In [2]:
# Toy fruit data set. Each row is [color, diameter, label] (see the column
# names defined in the next cell); the final element is the class label,
# which class_counts reads as row[-1].
training_data = [
    ['Green',  3, 'Apple'],
    ['Yellow', 3, 'Apple'],
    ['Red',    1, 'Grape'],
    ['Red',    1, 'Grape'],
    ['Yellow', 3, 'Lemon'],
]

In [4]:
# Column names for the rows of training_data, in order; the last is the class label.
# Available as both `header` and `labels` so either name can be used to look up
# a column's human-readable name.
header = ['color', 'diameter', 'label']
labels = header

#Finds the unique values for a given column
def unique_values(rows, col):
    return set(row[col] for row in rows)


{'Red', 'Yellow', 'Green'}


In [10]:
# Check: distinct values in column 0 (color) of the training data
print(unique_values(rows=training_data, col=0))

{'Red', 'Yellow', 'Green'}


In [9]:
def class_counts(rows):
    """Tally how many rows carry each class label.

    The label of a row is its last element. Returns a dict mapping
    label -> number of rows with that label, in first-seen order.
    """
    tally = {}
    for row in rows:
        lbl = row[-1]
        tally[lbl] = tally.get(lbl, 0) + 1
    return tally


In [11]:
# Check: number of training rows per class label
print(class_counts(rows=training_data))

{'Apple': 2, 'Grape': 2, 'Lemon': 1}


In [12]:
def is_numeric(value):
    """Return True when ``value`` is an int or a float (bool, an int subclass, counts too)."""
    return isinstance(value, (int, float))

In [16]:
# Check: an int is numeric, a string is not
print(is_numeric(value=42))
print(is_numeric(value='the meaning of life and the universe'))

True
False


In [21]:
class Question:
    """A yes/no test on one column of a row, used to partition a data set.

    If the value being tested is numeric the question is
    ``row[column] >= value``; otherwise it is ``row[column] == value``.
    """

    def __init__(self, column, value):
        self.column = column  # index of the column to test in each row
        self.value = value    # threshold (numeric) or value to compare for equality

    def match(self, example):
        """Return True if ``example`` (a row) answers this question with "yes"."""
        val = example[self.column]
        if is_numeric(val):
            return val >= self.value
        return val == self.value

    def __repr__(self):
        """Readable form of the question, e.g. ``Is color == Red?``."""
        condition = ">=" if is_numeric(self.value) else "=="
        # Column names come from the notebook-level `labels` list; `header`
        # was referenced here but is never defined in this notebook.
        return "Is %s %s %s?" % (labels[self.column], condition, str(self.value))

In [None]:
def partitions(rows, question):
    """Split ``rows`` by ``question``.

    Parameters:
        rows: iterable of rows.
        question: object with a ``match(row) -> bool`` method (e.g. a Question).

    Returns:
        (true_rows, false_rows): the rows for which ``question.match`` is truthy,
        and the rest, each in their original order.
    """
    true_rows, false_rows = [], []
    for row in rows:
        if question.match(row):
            true_rows.append(row)
        else:
            # Was `false.rows.append`, a NameError on the first non-matching row.
            false_rows.append(row)
    return true_rows, false_rows

In [None]:
def gini_impurity(rows):
    """Gini impurity of a set of rows: ``1 - sum(p_k ** 2)`` over class labels k.

    ``p_k`` is the fraction of rows whose label (last element, via
    class_counts) is k. The result is 0.0 for a set with a single label and
    grows as labels become more mixed.

    An empty ``rows`` returns 0.0: an empty node contains nothing to mislabel.
    """
    if not rows:
        return 0.0
    counts = class_counts(rows)
    total = float(len(rows))
    impurity = 1.0
    for label in counts:
        prob_of_lbl = counts[label] / total
        # Subtract the squared probability; a dominant label drives impurity toward 0.
        impurity -= prob_of_lbl ** 2
    return impurity

In [23]:
def information_gain(left, right, current_uncertainty, impurity=None):
    """Reduction in impurity obtained by splitting a node into ``left`` and ``right``.

    gain = current_uncertainty - (p * impurity(left) + (1 - p) * impurity(right))
    where p = len(left) / (len(left) + len(right)).

    Parameters:
        left, right: the two partitions of the parent's rows.
        current_uncertainty: impurity of the parent node.
        impurity: optional callable rows -> float; defaults to gini_impurity.

    Raises:
        ValueError: if both partitions are empty (the weights are undefined).
    """
    if impurity is None:
        impurity = gini_impurity
    total = len(left) + len(right)
    if total == 0:
        raise ValueError("information_gain: left and right cannot both be empty")
    # Parenthesise the denominator: the original computed len(left)/len(left) + len(right).
    p = len(left) / float(total)
    weighted_child_impurity = p * impurity(left) + (1 - p) * impurity(right)
    return current_uncertainty - weighted_child_impurity

In [None]:
def find_best_split(rows):
    """Find the Question whose split of ``rows`` yields the highest information gain.

    Every value of every feature column (all columns except the last, which
    holds the label) is tried as a Question. Splits that leave one side empty
    are skipped.

    Returns:
        (best_gain, best_question); (0, None) if ``rows`` is empty or no
        non-trivial split exists.
    """
    best_gain = 0
    best_question = None
    if not rows:
        return best_gain, best_question

    current_uncertainty = gini_impurity(rows)
    n_features = len(rows[0]) - 1  # last column is the label

    for col in range(n_features):
        # NOTE(review): iterating a set makes tie-breaking order-dependent on hashing.
        for val in unique_values(rows, col):
            question = Question(col, val)
            true_rows, false_rows = partitions(rows, question)

            # A split with an empty side does not divide the data at all.
            if not true_rows or not false_rows:
                continue
            gain = information_gain(true_rows, false_rows, current_uncertainty)
            # `>=` keeps the last of equally good questions (and accepts zero gain).
            if gain >= best_gain:
                best_gain, best_question = gain, question
    return best_gain, best_question