# Running Multiple Chains (Sequentially or in Parallel) in StochTree

Mixing of an MCMC sampler is a perennial concern for complex Bayesian models, and BART and BCF are no exception. One common way to address such concerns is to run multiple independent "chains" of an MCMC sampler, so that if each chain gets stuck in a different region of the posterior, their combined samples attain better coverage of the full posterior.
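To see why pooling chains can help, here is a minimal, self-contained sketch. It uses a toy one-dimensional bimodal target and a plain random-walk Metropolis sampler in place of BART; both are illustrative assumptions, not stochtree internals. With a small step size, each chain stays in the mode where it starts, so neither chain alone describes the posterior, while the pooled draws cover both modes.

```python
import numpy as np

def log_target(x):
    # Toy bimodal target (illustrative assumption): equal mixture of
    # N(-4, 0.5^2) and N(4, 0.5^2), up to an additive constant
    return np.logaddexp(-0.5 * ((x + 4.0) / 0.5) ** 2,
                        -0.5 * ((x - 4.0) / 0.5) ** 2)

def random_walk_metropolis(x0, n_iter, step, rng):
    # Basic random-walk Metropolis with Gaussian proposals
    draws = np.empty(n_iter)
    x, lp = x0, log_target(x0)
    for i in range(n_iter):
        prop = x + step * rng.normal()
        lp_prop = log_target(prop)
        if np.log(rng.uniform()) < lp_prop - lp:
            x, lp = prop, lp_prop
        draws[i] = x
    return draws

rng = np.random.default_rng(2024)
starts = [-4.0, 4.0]  # one chain started near each mode
chains = np.stack([random_walk_metropolis(s, 2000, 0.5, rng) for s in starts])

print(chains.mean(axis=1))  # each chain's mean sits near its own mode
print(chains.mean())        # the pooled mean is close to the target mean of 0
```

Pooling improves coverage but does not fix mode weights: each mode ends up weighted by the number of chains that found it rather than by its posterior mass, which is one reason per-chain diagnostics remain important.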

This idea works with the classic "root-initialized" MCMC sampler of Chipman et al. (2010), but a key insight of He and Hahn (2023) and Krantsevich et al. (2023) is that the grow-from-root (GFR) algorithm can be used to warm-start multiple chains of the BART / BCF MCMC sampler.

Operationally, both approaches share the same implementation (set `num_gfr > 0` if warm-start initialization is desired), so this vignette demonstrates how to run a multi-chain sampler sequentially.

To begin, load `stochtree` and other relevant libraries.

```python
import numpy as np
import matplotlib.pyplot as plt
import arviz as az
from sklearn.model_selection import train_test_split
from stochtree import BARTModel, BCFModel
```

# Demo 1: Supervised Learning

## Data Simulation

Simulate a simple partitioned linear model.

```python
# Generate the data
random_seed = 1111
rng = np.random.default_rng(random_seed)
n = 500
p_x = 10
p_w = 1
snr = 3
X = rng.uniform(size=(n, p_x))
leaf_basis = rng.uniform(size=(n, p_w))
f_XW = (((0 <= X[:, 0]) & (0.25 > X[:, 0])) * (-7.5 * leaf_basis[:, 0]) +
        ((0.25 <= X[:, 0]) & (0.5 > X[:, 0])) * (-2.5 * leaf_basis[:, 0]) +
        ((0.5 <= X[:, 0]) & (0.75 > X[:, 0])) * (2.5 * leaf_basis[:, 0]) +
        ((0.75 <= X[:, 0]) & (1 > X[:, 0])) * (7.5 * leaf_basis[:, 0]))
noise_sd = np.std(f_XW) / snr
y = f_XW + rng.normal(0, noise_sd, size=n)

# Split data into test and train sets
test_set_pct = 0.2
train_inds, test_inds = train_test_split(np.arange(n), test_size=test_set_pct, random_state=random_seed)
n_train = len(train_inds)
n_test = len(test_inds)
X_train = X[train_inds]
X_test = X[test_inds]
leaf_basis_train = leaf_basis[train_inds]
leaf_basis_test = leaf_basis[test_inds]
y_train = y[train_inds]
y_test = y[test_inds]
```
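The simulated mean function is linear in the basis `leaf_basis` with a slope that depends only on which quartile the first covariate `X[:, 0]` falls in: `f(X, W) = beta(X[:, 0]) * W`, with `beta` equal to -7.5, -2.5, 2.5, or 7.5 on the four quartiles. The sketch below regenerates the same draws and checks that a lookup-table form reproduces `f_XW`; the names `slopes`, `quartile`, and `f_compact` are illustrative, not part of the vignette.

```python
import numpy as np

rng = np.random.default_rng(1111)
n = 500
X = rng.uniform(size=(n, 10))
leaf_basis = rng.uniform(size=(n, 1))

# Indicator-based definition from the simulation cell
f_XW = (((0 <= X[:, 0]) & (0.25 > X[:, 0])) * (-7.5 * leaf_basis[:, 0]) +
        ((0.25 <= X[:, 0]) & (0.5 > X[:, 0])) * (-2.5 * leaf_basis[:, 0]) +
        ((0.5 <= X[:, 0]) & (0.75 > X[:, 0])) * (2.5 * leaf_basis[:, 0]) +
        ((0.75 <= X[:, 0]) & (1 > X[:, 0])) * (7.5 * leaf_basis[:, 0]))

# Equivalent form: look up the slope by the quartile of X[:, 0]
slopes = np.array([-7.5, -2.5, 2.5, 7.5])
quartile = np.floor(4 * X[:, 0]).astype(int)  # 0..3 because X[:, 0] lies in [0, 1)
f_compact = slopes[quartile] * leaf_basis[:, 0]

print(np.allclose(f_XW, f_compact))
```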

## Sampling Multiple Chains Sequentially

The simplest way to sample multiple chains of a stochtree model is to do so "sequentially": after chain 1 is sampled, chain 2 is sampled from a different starting state, and so on for each of the requested chains. This is supported internally by the `sample()` method of both `BARTModel` and `BCFModel`, via the `num_chains` entry of the `general_params` dictionary.

Define some high-level parameters, including the number of chains to run and the number of samples per chain. Here we run 4 independent chains, each retaining 5000 MCMC iterations. Each chain is initialized from a different "grow-from-root" (GFR) sample (the last 4 of 5 GFR samples) and burned in for 2000 iterations after the warm start.
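The settings above imply some simple bookkeeping, sketched here in plain Python. It assumes that GFR and burn-in draws are discarded, which is consistent with the `num_chains * num_mcmc` retained samples described after the sampler runs, and that each chain needs its own GFR draw, so `num_chains` cannot exceed `num_gfr`. The list of zero-based GFR indices is only the set of starting points; which chain receives which draw is not specified here.

```python
num_chains = 4
num_gfr = 5
num_burnin = 2000
num_mcmc = 5000

# Assumption: every chain needs a distinct GFR draw as its starting point
assert num_chains <= num_gfr, "need at least one GFR sample per chain"

# Zero-based indices of the last num_chains GFR draws (used as starting points)
init_gfr_indices = list(range(num_gfr - num_chains, num_gfr))

mcmc_iters_per_chain = num_burnin + num_mcmc     # burn-in plus retained draws
total_mcmc_iters = num_chains * mcmc_iters_per_chain
retained_draws = num_chains * num_mcmc           # burn-in draws are discarded

print(init_gfr_indices)   # [1, 2, 3, 4]
print(total_mcmc_iters)   # 28000
print(retained_draws)     # 20000
```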

```python
num_chains = 4
num_gfr = 5
num_burnin = 2000
num_mcmc = 5000
```

Run the sampler.

```python
bart_model = BARTModel()
bart_model.sample(
    X_train=X_train,
    leaf_basis_train=leaf_basis_train,
    y_train=y_train,
    num_gfr=num_gfr,
    num_burnin=num_burnin,
    num_mcmc=num_mcmc,
    general_params={"num_chains": num_chains},
)
```

Now we have a `BARTModel` object with `num_chains * num_mcmc` samples stored internally. These samples are arranged sequentially: the first `num_mcmc` samples correspond to chain 1, the next `num_mcmc` samples to chain 2, and so on.
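This chain-major layout can be made concrete with a small bookkeeping sketch that needs no fitted model. The `flat` vector below is a hypothetical stand-in for a vector of retained draws (such as `global_var_samples`); each entry encodes its own chain and iteration so the layout can be checked directly.

```python
import numpy as np

num_chains, num_mcmc = 4, 5000

# Hypothetical stand-in for retained draws: entry = chain * 10**6 + iteration
chain_ids = np.repeat(np.arange(num_chains), num_mcmc)  # 0,...,0,1,...,1,...
iter_ids = np.tile(np.arange(num_mcmc), num_chains)     # 0,...,4999,0,...,4999,...
flat = chain_ids * 10**6 + iter_ids

# Draw j of chain c sits at flat index c * num_mcmc + j
c, j = 2, 17
assert flat[c * num_mcmc + j] == c * 10**6 + j

# A row-major (C-order) reshape gives a (chain, draw) array, one row per chain
by_chain = flat.reshape(num_chains, num_mcmc)
assert np.all(by_chain[:, 0] // 10**6 == np.arange(num_chains))
print(by_chain.shape)  # (4, 5000)
```

Reshaping with `order="F"` would instead interleave the chains, so NumPy's default row-major order is the one that matches this layout.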

Since each chain is a set of samples of the same model, we can analyze the samples collectively, for example, by looking at out-of-sample predictions.

```python
y_hat_test = bart_model.predict(
    covariates=X_test,
    basis=leaf_basis_test,
    type="mean",
    terms="y_hat"
)
plt.scatter(y_hat_test, y_test)
plt.xlabel("Estimated conditional mean")
plt.ylabel("Actual outcome")
plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3, 3)))
```
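Assuming that `type = "mean"` averages each test observation's predictions over all `num_chains * num_mcmc` retained draws, the pooled estimate equals the average of the per-chain posterior means, because every chain contributes the same number of draws. The sketch below checks that identity on synthetic draws laid out with one row per test observation and one column per draw in chain-major order; that layout is an assumption made for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
num_chains, num_mcmc, n_test = 4, 5000, 3

# Synthetic stand-in for posterior draws of the conditional mean:
# one row per test observation, chain-major draws along the columns
draws = rng.normal(size=(n_test, num_chains * num_mcmc))

# Pooled posterior mean over all retained draws
pooled_mean = draws.mean(axis=1)

# Split the columns back into chains: shape (n_test, num_chains, num_mcmc)
per_chain_means = draws.reshape(n_test, num_chains, num_mcmc).mean(axis=2)

# With equal-length chains, the pooled mean is the average of the chain means
print(np.allclose(pooled_mean, per_chain_means.mean(axis=1)))
```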

Now, suppose we want to analyze each of the chains separately to assess mixing / convergence.

We can use our knowledge of the internal arrangement of the chain samples to reshape them into a `(num_chains, num_mcmc)` array, which ArviZ interprets as `(chain, draw)`, and then run various diagnostics.

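```python
# Retained global error variance draws, stored chain by chain
sigma2_samples = bart_model.global_var_samples
# Row c holds the num_mcmc draws of chain c, matching ArviZ's (chain, draw) convention
sigma2_samples_by_chain = {"sigma2": np.reshape(sigma2_samples, (num_chains, num_mcmc))}
az.plot_trace(sigma2_samples_by_chain)
```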

```python
az.ess(sigma2_samples_by_chain)
```
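Effective sample size shrinks as within-chain autocorrelation grows. As a rough, self-contained illustration, the sketch below estimates a single chain's ESS as `n / tau`, where `tau = 1 + 2 * sum(rho_t)` is the integrated autocorrelation time truncated with Geyer's initial positive sequence. ArviZ's `az.ess` defaults to a rank-normalized, multi-chain "bulk" ESS, so this simpler estimator will not reproduce its numbers exactly. The AR(1) test chain is a hypothetical stand-in for sampler output, chosen because its ESS is known in closed form.

```python
import numpy as np

def autocorr(x):
    # Sample autocorrelation at lags 0..n-1 via zero-padded FFT (divides by n)
    n = len(x)
    x = x - x.mean()
    f = np.fft.rfft(x, n=2 * n)
    acov = np.fft.irfft(f * np.conj(f))[:n] / n
    return acov / acov[0]

def ess_single_chain(x):
    # Geyer's initial positive sequence: add pair sums rho_{2k} + rho_{2k+1}
    # while they stay positive; tau = -1 + 2 * sum of those pair sums
    rho = autocorr(x)
    n = len(x)
    tau = -1.0
    for k in range(0, n - 1, 2):
        pair = rho[k] + rho[k + 1]
        if pair <= 0:
            break
        tau += 2.0 * pair
    return n / tau

# Hypothetical test chain: AR(1) with coefficient phi, started at stationarity
rng = np.random.default_rng(42)
phi, n = 0.9, 100_000
x = np.empty(n)
x[0] = rng.normal() / np.sqrt(1 - phi**2)
eps = rng.normal(size=n)
for t in range(1, n):
    x[t] = phi * x[t - 1] + eps[t]

ess_est = ess_single_chain(x)
theory = n * (1 - phi) / (1 + phi)  # exact ESS for an AR(1) chain
print(ess_est, theory)
```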

```python
az.rhat(sigma2_samples_by_chain)
```
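R-hat compares between-chain and within-chain variability; values near 1 suggest the chains agree. The sketch below implements the classic split R-hat (each chain is split in half, then the Gelman-Rubin ratio is computed) on a `(chain, draw)` array like `sigma2_samples_by_chain["sigma2"]`. ArviZ's `az.rhat` defaults to a rank-normalized variant, so its values will differ somewhat. The `mixed` and `stuck` arrays are synthetic examples, not stochtree output.

```python
import numpy as np

def split_rhat(draws):
    # draws: array of shape (num_chains, num_draws)
    m, n = draws.shape
    half = n // 2
    # Split each chain into two halves (dropping the middle draw if n is odd)
    halves = np.concatenate([draws[:, :half], draws[:, n - half:]], axis=0)
    n2 = halves.shape[1]
    chain_means = halves.mean(axis=1)
    W = halves.var(axis=1, ddof=1).mean()   # mean within-chain variance
    B = n2 * chain_means.var(ddof=1)        # between-chain variance
    var_hat = (n2 - 1) / n2 * W + B / n2    # pooled variance estimate
    return np.sqrt(var_hat / W)

rng = np.random.default_rng(7)
mixed = rng.normal(size=(4, 5000))   # four chains sampling the same distribution
stuck = mixed.copy()
stuck[3] += 3.0                      # one chain stuck in a shifted region

print(split_rhat(mixed))   # close to 1
print(split_rhat(stuck))   # well above 1
```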

```python
az.plot_autocorr(sigma2_samples_by_chain)
```

```python
az.plot_violin(sigma2_samples_by_chain)
```