# Decision Tree

# In [3]:
# Sample Data
# Format: each row is an example
# The last column is the label
# The first two columns are features
# More features and examples can be added
# 2nd and 5th examples have same features but different labels

# Toy dataset: each row is [color, diameter, label].
training_data = [
    ["Green", 3, "Mango"],
    ["Yellow", 3, "Mango"],
    ["Red", 1, "Grape"],
    ["Red", 1, "Grape"],
    ["Yellow", 3, "Lemon"],
    ["Green", 1, "Banana"],
    ["Yellow", 1, "Banana"],
]

# Column names, one per position in a training row.
# Used when printing questions about a column.
header = [
    "color",
    "diameter",
    "label",
]

def unique_vals(rows, col):
    """Find the unique values for a column in the dataset.

    Args:
        rows: iterable of examples; each example is indexable by column.
        col: index of the column to inspect.

    Returns:
        A set holding every distinct value found at ``col``.
    """
    # A set comprehension avoids building an intermediate list first.
    return {row[col] for row in rows}


#####
# Demo:
# unique_vals(training_data, 0)
# unique_vals(training_data, 1)
#####


def class_counts(rows):
    """Tally how many examples carry each label.

    The label is read from the last column of every row.

    Args:
        rows: iterable of examples whose final item is the label.

    Returns:
        A dict mapping each label to the number of rows that have it.
    """
    tally = {}
    for example in rows:
        label = example[-1]
        # Start from zero the first time a label is seen.
        tally[label] = tally.get(label, 0) + 1
    return tally

#####
# Demo
# class_counts(training_data)
#####

def is_numeric(value):
    """Return True when ``value`` is an int or a float, False otherwise.

    Note that ``bool`` is a subclass of ``int``, so booleans count as numeric.
    """
    return isinstance(value, (int, float))

######
# Demo:
# is_numeric(7)
# is_numeric("Red")
######

class Question:
    """A yes/no test on one column, used to partition a dataset.

    An instance stores a column index (e.g. 0 for color) and a reference
    value (e.g. 'Green'). ``match`` compares the value an example holds in
    that column against the stored reference value.
    """

    def __init__(self, column, value):
        self.column = column
        self.value = value

    def match(self, example):
        """Return whether ``example`` satisfies this question.

        Numeric feature values are tested with ``>=`` against the stored
        value; every other feature value is tested for equality.
        """
        val = example[self.column]
        if not is_numeric(val):
            return val == self.value
        return val >= self.value

    def __repr__(self):
        """Render the question in a readable form, e.g. ``Is color = Red?``."""
        # The operator shown depends on the type of the stored value.
        condition = ">=" if is_numeric(self.value) else "="
        column_name = header[self.column]
        return "Is %s %s %s?" % (column_name, condition, str(self.value))
    
#######

def partition(rows, question):
    """Split a dataset in two according to a question.

    Args:
        rows: iterable of examples.
        question: object with a ``match(row)`` method.

    Returns:
        A ``(true_rows, false_rows)`` tuple of lists. Rows for which
        ``question.match`` is truthy go in the first list and all others
        in the second; each list keeps the input order.
    """
    matched, unmatched = [], []
    for row in rows:
        bucket = matched if question.match(row) else unmatched
        bucket.append(row)
    return matched, unmatched

#######
# Demo
# Let's partition the training data based on whether rows are Red.
# true_rows, false_rows = partition(training_data, Question(0, 'Red'))
# This will contain all the 'Red' rows.
# true_rows
# This will contain anything else
# false_rows
#######

def gini(rows):
    """Compute the Gini impurity of a list of rows.

    The impurity is 1 minus the sum of the squared label proportions,
    where the label counts come from ``class_counts``. See:
    https://en.wikipedia.org/wiki/Decision_tree_learning#Gini_impurity

    Args:
        rows: list of examples, as accepted by ``class_counts``.

    Returns:
        The impurity value. For an empty list no label is subtracted and
        the initial value 1 is returned unchanged.
    """
    label_counts = class_counts(rows)
    impurity = 1
    for _label, count in label_counts.items():
        fraction = count / float(len(rows))
        # Subtract one label at a time, in dict order.
        impurity -= fraction ** 2
    return impurity

########
# Demo
# Let's look at some examples to understand how Gini impurity works.
#
# First, we'll look at a dataset with no mixing.
# no_mixing = [['Mango'],
#              ['Mango']]
# gini(no_mixing)
# This will return 0


def info_gain(left, right, current_uncertainty):
    """Information gain of a split.

    This is the uncertainty of the starting node minus the size-weighted
    Gini impurity of the two child nodes.

    Args:
        left: rows in the first child node.
        right: rows in the second child node.
        current_uncertainty: impurity of the node before the split.

    Returns:
        ``current_uncertainty`` reduced by each child's weighted impurity.

    Raises:
        ZeroDivisionError: if both ``left`` and ``right`` are empty.
    """
    n_left = len(left)
    weight_left = float(n_left) / (n_left + len(right))
    left_term = weight_left * gini(left)
    right_term = (1 - weight_left) * gini(right)
    return current_uncertainty - left_term - right_term

#######
# Demo
# calculate the uncertainty of our training data.
# current_uncertainty = gini(training_data)
#
# How much information do we gain by partitioning on 'Green'?
# true_rows, false_rows = partition(training_data, Question(0, 'Green'))
# info_gain(true_rows, false_rows, current_uncertainty)
#
# What about if we partition on 'Red' instead?
# true_rows, false_rows = partition(training_data, Question(0, 'Red'))
# info_gain(true_rows, false_rows, current_uncertainty)

def find_best_split(rows):
    """Find the best question to ask by iterating over every feature / value
    and calculating the information gain. """
    best_gain = 0 # keep track of the best info gain
    best_question = # keep trck of the feature / value that produced it
    current_uncertainty = gini(rows)
    n_features = len(row[0]) - 1 #number of columns
    
    for col in range(n_features): # for each feature
        values = set([row[col] for row in rows]) # unique values in the column
        for val in values: # for each value
            question = Question(col, val)
            
            
            # try splitting the dataset
            true_rows,false_rows = partition(rows, question)