In [1]:
import numpy as np
from sklearn import datasets
from sklearn import tree
from sklearn.tree import _tree
import networkx as nx

## Loading a sample dataset

In [2]:
# Load iris
iris = datasets.load_iris()
X = iris.data
y = iris.target

# Build decision tree classifier
dt = tree.DecisionTreeClassifier(criterion='entropy')
dt.fit(X, y)

DecisionTreeClassifier(class_weight=None, criterion='entropy', max_depth=None,
            max_features=None, max_leaf_nodes=None,
            min_impurity_decrease=0.0, min_impurity_split=None,
            min_samples_leaf=1, min_samples_split=2,
            min_weight_fraction_leaf=0.0, presort=False, random_state=None,
            splitter='best')

In [13]:
tree_ = dt.tree_

In [16]:
feature_names = list(iris.feature_names)

In [17]:
feature_name = [
    feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
    for i in tree_.feature
]

In [19]:
tree_.node_count

17

In [23]:
node_features = ['undefined' if f == _tree.TREE_UNDEFINED else feature_names[f] 
                     for f in tree_.feature] 

In [24]:
node_features

['petal length (cm)',
 'undefined',
 'petal width (cm)',
 'petal length (cm)',
 'petal width (cm)',
 'undefined',
 'undefined',
 'petal width (cm)',
 'undefined',
 'sepal length (cm)',
 'undefined',
 'undefined',
 'petal length (cm)',
 'sepal width (cm)',
 'undefined',
 'undefined',
 'undefined']

In [3]:
def tree_to_code(tree, feature_names):
    '''
    Outputs a decision tree model as a Python function

    Parameters:
    -----------
    tree: decision tree model
        The decision tree to represent as a function
    feature_names: list
        The feature names of the dataset used for building the decision tree
    '''

    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]
    print("def tree({}):".format(", ".join(feature_names)))

    g = nx.DiGraph()
    
    def recurse(node, depth, g):
        indent = "  " * depth
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            node_name = "{}\n<=\n{:.2f}\n".format(name, threshold) + ' NODE {}'.format(node)
            g.add_node(node_name)
            print("{}if {} <= {}:".format(indent, name, threshold))
            cl_name = recurse(tree_.children_left[node], depth + 1, g)
            g.add_edge(node_name, cl_name, name='yes')
            print("{}else:".format(indent, name, threshold))
            cr_name = recurse(tree_.children_right[node], depth + 1, g)
            g.add_edge(node_name, cr_name, name='no')
        else:
            node_name = "return {}".format(tree_.value[node]) + 'NODE{}'.format(node)
            g.add_node(node_name)
            print("{}return {}".format(indent, tree_.value[node]))
        return node_name
    recurse(0, 1, g)
    relabel_dict = {}
    order_dict = {}
    for n in g.nodes():
        relabel_dict[n], order = n.split('NODE')
        order_dict[relabel_dict[n]] = int(order)
    
    return g, order_dict

In [4]:
g, order_dict = tree_to_code(dt, list(iris.feature_names))

def tree(sepal length (cm), sepal width (cm), petal length (cm), petal width (cm)):
  if petal length (cm) <= 2.450000047683716:
    return [[ 50.   0.   0.]]
  else:
    if petal width (cm) <= 1.75:
      if petal length (cm) <= 4.949999809265137:
        if petal width (cm) <= 1.6500000953674316:
          return [[  0.  47.   0.]]
        else:
          return [[ 0.  0.  1.]]
      else:
        if petal width (cm) <= 1.5499999523162842:
          return [[ 0.  0.  3.]]
        else:
          if sepal length (cm) <= 6.949999809265137:
            return [[ 0.  2.  0.]]
          else:
            return [[ 0.  0.  1.]]
    else:
      if petal length (cm) <= 4.850000381469727:
        if sepal width (cm) <= 3.0999999046325684:
          return [[ 0.  0.  2.]]
        else:
          return [[ 0.  1.  0.]]
      else:
        return [[  0.   0.  43.]]


In [5]:
def get_root(g):
    root = [node for node, deg in g.degree_iter() if deg == 2]
    if len(root) != 1:
        raise Exception('something wrong')
    else:
        return root[0]

In [9]:
def get_node_positions(g, dx=50, dy=3, root_coord=(0, 1), eps=0.5):
    '''Define node positions for the graph g that represents the decision tree.
    '''
    # Set up values and call get_node_positions_rec.
    parent = None
    node = None
    pos_dict = {}
    get_node_positions_rec(g, parent, node, pos_dict, 
                           dx=dx, dy=dy, root_coord=root_coord, eps=eps)
    return pos_dict
    

def get_node_positions_rec(g, parent, node, pos_dict, 
                           dx=1, dy=1, root_coord=(0, 1), eps=0.5):
    '''Recursively defines node positions for the graph g that represents a decision tree.
    NOTE: Do not call this function directly. Use get_node_positions instead.
    '''
    if parent is None:
        node = get_root(g)
        x, y = root_coord
    else:
        x, y = pos_dict[parent]
        y = y - dy
        edge = g.get_edge_data(parent, node)
        if edge['name'] == 'yes':
            x = x + dx
        else:
            x = x - dx
    pos_dict[node] = np.array((x, y))
    
    children = [dest for orig, dest in g.edges() if orig == node]
    for child in children:
        get_node_positions_rec(g, node, child, pos_dict, dx=dx*eps)

## Defining a layout for plotting

In [10]:
def fun_layout(g, scale=1, center=(0,0)):
    pos = get_node_positions(g)
    xy = pos.values()
    xy = np.array(list(xy))
    mean = xy.mean(axis=0)
    max_ = np.abs(xy).max(axis=0)
    xy = (xy - mean + center)*scale/max_
    i = 0
    for k, v in pos.items():
        pos[k] = xy[i]
        i += 1
    return pos

## Using bokeh to display the tree

In [11]:
import networkx as nx

from bokeh.io import show, output_file, output_notebook
from bokeh.models import Plot, Range1d, MultiLine, Circle, HoverTool, TapTool, BoxSelectTool, WheelZoomTool
from bokeh.models.graphs import from_networkx, NodesAndLinkedEdges, EdgesAndLinkedNodes
from bokeh.palettes import Spectral4

G = g
plot = Plot(plot_width=400, plot_height=400,
            x_range=Range1d(-1.1,1.1), y_range=Range1d(-1.1,1.1))
plot.title.text = "Graph Interaction Demonstration"
hover = HoverTool(tooltips=[("Name:", "@name")])
plot.add_tools(hover, TapTool(), BoxSelectTool(), WheelZoomTool())

graph_renderer = from_networkx(G, fun_layout, scale=1, center=(0,0))

graph_renderer.node_renderer.glyph = Circle(size=15, fill_color=Spectral4[0])
graph_renderer.node_renderer.selection_glyph = Circle(size=15, fill_color=Spectral4[2])
graph_renderer.node_renderer.hover_glyph = Circle(size=15, fill_color=Spectral4[1])
graph_renderer.node_renderer.data_source.data['name'] = [e.split('NODE')[0] for e in list(g.nodes())]


graph_renderer.edge_renderer.glyph = MultiLine(line_color="#CCCCCC", line_alpha=0.8, line_width=5)
graph_renderer.edge_renderer.selection_glyph = MultiLine(line_color=Spectral4[2], line_width=5)
graph_renderer.edge_renderer.hover_glyph = MultiLine(line_color=Spectral4[1], line_width=5)

graph_renderer.selection_policy = NodesAndLinkedEdges()
graph_renderer.inspection_policy = NodesAndLinkedEdges()

plot.renderers.append(graph_renderer)

#output_file("interactive_graphs.html")
output_notebook()
show(plot)

In [76]:
import bokeh.core.state

ModuleNotFoundError: No module named 'bokeh.core.state'

In [73]:
help(output_file)

Help on function output_file in module bokeh.io.output:

output_file(filename, title='Bokeh Plot', mode='cdn', root_dir=None)
    Configure the default output state to generate output saved
    to a file when :func:`show` is called.
    
    Does not change the current Document from curdoc(). File and notebook
    output may be active at the same time, so e.g., this does not clear the
    effects of ``output_notebook()``.
    
    Args:
        filename (str) : a filename for saving the HTML document
    
        title (str, optional) : a title for the HTML document (default: "Bokeh Plot")
    
        mode (str, optional) : how to include BokehJS (default: ``'cdn'``)
            One of: ``'inline'``, ``'cdn'``, ``'relative(-dev)'`` or
            ``'absolute(-dev)'``. See :class:`bokeh.resources.Resources` for more details.
    
        root_dir (str, optional) : root directory to use for 'absolute' resources. (default: None)
            This value is ignored for other resource types,

In [None]:
## 