In [96]:
import numpy as np

num_samples = 200
item_idx = np.arange(num_samples)
item = np.random.uniform(0, 1, num_samples)

num_reads = 10
param = np.random.choice([800, 850, 900], (num_samples, num_reads), replace=True)
A, B, C = 100, 0.5, 45
response = item[:, None] * A  + param * B + C + np.random.normal(0, 10, (num_samples, num_reads))


In [97]:
import pandas as pd

reads = pd.DataFrame(
    {
        "response": response.flatten(),
        "param": param.flatten(),
        "item": np.repeat(item, num_reads),
        "item_idx": np.repeat(item_idx, num_reads),
    }
)

train = reads.query("item_idx < 198")
test = reads.query("item_idx >= 198")

In [98]:
import numpyro
import numpyro.distributions as dist

def model(param, response, item):
    A = numpyro.sample("A", dist.Normal(0, 200))
    B = numpyro.sample("B", dist.Normal(0, 1))
    C = numpyro.sample("C", dist.Normal(0, 100))

    numpyro.sample("response", dist.Normal(item * A + param * B + C, 1), obs=response)


In [99]:
from numpyro.infer import NUTS, MCMC
from jax import random

# Start from this source of randomness. We will split keys for subsequent operations.
rng_key = random.PRNGKey(0)
rng_key, rng_key_ = random.split(rng_key)

# Run NUTS.
kernel = NUTS(model)
num_samples = 6000
mcmc = MCMC(kernel, num_warmup=2000, num_samples=num_samples)
mcmc.run(
    rng_key_, param=train.param.values, response=train.response.values, item=train.item.values
)
mcmc.print_summary()
posterior = mcmc.get_samples()

sample: 100%|██████████| 8000/8000 [00:12<00:00, 650.33it/s, 7 steps of size 1.68e-03. acc. prob=0.92]   



                mean       std    median      5.0%     95.0%     n_eff     r_hat
         A     99.41      0.08     99.41     99.26     99.53    840.58      1.00
         B      0.49      0.00      0.49      0.49      0.49   2731.59      1.00
         C     54.93      0.47     54.92     54.16     55.69   2697.88      1.00

Number of divergences: 0


In [102]:
import altair as alt

base = alt.Chart(pd.DataFrame(posterior).head(5000)).mark_bar().encode(y="count()")
base.encode(
    x=alt.X("A:O").bin(maxbins=20),
) | base.encode(
    x=alt.X("B:O").bin(maxbins=20),
) | base.encode(
    x=alt.X("C:O").bin(maxbins=20),
)

# Using coefficients as priors

In [103]:
from scipy.stats import norm

mu_A, sigma_A = norm.fit(posterior["A"])
mu_B, sigma_B = norm.fit(posterior["B"])
mu_C, sigma_C = norm.fit(posterior["C"])

In [104]:
import numpyro
import numpyro.distributions as dist

def item_model(param, response, mu_A, sigma_A, mu_B, sigma_B, mu_C, sigma_C):
    A = numpyro.sample("A", dist.Normal(mu_A, sigma_A))
    B = numpyro.sample("B", dist.Normal(mu_B, sigma_B))
    C = numpyro.sample("C", dist.Normal(mu_C, sigma_C))

    item = numpyro.sample("item", dist.Uniform(0, 1))

    numpyro.sample("response", dist.Normal(item * A + param * B + C, 0.2), obs=response)


In [105]:

from numpyro.infer import NUTS, MCMC
from jax import random

# Start from this source of randomness. We will split keys for subsequent operations.
rng_key = random.PRNGKey(0)
rng_key, rng_key_ = random.split(rng_key)

# Run NUTS.
kernel = NUTS(item_model)
num_samples = 4000
mcmc = MCMC(kernel, num_warmup=1000, num_samples=num_samples)
for item_idx, test_item in test.groupby("item_idx"):
    print(item_idx)
    print(test_item.item)
    mcmc.run(
        rng_key_,
        param=test_item.param.values,
        response=test_item.response.values,
        mu_A=mu_A,
        sigma_A=sigma_A,
        mu_B=mu_B,
        sigma_B=sigma_B,
        mu_C=mu_C,
        sigma_C=sigma_C,

    )
    mcmc.print_summary()

198
1980    0.862496
1981    0.862496
1982    0.862496
1983    0.862496
1984    0.862496
1985    0.862496
1986    0.862496
1987    0.862496
1988    0.862496
1989    0.862496
Name: item, dtype: float64


  0%|          | 0/5000 [00:00<?, ?it/s]

sample: 100%|██████████| 5000/5000 [00:05<00:00, 846.43it/s, 87 steps of size 2.16e-03. acc. prob=0.94]  



                mean       std    median      5.0%     95.0%     n_eff     r_hat
         A     99.41      0.09     99.41     99.27     99.55    432.51      1.01
         B      0.50      0.00      0.50      0.50      0.50   2241.60      1.00
         C     54.93      0.48     54.92     54.18     55.75   3007.19      1.00
      item      0.81      0.01      0.81      0.79      0.82   2068.55      1.00

Number of divergences: 0
199
1990    0.947654
1991    0.947654
1992    0.947654
1993    0.947654
1994    0.947654
1995    0.947654
1996    0.947654
1997    0.947654
1998    0.947654
1999    0.947654
Name: item, dtype: float64


sample: 100%|██████████| 5000/5000 [00:04<00:00, 1029.00it/s, 39 steps of size 1.99e-03. acc. prob=0.95] 



                mean       std    median      5.0%     95.0%     n_eff     r_hat
         A     99.41      0.08     99.42     99.28     99.55    524.48      1.00
         B      0.49      0.00      0.49      0.49      0.49   2129.52      1.00
         C     54.93      0.47     54.94     54.18     55.70   1969.36      1.00
      item      0.94      0.01      0.94      0.93      0.95   1591.15      1.00

Number of divergences: 0


# Sample static coefficient distribution

In [118]:
from functools import partial 

from numpyro.infer import MCMC, NUTS, HMCGibbs


def item_model(param, response, mu_A, sigma_A, mu_B, sigma_B, mu_C, sigma_C):
    A = numpyro.sample("A", dist.Normal(mu_A, sigma_A))
    B = numpyro.sample("B", dist.Normal(mu_B, sigma_B))
    C = numpyro.sample("C", dist.Normal(mu_C, sigma_C))

    item = numpyro.sample("item", dist.Uniform(0, 1))

    numpyro.sample("response", dist.Normal(item * A + param * B + C, 0.2), obs=response)

# def gibbs_fn(rng_key, gibbs_sites, hmc_sites, mu_A, sigma_A, mu_B, sigma_B, mu_C, sigma_C):
#     return {
#         'A': dist.Normal(mu_A, sigma_A).sample(rng_key),
#         'B': dist.Normal(mu_B, sigma_B).sample(rng_key),
#         'C': dist.Normal(mu_C, sigma_C).sample(rng_key),
#     }
def gibbs_fn(rng_key, gibbs_sites, hmc_sites, mu_A, sigma_A, mu_B, sigma_B, mu_C, sigma_C):
    return {
        'A': np.random.choice(posterior["A"], gibbs_sites["A"].shape),
        'B': np.random.choice(posterior["B"], gibbs_sites["B"].shape),
        'C': np.random.choice(posterior["C"], gibbs_sites["C"].shape),
    }

hmc_kernel = NUTS(item_model)
kernel = HMCGibbs(
    hmc_kernel,
    gibbs_fn=partial(
        gibbs_fn,
        mu_A=mu_A,
        sigma_A=sigma_A,
        mu_B=mu_B,
        sigma_B=sigma_B,
        mu_C=mu_C,
        sigma_C=sigma_C,
    ),
    gibbs_sites=['A', 'B', 'C']
)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, progress_bar=False)
for item_idx, test_item in test.groupby("item_idx"):
    print(item_idx)
    print(test_item.item)
    mcmc.run(
        rng_key_,
        param=test_item.param.values,
        response=test_item.response.values,
        mu_A=mu_A,
        sigma_A=sigma_A,
        mu_B=mu_B,
        sigma_B=sigma_B,
        mu_C=mu_C,
        sigma_C=sigma_C,

    )
    mcmc.print_summary()

198
1980    0.862496
1981    0.862496
1982    0.862496
1983    0.862496
1984    0.862496
1985    0.862496
1986    0.862496
1987    0.862496
1988    0.862496
1989    0.862496
Name: item, dtype: float64

                mean       std    median      5.0%     95.0%     n_eff     r_hat
         A     99.59      0.00     99.59     99.59     99.59      0.50      1.00
         B      0.49      0.00      0.49      0.49      0.49      0.50      1.00
         C     55.60      0.00     55.60     55.60     55.60      0.50      1.00
      item      0.87      0.00      0.87      0.87      0.87    687.88      1.00

199
1990    0.947654
1991    0.947654
1992    0.947654
1993    0.947654
1994    0.947654
1995    0.947654
1996    0.947654
1997    0.947654
1998    0.947654
1999    0.947654
Name: item, dtype: float64

                mean       std    median      5.0%     95.0%     n_eff     r_hat
         A     99.46      0.00     99.46     99.46     99.46      0.50      1.00
         B      0.49      0.

# Inference all together

In [None]:
import numpyro
import numpyro.distributions as dist

def model(param, response, item):
    A = numpyro.sample("A", dist.Normal(0, 200))
    B = numpyro.sample("B", dist.Normal(0, 1))
    C = numpyro.sample("C", dist.Normal(0, 100))

    for i in range(len(item)):

    numpyro.sample("response", dist.Normal(item * A + param * B + C, 1), obs=response)


In [None]:
from numpyro.infer import NUTS, MCMC
from jax import random

# Start from this source of randomness. We will split keys for subsequent operations.
rng_key = random.PRNGKey(0)
rng_key, rng_key_ = random.split(rng_key)

# Run NUTS.
kernel = NUTS(model)
num_samples = 6000
mcmc = MCMC(kernel, num_warmup=2000, num_samples=num_samples)
mcmc.run(
    rng_key_, param=train.param.values, response=train.response.values, item=train.item.values
)
mcmc.print_summary()
posterior = mcmc.get_samples()

sample: 100%|██████████| 8000/8000 [00:12<00:00, 650.33it/s, 7 steps of size 1.68e-03. acc. prob=0.92]   



                mean       std    median      5.0%     95.0%     n_eff     r_hat
         A     99.41      0.08     99.41     99.26     99.53    840.58      1.00
         B      0.49      0.00      0.49      0.49      0.49   2731.59      1.00
         C     54.93      0.47     54.92     54.16     55.69   2697.88      1.00

Number of divergences: 0
