In [1]:
import pandas as pd
import numpy as np
from sklearn.datasets import load_iris
from sklearn import tree
from sklearn.inspection import DecisionBoundaryDisplay
from sklearn.tree import DecisionTreeClassifier

import matplotlib.pyplot as plt
import warnings
# Silence warnings emitted from matplotlib submodules. The pattern is a raw
# string: "\." inside a plain literal is an invalid escape sequence
# (SyntaxWarning on Python 3.12+).
warnings.filterwarnings("ignore", module=r"matplotlib\..*")

# Load the iris features/labels and fit a single decision tree on all of them.
# random_state is fixed so the fitted tree -- and every internal array shown
# below -- is reproducible across runs.
iris = load_iris()
X = iris.data
y = iris.target
clf = DecisionTreeClassifier(random_state=0).fit(X, y)

## Decision Tree Classifier Internals

In [2]:
# Class labels seen during fit (the iris targets 0, 1, 2).
clf.classes_

array([0, 1, 2])

In [3]:
# Impurity-based importance of each feature, one value per column of X;
# the values are normalized to sum to 1.
clf.feature_importances_

array([0.        , 0.01333333, 0.06405596, 0.92261071])

In [4]:
# decision_path is a method, not an attribute: calling it returns a sparse
# (n_samples, n_nodes) indicator matrix whose entry [i, j] is 1 when sample i
# passes through node j. Densify a few rows (0, 50, 100 -- iris is stored
# class-by-class, so these should be one sample per class) to see the paths.
clf.decision_path(X[[0, 50, 100]]).toarray()

### Tree representation

In [5]:
# Low-level Tree object; the cells below read its per-node arrays.
clf.tree_

<sklearn.tree._tree.Tree at 0x7fcd5b4610a0>

Number of nodes:

In [6]:
# Total number of nodes (internal nodes and leaves) in the fitted tree.
clf.tree_.node_count

17

Id of each node's left child ("-1" for leaf nodes):

In [7]:
# Id of each node's left child; -1 marks a leaf.
clf.tree_.children_left

array([ 1, -1,  3,  4,  5, -1, -1,  8, -1, 10, -1, -1, 13, 14, -1, -1, -1],
      dtype=int64)

Id of each node's right child ("-1" for leaf nodes):

In [8]:
# Id of each node's right child; -1 marks a leaf.
clf.tree_.children_right

array([ 2, -1, 12,  7,  6, -1, -1,  9, -1, 11, -1, -1, 16, 15, -1, -1, -1],
      dtype=int64)

Index of the feature used for the split at each node ("-2" means it's a leaf node):

In [9]:
# Column of X each internal node splits on; -2 marks a leaf.
clf.tree_.feature

array([ 3, -2,  3,  2,  3, -2, -2,  3, -2,  2, -2, -2,  2,  1, -2, -2, -2],
      dtype=int64)

Threshold applied to that feature at each node (samples with a value <= threshold go to the left child; "-2" for leaf nodes):

In [10]:
# Split threshold per node: a sample goes to the left child when its value for
# that node's feature is <= threshold; leaves store -2.
clf.tree_.threshold

array([ 0.80000001, -2.        ,  1.75      ,  4.95000005,  1.65000004,
       -2.        , -2.        ,  1.55000001, -2.        ,  5.45000005,
       -2.        , -2.        ,  4.85000014,  3.10000002, -2.        ,
       -2.        , -2.        ])

Number of samples per node:

In [11]:
# Number of training samples that reach each node.
clf.tree_.n_node_samples

array([150,  50, 100,  54,  48,  47,   1,   6,   3,   3,   2,   1,  46,
         3,   2,   1,  43], dtype=int64)

Weighted number of samples per node (identical to the raw counts above, because no sample weights were passed to `fit`):

In [12]:
# Sum of sample weights reaching each node; identical to n_node_samples here
# because fit was called without sample_weight.
clf.tree_.weighted_n_node_samples

array([150.,  50., 100.,  54.,  48.,  47.,   1.,   6.,   3.,   3.,   2.,
         1.,  46.,   3.,   2.,   1.,  43.])

Impurity at each node (Gini, the default criterion; 0 means the node is pure):

In [13]:
# Impurity of each node under the split criterion (Gini by default);
# 0 means every sample in the node has the same class.
clf.tree_.impurity

array([0.66666667, 0.        , 0.5       , 0.16803841, 0.04079861,
       0.        , 0.        , 0.44444444, 0.        , 0.44444444,
       0.        , 0.        , 0.04253308, 0.44444444, 0.        ,
       0.        , 0.        ])

## Node-parent representation

Associate each node with its parent:

In [14]:
# Invert the child arrays into a parent lookup: walking the nodes in id order,
# every valid child id (>= 0; leaves store -1) is pointed back at the node that
# owns it. A node that never appears as a child -- the root -- keeps None.
n_nodes = clf.tree_.node_count
node_ids = list(range(n_nodes))
parents = [None] * n_nodes
child_pairs = zip(clf.tree_.children_left, clf.tree_.children_right)
for owner, pair in enumerate(child_pairs):
    for child in pair:
        if child >= 0:
            parents[child] = owner

print("Nodes: ", node_ids)
print("Parents: ", parents)

Nodes:  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]
Parents:  [None, 0, 0, 2, 3, 4, 4, 3, 7, 7, 9, 9, 2, 12, 13, 13, 12]


In [21]:
import plotly.express as px

# Sunburst of the tree: each wedge is a node nested inside its parent (the
# root's parent is None). Sizing wedges by n_node_samples with
# branchvalues="total" makes each wedge's angle proportional to the share of
# training samples reaching that node -- valid because every parent's sample
# count is exactly the sum of its two children's counts.
px.sunburst(
    names=node_ids,
    parents=parents,
    values=clf.tree_.n_node_samples,
    branchvalues="total",
    title="Decision tree nodes sized by training-sample count",
)