In [86]:
from __future__ import print_function

In [87]:
feature_names = ["height" , "hair-length" , "voice-pitch"]

In [88]:
training_data = [ 
    
      # height , hair-length , voice-pitch
      [180, 15,0,'man'],                                                              
      [167, 42,1,'woman'],                                                              
      [136, 35,1,'woman'],                                                              
      [174, 15,0,'man'],                                                              
      [141, 28,1,'woman']
]                                                              

# labels
Y = ['man', 'woman', 'woman', 'man', 'woman']

In [89]:
def gen_counts(rows):
    counts ={}
    
    for e in rows:
        
        e_label = e[-1]
        if e_label not in counts:
            counts[e_label] = 1
        else:
            counts[e_label] += 1
            
            
    return counts


In [90]:
def gini(data):
    """Calculate the gini index"""
    counts = gen_counts(data)
    impurity = 1
    for d in counts:
        prob = counts[d]/len(data)
        impurity -= prob**2
    return impurity
    

In [91]:
gini(Y)

0.0

In [92]:
gen_counts(training_data)

{'man': 2, 'woman': 3}

In [93]:
class Question:
    
    def __init__(self,column,value):
        self.column = column;
        self.value = value;
        
    def match(self,example):
        """Check if value matches"""
        val = example[self.column]
        return val == self.value
    
    def __repr__(self):
        """Compose a question"""
        return "Is %s <= %s "%(feature_names[self.column], str(self.value))
        
    

In [94]:
def partition(data,question):
    true_rows,false_rows = [],[]
    for row in data:
        if question.match(row):
            true_rows.append(row)
        else:
            false_rows.append(row)
            
    return true_rows,false_rows
    

In [95]:
# partition based on question " Is Voice Pitch equal to 0 ? "
partition(training_data,Question(2,0))

([[180, 15, 0, 'man'], [174, 15, 0, 'man']],
 [[167, 42, 1, 'woman'], [136, 35, 1, 'woman'], [141, 28, 1, 'woman']])

In [96]:
def information_gain(true_rows,false_rows,data_uncertainity):
    p = len(true_rows)/(len(true_rows)+len(false_rows))
    return data_uncertainity-(p*gini(true_rows))-((1-p)*gini(false_rows))
    

In [97]:
data_uncertainity = gini(training_data)


In [98]:
print(data_uncertainity)

0.48


In [99]:

t,f = partition(training_data,Question(0,180))
information_gain(t,f,data_uncertainity)

0.17999999999999994

In [100]:
def find_best_split(training_data):
    best_gain = 0
    best_question = None
    data_uncertainity = gini(training_data)
    num_features = len(training_data[0])-1
    
    for col in range(num_features):
        
        values = set([row[col] for row in training_data])
        
        for value in values:
            
            question = Question(col,value)
            
            t_rows,f_rows = partition(training_data,question)
            
            if len(t_rows)==0 or len(f_rows)==0:
                continue
            
            gain = information_gain(t_rows,f_rows,data_uncertainity)
            
            if gain > best_gain:
                best_gain = gain
                best_question = question
                
    return best_gain,best_question
        

In [101]:
find_best_split(training_data)

(0.48, Is hair-length <= 15 )

In [102]:
class Leaf:
    
    def __init__(self,rows):
        self.predictions = gen_counts(rows)

In [103]:
class Decision_Node:
    
    def __init__(self,question,true_branch,false_branch):
        self.question = question
        self.true_branch = true_branch
        self.false_branch = false_branch

In [104]:
def build_tree(training_data):
    
    best_gain , best_question = find_best_split(training_data)
   
    if best_gain == 0:
        return Leaf(training_data)
    
    true_rows , false_rows = partition(training_data,best_question)
    
    true_branch = build_tree(true_rows)
    
    false_branch = build_tree(false_rows)
    
    return Decision_Node(best_question,true_branch,false_branch)
    

In [105]:
def print_tree(node):
    if isinstance(node,Leaf):
        print(" " + "Predict", node.predictions)
        return
        
    print(" " + str(node.question))
    
    print(" " + "--> True")
    print_tree(node.true_branch)
    
    print(" " + "--> False")
    print_tree(node.false_branch)
    

In [106]:
tree = build_tree(training_data)

In [107]:
print_tree(tree)

 Is hair-length <= 15 
 --> True
 Predict {'man': 2}
 --> False
 Predict {'woman': 3}
