Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Graphviz #52

Merged
merged 6 commits into from
Apr 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions stree/Splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,28 @@ def make_predictor(self):
except IndexError:
self._class = None

def graph(self):
    """
    Build the graphviz statements that describe this node.

    A leaf produces one filled-box node statement showing its class,
    its impurity and the distinct labels / counts found in ``self._y``.
    Any other node produces a node statement showing the number of
    features, the distinct labels / counts and the total sample count,
    followed by one edge to ``get_up()`` and one to ``get_down()``.

    Returns
    -------
    str
        the graphviz statements, each one terminated by a newline
    """
    # The node name is derived from id(), so it differs between objects
    name = f"N{id(self)}"
    labels, counts = np.unique(self._y, return_counts=True)
    if self.is_leaf():
        return (
            f'{name} [shape=box style=filled label="'
            f"class={self._class} impurity={self._impurity:.3f} "
            f'classes={labels} samples={counts}"];\n'
        )
    statements = [
        f'{name} [label="#features={len(self._features)} '
        f"classes={labels} samples={counts} "
        f'({sum(counts)})" fontcolor=black];\n',
        f"{name} -> N{id(self.get_up())} [color=black];\n",
        f"{name} -> N{id(self.get_down())} [color=black];\n",
    ]
    return "".join(statements)

def __str__(self) -> str:
count_values = np.unique(self._y, return_counts=True)
if self.is_leaf():
Expand Down
17 changes: 17 additions & 0 deletions stree/Strees.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,23 @@ def __iter__(self) -> Siterator:
tree = None
return Siterator(tree)

def graph(self, title="") -> str:
    """Graphviz code representing the tree

    Parameters
    ----------
    title : str, optional
        text placed after "STree" in the graph label, by default ""

    Returns
    -------
    str
        graphviz code
    """
    header = (
        "digraph STree {\nlabel=<STree "
        f"{title}>\nfontsize=30\nfontcolor=blue\nlabelloc=t\n"
    )
    # Each node visited by iterating the tree contributes its own
    # statements
    body = "".join(node.graph() for node in self)
    return header + body + "}\n"

def __str__(self) -> str:
"""String representation of the tree

Expand Down
2 changes: 1 addition & 1 deletion stree/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "1.2.3"
__version__ = "1.2.4"
59 changes: 59 additions & 0 deletions stree/tests/Stree_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,7 @@ def test_predict_feature_dimensions(self):

# Tests of score
def test_score_binary(self):
"""Check score for binary classification."""
X, y = load_dataset(self._random_state)
accuracies = [
0.9506666666666667,
Expand All @@ -380,6 +381,7 @@ def test_score_binary(self):
self.assertAlmostEqual(accuracy_expected, accuracy_score)

def test_score_max_features(self):
"""Check score using max_features."""
X, y = load_dataset(self._random_state)
clf = Stree(
kernel="liblinear",
Expand All @@ -391,6 +393,7 @@ def test_score_max_features(self):
self.assertAlmostEqual(0.9453333333333334, clf.score(X, y))

def test_bogus_splitter_parameter(self):
"""Check that bogus splitter parameter raises exception."""
clf = Stree(splitter="duck")
with self.assertRaises(ValueError):
clf.fit(*load_dataset())
Expand Down Expand Up @@ -446,6 +449,7 @@ def test_multiclass_classifier_integrity(self):
self.assertListEqual([47], resdn[1].tolist())

def test_score_multiclass_rbf(self):
"""Test score for multiclass classification with rbf kernel."""
X, y = load_dataset(
random_state=self._random_state,
n_classes=3,
Expand All @@ -463,6 +467,7 @@ def test_score_multiclass_rbf(self):
self.assertEqual(1.0, clf2.fit(X, y).score(X, y))

def test_score_multiclass_poly(self):
"""Test score for multiclass classification with poly kernel."""
X, y = load_dataset(
random_state=self._random_state,
n_classes=3,
Expand All @@ -484,6 +489,7 @@ def test_score_multiclass_poly(self):
self.assertEqual(1.0, clf2.fit(X, y).score(X, y))

def test_score_multiclass_liblinear(self):
"""Test score for multiclass classification with liblinear kernel."""
X, y = load_dataset(
random_state=self._random_state,
n_classes=3,
Expand All @@ -509,6 +515,7 @@ def test_score_multiclass_liblinear(self):
self.assertEqual(1.0, clf2.fit(X, y).score(X, y))

def test_score_multiclass_sigmoid(self):
"""Test score for multiclass classification with sigmoid kernel."""
X, y = load_dataset(
random_state=self._random_state,
n_classes=3,
Expand All @@ -529,6 +536,7 @@ def test_score_multiclass_sigmoid(self):
self.assertEqual(0.9662921348314607, clf2.fit(X, y).score(X, y))

def test_score_multiclass_linear(self):
"""Test score for multiclass classification with linear kernel."""
warnings.filterwarnings("ignore", category=ConvergenceWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)
X, y = load_dataset(
Expand Down Expand Up @@ -556,11 +564,13 @@ def test_score_multiclass_linear(self):
self.assertEqual(1.0, clf2.fit(X, y).score(X, y))

def test_zero_all_sample_weights(self):
    """Fitting with every sample weight set to zero raises ValueError."""
    X, y = load_dataset(self._random_state)
    with self.assertRaises(ValueError):
        model = Stree()
        model.fit(X, y, np.zeros(len(y)))

def test_mask_samples_weighted_zero(self):
"""Check that the weighted zero samples are masked."""
X = np.array(
[
[1, 1],
Expand Down Expand Up @@ -588,6 +598,7 @@ def test_mask_samples_weighted_zero(self):
self.assertEqual(model2.score(X, y, w), 1)

def test_depth(self):
"""Check depth of the tree."""
X, y = load_dataset(
random_state=self._random_state,
n_classes=3,
Expand All @@ -603,6 +614,7 @@ def test_depth(self):
self.assertEqual(4, clf.depth_)

def test_nodes_leaves(self):
"""Check number of nodes and leaves."""
X, y = load_dataset(
random_state=self._random_state,
n_classes=3,
Expand All @@ -622,6 +634,7 @@ def test_nodes_leaves(self):
self.assertEqual(6, leaves)

def test_nodes_leaves_artificial(self):
"""Check leaves of artificial dataset."""
n1 = Snode(None, [1, 2, 3, 4], [1, 0, 1, 1], [], 0.0, "test1")
n2 = Snode(None, [1, 2, 3, 4], [1, 0, 1, 1], [], 0.0, "test2")
n3 = Snode(None, [1, 2, 3, 4], [1, 0, 1, 1], [], 0.0, "test3")
Expand All @@ -640,12 +653,14 @@ def test_nodes_leaves_artificial(self):
self.assertEqual(2, leaves)

def test_bogus_multiclass_strategy(self):
    """An unknown multiclass_strategy makes fit() raise ValueError."""
    model = Stree(multiclass_strategy="other")
    features, labels = load_wine(return_X_y=True)
    # Callable form of assertRaises: fit is invoked with the wine data
    self.assertRaises(ValueError, model.fit, features, labels)

def test_multiclass_strategy(self):
"""Check multiclass strategy."""
X, y = load_wine(return_X_y=True)
clf_o = Stree(multiclass_strategy="ovo")
clf_r = Stree(multiclass_strategy="ovr")
Expand All @@ -655,6 +670,7 @@ def test_multiclass_strategy(self):
self.assertEqual(0.9269662921348315, score_r)

def test_incompatible_hyperparameters(self):
"""Check incompatible hyperparameters."""
X, y = load_wine(return_X_y=True)
clf = Stree(kernel="liblinear", multiclass_strategy="ovo")
with self.assertRaises(ValueError):
Expand All @@ -664,5 +680,48 @@ def test_incompatible_hyperparameters(self):
clf.fit(X, y)

def test_version(self):
    """The version reported by Stree equals the package __version__."""
    reported = Stree().version()
    self.assertEqual(__version__, reported)

def test_graph(self):
    """Check graphviz representation of the tree."""
    head = (
        "digraph STree {\nlabel=<STree >\nfontsize=30\n"
        "fontcolor=blue\nlabelloc=t\n"
    )
    tail = (
        ' [shape=box style=filled label="class=1 impurity=0.000 '
        'classes=[1] samples=[1]"];\n}\n'
    )
    X, y = load_wine(return_X_y=True)
    model = Stree(random_state=self._random_state)
    # Before fitting, the output is just the header and closing brace
    self.assertEqual(model.graph(), head + "}\n")
    model.fit(X, y)
    dot = model.graph()
    # Node names are built from id() values, so only the fixed header
    # and the attributes of the last statement are compared literally
    self.assertEqual(dot[: len(head)], head)
    self.assertEqual(dot[-len(tail):], tail)

def test_graph_title(self):
    """Check graphviz representation of the tree when a title is given."""
    title = "Sample title"
    head = (
        "digraph STree {\nlabel=<STree Sample title>\nfontsize=30\n"
        "fontcolor=blue\nlabelloc=t\n"
    )
    tail = (
        ' [shape=box style=filled label="class=1 impurity=0.000 '
        'classes=[1] samples=[1]"];\n}\n'
    )
    X, y = load_wine(return_X_y=True)
    model = Stree(random_state=self._random_state)
    # Before fitting, the output is just the titled header and the
    # closing brace
    self.assertEqual(model.graph(title), head + "}\n")
    model.fit(X, y)
    dot = model.graph(title)
    # Node names are built from id() values, so only the fixed header
    # and the attributes of the last statement are compared literally
    self.assertEqual(dot[: len(head)], head)
    self.assertEqual(dot[-len(tail):], tail)