In [337]:
import numpy as np
from math import log
import operator
from collections import Counter

In [338]:
input_file = open("lenses.txt")

In [339]:
lenses = [inst.strip().split('\t') for inst in input_file.readlines()]

In [340]:
# Sanity check on the parsed table: (rows, columns).
np.shape(lenses)

(24, 5)

In [341]:
# One name per feature column, in column order (the class label column is unnamed).
lensesLabels = "age prescript astigmatic tearRate".split()

In [342]:
def calc_ent(data):
    """Shannon entropy (in bits) of the class labels in the last column.

    Parameters
    ----------
    data : np.ndarray
        2-D array whose last column holds the class label of each row.

    Returns
    -------
    float
        ``-sum(p * log2(p))`` over the empirical class frequencies.
        0 for a single-class array and for an array with no rows.
    """
    # Every count is >= 1, so no probability is 0 and log2 never sees 0.
    counts = np.array(list(Counter(data[:, -1]).values()), dtype=float)
    if counts.size == 0:
        return 0.0
    probs = counts / counts.sum()
    return -(probs * np.log2(probs)).sum()

In [343]:
def splitDataSet(data, axis, value):
    """Select rows where column ``axis`` equals ``value`` and drop that column.

    Parameters
    ----------
    data : np.ndarray
        2-D array of samples.
    axis : int
        Index of the column to filter on and then remove.
    value
        Value the column must equal for a row to be kept.

    Returns
    -------
    np.ndarray
        The matching rows, with one fewer column than ``data``.
    """
    keep = data[:, axis] == value
    return np.delete(data[keep], axis, axis=1)

In [344]:
def chooseBestFeatureToSplit(data):
    """Return the feature column with the largest information gain.

    Parameters
    ----------
    data : np.ndarray
        2-D array; all columns but the last are features, the last is the class.

    Returns
    -------
    int
        Index of the best feature column. Ties go to the lowest index.
        Returns -1 when there are no feature columns or no feature gives a
        strictly positive gain.
    """
    parent_entropy = calc_ent(data)
    n_rows = data.shape[0]

    def _gain(col):
        # Parent entropy minus the size-weighted entropy of each partition.
        conditional = 0.0
        for val in set(data[:, col]):
            part = splitDataSet(data, col, val)
            conditional += (part.shape[0] * 1.0 / n_rows) * calc_ent(part)
        return parent_entropy - conditional

    gains = [_gain(col) for col in range(data.shape[1] - 1)]
    if not gains:
        return -1
    # max() keeps the first maximal index, matching a strict '>' scan.
    best_col = max(range(len(gains)), key=gains.__getitem__)
    return best_col if gains[best_col] > 0.0 else -1

In [345]:
def majorityCnt(classList):
    """Return the most frequent label in ``classList``.

    Labels with equal counts are ranked in first-seen order, so ties go to the
    label that appears first. An empty input raises IndexError.
    """
    ranking = Counter(classList).most_common(1)
    winner, _votes = ranking[0]
    return winner

In [346]:
def createTree(data, labels_ori, depth):
    """Recursively build an ID3 decision tree as nested dicts.

    Parameters
    ----------
    data : np.ndarray
        2-D array; all columns but the last are features, the last is the class.
    labels_ori : sequence of str
        Feature names, one per feature column of ``data``. Never mutated.
    depth : int
        Current recursion depth (0 at the root). It is only passed down the
        recursion and does not affect the result.

    Returns
    -------
    dict or label
        A class label for a leaf. Otherwise ``{feature_name: {value: subtree}}``.
    """
    labels = list(labels_ori)  # private copy; the caller's list is untouched
    classList = data[:, -1]
    first = classList[0]
    # Leaf: all remaining rows share one class.
    if (classList == first).all():
        return first
    # Leaf: only the class column is left (every feature already used), so
    # fall back to a majority vote. This must test columns (shape[1]); a
    # row-count test could never fire, because a single row is always pure.
    if data.shape[1] == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(data)
    # -1 means no feature has positive information gain. Using it as an index
    # would silently split on the class column, so stop and vote instead.
    if bestFeat < 0:
        return majorityCnt(classList)
    bestFeatLabel = labels.pop(bestFeat)
    myTree = {bestFeatLabel: {}}
    # np.unique returns the values sorted, so branch order is reproducible
    # (iteration order of a set of strings depends on hash randomisation).
    for value in np.unique(data[:, bestFeat]):
        # The recursive call copies `labels` itself, so no per-branch copy is needed.
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(data, bestFeat, value), labels, depth + 1)
    return myTree

In [347]:
# Build the feature matrix from the loaded rows here so this cell works on a
# fresh kernel (Restart & Run All) instead of relying on leftover state.
data_l = np.array(lenses)
lenseTree = createTree(data_l, lensesLabels, 0)

In [348]:
# Display the learned tree: nested dicts of {feature: {value: subtree or class label}}.
lenseTree

{'tearRate': {'normal': {'astigmatic': {'no': {'age': {'pre': 'soft',
      'presbyopic': {'prescript': {'hyper': 'soft', 'myope': 'no lenses'}},
      'young': 'soft'}},
    'yes': {'prescript': {'hyper': {'age': {'pre': 'no lenses',
        'presbyopic': 'no lenses',
        'young': 'hard'}},
      'myope': 'hard'}}}},
  'reduced': 'no lenses'}}