In [None]:
import numpy as np
from grid import grid_world, Grid, print_values
import time

In [None]:
# Hyperparameters passed to policy_evaluation / policy_improvement in later cells.
gamma = 0.9    # discount factor applied to the successor state's value
theta = 0.1   # convergence threshold for policy evaluation (max per-sweep change)

# Candidate action names; policy_improvement reads this module-level list directly.
actions = ["up", "down", "left", "right"]

In [None]:
# Build the environment with grid_world() (imported from the local `grid`
# module) and print it so the layout is visible in the notebook.
grid = grid_world()
print("Grid World:")
print(grid)

Setting random policy π

In [None]:
# Create the starting policy for policy iteration and print it.
# NOTE(review): no random seed is set anywhere in this notebook; if
# Grid.generate_random_policy draws from a global RNG, the initial policy will
# differ between runs — confirm and seed if reproducibility matters.
goal = (3, 3)  # goal state
policy = Grid.generate_random_policy(grid, goal)
print("Initial Random Policy:")
print(policy)


In [None]:
# Inspect grid.actions; the functions below iterate over it to enumerate states.
grid.actions

Policy Evaluation

In [None]:
def policy_evaluation(grid, policy, gamma=1, theta=0.1):
    """Iteratively evaluate a deterministic policy on a deterministic grid.

    Every non-terminal state is swept repeatedly with the backup
    ``V(s) <- r(s) + gamma * V(s')``, where ``s'`` is the state returned by
    ``grid.take_action(s, policy[s])``. Values are updated in place, so later
    states in a sweep already see the new values of earlier ones. Sweeping
    stops once the largest change made during a full sweep is below ``theta``.

    Args:
        grid: Environment object. This function uses ``grid.actions``
            (iterated to enumerate states), ``grid.is_terminal(s)``,
            ``grid.take_action(s, a)`` and ``grid.rewards`` (read via ``.get``).
        policy: Mapping from each non-terminal state to the action to take.
        gamma: Discount factor.
        theta: Convergence threshold on the per-sweep maximum value change.

    Returns:
        dict mapping every state in ``grid.actions`` to its estimated value.
        Terminal states keep their initial value of 0.0.

    Note:
        With ``gamma=1`` and a policy that cycles without reaching a terminal
        state while collecting non-zero reward, the values never settle and
        this loop does not terminate. Pass ``gamma < 1`` for such policies.
    """
    # Value table, initialised to zero for every state.
    V = {state: 0.0 for state in grid.actions}

    while True:
        delta = 0  # largest absolute value change seen in this sweep
        for s in grid.actions:
            if grid.is_terminal(s):
                continue
            next_state = grid.take_action(s, policy[s])
            # NOTE(review): the reward is looked up by the state being left,
            # not by next_state (policy_improvement does the same) — confirm
            # this matches how grid.rewards is keyed.
            reward = grid.rewards.get(s, 0)
            # P(s'|s,a) = 1, so the expectation collapses to a single term.
            new_val = reward + gamma * V.get(next_state, 0)
            delta = max(delta, abs(V[s] - new_val))
            # Store the backup; without this write V never changes and the
            # loop either returns all zeros or spins forever.
            V[s] = new_val
        if delta < theta:
            return V

    

# policy_evaluation(grid, policy, gamma, theta)
    
    


In [None]:
# Display the candidate action list that policy_improvement iterates over.
actions

Policy Improvement

In [None]:
def policy_improvement(grid, V, gamma):
    """Derive a greedy policy from a state-value table and print it.

    For each non-terminal state in ``grid.actions``, every action in the
    module-level ``actions`` list is scored as
    ``grid.rewards.get(s, 0) + gamma * V.get(next_state, 0)``, where
    ``next_state`` is ``grid.take_action(s, a)``. The first action with the
    strictly highest score is kept.

    Args:
        grid: Environment object providing ``actions``, ``is_terminal``,
            ``take_action`` and ``rewards``.
        V: Mapping from state to value; states missing from it count as 0.
        gamma: Discount factor.

    Returns:
        dict mapping each non-terminal state to its selected action
        (``None`` when no action scores above negative infinity).
    """
    greedy = {}
    for state in grid.actions:
        if not grid.is_terminal(state):
            # Score every candidate action first, in list order.
            scored = []
            for candidate in actions:
                successor = grid.take_action(state, candidate)
                immediate = grid.rewards.get(state, 0)
                scored.append((candidate, immediate + gamma * V.get(successor, 0)))

            # Strict ">" keeps the earliest action among equal scores.
            chosen, top_score = None, float("-inf")
            for candidate, score in scored:
                if score > top_score:
                    chosen, top_score = candidate, score
            greedy[state] = chosen

    print("\nImproved Policy:")
    for state in grid.actions:
        if state not in greedy:
            continue
        print(f"  State {state}: {greedy[state]}")
    return greedy



In [None]:
# One evaluate-then-improve pass on the initial policy, shown before the full loop.
V = policy_evaluation(grid, policy, gamma=gamma, theta=theta)
new_policy = policy_improvement(grid, V, gamma=gamma)
grid.print_policy(new_policy, grid)

Policy Iteration

In [None]:
# Policy iteration: alternate evaluation and greedy improvement until the
# improvement step returns a policy equal to the one that was just evaluated.
while True:
    V = policy_evaluation(grid, policy, gamma, theta)
    new_policy = policy_improvement(grid, V, gamma)

    if not (new_policy == policy):
        # Policy changed: evaluate the improved one on the next pass.
        policy = new_policy
        continue

    print("\nPolicy is stable. Final Policy:")
    grid.print_policy(new_policy, grid)
    break