In [1]:
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

In [2]:
iris_data = pd.read_csv('https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data', header=None)
iris_data.columns = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width', 'class']

X = iris_data.iloc[:, :-1].values
y = iris_data.iloc[:, -1].values

In [3]:
def entropy(labels):
    _, counts = np.unique(labels, return_counts=True)
    probabilities = counts / counts.sum()
    return -probabilities.dot(np.log2(probabilities))

In [4]:
def split(X, y):
    best_feature, best_threshold, best_gain = None, None, 0
    for feature in range(X.shape[1]):
        thresholds = np.unique(X[:, feature])
        for threshold in thresholds:
            y_left = y[X[:, feature] < threshold]
            y_right = y[X[:, feature] >= threshold]
            if len(y_left) == 0 or len(y_right) == 0:
                continue
            gain = entropy(y) - (len(y_left) / len(y)) * entropy(y_left) - (len(y_right) / len(y)) * entropy(y_right)
            if gain > best_gain:
                best_feature = feature
                best_threshold = threshold
                best_gain = gain
    return best_feature, best_threshold

In [5]:
def build_tree(X, y, max_depth=10, min_samples_split=2, depth=0):
    if len(y) < min_samples_split or depth >= max_depth:
        return pd.Series(y).value_counts().index[0]
    feature, threshold = split(X, y)
    if feature is None:
        return pd.Series(y).value_counts().index[0]
    indices_left = X[:, feature] < threshold
    X_left, y_left = X[indices_left], y[indices_left]
    X_right, y_right = X[~indices_left], y[~indices_left]
    left = build_tree(X_left, y_left, max_depth, min_samples_split, depth+1)
    right = build_tree(X_right, y_right, max_depth, min_samples_split, depth+1)
    return {'feature': feature, 'threshold': threshold, 'left': left, 'right': right}

In [6]:
def predict(x, tree):
    if isinstance(tree, dict):
        feature, threshold, left, right = tree.values()
        if x[feature] < threshold:
            return predict(x, left)
        else:
            return predict(x, right)
    else:
        return tree

In [23]:
def print_tree(node, depth=0, is_left=None):
    if isinstance(node, dict):
        feature = node['feature']
        threshold = node['threshold']
        left = node['left']
        right = node['right']
        if is_left:
            arrow = '↙'
        else:
            arrow = '↘'
        print("  " * depth + arrow + f" {feature} < {threshold}")
        print_tree(left, depth + 1, is_left=True)
        print_tree(right, depth + 1, is_left=False)
    else:
        print("  " * depth + "➤ " + str(node))


In [28]:
def printtree(node, depth=0, edge=''):
    if isinstance(node, dict):
        feature = node['feature']
        threshold = node['threshold']
        left = node['left']
        right = node['right']
        print("{}{}|-- X[{}] < {} Depth ={} ".format(depth * "  ", edge, feature, threshold,depth))
        printtree(left, depth, edge='L')
        printtree(right, depth + 1, edge='R')
    else:
        print("{}{}|-- {} Depth ={}".format(depth * "  ", edge, node,depth))

In [8]:
def accuracy(y_true, y_pred):
    return np.mean(y_true == y_pred)

In [9]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

In [10]:
tree = build_tree(X_train, y_train)
y_pred = np.array([predict(x, tree) for x in X_test])
acc = accuracy(y_test, y_pred)

In [24]:
print_tree(tree)

↘ 2 < 3.0
  ➤ Iris-setosa
  ↘ 2 < 4.9
    ↙ 3 < 1.8
      ➤ Iris-versicolor
      ↘ 0 < 6.0
        ➤ Iris-versicolor
        ➤ Iris-virginica
    ↘ 3 < 1.8
      ↙ 0 < 6.3
        ➤ Iris-virginica
        ↘ 0 < 7.2
          ➤ Iris-versicolor
          ➤ Iris-virginica
      ➤ Iris-virginica


In [29]:
printtree(tree)

|-- X[2] < 3.0 Depth =0 
L|-- Iris-setosa Depth =0
  R|-- X[2] < 4.9 Depth =1 
  L|-- X[3] < 1.8 Depth =1 
  L|-- Iris-versicolor Depth =1
    R|-- X[0] < 6.0 Depth =2 
    L|-- Iris-versicolor Depth =2
      R|-- Iris-virginica Depth =3
    R|-- X[3] < 1.8 Depth =2 
    L|-- X[0] < 6.3 Depth =2 
    L|-- Iris-virginica Depth =2
      R|-- X[0] < 7.2 Depth =3 
      L|-- Iris-versicolor Depth =3
        R|-- Iris-virginica Depth =4
      R|-- Iris-virginica Depth =3


In [12]:
print("Accuracy:", acc)

Accuracy: 0.9
