Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ name = "DecisionTree"
uuid = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb"
license = "MIT"
desc = "Julia implementation of Decision Tree (CART) and Random Forest algorithms"
version = "0.12.0"
version = "0.12.1"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Expand Down
24 changes: 17 additions & 7 deletions src/abstract_trees.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ apart from the two points mentioned.
In analogy to the type definitions of `DecisionTree`, the generic type `S` is
the type of the feature values used within a node as a threshold for the splits
between its children and `T` is the type of the classes given (these might be ids or labels).

!!! note
    You may only add missing class labels; it is not possible to overwrite existing
    labels with this mechanism. If you want to add class labels, the generic type `T`
    must be a subtype of `Integer`.
"""
struct InfoNode{S, T} <: AbstractTrees.AbstractNode{DecisionTree.Node{S,T}}
node :: DecisionTree.Node{S, T}
Expand Down Expand Up @@ -89,8 +94,8 @@ AbstractTrees.children(node::InfoNode) = (
AbstractTrees.children(node::InfoLeaf) = ()

"""
printnode(io::IO, node::InfoNode)
printnode(io::IO, leaf::InfoLeaf)
printnode(io::IO, node::InfoNode; sigdigits=4)
printnode(io::IO, leaf::InfoLeaf; sigdigits=4)

Write a printable representation of `node` or `leaf` to output-stream `io`.

Expand All @@ -108,23 +113,28 @@ For the condition of the form `feature < value` which gets printed in the `print
variant for `InfoNode`, the left subtree is the 'yes-branch' and the right subtree
accordingly the 'no-branch'. `AbstractTrees.print_tree` outputs the left subtree first
and then below the right subtree.

`value` gets rounded to `sigdigits` significant digits.
"""
function AbstractTrees.printnode(io::IO, node::InfoNode; sigdigits=4)
    # Prints the split condition of an inner node as `<feature> < <threshold>`.
    #   io        - output stream the text is written to
    #   node      - wrapped node; `node.node.featid` / `node.node.featval` hold the
    #               split feature index and threshold, `node.info` optional metadata
    #   sigdigits - number of significant digits the threshold is rounded to
    featid = node.node.featid
    rawval = node.node.featval
    # Only real-valued thresholds can be rounded; any other threshold type is
    # printed unchanged instead of failing inside `round`.
    featval = rawval isa Real ? round(rawval; sigdigits) : rawval
    if :featurenames ∈ keys(node.info)
        # Caller supplied names: show the name instead of the numeric feature id.
        print(io, node.info.featurenames[featid], " < ", featval)
    else
        print(io, "Feature: ", featid, " < ", featval)
    end
end

function AbstractTrees.printnode(io::IO, leaf::InfoLeaf; sigdigits=4)
    # Prints a leaf as its majority class followed by `(matches/total)`, where
    # `matches` is the number of entries in the leaf's `values` equal to the majority.
    #   io        - output stream the text is written to
    #   leaf      - wrapped leaf; `leaf.leaf` is the underlying leaf, `leaf.info`
    #               optional metadata
    #   sigdigits - unused here; accepted so both `printnode` methods share the
    #               same keyword interface
    dt_leaf = leaf.leaf
    majority = dt_leaf.majority
    match_count = count(==(majority), dt_leaf.values)
    val_count = length(dt_leaf.values)
    if :classlabels ∈ keys(leaf.info)
        # Labels are looked up by index, so the majority class must be an Integer.
        # The AssertionError is part of the tested behaviour ("test misuse" testset).
        @assert majority isa Integer "classes must be represented as Integers"
        print(io, leaf.info.classlabels[majority], " ($match_count/$val_count)")
    else
        # A non-Integer majority is already a readable label, so the "Class: "
        # prefix is only added for numeric class ids.
        print(io, majority isa Integer ? "Class: " : "", majority,
              " ($match_count/$val_count)")
    end
end
19 changes: 19 additions & 0 deletions test/miscellaneous/abstract_trees_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,4 +81,23 @@ end
traverse_tree(leaf::InfoLeaf) = nothing

traverse_tree(wrapped_tree)
end

@testset "abstract_trees - test misuse" begin
    # Supplying `classlabels` for a tree whose classes are strings (not Integers)
    # must be rejected with an AssertionError when the tree is printed.
    @info("Test misuse of `classlabel` information")

    @info("Create test data - a decision tree based on the iris data set")
    raw_features, raw_labels = load_data("iris")
    X = float.(raw_features)
    y = string.(raw_labels)
    classifier = DecisionTreeClassifier()
    fit!(classifier, X, y)

    @info("Try to replace the exisitng class labels")
    wrapped = DecisionTree.wrap(classifier.root.node, (classlabels = unique(y),))
    @test_throws AssertionError AbstractTrees.print_tree(wrapped)
end