Closed
Description
Tell us about it
When creating InferenceData using az.from_numpyro(...), the resulting autogenerated coordinates are not very telling.
Now consider the NumPyro model below that produced these samples.
These coords with autogenerated names are in fact plate dimensions. And plates have names.
import numpy as np
import pandas as pd
# INFO: PPL specific imports
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
import numpyro.optim as optim
from numpyro.infer import SVI, Trace_ELBO, Predictive
from numpyro.infer import MCMC, NUTS
from numpyro.infer.autoguide import AutoLaplaceApproximation, AutoNormal
from jax import lax
from jax.scipy.special import expit
import arviz as az
# %%
data_uri = "https://raw.githubusercontent.com/rmcelreath/rethinking/master/data/NWOGrants.csv"
df_dev = pd.read_csv(data_uri, sep=";")
df_dev.head()
df_dev["gender"] = df_dev["gender"] == "m"
df_dev["gender"] = df_dev["gender"].astype(int)
df_dev["discipline"] = df_dev["discipline"].astype("category").cat.codes
# %%
def model(data: pd.DataFrame, observed=True):
    applications = data["applications"].values
    awards = data["awards"].values
    discipline = data["discipline"].values
    discipline_card = np.unique(discipline).shape[0]
    gender = data["gender"].values
    gender_card = np.unique(gender).shape[0]
    observations_card = data.shape[0]
    # INFO: good plate version
    with numpyro.plate("plate_gender", gender_card):
        with numpyro.plate("plate_discipline", discipline_card):
            alpha_gender_discipline = numpyro.sample("alpha_gender_discipline", dist.Normal(-1, 1))
    assert alpha_gender_discipline.shape == (9, 2)
    link_p = numpyro.deterministic("link_p", alpha_gender_discipline[discipline, gender])
    with numpyro.plate("plate_observations", observations_card):
        numpyro.sample(
            "awards", dist.Binomial(total_count=applications, logits=link_p), obs=awards if observed else None
        )
kernel = NUTS(model)
mcmc = MCMC(
    kernel,
    num_warmup=1000,
    num_samples=5000,
    num_chains=1,
    progress_bar=True,
)
mcmc.run(random.PRNGKey(0), df_dev)
samples = mcmc.get_samples()
az.from_numpyro(mcmc)
Thoughts on implementation
It would be handy if from_numpyro could extract those sites from the NumPyro model.
I'll note that providing custom coords, e.g. az.from_numpyro(mcmc, coords={"gender": np.array([0, 1])}), has no effect (the coords don't get renamed or anything). For now I stick to .rename({"alpha_gender_dim_0": "gender"}).
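One possible explanation, offered as an assumption rather than a confirmed diagnosis: in ArviZ 0.x, `coords` only supplies labels for dimension names that some variable actually uses, and the variable-to-dimension mapping comes from the separate `dims` argument. The sketch below uses `az.from_dict` with random stand-in draws (not output of the model above) so that it runs without sampling. `az.from_numpyro` accepts the same `coords` and `dims` keywords in ArviZ 0.x, and the dimension names "discipline" and "gender" are chosen here for illustration.

```python
import arviz as az
import numpy as np

rng = np.random.default_rng(0)
# Stand-in posterior draws with shape (chain, draw, discipline, gender).
draws = rng.normal(size=(1, 50, 9, 2))

# coords only: no variable declares a "gender" dim, so the names stay autogenerated.
auto = az.from_dict(
    posterior={"alpha_gender_discipline": draws},
    coords={"gender": np.array([0, 1])},
)

# coords plus dims: the two trailing axes get the requested names and labels.
named = az.from_dict(
    posterior={"alpha_gender_discipline": draws},
    coords={"discipline": np.arange(9), "gender": np.array([0, 1])},
    dims={"alpha_gender_discipline": ["discipline", "gender"]},
)

auto_dims = auto.posterior["alpha_gender_discipline"].dims
named_dims = named.posterior["alpha_gender_discipline"].dims
print(auto_dims)
print(named_dims)
```

With the real sampler output, the analogous call would be `az.from_numpyro(mcmc, coords=..., dims={"alpha_gender_discipline": ["discipline", "gender"]})`, which avoids the `.rename(...)` step but still requires spelling out the plate structure by hand.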