## References

* [Iris Data Set](https://archive.ics.uci.edu/ml/datasets/Iris) (UCI)

* [Visualizing a Decision Tree](https://www.youtube.com/watch?v=tNa99PG8hR8) (Josh Gordon, YouTube)

In [None]:
import matplotlib.pyplot as pp
import numpy as np
import seaborn as sb

from graphviz import Source
from problem import load_split_train_test
from problem import split_feature_target
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import export_graphviz
from sklearn.metrics import accuracy_score

In [None]:
# Seed NumPy's global random number generator so runs are reproducible.
np.random.seed(42)
# IPython magic: render inline matplotlib figures as SVG.
%config InlineBackend.figure_format = 'svg'

In [None]:
# Load the training and test partitions via the project's `problem` module.
# NOTE(review): presumably pandas DataFrames with a 'Species' column, given how
# they are indexed in later cells — confirm against problem.load_split_train_test.
data_train, data_test = load_split_train_test()

In [None]:
 def plot(correlation):
    mask = np.zeros_like(correlation, dtype=np.bool)
    mask[np.triu_indices_from(mask)] = True
    sb.heatmap(correlation, vmin=-1, vmax=1, center=0, mask=mask, cmap='coolwarm',
               square=True, linewidths=1, cbar=False)

# One heatmap per species: the label column is replaced by a boolean
# one-vs-rest indicator, so corr() shows how each feature correlates with
# membership in that species.
pp.figure(figsize=(12, 2))
class_names = data_train['Species'].unique()
for i, species in enumerate(class_names):
    one_vs_rest = data_train.copy()
    # Vectorised comparison already yields a bool Series; no map/astype needed.
    one_vs_rest['Species'] = one_vs_rest['Species'] == species
    one_vs_rest = one_vs_rest.rename(columns={'Species': species})
    correlation = one_vs_rest.corr()
    pp.subplot(1, len(class_names), i + 1)
    plot(correlation)

In [None]:
x_train, y_train = split_feature_target(data_train)

# Pin random_state explicitly. The tree randomly permutes the features at each
# split. With random_state=None it draws from NumPy's global RNG, so the fitted
# tree (whenever candidate splits tie) would depend on how many times the
# global RNG had been consumed. That depends on cell execution order, not on
# the seed alone.
model = DecisionTreeClassifier(max_depth=3, random_state=42)
model.fit(x_train, y_train)

# Mean accuracy on the training data itself (an optimistic estimate).
print('Score: {:.4}'.format(model.score(x_train, y_train)))

In [None]:
# Evaluate the fitted tree on the held-out test partition.
x_test, y_test = split_feature_target(data_test)
y_pred = model.predict(x_test)
test_accuracy = accuracy_score(y_true=y_test, y_pred=y_pred)
print('Accuracy: {:.4}'.format(test_accuracy))

In [None]:
# Render the fitted tree as a Graphviz diagram (last expression, so the
# notebook displays it).
# export_graphviz expects class_names in the same ascending order as the
# model's own classes. Take them from model.classes_ rather than the target's
# category order, which could differ (custom order, unused categories) and
# mislabel the leaves.
Source(export_graphviz(model,
                       out_file=None,
                       feature_names=list(x_train.columns),
                       class_names=[str(name) for name in model.classes_]))