In [345]:
#########################################

In [346]:
# Column names, used only to render questions in a readable form.
header = ["color", "diameter", "label"]

In [347]:
def is_numeric(x):
    """Return True when ``x`` is an ``int`` or a ``float`` instance.

    Note that ``bool`` is a subclass of ``int``, so booleans count as numeric.
    """
    return isinstance(x, (int, float))

In [348]:
# Toy fruit data. Each row is [color, diameter, label], matching `header`.
training_data = [
    ["Green", 3, "Apple"],
    ["Yellow", 3, "Apple"],
    ["Red", 1, "Grape"],
    ["Red", 1, "Grape"],
    ["Yellow", 3, "Lemon"],
]

# Held-out rows in the same layout; some diameters differ from the training set.
testing_data = [
    ["Green", 3, "Apple"],
    ["Yellow", 4, "Apple"],
    ["Red", 2, "Grape"],
    ["Red", 1, "Grape"],
    ["Yellow", 3, "Lemon"],
]

In [349]:
class Question:
    """A yes/no test on one column of a row, used to split a dataset.

    Attributes:
        column: index of the feature the question inspects.
        value: the value the feature is compared against.
    """

    def __init__(self, column, value):
        self.column = column
        self.value = value

    def match(self, example):
        """Return whether ``example`` answers this question with "yes".

        Numeric feature values are tested with ``>=``; everything else is
        tested for equality.
        """
        candidate = example[self.column]
        if not is_numeric(candidate):
            return candidate == self.value
        return candidate >= self.value

    def __repr__(self):
        # The displayed operator is chosen from the question's own value.
        operator = ">=" if is_numeric(self.value) else "=="
        return "Is %s %s %s?" % (header[self.column], operator, str(self.value))
        
    

In [350]:
# Demo: a question on column 1 against the value 3; the cell displays its repr.
Question(1,3)

Is diameter >= 3?

In [351]:
def partition(rows, question):
    """Split ``rows`` into two lists according to ``question.match``.

    Args:
        rows: iterable of rows to divide.
        question: object exposing ``match(row)``.

    Returns:
        A ``(true_rows, false_rows)`` tuple; each list keeps the input order.
    """
    matched, unmatched = [], []
    for row in rows:
        bucket = matched if question.match(row) else unmatched
        bucket.append(row)
    return matched, unmatched

In [352]:
def count(rows):
    """Tally how many rows carry each label.

    The label is taken from the last element of every row.

    Returns:
        dict mapping label -> number of occurrences.
    """
    tally = {}
    for row in rows:
        label = row[-1]
        tally[label] = tally.get(label, 0) + 1
    return tally

In [353]:
def gini_impurity(rows):
    """Compute the Gini impurity of ``rows``.

    Starts from 1 and subtracts the squared share of every label, so a
    single-label set yields 0. An empty ``rows`` yields 1 because there are
    no labels to subtract.
    """
    label_counts = count(rows)
    n_rows = float(len(rows))
    remaining = 1
    for occurrences in label_counts.values():
        share = occurrences / n_rows
        remaining -= share ** 2
    return remaining

In [354]:
def info_gain(left, right, current_uncertainty):
    """Return the impurity reduction achieved by splitting into two children.

    The result is ``current_uncertainty`` minus the Gini impurity of each
    child, weighted by the child's share of the rows.
    """
    weight_left = float(len(left)) / (len(left) + len(right))
    weight_right = 1 - weight_left
    weighted_left = weight_left * gini_impurity(left)
    weighted_right = weight_right * gini_impurity(right)
    return current_uncertainty - weighted_left - weighted_right

In [355]:
def unique_vals(rows, col):
    """Return the set of distinct values found in column ``col`` of ``rows``."""
    return {row[col] for row in rows}

In [356]:
def find_best_split(rows):
    """Find the question that yields the highest information gain.

    Every distinct value of every feature column is tried as a question.
    Questions that leave one side of the partition empty are ignored.

    Args:
        rows: list of rows; the last element of each row is its label and
            all preceding elements are features.

    Returns:
        ``(best_gain, best_question)``. When no question splits the data, or
        ``rows`` is empty, this is ``(0, None)``.
    """
    best_gain = 0
    best_question = None

    # Nothing to split; without this guard rows[0] below raises IndexError.
    if len(rows) == 0:
        return best_gain, best_question

    current_uncertainty = gini_impurity(rows)
    n_features = len(rows[0]) - 1  # the last column holds the label

    for col in range(n_features):
        # De-duplicate while keeping first-seen order. Iterating a set of
        # strings follows hash order, which changes between interpreter runs,
        # so ties under ">=" below would be resolved differently on each run.
        values = list(dict.fromkeys(row[col] for row in rows))

        for val in values:
            question = Question(col, val)
            true_rows, false_rows = partition(rows, question)

            # If the question does not actually split the rows, skip it.
            if len(true_rows) == 0 or len(false_rows) == 0:
                continue

            gain = info_gain(true_rows, false_rows, current_uncertainty)

            # ">=" means that, among equal gains, the last candidate wins.
            if gain >= best_gain:
                best_gain, best_question = gain, question

    return best_gain, best_question

In [357]:
class Leaf:
    """Terminal tree node.

    Attributes:
        predictions: dict of label -> number of training rows that reached
            this leaf, as produced by ``count``.
    """

    def __init__(self, rows):
        label_counts = count(rows)
        self.predictions = label_counts

In [358]:
class Decision_Node:
    """Internal tree node: a question plus the subtree for each answer.

    Attributes:
        question: the question asked at this node.
        true_branch: subtree followed when the question matches.
        false_branch: subtree followed when it does not.
    """

    def __init__(self, question, true_branch, false_branch):
        self.question = question
        self.true_branch = true_branch
        self.false_branch = false_branch

In [359]:
def build_tree(rows):
    """Recursively build a decision tree for ``rows``.

    Returns a ``Leaf`` when the best available split has zero gain; otherwise
    a ``Decision_Node`` whose branches are built from the two partitions.
    """
    gain, question = find_best_split(rows)
    if gain != 0:
        matched, unmatched = partition(rows, question)
        # The true branch is built before the false branch.
        return Decision_Node(question, build_tree(matched), build_tree(unmatched))
    return Leaf(rows)

In [360]:

def print_tree(node, spacing=""):
    """Pretty-print the tree rooted at ``node``, indenting each level.

    Args:
        node: a ``Leaf`` or a node exposing ``question``, ``true_branch`` and
            ``false_branch``.
        spacing: prefix prepended to every line printed for this node.
    """
    # credits to google

    # A leaf ends the recursion: show its label counts.
    if isinstance(node, Leaf):
        print(spacing + "Predict", node.predictions)
        return

    print(spacing + str(node.question))

    # Visit the true branch first, then the false branch, one level deeper.
    # Branches are looked up lazily so each is read only when it is visited.
    for heading, branch_attr in (('--> True:', 'true_branch'),
                                 ('--> False:', 'false_branch')):
        print(spacing + heading)
        print_tree(getattr(node, branch_attr), spacing + "  ")

In [361]:
# Grow a decision tree from the toy training rows.
my_tree = build_tree(training_data)

In [362]:
# Print the questions and leaf predictions of the tree built above.
print_tree(my_tree)

Is diameter >= 3?
--> True:
  Is color == Yellow?
  --> True:
    Predict {'Apple': 1, 'Lemon': 1}
  --> False:
    Predict {'Apple': 1}
--> False:
  Predict {'Grape': 2}


In [363]:
def classify(row, node):
    """Route ``row`` down the tree and return the reached leaf's predictions.

    At each internal node the row follows ``true_branch`` when the node's
    question matches it, and ``false_branch`` otherwise.
    """
    current = node
    while not isinstance(current, Leaf):
        if current.question.match(row):
            current = current.true_branch
        else:
            current = current.false_branch
    return current.predictions
        

In [364]:
def print_leaf(counts):
    """Convert a leaf's label counts into whole-percent strings.

    Despite its name, this function prints nothing; it returns the mapping.

    Args:
        counts: dict of label -> count, as stored on a ``Leaf``.

    Returns:
        dict of label -> percentage string such as ``'50%'``. Percentages are
        truncated toward zero. An empty ``counts`` yields an empty dict.
    """
    total = float(sum(counts.values()))
    probs = {}
    for label, occurrences in counts.items():
        # Multiply before dividing: 29 / 100.0 * 100 is 28.999999999999996,
        # which int() would truncate to 28, whereas 100 * 29 / 100.0 is 29.0.
        probs[label] = str(int(100 * occurrences / total)) + "%"
    return probs

In [365]:
# Classify the first training row and display its leaf as percentages.
print_leaf(classify(training_data[0],my_tree))

{'Apple': '100%'}

In [366]:
def predict(testing_data, tree):
    """Build a text report comparing each row's label with the tree's output.

    Args:
        testing_data: rows whose last element is the actual label.
        tree: root node of a tree accepted by ``classify``.

    Returns:
        One string with a line per row, each terminated by a newline.
    """
    report_lines = []
    for row in testing_data:
        report_lines.append(
            "Actual: %s. Predcited: %s \n" % (row[-1], print_leaf(classify(row, tree)))
        )
    return "".join(report_lines)

In [367]:
def fit(training_data):
    """Train a decision tree on ``training_data`` and return its root node.

    This is a thin wrapper around ``build_tree``.
    """
    tree = build_tree(training_data)
    return tree

In [368]:
# Rebuild the tree through the fit() wrapper; this rebinds my_tree.
my_tree = fit(training_data)

In [369]:
# Build the actual-vs-predicted report for every row of the testing set.
predictions = predict(testing_data,my_tree)

In [370]:
# print() is used so the newlines embedded in the report string are rendered.
print(predictions)

Actual: Apple. Predcited: {'Apple': '100%'} 
Actual: Apple. Predcited: {'Apple': '50%', 'Lemon': '50%'} 
Actual: Grape. Predcited: {'Grape': '100%'} 
Actual: Grape. Predcited: {'Grape': '100%'} 
Actual: Lemon. Predcited: {'Apple': '50%', 'Lemon': '50%'} 



In [371]:
# Second, larger toy training set; rows are [color, diameter, label].
training_data_1 = [
    ["Yellow", 5, "Banana"],
    ["Red", 3, "Apple"],
    ["Green", 3, "Apple"],
    ["Green", 2, "Kiwi"],
    ["Red", 1, "Grape"],
    ["Red", 3, "Apple"],
    ["Red", 1, "Grape"],
    ["Yellow", 5, "Banana"],
    ["Yellow", 3, "Apple"],
    ["Blue", 1, "Grape"],
    ["Green", 2, "Kiwi"],
]

In [372]:
# Test rows for the second data set; the last row uses a color ("Brown")
# that does not appear in training_data_1.
testing_data_1 = [
    ["Yellow", 5, "Banana"],
    ["Green", 3, "Apple"],
    ["Red", 1, "Grape"],
    ["Green", 2, "Kiwi"],
    ["Brown", 2, "Kiwi"],
]

In [373]:
# Train a separate tree on the second training set.
my_tree_1= fit(training_data_1)

In [374]:
# Build the actual-vs-predicted report for the second testing set.
predictions_1= predict(testing_data_1,my_tree_1)

In [375]:
# print() is used so the newlines embedded in the report string are rendered.
print(predictions_1)

Actual: Banana. Predcited: {'Banana': '100%'} 
Actual: Apple. Predcited: {'Apple': '100%'} 
Actual: Grape. Predcited: {'Grape': '100%'} 
Actual: Kiwi. Predcited: {'Kiwi': '100%'} 
Actual: Kiwi. Predcited: {'Kiwi': '100%'} 

