In [90]:
import numpy as np
from collections import Counter

class Node:
    """A single node of a binary decision tree.

    Internal nodes carry ``feature``/``threshold`` and two children; leaf
    nodes carry ``value`` (the prediction) instead.

    Parameters
    ----------
    feature : hashable, optional
        Column label tested at this node (internal nodes only).
    threshold : scalar, optional
        Split point; samples with ``feature <= threshold`` go to ``left``.
    left, right : Node, optional
        Child nodes of an internal node.
    value : scalar, optional (keyword-only)
        Prediction stored on a leaf. ``None`` marks an internal node.
    n_value : int, optional (keyword-only)
        Number of training samples that ended up in this leaf.
    """

    def __init__(self,feature = None,threshold = None,left = None,right = None,*,value= None,n_value=None) -> None:
        self.feature = feature
        self.threshold = threshold
        self.left = left
        self.right = right
        self.value = value
        # Previously accepted but discarded; keep it so callers can inspect
        # how many samples back each leaf prediction.
        self.n_value = n_value

    def is_leaf_node(self):
        """Return True when this node stores a prediction (``value`` is set)."""
        # Compare against None explicitly: a prediction of 0.0 is a valid leaf.
        return self.value is not None
    
def mean_square_error(y_pred,y):
    """Return the mean of the squared element-wise differences ``y_pred - y``."""
    residuals = y_pred - y
    return np.average(np.square(residuals))

In [91]:
class DecisionTree:
    """Regression tree grown greedily by variance reduction.

    Each internal node tests ``feature <= threshold``; each leaf predicts the
    mean target of the training rows that reached it.

    Parameters
    ----------
    min_samples_split : int
        A candidate split is rejected (the node becomes a leaf) when either
        child would receive fewer than this many rows.
    max_depth : int
        Maximum depth of the tree; the root is depth 0 and no leaf is created
        deeper than ``max_depth``.
    n_features : int
        Overwritten by ``fit`` with the number of columns of ``X``.
    root : Node, optional
        Pre-built tree to predict with; overwritten by ``fit``.

    Attributes
    ----------
    features_importance : dict
        Maps depth -> list of ``[feature, threshold]`` pairs chosen at that
        depth during the last ``fit``.

    Notes
    -----
    * A feature used for a split is removed from the data passed to both
      children, so each feature is tested at most once per root-to-leaf path.
    * Only the single best split is checked against ``min_samples_split``;
      a node whose best split is too unbalanced becomes a leaf without
      trying the next-best candidate.
    """

    def __init__(self, min_samples_split=15, max_depth=4, n_features=0, root=None) -> None:
        self.min_samples_split = min_samples_split
        self.max_depth = max_depth
        self.n_features = n_features
        self.root = root
        self.features_importance = {}
        # Defined up front so the helpers never hit a missing attribute;
        # ``fit`` replaces it with the caller's choice.
        self.verbose = False

    def __calulate_variance(self, y):
        """Return the population variance of all values in ``y``."""
        return np.var(y.to_numpy().flatten())

    def __split_data(self, X, threshold):
        """Return boolean masks ``(X <= threshold, X > threshold)`` for a column.

        Values that compare false on both sides (NaN) fall in neither mask.
        """
        # Smaller-or-equal goes left, strictly greater goes right.
        return X <= threshold, X > threshold

    def __information_gain(self, X, threshold, y):
        """Return the variance reduction achieved by splitting at ``threshold``.

        Returns 0 when the split leaves one side empty.
        """
        left_mask, right_mask = self.__split_data(X, threshold)
        n_l, n_r = int(left_mask.sum()), int(right_mask.sum())

        if n_l == 0 or n_r == 0:
            return 0

        n = len(y)
        parent_variance = self.__calulate_variance(y)
        # Boolean-mask selection stays correct when the index has repeated
        # labels, unlike a lookup by label list.
        e_l = self.__calulate_variance(y.loc[left_mask])
        e_r = self.__calulate_variance(y.loc[right_mask])

        # Size-weighted variance of the two children.
        child_variance = (n_l / n) * e_l + (n_r / n) * e_r

        return parent_variance - child_variance

    def __update_features_importance(self, best_feature, best_threshold, depth):
        """Record the ``[feature, threshold]`` pair chosen at ``depth``."""
        self.features_importance.setdefault(depth, []).append([best_feature, best_threshold])

    def __best_split(self, X, y):
        """Search every column and every unique value for the best split.

        Returns
        -------
        (feature, threshold, stop) : tuple
            ``stop`` is True (and the other two are None) when no candidate
            yields a positive gain.
        """
        best_ig = 0
        best_feature = None
        best_threshold = None

        for feature_name in X.columns:
            column = X[feature_name]
            for threshold in np.unique(column):
                ig = self.__information_gain(column, threshold, y)

                # Strict '>' keeps the first candidate among ties.
                if ig > best_ig:
                    best_ig = ig
                    best_threshold = threshold
                    best_feature = feature_name
                    if self.verbose == 2:
                        print(f'best_feature best_threshold {best_feature} {best_threshold}')
                        print(f'current information gain = {best_ig}')

        if self.verbose == 2:
            print(f'best information gain = {best_ig}')

        if best_ig == 0:
            if self.verbose:
                print('early stop : best_ig = 0')
            return None, None, True

        if self.verbose:
            print(f'best_feature,best_threshold {best_feature,best_threshold}\n')

        return best_feature, best_threshold, False

    @staticmethod
    def __most_common_label(y):
        """Return the mean of ``y`` (the leaf prediction for regression).

        The name is historical; no mode/majority vote is computed.
        """
        return np.average(y.to_numpy().flatten())

    def __make_leaf(self, y):
        """Build a leaf predicting the mean of ``y`` and remembering its size."""
        return Node(value=self.__most_common_label(y), n_value=len(y))

    def __grow_tree(self, X, y, current_depth=0):
        """Recursively build the subtree for the rows in ``X``/``y``."""
        # '>=' so that no node is created deeper than max_depth (the previous
        # '>' test allowed one extra level).
        depth_exhausted = current_depth >= self.max_depth
        if depth_exhausted or len(X.columns) == 0 or len(np.unique(y)) == 1:
            if depth_exhausted and self.verbose:
                print("max depth reach")
            return self.__make_leaf(y)

        if self.verbose:
            print(f'current depth = {current_depth}')

        best_feature, best_threshold, stop_grow = self.__best_split(X, y)

        if stop_grow:
            return self.__make_leaf(y)

        left_mask, right_mask = self.__split_data(X[best_feature], best_threshold)

        if left_mask.sum() < self.min_samples_split or right_mask.sum() < self.min_samples_split:
            if self.verbose:
                print("min samples split reach")
            return self.__make_leaf(y)

        self.__update_features_importance(best_feature, best_threshold, current_depth)

        # The chosen feature is not offered to the subtrees again.
        remaining = X.drop(best_feature, axis=1)

        left = self.__grow_tree(remaining.loc[left_mask], y.loc[left_mask], current_depth + 1)
        right = self.__grow_tree(remaining.loc[right_mask], y.loc[right_mask], current_depth + 1)

        return Node(best_feature, best_threshold, left, right)

    def fit(self, X, y, verbose=True):
        """Grow the tree on ``X``/``y`` and print the training mean squared error.

        Parameters
        ----------
        X : pandas.DataFrame
            Feature columns.
        y : pandas.DataFrame or pandas.Series
            Targets, indexed like ``X``.
        verbose : bool or int
            Falsy silences progress messages, truthy prints them, and ``2``
            additionally prints every improvement found during split search.
            The final training error is printed regardless.
        """
        self.verbose = verbose
        self.n_features = X.shape[1]
        self.features_name = X.columns
        # Start from a clean record so refitting does not mix in old splits.
        self.features_importance = {}
        self.root = self.__grow_tree(X, y)

        predictions_train = self.predict(X)
        print('mean_square_error : ', mean_square_error(y.to_numpy().flatten(), predictions_train))

    def __traverse_tree(self, X, node: "Node"):
        """Follow one row ``X`` from ``node`` down to a leaf; return its value."""
        while not node.is_leaf_node():
            node = node.left if X[node.feature] <= node.threshold else node.right
        return node.value

    def predict(self, X):
        """Return a numpy array with one prediction per row of DataFrame ``X``."""
        return np.array([self.__traverse_tree(row, self.root) for _, row in X.iterrows()])
        
    
        

In [92]:
from sklearn import datasets
from sklearn.model_selection import train_test_split
import pandas as pd

# Load the diabetes dataset that ships with scikit-learn.
data = datasets.load_diabetes()


# Wrap features and target in DataFrames so the tree can address columns by
# name; the target is kept as a one-column frame named 'target'.
x = pd.DataFrame(data.data,columns=data.feature_names)
y = pd.DataFrame(data.target,columns=['target'])

In [93]:
# Display the feature frame built above.
x

Unnamed: 0,age,sex,bmi,bp,s1,s2,s3,s4,s5,s6
0,0.038076,0.050680,0.061696,0.021872,-0.044223,-0.034821,-0.043401,-0.002592,0.019908,-0.017646
1,-0.001882,-0.044642,-0.051474,-0.026328,-0.008449,-0.019163,0.074412,-0.039493,-0.068330,-0.092204
2,0.085299,0.050680,0.044451,-0.005671,-0.045599,-0.034194,-0.032356,-0.002592,0.002864,-0.025930
3,-0.089063,-0.044642,-0.011595,-0.036656,0.012191,0.024991,-0.036038,0.034309,0.022692,-0.009362
4,0.005383,-0.044642,-0.036385,0.021872,0.003935,0.015596,0.008142,-0.002592,-0.031991,-0.046641
...,...,...,...,...,...,...,...,...,...,...
437,0.041708,0.050680,0.019662,0.059744,-0.005697,-0.002566,-0.028674,-0.002592,0.031193,0.007207
438,-0.005515,0.050680,-0.015906,-0.067642,0.049341,0.079165,-0.028674,0.034309,-0.018118,0.044485
439,0.041708,0.050680,-0.015906,0.017282,-0.037344,-0.013840,-0.024993,-0.011080,-0.046879,0.015491
440,-0.045472,-0.044642,0.039062,0.001215,0.016318,0.015283,-0.028674,0.026560,0.044528,-0.025930


In [94]:

# Hold out 20% of the rows for testing, with a fixed random_state.
X_train,X_test,y_train,y_test = train_test_split(x,y,test_size=0.2 ,random_state=42)

# Fit the hand-written regression tree on the training split; fit() prints the
# training mean squared error itself.
clf = DecisionTree(min_samples_split=5,max_depth=10)
clf.fit(X_train,y_train,verbose=False)


min samples split reach
min samples split reach
min samples split reach
min samples split reach
min samples split reach
mean_square_error :  3283.6870038055476


In [95]:
# Predict on both splits with the fitted tree.
predictions_train = clf.predict(X_train)
predictions_test = clf.predict(X_test)

# Report mean squared error on each split; targets are flattened from their
# one-column frames to 1-D arrays before comparing with the predictions.
print('train_mean_square_error : ',mean_square_error(y_train.to_numpy().flatten(),predictions_train))
print('test_mean_square_error : ',mean_square_error(y_test.to_numpy().flatten(),predictions_test))

train_mean_square_error :  3283.6870038055476
test_mean_square_error :  3259.7975547678784


In [96]:
# Splits recorded while fitting: depth -> list of [feature, threshold] pairs
# (a log of chosen splits, not a numeric importance score).
clf.features_importance

{0: [['bmi', 0.00457216660300077]],
 1: [['s5', 0.00538436996854573], ['s6', 0.0320591578182113]],
 2: [['bp', 0.0597439326260547]]}

In [97]:
# Display the test-set predictions computed above.
predictions_test

array([164.66666667, 175.475     , 164.66666667, 243.        ,
       100.55921053, 100.55921053, 175.475     , 243.        ,
       164.66666667, 175.475     , 100.55921053, 164.66666667,
       100.55921053, 243.        , 100.55921053, 175.475     ,
       243.        , 243.        , 243.57142857, 243.57142857,
       175.475     , 100.55921053, 100.55921053, 175.475     ,
       175.475     , 175.475     , 175.475     , 100.55921053,
       100.55921053, 100.55921053, 175.475     , 100.55921053,
       175.475     , 164.66666667, 175.475     , 175.475     ,
       164.66666667, 164.66666667, 164.66666667, 100.55921053,
       100.55921053, 100.55921053, 100.55921053, 164.66666667,
       164.66666667, 100.55921053, 100.55921053, 100.55921053,
       100.55921053, 164.66666667, 175.475     , 100.55921053,
       175.475     , 100.55921053, 164.66666667, 175.475     ,
       100.55921053, 175.475     , 100.55921053, 100.55921053,
       175.475     , 175.475     , 175.475     , 100.55