In [1]:
import jax
import jax.numpy as jnp
import numpy as np
from numpyro import sample, plate
from numpyro.infer import MCMC, NUTS
import numpyro.distributions as dist
from numpy.random import default_rng

In [15]:
K = 4   # {Train, Clean, EDA, Release}
d = 6   # Feature dimensions
N = 1  # Number of scenes
rng = default_rng(42)

In [16]:
def generate_scene_vector():
    x_india = rng.integers(0, 2)
    imbs = rng.uniform(0, 1)
    divs = rng.uniform(0, 1)
    clouds = rng.uniform(0, 100) / 100
    months = rng.integers(1, 13) / 12
    aligns = rng.uniform(0, 1)
    return np.array([x_india, imbs, divs, clouds, months, aligns])

X = np.stack([generate_scene_vector() for _ in range(N)])

In [22]:
true_beta = rng.normal(0, 1, size=(K, d))
true_pi = rng.dirichlet([1.0]*K)

In [23]:
logits = X @ true_beta.T
probs = (true_pi * np.exp(logits.T)).T
probs = probs / probs.sum(axis=1, keepdims=True)
y = np.array([rng.choice(K, p=p) for p in probs])

In [24]:
def routing_model(X, y=None):
    pi = sample("pi", dist.Dirichlet(jnp.ones(K)))
    sigma = sample("sigma", dist.HalfCauchy(1.0))
    beta = sample("beta", dist.Normal(0, sigma).expand([K, d]))

    logits = jnp.dot(X, beta.T) + jnp.log(pi)
    probs = jax.nn.softmax(logits)

    with plate("data", X.shape[0]):
        sample("obs", dist.Categorical(probs=probs), obs=y)

In [25]:
kernel = NUTS(routing_model)
mcmc = MCMC(kernel, num_warmup=500, num_samples=1000, num_chains=1)
mcmc.run(jax.random.PRNGKey(0), X=X, y=y)
posterior = mcmc.get_samples()

sample: 100%|██████████| 1500/1500 [00:02<00:00, 572.09it/s, 63 steps of size 1.17e-01. acc. prob=0.88]


In [30]:
def predict_routing_probs(x_new, posterior):
    beta_samples = posterior['beta']         # shape: (S, K, d)
    pi_samples = posterior['pi']             # shape: (S, K)
    
    S = beta_samples.shape[0]                # number of posterior samples

    # Repeat x_new for all samples: shape (S, K, d)
    x_rep = jnp.repeat(x_new[None, :, :], S, axis=0)  # (S, 1, d)
    x_rep = jnp.tile(x_rep, (1, K, 1))                # (S, K, d)

    # Compute logits for each sample
    logits = jnp.sum(beta_samples * x_rep, axis=-1) + jnp.log(pi_samples)  # shape: (S, K)
    probs = jax.nn.softmax(logits, axis=-1)

    return probs.mean(axis=0) 

In [31]:
x_test = jnp.array(generate_scene_vector()).reshape(1, -1)
routing_probs = predict_routing_probs(x_test, posterior)

In [32]:
modes = ["Train", "Clean", "EDA", "Release"]
for mode, p in zip(modes, routing_probs):
    print(f"{mode}: {p:.3f}")

print(f"Predicted mode: {modes[jnp.argmax(routing_probs)]}")

Train: 0.259
Clean: 0.372
EDA: 0.248
Release: 0.121
Predicted mode: Clean
