## References

* [Iris Data Set](https://archive.ics.uci.edu/ml/datasets/Iris) (UCI)

* [Visualizing a Decision Tree](https://www.youtube.com/watch?v=tNa99PG8hR8) (Josh Gordon, YouTube)

In [None]:
import matplotlib.pyplot as pp
import numpy as np
import seaborn as sb

from problem import load
from sklearn.model_selection import train_test_split as split
from sklearn.tree import DecisionTreeClassifier as Model
from sklearn.metrics import accuracy_score as metric

# Seed NumPy's global RNG: train_test_split and DecisionTreeClassifier are
# called below without a random_state, so they draw from this global state
# and the split/tree are only reproducible because of this seed.
np.random.seed(42)
# Render inline matplotlib figures as SVG (IPython magic; notebook-only).
%config InlineBackend.figure_format = 'svg'

In [None]:
# Hold out 30% of the rows for evaluation. No random_state is passed, so the
# shuffle comes from NumPy's global RNG (seeded in the setup cell).
dataset = load()
data_train, data_test = split(dataset, test_size=0.3)

In [None]:
 def plot(correlation):
    """Draw the lower triangle of a correlation matrix as a heatmap.

    The upper triangle and the diagonal are masked out: the matrix is
    symmetric, so they repeat or carry no information. The color scale is
    fixed to [-1, 1] and centered at 0 so panels are comparable.

    Args:
        correlation: square correlation matrix, e.g. the result of
            ``DataFrame.corr()``; when it is a DataFrame, its labels become
            the heatmap tick labels.
    """
    # `np.bool` was deprecated in NumPy 1.20 and removed in 1.24 (it raises
    # AttributeError there); the builtin `bool` is the supported spelling.
    mask = np.zeros_like(correlation, dtype=bool)
    # triu_indices_from uses k=0, so the diagonal is masked as well.
    mask[np.triu_indices_from(mask)] = True
    sb.heatmap(correlation, vmin=-1, vmax=1, center=0, mask=mask, cmap='coolwarm',
               square=True, linewidths=1, cbar=False)

pp.figure(figsize=(12, 2))
class_names = data_train['Species'].unique()
panel_count = len(class_names)
for i, species in enumerate(class_names):
    # One-vs-rest view: swap 'Species' for a boolean "is this species"
    # indicator (vectorized comparison rather than a per-row lambda) and
    # name the column after the species, so the heatmap labels show which
    # class each panel correlates the features against.
    frame = data_train.copy()
    frame['Species'] = frame['Species'] == species
    frame = frame.rename(columns={'Species': species})
    correlation = frame.corr()
    pp.subplot(1, panel_count, i + 1)
    plot(correlation)

In [None]:
# Split the target off the training frame. `pop` removes 'Species' from
# data_train in place, so x_train is that same (now feature-only) frame.
y_train = data_train.pop('Species')
x_train = data_train

# fit() returns the estimator itself, so construction and training chain.
model = Model().fit(x_train, y_train)

# Mean accuracy on the rows the tree was trained on (not a held-out estimate).
train_score = model.score(x_train, y_train)
print('Score: {:.4}'.format(train_score))

In [None]:
# Same target/feature separation for the held-out rows; `pop` mutates
# data_test in place, leaving only the feature columns in x_test.
y_test = data_test.pop('Species')
x_test = data_test

y_pred = model.predict(x_test)
test_accuracy = metric(y_test, y_pred)
print('Accuracy: {:.4}'.format(test_accuracy))