diff --git a/src/MLJDecisionTreeInterface.jl b/src/MLJDecisionTreeInterface.jl
index bd22487..35ce4d4 100644
--- a/src/MLJDecisionTreeInterface.jl
+++ b/src/MLJDecisionTreeInterface.jl
@@ -25,11 +25,9 @@ Base.show(stream::IO, c::TreePrinter) =

 # # DECISION TREE CLASSIFIER

 # The following meets the MLJ standard for a `Model` docstring and is
-# created without the use of interpolation so it can be used a
-# template for authors of other MLJ model interfaces. The other
+# created without the use of interpolation so it can be used a # template for authors of other MLJ model interfaces. The other
 # doc-strings, defined later, are generated using the `doc_header`
 # utility to automatically generate the header, another option.
-
 MMI.@mlj_model mutable struct DecisionTreeClassifier <: MMI.Probabilistic
     max_depth::Int = (-)(1)::(_ ≥ -1)
     min_samples_leaf::Int = 1::(_ ≥ 0)
@@ -39,6 +37,7 @@ MMI.@mlj_model mutable struct DecisionTreeClassifier <: MMI.Probabilistic
     post_prune::Bool = false
     merge_purity_threshold::Float64 = 1.0::(_ ≤ 1)
     display_depth::Int = 5::(_ ≥ 1)
+    feature_importance::Symbol = :impurity::(_ ∈ (:impurity, :split))
     rng::Union{AbstractRNG,Integer} = GLOBAL_RNG
 end

@@ -73,8 +72,8 @@ function MMI.fit(m::DecisionTreeClassifier, verbosity::Int, X, y)
     cache = nothing
     report = (classes_seen=classes_seen,
               print_tree=TreePrinter(tree),
-              features=features)
-
+              features=features,
+              )
     return fitresult, cache, report
 end

@@ -107,6 +106,8 @@ function MMI.predict(m::DecisionTreeClassifier, fitresult, Xnew)
     return MMI.UnivariateFinite(classes_seen, scores)
 end

+MMI.reports_feature_importances(::Type{<:DecisionTreeClassifier}) = true
+

 # # RANDOM FOREST CLASSIFIER

@@ -118,13 +119,21 @@ MMI.@mlj_model mutable struct RandomForestClassifier <: MMI.Probabilistic
     n_subfeatures::Int = (-)(1)::(_ ≥ -1)
     n_trees::Int = 10::(_ ≥ 2)
     sampling_fraction::Float64 = 0.7::(0 < _ ≤ 1)
+    feature_importance::Symbol = :impurity::(_ ∈ (:impurity, :split))
     rng::Union{AbstractRNG,Integer} = GLOBAL_RNG
 end

 function MMI.fit(m::RandomForestClassifier, verbosity::Int, X, y)
+    schema = Tables.schema(X)
     Xmatrix = MMI.matrix(X)
     yplain  = MMI.int(y)

+    if schema === nothing
+        features = [Symbol("x$j") for j in 1:size(Xmatrix, 2)]
+    else
+        features = schema.names |> collect
+    end
+
     classes_seen  = filter(in(unique(y)), MMI.classes(y[1]))
     integers_seen = MMI.int(classes_seen)

@@ -138,7 +147,9 @@ function MMI.fit(m::RandomForestClassifier, verbosity::Int, X, y)
                              m.min_purity_increase;
                              rng=m.rng)
     cache = nothing
-    report = NamedTuple()
+
+    report = (features=features,)
+
     return (forest, classes_seen, integers_seen), cache, report
 end
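Context for the `schema === nothing` branch introduced in the `fit` methods above: `Tables.schema` returns `nothing` when a table cannot report its column layout, in which case synthetic names `:x1, :x2, ...` are generated. A minimal sketch of the two cases, assuming only the Tables.jl API (the column names here are hypothetical):

    using Tables

    X = (sepal_length = [5.1, 4.9], sepal_width = [3.5, 3.0])  # a NamedTuple of vectors is a table
    schema = Tables.schema(X)
    schema.names              # (:sepal_length, :sepal_width)

    # Fallback used when a source yields `schema === nothing`:
    features = schema === nothing ?
        [Symbol("x$j") for j in 1:2] :
        collect(schema.names)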
@@ -151,25 +162,38 @@ function MMI.predict(m::RandomForestClassifier, fitresult, Xnew)
     return MMI.UnivariateFinite(classes_seen, scores)
 end

+MMI.reports_feature_importances(::Type{<:RandomForestClassifier}) = true
+

 # # ADA BOOST STUMP CLASSIFIER

 MMI.@mlj_model mutable struct AdaBoostStumpClassifier <: MMI.Probabilistic
     n_iter::Int = 10::(_ ≥ 1)
+    feature_importance::Symbol = :impurity::(_ ∈ (:impurity, :split))
     rng::Union{AbstractRNG,Integer} = GLOBAL_RNG
 end

 function MMI.fit(m::AdaBoostStumpClassifier, verbosity::Int, X, y)
+    schema = Tables.schema(X)
     Xmatrix = MMI.matrix(X)
     yplain  = MMI.int(y)

+    if schema === nothing
+        features = [Symbol("x$j") for j in 1:size(Xmatrix, 2)]
+    else
+        features = schema.names |> collect
+    end
+
     classes_seen  = filter(in(unique(y)), MMI.classes(y[1]))
     integers_seen = MMI.int(classes_seen)

     stumps, coefs = DT.build_adaboost_stumps(yplain, Xmatrix,
                                              m.n_iter, rng=m.rng)
     cache = nothing
-    report = NamedTuple()
+
+    report = (features=features,)
+
     return (stumps, coefs, classes_seen, integers_seen), cache, report
 end

@@ -184,6 +208,8 @@ function MMI.predict(m::AdaBoostStumpClassifier, fitresult, Xnew)
     return MMI.UnivariateFinite(classes_seen, scores)
 end

+MMI.reports_feature_importances(::Type{<:AdaBoostStumpClassifier}) = true
+

 # # DECISION TREE REGRESSOR

@@ -195,11 +221,20 @@ MMI.@mlj_model mutable struct DecisionTreeRegressor <: MMI.Deterministic
     n_subfeatures::Int = 0::(_ ≥ -1)
     post_prune::Bool = false
     merge_purity_threshold::Float64 = 1.0::(0 ≤ _ ≤ 1)
+    feature_importance::Symbol = :impurity::(_ ∈ (:impurity, :split))
     rng::Union{AbstractRNG,Integer} = GLOBAL_RNG
 end

 function MMI.fit(m::DecisionTreeRegressor, verbosity::Int, X, y)
+    schema = Tables.schema(X)
     Xmatrix = MMI.matrix(X)
+
+    if schema === nothing
+        features = [Symbol("x$j") for j in 1:size(Xmatrix, 2)]
+    else
+        features = schema.names |> collect
+    end
+
     tree = DT.build_tree(float(y), Xmatrix,
                          m.n_subfeatures,
                          m.max_depth,
@@ -212,7 +247,9 @@ function MMI.fit(m::DecisionTreeRegressor, verbosity::Int, X, y)
         tree = DT.prune_tree(tree, m.merge_purity_threshold)
     end
     cache = nothing
-    report = NamedTuple()
+
+    report = (features=features,)
+
     return tree, cache, report
 end

@@ -223,6 +260,8 @@ function MMI.predict(::DecisionTreeRegressor, tree, Xnew)
     return DT.apply_tree(tree, Xmatrix)
 end

+MMI.reports_feature_importances(::Type{<:DecisionTreeRegressor}) = true
+

 # # RANDOM FOREST REGRESSOR

@@ -234,11 +273,20 @@ MMI.@mlj_model mutable struct RandomForestRegressor <: MMI.Deterministic
     n_subfeatures::Int = (-)(1)::(_ ≥ -1)
     n_trees::Int = 10::(_ ≥ 2)
     sampling_fraction::Float64 = 0.7::(0 < _ ≤ 1)
+    feature_importance::Symbol = :impurity::(_ ∈ (:impurity, :split))
     rng::Union{AbstractRNG,Integer} = GLOBAL_RNG
 end

 function MMI.fit(m::RandomForestRegressor, verbosity::Int, X, y)
+    schema = Tables.schema(X)
     Xmatrix = MMI.matrix(X)
+
+    if schema === nothing
+        features = [Symbol("x$j") for j in 1:size(Xmatrix, 2)]
+    else
+        features = schema.names |> collect
+    end
+
     forest = DT.build_forest(float(y), Xmatrix,
                              m.n_subfeatures,
                              m.n_trees,
@@ -249,7 +297,8 @@ function MMI.fit(m::RandomForestRegressor, verbosity::Int, X, y)
                              m.min_purity_increase,
                              rng=m.rng)
     cache = nothing
-    report = NamedTuple()
+    report = (features=features,)
+
     return forest, cache, report
 end
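Since every `fit` now records the training feature names under `report.features`, they can be read back from a fitted MLJ machine. A hedged sketch, assuming a fitted machine `mach` wrapping one of these models (the names shown are hypothetical):

    rpt = report(mach)
    rpt.features   # e.g. [:sepal_length, :sepal_width, :petal_length, :petal_width]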
@@ -260,6 +309,34 @@ function MMI.predict(::RandomForestRegressor, forest, Xnew)
     return DT.apply_forest(forest, Xmatrix)
 end

+MMI.reports_feature_importances(::Type{<:RandomForestRegressor}) = true
+
+
+# # Feature Importances
+
+# get actual arguments needed for importance calculation from various fitresults.
+get_fitresult(m::Union{DecisionTreeClassifier, RandomForestClassifier}, fitresult) = (fitresult[1],)
+get_fitresult(m::Union{DecisionTreeRegressor, RandomForestRegressor}, fitresult) = (fitresult,)
+get_fitresult(m::AdaBoostStumpClassifier, fitresult) = (fitresult[1], fitresult[2])
+
+function MMI.feature_importances(m::Union{DecisionTreeClassifier, RandomForestClassifier, AdaBoostStumpClassifier, DecisionTreeRegressor, RandomForestRegressor}, fitresult, report)
+    # generate feature importances for report
+    if m.feature_importance == :impurity
+        feature_importance_func = DT.impurity_importance
+    elseif m.feature_importance == :split
+        feature_importance_func = DT.split_importance
+    end
+
+    mdl = get_fitresult(m, fitresult)
+    features = report.features
+    fi = feature_importance_func(mdl..., normalize=true)
+    fi_pairs = Pair.(features, fi)
+    # sort descending
+    sort!(fi_pairs, by = x -> -x[2])
+
+    return fi_pairs
+end
+

 # # METADATA (MODEL TRAITS)

@@ -379,6 +456,8 @@ Train the machine using `fit!(mach, rows=...)`.

 - `display_depth=5`: max depth to show when displaying the tree

+- `feature_importance`: method to use for computing feature importances. One of `(:impurity, :split)`
+
 - `rng=Random.GLOBAL_RNG`: random number generator or seed

@@ -512,6 +591,8 @@ Train the machine with `fit!(mach, rows=...)`.

 - `sampling_fraction=0.7` fraction of samples to train each tree on

+- `feature_importance`: method to use for computing feature importances. One of `(:impurity, :split)`
+
 - `rng=Random.GLOBAL_RNG`: random number generator or seed

@@ -587,6 +668,8 @@ Train the machine with `fit!(mach, rows=...)`.

 - `n_iter=10`: number of iterations of AdaBoost

+- `feature_importance`: method to use for computing feature importances. One of `(:impurity, :split)`
+
 - `rng=Random.GLOBAL_RNG`: random number generator or seed

 # Operations

@@ -678,6 +761,8 @@ Train the machine with `fit!(mach, rows=...)`.

 - `merge_purity_threshold=1.0`: (post-pruning) merge leaves having combined purity `>= merge_purity_threshold`

+- `feature_importance`: method to use for computing feature importances. One of `(:impurity, :split)`
+
 - `rng=Random.GLOBAL_RNG`: random number generator or seed

@@ -760,6 +845,8 @@ Train the machine with `fit!(mach, rows=...)`.

 - `sampling_fraction=0.7` fraction of samples to train each tree on

+- `feature_importance`: method to use for computing feature importances. One of `(:impurity, :split)`
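Both `DT.impurity_importance` (total impurity decrease attributed to each feature; the default) and `DT.split_importance` (based on how often each feature is used to split) return a plain positional vector, so the method above pairs it with the recorded names and sorts. A standalone sketch of that idiom, with made-up names and numbers:

    features = [:petal_length, :petal_width, :sepal_length]
    fi = [0.70, 0.25, 0.05]            # positional importances from DecisionTree.jl
    fi_pairs = Pair.(features, fi)     # broadcast into feature => score pairs
    sort!(fi_pairs, by = x -> -x[2])   # descending by score
    # [:petal_length => 0.7, :petal_width => 0.25, :sepal_length => 0.05]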
+
 - `rng=Random.GLOBAL_RNG`: random number generator or seed

diff --git a/test/runtests.jl b/test/runtests.jl
index 89d207d..14553f2 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -45,6 +45,7 @@ yyhat = predict_mode(baretree, fitresult, MLJBase.selectrows(X, 1:3))
 @test Set(report.classes_seen) == Set(levels(y))
 @test report.print_tree(2) === nothing # :-(
 @test report.features == [:sepal_length, :sepal_width, :petal_length, :petal_width]
+
 fp = fitted_params(baretree, fitresult)
 @test Set([:tree, :encoding, :features]) == Set(keys(fp))
 @test fp.features == report.features
@@ -155,3 +156,81 @@ end
     @test reproducibility(model, X, y, loss)
 end
 end
+
+
+@testset "feature importance defined" begin
+    for model ∈ [
+        DecisionTreeClassifier(),
+        RandomForestClassifier(),
+        AdaBoostStumpClassifier(),
+        DecisionTreeRegressor(),
+        RandomForestRegressor(),
+    ]
+
+        @test reports_feature_importances(model) == true
+    end
+end
+
+
+@testset "impurity importance" begin
+
+    X, y = MLJBase.make_blobs(100, 3; rng=stable_rng())
+
+    for model ∈ [
+        DecisionTreeClassifier(),
+        RandomForestClassifier(),
+        AdaBoostStumpClassifier(),
+    ]
+        m = machine(model, X, y)
+        fit!(m)
+        rpt = MLJBase.report(m)
+        fi = MLJBase.feature_importances(model, m.fitresult, rpt)
+        @test size(fi,1) == 3
+    end
+
+    X, y = make_regression(100,3; rng=stable_rng());
+    for model in [
+        DecisionTreeRegressor(),
+        RandomForestRegressor(),
+    ]
+        m = machine(model, X, y)
+        fit!(m)
+        rpt = MLJBase.report(m)
+        fi = MLJBase.feature_importances(model, m.fitresult, rpt)
+        @test size(fi,1) == 3
+    end
+end
+
+
+@testset "split importance" begin
+    X, y = MLJBase.make_blobs(100, 3; rng=stable_rng())
+
+    for model ∈ [
+        DecisionTreeClassifier(feature_importance=:split),
+        RandomForestClassifier(feature_importance=:split),
+        AdaBoostStumpClassifier(feature_importance=:split),
+    ]
+        m = machine(model, X, y)
+        fit!(m)
+        rpt = MLJBase.report(m)
+        fi = MLJBase.feature_importances(model, m.fitresult, rpt)
+        @test size(fi,1) == 3
+    end
+
+    X, y = make_regression(100,3; rng=stable_rng());
+    for model in [
+        DecisionTreeRegressor(feature_importance=:split),
+        RandomForestRegressor(feature_importance=:split),
+    ]
+        m = machine(model, X, y)
+        fit!(m)
+        rpt = MLJBase.report(m)
+        fi = MLJBase.feature_importances(model, m.fitresult, rpt)
+        @test size(fi,1) == 3
+    end
+end
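For reviewers, an end-to-end sketch of how the new API surfaces through MLJ once merged. This mirrors the tests above and assumes an MLJBase version whose machine-level `feature_importances(mach)` forwards to this implementation; the importance values shown are illustrative only:

    using MLJBase, MLJDecisionTreeInterface

    X, y = MLJBase.make_blobs(100, 3)   # table with columns :x1, :x2, :x3
    forest = RandomForestClassifier(feature_importance=:split)
    mach = machine(forest, X, y)
    fit!(mach)
    feature_importances(mach)
    # 3-element Vector{Pair{Symbol, Float64}}, sorted by descending importance:
    #  :x2 => 0.54
    #  :x1 => 0.28
    #  :x3 => 0.18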