# Supervised Learning (Classification)

Load necessary libraries

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_curve, auc

from stochtree import BARTModel

Generate sample data

In [None]:
# Problem dimensions: n observations, p_X covariates
n = 1000
p_X = 10

# Random number generator (no seed is given, so each run draws different data)
rng = np.random.default_rng()

# Covariate matrix of shape (n, p_X), each entry drawn uniformly from [0, 1)
X = rng.uniform(low=0, high=1, size=(n, p_X))


# Define the outcome mean function
def outcome_mean(X):
    """Return the mean of the latent outcome for every row of ``X``.

    The mean is a slope times the second covariate ``X[:, 1]``, where the
    slope depends on which band the first covariate ``X[:, 0]`` falls in:

    * ``0.0 <= X[:, 0] < 0.25``: slope -7.5
    * ``0.25 <= X[:, 0] < 0.5``: slope -2.5
    * ``0.5 <= X[:, 0] < 0.75``: slope 2.5
    * any other value: slope 7.5
    """
    x0 = X[:, 0]
    x1 = X[:, 1]
    in_band = [
        (x0 >= 0.0) & (x0 < 0.25),
        (x0 >= 0.25) & (x0 < 0.5),
        (x0 >= 0.5) & (x0 < 0.75),
    ]
    band_means = [-7.5 * x1, -2.5 * x1, 2.5 * x1]
    # Rows matching none of the three bands fall through to the 7.5 slope.
    return np.select(in_band, band_means, default=7.5 * x1)


# Generate outcome: the latent z is the mean function plus standard normal
# noise, and the observed binary y is 1 exactly where z >= 0
epsilon = rng.normal(loc=0, scale=1, size=n)
z = outcome_mean(X) + epsilon
y = (z >= 0).astype(int)

Test-train split

In [None]:
# Split the row indices in half at random (no random_state is passed, so the
# partition differs between runs), then slice every array with the same indices
sample_inds = np.arange(n)
train_inds, test_inds = train_test_split(sample_inds, test_size=0.5)
X_train, X_test = X[train_inds, :], X[test_inds, :]
z_train, z_test = z[train_inds], z[test_inds]
y_train, y_test = y[train_inds], y[test_inds]

Run BART

In [None]:
# Sampler configuration, forwarded to BARTModel.sample below
num_gfr, num_mcmc = 10, 100
general_params = dict(num_chains=1, probit_outcome_model=True)

# Fit on the training half. X_test is supplied at sampling time; later cells
# read the test-set predictions from bart_model.y_hat_test.
bart_model = BARTModel()
bart_model.sample(
    X_train=X_train,
    y_train=y_train,
    X_test=X_test,
    num_gfr=num_gfr,
    num_mcmc=num_mcmc,
    general_params=general_params,
)

Because this data is simulated, we know the true latent continuous outcome `z`, so we can compare it against the posterior mean of the probit-scale predictions on the test set.

In [None]:
# Average the test-set predictions over axis 1 to get one value per test row
y_hat_test_mean = np.mean(bart_model.y_hat_test, axis=1)

# Predicted value against the true latent outcome, with a dashed y = x reference line
plt.scatter(x=y_hat_test_mean, y=z_test)
plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3, 3)))
plt.xlabel("Predicted")
plt.ylabel("Actual")
plt.show()

On non-simulated datasets, the first thing we would evaluate is the prediction accuracy.

In [None]:
# Predict class 1 wherever the averaged test-set prediction is positive,
# then report the fraction of test labels that match
preds_test = np.mean(bart_model.y_hat_test, axis=1) > 0
accuracy = np.mean(y_test == preds_test)
print(f"Test set accuracy: {accuracy:.3f}")

We can also compute the [ROC curve](https://en.wikipedia.org/wiki/Receiver_operating_characteristic) for every posterior sample, as well as the ROC curve of the posterior mean prediction.

In [None]:
# One ROC curve per posterior draw. The number of draws is read from the
# prediction matrix instead of being re-declared in this cell, so the loop
# cannot drift out of sync with the settings used when the model was sampled.
# Assumes y_hat_test is laid out as (test rows, posterior draws), consistent
# with the column indexing and axis=1 averaging applied to it.
num_draws = bart_model.y_hat_test.shape[1]
fpr_list = []
tpr_list = []
threshold_list = []
for i in range(num_draws):
    fpr, tpr, thresholds = roc_curve(y_test, bart_model.y_hat_test[:, i], pos_label=1)
    fpr_list.append(fpr)
    tpr_list.append(tpr)
    threshold_list.append(thresholds)

# ROC curve of the prediction averaged over all draws
probit_preds_test_mean = np.mean(bart_model.y_hat_test, axis=1)
fpr_mean, tpr_mean, thresholds_mean = roc_curve(y_test, probit_preds_test_mean, pos_label=1)

# Per-draw curves in solid blue, the averaged prediction in dashed black,
# and the diagonal y = x reference line in dashed red
for fpr, tpr in zip(fpr_list, tpr_list):
    plt.plot(fpr, tpr, color="blue", linestyle="solid", linewidth=0.9)
plt.plot(fpr_mean, tpr_mean, color="black", linestyle="dashed", linewidth=1.75)
plt.axline((0, 0), slope=1, color="red", linestyle="dashed", linewidth=1.5)
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.xlim(0, 1)
plt.ylim(0, 1)
plt.show()