In [28]:
#git clone -b prune-experiment https://github.com/atikul-islam-sajib/TreeBasedModel.git
!git clone -b prune-experiment https://github.com/markusloecher/TreeSandBox.git

In [2]:
%cd TreeSandBox

/content/TreeBasedModel


In [28]:
!pip install -e . --verbose

In [36]:
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
import pandas as pd
import numpy as np
#from sklearn.metrics import accuracy_score
from TreeModelsFromScratch.RandomForest import RandomForest
from copy import deepcopy

In [3]:
# Generate synthetic dataset
X, y = make_classification(n_samples=100, n_features=20, n_informative=10, n_redundant=10, n_clusters_per_class=2, random_state=42)
X = pd.DataFrame(X)
y = pd.Series(y)

# Split dataset
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=42)

# Instantiate RandomForest
rf = RandomForest(n_trees=2, max_depth=10, min_samples_split=20, min_samples_leaf=5,
                  n_feature="sqrt", bootstrap=True, oob=True, criterion="gini",
                  treetype="classification", random_state=42)

# Fit the model
rf.fit(X_train, y_train)


25 out of 75 samples do not have OOB scores. This probably means too few trees were used to compute any reliable OOB estimates. These samples were dropped before computing the oob_score


In [15]:
X_train.shape

(75, 20)

Understand the tree structure first

In [4]:
dt =deepcopy(rf.trees[0])
rootNode = dt.node_list[0] #??
print(rootNode.samples)
print(rootNode.threshold)
print(rootNode.value)
rootNode.left.samples
rootNode.right.leaf_node
rootNode.right.samples
dt._get_y_for_node(rootNode.right)
rootNode.right.id
rootNode.left.id
dt.decision_paths


75
-0.6577327560706901
0


[(0, 1), (0, 2, 3), (0, 2, 4, 5), (0, 2, 4, 6)]

## Try a recursive function

In [29]:
def traverseTree(parentNode, numLeafs):
    if parentNode.leaf_node:
        print("leaf node", parentNode.id, "with", parentNode.samples, "samples")
        return numLeafs+1
    else:
        print("inner node", parentNode.id, "with", parentNode.samples, "samples")
        if parentNode.left != None:
            print("left turn")
            numLeafs = traverseTree(parentNode.left,numLeafs)
        if parentNode.right != None:
            print("right turn")
            numLeafs = traverseTree(parentNode.right, numLeafs)
        return numLeafs

nLeafs = traverseTree(rootNode,0)
print(nLeafs, "leaves in total")

inner node 0 with 75 samples
left turn
leaf node 1 with 20 samples
right turn
inner node 2 with 55 samples
left turn
leaf node 3 with 17 samples
right turn
inner node 4 with 38 samples
left turn
leaf node 5 with 5 samples
right turn
leaf node 6 with 33 samples
4 leaves in total


Now we use the same idea of a recursive function to prune:

In [27]:
import math

def pruneTree(parentNode, min_samples_leaf = 10):
    if parentNode.leaf_node:
        print("leaf node", parentNode.id, "with", parentNode.samples, "samples")
        return
    else:
        leftChild = parentNode.left
        rightChild = parentNode.right
        print("inner node", parentNode.id, "with", parentNode.samples, "samples and children:",leftChild.id, rightChild.id)
        if (leftChild.samples < min_samples_leaf) & (rightChild.samples < min_samples_leaf):#easiest case
            print('\033[1m' + "pruning both children", leftChild.id, "with", leftChild.samples, "samples", rightChild.id, "with", rightChild.samples, "samples" + '\033[0m')
            parentNode.leaf_node = True
            parentNode.left = None
            parentNode.right = None
        elif (leftChild.samples >= min_samples_leaf) & (rightChild.samples >= min_samples_leaf):#also easy
            print("left turn to ", leftChild.id)
            pruneTree(leftChild,min_samples_leaf)
            print("right turn to ", rightChild.id)
            pruneTree(rightChild,min_samples_leaf)
        elif leftChild.samples < min_samples_leaf:
            print('\033[1m' + "pruning left child", leftChild.id, "with", leftChild.samples, "samples" + '\033[0m')
            parentNode.left = None
            #reminder: left_idxs = np.argwhere(X_column <= split_thresh).flatten()
            parentNode.threshold = -math.inf#ideally one should remove this useless inner node
            print("right turn to ", rightChild.id)
            pruneTree(rightChild,min_samples_leaf)
        elif rightChild.samples < min_samples_leaf:
            print('\033[1m' + "pruning right child", rightChild.id, "with", rightChild.samples, "samples" + '\033[0m')
            parentNode.right = None
            parentNode.threshold = math.inf#ideally one should remove this useless inner node
            print("left turn to ", leftChild.id)
            pruneTree(leftChild,min_samples_leaf)

    #print("parent node", parentNode.id, "with", "children:",parentNode.left, parentNode.right)
    #if (parentNode.left == None) & (parentNode.right == None):
    #    parentNode.leaf_node = True

dt =deepcopy(rf.trees[0])
rootNode = dt.node_list[0] #??

pruneTree(rootNode,20)
print("------------ done pruning, now just traversing again:---------------")
nLeafs = traverseTree(rootNode,0)
print(nLeafs, "leaves in total")

inner node 0 with 75 samples and children: 1 2
left turn to  1
leaf node 1 with 20 samples
right turn to  2
inner node 2 with 55 samples and children: 3 4
[1mpruning left child 3 with 17 samples[0m
right turn to  4
inner node 4 with 38 samples and children: 5 6
[1mpruning left child 5 with 5 samples[0m
right turn to  6
leaf node 6 with 33 samples
------------ done pruning, now just traversing again:---------------
inner node 0 with 75 samples
left turn
leaf node 1 with 20 samples
right turn
inner node 2 with 55 samples
right turn
inner node 4 with 38 samples
right turn
leaf node 6 with 33 samples
2 leaves in total


While the routine above sort of works, it needs two major improvements:
1. The samples in the pruned nodes are discarded at the moment, whereas I think it would be better to redistribute them along all the children.
2. When only one child is pruned, the other surviving one becomes a useless inner node without a split. (I set the split threshold to infinity) So best would be to remove that node.

# **Task 1 & 2**

**Task 1**: Redistribute Samples from Pruned Nodes to All Remaining Children
1. **Function Creation:** Developed a function `redistribute_samples_to_children` to distribute samples from pruned nodes to their remaining sibling nodes.

2. **Integration with Pruning Logic:** Incorporated this function into the pruning process so that whenever a node is pruned, its samples are redistributed among the remaining children.

**Task 2**: Remove Useless Inner Nodes and Promote the Surviving Child
1. **Identifying Useless Inner Nodes:** Updated the pruning logic to detect when only one child remains after pruning.

2. **Promotion of Surviving Child:** Removed the useless inner node by promoting the surviving child, transferring its properties to the parent node.

3. **Recursive Pruning:** Ensured that the pruning process continues recursively if the promoted node is not a leaf node.

In [28]:
import math

def redistribute_samples_to_children(parentNode, prunedChild):
    remaining_children = [parentNode.left, parentNode.right]
    remaining_children = [child for child in remaining_children if child and child.id != prunedChild.id]

    if prunedChild and prunedChild.leaf_node and remaining_children:
        num_remaining_children = len(remaining_children)
        samples_per_child = prunedChild.samples // num_remaining_children
        for child in remaining_children:
            print(f"Redistributing {samples_per_child} samples from pruned node {prunedChild.id} to child node {child.id}")
            child.samples += samples_per_child

def pruneTree(parentNode, min_samples_leaf=10):
    if parentNode.leaf_node:
        print("leaf node", parentNode.id, "with", parentNode.samples, "samples")
        return
    else:
        leftChild = parentNode.left
        rightChild = parentNode.right
        print("inner node", parentNode.id, "with", parentNode.samples, "samples and children:", leftChild.id if leftChild else None, rightChild.id if rightChild else None)

        if leftChild and rightChild:
            if (leftChild.samples < min_samples_leaf) and (rightChild.samples < min_samples_leaf):
                # Prune both children
                print('\033[1m' + "Pruning both children", leftChild.id, "with", leftChild.samples, "samples", rightChild.id, "with", rightChild.samples, "samples" + '\033[0m')
                parentNode.leaf_node = True
                parentNode.left = None
                parentNode.right = None
                # Distribute samples from pruned nodes
                #redistribute_samples_to_children(parentNode, leftChild)
                #redistribute_samples_to_children(parentNode, rightChild)
            elif (leftChild.samples >= min_samples_leaf) and (rightChild.samples >= min_samples_leaf):
                # Recurse on both children
                print("left turn to ", leftChild.id)
                pruneTree(leftChild, min_samples_leaf)
                print("right turn to ", rightChild.id)
                pruneTree(rightChild, min_samples_leaf)
            elif leftChild.samples < min_samples_leaf:
                # Prune left child only
                print('\033[1m' + "Pruning left child", leftChild.id, "with", leftChild.samples, "samples" + '\033[0m')
                redistribute_samples_to_children(parentNode, leftChild)
                parentNode.left = None
                print(f"Removing useless inner node {parentNode.id} and promoting right child {rightChild.id}")
                # Remove useless inner node and promote the right child
                parentNode.left = rightChild.left
                parentNode.right = rightChild.right
                parentNode.threshold = rightChild.threshold
                parentNode.leaf_node = rightChild.leaf_node
                if not parentNode.leaf_node:
                    pruneTree(parentNode, min_samples_leaf)
            elif rightChild.samples < min_samples_leaf:
                # Prune right child only
                print('\033[1m' + "Pruning right child", rightChild.id, "with", rightChild.samples, "samples" + '\033[0m')
                redistribute_samples_to_children(parentNode, rightChild)
                parentNode.right = None
                print(f"Removing useless inner node {parentNode.id} and promoting left child {leftChild.id}")
                # Remove useless inner node and promote the left child
                parentNode.left = leftChild.left
                parentNode.right = leftChild.right
                parentNode.threshold = leftChild.threshold
                parentNode.leaf_node = leftChild.leaf_node
                if not parentNode.leaf_node:
                    pruneTree(parentNode, min_samples_leaf)

        elif leftChild:
            if leftChild.samples < min_samples_leaf:
                print('\033[1m' + "Pruning left child", leftChild.id, "with", leftChild.samples, "samples" + '\033[0m')
                parentNode.left = None
                #redistribute_samples_to_children(parentNode, leftChild)
                parentNode.leaf_node = True

        elif rightChild:
            if rightChild.samples < min_samples_leaf:
                print('\033[1m' + "Pruning right child", rightChild.id, "with", rightChild.samples, "samples" + '\033[0m')
                parentNode.right = None
                #redistribute_samples_to_children(parentNode, rightChild)
                parentNode.leaf_node = True

if __name__ == "__main__":
  dt = deepcopy(rf.trees[0])
  rootNode = dt.node_list[0]

  pruneTree(rootNode, 20)
  print("------------ done pruning, now just traversing again:---------------")
  nLeafs = traverseTree(rootNode, 0)
  print(nLeafs, "leaves in total")

inner node 0 with 75 samples and children: 1 2
left turn to  1
leaf node 1 with 20 samples
right turn to  2
inner node 2 with 55 samples and children: 3 4
[1mPruning left child 3 with 17 samples[0m
Redistributing 17 samples from pruned node 3 to child node 4
Removing useless inner node 2 and promoting right child 4
inner node 2 with 55 samples and children: 5 6
[1mPruning left child 5 with 5 samples[0m
Redistributing 5 samples from pruned node 5 to child node 6
Removing useless inner node 2 and promoting right child 6
------------ done pruning, now just traversing again:---------------
inner node 0 with 75 samples
left turn
leaf node 1 with 20 samples
right turn
leaf node 2 with 55 samples
2 leaves in total


## Refining Redistribution of samples to child nodes

The current function `redistribute_samples_to_children` redistributes the samples only to the next child, not all children down the path. That works well for the given example (as the remaining children are pruned anyways) but  unfortunately it is not sufficient to only pass to the immediate child.

The case where no more children further downstream would be pruned is an example for such a use case.

We will have to write a function which updates all the children downwards from the respective node. This would be a modification of the existing function `_reestimate_node_values()` which always starts at the root node. Can we modify it so that it starts at any node and then just re-estimates the node values and sample counts going down the tree from that node require a traversal of the subtree.

The rough idea would be as follows:
When we prune the "left child 3 with 17 samples" we currently redistribute 17 samples from pruned node 3 to child node 4. But we have to go deeper.
(The code below assumes -incorrectly I think- that the data belonging to node3 are just from X_train)

In [41]:
dir(dt)
node3= dt.node_id_dict[3]["node"]
node4= dt.node_id_dict[4]["node"]
#dt.node_id_dict[3]
#dir(node3)#.samples
node3Samples = np.array(X_train.iloc[node3.sample_indices,:])
print(node3Samples.shape)
node3Samples
np.array([dt.traverse_explain_path(x, node4) for x in node3Samples], dtype="object")
#dt.traverse_explain_path(x=node3Samples,node=node4)
#dt.traverse_explain_path(x=node3Samples,node=rootNode)
#rootNode.sample_indices
#dir(rf)
#rf._bootstrap_samples()
#rf.bootstrap

(17, 20)


array([[[4, 6],
        [{'node_id': 4, 'feature': 17, 'threshold': -2.862, 'value_observation': -1.428, 'decision': '-1.428 > -2.862 --> right'},
         {'node_id': 6, 'value': 0, 'prob_distribution': array([1., 0.])}]],

       [[4, 6],
        [{'node_id': 4, 'feature': 17, 'threshold': -2.862, 'value_observation': -2.152, 'decision': '-2.152 > -2.862 --> right'},
         {'node_id': 6, 'value': 0, 'prob_distribution': array([1., 0.])}]],

       [[4, 6],
        [{'node_id': 4, 'feature': 17, 'threshold': -2.862, 'value_observation': 3.359, 'decision': '3.359 > -2.862 --> right'},
         {'node_id': 6, 'value': 0, 'prob_distribution': array([1., 0.])}]],

       [[4, 5],
        [{'node_id': 4, 'feature': 17, 'threshold': -2.862, 'value_observation': -5.523, 'decision': '-5.523 <= -2.862 --> left'},
         {'node_id': 5, 'value': 0, 'prob_distribution': array([0.8, 0.2])}]],

       [[4, 6],
        [{'node_id': 4, 'feature': 17, 'threshold': -2.862, 'value_observation': 1.8

In [47]:
traversed_nodes = dt.explain_decision_path(node3Samples)#[:, 0].copy()
traversed_nodes.shape#[0]

(17, 2)

In [33]:
dt._reestimate_node_values(node3Samples,)

False