In [3]:
import numpy as np

# --- Gridworld configuration ---
gamma = 0.9            # discount factor applied to successor values
reward_off_grid = -1   # reward for a move that would leave the grid
reward_A = 10          # reward for any action taken in state A
reward_B = 5           # reward for any action taken in state B

# Grid dimensions as (rows, columns)
grid_size = (5, 5)

# Special states: any action in A leads to A_prime, any action in B to B_prime
A, A_prime = (0, 1), (4, 1)
B, B_prime = (0, 3), (2, 3)

# State space: every (row, column) cell, enumerated row by row
states = [
    (row, col)
    for row in range(grid_size[0])
    for col in range(grid_size[1])
]

# Action space, with the (row, column) offset each action applies
action_effects = {
    'north': (-1, 0),
    'south': (1, 0),
    'east': (0, 1),
    'west': (0, -1),
}
actions = list(action_effects)  # same order as the offsets table above

# Value function: every state starts at zero
V = dict.fromkeys(states, 0)

# Define the transition function
def transition(state, action):
    """Return the ``(next_state, reward)`` pair for taking `action` in `state`.

    Any action taken in A or B moves to A_prime or B_prime and yields
    reward_A or reward_B. Otherwise the action's offset is applied to the
    state; a move that would leave the grid keeps the state unchanged and
    yields reward_off_grid, and every other move yields a reward of 0.
    """
    # The special states ignore the chosen action entirely.
    if state == A:
        return A_prime, reward_A
    if state == B:
        return B_prime, reward_B

    d_row, d_col = action_effects[action]
    row, col = state[0] + d_row, state[1] + d_col

    n_rows, n_cols = grid_size
    inside_grid = 0 <= row < n_rows and 0 <= col < n_cols
    if not inside_grid:
        # Bumping into the edge: stay put and take the penalty.
        return state, reward_off_grid
    return (row, col), 0

# Policy Iteration

# Initialize the policy with a random action for each state, drawn
# independently per state by np.random.choice (no seed is set here, so the
# starting policy differs from run to run).
policy = {s: np.random.choice(actions) for s in states}

def policy_evaluation(policy, V, gamma=0.9, theta=1e-8):
    """Evaluate a deterministic policy by sweeping V in place until it settles.

    Args:
        policy: mapping from each state to the action taken in that state.
        V: mapping from state to its current value estimate; updated in place.
        gamma: discount factor applied to the successor state's value.
        theta: sweeping stops once the largest per-state change in a full
            sweep is below this threshold.

    Returns:
        The same V mapping that was passed in, after the final sweep.
    """
    converged = False
    while not converged:
        largest_change = 0
        for state in states:
            previous = V[state]
            next_state, reward = transition(state, policy[state])
            # In-place backup: states later in the sweep already see this value.
            V[state] = reward + gamma * V[next_state]
            largest_change = max(largest_change, abs(previous - V[state]))
        converged = largest_change < theta
    return V

def policy_improvement(policy, V, gamma=0.9, tol=1e-6):
    """Make the policy greedy with respect to V and report whether it was stable.

    For every state the one-step value ``r + gamma * V[s_next]`` of each
    action is computed and the policy is set to the first action (in the
    order of `actions`) with the highest value.

    The policy counts as unstable only if, in some state, the previous action
    is worse than the best one by more than `tol`. Comparing action labels
    instead would treat a switch between equally good actions as a change;
    because V is only evaluated to a finite accuracy, such ties can swap
    order between rounds and keep the caller's loop running forever.

    Args:
        policy: mapping from state to action; updated in place.
        V: mapping from state to value estimate for the current policy.
        gamma: discount factor applied to the successor state's value.
        tol: margin within which two action values are considered tied;
            should exceed the accuracy to which V was evaluated.

    Returns:
        A ``(policy, policy_stable)`` tuple: the same policy mapping, now
        greedy, and a flag that is True when no state had a strictly better
        action than the one it previously used.
    """
    policy_stable = True
    for s in states:
        old_action = policy[s]
        action_values = {}
        for a in actions:
            s_next, r = transition(s, a)
            action_values[a] = r + gamma * V[s_next]
        best_action = max(action_values, key=action_values.get)
        policy[s] = best_action
        # An old action missing from `actions` is treated as infinitely bad,
        # which forces another evaluation round.
        old_value = action_values.get(old_action, float('-inf'))
        if old_value < action_values[best_action] - tol:
            policy_stable = False
    return policy, policy_stable

# Policy Iteration Algorithm: alternate evaluation and greedy improvement
# until an improvement pass reports that the policy is stable.
# The configured discount factor is passed explicitly so these calls use the
# same gamma as the rest of the script instead of the helpers' own defaults.
while True:
    V = policy_evaluation(policy, V, gamma=gamma)
    policy, policy_stable = policy_improvement(policy, V, gamma=gamma)
    if policy_stable:
        break

# Print the optimal policy, one grid row per line
print("Optimal Policy (Policy Iteration):")
for i in range(grid_size[0]):
    print([policy[(i, j)] for j in range(grid_size[1])])

# Value Iteration: start again from a fresh, all-zero value function
V = dict.fromkeys(states, 0)

def value_iteration(V, gamma=0.9, theta=1e-8):
    """Sweep V in place with max-over-actions backups until it settles.

    Args:
        V: mapping from state to its current value estimate; updated in place.
        gamma: discount factor applied to the successor state's value.
        theta: sweeping stops once the largest per-state change in a full
            sweep is below this threshold.

    Returns:
        The same V mapping that was passed in, after the final sweep.
    """
    def backup(state, action):
        # One-step lookahead value of taking `action` in `state`.
        next_state, reward = transition(state, action)
        return reward + gamma * V[next_state]

    converged = False
    while not converged:
        largest_change = 0
        for state in states:
            previous = V[state]
            V[state] = max(backup(state, action) for action in actions)
            largest_change = max(largest_change, abs(previous - V[state]))
        converged = largest_change < theta
    return V

# Run value iteration with the configured discount factor, so the backups
# here and the policy extraction below use the same gamma.
V = value_iteration(V, gamma=gamma)

# Print the value function, one grid row per line, rounded to two decimals
print("Optimal Value Function (Value Iteration):")
for i in range(grid_size[0]):
    print([round(V[(i, j)], 2) for j in range(grid_size[1])])

# Extract the greedy policy from the value function: for each state, pick the
# first action (in the order of `actions`) with the highest one-step value.
optimal_policy_vi = {}
for s in states:
    action_values = {}
    for a in actions:
        s_next, r = transition(s, a)
        action_values[a] = r + gamma * V[s_next]
    optimal_policy_vi[s] = max(action_values, key=action_values.get)

# Print the optimal policy from value iteration, one grid row per line
print("Optimal Policy (Value Iteration):")
for i in range(grid_size[0]):
    print([optimal_policy_vi[(i, j)] for j in range(grid_size[1])])


Optimal Policy (Policy Iteration):
['east', 'north', 'west', 'north', 'west']
['east', 'north', 'north', 'west', 'west']
['east', 'north', 'north', 'north', 'north']
['east', 'north', 'north', 'north', 'north']
['east', 'north', 'north', 'north', 'north']
Optimal Value Function (Value Iteration):
[21.98, 24.42, 21.98, 19.42, 17.48]
[19.78, 21.98, 19.78, 17.8, 16.02]
[17.8, 19.78, 17.8, 16.02, 14.42]
[16.02, 17.8, 16.02, 14.42, 12.98]
[14.42, 16.02, 14.42, 12.98, 11.68]
Optimal Policy (Value Iteration):
['east', 'north', 'west', 'north', 'west']
['east', 'north', 'north', 'west', 'west']
['east', 'north', 'north', 'north', 'north']
['east', 'north', 'north', 'north', 'north']
['east', 'north', 'north', 'north', 'north']
