In [1]:
import numpy as np
import matplotlib.pyplot as plt

In [2]:
class Node:
    """One node of a binary decision tree.

    Every node stores ``predicted_class``, the class it votes for. Internal
    nodes additionally route a sample to ``left`` when
    ``sample[feature_index] < threshold`` and to ``right`` otherwise. Leaves
    keep ``left``/``right`` as None and the placeholder split (0, 0).
    """

    def __init__(self, predicted_class):
        self.predicted_class = predicted_class
        # Placeholder split; the tree builder overwrites these for internal nodes.
        self.feature_index, self.threshold = 0, 0
        # Children remain None while the node is a leaf.
        self.left, self.right = None, None

class DecisionTree:
    """Binary classification tree grown greedily by minimising weighted Gini impurity.

    Class labels must be non-negative integers (int or integral float arrays);
    they are used directly as list indices when counting samples per class.

    Parameters
    ----------
    X : np.ndarray
        2-D training features; only its column count is recorded here.
    y : np.ndarray
        Class labels; used to size the per-class count vectors.
    max_depth : int
        Maximum depth of the tree (the root is depth 0).
    """

    def __init__(self, X, y, max_depth):
        self.max_depth = max_depth
        # Size the count vectors by the largest label rather than the number of
        # distinct labels, so non-contiguous labels (e.g. {1, 2}) do not index
        # past the end of the per-class lists.
        self.n_classes = int(np.max(y)) + 1 if len(y) else 0
        self.n_features = X.shape[1]
        # Root of the most recently fitted tree (set by a top-level fit call).
        self.tree_ = None

    def find_split(self, X, y):
        """Find the best (feature, threshold) split of ``X`` by weighted Gini impurity.

        Returns ``(None, None)`` when there are fewer than two samples or when
        no split strictly lowers the parent's impurity. The threshold is the
        midpoint between two consecutive distinct values of the chosen feature;
        samples with ``X[:, feature] < threshold`` belong to the left side.
        """
        n_samples = X.shape[0]
        if n_samples <= 1:
            return None, None

        feature_ix, threshold = None, None

        sample_per_class_parent = [np.sum(y == c) for c in range(self.n_classes)]
        # A candidate split is accepted only if it beats the parent's own impurity.
        best_gini = 1.0 - sum((n / n_samples) ** 2 for n in sample_per_class_parent)

        # Iterate over the columns actually present in X, so fitting on data whose
        # width differs from what the constructor saw does not silently drop features.
        for feature in range(X.shape[1]):
            sort_idx = np.argsort(X[:, feature])
            values_sorted = X[sort_idx, feature]
            y_sorted = y[sort_idx]

            sample_per_class_left = [0] * self.n_classes
            sample_per_class_right = list(sample_per_class_parent)

            for i in range(1, n_samples):
                # Move sample i-1 from the right partition into the left one.
                # int() lets float-typed label arrays index the count lists.
                c = int(y_sorted[i - 1])
                sample_per_class_left[c] += 1
                sample_per_class_right[c] -= 1

                # No threshold can separate two equal values; skip before doing
                # the Gini arithmetic (counts above must still be updated).
                if values_sorted[i] == values_sorted[i - 1]:
                    continue

                gini_left = 1.0 - sum((sample_per_class_left[x] / i) ** 2 for x in range(self.n_classes))
                gini_right = 1.0 - sum((sample_per_class_right[x] / (n_samples - i)) ** 2 for x in range(self.n_classes))
                weighted_gini = ((i / n_samples) * gini_left) + ((n_samples - i) / n_samples) * gini_right

                if weighted_gini < best_gini:
                    best_gini = weighted_gini
                    feature_ix = feature
                    threshold = (values_sorted[i] + values_sorted[i - 1]) / 2

        return feature_ix, threshold

    def fit(self, X, y, depth=0):
        """Recursively grow the tree on ``(X, y)`` and return its root ``Node``.

        Each node predicts the majority class of its samples. Splitting stops at
        ``max_depth`` or when ``find_split`` finds no improving split. When called
        at ``depth == 0`` the returned root is also stored in ``self.tree_`` so
        that ``predict`` can be called without an explicit tree.
        """
        num_samples_per_class = [np.sum(y == i) for i in range(self.n_classes)]
        # int() yields a plain Python int rather than a NumPy scalar (whose repr
        # differs between NumPy versions).
        predicted_class = int(np.argmax(num_samples_per_class))

        node = Node(predicted_class=predicted_class)

        if depth < self.max_depth:
            feature, threshold = self.find_split(X, y)

            if feature is not None:
                indices_left = X[:, feature] < threshold
                node.feature_index = feature
                node.threshold = threshold
                node.left = self.fit(X[indices_left], y[indices_left], depth + 1)
                node.right = self.fit(X[~indices_left], y[~indices_left], depth + 1)

        if depth == 0:
            self.tree_ = node
        return node

    def _predict(self, sample, tree):
        """Walk from ``tree`` down to a leaf for one sample and return its class."""
        node = tree
        # fit always attaches both children together, so ``left`` marks internal nodes.
        while node.left is not None:
            if sample[node.feature_index] < node.threshold:
                node = node.left
            else:
                node = node.right
        return node.predicted_class

    def predict(self, X, tree=None):
        """Predict a class for every row of ``X``.

        ``tree`` defaults to the root stored by the last top-level ``fit`` call.

        Raises
        ------
        ValueError
            If no tree is passed and the model has not been fitted.
        """
        if tree is None:
            tree = getattr(self, "tree_", None)
        if tree is None:
            raise ValueError("DecisionTree has not been fitted; call fit() first or pass a tree.")
        return [self._predict(sample, tree) for sample in X]

In [3]:
import sys
from sklearn.datasets import load_iris

# Fit a deep tree on the full Iris dataset and classify one hand-made flower.
dataset = load_iris()
X, y = dataset["data"], dataset["target"]

clf = DecisionTree(X, y, max_depth=10)
tree = clf.fit(X, y)
query = [[0, 0, 5, 1.5]]
print(clf.predict(query, tree))

[2]


In [4]:

# Sanity-check find_split on a 1-D toy problem: the classes separate between
# x=3 and x=10, so the best split is feature 0 at the midpoint 6.5.
X = np.array([[2], [3], [10], [19]])
y = np.array([0, 0, 1, 1])

model = DecisionTree(X, y, max_depth=1)
feature, threshold = model.find_split(X, y)

print("Best feature used for split: ", feature)
print("Best threshold used for split: ", threshold)
assert (feature, threshold) == (0, 6.5), "unexpected split on toy data"

Best feature used for split:  0
Best threshold used for split:  6.5


In [5]:
# Fit on the 2-feature training set and evaluate on the held-out test set.
# Feature 1 is constant in training, so only feature 0 can be split on.
Xtrain = np.array([[2, 5], [3, 5], [10, 5], [19, 5]])
ytrain = np.array([0, 0, 1, 1])
Xtest = np.array([[4, 6], [6, 9], [9, 2], [12, 8]])
ytest = np.array([0, 0, 1, 1])

# Build and evaluate with this cell's own data, not the leftover X, y of the
# previous cell (which has only one feature and is not the test set).
model = DecisionTree(Xtrain, ytrain, max_depth=3)
tree = model.fit(Xtrain, ytrain)
pred = model.predict(Xtest, tree)

print("Tree feature ind: ", tree.feature_index)
print("Tree threshold: ", tree.threshold)
print("Pred: ", np.array(pred))
print("ytest: ", ytest)
assert np.array_equal(pred, ytest), "test-set predictions do not match ytest"

Tree feature ind:  0
Tree threshold:  6.5
Pred:  [0 0 1 1]
ytest:  [0 0 1 1]
