MLJ Integration #226

Merged: 66 commits, Jul 5, 2023
Commits

All 66 commits by MilesCranmer.

1b963d1  Clean up printing in `Options` (Jul 3, 2023)
1c241e0  Working MLJModelInterface integration (Jul 3, 2023)
3cea91f  Return more report keys (Jul 3, 2023)
6239790  Formatting (Jul 3, 2023)
82a8c0b  Move selection functions (Jul 4, 2023)
7df1309  Have user pass full `Options()` to MLJ model (Jul 4, 2023)
3b7e3a7  Fix import order of MLJ (Jul 4, 2023)
32b2003  Deprecate `return_state` in options (Jul 4, 2023)
a298e13  Formatting (Jul 4, 2023)
edfb424  Automatically forward all option kwargs to MLJ (Jul 4, 2023)
6cc15d3  Store `Options` in fitresult to avoid recreation (Jul 4, 2023)
afe571f  Fix incorrect deprecated kwarg (Jul 4, 2023)
b70fbfd  Rename import to `MMI` (Jul 4, 2023)
0be9e29  Add MLJ metadata (Jul 4, 2023)
5047355  Add test of MLJ interface (Jul 4, 2023)
0509ee1  Turn off `variable_names == keys` behavior (Jul 4, 2023)
1ccf2c1  Fix MLJ output for multi-out (Jul 4, 2023)
d54f515  Formatting (Jul 4, 2023)
05426d1  Get iterative model fits to work with MLJ (Jul 4, 2023)
7f746fc  Expand `target_scitype` to include tables (Jul 4, 2023)
4d7a102  Specify input types in `Options` constructor (Jul 4, 2023)
19da06b  Missing parentheses (Jul 4, 2023)
d4d1e67  Only continuous input allowed (Jul 4, 2023)
a4b43d4  Split into MultiSRRegressor for multi-output (Jul 4, 2023)
8265113  Test MultiSRRegressor (Jul 4, 2023)
14834bb  Undo typing of operators (Jul 4, 2023)
03ac6c7  Other typing fixes (Jul 4, 2023)
f519d63  Allow `npopulations` to be nothing (Jul 4, 2023)
945bbc8  Automatically get column names using MMI.schema (Jul 4, 2023)
8cb7625  Fix `getcolnames` (Jul 4, 2023)
75ff44b  Update src/MLJInterface.jl (Jul 4, 2023)
2af66b7  Rename to MultitargetSRRegressor (Jul 4, 2023)
bab79e2  Formatting (Jul 4, 2023)
5680300  Ensure matrices respect column-major (Jul 4, 2023)
af441c4  Brute force way to get column names (Jul 4, 2023)
734e34b  Copy the correct way (Jul 4, 2023)
2edbcb4  Copy the correct way using permutedims (Jul 4, 2023)
e567a24  Safety check for `SRRegressor` (Jul 4, 2023)
08456ff  Fix copying in prediction as well (Jul 4, 2023)
0e95937  Update src/MLJInterface.jl (Jul 4, 2023)
f96751c  Fix `variable_names` and pass to `report` as well (Jul 4, 2023)
50bfaec  Add variable names test (Jul 4, 2023)
47fdb69  Allow normal matrix input to models (Jul 4, 2023)
7a16f7c  Add test of predictions (Jul 4, 2023)
2222347  Add tests for helpful error messages (Jul 4, 2023)
6139bba  Use `metadata_pkg` and `metadata_model` instead (Jul 4, 2023)
e585bd4  Avoid `ndims` for better portability (Jul 4, 2023)
38758dc  Missing `predict` import (Jul 4, 2023)
bd70964  Clean up `metadata_pkg` (Jul 4, 2023)
b2d8542  Avoid `size` when table is passed (Jul 4, 2023)
a8baf7b  Automatically flatten `y` if needed (Jul 4, 2023)
d417ef0  Generate MLJ docstrings (Jul 4, 2023)
79ba7a6  Document `selection_method` (Jul 4, 2023)
09fb9b5  Fix deprecated description parameter (Jul 4, 2023)
db5be13  Store schema in fitresult for consistent output (Jul 4, 2023)
ad88b44  Missing `schema` return value (Jul 4, 2023)
b20270d  Revert storage of schema in `fitresult` (Jul 4, 2023)
6bc8bc4  Always validate variable names (Jul 4, 2023)
3fbb4c3  Correct behavior for multitarget output (Jul 4, 2023)
dc6533e  Add detailed docstrings for MLJ interface (Jul 4, 2023)
e8a8398  Give example of printing equations in example (Jul 4, 2023)
2d9cc48  Fix incorrect string interpolation (Jul 4, 2023)
c2a6c1f  Add `Report` section to docstrings (Jul 5, 2023)
66382d7  Make sample weights work for MultitargetSRRegressor (Jul 5, 2023)
0fdcb56  Fix PySR compatibility in `Options` constructor (Jul 5, 2023)
be60f13  Bump version with MLJ integration (Jul 5, 2023)
8 changes: 7 additions & 1 deletion Project.toml
@@ -10,6 +10,8 @@ DynamicExpressions = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
 JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
 LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
 LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7"
+MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
+MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
 Optim = "429524aa-4258-5aef-a3af-852621145aeb"
 Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
 PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
@@ -33,6 +35,8 @@ DynamicExpressions = "0.9"
 JSON3 = "1"
 LineSearches = "7"
 LossFunctions = "0.6, 0.7, 0.8, 0.10"
+MLJModelInterface = "1.5, 1.6, 1.7, 1.8"
+MacroTools = "0.4, 0.5"
 Optim = "0.19, 1.1"
 Pkg = "1"
 PrecompileTools = "1"
@@ -47,10 +51,12 @@ julia = "1.6"
 [extras]
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
+MLJTestInterface = "72560011-54dd-4dc2-94f3-c5de45b75ecd"
 SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
 SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

 [targets]
-test = ["Test", "SafeTestsets", "ForwardDiff", "LinearAlgebra", "SymbolicUtils", "Zygote"]
+test = ["Test", "SafeTestsets", "ForwardDiff", "LinearAlgebra", "MLJBase", "MLJTestInterface", "SymbolicUtils", "Zygote"]
40 changes: 40 additions & 0 deletions src/HallOfFame.jl
@@ -161,4 +161,44 @@ function string_dominating_pareto_curve(
return output
end

function format_hall_of_fame(
hof::HallOfFame{T,L}, options, baseline_loss::L
) where {T<:DATA_TYPE,L<:LOSS_TYPE}
ZERO_POINT = L(1e-10)

dominating = calculate_pareto_frontier(hof)
cur_loss = baseline_loss
last_loss = cur_loss
last_complexity = 0

trees = [member.tree for member in dominating]
losses = [member.loss for member in dominating]
complexities = [compute_complexity(member, options) for member in dominating]
scores = Array{L}(undef, length(dominating))

for i in 1:length(dominating)
complexity = complexities[i]
cur_loss = losses[i]
delta_c = complexity - last_complexity
delta_l_mse = log(abs(cur_loss / last_loss) + ZERO_POINT)

scores[i] = -delta_l_mse / delta_c
last_loss = cur_loss
last_complexity = complexity
end
return (; trees=trees, scores=scores, losses=losses, complexities=complexities)
end
function format_hall_of_fame(
hof::AH, options, baseline_loss
) where {T,L,H<:HallOfFame{T,L},AH<:AbstractVector{H}}
outs = [format_hall_of_fame(h, options, baseline_loss) for h in hof]
return (;
trees=[out.trees for out in outs],
scores=[out.scores for out in outs],
losses=[out.losses for out in outs],
complexities=[out.complexities for out in outs],
)
end
# TODO: Re-use this in `string_dominating_pareto_curve`

end
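
The score assigned above to each Pareto-frontier member is the negative log-ratio of successive losses divided by the increase in complexity, so a large loss reduction bought with few extra complexity units scores highest. Below is a minimal standalone sketch of that arithmetic; the helper name `pareto_scores` and the input values are made up for illustration, with `ZERO_POINT` copied from the code above.

# Standalone sketch of the scoring rule in `format_hall_of_fame`
# (hypothetical inputs, for illustration only).
function pareto_scores(losses, complexities, baseline_loss)
    ZERO_POINT = 1e-10  # avoids log(0) if a loss ratio vanishes
    last_loss, last_complexity = baseline_loss, 0
    scores = similar(losses)
    for i in eachindex(losses)
        delta_c = complexities[i] - last_complexity
        delta_l_mse = log(abs(losses[i] / last_loss) + ZERO_POINT)
        scores[i] = -delta_l_mse / delta_c
        last_loss, last_complexity = losses[i], complexities[i]
    end
    return scores
end

pareto_scores([0.5, 0.2, 0.19], [1, 5, 12], 1.0)
# ≈ [0.69, 0.23, 0.007]: the second member buys a 2.5x loss reduction with
# 4 extra complexity units, so it still scores well; the third barely helps.
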
322 changes: 322 additions & 0 deletions src/MLJInterface.jl
@@ -0,0 +1,322 @@
module MLJInterfaceModule

using Optim: Optim
import MLJModelInterface as MMI
import DynamicExpressions: eval_tree_array, string_tree, Node
import LossFunctions: SupervisedLoss
import ..CoreModule: Options, Dataset, MutationWeights, LOSS_TYPE
import ..CoreModule.OptionsModule: DEFAULT_OPTIONS, OPTION_DESCRIPTIONS
import ..ComplexityModule: compute_complexity
import ..HallOfFameModule: HallOfFame, calculate_pareto_frontier, format_hall_of_fame
#! format: off
import ..equation_search
#! format: on

abstract type AbstractSRRegressor <: MMI.Deterministic end

# TODO: To reduce code re-use, we could forward these defaults from
# `equation_search`, similar to what we do for `Options`.

"""Generate an `SRRegressor` struct containing all the fields in `Options`."""
function modelexpr(model_name::Symbol)
struct_def = :(Base.@kwdef mutable struct $(model_name) <: AbstractSRRegressor
niterations::Int = 10
parallelism::Symbol = :multithreading
numprocs::Union{Int,Nothing} = nothing
procs::Union{Vector{Int},Nothing} = nothing
addprocs_function::Union{Function,Nothing} = nothing
runtests::Bool = true
loss_type::Type = Nothing
selection_method::Function = choose_best
end)
fields = last(last(struct_def.args).args).args

# Add everything from `Options` constructor directly to struct:
for (i, option) in enumerate(DEFAULT_OPTIONS)
insert!(fields, i, Expr(:(=), option.args...))
end

# We also need to create the `get_options` function, based on this:
constructor = :(Options(;))
constructor_fields = last(constructor.args).args
for option in DEFAULT_OPTIONS
symb = getsymb(first(option.args))
push!(constructor_fields, Expr(:kw, symb, Expr(:(.), :m, Core.QuoteNode(symb))))
end

return quote
$struct_def
function get_options(m::$(model_name))
return $constructor
end
end
end
function getsymb(ex::Symbol)
return ex
end
function getsymb(ex::Expr)
for arg in ex.args
isa(arg, Symbol) && return arg
s = getsymb(arg)
isa(s, Symbol) && return s
end
return nothing
end

"""Get an equivalent `Options()` object for a particular regressor."""
function get_options(::AbstractSRRegressor) end

eval(modelexpr(:SRRegressor))
eval(modelexpr(:MultitargetSRRegressor))

# Cleaning already taken care of by `Options` and `equation_search`
function full_report(m::AbstractSRRegressor, fitresult)
_, hof = fitresult.state
# TODO: Adjust baseline loss
formatted = format_hall_of_fame(hof, fitresult.options, 1.0)
equation_strings = get_equation_strings_for(
m, formatted.trees, fitresult.options, fitresult.variable_names
)
best_idx = dispatch_selection_for(
m, formatted.trees, formatted.losses, formatted.scores, formatted.complexities
)
return (;
best_idx=best_idx,
equations=formatted.trees,
equation_strings=equation_strings,
losses=formatted.losses,
complexities=formatted.complexities,
scores=formatted.scores,
)
end

MMI.clean!(::AbstractSRRegressor) = ""

# TODO: Enable `verbosity` being passed to `equation_search`
function MMI.fit(m::AbstractSRRegressor, verbosity, X, y, w=nothing)
return MMI.update(m, verbosity, (; state=nothing), nothing, X, y, w)
end
function MMI.update(
m::AbstractSRRegressor, verbosity, old_fitresult, old_cache, X, y, w=nothing
)
options = get(old_fitresult, :options, get_options(m))
X_t, variable_names = get_matrix_and_colnames(X)
y_t, y_variable_names = format_input_for(m, y)
search_state = equation_search(
X_t,
y_t;
niterations=m.niterations,
weights=w,
variable_names=variable_names,
options=options,
parallelism=m.parallelism,
numprocs=m.numprocs,
procs=m.procs,
addprocs_function=m.addprocs_function,
runtests=m.runtests,
saved_state=old_fitresult.state,
return_state=true,
loss_type=m.loss_type,
)
fitresult = (;
state=search_state,
options=options,
variable_names=variable_names,
y_variable_names=y_variable_names,
)
return (fitresult, nothing, full_report(m, fitresult))
end
function get_matrix_and_colnames(X)
sch = MMI.istable(X) ? MMI.schema(X) : nothing
Xm_t = MMI.matrix(X; transpose=true)
colnames = if sch === nothing
[map(i -> "x$(i)", axes(Xm_t, 1))...]
else
[string.(sch.names)...]
end
return Xm_t, colnames
end

function format_input_for(::SRRegressor, y)
@assert(
!(MMI.istable(y) || (length(size(y)) == 2 && size(y, 2) > 1)),
"For multi-output regression, please use `MultitargetSRRegressor`."
)
return vec(y), nothing
end
function format_input_for(::MultitargetSRRegressor, y)
@assert(
MMI.istable(y) || (length(size(y)) == 2 && size(y, 2) > 1),
"For single-output regression, please use `SRRegressor`."
)
return get_matrix_and_colnames(y)
end
function validate_variable_names(variable_names, fitresult)
@assert(
variable_names == fitresult.variable_names,
"Variable names do not match fitted regressor."
)
end

function MMI.fitted_params(m::AbstractSRRegressor, fitresult)
report = full_report(m, fitresult)
return (;
best_idx=report.best_idx,
equations=report.equations,
equation_strings=report.equation_strings,
)
end
function MMI.predict(m::SRRegressor, fitresult, Xnew)
params = MMI.fitted_params(m, fitresult)
Xnew_t, variable_names = get_matrix_and_colnames(Xnew)
validate_variable_names(variable_names, fitresult)
eq = params.equations[params.best_idx]
out, flag = eval_tree_array(eq, Xnew_t, fitresult.options)
!flag && error("Detected a NaN in evaluating expression.")
return out
end
function MMI.predict(m::MultitargetSRRegressor, fitresult, Xnew)
params = MMI.fitted_params(m, fitresult)
Xnew_t, variable_names = get_matrix_and_colnames(Xnew)
validate_variable_names(variable_names, fitresult)
equations = params.equations
best_idx = params.best_idx
outs = [
let (out, flag) = eval_tree_array(eq[i], Xnew_t, fitresult.options)
!flag && error("Detected a NaN in evaluating expression.")
out
end for (i, eq) in zip(best_idx, equations)
]
return reduce(hcat, outs)
end

function get_equation_strings_for(::SRRegressor, trees, options, variable_names)
return (t -> string_tree(t, options; variable_names=variable_names)).(trees)
end
function get_equation_strings_for(::MultitargetSRRegressor, trees, options, variable_names)
return [
(t -> string_tree(t, options; variable_names=variable_names)).(ts) for ts in trees
]
end

function choose_best(; trees, losses::Vector{L}, scores, complexities) where {L<:LOSS_TYPE}
# Same as in PySR:
# https://github.com/MilesCranmer/PySR/blob/e74b8ad46b163c799908b3aa4d851cf8457c79ef/pysr/sr.py#L2318-L2332
# threshold = 1.5 * minimum_loss
# Then, we get max score of those below the threshold.
threshold = 1.5 * minimum(losses)
return argmax([
(losses[i] <= threshold) ? scores[i] : typemin(L) for i in eachindex(losses)
])
end

function dispatch_selection_for(m::SRRegressor, trees, losses, scores, complexities)
return m.selection_method(;
trees=trees, losses=losses, scores=scores, complexities=complexities
)::Integer
end
function dispatch_selection_for(
m::MultitargetSRRegressor, trees, losses, scores, complexities
)
return [
m.selection_method(;
trees=trees[i], losses=losses[i], scores=scores[i], complexities=complexities[i]
)::Integer for i in eachindex(trees)
]
end

MMI.metadata_pkg(
AbstractSRRegressor;
name="SymbolicRegression",
uuid="8254be44-1295-4e6a-a16d-46603ac705cb",
url="https://github.com/MilesCranmer/SymbolicRegression.jl",
julia=true,
license="Apache-2.0",
is_wrapper=false,
)

MMI.metadata_model(
SRRegressor;
input_scitype=Union{MMI.Table(MMI.Continuous),AbstractMatrix{<:MMI.Continuous}},
target_scitype=AbstractVector{<:MMI.Continuous},
supports_weights=true,
reports_feature_importances=false,
load_path="SymbolicRegression.MLJInterfaceModule.SRRegressor",
human_name="Symbolic Regression via Evolutionary Search",
)
MMI.metadata_model(
MultitargetSRRegressor;
input_scitype=Union{MMI.Table(MMI.Continuous),AbstractMatrix{<:MMI.Continuous}},
target_scitype=Union{MMI.Table(MMI.Continuous),AbstractMatrix{<:MMI.Continuous}},
supports_weights=true,
reports_feature_importances=false,
load_path="SymbolicRegression.MLJInterfaceModule.MultitargetSRRegressor",
human_name="Multi-Target Symbolic Regression via Evolutionary Search",
)

function tag_with_docstring(model_name::Symbol, description::String)
docstring = """$(MMI.doc_header(eval(model_name)))

# Arguments
"""

# TODO: These ones are copied (or written) manually:
append_arguments = """- `niterations::Int=10`: The number of iterations to run the search for.
More iterations will improve the results.
- `parallelism=:multithreading`: What parallelism mode to use.
The options are `:multithreading`, `:multiprocessing`, and `:serial`.
By default, multithreading will be used. Multithreading uses less memory,
but multiprocessing can handle multi-node compute. If using `:multithreading`
mode, the number of threads available to Julia will be used. If using
`:multiprocessing`, `numprocs` processes will be created dynamically if
`procs` is unset. If you have already allocated processes, pass them
to the `procs` argument and they will be used.
You may also pass a string instead of a symbol, like `"multithreading"`.
- `numprocs::Union{Int, Nothing}=nothing`: The number of processes to use,
if you want `equation_search` to set this up automatically. By default
this will be `4`, but can be any number (you should pick a number <=
the number of cores available).
- `procs::Union{Vector{Int}, Nothing}=nothing`: If you have set up
a distributed run manually with `procs = addprocs()` and `@everywhere`,
pass the `procs` to this keyword argument.
- `addprocs_function::Union{Function, Nothing}=nothing`: If using multiprocessing
(`parallelism=:multiprocessing`) and not passing `procs` manually,
then they will be allocated dynamically using `addprocs`. However,
you may also pass a custom function to use instead of `addprocs`.
This function should take a single positional argument,
which is the number of processes to use, as well as the `lazy` keyword argument.
For example, if set up on a slurm cluster, you could pass
`addprocs_function = addprocs_slurm`, which will set up slurm processes.
- `runtests::Bool=true`: Whether to run (quick) tests before starting the
search, to see if there will be any problems during the equation search
related to the host environment.
- `loss_type::Type=Nothing`: If you would like to use a different type
for the loss than for the data you passed, specify the type here.
Note that if you pass complex data `::Complex{L}`, then the loss
type will automatically be set to `L`.
- `selection_method::Function`: Function to select an expression from
the Pareto frontier for use in `predict`. See `SymbolicRegression.MLJInterfaceModule.choose_best`
for an example. This function should return a single integer specifying
the index of the expression to use. By default, `choose_best` maximizes
the score (a pound-for-pound rating) among expressions whose loss is within
1.5x of the minimum loss. To fix the index at `5`, you could just write `Returns(5)`.
"""

# Remove common indentation:
docstring = replace(docstring, r"^ " => "")
extra_arguments = replace(append_arguments, r"^ " => "")

# Add parameter descriptions:
docstring = docstring * OPTION_DESCRIPTIONS
docstring = docstring * extra_arguments
return quote
@doc $docstring $model_name
end
end

#! format: off
eval(tag_with_docstring(:SRRegressor, "Symbolic Regression via Evolutionary Search"))
eval(tag_with_docstring(:MultitargetSRRegressor, "Multi-Target Symbolic Regression via Evolutionary Search"))
#! format: on

end
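
A minimal sketch of the end-to-end workflow these docstrings describe. The dataset, hyperparameter values, and the `my_selector` function below are made up for illustration; `machine`, `fit!`, `predict`, and `report` are the standard MLJ API, and `SRRegressor` is assumed to be exported (it is otherwise reachable at `SymbolicRegression.MLJInterfaceModule.SRRegressor`).

using MLJ                 # provides machine, fit!, predict, report
using SymbolicRegression  # assumed to export SRRegressor

# Hypothetical data: any Tables.jl-compatible table (or plain matrix) works.
X = (a=rand(100), b=rand(100))
y = @. 2 * cos(X.a * 23.5) - X.b^2

# A custom selection rule must accept these keyword arguments and return an
# integer index into the Pareto frontier; here we simply take the lowest loss.
my_selector(; trees, losses, scores, complexities) = argmin(losses)

# `Options` keyword arguments (binary_operators, unary_operators, ...) are
# forwarded directly through the model struct.
model = SRRegressor(;
    niterations=40,
    binary_operators=[+, -, *, /],
    unary_operators=[cos],
    selection_method=my_selector,
)

mach = machine(model, X, y)
fit!(mach)

r = report(mach)                         # best_idx, equations, equation_strings, ...
println(r.equation_strings[r.best_idx])  # print the selected expression

yhat = predict(mach, X)                  # evaluates the selected expression on X

For `MultitargetSRRegressor`, `y` would instead be a table or matrix with one column per target, and each field of the report becomes a vector with one entry per target.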