# Workbench

## Load the data

In [15]:
from sklearn import datasets
from sklearn import preprocessing

#X, y = datasets.fetch_openml('mnist_784', return_X_y=True)
X, y = datasets.load_digits(return_X_y=True)
labels = sorted(set(y))
X = preprocessing.scale(X)

In [16]:
from sklearn import model_selection

X_train, X_test, y_train, y_test = model_selection.train_test_split(X, y, random_state=42)
X_fit, X_val, y_fit, y_val = model_selection.train_test_split(X_train, y_train, random_state=42)

## A first random tree

In [17]:
import myriade
from sklearn import linear_model

clf = linear_model.LogisticRegression()
rand = myriade.multiclass.RandomBalancedHierarchyClassifier(clf)
rand.fit(X_fit, y_fit)
rand.tree_

Evaluate accuracy on the test set.

In [18]:
from sklearn import metrics

y_pred = rand.predict(X_test)
print(f"{metrics.accuracy_score(y_test, y_pred):.2%}")

81.78%


Predict on the validation set, in order to build the confusion matrix.

In [19]:
y_fit

In [20]:
y_pred = rand.predict(X_val)
confusion = metrics.confusion_matrix(y_val, y_pred, labels=labels)
confusion = pd.DataFrame(confusion, index=labels, columns=labels).sort_index(axis=1).sort_index(axis=0)
confusion

Unnamed: 0,0,1,2,3,4,5,6,7,8,9
0,30,0,2,0,0,3,0,0,0,0
1,0,34,1,1,4,0,0,0,0,1
2,1,5,21,1,0,0,0,0,2,0
3,0,0,0,34,0,1,0,0,3,0
4,0,3,0,0,28,0,4,2,0,0
5,0,0,0,0,0,29,0,1,5,0
6,0,0,0,0,0,0,39,0,0,0
7,0,0,0,0,1,0,0,21,0,2
8,0,4,1,1,0,3,0,0,20,3
9,0,1,1,0,0,0,0,0,1,23


In [32]:
cm = confusion.to_numpy()
cm[:, -1]

In [33]:
cm[-1, :]

In [36]:
base_model = myriade.multiclass.RandomBalancedHierarchyClassifier(
    classifier=linear_model.LogisticRegression(),
    seed=42
)
cv = model_selection.KFold(
    n_splits=2,
    shuffle=True,
    random_state=42
)
model = myriade.multiclass.BalancedHierarchyClassifier(
    classifier=linear_model.LogisticRegression(),
    base_model=base_model,
    cv=cv
)
model.fit(X_train, y_train)
print(f"{model.score(X_test, y_test):.2%}")

94.89%


## Learning a (better) tree structure from the confusion matrix

In [102]:
import warnings
import numpy as np

def make_tree_from_pairs(pairs):
    return myriade.Branch(
        make_tree_from_pairs(pairs[0]) if isinstance(pairs[0], tuple) else pairs[0],
        make_tree_from_pairs(pairs[1]) if isinstance(pairs[1], tuple) else pairs[1]
    )


def pair_labels(confusion):

    labels = confusion.index.tolist()

    errors = np.triu(1 + confusion + confusion.T, k=1)
    
    # Here we find the pairs of labels that are most confused with each other. We keep going until
    # there are no more pairs to compare.
    pairs_idx = []
    while errors.any():
        i, j = np.unravel_index(np.argmax(errors), errors.shape)
        pairs_idx.append((i, j))
        errors[i, :] = 0
        errors[:, i] = 0
        errors[j, :] = 0
        errors[:, j] = 0
    pairs = [(labels[i], labels[j]) for i, j in pairs_idx]

    # If the number of labels is odd, there will be one label left over. We call this an orphan.
    orphans_idx = list(set(range(len(labels))) - set(i for i, _ in pairs_idx) - set(j for _, j in pairs_idx))
    orphans = [labels[idx] for idx in orphans_idx]

    # We now create a new confusion matrix that combines the pairs of labels.
    confusion_arr = (
        confusion.iloc[[p[0] for p in pairs_idx] + orphans_idx, [p[0] for p in pairs_idx] + orphans_idx].to_numpy() +
        confusion.iloc[[p[0] for p in pairs_idx] + orphans_idx, [p[1] for p in pairs_idx] + orphans_idx].to_numpy() +
        confusion.iloc[[p[1] for p in pairs_idx] + orphans_idx, [p[0] for p in pairs_idx]+ orphans_idx].to_numpy() +
        confusion.iloc[[p[1] for p in pairs_idx] + orphans_idx, [p[1] for p in pairs_idx]+ orphans_idx].to_numpy()
    )

    # HACK: If there are orphans, we need to divide the corresponding rows and columns by 2 to
    # avoid double counting.
    if orphans:
        confusion_arr[-1, :] = confusion_arr[-1, :] / 2
        confusion_arr[:, -1] = confusion_arr[:, -1] / 2

    confusion = pd.DataFrame(confusion_arr, index=pairs + orphans, columns=pairs + orphans)

    # Termination condition
    if len(confusion.columns) == 1:
        return make_tree_from_pairs(pairs[0])

    return pair_labels(confusion)

confusion = metrics.confusion_matrix(y_val, y_pred, labels=labels)
confusion = pd.DataFrame(confusion, index=labels, columns=labels).sort_index(axis=1).sort_index(axis=0)
new_tree = pair_labels(confusion)
new_tree

In [103]:
better = myriade.multiclass.ManualHierarchyClassifier(clf, new_tree)
better.fit(X_fit, y_fit)
y_pred = rand.predict(X_test)
print(f"{metrics.accuracy_score(y_test, y_pred):.2%}")

94.44%


In [13]:
import myriade
from sklearn import linear_model, metrics

clf = linear_model.LogisticRegression()
balanced = myriade.multiclass.BalancedHierarchyClassifier(clf)
balanced.fit(X_fit, y_fit)
y_pred = balanced.predict(X_test)
print(f"{metrics.accuracy_score(y_test, y_pred):.2%}")

93.78%


In [14]:
confusion