## Creating the environment

In [None]:
from typing import Callable

import gymnasium as gym
import matplotlib.pylab as plt
import numpy as np
import pygame
from IPython import display
from IPython.display import clear_output

from env import SimpleCorridor

## Test Agent Function 

In [None]:
def test_agent(env: gym.Env, policy: Callable, episodes: int = 10) -> None:
    """Play ``episodes`` episodes with ``policy`` in ``env``, rendering every step.

    At each step the action is sampled from the probability vector that
    ``policy`` returns for the agent's current position, so both stochastic
    and deterministic (one-hot) policies work.

    Args:
        env: Environment exposing ``reset()``, ``step(action)`` and ``render()``.
            Only element 0 of the state returned by ``reset()``/``step()`` is
            used, as the agent position. NOTE(review): element 1 looks like the
            target position; confirm against ``SimpleCorridor``.
        policy: Callable mapping an agent position to a sequence of action
            probabilities (one entry per action).
        episodes: Number of episodes to play.
    """
    for _ in range(episodes):
        agent_pos = env.reset()[0]
        done = False

        # NOTE(review): the first frame is rendered with mode="rgb_array" while
        # the per-step frames use render()'s default mode; confirm intended.
        env.render(mode="rgb_array")
        while not done:
            action_probs = policy(agent_pos)
            # Sample over as many actions as the policy provides probabilities for.
            action = np.random.choice(len(action_probs), p=action_probs)

            next_state, _, done, _ = env.step(action)
            env.render()
            plt.axis('off')
            display.display(plt.gcf())
            # wait=True defers clearing until the next frame is shown, which
            # gives a flicker-free animation in the notebook.
            display.clear_output(wait=True)

            agent_pos = next_state[0]

## Initialising the environment 

In [None]:
# Build the corridor environment and show its initial state. This module-level
# ``env`` is also the environment that the policy-evaluation and
# policy-improvement functions below query through ``env.simulate_step``.
env = SimpleCorridor()
env.reset()
env.render()


### Creating the Policy

In [None]:
# Uniform random starting policy: one row per state (15 states), one column
# per action (2 actions), so every action has probability 0.5.
policy_probs = np.ones((15, 2)) / 2
print(policy_probs)

In [None]:
def policy(state):
    """Look up the action distribution for ``state`` in the global ``policy_probs``.

    The table is read at call time, so this reflects any in-place updates
    that policy iteration makes to ``policy_probs``.
    """
    action_distribution = policy_probs[state]
    return action_distribution

## Value Table

In [None]:
# One value estimate per state (15 states), all starting at zero.
state_values = np.zeros((15,))
print(state_values)

# Implementing the policy iteration algorithm 

In [None]:
def policy_evaluation(policy_probs, state_values, theta=1e-6, gamma=0.99, model=None):
    """Evaluate ``policy_probs`` by iterative Bellman expectation backups, in place.

    Sweeps over every state, replacing ``state_values[state]`` with the
    expected one-step return under the policy. New values are written
    immediately and are therefore used by later states in the same sweep.
    Sweeping stops once the largest change within a sweep is <= ``theta``.

    Args:
        policy_probs: 2-D array; row ``s`` holds the probability of each
            action in state ``s``.
        state_values: 1-D array of state values, updated in place (should be a
            float array). Its length defines how many states are swept.
        theta: Convergence threshold on the maximum per-sweep value change.
        gamma: Discount factor.
        model: Object whose ``simulate_step(state, action)`` returns a 4-tuple
            starting with ``(next_state, reward, ...)``. Defaults to the
            module-level ``env``.
    """
    model = env if model is None else model
    delta = float("inf")

    while delta > theta:
        delta = 0.0
        for state in range(len(state_values)):
            old_value = state_values[state]
            new_value = 0.0
            for action, prob in enumerate(policy_probs[state]):
                next_state, reward, _, _ = model.simulate_step(state, action)
                # The terminal flag is not used: the backup always bootstraps
                # from the current estimate of next_state's value.
                new_value += prob * (reward + gamma * state_values[next_state])

            state_values[state] = new_value
            delta = max(delta, abs(old_value - new_value))

In [None]:
def policy_improvement(policy_probs, state_values, gamma=0.99, model=None):
    """Make ``policy_probs`` greedy with respect to ``state_values``, in place.

    For each state, every action's one-step lookahead value
    ``reward + gamma * state_values[next_state]`` is computed, and the state's
    row is replaced by a one-hot vector on the best action (the lowest-index
    action wins ties).

    Args:
        policy_probs: 2-D float array of shape (n_states, n_actions), updated
            in place.
        state_values: 1-D array of current state-value estimates.
        gamma: Discount factor.
        model: Object whose ``simulate_step(state, action)`` returns a 4-tuple
            starting with ``(next_state, reward, ...)``. Defaults to the
            module-level ``env``.

    Returns:
        True if no row of ``policy_probs`` changed, False otherwise.
    """
    model = env if model is None else model
    n_states, n_actions = policy_probs.shape
    policy_stable = True

    for state in range(n_states):
        q_values = np.empty(n_actions)
        for action in range(n_actions):
            next_state, reward, _, _ = model.simulate_step(state, action)
            q_values[action] = reward + gamma * state_values[next_state]

        greedy_probs = np.zeros(n_actions)
        # np.argmax returns the first maximum, i.e. lowest-index tie-breaking.
        greedy_probs[np.argmax(q_values)] = 1.0

        # Compare whole distributions, not just argmaxes: a stochastic row such
        # as [0.5, 0.5] has argmax 0, so replacing it with [1, 0] would
        # otherwise count as "no change" and iteration could stop before the
        # new greedy policy is ever evaluated.
        if not np.array_equal(policy_probs[state], greedy_probs):
            policy_stable = False
        policy_probs[state] = greedy_probs

    return policy_stable

In [None]:
def policy_iteration(policy_probs, state_values, theta=1e-6, gamma=0.99):
    """Alternate policy evaluation and greedy improvement until the policy is stable.

    Both ``policy_probs`` and ``state_values`` are modified in place; nothing
    is returned.

    Args:
        policy_probs: Per-state action-probability table.
        state_values: Per-state value estimates.
        theta: Convergence threshold forwarded to ``policy_evaluation``.
        gamma: Discount factor forwarded to both phases.
    """
    while True:
        policy_evaluation(policy_probs, state_values, theta, gamma)
        stable = policy_improvement(policy_probs, state_values, gamma)
        if stable:
            break

In [None]:
# Run policy iteration; this mutates policy_probs and state_values in place.
policy_iteration(policy_probs, state_values)

# Printing the final values

In [None]:
# Final state values, then the per-state (one-hot) greedy policy.
print(state_values)
print(policy_probs)

# Testing the resulting agent

In [None]:
# Roll out the learned policy (10 episodes by default) with live rendering.
test_agent(env, policy)