In [36]:
import pandas as pd
import numpy as np
from sklearn.datasets import make_regression
import random
from sklearn.model_selection import train_test_split

In [37]:
# Synthetic regression data: 50 rows, 20 features, of which n_informative=2 drive the target.
X, y = make_regression(n_samples=50, n_features=20, n_informative=2, noise=5, random_state=42)
# Wrap as pandas objects; name the feature columns col_0 ... col_19 up front.
X = pd.DataFrame(X, columns=[f'col_{i}' for i in range(X.shape[1])])
y = pd.Series(y)

In [58]:
def mse(y):
    """Mean squared deviation of ``y`` from its own mean (population variance).

    Used as the impurity of a regression-tree node.

    Parameters
    ----------
    y : pandas.Series or numpy.ndarray
        Target values falling into a node.

    Returns
    -------
    float
        ``mean((y - mean(y)) ** 2)``; ``0.0`` for an empty input, so an empty
        side of a split contributes no impurity instead of propagating NaN.
    """
    if len(y) == 0:
        return 0.0
    return ((y - y.mean()) ** 2).mean()

In [59]:
def msep(y, X, Q):
    """MSE decrease obtained by splitting the rows of ``y`` on ``X <= Q``.

    Parameters
    ----------
    y : pandas.Series
        Targets of the node being split.
    X : pandas.Series
        A single feature column for the same rows (its index matches ``y``'s).
    Q : float
        Threshold: rows with ``X <= Q`` go left, rows with ``X > Q`` go right.

    Returns
    -------
    float
        ``mse(y) - (n_left * mse_left / n + n_right * mse_right / n)``.
        A threshold that leaves one side empty separates nothing, so its
        gain is 0.0.
    """
    n = len(y)
    y_left = y.loc[X[X <= Q].index]
    y_right = y.loc[X[X > Q].index]

    n_left = len(y_left)
    n_right = len(y_right)

    # A one-sided split has zero gain by definition; return it explicitly rather
    # than depend on what mse() yields for an empty input (NaN would make every
    # later `gain > best` comparison silently False).
    if n_left == 0 or n_right == 0:
        return 0.0

    mse_left = mse(y_left)
    mse_right = mse(y_right)

    return mse(y) - (n_left * mse_left / n + n_right * mse_right / n)

In [50]:
class TreeNode:
    """Internal (split) node of a binary decision tree.

    MyTreeReg sends rows with ``X[split_column] <= split_value`` to ``left``
    and the remaining rows to ``right``. Both children start as ``None`` and
    are attached by the tree builder; each is either another ``TreeNode`` or
    a leaf value.
    """

    def __init__(self, split_column, split_value, predicted_classes):
        # Feature name and threshold that define this split.
        self.split_column, self.split_value = split_column, split_value
        # Stored as given; MyTreeReg passes -1 here.
        self.predicted_classes = predicted_classes
        # Subtrees, filled in after construction.
        self.left = self.right = None

In [85]:
class MyTreeReg:
    """Binary regression tree that greedily picks the split with the largest MSE decrease.

    Parameters
    ----------
    max_depth : int, default 5
        Nodes at this depth (root is depth 0) become leaves.
    min_samples_split : int, default 2
        Nodes with fewer rows than this become leaves.
    max_leafs : int, default 20
        Upper bound on the number of leaves; values below 2 are raised to 2.
    bins : int or None, default None
        If given and smaller than ``n_rows - 2`` at fit time, candidate
        thresholds for each column are the interior edges of a ``bins``-bin
        histogram of the training data, instead of every midpoint between
        consecutive distinct values.

    Attributes set by ``fit``
    -------------------------
    node : TreeNode or leaf value
        Root of the fitted tree.
    leafs_cnt : int
        Number of leaves (starts at 1, +1 per split).
    fi : dict
        Per-column importance: sum over that column's splits of
        ``(rows at node / training rows) * MSE decrease``.
    bin_thresholds : pandas.DataFrame or None
        Histogram-edge thresholds per column, or None when not binning.
    """

    def __init__(self, max_depth=5, min_samples_split=2, max_leafs=20, bins=None):
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        # A single split already yields two leaves, so fewer is meaningless.
        self.max_leafs = max_leafs if max_leafs > 1 else 2
        # Requested number of histogram bins; fit() never overwrites it.
        self.bins = bins
        # Per-column thresholds derived from `bins` during fit().
        self.bin_thresholds = None
        self.fi = None

    def __str__(self):
        return (f"MyTreeReg class: max_depth={self.max_depth}, "
                f"min_samples_split={self.min_samples_split}, max_leafs={self.max_leafs}")

    def fit(self, X, y):
        """Grow the tree on feature frame ``X`` and targets ``y`` (same index).

        Can be called repeatedly: the ``bins`` hyper-parameter is left intact.
        """
        self.n = len(y)
        self.leafs_cnt = 1
        self.fi = {key: 0 for key in X.columns}
        if self.bins is not None and self.bins < X.shape[0] - 2:
            self.bin_thresholds = self.get_bins(X)
        else:
            self.bin_thresholds = None
        self.node = self.build_tree(X, y, 0)

    def predict(self, X):
        """Return a float Series of predictions aligned with ``X.index`` (requires fit())."""
        # Explicit float dtype: an index-only Series has no data to infer it from.
        self.y_predict = pd.Series(np.nan, index=X.index, dtype=float)
        self.prediction(self.node, X)
        return self.y_predict

    def get_bins(self, X):
        """Interior histogram edges (``self.bins`` bins) for every column of ``X``."""
        return pd.DataFrame({
            column: np.histogram(X[column], bins=self.bins)[1][1:-1]
            for column in X.columns
        })

    def prediction(self, node, X):
        """Write, for each row of ``X``, the leaf value it reaches into ``self.y_predict``."""
        # Anything that is not a split node is a leaf value.
        if not isinstance(node, TreeNode):
            self.y_predict.loc[X.index] = node
            return

        column = X[node.split_column]
        self.prediction(node.left, X[column <= node.split_value])
        self.prediction(node.right, X[column > node.split_value])

    def build_tree(self, X, y, current_depth):
        """Recursively grow the subtree for rows ``X``/``y`` at ``current_depth``.

        Returns a ``TreeNode``, or a leaf value (mean of ``y``) when a stopping
        rule applies or no split reduces the MSE.
        """
        # Cheap stopping rules first, so no split search is wasted on a node
        # that will become a leaf anyway.
        if (self.leafs_cnt >= self.max_leafs
                or X.shape[0] <= 1
                or len(np.unique(y)) <= 1
                or current_depth >= self.max_depth
                or len(y) < self.min_samples_split):
            return self.build_leaf(X, y)

        split_column, Q, ig = self.get_best_split(X, y)
        # No threshold improved on the parent (e.g. identical feature rows, or
        # no histogram edge inside this node's range): there is nothing to split on.
        if split_column is None:
            return self.build_leaf(X, y)

        left_indices = X[split_column] <= Q
        right_indices = X[split_column] > Q
        node = TreeNode(split_column, Q, -1)

        # Splitting turns one leaf into two.
        self.leafs_cnt += 1
        node.left = self.build_tree(X[left_indices], y[left_indices], current_depth + 1)
        node.right = self.build_tree(X[right_indices], y[right_indices], current_depth + 1)

        # `ig` is mse(y) - (n_left/n * mse_left + n_right/n * mse_right) for this
        # split; weight it by this node's share of the training rows.
        self.fi[split_column] += X.shape[0] / self.n * ig
        return node

    def build_leaf(self, X, y):
        """Leaf value: mean target of the rows in ``X``."""
        return y.loc[X.index].mean()

    def get_best_split(self, X, y):
        """Find the (column, threshold) with the largest positive MSE decrease.

        Returns ``(column_name, Q, gain)``; ``column_name`` is None and
        ``Q``/``gain`` are 0 when no candidate beats a gain of 0.
        """
        best_column_name, best_Q, best_ig = None, 0, 0
        for name in X.columns:
            column = X[name]
            if self.bin_thresholds is not None:
                candidates = self.bin_thresholds[name]
            else:
                # Midpoints between consecutive distinct values, computed per
                # column so no candidate straddles two different columns and
                # duplicate values don't re-test the same partition.
                values = np.unique(column)
                candidates = (values[:-1] + values[1:]) / 2
            for Q in candidates:
                ig = msep(y, column, Q)
                if ig > best_ig:
                    best_column_name, best_Q, best_ig = name, Q, ig
        return best_column_name, best_Q, best_ig

    def print_tree(self, node):
        """Print the tree pre-order: ``column threshold`` for splits, the value for leaves."""
        if not isinstance(node, TreeNode):
            print(node)
            return
        print(node.split_column, node.split_value)
        self.print_tree(node.left)
        self.print_tree(node.right)

In [86]:
# Depth <= 5, at least 5 rows to split, at most 10 leaves, 10-bin histogram thresholds.
tree = MyTreeReg(max_depth=5, min_samples_split=5, max_leafs=10, bins=10)

In [87]:
# Grow the tree on the full synthetic dataset; also fills tree.fi.
tree.fit(X=X, y=y)

In [88]:
# Pre-order dump: a split prints "column threshold", a leaf prints its mean target.
tree.print_tree(node=tree.node)

col_12 -0.5305322341700651
col_3 -1.0787319943443108
-51.53357138412902
col_12 -1.4040935656111186
-26.193740693290327
col_4 -0.6398858261592428
-2.8460830070482714
col_0 1.0875910915842835
-15.798094157876667
-4.09681180271582
col_12 1.2165904287120415
col_3 0.10406780302684537
col_12 0.3430290972709882
col_3 -1.4729985934680294
-15.750145202415375
-6.55925550638459
6.501698112514691
17.697987445361417
35.38642492999285


In [89]:
# In-sample predictions: the tree is evaluated on the same frame it was fit on.
tree.predict(X=X)

0      6.501698
1    -15.798094
2    -26.193741
3    -15.798094
4    -15.798094
5     17.697987
6     17.697987
7     35.386425
8     -2.846083
9    -15.798094
10    35.386425
11     6.501698
12    17.697987
13     6.501698
14    17.697987
15   -15.750145
16    17.697987
17    -4.096812
18    -6.559256
19    35.386425
20     6.501698
21     6.501698
22    35.386425
23   -51.533571
24    17.697987
25    17.697987
26     6.501698
27   -15.798094
28   -26.193741
29    17.697987
30   -26.193741
31     6.501698
32   -15.798094
33   -15.798094
34    -6.559256
35    -6.559256
36    17.697987
37   -15.750145
38    35.386425
39     6.501698
40   -15.798094
41    17.697987
42     6.501698
43    -6.559256
44    17.697987
45    17.697987
46    17.697987
47     6.501698
48    17.697987
49    17.697987
dtype: float64

In [90]:
# Per-feature importance accumulated during fit (0 for features never split on).
tree.fi

{'col_0': 2.4341334889813093,
 'col_1': 0,
 'col_2': 0,
 'col_3': 71.34321292256898,
 'col_4': 2.443788781761423,
 'col_5': 0,
 'col_6': 0,
 'col_7': 0,
 'col_8': 0,
 'col_9': 0,
 'col_10': 0,
 'col_11': 0,
 'col_12': 285.62074290086684,
 'col_13': 0,
 'col_14': 0,
 'col_15': 0,
 'col_16': 0,
 'col_17': 0,
 'col_18': 0,
 'col_19': 0}