In [1]:
import numpy as np
import math
import sklearn.datasets

In [7]:


class DecisionTreeBase:
    def __init__(self, depth=0, max_depth=5, min_samples=3):
        self.l_child = None
        self.r_child = None
        self.max_depth = max_depth
        self.min_samples = min_samples
        self.depth = depth
        self.split_col = None
        self.split_val = None
        self.is_leaf = False
        self.y_vals = []
        
    def find_best_split(self, data, target):
        min_loss = math.inf
        best_col = None
        best_split = None

        for col_idx in range(len(data[0])):
            sorted_x = data[data[:, col_idx].argsort()]
            sorted_y = target[data[:, col_idx].argsort()]

            for row_idx in range(len(data) - 1):
                lesser_half = sorted_y[sorted_x[:, col_idx] <= sorted_x[row_idx][col_idx]]
                greater_half = sorted_y[sorted_x[:, col_idx] > sorted_x[row_idx][col_idx]]
                avg_loss = (self.loss(lesser_half) + self.loss(greater_half)) / 2
                if avg_loss < min_loss:
                    best_col = col_idx
                    best_split = sorted_x[row_idx][col_idx]
                    min_loss = avg_loss 
                if min_loss == 0:
                    break
            if min_loss == 0:
                break
                
        return best_col, best_split 
    
    def fit(self, X, y):
        split_col, split_val = self.find_best_split(X, y)
        self.split_col = split_col
        self.split_val = split_val

        lesser_criteria = X[:, split_col] <= split_val
        greater_criteria = X[:, split_col] > split_val
        if len(y[lesser_criteria]) < self.min_samples or len(y[greater_criteria]) < self.min_samples:
            self.is_leaf = True
            self.y_vals = y
        else:
            self.l_child = type(self)(self.depth + 1)
            self.l_child.fit(X[lesser_criteria], y[lesser_criteria])
            self.r_child = type(self)(self.depth + 1)
            self.r_child.fit(X[greater_criteria], y[greater_criteria])
            
    def predict(self, X):
        if not self.l_child:
            return self.leaf_predict(self.y_vals)
        if X[self.split_col] <= self.split_val:
            return self.l_child.predict(X)
        else:
            return self.r_child.predict(X)

    # to be implemented by child classes
    def loss(self, y):
        return None

    def leaf_predict(self, y):
        return None
    

class DecisionTreeClassifier(DecisionTreeBase):
    
    def loss(self, y):
        # get count of each unique value in the dataset
        _, counts = np.unique(y, return_counts=True)

        # gini impurity formula
        impurity = 1 - np.sum((counts / len(y))**2)
        return impurity

    def leaf_predict(self, y):
        return np.bincount(y).argmax()


class DecisionTreeRegressor(DecisionTreeBase):

    def loss(self, y):
        return np.mean((y - np.mean(y))**2)

    def leaf_predict(self, y):
        return np.mean(y)



# Testing Classification Tree

In [8]:
# note: all iris data is numeric

iris_df = sklearn.datasets.load_iris()
data = iris_df["data"]
target = iris_df["target"]

tree = DecisionTreeClassifier()
tree.fit(data, target)
a = 10
print(tree.predict(iris_df["data"][a]))
print(iris_df["target"][a])

0
0


In [9]:
correct = 0
for i in range(len(iris_df["data"])):
    pred = tree.predict(iris_df["data"][i])
    true = iris_df["target"][i]
    if pred == true:
        correct += 1        
print(f"{correct / len(iris_df['data']) * 100}% accurate")

96.0% accurate


# Testing Regression Tree

In [10]:
diabetes_df = sklearn.datasets.load_diabetes()
data = diabetes_df["data"]
target = diabetes_df["target"]
tree = DecisionTreeRegressor()
tree.fit(data, target)

a = [tree.predict(diabetes_df["data"][i]) for i in range(len(target))]
a



  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


[152.13348416289594,
 152.13348416289594,
 152.13348416289594,
 152.13348416289594,
 152.13348416289594,
 152.13348416289594,
 152.13348416289594,
 152.13348416289594,
 152.13348416289594,
 152.13348416289594,
 152.13348416289594,
 152.13348416289594,
 152.13348416289594,
 152.13348416289594,
 152.13348416289594,
 152.13348416289594,
 152.13348416289594,
 152.13348416289594,
 152.13348416289594,
 152.13348416289594,
 152.13348416289594,
 152.13348416289594,
 152.13348416289594,
 152.13348416289594,
 152.13348416289594,
 152.13348416289594,
 152.13348416289594,
 152.13348416289594,
 152.13348416289594,
 152.13348416289594,
 152.13348416289594,
 152.13348416289594,
 152.13348416289594,
 152.13348416289594,
 152.13348416289594,
 152.13348416289594,
 152.13348416289594,
 152.13348416289594,
 152.13348416289594,
 152.13348416289594,
 152.13348416289594,
 152.13348416289594,
 152.13348416289594,
 152.13348416289594,
 152.13348416289594,
 152.13348416289594,
 152.13348416289594,
 152.13348416