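# MDP

A toy Markov decision process for a recommendation dialogue. From the start state `S0` the agent can ask about the weather, the mood, or both; once both are known (`WM`) it recommends. Every step may also end the episode in `ABANDON`, and a recommendation can end it in `SUCCESS` (reward 10). We first simulate trajectories under a random policy, then estimate the transition and reward models from those trajectories with NumPyro.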

In [1]:
import random
from typing import Callable, Dict, List, Tuple
from enum import IntEnum


# Define the state space, action space, and other MDP components
class State(IntEnum):
    """States of the dialogue MDP.

    S0 is the start state. W is reached from S0 via ASK_WEATHER, M via
    ASK_MOOD, and WM once both have been asked (see TRANSITIONS).
    SUCCESS and ABANDON end an episode.
    """

    S0 = 0
    W = 1
    M = 2
    WM = 3
    # Terminal states (collected in TERMINAL_STATES).
    SUCCESS = 4
    ABANDON = 5


class Action(IntEnum):
    """Actions the agent can take.

    Which actions are legal in which state is given by AVAILABLE_ACTIONS;
    RECOMMEND is only legal in state WM.
    """

    ASK_WEATHER = 0
    ASK_MOOD = 1
    ASK_BOTH = 2
    RECOMMEND = 3


# Reaching either of these ends the episode.
TERMINAL_STATES = {State.SUCCESS, State.ABANDON}

# Legal actions per state.
AVAILABLE_ACTIONS = {
    # From the start state the agent may ask one question or both at once.
    State.S0: [Action.ASK_WEATHER, Action.ASK_MOOD, Action.ASK_BOTH],
    # With one question answered, the only move is to ask the other one.
    State.W: [Action.ASK_MOOD],
    State.M: [Action.ASK_WEATHER],
    # With both answered, the agent recommends.
    State.WM: [Action.RECOMMEND],
}
# Terminal states allow no actions (each gets its own empty list).
AVAILABLE_ACTIONS.update({terminal: [] for terminal in (State.SUCCESS, State.ABANDON)})

In [2]:
# Transition table P(s' | s, a): each valid (s, a) pair maps to a
# {next_state: probability} dict. The insertion order of every inner dict is
# significant: later cells derive the next-state lists, the Dirichlet component
# order and the (s, a, s') indices from it.
TRANSITIONS = {
    (State.S0, Action.ASK_WEATHER): {
        State.W: 0.75,
        State.S0: 0.15,
        State.ABANDON: 0.10,
    },
    (State.S0, Action.ASK_MOOD): {
        State.M: 0.70,
        State.S0: 0.20,
        State.ABANDON: 0.10,
    },
    (State.S0, Action.ASK_BOTH): {
        State.WM: 0.50,
        State.W: 0.20,
        State.M: 0.15,
        State.ABANDON: 0.15,
    },
    (State.W, Action.ASK_MOOD): {
        State.WM: 0.80,
        State.W: 0.10,
        State.ABANDON: 0.10,
    },
    (State.M, Action.ASK_WEATHER): {
        State.WM: 0.80,
        State.M: 0.10,
        State.ABANDON: 0.10,
    },
    (State.WM, Action.RECOMMEND): {
        State.SUCCESS: 0.75,
        State.ABANDON: 0.25,
    },
}

# Sanity checks: every row must be a probability distribution, and the table
# must cover exactly the (s, a) pairs the policy is allowed to choose.
assert all(
    min(row.values()) >= 0.0 and abs(sum(row.values()) - 1.0) < 1e-9
    for row in TRANSITIONS.values()
), "every transition row must be a probability distribution"
assert set(TRANSITIONS) == {
    (state, action)
    for state, actions in AVAILABLE_ACTIONS.items()
    for action in actions
}, "TRANSITIONS must define exactly the (state, action) pairs in AVAILABLE_ACTIONS"

TransitionModel = Callable[[State, Action], Dict[State, float]]


def transition_model(state: State, action: Action) -> Dict[State, float]:
    """Return P(s' | state, action) as a ``{next_state: probability}`` dict.

    The returned dict is a fresh copy, so callers may modify it without
    corrupting the module-level ``TRANSITIONS`` table.

    Raises:
        ValueError: if ``(state, action)`` has no entry in ``TRANSITIONS``.
    """
    try:
        row = TRANSITIONS[(state, action)]
    except KeyError:
        raise ValueError(f"Invalid state-action pair: ({state}, {action})") from None
    return dict(row)


In [3]:
RewardModel = Callable[[State, Action, State], float]


def reward_model(state: State, action: Action, next_state: State) -> float:
    """Reward for the transition ``(state, action) -> next_state``.

    Only landing in ``State.SUCCESS`` pays 10.0; every other transition pays
    0.0. ``state`` and ``action`` are accepted to match the ``RewardModel``
    signature but do not influence the result.
    """
    if next_state != State.SUCCESS:
        return 0.0
    return 10.0

In [4]:
def mdp_step(
    state: State, action: Action, TM: TransitionModel, RM: RewardModel
) -> Tuple[State, float]:
    """Advance the MDP by one step.

    Draws ``next_state`` from the distribution ``TM(state, action)`` using the
    module-level ``random`` generator, then scores the transition with
    ``RM(state, action, next_state)``.

    Returns:
        ``(next_state, reward)``.
    """
    distribution = TM(state, action)
    candidates = list(distribution)
    weights = [*distribution.values()]
    (next_state,) = random.choices(candidates, weights=weights)
    reward = RM(state, action, next_state)
    return next_state, reward

In [5]:
def random_policy(state: State) -> Action:
    """Pick uniformly at random among the legal actions of ``state``.

    Terminal states have no legal actions, so ``random.choice`` raises
    IndexError for them.
    """
    legal = AVAILABLE_ACTIONS[state]
    return random.choice(legal)


def sample_trajectory(
    policy: Callable[[State], Action] = random_policy,
    num_steps: int = 1000,
    include_partial: bool = False,
) -> List[List[Tuple[State, Action, float, State]]]:
    """Roll out the MDP for ``num_steps`` steps and split the run into episodes.

    Each episode starts in ``State.S0`` and ends as soon as a state in
    ``TERMINAL_STATES`` is reached; the environment is then reset to
    ``State.S0`` and the next episode begins.

    Args:
        policy: maps the current (non-terminal) state to the action to take.
        num_steps: total number of environment steps across all episodes.
        include_partial: if True, also return the final episode when the
            step budget runs out before it reaches a terminal state.

    Returns:
        A list of episodes, each a list of
        ``(state, action, reward, next_state)`` transitions.
    """
    episodes = []
    episode = []
    state = State.S0

    for _ in range(num_steps):
        action = policy(state)
        next_state, reward = mdp_step(state, action, transition_model, reward_model)
        episode.append((state, action, reward, next_state))

        # Check for termination right after the step, so an episode that ends
        # on the final step is still recorded.
        if next_state in TERMINAL_STATES:
            episodes.append(episode)
            episode = []
            state = State.S0
        else:
            state = next_state

    if include_partial and episode:
        episodes.append(episode)

    return episodes


# Seed Python's RNG so the simulated trajectories (and everything estimated
# from them below) are reproducible under Restart & Run All.
random.seed(0)
trajectories = sample_trajectory(num_steps=1000)

# Show the first three episodes.
for trajectory in trajectories[:3]:
    for state, action, reward, next_state in trajectory:
        print(
            f"State: {state.name}, Action: {action.name}, Reward: {reward}, Next: {next_state.name}"
        )
    print()

State: S0, Action: ASK_WEATHER, Reward: 0.0, Next: W
State: W, Action: ASK_MOOD, Reward: 0.0, Next: WM
State: WM, Action: RECOMMEND, Reward: 10.0, Next: SUCCESS

State: S0, Action: ASK_MOOD, Reward: 0.0, Next: S0
State: S0, Action: ASK_MOOD, Reward: 0.0, Next: S0
State: S0, Action: ASK_BOTH, Reward: 0.0, Next: W
State: W, Action: ASK_MOOD, Reward: 0.0, Next: WM
State: WM, Action: RECOMMEND, Reward: 10.0, Next: SUCCESS

State: S0, Action: ASK_MOOD, Reward: 0.0, Next: M
State: M, Action: ASK_WEATHER, Reward: 0.0, Next: WM
State: WM, Action: RECOMMEND, Reward: 10.0, Next: SUCCESS



# Estimating the Transition Model and the Reward Model

In [6]:
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import SVI, Trace_ELBO, autoguide
import optax

# Valid (s, a) pairs, in TRANSITIONS order, and their integer index.
VALID_SA_PAIRS = list(TRANSITIONS.keys())
SA_TO_IDX = {sa: i for i, sa in enumerate(VALID_SA_PAIRS)}

# For each valid (s, a), the next states it can reach (order matches TRANSITIONS[sa]).
SA_NEXT_STATES = {sa: list(TRANSITIONS[sa].keys()) for sa in VALID_SA_PAIRS}

# Flat list of every reachable (s, a, s') triple and its index (used for the
# per-triple reward means). Built with comprehensions so no loop variables
# (s, a, sa, ns) leak into the notebook's global namespace.
SAS_TRIPLES = [
    (s, a, ns) for (s, a) in VALID_SA_PAIRS for ns in SA_NEXT_STATES[(s, a)]
]
SAS_TO_IDX = {sas: i for i, sas in enumerate(SAS_TRIPLES)}
assert len(SAS_TO_IDX) == len(SAS_TRIPLES), "(s, a, s') triples must be unique"

# Number of states (columns of the transition-probability matrix).
NUM_STATES = len(State)

In [7]:
def preprocess_data(trajectories):
    """Flatten episodes into parallel arrays for vectorized inference.

    Args:
        trajectories: iterable of episodes, each a list of
            ``(state, action, reward, next_state)`` tuples as produced by
            ``sample_trajectory``.

    Returns:
        sa_indices: int32 index into VALID_SA_PAIRS of each transition's (s, a)
        next_state_indices: int32 integer value of each s'
        sas_indices: int32 index into SAS_TRIPLES of each transition's (s, a, s')
        rewards: observed reward values

    Raises:
        KeyError: if a transition's (s, a) or (s, a, s') is not a valid pair/triple.
    """
    sa_indices = []
    next_state_indices = []
    sas_indices = []
    rewards = []

    for trajectory in trajectories:
        for state, action, reward, next_state in trajectory:
            sa_idx = SA_TO_IDX[(state, action)]
            sas_idx = SAS_TO_IDX[(state, action, next_state)]

            sa_indices.append(sa_idx)
            next_state_indices.append(int(next_state))
            sas_indices.append(sas_idx)
            rewards.append(float(reward))

    # Pin integer dtypes explicitly: with no transitions, jnp.array([]) would be
    # float32, which cannot be used to index arrays or as Categorical observations.
    return (
        jnp.array(sa_indices, dtype=jnp.int32),
        jnp.array(next_state_indices, dtype=jnp.int32),
        jnp.array(sas_indices, dtype=jnp.int32),
        jnp.array(rewards),
    )


# Flatten the sampled episodes into the arrays consumed by the model below.
(
    sa_indices,
    next_state_indices,
    sas_indices,
    rewards_data,
) = preprocess_data(trajectories)
print("Number of transitions:", len(sa_indices))

Number of transitions: 999


In [8]:
SA_SIZES = [len(SA_NEXT_STATES[pair]) for pair in VALID_SA_PAIRS]


def mdp_model(sa_indices, next_state_indices, sas_indices, rewards):
    """Joint NumPyro model of the MDP's transitions and rewards.

    Transition model (one Dirichlet per valid (s, a) pair):
        θ_{s,a} ~ Dirichlet(1)
        s' | s, a ~ Categorical(θ_{s,a})

    Reward model (one mean per reachable (s, a, s') triple, shared noise):
        μ_{s,a,s'} ~ Normal(0, 10)
        σ ~ HalfNormal(1)
        r | s, a, s' ~ Normal(μ_{s,a,s'}, σ)

    Args:
        sa_indices: index into VALID_SA_PAIRS for each observed transition.
        next_state_indices: integer value of each observed next state.
        sas_indices: index into SAS_TRIPLES for each observed transition.
        rewards: observed reward for each transition.

    Note: ``reward_model`` returns fixed values, so the data contain no reward
    noise and the posterior on σ is pushed toward zero.
    """
    num_pairs = len(VALID_SA_PAIRS)

    # Dense (num_pairs, NUM_STATES) matrix; next states a pair cannot reach
    # keep probability 0.
    transition_matrix = jnp.zeros((num_pairs, NUM_STATES))
    for row, pair in enumerate(VALID_SA_PAIRS):
        theta = numpyro.sample(
            f"trans_prob_{row}", dist.Dirichlet(jnp.ones(SA_SIZES[row]))
        )
        # Component k of theta belongs to SA_NEXT_STATES[pair][k].
        columns = jnp.array([int(ns) for ns in SA_NEXT_STATES[pair]])
        transition_matrix = transition_matrix.at[row, columns].set(theta)

    reward_means = numpyro.sample(
        "reward_means", dist.Normal(0, 10).expand([len(SAS_TRIPLES)])
    )
    reward_noise = numpyro.sample("reward_noise", dist.HalfNormal(1.0))

    # Every observed transition contributes one next-state term and one reward term.
    with numpyro.plate("data", len(sa_indices)):
        numpyro.sample(
            "obs_next_state",
            dist.Categorical(probs=transition_matrix[sa_indices]),
            obs=next_state_indices,
        )
        numpyro.sample(
            "obs_reward",
            dist.Normal(reward_means[sas_indices], reward_noise),
            obs=rewards,
        )


In [None]:
# Run inference using SVI (faster than MCMC for this model).
SVI_STEPS = 2000
SVI_LEARNING_RATE = 0.01
NUM_POSTERIOR_SAMPLES = 1000

# Derive independent keys up front instead of consuming one key in svi.run and
# then splitting that same key again for sampling. `rng_key` stays a fresh key
# for any later cell.
rng_key, svi_key, sample_key = jax.random.split(jax.random.PRNGKey(0), 3)

print("Running SVI inference...")
guide = autoguide.AutoNormal(mdp_model)
optimizer = optax.adam(SVI_LEARNING_RATE)
svi = SVI(mdp_model, guide, optimizer, loss=Trace_ELBO())

# Run optimization
svi_result = svi.run(
    svi_key,
    num_steps=SVI_STEPS,
    sa_indices=sa_indices,
    next_state_indices=next_state_indices,
    sas_indices=sas_indices,
    rewards=rewards_data,
    progress_bar=True,
)

# Draw posterior samples of the latent sites from the fitted guide.
predictive = numpyro.infer.Predictive(
    guide, params=svi_result.params, num_samples=NUM_POSTERIOR_SAMPLES
)
samples = predictive(
    sample_key,
    sa_indices=sa_indices,
    next_state_indices=next_state_indices,
    sas_indices=sas_indices,
    rewards=rewards_data,
)


Running SVI inference...


100%|██████████| 2000/2000 [00:02<00:00, 919.25it/s, init loss: 4405.4521, avg. loss [1901-2000]: -4401.1343] 


In [None]:
print("\n" + "=" * 60)
print("LEARNED TRANSITION PROBABILITIES")
print("=" * 60)

for i, sa in enumerate(VALID_SA_PAIRS):
    s, a = sa
    possible_next = SA_NEXT_STATES[sa]

    # Posterior mean/std of θ_{s,a}; component idx corresponds to possible_next[idx].
    probs_samples = samples[f"trans_prob_{i}"]
    probs_mean = jnp.mean(probs_samples, axis=0)
    probs_std = jnp.std(probs_samples, axis=0)

    print(f"\nP(s' | s={State(s).name}, a={Action(a).name}):")
    # Header widths mirror the row format below ("{name:>8}: " is 10 characters),
    # so each column title sits directly above its values.
    print(f"  {'Next':>8}  {'True':>8}  {'Learned':>12}  {'Std':>8}")

    true_probs = TRANSITIONS[sa]
    for idx, ns in enumerate(possible_next):
        true_p = true_probs.get(ns, 0.0)
        learned_p = float(probs_mean[idx])
        std_p = float(probs_std[idx])
        print(
            f"  {State(ns).name:>8}: {true_p:>8.3f}  {learned_p:>12.3f}  {std_p:>8.4f}"
        )


LEARNED TRANSITION PROBABILITIES

P(s' | s=S0, a=ASK_WEATHER):
          True       Learned       Std
         W:    0.750         0.735    0.0375
        S0:    0.150         0.170    0.0339
   ABANDON:    0.100         0.095    0.0268

P(s' | s=S0, a=ASK_MOOD):
          True       Learned       Std
         M:    0.700         0.708    0.0371
        S0:    0.200         0.186    0.0310
   ABANDON:    0.100         0.106    0.0241

P(s' | s=S0, a=ASK_BOTH):
          True       Learned       Std
        WM:    0.500         0.475    0.0433
         W:    0.200         0.232    0.0343
         M:    0.150         0.149    0.0284
   ABANDON:    0.150         0.144    0.0280

P(s' | s=W, a=ASK_MOOD):
          True       Learned       Std
        WM:    0.800         0.755    0.0349
         W:    0.100         0.125    0.0263
   ABANDON:    0.100         0.120    0.0261

P(s' | s=M, a=ASK_WEATHER):
          True       Learned       Std
        WM:    0.800         0.800    0.0292
  

In [None]:
print("\n" + "=" * 60)
print("LEARNED REWARD PARAMETERS")
print("=" * 60)

reward_means_samples = samples["reward_means"]
reward_noise_samples = samples["reward_noise"]

print(f"Shared reward noise σ: {float(jnp.mean(reward_noise_samples)):.4f}")
print()

# SAS_TRIPLES is ordered by (s, a) pair and then by next state, and position k
# is exactly the column of reward_means that belongs to that triple.
for sas_idx, (s, a, ns) in enumerate(SAS_TRIPLES):
    column = reward_means_samples[:, sas_idx]
    learned_mean = float(jnp.mean(column))
    mean_std = float(jnp.std(column))

    true_reward = reward_model(State(s), Action(a), State(ns))

    label = f"R(s={State(s).name}, a={Action(a).name}, s'={State(ns).name}): "
    print(
        label
        + f"True={true_reward:.1f}, Learned μ={learned_mean:.3f} (±{mean_std:.3f})"
    )


LEARNED REWARD PARAMETERS
Shared reward noise σ: 0.0012

R(s=S0, a=ASK_WEATHER, s'=W): True=0.0, Learned μ=-0.001 (±0.000)
R(s=S0, a=ASK_WEATHER, s'=S0): True=0.0, Learned μ=0.002 (±0.000)
R(s=S0, a=ASK_WEATHER, s'=ABANDON): True=0.0, Learned μ=0.001 (±0.001)
R(s=S0, a=ASK_MOOD, s'=M): True=0.0, Learned μ=0.000 (±0.000)
R(s=S0, a=ASK_MOOD, s'=S0): True=0.0, Learned μ=0.001 (±0.000)
R(s=S0, a=ASK_MOOD, s'=ABANDON): True=0.0, Learned μ=-0.001 (±0.000)
R(s=S0, a=ASK_BOTH, s'=WM): True=0.0, Learned μ=-0.001 (±0.000)
R(s=S0, a=ASK_BOTH, s'=W): True=0.0, Learned μ=0.001 (±0.000)
R(s=S0, a=ASK_BOTH, s'=M): True=0.0, Learned μ=-0.001 (±0.000)
R(s=S0, a=ASK_BOTH, s'=ABANDON): True=0.0, Learned μ=-0.001 (±0.000)
R(s=W, a=ASK_MOOD, s'=WM): True=0.0, Learned μ=0.000 (±0.000)
R(s=W, a=ASK_MOOD, s'=W): True=0.0, Learned μ=0.002 (±0.000)
R(s=W, a=ASK_MOOD, s'=ABANDON): True=0.0, Learned μ=0.000 (±0.000)
R(s=M, a=ASK_WEATHER, s'=WM): True=0.0, Learned μ=0.001 (±0.000)
R(s=M, a=ASK_WEATHER, s'=M): Tru