diff --git a/Project.toml b/Project.toml index 9a689a5a..c29932ca 100644 --- a/Project.toml +++ b/Project.toml @@ -2,7 +2,7 @@ name = "DecisionTree" uuid = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb" license = "MIT" desc = "Julia implementation of Decision Tree (CART) and Random Forest algorithms" -version = "0.12.0" +version = "0.12.1" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" diff --git a/src/abstract_trees.jl b/src/abstract_trees.jl index 7921f7ef..6b2ac672 100644 --- a/src/abstract_trees.jl +++ b/src/abstract_trees.jl @@ -28,6 +28,11 @@ apart from the two points mentioned. In analogy to the type definitions of `DecisionTree`, the generic type `S` is the type of the feature values used within a node as a threshold for the splits between its children and `T` is the type of the classes given (these might be ids or labels). + + !!! note + You may only add lacking class labels. It's not possible to overwrite existing labels + with this mechanism. In case you want add class labels, the generic type `T` must + be a subtype of `Integer`. """ struct InfoNode{S, T} <: AbstractTrees.AbstractNode{DecisionTree.Node{S,T}} node :: DecisionTree.Node{S, T} @@ -89,8 +94,8 @@ AbstractTrees.children(node::InfoNode) = ( AbstractTrees.children(node::InfoLeaf) = () """ - printnode(io::IO, node::InfoNode) - printnode(io::IO, leaf::InfoLeaf) + printnode(io::IO, node::InfoNode; sigdigits=4) + printnode(io::IO, leaf::InfoLeaf; sigdigits=4) Write a printable representation of `node` or `leaf` to output-stream `io`. @@ -108,23 +113,28 @@ For the condition of the form `feature < value` which gets printed in the `print variant for `InfoNode`, the left subtree is the 'yes-branch' and the right subtree accordingly the 'no-branch'. `AbstractTrees.print_tree` outputs the left subtree first and then below the right subtree. + +`value` gets rounded to `sigdigits` significant digits. """ -function AbstractTrees.printnode(io::IO, node::InfoNode) +function AbstractTrees.printnode(io::IO, node::InfoNode; sigdigits=4) + featval = round(node.node.featval; sigdigits) if :featurenames ∈ keys(node.info) - print(io, node.info.featurenames[node.node.featid], " < ", node.node.featval) + print(io, node.info.featurenames[node.node.featid], " < ", featval) else - print(io, "Feature: ", node.node.featid, " < ", node.node.featval) + print(io, "Feature: ", node.node.featid, " < ", featval) end end -function AbstractTrees.printnode(io::IO, leaf::InfoLeaf) +function AbstractTrees.printnode(io::IO, leaf::InfoLeaf; sigdigits=4) dt_leaf = leaf.leaf matches = findall(dt_leaf.values .== dt_leaf.majority) match_count = length(matches) val_count = length(dt_leaf.values) if :classlabels ∈ keys(leaf.info) + @assert dt_leaf.majority isa Integer "classes must be represented as Integers" print(io, leaf.info.classlabels[dt_leaf.majority], " ($match_count/$val_count)") else - print(io, "Class: ", dt_leaf.majority, " ($match_count/$val_count)") + print(io, dt_leaf.majority isa Integer ? "Class: " : "", + dt_leaf.majority, " ($match_count/$val_count)") end end diff --git a/test/miscellaneous/abstract_trees_test.jl b/test/miscellaneous/abstract_trees_test.jl index a1bdd141..b17bedcb 100644 --- a/test/miscellaneous/abstract_trees_test.jl +++ b/test/miscellaneous/abstract_trees_test.jl @@ -81,4 +81,23 @@ end traverse_tree(leaf::InfoLeaf) = nothing traverse_tree(wrapped_tree) +end + +@testset "abstract_trees - test misuse" begin + + @info("Test misuse of `classlabel` information") + + @info("Create test data - a decision tree based on the iris data set") + features, labels = load_data("iris") + features = float.(features) + labels = string.(labels) + model = DecisionTreeClassifier() + fit!(model, features, labels) + + @info("Try to replace the exisitng class labels") + class_labels = unique(labels) + dtree = model.root.node + wt = DecisionTree.wrap(dtree, (classlabels = class_labels,)) + @test_throws AssertionError AbstractTrees.print_tree(wt) + end \ No newline at end of file