In [1]:
import pathlib
import arviz_base as azb
import numpy as np
import xarray as xr
import bambi as bmb
import pandas as pd
seed: int = 123  # one seed shared by fitting, response prediction and prior-predictive sampling

In [2]:
# Raw input table; the path is resolved relative to the notebook's working directory.
cats = pd.read_csv(filepath_or_buffer=pathlib.Path("cats.csv"))

In [3]:
# Intercept-only exponential model (log link) for censored time-to-event data.
# Time is rescaled from days to months (days / 31); `adopt` is handed to bambi's
# `censored()` as the censoring-status column.
model_1 = bmb.Model(
    formula="censored(days_to_event / 31, adopt) ~ 1",
    data=cats,
    family="exponential",
    link="log",
)
# Short name for the response; the key must match the formula's response term verbatim.
model_1.set_alias({"censored(days_to_event / 31, adopt)": "months"})

idata_1 = model_1.fit(
    chains=4,
    draws=500,
    tune=500,
    random_seed=seed,
)
# Draw samples of the response and store them in idata_1 (inplace=True) ...
model_1.predict(idata_1, kind="response", inplace=True, random_seed=seed)
# ... then merge the prior / prior-predictive samples into the same object.
idata_1.extend(model_1.prior_predictive(random_seed=seed))


Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [Intercept]


Output()

Sampling 4 chains for 500 tune and 500 draw iterations (2_000 + 2_000 draws total) took 1 seconds.
Sampling: [Intercept, months]


In [4]:
# Attach the censoring status to idata_1 and export everything to netCDF.
# Encoding: 0 where `adopt == "none"`, 1 otherwise.
# NOTE(review): under bambi's `censored()` convention "none" marks an uncensored
# (observed) row, so 1 should flag censored rows -- confirm this matches what
# readers of censored_cats.nc expect.
status = np.where(cats["adopt"] == "none", 0, 1)

# Reuse the observed response's dimension name instead of xarray's default "dim_0",
# so the status aligns element-wise with observed_data["months"] rather than broadcasting.
status_da = xr.DataArray(status, dims=idata_1.observed_data["months"].dims)

# The edits below mutate idata_1 in place; guard them so re-running this cell does not
# raise (add_groups rejects an existing group, and `del` fails once the group is gone).
if "constant_data" in idata_1.groups():
    idata_1.constant_data["months"] = status_da
else:
    idata_1.add_groups({"constant_data": {"months": status_da}})

if "prior" in idata_1.groups():
    # Keep prior_predictive, but drop the prior parameter draws from the exported file.
    del idata_1.prior

dt = azb.convert_to_datatree(idata_1)
dt.to_netcdf(pathlib.Path("..", "..", "data", "censored_cats.nc"), engine="netcdf4")

