In [2]:
import numpy as np
from random import shuffle
from sklearn.datasets import load_iris
from typing import List
from enum import Enum

In [3]:
class FeatureType(Enum):
    """Kind of a feature column; selects how candidate splits are generated."""
    # Split by comparing against a threshold (feature < threshold).
    NUMERICAL = 0
    # Split by equality with a single value (feature == value).
    CATEGORICAL = 1

def entropy(y: np.ndarray) -> float:
    """Compute the Shannon entropy, in bits, of a label array.

    Args:
        y: Array of class labels. Only the relative frequency of each
            distinct value matters.

    Returns:
        ``-sum(p * log2(p))`` over the frequencies ``p`` of the distinct labels
        in ``y``. A single distinct label (or an empty array) yields zero.
    """
    _, counts = np.unique(y, return_counts=True)
    p = counts / counts.sum()
    # np.unique only reports labels that actually occur, so every p is > 0 and
    # log2 is always defined. No epsilon is added: it would bias every result
    # and make the entropy of a pure set come out slightly negative.
    return -(p * np.log2(p)).sum()

def information_gain(y: np.ndarray, y_splits: List[np.ndarray]):
    """Return how much splitting ``y`` into ``y_splits`` reduces entropy.

    The result is the entropy of ``y`` minus the entropy of every non-empty
    part, each weighted by the fraction of ``y``'s samples that it holds.

    Args:
        y: Labels before the split.
        y_splits: Label arrays produced by the split; empty ones are ignored.

    Returns:
        The information gain of the split.
    """
    n_parent = len(y)
    gain = entropy(y)
    for part in y_splits:
        n_part = len(part)
        if n_part:
            # Weight the child's entropy by its share of the parent's samples.
            gain = gain - (n_part / n_parent) * entropy(part)
    return gain

def split_at_feature(feature: np.ndarray, feature_type: FeatureType, y: np.ndarray, num_thresholds: int = 10):
    """Find the best binary split of ``y`` along a single feature column.

    CATEGORICAL features try every distinct value as a one-vs-rest split
    (``feature == value`` against ``feature != value``). NUMERICAL features try
    ``num_thresholds`` evenly spaced thresholds between the column minimum
    (excluded) and maximum (included), splitting ``feature < threshold``
    against ``feature >= threshold``.

    Args:
        feature: Values of one feature for every sample.
        feature_type: How the feature is to be split.
        y: Labels aligned with ``feature``.
        num_thresholds: Number of candidate thresholds tried for a NUMERICAL
            feature. Ignored for CATEGORICAL features.

    Returns:
        A ``(split_value, information_gain)`` tuple for the best candidate.
        Ties keep the first candidate found. If there are no candidates (an
        empty categorical feature) the result is ``(None, -1)``.

    Raises:
        ValueError: If ``feature_type`` is not a supported ``FeatureType`` or
            ``num_thresholds`` is less than 1 for a NUMERICAL feature.
    """
    if feature_type == FeatureType.CATEGORICAL:
        best_ig, best_value = -1, None
        for value in np.unique(feature):
            left, right = y[feature == value], y[feature != value]
            ig = information_gain(y, [left, right])
            if ig > best_ig:
                best_ig, best_value = ig, value
        return best_value, best_ig

    if feature_type == FeatureType.NUMERICAL:
        if num_thresholds < 1:
            raise ValueError(f"num_thresholds must be at least 1, got {num_thresholds}")
        # Drop the first linspace point (the minimum): nothing lies strictly
        # below it, so it could only produce an empty left side.
        thresholds = np.linspace(feature.min(), feature.max(), num_thresholds + 1)[1:]
        best_ig, best_threshold = -1, None
        for threshold in thresholds:
            left, right = y[feature < threshold], y[feature >= threshold]
            ig = information_gain(y, [left, right])
            if ig > best_ig:
                best_ig, best_threshold = ig, threshold
        return best_threshold, best_ig

    # Fail loudly instead of implicitly returning None, which callers would
    # only notice as an opaque tuple-unpacking error.
    raise ValueError(f"Unsupported feature type: {feature_type!r}")
        

class DecisionTree:
    """Binary decision-tree classifier grown greedily by information gain.

    Every instance represents one node. After ``fit`` a node is either a leaf
    (no children; it predicts the most frequent class among the samples it was
    fitted on) or an internal node with exactly two children. A sample goes to
    the left child when the chosen feature equals the split value (categorical)
    or is strictly below it (numerical), and to the right child otherwise.

    Class labels must be integers in ``[0, num_classes)`` because they are used
    directly as indices into the per-node class distribution.
    """

    def __init__(self, num_classes: int, feature_types: List[FeatureType], max_depth: int=100, gain_threshold: float=1e-3):
        """Create an unfitted node.

        Args:
            num_classes: Number of distinct class labels.
            feature_types: Type of each feature column, in column order.
                Columns beyond ``len(feature_types)`` are never considered
                for splitting.
            max_depth: Maximum number of node levels in this subtree, counting
                this node. A value of 1 or less makes this node a leaf.
            gain_threshold: Minimum information gain a split must reach;
                below it the node stays a leaf.
        """
        self.num_classes = num_classes
        self.feature_types = feature_types
        self.max_depth = max_depth
        self.gain_threshold = gain_threshold

        self._reset()

    def _reset(self) -> None:
        """Clear everything learned by a previous ``fit`` call."""
        # Relative frequency of each class among the samples seen by this node.
        self._distribution: np.ndarray = np.zeros(self.num_classes)
        # Column index and value of the chosen split; None while a leaf.
        self._feature_idx: int = None
        self._split_value: float = None
        # Information gain of the chosen split; 0 while a leaf.
        self._information_gain: float = 0
        # Empty for a leaf, otherwise [left, right].
        self._children: List[DecisionTree] = []

    def fit(self, X: np.ndarray, y: np.ndarray):
        """Grow the subtree rooted at this node from the given samples.

        Args:
            X: 2-D array with one row per sample and one column per feature.
            y: Integer class labels aligned with the rows of ``X``.
        """
        # Start from a clean node so that refitting cannot leave stale
        # children, split fields or class frequencies behind.
        self._reset()

        classes, counts = np.unique(y, return_counts=True)
        self._distribution[classes] = counts / len(y)

        # `<=` rather than `==`: a non-positive max_depth must still stop here
        # instead of recursing with ever more negative depths.
        if self.max_depth <= 1 or len(classes) == 1:
            return

        # Search every feature for the split with the highest information gain.
        best_ig, best_idx, best_value = 0, None, None
        for i, (feature, feature_type) in enumerate(zip(X.T, self.feature_types)):
            split_value, ig = split_at_feature(feature, feature_type, y)
            if ig > best_ig:
                best_ig, best_idx, best_value = ig, i, split_value

        if best_idx is None or best_ig < self.gain_threshold:
            return  # Stop splitting

        # Split the data
        feature = X[:, best_idx]
        if self.feature_types[best_idx] == FeatureType.CATEGORICAL:
            left_mask, right_mask = feature == best_value, feature != best_value
        else:
            left_mask, right_mask = feature < best_value, feature >= best_value

        if not left_mask.any() or not right_mask.any():  # Prevent empty splits
            return

        # Commit the split only now that it is known to be usable, so a leaf
        # never carries the fields of a rejected split.
        self._feature_idx, self._split_value = best_idx, best_value
        self._information_gain = best_ig

        # Create children
        self._children = [
            DecisionTree(self.num_classes, self.feature_types, self.max_depth - 1, self.gain_threshold),
            DecisionTree(self.num_classes, self.feature_types, self.max_depth - 1, self.gain_threshold)
        ]
        self._children[0].fit(X[left_mask], y[left_mask])
        self._children[1].fit(X[right_mask], y[right_mask])

    def predict(self, x: np.ndarray) -> int:
        """Predict the class of a single sample.

        Args:
            x: Feature vector of one sample, indexed by feature column.

        Returns:
            Index of the most frequent class in the leaf that ``x`` reaches.
        """
        if not self._children:
            return self._distribution.argmax()

        if self.feature_types[self._feature_idx] == FeatureType.CATEGORICAL:
            branch = 0 if x[self._feature_idx] == self._split_value else 1
        else:
            branch = 0 if x[self._feature_idx] < self._split_value else 1

        return self._children[branch].predict(x)

    def __repr__(self):
        """Render the subtree as indented text, one node per line."""
        if not self._children:
            majority = self._distribution.argmax()
            return f"Leaf(class={majority}, prob={self._distribution[majority]:.2f})"

        s = f"Node(feature_idx={self._feature_idx}, split_value={self._split_value})"
        for i, child in enumerate(self._children):
            # Indent the child's own lines one extra level below this node.
            s += "\n\t" + f"{'Left' if i == 0 else 'Right'}: " + repr(child).replace("\n", "\n\t")

        return s

In [4]:
# Fixed seed so the shuffle, and with it the train/validation split and every
# result computed from it, is the same on each Restart & Run All.
SEED = 42
VAL_FRACTION = 0.1  # share of the samples held out for validation

data = load_iris(as_frame=True)
X = data["data"].values
y = data["target"].values

# Apply one permutation to features and labels so the rows stay aligned.
idx = np.random.default_rng(SEED).permutation(len(X))
X = X[idx]
y = y[idx]

# The first `val_size` shuffled rows form the validation set, the rest train.
val_size = int(len(X) * VAL_FRACTION)
train_X, val_X = X[val_size:], X[:val_size]
train_y, val_y = y[val_size:], y[:val_size]

In [5]:
# All columns are treated as numerical. The class count is taken from the
# full label array `y`, not only from the training split.
classifier = DecisionTree(
    num_classes=len(np.unique(y)),
    feature_types=[FeatureType.NUMERICAL for _ in range(train_X.shape[1])],
)
classifier.fit(train_X, train_y)

In [6]:
# Validation accuracy: fraction of held-out samples whose predicted class
# equals the true label.
preds = np.array([classifier.predict(sample) for sample in val_X])

(preds == val_y).mean().item()

1.0

In [7]:
# Training accuracy: fraction of training samples whose predicted class
# equals the true label.
preds = np.array([classifier.predict(sample) for sample in train_X])

(preds == train_y).mean().item()

1.0

In [8]:
# Show the fitted tree; the bare expression is rendered via DecisionTree.__repr__.
classifier

Node(feature_idx=2, split_value=2.18)
	Left: Leaf(class=0, prob=1.00)
	Right: Node(feature_idx=3, split_value=1.75)
		Left: Node(feature_idx=2, split_value=4.96)
			Left: Node(feature_idx=3, split_value=1.63)
				Left: Leaf(class=1, prob=1.00)
				Right: Leaf(class=2, prob=1.00)
			Right: Node(feature_idx=3, split_value=1.52)
				Left: Leaf(class=2, prob=1.00)
				Right: Node(feature_idx=0, split_value=6.720000000000001)
					Left: Leaf(class=1, prob=1.00)
					Right: Leaf(class=2, prob=1.00)
		Right: Node(feature_idx=0, split_value=6.06)
			Left: Node(feature_idx=1, split_value=3.0500000000000003)
				Left: Leaf(class=2, prob=1.00)
				Right: Leaf(class=1, prob=1.00)
			Right: Leaf(class=2, prob=1.00)