In [10]:
# %pip (not bare `pip`) installs into the environment of the running kernel; pinned to the version recorded below.
# The Python package only wraps Graphviz: the Graphviz `dot` executable must also be installed for render() to work.
%pip install graphviz==0.21

Defaulting to user installation because normal site-packages is not writeable
Collecting graphviz
  Downloading graphviz-0.21-py3-none-any.whl.metadata (12 kB)
Downloading graphviz-0.21-py3-none-any.whl (47 kB)
Installing collected packages: graphviz
Successfully installed graphviz-0.21
Note: you may need to restart the kernel to use updated packages.


In [107]:
import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from collections import Counter

In [108]:
def tf_entropy(labels):
    """Shannon entropy (in bits) of a collection of class labels, computed with TensorFlow.

    Args:
        labels: array-like of labels (list, NumPy array, pandas ``.values``, ...).

    Returns:
        float: entropy in bits; 0.0 for an empty input or a single-class input.
    """
    labels = np.asarray(labels)
    if labels.size == 0:
        return 0.0

    # Only the per-label counts matter; the distinct values themselves are not needed.
    _, counts = np.unique(labels, return_counts=True)
    probs = counts.astype(np.float64) / counts.sum()

    # Every count is >= 1, so every probability is > 0 and log(p) is finite.
    p = tf.constant(probs, dtype=tf.float64)
    ent = -tf.reduce_sum(p * tf.math.log(p)) / tf.math.log(tf.constant(2.0, dtype=tf.float64))

    # .numpy() extracts the scalar from the eager tensor and float() makes it a plain Python float.
    # max() turns the -0.0 produced for a single-class input into 0.0 (entropy is never negative).
    return max(0.0, float(ent.numpy()))

In [109]:
def information_gain(parent_labels, subsets_labels):
    """Entropy reduction (in bits) from splitting ``parent_labels`` into ``subsets_labels``.

    IG = H(parent) - sum_i (|S_i| / N) * H(S_i), where N is the combined size of all subsets.

    Args:
        parent_labels: labels of the node before the split.
        subsets_labels: sequence of label collections, one per child of the split.

    Returns:
        float: the information gain.
    """
    parent_entropy = tf_entropy(parent_labels)
    n_total = sum(len(subset) for subset in subsets_labels)

    # Accumulate explicitly (not via sum()) so float rounding matches a plain left-to-right sum.
    weighted_child_entropy = 0.0
    for subset in subsets_labels:
        child_entropy = tf_entropy(subset)
        weighted_child_entropy += (len(subset) / n_total) * child_entropy

    return parent_entropy - weighted_child_entropy

In [110]:
class ID3DecisionTree:
    """ID3 decision tree over categorical features, stored as nested dicts.

    Node layout:
      * leaf:     {'type': 'leaf', 'label': <class label>}
      * internal: {'type': 'node', 'feature': <column name>,
                   'children': {<feature value>: <child node>, ...}}

    Splits are chosen greedily by ``information_gain``. Every distinct value of
    the chosen feature gets its own branch, and a feature is used at most once
    on any root-to-leaf path.
    """

    def __init__(self, max_depth=None):
        """
        Args:
            max_depth: depth limit (root is depth 0); None grows the tree until
                nodes are pure, no feature has positive gain, or features run out.
        """
        self.tree = None
        self.max_depth = max_depth

    def fit(self, X: pd.DataFrame, y: pd.Series):
        """Build the tree from categorical features ``X`` and labels ``y``.

        A ``y`` Series is aligned to ``X`` by index, as in a normal pandas column
        assignment.

        Returns:
            self, to allow chaining.

        Raises:
            ValueError: if ``X`` is empty, if ``X`` has a column named ``'_label'``
                (reserved for the target), if ``y`` has a different length, or if
                ``y`` has missing labels or an index that does not line up with ``X``.
        """
        if len(X) == 0:
            raise ValueError("Cannot fit ID3DecisionTree on an empty dataset.")
        if '_label' in X.columns:
            raise ValueError("X must not contain a column named '_label' (reserved for the target).")

        data = X.copy()
        data['_label'] = y
        # Index alignment turns unmatched rows into NaN; fail loudly instead of learning NaN labels.
        if data['_label'].isna().any():
            raise ValueError("y contains missing labels or its index does not match X.index.")

        features = list(X.columns)
        self.tree = self._build_tree(data, features, depth=0)
        return self

    @staticmethod
    def _majority_label(labels):
        """Most frequent label; ties go to the label encountered first."""
        return Counter(labels).most_common(1)[0][0]

    def _build_tree(self, data, features, depth):
        """Recursively grow the subtree for the rows in ``data``.

        Args:
            data: DataFrame with the not-yet-used feature columns plus '_label'.
            features: feature columns that may still be split on.
            depth: depth of the node being built (root is 0).

        Returns:
            dict: a leaf or internal node (see the class docstring).
        """
        labels = data['_label'].values

        # Pure node: every sample carries the same label.
        if len(np.unique(labels)) == 1:
            return {'type': 'leaf', 'label': labels[0]}

        # Stopping criteria: no features left, or the depth limit is reached.
        if not features or (self.max_depth is not None and depth >= self.max_depth):
            return {'type': 'leaf', 'label': self._majority_label(labels)}

        # Choose the feature with the highest information gain (the first one wins ties).
        best_feat = None
        best_ig = -np.inf
        for feat in features:
            subsets_labels = [
                data.loc[data[feat] == v, '_label'].values for v in data[feat].unique()
            ]
            ig = information_gain(labels, subsets_labels)
            if ig > best_ig:
                best_ig = ig
                best_feat = feat

        # No split reduces entropy meaningfully: stop with the majority label.
        if best_ig <= 1e-9:
            return {'type': 'leaf', 'label': self._majority_label(labels)}

        remaining_features = [f for f in features if f != best_feat]
        node = {'type': 'node', 'feature': best_feat, 'children': {}}

        # Partition the rows once, for the winning feature only.
        for val in data[best_feat].unique():
            subset = data[data[best_feat] == val].drop(columns=[best_feat])
            if subset.empty:
                # A value can select no rows if it never compares equal to itself (e.g. NaN);
                # fall back to this node's majority label.
                node['children'][val] = {'type': 'leaf', 'label': self._majority_label(labels)}
            else:
                node['children'][val] = self._build_tree(subset, remaining_features, depth + 1)

        return node

    def _fitted_tree(self):
        """Return the root node, raising a clear error if ``fit`` has not been called."""
        if self.tree is None:
            raise RuntimeError("ID3DecisionTree is not fitted yet; call fit() first.")
        return self.tree

    @staticmethod
    def _collect_leaf_labels(node):
        """Labels of all leaves under ``node``, in depth-first, insertion order."""
        if node['type'] == 'leaf':
            return [node['label']]
        leaf_labels = []
        for child in node['children'].values():
            leaf_labels.extend(ID3DecisionTree._collect_leaf_labels(child))
        return leaf_labels

    def predict_single(self, x: pd.Series):
        """Predict the label of one sample.

        Args:
            x: one row of features, indexed by column name.

        Returns:
            The predicted label. If ``x`` lacks the feature tested at some node,
            or has a value with no matching branch, the walk stops there. The
            result is then the majority label among the leaves under that node
            (leaves are counted, not training samples).

        Raises:
            RuntimeError: if called before ``fit``.
        """
        node = self._fitted_tree()

        while node['type'] != 'leaf':
            val = x.get(node['feature'], None)
            if val not in node['children']:
                leaf_labels = self._collect_leaf_labels(node)
                return self._majority_label(leaf_labels) if leaf_labels else None
            node = node['children'][val]

        return node['label']

    def predict(self, X: pd.DataFrame):
        """Predict a label for every row of ``X``; the returned Series shares ``X.index``."""
        predictions = [self.predict_single(X.iloc[i]) for i in range(len(X))]
        return pd.Series(predictions, index=X.index)

    def print_tree(self, node=None, indent=""):
        """Print the tree (or the subtree rooted at ``node``) as indented text.

        Raises:
            RuntimeError: if no ``node`` is given and the model is not fitted.
        """
        if node is None:
            node = self._fitted_tree()

        if node['type'] == 'leaf':
            print(indent + "Leaf:", node['label'])
            return

        print(indent + f"[Feature: {node['feature']}]")
        for val, child in node['children'].items():
            print(indent + f" -> {node['feature']} = {val}:")
            self.print_tree(child, indent + "    ")


In [111]:
def load_and_discretize_iris():
    """Load the iris dataset and binarize every feature at its median.

    Each measurement becomes the string ``"<=_{median}"`` or ``">_{median}"``
    (median printed with 3 decimals). This turns the continuous columns into the
    two-valued categorical features that ID3 splits on.

    Note: the medians are computed over the full dataset, i.e. before any
    train/test split done by the caller, so test rows influence the thresholds.

    Returns:
        (X_disc, y): DataFrame of string-valued features (original column names)
        and a Series of class ids.
    """
    iris = load_iris()
    X = pd.DataFrame(iris.data, columns=iris.feature_names)
    y = pd.Series(iris.target)

    X_disc = X.copy()
    for col in X.columns:
        threshold = X[col].median()
        below = f"<=_{threshold:.3f}"
        above = f">_{threshold:.3f}"
        # .tolist() yields plain Python strings, exactly like building the list element by element.
        X_disc[col] = np.where(X[col] <= threshold, below, above).tolist()

    return X_disc, y

In [112]:
from graphviz import Digraph

def visualize_tree(tree, filename="id3_tree", view=True):
    """Render the nested-dict decision tree with Graphviz.

    Args:
        tree: root node dict as built by ``ID3DecisionTree`` (its ``.tree`` attribute).
        filename: output path stem passed to ``Digraph.render``.
        view: open the rendered file in the system viewer. Defaults to True;
            pass False on headless machines.

    Returns:
        graphviz.Digraph: the graph object, so a notebook cell can also display it.
    """
    dot = Digraph(comment="ID3 Decision Tree", format="png")

    # Sequential node names keep the generated DOT source identical across runs
    # (id() values differ from one kernel session to the next).
    node_count = 0

    def add_nodes_edges(node, parent=None, edge_label=""):
        nonlocal node_count
        node_id = f"n{node_count}"
        node_count += 1

        if node['type'] == 'leaf':
            label = f"Leaf: {node['label']}"
            dot.node(node_id, label, shape="box", style="filled", color="lightblue")
        else:
            label = f"{node['feature']}"
            dot.node(node_id, label, shape="ellipse", color="black")

        if parent is not None:
            dot.edge(parent, node_id, label=edge_label)

        if node['type'] == 'node':
            for val, child in node['children'].items():
                add_nodes_edges(child, node_id, edge_label=str(val))

    add_nodes_edges(tree)

    output_path = dot.render(filename, view=view)
    print(f"Tree visualization saved as: {output_path}")
    return dot


In [113]:
if __name__ == "__main__":
    # Median-binarized iris features, with a stratified 25% hold-out split.
    X_disc, y = load_and_discretize_iris()
    X_train, X_test, y_train, y_test = train_test_split(
        X_disc,
        y,
        test_size=0.25,
        stratify=y,
        random_state=42,
    )

    # Each feature is used at most once per path, so with iris's 4 features a depth limit of 10 never binds.
    tree = ID3DecisionTree(max_depth=10).fit(X_train, y_train)

    print("Learned decision tree:")
    tree.print_tree()
    visualize_tree(tree.tree, filename="iris_id3_tree")

    # Accuracy = fraction of test rows whose predicted label equals the true label.
    preds = tree.predict(X_test)
    acc = (preds.to_numpy() == y_test.to_numpy()).mean()
    print(f"\nTest accuracy: {acc:.4f} ({len(y_test)} samples)")


Learned decision tree:
[Feature: petal width (cm)]
 -> petal width (cm) = >_1.300:
    [Feature: petal length (cm)]
     -> petal length (cm) = >_4.350:
        [Feature: sepal length (cm)]
         -> sepal length (cm) = >_5.800:
            [Feature: sepal width (cm)]
             -> sepal width (cm) = <=_3.000:
                Leaf: 2
             -> sepal width (cm) = >_3.000:
                Leaf: 2
         -> sepal length (cm) = <=_5.800:
            Leaf: 2
     -> petal length (cm) = <=_4.350:
        Leaf: 1
 -> petal width (cm) = <=_1.300:
    [Feature: sepal width (cm)]
     -> sepal width (cm) = <=_3.000:
        [Feature: sepal length (cm)]
         -> sepal length (cm) = <=_5.800:
            [Feature: petal length (cm)]
             -> petal length (cm) = <=_4.350:
                Leaf: 1
             -> petal length (cm) = >_4.350:
                Leaf: 1
         -> sepal length (cm) = >_5.800:
            Leaf: 1
     -> sepal width (cm) = >_3.000:
        Leaf: 0
Tr