In [1]:
from sklearn.model_selection import cross_val_score, KFold
import numpy as np

In [2]:
from sklearn.datasets import (
    load_iris,
    load_digits,
    load_wine,
    load_breast_cancer,
    load_diabetes,
)

# Registry of benchmark datasets: short display name -> zero-argument loader.
# The loaders are only called later, inside the benchmark, so nothing is
# loaded when this cell runs.
classification_loaders = dict(
    iris=load_iris,
    digits=load_digits,
    wine=load_wine,
    breast_cancer=load_breast_cancer,
)

# Same layout for the regression benchmark.
regression_loaders = dict(diabetes=load_diabetes)

In [3]:
from sklearn import tree
from sklearn import ensemble
from decision_tree.numpy import (
    DecisionTreeClassifier,
    DecisionTreeRegressor,
    RandomForestClassifier,
)
from decision_tree.jax import DecisionTreeClassifier as JaxDTC
from decision_tree.numpy.bagging import RandomForestRegressor

# Estimators compared on the classification datasets, keyed by the column
# title used in the leaderboard.
# The sklearn estimators get a fixed `random_state` so their scores are
# reproducible from run to run.
# NOTE(review): the project estimators' constructors are not visible here, so
# no seed is passed to them -- confirm whether they accept one.
classification_models = {
    "sklearn DTC": tree.DecisionTreeClassifier(
        criterion="entropy", max_depth=4, random_state=0
    ),
    "numpy DTC": DecisionTreeClassifier(max_depth=4, min_samples=1),
    "JAX DTC": JaxDTC(max_depth=4, min_samples=2),
    "sklearn RFC": ensemble.RandomForestClassifier(
        n_estimators=20, criterion="entropy", max_depth=4, random_state=0
    ),
    # Currently excluded from the benchmark; uncomment to include it.
    # "numpy RFC": RandomForestClassifier(n_estimators=20, max_depth=4, min_samples=2),
}

# Estimators compared on the regression datasets (same conventions as above).
regression_models = {
    "sklearn DTR": tree.DecisionTreeRegressor(
        criterion="squared_error", max_depth=4, random_state=0
    ),
    "our DTR": DecisionTreeRegressor(max_depth=4, min_samples=1),
    "sklearn RFR": ensemble.RandomForestRegressor(
        max_depth=4, criterion="squared_error", random_state=0
    ),
    "our RFR": RandomForestRegressor(max_depth=4, min_samples=1),
}

In [4]:
def benchmark_dataset(ds_name, loader, model, n_splits=5, random_state=0, cv=None):
    """Return the mean cross-validated score of ``model`` on one dataset.

    Parameters
    ----------
    ds_name : str
        Name of the dataset. Not used in the computation; kept so existing
        callers keep working.
    loader : callable
        Zero-argument callable returning a mapping with ``"data"`` and
        ``"target"`` entries that can be indexed with integer index arrays.
    model : object
        Estimator exposing ``fit(X, y)`` and ``score(X, y)``. The same object
        is refit in place on every fold.
    n_splits : int, default 5
        Number of folds of the default ``KFold`` splitter.
    random_state : int or None, default 0
        Seed of the default shuffled ``KFold``. With a fixed seed every model
        is scored on the same folds of a given dataset and reruns reproduce
        the same splits. Pass ``None`` for unseeded shuffling.
    cv : object, optional
        Any splitter whose ``split(X)`` yields ``(train_idx, test_idx)``
        pairs. When given, it replaces the default ``KFold`` and
        ``n_splits`` / ``random_state`` are ignored.

    Returns
    -------
    The mean of the per-fold values returned by ``model.score``, as computed
    by ``np.mean``.
    """
    dataset = loader()
    X, y = dataset["data"], dataset["target"]

    if cv is None:
        # Seeded shuffle: identical folds for every model and every rerun.
        cv = KFold(n_splits=n_splits, shuffle=True, random_state=random_state)

    fold_scores = []
    for train_idx, test_idx in cv.split(X):
        model.fit(X[train_idx], y[train_idx])
        fold_scores.append(model.score(X[test_idx], y[test_idx]))
    return np.mean(fold_scores)

In [5]:
from prettytable import PrettyTable


def benchmark(loaders, models):
    """Print a leaderboard of cross-validated scores.

    The table has one row per dataset in ``loaders`` and one column per
    entry in ``models``; every cell is the value returned by
    ``benchmark_dataset`` for that pair, formatted with three decimals.

    Parameters
    ----------
    loaders : dict
        Maps a dataset name to the callable that loads it.
    models : dict
        Maps a column title to the estimator to evaluate.
    """
    names = [*loaders]
    table = PrettyTable()
    table.add_column("Dataset", names)

    for title, estimator in models.items():
        column = []
        for name in names:
            score = benchmark_dataset(name, loaders[name], estimator)
            column.append(f"{score:.3f}")
        table.add_column(title, column)

    print(table)

In [6]:
# Classification leaderboard: mean cross-validated `score` of every
# classification model on every classification dataset.
benchmark(loaders=classification_loaders, models=classification_models)

+---------------+-------------+-----------+---------+-------------+
|    Dataset    | sklearn DTC | numpy DTC | JAX DTC | sklearn RFC |
+---------------+-------------+-----------+---------+-------------+
|      iris     |    0.967    |   0.933   |  0.947  |    0.953    |
|     digits    |    0.679    |   0.728   |  0.685  |    0.905    |
|      wine     |    0.950    |   0.944   |  0.893  |    0.983    |
| breast_cancer |    0.935    |   0.944   |  0.940  |    0.954    |
+---------------+-------------+-----------+---------+-------------+


In [7]:
# Regression leaderboard: mean cross-validated `score` of every regression
# model on every regression dataset.
benchmark(loaders=regression_loaders, models=regression_models)

+----------+-------------+---------+-------------+---------+
| Dataset  | sklearn DTR | our DTR | sklearn RFR | our RFR |
+----------+-------------+---------+-------------+---------+
| diabetes |    0.332    |  0.376  |    0.438    |  0.344  |
+----------+-------------+---------+-------------+---------+
