# Vizualize a shallow single tree

In [1]:
import warnings
warnings.filterwarnings("ignore")


import numpy as np
from scipy import io
from scipy import stats
from graphviz import Digraph
import math

eps = 1e-5  # a small number

In [2]:
data = io.loadmat("data/spam_data.mat")
print("\nloaded data!")
fields = "test_data", "training_data", "training_labels"
for field in fields:
    print(field, data[field].shape)


loaded data!
test_data (5857, 32)
training_data (5172, 32)
training_labels (1, 5172)


In [3]:
X_train = data['training_data']
y_train = data['training_labels'][0]

In [4]:
class DecisionTree:
    def __init__(self, max_depth=3, feature_labels=None, m=0, count=0, count_list=[]):
        self.max_depth = max_depth
        self.features = feature_labels
        self.left, self.right = None, None  # for non-leaf nodes
        self.split_idx, self.thresh = None, None  # for non-leaf nodes
        self.data, self.pred = None, None  # for leaf nodes
        self.m = m
        
        #visualization
        f = Digraph('graph', filename='graph')
        f.attr('node', shape='circle')
        
    def reduce_feature(self, X, y):        
        if self.m == 0:
            self.X = X
            self.y = y            
        else:
            random.seed(42)
            random_idx_m = [random.randint(0,len(X[0])-1) for i in range(self.m)]
            self.X = X.T[random_idx_m].T
            self.y = y
        
        return self.X, self.y
    
    
    @staticmethod
    def information_gain(X, y, thresh):
        n = len(y)
        n_class0 = list(y).count(0)
        n_class1 = list(y).count(1)

        H = -n_class0/n * np.log2(n_class0/n) - n_class1/n * np.log2(n_class1/n) # initial entropy

        left, right = [], []
        for idx in range(n): 
            if X[idx] < thresh:
                left.append(y[idx])
            else: right.append(y[idx])

        left_n_class0 = list(left).count(0)
        left_n_class1 = list(left).count(1)
        right_n_class0 = list(right).count(0)
        right_n_class1 = list(right).count(1)

        left_H = -left_n_class0/n * np.log2(left_n_class0/n) - left_n_class1/n * np.log2(left_n_class1/n)
        right_H = -right_n_class0/n * np.log2(right_n_class0/n) - right_n_class1/n * np.log2(right_n_class1/n)
        H_after = (len(left)*left_H + len(right)*right_H)/n

        gain = H - H_after
        
        return gain

    @staticmethod
    def gini_impurity(X, y, thresh):
        # TODO implement gini_impurity function
        pass

    def split(self, X, y, idx, thresh):
        X0, idx0, X1, idx1 = self.split_test(X, idx=idx, thresh=thresh)
        y0, y1 = y[idx0], y[idx1]

        return X0, y0, X1, y1

    def split_test(self, X, idx, thresh):
        idx0 = np.where(X[:, idx] < thresh)[0]
        idx1 = np.where(X[:, idx] >= thresh)[0]
        X0, X1 = X[idx0, :], X[idx1, :]

        return X0, idx0, X1, idx1

    def fit(self, X_in, y_in, count=0, count_list=[]):
        
        
        # reduce column space for random forrest
        X, y = self.reduce_feature(X_in, y_in)

        
        if self.max_depth > 0:

            
            # compute entropy gain for all single-dimension splits,
            # thresholding with a linear interpolation of 10 values
            gains = []
            # The following logic prevents thresholding on exactly the minimum
            # or maximum values, which may not lead to any meaningful node
            # splits.
            thresh = np.array([
                np.linspace(np.min(X[:, i]) + eps, np.max(X[:, i]) - eps, num=10)
                for i in range(X.shape[1])
            ])
            for i in range(X.shape[1]):
                gains.append([self.information_gain(X[:, i], y, t) for t in thresh[i, :]])

            gains = np.nan_to_num(np.array(gains))
            self.split_idx, thresh_idx = np.unravel_index(np.argmax(gains), gains.shape)            
            self.thresh = thresh[self.split_idx, thresh_idx]
            X0, y0, X1, y1 = self.split(X, y, idx=self.split_idx, thresh=self.thresh)           
            
            print('\n' + '"' + str(self.features[self.split_idx]) + '" (frequence > ' + str(math.trunc(self.thresh)) + ')')
            print('# left: ' + str(len(X0)) + ' -- # right: ' + str(len(X1)))
            
            
            
            if X0.size > 0 and X1.size > 0:
                count+=1
                count_list.append(count)
                count_list.append(str(self.features[self.split_idx]) + '_' + str(count) + '" (frequence > ' + str(math.trunc(self.thresh)) + ')')
                #print(count)
                #print(count_list)
                
                self.left = DecisionTree(
                    max_depth=self.max_depth - 1, feature_labels=self.features)
                self.left.fit(X0, y0, count, count_list)
                self.right = DecisionTree(
                    max_depth=self.max_depth - 1, feature_labels=self.features)
                self.right.fit(X1, y1, count, count_list) 
                
            else:
                self.max_depth = 0
                self.data, self.labels = X, y
                self.pred = stats.mode(y).mode[0]
            
        else:
            
            self.data, self.labels = X, y
            self.pred = stats.mode(y).mode[0]
 
        #f.view()
        count_list = np.sort(np.array(count_list).reshape(-1,2), axis=1)
        
        count_list = np.array((sorted(count_list, key = lambda x: float(x[0]))))
        count_list = count_list.reshape(-1,2)
        #print(np.sort(count_list, axis=0))
        #print(count)
        
        
        return self, count, count_list
    
    def predict(self, X):
        if self.max_depth == 0:
            return self.pred * np.ones(X.shape[0])
        else:
            X0, idx0, X1, idx1 = self.split_test(X, idx=self.split_idx, thresh=self.thresh)
            yhat = np.zeros(X.shape[0])
            yhat[idx0] = self.left.predict(X0)
            yhat[idx1] = self.right.predict(X1)
            return yhat


In [5]:
label = ['pain','private','bank','money','drug','spam','prescription','creative','height','featured','differ','width','other','energy','business','message','volumes','revision','path','meter','memo','planning','pleased','record','out','semicolon','dollar','sharp','exclamation','para','bracket','and']
dt = DecisionTree(max_depth=3, feature_labels=label)
count_list = dt.fit(X_train, y_train)[2]


"exclamation" (frequence > 0)
# left: 3917 -- # right: 1255

"pain" (frequence > 0)
# left: 3878 -- # right: 39

"money" (frequence > 0)
# left: 3788 -- # right: 90

"pain" (frequence > 1)
# left: 33 -- # right: 6

"and" (frequence > 0)
# left: 1100 -- # right: 155

"message" (frequence > 1)
# left: 995 -- # right: 105

"money" (frequence > 0)
# left: 146 -- # right: 9


In [6]:
f = Digraph('graph', filename='graph')
f.attr('node', shape='circle')
for i in range(3):
    j = 2*(i+1)-1
    k = 2*(i+1)
    f.edge(count_list[i,1], count_list[j,1], label='true')
    f.edge(count_list[i,1], count_list[k,1], label='flase')
f.view()

'graph.pdf'