## Imports, Configurations

In [None]:
import collections

import pydotplus

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_graphviz
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split


%matplotlib inline

In [None]:
# File the rendered decision-tree image is written to and later read back from.
TREE_IMAGE_PATH = "tree.png"

# Fixed seed so the randomized steps of the notebook (e.g. the train/test
# split) give the same result on every run.
RANDOM_STATE = 51

In [None]:
def export_tree(model, local_image_path, feature_names=None, class_names=None):
    """Render a fitted decision tree to a PNG file.

    The tree is exported to Graphviz DOT source, parsed with pydotplus, and
    the children of every node are recolored before the image is written:
    the child with the lower node id becomes turquoise, the one with the
    higher node id becomes orange.

    Args:
        model: Fitted tree estimator accepted by ``export_graphviz``.
        local_image_path: Destination path of the PNG image.
        feature_names: Labels for the input features. Defaults to
            ``iris.feature_names`` from the notebook's global ``iris``.
        class_names: Labels for the target classes. Defaults to
            ``iris.target_names`` from the notebook's global ``iris``.
    """
    # Fall back to the notebook-level Iris dataset so existing two-argument
    # calls keep working; the global is only touched when actually needed.
    if feature_names is None:
        feature_names = iris.feature_names
    if class_names is None:
        class_names = iris.target_names

    # out_file=None makes export_graphviz return the DOT source as a string.
    dot_data = export_graphviz(
        model,
        out_file=None,
        feature_names=feature_names,
        class_names=class_names,
        rounded=True,
        proportion=False,
        precision=2,
        filled=True,
    )
    graph = pydotplus.graph_from_dot_data(dot_data)

    colors = ('turquoise', 'orange')

    # Group child node ids by their parent node.
    children_by_parent = collections.defaultdict(list)
    for edge in graph.get_edge_list():
        children_by_parent[edge.get_source()].append(int(edge.get_destination()))

    # Pair colors with children in ascending id order. zip() stops at the
    # shorter sequence, so a parent with fewer than two children no longer
    # raises IndexError and extra children are simply left untouched.
    for child_ids in children_by_parent.values():
        for color, child_id in zip(colors, sorted(child_ids)):
            node = graph.get_node(str(child_id))[0]
            node.set_fillcolor(color)

    graph.write_png(local_image_path)

## Load Data and Split Train/Test

In [None]:
# Load the Iris dataset that ships with scikit-learn.
iris = load_iris()

# Hold out 20% of the samples for testing; the fixed seed keeps the split
# identical across runs.
X_train, X_test, y_train, y_test = train_test_split(
    iris.data,
    iris.target,
    test_size=0.2,
    random_state=RANDOM_STATE,
)

## Creation and Training of the Decision Tree Model

In [None]:
# Upper bound on the depth of the fitted tree.
# NOTE(review): 67 is far deeper than a tree on the Iris training split is
# likely to grow, so this cap probably never binds - confirm the intended value.
max_tree_depth = 67

# Seed the estimator with the notebook-wide RANDOM_STATE: the classifier
# permutes features randomly at each split, so without a seed the fitted tree
# (and therefore the score and the rendered image) can change between runs.
model = DecisionTreeClassifier(
    max_depth=max_tree_depth,
    random_state=RANDOM_STATE,
)
model.fit(X_train, y_train)

## Score Evaluation

In [None]:
# Score the fitted model on the held-out test split.
accuracy = model.score(X_test, y_test)
# Bare expression as the last statement so the notebook shows the value as
# the cell output.
accuracy

## Decision Tree Visualization

In [None]:
# Render the fitted tree to TREE_IMAGE_PATH.
export_tree(model, TREE_IMAGE_PATH)

# Display the rendered image inline with matplotlib.
import matplotlib.pyplot as plt

tree_image = plt.imread(TREE_IMAGE_PATH)

plt.figure(figsize=(14, 18))
plt.imshow(tree_image)
plt.axis('off')  # hide the axes; only the picture is of interest
plt.show()