<a href="https://colab.research.google.com/github/Riddhi-14/Assessments/blob/main/CodingAssignment3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [5]:
import numpy as np
from tabulate import tabulate

class DecisionTree:
    """Binary decision-tree classifier that splits on information gain.

    Every internal node tests ``x[feature_idx] <= split_value``; samples for
    which the test holds go to the left subtree, the others to the right.

    The fitted tree is stored in ``self.tree`` as nested 4-tuples:

    * internal node: ``(feature_idx, split_value, left_subtree, right_subtree)``
    * leaf node:     ``(None, None, None, predicted_label)``

    Class labels may be any values ``np.unique`` can sort (non-negative or
    negative integers, floats, strings, ...).
    """

    # Gains at or below this threshold are treated as "no improvement".
    # An exact ``== 0`` test is fragile because the gain is a difference of
    # floating-point entropies and can come out as tiny rounding noise.
    _MIN_GAIN = 1e-12

    def __init__(self):
        # Root node of the fitted tree; None until fit() has been called.
        self.tree = None

    def entropy(self, y):
        """Return the Shannon entropy (in bits) of the labels in ``y``.

        An empty ``y`` has entropy 0.0.
        """
        if len(y) == 0:
            return 0.0
        _, counts = np.unique(y, return_counts=True)
        probabilities = counts / len(y)
        # Every probability is > 0 (it comes from an observed count), so
        # log2 is always finite here.
        return -np.sum(probabilities * np.log2(probabilities))

    def information_gain(self, X, y, feature_idx, split_value):
        """Return the entropy reduction obtained by one candidate split.

        The split sends rows with ``X[:, feature_idx] <= split_value`` to the
        left partition and all remaining rows to the right partition.
        """
        left_mask = X[:, feature_idx] <= split_value
        right_mask = ~left_mask

        n = len(y)
        left_weight = np.sum(left_mask) / n
        right_weight = np.sum(right_mask) / n

        # Size-weighted average entropy of the two partitions.
        children_entropy = (
            left_weight * self.entropy(y[left_mask])
            + right_weight * self.entropy(y[right_mask])
        )
        return self.entropy(y) - children_entropy

    def find_best_split(self, X, y, feature_idx):
        """Return every candidate split of one feature with its gain.

        Despite the name, this does not pick a winner: it returns a list of
        ``[feature_idx, split_value, information_gain]`` rows, one per unique
        value of the feature (in ascending order). ``build_tree`` selects the
        best row across all features.
        """
        return [
            [feature_idx, value, self.information_gain(X, y, feature_idx, value)]
            for value in np.unique(X[:, feature_idx])
        ]

    @staticmethod
    def _majority_label(y):
        """Return the most frequent label in ``y`` (smallest label on ties)."""
        # np.unique works for any sortable label type, whereas np.bincount
        # only accepts non-negative integers.
        labels, counts = np.unique(y, return_counts=True)
        return labels[np.argmax(counts)]

    def _leaf(self, y):
        """Build a leaf node predicting the majority label of ``y``."""
        return (None, None, None, self._majority_label(y))

    def build_tree(self, X, y, depth=0, max_depth=None):
        """Recursively build the tree for samples ``X`` with labels ``y``.

        Recursion stops with a leaf when the node is pure, when ``depth``
        equals ``max_depth`` (if given), or when no split yields a gain above
        ``_MIN_GAIN``.

        Returns a node tuple as described in the class docstring.
        """
        n_features = X.shape[1]
        n_classes = len(np.unique(y))

        if n_classes == 1 or (max_depth is not None and depth == max_depth):
            return self._leaf(y)

        best_feature_idx = None
        best_split_value = None
        best_information_gain = -1

        for feature_idx in range(n_features):
            for _, value, gain in self.find_best_split(X, y, feature_idx):
                # Strict '>' keeps the first split found among equal gains.
                if gain > best_information_gain:
                    best_feature_idx = feature_idx
                    best_split_value = value
                    best_information_gain = gain

        # No candidate at all (e.g. zero features) or nothing to gain.
        if best_feature_idx is None or best_information_gain <= self._MIN_GAIN:
            return self._leaf(y)

        left_mask = X[:, best_feature_idx] <= best_split_value
        right_mask = ~left_mask

        left_subtree = self.build_tree(X[left_mask], y[left_mask], depth + 1, max_depth)
        right_subtree = self.build_tree(X[right_mask], y[right_mask], depth + 1, max_depth)

        return (best_feature_idx, best_split_value, left_subtree, right_subtree)

    def fit(self, X, y, max_depth=None):
        """Fit the tree on ``X`` (2-D, samples x features) and labels ``y``.

        ``X`` and ``y`` may be NumPy arrays or nested sequences. ``max_depth``
        limits the depth of the tree; ``None`` means no limit.

        Raises:
            ValueError: if ``X`` is not 2-D, ``y`` is not 1-D, their lengths
                differ, or there are no samples.
        """
        X = np.asarray(X)
        y = np.asarray(y)
        if X.ndim != 2 or y.ndim != 1:
            raise ValueError("X must be 2-dimensional and y 1-dimensional")
        if X.shape[0] != y.shape[0]:
            raise ValueError("X and y must contain the same number of samples")
        if y.shape[0] == 0:
            raise ValueError("cannot fit a decision tree on an empty dataset")

        self.tree = self.build_tree(X, y, max_depth=max_depth)

    def predict(self, X):
        """Return an array with the predicted label for each row of ``X``.

        Raises:
            RuntimeError: if called before ``fit``.
        """
        if self.tree is None:
            raise RuntimeError("DecisionTree.predict() called before fit()")
        return np.array([self._predict(x, self.tree) for x in X])

    def _predict(self, x, tree):
        """Walk ``tree`` for the single sample ``x`` and return its label."""
        if tree[0] is None:  # Base case: leaf node
            return tree[3]  # Return the predicted class label
        feature_idx, split_value, left_subtree, right_subtree = tree
        if x[feature_idx] <= split_value:
            return self._predict(x, left_subtree)
        return self._predict(x, right_subtree)

# Toy dataset: five samples with two numeric features each, plus one class
# label per sample.
X_train = np.array(
    [
        [5.1, 3.5],
        [4.9, 3.0],
        [4.7, 3.2],
        [4.6, 3.1],
        [5.0, 3.6],
    ]
)
y_train = np.array([0, 0, 1, 1, 2])

# An unfitted tree is sufficient here: only the split-scoring helper is used.
tree = DecisionTree()

# Gather the [feature index, split value, information gain] rows of every
# feature, in feature order, and print them as a single grid table.
splits = []
n_features = X_train.shape[1]
for feature_idx in range(n_features):
    feature_splits = tree.find_best_split(X_train, y_train, feature_idx)
    splits += feature_splits

print(
    tabulate(
        splits,
        headers=['Feature Index', 'Split Value', 'Information Gain'],
        tablefmt='grid',
    )
)


+-----------------+---------------+--------------------+
|   Feature Index |   Split Value |   Information Gain |
+=================+===============+====================+
|               0 |           4.6 |           0.321928 |
+-----------------+---------------+--------------------+
|               0 |           4.7 |           0.970951 |
+-----------------+---------------+--------------------+
|               0 |           4.9 |           0.570951 |
+-----------------+---------------+--------------------+
|               0 |           5   |           0.321928 |
+-----------------+---------------+--------------------+
|               0 |           5.1 |           0        |
+-----------------+---------------+--------------------+
|               1 |           3   |           0.321928 |
+-----------------+---------------+--------------------+
|               1 |           3.1 |           0.170951 |
+-----------------+---------------+--------------------+
|               1 |           3.2 |           0.570951 |
+-----------------+------------