In [73]:
from math import log

In [87]:
def parse(path):
    data = []
    with open(path) as infile:
        for line in infile:
            if line.isspace():
                continue

            sample = line.split(', ')
            data.append(sample)

    return data

def attr_order(data):
    order = []
    
    for attr_index in range(len(data[0])):
        order.append({})

        pos = 0
        for sample in data:
            value = sample[attr_index]
            if value not in order[attr_index]:
                order[attr_index][value] = pos
                pos += 1

    return order

In [75]:
def entropy(data, class_index):
    class_sizes = {}
    for sample in data:
        class_label = sample[class_index]
        if class_label == '?':
            continue

        if class_label not in class_sizes:
            class_sizes[class_label] = 1
        else:
            class_sizes[class_label] += 1

    num_samples = sum(class_sizes.values())
    h = 0
    for class_label in class_sizes:
        p = class_sizes[class_label] / num_samples
        h += -p * log(p)

    return h

In [76]:
def information_gain(data, attr_index, class_index):
    subsets = {}
    for sample in data:
        value = sample[attr_index]

        if value == '?' or sample[class_index] == '?':
            continue

        if not value in subsets:
            subsets[value] = [sample]
        else:
            subsets[value].append(sample)

    num_samples = sum([len(s) for s in subsets.values()])

    rem = 0
    for value in subsets:
        subset = subsets[value]
        rem += len(subset) / num_samples * entropy(subset, class_index)
    
    return entropy(data, class_index) - rem    

In [77]:
def majority(data, attr_index):
    occurences = {}
    for sample in data:
        attr = sample[attr_index]
        if attr == '?':
            continue
            
        if attr not in occurences:
            occurences[attr] = 1
        else:
            occurences[attr] += 1

    return max(occurences, key=occurences.get)

In [78]:
def best_attr(data, attr_avail, class_index):
    best = 0
    max_ig= 0

    for attr_index in range(len(data[0])):
        if not attr_avail[attr_index]:
            continue
            
        ig = information_gain(data, attr_index, class_index)
        if ig > max_ig:
            best = attr_index
            max_ig = ig

    return best

In [79]:
def dtree_learn(data, attr_avail, default, class_index):
    if not data:
        return (None, default)
    
    classifications = set()
    for sample in data:
        class_label = sample[class_index]
        if class_label == '?':
            continue
            
        classifications.add(class_label)
        if len(classifications) == 1:
            return (None, list(classifications)[0])
        
    if True not in attr_avail:
        return (None, majority(data, class_index))
    
    attr_index = best_attr(data, attr_avail, class_index)
    subtrees = []
    
    subsets = {}
    for sample in data:
        value = sample[attr_index]
        if value == '?':
            continue

        if value not in subsets:
            subsets[value] = [sample]
        else:
            subsets[value].append(sample)
            
    attr_avail_subtree = attr_avail
    attr_avail_subtree[attr_index] = False
    for value in subsets:
        subset = subsets[value]
        subtree = tree_learn(subset, attr_avail_subtree, majority(subset, class_index))
        subtrees.append(subtree)
        
    return (attr_index, subtrees)

In [80]:
def dtree_classify(dtree, sample):
    global ATTR_ORDER
    
    node = dtree   
    while node[0] is not None:
        subtrees = node[1]
        node = subtrees[ATTR_ORDER[sample[attr_index]]]
        
    return node[1]

In [81]:
def dtree_test(dtree, data, class_index):
    correct = 0
    for sample in data:
        if dtree_classify(dtree, sample) == class_index:
            correct += 1
            
    print("{} out of {} correct".format(correct, len(data)))

In [89]:
INDICES = {
    'workclass' : 0,
    'final_weight' : 1,
    'education' : 2,
    'education_num' : 3,
    'marital_status' : 4,
    'occupation' : 5,
    'relationship' : 6,
    'race' : 7,
    'sex' : 8,
    'capital_gain' : 9,
    'capital_loss' : 10,
    'hours_per_week' : 11,
    'native_country' : 12,
    'income_class' : 13
}

ATTR_ORDER = attr_order(data)

training_data = parse('training_data.txt')
test_data = parse('test_data.txt')

dtree = dtree_learn(training_data, [True] * len(training_data[0]), None, INDICES['income_class'])
dtree_test(dtree, test_data, INDICES['income_class'])

0 out of 16282 correct
