Skip to content

Commit

Permalink
Merge dfd3cb7 into 727493d
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed Jul 21, 2023
2 parents 727493d + dfd3cb7 commit 3ad61e7
Show file tree
Hide file tree
Showing 23 changed files with 1,362 additions and 345 deletions.
21 changes: 14 additions & 7 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@ authors = ["MilesCranmer <miles.cranmer@gmail.com>"]
version = "0.20.0"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
DynamicExpressions = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
DynamicQuantities = "06fc5a27-2a28-4c7c-a15d-362465fb6821"
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7"
Expand All @@ -23,15 +25,12 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76"

[weakdeps]
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"

[extensions]
SymbolicRegressionSymbolicUtilsExt = "SymbolicUtils"
Tricks = "410a4b4d-49e4-4fbc-ab6d-cb71b17b3775"

[compat]
DynamicExpressions = "0.9"
Compat = "^4.2"
DynamicExpressions = "0.10"
DynamicQuantities = "^0.6.2"
JSON3 = "1"
LineSearches = "7"
LossFunctions = "0.6, 0.7, 0.8, 0.10"
Expand All @@ -46,8 +45,12 @@ Requires = "1"
SpecialFunctions = "0.10.1, 1, 2"
StatsBase = "0.33, 0.34"
SymbolicUtils = "0.19, ^1.0.5"
Tricks = "0.1"
julia = "1.6"

[extensions]
SymbolicRegressionSymbolicUtilsExt = "SymbolicUtils"

[extras]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand All @@ -56,7 +59,11 @@ MLJTestInterface = "72560011-54dd-4dc2-94f3-c5de45b75ecd"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Test", "SafeTestsets", "ForwardDiff", "LinearAlgebra", "MLJBase", "MLJTestInterface", "SymbolicUtils", "Zygote"]

[weakdeps]
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
4 changes: 2 additions & 2 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ MultitargetSRRegressor
equation_search(X::AbstractMatrix{T}, y::AbstractMatrix{T};
niterations::Int=10,
weights::Union{AbstractVector{T}, Nothing}=nothing,
varMap::Union{Array{String, 1}, Nothing}=nothing,
variable_names::Union{Array{String, 1}, Nothing}=nothing,
options::Options=Options(),
numprocs::Union{Int, Nothing}=nothing,
procs::Union{Array{Int, 1}, Nothing}=nothing,
Expand Down Expand Up @@ -59,7 +59,7 @@ eval_grad_tree_array(tree::Node, X::AbstractMatrix, options::Options; kws...)

```@docs
node_to_symbolic(tree::Node, options::Options;
varMap::Union{Array{String, 1}, Nothing}=nothing,
variable_names::Union{Array{String, 1}, Nothing}=nothing,
index_functions::Bool=false)
```

Expand Down
18 changes: 8 additions & 10 deletions docs/src/types.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,15 +75,13 @@ HallOfFame(options::Options, ::Type{T}, ::Type{L}) where {T<:DATA_TYPE,L<:LOSS_T

```@docs
Dataset
Dataset(
X::AbstractMatrix{T},
y::Union{AbstractVector{T},Nothing}=nothing;
weights::Union{AbstractVector{T},Nothing}=nothing,
variable_names::Union{Array{String,1},Nothing}=nothing,
extra::NamedTuple=NamedTuple(),
loss_type::Type=Nothing,
# Deprecated:
varMap=nothing,
) where {T<:DATA_TYPE}
Dataset(X::AbstractMatrix{T}, y::Union{AbstractVector{T},Nothing}=nothing;
weights::Union{AbstractVector{T}, Nothing}=nothing,
variable_names::Union{Array{String, 1}, Nothing}=nothing,
X_units::Union{AbstractVector, Nothing}=nothing,
y_units=nothing,
extra::NamedTuple=NamedTuple(),
loss_type::Type=Nothing,
)
update_baseline_loss!
```
18 changes: 6 additions & 12 deletions ext/SymbolicRegressionSymbolicUtilsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,35 +4,29 @@ import Base: convert
if isdefined(Base, :get_extension)
using SymbolicUtils: Symbolic
import SymbolicRegression: node_to_symbolic, symbolic_to_node
import SymbolicRegression: Node, Options, deprecate_varmap
import SymbolicRegression: Node, Options
else
using ..SymbolicUtils: Symbolic
import ..SymbolicRegression: node_to_symbolic, symbolic_to_node
import ..SymbolicRegression: Node, Options, deprecate_varmap
import ..SymbolicRegression: Node, Options
end

"""
node_to_symbolic(tree::Node, options::Options; kws...)
Convert an expression to SymbolicUtils.jl form.
"""
function node_to_symbolic(
tree::Node, options::Options; variable_names=nothing, varMap=nothing, kws...
)
variable_names = deprecate_varmap(variable_names, varMap, :node_to_symbolic)
return node_to_symbolic(tree, options.operators; varMap=variable_names, kws...)
function node_to_symbolic(tree::Node, options::Options; kws...)
return node_to_symbolic(tree, options.operators; kws...)
end

"""
symbolic_to_node(eqn::Symbolic, options::Options; kws...)
Convert a SymbolicUtils.jl expression to SymbolicRegression.jl's `Node` type.
"""
function symbolic_to_node(
eqn::Symbolic, options::Options; variable_names=nothing, varMap=nothing, kws...
)
variable_names = deprecate_varmap(variable_names, varMap, :symbolic_to_node)
return symbolic_to_node(eqn, options.operators; varMap=variable_names, kws...)
function symbolic_to_node(eqn::Symbolic, options::Options; kws...)
return symbolic_to_node(eqn, options.operators; kws...)
end

function convert(::Type{Symbolic}, tree::Node, options::Options; kws...)
Expand Down
2 changes: 1 addition & 1 deletion src/ConstantOptimization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ function opt_func(
) where {T<:DATA_TYPE,L<:LOSS_TYPE}
_set_constants!(x, constant_nodes)
# TODO(mcranmer): This should use score_func batching.
loss = eval_loss(tree, dataset, options)
loss = eval_loss(tree, dataset, options, false)
return loss::L
end

Expand Down
165 changes: 157 additions & 8 deletions src/Dataset.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
module DatasetModule

import DynamicQuantities:
AbstractDimensions,
Dimensions,
SymbolicDimensions,
Quantity,
uparse,
sym_uparse,
DEFAULT_DIM_BASE_TYPE

import ..UtilsModule: subscriptify
import ..ProgramConstantsModule: BATCH_DIM, FEATURE_DIM, DATA_TYPE, LOSS_TYPE
#! format: off
import ...deprecate_varmap
Expand Down Expand Up @@ -28,6 +38,17 @@ import ...deprecate_varmap
`update_baseline_loss!`.
- `variable_names::Array{String,1}`: The names of the features,
with shape `(nfeatures,)`.
- `pretty_variable_names::Array{String,1}`: A version of `variable_names`
but for printing to the terminal (e.g., with unicode versions).
- `y_variable_name::String`: The name of the output variable.
- `X_units`: Unit information of `X`. When used, this is a vector
of `DynamicQuantities.Quantity{<:Any,<:Dimensions}` with shape `(nfeatures,)`.
- `y_units`: Unit information of `y`. When used, this is a single
`DynamicQuantities.Quantity{<:Any,<:Dimensions}`.
- `X_sym_units`: Unit information of `X`. When used, this is a vector
of `DynamicQuantities.Quantity{<:Any,<:SymbolicDimensions}` with shape `(nfeatures,)`.
- `y_sym_units`: Unit information of `y`. When used, this is a single
`DynamicQuantities.Quantity{<:Any,<:SymbolicDimensions}`.
"""
mutable struct Dataset{
T<:DATA_TYPE,
Expand All @@ -36,6 +57,10 @@ mutable struct Dataset{
AY<:Union{AbstractVector{T},Nothing},
AW<:Union{AbstractVector{T},Nothing},
NT<:NamedTuple,
XU<:Union{AbstractVector{<:Quantity},Nothing},
YU<:Union{Quantity,Nothing},
XUS<:Union{AbstractVector{<:Quantity},Nothing},
YUS<:Union{Quantity,Nothing},
}
X::AX
y::AY
Expand All @@ -48,14 +73,24 @@ mutable struct Dataset{
use_baseline::Bool
baseline_loss::L
variable_names::Array{String,1}
pretty_variable_names::Array{String,1}
y_variable_name::String
X_units::XU
y_units::YU
X_sym_units::XUS
y_sym_units::YUS
end

"""
Dataset(X::AbstractMatrix{T}, y::Union{AbstractVector{T},Nothing}=nothing;
weights::Union{AbstractVector{T}, Nothing}=nothing,
variable_names::Union{Array{String, 1}, Nothing}=nothing,
y_variable_name::Union{String,Nothing}=nothing,
extra::NamedTuple=NamedTuple(),
loss_type::Type=Nothing)
loss_type::Type=Nothing,
X_units::Union{AbstractVector, Nothing}=nothing,
y_units=nothing,
)
Construct a dataset to pass between internal functions.
"""
Expand All @@ -64,11 +99,14 @@ function Dataset(
y::Union{AbstractVector{T},Nothing}=nothing;
weights::Union{AbstractVector{T},Nothing}=nothing,
variable_names::Union{Array{String,1},Nothing}=nothing,
y_variable_name::Union{String,Nothing}=nothing,
extra::NamedTuple=NamedTuple(),
loss_type::Type=Nothing,
loss_type::Type{Linit}=Nothing,
X_units::Union{AbstractVector,Nothing}=nothing,
y_units=nothing,
# Deprecated:
varMap=nothing,
) where {T<:DATA_TYPE}
) where {T<:DATA_TYPE,Linit}
Base.require_one_based_indexing(X)
y !== nothing && Base.require_one_based_indexing(y)
# Deprecation warning:
Expand All @@ -77,8 +115,15 @@ function Dataset(
n = size(X, BATCH_DIM)
nfeatures = size(X, FEATURE_DIM)
weighted = weights !== nothing
if variable_names === nothing
variable_names = ["x$(i)" for i in 1:nfeatures]
(variable_names, pretty_variable_names) = if variable_names === nothing
(["x$(i)" for i in 1:nfeatures], ["x$(subscriptify(i))" for i in 1:nfeatures])
else
(variable_names, variable_names)
end
y_variable_name = if y_variable_name === nothing
("y" variable_names) ? "y" : "target"
else
y_variable_name
end
avg_y = if y === nothing
nothing
Expand All @@ -89,11 +134,46 @@ function Dataset(
sum(y) / n
end
end
loss_type = (loss_type == Nothing) ? T : loss_type
out_loss_type = (Linit === Nothing) ? T : Linit
use_baseline = true
baseline = one(loss_type)
baseline = one(out_loss_type)
D = Dimensions{DEFAULT_DIM_BASE_TYPE}
SD = SymbolicDimensions{DEFAULT_DIM_BASE_TYPE}
y_si_units = get_units(T, D, y_units, uparse)
y_sym_units = get_units(T, SD, y_units, sym_uparse)

return Dataset{T,loss_type,typeof(X),typeof(y),typeof(weights),typeof(extra)}(
# TODO: Refactor
# This basically just ensures that if the `y` units are set,
# then the `X` units are set as well.
X_si_units = let (_X = get_units(T, D, X_units, uparse))
if _X === nothing && y_si_units !== nothing
get_units(T, D, [one(T) for _ in 1:nfeatures], uparse)
else
_X
end
end
X_sym_units = let _X = get_units(T, SD, X_units, sym_uparse)
if _X === nothing && y_sym_units !== nothing
get_units(T, SD, [one(T) for _ in 1:nfeatures], sym_uparse)
else
_X
end
end

error_on_mismatched_size(nfeatures, X_si_units)

return Dataset{
T,
out_loss_type,
typeof(X),
typeof(y),
typeof(weights),
typeof(extra),
typeof(X_si_units),
typeof(y_si_units),
typeof(X_sym_units),
typeof(y_sym_units),
}(
X,
y,
n,
Expand All @@ -105,7 +185,76 @@ function Dataset(
use_baseline,
baseline,
variable_names,
pretty_variable_names,
y_variable_name,
X_si_units,
y_si_units,
X_sym_units,
y_sym_units,
)
end
function Dataset(
X::AbstractMatrix,
y::Union{<:AbstractVector,Nothing}=nothing;
weights::Union{<:AbstractVector,Nothing}=nothing,
kws...,
)
T = promote_type(
eltype(X),
(y === nothing) ? eltype(X) : eltype(y),
(weights === nothing) ? eltype(X) : eltype(weights),
)
X = Base.Fix1(convert, T).(X)
if y !== nothing
y = Base.Fix1(convert, T).(y)
end
if weights !== nothing
weights = Base.Fix1(convert, T).(weights)
end
return Dataset(X, y; weights=weights, kws...)
end

# Base
function get_units(args...)
return error(
"Unit information must be passed as one of `AbstractDimensions`, `AbstractQuantity`, `AbstractString`, `Real`.",
)
end
function get_units(_, _, ::Nothing, ::Function)
return nothing
end
function get_units(::Type{T}, ::Type{D}, x::AbstractString, f::Function) where {T,D}
return convert(Quantity{T,D}, f(x))
end
function get_units(::Type{T}, ::Type{D}, x::Quantity, ::Function) where {T,D}
return convert(Quantity{T,D}, x)
end
function get_units(::Type{T}, ::Type{D}, x::AbstractDimensions, ::Function) where {T,D}
return convert(Quantity{T,D}, Quantity(one(T), x))
end
function get_units(::Type{T}, ::Type{D}, x::Real, ::Function) where {T,D}
return Quantity(convert(T, x), D)::Quantity{T,D}
end

# Derived
function get_units(::Type{T}, ::Type{D}, x::AbstractVector, f::Function) where {T,D}
return Quantity{T,D}[get_units(T, D, xi, f) for xi in x]
end

function error_on_mismatched_size(_, ::Nothing)
return nothing
end
function error_on_mismatched_size(nfeatures, X_units::AbstractVector)
if nfeatures != length(X_units)
error(
"Number of features ($(nfeatures)) does not match number of units ($(length(X_units)))",
)
end
return nothing
end

function has_units(dataset::Dataset)
return dataset.X_units !== nothing || dataset.y_units !== nothing
end

end

0 comments on commit 3ad61e7

Please sign in to comment.