In [50]:
# Clone the `prune-experiment` branch of the TreeSandBox repository into the current working directory.
# The commented-out line below is an alternative source (same branch name, TreeBasedModel repository).
#!git clone -b prune-experiment https://github.com/atikul-islam-sajib/TreeBasedModel.git
!git clone -b prune-experiment https://github.com/markusloecher/TreeSandBox.git

Cloning into 'TreeSandBox'...
remote: Enumerating objects: 21, done.[K
remote: Counting objects: 100% (21/21), done.[K
remote: Compressing objects: 100% (19/19), done.[K
remote: Total 21 (delta 3), reused 0 (delta 0), pack-reused 0[K
Receiving objects: 100% (21/21), 32.43 KiB | 1.16 MiB/s, done.
Resolving deltas: 100% (3/3), done.


In [52]:
# Enter the freshly cloned repository; the editable install in the next cell uses `.` (the current directory).
%cd TreeSandBox

/Users/loecherm/Nextcloud2/SHKs/Atikul/TreeSandBox


In [None]:
!pip install -e . --verbose

In [54]:
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
import pandas as pd
#import numpy as np
#from sklearn.metrics import accuracy_score
from TreeModelsFromScratch.RandomForest import RandomForest
from copy import deepcopy

In [21]:
# Synthetic binary-classification data: 100 rows, 20 features (10 informative + 10 redundant), fixed seed
X, y = make_classification(
    n_samples=100,
    n_features=20,
    n_informative=10,
    n_redundant=10,
    n_clusters_per_class=2,
    random_state=42,
)
X, y = pd.DataFrame(X), pd.Series(y)

# Hold out a quarter of the rows for testing
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=42
)

# A deliberately tiny forest (two trees) that is easy to inspect by hand
rf = RandomForest(
    n_trees=2,
    max_depth=10,
    min_samples_split=20,
    min_samples_leaf=5,
    n_feature="sqrt",
    bootstrap=True,
    oob=True,
    criterion="gini",
    treetype="classification",
    random_state=42,
)

# Train on the training split
rf.fit(X_train, y_train)


25 out of 75 samples do not have OOB scores. This probably means too few trees were used to compute any reliable OOB estimates. These samples were dropped before computing the oob_score


Understand the tree structure first

In [43]:
# Work on a deep copy of the first fitted tree so that rf.trees[0] itself is never modified.
dt =deepcopy(rf.trees[0])
# Assumption: node_list[0] is the root node -- the 75 samples printed below equal the size of the
# training split (75% of 100 rows), which is consistent with that. TODO confirm against the tree class.
rootNode = dt.node_list[0]
# Explicitly printed attributes of the root node.
print(rootNode.samples)
print(rootNode.threshold)
print(rootNode.value)
# The bare expressions below are evaluated but NOT displayed: a cell only echoes its final expression.
# They document which attributes/methods exist on a node (samples, leaf_node, id, left, right).
rootNode.left.samples
rootNode.right.leaf_node
rootNode.right.samples
# Private helper of the tree class; what it returns is not visible in this notebook.
dt._get_y_for_node(rootNode.right)
rootNode.right.id
rootNode.left.id
# Last expression of the cell, hence the only one echoed: tuples of node ids, each starting at node 0.
dt.decision_paths


75
-0.6577327560706901
0


[(0, 1), (0, 2, 3), (0, 2, 4, 5), (0, 2, 4, 6)]

## Try a recursive function

In [55]:
def traverseTree(parentNode, numLeafs):
    """Walk the subtree rooted at ``parentNode`` depth-first (left before right) and count its leaves.

    Every visited node is printed together with its id and sample count.

    Parameters
    ----------
    parentNode : node object
        Must expose ``leaf_node``, ``id``, ``samples``, ``left`` and ``right``.
        A child that is ``None`` (for example one removed by pruning) is skipped.
    numLeafs : int
        Running leaf count accumulated so far; pass 0 to start a fresh count.

    Returns
    -------
    int
        ``numLeafs`` plus the number of nodes flagged ``leaf_node`` in this subtree.
    """
    if parentNode.leaf_node:
        print("leaf node", parentNode.id, "with", parentNode.samples, "samples")
        return numLeafs + 1

    print("inner node", parentNode.id, "with", parentNode.samples, "samples")
    # Identity test instead of `!= None`: `!=` would dispatch to a comparison operator
    # defined on the node class and could wrongly skip an existing child.
    if parentNode.left is not None:
        print("left turn")
        numLeafs = traverseTree(parentNode.left, numLeafs)
    if parentNode.right is not None:
        print("right turn")
        numLeafs = traverseTree(parentNode.right, numLeafs)
    return numLeafs

# Count the leaves reachable from rootNode, starting the running total at 0.
# NOTE(review): rootNode is whatever the most recently executed cell bound it to. The recorded output
# below reports 2 leaves and never visits nodes 3 and 5, although dt.decision_paths above lists four
# root-to-leaf paths -- so this cell was apparently last run after the pruning cell further down.
# Restart the kernel and run all cells in order to see the traversal of the unpruned tree here.
nLeafs = traverseTree(rootNode,0)
print(nLeafs, "leaves in total")

inner node 0 with 75 samples
left turn
leaf node 1 with 20 samples
right turn
inner node 2 with 55 samples
right turn
inner node 4 with 38 samples
right turn
leaf node 6 with 33 samples
2 leaves in total


Now we use the same idea of a recursive function to prune:

In [47]:
import math

def pruneTree(parentNode, min_samples_leaf = 10):
    """Recursively remove, in place, every child subtree holding fewer than ``min_samples_leaf`` samples.

    Rules applied at each inner node:

    * both children survive  -> recurse into the left child, then into the right child;
    * no child survives      -> the node itself becomes a leaf (``leaf_node = True``, no children);
    * exactly one survives   -> the other child is detached and the split threshold is set to
      -inf (left removed) or +inf (right removed) so that every sample is routed to the survivor,
      then the survivor is pruned recursively.

    A child that is already ``None`` (e.g. removed by an earlier call) is treated as "does not
    survive", so the function can safely be applied to a tree that was pruned before.

    Parameters
    ----------
    parentNode : node object
        Must expose ``leaf_node``, ``id``, ``samples``, ``left``, ``right`` and ``threshold``.
        It is modified in place.
    min_samples_leaf : int, default 10
        Minimum number of samples a child needs in order to be kept.

    Returns
    -------
    None
    """
    if parentNode.leaf_node:
        print("leaf node", parentNode.id, "with", parentNode.samples, "samples")
        return

    leftChild = parentNode.left
    rightChild = parentNode.right
    leftId = None if leftChild is None else leftChild.id
    rightId = None if rightChild is None else rightChild.id
    print("inner node", parentNode.id, "with", parentNode.samples, "samples and children:", leftId, rightId)

    # A child is pruned now if it exists but is too small; it is kept only if it exists and is large enough.
    pruneLeft = leftChild is not None and leftChild.samples < min_samples_leaf
    pruneRight = rightChild is not None and rightChild.samples < min_samples_leaf
    keepLeft = leftChild is not None and not pruneLeft
    keepRight = rightChild is not None and not pruneRight

    if pruneLeft and pruneRight:
        print('\033[1m' + "pruning both children", leftChild.id, "with", leftChild.samples, "samples", rightChild.id, "with", rightChild.samples, "samples" + '\033[0m')
    elif pruneLeft:
        print('\033[1m' + "pruning left child", leftChild.id, "with", leftChild.samples, "samples" + '\033[0m')
    elif pruneRight:
        print('\033[1m' + "pruning right child", rightChild.id, "with", rightChild.samples, "samples" + '\033[0m')

    if not keepLeft and not keepRight:
        # Nothing left below this node: it becomes a leaf itself.
        parentNode.leaf_node = True
        parentNode.left = None
        parentNode.right = None
        return

    if not keepLeft:
        parentNode.left = None
        # reminder: left_idxs = np.argwhere(X_column <= split_thresh).flatten()
        # With a threshold of -inf no sample goes left any more.
        # Ideally one should remove this now useless inner node.
        parentNode.threshold = -math.inf
    if not keepRight:
        parentNode.right = None
        # With a threshold of +inf every sample goes left.
        # Ideally one should remove this now useless inner node.
        parentNode.threshold = math.inf

    if keepLeft:
        print("left turn to ", leftChild.id)
        pruneTree(leftChild, min_samples_leaf)
    if keepRight:
        print("right turn to ", rightChild.id)
        pruneTree(rightChild, min_samples_leaf)

# Prune a fresh deep copy so the fitted forest (rf.trees[0]) stays unchanged and this cell can be re-run.
# Note: this rebinds the notebook-level names dt and rootNode to the copy that is pruned below.
dt =deepcopy(rf.trees[0])
rootNode = dt.node_list[0] # assumed to be the root node -- TODO confirm

# Remove every child subtree with fewer than 20 samples (modifies the copied tree in place) ...
pruneTree(rootNode,20)
print("------------ done pruning, now just traversing again:---------------")
# ... then walk the pruned tree to check which nodes are still reachable.
nLeafs = traverseTree(rootNode,0)
print(nLeafs, "leaves in total")

inner node 0 with 75 samples and children: 1 2
left turn to  1
leaf node 1 with 20 samples
right turn to  2
inner node 2 with 55 samples and children: 3 4
[1mpruning left child 3 with 17 samples[0m
right turn to  4
inner node 4 with 38 samples and children: 5 6
[1mpruning left child 5 with 5 samples[0m
right turn to  6
leaf node 6 with 33 samples
------------ done pruning, now just traversing again:---------------
inner node 0 with 75 samples
left turn
leaf node 1 with 20 samples
right turn
inner node 2 with 55 samples
right turn
inner node 4 with 38 samples
right turn
leaf node 6 with 33 samples
2 leaves in total


While the routine above sort of works, it needs two major improvements:
1. The samples in the pruned nodes are currently discarded, whereas I think it would be better to redistribute them among the remaining children.
2. When only one child is pruned, the parent becomes a useless inner node without a real split: I set its split threshold to ±infinity, so every sample is simply routed to the surviving child. It would be best to remove that parent node and put the surviving child in its place.