In [21]:
# Toy dataset used throughout the notebook.
# The last two 'Yellow', 3 rows have identical features but different labels.
training_data = [
    ['Green', 3, 'Apple'],
    ['Yellow', 3, 'Apple'],
    ['Red', 1, 'Grape'],
    ['Red', 1, 'Grape'],
    ['Yellow', 3, 'Lemon']
]
# Row format: [Color, Diameter, Label] -- the label is the last column.

In [22]:
# Column labels, indexed by column number. Question.__repr__ looks names up
# here, so they are only needed when printing questions / the tree.
header = ["color", "diameter", "label"]

In [23]:
def unique_vals(rows, col):
    """Return the set of distinct values found in column `col` of `rows`."""
    return {example[col] for example in rows}

In [24]:
# Demo: collect the distinct values found in column 0 (color) of the training data
unique_vals(training_data, 0)

{'Green', 'Red', 'Yellow'}

In [25]:
def class_counts(rows):
    """Count how many examples of each label appear in a dataset.

    The label is read from the last column of every row. The result is a
    plain dict mapping label -> number of rows carrying that label.
    """
    tally = {}
    for example in rows:
        label = example[-1]  # the label is always the last column
        tally[label] = tally.get(label, 0) + 1
    return tally

In [7]:
# Demo: tally how many training examples carry each label
class_counts(training_data)

{'Apple': 2, 'Grape': 2, 'Lemon': 1}

In [8]:
def is_numeric(value):
    """Return True when `value` is an int or a float, False otherwise."""
    return isinstance(value, (int, float))

In [9]:
# Demo: a string is not numeric, while a float and an int both are
is_numeric("Apple"), is_numeric(1.2), is_numeric(30)

(False, True, True)

In [10]:
class Question:
    """A Question is used to partition a dataset.

    It stores a column number (e.g. 0 for Color) and a column value
    (e.g. Green). `match` compares the feature value in an example to the
    feature value stored in the question.
    """

    def __init__(self, column, value):
        self.column = column
        self.value = value

    def match(self, example):
        """Return whether `example` satisfies this question.

        Numeric feature values are tested with `>=` against the stored
        value; every other feature value is tested with `==`.
        """
        feature = example[self.column]
        if not is_numeric(feature):
            return feature == self.value
        return feature >= self.value

    def __repr__(self):
        """Render the question in a readable form, e.g. 'Is diameter >= 3?'."""
        # The operator shown depends on the type of the stored value.
        operator = ">=" if is_numeric(self.value) else "=="
        column_name = header[self.column]
        return "Is {} {} {}?".format(column_name, operator, str(self.value))
    

In [11]:
# Demo: one question on the diameter column (numeric value) and one on the
# color column (string value); the cell output is their __repr__
Question(1, 3), Question(0, 'Green')

(Is diameter >= 3?, Is color == Green?)

In [12]:
# Let's pick an example from the training set
example = training_data[0]
print(example)
# Build two more questions; the cell displays their __repr__.
# NOTE(review): `example` is printed but never passed to Question.match here.
Question(0, 'Red'), Question(1, 5)

['Green', 3, 'Apple']


(Is color == Red?, Is diameter >= 5?)

In [13]:
def partition(rows, question):
    """Split a dataset in two according to a question.

    Each row is tested once with `question.match`. Rows for which the
    answer is truthy go to the first returned list ('true rows'); all other
    rows go to the second ('false rows'). Row order is preserved.
    """
    buckets = {True: [], False: []}
    for example in rows:
        answer = bool(question.match(example))
        buckets[answer].append(example)
    return buckets[True], buckets[False]

In [14]:
# Demo:
# Let's partition the training data based on whether rows are Red
# (true_rows receives the rows whose color == 'Red', false_rows the rest)
true_rows, false_rows = partition(training_data, Question(0, 'Red'))
print(true_rows) 
print(false_rows)

[['Red', 1, 'Grape'], ['Red', 1, 'Grape']]
[['Green', 3, 'Apple'], ['Yellow', 3, 'Apple'], ['Yellow', 3, 'Lemon']]


In [15]:
def gini(rows=()):
    """Calculate the Gini impurity of a list of rows.

    The impurity is 1 - sum(p_label ** 2) over every label present, where
    p_label is the fraction of rows carrying that label. It is 0 when all
    rows share one label and grows as the labels become more mixed.

    Args:
        rows: examples whose last element is the label. Defaults to an
            empty dataset so a bare `gini()` call still returns 0.

    Returns:
        The impurity as a number; 0 for an empty dataset.
    """
    total = len(rows)
    if total == 0:
        # Nothing to be uncertain about (and avoids dividing by zero).
        return 0

    # Tally the labels; the label is always the last column of a row.
    label_totals = {}
    for row in rows:
        label = row[-1]
        label_totals[label] = label_totals.get(label, 0) + 1

    impurity = 1
    for count in label_totals.values():
        probability = count / total
        impurity -= probability ** 2
    return impurity

In [16]:
def find_best_split(rows):
    """Find the best question to ask by iterating over every feature/value
    and calculating the information gain.

    Every unique value of every feature column becomes a candidate
    Question. The dataset is partitioned on each candidate and the split is
    scored with `info_gain` against the impurity of `rows` (from `gini`).
    Neither the expected `gini(rows)` implementation nor `info_gain` is
    visible in this part of the notebook; they are called as written.

    Args:
        rows: list of examples; the last element of each row is the label
            and every element before it is a feature.

    Returns:
        A tuple (best_gain, best_question). If `rows` is empty, or no
        candidate question actually divides the dataset, the result is
        (0, None).
    """
    best_gain = 0  # keep track of the best information gain
    best_question = None  # keep track of the feature/value that produced it
    if len(rows) == 0:
        return best_gain, best_question

    current_uncertainty = gini(rows)
    n_features = len(rows[0]) - 1  # number of features (last column is the label)

    for col in range(n_features):
        values = {row[col] for row in rows}  # unique values in the column
        for val in values:
            question = Question(col, val)

            # try splitting the dataset
            true_rows, false_rows = partition(rows, question)

            # Skip this split if it doesn't divide the dataset
            if len(true_rows) == 0 or len(false_rows) == 0:
                continue

            # Calculate the information gain from this split. This must run
            # for every candidate, i.e. inside the inner loop.
            gain = info_gain(true_rows, false_rows, current_uncertainty)

            # You actually can use '>' instead of '>=' here,
            # but I wanted the tree to look a certain way for our dataset
            if gain >= best_gain:
                best_gain, best_question = gain, question

    return best_gain, best_question