In [None]:
# Install the two third-party packages imported by this notebook: proglearn
# from the package index and treeple straight from its GitHub repository.
# Neither install is version-pinned, so the versions obtained depend on when
# this cell is run.
!pip install proglearn
!pip install git+https://github.com/neurodata/treeple.git

In [9]:
import numpy as np
import time
import matplotlib.pyplot as plt
import seaborn as sns
from treeple import ObliqueRandomForestClassifier
from proglearn.sims import generate_gaussian_parity
from sklearn.metrics import accuracy_score
from proglearn.sims import generate_spirals
from sklearn.preprocessing import StandardScaler
from tensorflow.keras.datasets import mnist, fashion_mnist

In [10]:
def MultiTaskClf(task0_train, task1_train, task0_test, task1_test,
                     clf_type="SPORF", **kwargs):
    """
    Train one forest on two tasks jointly and report per-task test accuracy.

    The two training sets are stacked row-wise into a single dataset, and an
    extra feature column is appended that marks which task each row came from
    (0.0 for task 0, 1.0 for task 1). The same indicator column is appended to
    each test set before prediction.

    Parameters:
        task0_train, task1_train: (X, y) pairs holding the training features
            and labels of task 0 and task 1. Both X must have the same number
            of columns so that they can be stacked.
        task0_test, task1_test: (X, y) pairs holding the test features and
            labels of task 0 and task 1.
        clf_type: which forest to use. Only "SPORF" (treeple's
            ObliqueRandomForestClassifier) is implemented; "MORF" and
            "HonestForest" are reserved names that currently raise.
        **kwargs: keyword arguments forwarded to the classifier constructor.
            They override the defaults n_estimators=200,
            feature_combinations=2.0 and random_state=42.
    Returns:
        Dictionary with keys "task0_accuracy" and "task1_accuracy" holding the
        accuracy_score of the predictions on the respective test set.
    Raises:
        NotImplementedError: if clf_type is anything other than "SPORF".
    """
    # Reject unsupported forest types up front, before touching any data.
    if clf_type == "MORF":
        raise NotImplementedError("MORF will be available soon.")
    if clf_type == "HonestForest":
        raise NotImplementedError("HonestForest will be available soon.")
    if clf_type != "SPORF":
        raise NotImplementedError(f"{clf_type} not supported.")

    # random_state sits with the other defaults so that a caller-supplied
    # random_state overrides it; passing it as a separate hard-coded keyword
    # next to **kwargs would raise "got multiple values for keyword argument".
    params = {
        "n_estimators": 200,
        "feature_combinations": 2.0,
        "random_state": 42,
        **kwargs,
    }

    X0_train, y0_train = task0_train
    X1_train, y1_train = task1_train
    X0_test, y0_test = task0_test
    X1_test, y1_test = task1_test

    # Build the joint training set: stack both tasks and append the task
    # indicator as the last feature column.
    task_train_labels = np.concatenate(
        [np.zeros(len(y0_train)), np.ones(len(y1_train))]
    )
    X_train = np.column_stack((np.vstack([X0_train, X1_train]), task_train_labels))
    y_train = np.concatenate([y0_train, y1_train])

    # Test rows need the same extra column so their layout matches training.
    X0_test = np.column_stack((X0_test, np.zeros(len(y0_test))))
    X1_test = np.column_stack((X1_test, np.ones(len(y1_test))))

    # Train once on the combined data, then score each task separately.
    clf = ObliqueRandomForestClassifier(**params)
    clf.fit(X_train, y_train)

    y0_pred = clf.predict(X0_test)
    y1_pred = clf.predict(X1_test)

    return {
        "task0_accuracy": accuracy_score(y0_test, y0_pred),
        "task1_accuracy": accuracy_score(y1_test, y1_pred)
    }