In [5]:
import time

import jax.lax as lax
from jax import random
import jax.numpy as jnp
import jax.nn as jnn

from make_mdp import MDP

# Problem sizes handed to the MDP constructor; num_states and num_actions
# also fix the shapes of the policy and value arrays built further down.
num_states, num_actions, num_rewards = 100, 20, 10

# Reward parameters forwarded to the MDP constructor.
reward_mean, reward_std = 5.0, 10.0

# Weight applied to the expected next-state value in the Bellman updates.
discount_factor = 0.9

# Single root PRNG key; every random draw below splits a fresh subkey off it.
seed = 42
key = random.key(seed)

# Build the MDP from its own subkey, then drop the subkey so it cannot be
# reused for another draw by accident.
key, subkey = random.split(key)
mdp = MDP(subkey, num_states, num_actions, num_rewards, reward_mean, reward_std)
del subkey

In [6]:
# Convergence threshold for policy evaluation: sweeps continue while the
# largest per-state change of the value estimate exceeds this number.
# NOTE(review): the value arrays are float32 (see v_pi_0) and the values reach
# the tens, where adjacent float32 numbers are ~4e-6 apart, so a change below
# 1e-8 can only be exactly 0 -- TODO confirm this strictness is intended.
eps = 10**(-8)

def vectorized_value_estimation(state):
    """Apply one Bellman expectation backup to all states at once.

    Written as the body function of a `lax.while_loop`: it takes and returns
    the same (delta, v_pi, policy) triple.

    Args:
        state: tuple (delta, v_pi, policy). The incoming delta is ignored.
            Assumes v_pi is (num_states,) and policy is
            (num_states, num_actions), with mdp.expected_rewards and
            mdp.transition_ps broadcast-compatible with (states, actions) and
            (states, actions, next_states) -- TODO confirm against make_mdp.

    Returns:
        (delta, v_pi_new, policy), where delta is the largest absolute
        per-state change between v_pi and v_pi_new; policy is passed through.
    """
    _, values, pi = state

    # Immediate reward expected under the policy: sum_a pi(a|s) * r(s, a).
    reward_term = jnp.sum(pi * mdp.expected_rewards, axis=1)

    # Expected value of the successor state:
    # sum_{a, s'} pi(a|s) * p(s'|s, a) * v(s').
    weighted_successors = pi[:, :, None] * mdp.transition_ps * values[None, None, :]
    successor_term = jnp.sum(weighted_successors, axis=(1, 2))

    refreshed = reward_term + discount_factor * successor_term
    largest_change = jnp.max(jnp.abs(refreshed - values))
    return (largest_change, refreshed, pi)

def cond_function(state):
    """Loop predicate for policy evaluation.

    Args:
        state: the (delta, v_pi, policy) triple; only delta is inspected.

    Returns:
        True while the most recent sweep changed some state's value by more
        than `eps`, i.e. while another sweep is still required.
    """
    largest_change, _, _ = state
    return largest_change > eps

def calculate_value_function(policy, v_pi, max_iters=10_000):
    """Evaluate `policy` by repeating Bellman expectation backups until convergence.

    Backups (`vectorized_value_estimation`) are repeated while `cond_function`
    reports that the largest per-state change still exceeds `eps`.

    Args:
        policy: action probabilities per state, passed through to every backup.
        v_pi: initial value estimate; used as the starting point of the sweeps.
        max_iters: upper bound on the number of sweeps. `eps` may be smaller
            than the floating-point resolution of the values, in which case the
            change never drops below it unless it becomes exactly zero; the
            bound guarantees the compiled loop still terminates.

    Returns:
        The value estimate after the last sweep. If `max_iters` is reached
        first, the most recent estimate is returned as-is.
    """
    def _keep_sweeping(carry):
        sweeps, inner_state = carry
        return jnp.logical_and(cond_function(inner_state), sweeps < max_iters)

    def _sweep(carry):
        sweeps, inner_state = carry
        return sweeps + 1, vectorized_value_estimation(inner_state)

    # Start with an infinite "previous change" so the first sweep always runs.
    init_state = (0, (jnp.inf, v_pi, policy))
    _, (_, v_pi, _) = lax.while_loop(_keep_sweeping, _sweep, init_state)
    return v_pi

In [7]:
# Random stochastic starting policy. The raw uniform draw `policy` has shape
# (num_states, num_actions) and is NOT normalised; dividing each row by its
# sum gives policy_0 with sum(policy_0[i, :]) == 1.
key, subkey = random.split(key)
policy = random.uniform(subkey, (num_states, num_actions))
del subkey

policy_0 = policy / policy.sum(axis=-1, keepdims=True)

# Random initial value estimate, one entry per state.
key, subkey = random.split(key)
v_pi_0 = random.uniform(subkey, (num_states,), dtype=jnp.float32)
del subkey  # consumed; dropped like the other subkeys so it cannot be reused

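Improve on the random policy with policy iteration: repeatedly evaluate the current policy, then replace it with the policy that is greedy with respect to the resulting action values: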

$$\pi'(s) = \text{argmax}_a \ q_\pi(s,a) = \text{argmax}_a \sum_{s',r} p(s',r|s,a) [r + \gamma v_\pi(s')] = \text{argmax}_a [\sum_{r} p(r|s,a) r +  \gamma \sum_{s'} p(s'|s,a) v_\pi(s')]$$

In [9]:
# Number of evaluate-then-improve rounds to run.
num_iters = 10

# Work on copies so policy_0 / v_pi_0 keep the initial random draw.
policy, v_pi = jnp.copy(policy_0), jnp.copy(v_pi_0)

for i in range(num_iters):
    # Evaluation step, warm-started from the previous round's values.
    v_pi = calculate_value_function(policy, v_pi)

    # Action values from a one-step lookahead:
    # q(s, a) = r(s, a) + gamma * sum_s' p(s'|s, a) * v(s').
    q_pi = mdp.expected_rewards + discount_factor * jnp.sum(
        mdp.transition_ps * v_pi[None, None, :], axis=-1
    )

    # Improvement step: deterministic greedy policy, one-hot over actions.
    best_actions = jnp.argmax(q_pi, axis=-1)
    policy = jnn.one_hot(best_actions, num_actions)

    # v_pi belongs to the policy that was evaluated at the top of this round,
    # not to the greedy policy just derived from it.
    mean_value = jnp.mean(v_pi)
    print(f'{i} - mean policy value {mean_value}')

0 - mean policy value 16.9290771484375
1 - mean policy value 47.43392562866211
2 - mean policy value 47.47530746459961
3 - mean policy value 47.47545623779297
4 - mean policy value 47.47545623779297
5 - mean policy value 47.47545623779297
6 - mean policy value 47.47545623779297
7 - mean policy value 47.47545623779297
8 - mean policy value 47.47545623779297
9 - mean policy value 47.47545623779297
