In [1]:
import numpy as np

In [2]:
def policy_evaluation(V: np.ndarray):
    while True:
        new_V = V.copy()
        for row in range(4):
            for col in range(4):
                if row == 0 and col == 0 or row == 3 and col == 3:
                    continue

                v = 0
                for row_act, col_act in ((-1, 0), (1, 0), (0, -1), (0, 1)):
                    new_row = row + row_act
                    new_col = col + col_act

                    if new_row < 0 or new_row >= 4 or new_col < 0 or new_col >= 4:
                        new_row = row
                        new_col = col
                    v += 0.25 * (-1 + V[new_row, new_col])
                
                new_V[row, col] = v
        if np.allclose(V, new_V):
            return new_V
        V = new_V

In [6]:
def get_policy(V):
    dir_table = {0: "↑", 1: "↓", 2: "←", 3: "→"}
    direction_list = []
    for row in range(4):
        direction_list_row = []
        for col in range(4):
            if row == 0 and col == 0 or row == 3 and col == 3:
                direction_list_row.append("* ")
                continue
            direction = ""
            action_value = []
            for row_act, col_act in ((-1, 0), (1, 0), (0, -1), (0, 1)):
                new_row = row + row_act
                new_col = col + col_act

                if not 0 <= new_row < 3 or not 0 <= new_col < 3:
                    action_value.append(-np.inf)
                else:
                    action_value.append(V[new_row, new_col])

            max_idx_list = np.argwhere(np.isclose(action_value, np.max(action_value))).flatten().tolist()

            for i in max_idx_list:
                direction += dir_table[i]
            direction += " " * (2-len(max_idx_list))
            direction_list_row.append(direction)
        direction_list.append(direction_list_row)
    return np.array(direction_list)


In [7]:
V = np.zeros((4, 4))
V = policy_evaluation(V)
print(V)

[[  0.         -13.99771852 -19.99661926 -21.99621676]
 [-13.99771852 -17.99702177 -19.99664188 -19.99661926]
 [-19.99661926 -19.99664188 -17.99702177 -13.99771852]
 [-21.99621676 -19.99661926 -13.99771852   0.        ]]


In [8]:
policy = get_policy(V)
print(policy)

[['* ' '← ' '← ' '← ']
 ['↑ ' '↑←' '↓←' '← ']
 ['↑ ' '↑→' '↑←' '← ']
 ['↑ ' '↑ ' '↑ ' '* ']]
