105 changes: 96 additions & 9 deletions src/MLJDecisionTreeInterface.jl
@@ -25,11 +25,9 @@ Base.show(stream::IO, c::TreePrinter) =
# # DECISION TREE CLASSIFIER

# The following meets the MLJ standard for a `Model` docstring and is
# created without the use of interpolation so it can be used as a
# template for authors of other MLJ model interfaces. The other
# doc-strings, defined later, are generated using the `doc_header`
# utility to automatically generate the header, another option.

MMI.@mlj_model mutable struct DecisionTreeClassifier <: MMI.Probabilistic
max_depth::Int = (-)(1)::(_ ≥ -1)
min_samples_leaf::Int = 1::(_ ≥ 0)
@@ -39,6 +37,7 @@ MMI.@mlj_model mutable struct DecisionTreeClassifier <: MMI.Probabilistic
post_prune::Bool = false
merge_purity_threshold::Float64 = 1.0::(_ ≤ 1)
display_depth::Int = 5::(_ ≥ 1)
feature_importance::Symbol = :impurity::(_ ∈ (:impurity, :split))
rng::Union{AbstractRNG,Integer} = GLOBAL_RNG
end

@@ -73,8 +72,8 @@ function MMI.fit(m::DecisionTreeClassifier, verbosity::Int, X, y)
cache = nothing
report = (classes_seen=classes_seen,
print_tree=TreePrinter(tree),
features=features,
)
return fitresult, cache, report
end

@@ -107,6 +106,8 @@ function MMI.predict(m::DecisionTreeClassifier, fitresult, Xnew)
return MMI.UnivariateFinite(classes_seen, scores)
end

MMI.reports_feature_importances(::Type{<:DecisionTreeClassifier}) = true


# # RANDOM FOREST CLASSIFIER

@@ -118,13 +119,21 @@ MMI.@mlj_model mutable struct RandomForestClassifier <: MMI.Probabilistic
n_subfeatures::Int = (-)(1)::(_ ≥ -1)
n_trees::Int = 10::(_ ≥ 2)
sampling_fraction::Float64 = 0.7::(0 < _ ≤ 1)
feature_importance::Symbol = :impurity::(_ ∈ (:impurity, :split))
rng::Union{AbstractRNG,Integer} = GLOBAL_RNG
end

function MMI.fit(m::RandomForestClassifier, verbosity::Int, X, y)
schema = Tables.schema(X)
Xmatrix = MMI.matrix(X)
yplain = MMI.int(y)

if schema === nothing
features = [Symbol("x$j") for j in 1:size(Xmatrix, 2)]
else
features = schema.names |> collect
end

classes_seen = filter(in(unique(y)), MMI.classes(y[1]))
integers_seen = MMI.int(classes_seen)

@@ -138,7 +147,9 @@ function MMI.fit(m::RandomForestClassifier, verbosity::Int, X, y)
m.min_purity_increase;
rng=m.rng)
cache = nothing

report = (features=features,)

return (forest, classes_seen, integers_seen), cache, report
end

@@ -151,25 +162,38 @@ function MMI.predict(m::RandomForestClassifier, fitresult, Xnew)
return MMI.UnivariateFinite(classes_seen, scores)
end

MMI.reports_feature_importances(::Type{<:RandomForestClassifier}) = true


# # ADA BOOST STUMP CLASSIFIER

MMI.@mlj_model mutable struct AdaBoostStumpClassifier <: MMI.Probabilistic
n_iter::Int = 10::(_ ≥ 1)
feature_importance::Symbol = :impurity::(_ ∈ (:impurity, :split))
rng::Union{AbstractRNG,Integer} = GLOBAL_RNG
end

function MMI.fit(m::AdaBoostStumpClassifier, verbosity::Int, X, y)
schema = Tables.schema(X)
Xmatrix = MMI.matrix(X)
yplain = MMI.int(y)

if schema === nothing
features = [Symbol("x$j") for j in 1:size(Xmatrix, 2)]
else
features = schema.names |> collect
end


classes_seen = filter(in(unique(y)), MMI.classes(y[1]))
integers_seen = MMI.int(classes_seen)

stumps, coefs =
DT.build_adaboost_stumps(yplain, Xmatrix, m.n_iter, rng=m.rng)
cache = nothing

report = (features=features,)

return (stumps, coefs, classes_seen, integers_seen), cache, report
end

@@ -184,6 +208,8 @@ function MMI.predict(m::AdaBoostStumpClassifier, fitresult, Xnew)
return MMI.UnivariateFinite(classes_seen, scores)
end

MMI.reports_feature_importances(::Type{<:AdaBoostStumpClassifier}) = true


# # DECISION TREE REGRESSOR

@@ -195,11 +221,20 @@ MMI.@mlj_model mutable struct DecisionTreeRegressor <: MMI.Deterministic
n_subfeatures::Int = 0::(_ ≥ -1)
post_prune::Bool = false
merge_purity_threshold::Float64 = 1.0::(0 ≤ _ ≤ 1)
feature_importance::Symbol = :impurity::(_ ∈ (:impurity, :split))
rng::Union{AbstractRNG,Integer} = GLOBAL_RNG
end

function MMI.fit(m::DecisionTreeRegressor, verbosity::Int, X, y)
schema = Tables.schema(X)
Xmatrix = MMI.matrix(X)

if schema === nothing
features = [Symbol("x$j") for j in 1:size(Xmatrix, 2)]
else
features = schema.names |> collect
end

tree = DT.build_tree(float(y), Xmatrix,
m.n_subfeatures,
m.max_depth,
@@ -212,7 +247,9 @@ function MMI.fit(m::DecisionTreeRegressor, verbosity::Int, X, y)
tree = DT.prune_tree(tree, m.merge_purity_threshold)
end
cache = nothing

report = (features=features,)

return tree, cache, report
end

@@ -223,6 +260,8 @@ function MMI.predict(::DecisionTreeRegressor, tree, Xnew)
return DT.apply_tree(tree, Xmatrix)
end

MMI.reports_feature_importances(::Type{<:DecisionTreeRegressor}) = true


# # RANDOM FOREST REGRESSOR

@@ -234,11 +273,20 @@ MMI.@mlj_model mutable struct RandomForestRegressor <: MMI.Deterministic
n_subfeatures::Int = (-)(1)::(_ ≥ -1)
n_trees::Int = 10::(_ ≥ 2)
sampling_fraction::Float64 = 0.7::(0 < _ ≤ 1)
feature_importance::Symbol = :impurity::(_ ∈ (:impurity, :split))
rng::Union{AbstractRNG,Integer} = GLOBAL_RNG
end

function MMI.fit(m::RandomForestRegressor, verbosity::Int, X, y)
schema = Tables.schema(X)
Xmatrix = MMI.matrix(X)

if schema === nothing
features = [Symbol("x$j") for j in 1:size(Xmatrix, 2)]
else
features = schema.names |> collect
end

forest = DT.build_forest(float(y), Xmatrix,
m.n_subfeatures,
m.n_trees,
@@ -249,7 +297,8 @@ function MMI.fit(m::RandomForestRegressor, verbosity::Int, X, y)
m.min_purity_increase,
rng=m.rng)
cache = nothing
report = (features=features,)

return forest, cache, report
end

@@ -260,6 +309,34 @@ function MMI.predict(::RandomForestRegressor, forest, Xnew)
return DT.apply_forest(forest, Xmatrix)
end

MMI.reports_feature_importances(::Type{<:RandomForestRegressor}) = true


# # Feature Importances

# get actual arguments needed for importance calculation from various fitresults.
get_fitresult(m::Union{DecisionTreeClassifier, RandomForestClassifier}, fitresult) = (fitresult[1],)
get_fitresult(m::Union{DecisionTreeRegressor, RandomForestRegressor}, fitresult) = (fitresult,)
get_fitresult(m::AdaBoostStumpClassifier, fitresult) = (fitresult[1], fitresult[2])

function MMI.feature_importances(m::Union{DecisionTreeClassifier, RandomForestClassifier, AdaBoostStumpClassifier, DecisionTreeRegressor, RandomForestRegressor}, fitresult, report)
# generate feature importances for report
if m.feature_importance == :impurity
feature_importance_func = DT.impurity_importance
elseif m.feature_importance == :split
feature_importance_func = DT.split_importance
end

mdl = get_fitresult(m, fitresult)
features = report.features
fi = feature_importance_func(mdl..., normalize=true)
fi_pairs = Pair.(features, fi)
# sort descending
sort!(fi_pairs, by= x->-x[2])

return fi_pairs
end
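
# A minimal usage sketch (not part of the package source; assumes MLJBase is
# loaded, as in the test suite, and that feature names are tabular column names):
#
#     X, y = MLJBase.make_blobs(100, 3)
#     mach = MLJBase.machine(DecisionTreeClassifier(), X, y)
#     MLJBase.fit!(mach)
#     fi = MMI.feature_importances(mach.model, mach.fitresult, MLJBase.report(mach))
#     # fi is a vector of `feature_name => importance` pairs, sorted by decreasing importance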


# # METADATA (MODEL TRAITS)

@@ -379,6 +456,8 @@ Train the machine using `fit!(mach, rows=...)`.

- `display_depth=5`: max depth to show when displaying the tree

- `feature_importance=:impurity`: method to use for computing feature importances. One of `(:impurity, :split)`

- `rng=Random.GLOBAL_RNG`: random number generator or seed


@@ -512,6 +591,8 @@ Train the machine with `fit!(mach, rows=...)`.

- `sampling_fraction=0.7` fraction of samples to train each tree on

- `feature_importance=:impurity`: method to use for computing feature importances. One of `(:impurity, :split)`

- `rng=Random.GLOBAL_RNG`: random number generator or seed


@@ -587,6 +668,8 @@ Train the machine with `fit!(mach, rows=...)`.

- `n_iter=10`: number of iterations of AdaBoost

- `feature_importance=:impurity`: method to use for computing feature importances. One of `(:impurity, :split)`

- `rng=Random.GLOBAL_RNG`: random number generator or seed

# Operations
@@ -678,6 +761,8 @@ Train the machine with `fit!(mach, rows=...)`.
- `merge_purity_threshold=1.0`: (post-pruning) merge leaves having
combined purity `>= merge_purity_threshold`

- `feature_importance=:impurity`: method to use for computing feature importances. One of `(:impurity, :split)`

- `rng=Random.GLOBAL_RNG`: random number generator or seed


@@ -760,6 +845,8 @@ Train the machine with `fit!(mach, rows=...)`.

- `sampling_fraction=0.7` fraction of samples to train each tree on

- `feature_importance=:impurity`: method to use for computing feature importances. One of `(:impurity, :split)`

- `rng=Random.GLOBAL_RNG`: random number generator or seed


79 changes: 79 additions & 0 deletions test/runtests.jl
@@ -45,6 +45,7 @@ yyhat = predict_mode(baretree, fitresult, MLJBase.selectrows(X, 1:3))
@test Set(report.classes_seen) == Set(levels(y))
@test report.print_tree(2) === nothing # :-(
@test report.features == [:sepal_length, :sepal_width, :petal_length, :petal_width]

fp = fitted_params(baretree, fitresult)
@test Set([:tree, :encoding, :features]) == Set(keys(fp))
@test fp.features == report.features
@@ -155,3 +156,81 @@ end
@test reproducibility(model, X, y, loss)
end
end


@testset "feature importance defined" begin
for model ∈ [
DecisionTreeClassifier(),
RandomForestClassifier(),
AdaBoostStumpClassifier(),
DecisionTreeRegressor(),
RandomForestRegressor(),
]

@test reports_feature_importances(model) == true
end
end



@testset "impurity importance" begin

X, y = MLJBase.make_blobs(100, 3; rng=stable_rng())

for model ∈ [
DecisionTreeClassifier(),
RandomForestClassifier(),
AdaBoostStumpClassifier(),
]
m = machine(model, X, y)
fit!(m)
rpt = MLJBase.report(m)
fi = MLJBase.feature_importances(model, m.fitresult, rpt)
@test size(fi,1) == 3
end


X, y = make_regression(100,3; rng=stable_rng());
for model in [
DecisionTreeRegressor(),
RandomForestRegressor(),
]
m = machine(model, X, y)
fit!(m)
rpt = MLJBase.report(m)
fi = MLJBase.feature_importances(model, m.fitresult, rpt)
@test size(fi,1) == 3
end
end


@testset "split importance" begin
X, y = MLJBase.make_blobs(100, 3; rng=stable_rng())

for model ∈ [
DecisionTreeClassifier(feature_importance=:split),
RandomForestClassifier(feature_importance=:split),
AdaBoostStumpClassifier(feature_importance=:split),
]
m = machine(model, X, y)
fit!(m)
rpt = MLJBase.report(m)
fi = MLJBase.feature_importances(model, m.fitresult, rpt)
@test size(fi,1) == 3
end


X, y = make_regression(100,3; rng=stable_rng());
for model in [
DecisionTreeRegressor(feature_importance=:split),
RandomForestRegressor(feature_importance=:split),
]
m = machine(model, X, y)
fit!(m)
rpt = MLJBase.report(m)
fi = MLJBase.feature_importances(model, m.fitresult, rpt)
@test size(fi,1) == 3
end
end
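

# A supplementary sketch (not from the original PR) exercising the properties the
# `feature_importances` implementation in src/MLJDecisionTreeInterface.jl guarantees:
# pairs are keyed by the feature names recorded in the report and sorted by
# decreasing importance. It assumes the same test environment as the testsets above.
@testset "feature importance ordering (sketch)" begin
    X, y = MLJBase.make_blobs(100, 3; rng=stable_rng())
    model = DecisionTreeClassifier()
    m = machine(model, X, y)
    fit!(m)
    rpt = MLJBase.report(m)
    fi = MLJBase.feature_importances(model, m.fitresult, rpt)
    @test issorted(last.(fi), rev=true)        # importances in descending order
    @test Set(first.(fi)) ⊆ Set(rpt.features)  # keys come from the reported feature names
end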