In [18]:
import numpy as np
import pandas as pd
import seaborn as sns
from IPython.core.display import display, HTML, Javascript
import json
    
def _is_boolean_feature(feature_type, i_feature):
    """Return True when feature ``i_feature`` should be rendered as a boolean test."""
    if feature_type is None:
        return False
    if isinstance(feature_type, str):
        # A string acts as a global switch; only the exact value 'boolean'
        # enables it. It must not be indexed character by character.
        return feature_type == 'boolean'
    try:
        return bool(feature_type[i_feature])
    except (TypeError, IndexError, KeyError):
        # Not indexable, or no entry for this feature: use a threshold rule.
        return False


def recurse(model, node=None, rules=None, feature_names=None, feature_type=None):
    """Convert a fitted decision tree into a nested, JSON-serialisable dict.

    Parameters
    ----------
    model : object exposing ``tree_`` with the arrays ``feature``,
        ``threshold``, ``children_left``, ``children_right``, ``value`` and
        ``impurity`` indexed by node id.
    node : int, optional
        Node id to expand. Defaults to 0.
    rules : dict, optional
        Dict describing ``node``. It is updated in place with a ``'children'``
        entry when the node is a split. When omitted, a root dict with the
        keys ``'name'``, ``'rule'`` (``'root'``) and ``'values'`` is created.
    feature_names : sequence, optional
        Names indexed by feature id; the feature id itself is used when omitted.
    feature_type : str or indexable, optional
        ``'boolean'`` renders every split as ``"<name> is False/True"``.
        An indexable (list, dict, ...) is looked up per feature id and a truthy
        entry selects the boolean wording. Anything else yields threshold rules.

    Returns
    -------
    dict
        ``rules``, with each child carrying ``'name'``, ``'rule'``,
        ``'values'``, ``'impurity'`` and, for split nodes, ``'children'``.
    """
    tree = model.tree_

    if node is None:
        node = 0

    if rules is None:
        rules = {'name': 'node{:d}'.format(node),
                 'rule': 'root',
                 'values': tree.value[node].ravel().tolist()}

    left = int(tree.children_left[node])
    right = int(tree.children_right[node])

    # A leaf has identical left/right child ids, so the function does not need
    # the private ``sklearn.tree._tree`` constants to detect one.
    if left == right:
        return rules

    i_feature = int(tree.feature[node])
    threshold = float(tree.threshold[node])
    name = feature_names[i_feature] if feature_names is not None else i_feature

    if _is_boolean_feature(feature_type, i_feature):
        rule_left = '{} is False'.format(name)
        rule_right = '{} is True'.format(name)
    else:
        # NOTE(review): scikit-learn sends samples with value <= threshold to
        # the left child, so '<' is a display simplification; the label text is
        # kept unchanged.
        rule_left = '{} {} {:2.2g} '.format(name, '<', threshold)
        rule_right = '{} {} {:2.2g} '.format(name, '>', threshold)

    children = []
    for child, rule in ((left, rule_left), (right, rule_right)):
        child_rules = {'name': 'node{:d}'.format(child),
                       'rule': rule,
                       'values': tree.value[child].ravel().tolist(),
                       # Each child reports its own impurity.
                       'impurity': float(tree.impurity[child])}
        recurse(model, child, child_rules, feature_names=feature_names,
                feature_type=feature_type)
        children.append(child_rules)

    rules['children'] = children
    return rules

In [19]:
%%javascript

require.config({
    paths: {
        d3: 'https://d3js.org/d3.v4.min',
        tree: 'http://localhost:8888/files/tree'
    }
});

<IPython.core.display.Javascript object>

In [20]:
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import _tree

# Fit a decision tree on the iris dataset and export it for the chart cell.
data = sns.load_dataset('iris')
X, y = data.drop(columns='species'), data['species']

# Fix the seed so Restart & Run All reproduces the same tree.
model = DecisionTreeClassifier(random_state=0)
model.fit(X, y)

# Nested dict consumed by the chart cell, which reads `rules`.
rules = recurse(model, feature_names=X.columns)
rules['class_names'] = list(model.classes_)

In [21]:
from string import Template

# HTML/JS snippet for the chart container: $data is replaced with the JSON
# form of `rules`, then plot_tree draws into the #chart1 div.
html_template = Template(
"""
<style>
    .chart {
        width: 800px;
        height: 600px;
        padding: 10px;
        display: block;
    }
</style>
<div id="chart1" class="chart"></div>
<script>
require(['d3', 'tree'], function(d3){
    var data = $data;
    var chart = document.getElementById("chart1");
    plot_tree(d3, data, chart);
});
</script>
""")

# Fill the single placeholder through a mapping rather than a keyword argument.
html_string = html_template.substitute({'data': json.dumps(rules)})

# Keep this as the last expression so the cell renders the HTML.
HTML(html_string)

In [22]:
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import _tree

# Fit a decision tree on the attention dataset and export it for the chart cell.
data = sns.load_dataset('attention')
print(data.head())
X, y = data[['solutions', 'score']], data['attention']

# Fix the seed so Restart & Run All reproduces the same tree.
model = DecisionTreeClassifier(random_state=0)
model.fit(X, y)

# Nested dict consumed by the chart cell, which reads `rules`.
rules = recurse(model, feature_names=X.columns)
rules['class_names'] = list(model.classes_)

   Unnamed: 0  subject attention  solutions  score
0           0        1   divided          1    2.0
1           1        2   divided          1    3.0
2           2        3   divided          1    3.0
3           3        4   divided          1    5.0
4           4        5   divided          1    4.0


In [23]:
# HTML/JS snippet for the chart container: $data is replaced with the JSON
# form of `rules`, then plot_tree draws into the #chart2 div.
html_template = Template(
"""
<style>
    .chart {
        width: 800px;
        height: 600px;
        padding: 10px;
        display: block;
    }
</style>
<div id="chart2" class="chart"></div>
<script>
require(['d3', 'tree'], function(d3){
    var data = $data;
    var chart = document.getElementById("chart2");
    plot_tree(d3, data, chart);
});
</script>
""")

# Fill the single placeholder through a mapping rather than a keyword argument.
html_string = html_template.substitute({'data': json.dumps(rules)})

# Keep this as the last expression so the cell renders the HTML.
HTML(html_string)

In [28]:
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import _tree

# Fit a depth-limited decision tree on the exercise dataset and export it for
# the chart cell.
data = sns.load_dataset('exercise')

# Convert the duration column to minutes with the vectorised .dt accessor.
data['time'] = pd.to_timedelta(data['time']).dt.total_seconds() / 60

# One-hot encode `diet`, dropping the first level.
diet_dummies = pd.get_dummies(data['diet'], drop_first=True, prefix='diet')
data = pd.concat([data, diet_dummies], axis=1)
X, y = data[['diet_low fat', 'pulse', 'time']], data['kind']

# Fix the seed so Restart & Run All reproduces the same tree.
model = DecisionTreeClassifier(max_depth=4, random_state=0)
model.fit(X, y)

# Nested dict consumed by the chart cell, which reads `rules`.
rules = recurse(model, feature_names=X.columns)
rules['class_names'] = list(model.classes_)

In [29]:
# HTML/JS snippet for the chart container: $data is replaced with the JSON
# form of `rules`, then plot_tree draws into the #chart3 div.
html_template = Template(
"""
<style>
    .chart {
        width: 800px;
        height: 600px;
        padding: 10px;
        display: block;
    }
</style>
<div id="chart3" class="chart"></div>
<script>
require(['d3', 'tree'], function(d3){
    var data = $data;
    var chart = document.getElementById("chart3");
    plot_tree(d3, data, chart);
});
</script>
""")

# Fill the single placeholder through a mapping rather than a keyword argument.
html_string = html_template.substitute({'data': json.dumps(rules)})

# Keep this as the last expression so the cell renders the HTML.
HTML(html_string)