In [30]:
import time

import jax.lax as lax
from jax import random
import jax.numpy as jnp

from make_mdp import MDP


In [31]:
# Size of the randomly generated MDP
num_states, num_actions, num_rewards = 30, 15, 10

# Reward mean / std handed to the MDP constructor below
# (presumably the parameters of the reward distribution — see make_mdp.MDP)
reward_mean, reward_std = 0.0, 10.0

# gamma in the Bellman equation
discount_factor = 0.9

# Root seed: every random draw in this notebook is derived from it
seed = 42

In [32]:
# Root (typed) PRNG key; later cells split subkeys off it, so the whole notebook is reproducible from `seed`.
key = random.key(seed=seed)

In [33]:
# Build the random MDP from a fresh subkey; `key` is carried forward for later draws.
key, subkey = random.split(key, num=2)
mdp = MDP(
    subkey,
    num_states,
    num_actions,
    num_rewards,
    reward_mean,
    reward_std,
)
del subkey  # each subkey is used once, then discarded

In [34]:
# Random stochastic policy pi(a|s) as a (num_states, num_actions) matrix:
# draw uniform weights, then rescale each row so that sum(policy[s, :]) == 1.
key, subkey = random.split(key)
policy = random.uniform(subkey, (num_states, num_actions))
del subkey
policy = policy / jnp.sum(policy, axis=1, keepdims=True)

### Iterative Policy Evaluation

In [35]:
# Convergence threshold on the largest per-state change between two sweeps.
# NOTE: the value estimates below are float32 and about 3.4 in magnitude, where
# neighbouring float32 numbers are ~2.4e-7 apart. `delta < eps` therefore
# effectively demands an exact float32 fixed point (delta == 0).
eps = 10.0 ** -8

In [36]:
# Random initial value estimate: one float32 entry per state. The iterative
# methods below all start from copies of this same vector, so their results and
# iteration counts are comparable.
key, subkey = random.split(key)
# `shape` is a sequence of dimensions, so pass a 1-tuple rather than a bare int.
v_pi_0 = random.uniform(subkey, (num_states,), dtype=jnp.float32)
del subkey  # consumed; same key hygiene as the earlier cells

In Sutton and Barto we have the update rule (p.74):
$$v_{k+1}(s) = \sum_a \pi(a|s) \sum_{s',r} p(s',r|s,a) [r + \gamma v_k(s')]$$

Since our MDP is not defined via the four-argument probability $p(s',r|s,a)$ but via state-transition probabilities $p(s'|s,a)$ and reward probabilities $p(r|s,a)$, we rewrite this as:

$$
\begin{aligned}
v_{k+1}(s) &= \sum_a \pi(a|s) \sum_{s',r} p(s',r|s,a) [r + \gamma v_k(s')]\\
           &= \sum_a \pi(a|s) [\sum_{s',r} p(s',r|s,a) r + \gamma \sum_{s',r} p(s',r|s,a) v_k(s')]\\
           &= \sum_a \pi(a|s) [\sum_{r} r \sum_{s'} p(s',r|s,a) + \gamma \sum_{s'} v_k(s') \sum_{r} p(s',r|s,a) ]\\
           &= \sum_a \pi(a|s) [\sum_{r} r p(r|s,a) + \gamma \sum_{s'} v_k(s') p(s'|s,a) ]\\
\end{aligned}
$$

$\sum_{r} r p(r|s,a)$ is simply the expected reward when choosing action $a$ at state $s$. This is pre-computed as `mdp.expected_rewards`.
$\gamma \sum_{s'} v_k(s') p(s'|s,a)$ is the discounted expectation of the value over the next states when choosing action $a$. 


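Simple while-loop implementation (synchronous updates): every new value $v_{k+1}(s)$ is computed from the previous sweep's values $v_k$ (`v_pi_old`), not from values already updated earlier in the same sweep.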

In [37]:
# Plain Python implementation: one eager JAX op per (state, action) pair.
# This is the slow baseline for the timing comparison at the end.

# Safety cap on sweeps. The values are float32 and about 3.4 in magnitude here,
# where adjacent float32 numbers are ~2.4e-7 apart, so `delta < eps` (eps = 1e-8)
# only holds at an exact float32 fixed point. If rounding made the iterates
# oscillate between neighbouring floats, the loop would otherwise never end.
max_iterations = 1_000

v_pi = jnp.copy(v_pi_0)
num_iterations = 0

start = time.time()

while num_iterations < max_iterations:
    num_iterations += 1

    delta = 0
    # JAX arrays are immutable: `.at[s].set` below returns a new array, so this
    # name keeps the previous sweep's values (synchronous update) without a copy.
    v_pi_old = v_pi
    for s in range(num_states):
        v_pi_s = 0
        for a in range(num_actions):
            policy_s_a = policy[s, a]
            # pi(a|s) * expected immediate reward of (s, a)
            v_pi_s += policy_s_a * mdp.expected_rewards[s, a]
            # pi(a|s) * gamma * expected next-state value under the previous sweep
            v_pi_s += policy_s_a * discount_factor * jnp.dot(v_pi_old, mdp.transition_ps[s, a, :])
        v_pi = v_pi.at[s].set(v_pi_s)

        delta = max(delta, jnp.abs(v_pi_old[s] - v_pi[s]))

    if delta < eps:
        break
else:
    # Only reached when the cap stops the loop (no `break`).
    print(f"Warning: stopped after {max_iterations} iterations without reaching eps")

_ = v_pi.block_until_ready()

end = time.time()

elapsed_loop = end - start
print(f"Iterations {num_iterations} - Delta {delta:.5f} - Time: {elapsed_loop:.6f} seconds")
print(f"Value function:\n{v_pi}")


Iterations 151 - Delta 0.00000 - Time: 20.032519 seconds
Value function:
[-3.35251   -3.3723989 -3.3082206 -3.3765604 -3.371306  -3.4147203
 -3.476385  -3.41075   -3.4429617 -3.3209078 -3.4015646 -3.35638
 -3.410884  -3.3508618 -3.3867407 -3.3969636 -3.3065453 -3.3682504
 -3.3791513 -3.4573994 -3.4402492 -3.4397793 -3.4604473 -3.289907
 -3.3306336 -3.2623203 -3.434849  -3.3700235 -3.4616327 -3.424651 ]


`lax.while_loop` implementation (same synchronous update, compiled by JAX)

In [38]:
# Same synchronous sweep as the simple loop, but driven by `lax.while_loop`.
# The Python `for` loops are unrolled while JAX traces the body
# (num_states * num_actions scalar updates), which is likely where most of the
# measured time goes.

# Safety cap on sweeps: with float32 values of magnitude ~3.4, `delta > eps`
# (eps = 1e-8) only becomes false at an exact float32 fixed point, so bound the
# loop explicitly instead of relying on that.
max_iterations = 1_000

def iterative_value_estimation(state):
    """One synchronous policy-evaluation sweep, state by state.

    Args:
        state: loop carry ``(delta, v_pi, iteration)``; the incoming delta is unused.

    Returns:
        ``(delta, v_pi_new, iteration + 1)`` where ``delta`` is the largest
        absolute change of any state value during this sweep.
    """
    _, v_pi, iteration = state
    delta = 0
    iteration += 1
    # JAX arrays are immutable, so this keeps the previous sweep's values while
    # `v_pi` is rebuilt state by state below.
    v_pi_old = v_pi
    for s in range(num_states):
        v_pi_s = 0
        for a in range(num_actions):
            policy_s_a = policy[s, a]
            v_pi_s += policy_s_a * mdp.expected_rewards[s, a]
            v_pi_s += policy_s_a * discount_factor * jnp.dot(v_pi_old, mdp.transition_ps[s, a, :])
        v_pi = v_pi.at[s].set(v_pi_s)

        delta = jnp.maximum(delta, jnp.abs(v_pi_old[s] - v_pi[s]))
    return (delta, v_pi, iteration)

def cond_function(state):
    """Continue while the last sweep moved some value by more than eps and the cap is not reached."""
    delta, _, iteration = state
    return jnp.logical_and(delta > eps, iteration < max_iterations)

v_pi = jnp.copy(v_pi_0)
# delta starts at +inf so the first sweep always runs.
init_state = (jnp.inf, v_pi, 0)

start = time.time()
delta, v_pi, num_iterations = lax.while_loop(cond_function, iterative_value_estimation, init_state)
# JAX dispatches asynchronously: wait for the result so the timing covers the
# actual computation, as in the simple-loop cell.
v_pi.block_until_ready()
end = time.time()

if delta > eps:
    print(f"Warning: stopped after {max_iterations} iterations without reaching eps")

elapsed_lax_loop = end - start
print(f"Iterations {num_iterations} - Delta {delta:.5f} - Time: {elapsed_lax_loop:.6f} seconds")
print(f"Value function:\n{v_pi}")

Iterations 161 - Delta 0.00000 - Time: 4.081857 seconds
Value function:
[-3.3525105 -3.372399  -3.3082213 -3.3765607 -3.3713064 -3.4147208
 -3.4763856 -3.4107502 -3.4429622 -3.3209085 -3.4015653 -3.3563802
 -3.4108846 -3.350862  -3.3867414 -3.3969638 -3.3065462 -3.3682508
 -3.3791518 -3.4573998 -3.4402494 -3.4397798 -3.4604475 -3.2899072
 -3.3306339 -3.2623205 -3.4348495 -3.370024  -3.4616337 -3.424651 ]


Vectorized update inside `lax.while_loop` (the whole sweep as array operations)


In [39]:
# The whole sweep as array operations inside `lax.while_loop`.

# Safety cap on sweeps: with float32 values of magnitude ~3.4, `delta > eps`
# (eps = 1e-8) only becomes false at an exact float32 fixed point, so bound the
# loop explicitly instead of relying on that.
max_iterations = 1_000

def vectorized_value_estimation(state):
    """One synchronous policy-evaluation sweep over all states at once.

    Args:
        state: loop carry ``(delta, v_pi, iteration)``; the incoming delta is unused.

    Returns:
        ``(delta, v_pi_new, iteration + 1)`` where ``delta`` is the largest
        absolute change of any state value in this sweep.
    """
    _, v_pi, iteration = state
    iteration += 1
    # sum_a pi(a|s) * expected reward of (s, a), for every s
    expected_r = jnp.sum(policy * mdp.expected_rewards, axis=1)
    # sum_a pi(a|s) * sum_s' p(s'|s,a) * v_pi(s'), for every s
    expected_v = jnp.sum(policy[:, :, None] * mdp.transition_ps * v_pi[None, None, :], axis=(1, 2))
    v_pi_new = expected_r + discount_factor * expected_v
    delta = jnp.max(jnp.abs(v_pi_new - v_pi))
    return (delta, v_pi_new, iteration)

def cond_function(state):
    """Continue while the last sweep moved some value by more than eps and the cap is not reached."""
    delta, _, iteration = state
    return jnp.logical_and(delta > eps, iteration < max_iterations)

v_pi = jnp.copy(v_pi_0)
# delta starts at +inf so the first sweep always runs.
init_state = (jnp.inf, v_pi, 0)

start = time.time()
delta, v_pi, num_iterations = lax.while_loop(cond_function, vectorized_value_estimation, init_state)
# JAX dispatches asynchronously: wait for the result so the timing covers the
# actual computation, as in the simple-loop cell.
v_pi.block_until_ready()
end = time.time()

if delta > eps:
    print(f"Warning: stopped after {max_iterations} iterations without reaching eps")

elapsed_vectorized_lax_loop = end - start
print(f"Iterations {num_iterations} - Delta {delta:.5f} - Time: {elapsed_vectorized_lax_loop:.6f} seconds")
print(f"Value function:\n{v_pi}")

Iterations 153 - Delta 0.00000 - Time: 0.045203 seconds
Value function:
[-3.3525121 -3.3724003 -3.3082228 -3.3765624 -3.3713074 -3.4147222
 -3.4763868 -3.4107516 -3.4429636 -3.320909  -3.401567  -3.3563805
 -3.410886  -3.350864  -3.3867414 -3.396964  -3.3065472 -3.36825
 -3.3791535 -3.4574013 -3.4402509 -3.4397807 -3.4604495 -3.2899094
 -3.3306346 -3.2623215 -3.4348507 -3.370025  -3.461634  -3.424654 ]


Alternatively, we can directly solve the underlying linear system.
The original Bellman equation was:

$$v_\pi(s) = \sum_a \pi(a|s) [\sum_r r p(r|s,a) + \gamma \sum_{s'} v_\pi(s') p(s'|s,a)]$$

Moving all terms that contain $v_\pi$ to one side, we can rewrite this as:

$$\sum_a \pi(a|s) \sum_r r p(r|s,a) = v_\pi(s) - \gamma \sum_{s'} v_\pi(s') \sum_a \pi(a|s) p(s'|s,a)$$

For a finite MDP, this can be rewritten as the linear system:

$$\mathbf{r} = (\mathbf{I} - \gamma \mathbf{P}) \mathbf{v}$$

where:

$$
\begin{aligned}
    \mathbf{r}_i &= \sum_a \pi(a|s_i) \sum_r r\, p(r|s_i,a)\\
    \mathbf{P}_{i,j} &= \sum_a \pi(a|s_i)\, p(s_j|s_i,a)
\end{aligned}
$$

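Here $s_i$ denotes the $i$-th state. Each row of $\mathbf{P}$ sums to one, since $\sum_j \mathbf{P}_{i,j} = \sum_a \pi(a|s_i) \sum_j p(s_j|s_i,a) = 1$. Together with $\gamma < 1$, this makes $\mathbf{I} - \gamma \mathbf{P}$ invertible, so the system has the unique solution $\mathbf{v} = \mathbf{v}_\pi$.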

In [40]:
# Solve (I - gamma * P) v = r directly, with r and P the policy-averaged
# expected rewards and state-transition matrix defined in the markdown above.
start = time.time()
r = jnp.sum(policy * mdp.expected_rewards, axis=1)
P = jnp.sum(policy[:, :, None] * mdp.transition_ps, axis=1)
v_pi = jnp.linalg.solve(jnp.eye(num_states) - discount_factor * P, r)
# JAX dispatches asynchronously: wait for the solve before stopping the clock,
# otherwise the measured time may not include the computation itself.
v_pi.block_until_ready()
end = time.time()
elapsed_linear_system = end - start
print(f"Time: {elapsed_linear_system:.6f} seconds")
print(f"Value function:\n{v_pi}")


Time: 0.360672 seconds
Value function:
[-3.35251   -3.3723993 -3.308222  -3.376561  -3.3713067 -3.4147213
 -3.4763858 -3.4107506 -3.4429624 -3.3209093 -3.4015653 -3.3563802
 -3.4108849 -3.350863  -3.3867414 -3.3969636 -3.3065467 -3.3682516
 -3.379152  -3.4574    -3.4402497 -3.43978   -3.4604478 -3.2899075
 -3.3306339 -3.2623205 -3.4348502 -3.3700242 -3.461634  -3.4246514]


In [41]:
# Side-by-side wall-clock comparison of the four approaches (seconds).
timings = {
    'Simple Loop': elapsed_loop,
    'Lax Loop': elapsed_lax_loop,
    'Vectorized Lax Loop': elapsed_vectorized_lax_loop,
    'Linear system': elapsed_linear_system,
}
for label, seconds in timings.items():
    print(f'{label}: {seconds:.6f} seconds')

Simple Loop: 20.032519 seconds
Lax Loop: 4.081857 seconds
Vectorized Lax Loop: 0.045203 seconds
Linear system: 0.360672 seconds
