Skip to content

Commit

Permalink
Merge 6f83bca into 9ea8076
Browse files Browse the repository at this point in the history
  • Loading branch information
ablaom committed Sep 10, 2020
2 parents 9ea8076 + 6f83bca commit 1faaf51
Show file tree
Hide file tree
Showing 11 changed files with 405 additions and 303 deletions.
355 changes: 207 additions & 148 deletions README.md

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions src/MLJTuning.jl
Expand Up @@ -9,6 +9,9 @@ export TunedModel
# defined in strategies/:
export Explicit, Grid, RandomSearch

# defined in selection_heuristics/:
export OptimizePrimaryAggregatedMeasurement

# defined in learning_curves.jl:
export learning_curve!, learning_curve

Expand Down Expand Up @@ -37,6 +40,7 @@ const DEFAULT_N = 10 # for when `default_n` is not implemented

include("utilities.jl")
include("tuning_strategy_interface.jl")
include("selection_heuristics.jl")
include("tuned_models.jl")
include("range_methods.jl")
include("strategies/explicit.jl")
Expand Down
19 changes: 19 additions & 0 deletions src/selection_heuristics.jl
@@ -0,0 +1,19 @@
abstract type SelectionHeuristic end


## OPTIMIZE AGGREGATED MEASURE

struct OptimizePrimaryAggregatedMeasurement <: SelectionHeuristic end

function best(heuristic::OptimizePrimaryAggregatedMeasurement, history)
measurements = [h.measurement[1] for h in history]
measure = first(history).measure[1]
if orientation(measure) == :score
measurements = -measurements
end
best_index = argmin(measurements)
return history[best_index]
end

MLJTuning.supports_heuristic(::Any, ::OptimizePrimaryAggregatedMeasurement) =
true
10 changes: 2 additions & 8 deletions src/strategies/grid.jl
Expand Up @@ -160,14 +160,8 @@ MLJTuning.models!(tuning::Grid,
verbosity) =
state.models[_length(history) + 1:end]

function tuning_report(tuning::Grid, history, state)

plotting = plotting_report(state.fields, state.parameter_scales, history)

# todo: remove collects?
return (history=history, plotting=plotting)

end
tuning_report(tuning::Grid, history, state) =
(plotting = plotting_report(state.fields, state.parameter_scales, history),)

function default_n(tuning::Grid, user_range)

Expand Down
2 changes: 1 addition & 1 deletion src/strategies/random_search.jl
Expand Up @@ -141,6 +141,6 @@ function tuning_report(tuning::RandomSearch, history, field_sampler_pairs)

plotting = plotting_report(fields, parameter_scales, history)

return (history=history, plotting=plotting)
return (plotting=plotting,)

end

0 comments on commit 1faaf51

Please sign in to comment.