In [2]:
training_data = [
    ['Green', 3, 'Apple'],
    ['Yellow', 3, 'Apple'],
    ['Red', 1, 'Grape'],
    ['Red', 1, 'Grape'],
    ['Yellow', 3, 'Lemon'],
]

In [3]:
# Column labels
header = ["color", "diameter", "label"]

In [4]:
# Utility function to find unique values of a specific column
def unique(rows, column):
    return set([row[column] for row in rows])

In [5]:
# Utility function to return number of occurences of a certain class
def class_counts(rows):
    counts = {}
    for row in rows:
        label = row[-1]
        if label not in counts:
            counts[label] = 0
        counts[label] += 1
    return counts

In [6]:
# Utility function to check if a value is a number
def is_numeric(val):
    return isinstance(val, int) or isinstance(val, float)

In [7]:
# The gini function returns the gini impurity of the data 
#(This is the probability that an item can be wrongly 
# classified in a class if the label for the item is 
# randomly chosen from the data set)
def gini(rows):
    impurity = 1
    cls_counts = class_counts(rows)
    for lbl in cls_counts:
        p_lbl = (cls_counts[lbl]/float(len(rows)))
        impurity -= (p_lbl ** 2)
    return impurity

In [8]:
# The uncertainty of the starting node, minus the weighted impurity of two child nodes.
def info_gain(left, right, current_uncertainty):
    p = float(len(left)) / (len(left) + len(right))
    return current_uncertainty - p * gini(left) - (1 - p) * gini(right)

In [46]:
def print_tree(node, spacing=""):
    """World's most elegant tree printing function."""
    
    print("The Selected Decision Tree.")
    
    # Base case: we've reached a leaf
    if isinstance(node, Leaf):
        print (spacing + "Predict", node.predictions)
        return

    # Print the question at this node
    print (spacing + str(node.question))

    # Call this function recursively on the true branch
    print (spacing + '--> True:')
    print_tree(node.left_branch, spacing + "  ")

    # Call this function recursively on the false branch
    print (spacing + '--> False:')
    print_tree(node.right_branch, spacing + "  ")

In [22]:
class Question:
    '''
    Question class instance will take an init value of a column and a value
    The match function will be called to compare an example row passed to 
    the function with the value initialized with the Question instance.
    '''
    
    def __init__(self, column, value):
        self.column = column
        self.value = value
        
    def match(self, example):
        val = example[self.column]
        if is_numeric(self.value):
            return self.value >= val
        else:
            return self.value == val
    
    def __repr__(self):
        if is_numeric(self.value):
            condition = ">="
        else:
            condition = "=="
        return "Is %s %s %s?" % (header[self.column], condition, self.value)

In [11]:
class Leaf:
    """
    A Leaf node classifies data.
    This holds a dictionary of class (e.g., "Apple") -> number of times
    it appears in the rows from the training data that reach this leaf.
    """

    def __init__(self, rows):
        self.predictions = class_counts(rows)

In [31]:
class DecisionNode:
    """
    A Decision Node asks a question.
    This holds a reference to the question, and to the two child nodes.
    """

    def __init__(self, question, left_branch, right_branch):
        self.question = question
        self.left_branch = left_branch
        self.right_branch = right_branch

In [44]:
class DTClassifier:
    '''
    DTClassifier instance will take some training data 
    in the fit function then parition it accordingly using the 
    partition function then it can be used to predict a certain 
    class using the predict function. 
    '''

    def fit(self, training_data):
        self.training_data = training_data
        
    # Use a question to separate the data along a certain feature
    def partition(self, rows, question):
        true_rows, false_rows = [], []
        for row in rows:
            if question.match(row):
                true_rows.append(row)
            else:
                false_rows.append(row)
        return true_rows, false_rows

    def find_best_feature(self, rows):
        # placeholder return values 
        best_gain = 0
        best_question = None
        current_uncertainty = gini(rows)

        # for all features
        for i in range((len(rows[0]) - 1)):
            unique_values = unique(rows, i)
            
            # for every unique value
            for value in unique_values:
                question = Question(i, value)
                left, right = self.partition(rows, question)

                if len(left) == 0 or len(right) == 0:
                    continue

                gain = info_gain(left, right, current_uncertainty)

                if gain >= best_gain:
                    best_gain = gain
                    best_question = question
                          
        return best_gain, best_question
    
    def build_tree(self, rows):
        # get the feature that has the most info gain 
        # and use it to create a question
        gain, question = self.find_best_feature(rows)
        
        # No further classification necessary
        # Leaf reached
        if gain == 0:
            return Leaf(rows)
        
        left, right = self.partition(rows, question)
        
        # recursively build a tree on 
        # the newly created branches
        left_branch = self.build_tree(left)
        right_branch = self.build_tree(right)
        
        return DecisionNode(question, left_branch, right_branch)
    
    def classify(self, row, node):
        # If the node instance is already a leaf, return the predictions
        if isinstance(node, Leaf):
            return node.predictions
        
        # Do a recursive classification until a Leaf node is reached
        if node.question.match(row):
            return self.classify(row, node.left_branch)
        else:
            return self.classify(row, node.right_branch)
    
    def predict(self, test_data):
        predictions = []
        # create a tree from the training data
        tree = self.build_tree(self.training_data)
        print_tree(tree)
        for row in test_data:
            predictions.append(self.classify(row, tree))
        return predictions

In [47]:
testing_data = [
    ['Green', 3, 'Apple'],
    ['Green', 3, 'Apple'],
    ['Yellow', 4, 'Apple'],
    ['Red', 2, 'Grape'],
    ['Red', 1, 'Grape'],
    ['Yellow', 3, 'Lemon'],
]
classifier = DTClassifier()
classifier.fit(training_data)
classifier.predict(testing_data)

The Selected Decision Tree.
Is diameter >= 1?
--> True:
The Selected Decision Tree.
  Predict {'Grape': 2}
--> False:
The Selected Decision Tree.
  Is color == Yellow?
  --> True:
The Selected Decision Tree.
    Predict {'Apple': 1, 'Lemon': 1}
  --> False:
The Selected Decision Tree.
    Predict {'Apple': 1}


[{'Apple': 1},
 {'Apple': 1},
 {'Apple': 1, 'Lemon': 1},
 {'Apple': 1},
 {'Grape': 2},
 {'Apple': 1, 'Lemon': 1}]