**Building a Binary Decision Tree from Data**

Suppose we have a dataset with some predictor variables $x_1,x_2,\ldots,x_k$ and binary response variable $Y.$ For example, for the mortgage dataset we have predictors (location, principal, interest rate, credit score) and we want to predict the result (default, non-default). Our datset is _flat_/_rectangular_, with N rows and $k+1$ columns, with one column for each variable and one row for each _observation_ (mortgage loan).

We wish to use these data to build a decision tree in which the functions at the nodes are functions of the the predictor variables. 

Assume $Y$ takes the value 0 or 1. 

The predictor variables can be categorical or continuous.

**Recursive Description of the Algorithm**

The algorithm for building the tree has a recursive definition.

We begin by creating a root node with the entire dataset as attached to that node.

We start at the root node.

Whenever we visit a node, we compute and store at the node the following information about the dataset attached to the node:

a) the number of observations in the dataset attached to that node, and 

b) the proportion of observations in each class (Y=0 or 1)

Next, we take one of the following actions:

1) Find a function (splitting function) that splits/partitions the data into two pieces. The two pieces should look different in the sense that one piece tends to have a different proportions of observations with Y=1, and the pieces are each sufficiently large. Call these pieces left piece and right piece. If such a _split_ can be found, we attach the splitting function to the node, spawn two children of the current node, attach piece \#1 to the left child  and piece \#2 to the right child, and visit each of those children. 

or

2) Determine that a splitting function cannot be found so the current node becomes a leaf node (no children). 

The splitting function should be a function of $x_1,x_2,...,x_k$ that returns a value of "left" or "right". Typically, this function is taken to be a function of only one of the $x_i$'s and 

a) for a continuous variable this is a function of the form:  if $xi < c$ return("left") else return("right)

b) for a categorical variable, this function takes the form: if $x \in I$ return("left") else return("right") where $I$ is a subset of the values that the variable can take.

**How to Classify/Predict $Y$ for a New Observation**

Given a new observation with predictor variables $x1,x2,\ldots,xk$ we start at the root node and for each node we visit, we do one of the following:

a) if the node has chilren, apply the splitting function at the current node to determine which child node to visit next, or

b) if the current node is a leaf node, return the proportion $p_1$ of observations with $Y=1$ at that node

Finally, we predict $Y=1$ if $p_1$ exceeds some pre-determined threshold.

**Splitting Criterion - how to find a good splitting function**

We need a criterion for deciding on a good splitting function. There are several possibilities. We focus here on the Gini index.

Given a categorical variable taking K possible values and a set of data for that variable with proportions $p_1,p_2,\ldots,p_K$ of values in each category we define the Gini index by

$$ G = \sum_{i=1}^k p_i(1-p_i)$$

This number has the following interpretation. If we pick a data point at random, and classify it as class 1 with probability $p_1,$ class 2 with probability $p_2,$ etc., $G$ is the probability of incorrectly classifying that observation.

$G$ is a measure of _impurity_ of the dataset with regard to the class variable, if one of the $p_i$ is one and the others are zero (perfect _purity_) we get $G=0.$ In in the case of a binary class variable, with $p_1=p_2=1/2$ we get $G =1/2.$ 

When we split out dataset into two pieces, we would like the two child datasets to be as pure as possible so we try to minimize the quantity

$$ N_{left} G_{left} + N_{right} G_{right}$$

that is, the weighted sum of the impurities of the child datasets weighted by the number of observations in the datasets.

It is typical to require for splitting that the size of each child dataaet be above some pre-determined threshold.

In [1]:
import pandas as pd
mdata=pd.read_csv("mortgage_data.csv")
print(type(mdata))
mdata.head()

<class 'pandas.core.frame.DataFrame'>


Unnamed: 0,location,princ,irate,cscore,result
0,suburban,358,7.0,728,default
1,suburban,637,7.25,675,default
2,suburban,303,7.25,645,non-default
3,suburban,397,7.25,609,non-default
4,suburban,420,7.75,669,default


In [2]:
mdata.tail()

Unnamed: 0,location,princ,irate,cscore,result
9859,suburban,769,7.75,586,non-default
9860,suburban,451,7.25,684,non-default
9861,suburban,410,7.0,702,non-default
9862,suburban,851,7.0,774,non-default
9863,suburban,260,7.5,657,default


In [3]:
mdata.shape

(9864, 5)

In [4]:
#
# Create a Y variable - Y=1 for default Y=0 for non-default
#
def f(row):
    if row["result"]=="default":
        return(1)
    else:
        return(0)
mdata["Y"]=mdata.apply(f,axis=1)

In [5]:
mdata["Y"][3]

0

In [6]:
mdata.head()

Unnamed: 0,location,princ,irate,cscore,result,Y
0,suburban,358,7.0,728,default,1
1,suburban,637,7.25,675,default,1
2,suburban,303,7.25,645,non-default,0
3,suburban,397,7.25,609,non-default,0
4,suburban,420,7.75,669,default,1


In [7]:
mdata["location"].value_counts()

suburban    5347
urban       2423
rural       2094
Name: location, dtype: int64

**Evaluate quality of a split**

Let's write code to evaluate quality of an example of a splitting function.

That code should take a pandas data frame and a function as arguments.

If a split would produce nodes with sizes below some threshold, we return a value so large that it can't reduce Gini coefficient.

**Example of a splitting function**

Here, we classify a row (an observation) according to whether the **cscore** for that row exceeds some threshold.

In [8]:
def f(row):
    if row["cscore"]>620:
        return("left")
    else:
        return("right")
    

**Gini criterion**

The following function takes as an argument a function, a data frame, and a minimum node size, and calculates the Gini criterion.

If splitting produces a node that has too few observations (less than min_node_size) we return a large value so that we'll not choose this splitting function. 

In [None]:
def Gini_criterion(df,f,min_node_size):
    #
    # calculate f(row) for every row in the data frame
    # this produces a Pandas series
    #
    fvalue=df.apply(f,axis=1)
    #
    # get the series of Y's for which fvalue is "left" 
    # and the series of Y's for whcih fvalue is "right"
    #
    Yleft=df["Y"].loc[fvalue=="left"]
    Yright=df["Y"].loc[fvalue=="right"]
    #
    # compute number of obs in each side
    #
    nleft=Yleft.size
    nright=Yright.size
    #
    # if split puts too few values in a node
    # we return a value that makes it so we'd never choose this f
    #
    if nleft<min_node_size or nright<min_node_size:
        return(nleft+nright)
    
    p1left=Yleft.loc[Yleft==1].size/nleft
    p1right=Yright.loc[Yright==1].size/nright
    #
    # compute the Gini coefficient
    #
    Gini=Yleft.size*p1left*(1-p1left)+Yright.size*p1right*(1-p1right)
    return(Gini)

In [None]:
ginivalue=Gini_criterion(mdata,f,100)
print(ginivalue)

**Another example**

In [None]:
def f(row):
    if row["irate"]>7:
        return("left")
    else:
        return("right")

In [None]:
ginivalue=Gini_criterion(mdata,f,100)
print(ginivalue)

**Goal**

We want our children to be as pure as possible, and we see that Gini impurity is lower for this splitting function than the one above, so this one would be preferred. We can ask for the best possible split based on a continuous variable or a categorical variable.

For a continuous variable v we could try every possible split of the form: $v<c$ vs. $v>c$ but that might take too long to compute. Instead we try only using some quantiles  for that variable. Below, quartiles are used, but there are other options, e.g. deciles, percentiles.

If a split would produce a node with too few values, we return a huge gini value (one that can't be smaller than the current one)

In [None]:
def find_best_splitting_function_continuous_variable(data,vname,min_node_size):
    # to use e.g. deciles we change the two 4's here to 10's
    qvalues=[data[vname].quantile(i/4) for i in range(1,4)]
    minginivalue=mdata.shape[0] # Gini can't be this big
    for qvalue in qvalues:
        def f(row):
            if row[vname]<qvalue:
                return("left")
            else:
                return("right")
        ginivalue=Gini_criterion(data,f,min_node_size)
        if ginivalue<minginivalue:
            bestf=f
            bestvalue=qvalue
            minginivalue=ginivalue
    #
    # return the best function, the value and its gini value
    #
    return(bestf,bestvalue,minginivalue)  

In [None]:
f,v,g=find_best_splitting_function_continuous_variable(mdata,"cscore",100)
print(v)
print(g)
f,v,g=find_best_splitting_function_continuous_variable(mdata,"irate",100)
print(v)
print(g)


**Splits for a categorical variable**

We need a function to try all splits of a categorical variable v taking values in a set say S={1,2,3,...,K}

Here we try splitting on a given subset T of S -sending those observations with values of v in T to the left and the others to the right.

We can then iterate over all nonempty subsets T to find the minimizer of the Gini criterion - Note that this code is less than optimal because every set is tested twice - once when we send observations with values in T to the left and again when we send all observations in the complement of T to the left.

The itertools package is handy for getting all combinations of elements in a list of some size.


In [None]:
import itertools as it
L=list(it.combinations([1,2,3],2))
print(L)

**Getting all subsets**

We need a list of all ways we can split a list of values into two nonempty pieces. This is straightforward if n, the size of the list is odd, we just need to make a list of all subsets of size 1,2,...,(n-1)/2. But if n is even we don't want to check each split of a subset of size n/2 twice (once for the subset and once for its complement).

In [None]:
def find_all_set_splits(value_list):
    splits=[]
    n=len(value_list)
    m=int(n/2)
    for sz in range(1,m+1):
        combs=it.combinations(value_list,sz)
        for comb in combs:
            splits.append(list(comb))
    if 2*m<n:
        return(splits)
    #
    # even case - need to add in subsets of size n/2
    #
    combs=it.combinations(value_list,m+1)
    svalue_list=set(value_list) # by the way - sets can't contain mutable elements!!!
    for comb in combs:
        s=set(comb)
        sc=svalue_list.difference(s)
        if s not in splits and svalue_list.difference(s):
            splits.append(list(s))
    return(splits)
    

In [None]:
find_all_set_splits(['dog',"cat","bird"])

In [None]:
find_all_set_splits(["dog","cat","bird","turtle"])

In [None]:
def find_best_splitting_function_categorical_variable(data,vname,min_node_size):
    values=list(data[vname].unique())
    nvalues=len(values)
    minginivalue=data.shape[0] # Gini can't be this big
    subset_list=find_all_set_splits(values)
    for subset in subset_list:
        def f(row):
            if row[vname] in subset:
                return("left")
            else:
                return("right")
        ginivalue=Gini_criterion(data,f,min_node_size)
        if ginivalue<minginivalue:
            bestf=f
            bestsubset=subset
            minginivalue=ginivalue
    return(bestf,bestsubset,minginivalue)  

In [None]:
find_best_splitting_function_categorical_variable(mdata,"location",100)

**Finding best split using all variables (continuous & categorical)**

Now we can try all continuous *and* categorical variables looking for the best split.

The following function takes a data set, a list of continuous variables, and a list of categorical variables as input and finds the best function to split the data on.

In [None]:
def find_best_split(data,cont_vars,cat_vars,min_node_size):
    minginivalue=data.shape[0]
    for catvar in cat_vars:
        f,b,g=find_best_splitting_function_categorical_variable(data,catvar,min_node_size)
        if g<minginivalue:
            minginivalue=g
            bestvar=catvar
            bestvartype="categorical"
            bestvalue=b
            bestf=f
    for contvar in cont_vars:
        f,b,g=find_best_splitting_function_continuous_variable(data,contvar,min_node_size)
        if g<minginivalue:
            minginivalue=g
            bestvar=contvar
            bestvartype="continuous"
            bestvalue=b
            bestf=f
    return bestf,bestvar,bestvartype,bestvalue,minginivalue
find_best_split(mdata,["irate","cscore","princ"],["location"],100)

In [9]:
**Build tree recursively*

We need a function that builds a tree by starting at root and recursively splitting each node until a stopping rule kicks in.

To keep things simple, we'll stop splitting if a node has fewer than 25 observations.

Each time we split, we attach a data frame data["df"] to each new node.

As we go along, we'll attach the counts of Y=0 and Y=1 to each node.

In [10]:
import numpy as np
import pandas as pd
import itertools as it
def Gini_criterion(df,f,min_node_size):
    #
    # calculate f(row) for every row in the data frame
    # this produces a Pandas series
    #
    fvalue=df.apply(f,axis=1)
    #
    # get the series of Y's for which fvalue is "left" 
    # and the series of Y's for whcih fvalue is "right"
    #
    Yleft=df["Y"].loc[fvalue=="left"]
    Yright=df["Y"].loc[fvalue=="right"]
    #
    # compute number of obs in each side
    #
    nleft=Yleft.size
    nright=Yright.size
    #
    # if split puts too few values in a node
    # we return a value that makes it so we'd never choose this f
    #
    if nleft<min_node_size or nright<min_node_size:
        return(nleft+nright)
    
    p1left=Yleft.loc[Yleft==1].size/nleft
    p1right=Yright.loc[Yright==1].size/nright
    #
    # compute the Gini coefficient
    #
    Gini=Yleft.size*p1left*(1-p1left)+Yright.size*p1right*(1-p1right)
    return(Gini)
def find_best_splitting_function_continuous_variable(data,vname,min_node_size):
    qvalues=[data[vname].quantile(i/4) for i in range(1,4)]
    minginivalue=mdata.shape[0] # Gini can't be this big
    bestf=None
    bestvalue=None
    for qvalue in qvalues:
        def f(row):
            if row[vname]<qvalue:
                return("left")
            else:
                return("right")
        ginivalue=Gini_criterion(data,f,min_node_size)
        if ginivalue<minginivalue:
            bestf=f
            bestvalue=qvalue
            minginivalue=ginivalue
    #
    # return the best function, the value and its gini value
    #
    return(bestf,bestvalue,minginivalue)  


def find_all_set_splits(value_list):
    splits=[]
    n=len(value_list)
    m=int(n/2)
    for sz in range(1,m+1):
        combs=it.combinations(value_list,sz)
        for comb in combs:
            splits.append(list(comb))
    if 2*m<n:
        return(splits)
    #
    # even case - need to add in subsets of size n/2
    #
    combs=it.combinations(value_list,m+1)
    svalue_list=set(value_list) # by the way - sets can't contain mutable elements!!!
    for comb in combs:
        s=set(comb)
        sc=svalue_list.difference(s)
        if s not in splits and svalue_list.difference(s):
            splits.append(list(s))
    return(splits)
    
def find_best_splitting_function_categorical_variable(data,vname,min_node_size):
    values=list(data[vname].unique())
    nvalues=len(values)
    minginivalue=data.shape[0] # Gini can't be this big
    subset_list=find_all_set_splits(values)
    bestf=None
    bestsubset=None
    for subset in subset_list:
        def f(row):
            if row[vname] in subset:
                return("left")
            else:
                return("right")
        ginivalue=Gini_criterion(data,f,min_node_size)
        if ginivalue<minginivalue:
            bestf=f
            bestsubset=subset
            minginivalue=ginivalue
    return(bestf,bestsubset,minginivalue)  

def find_best_split(data,cont_vars,cat_vars,min_node_size):
    minginivalue=data.shape[0]
    bestf=None
    bestvar=None
    bestvartype=None
    bestvalue=None
    for catvar in cat_vars:
        f,b,g=find_best_splitting_function_categorical_variable(data,catvar,min_node_size)
        if g<minginivalue:
            minginivalue=g
            bestvar=catvar
            bestvartype="categorical"
            bestvalue=b
            bestf=f
    for contvar in cont_vars:
        f,b,g=find_best_splitting_function_continuous_variable(data,contvar,min_node_size)
        if g<minginivalue:
            minginivalue=g
            bestvar=contvar
            bestvartype="continuous"
            bestvalue=b
            bestf=f
    return bestf,bestvar,bestvartype,bestvalue,minginivalue

class node:
    __slots__=('parent','left_child','right_child','data')
    #
    # We instantiate a node by passing a parent (which can be None) 
    # and a dictionary
    #
    def __init__(self,parent,data):
        if parent==None:
            # making this a root node
            self.data=data
            self.data["depth"]=0
            self.parent=None
        else:
            # making this a non-root node
            self.data=data
            self.data["depth"]=parent.data["depth"]+1
            self.parent=parent
        self.left_child=None
        self.right_child=None
    def get_parent(self): # return the node's parent
        return(self.parent)
    def get_data(self):   # return the node's data
        return(self.data)
    def get_depth(self):  # return the node's depth
        return(self.data["depth"])
    def get_label(self):
        return(self.data["label"])
    def set_label(self,label):
        self.data["label"]=label
    def get_left_child(self):
        return(self.left_child)
    def get_right_child(self):
        return(self.right_child)
    def spawn_left_child(self,data):
        # create a new node n with self as parent w/ given data
        n=node(parent=self,data=data)
        #n.data=data
        n.data["depth"]=self.data["depth"]+1
        self.left_child=n
        return(n)
    def spawn_right_child(self,data):
        n=node(parent=self,data=data)
        n.data=data
        n.data["depth"]=self.data["depth"]+1
        self.right_child=n
        return(n)
    #
    # string consisting of information about node
    #
    def __str__(self):
        s="node label = "+self.data["label"]+"\n"
        if self.parent==None:
            s+="   no parent i.e. root node\n"
        else:
            s+="   parent label = " + self.parent.data["label"]+"\n"
        if self.left_child==None:
            s+="   no left child\n"
        else:
            s+="   left child label " + self.left_child.data["label"]+"\n"
        if self.right_child==None:
            s+="   no right child\n"
        else:
            s+="   right child label " + self.right_child.data["label"]+"\n"
        return(s)
    def treestr(self):
        d=self.data
        depth=d["depth"]
        G=d["gini"]
        Gstring="G: {:8.2f} ".format(G)
        Y0=d["Ycts"][0]
        Y1=d["Ycts"][1]
        Ycts_string="N: "+str(Y0+Y1)+" N0: "+str(Y0)+" "+" N1:"+str(Y1)+"\n"
        p0=Y0/(Y0+Y1)
        p1=Y1/(Y0+Y1)
        pstring="p0: {:5.4f} p1: {:5.4f}\n".format(p0,p1)
        spaces="".join(["  " for i in range(depth)])
        s=spaces+d["label"]+"\n"
        s+=spaces+Gstring+Ycts_string
        s+=spaces+pstring
        #
        # if this node has a split, include info about it
        #
        if "splitinfo" in self.data:
            splitinfo=self.data["splitinfo"]
        
        
        
        if self.left_child!=None:
            s+=self.left_child.treestr()
            s+=self.right_child.treestr()
        return(s)
    def treeprint(self):
        s=self.treestr()
        print(s)

**Recursive split node function**

In [11]:
def split_node(cnode,contvars,catvars,min_node_size):  
    cdf=cnode.data["df"]
    
    # compute Y counts in this node and store them
    N0=np.sum(1-cdf["Y"])
    N1=np.sum(cdf["Y"])
    cnode.data["Ycts"]=[N0,N1]
    
    #
    # Gini for a node is N*p(1-p) where p is prop of 1's
    # so this equalis (N0+N1)*(N0/(N0+N1)))*(N1/(N0+N1)) = N0*N1/(N0+N1)
    #
    cnode.data["gini"]=N0*N1/(N0+N1)
    
    if cnode.data["df"].shape[0]>=min_node_size:
        print("new node to try splitting: "+cnode.data["label"]+" size= "+str(cnode.data["df"].shape[0]))
        
        # find best split
        f,v,vtype,value,g=find_best_split(cnode.data["df"],contvars,catvars,min_node_size)
        
        #
        # if the split leads to a bigger gini, we don't split the node
        # so compare to gini at current node
        #
        if g>=cnode.data["gini"]:
            print("node is not split since gini not reduced")
        else:
            
            # determine which rows of current data frame go left and which go right
            child_assignment=cnode.data["df"].apply(f,axis=1)
        
            # compute counts of child nodes if we split
            nleft=np.sum(child_assignment=="left")
            nright=np.sum(child_assignment=="right")
            
            if nleft<min_node_size or nright<min_node_size:
                print("node is not split because of minimum node size constraint")
                
            else:
                
                # attach splitting function to data at this node
                splitinfo={"f":f, "vname": v, "vtype":vtype, "value":value}
                cnode.data["splitinfo"]=splitinfo
                     
                print("splitting node into sizes "+str(nleft)+" "+str(nright))
                # compute data frames to put at child nodes
                dfleft=cnode.data["df"].loc[child_assignment=="left"].copy()
                dfright=cnode.data["df"].loc[child_assignment=="right"].copy()
       
                # replace data frame indices by range
                dfleft.index=range(dfleft.shape[0])
                dfright.index=range(dfright.shape[0])
    
                # create a label 
                dataleft={"df":dfleft,"label":cnode.data["label"]+"L"}
                dataright={"df":dfright,"label":cnode.data["label"]+"R"}
                        
                # create child nodes 
                left_child=cnode.spawn_left_child(dataleft)
                right_child=cnode.spawn_right_child(dataright)
               
            
                # split child nodes
                split_node(left_child,contvars,catvars,min_node_size)
                split_node(right_child,contvars,catvars,min_node_size)

mdata=pd.read_csv("mortgage_data.csv")
#
# Create a Y variable - Y=1 for default Y=0 for non-default
#
def f(row):
    if row["result"]=="default":
        return(1)
    else:
        return(0)
mdata["Y"]=mdata.apply(f,axis=1)
rootnode=node(None,{"df":mdata,"label":""})
split_node(rootnode,["irate","cscore","princ"],["location"],500)

new node to try splitting:  size= 9864
splitting node into sizes 6781 3083
new node to try splitting: L size= 6781
splitting node into sizes 2981 3800
new node to try splitting: LL size= 2981
splitting node into sizes 627 2354
new node to try splitting: LLL size= 627
node is not split since gini not reduced
new node to try splitting: LLR size= 2354
splitting node into sizes 637 1717
new node to try splitting: LLRL size= 637
node is not split since gini not reduced
new node to try splitting: LLRR size= 1717
node is not split because of minimum node size constraint
new node to try splitting: LR size= 3800
splitting node into sizes 2850 950
new node to try splitting: LRL size= 2850
splitting node into sizes 2133 717
new node to try splitting: LRLL size= 2133
splitting node into sizes 661 1472
new node to try splitting: LRLLL size= 661
node is not split since gini not reduced
new node to try splitting: LRLLR size= 1472
node is not split because of minimum node size constraint
new node to t

In [None]:
rootnode.treeprint()

**Make the split node function a class method**

That function has been renamed to build_tree.

In [None]:
import numpy as np
import pandas as pd
import itertools as it
def Gini_criterion(df,f,min_node_size):
    #
    # calculate f(row) for every row in the data frame
    # this produces a Pandas series
    #
    fvalue=df.apply(f,axis=1)
    #
    # get the series of Y's for which fvalue is "left" 
    # and the series of Y's for whcih fvalue is "right"
    #
    Yleft=df["Y"].loc[fvalue=="left"]
    Yright=df["Y"].loc[fvalue=="right"]
    #
    # compute number of obs in each side
    #
    nleft=Yleft.size
    nright=Yright.size
    #
    # if split puts too few values in a node
    # we return a value that makes it so we'd never choose this f
    #
    if nleft<min_node_size or nright<min_node_size:
        return(nleft+nright)
    
    p1left=Yleft.loc[Yleft==1].size/nleft
    p1right=Yright.loc[Yright==1].size/nright
    #
    # compute the Gini coefficient
    #
    Gini=Yleft.size*p1left*(1-p1left)+Yright.size*p1right*(1-p1right)
    return(Gini)
def find_best_splitting_function_continuous_variable(data,vname,min_node_size):
    qvalues=[data[vname].quantile(i/4) for i in range(1,4)]
    minginivalue=mdata.shape[0] # Gini can't be this big
    bestf=None
    bestvalue=None
    for qvalue in qvalues:
        def f(row):
            if row[vname]<qvalue:
                return("left")
            else:
                return("right")
        ginivalue=Gini_criterion(data,f,min_node_size)
        if ginivalue<minginivalue:
            bestf=f
            bestvalue=qvalue
            minginivalue=ginivalue
    #
    # return the best function, the value and its gini value
    #
    return(bestf,bestvalue,minginivalue)  


def find_all_set_splits(value_list):
    splits=[]
    n=len(value_list)
    m=int(n/2)
    for sz in range(1,m+1):
        combs=it.combinations(value_list,sz)
        for comb in combs:
            splits.append(list(comb))
    if 2*m<n:
        return(splits)
    #
    # even case - need to add in subsets of size n/2
    #
    combs=it.combinations(value_list,m+1)
    svalue_list=set(value_list) # by the way - sets can't contain mutable elements!!!
    for comb in combs:
        s=set(comb)
        sc=svalue_list.difference(s)
        if s not in splits and svalue_list.difference(s):
            splits.append(list(s))
    return(splits)
    
def find_best_splitting_function_categorical_variable(data,vname,min_node_size):
    values=list(data[vname].unique())
    nvalues=len(values)
    minginivalue=data.shape[0] # Gini can't be this big
    subset_list=find_all_set_splits(values)
    bestf=None
    bestsubset=None
    for subset in subset_list:
        def f(row):
            if row[vname] in subset:
                return("left")
            else:
                return("right")
        ginivalue=Gini_criterion(data,f,min_node_size)
        if ginivalue<minginivalue:
            bestf=f
            bestsubset=subset
            minginivalue=ginivalue
    return(bestf,bestsubset,minginivalue)  

def find_best_split(data,cont_vars,cat_vars,min_node_size):
    minginivalue=data.shape[0]
    bestf=None
    bestvar=None
    bestvartype=None
    bestvalue=None
    for catvar in cat_vars:
        f,b,g=find_best_splitting_function_categorical_variable(data,catvar,min_node_size)
        if g<minginivalue:
            minginivalue=g
            bestvar=catvar
            bestvartype="categorical"
            bestvalue=b
            bestf=f
    for contvar in cont_vars:
        f,b,g=find_best_splitting_function_continuous_variable(data,contvar,min_node_size)
        if g<minginivalue:
            minginivalue=g
            bestvar=contvar
            bestvartype="continuous"
            bestvalue=b
            bestf=f
    return bestf,bestvar,bestvartype,bestvalue,minginivalue

class node:
    __slots__=('parent','left_child','right_child','data')
    #
    # We instantiate a node by passing a parent (which can be None) 
    # and a dictionary
    #
    def __init__(self,parent,data):
        if parent==None:
            # making this a root node
            self.data=data
            self.data["depth"]=0
            self.parent=None
        else:
            # making this a non-root node
            self.data=data
            self.data["depth"]=parent.data["depth"]+1
            self.parent=parent
        self.left_child=None
        self.right_child=None
    def get_parent(self): # return the node's parent
        return(self.parent)
    def get_data(self):   # return the node's data
        return(self.data)
    def get_depth(self):  # return the node's depth
        return(self.data["depth"])
    def get_label(self):
        return(self.data["label"])
    def set_label(self,label):
        self.data["label"]=label
    def get_left_child(self):
        return(self.left_child)
    def get_right_child(self):
        return(self.right_child)
    def spawn_left_child(self,data):
        # create a new node n with self as parent w/ given data
        n=node(parent=self,data=data)
        #n.data=data
        n.data["depth"]=self.data["depth"]+1
        self.left_child=n
        return(n)
    def spawn_right_child(self,data):
        n=node(parent=self,data=data)
        n.data=data
        n.data["depth"]=self.data["depth"]+1
        self.right_child=n
        return(n)
    #
    # string consisting of information about node
    #
    def __str__(self):
        s="node label = "+self.data["label"]+"\n"
        if self.parent==None:
            s+="   no parent i.e. root node\n"
        else:
            s+="   parent label = " + self.parent.data["label"]+"\n"
        if self.left_child==None:
            s+="   no left child\n"
        else:
            s+="   left child label " + self.left_child.data["label"]+"\n"
        if self.right_child==None:
            s+="   no right child\n"
        else:
            s+="   right child label " + self.right_child.data["label"]+"\n"
        
        return(s)
class node:
    __slots__=('parent','left_child','right_child','data')
    #
    # We instantiate a node by passing a parent (which can be None) 
    # and a dictionary
    #
    def __init__(self,parent,data):
        if parent==None:
            # making this a root node
            self.data=data
            self.data["depth"]=0
            self.parent=None
        else:
            # making this a non-root node
            self.data=data
            self.data["depth"]=parent.data["depth"]+1
            self.parent=parent
        self.left_child=None
        self.right_child=None
    def get_parent(self): # return the node's parent
        return(self.parent)
    def get_data(self):   # return the node's data
        return(self.data)
    def get_depth(self):  # return the node's depth
        return(self.data["depth"])
    def get_label(self):
        return(self.data["label"])
    def set_label(self,label):
        self.data["label"]=label
    def get_left_child(self):
        return(self.left_child)
    def get_right_child(self):
        return(self.right_child)
    def spawn_left_child(self,data):
        # create a new node n with self as parent w/ given data
        n=node(parent=self,data=data)
        #n.data=data
        n.data["depth"]=self.data["depth"]+1
        self.left_child=n
        return(n)
    def spawn_right_child(self,data):
        n=node(parent=self,data=data)
        n.data=data
        n.data["depth"]=self.data["depth"]+1
        self.right_child=n
        return(n)
    #
    # string consisting of information about node
    #
    def __str__(self):
        s="node label = "+self.data["label"]+"\n"
        if self.parent==None:
            s+="   no parent i.e. root node\n"
        else:
            s+="   parent label = " + self.parent.data["label"]+"\n"
        if self.left_child==None:
            s+="   no left child\n"
        else:
            s+="   left child label " + self.left_child.data["label"]+"\n"
        if self.right_child==None:
            s+="   no right child\n"
        else:
            s+="   right child label " + self.right_child.data["label"]+"\n"
        return(s)
    def treestr(self):
        d=self.data
        depth=d["depth"]
        G=d["gini"]
        Gstring="G: {:8.2f} ".format(G)
        Y0=d["Ycts"][0]
        Y1=d["Ycts"][1]
        Ycts_string="N: "+str(Y0+Y1)+" N0: "+str(Y0)+" "+" N1:"+str(Y1)+"\n"
        p0=Y0/(Y0+Y1)
        p1=Y1/(Y0+Y1)
        pstring="p0: {:5.4f} p1: {:5.4f}\n".format(p0,p1)
        spaces="".join(["  " for i in range(depth)])
        s=spaces+d["label"]+"\n"
        s+=spaces+Gstring+Ycts_string
        s+=spaces+pstring
        #
        # if this node has a split, include info about it
        #
        if "splitinfo" in self.data:
            splitinfo=self.data["splitinfo"]
        
        
        
        if self.left_child!=None:
            s+=self.left_child.treestr()
            s+=self.right_child.treestr()
        return(s)
    def treeprint(self):
        s=self.treestr()
        print(s)
    def build_tree(self,contvars,catvars,min_node_size):  
        cdf=self.data["df"]
    
        # compute Y counts in this node and store them
        N0=np.sum(1-cdf["Y"])
        N1=np.sum(cdf["Y"])
        self.data["Ycts"]=[N0,N1]
    
        #
        # Gini for a node is N*p(1-p) where p is prop of 1's
        # so this equalis (N0+N1)*(N0/(N0+N1)))*(N1/(N0+N1)) = N0*N1/(N0+N1)
        #
        self.data["gini"]=N0*N1/(N0+N1)
    
        if self.data["df"].shape[0]>=min_node_size:
            print("new node to try splitting: "+self.data["label"]+" size= "+str(self.data["df"].shape[0]))
        
            # find best split
            f,v,vtype,value,g=find_best_split(self.data["df"],contvars,catvars,min_node_size)
        
            #
            # if the split leads to a bigger gini, we don't split the node
            # so compare to gini at current node
            #
            if g>=self.data["gini"]:
                print("node is not split since gini not reduced")
            else:
            
                # determine which rows of current data frame go left and which go right
                child_assignment=self.data["df"].apply(f,axis=1)
        
                # compute counts of child nodes if we split
                nleft=np.sum(child_assignment=="left")
                nright=np.sum(child_assignment=="right")
            
                if nleft<min_node_size or nright<min_node_size:
                    print("node is not split because of minimum node size constraint")
                
                else:
                
                    # attach splitting function to data at this node
                    splitinfo={"f":f, "vname": v, "vtype":vtype, "value":value}
                    self.data["splitinfo"]=splitinfo
                     
                    print("splitting node into sizes "+str(nleft)+" "+str(nright))
                    # compute data frames to put at child nodes
                    dfleft=self.data["df"].loc[child_assignment=="left"].copy()
                    dfright=self.data["df"].loc[child_assignment=="right"].copy()
       
                    # replace data frame indices by range
                    dfleft.index=range(dfleft.shape[0])
                    dfright.index=range(dfright.shape[0])
    
                    # create a label 
                    dataleft={"df":dfleft,"label":self.data["label"]+"L"}
                    dataright={"df":dfright,"label":self.data["label"]+"R"}
                        
                    # create child nodes 
                    left_child=self.spawn_left_child(dataleft)
                    right_child=self.spawn_right_child(dataright)
               
            
                    # split child nodes
                    left_child.build_tree(contvars,catvars,min_node_size)
                    right_child.build_tree(contvars,catvars,min_node_size)

**Test the method**

We create a root node, attach a data frame to it, call the build_tree method a this node.

In [None]:
mdata=pd.read_csv("mortgage_data.csv")
#
# Create a Y variable - Y=1 for default Y=0 for non-default
#
def f(row):
    if row["result"]=="default":
        return(1)
    else:
        return(0)
mdata["Y"]=mdata.apply(f,axis=1)
rootnode=node(None,{"df":mdata,"label":""})
rootnode.build_tree(["irate","cscore","princ"],["location"],500)

In [None]:
rootnode.treeprint()

**Classify an observation**

Finally, we need to classify an observation, which is another recursive function.

This is left as an exercise.