diff --git a/src/tuning.jl b/src/tuning.jl
index e170aad45..a1eca5000 100644
--- a/src/tuning.jl
+++ b/src/tuning.jl
@@ -1,29 +1,42 @@
 abstract type TuningStrategy <: MLJ.MLJType end
 
-mutable struct Grid <: TuningStrategy
-    resolution::Int
-    parallel::Bool
+"""
+$TYPEDEF
+
+Grid object for a grid search tuning strategy.
+"""
+@with_kw mutable struct Grid <: TuningStrategy
+    resolution::Int = 10
+    parallel::Bool = true
 end
 
-Grid(;resolution=10, parallel=true) = Grid(resolution, parallel)
 
 MLJBase.show_as_constructed(::Type{<:Grid}) = true
 
+"""
+$TYPEDEF
+Container for a deterministic model to be tuned.
+"""
 mutable struct DeterministicTunedModel{T,M<:Deterministic} <: MLJ.Deterministic
     model::M
-    tuning::T # tuning strategy 
+    tuning::T # tuning strategy
     resampling # resampling strategy
     measure
     operation
     ranges::Union{Vector,ParamRange}
     minimize::Bool
     full_report::Bool
-    train_best::Bool 
+    train_best::Bool
 end
 
+"""
+$TYPEDEF
+
+Container for a probabilistic model to be tuned.
+"""
 mutable struct ProbabilisticTunedModel{T,M<:Probabilistic} <: MLJ.Probabilistic
     model::M
-    tuning::T # tuning strategy 
+    tuning::T # tuning strategy
     resampling # resampling strategy
     measure
     operation
@@ -76,13 +89,13 @@ function TunedModel(;model=nothing,
                     minimize=true, full_report=true,
                     train_best=true)
-    
+
     !isempty(ranges) || error("You need to specify ranges=... ")
     model !== nothing || error("You need to specify model=... ")
 
     message = clean!(model)
     isempty(message) || @info message
-    
+
     if model isa Deterministic
         return DeterministicTunedModel(model, tuning, resampling,
                                        measure, operation, ranges,
                                        minimize, full_report, train_best)
@@ -111,13 +124,13 @@ function MLJBase.fit(tuned_model::EitherTunedModel{Grid,M}, verbosity::Int, X, y
     end
 
     ranges isa Vector{<:ParamRange} ||
-        error("ranges must be a ParamRange object or a vector of "*
+        error("ranges must be a ParamRange object or a vector of " *
              "ParamRange objects. ")
-
-    
+
+
     parameter_names = [string(r.field) for r in ranges]
     scales = [scale(r) for r in ranges]
-    
+
     # the mutating model:
     clone = deepcopy(tuned_model.model)
 
@@ -137,7 +150,7 @@ function MLJBase.fit(tuned_model::EitherTunedModel{Grid,M}, verbosity::Int, X, y
         elseif range isa MLJ.NumericRange
             MLJ.iterator(range, tuned_model.tuning.resolution)
         else
-            throw(TypeError(:iterator, "", MLJ.ParamRange, rrange))
+            throw(TypeError(:iterator, "", MLJ.ParamRange, range))
         end
     end
 
@@ -154,10 +167,8 @@ function MLJBase.fit(tuned_model::EitherTunedModel{Grid,M}, verbosity::Int, X, y
 
     # initialize search for best model:
     best_model = deepcopy(tuned_model.model)
-    best_measurement =
-        tuned_model.minimize ? Inf : -Inf
-    s =
-        tuned_model.minimize ? 1 : -1
+    best_measurement = ifelse(tuned_model.minimize, Inf, -Inf)
+    s = ifelse(tuned_model.minimize, 1, -1)
 
     # evaluate all the models using specified resampling:
     # TODO: parallelize!
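Here `@with_kw` (from Parameters.jl) generates the keyword constructor that the deleted hand-written `Grid(;resolution=10, parallel=true)` method provided, so `Grid()` and `Grid(resolution=5)` keep working. A minimal sketch of the equivalence, with a stand-in abstract type in place of MLJ's:

    using Parameters  # provides @with_kw

    abstract type TuningStrategy end  # stand-in for the MLJ abstract type

    # @with_kw records the field defaults and injects a keyword
    # constructor, replacing the hand-written method deleted above:
    @with_kw mutable struct Grid <: TuningStrategy
        resolution::Int = 10
        parallel::Bool = true
    end

    Grid()                # Grid(resolution = 10, parallel = true)
    Grid(resolution = 5)  # override a single field by keyword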
@@ -165,13 +176,13 @@ function MLJBase.fit(tuned_model::EitherTunedModel{Grid,M}, verbosity::Int, X, y
     meter = Progress(N+1, dt=0, desc="Iterating over a $N-point grid: ",
                      barglyphs=BarGlyphs("[=> ]"), barlen=25, color=:yellow)
 
     verbosity != 1 || next!(meter)
-    for i in 1:N 
+    for i in 1:N
 
         verbosity != 1 || next!(meter)
 
         A_row = Tuple(A[i,:])
-        # new_params = copy(nested_iterators, A_row) 
+        # new_params = copy(nested_iterators, A_row)
 
         # mutate `clone` (the model to which `resampler` points):
         for k in 1:n_iterators
@@ -181,19 +192,17 @@ function MLJBase.fit(tuned_model::EitherTunedModel{Grid,M}, verbosity::Int, X, y
 
         if verbosity == 2
             fit!(resampling_machine, verbosity=0)
-
         else
             fit!(resampling_machine, verbosity=verbosity-1)
         end
 
         e = mean(evaluate(resampling_machine))
         if verbosity > 1
-            text = reduce(*, ["$(parameter_names[j])=$(A_row[j]) \t"
-                              for j in 1:length(A_row)])
+            text = prod("$(parameter_names[j])=$(A_row[j]) \t" for j in 1:length(A_row))
             text *= "measurement=$e"
             println(text)
         end
-        
+
         if s*(best_measurement - e) > 0
             best_model = deepcopy(clone)
             best_measurement = e
@@ -203,7 +212,7 @@ function MLJBase.fit(tuned_model::EitherTunedModel{Grid,M}, verbosity::Int, X, y
         # models[i] = deepcopy(clone)
         measurements[i] = e
 
     end
-    
+
 
     if tuned_model.train_best
@@ -234,11 +243,11 @@ function MLJBase.fit(tuned_model::EitherTunedModel{Grid,M}, verbosity::Int, X, y
                       measurements=[best_measurement, ][1:0], # empty vector
                       best_measurement=best_measurement)
     end
-    
+
     cache = nothing
-    
+
     return fitresult, cache, report
-    
+
 end
 
 MLJBase.fitted_params(::EitherTunedModel, fitresult) = (best_model=fitresult.model,)
@@ -268,7 +277,7 @@ MLJBase.target_scitype_union(::Type{<:ProbabilisticTunedModel{T,M}}) where {T,M}
 MLJBase.input_is_multivariate(::Type{<:ProbabilisticTunedModel{T,M}}) where {T,M} =
     MLJBase.input_is_multivariate(M)
 
-## LEARNING CURVES 
+## LEARNING CURVES
 
 """
     curve = learning_curve!(mach; resolution=30, resampling=Holdout(), measure=rms, operation=predict, range=nothing, n=1)
@@ -310,9 +319,9 @@
 plot(curves.parameter_values, curves.measurements, xlab=curves.parameter_name)
 
 """
 function learning_curve!(mach::Machine{<:Supervised};
-                         resolution=30,
-                         resampling=Holdout(),
-                         measure=rms, operation=predict, range=nothing, verbosity=1, n=1)
+                         resolution=30, resampling=Holdout(),
+                         measure=rms, operation=predict,
+                         range=nothing, verbosity=1, n=1)
 
     range !== nothing || error("No param range specified. Use range=... ")
@@ -328,14 +337,10 @@ function learning_curve!(mach::Machine{<:Supervised};
     parameter_name=report.parameter_names[1]
     parameter_scale=report.parameter_scales[1]
     parameter_values=[report.parameter_values[:, 1]...]
-    measurements_ =
-        n == 1 ? [measurements...] : measurements
-    
+    measurements_ = (n == 1) ? [measurements...] : measurements
+
     return (parameter_name=parameter_name,
             parameter_scale=parameter_scale,
             parameter_values=parameter_values,
             measurements = measurements_)
 end
-
-
-
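The `ifelse` pair introduced in the fit method drives a single comparison inside the grid loop: with `s = 1` for minimization and `s = -1` for maximization, `s*(best_measurement - e) > 0` accepts an improved measurement in either direction. A standalone sketch of the trick (`best_of` is a made-up helper, not part of this patch):

    # s = 1 selects the minimum, s = -1 the maximum; one comparison covers both.
    function best_of(measurements; minimize=true)
        best = ifelse(minimize, Inf, -Inf)
        s = ifelse(minimize, 1, -1)
        for e in measurements
            if s*(best - e) > 0  # e improves on best in the chosen direction
                best = e
            end
        end
        return best
    end

    best_of([0.9, 0.7, 0.8])                  # 0.7
    best_of([0.9, 0.7, 0.8], minimize=false)  # 0.9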