# Decision Tree

Predict how likely a user is to pay for premium access.

In [49]:
from collections import Counter, defaultdict, namedtuple
import math

# Immutable record for one decision-tree node. Nothing in the visible code
# constructs it yet; presumably criteria/value describe the split test,
# results holds label counts at a leaf, and true/false are the two child
# branches -- TODO confirm once build_tree is implemented.
Node = namedtuple('Node', ['criteria', 'value', 'results', 'true', 'false'])

In [21]:
# Toy training set: one row per site visitor.
# Columns: Referrer, Location, ReadFAQ ('yes'/'no'), Pages Viewed (int),
# Service Chosen. The last column is the label the tree learns to predict
# ('None', 'Basic' or 'Premium'); the helpers below read it as row[-1].
data = [['slashdot', 'USA', 'yes', 18, 'None'],
        ['google', 'France', 'yes', 23, 'Premium'],
        ['digg', 'USA', 'yes', 24, 'Basic'],
        ['kiwitobes', 'France', 'yes', 23, 'Basic'],
        ['google', 'UK', 'no', 21, 'Premium'],
        ['(direct)', 'New Zealand', 'no', 12, 'None'],
        ['(direct)', 'UK', 'no', 21, 'Basic'],
        ['google', 'USA', 'no', 24, 'Premium'],
        ['slashdot', 'France', 'yes', 19, 'None'],
        ['digg', 'USA', 'no', 18, 'None'],
        ['google', 'UK', 'no', 18, 'None'],
        ['kiwitobes', 'UK', 'no', 19, 'None'],
        ['digg', 'New Zealand', 'yes', 12, 'Basic'],
        ['slashdot', 'UK', 'no', 21, 'None'],
        ['google', 'UK', 'yes', 18, 'Basic'], 
        ['kiwitobes', 'France', 'yes', 19, 'Basic']]

In [23]:
def partition_by(data, col, value):
    """Split ``data`` into two lists of rows by testing column ``col``.

    A numeric ``value`` (int, float or any other ``numbers.Real``, but not
    bool) gives a threshold split: rows with ``row[col] >= value`` go to the
    first list. Any other ``value`` (strings, booleans, ...) gives an equality
    split: rows with ``row[col] == value`` go to the first list.

    Args:
        data: Iterable of rows; each row must be indexable by ``col``.
        col: Index of the column to test.
        value: Threshold (numeric) or category to match.

    Returns:
        A tuple ``(matching, non_matching)`` of lists of rows. Both lists
        preserve the original row order.
    """
    import numbers

    # bool is a subclass of int, but a ">=" split on booleans is wrong:
    # every row satisfies ``x >= False``. Treat booleans as categories.
    if isinstance(value, numbers.Real) and not isinstance(value, bool):
        def split_fn(row):
            return row[col] >= value
    else:
        def split_fn(row):
            return row[col] == value

    set1, set2 = [], []
    for row in data:
        # Evaluate the predicate once per row rather than once per output set.
        (set1 if split_fn(row) else set2).append(row)
    return set1, set2

In [25]:
# Categorical split on column 2 (ReadFAQ): rows equal to 'yes' come first,
# all other rows second.
partition_by(data, col=2, value='yes')

([['slashdot', 'USA', 'yes', 18, 'None'],
  ['google', 'France', 'yes', 23, 'Premium'],
  ['digg', 'USA', 'yes', 24, 'Basic'],
  ['kiwitobes', 'France', 'yes', 23, 'Basic'],
  ['slashdot', 'France', 'yes', 19, 'None'],
  ['digg', 'New Zealand', 'yes', 12, 'Basic'],
  ['google', 'UK', 'yes', 18, 'Basic'],
  ['kiwitobes', 'France', 'yes', 19, 'Basic']],
 [['google', 'UK', 'no', 21, 'Premium'],
  ['(direct)', 'New Zealand', 'no', 12, 'None'],
  ['(direct)', 'UK', 'no', 21, 'Basic'],
  ['google', 'USA', 'no', 24, 'Premium'],
  ['digg', 'USA', 'no', 18, 'None'],
  ['google', 'UK', 'no', 18, 'None'],
  ['kiwitobes', 'UK', 'no', 19, 'None'],
  ['slashdot', 'UK', 'no', 21, 'None']])

In [34]:
def counter(data):
    """Tally the label (last column) of every row in ``data``.

    Returns:
        A list of ``(label, count)`` pairs, with labels in the order they
        are first seen.
    """
    tally = Counter(row[-1] for row in data)
    return list(tally.items())

In [33]:
# Label frequencies over the whole training set. counter() already returns a
# list, so no extra list(...) wrapper is needed.
counter(data)

[('None', 7), ('Premium', 3), ('Basic', 6)]

In [47]:
def gini_impurity(data):
    """Return the Gini impurity of the labels (last column) in ``data``.

    This is the probability that two labels drawn at random, with
    replacement, from ``data`` differ. It is 0.0 for a pure set and grows as
    the labels become more evenly mixed; lower is better.

    An empty set is treated as pure and scores 0.0, the same as ``entropy``
    does for an empty set.
    """
    n = len(data)
    if n == 0:
        # Without this guard the sum below is empty and the result is 1, so an
        # empty set would look maximally impure.
        return 0.0
    p_same = 0.0
    for _, count in counter(data):
        p_same += (count / n) ** 2
    return 1 - p_same

In [48]:
# Gini impurity of the full training set (label counts 7/3/6 out of 16).
gini_impurity(data) # 0.6328125

0.6328125

In [54]:
def entropy(data):
    """Return the Shannon entropy, in bits, of the labels (last column) in ``data``.

    Entropy measures how mixed a set is: 0.0 when every row has the same
    label (or ``data`` is empty), larger as the labels become more evenly
    mixed. Lower is better when choosing a split.
    """
    n = len(data)
    ent = 0.0
    for _, count in counter(data):
        # Every count is >= 1, so p is in (0, 1] and log2(p) is defined.
        p = count / n
        # math.log2 is more accurate than math.log(p) / math.log(2).
        ent -= p * math.log2(p)
    return ent

In [53]:
# Entropy (in bits) of the full training set's label distribution.
entropy(data) # 1.5052408149441479

1.5052408149441479

In [56]:
# Score the 'yes' half of the ReadFAQ split with both impurity measures;
# both are lower than for the full set, so the split reduces disorder.
set1, set2 = partition_by(data, 2, 'yes')
entropy(set1), gini_impurity(set1)

(1.2987949406953985, 0.53125)

In [None]:
# def build_tree(data, score_fn=entropy):
#     if len(data) == 0: return None
#     current_score = score_fn(data)
    
#     # Set up some variables to track the best criteria.
#     best_gain = 0.0
#     best_criteria = None
#     best_sets = None
    
#     # The last column is the target.
#     for col in range(0, len(data[0]) - 1):