In [82]:
# Toy fruit dataset: each row is [color, diameter, label].
training_data = [
    ["Green", 3, "Apple"],
    ["Yellow", 3, "Apple"],
    ["Red", 1, "Grape"],
    ["Red", 1, "Grape"],
    ["Yellow", 3, "Lemon"],
]

In [83]:
# Human-readable column names, indexed in the same order as the row entries.
header = ["Color", "Diameter", "Label"]

In [84]:
def unique_vals(Data,col):
    """Return the set of distinct values found at index `col` across `Data`.

    `Data` is an iterable of indexable rows; each row is read at `col`.
    """
    return {record[col] for record in Data}

In [85]:
unique_vals(training_data,0)

{'Green', 'Red', 'Yellow'}

In [86]:
def class_counts(Data):
    """Tally how many rows carry each label.

    The label is read from the last element of every row. Returns a dict
    mapping label -> number of occurrences, with keys in first-seen order.
    """
    tally = {}
    for record in Data:
        key = record[-1]
        tally[key] = tally.get(key, 0) + 1
    return tally
       
class_counts(training_data)

{'Apple': 2, 'Grape': 2, 'Lemon': 1}

In [87]:
def is_numeric(value):
    """Return True when `value` is an int or a float, False otherwise.

    bool is a subclass of int, so True/False also count as numeric.
    """
    return isinstance(value, (int, float))
is_numeric(20)


True

In [88]:
class Question:
    """A yes/no test used to partition rows.

    A question stores a column index and a reference value; a row matches
    when its entry in that column equals the reference value.
    """

    def __init__(self,column,value):
        self.column=column  # index of the column to inspect, e.g. 1
        self.value=value    # value the column entry is compared against, e.g. 3

    def match(self,example):
        """Return True if `example[self.column]` equals the stored value."""
        # Numeric and non-numeric entries are both compared by equality, so no
        # type-based branching is needed here.
        return example[self.column] == self.value

    def __repr__(self):
        """Render the question in readable form, naming the column via the module-level `header`."""
        return "Is %s %s %s ?" %(header[self.column],"==",str(self.value))

   

In [89]:
# NOTE(review): the literal here is lowercase 'red' while training_data stores
# 'Red'; the comparison is case-sensitive, so this evaluates to False.
# Confirm whether 'Red' was intended.
q=Question(0,'red')
q.match(training_data[2])

False

In [142]:

def partition(rows,question):
    """Split `rows` in two according to `question`.

    Each row is passed once to `question.match`; rows for which the result is
    truthy go to the first list, the others to the second. Input order is kept
    within each list. Returns the pair (true_rows, false_rows).
    """
    matched = []
    unmatched = []
    for row in rows:
        bucket = matched if question.match(row) else unmatched
        bucket.append(row)
    return matched, unmatched


In [143]:
true_rows, false_rows = partition(training_data, Question(1,3))

In [144]:
false_rows

[['Red', 1, 'Grape'], ['Red', 1, 'Grape']]

In [145]:
def gini(rows):
    """Gini impurity of the labels in `rows`.

    Starts from 1 and subtracts the squared relative frequency of every label
    reported by `class_counts`. When there are no labels to subtract, the
    result is 1.
    """
    label_tally = class_counts(rows)
    remaining = 1
    for occurrences in label_tally.values():
        share = occurrences / float(len(rows))
        remaining -= share ** 2
    return remaining

In [146]:

gini(training_data)

0.6399999999999999

In [147]:

gini(training_data)

0.6399999999999999

In [148]:
def info_gain(left,right,current_uncertainty):
    """Information gain of a split.

    Computed as the parent's uncertainty minus the Gini impurity of each child,
    where each child is weighted by its share of the combined row count.
    """
    weight_left = float(len(left)) / (len(left) + len(right))
    weight_right = 1 - weight_left
    left_term = weight_left*gini(left)
    right_term = weight_right*gini(right)
    return current_uncertainty - left_term - right_term


In [149]:
true_rows,false_rows = partition(training_data,Question(1,3))
       

In [150]:
print(true_rows)
# NOTE(review): 0.639 is a hand-typed stand-in for the parent impurity; an
# earlier cell shows gini(training_data) as 0.6399999999999999, so the gain
# displayed here differs slightly from the value find_best_split reports for
# the same split.
info_gain(true_rows,false_rows,0.639)

[['Green', 3, 'Apple'], ['Yellow', 3, 'Apple'], ['Yellow', 3, 'Lemon']]


0.37233333333333335

In [151]:
def find_best_split(rows):
    """Find the question with the highest information gain for `rows`.

    Every (column, value) pair present in `rows` is tried as a Question; the
    last column is excluded because it holds the label. Splits that leave one
    side empty are skipped.

    Returns a tuple (best_gain, best_question). For empty `rows`, or when no
    usable split exists, the result is (0, None).
    """
    # With no rows there is nothing to split, and no first row from which to
    # read the number of feature columns.
    if len(rows) == 0:
        return 0, None

    best_gain = 0
    best_question = None
    current_uncertainty = gini(rows)

    n_features = len(rows[0]) - 1  # last column is the label
    for col in range(n_features):
        values = set(row[col] for row in rows)
        for val in values:
            question = Question(col, val)
            true_rows, false_rows = partition(rows, question)
            if len(true_rows) == 0 or len(false_rows) == 0:
                continue

            gain = info_gain(true_rows, false_rows, current_uncertainty)
            # `>=` means that on ties the most recently tried question wins.
            if gain >= best_gain:
                best_gain, best_question = gain, question
    return best_gain, best_question



In [152]:
best_gain, best_question = find_best_split(training_data)
print(best_question)
print(best_gain)

Is Diameter == 3 ?
0.37333333333333324


In [121]:
class Leaf:
    """Terminal node of the tree.

    Stores, in `prediction`, the label counts that `class_counts` reports for
    the rows handed to this node.
    """

    def __init__(self,rows):
        label_tally = class_counts(rows)
        self.prediction = label_tally
        

In [122]:
class Decision_Node:
    """An internal tree node that asks a question.

    Holds the question together with the two child nodes: the branch stored
    as `true_branch` and the branch stored as `false_branch`.
    """

    def __init__(self,question,true_branch,false_branch):
        self.question, self.true_branch, self.false_branch = (
            question, true_branch, false_branch)
        

In [159]:
def build_tree(rows):
    """Recursively grow a decision tree from `rows`.

    When the best split found has zero gain, a Leaf holding the rows is
    returned. Otherwise the rows are partitioned on the best question and a
    subtree is built for each side, true side first.
    """
    best_gain, best_question = find_best_split(rows)
    if best_gain != 0:
        matched, unmatched = partition(rows, best_question)
        matched_subtree = build_tree(matched)
        unmatched_subtree = build_tree(unmatched)
        return Decision_Node(best_question, matched_subtree, unmatched_subtree)
    return Leaf(rows)



In [163]:
# Build the tree from the full training set and print the root node object.
my_tree = build_tree(training_data)
print(my_tree)

<__main__.Decision_Node object at 0x000002AFEC0AA518>


In [181]:
def print_tree(node,spacing=""):
    """Print the tree rooted at `node`, one line per question or prediction.

    `spacing` is the prefix used for this level; each nested level is printed
    with one extra space.
    """
    if not isinstance(node,Leaf):
        # Internal node: show its question, then each branch one level deeper.
        print(spacing + str(node.question))
        deeper = spacing + " "
        for caption, branch_attr in (('---> True:', 'true_branch'),
                                     ('---> False:', 'false_branch')):
            print(spacing + caption)
            print_tree(getattr(node, branch_attr), deeper)
        return

    # Leaf: report the stored label counts.
    print(spacing + "predict", node.prediction)
    

In [182]:
print_tree(my_tree)

Is Diameter == 3 ?
---> True:
 Is Color == Yellow ?
 ---> True:
  predict {'Apple': 1, 'Lemon': 1}
 ---> False:
  predict {'Apple': 1}
---> False:
 predict {'Grape': 2}
