In [1]:
# Toy dataset: one example per row.
# Each row is [color, diameter, label]; the label is always the last entry.
training_data = [
    ['Green', 3, 'Apple'],
    ['Yellow', 3, 'Apple'],
    ['Red', 1, 'Grape'],
    ['Red', 1, 'Grape'],
    ['Yellow', 3, 'Lemon'],
]

In [2]:
# Column labels, in the same order as the entries of each data row.
# They are only used when printing (e.g. rendering a question as text).
header = ['color', 'diameter', 'label']

In [3]:
def unique_values(rows, col):
    """Return the set of distinct values found in column 'col' of 'rows'."""
    return {row[col] for row in rows}

In [4]:
# Distinct values in column 0 (the colors) of the toy dataset
unique_values(training_data, 0)

{'Green', 'Red', 'Yellow'}

In [11]:
def class_counts(rows):
    """Tally how many examples in a dataset carry each label.

    Returns a dictionary mapping label --> number of rows with that label.
    """
    tally = {}
    for example in rows:
        # in the dataset format, the label is always the last column
        label = example[-1]
        tally[label] = tally.get(label, 0) + 1
    return tally

In [6]:
# Number of examples per label in the whole toy dataset
class_counts(training_data)

{'Apple': 2, 'Grape': 2, 'Lemon': 1}

In [18]:
def is_numeric(value):
    """Return True when 'value' is an int or a float, False otherwise."""
    return isinstance(value, (int, float))

In [19]:
# An int is treated as numeric
is_numeric(7)

True

In [9]:
class Question:
    """A yes/no test on a single feature, used to partition a dataset.

    A Question stores a 'column number' (e.g., 0 for Color) and a
    'column value' (e.g., Green). The 'match' method compares the value an
    example holds in that column against the stored value.
    """

    def __init__(self, column, value):
        self.column, self.value = column, value

    def match(self, example):
        # Numeric feature values are tested against the stored value as a
        # threshold (>=); any other value is tested for equality.
        candidate = example[self.column]
        return candidate >= self.value if is_numeric(candidate) else candidate == self.value

    def __repr__(self):
        # Readable rendering of the question; the column name is looked up
        # in the module-level 'header' list.
        operator = ">=" if is_numeric(self.value) else "=="
        return "Is %s %s %s?" % (header[self.column], operator, str(self.value))

In [10]:
# A question with a numeric value is rendered with '>='
Question(1, 3)

Is diameter >= 3?

In [11]:
# A question with a non-numeric value is rendered with '=='
q = Question(0, 'Green')
q

Is color == Green?

In [12]:
# Pick an example from the training set...
example = training_data[0]
# ... and see if it matches the question
q.match(example) # this will be True, since the first example is Green

True

In [13]:
# Show the example that was just tested
example

['Green', 3, 'Apple']

In [14]:
def partition(rows, question):
    '''Split a dataset in two according to a question.

    Every row is tested with question.match(row). Rows that match are
    collected in the first returned list (the 'true rows'), all other rows
    in the second (the 'false rows'). Row order is kept within each list.
    '''
    matched, unmatched = [], []
    for row in rows:
        bucket = matched if question.match(row) else unmatched
        bucket.append(row)
    return matched, unmatched

In [15]:
# Partition the training data based on whether rows are Red
true_rows, false_rows = partition(training_data, Question(0,'Red'))
# This will contain all the Red rows
true_rows

[['Red', 1, 'Grape'], ['Red', 1, 'Grape']]

In [16]:
# The false rows hold every example that did not match the question
false_rows

[['Green', 3, 'Apple'], ['Yellow', 3, 'Apple'], ['Yellow', 3, 'Lemon']]

In [22]:
def gini(rows):
    '''Compute the Gini impurity of a list of rows.

    Starts from 1 and subtracts, for each label, the squared fraction of
    rows that carry that label.
    https://en.wikipedia.org/wiki/Decision_tree_learning#Gini_impurity
    '''
    label_totals = class_counts(rows)
    remaining = 1
    for n_with_label in label_totals.values():
        share = n_with_label / float(len(rows))
        remaining -= share ** 2
    return remaining

In [23]:
# Three label-only datasets with an increasing variety of labels
no_mixing = [['Apple'], ['Apple']]
some_mixing = [['Apple'], ['Orange']]
lots_of_mixing = [
    ['Apple'],
    ['Orange'],
    ['Grape'],
    ['Grapefruit'],
    ['Blueberry'],
]

In [24]:
# Impurity is 0 when every label is the same and grows as the labels
# become more varied
print(gini(no_mixing))
print(gini(some_mixing))
print(gini(lots_of_mixing))

0.0
0.5
0.7999999999999998


In [29]:
def info_gain(left, right, current_uncertainty):
    """Information gain of a split.

    The uncertainty of the starting node, minus the impurity of the two
    child nodes, each weighted by its share of the rows.
    """
    n_left, n_right = len(left), len(right)
    left_weight = float(n_left) / (n_left + n_right)
    weighted_left = left_weight * gini(left)
    weighted_right = (1 - left_weight) * gini(right)
    return current_uncertainty - weighted_left - weighted_right

In [30]:
# Calculate the uncertainty (Gini impurity) of the training data
current_uncertainty = gini(training_data)
current_uncertainty

0.6399999999999999

In [31]:
# How much information do we gain by partitioning on 'Green'?
true_rows, false_rows = partition(training_data, Question(0, 'Green'))
info_gain(true_rows, false_rows, current_uncertainty)

0.1399999999999999

In [32]:
# How much information do we gain by partitioning on 'Red' instead?
true_rows, false_rows = partition(training_data, Question(0, 'Red'))
info_gain(true_rows, false_rows, current_uncertainty)

0.37333333333333324

In [33]:
# It looks like we learned more using 'Red' (.37) than 'Green' (.14).
# Why? Look at the different splits that result, and see which one looks more unmixed.

# Here, the true rows contain only grapes
true_rows

[['Red', 1, 'Grape'], ['Red', 1, 'Grape']]

In [34]:
# On the other hand, partitioning by Green doesn't help so much
true_rows, false_rows = partition(training_data, Question(0, 'Green'))

# We've isolated one apple in the true rows
true_rows

[['Green', 3, 'Apple']]

In [35]:
# But the false rows are badly mixed up
false_rows

[['Yellow', 3, 'Apple'],
 ['Red', 1, 'Grape'],
 ['Red', 1, 'Grape'],
 ['Yellow', 3, 'Lemon']]

In [39]:
def find_best_split(rows):
    """Find the best question to ask by iterating over every feature / value
    and calculating the information gain.

    Every (feature column, observed value) pair becomes a candidate
    Question. The rows are partitioned on it and the information gain of
    that split is measured against the Gini impurity of 'rows'.

    Args:
        rows: list of examples; the last entry of each example is its label,
            every other entry is a feature value.

    Returns:
        A (best_gain, best_question) tuple. best_question is None and
        best_gain is 0 when 'rows' is empty or when no candidate question
        divides the rows into two non-empty groups.
    """
    best_gain = 0  # keep track of the best information gain
    best_question = None  # keep track of the feature / value that produced it

    # Nothing to split: report "no useful question" rather than failing
    # on rows[0] below.
    if len(rows) == 0:
        return best_gain, best_question

    current_uncertainty = gini(rows)
    n_features = len(rows[0]) - 1  # every column except the trailing label

    for col in range(n_features):  # for each feature

        values = {row[col] for row in rows}  # unique values in the column

        for val in values:  # for each value

            question = Question(col, val)

            # try splitting the dataset
            true_rows, false_rows = partition(rows, question)

            # Skip this split if it doesn't divide the
            # dataset.
            if len(true_rows) == 0 or len(false_rows) == 0:
                continue

            # Calculate the information gain from this split
            gain = info_gain(true_rows, false_rows, current_uncertainty)

            # You actually can use '>' instead of '>=' here
            # but I wanted the tree to look a certain way for our
            # toy dataset.
            if gain >= best_gain:
                best_gain, best_question = gain, question

    # Return only after *every* feature column has been examined. Returning
    # from inside the outer loop would stop after the first column and
    # silently ignore all remaining features.
    return best_gain, best_question


In [None]:
# Find the best question to ask first for the toy dataset
best_gain, best_question = find_best_split(training_data)
# Display the winning question
best_question