# Summary and Plotting Utilities

This notebook demonstrates the summary and plotting utilities available for `stochtree` models in Python.

We begin by loading all necessary libraries.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from stochtree import (
    BARTModel, 
    BCFModel,
    plot_parameter_trace
)

We also set a random seed for reproducibility.

In [None]:
random_seed = 1234
rng = np.random.default_rng(random_seed)
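
As a quick check of what seeding provides, the sketch below uses plain NumPy, independent of `stochtree`, to show that two generators built from the same seed produce identical draws. The names `rng_a` and `rng_b` are illustrative only.

```python
import numpy as np

random_seed = 1234

# Two separate generators created from the same seed
rng_a = np.random.default_rng(random_seed)
rng_b = np.random.default_rng(random_seed)

# Identical seeds yield identical streams of draws
draws_a = rng_a.uniform(0, 1, 5)
draws_b = rng_b.uniform(0, 1, 5)
same_draws = np.array_equal(draws_a, draws_b)
print(same_draws)
```

Note that `rng` governs only the NumPy draws used to simulate data in this notebook; the generator is not passed to the samplers called later.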

## Supervised Learning

We begin with the supervised learning use case served by the `BARTModel` class.

Below we simulate a simple regression dataset.

In [None]:
# Generate covariates and basis
n = 1000
p_X = 10
p_W = 1
X = rng.uniform(0, 1, (n, p_X))
W = rng.uniform(0, 1, (n, p_W))

# Define the outcome mean function
def outcome_mean(X, W):
    return np.where(
        (X[:, 0] >= 0.0) & (X[:, 0] < 0.25),
        -7.5 * W[:, 0],
        np.where(
            (X[:, 0] >= 0.25) & (X[:, 0] < 0.5),
            -2.5 * W[:, 0],
            np.where(
                (X[:, 0] >= 0.5) & (X[:, 0] < 0.75),
                2.5 * W[:, 0],
                7.5 * W[:, 0],
            ),
        ),
    )


# Generate outcome
epsilon = rng.normal(0, 1, n)
y = outcome_mean(X, W) + epsilon
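
The mean function above is a piecewise-linear function of `W[:, 0]` whose slope depends on which quarter of the unit interval `X[:, 0]` falls in. As a hedged sketch, the same function can be written as a slope lookup with `np.digitize`; the helper name `outcome_mean_lookup` is hypothetical and not part of the notebook, and the two versions agree only when `X[:, 0]` lies in `[0, 1)`, as it does here.

```python
import numpy as np

rng = np.random.default_rng(1234)
n = 1000
X = rng.uniform(0, 1, (n, 10))
W = rng.uniform(0, 1, (n, 1))

# Nested np.where version, as in the cell above
def outcome_mean(X, W):
    return np.where(
        (X[:, 0] >= 0.0) & (X[:, 0] < 0.25),
        -7.5 * W[:, 0],
        np.where(
            (X[:, 0] >= 0.25) & (X[:, 0] < 0.5),
            -2.5 * W[:, 0],
            np.where(
                (X[:, 0] >= 0.5) & (X[:, 0] < 0.75),
                2.5 * W[:, 0],
                7.5 * W[:, 0],
            ),
        ),
    )

# Hypothetical equivalent: find the interval index, then look up the slope
def outcome_mean_lookup(X, W):
    slopes = np.array([-7.5, -2.5, 2.5, 7.5])
    # Index 0 for x < 0.25, 1 for [0.25, 0.5), 2 for [0.5, 0.75), 3 otherwise
    interval = np.digitize(X[:, 0], [0.25, 0.5, 0.75])
    return slopes[interval] * W[:, 0]

versions_agree = np.allclose(outcome_mean(X, W), outcome_mean_lookup(X, W))
print(versions_agree)
```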

Now we fit a simple BART model to the data, using `W` as the leaf basis and running three chains (`num_chains = 3`).

In [None]:
bart_model = BARTModel()
general_params = {"num_chains": 3}
bart_model.sample(
    X_train=X,
    y_train=y,
    leaf_basis_train=W,
    num_gfr=10,
    num_mcmc=1000,
    general_params=general_params,
)

We obtain a high-level summary of the BART model by calling `print()` on it.

In [None]:
print(bart_model)

For a more detailed summary (including the information above), we use the `summary()` method of `BARTModel`.

In [None]:
print(bart_model.summary())
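
For intuition about summarizing sampled terms, the sketch below computes a posterior mean and a 95% interval from an array of draws using plain NumPy. The draws are synthetic stand-ins (the array `sigma2_draws` is hypothetical, not `stochtree` output), and this is not a claim about what `summary()` reports.

```python
import numpy as np

rng = np.random.default_rng(1234)

# Hypothetical posterior draws of a positive scalar parameter (NOT stochtree output)
sigma2_draws = rng.gamma(shape=50.0, scale=0.02, size=3000)

# Posterior mean and equal-tailed 95% interval from the draws
posterior_mean = sigma2_draws.mean()
lower, upper = np.quantile(sigma2_draws, [0.025, 0.975])
print(f"mean = {posterior_mean:.3f}, 95% interval = ({lower:.3f}, {upper:.3f})")
```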

Finally, we can use the `plot_parameter_trace` utility function to make quick traceplots of any parametric terms, which in this case are the global error scale $\sigma^2$ and the leaf scale $\sigma^2_{\ell}$.

In [None]:
ax = plot_parameter_trace(bart_model, term="global_error_scale")
plt.show()

In [None]:
ax = plot_parameter_trace(bart_model, term="leaf_scale")
plt.show()
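
Traceplots from several chains are usually read by eye: chains that overlap and wander around the same level suggest good mixing. A numerical companion is the classic (non-split) Gelman-Rubin statistic, sketched below in plain NumPy on synthetic draws. The function `gelman_rubin` and the array `chains` are hypothetical helpers for illustration; nothing here is taken from the `stochtree` API, and the layout of one row per chain is an assumption.

```python
import numpy as np

rng = np.random.default_rng(1234)
num_chains, num_mcmc = 3, 1000

# Hypothetical draws, one row per chain (NOT stochtree output)
chains = rng.normal(loc=1.0, scale=0.1, size=(num_chains, num_mcmc))

def gelman_rubin(chains):
    """Classic Gelman-Rubin statistic for an array of shape (chains, draws)."""
    n = chains.shape[1]
    within = chains.var(axis=1, ddof=1).mean()     # mean within-chain variance
    between = n * chains.mean(axis=1).var(ddof=1)  # scaled variance of chain means
    pooled = (n - 1) / n * within + between / n    # pooled variance estimate
    return np.sqrt(pooled / within)

# Well-mixed chains give a value close to 1
rhat = gelman_rubin(chains)

# Shifting one chain mimics a chain stuck at a different level
shifted = chains.copy()
shifted[0] += 1.0
rhat_shifted = gelman_rubin(shifted)

print(f"R-hat (mixed) = {rhat:.3f}, R-hat (one chain shifted) = {rhat_shifted:.3f}")
```

Values near 1 are consistent with chains that agree with one another; values well above 1 flag disagreement that would also be visible in the traceplot.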

## Causal Inference

We now turn to the causal inference use case served by the `BCFModel` class.

Below we simulate a simple dataset for a causal inference problem with binary treatment and continuous outcome.

In [None]:
# Generate covariates, propensity scores, and treatment
n = 1000
p_X = 5
X = rng.uniform(0, 1, (n, p_X))
pi_X = 0.25 + 0.5 * X[:, 0]
Z = rng.binomial(1, pi_X, n).astype(float)

# Define the outcome mean functions (prognostic and treatment effects)
mu_X = pi_X * 5 + 2 * X[:, 2]
tau_X = X[:, 1] * 2 - 1

# Generate outcome
epsilon = rng.normal(0, 1, n)
y = mu_X + tau_X * Z + epsilon
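
In this design the propensity `pi_X` also enters the prognostic function `mu_X`, so treated units tend to have higher baseline outcomes and treatment assignment is confounded. The sketch below repeats the simulation in plain NumPy and compares the average of the true treatment effects with a naive difference in means. It uses a larger `n` than the notebook purely to reduce Monte Carlo noise; the names `true_ate` and `naive_diff` are illustrative.

```python
import numpy as np

rng = np.random.default_rng(1234)
n = 100_000  # larger than the notebook's n, only to reduce Monte Carlo noise

# Same data-generating process as the cell above
X = rng.uniform(0, 1, (n, 5))
pi_X = 0.25 + 0.5 * X[:, 0]
Z = rng.binomial(1, pi_X, n).astype(float)
mu_X = pi_X * 5 + 2 * X[:, 2]
tau_X = X[:, 1] * 2 - 1
y = mu_X + tau_X * Z + rng.normal(0, 1, n)

# Average of the true unit-level treatment effects
true_ate = tau_X.mean()

# Naive comparison of treated and control outcome means
naive_diff = y[Z == 1].mean() - y[Z == 0].mean()

print(f"true ATE = {true_ate:.3f}, naive difference in means = {naive_diff:.3f}")
```

The naive difference overstates the average effect because it also picks up the gap in `mu_X` between treated and control units.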

Now we fit a simple BCF model to the data, supplying the true propensity scores `pi_X` as `propensity_train`.

In [None]:
bcf_model = BCFModel()
general_params = {"num_chains": 3}
bcf_model.sample(
    X_train=X,
    Z_train=Z,
    y_train=y,
    propensity_train=pi_X,
    num_gfr=10,
    num_mcmc=1000,
    general_params=general_params,
)

As above, we can `print()` this model for a quick overview.

In [None]:
print(bcf_model)

We can also use the `summary()` method for a more detailed look at the sampled model terms.

In [None]:
print(bcf_model.summary())

Finally, we can also plot parametric terms with `plot_parameter_trace()`, here the global error scale and the adaptive coding parameters.

In [None]:
ax = plot_parameter_trace(bcf_model, term="global_error_scale")
plt.show()

In [None]:
ax = plot_parameter_trace(bcf_model, term="adaptive_coding")
plt.show()