<a href="https://colab.research.google.com/github/anirbansen3027/Practice/blob/main/DecisionTreeFromScratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [52]:
import numpy as np
from collections import Counter

In [53]:
def entropy(y):
  """Return the Shannon entropy, in bits, of the label array ``y``.

  Labels may be any values ``np.unique`` can sort (non-negative or negative
  integers, strings, ...). An empty ``y`` has entropy 0.
  """
  y = np.asarray(y)
  if y.size == 0:
    return 0.0
  # counts of each distinct label; every count is >= 1, so log2 never sees 0
  _, counts = np.unique(y, return_counts=True)
  probs = counts / y.size
  return -np.sum(probs * np.log2(probs))

class Node:
  """One node of a binary decision tree.

  Internal nodes store the split (``feature`` index and ``threshold``) and
  their two children; leaf nodes store only the predicted ``value``, which
  must be passed by keyword.
  """

  def __init__(self, feature=None, threshold=None, left=None, right=None, *, value=None):
    # split description; left unset (None) on leaves
    self.feature, self.threshold = feature, threshold
    # child subtrees; left unset (None) on leaves
    self.left, self.right = left, right
    # predicted label; only leaves are given one
    self.value = value

  def is_leaf_node(self):
    """Return True when this node carries a prediction (``value`` is set)."""
    has_prediction = not (self.value is None)
    return has_prediction

class DecisionTree:
  """Binary classification tree grown greedily by entropy-based information gain.

  Parameters
  ----------
  min_samples_split : int, default 5
    A node holding ``min_samples_split`` samples or fewer becomes a leaf, so
    a node needs strictly more samples than this to be split.
  max_depth : int, default 50
    Nodes at depth ``max_depth`` or deeper become leaves (the root is depth 0).
  n_feats : int or None, default None
    Number of features drawn without replacement as split candidates at each
    node; ``None`` (or 0) means all features. ``fit`` overwrites this
    attribute with the effective value ``min(n_features, n_feats)``.

  Labels ``y`` must be a NumPy array (it is fancy-indexed) whose values the
  module-level ``entropy`` helper accepts.
  """

  def __init__(self, min_samples_split = 5,  max_depth = 50, n_feats = None):
    self.min_samples_split = min_samples_split
    self.max_depth = max_depth
    self.n_feats = n_feats
    self.root = None  # set by fit()

  def _most_common_label(self, y):
    """Return the most frequent label in ``y``; ties go to the label seen first."""
    counter = Counter(y)
    return counter.most_common(1)[0][0]

  def _information_gain(self, y, X_column, thres):
    """Return the entropy reduction from splitting ``y`` at ``X_column <= thres``.

    A split that leaves either side empty yields a gain of 0.
    """
    parent_entropy = entropy(y)
    left_idx, right_idx = self._split(X_column, thres)
    if len(left_idx) == 0 or len(right_idx) == 0:
      return 0
    n = len(y)
    n_l = len(left_idx)
    n_r = len(right_idx)
    e_l, e_r = entropy(y[left_idx]), entropy(y[right_idx])
    # children's entropies weighted by the fraction of samples each receives
    child_entropy = ((n_l/n) * (e_l)) + ((n_r/n) * (e_r))
    return parent_entropy - child_entropy

  def _split(self, X_column, thres):
    """Return index arrays of the samples with ``X_column <= thres`` and ``> thres``."""
    left_idx = np.argwhere(X_column <= thres).flatten()
    right_idx = np.argwhere(X_column > thres).flatten()
    return left_idx, right_idx

  def _best_criteria(self, X, y, feat_idx):
    """Return ``(feature, threshold)`` of the highest-gain split among ``feat_idx``.

    Only thresholds that leave at least one sample on each side are tried:
    the largest unique value of a column would send every sample left, so it
    is skipped. Returns ``(None, None)`` when every candidate feature is
    constant on this node.
    """
    best_gain = -1  # a number, not a tuple: any real split's gain beats it
    split_idx, split_thres = None, None
    for feat_id in feat_idx:
      X_column = X[:, feat_id]
      # sorted unique values minus the maximum -> both sides always non-empty
      for thres in np.unique(X_column)[:-1]:
        gain = self._information_gain(y, X_column, thres)
        # strict '>' keeps the first best split found in iteration order
        if gain > best_gain:
          best_gain = gain
          split_idx = feat_id
          split_thres = thres

    return split_idx, split_thres

  def _grow_tree(self, X, y, depth = 0):
    """Recursively build and return the subtree for samples ``X``/``y`` at ``depth``."""
    n_samples, n_features = X.shape
    n_labels = len(np.unique(y))

    # stopping criteria
    if depth >= self.max_depth or n_labels == 1 or n_samples <= self.min_samples_split:
      leaf_value = self._most_common_label(y)
      return Node(value = leaf_value)

    # random candidate subset (a random ordering when n_feats == n_features)
    feat_idx = np.random.choice(n_features, self.n_feats, replace = False)

    best_feat, best_thres = self._best_criteria(X, y, feat_idx)
    if best_feat is None:
      # every sampled feature is constant here, so any split would leave one
      # side empty; stop with a majority-vote leaf instead
      return Node(value = self._most_common_label(y))
    left_idx, right_idx = self._split(X[:, best_feat], best_thres)
    left = self._grow_tree(X[left_idx,:], y[left_idx], depth+1)
    right = self._grow_tree(X[right_idx,:], y[right_idx], depth+1)
    return Node(best_feat, best_thres, left, right)

  def fit(self, X, y):
    """Grow the tree on the 2-D feature matrix ``X`` and label array ``y``."""
    self.n_feats = X.shape[1] if not self.n_feats else min(X.shape[1], self.n_feats)
    self.root = self._grow_tree(X,y)

  def predict(self, X):
    """Return an array with the predicted label for each row of ``X``."""
    return np.array([self._traverse_tree(x, self.root) for x in X])

  def _traverse_tree(self, x, node):
    """Follow sample ``x`` from ``node`` down to a leaf and return the leaf's label."""
    if node.is_leaf_node():
      return node.value
    if x[node.feature] <= node.threshold:
      return self._traverse_tree(x, node.left)
    return self._traverse_tree(x, node.right)

In [54]:
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report

# Seed the generator so the dataset (and hence every downstream result) is
# reproducible, consistent with the fixed random_state used for the split.
X, y = make_classification(random_state = 123)
X.shape,  y.shape

((100, 20), (100,))

In [55]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.2, stratify= y, shuffle = True, random_state = 123)
dt = DecisionTree()
# DecisionTree orders its candidate features per node with np.random.choice,
# which decides ties between equally good splits; seed the global RNG so the
# fitted tree and the report below are reproducible.
np.random.seed(123)
dt.fit(X_train, y_train)
y_pred = dt.predict(X_test)
print(classification_report(y_test, y_pred))

              precision    recall  f1-score   support

           0       0.89      0.80      0.84        10
           1       0.82      0.90      0.86        10

    accuracy                           0.85        20
   macro avg       0.85      0.85      0.85        20
weighted avg       0.85      0.85      0.85        20

