# Policy Evaluation

**Author:** ZHENG Wenjie

**Last Update:** 2021-08-25

This notebook accompanies Section 4.1 of the book. It studies the policy evaluation problem on GridWorld. In particular, it implements three methods: solving the linear system directly, fixed-point iteration, and in-place fixed-point iteration.

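The convergence of the iterative algorithms is guaranteed by the contraction mapping (Banach fixed-point) theorem. The transition matrix is a stochastic matrix (each row sums to 1, so its spectral radius is 1), and multiplying it by $\gamma<1$ makes the Bellman operator a $\gamma$-contraction in the max norm. The 4×4 example below uses $\gamma = 1$, so there convergence relies instead on the terminal states. Under the uniform random policy, every state reaches a terminal state with probability 1. As a result, the transition matrix restricted to the non-terminal states has spectral radius strictly less than 1.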

The implementation takes advantage of the special structure of GridWorld: each state has exactly 4 actions, and each action leads to a deterministic next state and a deterministic reward. Therefore, only three matrices of size $(m \times n, 4)$ are needed to fully specify the policy evaluation problem: one for the next states, one for the rewards, and one for the policy.

For the plot renderer, I used 'notebook_connected' to reduce the file size. For personal use, replace it with 'notebook'.

In [1]:
import numpy as np
import plotly.graph_objects as go
import plotly.express as px
import plotly.figure_factory as ff
import plotly.io as pio
# 'notebook_connected' keeps the saved notebook file small.
pio.renderers.default = 'notebook_connected' # or 'notebook' for personal use

In [2]:
class Maze:
    """GridWorld MDP on an ``m x n`` grid with four deterministic actions per state.

    States are numbered row-major (``i = row * n + col``). ``self.link[i, a]``
    is the state reached by taking action ``a`` in state ``i``, and
    ``self.reward[i, a]`` is the reward received for it. Action columns are
    0: row - 1, 1: col - 1, 2: col + 1, 3: row + 1. A move that would leave
    the grid keeps the agent where it is.

    Parameters
    ----------
    m, n : int
        Number of grid rows and columns.
    exit : container of int, optional
        Terminal states. They have no successors (their ``link`` row is -1),
        and :meth:`evaluate` fixes their value at 0.
    teleport : dict[int, int], optional
        ``teleport[i] = j`` sends every action taken in state ``i`` to ``j``.
    reward : ndarray of shape (m*n, 4), optional
        Immediate reward for each (state, action) pair. Required by
        :meth:`evaluate`.
    """

    def __init__(self, m, n, exit=None, teleport=None, reward=None):
        self.m = m
        self.n = n
        self.exit = exit
        self.link = np.zeros((m*n, 4), dtype=int)
        for i in range(m*n):
            if exit is not None and i in exit:
                self.link[i, :] = -1              # terminal: no successor
            elif teleport is not None and i in teleport:
                self.link[i, :] = teleport[i]     # every action jumps to teleport[i]
            else:
                row, col = divmod(i, n)
                self.link[i] = [
                    i - n if row > 0 else i,        # action 0: row - 1
                    i - 1 if col > 0 else i,        # action 1: col - 1
                    i + 1 if col < n - 1 else i,    # action 2: col + 1
                    i + n if row < m - 1 else i,    # action 3: row + 1
                ]
        self.reward = reward

    def evaluate(self, γ, π=None, direct=True, inplace=False, n_iter=60):
        """Compute the state-value function of policy ``π``.

        Parameters
        ----------
        γ : float
            Discount factor. ``γ = 1`` is only well-posed when every state
            eventually reaches an exit under ``π``.
        π : ndarray of shape (m*n, 4), optional
            ``π[i, a]`` is the probability of taking action ``a`` in state
            ``i``. Defaults to the uniform random policy.
        direct : bool
            Solve the Bellman linear system ``(I - γA) v = b`` exactly.
        inplace : bool
            When ``direct`` is False, run in-place sweeps (each update
            immediately uses the newest values) instead of synchronous
            fixed-point iteration.
        n_iter : int
            Number of iterations (or sweeps) for the iterative methods.

        Returns
        -------
        direct=True
            ndarray of shape (m, n) holding the exact values.
        direct=False, inplace=False
            list of ``n_iter + 1`` flat arrays: the iterates, starting at 0.
        direct=False, inplace=True
            ``(v, trace)``. ``v`` is the flat final value array. ``trace[k]``
            is the value of state ``k % (m*n)`` right after its update in
            sweep ``k // (m*n)``, so it has ``n_iter * m * n`` entries.

        Raises
        ------
        ValueError
            If the maze was built without a reward table.
        """
        if self.reward is None:
            raise ValueError("Maze.evaluate requires a reward table; pass reward= to Maze()")
        m, n = self.m, self.n
        N = m * n
        if π is None:
            π = np.full((N, 4), 0.25)
        # Same membership test as __init__, so exits are identified identically.
        is_exit = np.array([self.exit is not None and i in self.exit
                            for i in range(N)], dtype=bool)

        if direct or not inplace:
            # Bellman equation v = b + γ A v, where A[i, j] is the probability
            # of moving from i to j under π. np.add.at accumulates repeated
            # targets (e.g. several actions bumping a wall all return to i);
            # plain fancy-index assignment would keep only one of them.
            active = np.flatnonzero(~is_exit)
            A = np.zeros((N, N))
            np.add.at(A, (np.repeat(active, 4), self.link[active].ravel()),
                      π[active].ravel())
            b = np.sum(π * self.reward, axis=1)
            # Terminal states are worth 0 whatever their reward row says,
            # matching the in-place branch, which never updates them.
            b[is_exit] = 0.0

            if direct:
                return np.linalg.solve(np.eye(N) - γ * A, b).reshape(m, n)

            trace = [np.zeros(N)]
            for _ in range(n_iter):
                trace.append(γ * (A @ trace[-1]) + b)
            return trace

        trace = []
        v = np.zeros(N)
        for _ in range(n_iter):
            for i in range(N):
                if not is_exit[i]:
                    # Uses already-updated values of earlier states in this sweep.
                    v[i] = π[i] @ (self.reward[i] + γ * v[self.link[i]])
                trace.append(v[i])
        return v, trace

## Gridworld 5x5

In [3]:
# Reward table for the 5x5 GridWorld: -1 for any move that would leave the
# grid. Every action in state 1 earns +10 and jumps to 21, and every action
# in state 3 earns +5 and jumps to 13.
reward = np.zeros((5*5, 4))
reward[:5, 0] = -1        # action 0 (row - 1) from the first row
reward[::5, 1] = -1       # action 1 (col - 1) from the first column
reward[4::5, 2] = -1      # action 2 (col + 1) from the last column
reward[20:, 3] = -1       # action 3 (row + 1) from the last row
reward[1] = 10
reward[3] = 5

maze1 = Maze(5, 5, teleport={1: 21, 3: 13}, reward=reward)

### Solve the linear system

In [4]:
v = maze1.evaluate(0.9, direct=True)  # exact solution, shape (5, 5)

In [5]:
# Reverse the row order (same as np.flipud) so grid row 0 is drawn at the top.
ff.create_annotated_heatmap(z=np.round(v, 1)[::-1])

### Fixed-point iteration

In [6]:
v = maze1.evaluate(0.9, direct=False, inplace=False)  # list of 61 iterates, v[0] all zeros

In [7]:
# Last fixed-point iterate (v[0] is the all-zero start): 25 state values, row-major.
v[-1]

array([ 3.30919888,  8.78949409,  4.4278208 ,  5.3225686 ,  1.49237943,
        1.52179059,  2.99252004,  2.25034156,  1.90777275,  0.5476034 ,
        0.051025  ,  0.73837276,  0.67331488,  0.35838728, -0.40294041,
       -0.97338981, -0.43529326, -0.35468064, -0.585404  , -1.18287432,
       -1.85749806, -1.3450291 , -1.22906562, -1.42271704, -1.97497827])

In [8]:
# Static view shows the last iterate; the Play button animates the first 10.
play_button = dict(label="Play", method="animate", args=[None])
fig = go.Figure(
    data=go.Heatmap(z=np.flipud(v[-1].reshape(5, 5)), zmin=-2, zmax=9, colorscale='Blues'),
    layout=go.Layout(
        title="Fixed-point iteration",
        updatemenus=[dict(type="buttons", buttons=[play_button])],
    ),
    frames=[go.Frame(data=go.Heatmap(z=np.flipud(z.reshape(5, 5)))) for z in v[:10]],
)

fig.show()

### In-place fixed-point iteration

In [9]:
v, trace = maze1.evaluate(0.9, inplace=True, direct=False)  # final values + per-update trace

In [10]:
# Values after 60 in-place sweeps over all 25 states (flat, row-major).
v

array([ 3.30901641,  8.78930766,  4.42763437,  5.32238134,  1.49219316,
        1.52160654,  2.99233348,  2.25015413,  1.90758496,  0.54741581,
        0.05084012,  0.73818574,  0.67312688,  0.35819896, -0.40312873,
       -0.97357513, -0.43548062, -0.35486898, -0.58559268, -1.18306305,
       -1.85768358, -1.34521662, -1.22925413, -1.4229059 , -1.97516718])

In [11]:
# Replay the in-place sweeps: each snapshot differs from the previous one in
# exactly one state, the one whose update produced trace[k].
frame = np.zeros(5*5)
v = [frame.copy()]
for k, value in enumerate(trace):
    frame[k % 25] = value
    v.append(frame.copy())

In [12]:
# Static view shows the final snapshot; the Play button animates the first
# 250 single-state updates (10 sweeps of 25 states).
play_button = dict(label="Play", method="animate", args=[None])
fig = go.Figure(
    data=go.Heatmap(z=np.flipud(v[-1].reshape(5, 5)), zmin=-2, zmax=9, colorscale='Blues'),
    layout=go.Layout(
        title="In-place fixed-point iteration",
        updatemenus=[dict(type="buttons", buttons=[play_button])],
    ),
    frames=[go.Frame(data=go.Heatmap(z=np.flipud(z.reshape(5, 5)))) for z in v[:250]],
)

fig.show()

## Gridworld 4x4

In [13]:
# 4x4 GridWorld: every move costs -1; states 0 and 15 are terminal with zero reward.
# np.float64 instead of the np.float_ alias, which NumPy 2.0 removed.
reward = np.full((4*4, 4), -1, dtype=np.float64)
reward[[0, 15], :] = 0
maze2 = Maze(4, 4, exit=[0, 15], reward=reward)

### Solve the linear system

In [14]:
v = maze2.evaluate(1, direct=True)  # exact values, undiscounted (γ = 1)

In [15]:
# Round to integers and reverse the row order (same as np.flipud) so grid row 0 is on top.
ff.create_annotated_heatmap(z=np.round(v, 0)[::-1])

### Fixed-point iteration

In [16]:
v = maze2.evaluate(1, direct=False, inplace=False)  # list of 61 iterates, v[0] all zeros

In [17]:
# Last fixed-point iterate for the 4x4 grid: 16 state values, row-major.
v[-1]

array([  0.        , -13.48894025, -19.24270148, -21.15253817,
       -13.48894025, -17.33286479, -19.2477673 , -19.24270148,
       -19.24270148, -19.2477673 , -17.33286479, -13.48894025,
       -21.15253817, -19.24270148, -13.48894025,   0.        ])

In [18]:
# Animate every fixed-point iterate, v[0] (all zeros) through v[-1], so the
# animation ends on the same converged values the static view shows.
# v holds n_iter + 1 = 61 arrays; the former range(60) dropped the last one.
play_button = dict(label="Play", method="animate", args=[None])
fig = go.Figure(
    data=go.Heatmap(z=np.flipud(v[-1].reshape(4, 4)), zmin=-22, zmax=0, colorscale='Reds', reversescale=True),
    layout=go.Layout(
        title="Fixed-point iteration",
        updatemenus=[dict(type="buttons", buttons=[play_button])],
    ),
    frames=[go.Frame(data=go.Heatmap(z=np.flipud(z.reshape(4, 4)))) for z in v],
)

fig.show()

### In-place fixed-point iteration

In [19]:
v, trace = maze2.evaluate(1, inplace=True, direct=False)  # final values + per-update trace

In [20]:
# Values after 60 in-place sweeps; terminal states 0 and 15 are never updated and stay 0.
v

array([  0.        , -13.92236299, -19.88838642, -21.87714856,
       -13.92236299, -17.90472921, -19.89655781, -19.89774061,
       -19.88838642, -19.89655781, -17.91271373, -13.934831  ,
       -21.87714856, -19.89774061, -13.934831  ,   0.        ])

In [21]:
# Replay the in-place sweeps: each snapshot differs from the previous one in
# exactly one state, the one whose update produced trace[k].
frame = np.zeros(4*4)
v = [frame.copy()]
for k, value in enumerate(trace):
    frame[k % 16] = value
    v.append(frame.copy())

In [22]:
# Animate every snapshot, from the all-zero start through the final one, so the
# animation ends on the values the static view shows.
# v holds 1 + 60 * 16 = 961 snapshots; the former range(960) dropped the last one.
play_button = dict(label="Play", method="animate", args=[None])
fig = go.Figure(
    data=go.Heatmap(z=np.flipud(v[-1].reshape(4, 4)), zmin=-22, zmax=0, colorscale='Reds', reversescale=True),
    layout=go.Layout(
        title="In-place fixed-point iteration",
        updatemenus=[dict(type="buttons", buttons=[play_button])],
    ),
    frames=[go.Frame(data=go.Heatmap(z=np.flipud(z.reshape(4, 4)))) for z in v],
)

fig.show()