In [None]:
# NOTE(review): consider pinning a version (e.g. `%pip install bartz==<version>`)
# so that re-running this notebook keeps reproducing the same results.
%pip install bartz



The next cell tells JAX to preallocate 95% of the GPU memory (instead of the default 75%). It must run before JAX is imported:

In [None]:
import os

# Fraction of GPU memory that JAX preallocates. XLA reads this variable when
# JAX initializes its GPU backend, so it is set here, before `import jax`.
os.environ.update(XLA_PYTHON_CLIENT_MEM_FRACTION='.95')

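The next cell runs bartz on simulated data. The data-generating process is $y_i = \sum_j X_{ij}\beta_j + \sum_{j,k} A_{jk}X_{ij}X_{ik} + \varepsilon_i$, where $X$, $\beta$ and $\varepsilon$ have i.i.d. standard normal entries, and $A$ is a sparse banded matrix with Gaussian nonzero entries, so that each predictor interacts only with a few neighboring ones.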

In [None]:
import functools
import time

import jax
from jax import numpy as jnp
from jax import random
import bartz

# problem size
n_train = 100_000  # observations used to fit the model
n_test = 1_000     # held-out observations used to compute the test RMSE
p = 1_000          # number of predictors (rows of X)
n_tree = 10_000    # number of trees passed to bartz as `ntree`

@functools.partial(jax.jit, static_argnums=(1, 2))
def simulate_data(key, n, p, max_interactions):
    """Simulate a regression dataset with linear effects and sparse pairwise interactions.

    The response is ``y = beta @ X + sum_{a,b} A[a, b] * X[a] * X[b] + error``,
    where ``X``, ``beta`` and ``error`` have i.i.d. standard normal entries and
    ``A`` is a standard normal matrix masked to a (wrap-around) band and rescaled.

    Parameters
    ----------
    key : jax random key
        Source of all the randomness.
    n : int
        Number of observations (static under ``jit``).
    p : int
        Number of predictors (static under ``jit``).
    max_interactions : int
        Each row ``a`` of ``A`` keeps the ``1 + (max_interactions - 1) // 2``
        consecutive entries ``a, a + 1, ...`` (mod ``p``), clipped to ``[0, p]``.
        Since the quadratic form only depends on ``A + A.T``, each predictor is
        paired with at most ``max_interactions`` predictors (itself included).
        Values ``<= 0`` give a purely linear model.

    Returns
    -------
    X : array of shape (p, n)
        Predictors, one column per observation.
    y : array of shape (n,)
        Response.
    """

    # split random key
    keys = list(random.split(key, 4))

    # generate matrices
    X = random.normal(keys.pop(), (p, n))
    beta = random.normal(keys.pop(), (p,))
    A = random.normal(keys.pop(), (p, p))
    error = random.normal(keys.pop(), (n,))

    # make A banded to limit the number of interactions: row i of the mask is
    # the pattern [1]*num_nonzero + [0]*(p - num_nonzero) rolled by i
    num_nonzero = 1 + (max_interactions - 1) // 2
    num_nonzero = jnp.clip(num_nonzero, 0, p)
    interaction_pattern = jnp.arange(p) < num_nonzero
    multi_roll = jax.vmap(jnp.roll, in_axes=(None, 0))
    nonzero = multi_roll(interaction_pattern, jnp.arange(p))
    A *= nonzero
    # rescale so the interaction term's scale does not grow with the band
    # width; guard against num_nonzero == 0, where A is already all zeros and
    # dividing by sqrt(0) would turn it (and y) into NaN
    A /= jnp.sqrt(jnp.maximum(num_nonzero, 1))

    # compute response: linear part + quadratic form x^T A x per observation
    y = beta @ X + jnp.einsum('ai,bi,ab->i', X, X, A) + error

    return X, y

# seeds for random sampling: one key for the data, one for the MCMC
keys = list(random.split(random.key(202404161853), 2))

# generate the data on CPU
cpu = jax.devices('cpu')[0]
key = jax.device_put(keys.pop(), cpu)
X, y = simulate_data(key, n_train + n_test, p, 5)
X_train, y_train = X[:, :n_train], y[:n_train]
X_test, y_test = X[:, n_train:], y[n_train:]

# move the data to GPU (if present)
device = jax.devices()[0]
X_train, y_train, X_test, y_test = jax.device_put((X_train, y_train, X_test, y_test), device)

# fit the model and time it; JAX dispatches work asynchronously, so wait for
# the MCMC output to actually be computed before stopping the clock
start = time.perf_counter()
bart = bartz.BART.gbart(X_train, y_train, ntree=n_tree, printevery=10, seed=keys.pop())
jax.block_until_ready(bart.sigma)
end = time.perf_counter()

# average the test-set predictions over axis 0 (assumed to index posterior
# draws — TODO confirm against bartz docs) and compare with the truth
yhat_test_mean = bart.predict(X_test).mean(axis=0)
rmse = jnp.sqrt(jnp.mean(jnp.square(yhat_test_mean - y_test)))
print(f'RMSE: {rmse:#.2g}')
# the simulated error is standard normal, so the true sigma is 1
print(f'sigma: {jnp.sqrt(jnp.mean(bart.sigma ** 2)):#.2g}')
print(f'time: {(end - start) / 60:#.2g} min')

Iteration   10/1100 P_grow=0.53 P_prune=0.47 A_grow=0.32 A_prune=0.35 (burnin)
Iteration   20/1100 P_grow=0.52 P_prune=0.48 A_grow=0.32 A_prune=0.35 (burnin)
Iteration   30/1100 P_grow=0.53 P_prune=0.47 A_grow=0.32 A_prune=0.36 (burnin)
Iteration   40/1100 P_grow=0.53 P_prune=0.47 A_grow=0.31 A_prune=0.34 (burnin)
Iteration   50/1100 P_grow=0.52 P_prune=0.48 A_grow=0.31 A_prune=0.35 (burnin)
Iteration   60/1100 P_grow=0.54 P_prune=0.46 A_grow=0.31 A_prune=0.33 (burnin)
Iteration   70/1100 P_grow=0.53 P_prune=0.47 A_grow=0.31 A_prune=0.34 (burnin)
Iteration   80/1100 P_grow=0.53 P_prune=0.47 A_grow=0.31 A_prune=0.34 (burnin)
Iteration   90/1100 P_grow=0.54 P_prune=0.46 A_grow=0.31 A_prune=0.34 (burnin)
Iteration  100/1100 P_grow=0.53 P_prune=0.47 A_grow=0.29 A_prune=0.34 (burnin)
Iteration  110/1100 P_grow=0.52 P_prune=0.48 A_grow=0.31 A_prune=0.33
Iteration  120/1100 P_grow=0.53 P_prune=0.47 A_grow=0.30 A_prune=0.33
Iteration  130/1100 P_grow=0.52 P_prune=0.48 A_grow=0.31 A_prune=0.33
