Skip to content

Commit

Permalink
implement changes in docs - first attempt with passing tests
Browse files Browse the repository at this point in the history
add forgotten file
  • Loading branch information
ablaom committed Sep 10, 2020
1 parent e482296 commit 6f83bca
Show file tree
Hide file tree
Showing 10 changed files with 198 additions and 155 deletions.
4 changes: 4 additions & 0 deletions src/MLJTuning.jl
Expand Up @@ -9,6 +9,9 @@ export TunedModel
# defined in strategies/:
export Explicit, Grid, RandomSearch

# defined in selection_heuristics/:
export OptimizePrimaryAggregatedMeasurement

# defined in learning_curves.jl:
export learning_curve!, learning_curve

Expand Down Expand Up @@ -37,6 +40,7 @@ const DEFAULT_N = 10 # for when `default_n` is not implemented

include("utilities.jl")
include("tuning_strategy_interface.jl")
include("selection_heuristics.jl")
include("tuned_models.jl")
include("range_methods.jl")
include("strategies/explicit.jl")
Expand Down
19 changes: 19 additions & 0 deletions src/selection_heuristics.jl
@@ -0,0 +1,19 @@
abstract type SelectionHeuristic end


## OPTIMIZE AGGREGATED MEASURE

struct OptimizePrimaryAggregatedMeasurement <: SelectionHeuristic end

function best(heuristic::OptimizePrimaryAggregatedMeasurement, history)
measurements = [h.measurement[1] for h in history]
measure = first(history).measure[1]
if orientation(measure) == :score
measurements = -measurements
end
best_index = argmin(measurements)
return history[best_index]
end

MLJTuning.supports_heuristic(::Any, ::OptimizePrimaryAggregatedMeasurement) =
true
10 changes: 2 additions & 8 deletions src/strategies/grid.jl
Expand Up @@ -160,14 +160,8 @@ MLJTuning.models!(tuning::Grid,
verbosity) =
state.models[_length(history) + 1:end]

function tuning_report(tuning::Grid, history, state)

plotting = plotting_report(state.fields, state.parameter_scales, history)

# todo: remove collects?
return (history=history, plotting=plotting)

end
tuning_report(tuning::Grid, history, state) =
(plotting = plotting_report(state.fields, state.parameter_scales, history),)

function default_n(tuning::Grid, user_range)

Expand Down
2 changes: 1 addition & 1 deletion src/strategies/random_search.jl
Expand Up @@ -141,6 +141,6 @@ function tuning_report(tuning::RandomSearch, history, field_sampler_pairs)

plotting = plotting_report(fields, parameter_scales, history)

return (history=history, plotting=plotting)
return (plotting=plotting,)

end

0 comments on commit 6f83bca

Please sign in to comment.