7 changes: 4 additions & 3 deletions Project.toml
@@ -1,7 +1,7 @@
name = "MLJDecisionTreeInterface"
uuid = "c6f25543-311c-4c74-83dc-3ea6d1015661"
authors = ["Anthony D. Blaom <anthony.blaom@gmail.com>"]
version = "0.2.2"
version = "0.2.3"

[deps]
DecisionTree = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb"
@@ -10,15 +10,16 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

[compat]
DecisionTree = "0.10"
DecisionTree = "0.11"
MLJModelInterface = "1.4"
Tables = "1.6"
julia = "1.6"

[extras]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
+StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["CategoricalArrays", "MLJBase", "Test"]
test = ["CategoricalArrays", "MLJBase", "StableRNGs", "Test"]
6 changes: 4 additions & 2 deletions src/MLJDecisionTreeInterface.jl
@@ -156,6 +156,7 @@ end

MMI.@mlj_model mutable struct AdaBoostStumpClassifier <: MMI.Probabilistic
    n_iter::Int = 10::(_ ≥ 1)
+    rng::Union{AbstractRNG,Integer} = GLOBAL_RNG
end

function MMI.fit(m::AdaBoostStumpClassifier, verbosity::Int, X, y)
@@ -165,8 +166,8 @@ function MMI.fit(m::AdaBoostStumpClassifier, verbosity::Int, X, y)
    classes_seen = filter(in(unique(y)), MMI.classes(y[1]))
    integers_seen = MMI.int(classes_seen)

-    stumps, coefs = DT.build_adaboost_stumps(yplain, Xmatrix,
-                                              m.n_iter)
+    stumps, coefs =
+        DT.build_adaboost_stumps(yplain, Xmatrix, m.n_iter, rng=m.rng)
    cache = nothing
    report = NamedTuple()
    return (stumps, coefs, classes_seen, integers_seen), cache, report
@@ -586,6 +587,7 @@ Train the machine with `fit!(mach, rows=...)`.

- `n_iter=10`: number of iterations of AdaBoost

+- `rng=Random.GLOBAL_RNG`: random number generator or seed

# Operations

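At the MLJ level, the new hyperparameter accepts either an `AbstractRNG` or an integer seed, per the docstring fragment above. A minimal usage sketch; the seed, data, and variable names are illustrative and not from this PR:

```julia
using MLJBase, Random
using MLJDecisionTreeInterface

X, y = @load_iris

# pass an AbstractRNG (an integer seed such as rng=123 also works)
model = AdaBoostStumpClassifier(n_iter=10, rng=MersenneTwister(42))

mach = machine(model, X, y)
fit!(mach, verbosity=0)
yhat = predict_mode(mach, X)   # point predictions from the probabilistic model
```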
56 changes: 38 additions & 18 deletions test/runtests.jl
@@ -2,17 +2,20 @@ using Test
import CategoricalArrays
import CategoricalArrays.categorical
using MLJBase
+using StableRNGs
using Random
Random.seed!(1234)

+stable_rng() = StableRNGs.StableRNG(123)

# load code to be tested:
import DecisionTree
using MLJDecisionTreeInterface

# get some test data:
X, y = @load_iris

-baretree = DecisionTreeClassifier()
+baretree = DecisionTreeClassifier(rng=stable_rng())

baretree.max_depth = 1
fitresult, cache, report = MLJBase.fit(baretree, 2, X, y);
@@ -50,13 +53,17 @@ using Random: seed!
seed!(0)

n,m = 10^3, 5;
-raw_features = rand(n,m);
-weights = rand(-1:1,m);
+raw_features = rand(stable_rng(), n,m);
+weights = rand(stable_rng(), -1:1,m);
labels = raw_features * weights;
features = MLJBase.table(raw_features);

-R1Tree = DecisionTreeRegressor(min_samples_leaf=5, merge_purity_threshold=0.1)
-R2Tree = DecisionTreeRegressor(min_samples_split=5)
+R1Tree = DecisionTreeRegressor(
+    min_samples_leaf=5,
+    merge_purity_threshold=0.1,
+    rng=stable_rng(),
+)
+R2Tree = DecisionTreeRegressor(min_samples_split=5, rng=stable_rng())
model1, = MLJBase.fit(R1Tree,1, features, labels)

vals1 = MLJBase.predict(R1Tree,model1,features)
@@ -75,11 +82,15 @@ vals2 = MLJBase.predict(R2Tree, model2, features)
## TEST ON ORDINAL FEATURES OTHER THAN CONTINUOUS

N = 20
-X = (x1=rand(N), x2=categorical(rand("abc", N), ordered=true), x3=collect(1:N))
+X = (
+    x1=rand(stable_rng(), N),
+    x2=categorical(rand(stable_rng(), "abc", N), ordered=true),
+    x3=collect(1:N),
+)
yfinite = X.x2
ycont = float.(X.x3)

-rgs = DecisionTreeRegressor()
+rgs = DecisionTreeRegressor(rng=stable_rng())
fitresult, _, _ = MLJBase.fit(rgs, 1, X, ycont)
@test rms(predict(rgs, fitresult, X), ycont) < 1.5

@@ -90,10 +101,10 @@ fitresult, _, _ = MLJBase.fit(clf, 1, X, yfinite)

# -- Ensemble

-rfc = RandomForestClassifier()
-abs = AdaBoostStumpClassifier()
+rfc = RandomForestClassifier(rng=stable_rng())
+abs = AdaBoostStumpClassifier(rng=stable_rng())

-X, y = MLJBase.make_blobs(100, 3; rng=555)
+X, y = MLJBase.make_blobs(100, 3; rng=stable_rng())

m = machine(rfc, X, y)
fit!(m)
@@ -103,19 +114,21 @@ m = machine(abs, X, y)
fit!(m)
@test accuracy(predict_mode(m, X), y) > 0.95

-X, y = MLJBase.make_regression(rng=5124)
-rfr = RandomForestRegressor()
+X, y = MLJBase.make_regression(rng=stable_rng())
+rfr = RandomForestRegressor(rng=stable_rng())
m = machine(rfr, X, y)
fit!(m)
@test rms(predict(m, X), y) < 0.4

N = 10
function reproducibility(model, X, y, loss)
-    model.rng = 123
-    model.n_subfeatures = 1
+    if !(model isa AdaBoostStumpClassifier)
+        model.n_subfeatures = 1
+    end
    mach = machine(model, X, y)
    train, test = partition(eachindex(y), 0.7)
    errs = map(1:N) do i
+        model.rng = stable_rng()
        fit!(mach, rows=train, force=true, verbosity=0)
        yhat = predict(mach, rows=test)
        loss(yhat, y[test]) |> mean
@@ -124,14 +137,21 @@
end

@testset "reporoducibility" begin
X, y = make_blobs();
X, y = make_blobs(rng=stable_rng());
loss = BrierLoss()
for model in [DecisionTreeClassifier(), RandomForestClassifier()]
for model in [
DecisionTreeClassifier(),
RandomForestClassifier(),
AdaBoostStumpClassifier(),
]
@test reproducibility(model, X, y, loss)
end
X, y = make_regression();
X, y = make_regression(rng=stable_rng());
loss = LPLoss(p=2)
for model in [DecisionTreeRegressor(), RandomForestRegressor()]
for model in [
DecisionTreeRegressor(),
RandomForestRegressor(),
]
@test reproducibility(model, X, y, loss)
end
end
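
The test changes all follow one pattern: the helper `stable_rng()` returns a freshly seeded generator on every call, so each fit starts from the same stream regardless of evaluation order. That this yields reproducible runs can be checked in isolation; a minimal sketch:

```julia
using StableRNGs

# two generators built from the same seed produce identical streams,
# which is what the reproducibility testset above relies on
@assert rand(StableRNG(123), 5) == rand(StableRNG(123), 5)
```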