# Frozen Lake: Bellman Equations

In reinforcement learning, an agent acts in an environment and learns by trial and error to maximize its cumulative reward. The environment can be formalized as a Markov Decision Process (MDP). To solve the MDP, value functions are defined that give the expected cumulative reward from a given state. These value functions satisfy the Bellman equations, so solving the MDP comes down to solving these recursive equations. Because there is one equation per state, the system is usually solved iteratively with dynamic programming (DP). This notebook discusses the Bellman update used in the Value Iteration algorithm and implements it with NumPy. Four implementations (nested loops, list comprehensions, a loop over states with array operations, and array operations only) are timed to see which is computationally most efficient.

The update is given below in two equivalent forms, where $P_{sas'}$ is the probability of moving from state $s$ to state $s'$ under action $a$, $R_{sas'}$ is the corresponding reward, $\gamma$ is the discount factor, and $\pi(a|s)$ is the probability that the policy selects action $a$ in state $s$. Note that, unlike the textbook Value Iteration update, the update used here multiplies each action's term by $\pi(a|s)$ before taking the maximum.

$ V_k(s) = \max_{a}\left( \pi(a|s) \sum\limits_{s'} P_{sas'}\left[R_{sas'} + \gamma V_{k-1}(s')\right] \right) $

$ V_k(s) = \max_{a}\left( \pi(a|s) \left[ \sum\limits_{s'} P_{sas'} R_{sas'} + \sum\limits_{s'} \gamma P_{sas'} V_{k-1}(s') \right] \right) $
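The second form separates a term that does not depend on $V_{k-1}$ from one that does, which is what allows the reward term to be precomputed once. As a minimal sketch, the equivalence of the two forms can be checked numerically on a small random MDP. The arrays `P`, `R`, `pi` and the value of `gamma` below are made up for illustration and are not taken from the Frozen Lake environment.

```python
import numpy as np

rng = np.random.default_rng(0)
nstates, nactions, gamma = 4, 2, 0.9  # hypothetical sizes and discount factor

# Hypothetical random MDP: P[s, a, s'] sums to 1 over s'
P = rng.random((nstates, nactions, nstates))
P /= P.sum(axis=2, keepdims=True)
R = rng.random((nstates, nactions, nstates))
pi = np.full((nstates, nactions), 1.0 / nactions)  # uniform random policy
V_prev = rng.random(nstates)

# First form: sum over s' of P * (R + gamma * V)
Q1 = pi * np.sum(P * (R + gamma * V_prev), axis=2)

# Second form: precomputable reward term plus discounted value term
Q2 = pi * (np.sum(P * R, axis=2) + gamma * (P @ V_prev))

V1 = Q1.max(axis=1)
V2 = Q2.max(axis=1)
print(np.allclose(V1, V2))
```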


In [1]:
from ReinforcementLearning import *
import numpy as np  # explicit import rather than relying on the star import
import time

In [2]:
env = FrozenLake.make()
mdp = GymMDP(env)
policy = UniformRandomPolicy(env)

In [3]:
nstates = env.nstates
nactions = env.nactions
niter = 10000

## Loops only

In [4]:
time1 = time.time()

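Vs = np.zeros(nstates)

for i in range(niter):
    for s in range(nstates):
        Qsa = []  # policy-weighted value of each action in state s
        for a in range(nactions):
            qsum = 0.0
            # expectation over next states n of reward plus discounted value
            for n in range(nstates):
                qsum += mdp.Psas[s, a, n] * (mdp.Rsas[s, a, n] + mdp.gamma * Vs[n])
            Qsa.append(policy.prob[s, a] * qsum)
        Vs[s] = max(Qsa)  # in-place update: later states in this sweep see the new value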

time2 = time.time()
print(time2-time1)
print(Vs)

15.728909969329834
[1.55787288e-06 5.61039750e-06 6.01564996e-05 6.01564996e-06
 1.15262041e-05 0.00000000e+00 6.55705846e-04 0.00000000e+00
 1.25230373e-04 1.36600790e-03 7.80831365e-03 0.00000000e+00
 0.00000000e+00 8.45855072e-03 9.16780501e-02 0.00000000e+00]
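A single `time.time()` difference, as used in the cells here, can be noisy because it captures whatever else the machine is doing during that one run. A hedged alternative is to repeat the measurement with `timeit.repeat` and keep the fastest run. The sketch below applies this to the loop-only update on a made-up random MDP. The names `P`, `R`, `pi`, `run_loops` and the value of `gamma` are illustrative assumptions, not part of the `ReinforcementLearning` module.

```python
import timeit
import numpy as np

rng = np.random.default_rng(0)
nstates, nactions, gamma = 16, 4, 0.95  # hypothetical sizes and discount factor

# Hypothetical random MDP standing in for mdp.Psas, mdp.Rsas, policy.prob
P = rng.random((nstates, nactions, nstates))
P /= P.sum(axis=2, keepdims=True)
R = rng.random((nstates, nactions, nstates))
pi = np.full((nstates, nactions), 1.0 / nactions)

def run_loops(niter=20):
    """Loop-only update, same structure as the cell above."""
    V = np.zeros(nstates)
    for _ in range(niter):
        for s in range(nstates):
            q = []
            for a in range(nactions):
                total = 0.0
                for n in range(nstates):
                    total += P[s, a, n] * (R[s, a, n] + gamma * V[n])
                q.append(pi[s, a] * total)
            V[s] = max(q)
    return V

# Run the function three times, once per measurement, and keep the fastest
times = timeit.repeat(run_loops, number=1, repeat=3)
best = min(times)
print(f"best of 3: {best:.4f} s")
```

Taking the minimum rather than the mean is a common choice for micro-benchmarks, since slower runs mostly reflect interference from other processes.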


## Outer loop and list comprehensions

In [5]:
time1 = time.time()

Vs = np.zeros(nstates)

for i in range(niter):
    for s in range(nstates):
        # inner comprehension sums over next states n, outer one lists the actions a
        Vs[s] = max([policy.prob[s, a] *
                     sum([mdp.Psas[s, a, n] *
                          (mdp.Rsas[s, a, n] + mdp.gamma * Vs[n])
                          for n in range(nstates)])
                     for a in range(nactions)])

time2 = time.time()
print(time2-time1)
print(Vs)

15.292079448699951
[1.55787288e-06 5.61039750e-06 6.01564996e-05 6.01564996e-06
 1.15262041e-05 0.00000000e+00 6.55705846e-04 0.00000000e+00
 1.25230373e-04 1.36600790e-03 7.80831365e-03 0.00000000e+00
 0.00000000e+00 8.45855072e-03 9.16780501e-02 0.00000000e+00]


## Outer loop and array operations

In [6]:
time1 = time.time()

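Vs = np.zeros(nstates)
# precompute the terms that do not depend on Vs
PR = np.sum(mdp.Psas * mdp.Rsas, axis=2)  # expected immediate reward per (s, a)
gPsas = mdp.gamma * mdp.Psas  # discounted transition probabilities

for i in range(niter):
    for s in range(nstates):
        # one dot product per state replaces the loops over actions and next states
        Vs[s] = np.max(policy.prob[s, :] *
                       (PR[s, :] + np.squeeze(np.dot(gPsas[s, :, :], Vs))))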

time2 = time.time()
print(time2-time1)
print(Vs)

1.7982220649719238
[1.55787288e-06 5.61039750e-06 6.01564996e-05 6.01564996e-06
 1.15262041e-05 0.00000000e+00 6.55705846e-04 0.00000000e+00
 1.25230373e-04 1.36600790e-03 7.80831365e-03 0.00000000e+00
 0.00000000e+00 8.45855072e-03 9.16780501e-02 0.00000000e+00]


## Array operations only

In [7]:
time1 = time.time()

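nsa = nstates * nactions
Vs = np.zeros(nstates)
# flatten the (state, action) axes so one matrix-vector product covers all pairs
prob = np.reshape(policy.prob, (nsa,), order="c")
PR = np.sum(mdp.Psas * mdp.Rsas, axis=2)  # expected immediate reward per (s, a)
PR = np.reshape(PR, (nsa, ), order="c")
gPsa = mdp.gamma * np.reshape(mdp.Psas, (nsa, nstates), order="c")

for i in range(niter):
    Q = prob * (PR + np.dot(gPsa, Vs))  # policy-weighted value of every (s, a)
    Q = np.reshape(Q, (nstates, nactions), order="c")
    Vs = np.max(Q, axis=1)  # maximum over actions; all states updated at once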

time2 = time.time()
print(time2-time1)
print(Vs)

0.14261531829833984
[1.55787288e-06 5.61039750e-06 6.01564996e-05 6.01564996e-06
 1.15262041e-05 0.00000000e+00 6.55705846e-04 0.00000000e+00
 1.25230373e-04 1.36600790e-03 7.80831365e-03 0.00000000e+00
 0.00000000e+00 8.45855072e-03 9.16780501e-02 0.00000000e+00]
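The versions that loop over states overwrite `Vs[s]` in place, so states later in a sweep already see updated values, whereas the array-only version updates all states from the previous iterate at once. Both schemes produce the same printed result here because, for a discount factor below 1, the update is a contraction with a single fixed point that both reach after enough iterations. The sketch below illustrates this on a made-up random MDP. The arrays `P`, `R`, `pi` and the value of `gamma` are assumptions for illustration only.

```python
import numpy as np

rng = np.random.default_rng(1)
nstates, nactions, gamma = 5, 3, 0.9  # hypothetical sizes and discount factor

# Hypothetical random MDP: P[s, a, s'] sums to 1 over s'
P = rng.random((nstates, nactions, nstates))
P /= P.sum(axis=2, keepdims=True)
R = rng.random((nstates, nactions, nstates))
pi = np.full((nstates, nactions), 1.0 / nactions)  # uniform random policy

PR = np.sum(P * R, axis=2)  # expected immediate reward per (s, a)

V_inplace = np.zeros(nstates)
V_sync = np.zeros(nstates)
for _ in range(200):
    # in-place sweep: each state uses values already updated in this sweep
    for s in range(nstates):
        V_inplace[s] = np.max(pi[s] * (PR[s] + gamma * (P[s] @ V_inplace)))
    # synchronous sweep: every state uses the previous iterate only
    V_sync = np.max(pi * (PR + gamma * (P @ V_sync)), axis=1)

print(np.allclose(V_inplace, V_sync))
```

The intermediate iterates of the two schemes generally differ, so the outputs would only be expected to agree once both have converged.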
