From 1206d3523e3f4c6ab2e0ffaf696f70396207e595 Mon Sep 17 00:00:00 2001 From: john-waczak Date: Tue, 9 Aug 2022 09:36:31 -0500 Subject: [PATCH 01/11] added feature_importances to report for DecisionTreeClassifier --- src/MLJDecisionTreeInterface.jl | 10 +++++++++- test/runtests.jl | 5 ++++- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/src/MLJDecisionTreeInterface.jl b/src/MLJDecisionTreeInterface.jl index bd22487..30e7148 100644 --- a/src/MLJDecisionTreeInterface.jl +++ b/src/MLJDecisionTreeInterface.jl @@ -71,9 +71,17 @@ function MMI.fit(m::DecisionTreeClassifier, verbosity::Int, X, y) fitresult = (tree, classes_seen, integers_seen, features) cache = nothing + + # generate feature importances for report + fi = DecisionTree.impurity_importance(tree) + fi_pairs = collect(Dict(zip(features, fi))) + # sort descending + sort!(fi_pairs, by= x->-x[2]) + report = (classes_seen=classes_seen, print_tree=TreePrinter(tree), - features=features) + features=features, + feature_importances=fi_pairs) return fitresult, cache, report end diff --git a/test/runtests.jl b/test/runtests.jl index 89d207d..768dfe3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -41,10 +41,13 @@ yyhat = predict_mode(baretree, fitresult, MLJBase.selectrows(X, 1:3)) @test MLJBase.classes(yyhat[1]) == MLJBase.classes(y[1]) # check report and fitresult fields: -@test Set([:classes_seen, :print_tree, :features]) == Set(keys(report)) +@test Set([:classes_seen, :print_tree, :features, :feature_importances]) == Set(keys(report)) @test Set(report.classes_seen) == Set(levels(y)) @test report.print_tree(2) === nothing # :-( @test report.features == [:sepal_length, :sepal_width, :petal_length, :petal_width] +# check feature_importances +@test size(report.features,1) == size(report.feature_importances, 1) + fp = fitted_params(baretree, fitresult) @test Set([:tree, :encoding, :features]) == Set(keys(fp)) @test fp.features == report.features From d4451ef069365b9207689acdfd3d04d38b4eb7fd Mon Sep 17 00:00:00 2001 From: john-waczak Date: Tue, 9 Aug 2022 09:59:47 -0500 Subject: [PATCH 02/11] added feature_importances for RandomForestClassifier --- src/MLJDecisionTreeInterface.jl | 16 +++++++++++++++- test/runtests.jl | 4 ++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/src/MLJDecisionTreeInterface.jl b/src/MLJDecisionTreeInterface.jl index 30e7148..910c7fd 100644 --- a/src/MLJDecisionTreeInterface.jl +++ b/src/MLJDecisionTreeInterface.jl @@ -130,9 +130,16 @@ MMI.@mlj_model mutable struct RandomForestClassifier <: MMI.Probabilistic end function MMI.fit(m::RandomForestClassifier, verbosity::Int, X, y) + schema = Tables.schema(X) Xmatrix = MMI.matrix(X) yplain = MMI.int(y) + if schema === nothing + features = [Symbol("x$j") for j in 1:size(Xmatrix, 2)] + else + features = schema.names |> collect + end + classes_seen = filter(in(unique(y)), MMI.classes(y[1])) integers_seen = MMI.int(classes_seen) @@ -146,7 +153,14 @@ function MMI.fit(m::RandomForestClassifier, verbosity::Int, X, y) m.min_purity_increase; rng=m.rng) cache = nothing - report = NamedTuple() + + fi = DecisionTree.impurity_importance(forest) + fi_pairs = collect(Dict(zip(features, fi))) + # sort descending + sort!(fi_pairs, by= x->-x[2]) + + report = (feature_importances=fi_pairs,) + # report = NamedTuple() return (forest, classes_seen, integers_seen), cache, report end diff --git a/test/runtests.jl b/test/runtests.jl index 768dfe3..f450013 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -112,6 +112,10 @@ X, y = MLJBase.make_blobs(100, 3; 
rng=stable_rng()) m = machine(rfc, X, y) fit!(m) @test accuracy(predict_mode(m, X), y) > 0.95 +# check feature_importances +rpt = MLJBase.report(m) +@test size(rpt.feature_importances, 1) == 3 # make sure we get an importance for each feature + m = machine(abs, X, y) fit!(m) From c077eaab73a9007727de97564420456de3f44005 Mon Sep 17 00:00:00 2001 From: john-waczak Date: Tue, 9 Aug 2022 11:07:39 -0500 Subject: [PATCH 03/11] added feature_importances for DecisionTreeRegressor and RandomForestRegressor --- src/MLJDecisionTreeInterface.jl | 53 +++++++++++++++++++++++++++++++-- test/runtests.jl | 22 ++++++++++++++ 2 files changed, 72 insertions(+), 3 deletions(-) diff --git a/src/MLJDecisionTreeInterface.jl b/src/MLJDecisionTreeInterface.jl index 910c7fd..c1e39e4 100644 --- a/src/MLJDecisionTreeInterface.jl +++ b/src/MLJDecisionTreeInterface.jl @@ -182,16 +182,31 @@ MMI.@mlj_model mutable struct AdaBoostStumpClassifier <: MMI.Probabilistic end function MMI.fit(m::AdaBoostStumpClassifier, verbosity::Int, X, y) + schema = Tables.schema(X) Xmatrix = MMI.matrix(X) yplain = MMI.int(y) + if schema === nothing + features = [Symbol("x$j") for j in 1:size(Xmatrix, 2)] + else + features = schema.names |> collect + end + + classes_seen = filter(in(unique(y)), MMI.classes(y[1])) integers_seen = MMI.int(classes_seen) stumps, coefs = DT.build_adaboost_stumps(yplain, Xmatrix, m.n_iter, rng=m.rng) cache = nothing - report = NamedTuple() + + fi = DecisionTree.impurity_importance(stumps, coefs) + fi_pairs = collect(Dict(zip(features, fi))) + # sort descending + sort!(fi_pairs, by= x->-x[2]) + + report = (feature_importances=fi_pairs,) + # report = NamedTuple() return (stumps, coefs, classes_seen, integers_seen), cache, report end @@ -221,7 +236,15 @@ MMI.@mlj_model mutable struct DecisionTreeRegressor <: MMI.Deterministic end function MMI.fit(m::DecisionTreeRegressor, verbosity::Int, X, y) + schema = Tables.schema(X) Xmatrix = MMI.matrix(X) + + if schema === nothing + features = [Symbol("x$j") for j in 1:size(Xmatrix, 2)] + else + features = schema.names |> collect + end + tree = DT.build_tree(float(y), Xmatrix, m.n_subfeatures, m.max_depth, @@ -234,7 +257,14 @@ function MMI.fit(m::DecisionTreeRegressor, verbosity::Int, X, y) tree = DT.prune_tree(tree, m.merge_purity_threshold) end cache = nothing - report = NamedTuple() + + fi = DecisionTree.impurity_importance(tree) + fi_pairs = collect(Dict(zip(features, fi))) + # sort descending + sort!(fi_pairs, by= x->-x[2]) + + report = (feature_importances=fi_pairs,) + # report = NamedTuple() return tree, cache, report end @@ -260,7 +290,15 @@ MMI.@mlj_model mutable struct RandomForestRegressor <: MMI.Deterministic end function MMI.fit(m::RandomForestRegressor, verbosity::Int, X, y) + schema = Tables.schema(X) Xmatrix = MMI.matrix(X) + + if schema === nothing + features = [Symbol("x$j") for j in 1:size(Xmatrix, 2)] + else + features = schema.names |> collect + end + forest = DT.build_forest(float(y), Xmatrix, m.n_subfeatures, m.n_trees, @@ -271,7 +309,16 @@ function MMI.fit(m::RandomForestRegressor, verbosity::Int, X, y) m.min_purity_increase, rng=m.rng) cache = nothing - report = NamedTuple() + + + fi = DecisionTree.impurity_importance(forest) + fi_pairs = collect(Dict(zip(features, fi))) + # sort descending + sort!(fi_pairs, by= x->-x[2]) + + report = (feature_importances=fi_pairs,) + # report = NamedTuple() + return forest, cache, report end diff --git a/test/runtests.jl b/test/runtests.jl index f450013..d4d4b0b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ 
-120,6 +120,28 @@ rpt = MLJBase.report(m) m = machine(abs, X, y) fit!(m) @test accuracy(predict_mode(m, X), y) > 0.95 +# check feature_importances +rpt = MLJBase.report(m) +@test size(rpt.feature_importances, 1) == 3 # make sure we get an importance for each feature + + + + +# test DecisionTreeRegressor and RandomForestRegressor +X, y = make_regression(100,3; rng=stable_rng()); +dtr = DecisionTreeRegressor(rng=stable_rng()) +rfr = RandomForestRegressor(rng=stable_rng()) + +m = machine(dtr, X, y) +fit!(m) +rpt = MLJBase.report(m) +@test size(rpt.feature_importances, 1) == 3 # make sure we get an importance for each feature + +m = machine(rfr, X, y) +fit!(m) +rpt = MLJBase.report(m) +@test size(rpt.feature_importances, 1) == 3 # make sure we get an importance for each feature + X, y = MLJBase.make_regression(rng=stable_rng()) rfr = RandomForestRegressor(rng=stable_rng()) From 5b5067411930444c42431bc76fa94d4704c0f346 Mon Sep 17 00:00:00 2001 From: john-waczak Date: Wed, 10 Aug 2022 09:16:45 -0500 Subject: [PATCH 04/11] added ability to choose from :none, :impurity, :split for importances --- src/MLJDecisionTreeInterface.jl | 106 +++++++++++++++++++++----------- test/runtests.jl | 33 +++++++++- 2 files changed, 101 insertions(+), 38 deletions(-) diff --git a/src/MLJDecisionTreeInterface.jl b/src/MLJDecisionTreeInterface.jl index c1e39e4..5e054e6 100644 --- a/src/MLJDecisionTreeInterface.jl +++ b/src/MLJDecisionTreeInterface.jl @@ -22,6 +22,13 @@ Base.show(stream::IO, c::TreePrinter) = print(stream, "TreePrinter object (call with display depth)") +# add a new variable to struct a la +const feature_importance_options = Dict(:impurity => DecisionTree.impurity_importance, + :split => DecisionTree.split_importance, + # :permutation => DecisionTree.permutation_importance, + ) + + # # DECISION TREE CLASSIFIER # The following meets the MLJ standard for a `Model` docstring and is @@ -29,7 +36,6 @@ Base.show(stream::IO, c::TreePrinter) = # template for authors of other MLJ model interfaces. The other # doc-strings, defined later, are generated using the `doc_header` # utility to automatically generate the header, another option. 
- MMI.@mlj_model mutable struct DecisionTreeClassifier <: MMI.Probabilistic max_depth::Int = (-)(1)::(_ ≥ -1) min_samples_leaf::Int = 1::(_ ≥ 0) @@ -39,9 +45,13 @@ MMI.@mlj_model mutable struct DecisionTreeClassifier <: MMI.Probabilistic post_prune::Bool = false merge_purity_threshold::Float64 = 1.0::(_ ≤ 1) display_depth::Int = 5::(_ ≥ 1) + feature_importance::Symbol = :impurity::(_ ∈ (:none, :impurity, :split)) #:permutation)) rng::Union{AbstractRNG,Integer} = GLOBAL_RNG end + + + function MMI.fit(m::DecisionTreeClassifier, verbosity::Int, X, y) schema = Tables.schema(X) Xmatrix = MMI.matrix(X) @@ -73,15 +83,22 @@ function MMI.fit(m::DecisionTreeClassifier, verbosity::Int, X, y) cache = nothing # generate feature importances for report - fi = DecisionTree.impurity_importance(tree) - fi_pairs = collect(Dict(zip(features, fi))) - # sort descending - sort!(fi_pairs, by= x->-x[2]) - - report = (classes_seen=classes_seen, - print_tree=TreePrinter(tree), - features=features, - feature_importances=fi_pairs) + if m.feature_importance != :none + fi = feature_importance_options[m.feature_importance](tree) + fi_pairs = collect(Dict(zip(features, fi))) + # sort descending + sort!(fi_pairs, by= x->-x[2]) + + report = (classes_seen=classes_seen, + print_tree=TreePrinter(tree), + features=features, + feature_importances=fi_pairs) + else + report = (classes_seen=classes_seen, + print_tree=TreePrinter(tree), + features=features, + ) + end return fitresult, cache, report end @@ -126,6 +143,7 @@ MMI.@mlj_model mutable struct RandomForestClassifier <: MMI.Probabilistic n_subfeatures::Int = (-)(1)::(_ ≥ -1) n_trees::Int = 10::(_ ≥ 2) sampling_fraction::Float64 = 0.7::(0 < _ ≤ 1) + feature_importance::Symbol = :impurity::(_ ∈ (:none, :impurity, :split)) #, :permutation)) rng::Union{AbstractRNG,Integer} = GLOBAL_RNG end @@ -154,13 +172,17 @@ function MMI.fit(m::RandomForestClassifier, verbosity::Int, X, y) rng=m.rng) cache = nothing - fi = DecisionTree.impurity_importance(forest) - fi_pairs = collect(Dict(zip(features, fi))) - # sort descending - sort!(fi_pairs, by= x->-x[2]) + if m.feature_importance != :none + fi = feature_importance_options[m.feature_importance](forest) + fi_pairs = collect(Dict(zip(features, fi))) + # sort descending + sort!(fi_pairs, by= x->-x[2]) + + report = (feature_importances=fi_pairs,) + else + report = NamedTuple() + end - report = (feature_importances=fi_pairs,) - # report = NamedTuple() return (forest, classes_seen, integers_seen), cache, report end @@ -178,6 +200,7 @@ end MMI.@mlj_model mutable struct AdaBoostStumpClassifier <: MMI.Probabilistic n_iter::Int = 10::(_ ≥ 1) + feature_importance::Symbol = :impurity::(_ ∈ (:none, :impurity, :split)) # , :permutation)) rng::Union{AbstractRNG,Integer} = GLOBAL_RNG end @@ -200,13 +223,16 @@ function MMI.fit(m::AdaBoostStumpClassifier, verbosity::Int, X, y) DT.build_adaboost_stumps(yplain, Xmatrix, m.n_iter, rng=m.rng) cache = nothing - fi = DecisionTree.impurity_importance(stumps, coefs) - fi_pairs = collect(Dict(zip(features, fi))) - # sort descending - sort!(fi_pairs, by= x->-x[2]) + if m.feature_importance != :none + fi = feature_importance_options[m.feature_importance](stumps, coefs) + fi_pairs = collect(Dict(zip(features, fi))) + # sort descending + sort!(fi_pairs, by= x->-x[2]) + report = (feature_importances=fi_pairs,) + else + report = NamedTuple() + end - report = (feature_importances=fi_pairs,) - # report = NamedTuple() return (stumps, coefs, classes_seen, integers_seen), cache, report end @@ -232,6 +258,7 @@ MMI.@mlj_model 
mutable struct DecisionTreeRegressor <: MMI.Deterministic n_subfeatures::Int = 0::(_ ≥ -1) post_prune::Bool = false merge_purity_threshold::Float64 = 1.0::(0 ≤ _ ≤ 1) + feature_importance::Symbol = :impurity::(_ ∈ (:none, :impurity, :split)) # , :permutation)) rng::Union{AbstractRNG,Integer} = GLOBAL_RNG end @@ -258,13 +285,17 @@ function MMI.fit(m::DecisionTreeRegressor, verbosity::Int, X, y) end cache = nothing - fi = DecisionTree.impurity_importance(tree) - fi_pairs = collect(Dict(zip(features, fi))) - # sort descending - sort!(fi_pairs, by= x->-x[2]) + if m.feature_importance != :none + fi = feature_importance_options[m.feature_importance](tree) + fi_pairs = collect(Dict(zip(features, fi))) + # sort descending + sort!(fi_pairs, by= x->-x[2]) + report = (feature_importances=fi_pairs,) + else + report = NamedTuple() + end + - report = (feature_importances=fi_pairs,) - # report = NamedTuple() return tree, cache, report end @@ -286,6 +317,7 @@ MMI.@mlj_model mutable struct RandomForestRegressor <: MMI.Deterministic n_subfeatures::Int = (-)(1)::(_ ≥ -1) n_trees::Int = 10::(_ ≥ 2) sampling_fraction::Float64 = 0.7::(0 < _ ≤ 1) + feature_importance::Symbol = :impurity::(_ ∈ (:none, :impurity, :split)) #, :permutation)) rng::Union{AbstractRNG,Integer} = GLOBAL_RNG end @@ -310,15 +342,15 @@ function MMI.fit(m::RandomForestRegressor, verbosity::Int, X, y) rng=m.rng) cache = nothing - - fi = DecisionTree.impurity_importance(forest) - fi_pairs = collect(Dict(zip(features, fi))) - # sort descending - sort!(fi_pairs, by= x->-x[2]) - - report = (feature_importances=fi_pairs,) - # report = NamedTuple() - + if m.feature_importance != :none + fi = feature_importance_options[m.feature_importance](forest) + fi_pairs = collect(Dict(zip(features, fi))) + # sort descending + sort!(fi_pairs, by= x->-x[2]) + report = (feature_importances=fi_pairs,) + else + report = NamedTuple() + end return forest, cache, report end diff --git a/test/runtests.jl b/test/runtests.jl index d4d4b0b..29c94d0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -142,7 +142,6 @@ fit!(m) rpt = MLJBase.report(m) @test size(rpt.feature_importances, 1) == 3 # make sure we get an importance for each feature - X, y = MLJBase.make_regression(rng=stable_rng()) rfr = RandomForestRegressor(rng=stable_rng()) m = machine(rfr, X, y) @@ -184,3 +183,35 @@ end @test reproducibility(model, X, y, loss) end end + + +# try testing model output for different feature_importance options +# 1. :none +X, y = make_regression(100,3; rng=stable_rng()); +dtr = DecisionTreeRegressor(feature_importance=:none, rng=stable_rng()) +rfr = RandomForestRegressor(feature_importance=:none, rng=stable_rng()) + +m = machine(dtr, X, y) +fit!(m) +rpt = MLJBase.report(m) +@test isempty(rpt) + +m = machine(rfr, X, y) +fit!(m) +rpt = MLJBase.report(m) +@test isempty(rpt) + + +# 2. 
:split +dtr = DecisionTreeRegressor(feature_importance=:split, rng=stable_rng()) +rfr = RandomForestRegressor(feature_importance=:split, rng=stable_rng()) + +m = machine(dtr, X, y) +fit!(m) +rpt = MLJBase.report(m) +@test size(rpt.feature_importances, 1) == 3 # make sure we get an importance for each feature + +m = machine(rfr, X, y) +fit!(m) +rpt = MLJBase.report(m) +@test size(rpt.feature_importances, 1) == 3 # make sure we get an importance for each feature From 0413626dd608e4678f303358cd0ca69bfb4cc95d Mon Sep 17 00:00:00 2001 From: John Waczak Date: Wed, 10 Aug 2022 17:15:18 -0500 Subject: [PATCH 05/11] updated to remove :none option, set reports_feature_importances to true, and define MMI.feature_importances function --- src/MLJDecisionTreeInterface.jl | 189 +++++++++++++++++++++----------- test/runtests.jl | 114 ++++++++++++------- 2 files changed, 194 insertions(+), 109 deletions(-) diff --git a/src/MLJDecisionTreeInterface.jl b/src/MLJDecisionTreeInterface.jl index 5e054e6..018806f 100644 --- a/src/MLJDecisionTreeInterface.jl +++ b/src/MLJDecisionTreeInterface.jl @@ -22,11 +22,11 @@ Base.show(stream::IO, c::TreePrinter) = print(stream, "TreePrinter object (call with display depth)") -# add a new variable to struct a la -const feature_importance_options = Dict(:impurity => DecisionTree.impurity_importance, - :split => DecisionTree.split_importance, - # :permutation => DecisionTree.permutation_importance, - ) +# # add a new variable to struct a la +# const feature_importance_options = Dict(:impurity => DecisionTree.impurity_importance, +# :split => DecisionTree.split_importance, +# # :permutation => DecisionTree.permutation_importance, +# ) # # DECISION TREE CLASSIFIER @@ -45,7 +45,7 @@ MMI.@mlj_model mutable struct DecisionTreeClassifier <: MMI.Probabilistic post_prune::Bool = false merge_purity_threshold::Float64 = 1.0::(_ ≤ 1) display_depth::Int = 5::(_ ≥ 1) - feature_importance::Symbol = :impurity::(_ ∈ (:none, :impurity, :split)) #:permutation)) + feature_importance::Symbol = :impurity::(_ ∈ (:impurity, :split)) rng::Union{AbstractRNG,Integer} = GLOBAL_RNG end @@ -81,25 +81,10 @@ function MMI.fit(m::DecisionTreeClassifier, verbosity::Int, X, y) fitresult = (tree, classes_seen, integers_seen, features) cache = nothing - - # generate feature importances for report - if m.feature_importance != :none - fi = feature_importance_options[m.feature_importance](tree) - fi_pairs = collect(Dict(zip(features, fi))) - # sort descending - sort!(fi_pairs, by= x->-x[2]) - - report = (classes_seen=classes_seen, - print_tree=TreePrinter(tree), - features=features, - feature_importances=fi_pairs) - else - report = (classes_seen=classes_seen, - print_tree=TreePrinter(tree), - features=features, - ) - end - + report = (classes_seen=classes_seen, + print_tree=TreePrinter(tree), + features=features, + ) return fitresult, cache, report end @@ -132,6 +117,27 @@ function MMI.predict(m::DecisionTreeClassifier, fitresult, Xnew) return MMI.UnivariateFinite(classes_seen, scores) end +MMI.reports_feature_importances(::Type{<:DecisionTreeClassifier}) = true + +function MMI.feature_importances(m::DecisionTreeClassifier, fitresult, report) + # generate feature importances for report + if m.feature_importance == :impurity + feature_importance_func = DecisionTree.impurity_importance + elseif m.feature_importance == :split + feature_importance_func = DecisionTree.split_importance + end + + tree = fitresult[1] + features = fitresult[end] + fi = feature_importance_func(tree) + fi_pairs = collect(Dict(zip(features, 
fi))) + # sort descending + sort!(fi_pairs, by= x->-x[2]) + + return fi_pairs +end + + # # RANDOM FOREST CLASSIFIER @@ -143,7 +149,7 @@ MMI.@mlj_model mutable struct RandomForestClassifier <: MMI.Probabilistic n_subfeatures::Int = (-)(1)::(_ ≥ -1) n_trees::Int = 10::(_ ≥ 2) sampling_fraction::Float64 = 0.7::(0 < _ ≤ 1) - feature_importance::Symbol = :impurity::(_ ∈ (:none, :impurity, :split)) #, :permutation)) + feature_importance::Symbol = :impurity::(_ ∈ (:impurity, :split)) rng::Union{AbstractRNG,Integer} = GLOBAL_RNG end @@ -172,16 +178,7 @@ function MMI.fit(m::RandomForestClassifier, verbosity::Int, X, y) rng=m.rng) cache = nothing - if m.feature_importance != :none - fi = feature_importance_options[m.feature_importance](forest) - fi_pairs = collect(Dict(zip(features, fi))) - # sort descending - sort!(fi_pairs, by= x->-x[2]) - - report = (feature_importances=fi_pairs,) - else - report = NamedTuple() - end + report = (features=features,) return (forest, classes_seen, integers_seen), cache, report end @@ -195,12 +192,34 @@ function MMI.predict(m::RandomForestClassifier, fitresult, Xnew) return MMI.UnivariateFinite(classes_seen, scores) end +MMI.reports_feature_importances(::Type{<:RandomForestClassifier}) = true + +function MMI.feature_importances(m::RandomForestClassifier, fitresult, report) + # generate feature importances for report + if m.feature_importance == :impurity + feature_importance_func = DecisionTree.impurity_importance + elseif m.feature_importance == :split + feature_importance_func = DecisionTree.split_importance + end + + forest = fitresult[1] + features = report.features + fi = feature_importance_func(forest) + fi_pairs = collect(Dict(zip(features, fi))) + # sort descending + sort!(fi_pairs, by= x->-x[2]) + + return fi_pairs +end + + + # # ADA BOOST STUMP CLASSIFIER MMI.@mlj_model mutable struct AdaBoostStumpClassifier <: MMI.Probabilistic n_iter::Int = 10::(_ ≥ 1) - feature_importance::Symbol = :impurity::(_ ∈ (:none, :impurity, :split)) # , :permutation)) + feature_importance::Symbol = :impurity::(_ ∈ (:impurity, :split)) rng::Union{AbstractRNG,Integer} = GLOBAL_RNG end @@ -223,15 +242,7 @@ function MMI.fit(m::AdaBoostStumpClassifier, verbosity::Int, X, y) DT.build_adaboost_stumps(yplain, Xmatrix, m.n_iter, rng=m.rng) cache = nothing - if m.feature_importance != :none - fi = feature_importance_options[m.feature_importance](stumps, coefs) - fi_pairs = collect(Dict(zip(features, fi))) - # sort descending - sort!(fi_pairs, by= x->-x[2]) - report = (feature_importances=fi_pairs,) - else - report = NamedTuple() - end + report = (features=features,) return (stumps, coefs, classes_seen, integers_seen), cache, report end @@ -247,6 +258,27 @@ function MMI.predict(m::AdaBoostStumpClassifier, fitresult, Xnew) return MMI.UnivariateFinite(classes_seen, scores) end +MMI.reports_feature_importances(::Type{<:AdaBoostStumpClassifier}) = true + +function MMI.feature_importances(m::AdaBoostStumpClassifier, fitresult, report) + # generate feature importances for report + if m.feature_importance == :impurity + feature_importance_func = DecisionTree.impurity_importance + elseif m.feature_importance == :split + feature_importance_func = DecisionTree.split_importance + end + + stumps = fitresult[1] + coefs = fitresult[2] + features = report.features + fi = feature_importance_func(stumps, coefs) + fi_pairs = collect(Dict(zip(features, fi))) + # sort descending + sort!(fi_pairs, by= x->-x[2]) + + return fi_pairs +end + # # DECISION TREE REGRESSOR @@ -258,7 +290,7 @@ MMI.@mlj_model mutable 
struct DecisionTreeRegressor <: MMI.Deterministic n_subfeatures::Int = 0::(_ ≥ -1) post_prune::Bool = false merge_purity_threshold::Float64 = 1.0::(0 ≤ _ ≤ 1) - feature_importance::Symbol = :impurity::(_ ∈ (:none, :impurity, :split)) # , :permutation)) + feature_importance::Symbol = :impurity::(_ ∈ (:impurity, :split)) rng::Union{AbstractRNG,Integer} = GLOBAL_RNG end @@ -285,16 +317,7 @@ function MMI.fit(m::DecisionTreeRegressor, verbosity::Int, X, y) end cache = nothing - if m.feature_importance != :none - fi = feature_importance_options[m.feature_importance](tree) - fi_pairs = collect(Dict(zip(features, fi))) - # sort descending - sort!(fi_pairs, by= x->-x[2]) - report = (feature_importances=fi_pairs,) - else - report = NamedTuple() - end - + report = (features=features,) return tree, cache, report end @@ -306,6 +329,26 @@ function MMI.predict(::DecisionTreeRegressor, tree, Xnew) return DT.apply_tree(tree, Xmatrix) end +MMI.reports_feature_importances(::Type{<:DecisionTreeRegressor}) = true + +function MMI.feature_importances(m::DecisionTreeRegressor, fitresult, report) + # generate feature importances for report + if m.feature_importance == :impurity + feature_importance_func = DecisionTree.impurity_importance + elseif m.feature_importance == :split + feature_importance_func = DecisionTree.split_importance + end + + tree = fitresult + features = report.features + fi = feature_importance_func(tree) + fi_pairs = collect(Dict(zip(features, fi))) + # sort descending + sort!(fi_pairs, by= x->-x[2]) + + return fi_pairs +end + # # RANDOM FOREST REGRESSOR @@ -317,7 +360,7 @@ MMI.@mlj_model mutable struct RandomForestRegressor <: MMI.Deterministic n_subfeatures::Int = (-)(1)::(_ ≥ -1) n_trees::Int = 10::(_ ≥ 2) sampling_fraction::Float64 = 0.7::(0 < _ ≤ 1) - feature_importance::Symbol = :impurity::(_ ∈ (:none, :impurity, :split)) #, :permutation)) + feature_importance::Symbol = :impurity::(_ ∈ (:impurity, :split)) rng::Union{AbstractRNG,Integer} = GLOBAL_RNG end @@ -341,16 +384,8 @@ function MMI.fit(m::RandomForestRegressor, verbosity::Int, X, y) m.min_purity_increase, rng=m.rng) cache = nothing + report = (features=features,) - if m.feature_importance != :none - fi = feature_importance_options[m.feature_importance](forest) - fi_pairs = collect(Dict(zip(features, fi))) - # sort descending - sort!(fi_pairs, by= x->-x[2]) - report = (feature_importances=fi_pairs,) - else - report = NamedTuple() - end return forest, cache, report end @@ -361,6 +396,26 @@ function MMI.predict(::RandomForestRegressor, forest, Xnew) return DT.apply_forest(forest, Xmatrix) end +MMI.reports_feature_importances(::Type{<:RandomForestRegressor}) = true + +function MMI.feature_importances(m::RandomForestRegressor, fitresult, report) + # generate feature importances for report + if m.feature_importance == :impurity + feature_importance_func = DecisionTree.impurity_importance + elseif m.feature_importance == :split + feature_importance_func = DecisionTree.split_importance + end + + forest = fitresult + features = report.features + fi = feature_importance_func(forest) + fi_pairs = collect(Dict(zip(features, fi))) + # sort descending + sort!(fi_pairs, by= x->-x[2]) + + return fi_pairs +end + # # METADATA (MODEL TRAITS) diff --git a/test/runtests.jl b/test/runtests.jl index 29c94d0..20c7c2d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -41,12 +41,10 @@ yyhat = predict_mode(baretree, fitresult, MLJBase.selectrows(X, 1:3)) @test MLJBase.classes(yyhat[1]) == MLJBase.classes(y[1]) # check report and fitresult fields: 
-@test Set([:classes_seen, :print_tree, :features, :feature_importances]) == Set(keys(report)) +@test Set([:classes_seen, :print_tree, :features]) == Set(keys(report)) @test Set(report.classes_seen) == Set(levels(y)) @test report.print_tree(2) === nothing # :-( @test report.features == [:sepal_length, :sepal_width, :petal_length, :petal_width] -# check feature_importances -@test size(report.features,1) == size(report.feature_importances, 1) fp = fitted_params(baretree, fitresult) @test Set([:tree, :encoding, :features]) == Set(keys(fp)) @@ -113,34 +111,34 @@ m = machine(rfc, X, y) fit!(m) @test accuracy(predict_mode(m, X), y) > 0.95 # check feature_importances -rpt = MLJBase.report(m) -@test size(rpt.feature_importances, 1) == 3 # make sure we get an importance for each feature +# rpt = MLJBase.report(m) +# @test size(rpt.feature_importances, 1) == 3 # make sure we get an importance for each feature m = machine(abs, X, y) fit!(m) @test accuracy(predict_mode(m, X), y) > 0.95 # check feature_importances -rpt = MLJBase.report(m) -@test size(rpt.feature_importances, 1) == 3 # make sure we get an importance for each feature +# rpt = MLJBase.report(m) +# @test size(rpt.feature_importances, 1) == 3 # make sure we get an importance for each feature # test DecisionTreeRegressor and RandomForestRegressor -X, y = make_regression(100,3; rng=stable_rng()); -dtr = DecisionTreeRegressor(rng=stable_rng()) -rfr = RandomForestRegressor(rng=stable_rng()) +# X, y = make_regression(100,3; rng=stable_rng()); +# dtr = DecisionTreeRegressor(rng=stable_rng()) +# rfr = RandomForestRegressor(rng=stable_rng()) -m = machine(dtr, X, y) -fit!(m) -rpt = MLJBase.report(m) -@test size(rpt.feature_importances, 1) == 3 # make sure we get an importance for each feature +# m = machine(dtr, X, y) +# fit!(m) +# rpt = MLJBase.report(m) +# @test size(rpt.feature_importances, 1) == 3 # make sure we get an importance for each feature -m = machine(rfr, X, y) -fit!(m) -rpt = MLJBase.report(m) -@test size(rpt.feature_importances, 1) == 3 # make sure we get an importance for each feature +# m = machine(rfr, X, y) +# fit!(m) +# rpt = MLJBase.report(m) +# @test size(rpt.feature_importances, 1) == 3 # make sure we get an importance for each feature X, y = MLJBase.make_regression(rng=stable_rng()) rfr = RandomForestRegressor(rng=stable_rng()) @@ -185,33 +183,65 @@ end end -# try testing model output for different feature_importance options -# 1. :none -X, y = make_regression(100,3; rng=stable_rng()); -dtr = DecisionTreeRegressor(feature_importance=:none, rng=stable_rng()) -rfr = RandomForestRegressor(feature_importance=:none, rng=stable_rng()) +@testset "impurity importance" begin + X, y = MLJBase.make_blobs(100, 3; rng=stable_rng()) -m = machine(dtr, X, y) -fit!(m) -rpt = MLJBase.report(m) -@test isempty(rpt) + for model ∈ [ + DecisionTreeClassifier(), + RandomForestClassifier(), + AdaBoostStumpClassifier(), + ] + m = machine(model, X, y) + fit!(m) + rpt = MLJBase.report(m) + fi = MLJBase.feature_importances(model, m.fitresult, rpt) + @test size(fi,1) == 3 + end -m = machine(rfr, X, y) -fit!(m) -rpt = MLJBase.report(m) -@test isempty(rpt) + X, y = make_regression(100,3; rng=stable_rng()); + for model in [ + DecisionTreeRegressor(), + RandomForestRegressor(), + ] + m = machine(model, X, y) + fit!(m) + rpt = MLJBase.report(m) + fi = MLJBase.feature_importances(model, m.fitresult, rpt) + @test size(fi,1) == 3 + end +end -# 2. 
:split -dtr = DecisionTreeRegressor(feature_importance=:split, rng=stable_rng()) -rfr = RandomForestRegressor(feature_importance=:split, rng=stable_rng()) -m = machine(dtr, X, y) -fit!(m) -rpt = MLJBase.report(m) -@test size(rpt.feature_importances, 1) == 3 # make sure we get an importance for each feature +@testset "split importance" begin + X, y = MLJBase.make_blobs(100, 3; rng=stable_rng()) + + for model ∈ [ + DecisionTreeClassifier(feature_importance=:split), + RandomForestClassifier(feature_importance=:split), + AdaBoostStumpClassifier(feature_importance=:split), + ] + m = machine(model, X, y) + fit!(m) + rpt = MLJBase.report(m) + fi = MLJBase.feature_importances(model, m.fitresult, rpt) + + println(fi) + @test size(fi,1) == 3 + end + + + X, y = make_regression(100,3; rng=stable_rng()); + for model in [ + DecisionTreeRegressor(feature_importance=:split), + RandomForestRegressor(feature_importance=:split), + ] + m = machine(model, X, y) + fit!(m) + rpt = MLJBase.report(m) + fi = MLJBase.feature_importances(model, m.fitresult, rpt) + @test size(fi,1) == 3 + end +end + -m = machine(rfr, X, y) -fit!(m) -rpt = MLJBase.report(m) -@test size(rpt.feature_importances, 1) == 3 # make sure we get an importance for each feature From 70d146008da5dc0a1f8c0384978643a35b4f28bc Mon Sep 17 00:00:00 2001 From: John Waczak Date: Wed, 10 Aug 2022 17:43:48 -0500 Subject: [PATCH 06/11] removed print statement and updated docstrings --- src/MLJDecisionTreeInterface.jl | 30 ++++++++++++++++++++---------- test/runtests.jl | 2 -- 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/src/MLJDecisionTreeInterface.jl b/src/MLJDecisionTreeInterface.jl index 018806f..406ee2f 100644 --- a/src/MLJDecisionTreeInterface.jl +++ b/src/MLJDecisionTreeInterface.jl @@ -122,9 +122,9 @@ MMI.reports_feature_importances(::Type{<:DecisionTreeClassifier}) = true function MMI.feature_importances(m::DecisionTreeClassifier, fitresult, report) # generate feature importances for report if m.feature_importance == :impurity - feature_importance_func = DecisionTree.impurity_importance + feature_importance_func = DT.impurity_importance elseif m.feature_importance == :split - feature_importance_func = DecisionTree.split_importance + feature_importance_func = DT.split_importance end tree = fitresult[1] @@ -197,9 +197,9 @@ MMI.reports_feature_importances(::Type{<:RandomForestClassifier}) = true function MMI.feature_importances(m::RandomForestClassifier, fitresult, report) # generate feature importances for report if m.feature_importance == :impurity - feature_importance_func = DecisionTree.impurity_importance + feature_importance_func = DT.impurity_importance elseif m.feature_importance == :split - feature_importance_func = DecisionTree.split_importance + feature_importance_func = DT.split_importance end forest = fitresult[1] @@ -263,9 +263,9 @@ MMI.reports_feature_importances(::Type{<:AdaBoostStumpClassifier}) = true function MMI.feature_importances(m::AdaBoostStumpClassifier, fitresult, report) # generate feature importances for report if m.feature_importance == :impurity - feature_importance_func = DecisionTree.impurity_importance + feature_importance_func = DT.impurity_importance elseif m.feature_importance == :split - feature_importance_func = DecisionTree.split_importance + feature_importance_func = DT.split_importance end stumps = fitresult[1] @@ -334,9 +334,9 @@ MMI.reports_feature_importances(::Type{<:DecisionTreeRegressor}) = true function MMI.feature_importances(m::DecisionTreeRegressor, fitresult, report) # generate 
feature importances for report if m.feature_importance == :impurity - feature_importance_func = DecisionTree.impurity_importance + feature_importance_func = DT.impurity_importance elseif m.feature_importance == :split - feature_importance_func = DecisionTree.split_importance + feature_importance_func = DT.split_importance end tree = fitresult @@ -401,9 +401,9 @@ MMI.reports_feature_importances(::Type{<:RandomForestRegressor}) = true function MMI.feature_importances(m::RandomForestRegressor, fitresult, report) # generate feature importances for report if m.feature_importance == :impurity - feature_importance_func = DecisionTree.impurity_importance + feature_importance_func = DT.impurity_importance elseif m.feature_importance == :split - feature_importance_func = DecisionTree.split_importance + feature_importance_func = DT.split_importance end forest = fitresult @@ -535,6 +535,8 @@ Train the machine using `fit!(mach, rows=...)`. - `display_depth=5`: max depth to show when displaying the tree +- `feature_importance`: method to use for computing feature importances + - `rng=Random.GLOBAL_RNG`: random number generator or seed @@ -668,6 +670,8 @@ Train the machine with `fit!(mach, rows=...)`. - `sampling_fraction=0.7` fraction of samples to train each tree on +- `feature_importance`: method to use for computing feature importances + - `rng=Random.GLOBAL_RNG`: random number generator or seed @@ -743,6 +747,8 @@ Train the machine with `fit!(mach, rows=...)`. - `n_iter=10`: number of iterations of AdaBoost +- `feature_importance`: method to use for computing feature importances + - `rng=Random.GLOBAL_RNG`: random number generator or seed # Operations @@ -834,6 +840,8 @@ Train the machine with `fit!(mach, rows=...)`. - `merge_purity_threshold=1.0`: (post-pruning) merge leaves having combined purity `>= merge_purity_threshold` +- `feature_importance`: method to use for computing feature importances + - `rng=Random.GLOBAL_RNG`: random number generator or seed @@ -916,6 +924,8 @@ Train the machine with `fit!(mach, rows=...)`. 
- `sampling_fraction=0.7` fraction of samples to train each tree on +- `feature_importance`: method to use for computing feature importances + - `rng=Random.GLOBAL_RNG`: random number generator or seed diff --git a/test/runtests.jl b/test/runtests.jl index 20c7c2d..cdb1a9e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -225,8 +225,6 @@ end fit!(m) rpt = MLJBase.report(m) fi = MLJBase.feature_importances(model, m.fitresult, rpt) - - println(fi) @test size(fi,1) == 3 end From bbf074acb7483c11bd1699e0b627002b3223e567 Mon Sep 17 00:00:00 2001 From: John Waczak Date: Wed, 10 Aug 2022 18:17:58 -0500 Subject: [PATCH 07/11] remove hash map in favor of Pair function --- src/MLJDecisionTreeInterface.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/MLJDecisionTreeInterface.jl b/src/MLJDecisionTreeInterface.jl index 406ee2f..28b862e 100644 --- a/src/MLJDecisionTreeInterface.jl +++ b/src/MLJDecisionTreeInterface.jl @@ -130,7 +130,7 @@ function MMI.feature_importances(m::DecisionTreeClassifier, fitresult, report) tree = fitresult[1] features = fitresult[end] fi = feature_importance_func(tree) - fi_pairs = collect(Dict(zip(features, fi))) + fi_pairs = Pair.(features, fi) # sort descending sort!(fi_pairs, by= x->-x[2]) @@ -205,7 +205,7 @@ function MMI.feature_importances(m::RandomForestClassifier, fitresult, report) forest = fitresult[1] features = report.features fi = feature_importance_func(forest) - fi_pairs = collect(Dict(zip(features, fi))) + fi_pairs = Pair.(features, fi) # sort descending sort!(fi_pairs, by= x->-x[2]) @@ -272,7 +272,7 @@ function MMI.feature_importances(m::AdaBoostStumpClassifier, fitresult, report) coefs = fitresult[2] features = report.features fi = feature_importance_func(stumps, coefs) - fi_pairs = collect(Dict(zip(features, fi))) + fi_pairs = Pair.(features, fi) # sort descending sort!(fi_pairs, by= x->-x[2]) @@ -342,7 +342,7 @@ function MMI.feature_importances(m::DecisionTreeRegressor, fitresult, report) tree = fitresult features = report.features fi = feature_importance_func(tree) - fi_pairs = collect(Dict(zip(features, fi))) + fi_pairs = Pair.(features, fi) # sort descending sort!(fi_pairs, by= x->-x[2]) @@ -409,7 +409,7 @@ function MMI.feature_importances(m::RandomForestRegressor, fitresult, report) forest = fitresult features = report.features fi = feature_importance_func(forest) - fi_pairs = collect(Dict(zip(features, fi))) + fi_pairs = Pair.(features, fi) # sort descending sort!(fi_pairs, by= x->-x[2]) From eae1f3fb68900d50e4886c999b8f3bb5cb2938cc Mon Sep 17 00:00:00 2001 From: John Waczak Date: Wed, 10 Aug 2022 21:19:14 -0500 Subject: [PATCH 08/11] apply normlize=true to importances --- src/MLJDecisionTreeInterface.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/MLJDecisionTreeInterface.jl b/src/MLJDecisionTreeInterface.jl index 28b862e..f1fdc44 100644 --- a/src/MLJDecisionTreeInterface.jl +++ b/src/MLJDecisionTreeInterface.jl @@ -129,7 +129,7 @@ function MMI.feature_importances(m::DecisionTreeClassifier, fitresult, report) tree = fitresult[1] features = fitresult[end] - fi = feature_importance_func(tree) + fi = feature_importance_func(tree, normalize=true) fi_pairs = Pair.(features, fi) # sort descending sort!(fi_pairs, by= x->-x[2]) @@ -204,7 +204,7 @@ function MMI.feature_importances(m::RandomForestClassifier, fitresult, report) forest = fitresult[1] features = report.features - fi = feature_importance_func(forest) + fi = feature_importance_func(forest, normalize=true) fi_pairs = 
Pair.(features, fi) # sort descending sort!(fi_pairs, by= x->-x[2]) @@ -271,7 +271,7 @@ function MMI.feature_importances(m::AdaBoostStumpClassifier, fitresult, report) stumps = fitresult[1] coefs = fitresult[2] features = report.features - fi = feature_importance_func(stumps, coefs) + fi = feature_importance_func(stumps, coefs, normalize=true) fi_pairs = Pair.(features, fi) # sort descending sort!(fi_pairs, by= x->-x[2]) @@ -341,7 +341,7 @@ function MMI.feature_importances(m::DecisionTreeRegressor, fitresult, report) tree = fitresult features = report.features - fi = feature_importance_func(tree) + fi = feature_importance_func(tree, normalize=true) fi_pairs = Pair.(features, fi) # sort descending sort!(fi_pairs, by= x->-x[2]) @@ -408,7 +408,7 @@ function MMI.feature_importances(m::RandomForestRegressor, fitresult, report) forest = fitresult features = report.features - fi = feature_importance_func(forest) + fi = feature_importance_func(forest, normalize=true) fi_pairs = Pair.(features, fi) # sort descending sort!(fi_pairs, by= x->-x[2]) From b8e136a756661a8eafde360ee6a7e9a8e0975b74 Mon Sep 17 00:00:00 2001 From: John Waczak Date: Wed, 10 Aug 2022 21:46:28 -0500 Subject: [PATCH 09/11] reduced code repetition --- src/MLJDecisionTreeInterface.jl | 96 +++++---------------------------- 1 file changed, 12 insertions(+), 84 deletions(-) diff --git a/src/MLJDecisionTreeInterface.jl b/src/MLJDecisionTreeInterface.jl index f1fdc44..acbc4b6 100644 --- a/src/MLJDecisionTreeInterface.jl +++ b/src/MLJDecisionTreeInterface.jl @@ -32,8 +32,7 @@ Base.show(stream::IO, c::TreePrinter) = # # DECISION TREE CLASSIFIER # The following meets the MLJ standard for a `Model` docstring and is -# created without the use of interpolation so it can be used a -# template for authors of other MLJ model interfaces. The other +# created without the use of interpolation so it can be used a # template for authors of other MLJ model interfaces. The other # doc-strings, defined later, are generated using the `doc_header` # utility to automatically generate the header, another option. 
MMI.@mlj_model mutable struct DecisionTreeClassifier <: MMI.Probabilistic @@ -49,9 +48,6 @@ MMI.@mlj_model mutable struct DecisionTreeClassifier <: MMI.Probabilistic rng::Union{AbstractRNG,Integer} = GLOBAL_RNG end - - - function MMI.fit(m::DecisionTreeClassifier, verbosity::Int, X, y) schema = Tables.schema(X) Xmatrix = MMI.matrix(X) @@ -119,25 +115,6 @@ end MMI.reports_feature_importances(::Type{<:DecisionTreeClassifier}) = true -function MMI.feature_importances(m::DecisionTreeClassifier, fitresult, report) - # generate feature importances for report - if m.feature_importance == :impurity - feature_importance_func = DT.impurity_importance - elseif m.feature_importance == :split - feature_importance_func = DT.split_importance - end - - tree = fitresult[1] - features = fitresult[end] - fi = feature_importance_func(tree, normalize=true) - fi_pairs = Pair.(features, fi) - # sort descending - sort!(fi_pairs, by= x->-x[2]) - - return fi_pairs -end - - # # RANDOM FOREST CLASSIFIER @@ -194,26 +171,6 @@ end MMI.reports_feature_importances(::Type{<:RandomForestClassifier}) = true -function MMI.feature_importances(m::RandomForestClassifier, fitresult, report) - # generate feature importances for report - if m.feature_importance == :impurity - feature_importance_func = DT.impurity_importance - elseif m.feature_importance == :split - feature_importance_func = DT.split_importance - end - - forest = fitresult[1] - features = report.features - fi = feature_importance_func(forest, normalize=true) - fi_pairs = Pair.(features, fi) - # sort descending - sort!(fi_pairs, by= x->-x[2]) - - return fi_pairs -end - - - # # ADA BOOST STUMP CLASSIFIER @@ -260,25 +217,6 @@ end MMI.reports_feature_importances(::Type{<:AdaBoostStumpClassifier}) = true -function MMI.feature_importances(m::AdaBoostStumpClassifier, fitresult, report) - # generate feature importances for report - if m.feature_importance == :impurity - feature_importance_func = DT.impurity_importance - elseif m.feature_importance == :split - feature_importance_func = DT.split_importance - end - - stumps = fitresult[1] - coefs = fitresult[2] - features = report.features - fi = feature_importance_func(stumps, coefs, normalize=true) - fi_pairs = Pair.(features, fi) - # sort descending - sort!(fi_pairs, by= x->-x[2]) - - return fi_pairs -end - # # DECISION TREE REGRESSOR @@ -331,24 +269,6 @@ end MMI.reports_feature_importances(::Type{<:DecisionTreeRegressor}) = true -function MMI.feature_importances(m::DecisionTreeRegressor, fitresult, report) - # generate feature importances for report - if m.feature_importance == :impurity - feature_importance_func = DT.impurity_importance - elseif m.feature_importance == :split - feature_importance_func = DT.split_importance - end - - tree = fitresult - features = report.features - fi = feature_importance_func(tree, normalize=true) - fi_pairs = Pair.(features, fi) - # sort descending - sort!(fi_pairs, by= x->-x[2]) - - return fi_pairs -end - # # RANDOM FOREST REGRESSOR @@ -398,7 +318,15 @@ end MMI.reports_feature_importances(::Type{<:RandomForestRegressor}) = true -function MMI.feature_importances(m::RandomForestRegressor, fitresult, report) + +# # Feature Importances + +# get actual arguments needed for importance calculation from various fitresults. 
+get_fitresult(m::Union{DecisionTreeClassifier, RandomForestClassifier}, fitresult) = (fitresult[1],) +get_fitresult(m::Union{DecisionTreeRegressor, RandomForestRegressor}, fitresult) = (fitresult,) +get_fitresult(m::AdaBoostStumpClassifier, fitresult)= (fitresult[1], fitresult[2]) + +function MMI.feature_importances(m::Union{DecisionTreeClassifier, RandomForestClassifier, AdaBoostStumpClassifier, DecisionTreeRegressor, RandomForestRegressor}, fitresult, report) # generate feature importances for report if m.feature_importance == :impurity feature_importance_func = DT.impurity_importance @@ -406,9 +334,9 @@ function MMI.feature_importances(m::RandomForestRegressor, fitresult, report) feature_importance_func = DT.split_importance end - forest = fitresult + mdl = get_fitresult(m, fitresult) features = report.features - fi = feature_importance_func(forest, normalize=true) + fi = feature_importance_func(mdl..., normalize=true) fi_pairs = Pair.(features, fi) # sort descending sort!(fi_pairs, by= x->-x[2]) From 75b8f7bd3536b850b5411fcfb8c6e8a99ebeed40 Mon Sep 17 00:00:00 2001 From: John Waczak Date: Thu, 11 Aug 2022 08:07:22 -0500 Subject: [PATCH 10/11] cleaned up comments and added new test for feature importance existence --- src/MLJDecisionTreeInterface.jl | 7 ------ test/runtests.jl | 41 +++++++++++++-------------------- 2 files changed, 16 insertions(+), 32 deletions(-) diff --git a/src/MLJDecisionTreeInterface.jl b/src/MLJDecisionTreeInterface.jl index acbc4b6..28e20f2 100644 --- a/src/MLJDecisionTreeInterface.jl +++ b/src/MLJDecisionTreeInterface.jl @@ -22,13 +22,6 @@ Base.show(stream::IO, c::TreePrinter) = print(stream, "TreePrinter object (call with display depth)") -# # add a new variable to struct a la -# const feature_importance_options = Dict(:impurity => DecisionTree.impurity_importance, -# :split => DecisionTree.split_importance, -# # :permutation => DecisionTree.permutation_importance, -# ) - - # # DECISION TREE CLASSIFIER # The following meets the MLJ standard for a `Model` docstring and is diff --git a/test/runtests.jl b/test/runtests.jl index cdb1a9e..14553f2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -110,35 +110,10 @@ X, y = MLJBase.make_blobs(100, 3; rng=stable_rng()) m = machine(rfc, X, y) fit!(m) @test accuracy(predict_mode(m, X), y) > 0.95 -# check feature_importances -# rpt = MLJBase.report(m) -# @test size(rpt.feature_importances, 1) == 3 # make sure we get an importance for each feature - m = machine(abs, X, y) fit!(m) @test accuracy(predict_mode(m, X), y) > 0.95 -# check feature_importances -# rpt = MLJBase.report(m) -# @test size(rpt.feature_importances, 1) == 3 # make sure we get an importance for each feature - - - - -# test DecisionTreeRegressor and RandomForestRegressor -# X, y = make_regression(100,3; rng=stable_rng()); -# dtr = DecisionTreeRegressor(rng=stable_rng()) -# rfr = RandomForestRegressor(rng=stable_rng()) - -# m = machine(dtr, X, y) -# fit!(m) -# rpt = MLJBase.report(m) -# @test size(rpt.feature_importances, 1) == 3 # make sure we get an importance for each feature - -# m = machine(rfr, X, y) -# fit!(m) -# rpt = MLJBase.report(m) -# @test size(rpt.feature_importances, 1) == 3 # make sure we get an importance for each feature X, y = MLJBase.make_regression(rng=stable_rng()) rfr = RandomForestRegressor(rng=stable_rng()) @@ -183,7 +158,23 @@ end end +@testset "feature importance defined" begin + for model ∈ [ + DecisionTreeClassifier(), + RandomForestClassifier(), + AdaBoostStumpClassifier(), + DecisionTreeRegressor(), + 
RandomForestRegressor(), + ] + + @test reports_feature_importances(model) == true + end +end + + + @testset "impurity importance" begin + X, y = MLJBase.make_blobs(100, 3; rng=stable_rng()) for model ∈ [ From 0b3b19fdfcb028e7e8476886f537a0f3d6181566 Mon Sep 17 00:00:00 2001 From: John Waczak Date: Thu, 11 Aug 2022 08:39:45 -0500 Subject: [PATCH 11/11] updated docstrings with options --- src/MLJDecisionTreeInterface.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/MLJDecisionTreeInterface.jl b/src/MLJDecisionTreeInterface.jl index 28e20f2..35ce4d4 100644 --- a/src/MLJDecisionTreeInterface.jl +++ b/src/MLJDecisionTreeInterface.jl @@ -456,7 +456,7 @@ Train the machine using `fit!(mach, rows=...)`. - `display_depth=5`: max depth to show when displaying the tree -- `feature_importance`: method to use for computing feature importances +- `feature_importance`: method to use for computing feature importances. One of `(:impurity, :split)` - `rng=Random.GLOBAL_RNG`: random number generator or seed @@ -591,7 +591,7 @@ Train the machine with `fit!(mach, rows=...)`. - `sampling_fraction=0.7` fraction of samples to train each tree on -- `feature_importance`: method to use for computing feature importances +- `feature_importance`: method to use for computing feature importances. One of `(:impurity, :split)` - `rng=Random.GLOBAL_RNG`: random number generator or seed @@ -668,7 +668,7 @@ Train the machine with `fit!(mach, rows=...)`. - `n_iter=10`: number of iterations of AdaBoost -- `feature_importance`: method to use for computing feature importances +- `feature_importance`: method to use for computing feature importances. One of `(:impurity, :split)` - `rng=Random.GLOBAL_RNG`: random number generator or seed @@ -761,7 +761,7 @@ Train the machine with `fit!(mach, rows=...)`. - `merge_purity_threshold=1.0`: (post-pruning) merge leaves having combined purity `>= merge_purity_threshold` -- `feature_importance`: method to use for computing feature importances +- `feature_importance`: method to use for computing feature importances. One of `(:impurity, :split)` - `rng=Random.GLOBAL_RNG`: random number generator or seed @@ -845,7 +845,7 @@ Train the machine with `fit!(mach, rows=...)`. - `sampling_fraction=0.7` fraction of samples to train each tree on -- `feature_importance`: method to use for computing feature importances +- `feature_importance`: method to use for computing feature importances. One of `(:impurity, :split)` - `rng=Random.GLOBAL_RNG`: random number generator or seed
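
For reference, a minimal usage sketch of the feature-importance API introduced by this patch series, adapted from the calling pattern in the new tests (test/runtests.jl). The blob-generation call and the choice of `RandomForestClassifier` are illustrative only; any of the five models touched by these patches, and any table `X` and target `y` they accept, work the same way.

    using MLJBase, MLJDecisionTreeInterface

    X, y = MLJBase.make_blobs(100, 3)   # toy classification data with 3 features
    model = MLJDecisionTreeInterface.RandomForestClassifier(feature_importance=:impurity)

    m = machine(model, X, y)
    fit!(m)

    # Same pattern as in the tests: importances come back as a vector of
    # `feature => importance` pairs, normalized and sorted in descending order.
    rpt = MLJBase.report(m)
    fi = MLJBase.feature_importances(model, m.fitresult, rpt)

Passing `feature_importance=:split` instead selects `DT.split_importance` in place of `DT.impurity_importance`, as wired up in the shared `MMI.feature_importances` method added in PATCH 09/11.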