In [3]:
import numpy as np
import numpyro
import numpyro.distributions as dist
import jax
import jax.numpy as jnp
from numpyro.infer import MCMC, NUTS
from itertools import product
from typing import Tuple
from numpy.typing import NDArray

rng_key = jax.random.PRNGKey(0)

In [4]:
num_states = 2
num_actions = 2
states = [0, 1]
actions = [0, 1]


def mdp(state: int, action: int) -> tuple[int, float]:
    assert state in [0, 1], "Invalid state"
    assert action in [0, 1], "Invalid action"

    if state == 0:
        if action == 0:
            next_state = 0 if np.random.rand() < 0.9 else 1
            reward = np.random.normal(2.0, 1.0)
            return next_state, reward
        elif action == 1:
            next_state = 0 if np.random.rand() < 0.1 else 1
            reward = np.random.normal(1.5, 1.0)
            return next_state, reward
    elif state == 1:
        if action == 0:
            next_state = 0 if np.random.rand() < 0.9 else 1
            reward = np.random.normal(0.0, 1.0)
            return next_state, reward
        elif action == 1:
            next_state = 0 if np.random.rand() < 0.2 else 1
            reward = np.random.normal(3.0, 1.0)
            return next_state, reward
    assert False, "Should not reach here"

In [5]:
# Sample a trajectory
s = np.random.choice(states)
a = np.random.choice(actions)
state_traj = [s]
action_traj = [a]
reward_traj = []

for _ in range(300):
    s_, reward = mdp(s, a)
    a_ = np.random.choice(actions)

    state_traj.append(s_)
    action_traj.append(a_)
    reward_traj.append(reward)
    s = s_
    a = a_

for s, a, r in list(zip(state_traj, action_traj, reward_traj))[:10]:
    print(f"State: {s}, Action: {a}, Reward: {r}")
print("...")

State: 1, Action: 1, Reward: 3.5626485439432516
State: 1, Action: 0, Reward: 0.08674049946635698
State: 0, Action: 1, Reward: 1.5155929367283145
State: 1, Action: 1, Reward: 2.642991390603359
State: 1, Action: 1, Reward: 4.050187765651558
State: 1, Action: 1, Reward: 4.735093173729242
State: 1, Action: 1, Reward: 2.3329309020785756
State: 1, Action: 0, Reward: 0.5269864672580259
State: 0, Action: 0, Reward: 0.5124401925578577
State: 0, Action: 0, Reward: 0.38451923945558564
...


In [6]:
def mdp_model(states, actions, rewards, next_states):
    # Prior for transition probabilities for each state-action pair
    alpha_prior = jnp.array([
        [[1.0, 1.0], [1.0, 1.0]],
        [[1.0, 1.0], [1.0, 1.0]]
    ])

    trans_probs = numpyro.sample(
        "trans_probs",
        dist.Dirichlet(alpha_prior).expand((num_states, num_actions))
    )

    # Prior for reward means
    reward_means = numpyro.sample(
        "reward_means",
        dist.Normal(0, 5.0).expand((num_states, num_actions))
    )

    reward_stds = numpyro.sample(
        "reward_stds",
        dist.HalfNormal(5.0).expand((num_states, num_actions))
    )

    with numpyro.plate("record", len(states)):
        numpyro.sample(
            "trans",
            dist.Categorical(probs=trans_probs[states, actions]),
            obs=next_states
        )

        # Gaussian likelihood for rewards
        numpyro.sample(
            "reward",
            dist.Normal(
                reward_means[states, actions],
                reward_stds[states, actions]
            ),
            obs=rewards
        )

In [7]:
next_state_traj = jnp.array(state_traj[1:])
state_traj = jnp.array(state_traj[:300])
action_traj = jnp.array(action_traj[:300])
reward_traj = jnp.array(reward_traj)

In [8]:
rng_key, subkey = jax.random.split(rng_key)
kernel = NUTS(mdp_model)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000)
mcmc.run(subkey, state_traj, action_traj, reward_traj, next_state_traj)
samples = mcmc.get_samples()

sample: 100%|██████████| 3000/3000 [00:03<00:00, 834.83it/s, 7 steps of size 5.52e-01. acc. prob=0.92] 


In [9]:
trans_probs_mean = jnp.mean(samples['trans_probs'], axis=0)
reward_means_mean = jnp.mean(samples['reward_means'], axis=0)

# Print the learned transition probabilities
print("Transition probabilities:")
for s, a in product(states, actions):
    print(f"P(s' | s={s}, a={a}) = {trans_probs_mean[s, a]}")

# Print the learned reward means
print("Reward means:")
for s, a in product(states, actions):
    print(f"E[r | s={s}, a={a}] = {reward_means_mean[s, a]}")

# Transition probabilities:
# P(s' | s=0, a=0) = [0.87283117 0.12716895]
# P(s' | s=0, a=1) = [0.08410239 0.9158976 ]
# P(s' | s=1, a=0) = [0.804974   0.19502601]
# P(s' | s=1, a=1) = [0.14632195 0.85367817]
# Reward means:
# E[r | s=0, a=0] = 2.0038230419158936
# E[r | s=0, a=1] = 1.5468661785125732
# E[r | s=1, a=0] = -0.12734664976596832
# E[r | s=1, a=1] = 2.970562696456909


Transition probabilities:
P(s' | s=0, a=0) = [0.92480195 0.07519802]
P(s' | s=0, a=1) = [0.0983525 0.9016475]
P(s' | s=1, a=0) = [0.9293374  0.07066277]
P(s' | s=1, a=1) = [0.2576761 0.7423239]
Reward means:
E[r | s=0, a=0] = 1.7359533309936523
E[r | s=0, a=1] = 1.5451616048812866
E[r | s=1, a=0] = -0.21869699656963348
E[r | s=1, a=1] = 2.837351083755493


In [10]:
def policy_evaluation(
        policy: NDArray[np.float64],
        P: NDArray[np.float64],
        R: NDArray[np.float64],
        gamma: float = 0.95,
        theta: float = 0.01,
        max_iter: int = 1000
) -> NDArray[np.float64]:
    """Perform policy evaluation to estimate the state value function with a given policy.
    """
    num_states = P.shape[0]
    num_actions = P.shape[1]
    assert policy.shape == (num_states, num_actions)

    V = np.zeros(num_states)

    for i in range(max_iter):
        delta = 0
        V_new = np.zeros(num_states)
        for s in range(num_states):
            for a in range(num_actions):
                for s_next in range(num_states):
                    V_new[s] += policy[s, a] * P[s, a, s_next] * (R[s, a] + gamma * V[s_next])
            delta = max(delta, np.abs(V_new[s] - V[s]))

        if delta < theta:
            break
        V = V_new

    return V

In [11]:
def policy_improvement(
        V: NDArray[np.float64],
        P: NDArray[np.float64],
        R: NDArray[np.float64],
        gamma=0.99
) -> NDArray[np.float64]:
    """Perform policy improvement to find the optimal policy given the value function."""
    num_states = P.shape[0]
    num_actions = P.shape[1]
    policy = np.zeros((num_states, num_actions))

    for s in range(num_states):
        Q = np.zeros(num_actions)
        for a in range(num_actions):
            for s_next in range(num_states):
                Q[a] += P[s, a, s_next] * (R[s, a] + gamma * V[s_next])

        ai = np.argmax(Q)
        policy[s, ai] = 1.0

    assert np.allclose(np.sum(policy, axis=1), 1.0)
    return policy

In [12]:
def policy_iteration(
        P: NDArray[np.float64],
        R: NDArray[np.float64],
        gamma: float = 0.99
) -> Tuple[NDArray[np.float64], NDArray[np.float64]]:
    """Perform policy iteration to find the optimal policy and value function."""
    num_states = P.shape[0]
    num_actions = P.shape[1]
    policy = np.zeros((num_states, num_actions))
    policy[:, 0] = 1.0  # Start with a policy that always takes action 0

    while True:
        V = policy_evaluation(policy, P, R, gamma)
        new_policy = policy_improvement(V, P, R, gamma)

        if np.array_equal(policy, new_policy):
            break

        policy = new_policy

    return policy, V

In [None]:
# Run policy iteration
P = np.zeros((2, 2, 2))  # [state, action, next_state]

# State 0. Using the true transition probabilities for simplicity
P[0, 0, 0] = 0.9  # Action 0: 90% chance to stay in state 0
P[0, 0, 1] = 0.1  # Action 0: 10% chance to go to state 1
P[0, 1, 0] = 0.1  # Action 1: 10% chance to stay in state 0
P[0, 1, 1] = 0.9  # Action 1: 90% chance to go to state 1

# State 1
P[1, 0, 0] = 0.9  # Action 0: 90% chance to go to state 0
P[1, 0, 1] = 0.1  # Action 0: 10% chance to stay in state 1
P[1, 1, 0] = 0.2  # Action 1: 20% chance to go to state 0
P[1, 1, 1] = 0.8  # Action 1: 80% chance to stay in state 1

R = np.zeros((2, 2))  # [state, action]

# State 0
R[0, 0] = 2.0  # Expected reward for action 0
R[0, 1] = 1.5  # Expected reward for action 1

# State 1
R[1, 0] = 0.0  # Expected reward for action 0
R[1, 1] = 3.0  # Expected reward for action 1
optimal_policy, optimal_value = policy_iteration(P, R)
print("Optimal Policy:\n", optimal_policy)
print("Optimal Value Function:", optimal_value)


Optimal Policy:
 [[0. 1.]
 [0. 1.]]
Optimal Value Function: [270.62017394 271.9850511 ]
