Import necessary modules/libraries

In [1]:
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn import datasets

Define the node class, along with its attributes and methods

In [3]:
class TreeNode:
  """A tree node holding a subset of the data plus the split recorded for it.

  Nodes are linked in both directions: each node keeps a list of its
  children and a reference to its parent (None until it is attached).
  """

  def __init__(self,X,y,thresh,feature):
    """Store the node's data and split information; the node starts detached.

    Args:
      X: input data held at this node.
      y: class labels held at this node.
      thresh: threshold value recorded for this node.
      feature: feature index recorded for this node.
    """
    # Data kept at this node
    self.X = X
    self.y = y
    # Split description recorded for this node
    self.thresh = thresh
    self.feature = feature
    # Links to the rest of the tree
    self.parent = None
    self.children = []

  def add_child(self, child):
    """Attach `child` below this node and point its parent link back here."""
    self.children.append(child)
    child.parent = self

  def get_level(self):
    """Return the number of ancestors above this node (0 for a root)."""
    depth = 0
    ancestor = self.parent
    while ancestor:
      ancestor = ancestor.parent
      depth += 1
    return depth

  def print_tree(self):
    """Print this node and its whole subtree, one node per line.

    A node that has a parent is indented three spaces per level and marked
    with "|__"; a node without a parent is printed with no prefix.
    """
    if self.parent:
      lead = ' ' * (3 * self.get_level()) + "|__"
    else:
      lead = ""
    print(f"{lead}feature_idx: {self.feature!s}")
    # Recurse depth-first in insertion order; an empty list ends the recursion
    for subtree in self.children:
      subtree.print_tree()

Creating a binary classification dataset with 100 data-points and 10 features

In [4]:
num_feats = 10  # number of feature columns requested for the synthetic dataset
# Two-class dataset with 100 samples; random_state is fixed so the data is the same on every run
X, y = datasets.make_classification(
    n_samples=100,
    n_features=num_feats,
    n_classes=2,
    random_state=25,
)

In [5]:
# Root node holds the full dataset; it has no split of its own, so thresh and feature are -1 placeholders
root = TreeNode(X, y, thresh=-1, feature=-1)

Function to build a tree by recursively splitting the above dataset on randomly chosen features

In [6]:
def build_tree(node, depth, rng=None):
    """Recursively grow a random binary tree below ``node``.

    At each split a feature index is drawn uniformly at random and the
    threshold is the midpoint between the minimum and maximum of that
    feature *among the samples held by the node*. Samples below the threshold
    go to the left child, the rest to the right child. Each child stores the
    threshold and feature index of the split that produced it.

    Args:
        node: node to split; must expose ``X`` (2-D NumPy array), ``y``
            (1-D NumPy array aligned with ``X``) and ``add_child``.
        depth: remaining depth budget. It is decremented on entry and the
            node is only split while the decremented value is positive, so
            ``build_tree(root, 5)`` yields at most 5 levels of nodes
            including the root.
        rng: optional ``numpy.random.Generator`` used to draw feature
            indices. When omitted, NumPy's global random state is used.

    Returns:
        None. The tree is built in place by attaching children to ``node``.
    """
    depth -= 1  # this node consumes one level of the depth budget

    # Stop when the budget is exhausted or there are too few samples to split.
    if depth <= 0 or len(node.X) <= 1:
        return

    # Draw the split feature from the columns this node actually has instead
    # of relying on a module-level feature count.
    n_features = node.X.shape[1]
    if rng is None:
        sel_feature = int(np.random.randint(0, n_features))
    else:
        sel_feature = int(rng.integers(0, n_features))

    # Use the node's own samples, not the full dataset, so that children
    # partition their parent's data and the size check above is meaningful.
    column = node.X[:, sel_feature]
    thresh = (np.min(column) + np.max(column)) / 2

    goes_left = column < thresh
    goes_right = column >= thresh
    left_child = TreeNode(node.X[goes_left], node.y[goes_left], thresh, sel_feature)
    right_child = TreeNode(node.X[goes_right], node.y[goes_right], thresh, sel_feature)

    node.add_child(left_child)
    node.add_child(right_child)
    build_tree(left_child, depth, rng)
    build_tree(right_child, depth, rng)


In [7]:
np.random.seed(25)  # build_tree draws features from NumPy's global RNG; seed it so the tree is reproducible
root.children.clear()  # discard any subtree left from a previous run so re-running this cell does not add extra children
build_tree(root,5) #build a tree of depth 5 max

In [8]:
# Print the tree from the root down; each line shows the feature index stored on that node
root.print_tree()

feature_idx: -1
   |__feature_idx: 2
      |__feature_idx: 3
         |__feature_idx: 3
            |__feature_idx: 7
            |__feature_idx: 7
         |__feature_idx: 3
            |__feature_idx: 8
            |__feature_idx: 8
      |__feature_idx: 3
         |__feature_idx: 2
            |__feature_idx: 5
            |__feature_idx: 5
         |__feature_idx: 2
            |__feature_idx: 7
            |__feature_idx: 7
   |__feature_idx: 2
      |__feature_idx: 0
         |__feature_idx: 8
            |__feature_idx: 9
            |__feature_idx: 9
         |__feature_idx: 8
            |__feature_idx: 8
            |__feature_idx: 8
      |__feature_idx: 0
         |__feature_idx: 8
            |__feature_idx: 2
            |__feature_idx: 2
         |__feature_idx: 8
            |__feature_idx: 6
            |__feature_idx: 6
