Skip to content

InferenceData coords for NumPyro plates #2022

@ColdTeapot273K

Description

@ColdTeapot273K

Tell us about it

When creating InferenceData using

az.from_numpyro(...)

the resulting autogenerated coordinates are not very informative:
image

Now consider the NumPyro model below that produced these samples.
These coords with autogenerated names are in fact plate dimensions. And plates have names.

import numpy as np
import pandas as pd

# INFO: PPL specific imports

import jax.numpy as jnp
from jax import random

import numpyro
import numpyro.distributions as dist
import numpyro.optim as optim
from numpyro.infer import SVI, Trace_ELBO, Predictive
from numpyro.infer import MCMC, NUTS
from numpyro.infer.autoguide import AutoLaplaceApproximation, AutoNormal

from jax import lax, random
from jax.scipy.special import expit

import arviz as az

# %%

# Load the NWO grants dataset (semicolon-separated CSV) from the rethinking repo.
data_uri = "https://raw.githubusercontent.com/rmcelreath/rethinking/master/data/NWOGrants.csv"
df_dev = pd.read_csv(data_uri, sep=";")
df_dev.head()

# Encode gender as an integer indicator: 1 where the value is "m", 0 otherwise.
df_dev["gender"] = (df_dev["gender"] == "m").astype(int)
# Replace discipline labels with their integer category codes.
df_dev["discipline"] = df_dev["discipline"].astype("category").cat.codes

# %%


def model(data: pd.DataFrame, observed=True):
    """Binomial model of grant awards with one logit intercept per (discipline, gender) cell.

    Args:
        data: Frame with columns ``applications``, ``awards``, ``discipline``
            and ``gender``. ``discipline`` and ``gender`` are used directly as
            integer indices into the intercept table.
        observed: When true, condition the ``awards`` site on the ``awards``
            column; when false, leave the site unobserved so it is sampled.
    """
    applications = data["applications"].values
    awards = data["awards"].values

    discipline = data["discipline"].values
    discipline_card = np.unique(discipline).shape[0]
    gender = data["gender"].values
    gender_card = np.unique(gender).shape[0]

    observations_card = data.shape[0]

    # INFO: good plate version
    with numpyro.plate("plate_gender", gender_card):
        with numpyro.plate("plate_discipline", discipline_card):
            alpha_gender_discipline = numpyro.sample("alpha_gender_discipline", dist.Normal(-1, 1))

    # Sanity-check the plate layout against the cardinalities found in the data
    # rather than against fixed numbers, so the model also runs on other datasets.
    # The inner plate (discipline) is the leading axis, the outer plate (gender) the trailing one.
    assert alpha_gender_discipline.shape == (discipline_card, gender_card)

    # Pick the intercept of each row's (discipline, gender) cell as its logit.
    link_p = numpyro.deterministic("link_p", alpha_gender_discipline[discipline, gender])

    with numpyro.plate("plate_observations", observations_card):
        numpyro.sample(
            "awards", dist.Binomial(total_count=applications, logits=link_p), obs=awards if observed else None
        )


# Sampler setup: one NUTS chain with 1000 warmup steps and 5000 retained draws.
kernel = NUTS(model)
mcmc_settings = dict(num_warmup=1000, num_samples=5000, num_chains=1, progress_bar=True)
mcmc = MCMC(kernel, **mcmc_settings)

# Run with a fixed PRNG key; df_dev is forwarded to the model as its `data` argument.
rng_key = random.PRNGKey(0)
mcmc.run(rng_key, df_dev)
samples = mcmc.get_samples()


# Convert the fitted MCMC object to InferenceData (the step this report is about).
az.from_numpyro(mcmc)

Thoughts on implementation

It would be handy if from_numpyro could extract those sites from the NumPyro model.

I'll note that providing custom coords, e.g. az.from_numpyro(mcmc, coords={"gender": np.array([0, 1])}), has no effect (the coords don't get renamed or anything). For now I stick to .rename({"alpha_gender_dim_0": "gender"}).

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions