# Decision Tree

## Implementation Details
- DecisionNode class represents a node in the decision tree
- DecisionTreeClassifier class implements the decision tree algorithm
- The `fit` method trains the model by recursively growing the tree
- The `predict` method makes predictions by traversing the tree
- The `_best_split` method finds the best split for the data
- The `_grow_tree` method grows the tree by recursively splitting the data
- The `_entropy` method calculates the entropy of the data
- Information gain (reduction in entropy) is used as the splitting criterion


Todo:
- Regression
- Multi-Class classification
- Entropy
- Pruning
- Feature Importance
- Handling Missing Values
- Handling Categorical Variables
- Handling Imbalanced Data
- Handling High-Dimensional Data

In [1]:
%pip install torch numpy

Note: you may need to restart the kernel to use updated packages.


In [2]:
import torch
import numpy as np


In [3]:
class DecisionNode:
    """A single node of the decision tree.

    Internal nodes carry ``feature_index``, ``threshold`` and both children.
    Leaf nodes carry only ``value``; a node is treated as a leaf whenever
    ``value`` is not ``None``.
    """

    def __init__(self, feature_index=None, threshold=None, left=None, right=None, value=None):
        self.feature_index = feature_index  # column of X tested at this node
        self.threshold = threshold          # samples with feature < threshold go left
        self.left = left                    # subtree for feature < threshold
        self.right = right                  # subtree for feature >= threshold
        self.value = value                  # predicted class label (leaves only)


class DecisionTreeClassifier:
    """Decision tree classifier operating on torch tensors.

    Each internal node tests ``x[feature] < threshold``. The split is chosen
    by maximising entropy-based information gain over all features and all
    observed feature values.

    Args:
        max_depth: Maximum depth of the tree; ``None`` grows until every leaf
            is pure or cannot be split any further.
    """

    def __init__(self, max_depth=None):
        self.max_depth = max_depth
        self.root = None
        self.feature_names = None

    def fit(self, X, y, feature_names=None):
        """Build the tree from training data.

        Args:
            X: 2-D tensor of shape ``(n_samples, n_features)``.
            y: 1-D tensor of non-negative integer class labels (required by
                ``torch.bincount``), one per row of ``X``.
            feature_names: Optional list of names, one per feature; defaults
                to ``feature_0``, ``feature_1``, ...

        Raises:
            ValueError: If the training set is empty, if ``X`` and ``y``
                disagree on the number of samples, or if no tree was built.
        """
        if X.shape[0] == 0:
            raise ValueError("Cannot fit a decision tree on an empty training set.")
        if len(y) != X.shape[0]:
            raise ValueError(
                f"X and y have inconsistent lengths: {X.shape[0]} != {len(y)}."
            )

        self.n_classes = len(torch.unique(y))
        self.n_features = X.shape[1]
        self.feature_names = feature_names if feature_names is not None else [f"feature_{i}" for i in range(self.n_features)]
        self.root = self._grow_tree(X, y)
        if self.root is None:
            raise ValueError("Failed to build the tree. Check your data and parameters.")

    def _grow_tree(self, X, y, depth=0):
        """Recursively build the subtree for the samples ``(X, y)``.

        Returns a leaf when the depth limit is reached, the labels are pure,
        fewer than two samples remain, or no valid split exists.
        """
        n_samples = X.shape[0]
        n_labels = len(torch.unique(y))

        depth_exhausted = self.max_depth is not None and depth >= self.max_depth
        if depth_exhausted or n_labels == 1 or n_samples < 2:
            return DecisionNode(value=self._leaf_value(y))

        best_feature, best_threshold = self._best_split(X, y)

        # No feature has two distinct values (e.g. duplicate rows with
        # different labels), so the samples cannot be separated.
        if best_feature is None:
            return DecisionNode(value=self._leaf_value(y))

        left_idxs = X[:, best_feature] < best_threshold
        right_idxs = ~left_idxs

        left = self._grow_tree(X[left_idxs], y[left_idxs], depth + 1)
        right = self._grow_tree(X[right_idxs], y[right_idxs], depth + 1)

        return DecisionNode(best_feature, best_threshold, left, right)

    def _best_split(self, X, y):
        """Return ``(feature_index, threshold)`` of the highest-gain split.

        Only splits that leave both children non-empty are considered.
        Returns ``(None, None)`` when no such split exists.
        """
        best_gain = -float('inf')
        best_feature, best_threshold = None, None

        for feature in range(X.shape[1]):
            column = X[:, feature]
            # torch.unique returns values in ascending order. The smallest
            # value would send no sample left under `<`, so it is skipped;
            # every remaining candidate yields two non-empty children.
            for threshold in torch.unique(column)[1:]:
                gain = float(self._information_gain(column, y, threshold))
                if gain > best_gain:
                    best_gain = gain
                    best_feature = feature
                    best_threshold = threshold.item()

        return best_feature, best_threshold

    def _information_gain(self, X_column, y, threshold):
        """Entropy reduction obtained by splitting on ``X_column < threshold``.

        Returns 0 when the split leaves one side empty.
        """
        parent_entropy = self._entropy(y)

        left_idxs = X_column < threshold
        right_idxs = ~left_idxs

        n = len(y)
        n_l, n_r = left_idxs.sum().item(), right_idxs.sum().item()

        if n_l == 0 or n_r == 0:
            return 0

        e_l, e_r = self._entropy(y[left_idxs]), self._entropy(y[right_idxs])
        # Children are weighted by their share of the parent's samples.
        child_entropy = (n_l / n) * e_l + (n_r / n) * e_r

        return parent_entropy - child_entropy

    def _entropy(self, y):
        """Shannon entropy (base 2) of the label distribution in ``y``."""
        _, counts = torch.unique(y, return_counts=True)
        probabilities = counts.float() / len(y)
        # The small epsilon keeps log2 finite.
        return -torch.sum(probabilities * torch.log2(probabilities + 1e-9))

    def _leaf_value(self, y):
        """Most frequent label in ``y``; ties resolve to the smallest label."""
        return torch.argmax(torch.bincount(y)).item()

    def predict(self, X):
        """Return a tensor with the predicted class label for each row of ``X``."""
        return torch.tensor([self._traverse_tree(x, self.root) for x in X])

    def _traverse_tree(self, x, node):
        """Walk from ``node`` down to a leaf for sample ``x``; return its label."""
        while node.value is None:
            if x[node.feature_index] < node.threshold:
                node = node.left
            else:
                node = node.right
        return node.value


In [4]:
# Toy training set: four 2-D points, the first two labelled 0 and the last two labelled 1.
X = torch.tensor(
    [[1, 2], [3, 4], [5, 6], [7, 8]],
    dtype=torch.float32,
)
y = torch.tensor([0, 0, 1, 1])

# Train a tree limited to depth 2 on the toy data.
model = DecisionTreeClassifier(max_depth=2)
model.fit(X, y, feature_names=['Feature A', 'Feature B'])

# Classify two unseen points, one near each labelled pair.
X_test = torch.tensor(
    [[2, 3], [6, 7]],
    dtype=torch.float32,
)
predictions = model.predict(X_test)
print("\nPredictions:", predictions)



Predictions: tensor([0, 1])
