Code to accompany Machine Learning Recipes #8. We'll write a Decision Tree Classifier, in pure Python. Below each of the methods, I've written a little demo to help explain what it does.

In [194]:
# For Python 2 / 3 compatability
from __future__ import print_function
import random

In [195]:
# Toy dataset.
# Format: each row is an example.
# The last column is the label.
# The first two columns are features.
# Feel free to play with it by adding more features & examples.
# Interesting note: I've written this so the 2nd and 5th examples
# have the same features, but different labels - so we can see how the
# tree handles this case.
training_data = [
    [54.4, 14.4, 'Kylrum'],    
    [45.4, 12.4, 'Kylrum'],   
    [89.4, 19.5, 'Klassrum'],
    [57.4, 18.1, 'Lärarrum'],    
    [22.4, 8.6, 'Kylrum'],   
    [24.4, 11.24, 'Kylrum'],   
    [84.4, 24.4,'Klassrum'],    
    [95.4, 22.4, 'Klassrum'],   
    [81.4, 20.1, 'Lärarrum'],  
    [70.5, 19.7, 'Lärarrum'],
    [70.5, 19.8,'Klassrum'], 
]

In [196]:
def generate_data2(times):
    while times != 0:
        randomRow = training_data[random.randint(0, len(training_data) - 1)]
        new_row = [randomRow[0]+ round(random.uniform(-5, 5)), randomRow[1] + round(random.uniform(-3, 3)), randomRow[2]]
        training_data.append(new_row)
        
        times += -1

In [197]:
generate_data2(2000)

In [198]:
# Column labels.
# These are used only to print the tree.
header = ["size", "temp", "label"]

In [199]:
def unique_vals(rows, col):
    """Find the unique values for a column in a dataset."""
    return set([row[col] for row in rows])

In [200]:
#######
# Demo:
unique_vals(training_data, 0)
# unique_vals(training_data, 1)
#######

{3.3999999999999986,
 4.399999999999999,
 5.399999999999999,
 6.399999999999999,
 7.399999999999999,
 8.399999999999999,
 9.399999999999999,
 10.399999999999999,
 11.399999999999999,
 12.399999999999999,
 13.399999999999999,
 14.399999999999999,
 15.399999999999999,
 16.4,
 17.4,
 18.4,
 19.4,
 20.4,
 21.4,
 22.4,
 23.4,
 24.4,
 25.4,
 26.4,
 27.4,
 28.4,
 29.4,
 30.4,
 31.4,
 32.4,
 33.4,
 34.4,
 35.4,
 36.4,
 37.4,
 38.4,
 39.4,
 40.4,
 41.4,
 42.4,
 43.4,
 44.4,
 45.4,
 46.4,
 47.4,
 48.4,
 49.4,
 50.4,
 51.4,
 52.4,
 53.4,
 54.4,
 55.4,
 55.5,
 56.4,
 57.4,
 57.5,
 58.4,
 58.5,
 59.4,
 59.5,
 60.4,
 60.5,
 61.4,
 61.5,
 62.4,
 62.5,
 63.4,
 63.5,
 64.4,
 64.5,
 65.5,
 66.5,
 67.5,
 68.4,
 68.5,
 69.5,
 70.4,
 70.5,
 71.4,
 71.5,
 72.4,
 72.5,
 73.4,
 73.5,
 74.4,
 74.5,
 75.4,
 75.5,
 76.4,
 76.5,
 77.4,
 77.5,
 78.4,
 78.5,
 79.4,
 79.5,
 80.4,
 80.5,
 81.4,
 81.5,
 82.4,
 82.5,
 83.4,
 84.4,
 84.5,
 85.4,
 86.4,
 87.4,
 88.4,
 89.4,
 90.4,
 91.4,
 92.4,
 93.4,
 94.4,
 95.4,
 96.4

In [201]:
def class_counts(rows):
    """Counts the number of each type of example in a dataset."""
    counts = {}  # a dictionary of label -> count.
    for row in rows:
        # in our dataset format, the label is always the last column
        label = row[-1]
        if label not in counts:
            counts[label] = 0
        counts[label] += 1
    return counts

In [202]:
#######
# Demo:
class_counts(training_data)
#######

{'Kylrum': 913, 'Klassrum': 616, 'Lärarrum': 482}

In [203]:
def is_numeric(value):
    """Test if a value is numeric."""
    return isinstance(value, int) or isinstance(value, float)

In [204]:
#######
# Demo:
is_numeric(7)
# is_numeric("Red")
#######

True

In [205]:
class Question:
    """A Question is used to partition a dataset.

    This class just records a 'column number' (e.g., 0 for Color) and a
    'column value' (e.g., Green). The 'match' method is used to compare
    the feature value in an example to the feature value stored in 
    question. See the demo below.
    """

    def __init__(self, column, value):
        self.column = column
        self.value = value

    def match(self, example):
        # Compare the feature value in an example to the
        # feature value in this question.
        val = example[self.column]
        if is_numeric(val):
            return val >= self.value
        else:
            return val == self.value

    def __repr__(self):
        # This is just a helper method to print
        # the question in a readable format.
        condition = "=="
        if is_numeric(self.value):
            condition = ">="
        return "Is %s %s %s?" % (
            header[self.column], condition, str(self.value))

In [206]:
#######
# Demo:
# Let's write a question for a numeric attribute
Question(0, 60)

Is size >= 60?

In [207]:
# How about one for a categorical attribute
q = Question(2, 'kylrum')
q

Is label == kylrum?

In [208]:
# Let's pick an example from the training set...
example = training_data[0]
# ... and see if it matches the question
q.match(example) # this will be true, since the first example is Green.
#######

False

In [209]:
def partition(rows, question):
    """Partitions a dataset.

    For each row in the dataset, check if it matches the question. If
    so, add it to 'true rows', otherwise, add it to 'false rows'.
    """
    true_rows, false_rows = [], []
    for row in rows:
        if question.match(row):
            true_rows.append(row)
        else:
            false_rows.append(row)
    return true_rows, false_rows

In [210]:
#######
# Demo:
# Let's partition the training data based on whether rows are Red.
true_rows, false_rows = partition(training_data, Question(0, 60))
# This will contain all the rows bigger than 60
true_rows

[[89.4, 19.5, 'Klassrum'],
 [84.4, 24.4, 'Klassrum'],
 [95.4, 22.4, 'Klassrum'],
 [81.4, 20.1, 'Lärarrum'],
 [70.5, 19.7, 'Lärarrum'],
 [70.5, 19.8, 'Klassrum'],
 [93.4, 22.4, 'Klassrum'],
 [88.4, 22.5, 'Klassrum'],
 [90.4, 20.5, 'Klassrum'],
 [84.4, 18.1, 'Lärarrum'],
 [82.4, 18.1, 'Lärarrum'],
 [67.5, 18.7, 'Lärarrum'],
 [78.4, 19.1, 'Lärarrum'],
 [74.5, 19.8, 'Klassrum'],
 [91.4, 21.5, 'Klassrum'],
 [93.4, 21.5, 'Klassrum'],
 [88.4, 20.1, 'Lärarrum'],
 [84.4, 21.5, 'Klassrum'],
 [94.4, 22.5, 'Klassrum'],
 [89.4, 20.5, 'Klassrum'],
 [63.5, 17.7, 'Lärarrum'],
 [76.5, 18.8, 'Klassrum'],
 [70.5, 21.7, 'Lärarrum'],
 [84.4, 20.1, 'Lärarrum'],
 [88.4, 24.4, 'Klassrum'],
 [75.5, 22.8, 'Klassrum'],
 [72.5, 21.8, 'Klassrum'],
 [85.4, 18.1, 'Lärarrum'],
 [83.4, 18.1, 'Lärarrum'],
 [84.4, 22.1, 'Lärarrum'],
 [81.4, 16.1, 'Lärarrum'],
 [67.5, 15.7, 'Lärarrum'],
 [70.5, 17.7, 'Lärarrum'],
 [86.4, 22.4, 'Klassrum'],
 [92.4, 24.4, 'Klassrum'],
 [78.4, 22.1, 'Lärarrum'],
 [87.4, 20.1, 'Lärarrum'],
 

In [211]:
# This will contain everything else.
false_rows
#######

[[54.4, 14.4, 'Kylrum'],
 [45.4, 12.4, 'Kylrum'],
 [57.4, 18.1, 'Lärarrum'],
 [22.4, 8.6, 'Kylrum'],
 [24.4, 11.24, 'Kylrum'],
 [23.4, 5.6, 'Kylrum'],
 [18.4, 4.6, 'Kylrum'],
 [26.4, 11.6, 'Kylrum'],
 [41.4, 11.4, 'Kylrum'],
 [20.4, 3.5999999999999996, 'Kylrum'],
 [16.4, 5.6, 'Kylrum'],
 [43.4, 13.4, 'Kylrum'],
 [20.4, 2.5999999999999996, 'Kylrum'],
 [43.4, 12.4, 'Kylrum'],
 [53.4, 13.4, 'Kylrum'],
 [18.4, 0.5999999999999996, 'Kylrum'],
 [16.4, 6.6, 'Kylrum'],
 [17.4, -2.4000000000000004, 'Kylrum'],
 [44.4, 13.4, 'Kylrum'],
 [38.4, 13.4, 'Kylrum'],
 [13.399999999999999, 6.6, 'Kylrum'],
 [19.4, 11.6, 'Kylrum'],
 [19.4, -1.4000000000000004, 'Kylrum'],
 [41.4, 9.4, 'Kylrum'],
 [22.4, 5.6, 'Kylrum'],
 [38.4, 11.4, 'Kylrum'],
 [19.4, 1.5999999999999996, 'Kylrum'],
 [51.4, 13.4, 'Kylrum'],
 [16.4, 13.6, 'Kylrum'],
 [27.4, 9.6, 'Kylrum'],
 [23.4, 0.5999999999999996, 'Kylrum'],
 [39.4, 13.4, 'Kylrum'],
 [47.4, 10.4, 'Kylrum'],
 [17.4, 6.6, 'Kylrum'],
 [15.399999999999999, 1.5999999999999996, '

In [212]:
def gini(rows):
    """Calculate the Gini Impurity for a list of rows.

    There are a few different ways to do this, I thought this one was
    the most concise. See:
    https://en.wikipedia.org/wiki/Decision_tree_learning#Gini_impurity
    """
    counts = class_counts(rows)
    impurity = 1
    for lbl in counts:
        prob_of_lbl = counts[lbl] / float(len(rows))
        impurity -= prob_of_lbl**2
    return impurity

In [213]:
#######
# Demo:
# Let's look at some example to understand how Gini Impurity works.
#
# First, we'll look at a dataset with no mixing.
no_mixing = [['Kylrum'],
              ['Kylrum']]
# this will return 0
gini(no_mixing)

0.0

In [214]:
# Now, we'll look at dataset with a 50:50 apples:oranges ratio
some_mixing = [['Kylrum'],
               ['Klassrum']]
# this will return 0.5 - meaning, there's a 50% chance of misclassifying
# a random example we draw from the dataset.
gini(some_mixing)

0.5

In [215]:
# Now, we'll look at a dataset with many different labels
lots_of_mixing = [['Kylrum'],
                  ['Klassrum'],
                  ['Lärarrum']]
                  
# This will return 0.8
gini(lots_of_mixing)


0.6666666666666665

In [216]:
def info_gain(left, right, current_uncertainty):
    """Information Gain.

    The uncertainty of the starting node, minus the weighted impurity of
    two child nodes.
    """
    p = float(len(left)) / (len(left) + len(right))
    return current_uncertainty - p * gini(left) - (1 - p) * gini(right)

In [217]:
#######
# Demo:
# Calculate the uncertainy of our training data.
current_uncertainty = gini(training_data)
current_uncertainty

0.6426049072221133

In [218]:
# How much information do we gain by partioning on 'Green'?
true_rows, false_rows = partition(training_data, Question(0, 57))
info_gain(true_rows, false_rows, current_uncertainty)

0.34226232548924396

In [219]:
# What about if we partioned on 'Red' instead?
true_rows, false_rows = partition(training_data, Question(0, 70))
info_gain(true_rows, false_rows, current_uncertainty)

0.2741462542175305

In [220]:
# It looks like we learned more using 'Red' (0.37), than 'Green' (0.14).
# Why? Look at the different splits that result, and see which one
# looks more 'unmixed' to you.
true_rows, false_rows = partition(training_data, Question(0,57))

# Here, the true_rows contain only 'Grapes'.
true_rows

[[89.4, 19.5, 'Klassrum'],
 [57.4, 18.1, 'Lärarrum'],
 [84.4, 24.4, 'Klassrum'],
 [95.4, 22.4, 'Klassrum'],
 [81.4, 20.1, 'Lärarrum'],
 [70.5, 19.7, 'Lärarrum'],
 [70.5, 19.8, 'Klassrum'],
 [93.4, 22.4, 'Klassrum'],
 [88.4, 22.5, 'Klassrum'],
 [90.4, 20.5, 'Klassrum'],
 [84.4, 18.1, 'Lärarrum'],
 [82.4, 18.1, 'Lärarrum'],
 [67.5, 18.7, 'Lärarrum'],
 [78.4, 19.1, 'Lärarrum'],
 [74.5, 19.8, 'Klassrum'],
 [91.4, 21.5, 'Klassrum'],
 [93.4, 21.5, 'Klassrum'],
 [88.4, 20.1, 'Lärarrum'],
 [84.4, 21.5, 'Klassrum'],
 [94.4, 22.5, 'Klassrum'],
 [89.4, 20.5, 'Klassrum'],
 [63.5, 17.7, 'Lärarrum'],
 [76.5, 18.8, 'Klassrum'],
 [70.5, 21.7, 'Lärarrum'],
 [84.4, 20.1, 'Lärarrum'],
 [88.4, 24.4, 'Klassrum'],
 [75.5, 22.8, 'Klassrum'],
 [72.5, 21.8, 'Klassrum'],
 [85.4, 18.1, 'Lärarrum'],
 [83.4, 18.1, 'Lärarrum'],
 [84.4, 22.1, 'Lärarrum'],
 [81.4, 16.1, 'Lärarrum'],
 [67.5, 15.7, 'Lärarrum'],
 [70.5, 17.7, 'Lärarrum'],
 [86.4, 22.4, 'Klassrum'],
 [92.4, 24.4, 'Klassrum'],
 [78.4, 22.1, 'Lärarrum'],
 

In [221]:
# And the false rows contain two types of fruit. Not too bad.
false_rows

[[54.4, 14.4, 'Kylrum'],
 [45.4, 12.4, 'Kylrum'],
 [22.4, 8.6, 'Kylrum'],
 [24.4, 11.24, 'Kylrum'],
 [23.4, 5.6, 'Kylrum'],
 [18.4, 4.6, 'Kylrum'],
 [26.4, 11.6, 'Kylrum'],
 [41.4, 11.4, 'Kylrum'],
 [20.4, 3.5999999999999996, 'Kylrum'],
 [16.4, 5.6, 'Kylrum'],
 [43.4, 13.4, 'Kylrum'],
 [20.4, 2.5999999999999996, 'Kylrum'],
 [43.4, 12.4, 'Kylrum'],
 [53.4, 13.4, 'Kylrum'],
 [18.4, 0.5999999999999996, 'Kylrum'],
 [16.4, 6.6, 'Kylrum'],
 [17.4, -2.4000000000000004, 'Kylrum'],
 [44.4, 13.4, 'Kylrum'],
 [38.4, 13.4, 'Kylrum'],
 [13.399999999999999, 6.6, 'Kylrum'],
 [19.4, 11.6, 'Kylrum'],
 [19.4, -1.4000000000000004, 'Kylrum'],
 [41.4, 9.4, 'Kylrum'],
 [22.4, 5.6, 'Kylrum'],
 [38.4, 11.4, 'Kylrum'],
 [19.4, 1.5999999999999996, 'Kylrum'],
 [51.4, 13.4, 'Kylrum'],
 [16.4, 13.6, 'Kylrum'],
 [27.4, 9.6, 'Kylrum'],
 [23.4, 0.5999999999999996, 'Kylrum'],
 [39.4, 13.4, 'Kylrum'],
 [47.4, 10.4, 'Kylrum'],
 [17.4, 6.6, 'Kylrum'],
 [15.399999999999999, 1.5999999999999996, 'Kylrum'],
 [21.4, -1.400000

In [222]:
# On the other hand, partitioning by Green doesn't help so much.
true_rows, false_rows = partition(training_data, Question(0, 30))

# We've isolated one apple in the true rows.
true_rows

[[54.4, 14.4, 'Kylrum'],
 [45.4, 12.4, 'Kylrum'],
 [89.4, 19.5, 'Klassrum'],
 [57.4, 18.1, 'Lärarrum'],
 [84.4, 24.4, 'Klassrum'],
 [95.4, 22.4, 'Klassrum'],
 [81.4, 20.1, 'Lärarrum'],
 [70.5, 19.7, 'Lärarrum'],
 [70.5, 19.8, 'Klassrum'],
 [93.4, 22.4, 'Klassrum'],
 [88.4, 22.5, 'Klassrum'],
 [90.4, 20.5, 'Klassrum'],
 [84.4, 18.1, 'Lärarrum'],
 [82.4, 18.1, 'Lärarrum'],
 [67.5, 18.7, 'Lärarrum'],
 [78.4, 19.1, 'Lärarrum'],
 [41.4, 11.4, 'Kylrum'],
 [74.5, 19.8, 'Klassrum'],
 [91.4, 21.5, 'Klassrum'],
 [93.4, 21.5, 'Klassrum'],
 [88.4, 20.1, 'Lärarrum'],
 [84.4, 21.5, 'Klassrum'],
 [43.4, 13.4, 'Kylrum'],
 [94.4, 22.5, 'Klassrum'],
 [43.4, 12.4, 'Kylrum'],
 [53.4, 13.4, 'Kylrum'],
 [89.4, 20.5, 'Klassrum'],
 [63.5, 17.7, 'Lärarrum'],
 [76.5, 18.8, 'Klassrum'],
 [70.5, 21.7, 'Lärarrum'],
 [84.4, 20.1, 'Lärarrum'],
 [44.4, 13.4, 'Kylrum'],
 [88.4, 24.4, 'Klassrum'],
 [38.4, 13.4, 'Kylrum'],
 [75.5, 22.8, 'Klassrum'],
 [72.5, 21.8, 'Klassrum'],
 [85.4, 18.1, 'Lärarrum'],
 [83.4, 18.1, 'Lä

In [223]:
# But, the false-rows are badly mixed up.
false_rows
#######

[[22.4, 8.6, 'Kylrum'],
 [24.4, 11.24, 'Kylrum'],
 [23.4, 5.6, 'Kylrum'],
 [18.4, 4.6, 'Kylrum'],
 [26.4, 11.6, 'Kylrum'],
 [20.4, 3.5999999999999996, 'Kylrum'],
 [16.4, 5.6, 'Kylrum'],
 [20.4, 2.5999999999999996, 'Kylrum'],
 [18.4, 0.5999999999999996, 'Kylrum'],
 [16.4, 6.6, 'Kylrum'],
 [17.4, -2.4000000000000004, 'Kylrum'],
 [13.399999999999999, 6.6, 'Kylrum'],
 [19.4, 11.6, 'Kylrum'],
 [19.4, -1.4000000000000004, 'Kylrum'],
 [22.4, 5.6, 'Kylrum'],
 [19.4, 1.5999999999999996, 'Kylrum'],
 [16.4, 13.6, 'Kylrum'],
 [27.4, 9.6, 'Kylrum'],
 [23.4, 0.5999999999999996, 'Kylrum'],
 [17.4, 6.6, 'Kylrum'],
 [15.399999999999999, 1.5999999999999996, 'Kylrum'],
 [21.4, -1.4000000000000004, 'Kylrum'],
 [20.4, -0.40000000000000036, 'Kylrum'],
 [18.4, 8.6, 'Kylrum'],
 [20.4, 1.5999999999999996, 'Kylrum'],
 [18.4, 3.5999999999999996, 'Kylrum'],
 [24.4, 2.5999999999999996, 'Kylrum'],
 [24.4, 9.6, 'Kylrum'],
 [15.399999999999999, 5.6, 'Kylrum'],
 [14.399999999999999, -2.4000000000000004, 'Kylrum'],
 [2

In [224]:
def find_best_split(rows):
    """Find the best question to ask by iterating over every feature / value
    and calculating the information gain."""
    best_gain = 0  # keep track of the best information gain
    best_question = None  # keep train of the feature / value that produced it
    current_uncertainty = gini(rows)
    n_features = len(rows[0]) - 1  # number of columns

    for col in range(n_features):  # for each feature

        values = set([row[col] for row in rows])  # unique values in the column

        for val in values:  # for each value

            question = Question(col, val)

            # try splitting the dataset
            true_rows, false_rows = partition(rows, question)

            # Skip this split if it doesn't divide the
            # dataset.
            if len(true_rows) == 0 or len(false_rows) == 0:
                continue

            # Calculate the information gain from this split
            gain = info_gain(true_rows, false_rows, current_uncertainty)

            # You actually can use '>' instead of '>=' here
            # but I wanted the tree to look a certain way for our
            # toy dataset.
            if gain >= best_gain:
                best_gain, best_question = gain, question

    return best_gain, best_question

In [225]:
#######
# Demo:
# Find the best question to ask first for our toy dataset.
best_gain, best_question = find_best_split(training_data)
best_question
# FYI: is color == Red is just as good. See the note in the code above
# where I used '>='.
#######

Is size >= 60.5?

In [226]:
class Leaf:
    """A Leaf node classifies data.

    This holds a dictionary of class (e.g., "Apple") -> number of times
    it appears in the rows from the training data that reach this leaf.
    """

    def __init__(self, rows):
        self.predictions = class_counts(rows)

In [227]:
class Decision_Node:
    """A Decision Node asks a question.

    This holds a reference to the question, and to the two child nodes.
    """

    def __init__(self,
                 question,
                 true_branch,
                 false_branch):
        self.question = question
        self.true_branch = true_branch
        self.false_branch = false_branch

In [228]:
def build_tree(rows):
    """Builds the tree.

    Rules of recursion: 1) Believe that it works. 2) Start by checking
    for the base case (no further information gain). 3) Prepare for
    giant stack traces.
    """

    # try partitioning the dataset on each of the unique attribute,
    # calculate the information gain,
    # and return the question that produces the highest gain.
    gain, question = find_best_split(rows)

    # Base case: no further info gain
    # Since we can ask no further questions,
    # we'll return a leaf.
    if gain == 0:
        return Leaf(rows)

    # If we reach here, we have found a useful feature / value
    # to partition on.
    true_rows, false_rows = partition(rows, question)

    # Recursively build the true branch.
    true_branch = build_tree(true_rows)

    # Recursively build the false branch.
    false_branch = build_tree(false_rows)

    # Return a Question node.
    # This records the best feature / value to ask at this point,
    # as well as the branches to follow
    # dependingo on the answer.
    return Decision_Node(question, true_branch, false_branch)

In [229]:
def print_tree(node, spacing=""):
    """World's most elegant tree printing function."""

    # Base case: we've reached a leaf
    if isinstance(node, Leaf):
        print (spacing + "Predict", node.predictions)
        return

    # Print the question at this node
    print (spacing + str(node.question))

    # Call this function recursively on the true branch
    print (spacing + '--> True:')
    print_tree(node.true_branch, spacing + "  ")

    # Call this function recursively on the false branch
    print (spacing + '--> False:')
    print_tree(node.false_branch, spacing + "  ")

In [230]:
my_tree = build_tree(training_data)

In [231]:
print_tree(my_tree)

Is size >= 60.5?
--> True:
  Is size >= 90.4?
  --> True:
    Is temp >= 20.4?
    --> True:
      Is size >= 91.4?
      --> True:
        Is size >= 94.4?
        --> True:
          Predict {'Klassrum': 86}
        --> False:
          Is temp >= 23.4?
          --> True:
            Predict {'Klassrum': 26}
          --> False:
            Is temp >= 23.1?
            --> True:
              Predict {'Lärarrum': 1}
            --> False:
              Is temp >= 21.4?
              --> True:
                Is size >= 93.4?
                --> True:
                  Is temp >= 22.4?
                  --> True:
                    Predict {'Klassrum': 4}
                  --> False:
                    Is temp >= 22.1?
                    --> True:
                      Predict {'Lärarrum': 1}
                    --> False:
                      Predict {'Klassrum': 3}
                --> False:
                  Predict {'Klassrum': 18}
              --> False:
                Is 

In [232]:
def classify(row, node):
    """See the 'rules of recursion' above."""

    # Base case: we've reached a leaf
    if isinstance(node, Leaf):
        return node.predictions

    # Decide whether to follow the true-branch or the false-branch.
    # Compare the feature / value stored in the node,
    # to the example we're considering.
    if node.question.match(row):
        return classify(row, node.true_branch)
    else:
        return classify(row, node.false_branch)

In [233]:
#######
# Demo:
# The tree predicts the 1st row of our
# training data is an apple with confidence 1.
classify(training_data[0], my_tree)
#######

{'Kylrum': 869}

In [234]:
def print_leaf(counts):
    """A nicer way to print the predictions at a leaf."""
    total = sum(counts.values()) * 1.0
    probs = {}
    for lbl in counts.keys():
        probs[lbl] = str(int(counts[lbl] / total * 100)) + "%"
    return probs

In [235]:
#######
# Demo:
# Printing that a bit nicer
print_leaf(classify(training_data[0], my_tree))
#######

{'Kylrum': '100%'}

In [236]:
#######
# Demo:
# On the second example, the confidence is lower
print_leaf(classify(training_data[1], my_tree))
#######

{'Kylrum': '100%'}

In [237]:
# Evaluate
testing_data = [   
    [30, 15.6, 'Kylrum'],
    [50, 16, 'Klassrum'],
    [89, 21, 'Klassrum'],
    [81.7, 19,'Lärarrum'],
    [87, 19.8, 'Lärarrum'],
]

In [239]:
def generate_testingdata(times):
    while times != 0:
        randomRow = testing_data[random.randint(0, len(testing_data) - 1)]
        new_row = [randomRow[0]+ round(random.uniform(-5, 5)), randomRow[1] + round(random.uniform(-3, 3)), randomRow[2]]
        testing_data.append(new_row)
        times += -1

In [240]:
generate_testingdata(20)

In [241]:
for row in testing_data:
    print ("Actual: %s. Predicted: %s" %
           (row[-1], print_leaf(classify(row, my_tree))))

Actual: Kylrum. Predicted: {'Kylrum': '100%'}
Actual: Klassrum. Predicted: {'Kylrum': '100%'}
Actual: Klassrum. Predicted: {'Klassrum': '100%'}
Actual: Lärarrum. Predicted: {'Klassrum': '100%'}
Actual: Lärarrum. Predicted: {'Klassrum': '100%'}
Actual: Kylrum. Predicted: {'Kylrum': '100%'}
Actual: Lärarrum. Predicted: {'Klassrum': '100%'}
Actual: Lärarrum. Predicted: {'Klassrum': '100%'}
Actual: Kylrum. Predicted: {'Kylrum': '100%'}
Actual: Kylrum. Predicted: {'Kylrum': '100%'}
Actual: Lärarrum. Predicted: {'Klassrum': '100%'}
Actual: Klassrum. Predicted: {'Lärarrum': '100%'}
Actual: Lärarrum. Predicted: {'Klassrum': '100%'}
Actual: Lärarrum. Predicted: {'Lärarrum': '100%'}
Actual: Lärarrum. Predicted: {'Klassrum': '100%'}
Actual: Lärarrum. Predicted: {'Klassrum': '100%'}
Actual: Lärarrum. Predicted: {'Klassrum': '100%'}
Actual: Kylrum. Predicted: {'Kylrum': '100%'}
Actual: Kylrum. Predicted: {'Kylrum': '100%'}
Actual: Klassrum. Predicted: {'Kylrum': '100%'}
Actual: Kylrum. Predicted: {