In [1]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets

%matplotlib inline

In [2]:
# Load the Iris dataset and build one row per flower: features first, label last.
iris = datasets.load_iris()

X = iris.data
# Turn the target vector into a single column so it can be glued onto X.
y = iris.target[:, np.newaxis]
samples = np.concatenate([X, y], axis=1)
# Shuffles the rows in place. No seed is set, so the split changes on every run.
np.random.shuffle(samples)

# Hold out the last 15 shuffled rows for testing; everything before them is training data.
test_data = samples[-15:]
train_data = samples[:-15]
print(train_data.shape)

# Feature names; Question.__repr__ indexes into this list by column number.
header = iris.feature_names

(135, 5)


In [3]:
class Question(object):
    """A threshold test on one feature column.

    A sample matches when its entry at index `column` is greater than or
    equal to `value`.
    """

    def __init__(self, column, value):
        self.value = value
        self.column = column

    def match(self, sample):
        """Return the result of `sample[self.column] >= self.value`."""
        return sample[self.column] >= self.value

    def __repr__(self):
        # `header` is the module-level list of feature names.
        operator = ">="
        feature = header[self.column]
        return "Is %s %s %s?" % (feature, operator, str(self.value))

In [4]:
def partition(samples, question):
    """Split `samples` by whether each row satisfies `question`.

    Args:
        samples: iterable of rows.
        question: object exposing `match(row)`.

    Returns:
        A pair `(true_rows, false_rows)` of NumPy arrays, each keeping the
        original row order. A side that received no rows is `np.array([])`.
    """
    buckets = {True: [], False: []}
    for row in samples:
        # bool() collapses whatever truthy/falsy value match() returns into a dict key.
        buckets[bool(question.match(row))].append(row)
    return np.array(buckets[True]), np.array(buckets[False])

In [5]:
def class_counts(data):
    """Count how many rows of `data` carry each label.

    The label is read from the last column via `data[:, -1]`, so `data`
    must support 2-D NumPy-style indexing.

    Returns:
        dict mapping each label value to its number of occurrences.
    """
    tally = {}
    for label in data[:, -1]:
        tally[label] = tally.get(label, 0) + 1
    return tally

def gini(data):
    """Gini impurity of the labels in `data`: 1 minus the sum of squared class shares.

    Returns 0 when every row has the same label. If `class_counts` reports
    no labels at all, the starting value 1 is returned unchanged.
    """
    label_counts = class_counts(data)
    n_rows = len(data)
    impurity = 1
    # Subtract one squared share at a time, in the order class_counts yields them.
    for occurrences in label_counts.values():
        share = occurrences / float(n_rows)
        impurity -= share ** 2
    return impurity

def entropy(data):
    """Shannon entropy, in bits, of the labels in `data`.

    Computed as the negated sum of p * log2(p) over the class shares p
    reported by `class_counts`.
    """
    n_rows = len(data)
    label_counts = class_counts(data)
    bits = 0
    for occurrences in label_counts.values():
        share = occurrences / float(n_rows)
        bits -= share * np.log2(share)
    return bits

def info_gain(left, right, current):
    """Impurity reduction achieved by splitting a node into `left` and `right`.

    Despite the name, the child impurities are measured with `gini`, not
    entropy. Each child is weighted by its share of the combined rows.

    Args:
        left, right: the two partitions produced by a split.
        current: impurity of the node before splitting.
    """
    n_left = len(left)
    weight_left = float(n_left) / (n_left + len(right))
    weight_right = 1 - weight_left
    return current - weight_left * gini(left) - weight_right * gini(right)
    

In [6]:
def find_best_split(data):
    """Search every (column, observed value) pair for the split with the highest gain.

    Every column except the last (the label) is considered, and every
    distinct value in a column is tried as a `>=` threshold.

    Returns:
        `(best_gain, best_question)`. `best_question` is None when no
        candidate produced two non-empty partitions.
    """
    top_gain, top_question = 0, None
    # Impurity of the unsplit node; each candidate split is scored against it.
    parent_impurity = gini(data)
    n_features = len(data[0]) - 1
    for col in range(n_features):
        for threshold in set(data[:, col]):
            candidate = Question(col, threshold)
            matched, unmatched = partition(data, candidate)
            # A candidate only counts if it actually separates the rows.
            if len(matched) != 0 and len(unmatched) != 0:
                score = info_gain(matched, unmatched, parent_impurity)
                # `>=` means a later candidate wins ties, including zero-gain ones.
                if score >= top_gain:
                    top_gain, top_question = score, candidate
    return top_gain, top_question

In [7]:
# Sanity check: best first split over the whole training set.
gain, question = find_best_split(train_data)

In [8]:
class DecisionNode(object):
    """Internal tree node: a question plus the subtree for each answer."""

    def __init__(self, question, true_branch, false_branch):
        # Subtrees followed when the question matches / does not match a sample.
        self.true_branch = true_branch
        self.false_branch = false_branch
        # The test that decides which subtree a sample goes to.
        self.question = question

class Leaf(object):
    """Terminal tree node holding the label counts of the samples that reached it."""

    def __init__(self, samples):
        # Maps each label found in the last column of `samples` to its count.
        self.predictions = class_counts(samples)

    def predict_prob(self):
        """Return the leaf's label distribution as whole-percent strings.

        Returns:
            dict mapping each label to a string such as "66%". Percentages
            are floored to a whole number, so they may sum to less than 100.
        """
        total = sum(self.predictions.values())
        probs = {}
        for lbl, count in self.predictions.items():
            # Multiply before dividing and floor with `//`. The previous
            # `int(count / total * 100)` went through a float that could land
            # just below the true value (29 of 100 -> 28.999... -> "28%").
            probs[lbl] = str(int(count * 100 // total)) + "%"
        return probs

In [9]:
def build_tree(samples):
    """Recursively grow a decision tree from `samples`.

    Returns a `Leaf` when the best available split has zero gain, and
    otherwise a `DecisionNode` whose two subtrees are built from the
    matching and non-matching rows.
    """
    best_gain, best_question = find_best_split(samples)
    if best_gain != 0:
        matched, unmatched = partition(samples, best_question)
        # Arguments are evaluated left to right: the true subtree is built first.
        return DecisionNode(best_question, build_tree(matched), build_tree(unmatched))
    return Leaf(samples)

def print_tree(node, spacing=""):
    """Print the tree rooted at `node`, indenting each level by two spaces.

    Leaves print their label counts; decision nodes print their question
    followed by the true subtree and then the false subtree.
    """
    if not isinstance(node, Leaf):
        deeper = spacing + "  "
        print(spacing + str(node.question))

        print(spacing + '--> True:')
        print_tree(node.true_branch, deeper)

        print(spacing + '--> False:')
        print_tree(node.false_branch, deeper)
    else:
        print(spacing + "Predict", node.predictions)

In [10]:
# Grow the full decision tree from the training rows.
my_tree = build_tree(train_data)

In [11]:
# Print the learned tree as indented text.
print_tree(my_tree)

Is petal width (cm) >= 1.0?
--> True:
  Is petal width (cm) >= 1.8?
  --> True:
    Is petal length (cm) >= 4.9?
    --> True:
      Predict {2.0: 39}
    --> False:
      Is sepal width (cm) >= 3.2?
      --> True:
        Predict {1.0: 1}
      --> False:
        Predict {2.0: 2}
  --> False:
    Is petal length (cm) >= 5.1?
    --> True:
      Is sepal width (cm) >= 2.8?
      --> True:
        Predict {2.0: 2}
      --> False:
        Predict {1.0: 1}
    --> False:
      Is sepal length (cm) >= 5.0?
      --> True:
        Predict {1.0: 41}
      --> False:
        Is petal width (cm) >= 1.7?
        --> True:
          Predict {2.0: 1}
        --> False:
          Predict {1.0: 1}
--> False:
  Predict {0.0: 47}


In [12]:
def classify(sample, node):
    """Route `sample` down the tree from `node` and return the reached leaf's prediction.

    At each decision node the sample follows the true branch when the
    node's question matches it and the false branch otherwise.

    Returns:
        Whatever the final leaf's `predict_prob()` returns.
    """
    current = node
    while not isinstance(current, Leaf):
        if current.question.match(sample):
            current = current.true_branch
        else:
            current = current.false_branch
    return current.predict_prob()

In [13]:
# For each held-out row, show its true label (last column) next to the
# label distribution of the leaf the row is routed to.
for sample in test_data:
    print("Ground Truth: {}. Prediction: {}".format(sample[-1], classify(sample, my_tree)))

Ground Truth: 0.0. Prediction: {0.0: '100%'}
Ground Truth: 1.0. Prediction: {1.0: '100%'}
Ground Truth: 2.0. Prediction: {1.0: '100%'}
Ground Truth: 0.0. Prediction: {0.0: '100%'}
Ground Truth: 1.0. Prediction: {1.0: '100%'}
Ground Truth: 2.0. Prediction: {2.0: '100%'}
Ground Truth: 2.0. Prediction: {1.0: '100%'}
Ground Truth: 1.0. Prediction: {1.0: '100%'}
Ground Truth: 1.0. Prediction: {1.0: '100%'}
Ground Truth: 2.0. Prediction: {2.0: '100%'}
Ground Truth: 2.0. Prediction: {2.0: '100%'}
Ground Truth: 2.0. Prediction: {2.0: '100%'}
Ground Truth: 0.0. Prediction: {0.0: '100%'}
Ground Truth: 1.0. Prediction: {1.0: '100%'}
Ground Truth: 1.0. Prediction: {1.0: '100%'}
