#### Imports

In [14]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML, Image
# Start from the ggplot theme, then override the font settings below so that
# legends, axis labels, titles and tick labels use larger text.
plt.style.use('ggplot')
params = {
    'legend.fontsize': '18',
    'axes.labelsize': '20',
    'axes.labelweight': 'bold',
    'axes.titlesize': '20',
    'xtick.labelsize': '18',
    'ytick.labelsize': '18',
}
plt.rcParams.update(params)

#### Functions

In [15]:
def gini(*args):
    """
    Calculates the gini impurity from per-class example counts.

    The impurity is ``sum(p_i * (1 - p_i))`` where ``p_i`` is the fraction of
    examples in class i. Any number of classes may be passed (two for
    binary data).

    Parameters
    ----------
    *args : int
        Number of examples in class i

    Returns
    -------
    float
        The gini impurity. An empty node (all counts zero, or no counts
        given) is reported as 0.0 rather than dividing by zero.
    """
    n = sum(args)  # total examples
    if n == 0:
        # No examples means there is nothing to be impure about; this also
        # avoids a ZeroDivisionError / NaN for empty partitions.
        return 0.0
    return sum((c / n) * (1 - (c / n)) for c in args)


def split(x, y, splits):
    """
    Calculates the weighted gini impurity of each candidate split of a feature.

    For every threshold, examples with ``x < threshold`` form the left
    partition and the remaining examples form the right partition. The
    impurity of each partition is weighted by the fraction of examples that
    fall into it, and the two weighted values are added.

    Parameters
    ----------
    x : array-like
        Feature values, one per example
    y : array-like
        Corresponding target values; only the labels 0 and 1 are counted
    splits : iterable of float
        Thresholds to calculate the gini criterion for

    Returns
    -------
    list
        Weighted gini impurity for each threshold, in the order given
    """
    # Accept plain lists as well as arrays; boolean masking needs ndarrays.
    x = np.asarray(x)
    y = np.asarray(y)
    n_total = len(x)

    gini_splits = []
    for threshold in splits:
        left = x < threshold
        weighted = 0.0
        for side in (left, ~left):
            n_side = np.count_nonzero(side)
            if n_side == 0:
                # A threshold outside the data range leaves one side empty;
                # it contributes nothing instead of dividing by zero.
                continue
            y_side = y[side]
            impurity = gini(np.count_nonzero(y_side == 0),
                            np.count_nonzero(y_side == 1))
            weighted += impurity * (n_side / n_total)
        gini_splits.append(weighted)

    return gini_splits


def mk_fig():
    """
    Build the empty two-panel canvas shared by every animation.

    The left panel is the X1/X2 feature space and the right panel is the
    gini impurity plotted against feature value.

    Returns
    -------
    fig, axes
        Figure and axes objects
    """
    fig, axes = plt.subplots(1, 2, figsize=(14, 7))
    feature_ax, gini_ax = axes

    # Left panel: feature space, with each axis label in its own colour.
    feature_ax.set(xlim=(-1, 11), ylim=(-1, 11), xlabel='X1', ylabel='X2')
    feature_ax.xaxis.label.set_color('#988ED5')
    feature_ax.yaxis.label.set_color('#E8A2A5')

    # Right panel: impurity of each candidate split.
    gini_ax.set(xlim=(-1, 11), ylim=(0, 1),
                xlabel='Feature Value', ylabel='Gini Impurity')

    return fig, axes

#### Data

In [16]:
# Fixed seed so the toy dataset (and therefore every GIF) is reproducible.
np.random.seed(3)
# Two integer features in [0, 10), drawn one after the other (x1 first).
x1, x2 = (np.random.randint(0, 10, size=10) for _ in range(2))
# Binary class label for each of the ten examples.
y = np.random.randint(low=0, high=2, size=10)

#### Splits

In [17]:
# Candidate thresholds sit halfway between consecutive distinct feature values.
x1_unique, x2_unique = np.unique(x1), np.unique(x2)
x1_splits = x1_unique[:-1] + 0.5 * np.diff(x1_unique)
x2_splits = x2_unique[:-1] + 0.5 * np.diff(x2_unique)

# Weighted gini impurity of every candidate, per feature.
x1_gini = split(x1, y, x1_splits)
x2_gini = split(x2, y, x2_splits)

#### Create and save animations

##### First split

In [18]:
fig, axes = mk_fig()
mask = y == 0  # True for class-0 examples
j = len(x1_splits)  # frames 0..j reveal the X1 candidates
f = len(x1_splits) + len(x2_splits) + 1  # index of the final (highlight) frame


def init():
    """
    Draw the static artists: both classes in feature space and the legends.
    """
    axes[0].scatter(x1[mask], x2[mask], s=100, c='#E24A33', label='Class 0')
    axes[0].scatter(x1[~mask], x2[~mask], s=100, c='#348ABD', label='Class 1')
    axes[0].legend(facecolor='#F0F0F0', framealpha=1)
    axes[1].plot(-1, -1, c='#988ED5', label='X1 splits')  # legend place-holders
    axes[1].plot(-1, -1, c='#E8A2A5', label='X2 splits')
    axes[1].legend(facecolor='#F0F0F0', framealpha=1)


def animate(i):
    """
    Draw frame ``i`` of the first-split search.

    Frames ``0..j`` reveal the first ``i`` X1 candidates, frames
    ``j+1..f-1`` reveal the first ``i - j`` X2 candidates, and frame ``f``
    highlights the candidate with the lowest weighted gini impurity.

    Parameters
    ----------
    i : int
        Frame index supplied by FuncAnimation
    """
    if i <= j:  # plot x1 splits
        axes[0].vlines(x1_splits[:i], -1, 11, '#988ED5')
        axes[1].plot(x1_splits[:i], x1_gini[:i], '#988ED5', marker='o', ms=9)
    elif i < f:  # plot x2 splits
        axes[0].hlines(x2_splits[:(i-j)], -1, 11, '#E8A2A5')
        axes[1].plot(x2_splits[:(i-j)], x2_gini[:(i-j)], '#E8A2A5', marker='o', ms=9)
    else:  # highlight optimum split
        if min(x1_gini) <= min(x2_gini):
            # Best split is on X1: a vertical line in feature space.
            k = np.argmin(x1_gini)
            axes[0].vlines(x1_splits[k], -1, 11, 'k', lw=3)
            axes[1].plot(x1_splits[k], x1_gini[k], 'o', mec='k', mfc='None', mew=3, ms=20)
        else:
            # Best split is on X2: a horizontal line in feature space, picked
            # from the X2 candidates (previously this branch re-used the X1
            # optimum and drew a vertical line).
            k = np.argmin(x2_gini)
            axes[0].hlines(x2_splits[k], -1, 11, 'k', lw=3)
            axes[1].plot(x2_splits[k], x2_gini[k], 'o', mec='k', mfc='None', mew=3, ms=20)


plt.close(fig)  # close the figure before building and saving the animation
ani = animation.FuncAnimation(fig,
                              animate,
                              init_func=init,
                              frames=f + 1,
                              interval=600)
ani.save('../gif/decision_tree/decision_tree_1.gif', writer='imagemagick', fps=1, dpi=150)
# HTML(ani.to_jshtml())

##### Second split

In [19]:
# X1 threshold drawn as the first split in the later figures. Hard-coded;
# presumably the optimum highlighted by the first-split animation --
# TODO confirm it still holds if the seed or data change.
split_1 = 4
mask_1 = x1 > split_1  # examples on the x1 > split_1 side, searched for the second split

In [20]:
# Candidate thresholds for the second split, restricted to the examples
# selected by mask_1: midpoints between consecutive distinct feature values.
x1_unique, x2_unique = np.unique(x1[mask_1]), np.unique(x2[mask_1])
x1_splits = x1_unique[:-1] + 0.5 * np.diff(x1_unique)
x2_splits = x2_unique[:-1] + 0.5 * np.diff(x2_unique)

# Weighted gini impurity of every candidate within that subset.
x1_gini = split(x1[mask_1], y[mask_1], x1_splits)
x2_gini = split(x2[mask_1], y[mask_1], x2_splits)

In [21]:
fig, axes = mk_fig()
mask = (y == 0)  # True for class-0 examples
j = len(x1_splits)  # frames 0..j reveal the X1 candidates
f = j + len(x2_splits) + 1  # index of the final (highlight) frame


def init():
    """
    Draw the static artists: both classes, the legends and the first split.
    """
    feature_ax, gini_ax = axes
    for selector, colour, name in ((mask, '#E24A33', 'Class 0'),
                                   (~mask, '#348ABD', 'Class 1')):
        feature_ax.scatter(x1[selector], x2[selector], s=100, c=colour, label=name)
    feature_ax.legend(facecolor='#F0F0F0', framealpha=1)
    # Single points placed outside the impurity panel's y-range; they exist
    # only to give the legend its two entries.
    for colour, name in (('#988ED5', 'X1 splits'), ('#E8A2A5', 'X2 splits')):
        gini_ax.plot(-1, -1, c=colour, label=name)
    gini_ax.legend(facecolor='#F0F0F0', framealpha=1)
    feature_ax.vlines(split_1, -1, 11, 'k', lw=3)  # first split, already decided


def animate(i):
    """
    Draw frame ``i`` of the second-split search.

    Frames up to ``j`` reveal X1 candidates, frames below ``f`` reveal X2
    candidates (drawn from ``split_1`` to the right edge), and the last
    frame highlights the candidate with the lowest weighted gini impurity.

    Parameters
    ----------
    i : int
        Frame index supplied by FuncAnimation
    """
    feature_ax, gini_ax = axes

    if i <= j:  # X1 candidates: vertical lines
        feature_ax.vlines(x1_splits[:i], -1, 11, '#988ED5')
        gini_ax.plot(x1_splits[:i], x1_gini[:i], '#988ED5', marker='o', ms=9)
        return

    if i < f:  # X2 candidates: horizontal lines
        shown = i - j
        feature_ax.hlines(x2_splits[:shown], split_1, 11, '#E8A2A5')
        gini_ax.plot(x2_splits[:shown], x2_gini[:shown], '#E8A2A5', marker='o', ms=9)
        return

    # Final frame: ring the winning candidate and draw it in black.
    ring = dict(mec='k', mfc='None', mew=3, ms=20)
    if min(x1_gini) <= min(x2_gini):
        k = np.argmin(x1_gini)
        feature_ax.vlines(x1_splits[k], -1, 11, 'k', lw=3)
        gini_ax.plot(x1_splits[k], x1_gini[k], 'o', **ring)
    else:
        k = np.argmin(x2_gini)
        feature_ax.hlines(x2_splits[k], split_1, 11, 'k', lw=3)
        gini_ax.plot(x2_splits[k], x2_gini[k], 'o', **ring)


plt.close(fig)
ani = animation.FuncAnimation(fig, animate, init_func=init,
                              frames=f + 1, interval=600)
ani.save('../gif/decision_tree/decision_tree_2.gif', writer='imagemagick', fps=1, dpi=150)
# HTML(ani.to_jshtml())

##### Third split

In [22]:
# X2 threshold drawn as the second split in the later figures. Hard-coded;
# presumably the optimum highlighted by the second-split animation --
# TODO confirm it still holds if the seed or data change.
split_2 = 4.5
mask_2 = (x1 > split_1) & (x2 > split_2)  # examples beyond both thresholds, searched for the third split

In [23]:
# Candidate thresholds for the third split, restricted to the examples
# selected by mask_2: midpoints between consecutive distinct feature values.
x1_unique, x2_unique = np.unique(x1[mask_2]), np.unique(x2[mask_2])
x1_splits = x1_unique[:-1] + 0.5 * np.diff(x1_unique)
x2_splits = x2_unique[:-1] + 0.5 * np.diff(x2_unique)

# Weighted gini impurity of every candidate within that subset.
x1_gini = split(x1[mask_2], y[mask_2], x1_splits)
x2_gini = split(x2[mask_2], y[mask_2], x2_splits)

In [24]:
fig, axes = mk_fig()
mask = (y == 0)  # True for class-0 examples
j = len(x1_splits)  # frames 0..j reveal the X1 candidates
f = j + len(x2_splits) + 1  # index of the final (highlight) frame


def init():
    """
    Draw the static artists: both classes, the legends and the two splits
    already decided.
    """
    feature_ax, gini_ax = axes
    for selector, colour, name in ((mask, '#E24A33', 'Class 0'),
                                   (~mask, '#348ABD', 'Class 1')):
        feature_ax.scatter(x1[selector], x2[selector], s=100, c=colour, label=name)
    feature_ax.legend(facecolor='#F0F0F0', framealpha=1)
    # Single points placed outside the impurity panel's y-range; they exist
    # only to give the legend its two entries.
    for colour, name in (('#988ED5', 'X1 splits'), ('#E8A2A5', 'X2 splits')):
        gini_ax.plot(-1, -1, c=colour, label=name)
    gini_ax.legend(facecolor='#F0F0F0', framealpha=1)
    feature_ax.vlines(split_1, -1, 11, 'k', lw=3)       # first split
    feature_ax.hlines(split_2, split_1, 11, 'k', lw=3)  # second split


def animate(i):
    """
    Draw frame ``i`` of the third-split search.

    X1 candidates are drawn from ``split_2`` up to the top edge and X2
    candidates from ``split_1`` to the right edge. The last frame
    highlights the candidate with the lowest weighted gini impurity.

    Parameters
    ----------
    i : int
        Frame index supplied by FuncAnimation
    """
    feature_ax, gini_ax = axes

    if i <= j:  # X1 candidates: vertical lines
        feature_ax.vlines(x1_splits[:i], split_2, 11, '#988ED5')
        gini_ax.plot(x1_splits[:i], x1_gini[:i], '#988ED5', marker='o', ms=9)
        return

    if i < f:  # X2 candidates: horizontal lines
        shown = i - j
        feature_ax.hlines(x2_splits[:shown], split_1, 11, '#E8A2A5')
        gini_ax.plot(x2_splits[:shown], x2_gini[:shown], '#E8A2A5', marker='o', ms=9)
        return

    # Final frame: ring the winning candidate and draw it in black.
    ring = dict(mec='k', mfc='None', mew=3, ms=20)
    if min(x1_gini) <= min(x2_gini):
        k = np.argmin(x1_gini)
        feature_ax.vlines(x1_splits[k], split_2, 11, 'k', lw=3)
        gini_ax.plot(x1_splits[k], x1_gini[k], 'o', **ring)
    else:
        k = np.argmin(x2_gini)
        feature_ax.hlines(x2_splits[k], split_1, 11, 'k', lw=3)
        gini_ax.plot(x2_splits[k], x2_gini[k], 'o', **ring)


plt.close(fig)
ani = animation.FuncAnimation(fig, animate, init_func=init,
                              frames=f + 1, interval=600)
ani.save('../gif/decision_tree/decision_tree_3.gif', writer='imagemagick', fps=1, dpi=150)
# HTML(ani.to_jshtml())

##### Final tree

In [25]:
# X2 position of the third split line drawn in the final figure. Hard-coded;
# presumably the optimum highlighted by the third-split animation -- TODO confirm.
split_3 = 6.5

In [26]:
fig, axes = mk_fig()
mask = y == 0  # True for class-0 examples


def init():
    """
    Draw the complete picture: both classes, the legends and all three splits.
    """
    axes[0].scatter(x1[mask], x2[mask], s=100, c='#E24A33', label='Class 0')
    axes[0].scatter(x1[~mask], x2[~mask], s=100, c='#348ABD', label='Class 1')
    axes[0].legend(facecolor='#F0F0F0', framealpha=1)
    axes[1].plot(-1, -1, c='#988ED5', label='X1 splits')  # legend place-holders
    axes[1].plot(-1, -1, c='#E8A2A5', label='X2 splits')
    axes[1].legend(facecolor='#F0F0F0', framealpha=1)
    axes[0].vlines(split_1, -1, 11, 'k', lw=3)
    axes[0].hlines(split_2, split_1, 11, 'k', lw=3)
    axes[0].hlines(split_3, split_1, 11, 'k', lw=3)


def animate(i):
    """
    No-op frame function: everything is drawn by ``init``.

    Parameters
    ----------
    i : int
        Frame index supplied by FuncAnimation (unused)
    """
    return


# The frame counters `j` and `f` that earlier cells compute are not needed
# here: the animation has a single frame.
plt.close(fig)
ani = animation.FuncAnimation(fig,
                              animate,
                              init_func=init,
                              frames=1,
                              interval=600)
ani.save('../gif/decision_tree/decision_tree_4.gif', writer='imagemagick', fps=1, dpi=150)
# HTML(ani.to_jshtml())

#### View animations

In [27]:
# Show the first-split GIF saved earlier in this notebook.
Image(url='../gif/decision_tree/decision_tree_1.gif')

In [28]:
# Show the second-split GIF saved earlier in this notebook.
Image(url='../gif/decision_tree/decision_tree_2.gif')

In [29]:
# Show the third-split GIF saved earlier in this notebook.
Image(url='../gif/decision_tree/decision_tree_3.gif')

In [30]:
# Show the final-tree GIF saved earlier in this notebook.
Image(url='../gif/decision_tree/decision_tree_4.gif')