In [1]:
import numpy as np
import matplotlib.pyplot as plt

In [2]:
# Synthetic toy dataset, reproducible via the legacy global NumPy seed.
# Rows are samples; the two columns stand for (color, diameter).
np.random.seed(seed=42)
data = np.random.random_sample(size=(10, 2))  # uniform [0, 1) features, same stream as np.random.rand(10, 2)
# One random class label per sample: 0 = apple, 1 = orange.
labels = np.random.choice(a=[0, 1], size=10)

In [3]:
class DecisionNode:
    """One node of a binary decision tree.

    A node is either a split node (``feature_index``/``threshold`` with
    ``left``/``right`` children) or a leaf carrying a predicted class in
    ``value``. The class itself does not enforce which kind a node is; the
    tree printer in this notebook treats ``value is not None`` as a leaf.
    """

    def __init__(self, feature_index=None, threshold=None, left=None, right=None, value=None):
        # Leaf payload: the predicted class label (left as None on split nodes).
        self.value = value
        # Split rule: rows with X[:, feature_index] <= threshold go to `left`.
        self.feature_index = feature_index
        self.threshold = threshold
        # Child subtrees (left as None on leaves).
        self.left = left
        self.right = right

def gini_impurity(y):
    """Return the Gini impurity ``1 - sum_k p_k**2`` of a collection of labels.

    Parameters
    ----------
    y : array-like
        Class labels; anything ``np.unique`` can sort and count.

    Returns
    -------
    float
        0.0 for a pure or empty set. The value grows as the labels become more
        evenly mixed (0.5 for two equally frequent classes).
    """
    if len(y) == 0:
        # An empty set contains no mixing. Without this guard the formula
        # sums over zero classes and reports 1.0, i.e. "maximally impure".
        return 0.0
    _, counts = np.unique(y, return_counts=True)
    probabilities = counts / len(y)
    return 1 - np.sum(probabilities ** 2)

def split_data(X, y, feature_index, threshold):
    """Partition rows of ``X`` (and matching entries of ``y``) by one feature.

    Rows whose value in column ``feature_index`` is ``<= threshold`` form the
    left side; every other row (including comparisons that are False, such as
    NaN) forms the right side.

    Returns
    -------
    tuple
        ``(X_left, y_left, X_right, y_right)``.
    """
    column = X[:, feature_index]
    in_left = column <= threshold
    left_part = (X[in_left], y[in_left])
    right_part = (X[~in_left], y[~in_left])
    return (*left_part, *right_part)

def find_best_split(X, y):
    """Find the feature/threshold pair with the lowest size-weighted Gini impurity.

    Candidate thresholds for a feature are its distinct observed values,
    except the largest. Because ``split_data`` sends ``value <= threshold`` to
    the left, the largest value would put every sample on the left and leave
    the right side empty. Such a "split" hands the caller back the same
    samples unchanged, so a recursive builder could loop forever (e.g. when
    every feature is constant but the labels are mixed).

    Parameters
    ----------
    X : numpy.ndarray
        2-D feature matrix, one row per sample.
    y : numpy.ndarray
        Labels aligned with the rows of ``X``.

    Returns
    -------
    tuple
        ``(feature_index, threshold)`` of the best split, or ``(None, None)``
        when no feature has at least two distinct values. Ties keep the first
        candidate examined (lowest feature index, then lowest threshold).
    """
    best_gini = float('inf')
    best_feature = None
    best_threshold = None
    n_samples = len(y)

    for feature_index in range(X.shape[1]):
        # np.unique returns sorted values, so [:-1] drops only the maximum,
        # which is the one threshold that cannot produce two non-empty sides.
        thresholds = np.unique(X[:, feature_index])[:-1]
        for threshold in thresholds:
            _, y_left, _, y_right = split_data(X, y, feature_index, threshold)
            # Impurity of each child, weighted by the fraction of samples it receives.
            gini = (len(y_left) * gini_impurity(y_left) + len(y_right) * gini_impurity(y_right)) / n_samples

            if gini < best_gini:
                best_gini = gini
                best_feature = feature_index
                best_threshold = threshold

    return best_feature, best_threshold

def _majority_leaf(y):
    """Return a leaf DecisionNode predicting the most frequent label in ``y`` (ties -> smallest label)."""
    return DecisionNode(value=np.bincount(y).argmax())


def build_tree(X, y, depth=1, max_depth=None):
    """Recursively grow a binary decision tree on ``(X, y)`` using Gini-based splits.

    Parameters
    ----------
    X : numpy.ndarray
        2-D feature matrix, one row per sample.
    y : numpy.ndarray
        Class labels, one per row of ``X``. The majority vote uses
        ``np.bincount``, so labels must be non-negative integers.
    depth : int, default 1
        Depth of the node being built; the root is depth 1.
    max_depth : int or None, default None
        Nodes at ``depth >= max_depth`` become leaves. ``None`` means no depth
        limit: growth stops only on pure nodes or when no useful split exists.

    Returns
    -------
    DecisionNode
        Root of the (sub)tree. Split nodes send ``X[:, feature_index] <= threshold``
        to ``left``.

    Raises
    ------
    ValueError
        If ``y`` is empty, since no majority class exists.
    """
    if len(y) == 0:
        raise ValueError("build_tree requires at least one sample")

    # `>=` rather than `==`: with `==`, a max_depth below the starting depth
    # (e.g. max_depth=0) never matches and the tree grows without bound.
    if (max_depth is not None and depth >= max_depth) or len(np.unique(y)) == 1:
        return _majority_leaf(y)

    feature_index, threshold = find_best_split(X, y)
    if feature_index is None:
        # Unable to find a split, create a leaf node
        return _majority_leaf(y)

    X_left, y_left, X_right, y_right = split_data(X, y, feature_index, threshold)
    if len(y_left) == 0 or len(y_right) == 0:
        # A split with an empty side makes no progress; recursing on the
        # unchanged samples would never terminate.
        return _majority_leaf(y)

    # Recursively build left and right subtrees
    left_subtree = build_tree(X_left, y_left, depth + 1, max_depth)
    right_subtree = build_tree(X_right, y_right, depth + 1, max_depth)

    return DecisionNode(feature_index=feature_index, threshold=threshold, left=left_subtree, right=right_subtree)

def plot_decision_tree(tree, feature_names=None, class_names=None, depth=0):
    """Print ``tree`` as indented text, one line per node.

    Despite its name, this function draws nothing with matplotlib; it writes
    to stdout with ``print``. Nodes are printed pre-order: the node itself,
    then its whole left subtree, then its whole right subtree.

    Parameters
    ----------
    tree : node object
        A node whose ``value`` is not None is printed as a leaf. Any other node
        is printed as a split and must provide ``feature_index``,
        ``threshold``, ``left`` and ``right``.
    feature_names : sequence, optional
        If truthy, ``feature_names[feature_index]`` replaces the numeric index.
    class_names : sequence, optional
        If truthy, ``class_names[value]`` replaces the raw class label.
    depth : int, default 0
        Indentation level of ``tree``; each level adds two spaces.
    """
    indent = '  ' * depth

    if tree.value is None:
        feature = feature_names[tree.feature_index] if feature_names else tree.feature_index
        print(f"{indent}Split on feature {feature} at threshold {tree.threshold}")
        for child in (tree.left, tree.right):
            plot_decision_tree(child, feature_names, class_names, depth + 1)
        return

    label = class_names[tree.value] if class_names else tree.value
    print(f"{indent}Predicted class: {label}")



In [4]:
# Grow the tree on the whole example dataset: start at the root (depth 1)
# with no depth limit, so splitting continues until nodes are pure or unsplittable.
tree = build_tree(data, labels, depth=1, max_depth=None)

In [5]:
# Print the fitted tree as indented text. No name lists are passed, so splits
# show numeric feature indices and leaves show raw class labels.
plot_decision_tree(tree, feature_names=None, class_names=None)

Split on feature 0 at threshold 0.15601864044243652
  Predicted class: 1
  Split on feature 0 at threshold 0.18182496720710062
    Predicted class: 0
    Split on feature 1 at threshold 0.5247564316322378
      Predicted class: 1
      Split on feature 0 at threshold 0.3745401188473625
        Predicted class: 0
        Split on feature 0 at threshold 0.6011150117432088
          Predicted class: 1
          Predicted class: 0
