# PPL Example 1: Activity Choice (without Context)

In [None]:
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from jax import random
from numpyro.infer import MCMC, NUTS, Predictive

In [None]:
def model(data=None):
    """Dirichlet-multinomial model of activity choice without context.

    Args:
        data: Observed count vector over the four activities, or ``None``
            to sample the ``obs`` site instead of conditioning on it.
    """
    # Flat Dirichlet prior: one unit of (float) concentration per activity.
    alpha = jnp.ones(4)
    theta = numpyro.sample('theta', dist.Dirichlet(alpha))
    # NOTE(review): total_count is left at the library default while `data`
    # holds aggregated counts -- confirm this is intended (draws made with
    # data=None then use the default count rather than sum(data)).
    numpyro.sample('obs', dist.MultinomialProbs(theta), obs=data)

In [None]:
# Observed counts per activity, ordered as (cycling, picnic, climbing, movie).
data = jnp.array(
    [
        20,  # cycling
        20,  # picnic
        20,  # climbing
        40,  # movie
    ]
)

In [None]:
# Parameter estimation: sample the posterior of `theta` with a NUTS kernel.
kernel = NUTS(model)
random_key = random.PRNGKey(42)  # fixed seed
mcmc = MCMC(
    kernel,
    num_warmup=1000,
    num_samples=2000,
    thinning=1,
)
mcmc.run(random_key, data=data)

# Print the per-site summary table of the collected samples.
mcmc.print_summary()

In [None]:
# Store the posterior samples
import pickle

# Keep the posterior draws on disk so a later cell can reload them
# without re-running the sampler.
samples = mcmc.get_samples()
with open('samples.pkl', mode='wb') as sink:
    pickle.dump(samples, sink)

In [None]:
import arviz as az
import matplotlib.pyplot as plt

# Use the seaborn look for every figure created from here on.
plt.style.use("seaborn-v0_8")

# Convert the finished MCMC run to ArviZ's format and draw its trace plot.
trace = az.from_numpyro(posterior=mcmc)
axes = az.plot_trace(trace)

plt.tight_layout()
plt.show()

In [None]:
import pickle

# Reload the posterior draws written to 'samples.pkl' by the earlier cell.
# NOTE: unpickling can execute arbitrary code -- only load files this
# notebook produced itself.
with open('samples.pkl', mode='rb') as source:
    samples = pickle.load(source)

# Draw the 'obs' site from the model, once per stored posterior sample.
random_key = random.PRNGKey(1)
predictive = Predictive(
    model,
    posterior_samples=samples,
    return_sites=['obs'],
)
pred = predictive(random_key)

def activity_name(activity_array: jnp.ndarray) -> str:
    """Return the label of the activity with the largest entry.

    Positions in ``activity_array`` are read in the order
    cycling, picnic, climbing, movie.
    """
    labels = ["cycling", "picnic", "climbing", "movie"]
    winner = jnp.argmax(activity_array)
    return labels[winner]

# Show the activity picked in each of the first 10 predictive draws.
for draw in pred["obs"][:10]:
    print(f"Activity: {activity_name(draw)}")

# PPL Example 2: Activity Choice (with Context)

In [None]:
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from jax import random
from numpyro.infer import MCMC, NUTS, Predictive

In [None]:
# Context encoding, one entry per row of `data`:
# mood 0 = active, 1 = chill; weather 0 = rainy, 1 = sunny.
mood = jnp.array([0, 0, 1, 1])
weather = jnp.array([0, 1, 0, 1])

# Observed counts: one row per context,
# columns ordered as (cycling, picnic, climbing, movie).
data = jnp.array(
    [
        [0, 0, 15, 10],  # active & rainy
        [15, 5, 0, 5],   # active & sunny
        [0, 0, 5, 20],   # chill & rainy
        [5, 15, 0, 5],   # chill & sunny
    ]
)

In [None]:
def model(mood, weather, obs=None):
    """Context-dependent Dirichlet-multinomial model of activity choice.

    The Dirichlet concentration of each context is the sum of a baseline
    vector and the rows of two effect tables selected by that context's
    mood and weather codes.

    Args:
        mood: Integer codes indexing the first axis of ``beta_mood``.
        weather: Integer codes indexing the first axis of ``beta_weather``.
        obs: Observed counts for the ``obs`` site, or ``None`` to sample it.
    """
    # Mood and weather effects share one rate and therefore one prior.
    lam = numpyro.sample("lam", dist.Exponential(rate=1.0))
    effect_prior = dist.Exponential(rate=lam).expand((2, 4))
    beta_mood = numpyro.sample("beta_mood", effect_prior)
    beta_weather = numpyro.sample("beta_weather", effect_prior)

    # The baseline concentration has its own rate.
    lam_0 = numpyro.sample("lam_0", dist.Exponential(rate=1.0))
    baseline_prior = dist.Exponential(rate=lam_0).expand((4,))
    beta_0 = numpyro.sample("beta_0", baseline_prior)

    with numpyro.plate("context", 4):
        # Recorded as a deterministic site so it is reported with the samples.
        alpha = numpyro.deterministic(
            "concentration",
            beta_0 + beta_mood[mood] + beta_weather[weather],
        )
        theta = numpyro.sample("theta", dist.Dirichlet(alpha))
        numpyro.sample("obs", dist.Multinomial(probs=theta), obs=obs)

In [None]:
# Fit the context model with NUTS, using two chains.
kernel = NUTS(model)
mcmc = MCMC(
    kernel,
    num_warmup=1000,
    num_samples=2000,
    thinning=1,
    num_chains=2,
)

mcmc.run(
    random.PRNGKey(12345),
    mood=mood,
    weather=weather,
    obs=data,
)
mcmc.print_summary()

In [None]:
import arviz as az
import matplotlib.pyplot as plt

# Convert the finished MCMC run to ArviZ's format and draw its trace plot.
trace = az.from_numpyro(posterior=mcmc)
az.plot_trace(trace)

plt.tight_layout()
plt.show()

In [None]:
# Run the model forward once per posterior sample, returning both the
# observation site and the per-context activity probabilities.
samples = mcmc.get_samples()
predictive = Predictive(
    model,
    posterior_samples=samples,
    return_sites=['obs', 'theta'],
)
pred = predictive(random.PRNGKey(1), mood=mood, weather=weather)

# Average over the sample axis to get mean probabilities per context.
thetas = pred["theta"]
print(thetas.mean(axis=0))

#   cycling,   picnic,    climbing,  movie
# [[0.02324187 0.02164141 0.552792   0.40232483]  # active & rainy
#  [0.5500262  0.20776746 0.02417052 0.218036  ]  # active & sunny
#  [0.02163169 0.02243164 0.2116474  0.74428934]  # chill & rainy
#  [0.21048392 0.5475201  0.0210917  0.22090434]] # chill & sunny

In [None]:
def ctx2idx(mood: str, weather: str) -> int:
    """Map a (mood, weather) label pair to its context row index.

    Index layout: 0 = active & rainy, 1 = active & sunny,
    2 = chill & rainy, 3 = chill & sunny.

    Args:
        mood: Either ``"active"`` or ``"chill"``.
        weather: Either ``"rainy"`` or ``"sunny"``.

    Returns:
        The row index of the requested context.

    Raises:
        ValueError: If either label is not one of the values listed above.
    """
    mood_offsets = {"active": 0, "chill": 2}
    weather_offsets = {"rainy": 0, "sunny": 1}
    # Reject unknown labels instead of silently treating them as
    # "active" / "rainy", which would select the wrong context.
    if mood not in mood_offsets:
        raise ValueError(
            f"unknown mood {mood!r}; expected one of {sorted(mood_offsets)}"
        )
    if weather not in weather_offsets:
        raise ValueError(
            f"unknown weather {weather!r}; expected one of {sorted(weather_offsets)}"
        )
    return mood_offsets[mood] + weather_offsets[weather]
    
def activity_name(activity_array: jnp.ndarray) -> str:
    """Return the name of the activity holding the maximum of ``activity_array``.

    Entries are interpreted in the order cycling, picnic, climbing, movie.
    """
    names = ["cycling", "picnic", "climbing", "movie"]
    best = jnp.argmax(activity_array)
    return names[best]


# Predictive draws for the "active & rainy" context: first axis is the
# sample, last axis is the activity.
activity_arrays = pred["obs"][:, ctx2idx("active", "rainy"), :]

# Print the chosen activity and the raw draw for the first 10 samples.
for draw in activity_arrays[:10]:
    print(f"Activity: {activity_name(draw)}", draw)


In [None]:
# Print the posterior summary table of the MCMC run once more.
mcmc.print_summary()
# Excerpt of the printed table (presumably captured from an earlier run;
# exact values may differ between runs and library versions):
#                mean       std    median      5.0%     95.0%     n_eff     r_hat
# beta_0[0]      0.22      0.23      0.15      0.00      0.50   3446.23      1.00
# beta_0[1]      0.22      0.22      0.15      0.00      0.50   3081.75      1.00
# ...