Skip to content

Commit

Permalink
add tests for sample weights in ensembles and do some clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
ablaom committed Nov 25, 2019
1 parent e770499 commit 721e657
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 16 deletions.
33 changes: 24 additions & 9 deletions src/ensembles.jl
Original file line number Diff line number Diff line change
Expand Up @@ -159,13 +159,21 @@ function get_ensemble(atom::Supervised, verbosity, n, n_patterns,
n_train, rng, progress_meter, args...)

# define generator of training rows:
ensemble_indices = (StatsBase.sample(rng, 1:n_patterns, n_train, replace=false)
for i in 1:n)
if n_train == n_patterns
# keep deterministic by avoiding re-ordering:
ensemble_indices = (1:n_patterns for i in 1:n)
else
ensemble_indices =
(StatsBase.sample(rng, 1:n_patterns, n_train, replace=false)
for i in 1:n)
end

ensemble = map(ensemble_indices) do train_rows
verbosity == 1 && next!(progress_meter)
verbosity < 2 || print("#")
atom_fitresult, atom_cache, atom_report =
fit(atom, verbosity - 1, [selectrows(arg, train_rows) for arg in args]...)
fit(atom, verbosity - 1, [selectrows(arg, train_rows) for
arg in args]...)
atom_fitresult
end
verbosity < 1 || println()
Expand Down Expand Up @@ -496,20 +504,27 @@ function fit(model::EitherEnsembleModel{Atom},
end
for k in eachindex(out_of_bag_measure)
m = out_of_bag_measure[k]
metrics[i,k] = value(m, yhat, Xtest, ytest, wtest)
end
if reports_each_observation(m)
s = aggregate(value(m, yhat, Xtest, ytest, wtest), m)
else
s = value(m, yhat, Xtest, ytest, wtest)
end
metrics[i,k] = s
end
end

# aggregate metrics across the ensembles:
aggregated_metrics = map(eachindex(out_of_bag_measure)) do k
aggregate(metrics[:,k], out_of_bag_measure[k])
end
metrics=mean(metrics, dims=1)

names = Symbol.(string.(out_of_bag_measure))
oob_estimates=NamedTuple{Tuple(names)}(Tuple(vec(metrics)))

else
oob_estimates=NamedTuple()
aggregated_metrics = missing
end

report=(oob_estimates=oob_estimates,)
report=(measures=out_of_bag_measure, oob_measurements=aggregated_metrics,)
cache = deepcopy(model)

return fitresult, cache, report
Expand Down
48 changes: 41 additions & 7 deletions test/ensembles.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ atomic_weights = rand(10)
atomic_weights = atomic_weights/sum(atomic_weights)
ensemble_model.atomic_weights = atomic_weights
fitresult, cache, report = MLJ.fit(ensemble_model, 1, X, y)
predict(ensemble_model, fitresult, MLJ.selectrows(X, test))
p = predict(ensemble_model, fitresult, MLJ.selectrows(X, test))
MLJBase.info_dict(ensemble_model)
@test MLJBase.target_scitype(ensemble_model) == MLJBase.target_scitype(atom)

Expand All @@ -104,7 +104,7 @@ MLJBase.info_dict(ensemble_model)

# target is :deterministic :continuous false:
atom = MLJModels.DeterministicConstantRegressor()
Random.seed!(1234)
Random.seed!(1234)
X = MLJ.table(randn(10,3))
y = randn(10)
train, test = partition(1:length(y), 0.8);
Expand All @@ -113,8 +113,7 @@ ensemble_model.out_of_bag_measure = [MLJ.rms,MLJ.rmsp]
ensemble_model.n = 2
fitresult, cache, report = MLJ.fit(ensemble_model, 1, X, y)
# TODO: the following test fails in distributed version (because of multiple rng's ?)
@test report[:oob_estimates][1] ≈ 1.083490899041915
# @test MLJBase.output_is(ensemble_model) == MLJBase.output_is(atom)
@test abs(report.oob_measurements[1] - 1.0834) < 0.001
ensemble_model = MLJ.DeterministicEnsembleModel(atom=atom,rng=Random.MersenneTwister(1))
ensemble_model.out_of_bag_measure = MLJ.rms
ensemble_model.n = 2
Expand All @@ -138,12 +137,17 @@ d = predict(ensemble_model, fitresult, MLJ.selectrows(X, test))[1]
@test pdf(d, 's') ≈ 1/5
@test pdf(d, 'd') ≈ 1/5
@test pdf(d, 'f') ≈ 1/5
@test mode(d) == 'a'
atomic_weights = rand(10)
atomic_weights = atomic_weights/sum(atomic_weights)
ensemble_model.atomic_weights = atomic_weights
predict(ensemble_model, fitresult, MLJ.selectrows(X, test))
MLJBase.info_dict(ensemble_model)
# @test MLJBase.output_is(ensemble_model) == MLJBase.output_is(atom)
# test sample weights
w = [1,100,1,1,1]
fitresult, cache, report = MLJ.fit(ensemble_model, 1, X, y, w)
p2 = predict(ensemble_model, fitresult, MLJ.selectrows(X, test))
@test mode(p2[1] ) == 's'

# target is :probabilistic :continuous false:
atom = ConstantRegressor()
Expand Down Expand Up @@ -177,8 +181,39 @@ MLJBase.info_dict(ensemble_model)
@test EnsembleModel(atom=ConstantRegressor()) isa Probabilistic
@test EnsembleModel(atom=MLJModels.DeterministicConstantRegressor()) isa Deterministic

# Checks that per-observation sample weights passed to `MLJ.fit` are used by
# an ensemble, both in the predictions of the fitted ensemble and in the
# reported out-of-bag measurements.
@testset "further test of sample weights" begin
    N = 20
    X = (x = rand(3N), );
    # 'b' is drawn three times as often as 'a' or 'c':
    y = categorical(rand("abbbc", 3N));
    atom = @load KNNClassifier
    ensemble_model = MLJ.ProbabilisticEnsembleModel(atom=atom,
                                                    bagging_fraction=1,
                                                    n = 5)

    # without weights the predicted mode at x = 0 is expected to be 'b':
    fitresult, cache, report = MLJ.fit(ensemble_model, 1, X, y)
    @test predict_mode(ensemble_model, fitresult, (x = [0, ],))[1] == 'b'

    # weighting every 'a' observation 100 times more than the others
    # should switch the predicted mode to 'a':
    w = map(y) do η
        η == 'a' ? 100 : 1
    end
    fitresult, cache, report = MLJ.fit(ensemble_model, 1, X, y, w)
    @test predict_mode(ensemble_model, fitresult, (x = [0, ],))[1] == 'a'

    ensemble_model.rng = 1234 # always start with same rng
    ensemble_model.bagging_fraction=0.6
    ensemble_model.out_of_bag_measure = [BrierScore(), cross_entropy]

    # two fits with identical settings must report identical
    # out-of-bag measurements:
    fitresult, cache, report = MLJ.fit(ensemble_model, 1, X, y)
    m1 = report.oob_measurements[1]
    fitresult, cache, report = MLJ.fit(ensemble_model, 1, X, y)
    m2 = report.oob_measurements[1]
    @test m1 == m2

    # supplying sample weights should change the oob measurements for
    # measures that support weights:
    fitresult, cache, report = MLJ.fit(ensemble_model, 1, X, y, w)
    m3 = report.oob_measurements[1]
    @test !(m1 ≈ m3)
end


## MACHINE TEST (INCLUDES TEST OF UPDATE)
## MACHINE TEST (INCLUDES TEST OF UPDATE)

N =100
X = (x1=rand(N), x2=rand(N), x3=rand(N))
Expand Down Expand Up @@ -207,6 +242,5 @@ ensemble_model.n = 5

@test !isnan(predict(ensemble, MLJ.selectrows(X, test))[1])


end
true

0 comments on commit 721e657

Please sign in to comment.