# Import Libraries

# In [5]:
import numpy as np
import plotly.express as px

# Environment Setting

# In [6]:
# Maze layout, one string per grid row: "0" is a wall, "1" is a walkable cell.
_MAZE_ROWS = (
    "00000000",
    "01111110",
    "11001010",
    "01100110",
    "00110100",
    "01010110",
    "01111011",
    "00000000",
)
MAZE = np.array([[int(cell) for cell in row] for row in _MAZE_ROWS])

# (row, column) of the goal. Column 8 is one step past the right edge of the
# 8-column grid, directly beside the open cell at (6, 7).
GOAL = (6, 8)

REWARD = -1  # reward received for every move
GAMMA = 0.95  # discount factor applied to the next state's value

# Grid dimensions, used for bounds checks and for sizing the value tables.
ROW = MAZE.shape[0]
COL = MAZE.shape[1]

# 1. Policy Evaluation

# In [7]:
# Iterative policy evaluation for the uniform-random policy: each of the four
# moves is weighted equally (the backed-up values are averaged), every move
# yields REWARD, and the value of the next cell is discounted by GAMMA.
#
# Transition rules used by the backup:
#   * a move that lands on GOAL contributes only REWARD (no future value);
#   * a move onto a path cell contributes REWARD + GAMMA * v[next cell];
#   * a move into a wall or off the grid leaves the agent where it is, so it
#     contributes REWARD + GAMMA * v[current cell].
ACTIONS = [(0, 1), (0, -1), (1, 0), (-1, 0)]  # (row offset, column offset)

# Stop once no value changes by more than this between two sweeps. Comparing
# floats for exact equality would only stop after every last bit has settled.
TOLERANCE = 1e-10

v = np.zeros((ROW, COL))
while True:
    # Synchronous sweep: read from v, write into next_v. Wall cells are never
    # written, so they keep their initial value of 0.
    next_v = v.copy()
    for i in range(ROW):
        for j in range(COL):
            if MAZE[i][j] == 0:
                continue
            total = 0.0
            for di, dj in ACTIONS:
                ni, nj = i + di, j + dj
                if (ni, nj) == GOAL:
                    total += REWARD
                elif 0 <= ni < ROW and 0 <= nj < COL and MAZE[ni][nj] == 1:
                    total += REWARD + GAMMA * v[ni][nj]
                else:
                    total += REWARD + GAMMA * v[i][j]
            next_v[i][j] = total / len(ACTIONS)
    # Largest absolute change of any cell in this sweep.
    delta = np.abs(next_v - v).max()
    v = next_v  # next_v is rebuilt from v at the top of the loop; no copy needed
    if delta < TOLERANCE:
        break

# Heat map of the converged state values, with each value printed in its cell.
fig = px.imshow(v, text_auto=True)
fig.update_layout(
    title="Policy Evaluation",
)
fig.show()

# 2. Value Iteration

# In [8]:
# Value iteration: each path cell is repeatedly assigned the best backed-up
# value over the four moves. Every move yields REWARD and the value of the
# next cell is discounted by GAMMA.
#
# Transition rules used by the backup:
#   * a move that lands on GOAL is worth REWARD alone (no future value);
#   * a move onto a path cell is worth REWARD + GAMMA * v[next cell];
#   * a move into a wall or off the grid leaves the agent where it is, so it
#     is worth REWARD + GAMMA * v[current cell].
ACTIONS = [(0, 1), (0, -1), (1, 0), (-1, 0)]  # (row offset, column offset)

# Stop once no value changes by more than this between two sweeps, rather
# than requiring the two float arrays to be exactly equal.
TOLERANCE = 1e-10

v = np.zeros((ROW, COL))
while True:
    # Synchronous sweep: read from v, write into next_v. Wall cells are never
    # written, so they keep their initial value of 0.
    next_v = v.copy()
    for i in range(ROW):
        for j in range(COL):
            if MAZE[i][j] == 0:
                continue
            best = -np.inf
            for di, dj in ACTIONS:
                ni, nj = i + di, j + dj
                if (ni, nj) == GOAL:
                    q = REWARD
                elif 0 <= ni < ROW and 0 <= nj < COL and MAZE[ni][nj] == 1:
                    q = REWARD + GAMMA * v[ni][nj]
                else:
                    q = REWARD + GAMMA * v[i][j]
                best = max(best, q)
            next_v[i][j] = best
    # Largest absolute change of any cell in this sweep.
    delta = np.abs(next_v - v).max()
    v = next_v  # next_v is rebuilt from v at the top of the loop; no copy needed
    if delta < TOLERANCE:
        break

# Heat map of the converged state values, with each value printed in its cell.
fig = px.imshow(v, text_auto=True)
fig.update_layout(
    title="Value Iteration",
)
fig.show()