In [1]:
%pip install bartz

Collecting bartz
  Downloading bartz-0.6.0-py3-none-any.whl.metadata (2.8 kB)
Collecting equinox>=0.12.2 (from bartz)
  Downloading equinox-0.12.2-py3-none-any.whl.metadata (18 kB)
Collecting jaxtyping>=0.3.2 (from bartz)
  Downloading jaxtyping-0.3.2-py3-none-any.whl.metadata (7.0 kB)
Collecting wadler-lindig>=0.1.0 (from equinox>=0.12.2->bartz)
  Downloading wadler_lindig-0.1.6-py3-none-any.whl.metadata (17 kB)
Downloading bartz-0.6.0-py3-none-any.whl (43 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.6/43.6 kB[0m [31m1.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading equinox-0.12.2-py3-none-any.whl (177 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m177.2/177.2 kB[0m [31m6.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading jaxtyping-0.3.2-py3-none-any.whl (55 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m55.4/55.4 kB[0m [31m1.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading wadler_lindig-0.1.6-py3-none-any.whl (20 k

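The next cell tells JAX to preallocate 95% of the GPU memory instead of its default fraction. It has to run before JAX starts using the GPU, which is why it comes before the cells that import JAX: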

In [2]:
import os
# Set before JAX is imported further down, so the GPU backend picks it up
# when it initializes: fraction of GPU memory that JAX preallocates.
os.environ.update(XLA_PYTHON_CLIENT_MEM_FRACTION='.95')

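The next cell defines the data generating process (DGP)
$$y_i = \frac{1}{\text{norm.}} \sum_{j=1}^p X_{ij} \beta_j + \frac{1}{\text{norm.}} \sum_{j=1}^p \sum_{k=1}^p A_{jk} X_{ij} X_{ik} + \varepsilon_i,$$
where $X_{ij} \sim U(0, 1)$, $\beta_j, A_{jk} \sim N(0, 1)$, $\varepsilon_i \sim N(0, \sigma^2)$, and each "norm." is the standard deviation of the corresponding term over the generated sample. The matrix $A$ is sparse: each row keeps only a few consecutive nonzero entries, wrapping around cyclically.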

In [3]:
import functools

import jax
from jax import numpy as jnp
from jax import random

@functools.partial(jax.jit, static_argnums=(1, 2))
def dgp(key, n, p, max_interactions, error_sdev):
    """Generate predictors and responses from a sparse quadratic model.

    The response is

        y = linear / std(linear) + quadratic / std(quadratic) + error

    with ``linear[i] = sum_a beta[a] X[a, i]`` and
    ``quadratic[i] = sum_{a, b} A[a, b] X[a, i] X[b, i]``. The standard
    deviations are computed on the generated sample itself, so train and
    test points must be generated in a single call and split afterwards.

    Parameters
    ----------
    key : jax random key
        Source of all randomness; split into four independent keys.
    n : int
        Number of observations (static under ``jit``).
    p : int
        Number of predictors (static under ``jit``).
    max_interactions : int
        Sparsity of ``A``: every row keeps ``1 + (max_interactions - 1) // 2``
        consecutive entries (clipped to ``[0, p]``, wrapping around
        cyclically), so each predictor appears in products with at most
        ``max_interactions`` distinct predictors, itself included. Values
        ``<= 0`` remove the quadratic term entirely.
    error_sdev : float
        Standard deviation of the Gaussian noise added to ``y``.

    Returns
    -------
    X : array of shape (p, n)
        Predictors, i.i.d. uniform on [0, 1).
    y : array of shape (n,)
        Responses.
    """

    # split random key
    keys = list(random.split(key, 4))

    # generate matrices
    X = random.uniform(keys.pop(), (p, n))
    beta = random.normal(keys.pop(), (p,))
    A = random.normal(keys.pop(), (p, p))
    error = random.normal(keys.pop(), (n,))

    # Keep a cyclic band of A: row j is nonzero only in columns
    # j, j+1, ..., j+num_nonzero-1 (mod p), which limits the interactions.
    num_nonzero = 1 + (max_interactions - 1) // 2
    num_nonzero = jnp.clip(num_nonzero, 0, p)
    interaction_pattern = jnp.arange(p) < num_nonzero
    multi_roll = jax.vmap(jnp.roll, in_axes=(None, 0))
    nonzero = multi_roll(interaction_pattern, jnp.arange(p))
    A *= nonzero

    # compute terms
    linear = beta @ X
    quadratic = jnp.einsum('ai,bi,ab->i', X, X, A)
    error *= error_sdev

    def rescale(term):
        # Divide by the sample standard deviation. A constant term (e.g. the
        # quadratic one when max_interactions <= 0, or any term when n == 1)
        # has zero spread; it is returned unscaled instead of becoming
        # 0/0 = NaN, which would make every y NaN.
        sdev = jnp.std(term)
        return term / jnp.where(sdev > 0, sdev, 1)

    # equalize the terms
    linear = rescale(linear)
    quadratic = rescale(quadratic)

    # compute response
    y = linear + quadratic + error

    return X, y

The next cell defines a convenience function that generates the data and splits it into train and test sets.

In [4]:
import collections

# Train/test split. The X arrays are laid out predictors x observations,
# as returned by `dgp`; the y arrays are 1-D.
Data = collections.namedtuple('Data', 'X_train y_train X_test y_test')

def make_synthetic_dataset(key, n_train, n_test, p, sigma, max_interactions=5):
    """Draw train and test data from `dgp` in one call, then split them.

    Both sets must come from the same call because `dgp` standardizes its
    terms with statistics of the sample it generates.

    Parameters
    ----------
    key : jax random key
        Seed passed to `dgp`.
    n_train, n_test : int
        Number of training and test observations.
    p : int
        Number of predictors.
    sigma : float
        Error standard deviation passed to `dgp`.
    max_interactions : int, default 5
        Interaction-sparsity parameter passed to `dgp`.

    Returns
    -------
    Data
        ``X_train`` (p, n_train), ``y_train`` (n_train,),
        ``X_test`` (p, n_test), ``y_test`` (n_test,).
    """
    X, y = dgp(key, n_train + n_test, p, max_interactions, sigma)
    return Data(
        X_train=X[:, :n_train],
        y_train=y[:n_train],
        X_test=X[:, n_train:],
        y_test=y[n_train:],
    )

The next cell generates the data and runs BART.

In [5]:
import time

import bartz

n_train = 100_000  # number of training points
p = 1000           # number of predictors/features
sigma = 0.1        # error standard deviation

n_test = 1000      # number of test points
n_tree = 10_000    # number of trees used by bartz

# seeds for random sampling: one key for the data, one for the MCMC
keys = list(random.split(random.key(202404161853), 2))

# generate the data on CPU to avoid running out of GPU memory
cpu = jax.devices('cpu')[0]
key = jax.device_put(keys.pop(), cpu) # the random key is the only jax-array input, so it determines the device used
data = make_synthetic_dataset(key, n_train, n_test, p, sigma)

# move the data to GPU (if there is a GPU)
device = jax.devices()[0] # the default jax device is gpu if there is one
data = jax.device_put(data, device)

# run bartz
start = time.perf_counter()
bart = bartz.BART.gbart(data.X_train, data.y_train, ntree=n_tree, printevery=10, seed=keys.pop())
# JAX dispatches device work asynchronously: wait until `bart.sigma` (used in
# the next cell) is materialized so the measured time includes any device
# work still in flight. A no-op if the results are already concrete.
jax.block_until_ready(bart.sigma)
end = time.perf_counter()

..........
It 10/1100 grow P=54% A=27%, prune P=46% A=30%, fill=6% (burnin)
..........
It 20/1100 grow P=54% A=23%, prune P=46% A=28%, fill=6% (burnin)
..........
It 30/1100 grow P=55% A=23%, prune P=45% A=26%, fill=6% (burnin)
..........
It 40/1100 grow P=53% A=21%, prune P=47% A=25%, fill=6% (burnin)
..........
It 50/1100 grow P=54% A=21%, prune P=46% A=25%, fill=6% (burnin)
..........
It 60/1100 grow P=53% A=20%, prune P=47% A=24%, fill=6% (burnin)
..........
It 70/1100 grow P=53% A=21%, prune P=47% A=24%, fill=6% (burnin)
..........
It 80/1100 grow P=53% A=20%, prune P=47% A=24%, fill=6% (burnin)
..........
It 90/1100 grow P=53% A=19%, prune P=47% A=23%, fill=6% (burnin)
..........
It 100/1100 grow P=53% A=19%, prune P=47% A=23%, fill=6% (burnin)
..........
It 110/1100 grow P=54% A=20%, prune P=46% A=23%, fill=6%
..........
It 120/1100 grow P=54% A=20%, prune P=46% A=22%, fill=6%
..........
It 130/1100 grow P=53% A=18%, prune P=47% A=21%, fill=6%
..........
It 140/1100 grow P=54% A

Interpretation of the printout:
* grow P = fraction of trees where a GROW move was proposed
* grow A = GROW acceptance: fraction of proposed GROW moves that were accepted
* prune P, A = the same for the PRUNE move

The fractions refer to the state of the trees at a single point in time; they are not averaged over multiple iterations.

A low acceptance means that the trees are changing very slowly.

The next cell computes the predictions.

In [6]:
# Posterior draws of the regression function at the test points, shaped
# (n_samples, n_test); reduce over the draws axis for pointwise summaries.
yhat_test = bart.predict(data.X_test)
yhat_test_mean, yhat_test_var = (
    reduce_fn(yhat_test, axis=0) for reduce_fn in (jnp.mean, jnp.var)
)

# observed test error of the posterior mean
rmse = jnp.sqrt(jnp.mean(jnp.square(yhat_test_mean - data.y_test)))

# error the model itself expects: posterior variance of the prediction plus
# the mean of the squared `bart.sigma` draws, averaged over test points
expected_error_variance = jnp.mean(jnp.square(bart.sigma))
avg_sigma = jnp.sqrt(expected_error_variance)
expected_rmse = jnp.sqrt(jnp.mean(yhat_test_var + expected_error_variance))

print(
    f'total sdev: {jnp.std(data.y_train):#.2g}',
    f'error sdev: {sigma:#.2g}',
    f'RMSE: {rmse:#.2g}',
    f'expected RMSE: {expected_rmse:#.2g}',
    f'model error sdev: {avg_sigma:#.2g}',
    f'time: {(end - start) / 60:#.2g} min',
    sep='\n',
)

total sdev: 1.4
error sdev: 0.10
RMSE: 0.29
expected RMSE: 0.28
model error sdev: 0.23
time: 6.5 min


The test RMSE cannot be expected to go below the error standard deviation used to generate the data (`sigma` = 0.1), because the test responses themselves contain that noise.