In [1]:
from collections import defaultdict

In [2]:
# Toy fruit dataset. Every row is [color, diameter, label]: the class label is
# always the last element, and `header` names the columns in that same order.
# `header` is also read as a global by `Question.__repr__` to print column names.
header = ['color', 'diameter', 'label']

training_data = [
    ['Green', 3, 'Apple'],
    ['Yellow', 3, 'Apple'],
    ['Red', 1, 'Grape'],
    ['Red', 1, 'Grape'],
    ['Yellow', 3, 'Lemon'],
]



### Utility functions

In [9]:
def unique_vals(rows, col):
    """Return the set of distinct values appearing in column `col` of `rows`."""
    return {row[col] for row in rows}

def class_counts(rows):
    """Tally how many rows carry each label.

    The label of a row is its last element. Returns a ``defaultdict`` mapping
    label -> count, so looking up an unseen label yields 0.
    """
    tally = defaultdict(lambda: 0)
    for label in (row[-1] for row in rows):
        tally[label] += 1
    return tally

def is_numeric(value):
    """Return True when `value` is an int or a float.

    Note that bools also qualify, because `bool` is a subclass of `int`.
    """
    return isinstance(value, (int, float))

def gini_impurity(rows):
    """Gini impurity of the labels in `rows`: 1 minus the sum of squared class proportions.

    The result is 0 when every row has the same label. For an empty `rows` it is 1,
    because no class proportions are subtracted.
    """
    label_counts = class_counts(rows)
    impurity = 1
    # Subtract each class's squared share one at a time, in label order.
    for count in label_counts.values():
        impurity -= (count / float(len(rows))) ** 2
    return impurity

def info_gain(left, right, current_impurity):
    """Reduction in Gini impurity obtained by splitting a node into `left` and `right`.

    `current_impurity` is the parent's impurity. Each child's impurity is
    weighted by that child's share of the combined rows.
    """
    weight_left = float(len(left)) / (len(left) + len(right))
    weight_right = 1 - weight_left
    weighted_child = weight_left * gini_impurity(left) + weight_right * gini_impurity(right)
    return current_impurity - weighted_child

def counts_to_freq(counts):
    """Convert a label -> count mapping into a label -> relative-frequency dict."""
    total = sum(counts.values()) * 1.0
    return {label: n / total for label, n in counts.items()}

def print_tree(node, spacing=""):
    """Print the subtree rooted at `node`, adding four spaces of indent per level.

    A Leaf prints its prediction dict. A decision node prints its question,
    then its true branch, then its false branch, each under a tag line.
    """
    if not isinstance(node, Leaf):
        print(spacing, node.question)
        child_spacing = spacing + "    "
        for tag, branch in (('-->True:', node.true_branch),
                            ('-->False:', node.false_branch)):
            print(child_spacing, tag)
            print_tree(branch, child_spacing)
        return
    print(spacing, 'Predict', node.predictions)


### Main functions for building a tree and making predictions

In [10]:
def partition(rows, question):
    """Split `rows` into (matching, non-matching) lists using `question.match`.

    `question.match` is called exactly once per row, and each output list
    keeps the original row order.
    """
    matched, unmatched = [], []
    for row in rows:
        (matched if question.match(row) else unmatched).append(row)
    return matched, unmatched

def find_best_split(rows):
    """Find the question with the largest information gain for partitioning `rows`.

    Every (feature column, value) pair seen in `rows` is tried as a candidate
    `Question`. A candidate that would leave either side empty is skipped.

    Args:
        rows: list of examples. The last element of each row is its label and
            the preceding elements are feature values.

    Returns:
        A `(best_question, best_gain)` tuple. `best_gain` is 0 when no split
        reduces impurity. In that case `best_question` may still be a
        zero-gain question, or it is None when no candidate separates the
        rows at all (including when `rows` is empty).
    """
    best_gain = 0
    best_question = None
    if not rows:
        # Nothing to split; also avoids an IndexError on rows[0] below.
        return best_question, best_gain

    current_impurity = gini_impurity(rows)
    n_features = len(rows[0]) - 1  # the last column is the label

    for col in range(n_features):
        # Candidate values in first-appearance order. Iterating a set here made
        # the tie-break below depend on string hash randomization, so the tree
        # could change between kernel restarts.
        values = dict.fromkeys(row[col] for row in rows)
        for val in values:
            question = Question(col, val)
            true_rows, false_rows = partition(rows, question)
            if not true_rows or not false_rows:
                continue
            gain = info_gain(true_rows, false_rows, current_impurity)
            # `>=` means the last of several equally good splits wins.
            if gain >= best_gain:
                best_gain, best_question = gain, question
    return best_question, best_gain

def build_tree(rows):
    """
    Recursively grow a decision tree from `rows`.

    If the best available split has zero information gain, the rows become a
    Leaf. Otherwise the rows are partitioned on the best question. Both sides
    are then built recursively, the true side first, and wrapped in a
    Decision_Node.
    """
    question, gain = find_best_split(rows)
    if gain != 0:
        true_rows, false_rows = partition(rows, question)
        return Decision_Node(
            question,
            build_tree(true_rows),   # built before the false branch
            build_tree(false_rows),
        )
    # No split improves purity: stop and summarise these rows.
    return Leaf(rows)

def predict_proba(node, row):
    """Route `row` from `node` down to a Leaf by answering each node's question.

    Returns that Leaf's `predictions` dict (label -> relative frequency).
    """
    while not isinstance(node, Leaf):
        node = node.true_branch if node.question.match(row) else node.false_branch
    return node.predictions

### Main objects in a tree

In [11]:
class Question:
    """A test on one feature column, used to partition a dataset.

    Numeric feature values are compared with `>=`; all other values with `==`.
    """

    def __init__(self, column, value):
        # Index of the feature to test, and the reference value to compare against.
        self.column, self.value = column, value

    def match(self, example):
        """Return True if `example` answers this question with "yes".

        The comparison depends on the type of the example's value in
        `self.column`: `>=` for ints/floats, `==` for anything else.
        """
        observed = example[self.column]
        return observed >= self.value if is_numeric(observed) else observed == self.value

    def __repr__(self):
        """Readable form such as 'Is diameter >= 3 ?'.

        Column names are looked up in the notebook-level `header` list.
        """
        op = ">=" if is_numeric(self.value) else "=="
        return "Is {} {} {} ?".format(header[self.column], op, str(self.value))

    
class Leaf:
    """Terminal node holding the label distribution of the rows that reach it.

    Attributes (as comments, not runtime strings):
    """
    # counts: plain dict of label -> number of rows.
    # predictions: dict of label -> fraction of rows (counts normalised to sum to 1).

    def __init__(self, rows):
        label_counts = dict(class_counts(rows))
        self.counts = label_counts
        self.predictions = counts_to_freq(label_counts)

        
class Decision_Node:
    """
    An internal tree node that asks a question and routes rows onward.

    It holds the Question and two subtrees: `true_branch` for rows that
    match the question and `false_branch` for rows that do not.
    """

    def __init__(self, question, true_branch, false_branch):
        self.question, self.true_branch, self.false_branch = (
            question, true_branch, false_branch)
        
        
class Decision_Tree_Classifier:
    """Small object wrapper around the notebook's module-level tree functions.

    `tree` is None until `train` is called, and afterwards holds the root node
    produced by `build_tree`. Calling `print_tree` or `predict_proba` before
    training fails, because there is no tree to walk.
    """

    def __init__(self):
        self.tree = None

    def train(self, rows):
        """Grow a tree from `rows` (each row: features..., label) and keep it on `self.tree`."""
        self.tree = build_tree(rows)

    def print_tree(self):
        """Print the trained tree via the module-level printer and return its result."""
        return print_tree(self.tree, spacing="")

    def predict_proba(self, row):
        """Return the label -> frequency dict of the leaf that `row` reaches."""
        return predict_proba(self.tree, row)

### Train the model and make a prediction

In [12]:
# Fit the tree on the toy dataset and print its structure (printed output below).
clf = Decision_Tree_Classifier()
clf.train(training_data)
clf.print_tree()

 Is diameter >= 3 ?
     -->True:
     Is color == Yellow ?
         -->True:
         Predict {'Apple': 0.5, 'Lemon': 0.5}
         -->False:
         Predict {'Apple': 1.0}
     -->False:
     Predict {'Grape': 1.0}


In [13]:
# Class probabilities for the training row ['Yellow', 3, 'Apple']. Its features are
# identical to the ['Yellow', 3, 'Lemon'] row, so the tree cannot tell them apart.
clf.predict_proba(training_data[1])

{'Apple': 0.5, 'Lemon': 0.5}