In [1]:
# Toy training set: one row per example, laid out as [color, diameter, label].
# The label is always the last column; the columns before it are features.
training_data = [
    ['Green', 3, 'Apple'],
    ['Yellow', 3, 'Apple'],
    ['Red', 1, 'Grape'],
    ['Red', 1, 'Grape'],
    ['Yellow', 3, 'Lemon'],
]

In [2]:
# Column names for the rows in training_data, in column order.
# Question.__repr__ indexes this list to print a readable question.
header = ["color", "diameter", "label"]

In [8]:
def unique_vals(rows, col):
    """Return the set of distinct values that appear in column `col` of `rows`."""
    return {record[col] for record in rows}

# Distinct values per feature column. The printed labels are 1-based
# ("COL 1", "COL 2") while the column indices passed in are 0-based (0, 1).
print("COL 1", unique_vals(training_data, 0))
print("COL 2", unique_vals(training_data, 1))

COL 1 {'Red', 'Yellow', 'Green'}
COL 2 {1, 3}


In [9]:
def class_counts(rows):
    """Tally how many rows carry each label.

    The label is read from the last element of every row. Returns a plain
    dict mapping label -> number of occurrences, keyed in first-seen order.
    """
    tally = {}
    for record in rows:
        lbl = record[-1]  # the label always sits in the final column
        tally[lbl] = tally.get(lbl, 0) + 1
    return tally

# Label distribution of the full training set.
print("CLASS", class_counts(training_data))

CLASS {'Apple': 2, 'Grape': 2, 'Lemon': 1}


In [10]:
def is_numeric(value):
    """Return True when `value` is an int or a float.

    bool is a subclass of int, so True/False also count as numeric here.
    """
    return isinstance(value, (int, float))

In [11]:
class Question:
    """
    A Question is used to partition a dataset.

    This class just records a 'column number' (e.g., 0 for Color) and a
    'column value' (e.g., Green). The 'match' method is used to compare
    the feature value in an example to the feature value stored in the
    question. See the demo below.

    Numeric values are treated as thresholds (example >= value); anything
    else is tested for equality (example == value).
    """

    def __init__(self, column, value):
        # column: index of the feature inside a row.
        # value:  the threshold (numeric) or category to compare against.
        self.column = column
        self.value = value

    def match(self, example):
        """
        Return whether `example` (a whole row) satisfies this question.

        The row's value in `self.column` is compared with `self.value`:
        with `>=` when both are numeric, with `==` otherwise.
        """
        val = example[self.column]
        # Only use the ordering comparison when BOTH sides are numeric.
        # Checking just the example's value would attempt e.g. 3 >= "Green",
        # which raises TypeError; mixed types fall back to equality instead.
        if is_numeric(val) and is_numeric(self.value):
            return val >= self.value
        return val == self.value

    def __repr__(self):
        # This is just a helper method to print
        # the question in a readable format.
        # The column name is looked up in the module-level `header` list.
        condition = "=="
        if is_numeric(self.value):
            condition = ">="
        return "Is %s %s %s?" % (
            header[self.column], condition, str(self.value))

In [18]:
# Numeric question: compares column 1 (diameter) against 3.
print(Question(1, 3))

# Categorical question: checks whether column 0 (color) equals "Green".
is_green = Question(0, "Green")
print(is_green)

Is diameter >= 3?
Is color == Green?


In [19]:
# Take the first training row and test it against the "is green" question.
match_ex = training_data[0]  # first row
print(match_ex)

# match() receives the whole row and reads the relevant column itself.
# Uses `is_green` from the previous cell; the old name `q` was never defined,
# so the cell failed with NameError on a fresh kernel.
is_green.match(match_ex)

['Green', 3, 'Apple']


True

In [16]:
def partition(rows, question):
    """
    Split `rows` in two according to `question`.

    Each row for which `question.match(row)` is truthy is placed in the
    first list; every other row is placed in the second. The relative
    order of rows is preserved in both lists.

    Returns the pair (matching_rows, non_matching_rows).
    """
    matched, unmatched = [], []
    for row in rows:
        # Pick the destination list first, then append once.
        bucket = matched if question.match(row) else unmatched
        bucket.append(row)
    return matched, unmatched

In [25]:
# Split the training data on color. `true` / `false` are ordinary variables
# (not the builtins True / False) holding the matching and non-matching rows;
# the next partition demo rebinds both names.
true, false = partition(training_data, is_green)
print("Which rows are green?")
print("TRUE", true)
print("FALSE", false)

Which rows are green?
TRUE [['Green', 3, 'Apple']]
FALSE [['Yellow', 3, 'Apple'], ['Red', 1, 'Grape'], ['Red', 1, 'Grape'], ['Yellow', 3, 'Lemon']]


In [29]:
# Question(1, 3) asks "diameter >= 3", so the variable is named for what it
# actually tests (the previous name said "less than 3").
is_diameter_at_least_3 = Question(1, 3)
print(is_diameter_at_least_3)

# The names `true` / `false` are kept because the Gini demo further down
# reads `true` from this cell.
true, false = partition(training_data, is_diameter_at_least_3)
print("TRUE", true)
print("FALSE", false)

Is diameter >= 3?
TRUE [['Green', 3, 'Apple'], ['Yellow', 3, 'Apple'], ['Yellow', 3, 'Lemon']]
FALSE [['Red', 1, 'Grape'], ['Red', 1, 'Grape']]


In [37]:
def gini(predicted_classes):
    """
    Calculate the Gini Impurity for a list of rows.

    Despite its name, `predicted_classes` is a list of dataset rows; the
    label of each row is read from its last column by class_counts().
    The result is 1 minus the sum of the squared label probabilities.

    An empty list is treated as perfectly pure and yields 0.0 rather than
    raising ZeroDivisionError.

    See:
    https://en.wikipedia.org/wiki/Decision_tree_learning#Gini_impurity
    """
    total = len(predicted_classes)
    if total == 0:
        # No rows means no impurity; this also avoids dividing by zero below.
        return 0.0
    counts = class_counts(predicted_classes)
    impurity = 1
    for lbl in counts:
        prob_of_lbl = counts[lbl] / total
        # Subtract term by term (rather than summing first) so the floating
        # point result stays identical to earlier runs of this notebook.
        impurity -= prob_of_lbl**2
    return impurity

In [34]:
# Impurity of the whole, unsplit training set.
overall_g = gini(training_data)
# Display the value just computed; the old trailing name `g` was never defined.
overall_g

0.6399999999999999

In [36]:
# Impurity of the `true` partition left behind by the most recent
# partition() call (the "diameter >= 3" rows when the cells run in order).
g_d_3 = gini(true)
g_d_3

0.4444444444444445