Description
Both BARTModel and BCFModel in Python allow random effects models to be specified without an explicit user-supplied basis, via the intercept_only and intercept_plus_treatment options for model_spec. However, the predict method assumes that the group IDs it receives are zero-indexed, because it uses them directly as indices into an array of sampled random effects. The R implementation handles this correctly by using the LabelMapper class.
We should implement this same logic in Python.
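A minimal sketch of what that logic could look like, assuming the mapping is learned from the group IDs passed at sampling time and reused at prediction time. The class name GroupLabelMapper and its map method are hypothetical, chosen for illustration only. They are not stochtree's API, and the sketch does not reproduce the R LabelMapper's internals.

```python
import numpy as np


class GroupLabelMapper:
    """Hypothetical sketch: map arbitrary group labels to zero-based indices.

    Not part of stochtree; it only illustrates the idea of learning a
    label -> index mapping at sample time and reusing it at predict time.
    """

    def __init__(self, labels):
        # Sorted unique labels seen during sampling define the index order.
        self.labels_ = np.unique(np.asarray(labels))
        # .item() converts NumPy scalars to plain Python values for dict keys.
        self._index = {label.item(): i for i, label in enumerate(self.labels_)}

    def map(self, labels):
        # Translate each label to its zero-based index; unseen labels raise.
        labels = np.asarray(labels)
        try:
            return np.array([self._index[label.item()] for label in labels], dtype=int)
        except KeyError as err:
            raise ValueError(
                f"Group label {err.args[0]!r} was not seen during sampling"
            ) from err
```

For the IDs in the reproduction below, GroupLabelMapper([2, 3, 4]).map([4, 2, 3]) gives array([2, 0, 1]). Raising on unseen labels is one possible design choice; how predict should treat groups absent from training is a separate decision.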
Reproducing
- Generate data whose random effects group IDs are not zero-indexed (here they take the values 2, 3 and 4):

```python
import numpy as np
from stochtree import BARTModel

rng = np.random.default_rng()
n = 100
p = 10
X = rng.uniform(0, 1, (n, p))

# Three groups, labeled 2, 3 and 4 rather than 0, 1 and 2
num_groups = 3
group_ids = rng.choice(num_groups, size=n) + 2
random_intercepts = rng.uniform(0, 1, num_groups)
rfx_term = random_intercepts[group_ids - 2]
y = X[:, 0] + rfx_term + rng.normal(0, 1, n)
```
- Sample a BART model and predict with it:

```python
bart_model = BARTModel()
bart_model.sample(
    X_train=X,
    y_train=y,
    rfx_group_ids_train=group_ids,
    random_effects_params={"model_spec": "intercept_only"},
)
bart_model.predict(X=X, rfx_group_ids=group_ids)
```
Expected behavior
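The .predict() call above should run correctly, since it uses the same group IDs that were used for sampling. Because those IDs are 2, 3 and 4 rather than 0, 1 and 2, predict should map them to zero-based indices instead of using them directly.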
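The failure mode can be illustrated with NumPy alone. This is a simplified stand-in, assuming (as the description says) that predict indexes a per-group array of sampled random effects directly by group ID; the array values below are made up for illustration and are not stochtree output.

```python
import numpy as np

# Hypothetical stand-in for one draw of sampled random intercepts:
# one value per group, stored at positions 0, 1 and 2.
sampled_intercepts = np.array([0.1, 0.5, 0.9])

# Group IDs as in the reproduction above: labels 2, 3 and 4.
group_ids = np.array([2, 3, 4, 2])

# Using the raw labels as positions: 3 and 4 are out of bounds,
# and label 2 would silently pick up the third group's value.
try:
    sampled_intercepts[group_ids]
except IndexError as err:
    print("raw indexing fails:", err)

# Mapping labels to zero-based positions first makes the lookup valid.
labels, zero_based = np.unique(group_ids, return_inverse=True)
print(sampled_intercepts[zero_based])  # [0.1 0.5 0.9 0.1]
```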