Skip to content

Commit

Permalink
Store Options in fitresult to avoid recreation
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed Jul 4, 2023
1 parent ef9c89f commit 560fd78
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 25 deletions.
52 changes: 28 additions & 24 deletions src/MLJInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,29 +13,31 @@ import ..equation_search

abstract type AbstractSRRegressor <: MLJModelInterface.Deterministic end

const sr_regressor_base = :(Base.@kwdef mutable struct SRRegressor <: AbstractSRRegressor
niterations::Int = 10
parallelism::Symbol = :multithreading
numprocs::Union{Int,Nothing} = nothing
procs::Union{Vector{Int},Nothing} = nothing
addprocs_function::Union{Function,Nothing} = nothing
runtests::Bool = true
loss_type::Type = Nothing
selection_method::Function = choose_best
end)

"""Generate an`SRRegressor` struct containing all the fields in `Options`."""
const sr_regressor_template =
:(Base.@kwdef mutable struct SRRegressor <: AbstractSRRegressor
niterations::Int = 10
parallelism::Symbol = :multithreading
numprocs::Union{Int,Nothing} = nothing
procs::Union{Vector{Int},Nothing} = nothing
addprocs_function::Union{Function,Nothing} = nothing
runtests::Bool = true
loss_type::Type = Nothing
selection_method::Function = choose_best
end)
# TODO: To reduce code re-use, we could forward these defaults from
# `equation_search`, similar to what we do for `Options`.

"""Generate an `SRRegressor` struct containing all the fields in `Options`."""
function modelexpr()
# MLJModelInterface.@mlj_model
struct_def = copy(sr_regressor_base)
struct_def = deepcopy(sr_regressor_template)
fields = last(last(struct_def.args).args).args
i = 1

# Add everything from `Options` constructor directly to struct:
for option in DEFAULT_OPTIONS
for (i, option) in enumerate(DEFAULT_OPTIONS)
insert!(fields, i, Expr(:(=), option.args...))
i += 1
end

# We also need to create the `get_options` function, based on this:
constructor = :(Options(;))
constructor_fields = last(constructor.args).args
for option in DEFAULT_OPTIONS
Expand Down Expand Up @@ -69,10 +71,10 @@ eval(modelexpr())

# Cleaning already taken care of by `Options` and `equation_search`
function full_report(m::AbstractSRRegressor, fitresult)
_, hof = fitresult
_, hof = fitresult.state
# TODO: Adjust baseline loss
formatted = format_hall_of_fame(hof, get_options(m), 1.0)
equation_strings = get_equation_strings(formatted.trees, get_options(m))
formatted = format_hall_of_fame(hof, fitresult.options, 1.0)
equation_strings = get_equation_strings(formatted.trees, fitresult.options)
best_idx = dispatch_selection(
m.selection_method,
formatted.trees,
Expand All @@ -95,13 +97,14 @@ end
MLJModelInterface.clean!(::AbstractSRRegressor) = ""

function MLJModelInterface.fit(m::AbstractSRRegressor, verbosity, X, y, w=nothing)
fitresult = equation_search(
options = get_options(m)
search_state = equation_search(
X,
y;
niterations=m.niterations,
weights=w,
variable_names=nothing,
options=get_options(m),
options=options,
parallelism=m.parallelism,
numprocs=m.numprocs,
procs=m.procs,
Expand All @@ -111,6 +114,7 @@ function MLJModelInterface.fit(m::AbstractSRRegressor, verbosity, X, y, w=nothin
return_state=true,
loss_type=m.loss_type,
)
fitresult = (; state=search_state, options=options)
return (fitresult, nothing, full_report(m, fitresult))
end
function MLJModelInterface.fitted_params(m::AbstractSRRegressor, fitresult)
Expand All @@ -127,14 +131,14 @@ function MLJModelInterface.predict(m::AbstractSRRegressor, fitresult, Xnew)
best_idx = params.best_idx
if isa(best_idx, Vector)
outs = [
let out, flag = eval_tree_array(eq[i], Xnew, get_options(m))
let out, flag = eval_tree_array(eq[i], Xnew, fitresult.options)
!flag && error("Detected a NaN in evaluating expression.")
out
end for (i, eq) in zip(best_idx, equations)
]
return reduce(hcat, outs)
else
out, flag = eval_tree_array(equations[best_idx], Xnew, get_options(m))
out, flag = eval_tree_array(equations[best_idx], Xnew, fitresult.options)
!flag && error("Detected a NaN in evaluating expression.")
return out
end
Expand Down
1 change: 0 additions & 1 deletion src/Utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -185,5 +185,4 @@ function _save_kwargs(log_variable::Symbol, fdef::Expr)
end
end


end

0 comments on commit 560fd78

Please sign in to comment.