## Decision Tree (DT) Intuition

- How does splitting work?
- How do we measure the impact of a split?
- Examples from both classification and regression settings

In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

In [2]:
# Raw toy records, one row per observation. The three fields are later
# labelled Age, Gender and Profitability. Every field, including the age,
# is written as a string here, so the resulting array has a string dtype.
vals = np.array(
    [
        ["36", "M", "P"],
        ["32", "S", "U"],
        ["38", "M", "P"],
        ["40", "S", "U"],
        ["44", "M", "P"],
        ["56", "M", "P"],
        ["58", "S", "U"],
        ["30", "S", "P"],
        ["28", "M", "U"],
        ["26", "M", "U"],
    ]
)

In [3]:
# Wrap the raw string records in a labelled table (all columns are still text at this point).
df = pd.DataFrame(data=vals, columns=["Age", "Gender", "Profitability"])

In [4]:
# Age was loaded as text; cast it to integers so it can be compared against numeric cutoffs.
df["Age"] = df["Age"].astype("int")

In [5]:
import ipywidgets as widgets

In [6]:
from ipywidgets import interact, interactive, fixed, interact_manual

In [7]:
def f(x,col):
    """Plot the 'Profitability' class counts on each side of a threshold split.

    The rows of the module-level ``df`` are partitioned into
    ``df[col] > x`` (left panel) and ``df[col] <= x`` (right panel); each
    panel shows a bar chart of the class counts inside its partition.

    Parameters
    ----------
    x : number
        Cutoff that ``df[col]`` is compared against.
    col : str
        Name of the column to split on.

    Returns
    -------
    The return value of ``plt.show()``.

    Notes
    -----
    If the split cannot be drawn (for example an unknown column name, a
    non-numeric column, or a side with no rows), "Can't compute" plus the
    reason is printed instead of raising, so the interactive widget stays
    usable while the user is still adjusting its inputs.
    """
    cutoff = x
    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))
    try:
        # Each title is paired with the exact comparison used to filter the rows,
        # so the label can never drift away from what the bars actually show.
        panels = (
            (axes[0], df[col] > cutoff, f'{col}>{cutoff}'),
            (axes[1], df[col] <= cutoff, f'{col}<={cutoff}'),
        )
        for ax, mask, title in panels:
            df.loc[mask, 'Profitability'].value_counts(normalize=False).plot(kind='bar', ax=ax)
            ax.set_title(title)
        plt.tight_layout()
    except Exception as exc:
        # Exception (not a bare except) lets KeyboardInterrupt/SystemExit propagate.
        print(f"Can't compute: {exc}")
    return plt.show()

In [8]:
# Interactive class-count explorer: the slider moves the Age cutoff
# between 0 and the largest age present in the data.
interactive(
    f,
    col='Age',
    x=widgets.IntSlider(min=0, max=df['Age'].max()),
)

interactive(children=(IntSlider(value=0, description='x', max=58), Text(value='Age', description='col'), Outpu…

In [9]:
def compute_node_entropy(n1,n2):
    """Return the base-2 entropy of a two-class node.

    Parameters
    ----------
    n1, n2 : int or None
        Number of observations of each class in the node. ``None`` (what
        ``Series.get`` returns for a label that is absent) is treated as 0.

    Returns
    -------
    float
        ``-(p1*log2(p1) + p2*log2(p2))`` with ``p_i = n_i / (n1 + n2)``.
        A class with zero observations contributes 0 (the limit of
        ``p*log2(p)`` as ``p -> 0``), so a pure node has entropy 0 rather
        than NaN. A node with no observations at all also returns 0.0.
    """
    counts = [0 if n is None else n for n in (n1, n2)]
    total = sum(counts)
    if total == 0:
        # Nothing to measure; avoids a 0/0 division.
        return 0.0
    entropy = 0.0
    for n in counts:
        if n > 0:
            # Skipping empty classes keeps log2(0) = -inf out of the sum.
            p = n/total
            entropy -= p*np.log2(p)
    return entropy
def compute_node_metrics_right(col,cutoff):
    """Count the "P" and "U" rows of the module-level ``df`` where ``df[col] > cutoff``.

    Parameters
    ----------
    col : str
        Column of ``df`` to compare against the cutoff.
    cutoff : number
        Threshold; only rows strictly above it are counted.

    Returns
    -------
    tuple
        ``(p, u)``: the number of 'Profitability' values equal to "P" and to
        "U". A label that does not occur in the selection is reported as 0
        (not ``None``), so the counts can always be added and divided.
    """
    # Count once and look both labels up in the same result.
    counts = df.loc[df[col]>cutoff, 'Profitability'].value_counts(normalize=False)
    return (counts.get("P", 0), counts.get("U", 0))
def compute_node_metrics_left(col,cutoff):
    """Count the "P" and "U" rows of the module-level ``df`` where ``df[col] <= cutoff``.

    Parameters
    ----------
    col : str
        Column of ``df`` to compare against the cutoff.
    cutoff : number
        Threshold; only rows at or below it are counted.

    Returns
    -------
    tuple
        ``(p, u)``: the number of 'Profitability' values equal to "P" and to
        "U". A label that does not occur in the selection is reported as 0
        (not ``None``), so the counts can always be added and divided.
    """
    # Count once and look both labels up in the same result.
    counts = df.loc[df[col]<=cutoff, 'Profitability'].value_counts(normalize=False)
    return (counts.get("P", 0), counts.get("U", 0))
def plot_entropy(x,col):
    """Plot both sides of a threshold split and annotate them with the split's entropy.

    The module-level ``df`` is partitioned into ``df[col] > x`` (left panel)
    and ``df[col] <= x`` (right panel), the same partition that
    ``compute_node_metrics_right`` / ``compute_node_metrics_left`` count, so
    the bars and the numbers in the titles describe the same rows.

    Each title shows the size-weighted entropy of the split (identical in
    both panels) and the share of "P" rows in that panel.

    Parameters
    ----------
    x : number
        Cutoff that ``df[col]`` is compared against.
    col : str
        Name of the column to split on.

    Returns
    -------
    The return value of ``plt.show()``.

    Notes
    -----
    If the split cannot be drawn or scored (unknown column, non-numeric
    column, a side with no rows), "Can't compute" plus the reason is
    printed instead of raising.
    """
    cutoff = x
    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))
    try:
        df[df[col]>cutoff]['Profitability'].value_counts(normalize=False).plot(kind='bar',ax=axes[0])
        df[df[col]<=cutoff]['Profitability'].value_counts(normalize=False).plot(kind='bar',ax=axes[1])

        # A label missing from one side may come back as None; count it as 0.
        n1,n2 = (0 if n is None else n for n in compute_node_metrics_right(col,cutoff))
        tot1 = n1+n2
        p11 = round(n1/(tot1),2)
        # A node holding a single class is pure: its entropy is 0 by definition,
        # so log2(0) is never evaluated for it.
        e1 = compute_node_entropy(n1,n2) if (n1 and n2) else 0.0

        n1,n2 = (0 if n is None else n for n in compute_node_metrics_left(col,cutoff))
        tot2 = n1+n2
        p12 = round(n1/(tot2),2)
        e2 = compute_node_entropy(n1,n2) if (n1 and n2) else 0.0

        # Entropy of the split: child entropies weighted by child size.
        entropy = (tot1/(tot1+tot2))*e1+((tot2)/(tot1+tot2))*e2
        axes[0].set_title(f'{col}>{cutoff}, Entropy: {round(entropy,4)}, Pr(P)={p11}')
        axes[1].set_title(f'{col}<={cutoff}, Entropy: {round(entropy,4)}, Pr(P)={p12}')
        plt.tight_layout()
    except Exception as exc:
        # Exception (not a bare except) lets KeyboardInterrupt/SystemExit propagate.
        print(f"Can't compute: {exc}")
    return plt.show()
        

In [10]:
# Interactive entropy explorer: the slider moves the Age cutoff
# between 0 and the largest age present in the data.
interact(
    plot_entropy,
    col='Age',
    x=widgets.IntSlider(min=0, max=df['Age'].max()),
)

interactive(children=(IntSlider(value=0, description='x', max=58), Text(value='Age', description='col'), Outpu…

<function __main__.plot_entropy(x, col)>

### Regression Example

In [9]:
from io import StringIO

# Small tab-separated table used for the regression example, with columns
# Country, Rim, Tires, Type and Price. The tabs are spelled out as "\t" so
# the separators survive editors that convert tabs to spaces.
reg = pd.read_table(
    StringIO(
        "Country\tRim\tTires\tType\tPrice\n"
        "Japan\tR14\t195/60\tSmall\t11.95\n"
        "Japan\tR15\t205/60\tMedium\t24.76\n"
        "Germany\tR15\t205/60\tMedium\t26.9\n"
        "Germany\tR14\t175/70\tCompact\t18.9\n"
        "Germany\tR14\t195/65\tCompact\t24.65\n"
        "Germany\tR15\t225/60\tMedium\t33.2\n"
        "USA\tR14\t185/75\tMedium\t13.15\n"
        "USA\tR14\t205/75\tLarge\t20.225\n"
        "USA\tR14\t205/75\tLarge\t16.145\n"
        "USA\tR15\t205/70\tMedium\t23.04"
    ),
    sep="\t",
)

In [10]:
# Display the parsed table (ten rows).
reg

Unnamed: 0,Country,Rim,Tires,Type,Price
0,Japan,R14,195/60,Small,11.95
1,Japan,R15,205/60,Medium,24.76
2,Germany,R15,205/60,Medium,26.9
3,Germany,R14,175/70,Compact,18.9
4,Germany,R14,195/65,Compact,24.65
5,Germany,R15,225/60,Medium,33.2
6,USA,R14,185/75,Medium,13.15
7,USA,R14,205/75,Large,20.225
8,USA,R14,205/75,Large,16.145
9,USA,R15,205/70,Medium,23.04


In [11]:
from pandas.plotting import table    

In [12]:
def split(x,col):
    """Show the two sides of an equality split of ``reg`` with their average price.

    The module-level ``reg`` table is divided into the rows where
    ``reg[col] == x`` (left panel) and the rows where ``reg[col] != x``
    (right panel). Each panel renders its rows as a table and is annotated
    with the mean of 'Price' on that side, rounded to 3 decimals.

    NOTE: a second function with the same name ``split`` is defined further
    down in this notebook; after that cell runs, ``split`` refers to it.

    Parameters
    ----------
    x : value compared against ``reg[col]``.
    col : str
        Column of ``reg`` to split on.

    Returns
    -------
    The return value of ``plt.show()``.
    """
    # Build both sides (and their averages) before any figure is created.
    sides = (('=', reg[reg[col]==x]), ('!=', reg[reg[col]!=x]))
    averages = [rows['Price'].mean().round(3) for _, rows in sides]

    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))
    for ax, (op, rows), avg in zip(axes, sides, averages):
        # Hide every axis decoration: the panel is only a canvas for the table.
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
        ax.set(frame_on=False)
        ax.table(cellText=rows.values, colLabels=rows.keys(), loc='top')
        # Anchor the caption at the first x tick and the second-to-last y tick.
        ax.text(
            x=ax.get_xticks()[0],
            y=ax.get_yticks()[-2],
            s=f'{col}{op}{x}: Avg Price: {round(avg,4)}',
        )
    plt.tight_layout()
    return plt.show()
    

In [13]:
# Drive `split` from a dropdown listing every distinct value of the chosen
# column; the first distinct value is pre-selected.
column = "Type"
default = reg[column].unique().tolist()[0]
x = interact(
    split,
    col=column,
    x=widgets.Dropdown(
        options=reg[column].unique().tolist(),
        value=default,
        description=column,
        disabled=False,
    ),
)

interactive(children=(Dropdown(description='Type', options=('Small', 'Medium', 'Compact', 'Large'), value='Sma…

## Regression Purity Metrics

In [14]:
def process_subset(t,col,x):
    """Split ``t`` on ``t[col] == x`` and score the split by weighted MSE.

    Each side is predicted by its own mean 'Price' (rounded to 3 decimals);
    the squared errors against that prediction give the side's MSE (rounded
    to 2 decimals), and the two MSEs are combined weighted by row count.

    Parameters
    ----------
    t : pandas.DataFrame
        Table containing ``col`` and a numeric 'Price' column. It is not
        modified: both returned frames are independent copies.
    col : str
        Column to split on.
    x : value compared against ``t[col]``.

    Returns
    -------
    tuple
        ``(t1, t2, mse, mse1, mse2, w1, w2)`` where

        * ``t1`` / ``t2`` are the rows with ``t[col] == x`` / ``t[col] != x``,
          reduced to ``[col, 'Price']`` plus the added columns 'Pred',
          'Error^2' (rounded to 3 decimals) and 'MSE';
        * ``mse1`` / ``mse2`` are the per-side MSEs (NaN for an empty side);
        * ``w1`` / ``w2`` are the per-side row counts;
        * ``mse`` is the row-count-weighted MSE over the non-empty sides
          (NaN when ``t`` has no rows).
    """
    # Explicit copies: the columns added below must never write into, or
    # trigger chained-assignment warnings about, slices of the caller's frame.
    t1 = t.loc[t[col]==x, [col,'Price']].copy()
    t2 = t.loc[t[col]!=x, [col,'Price']].copy()

    side_mse = []
    for side in (t1, t2):
        # Builtin round() is used because the mean of an empty side may be a
        # plain-float NaN, which has no .round() method.
        side['Pred'] = round(side['Price'].mean(), 3)
        side['Error^2'] = (side['Price']-side['Pred'])**2
        # The MSE is taken from the unrounded squared errors; the column is
        # rounded afterwards for display only.
        side_mse.append(round(side['Error^2'].mean(), 2))
        side['MSE'] = side_mse[-1]
        side['Error^2'] = side['Error^2'].round(3)
    mse1, mse2 = side_mse

    w1 = t1.shape[0]
    w2 = t2.shape[0]
    total = w1+w2
    if total == 0:
        mse = float('nan')
    else:
        # An empty side has weight 0 but a NaN MSE; leaving it out keeps the
        # NaN from poisoning the weighted sum.
        mse = sum((w/total)*m for w, m in ((w1, mse1), (w2, mse2)) if w)
    return t1,t2,mse,mse1,mse2,w1,w2

In [15]:
def split(x,col):
    """Show the two sides of an equality split of ``reg`` with their MSE metrics.

    Uses ``process_subset(reg, col, x)`` to divide the module-level ``reg``
    table into the rows where ``reg[col] == x`` (left panel) and the rows
    where ``reg[col] != x`` (right panel). Each panel renders its rows as a
    table and is captioned with that side's MSE, its row count (``W``) and
    the weighted MSE of the whole split.

    NOTE: this definition reuses the name of the earlier average-price
    ``split`` in this notebook; once this cell has run, ``split`` refers to
    this version.

    Parameters
    ----------
    x : value compared against ``reg[col]``.
    col : str
        Column of ``reg`` to split on.

    Returns
    -------
    The return value of ``plt.show()``.
    """
    t1,t2,mse,mse1,mse2,w1,w2=process_subset(reg,col,x)
    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))
    panels = (
        (axes[0], t1, '=', mse1, w1),
        (axes[1], t2, '!=', mse2, w2),
    )
    for ax, rows, op, side_mse, weight in panels:
        # Hide every axis decoration: the panel is only a canvas for the table.
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
        ax.set(frame_on=False)
        ax.table(cellText=rows.values, colLabels=rows.keys(), loc='top')
        # One caption template for both panels keeps the labels consistent
        # (the right-hand caption previously lacked the colon after "W").
        ax.text(
            x=ax.get_xticks()[0],
            y=ax.get_yticks()[-2],
            s=f'{col}{op}{x}: MSE: {round(side_mse,4)}, W: {weight}, Wtd MSE: {round(mse,2)}',
        )
    plt.tight_layout()
    return plt.show()

In [17]:
# Drive `split` from a dropdown listing every distinct value of the chosen
# column; the first distinct value is pre-selected.
column = "Country"
default = reg[column].unique().tolist()[0]
x = interact(
    split,
    col=column,
    x=widgets.Dropdown(
        options=reg[column].unique().tolist(),
        value=default,
        description=column,
        disabled=False,
    ),
)

interactive(children=(Dropdown(description='Country', options=('Japan', 'Germany', 'USA'), value='Japan'), Tex…