In [None]:
import numpy as np
import numpyro
import numpyro.distributions as dist
import jax
import jax.numpy as jnp
from numpyro.infer import MCMC, NUTS
from itertools import product

# Root JAX PRNG key (seed 0). It is split into a fresh subkey before being
# handed to the MCMC run further down, rather than being consumed directly.
rng_key = jax.random.PRNGKey(0)

In [2]:
# Size of the toy MDP and the explicit index lists used when sampling
# random states/actions and when iterating over all (state, action) pairs.
num_states = 2
num_actions = 2
states = list(range(num_states))
actions = list(range(num_actions))


def mdp(
    state: int,
    action: int,
    rng: "np.random.Generator | None" = None,
) -> tuple[int, float]:
    """Simulate one step of the two-state, two-action toy MDP.

    For each (state, action) pair the environment moves to state 0 with a
    fixed probability (otherwise to state 1) and emits a reward drawn from a
    unit-variance Gaussian with a pair-specific mean.

    Args:
        state: Current state, 0 or 1.
        action: Chosen action, 0 or 1.
        rng: Optional ``np.random.Generator`` to draw from. When ``None``
            (the default) NumPy's global legacy random state is used, so
            existing callers and ``np.random.seed`` keep working unchanged.

    Returns:
        ``(next_state, reward)`` where ``next_state`` is 0 or 1 and
        ``reward`` is a float.

    Raises:
        AssertionError: If ``state`` or ``action`` is not 0 or 1.
    """
    assert state in [0, 1], "Invalid state"
    assert action in [0, 1], "Invalid action"

    # (state, action) -> (P(next_state == 0), reward mean). Reward std is 1.0.
    dynamics = {
        (0, 0): (0.9, 2.0),
        (0, 1): (0.1, 1.5),
        (1, 0): (0.9, 0.0),
        (1, 1): (0.2, 3.0),
    }
    # int() makes NumPy integer scalars / 0-d arrays usable as dict keys.
    prob_next_is_zero, reward_mean = dynamics[int(state), int(action)]

    # Draw order (uniform first, then Gaussian) is kept fixed so that a seeded
    # global state reproduces the same stream as before.
    if rng is None:
        uniform_draw = np.random.rand()
        reward = np.random.normal(reward_mean, 1.0)
    else:
        uniform_draw = rng.random()
        reward = rng.normal(reward_mean, 1.0)

    next_state = 0 if uniform_draw < prob_next_is_zero else 1
    return next_state, reward

In [15]:
# Sample a trajectory under a uniformly random behaviour policy.
# Seed NumPy's global random state (used by np.random.choice here and by
# `mdp` by default) so the rollout is reproducible on Restart & Run All.
np.random.seed(0)
num_steps = 300  # number of simulated transitions

s = np.random.choice(states)
a = np.random.choice(actions)
# state_traj / action_traj hold num_steps + 1 entries (initial pair plus one
# per transition); reward_traj holds one reward per transition.
state_traj = [s]
action_traj = [a]
reward_traj = []

for _ in range(num_steps):
    s_, reward = mdp(s, a)
    a_ = np.random.choice(actions)

    state_traj.append(s_)
    action_traj.append(a_)
    reward_traj.append(reward)
    s = s_
    a = a_

# Preview the first ten records. Distinct loop names keep the rollout
# variables `s` and `a` from being overwritten by the preview.
for state_t, action_t, reward_t in zip(state_traj[:10], action_traj[:10], reward_traj[:10]):
    print(f"State: {state_t}, Action: {action_t}, Reward: {reward_t}")
print("...")

State: 1, Action: 1, Reward: 1.7644137334598737
State: 1, Action: 0, Reward: 0.5918058443861409
State: 1, Action: 0, Reward: -1.172098637360545
State: 0, Action: 1, Reward: 1.9151350957581499
State: 1, Action: 0, Reward: -0.04030941450177798
State: 0, Action: 1, Reward: 1.081824572708424
State: 1, Action: 1, Reward: 3.3165145224564534
State: 1, Action: 1, Reward: 0.9866620231437442
State: 0, Action: 0, Reward: 1.394226635556238
State: 0, Action: 0, Reward: 2.307613019086643
...


In [13]:
def mdp_model(states, actions, rewards, next_states):
    """NumPyro model for a tabular MDP with Gaussian rewards.

    Latent sites, each with one entry per (state, action) pair:
      * ``trans_probs``: next-state distribution, flat Dirichlet prior.
      * ``reward_means``: reward mean, Normal(0, 5) prior.
      * ``reward_stds``: reward standard deviation, HalfNormal(5) prior.

    Args:
        states: Integer array of visited states, one per observed transition.
        actions: Integer array of actions taken, aligned with ``states``.
        rewards: Observed rewards, aligned with ``states``.
        next_states: Integer array of resulting states, aligned with ``states``.

    The table sizes come from the module-level ``num_states`` and
    ``num_actions``.
    """
    # Flat Dirichlet(1, ..., 1) prior over next states for every
    # (state, action) pair. Building the concentration from the table sizes
    # keeps the prior valid if num_states / num_actions change. The
    # distribution's batch shape is already (num_states, num_actions), so no
    # extra expand is needed.
    alpha_prior = jnp.ones((num_states, num_actions, num_states))

    trans_probs = numpyro.sample(
        "trans_probs",
        dist.Dirichlet(alpha_prior)
    )

    # Prior for reward means
    reward_means = numpyro.sample(
        "reward_means",
        dist.Normal(0, 5.0).expand((num_states, num_actions))
    )

    # Prior for reward standard deviations (positive by construction)
    reward_stds = numpyro.sample(
        "reward_stds",
        dist.HalfNormal(5.0).expand((num_states, num_actions))
    )

    # One plate entry per observed transition; the [states, actions] lookup
    # picks each record's row of the parameter tables.
    with numpyro.plate("record", len(states)):
        # Categorical likelihood for the observed next state
        numpyro.sample(
            "trans",
            dist.Categorical(probs=trans_probs[states, actions]),
            obs=next_states
        )

        # Gaussian likelihood for rewards
        numpyro.sample(
            "reward",
            dist.Normal(
                reward_means[states, actions],
                reward_stds[states, actions]
            ),
            obs=rewards
        )

In [16]:
# Convert the rollout into aligned JAX arrays with one entry per transition.
# The raw state/action trajectories carry one extra trailing entry (the final
# state/action), so they are longer than reward_traj by exactly one. The guard
# makes this cell safe to re-run: once the arrays are aligned, slicing again
# would shorten next_state_traj and misalign it with the other arrays.
if len(state_traj) == len(reward_traj) + 1:
    num_transitions = len(reward_traj)
    next_state_traj = jnp.array(state_traj[1:])
    state_traj = jnp.array(state_traj[:num_transitions])
    action_traj = jnp.array(action_traj[:num_transitions])
reward_traj = jnp.array(reward_traj)

# Every model input must describe the same set of transitions.
assert (
    len(state_traj) == len(action_traj) == len(reward_traj) == len(next_state_traj)
), "trajectory arrays are misaligned; re-run the trajectory sampling cell"

In [None]:
# Split off a subkey for this inference run; rng_key is rebound to the
# remaining key so later cells can split again.
rng_key, subkey = jax.random.split(rng_key)

# NUTS over mdp_model: 1000 warmup iterations, then 2000 retained draws.
kernel = NUTS(mdp_model)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000)
mcmc.run(
    subkey,
    states=state_traj,
    actions=action_traj,
    rewards=reward_traj,
    next_states=next_state_traj,
)
samples = mcmc.get_samples()

sample: 100%|██████████| 3000/3000 [00:03<00:00, 914.88it/s, 7 steps of size 5.69e-01. acc. prob=0.92] 


In [None]:
# Posterior means: average each site over its leading (draw) axis.
trans_probs_mean = samples['trans_probs'].mean(axis=0)
reward_means_mean = samples['reward_means'].mean(axis=0)

# Every (state, action) combination, in the order used for both tables below.
state_action_pairs = list(product(states, actions))

# Print the learned transition probabilities
print("Transition probabilities:")
for s, a in state_action_pairs:
    print(f"P(s' | s={s}, a={a}) = {trans_probs_mean[s, a]}")

# Print the learned reward means
print("Reward means:")
for s, a in state_action_pairs:
    print(f"E[r | s={s}, a={a}] = {reward_means_mean[s, a]}")

# Reference output from an earlier run (exact values depend on the sampled trajectory):
# Transition probabilities:
# P(s' | s=0, a=0) = [0.87283117 0.12716895]
# P(s' | s=0, a=1) = [0.08410239 0.9158976 ]
# P(s' | s=1, a=0) = [0.804974   0.19502601]
# P(s' | s=1, a=1) = [0.14632195 0.85367817]
# Reward means:
# E[r | s=0, a=0] = 2.0038230419158936
# E[r | s=0, a=1] = 1.5468661785125732
# E[r | s=1, a=0] = -0.12734664976596832
# E[r | s=1, a=1] = 2.970562696456909


Transition probabilities:
P(s' | s=0, a=0) = [0.87283117 0.12716895]
P(s' | s=0, a=1) = [0.08410239 0.9158976 ]
P(s' | s=1, a=0) = [0.804974   0.19502601]
P(s' | s=1, a=1) = [0.14632195 0.85367817]
Reward means:
E[r | s=0, a=0] = 2.0038230419158936
E[r | s=0, a=1] = 1.5468661785125732
E[r | s=1, a=0] = -0.12734664976596832
E[r | s=1, a=1] = 2.970562696456909


In [19]:
# Display the full posterior-mean transition array, indexed as
# [state, action, next_state].
trans_probs_mean

Array([[[0.87283117, 0.12716895],
        [0.08410239, 0.9158976 ]],

       [[0.804974  , 0.19502601],
        [0.14632195, 0.85367817]]], dtype=float32)