# State Value Functions for Frozen Lake

In [1]:
import gym
import numpy as np

import matplotlib
import matplotlib.pyplot as plt
from matplotlib.table import Table

matplotlib.use('Agg')

In [2]:
# generate grid of environment with state values
def draw_image(image, desc):
    fig, ax = plt.subplots()
    ax.set_axis_off()
    tb = Table(ax, bbox=[0, 0, 1, 1])

    nrows, ncols = image.shape
    width, height = 1.0 / ncols, 1.0 / nrows
    
    # Add cells
    for (i, j), val in np.ndenumerate(image):

        # add state labels
        if desc[i, j] == 'S':
            val = str(val) + " (S)"
        if desc[i, j] == 'F':
            val = str(val) + " (F')"
        if desc[i, j] == 'H':
            val = str(val) + " (H)"
        if desc[i, j] == 'G':
            val = str(val) + " (G')"
        
        tb.add_cell(i, j, width, height, text=val,
                    loc='center', facecolor='white')
        

    # Row and column labels...
    for i in range(len(image)):
        tb.add_cell(i, -1, width, height, text=i+1, loc='right',
                    edgecolor='none', facecolor='none')
        tb.add_cell(-1, i, width, height/2, text=i+1, loc='center',
                    edgecolor='none', facecolor='none')

    ax.add_table(tb)

In [3]:
DISCOUNT = 0.99
env = gym.make('FrozenLake-v0')
desc = env.desc

# make env description string
def decode(a):
    return a.decode("utf-8")
desc = np.reshape(np.array(list(map(decode, desc.flatten()))), env.desc.shape)

# turn state 0-15 into row-column position
def state_to_index(s, ncol):
    i, j = s // ncol, s % ncol
    return i, j
# turn row-column position into 0-15 state
def index_to_state(row, col, ncol):
    return row*ncol + col

# State Values Function V(s)

In [4]:
env.reset()

value = np.zeros((env.nrow, env.ncol))
itera = 0
while True:
    # keep iteration until convergence
    new_value = np.zeros_like(value)
    for i in range(env.nrow):
        for j in range(env.ncol):
            for action in range(env.env.nA):
                state = index_to_state(i,j,env.ncol)
                action_results = env.env.P[state][action]
                # average expected value for taking each action in state
                # result[0] = probability of taking the action
                # result[1] = next state
                # result[2] = reward (0 or 1)
                new_val = sum([result[0] * (result[2] + DISCOUNT * value[state_to_index(result[1], env.ncol)[0], state_to_index(result[1], env.ncol)[1]]) for result in action_results])
                new_value[i, j] += (1/env.env.nA) * new_val

    itera += 1
    if np.sum(np.abs(value - new_value)) < 1e-4:
        draw_image(np.round(new_value, decimals=2), desc)
        plt.savefig('../images/FL_figure_3_2.png')
        plt.close()
        break
    value = new_value
print(value)

[[0.01225205 0.01035079 0.01927327 0.00942753]
 [0.01472573 0.         0.03886867 0.        ]
 [0.03256587 0.08431524 0.13779117 0.        ]
 [0.         0.17032773 0.43356318 0.        ]]


# Optimal Value Function

In [5]:
env.reset()

value = np.zeros((env.nrow, env.ncol))
itera = 0
while True:
    # keep iteration until convergence
    new_value = np.zeros_like(value)
    for i in range(env.nrow):
        for j in range(env.ncol):
            values = []
            for action in range(env.env.nA):
                state = index_to_state(i,j,env.ncol)
                action_results = env.env.P[state][action]
                # average expected value for taking each action in state
                values.append(sum([result[0] * (result[2] + DISCOUNT * value[state_to_index(result[1], env.ncol)[0], state_to_index(result[1], env.ncol)[1]]) for result in action_results]))
                new_value[i, j] = np.max(values)

    itera += 1
    if np.sum(np.abs(value - new_value)) < 1e-4:
        draw_image(np.round(new_value, decimals=2), desc)
        plt.savefig('../images/FL_figure_3_5.png')
        plt.close()
        break
    value = new_value
print(value)

[[0.54172703 0.49840409 0.47022545 0.45634454]
 [0.55817381 0.         0.35813435 0.        ]
 [0.59156353 0.64290366 0.61505219 0.        ]
 [0.         0.74159535 0.86277252 0.        ]]
