In [1]:
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt


def dataframe_entropy(d: pd.DataFrame, column: str):
    """Return the Shannon entropy, in bits, of the values in ``d[column]``.

    Parameters
    ----------
    d : pd.DataFrame
        Rows to measure (e.g. the rows reaching one decision-tree node).
    column : str
        Column whose value distribution is measured (typically the label).

    Returns
    -------
    float
        0 for an empty frame or a single-valued column, 1 for a 50/50 binary
        column, ``log2(k)`` for ``k`` equally frequent values.
    """
    counts = d[column].value_counts()
    if counts.empty:
        # Nothing to be uncertain about; also avoids dividing by len(d) == 0.
        return 0
    if len(counts) <= 2:
        # Binary case: H depends only on the share of one class. ``.iloc`` is
        # required because ``counts`` is indexed by the label values themselves,
        # so ``counts[0]`` would look up a label named 0 (KeyError for integer
        # labels without a 0) or rely on pandas' deprecated positional fallback.
        return entropy(counts.iloc[0] / len(d))
    # More than two classes: H = -sum_i p_i * log2(p_i).
    proportions = counts / len(d)
    return float(-(proportions * np.log2(proportions)).sum())


def entropy(p):
    """Binary entropy H(p) = -p*log2(p) - (1-p)*log2(1-p), in bits.

    ``p`` is the probability of one of the two outcomes. The endpoints 0 and 1
    are special-cased to 0 because ``log2(0)`` is ``-inf`` and ``0 * -inf`` would
    produce NaN instead of the mathematically correct limit 0.
    """
    if p == 0 or p == 1:
        return 0
    ap = 1 - p  # probability of the other outcome
    return -(p * np.log2(p) + ap * np.log2(ap))


# Load the categorical feature table (with its "Cat" label column) and preview the last rows.
raw_dataset = pd.read_csv(filepath_or_buffer="./dogscats-categorical.csv")
raw_dataset.tail(n=5)


Unnamed: 0,Ears shape,Face shape,Whiskers,Cat
5,Pointy,Round,Absent,Yes
6,Floppy,Not round,Absent,No
7,Pointy,Round,Absent,Yes
8,Floppy,Round,Absent,No
9,Floppy,Round,Absent,No


In [2]:
# One-hot encode "Ears shape" (values Floppy / Oval / Pointy) so every feature is
# two-valued; the empty prefix makes each new column name equal to its category.
# get_dummies returns a new frame, so raw_dataset is left untouched.
dataset = raw_dataset.pipe(
    pd.get_dummies,
    columns=["Ears shape"],
    prefix="",
    prefix_sep="",
)
dataset.tail()


Unnamed: 0,Face shape,Whiskers,Cat,Floppy,Oval,Pointy
5,Round,Absent,Yes,False,False,True
6,Not round,Absent,No,True,False,False
7,Round,Absent,Yes,False,False,True
8,Round,Absent,No,True,False,False
9,Round,Absent,No,True,False,False


In [3]:
import pprint
import typing


def dt_split(dataframe: pd.DataFrame, labels_column: str, ignore_columns: typing.Optional[list[str]] = None):
    """Choose the feature column whose two-way split maximises information gain.

    Every column except ``labels_column`` and those in ``ignore_columns`` is a
    candidate. Each candidate is split with ``groupby`` into (at most) two
    groups and scored as ``H(parent) - sum(weight_i * H(group_i))``. On ties the
    earliest column wins.

    Parameters
    ----------
    dataframe : pd.DataFrame
        Rows at the current tree node.
    labels_column : str
        Column holding the class label.
    ignore_columns : list[str], optional
        Columns that must not be split on (e.g. features already used higher up
        the tree). Defaults to no extra exclusions.

    Returns
    -------
    tuple
        ``(feature, subsets)`` where ``subsets`` is
        ``((value_left, entropy_left, subset_left), (value_right, entropy_right, subset_right))``.
        If the chosen column holds a single value, the right side is
        ``('UNKNOWN', 0, None)``. If there is no candidate column at all,
        ``("", None)`` is returned.

    Raises
    ------
    ValueError
        If a candidate column has more than two distinct values. Such columns
        must be one-hot encoded first (e.g. with ``pd.get_dummies``); otherwise
        only the first two groups would be scored and the remaining rows lost.
    """
    if ignore_columns is None:
        # Avoid a shared mutable default argument.
        ignore_columns = []
    most_informative_feature: typing.Any = ""
    best_info_gain = -1
    base_entropy = dataframe_entropy(dataframe, labels_column)
    subsets: typing.Any = None
    n_rows = len(dataframe)
    for column in dataframe.columns:
        if column in ignore_columns or column == labels_column:
            continue
        # groupby sorts the keys, so for boolean columns the left group is False.
        splits = list(dataframe.groupby(column))
        if not splits:
            # Every value is NaN (groupby drops NaN keys): nothing to split on.
            continue
        if len(splits) > 2:
            raise ValueError(
                f"dt_split only supports binary features, but column {column!r} has "
                f"{len(splits)} distinct values; one-hot encode it first."
            )
        (value_left, subset_left) = splits[0]
        # A single-valued column yields no right branch; use a placeholder.
        (value_right, subset_right) = splits[1] if len(splits) > 1 else ('UNKNOWN', None)
        entropy_left = dataframe_entropy(subset_left, labels_column)
        entropy_right = dataframe_entropy(subset_right, labels_column) if subset_right is not None else 0
        weight_left = len(subset_left) / n_rows
        weight_right = len(subset_right) / n_rows if subset_right is not None else 0
        weighted_entropy = weight_left * entropy_left + weight_right * entropy_right
        info_gain = base_entropy - weighted_entropy
        # Strict '>' keeps the earliest column on ties.
        if info_gain > best_info_gain:
            best_info_gain = info_gain
            most_informative_feature = column
            subsets = (value_left, entropy_left, subset_left), (value_right, entropy_right, subset_right)
    return most_informative_feature, subsets


def most_common_value(dataframe: pd.DataFrame, column: str):
    """Return the most frequent value in ``dataframe[column]``.

    ``Series.mode`` lists tied modes in sorted order, so a tie resolves to the
    smallest value.
    """
    modes = dataframe[column].mode()
    return modes[0]


def recursive_split(dataframe: pd.DataFrame, labels_column: str, ignore_columns: typing.Optional[list[str]] = None, depth=0, max_depth=3):
    """Grow a binary decision tree as nested dicts.

    Parameters
    ----------
    dataframe : pd.DataFrame
        Rows reaching the current node.
    labels_column : str
        Column holding the class label.
    ignore_columns : list[str], optional
        Columns that may not be split on at this node (features already used
        on the path from the root are appended automatically).
    depth : int
        Depth of the current node; the root is 0.
    max_depth : int
        A split is made at every depth from 0 to ``max_depth`` inclusive, so a
        root-to-leaf path can test up to ``max_depth + 1`` features. Impure
        children of a node at ``depth == max_depth`` become majority-vote leaves.

    Returns
    -------
    dict or label
        ``{feature: {value: subtree_or_label, ...}}``. When the chosen feature
        is single-valued at this node, the extra ``'UNKNOWN'`` branch maps to
        ``None``. If no feature is left to split on, the majority label itself
        is returned.
    """
    if ignore_columns is None:
        # Avoid a shared mutable default argument.
        ignore_columns = []
    most_informative_feature, subsets = dt_split(dataframe=dataframe, labels_column=labels_column, ignore_columns=ignore_columns)
    if subsets is None:
        # No candidate feature remains (all used on this path or only the label
        # column exists): fall back to a majority-vote leaf instead of failing to
        # unpack None.
        return most_common_value(dataframe=dataframe, column=labels_column)

    child_ignore = [*ignore_columns, most_informative_feature]
    options = {}
    # Left branch first, then right, preserving the dict's key order.
    for value, branch_entropy, subset in subsets:
        if subset is None:
            # Placeholder branch for a single-valued feature: no rows reach it.
            options[value] = None
        elif branch_entropy > 0:
            if depth < max_depth:
                options[value] = recursive_split(dataframe=subset, labels_column=labels_column, ignore_columns=child_ignore, depth=depth+1, max_depth=max_depth)
            else:
                options[value] = most_common_value(dataframe=subset, column=labels_column)
        else:
            # Pure subset: every row shares the same label.
            options[value] = subset[labels_column].iloc[0]

    return {most_informative_feature: options}


# Grow the tree on the one-hot-encoded data, predicting the "Cat" column, and show it.
tree = recursive_split(dataframe=dataset, labels_column="Cat", max_depth=3)
pprint.pprint(tree)


{'Floppy': {False: {'Face shape': {'Not round': {'Oval': {False: 'No',
                                                          True: 'Yes'}},
                                   'Round': {'Oval': {False: 'Yes',
                                                      True: {'Whiskers': {'Absent': 'No',
                                                                          'Present': 'Yes'}}}}}},
            True: 'No'}}
