In [2]:
from collections import namedtuple
from copy import deepcopy

# One grid cell: its current utility estimate (None marks a cell the agent
# cannot occupy) and the name of the best move found for it (None until an
# update has chosen one).
Cell = namedtuple("Cell", ("utility", "direction"))
Grid = list[list[Cell]]

# Parameters read by the value-iteration update.
step_cost = -0.04
discount_factor = 0.95

# (row, column) of the two cells whose utilities the update never changes.
penalty_pos = (0, 1)
reward_pos = (0, 2)

# Starting utilities, one tuple per grid row; no cell has a direction yet.
initial_grid: Grid = [
    [Cell(utility, None) for utility in row_utilities]
    for row_utilities in (
        (0, -1, 1),
        (0, 0, 0),
        (0, None, 0),
        (0, 0, 0),
    )
]

In [3]:
def iterate(grid: Grid, p: float) -> Grid:
    """Perform one synchronous value-iteration sweep over ``grid``.

    Every cell that is neither a wall (``utility is None``) nor one of the two
    fixed cells (``penalty_pos`` / ``reward_pos``) receives the backup

        step_cost + discount_factor * max over moves of E[utility after move]

    The expectation assumes the chosen move happens with probability ``p`` and
    that each of the two perpendicular moves happens with probability
    ``(1 - p) / 2``.  A move that would leave the grid or enter a wall counts
    as staying in the current cell.

    All utilities are read from ``grid`` (the previous sweep); results are
    written to a copy, so the input is not modified.

    Args:
        grid: Current utilities, indexed as ``grid[x][y]`` with ``x`` the row
            and ``y`` the column.  The column count is taken from the first row.
        p: Probability that the intended move is the one carried out.

    Returns:
        A new grid in which every updated cell holds its new utility and the
        name of the best move (``"right"``, ``"down"``, ``"left"`` or ``"up"``).
        On ties the first move in that order wins.  Walls and the two fixed
        cells are copied unchanged.
    """
    new_grid = deepcopy(grid)
    row_count = len(grid)
    col_count = len(grid[0]) if grid else 0

    # Candidate moves in evaluation order; dict order decides tie-breaking.
    move_names = {(0, 1): "right", (1, 0): "down", (0, -1): "left", (-1, 0): "up"}
    # Moves perpendicular to a move along the row axis / the column axis.
    side_moves_for_row_move = [(0, 1), (0, -1)]
    side_moves_for_col_move = [(1, 0), (-1, 0)]

    def is_valid(x: int, y: int) -> bool:
        """Return True when (x, y) lies inside the grid and is not a wall."""
        return (
            0 <= x < row_count
            and 0 <= y < col_count
            and grid[x][y].utility is not None
        )

    def utility_at(x: int, y: int, fallback: float) -> float:
        """Utility of (x, y), or ``fallback`` when that cell cannot be entered."""
        return grid[x][y].utility if is_valid(x, y) else fallback

    for x, row in enumerate(grid):
        for y, cell in enumerate(row):
            if not is_valid(x, y) or penalty_pos == (x, y) or reward_pos == (x, y):
                continue

            # Start from -inf rather than a finite sentinel so that arbitrarily
            # negative utilities still produce a real maximum and a direction.
            best_utility = float("-inf")
            best_direction = None
            for (dx, dy), name in move_names.items():
                side_moves = side_moves_for_row_move if dx else side_moves_for_col_move
                # Blocked moves bounce back, hence the current cell's utility
                # is the fallback.
                expected_utility = p * utility_at(x + dx, y + dy, cell.utility) + sum(
                    (1 - p) / 2 * utility_at(x + sx, y + sy, cell.utility)
                    for sx, sy in side_moves
                )
                candidate = step_cost + discount_factor * expected_utility
                if candidate > best_utility:
                    best_utility = candidate
                    best_direction = name
            new_grid[x][y] = Cell(best_utility, best_direction)

    return new_grid

In [29]:
def print_cell(grid: Grid, directions: bool = False) -> None:
    """Print ``grid`` one row per line, each value followed by a tab.

    With ``directions`` false (the default) a cell is shown as its utility
    rounded to three decimals, or ``None`` when it has no utility.  With
    ``directions`` true the cell's ``direction`` is shown instead.  An empty
    line is printed after the last row.
    """

    def shown(cell):
        # Choose the value displayed for a single cell.
        if directions:
            return cell.direction
        if cell.utility is None:
            return None
        return round(cell.utility, 3)

    for cells in grid:
        for cell in cells:
            print(shown(cell), end="\t")
        print()
    print()

def converge(prev: Grid, cur: Grid, tolerance: float = 0.0001) -> bool:
    """Return True when ``prev`` and ``cur`` hold the same utilities within ``tolerance``.

    Cells are compared pairwise by position.  Two cells match when both have
    ``utility is None`` or when their utilities differ by at most
    ``tolerance``.  A cell that has a utility in one grid but ``None`` in the
    other counts as a difference (instead of raising ``TypeError``).
    Directions are ignored.

    Rows and cells are paired with ``zip``, so if the grids differ in shape
    only the overlapping part is compared.

    Args:
        prev: Grid from the earlier sweep.
        cur: Grid from the later sweep.
        tolerance: Largest absolute utility difference still treated as equal.

    Returns:
        False as soon as one pair of cells differs, True otherwise.
    """
    for prev_row, cur_row in zip(prev, cur):
        for before, after in zip(prev_row, cur_row):
            if before.utility is None or after.utility is None:
                # Both None: nothing to compare.  Only one None: the grids
                # disagree about which cells carry a utility.
                if before.utility is not after.utility:
                    return False
                continue
            if abs(before.utility - after.utility) > tolerance:
                return False
    return True

In [30]:
# Repeat value-iteration sweeps (intended move succeeds with probability 0.7)
# until two consecutive grids agree, printing the grid each new sweep starts from.
prev, cur = initial_grid, iterate(initial_grid, 0.7)
while not converge(prev, cur):
    prev, cur = cur, iterate(cur, 0.7)
    print_cell(prev)

-0.04	-1	1	
-0.04	-0.04	0.625	
-0.04	None	-0.04	
-0.04	-0.04	-0.04	

-0.078	-1	1	
-0.078	0.227	0.708	
-0.078	None	0.364	
-0.078	-0.078	-0.078	

-0.114	-1	1	
0.089	0.321	0.758	
-0.114	None	0.535	
-0.114	-0.114	0.18	

-0.119	-1	1	
0.141	0.368	0.779	
-0.013	None	0.617	
-0.148	0.047	0.325	

-0.106	-1	1	
0.185	0.388	0.788	
0.05	None	0.654	
-0.032	0.19	0.423	

-0.074	-1	1	
0.21	0.397	0.793	
0.098	None	0.671	
0.089	0.295	0.482	

-0.053	-1	1	
0.227	0.401	0.795	
0.127	None	0.678	
0.183	0.365	0.517	

-0.039	-1	1	
0.237	0.403	0.795	
0.147	None	0.682	
0.247	0.408	0.537	

-0.03	-1	1	
0.243	0.404	0.796	
0.166	None	0.683	
0.287	0.433	0.548	

-0.025	-1	1	
0.248	0.404	0.796	
0.198	None	0.684	
0.313	0.448	0.554	

-0.021	-1	1	
0.254	0.404	0.796	
0.224	None	0.684	
0.331	0.456	0.558	

-0.017	-1	1	
0.258	0.404	0.796	
0.244	None	0.684	
0.342	0.461	0.559	

-0.013	-1	1	
0.261	0.405	0.796	
0.257	None	0.684	
0.35	0.463	0.56	

-0.011	-1	1	
0.264	0.405	0.796	
0.266	None	0.684	
0.355	0.465	0.561	

-0.009	-1	1	
0.26