Skip to content

Commit

Permalink
Merge c97f594 into d8923be
Browse files Browse the repository at this point in the history
  • Loading branch information
ablaom committed Oct 20, 2020
2 parents d8923be + c97f594 commit a68b065
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 9 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "MLJTuning"
uuid = "03970b2e-30c4-11ea-3135-d1576263f10f"
authors = ["Anthony D. Blaom <anthony.blaom@gmail.com>"]
version = "0.5.2"
version = "0.5.3"

[deps]
ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3"
Expand Down
15 changes: 7 additions & 8 deletions src/selection_heuristics.jl
Expand Up @@ -39,20 +39,19 @@ with measures that are neither `:loss` nor `:score` are reset to zero.
struct NaiveSelection <: SelectionHeuristic
weights::Union{Nothing, Vector{Real}}
end
NaiveSelection(; weights=nothing) =
NaiveSelection(weights)

function NaiveSelection(; weights=nothing)
if weights isa Vector
all(x -> x >= 0, weights) ||
error("`weights` must be non-negative. ")
end
return NaiveSelection(weights)
end

function best(heuristic::NaiveSelection, history)
first_entry = history[1]
measures = first_entry.measure
weights = measure_adjusted_weights(heuristic.weights, measures)
measurements = [weights'*(h.measurement) for h in history]
measure = first(history).measure[1]
if orientation(measure) == :score
measurements = -measurements

end
best_index = argmin(measurements)
return history[best_index]
end
Expand Down
34 changes: 34 additions & 0 deletions test/selection_heuristics.jl
@@ -1,5 +1,39 @@
using .Models

measures = [accuracy, confmat, misclassification_rate]
@test MLJTuning.measure_adjusted_weights([2, 3, 4], measures) == [-2, 0, 4]
@test MLJTuning.measure_adjusted_weights(nothing, measures) == [-1, 0, 0]
@test_throws(DimensionMismatch,
MLJTuning.measure_adjusted_weights([2, 3], measures))

@testset "losses/scores get minimized/maximimized" begin
bad_model = KNNClassifier(K=100)
good_model = KNNClassifier(K=5)

am = [accuracy, misclassification_rate]
ma = [misclassification_rate, accuracy]

# scores when `weights=nothing`
history = [(model=bad_model, measure=am, measurement=[0, 1]),
(model=good_model, measure=am, measurement=[1, 0])]
@test MLJTuning.best(NaiveSelection(), history).model == good_model

# losses when `weights=nothing`
history = [(model=bad_model, measure=ma, measurement=[1, 0]),
(model=good_model, measure=ma, measurement=[0, 1])]
@test MLJTuning.best(NaiveSelection(), history).model == good_model

# mixed case favouring the score:
weights = [2, 1]
history = [(model=bad_model, measure=am, measurement=[0, 0]),
(model=good_model, measure=am, measurement=[1, 1])]
heuristic = NaiveSelection(weights=weights)
@test MLJTuning.best(heuristic, history).model == good_model

# mixed case favouring the loss:
weights = [1, 2]
history = [(model=bad_model, measure=am, measurement=[1, 1]),
(model=good_model, measure=am, measurement=[0, 0])]
heuristic = NaiveSelection(weights=weights)
@test MLJTuning.best(heuristic, history).model == good_model
end

0 comments on commit a68b065

Please sign in to comment.