# Ordinal Regression in StochTree

This notebook demonstrates how to use BART to model ordinal outcomes with a complementary log-log (cloglog) link function (Alam and Linero 2025).

We begin by loading the requisite libraries.

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from sklearn.model_selection import train_test_split

from stochtree import BARTModel, OutcomeModel

## Introduction to Ordinal BART with Cloglog Link

Ordinal data refers to outcomes that have a natural ordering but undefined distances between categories. Examples include survey responses (strongly disagree, disagree, neutral, agree, strongly agree), severity ratings (mild, moderate, severe), or educational levels (elementary, high school, college, graduate).

The cloglog link function is:
$$\text{cloglog}(p) = \log(-\log(1-p))$$
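As a quick numerical check of this definition, the following NumPy sketch implements the link and its inverse, $p = 1 - \exp(-e^{\eta})$. The helper names `cloglog` and `inv_cloglog` are our own illustration and are not part of the `stochtree` API. The last line evaluates the inverse link at $\eta = -2, 0, 2$ to show that it is not symmetric about $\eta = 0$.

```python
import numpy as np


def cloglog(p):
    """Complementary log-log link, log(-log(1 - p)), for p in (0, 1)."""
    p = np.asarray(p, dtype=float)
    return np.log(-np.log1p(-p))


def inv_cloglog(eta):
    """Inverse cloglog link, 1 - exp(-exp(eta)), mapping the real line to (0, 1)."""
    eta = np.asarray(eta, dtype=float)
    return -np.expm1(-np.exp(eta))


p = np.array([0.1, 0.5, 0.9])
eta = cloglog(p)
print(eta)
print(np.allclose(inv_cloglog(eta), p))  # the two functions invert each other

# Equal steps of +/-2 around eta = 0 move p by very different amounts
print(inv_cloglog(np.array([-2.0, 0.0, 2.0])))
```

Stepping from $\eta = 0$ to $\eta = 2$ leaves a gap to 1 of roughly 0.0006, while stepping to $\eta = -2$ still leaves a probability of roughly 0.127; a symmetric link such as the logit would give equal gaps.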

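In the BART framework with cloglog ordinal regression, for an outcome with $K$ ordered categories we model:
$$P(Y = k \mid Y \geq k, X = x) = 1 - \exp\left(-e^{\gamma_k + \lambda(x)}\right), \qquad k = 1, \dots, K - 1,$$

where $\lambda(x)$ is represented by a stochastic tree ensemble and $\gamma_1, \dots, \gamma_{K-1}$ are cutpoints for the ordinal categories (the last category needs no cutpoint, since $P(Y = K \mid Y \geq K, X = x) = 1$). Chaining these conditional probabilities gives the class probabilities
$$P(Y = k \mid X = x) = P(Y = k \mid Y \geq k, X = x) \prod_{j=1}^{k-1} \exp\left(-e^{\gamma_j + \lambda(x)}\right),$$
so larger values of $\lambda(x)$ shift probability mass toward the lower categories. Unlike the symmetric probit and logit links commonly used in ordinal regression, the cloglog link is asymmetric: as a function of $\eta = \gamma_k + \lambda(x)$, the probability $1 - \exp(-e^{\eta})$ approaches 0 exponentially as $\eta \to -\infty$ but approaches 1 doubly exponentially as $\eta \to \infty$.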
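The class probabilities implied by this model can be computed by chaining the conditional probabilities. The sketch below does this for arbitrary $K$ with plain NumPy; `ordinal_cloglog_probs` is a hypothetical helper written for illustration, not a `stochtree` function, and the cutpoints in the usage line simply reuse the values $(-2, 1)$ chosen for the simulation in this notebook.

```python
import numpy as np


def ordinal_cloglog_probs(lam, gamma):
    """Class probabilities under P(Y = k | Y >= k, x) = 1 - exp(-exp(gamma_k + lambda(x))).

    lam:   shape (n,), values of lambda(x_i)
    gamma: shape (K - 1,), cutpoints gamma_1, ..., gamma_{K-1}
    Returns an (n, K) array of P(Y = k | x) whose rows sum to 1.
    """
    lam = np.asarray(lam, dtype=float)
    gamma = np.asarray(gamma, dtype=float)
    # P(Y > k | Y >= k, x) for k = 1, ..., K - 1
    cont = np.exp(-np.exp(gamma[None, :] + lam[:, None]))
    # P(Y >= k | x) for k = 1, ..., K: cumulative products of cont, starting at 1
    at_least = np.hstack([np.ones((lam.size, 1)), np.cumprod(cont, axis=1)])
    # P(Y = k | Y >= k, x), with the last category absorbing the remaining mass
    stop = np.hstack([1.0 - cont, np.ones((lam.size, 1))])
    return at_least * stop


probs = ordinal_cloglog_probs([-1.0, 0.0, 1.0], [-2.0, 1.0])
print(probs.round(3))
print(probs.sum(axis=1))
```

Reading down the columns, the probability of category 1 rises and the probability of category 3 falls as $\lambda(x)$ increases, matching the sign convention of the model.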

## Data Simulation

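We simulate a dataset with $n = 2000$ observations, $p = 5$ standard normal covariates, and an ordinal outcome with three categories, $y_i \in \left\{1,2,3\right\}$, generated from the cloglog model above with a linear $\lambda(x) = x^\top \beta$, $\beta = (1, \dots, 1)/\sqrt{p}$, and cutpoints $\gamma_1 = -2$, $\gamma_2 = 1$.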

In [None]:
# RNG
random_seed = 2026
rng = np.random.default_rng(random_seed)

# Sample size and number of predictors
n = 2000
p = 5

# Design matrix and true lambda function
X = rng.standard_normal((n, p))
beta = np.ones(p) / np.sqrt(p)
true_lambda = X @ beta

# Set cutpoints for ordinal categories (3 categories: 1, 2, 3)
n_categories = 3
gamma_true = np.array([-2.0, 1.0])

# True ordinal class probabilities:
# P(Y = k) = P(Y >= k) * P(Y = k | Y >= k), with P(Y = k | Y >= k) = 1 - exp(-exp(gamma_k + lambda))
true_probs = np.zeros((n, n_categories))
prob_at_least = np.ones(n)  # P(Y >= k | x), equal to 1 for k = 1
for j in range(n_categories - 1):
    cond_prob = 1 - np.exp(-np.exp(gamma_true[j] + true_lambda))
    true_probs[:, j] = prob_at_least * cond_prob
    prob_at_least = prob_at_least * (1 - cond_prob)
true_probs[:, n_categories - 1] = 1 - true_probs[:, :-1].sum(axis=1)

# Generate ordinal outcomes (category labels 1, ..., n_categories, stored as floats)
y = np.array(
    [rng.choice(np.arange(1, n_categories + 1), p=true_probs[i]) for i in range(n)],
    dtype=float,
)

# Print outcome distribution
unique, counts = np.unique(y, return_counts=True)
print("Outcome distribution:", dict(zip(unique.astype(int), counts)))

# Train-test split
sample_inds = np.arange(n)
train_inds, test_inds = train_test_split(sample_inds, test_size=0.2, random_state=random_seed)
X_train = X[train_inds, :]
X_test = X[test_inds, :]
y_train = y[train_inds]
y_test = y[test_inds]

## Model Fitting

We specify the cloglog ordinal model by passing `OutcomeModel(outcome="ordinal", link="cloglog")` as the `"outcome_model"` entry of the `general_params` dictionary. Because ordinal outcomes are incompatible with the Gaussian global error variance model, we also set `"sample_sigma2_global": False`.

We also reduce the mean forest's `num_trees` from its default of 200 to 50, for greater regularization in the ordinal model, and set `sample_sigma2_leaf` to `False`. The sampler runs 1000 burn-in and 1000 retained MCMC iterations, with no grow-from-root warm-start iterations (`num_gfr=0`).

In [None]:
bart_model = BARTModel()
bart_model.sample(
    X_train=X_train,
    y_train=y_train,
    X_test=X_test,
    num_gfr=0,
    num_burnin=1000,
    num_mcmc=1000,
    general_params={
        "cutpoint_grid_size": 100,
        "sample_sigma2_global": False,
        "keep_every": 1,
        "num_chains": 1,
        "random_seed": random_seed,
        "outcome_model": OutcomeModel(outcome="ordinal", link="cloglog"),
    },
    mean_forest_params={"num_trees": 50, "sample_sigma2_leaf": False},
)

## Prediction

As with any other BART model in `stochtree`, we can call the `predict` method on our ordinal model.

- `scale="linear"` with `terms="y_hat"` returns predictions of the estimated $\lambda(x)$ function.
- `scale="probability"` returns estimated class probabilities, by default as an array of shape (`num_observations`, `num_categories`, `num_samples`), where `num_observations = X.shape[0]`, `num_categories` is the number of unique ordinal labels the outcome takes, and `num_samples` is the number of retained posterior draws.
- Specifying `type="mean"` collapses this array to a `num_observations` x `num_categories` matrix of posterior mean class probabilities for each observation.
- Specifying `type="class"` returns the maximum a posteriori (MAP) class label for each draw of each observation.
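To make these array shapes concrete, the sketch below builds a stand-in probability array from random Dirichlet draws (this is not output from `stochtree`) and reduces it in the two ways described above: averaging over the draw axis, and taking the most probable class within each draw. It illustrates the shapes only; whether `stochtree` computes `type="class"` exactly as an `argmax` over categories is an assumption here.

```python
import numpy as np

rng = np.random.default_rng(0)
num_observations, num_categories, num_samples = 4, 3, 500

# Stand-in for a scale="probability" prediction: each (observation, draw) pair gets a
# probability vector over categories; rng.dirichlet returns shape (4, 500, 3) here.
draws = rng.dirichlet(np.ones(num_categories), size=(num_observations, num_samples))
prob_draws = np.transpose(draws, (0, 2, 1))  # (num_observations, num_categories, num_samples)

# Posterior mean class probabilities: average over the draw axis
mean_probs = prob_draws.mean(axis=2)  # (num_observations, num_categories)

# Most probable class within each draw, as 1-indexed labels
map_class = prob_draws.argmax(axis=1) + 1  # (num_observations, num_samples)

print(mean_probs.shape, map_class.shape)  # (4, 3) (4, 500)
```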

Below we compute the posterior class probabilities for the train and test sets.

In [None]:
est_probs_train = bart_model.predict(
    X_train, scale="probability", terms="y_hat", type="mean"
)
est_probs_test = bart_model.predict(
    X_test, scale="probability", terms="y_hat", type="mean"
)

## Model Results and Interpretation

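The cutpoints $\gamma_k$ and the intercept of $\lambda(x)$ are not separately identifiable: adding a constant to every $\gamma_k$ and subtracting it from $\lambda(x)$ leaves all class probabilities unchanged. We therefore re-center the cutpoint draws before plotting the posterior distributions of the two cutpoints against their true simulated values (blue dotted lines).

The cutpoint samples are accessed via `bart_model.cloglog_cutpoint_samples` (shape: `(n_categories - 1, num_samples)`). Each draw is shifted by the mean of that draw's training-set predictions, `bart_model.y_hat_train.mean(axis=0)`, which moves the non-identifiable intercept from $\lambda(x)$ into the cutpoints.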
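The identifiability problem can be checked numerically. The sketch below uses a made-up $\lambda$ (standard normal values) and an arbitrary constant `c`, both hypothetical, to show that shifting the cutpoints and $\lambda(x)$ in opposite directions leaves every class probability unchanged, and that adding the mean of $\lambda$ to the cutpoints removes the arbitrary constant. This mirrors, for a single draw, the per-draw shift applied in the cell that follows.

```python
import numpy as np


def class_probs(lam, gamma):
    # Cloglog continuation-ratio class probabilities, shape (n, K) for K - 1 cutpoints
    cont = np.exp(-np.exp(gamma[None, :] + lam[:, None]))  # P(Y > k | Y >= k, x)
    at_least = np.hstack([np.ones((lam.size, 1)), np.cumprod(cont, axis=1)])
    stop = np.hstack([1.0 - cont, np.ones((lam.size, 1))])
    return at_least * stop


rng = np.random.default_rng(1)
lam = rng.standard_normal(100)  # a hypothetical lambda(x) evaluated at 100 points
gamma = np.array([-2.0, 1.0])

# Adding a constant c to lambda and subtracting it from every cutpoint changes nothing
c = 3.7
same = np.allclose(class_probs(lam, gamma), class_probs(lam + c, gamma - c))
print(same)  # True

# Re-centering a draw: add mean(lambda) to the cutpoints and subtract it from lambda
lam_draw, gamma_draw = lam + c, gamma - c
shift = lam_draw.mean()
gamma_centered = gamma_draw + shift
lam_centered = lam_draw - shift
print(np.allclose(gamma_centered, gamma + lam.mean()))  # True: c has been removed
```

The re-centered cutpoints equal $\gamma_k + \bar{\lambda}$, where $\bar{\lambda}$ is the average of the underlying $\lambda(x)$ over the points used for centering. They are comparable to `gamma_true` in this notebook because the simulated $\lambda(x) = x^\top \beta$ has a training-set mean close to zero.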

In [None]:
# Per-draw mean of the training-set predictions (one value per posterior sample)
train_mean_shift = bart_model.y_hat_train.mean(axis=0)
gamma1 = bart_model.cloglog_cutpoint_samples[0, :] + train_mean_shift
gamma2 = bart_model.cloglog_cutpoint_samples[1, :] + train_mean_shift

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

ax1.hist(gamma1, density=True, bins=40)
ax1.axvline(gamma_true[0], color="blue", linestyle="dotted", linewidth=2)
ax1.set_title("Posterior Distribution of Cutpoint 1")
ax1.set_xlabel("Cutpoint 1")
ax1.set_ylabel("Density")

ax2.hist(gamma2, density=True, bins=40)
ax2.axvline(gamma_true[1], color="blue", linestyle="dotted", linewidth=2)
ax2.set_title("Posterior Distribution of Cutpoint 2")
ax2.set_xlabel("Cutpoint 2")
ax2.set_ylabel("Density")

plt.tight_layout()
plt.show()

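Similarly, we can compare the true latent "utility function" $\lambda(x)$ to the (mean-shifted) BART forest predictions. The shift does not affect the reported correlations, but it centers the predictions so that they can be compared against the 45-degree line.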

In [None]:
# Posterior mean of lambda(x), centered by the training-set mean used for the cutpoints
lambda_pred_train = bart_model.y_hat_train.mean(axis=1) - bart_model.y_hat_train.mean()
lambda_pred_test = bart_model.y_hat_test.mean(axis=1) - bart_model.y_hat_train.mean()
corr_train = np.corrcoef(true_lambda[train_inds], lambda_pred_train)[0, 1]
corr_test = np.corrcoef(true_lambda[test_inds], lambda_pred_test)[0, 1]

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

ax1.scatter(lambda_pred_train, true_lambda[train_inds], alpha=0.3, s=10)
ax1.axline((0, 0), slope=1, color="blue", linewidth=2)
ax1.set_title("Train Set: Predicted vs Actual")
ax1.set_xlabel("Predicted")
ax1.set_ylabel("Actual")
ax1.text(0.05, 0.95, f"Correlation: {corr_train:.3f}", transform=ax1.transAxes,
         color="red", verticalalignment="top")

ax2.scatter(lambda_pred_test, true_lambda[test_inds], alpha=0.3, s=10)
ax2.axline((0, 0), slope=1, color="blue", linewidth=2)
ax2.set_title("Test Set: Predicted vs Actual")
ax2.set_xlabel("Predicted")
ax2.set_ylabel("Actual")
ax2.text(0.05, 0.95, f"Correlation: {corr_test:.3f}", transform=ax2.transAxes,
         color="red", verticalalignment="top")

plt.tight_layout()
plt.show()

Finally, we compare the estimated posterior mean class probabilities with the true simulated values for each class on the training set.

In [None]:
fig, axes = plt.subplots(1, n_categories, figsize=(15, 5))
for j in range(n_categories):
    corr = np.corrcoef(true_probs[train_inds, j], est_probs_train[:, j])[0, 1]
    axes[j].scatter(true_probs[train_inds, j], est_probs_train[:, j], alpha=0.3, s=10)
    axes[j].axline((0, 0), slope=1, color="blue", linewidth=2)
    axes[j].set_title(f"Training Set: True vs Estimated Probability, Class {j + 1}")
    axes[j].set_xlabel("True Class Probability")
    axes[j].set_ylabel("Estimated Class Probability")
    axes[j].text(0.05, 0.95, f"Correlation: {corr:.3f}", transform=axes[j].transAxes,
                 color="red", verticalalignment="top")
plt.tight_layout()
plt.show()

And we run the same comparison on the test set.

In [None]:
fig, axes = plt.subplots(1, n_categories, figsize=(15, 5))
for j in range(n_categories):
    corr = np.corrcoef(true_probs[test_inds, j], est_probs_test[:, j])[0, 1]
    axes[j].scatter(true_probs[test_inds, j], est_probs_test[:, j], alpha=0.3, s=10)
    axes[j].axline((0, 0), slope=1, color="blue", linewidth=2)
    axes[j].set_title(f"Test Set: True vs Estimated Probability, Class {j + 1}")
    axes[j].set_xlabel("True Class Probability")
    axes[j].set_ylabel("Estimated Class Probability")
    axes[j].text(0.05, 0.95, f"Correlation: {corr:.3f}", transform=axes[j].transAxes,
                 color="red", verticalalignment="top")
plt.tight_layout()
plt.show()

## References

Alam, Entejar, and Antonio R Linero. 2025. “A Unified Bayesian Nonparametric Framework for Ordinal, Survival, and Density Regression Using the Complementary Log-Log Link.” *arXiv Preprint arXiv:2502.00606*.