In [1]:
import numpy as np

In [2]:
# Toy dataset: every row lists two feature values followed by the label
# in the last position (the helpers in this notebook read it via row[-1]).
training_data = [
    ['Green', 3, 'Apple'],
    ['Yellow', 3, 'Apple'],
    ['Red', 1, 'Grape'],
    ['Red', 1, 'Grape'],
    ['Yellow', 3, 'Lemon'],
]

In [3]:
# Column names, positionally aligned with the rows of training_data.
# Question.__repr__ looks names up here to print readable questions.
header = ["color",'Diameter','Fruit']

In [4]:
def unique_vals(data,column):
    """Return the set of distinct values found at index `column` across the rows of `data`."""
    return {record[column] for record in data}
# Distinct values in column 0 of the training data.
unique_vals(training_data,0)

{'Green', 'Red', 'Yellow'}

In [5]:
def Counter(data):
    """Tally how often each label occurs in `data`.

    The label of a row is its last element. Returns a plain dict mapping
    label -> number of rows carrying it, with keys in first-seen order.
    """
    tally = {}
    for record in data:
        label = record[-1]
        tally[label] = tally.get(label, 0) + 1
    return tally
# Number of training rows per label (the last element of each row).
Counter(training_data)

{'Apple': 2, 'Grape': 2, 'Lemon': 1}

In [6]:
def is_numeric(value):
    """Tell whether `value` is an int or a float (bool counts too, being an int subclass)."""
    return isinstance(value, (int, float))

In [7]:
class Question:
    """A yes/no test on a single column of a row, used to split a dataset.

    `column` is the index of the inspected feature and `value` is what the
    feature is compared against: a threshold (>=) when numeric, otherwise a
    category (==).
    """

    def __init__(self,column,value):
        self.column = column
        self.value = value

    def match(self,example):
        """Return True when `example` answers this question with "yes".

        `example` is a row; only `example[self.column]` is read.
        """
        val = example[self.column]
        # Use >= only when BOTH sides are numeric. Deciding on the example's
        # value alone disagreed with the operator shown by __repr__ and raised
        # TypeError on a type mismatch (e.g. 3 >= 'Red').
        if is_numeric(val) and is_numeric(self.value):
            return val >= self.value
        return val == self.value

    def __repr__(self):
        """Render the question as text, e.g. "Is color == Red"."""
        # Column names come from the notebook-level `header` list.
        condition = ">=" if is_numeric(self.value) else "=="
        return "Is {} {} {}".format(header[self.column],condition,self.value)

In [8]:
# Build a categorical question on column 0, print it, and try it on a sample row.
q = Question(0,'Red')
print(q)
q.match(['Red',3,'Apple'])

Is color == Red


True

In [9]:
def partition(dataset,question):
    """Split `dataset` according to `question`.

    Returns a tuple of two lists: the rows for which `question.match(row)`
    is truthy, followed by all remaining rows. Row order is preserved.
    """
    # Index 0 collects matching rows, index 1 the non-matching ones.
    buckets = ([], [])
    for record in dataset:
        buckets[0 if question.match(record) else 1].append(record)
    return buckets

In [10]:
# Split the training rows with q: matching rows first, the rest second.
partition(training_data,q)

([['Red', 1, 'Grape'], ['Red', 1, 'Grape']],
 [['Green', 3, 'Apple'], ['Yellow', 3, 'Apple'], ['Yellow', 3, 'Lemon']])

In [11]:
def gini(data):
    """Gini impurity of the labels in `data`.

    Starts from 1 and subtracts the squared share of every label, so a
    single-label dataset scores 0. An empty dataset has no labels to
    subtract and therefore returns 1.
    """
    total = len(data)
    impurity = 1
    for label_count in Counter(data).values():
        share = label_count / total
        impurity -= share**2
    return impurity
# Gini impurity of the full training set (three labels spread over five rows).
gini(training_data)

0.6399999999999999

In [12]:
def infoGain(left,right,rootGini):
    """Impurity reduction achieved by splitting a node into `left` and `right`.

    Computed as `rootGini` minus the size-weighted average of the Gini
    impurities of the two children.
    """
    n_left = len(left)
    weight_left = n_left / (n_left + len(right))
    weighted_child_gini = weight_left*gini(left) + (1-weight_left)*gini(right)
    return rootGini - weighted_child_gini

In [15]:
# Gain obtained by splitting the training data with q, relative to the
# impurity of the unsplit training data.
left,right = partition(training_data,q)
infoGain(left,right,gini(training_data))

0.37333333333333324

In [20]:
def find_best_split(data):
    """Search every (column, value) pair for the question with the highest gain.

    Every distinct value of every feature column (all columns but the last)
    is tried as a Question; candidates that leave one side empty are skipped.

    Returns a tuple `(best_gain, best_question)`. `best_question` is None when
    no candidate yields two non-empty partitions; empty `data` gives `(0, None)`.
    """
    # Nothing to split: report zero gain instead of failing on data[0].
    if len(data) == 0:
        return 0, None

    best_gain = 0
    best_question = None
    rootGini = gini(data)
    n_features = len(data[0])-1  # the last column is the label, not a feature

    for col in range(n_features):
        # De-duplicate in first-seen row order. Iterating a set of strings
        # changes order between interpreter runs (hash randomization), which
        # made the winner among equal-gain questions vary from run to run.
        values = dict.fromkeys(row[col] for row in data)
        for val in values:
            question = Question(col,val)
            true_rows,false_rows = partition(data,question)

            # A question that sends every row the same way separates nothing.
            if len(true_rows) == 0 or len(false_rows) == 0:
                continue
            gain = infoGain(true_rows,false_rows,rootGini)
            # >= means that among equal-gain candidates the one seen last wins.
            if gain >= best_gain:
                best_gain, best_question = gain,question
    return best_gain,best_question

In [21]:
# Best first question for the whole training set, returned as (gain, question).
find_best_split(data=training_data)

(0.37333333333333324, Is Diameter >= 3)

In [24]:
class Leaf:
    """Terminal tree node holding the label counts of the rows that reached it."""

    def __init__(self,rows):
        # `prediction` maps each label to the number of rows carrying it.
        label_counts = Counter(rows)
        self.prediction = label_counts

In [73]:
class decisionNode:
    """Internal tree node: a question plus the subtrees for its two answers."""

    def __init__(self,question,true_branch,false_branch):
        # `true_branch` is followed when the question matches a row,
        # `false_branch` otherwise.
        self.question, self.true_branch, self.false_branch = (
            question,
            true_branch,
            false_branch,
        )

In [92]:
def buildTree(data):
    """Recursively grow a decision tree from the rows in `data`.

    Returns a Leaf when the best available split has zero gain, otherwise a
    decisionNode whose branches are built from the two partitions.
    """
    gain,question = find_best_split(data)

    # No question improves purity: stop here and predict from these rows.
    if gain == 0:
        return Leaf(data)

    matched,unmatched = partition(data,question)
    # Arguments are evaluated left to right, so the true branch is built first.
    return decisionNode(question,buildTree(matched),buildTree(unmatched))

In [93]:
def print_tree(node,spacing = ""):
    """Print the tree rooted at `node` as indented text.

    `spacing` is the prefix put in front of every line printed for this
    node; each level of children is indented two more spaces.
    """
    if isinstance(node,Leaf):
        print(spacing + "Predict",node.prediction)
        return

    print(spacing+str(node.question))

    # Both children share one indent. The false branch used to add a single
    # space, which misaligned it against its sibling true branch.
    child_spacing = spacing + "  "

    print(spacing+'--> True : ')
    print_tree(node.true_branch,child_spacing)

    print(spacing + '--> False : ')
    print_tree(node.false_branch,child_spacing)
    

In [94]:
# Fit a decision tree on the full training set.
my_tree = buildTree(training_data)

In [91]:
# Show the fitted tree as indented text.
print_tree(my_tree)

Is Diameter >= 3
--> True : 
  Is color == Yellow
  --> True : 
    Predict {'Apple': 1, 'Lemon': 1}
  --> False : 
   Predict {'Apple': 1}
--> False : 
 Predict {'Grape': 2}


In [95]:
def classify(row,node):
    """Route `row` down the tree starting at `node` and return the reached leaf's prediction.

    At each decision node the question is asked of `row`; a match follows
    the true branch, anything else the false branch.
    """
    current = node
    while not isinstance(current,Leaf):
        if current.question.match(row):
            current = current.true_branch
        else:
            current = current.false_branch
    return current.prediction

In [98]:
# Classify an unlabeled row (features only) with the fitted tree.
# NOTE(review): 'Yelow' looks like a typo for 'Yellow' — confirm. As written it
# cannot equal 'Yellow', so any "color == Yellow" question answers False for it.
classify(['Yelow',3],my_tree)

{'Apple': 1}