In [17]:
import time

import jax.lax as lax
from jax import random
import jax.numpy as jnp

from make_mdp import MDP


In [18]:
# Problem size handed to the MDP constructor below.
num_states = 100   # number of states; also the length of every value vector
num_actions = 20   # number of actions available in each state
num_rewards = 10   # presumably the number of distinct reward values — verify against make_mdp

# Discount gamma applied to next-state values in the Bellman updates.
discount_factor = 0.9

# Seed for the root PRNG key, so the whole notebook is reproducible.
seed = 42

In [19]:
# Root PRNG key; each later cell splits a fresh subkey off it before drawing random numbers.
key = random.key(seed)

In [20]:
# Build the MDP from its own subkey, leaving `key` free for later draws.
# Later cells read mdp.expected_rewards[s, a] and mdp.transition_ps[s, a, s'].
key, subkey = random.split(key)
mdp = MDP(subkey, num_states, num_actions, num_rewards)
# Drop the consumed subkey so it cannot be reused by accident.
del subkey

In [21]:
# Random stochastic policy: a (num_states, num_actions) matrix in which
# row i holds the action probabilities pi(. | s_i).
key, subkey = random.split(key)
policy = random.uniform(subkey, (num_states, num_actions))
del subkey

# Rescale every row by its own sum so that it becomes a probability distribution.
policy = policy / jnp.sum(policy, axis=-1, keepdims=True)

### Iterative Policy Evaluation

In [22]:
# Convergence tolerance: sweeping stops once the largest per-state change is no longer above eps.
# NOTE(review): the value vectors are float32, whose spacing for magnitudes around 3 is
# about 2.4e-7, so a threshold of 1e-8 is effectively met only by a sweep that changes
# nothing at all — confirm that this strictness is intended.
eps = 10**(-8)

In [23]:
# Random initial value estimate, one float32 entry per state. Every solver below
# starts from a copy of this vector so that their results are comparable.
key, subkey = random.split(key)
# Pass the shape as a tuple, which is the documented form of the `shape` argument.
v_pi_0 = random.uniform(subkey, (num_states,), dtype=jnp.float32)
# Drop the consumed subkey, as the other key-splitting cells do.
del subkey

In Sutton and Barto we have the update rule (p.74):
$$v_{k+1}(s) = \sum_a \pi(a|s) \sum_{s',r} p(s',r|s,a) [r + \gamma v_k(s')]$$

Since our MDP is not defined via the four-argument probability $p(s',r|s,a)$, but rather via state transitions $p(s'|s,a)$ and reward probabilities $p(r|s,a)$, we rewrite this as:

$$
\begin{aligned}
v_{k+1}(s) &= \sum_a \pi(a|s) \sum_{s',r} p(s',r|s,a) [r + \gamma v_k(s')]\\
           &= \sum_a \pi(a|s) [\sum_{s',r} p(s',r|s,a) r + \gamma \sum_{s',r} p(s',r|s,a) v_k(s')]\\
           &= \sum_a \pi(a|s) [\sum_{r} r \sum_{s'} p(s',r|s,a) + \gamma \sum_{s'} v_k(s') \sum_{r} p(s',r|s,a) ]\\
           &= \sum_a \pi(a|s) [\sum_{r} r p(r|s,a) + \gamma \sum_{s'} v_k(s') p(s'|s,a) ]\\
\end{aligned}
$$

$\sum_{r} r p(r|s,a)$ is simply the expected reward when choosing action $a$ at state $s$. This is pre-computed as `mdp.expected_rewards`.
$\gamma \sum_{s'} v_k(s') p(s'|s,a)$ is the discounted expectation of the value over the next states when choosing action $a$. 


Simple while-loop implementation (synchronous updates)

In [24]:
# Naive baseline: one Python-level Bellman backup per (state, action) pair, repeated
# in synchronous sweeps until the largest per-state change falls below eps.
v_pi = jnp.copy(v_pi_0)
num_iterations = 0

# Safety cap on the number of sweeps. With float32 values a very small eps can only be
# met by a sweep that changes nothing, so without a bound an oscillation in the last
# bit would keep the loop running forever.
max_iterations = 10_000

# perf_counter is the clock intended for measuring short intervals.
start = time.perf_counter()

while num_iterations < max_iterations:
    num_iterations += 1

    delta = 0
    # Synchronous update: every backup in this sweep reads the previous sweep's values.
    v_pi_old = jnp.copy(v_pi)
    for s in range(num_states):
        v_pi_s = 0
        for a in range(num_actions):
            policy_s_a = policy[s, a]
            # pi(a|s) * expected immediate reward
            v_pi_s += policy_s_a * mdp.expected_rewards[s,a]
            # pi(a|s) * gamma * expected value of the next state
            v_pi_s += policy_s_a * discount_factor * jnp.dot(v_pi_old, mdp.transition_ps[s, a, :])
        v_pi = v_pi.at[s].set(v_pi_s)

        # Track the largest absolute change seen in this sweep.
        delta = max(delta, jnp.abs(v_pi_old[s] - v_pi[s]))

    if delta < eps:
        break

# Wait for any outstanding device work before stopping the clock.
_ = v_pi.block_until_ready()

end = time.perf_counter()

elapsed_loop = end - start
print(f"Iterations {num_iterations} - Delta {delta:.5f} - Time: {elapsed_loop:.6f} seconds")
print(f"Value function:\n{v_pi}")


Iterations 167 - Delta 0.00000 - Time: 79.704119 seconds
Value function:
[-3.2663145 -3.329989  -3.3181143 -3.3560846 -3.396226  -3.3184915
 -3.2981966 -3.2488885 -3.2935884 -3.3894997 -3.2358794 -3.3393958
 -3.2019312 -3.2801948 -3.352533  -3.3263695 -3.3368227 -3.2322514
 -3.3093536 -3.309447  -3.283536  -3.3124564 -3.2621002 -3.3357408
 -3.2462966 -3.3823938 -3.3462186 -3.285872  -3.3370242 -3.247308
 -3.2328722 -3.219432  -3.3058698 -3.4036813 -3.2634966 -3.308061
 -3.3138328 -3.260548  -3.319434  -3.1981363 -3.3334718 -3.2856119
 -3.2657304 -3.302262  -3.349907  -3.3335283 -3.3454852 -3.3032675
 -3.3417597 -3.3143177 -3.3914058 -3.3519425 -3.293618  -3.324379
 -3.223511  -3.3436432 -3.304626  -3.4594944 -3.2967875 -3.2219634
 -3.261729  -3.3466275 -3.3195114 -3.3208437 -3.3270507 -3.334924
 -3.2326894 -3.3307955 -3.3029542 -3.32861   -3.2193573 -3.2928288
 -3.3427389 -3.3128898 -3.2813003 -3.3573318 -3.2608452 -3.2219803
 -3.3352835 -3.348486  -3.357401  -3.2815063 -3.3321888 -3.3

Lax-While Loop

In [25]:
def iterative_value_estimation(state):
    """Run one synchronous policy-evaluation sweep with explicit Python loops.

    Args:
        state: loop carry ``(delta, v_pi, iteration)``. The incoming delta is
            ignored; it is recomputed from scratch for this sweep.

    Returns:
        ``(delta, v_pi, iteration)`` where delta is the largest absolute
        per-state change made by this sweep, v_pi is the updated value
        vector and iteration has been incremented by one.
    """
    _, values, iteration = state
    # Snapshot of the previous sweep; all backups below read from it.
    previous_values = jnp.copy(values)
    max_change = 0
    for state_idx in range(num_states):
        backup = 0
        for action_idx in range(num_actions):
            action_prob = policy[state_idx, action_idx]
            next_state_probs = mdp.transition_ps[state_idx, action_idx, :]
            # pi(a|s) * expected immediate reward
            backup += action_prob * mdp.expected_rewards[state_idx, action_idx]
            # pi(a|s) * gamma * expected value of the next state
            backup += action_prob * discount_factor * jnp.dot(previous_values, next_state_probs)
        values = values.at[state_idx].set(backup)

        change = jnp.abs(previous_values[state_idx] - values[state_idx])
        max_change = jnp.maximum(max_change, change)
    return (max_change, values, iteration + 1)

def cond_function(state):
    """Loop predicate for lax.while_loop: True while another sweep is needed.

    Args:
        state: loop carry ``(delta, v_pi, iteration)``.

    Returns:
        Boolean scalar that is True while delta is still above eps and the
        sweep cap has not been reached.
    """
    # Safety cap: a lax.while_loop whose predicate never turns False cannot be
    # interrupted from Python, so bound the number of sweeps in case the float32
    # iterates never settle below eps.
    max_iterations = 10_000
    delta, _, iteration = state
    return jnp.logical_and(delta > eps, iteration < max_iterations)

v_pi = jnp.copy(v_pi_0)
# Carry is (delta, v_pi, iteration). The huge initial delta is only a sentinel
# that makes the loop condition true on entry.
init_state = (1e13, v_pi, 0)

# perf_counter is the clock intended for measuring short intervals.
start = time.perf_counter()
delta, v_pi, num_iterations = lax.while_loop(cond_function, iterative_value_estimation, init_state)
# JAX dispatches work asynchronously: wait for the result before stopping the clock,
# otherwise the measurement may cover only the dispatch.
_ = v_pi.block_until_ready()
end = time.perf_counter()

# This single call also traces and compiles the loop, so that cost is part of the measured time.
elapsed_lax_loop = end - start
print(f"Iterations {num_iterations} - Delta {delta:.5f} - Time: {elapsed_lax_loop:.6f} seconds")
print(f"Value function:\n{v_pi}")

Iterations 164 - Delta 0.00000 - Time: 24.065244 seconds
Value function:
[-3.2663147 -3.3299887 -3.318114  -3.3560846 -3.396226  -3.3184915
 -3.2981966 -3.2488885 -3.2935884 -3.3894997 -3.2358794 -3.3393958
 -3.2019312 -3.2801945 -3.352533  -3.3263695 -3.3368227 -3.2322514
 -3.3093536 -3.309447  -3.2835362 -3.3124564 -3.2621    -3.3357408
 -3.2462966 -3.3823936 -3.3462186 -3.285872  -3.3370242 -3.247308
 -3.2328722 -3.2194319 -3.3058698 -3.4036813 -3.2634964 -3.308061
 -3.3138328 -3.260548  -3.3194335 -3.198136  -3.333472  -3.2856119
 -3.2657304 -3.302262  -3.3499074 -3.3335283 -3.3454852 -3.3032675
 -3.3417594 -3.3143177 -3.3914058 -3.3519425 -3.2936177 -3.3243787
 -3.2235112 -3.3436432 -3.304626  -3.459494  -3.2967875 -3.2219636
 -3.261729  -3.3466275 -3.3195114 -3.3208437 -3.3270507 -3.334924
 -3.2326896 -3.3307953 -3.3029542 -3.3286095 -3.2193573 -3.2928288
 -3.3427389 -3.3128898 -3.2812996 -3.3573318 -3.2608452 -3.22198
 -3.3352835 -3.348486  -3.3574011 -3.2815063 -3.3321886 -3.30

Vectorized update in lax loop


In [26]:
def vectorized_value_estimation(state):
    """Run one synchronous policy-evaluation sweep using array operations only.

    Args:
        state: loop carry ``(delta, v_pi, iteration)``. The incoming delta is
            ignored; it is recomputed for this sweep.

    Returns:
        ``(delta, v_pi_new, iteration)`` where delta is the largest absolute
        per-state change, v_pi_new is the updated value vector and iteration
        has been incremented by one.
    """
    _, values, iteration = state
    # Policy-weighted expected immediate reward per state.
    reward_term = (policy * mdp.expected_rewards).sum(axis=1)
    # Broadcast pi(a|s) * p(s'|s,a) * v(s') over (state, action, next state),
    # then sum out the action and next-state axes.
    joint = policy[:, :, None] * mdp.transition_ps * values[None, None, :]
    next_value_term = joint.sum(axis=(1, 2))
    updated = reward_term + discount_factor * next_value_term
    max_change = jnp.abs(updated - values).max()
    return (max_change, updated, iteration + 1)

def cond_function(state):
    """Loop predicate for lax.while_loop: True while another sweep is needed.

    Args:
        state: loop carry ``(delta, v_pi, iteration)``.

    Returns:
        Boolean scalar that is True while delta is still above eps and the
        sweep cap has not been reached.
    """
    # Safety cap: a lax.while_loop whose predicate never turns False cannot be
    # interrupted from Python, so bound the number of sweeps in case the float32
    # iterates never settle below eps.
    max_iterations = 10_000
    delta, _, iteration = state
    return jnp.logical_and(delta > eps, iteration < max_iterations)

v_pi = jnp.copy(v_pi_0)
# Carry is (delta, v_pi, iteration). The huge initial delta is only a sentinel
# that makes the loop condition true on entry.
init_state = (1e13, v_pi, 0)

# perf_counter is the clock intended for measuring short intervals.
start = time.perf_counter()
delta, v_pi, num_iterations = lax.while_loop(cond_function, vectorized_value_estimation, init_state)
# JAX dispatches work asynchronously: wait for the result before stopping the clock,
# otherwise the measurement may cover only the dispatch.
_ = v_pi.block_until_ready()
end = time.perf_counter()

# This single call also traces and compiles the loop, so that cost is part of the measured time.
elapsed_vectorized_lax_loop = end - start
print(f"Iterations {num_iterations} - Delta {delta:.5f} - Time: {elapsed_vectorized_lax_loop:.6f} seconds")
print(f"Value function:\n{v_pi}")

Iterations 166 - Delta 0.00000 - Time: 0.047804 seconds
Value function:
[-3.2663147 -3.329988  -3.3181143 -3.3560848 -3.3962255 -3.3184915
 -3.298196  -3.2488887 -3.293588  -3.3895004 -3.2358797 -3.3393962
 -3.2019315 -3.2801945 -3.352533  -3.3263698 -3.3368225 -3.232251
 -3.309354  -3.3094468 -3.2835362 -3.3124566 -3.2621007 -3.335741
 -3.2462964 -3.382394  -3.3462183 -3.2858715 -3.3370245 -3.2473085
 -3.2328722 -3.219432  -3.3058703 -3.4036813 -3.2634969 -3.3080614
 -3.3138323 -3.2605484 -3.3194337 -3.1981366 -3.3334718 -3.285612
 -3.2657304 -3.3022628 -3.3499076 -3.3335285 -3.3454866 -3.3032672
 -3.3417604 -3.3143175 -3.3914058 -3.3519423 -3.2936182 -3.3243787
 -3.2235107 -3.3436432 -3.3046265 -3.4594948 -3.2967873 -3.221964
 -3.2617295 -3.3466275 -3.319511  -3.320844  -3.3270502 -3.3349245
 -3.2326899 -3.3307958 -3.3029554 -3.3286097 -3.2193568 -3.2928293
 -3.3427386 -3.31289   -3.2812994 -3.3573322 -3.2608454 -3.2219803
 -3.3352838 -3.3484857 -3.357401  -3.2815065 -3.3321886 -3.30

Alternatively, we can directly solve the underlying linear system.
The original Bellman equation was:

$$v_\pi(s) = \sum_a \pi(a|s) [\sum_r r p(r|s,a) + \gamma \sum_{s'} v_\pi(s') p(s'|s,a)]$$

Which we can rewrite as:

$$\sum_a \pi(a|s) \sum_r r p(r|s,a) = v_\pi(s) - \gamma \sum_{s'} v_\pi(s') \sum_a \pi(a|s) p(s'|s,a)$$

For a finite MDP, this can be rewritten as the linear system:

$$\mathbf{r} = (\mathbf{I} - \gamma \mathbf{P}) \mathbf{v}$$

where:

$$
\begin{align}
    \mathbf{r}_i = \sum_a \pi(a|s_i) \sum_r r p(r|s_i,a)\\
    \mathbf{P}_{i,j} = \sum_a \pi(a|s_i) p(s_j|s_i,a) \\
\end{align}
$$

 (we use $s_i$ to denote the $i$-th state).

In [27]:
# perf_counter is the clock intended for measuring short intervals.
start = time.perf_counter()
# r[i]: policy-weighted expected immediate reward in state i.
r = jnp.sum(policy * mdp.expected_rewards, axis=1)
# P[i, j]: probability of moving from state i to state j when acting under the policy.
P = jnp.sum( policy[:, :, None] * mdp.transition_ps, axis=1  )
# Solve (I - gamma * P) v = r for the value function directly.
v_pi = jnp.linalg.solve(jnp.eye(num_states) - discount_factor * P, r)
# JAX dispatches work asynchronously: wait for the solve to finish before stopping
# the clock, otherwise the measurement may cover only the dispatch.
_ = v_pi.block_until_ready()
end = time.perf_counter()
elapsed_linear_system = end - start
print(f"Time: {elapsed_linear_system:.6f} seconds")
print(f"Value function:\n{v_pi}")


Time: 0.002164 seconds
Value function:
[-3.266314  -3.3299894 -3.3181143 -3.3560843 -3.396226  -3.318492
 -3.2981966 -3.2488885 -3.2935886 -3.3895006 -3.235879  -3.3393962
 -3.2019315 -3.280194  -3.3525338 -3.3263705 -3.3368232 -3.2322514
 -3.309354  -3.309448  -3.2835362 -3.3124561 -3.262101  -3.3357413
 -3.246297  -3.3823943 -3.3462183 -3.2858713 -3.337024  -3.2473087
 -3.2328722 -3.2194319 -3.3058708 -3.403682  -3.2634974 -3.3080616
 -3.3138325 -3.2605488 -3.3194335 -3.1981368 -3.3334723 -3.2856119
 -3.2657301 -3.3022628 -3.3499079 -3.3335283 -3.3454857 -3.3032675
 -3.3417597 -3.3143177 -3.3914056 -3.3519423 -3.2936182 -3.3243787
 -3.223511  -3.3436434 -3.3046255 -3.4594948 -3.2967877 -3.2219641
 -3.2617297 -3.3466284 -3.3195107 -3.3208444 -3.3270504 -3.3349252
 -3.23269   -3.3307958 -3.3029552 -3.3286092 -3.2193573 -3.2928288
 -3.3427396 -3.31289   -3.2813    -3.3573322 -3.2608452 -3.22198
 -3.3352842 -3.348486  -3.3574007 -3.2815065 -3.3321886 -3.3096225
 -3.384559  -3.2686405 -3.

In [28]:
# Wall-clock summary of the four approaches timed above, one per line.
print(
    f'Simple Loop: {elapsed_loop:.6f} seconds',
    f'Lax Loop: {elapsed_lax_loop:.6f} seconds',
    f'Vectorized Lax Loop: {elapsed_vectorized_lax_loop:.6f} seconds',
    f'Linear system: {elapsed_linear_system:.6f} seconds',
    sep='\n',
)

Simple Loop: 79.704119 seconds
Lax Loop: 24.065244 seconds
Vectorized Lax Loop: 0.047804 seconds
Linear system: 0.002164 seconds
