From the Python interface, BARTModel.sample fails with a ValueError when the data is float32. Easy to sidestep by casting, but weird. It comes up for me because I want to replace BART3 with stochtree as the reference implementation in my unit tests of bartz, and all my arrays are float32 by default because of jax.
import numpy as np
import stochtree
rng = np.random.default_rng(0)
X = rng.normal(size=(100, 3)).astype(np.float32)
y = rng.normal(size=100).astype(np.float32)
m = stochtree.BARTModel()
m.sample(X_train=X, y_train=y)
Traceback (most recent call last):
  File "<python-input-0>", line 9, in <module>
    m.sample(X_train=X, y_train=y)
    ~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/giacomo/Library/Caches/uv/archive-v0/ewzn4Tz9IlUrLkmx/lib/python3.14/site-packages/stochtree/bart.py", line 1421, in sample
    global_model_config = GlobalModelConfig(global_error_variance=current_sigma2)
  File "/Users/giacomo/Library/Caches/uv/archive-v0/ewzn4Tz9IlUrLkmx/lib/python3.14/site-packages/stochtree/config.py", line 677, in __init__
    raise ValueError("`global_error_variance` must be a positive scalar")
ValueError: `global_error_variance` must be a positive scalar
Versions: stochtree 0.4.2, numpy 2.4.6, python 3.14, macOS 26.4.1
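For reference, a minimal sketch of the casting workaround mentioned above. The helper name as_float64 is hypothetical (not part of stochtree), and the sketch assumes float64 is the dtype the sampler expects. The stochtree call is left commented out so the snippet runs with numpy alone.

```python
import numpy as np


def as_float64(*arrays):
    """Hypothetical helper (not part of stochtree): cast each input to float64."""
    return tuple(np.asarray(a, dtype=np.float64) for a in arrays)


rng = np.random.default_rng(0)
X = rng.normal(size=(100, 3)).astype(np.float32)
y = rng.normal(size=100).astype(np.float32)

# Cast at the boundary; the original float32 arrays are left untouched.
X64, y64 = as_float64(X, y)

# Assumption: with float64 inputs the call from the report goes through.
# import stochtree
# m = stochtree.BARTModel()
# m.sample(X_train=X64, y_train=y64)
```

Casting only at the call site keeps the rest of the test suite in float32, which matters when the arrays come from jax.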