In [1]:
import time

import jax.lax as lax
from jax import random
import jax.numpy as jnp

from make_mdp import MDP


In [2]:
# Sizes handed to the MDP constructor: number of states, actions and rewards.
num_states, num_actions, num_rewards = 100, 20, 10

# Discount factor (gamma) applied to the expected value of the successor states.
discount_factor = 0.9

# Seed used to create the root JAX PRNG key.
seed = 42

In [3]:
# Root PRNG key built from `seed`; later cells split it to obtain a fresh
# subkey for every random draw instead of reusing the same key.
key = random.key(seed)

In [4]:
# Build the MDP from its own subkey; the root key is advanced by the split.
key, mdp_key = random.split(key)
mdp = MDP(mdp_key, num_states, num_actions, num_rewards)
del mdp_key  # consumed above; drop it so it cannot be reused by accident

In [5]:
# Random stochastic policy of shape (num_states, num_actions): draw uniform
# weights, then rescale each row so that sum(policy[i, :]) == 1.
key, policy_key = random.split(key)
policy = random.uniform(policy_key, (num_states, num_actions))
del policy_key  # consumed above; drop it so it cannot be reused by accident

policy = policy / jnp.sum(policy, axis=-1, keepdims=True)

### Iterative Policy Evaluation

In [6]:
# Convergence threshold: iteration stops once the largest per-state change
# of the value estimate in a sweep is no longer above this value.
eps = 1e-5

In [7]:
# Random initial guess for the value function: one uniform sample per state.
# The shape is passed as a tuple, which is the documented form for
# `random.uniform`.
key, subkey = random.split(key)
v_pi_0 = random.uniform(subkey, (num_states,), dtype=jnp.float32)
del subkey  # consumed above; dropped like the earlier subkeys to prevent accidental reuse

In Sutton and Barto we have the update rule (p.74):
$$v_{k+1}(s) = \sum_a \pi(a|s) \sum_{s',r} p(s',r|s,a) [r + \gamma v_k(s')]$$

Since our MDP is not defined via the four-argument probability $p(s',r|s,a)$ but rather via state-transition probabilities $p(s'|s,a)$ and reward probabilities $p(r|s,a)$, we rewrite this as:

$$
\begin{aligned}
v_{k+1}(s) &= \sum_a \pi(a|s) \sum_{s',r} p(s',r|s,a) [r + \gamma v_k(s')]\\
           &= \sum_a \pi(a|s) [\sum_{s',r} p(s',r|s,a) r + \gamma \sum_{s',r} p(s',r|s,a) v_k(s')]\\
           &= \sum_a \pi(a|s) [\sum_{r} r \sum_{s'} p(s',r|s,a) + \gamma \sum_{s'} v_k(s') \sum_{r} p(s',r|s,a) ]\\
           &= \sum_a \pi(a|s) [\sum_{r} r p(r|s,a) + \gamma \sum_{s'} v_k(s') p(s'|s,a) ]\\
\end{aligned}
$$

$\sum_{r} r p(r|s,a)$ is simply the expected reward when choosing action $a$ at state $s$. This is pre-computed as `mdp.expected_rewards`.
$\gamma \sum_{s'} v_k(s') p(s'|s,a)$ is the discounted expectation of the value over the next states when choosing action $a$. 


Simple while-loop implementation (synchronous updates)

In [8]:
# Baseline: plain Python loops over every state and action.  Each sweep reads
# only the snapshot taken at its start (`v_pi_old`), so all states are updated
# from the previous iterate (synchronous updates).
v_pi = jnp.copy(v_pi_0)
num_iterations = 0

start = time.time()

converged = False
while not converged:
    num_iterations += 1

    # Largest absolute per-state change seen during this sweep.
    delta = 0
    v_pi_old = jnp.copy(v_pi)
    for s in range(num_states):
        state_value = 0
        for a in range(num_actions):
            action_prob = policy[s, a]
            # Expected immediate reward, weighted by pi(a|s).
            reward_term = action_prob * mdp.expected_rewards[s, a]
            # Discounted expected value of the next state, weighted by pi(a|s).
            next_value_term = action_prob * discount_factor * jnp.dot(v_pi_old, mdp.transition_ps[s, a, :])
            state_value = state_value + reward_term
            state_value = state_value + next_value_term
        v_pi = v_pi.at[s].set(state_value)

        delta = max(delta, jnp.abs(v_pi_old[s] - v_pi[s]))

    converged = delta < eps

# Wait for any pending device work so the timer covers the full computation.
_ = v_pi.block_until_ready()

end = time.time()

elapsed_loop = end - start
print(f"Iterations {num_iterations} - Delta {delta:.5f} - Time: {elapsed_loop:.6f} seconds")
print(f"Value function:\n{v_pi}")


Iterations 102 - Delta 0.00001 - Time: 47.114200 seconds
Value function:
[-3.2662337 -3.3299084 -3.3180332 -3.3560042 -3.3961449 -3.3184109
 -3.2981157 -3.248808  -3.2935073 -3.3894188 -3.235799  -3.3393154
 -3.2018504 -3.2801137 -3.3524525 -3.3262892 -3.3367422 -3.232171
 -3.3092735 -3.3093662 -3.2834554 -3.3123755 -3.2620194 -3.3356605
 -3.2462163 -3.382313  -3.3461382 -3.285791  -3.3369431 -3.247228
 -3.2327912 -3.2193506 -3.3057892 -3.4036007 -3.263416  -3.30798
 -3.3137524 -3.2604678 -3.3193526 -3.1980562 -3.3333912 -3.285531
 -3.2656493 -3.3021815 -3.349827  -3.3334477 -3.3454044 -3.303187
 -3.3416786 -3.314237  -3.391325  -3.3518615 -3.293537  -3.3242981
 -3.2234306 -3.3435624 -3.3045456 -3.459414  -3.296707  -3.221883
 -3.2616482 -3.346547  -3.31943   -3.3207636 -3.3269696 -3.3348448
 -3.232609  -3.3307145 -3.302874  -3.3285286 -3.2192764 -3.2927485
 -3.3426578 -3.3128083 -3.2812192 -3.3572512 -3.2607646 -3.2218995
 -3.3352025 -3.3484054 -3.3573203 -3.2814257 -3.3321078 -3.3095

Lax-While Loop

In [9]:
def iterative_value_estimation(state):
    """Run one synchronous policy-evaluation sweep (body for ``lax.while_loop``).

    Reads the module-level ``policy``, ``mdp``, ``discount_factor``,
    ``num_states`` and ``num_actions``.

    Args:
        state: tuple ``(delta, v_pi, iteration)``. The incoming ``delta`` is
            ignored; it is recomputed from scratch for this sweep.

    Returns:
        Tuple ``(delta, v_pi, iteration)`` holding the largest absolute
        per-state change of this sweep, the updated value estimate, and the
        iteration counter increased by one.
    """
    _, values, sweep_count = state
    # Snapshot of the previous iterate: every state is updated from it.
    previous = jnp.copy(values)
    max_change = 0
    # NOTE: these are ordinary Python loops over states and actions.
    for s in range(num_states):
        state_value = 0
        for a in range(num_actions):
            action_prob = policy[s, a]
            state_value = state_value + action_prob * mdp.expected_rewards[s, a]
            state_value = state_value + action_prob * discount_factor * jnp.dot(previous, mdp.transition_ps[s, a, :])
        values = values.at[s].set(state_value)

        max_change = jnp.maximum(max_change, jnp.abs(previous[s] - values[s]))
    return (max_change, values, sweep_count + 1)

def cond_function(state):
    """Loop predicate for ``lax.while_loop``.

    Args:
        state: tuple ``(delta, v_pi, iteration)``; only ``delta`` is inspected.

    Returns:
        Truthy while ``delta`` is still above the module-level ``eps``,
        i.e. while the iteration should continue.
    """
    max_change, _, _ = state
    return max_change > eps

# Start from the same initial guess as the Python-loop baseline.
v_pi = jnp.copy(v_pi_0)
# (delta, v_pi, iteration): the huge initial delta guarantees that the loop
# condition holds on entry, so at least one sweep is executed.
init_state = (1e13, v_pi, 0)

start = time.time()
delta, v_pi, num_iterations = lax.while_loop(cond_function, iterative_value_estimation, init_state)
# JAX dispatches work asynchronously: the call above may return before the
# result is ready.  Block before reading the clock (as the Python-loop
# baseline does) so the measurement covers the whole computation.
_ = v_pi.block_until_ready()
end = time.time()

# NOTE: the elapsed time also contains tracing/compilation of the loop body.
elapsed_lax_loop = end - start
print(f"Iterations {num_iterations} - Delta {delta:.5f} - Time: {elapsed_lax_loop:.6f} seconds")
print(f"Value function:\n{v_pi}")

Iterations 102 - Delta 0.00001 - Time: 22.735680 seconds
Value function:
[-3.2662337 -3.3299081 -3.3180335 -3.3560042 -3.3961449 -3.3184109
 -3.2981157 -3.248808  -3.2935073 -3.3894188 -3.235799  -3.3393154
 -3.2018502 -3.280114  -3.3524525 -3.3262892 -3.3367424 -3.2321708
 -3.3092732 -3.3093665 -3.2834554 -3.3123758 -3.2620196 -3.3356605
 -3.246216  -3.382313  -3.3461382 -3.285791  -3.3369431 -3.2472274
 -3.2327914 -3.2193506 -3.3057892 -3.4036007 -3.263416  -3.3079803
 -3.3137522 -3.2604678 -3.3193529 -3.1980562 -3.3333912 -3.2855313
 -3.2656496 -3.3021815 -3.349827  -3.3334477 -3.3454046 -3.303187
 -3.3416789 -3.3142374 -3.391325  -3.3518615 -3.293537  -3.324298
 -3.2234306 -3.3435624 -3.3045456 -3.4594138 -3.296707  -3.221883
 -3.2616484 -3.346547  -3.31943   -3.3207636 -3.3269696 -3.3348446
 -3.232609  -3.3307145 -3.302874  -3.3285286 -3.2192767 -3.2927487
 -3.342658  -3.3128083 -3.2812192 -3.357251  -3.2607646 -3.2218995
 -3.3352025 -3.3484054 -3.3573205 -3.2814257 -3.332108  -3.

Vectorized update in lax loop


In [10]:
def vectorized_value_estimation(state):
    """Run one policy-evaluation sweep with array operations only.

    Reads the module-level ``policy``, ``mdp`` and ``discount_factor``.

    Args:
        state: tuple ``(delta, v_pi, iteration)``. The incoming ``delta`` is
            ignored; it is recomputed for this sweep.

    Returns:
        Tuple ``(delta, v_pi_new, iteration)`` holding the largest absolute
        per-state change, the updated value estimate, and the iteration
        counter increased by one.
    """
    _, values, sweep_count = state
    # Expected immediate reward per state: sum over actions of pi(a|s) * E[r|s,a].
    reward_term = jnp.sum(policy * mdp.expected_rewards, axis=1)
    # pi(a|s) * p(s'|s,a) * v(s') for every (s, a, s'), reduced over a and s'.
    weighted_next_values = policy[:, :, None] * mdp.transition_ps * values[None, None, :]
    next_value_term = jnp.sum(weighted_next_values, axis=(1, 2))
    updated_values = reward_term + discount_factor * next_value_term
    max_change = jnp.max(jnp.abs(updated_values - values))
    return (max_change, updated_values, sweep_count + 1)

def cond_function(state):
    """Loop predicate for ``lax.while_loop``.

    Args:
        state: tuple ``(delta, v_pi, iteration)``; only ``delta`` is inspected.

    Returns:
        Truthy while ``delta`` is still above the module-level ``eps``,
        i.e. while the iteration should continue.
    """
    max_change, _, _ = state
    return max_change > eps

# Start from the same initial guess as the other implementations.
v_pi = jnp.copy(v_pi_0)
# (delta, v_pi, iteration): the huge initial delta guarantees that the loop
# condition holds on entry, so at least one sweep is executed.
init_state = (1e13, v_pi, 0)

start = time.time()
delta, v_pi, num_iterations = lax.while_loop(cond_function, vectorized_value_estimation, init_state)
# JAX dispatches work asynchronously: the call above may return before the
# result is ready.  Block before reading the clock (as the Python-loop
# baseline does) so the measurement covers the whole computation.
_ = v_pi.block_until_ready()
end = time.time()

# NOTE: the elapsed time also contains tracing/compilation of the loop body.
elapsed_vectorized_lax_loop = end - start
print(f"Iterations {num_iterations} - Delta {delta:.5f} - Time: {elapsed_vectorized_lax_loop:.6f} seconds")
print(f"Value function:\n{v_pi}")

Iterations 102 - Delta 0.00001 - Time: 0.058187 seconds
Value function:
[-3.2662337 -3.3299072 -3.3180332 -3.3560035 -3.396145  -3.3184102
 -3.2981155 -3.248808  -3.2935073 -3.3894188 -3.2357981 -3.339315
 -3.2018507 -3.280114  -3.3524525 -3.326289  -3.3367424 -3.2321703
 -3.3092737 -3.3093657 -3.283455  -3.3123758 -3.2620194 -3.3356597
 -3.2462156 -3.382313  -3.346137  -3.2857904 -3.3369431 -3.2472284
 -3.2327907 -3.2193508 -3.305789  -3.4036007 -3.2634163 -3.3079803
 -3.3137517 -3.2604675 -3.3193517 -3.198056  -3.3333912 -3.2855313
 -3.2656493 -3.3021815 -3.3498266 -3.333447  -3.3454049 -3.3031862
 -3.341679  -3.3142369 -3.3913245 -3.3518612 -3.2935371 -3.3242974
 -3.2234302 -3.3435621 -3.3045452 -3.459414  -3.2967062 -3.2218835
 -3.2616484 -3.3465466 -3.3194304 -3.3207629 -3.3269691 -3.3348444
 -3.232609  -3.330715  -3.302874  -3.3285284 -3.219276  -3.2927477
 -3.3426578 -3.3128092 -3.2812183 -3.3572512 -3.2607644 -3.221899
 -3.3352032 -3.3484051 -3.3573198 -3.281426  -3.3321078 -3.

In [11]:
# Side-by-side wall-clock comparison of the three implementations.
timing_report = [
    f'Simple Loop: {elapsed_loop:.6f} seconds',
    f'Lax Loop: {elapsed_lax_loop:.6f} seconds',
    f'Vectorized Lax Loop: {elapsed_vectorized_lax_loop:.6f} seconds',
]
print(*timing_report, sep="\n")

Simple Loop: 47.114200 seconds
Lax Loop: 22.735680 seconds
Vectorized Lax Loop: 0.058187 seconds
