From 8be26e14332c5fc8b7c72cf3e145dd80689731ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Roland=20Sch=C3=A4tzle?= <80126696+roland-KA@users.noreply.github.com> Date: Sat, 3 Dec 2022 16:37:44 +0100 Subject: [PATCH 1/9] Round thresholds in `printnode` --- src/abstract_trees.jl | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/abstract_trees.jl b/src/abstract_trees.jl index 7921f7ef..7b0b08a4 100644 --- a/src/abstract_trees.jl +++ b/src/abstract_trees.jl @@ -108,16 +108,19 @@ For the condition of the form `feature < value` which gets printed in the `print variant for `InfoNode`, the left subtree is the 'yes-branch' and the right subtree accordingly the 'no-branch'. `AbstractTrees.print_tree` outputs the left subtree first and then below the right subtree. + +`value` gets rounded to `sigdigits` significant digits. """ -function AbstractTrees.printnode(io::IO, node::InfoNode) +function AbstractTrees.printnode(io::IO, node::InfoNode; sigdigits=4) + featval = round(node.node.featval; sigdigits) if :featurenames ∈ keys(node.info) - print(io, node.info.featurenames[node.node.featid], " < ", node.node.featval) + print(io, node.info.featurenames[node.node.featid], " < ", featval) else - print(io, "Feature: ", node.node.featid, " < ", node.node.featval) + print(io, "Feature: ", node.node.featid, " < ", featval) end end -function AbstractTrees.printnode(io::IO, leaf::InfoLeaf) +function AbstractTrees.printnode(io::IO, leaf::InfoLeaf; sigdigits=4) dt_leaf = leaf.leaf matches = findall(dt_leaf.values .== dt_leaf.majority) match_count = length(matches) From 94b1110da7237a82cf5b1c93b828974074bdc5aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Roland=20Sch=C3=A4tzle?= <80126696+roland-KA@users.noreply.github.com> Date: Sat, 3 Dec 2022 16:54:40 +0100 Subject: [PATCH 2/9] Clarify use of generic type `T` in nodes/leaves --- src/abstract_trees.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/abstract_trees.jl b/src/abstract_trees.jl index 
7b0b08a4..a411cab0 100644 --- a/src/abstract_trees.jl +++ b/src/abstract_trees.jl @@ -28,6 +28,10 @@ apart from the two points mentioned. In analogy to the type definitions of `DecisionTree`, the generic type `S` is the type of the feature values used within a node as a threshold for the splits between its children and `T` is the type of the classes given (these might be ids or labels). + +Please note: You may only add lacking class labels. It's not possible to overwrite +existing labels with this mechanism. In case you want add class labels, +the generic type `T` must be a subtype of `Integer`. """ struct InfoNode{S, T} <: AbstractTrees.AbstractNode{DecisionTree.Node{S,T}} node :: DecisionTree.Node{S, T} @@ -126,6 +130,7 @@ function AbstractTrees.printnode(io::IO, leaf::InfoLeaf; sigdigits=4) match_count = length(matches) val_count = length(dt_leaf.values) if :classlabels ∈ keys(leaf.info) + @assert typeof(dt_leaf.majority) <: Integer "classes must be represented as Integers" print(io, leaf.info.classlabels[dt_leaf.majority], " ($match_count/$val_count)") else print(io, "Class: ", dt_leaf.majority, " ($match_count/$val_count)") From 7a95a638ddc47fed32fdc5e8042934e72a418506 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Roland=20Sch=C3=A4tzle?= <80126696+roland-KA@users.noreply.github.com> Date: Sat, 3 Dec 2022 17:02:53 +0100 Subject: [PATCH 3/9] Beautify printing of leafs in `printnode` --- src/abstract_trees.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/abstract_trees.jl b/src/abstract_trees.jl index a411cab0..3d725767 100644 --- a/src/abstract_trees.jl +++ b/src/abstract_trees.jl @@ -133,6 +133,7 @@ function AbstractTrees.printnode(io::IO, leaf::InfoLeaf; sigdigits=4) @assert typeof(dt_leaf.majority) <: Integer "classes must be represented as Integers" print(io, leaf.info.classlabels[dt_leaf.majority], " ($match_count/$val_count)") else - print(io, "Class: ", dt_leaf.majority, " ($match_count/$val_count)") + print(io, 
typeof(dt_leaf.majority) <: Integer ? "Class: " : "", + dt_leaf.majority, " ($match_count/$val_count)") end end From 93bf743765494c7baa4cafda5085eaf49b6a65aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Roland=20Sch=C3=A4tzle?= <80126696+roland-KA@users.noreply.github.com> Date: Sat, 3 Dec 2022 19:35:23 +0100 Subject: [PATCH 4/9] Update src/abstract_trees.jl Co-authored-by: Rik Huijzer --- src/abstract_trees.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/abstract_trees.jl b/src/abstract_trees.jl index 3d725767..9a7427f8 100644 --- a/src/abstract_trees.jl +++ b/src/abstract_trees.jl @@ -130,7 +130,7 @@ function AbstractTrees.printnode(io::IO, leaf::InfoLeaf; sigdigits=4) match_count = length(matches) val_count = length(dt_leaf.values) if :classlabels ∈ keys(leaf.info) - @assert typeof(dt_leaf.majority) <: Integer "classes must be represented as Integers" + @assert dt_leaf.majority isa Integer "classes must be represented as Integers" print(io, leaf.info.classlabels[dt_leaf.majority], " ($match_count/$val_count)") else print(io, typeof(dt_leaf.majority) <: Integer ? "Class: " : "", From 3e829d64507f8104aa1dc16f338f49201aaea847 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Roland=20Sch=C3=A4tzle?= <80126696+roland-KA@users.noreply.github.com> Date: Sat, 3 Dec 2022 19:35:33 +0100 Subject: [PATCH 5/9] Update src/abstract_trees.jl Co-authored-by: Rik Huijzer --- src/abstract_trees.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/abstract_trees.jl b/src/abstract_trees.jl index 9a7427f8..933c7e96 100644 --- a/src/abstract_trees.jl +++ b/src/abstract_trees.jl @@ -133,7 +133,7 @@ function AbstractTrees.printnode(io::IO, leaf::InfoLeaf; sigdigits=4) @assert dt_leaf.majority isa Integer "classes must be represented as Integers" print(io, leaf.info.classlabels[dt_leaf.majority], " ($match_count/$val_count)") else - print(io, typeof(dt_leaf.majority) <: Integer ? "Class: " : "", + print(io, dt_leaf.majority isa Integer ? 
"Class: " : "", dt_leaf.majority, " ($match_count/$val_count)") end end From b667b920359f589dc1b67dec3a88a28dc5345bcf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Roland=20Sch=C3=A4tzle?= <80126696+roland-KA@users.noreply.github.com> Date: Sat, 3 Dec 2022 19:36:34 +0100 Subject: [PATCH 6/9] Update src/abstract_trees.jl Co-authored-by: Rik Huijzer --- src/abstract_trees.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/abstract_trees.jl b/src/abstract_trees.jl index 933c7e96..e0d36c23 100644 --- a/src/abstract_trees.jl +++ b/src/abstract_trees.jl @@ -29,9 +29,10 @@ In analogy to the type definitions of `DecisionTree`, the generic type `S` is the type of the feature values used within a node as a threshold for the splits between its children and `T` is the type of the classes given (these might be ids or labels). -Please note: You may only add lacking class labels. It's not possible to overwrite -existing labels with this mechanism. In case you want add class labels, -the generic type `T` must be a subtype of `Integer`. + !!! note + You may only add lacking class labels. It's not possible to overwrite existing labels + with this mechanism. In case you want to add class labels, the generic type `T` must + be a subtype of `Integer`. 
""" struct InfoNode{S, T} <: AbstractTrees.AbstractNode{DecisionTree.Node{S,T}} node :: DecisionTree.Node{S, T} From 42fd264408684e9d9a91e79cdb932f3f68b18150 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Roland=20Sch=C3=A4tzle?= <80126696+roland-KA@users.noreply.github.com> Date: Mon, 5 Dec 2022 12:19:25 +0100 Subject: [PATCH 7/9] Add `sigdigits` kw argument to docstring --- src/abstract_trees.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/abstract_trees.jl b/src/abstract_trees.jl index e0d36c23..6b2ac672 100644 --- a/src/abstract_trees.jl +++ b/src/abstract_trees.jl @@ -94,8 +94,8 @@ AbstractTrees.children(node::InfoNode) = ( AbstractTrees.children(node::InfoLeaf) = () """ - printnode(io::IO, node::InfoNode) - printnode(io::IO, leaf::InfoLeaf) + printnode(io::IO, node::InfoNode; sigdigits=4) + printnode(io::IO, leaf::InfoLeaf; sigdigits=4) Write a printable representation of `node` or `leaf` to output-stream `io`. From d7fccfa2621b81ea9d8885d0c720b11884a73451 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Roland=20Sch=C3=A4tzle?= <80126696+roland-KA@users.noreply.github.com> Date: Tue, 6 Dec 2022 11:52:34 +0100 Subject: [PATCH 8/9] Add test of misuse of `wrap` function --- test/miscellaneous/abstract_trees_test.jl | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/test/miscellaneous/abstract_trees_test.jl b/test/miscellaneous/abstract_trees_test.jl index a1bdd141..b17bedcb 100644 --- a/test/miscellaneous/abstract_trees_test.jl +++ b/test/miscellaneous/abstract_trees_test.jl @@ -81,4 +81,23 @@ end traverse_tree(leaf::InfoLeaf) = nothing traverse_tree(wrapped_tree) +end + +@testset "abstract_trees - test misuse" begin + + @info("Test misuse of `classlabel` information") + + @info("Create test data - a decision tree based on the iris data set") + features, labels = load_data("iris") + features = float.(features) + labels = string.(labels) + model = DecisionTreeClassifier() + fit!(model, features, labels) + + @info("Try to 
replace the existing class labels") + class_labels = unique(labels) + dtree = model.root.node + wt = DecisionTree.wrap(dtree, (classlabels = class_labels,)) + @test_throws AssertionError AbstractTrees.print_tree(wt) + end \ No newline at end of file From 9062bd15868715d8f40a10ac3655e0d73e523236 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Wed, 7 Dec 2022 10:25:59 +1300 Subject: [PATCH 9/9] bump 0.12.1 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 9a689a5a..c29932ca 100644 --- a/Project.toml +++ b/Project.toml @@ -2,7 +2,7 @@ name = "DecisionTree" uuid = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb" license = "MIT" desc = "Julia implementation of Decision Tree (CART) and Random Forest algorithms" -version = "0.12.0" +version = "0.12.1" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"