In [1]:
import numpy as np
from collections import Counter
import pickle

In [2]:
class DecisionTree:
    """ID3-style decision tree classifier for categorical features.

    Splits are chosen by information gain (base-2 Shannon entropy). The fitted
    tree is kept in ``self._tree`` as nested dicts of the form
    ``{featureName: {featureValue: subtree_or_label, ...}}``; a leaf is a bare
    class label, and the whole tree is a bare label when every training label
    is identical.
    """

    def _calShannonEnt(self, Y):
        """Return the Shannon entropy, in bits, of the labels in ``Y``.

        An empty or single-class ``Y`` has entropy 0.0.
        """
        totalCount = len(Y)
        shannonEnt = 0.0
        for count in Counter(Y).values():
            prob = count / totalCount
            shannonEnt -= prob * np.log2(prob)
        return shannonEnt

    def _splitDataset(self, X, Y, index, value):
        """Return the rows of ``(X, Y)`` whose column ``index`` equals ``value``.

        Returns ``(splitX, splitY)``: ``splitX`` holds the matching rows of ``X``
        with column ``index`` removed (keeping it aligned with a feature-name
        list from which that name was removed), and ``splitY`` holds the
        matching labels. Both are empty arrays when nothing matches.
        """
        X = np.asarray(X)
        Y = np.asarray(Y)
        mask = X[:, index] == value
        # np.delete handles negative indices correctly, unlike stitching
        # x[:index] and x[index + 1:] together (which duplicates the row for -1).
        return np.delete(X[mask], index, axis=1), Y[mask]

    def _chooseBestFeature(self, X, Y):
        """Return the column index of ``X`` with the largest information gain.

        Ties keep the earliest column. If no column gives a positive gain,
        column 0 is returned, so the result is always a valid column index.
        """
        X = np.asarray(X)
        baseEnt = self._calShannonEnt(Y)
        totalCount = len(Y)
        bestInfoGain = 0.0
        bestFeature = 0  # a -1 sentinel would be taken as "last column" by callers
        for i in range(X.shape[1]):
            # Label entropy after splitting on column i, weighted by branch size.
            featureEnt = 0.0
            for value in set(X[:, i]):
                _, splitY = self._splitDataset(X, Y, i, value)
                featureEnt += len(splitY) / totalCount * self._calShannonEnt(splitY)
            featureInfoGain = baseEnt - featureEnt
            if featureInfoGain > bestInfoGain:
                bestInfoGain = featureInfoGain
                bestFeature = i
        return bestFeature

    def _majorityCount(self, Y):
        """Return the most common label in ``Y``."""
        return Counter(Y).most_common(1)[0][0]

    def _createTree(self, X, Y, featureNames):
        """Recursively build the tree for ``(X, Y)``.

        ``featureNames[i]`` names column ``i`` of ``X``. Returns a bare label
        when ``Y`` holds a single class (that class) or no features remain (the
        majority class); otherwise ``{featureName: {value: subtree, ...}}``.
        """
        if len(Counter(Y)) == 1:
            return Y[0]
        if len(featureNames) == 0:
            return self._majorityCount(Y)

        bestFeatureIndex = self._chooseBestFeature(X, Y)
        bestFeature = featureNames[bestFeatureIndex]
        # Drop the chosen name by position so it lines up with the column that
        # _splitDataset removes; this list is the same for every branch.
        remainingFeatures = list(featureNames)
        del remainingFeatures[bestFeatureIndex]

        branches = {}
        for value in set(X[:, bestFeatureIndex]):
            splitX, splitY = self._splitDataset(X, Y, bestFeatureIndex, value)
            branches[value] = self._createTree(splitX, splitY, remainingFeatures)
        return {bestFeature: branches}

    def _treePredict(self, x, featureNames, tree):
        """Return the label for one sample ``x`` by walking ``tree`` to a leaf.

        ``featureNames`` gives the feature name of each position in ``x``.

        Raises
        ------
        KeyError
            If ``x`` holds a value for a tested feature that has no branch at
            that node (i.e. was not seen there during training).
        """
        node = tree
        while isinstance(node, dict):
            feature = next(iter(node))
            branches = node[feature]
            value = x[featureNames.index(feature)]
            if value not in branches:
                raise KeyError(f"unseen value {value!r} for feature {feature!r}")
            node = branches[value]
        return node

    def _treeDepth(self, tree):
        """Return the number of decision nodes on the longest root-to-leaf path.

        A bare-label tree has depth 0.
        """
        if not isinstance(tree, dict):
            return 0
        branches = next(iter(tree.values()))
        return 1 + max((self._treeDepth(sub) for sub in branches.values()), default=0)

    def _leafsNum(self, tree):
        """Return the number of leaves (bare labels) in ``tree``."""
        if not isinstance(tree, dict):
            return 1
        branches = next(iter(tree.values()))
        return sum(self._leafsNum(sub) for sub in branches.values())

    def fit(self, X, Y, featureNames):
        """Build the tree from samples ``X`` and labels ``Y``.

        ``featureNames`` (any sequence) names the columns of ``X`` in order.

        Raises
        ------
        ValueError
            If ``X`` is not 2-D, ``X`` and ``Y`` differ in length, the dataset
            is empty, or ``featureNames`` does not match the column count.
        """
        X = np.asarray(X)
        Y = np.asarray(Y)
        featureNames = list(featureNames)
        if X.ndim != 2 or len(X) != len(Y):
            raise ValueError("X must be 2-D with one row per label in Y")
        if len(Y) == 0:
            raise ValueError("cannot fit on an empty dataset")
        if X.shape[1] != len(featureNames):
            raise ValueError(
                f"got {len(featureNames)} feature names for {X.shape[1]} columns"
            )
        self._tree = self._createTree(X, Y, featureNames)

    def predict(self, X, featureNames):
        """Return an array with the predicted label for each row of ``X``.

        ``featureNames`` names the columns of ``X``; features are looked up by
        name, so the column order need not match training. An empty ``X``
        yields an empty array.
        """
        featureNames = list(featureNames)
        return np.array([self._treePredict(x, featureNames, self._tree) for x in X])

    def save(self, path):
        """Pickle the fitted tree (the nested dict only) to ``path``."""
        with open(path, 'wb') as f:
            pickle.dump(self._tree, f)

    def load(self, path):
        """Replace this instance's tree with one written by ``save``.

        Security: ``pickle.load`` can execute arbitrary code, so only load
        files from a trusted source.
        """
        with open(path, 'rb') as f:
            self._tree = pickle.load(f)

    def getTreeDepth(self):
        """Return the number of decision levels in the fitted tree (0 for a bare label)."""
        return self._treeDepth(self._tree)

    def getLeafsNum(self):
        """Return the number of leaves in the fitted tree."""
        return self._leafsNum(self._tree)
        
        

In [3]:
# Tab-separated data file read by createDataset; its last column is the label.
DATASET = 'lenses.txt'
# Names for the feature columns, in file order.
FEATURES = [
    'age',
    'prescript',
    'astigmatic',
    'tearRate',
]

In [4]:
def createDataset(path, delimiter='\t'):
    """Load a delimited text dataset whose last column is the class label.

    Parameters
    ----------
    path : str or path-like
        File to read, one sample per line. Blank lines are ignored.
    delimiter : str, optional
        Field separator (default: a tab).

    Returns
    -------
    X : np.ndarray, shape (n_samples, n_fields - 1)
        All fields except the last, as strings.
    Y : np.ndarray, shape (n_samples,)
        The last field of each row, as strings.

    Raises
    ------
    ValueError
        If the file has no non-blank lines or rows have differing field counts.
    """
    with open(path) as f:
        # Skip blank lines (e.g. a trailing empty line) so they cannot create a
        # ragged row that np.array is unable to turn into a 2-D array.
        rows = [line.strip().split(delimiter) for line in f if line.strip()]

    if not rows:
        raise ValueError(f"no data rows found in {path!r}")
    if len({len(row) for row in rows}) != 1:
        raise ValueError(f"rows in {path!r} have inconsistent field counts")

    dataset = np.array(rows)
    return dataset[:, :-1], dataset[:, -1]

In [5]:
X, Y = createDataset(DATASET)

# Sanity-check the loaded data against the feature list before training on it.
assert len(X) == len(Y), "feature rows and labels differ in count"
assert X.shape[1] == len(FEATURES), "column count does not match FEATURES"

X.shape, Y.shape

((24, 4), (24,))

In [6]:
# Create an unfitted classifier; the next cell builds its tree with fit().
tree = DecisionTree()

In [7]:
# Build the tree from every row of the dataset (no train/test split).
tree.fit(X, Y, FEATURES)

In [8]:
# Predict on the same rows used for fitting: shows training fit, not generalisation.
tree.predict(X, FEATURES)

array(['no lenses', 'soft', 'no lenses', 'hard', 'no lenses', 'soft',
       'no lenses', 'hard', 'no lenses', 'soft', 'no lenses', 'hard',
       'no lenses', 'soft', 'no lenses', 'no lenses', 'no lenses',
       'no lenses', 'no lenses', 'hard', 'no lenses', 'soft', 'no lenses',
       'no lenses'], dtype='<U9')

In [9]:
# Persist the fitted tree (the nested dict only) with pickle.
tree.save('decision_tree.pkl')

In [10]:
# Read the pickled tree back from disk; only unpickle files from a trusted source.
tree.load('decision_tree.pkl')

In [11]:
# Number of decision nodes on the longest root-to-leaf path.
tree.getTreeDepth()

4

In [12]:
# Number of leaves (class labels) in the fitted tree.
tree.getLeafsNum()

9