The code below was developed with reference to Josh Gordon's "Machine Learning Recipes" video lectures.

In [1]:
def split_data(root,condition):
    """Partition ``root`` using the boolean mask ``condition``.

    Returns a ``(true_rows, false_rows)`` pair: the rows of ``root`` selected
    by ``condition`` and the rows selected by its negation ``~condition``.
    """
    return root[condition], root[~condition]

In [2]:
def gini(root):
    """Return the Gini impurity of the last column of ``root``.

    The last column is treated as the label column. Starting from 1, the
    squared relative frequency of every distinct label is subtracted, so a
    frame whose labels are all identical scores 0. A frame with no rows
    returns the initial value 1, because there are no labels to subtract.
    """
    target=root[root.columns.values[-1]]
    frequency=target.value_counts()
    total=len(target)
    impurity=1
    # Walk the labels in order of first appearance (the order ``unique()``
    # yields) so the floating-point subtraction order stays fixed.
    for label in target.unique():
        share=frequency[label]/total
        impurity-=share**2
    return impurity

In [3]:
def calculate_gain(current_impurity,true_rows,false_rows):
    """Return the impurity reduction achieved by a split.

    The gain is ``current_impurity`` minus the Gini impurity of each
    partition, weighted by that partition's share of the combined row count.
    Raises ``ZeroDivisionError`` when both partitions are empty.
    """
    n_true=len(true_rows)
    n_total=n_true+len(false_rows)
    weight_true=float(n_true)/float(n_total)
    weighted_true=weight_true*gini(true_rows)
    weighted_false=(1-weight_true)*gini(false_rows)
    return current_impurity-weighted_true-weighted_false

In [4]:
def build_split_condition(root,val,root_column):
    """Build the boolean mask and readable text for one candidate split.

    The kind of test is chosen from the dtype of ``root[root_column]``
    rather than from the spelling of ``val``:

    * integer / float column -> ``column >= threshold``, where the threshold
      is ``val`` parsed as an int, or as a float when it is not a plain
      integer literal (e.g. ``"2.5"``, ``"-1"``, ``"1e3"``);
    * any other column -> ``column == val``.

    Deciding from the value's text alone mis-handles float and negative
    thresholds (they would be compared as strings against numbers and never
    match) and raises ``TypeError`` for text columns whose entries happen to
    consist of digits.

    Parameters:
        root: DataFrame being split.
        val: candidate value, normally already converted to ``str``.
        root_column: name of the column the split tests.

    Returns:
        ``(split_condition, split_readable)`` - a boolean Series aligned with
        ``root`` and a string such as ``"Diameter>=3"`` or ``"Color==Red"``.

    Raises:
        ValueError: if the column is numeric but ``val`` is not a number.
    """
    column=root[root_column]
    text=str(val)
    # numpy dtype kinds: 'i' signed integer, 'u' unsigned integer, 'f' float.
    if column.dtype.kind in ('i','u','f'):
        try:
            threshold=int(text)
        except ValueError:
            # Not an integer literal; fall back to a float threshold.
            threshold=float(text)
        split_readable=str(root_column)+'>='+text
        split_condition=column>=threshold
    else:
        split_readable=str(root_column)+'=='+text
        split_condition=column==val
    return split_condition, split_readable

In [5]:
def select_best_split(root):
    """Search every feature value for the split with the highest gain.

    All columns except the last (the label column) are treated as features.
    For each distinct value of each feature a candidate condition is built,
    the rows are partitioned, and the information gain is computed. The
    first candidate with the strictly largest gain wins.

    Returns:
        ``(best_gain, best_split, best_split_readable)``. When no candidate
        has a gain above 0, the result is ``(0, '', '')``.
    """
    best_gain=0
    best_split=''
    best_split_readable=''
    root_columns=root.columns.values
    features_len=len(root_columns)-1
    rows_count=root.shape[0]
    current_impurity=gini(root)
    for i in range(features_len):
        # A value seen earlier in this column produces the same condition
        # and therefore the same gain, which can never beat the strict '<'
        # test below. Skipping it avoids recomputing the split and leaves
        # the chosen split unchanged.
        tried=set()
        for j in range(rows_count):
            val=str(root.iloc[j,i])
            if val in tried:
                continue
            tried.add(val)
            split_condition, split_readable=build_split_condition(root,val,root_columns[i])
            true_rows,false_rows=split_data(root,split_condition)
            info_gain=calculate_gain(current_impurity,true_rows,false_rows)
            if best_gain<info_gain:
                best_gain=info_gain
                best_split=split_condition
                best_split_readable=split_readable
    return best_gain,best_split,best_split_readable

In [6]:
class Leaf:
    """Terminal tree node holding the label counts of the rows that reach it."""

    def __init__(self, rows):
        # The last column of ``rows`` is the label column. Its value counts
        # are stored as a one-column DataFrame indexed by label.
        label_column = rows.columns.values[-1]
        label_counts = rows[label_column].value_counts()
        self.predictions = pd.DataFrame(label_counts)

In [7]:
class node:
    """Internal decision node: a question plus the two subtrees it selects."""

    def __init__(self, question, true_branch, false_branch):
        # ``question`` is the split this node tests; ``true_branch`` and
        # ``false_branch`` are the subtrees for rows that do / do not
        # satisfy it.
        self.question = question
        self.true_branch = true_branch
        self.false_branch = false_branch

In [8]:
def build_tree(root):
    """Recursively grow a decision tree from the rows in ``root``.

    The best split for ``root`` is selected. If it yields no gain, the rows
    become a ``Leaf``. Otherwise the rows are partitioned and a ``node`` is
    returned whose branches are the trees built from each partition (true
    partition first, then false).
    """
    gain,mask,readable=select_best_split(root)
    if gain==0:
        # No split improves purity, so stop here.
        return Leaf(root)
    matched,unmatched=split_data(root,mask)
    subtree_true=build_tree(matched)
    subtree_false=build_tree(unmatched)
    return node(readable,subtree_true,subtree_false)

In [9]:
def print_tree(node):
    """Print the tree rooted at ``node`` to standard output.

    A ``Leaf`` prints ``Predict`` followed by its stored predictions. Any
    other node prints its question, then its true branch under a
    ``--> True:`` heading and its false branch under ``--> False:``.
    """
    if isinstance(node, Leaf):
        print("Predict", node.predictions)
        return

    print(str(node.question))

    # Same output order as the tree layout: true branch first, then false.
    for heading, subtree in (('--> True:', node.true_branch),
                             ('--> False:', node.false_branch)):
        print(heading)
        print_tree(subtree)

In [10]:
import pandas as pd
# Toy training set: one row per fruit as [Color, Diameter, Fruit]; the last
# column is the label the tree learns to predict.
data=[
    ['Yellow',3,'Apple'],
    ['Yellow',3,'Lemon'],
    ['Green',3,'Apple'],
    ['Red',1,'Grapes'],
    ['Red',1,'Grapes'],
    ['Black',2,'Jamun'],
    ['Black',2,'Jamun'],
]
train=pd.DataFrame(data,columns=['Color','Diameter','Fruit'])
print(train)

# Fit the decision tree on the whole training frame.
tree=build_tree(train)

# Print the tree
print("TREE:")
print_tree(tree)

    Color  Diameter   Fruit
0  Yellow         3   Apple
1  Yellow         3   Lemon
2   Green         3   Apple
3     Red         1  Grapes
4     Red         1  Grapes
5   Black         2   Jamun
6   Black         2   Jamun
TREE:
Color==Red
--> True:
Predict         Fruit
Grapes      2
--> False:
Color==Black
--> True:
Predict        Fruit
Jamun      2
--> False:
Color==Yellow
--> True:
Predict        Fruit
Lemon      1
Apple      1
--> False:
Predict        Fruit
Apple      1
