In [1]:
# Decision tree classifier written from scratch; sklearn is used only to load the data and split it.

import numpy  as np
from collections import Counter

from sklearn import datasets
from sklearn.model_selection import train_test_split

def entropy(y):
    """Return the Shannon entropy, in bits, of the label sequence *y*.

    Non-negative integer labels are counted with ``np.bincount``.  Any other
    labels that ``np.bincount`` rejects (negative integers, floats, strings)
    are counted with ``np.unique`` instead, so they no longer raise.

    Args:
        y: 1-D sequence or array of class labels.

    Returns:
        ``-sum(p * log2(p))`` over the relative frequency ``p`` of each label
        that occurs in *y*.
    """
    try:
        # Fast path: direct counting for non-negative integer labels.
        counts = np.bincount(y)
    except (TypeError, ValueError):
        # bincount raises ValueError for negative values and TypeError for
        # dtypes it cannot safely cast to int; count distinct labels instead.
        _, counts = np.unique(y, return_counts=True)

    # Drop labels that never occur so log2 is never evaluated at zero.
    ps = counts[counts > 0] / len(y)
    return -np.sum(ps * np.log2(ps))




class node:
    """A single vertex of the tree.

    A vertex stores a split description (``feature``, ``threshold``) together
    with its two children (``left``, ``right``), and/or a ``value``.  A vertex
    whose ``value`` is set is treated as a leaf.
    """

    def __init__(self, feature=None, threshold=None, left=None, right=None, *, value=None):
        # ``value`` is keyword-only, so positional arguments always describe a split.
        self.feature, self.threshold = feature, threshold
        self.left, self.right = left, right
        self.value = value

    def is_leaf_node(self):
        """Return True when this vertex carries a value (i.e. it is a leaf)."""
        return not (self.value is None)

class DecisonTree:
    """Binary decision-tree classifier grown greedily by information gain.

    Args:
        min_samples_split: a subset with fewer samples than this becomes a leaf.
        max_depth: a subset reached at this depth becomes a leaf.
        n_feats: number of features sampled (without replacement) as split
            candidates at every vertex.  A falsy value means "all features".
            NOTE: ``fit`` overwrites this attribute with the resolved count.
    """

    def __init__(self, min_samples_split=2, max_depth=100, n_feats=None):
        self.min_samples_split = min_samples_split
        self.max_depth = max_depth
        self.n_feats = n_feats
        self.root = None

    def fit(self, X, y):
        """Grow the tree from the training data.

        Args:
            X: 2-D array-like with one row per sample and one column per feature.
            y: 1-D array-like holding one label per row of *X*.

        Raises:
            ValueError: if *X* has no rows, or *X* and *y* differ in length.
        """
        # Accept lists and other array-likes; ndarrays pass through unchanged.
        X = np.asarray(X)
        y = np.asarray(y)
        if X.shape[0] == 0:
            raise ValueError("cannot fit a decision tree on an empty training set")
        if len(y) != X.shape[0]:
            raise ValueError("X and y must contain the same number of samples")

        # Never request more candidate features than the data provides.
        self.n_feats = X.shape[1] if not self.n_feats else min(self.n_feats, X.shape[1])
        self.root = self._grow_tree(X, y)

    def _grow_tree(self, X, y, depth=0):
        """Recursively build and return the subtree for the samples ``(X, y)``."""
        n_samples, n_features = X.shape
        n_labels = len(np.unique(y))

        # Stopping criteria: depth limit reached, pure subset, or too few samples.
        if (depth >= self.max_depth
                or n_labels == 1
                or n_samples < self.min_samples_split):
            return node(value=self._most_common_label(y))

        # Random subset of features to consider at this vertex.
        feat_idxs = np.random.choice(n_features, self.n_feats, replace=False)

        # Greedy search for the best (feature, threshold) pair.
        best_feat, best_thresh = self._best_criteria(X, y, feat_idxs)
        if best_feat is None:
            # Every sampled feature is constant here, so nothing can separate
            # the samples; recursing would only produce an empty child.
            return node(value=self._most_common_label(y))

        left_idx, right_idx = self._split(X[:, best_feat], best_thresh)
        if len(left_idx) == 0 or len(right_idx) == 0:
            # Defensive: e.g. NaN feature values compare False on both sides
            # and can leave one child empty.
            return node(value=self._most_common_label(y))

        left = self._grow_tree(X[left_idx, :], y[left_idx], depth + 1)
        right = self._grow_tree(X[right_idx, :], y[right_idx], depth + 1)
        return node(best_feat, best_thresh, left, right)

    def predict(self, X):
        """Return an array with the predicted label for every row of *X*."""
        return np.array([self._traverse_tree(x, self.root) for x in X])

    def _traverse_tree(self, x, node):
        """Walk from *node* down to a leaf for sample *x*; return the leaf's value."""
        while not node.is_leaf_node():
            # Same rule as ``_split``: values <= threshold go to the left child.
            node = node.left if x[node.feature] <= node.threshold else node.right
        return node.value

    def _best_criteria(self, X, y, feat_idxs):
        """Return the ``(feature index, threshold)`` pair with the highest gain.

        Returns ``(None, None)`` when none of the features in *feat_idxs* takes
        at least two distinct values, i.e. when no split can separate the data.
        """
        best_gain = -1
        split_idx, split_thresh = None, None

        for feat_idx in feat_idxs:
            X_column = X[:, feat_idx]
            # np.unique returns sorted values.  The largest one is skipped:
            # with the ``<=`` rule it would send every sample to the left
            # child and leave the right child empty.
            thresholds = np.unique(X_column)[:-1]

            for threshold in thresholds:
                gain = self._information_gain(y, X_column, threshold)

                if gain > best_gain:
                    best_gain = gain
                    split_idx = feat_idx
                    split_thresh = threshold
        return split_idx, split_thresh

    def _information_gain(self, y, X_column, split_thresh):
        """Return parent entropy minus the size-weighted entropy of the children.

        Returns 0 when the split would leave either child empty.
        """
        parent_entropy = entropy(y)

        left_idx, right_idx = self._split(X_column, split_thresh)
        if len(left_idx) == 0 or len(right_idx) == 0:
            return 0

        n = len(y)
        n_l, n_r = len(left_idx), len(right_idx)
        e_l, e_r = entropy(y[left_idx]), entropy(y[right_idx])
        child_entropy = (n_l / n) * e_l + (n_r / n) * e_r

        return parent_entropy - child_entropy

    def _most_common_label(self, y):
        """Return the label that occurs most often in *y*."""
        return Counter(y).most_common(1)[0][0]

    def _split(self, X_column, split_thresh):
        """Return the indices of values ``<= split_thresh`` and ``> split_thresh``."""
        left_idxs = np.argwhere(X_column <= split_thresh).flatten()
        right_idxs = np.argwhere(X_column > split_thresh).flatten()
        return left_idxs, right_idxs
  
def accuracy(y_true, y_pred):
    """Return the fraction of positions where *y_pred* equals *y_true*.

    Both arguments are converted to arrays first.  Comparing two plain Python
    lists with ``==`` yields a single bool rather than an element-wise result,
    which would silently give a wrong score.

    Args:
        y_true: array-like of reference labels.
        y_pred: array-like of predicted labels, same shape as *y_true*.

    Returns:
        The share of matching labels, between 0 and 1.

    Raises:
        ValueError: if the two inputs do not have the same shape.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if y_true.shape != y_pred.shape:
        # Broadcasting mismatched shapes would produce a meaningless score.
        raise ValueError("y_true and y_pred must have the same shape")
    return np.mean(y_true == y_pred)

if __name__ == "__main__":
    # Demo: fit the tree on the breast-cancer dataset and report test accuracy.
    data = datasets.load_breast_cancer()
    X, y = data.data, data.target

    # Hold out 20% of the samples; the fixed seed keeps the split reproducible.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=12
    )

    d = DecisonTree(max_depth=10)
    d.fit(X_train, y_train)

    # Show the attributes of the root vertex (the first split that was chosen).
    print(vars(d.root))

    y_pred = d.predict(X_test)

    print("checking accuracy")
    acc = accuracy(y_test, y_pred)
    print("Accuracy:", acc)

{'feature': 20, 'threshold': 16.77, 'left': <__main__.node object at 0x7f17c86232d0>, 'right': <__main__.node object at 0x7f17c86236d0>, 'value': None}
checking accuracy
Accuracy: 0.9473684210526315
