In [35]:
import time

import jax.lax as lax
from jax import random
import jax.numpy as jnp

from make_mdp import MDP


In [36]:
# Size of the randomly generated MDP (all three are passed to MDP(...)).
num_states, num_actions, num_rewards = 100, 20, 10

# Discount factor gamma used in every Bellman backup below.
discount_factor = 0.9

# Seed for the root PRNG key; fixes the MDP, the policy and v_0.
seed = 42

In [37]:
# Root PRNG key; later cells split fresh subkeys off it for each random draw.
key = random.key(seed=seed)

In [38]:
# Build the random MDP from its own dedicated subkey.
key, mdp_key = random.split(key)
mdp = MDP(mdp_key, num_states, num_actions, num_rewards)
del mdp_key  # consumed; drop it so it cannot be reused by accident

In [39]:
# Random stochastic policy pi(a|s): shape (num_states, num_actions),
# normalised so that every row sums to 1.
key, policy_key = random.split(key)
unnormalized = random.uniform(policy_key, (num_states, num_actions))
del policy_key

policy = unnormalized / jnp.sum(unnormalized, axis=-1, keepdims=True)
del unnormalized

### Iterative Policy Evaluation

In [40]:
# Convergence threshold: stop once the largest per-state change in a sweep drops below this.
eps = 1e-5

In [41]:
# Random initial value function v_0, one float32 entry per state.
key, subkey = random.split(key)
# jax.random expects `shape` as a sequence of ints, hence the 1-tuple.
v_pi_0 = random.uniform(subkey, (num_states,), dtype=jnp.float32)
del subkey  # consumed; drop it like the other key-splitting cells do

Sutton and Barto (p. 74) give the iterative policy evaluation update rule:
$$v_{k+1}(s) = \sum_a \pi(a|s) \sum_{s',r} p(s',r|s,a) [r + \gamma v_k(s')]$$

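Our MDP is not defined via the four-argument probability $p(s',r|s,a)$. Instead it provides state transitions $p(s'|s,a)$ and reward probabilities $p(r|s,a)$. These are the marginals $p(s'|s,a) = \sum_r p(s',r|s,a)$ and $p(r|s,a) = \sum_{s'} p(s',r|s,a)$, so we can rewrite the update as: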

$$
\begin{aligned}
v_{k+1}(s) &= \sum_a \pi(a|s) \sum_{s',r} p(s',r|s,a) [r + \gamma v_k(s')]\\
           &= \sum_a \pi(a|s) [\sum_{s',r} p(s',r|s,a) r + \gamma \sum_{s',r} p(s',r|s,a) v_k(s')]\\
           &= \sum_a \pi(a|s) [\sum_{r} r \sum_{s'} p(s',r|s,a) + \gamma \sum_{s'} v_k(s') \sum_{r} p(s',r|s,a) ]\\
           &= \sum_a \pi(a|s) [\sum_{r} r\, p(r|s,a) + \gamma \sum_{s'} v_k(s')\, p(s'|s,a) ]\\
\end{aligned}
$$

$\sum_{r} r\, p(r|s,a)$ is the expected reward for taking action $a$ in state $s$; it is pre-computed as `mdp.expected_rewards`.
$\gamma \sum_{s'} v_k(s')\, p(s'|s,a)$ is the discounted expected value of the next state after taking action $a$ in state $s$.


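Simple while-loop implementation (synchronous updates): each sweep computes every $v_{k+1}(s)$ from a frozen copy of $v_k$ (`v_pi_old`), so updates made earlier in the sweep do not affect later states in the same sweep.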

In [42]:
v_pi = jnp.copy(v_pi_0)
num_iterations = 0

start = time.time()

while True:
    num_iterations += 1

    # Largest absolute change of any single state value during this sweep.
    delta = 0
    # Freeze v_k so every state in this sweep is backed up from the same
    # values (synchronous update) while v_pi accumulates v_{k+1}.
    v_pi_old = jnp.copy(v_pi)
    for s in range(num_states):
        v_pi_s = 0
        for a in range(num_actions):
            policy_s_a = policy[s, a]
            # expected_rewards is laid out [state, action]: the vectorized cell
            # multiplies it elementwise with the (num_states, num_actions) policy.
            # JAX clamps out-of-bounds indices instead of raising, so the swapped
            # order [a, s] silently read the wrong entries.
            v_pi_s += policy_s_a * mdp.expected_rewards[s, a]
            v_pi_s += policy_s_a * discount_factor * jnp.dot(v_pi_old, mdp.transition_ps[s, a, :])
        v_pi = v_pi.at[s].set(v_pi_s)

        delta = max(delta, jnp.abs(v_pi_old[s] - v_pi[s]))

    if delta < eps:
        break

# JAX dispatches asynchronously; wait for the result before stopping the clock.
_ = v_pi.block_until_ready()

end = time.time()

elapsed_loop = end - start
print(f"Iterations {num_iterations} - Delta {delta:.5f} - Time: {elapsed_loop:.6f} seconds")

Iterations 102 - Delta 0.00001 - Time: 72.472999 seconds


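`lax.while_loop` version: the same per-state loops, run inside `jax.lax.while_loop`. Because the loop body is traced, the Python `for` loops are unrolled into `num_states * num_actions` backups, and the measured time includes tracing and compilation.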

In [43]:
def iterative_value_estimation(state):
    """One synchronous sweep of iterative policy evaluation (loop version).

    Args:
        state: tuple ``(delta, v_pi)``. The incoming ``delta`` is ignored;
            ``v_pi`` holds the current estimate v_k with one entry per state.

    Returns:
        Tuple ``(delta, v_pi)`` where ``v_pi`` is v_{k+1} and ``delta`` is the
        largest absolute per-state change between v_k and v_{k+1}.

    Reads the notebook globals ``policy``, ``mdp``, ``discount_factor``,
    ``num_states`` and ``num_actions``. When traced by ``lax.while_loop`` the
    Python ``for`` loops are unrolled into num_states * num_actions backups.
    """
    _, v_pi = state
    delta = 0
    # Back every state up from the frozen v_k (synchronous update).
    v_pi_old = jnp.copy(v_pi)
    for s in range(num_states):
        v_pi_s = 0
        for a in range(num_actions):
            policy_s_a = policy[s, a]
            # [state, action] layout, matching the vectorized cell; JAX clamps
            # out-of-bounds indices, so [a, s] would silently read wrong values.
            v_pi_s += policy_s_a * mdp.expected_rewards[s, a]
            v_pi_s += policy_s_a * discount_factor * jnp.dot(v_pi_old, mdp.transition_ps[s, a, :])
        v_pi = v_pi.at[s].set(v_pi_s)

        delta = jnp.maximum(delta, jnp.abs(v_pi_old[s] - v_pi[s]))
    return (delta, v_pi)

def cond_function(state):
    """Loop predicate for ``lax.while_loop``.

    Keeps iterating while the last sweep's ``delta`` (first element of the
    ``(delta, v_pi)`` carry) is still above ``eps``.
    """
    last_delta, _ = state
    return last_delta > eps

v_pi = jnp.copy(v_pi_0)
init_state = (1e13, v_pi)


def _lax_loop_body(carry):
    # One sweep, plus a sweep counter carried alongside (delta, v_pi).
    count, delta, v = carry
    delta, v = iterative_value_estimation((delta, v))
    return count + 1, delta, v


def _lax_loop_cond(carry):
    # Same stopping rule as cond_function; the counter is ignored.
    return cond_function(carry[1:])


start = time.time()
# lax.while_loop is functional: it returns the final carry, so keep it.
# The measured time includes tracing/compiling the unrolled body.
num_iterations, delta, v_pi = lax.while_loop(
    _lax_loop_cond, _lax_loop_body, (jnp.int32(0), *init_state)
)
# JAX dispatches asynchronously; wait for the result before stopping the clock.
v_pi.block_until_ready()
end = time.time()
elapsed_lax_loop = end - start
print(f"Iterations {num_iterations} - Delta {delta:.5f} - Time: {elapsed_lax_loop:.6f} seconds")

Iterations 102 - Delta 0.00001 - Time: 21.894257 seconds


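Vectorized update inside `lax.while_loop`. In matrix form one sweep is $v_{k+1} = r_\pi + \gamma P_\pi v_k$ with $r_\pi(s) = \sum_a \pi(a|s)\,\mathbb{E}[r|s,a]$ and $P_\pi(s,s') = \sum_a \pi(a|s)\,p(s'|s,a)$, which replaces the per-state loops with array operations.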


In [45]:
def vectorized_value_estimation(state):
    """One synchronous policy-evaluation sweep using array operations.

    Args:
        state: tuple ``(delta, v_pi)``. The incoming ``delta`` is ignored;
            ``v_pi`` holds v_k with one entry per state.

    Returns:
        Tuple ``(delta, v_pi_new)`` with ``v_pi_new = r_pi + gamma * P_pi @ v_pi``
        and ``delta`` the largest absolute per-state change.

    Reads the notebook globals ``policy`` (num_states, num_actions),
    ``mdp.expected_rewards`` (multiplied elementwise with ``policy``),
    ``mdp.transition_ps`` (indexed [s, a, s']) and ``discount_factor``.
    """
    _, v_pi = state
    # r_pi(s) = sum_a pi(a|s) * E[r | s, a]
    expected_r = jnp.sum(policy * mdp.expected_rewards, axis=1)
    # sum_a pi(a|s) sum_s' p(s'|s,a) v(s') as one contraction, without building
    # the (S, A, S) elementwise product. HIGHEST precision keeps float32 accuracy
    # on accelerators whose default matmul precision is lower than eps needs.
    expected_v = jnp.einsum(
        "sa,sat,t->s", policy, mdp.transition_ps, v_pi,
        precision=lax.Precision.HIGHEST,
    )
    v_pi_new = expected_r + discount_factor * expected_v
    delta = jnp.max(jnp.abs(v_pi_new - v_pi))
    return (delta, v_pi_new)

def cond_function(state):
    """Continue the vectorized ``lax.while_loop`` while ``delta`` > ``eps``.

    ``state`` is the ``(delta, v_pi)`` carry; only ``delta`` is inspected.
    """
    delta, _unused_v_pi = state
    keep_going = delta > eps
    return keep_going


v_pi = jnp.copy(v_pi_0)
init_state = (1e13, v_pi)


def _vectorized_loop_body(carry):
    # One vectorized sweep, plus a sweep counter carried alongside (delta, v_pi).
    count, delta, v = carry
    delta, v = vectorized_value_estimation((delta, v))
    return count + 1, delta, v


def _vectorized_loop_cond(carry):
    # Same stopping rule as cond_function; the counter is ignored.
    return cond_function(carry[1:])


start = time.time()
# lax.while_loop is functional: it returns the final carry, so keep it.
# The measured time includes tracing/compiling the loop body.
num_iterations, delta, v_pi = lax.while_loop(
    _vectorized_loop_cond, _vectorized_loop_body, (jnp.int32(0), *init_state)
)
# JAX dispatches asynchronously; wait for the result before stopping the clock.
v_pi.block_until_ready()
end = time.time()
elapsed_vectorized_lax_loop = end - start
print(f"Iterations {num_iterations} - Delta {delta:.5f} - Time: {elapsed_vectorized_lax_loop:.6f} seconds")

Iterations 102 - Delta 0.00001 - Time: 0.133021 seconds


In [46]:
# Side-by-side wall-clock comparison of the three implementations.
timings = {
    'Simple Loop': elapsed_loop,
    'Lax Loop': elapsed_lax_loop,
    'Vectorized Lax Loop': elapsed_vectorized_lax_loop,
}
for label, seconds in timings.items():
    print(f'{label}: {seconds:.6f} seconds')

Simple Loop: 72.472999 seconds
Lax Loop: 21.894257 seconds
Vectorized Lax Loop: 0.133021 seconds
