# Importing the required libraries

In [None]:
from env import CliffWalking
import gymnasium as gym 
import numpy as np 
import pygame
import matplotlib.pyplot as plt

# Initialising the environment

In [None]:
# Create the Cliff Walking environment (class imported from the local `env` module).
env = CliffWalking()
# Reset before rendering; the returned initial state is not needed in this cell.
env.reset()
# Draw the initial grid. As the last expression of the cell, its return value
# (if any) is echoed by Jupyter.
env.render()

# Creating the policy (uniform random to start)

In [None]:
# Uniform random starting policy: one row per state (48), one column per
# action (4), every entry equal to 1/4.
policy_probs = np.ones(shape=(48, 4)) / 4
print(policy_probs)

In [None]:
def policy(state):
    """Sample an action for ``state`` from the global ``policy_probs`` table.

    Args:
        state: Integer index selecting a row of ``policy_probs``.

    Returns:
        An action index drawn according to the probabilities in that row.
    """
    action_probs = policy_probs[state]
    # Derive the number of actions from the row instead of hard-coding 4,
    # so the sampler always agrees with the table's shape.
    return np.random.choice(len(action_probs), p=action_probs)

# Value table (one estimate per state)

In [None]:
# Value estimates for all 48 states, all starting at 0.0.
state_values = np.full(48, 0.0)
print(state_values)

# Implementing the policy iteration algorithm

In [None]:
def policy_evaluation(policy_probs, state_values, theta=1e-6, gamma=0.99,
                      environment=None) -> None:
    """Evaluate ``policy_probs`` in place by repeated Bellman expectation sweeps.

    Each sweep replaces every state's value with the expected one-step return
    under the policy. Updates are written straight into ``state_values``, so
    later states in the same sweep already see them. Sweeping stops once no
    value changed by more than ``theta`` during a full sweep.

    Args:
        policy_probs: 2-D array; row ``s`` holds the action probabilities for
            state ``s``. Its number of rows sets how many states are swept.
        state_values: Per-state value array, updated in place.
        theta: Convergence threshold on the largest change within one sweep.
        gamma: Discount factor.
        environment: Object whose ``simulate_step(state, action)`` returns a
            4-tuple starting with ``(next_state, reward)``. Defaults to the
            global ``env``.
    """
    # Only touch the global when no environment was supplied.
    sim_env = env if environment is None else environment
    num_states = policy_probs.shape[0]
    delta = float("inf")

    while delta > theta:
        delta = 0.0
        for state in range(num_states):
            old_value = state_values[state]
            new_value = 0.0
            for action, prob in enumerate(policy_probs[state]):
                # Greedy policies put zero weight on most actions; such
                # actions contribute nothing, so skip the simulated step.
                if prob == 0:
                    continue
                next_state, reward, _, _ = sim_env.simulate_step(state, action)
                new_value += prob * (reward + gamma * state_values[next_state])

            state_values[state] = new_value
            delta = max(delta, abs(old_value - new_value))

In [None]:
def policy_improvement(policy_probs, state_values, gamma=0.99, tol=1e-9,
                       environment=None):
    """Make ``policy_probs`` greedy with respect to ``state_values``, in place.

    Every action of every state is scored by its one-step lookahead value
    ``reward + gamma * state_values[next_state]``. The state's row is then
    replaced by a one-hot vector on the chosen action.

    Args:
        policy_probs: 2-D array (states x actions), overwritten in place.
        state_values: Per-state value array to act greedily on.
        gamma: Discount factor.
        tol: Another action only replaces the row's current (argmax) action
            if it scores more than ``tol`` higher. Keeping the current action
            on (near-)ties stops policy iteration from cycling forever between
            equally good policies because of floating-point noise.
        environment: Object whose ``simulate_step(state, action)`` returns a
            4-tuple starting with ``(next_state, reward)``. Defaults to the
            global ``env``.

    Returns:
        True if no row of ``policy_probs`` changed, otherwise False.
    """
    sim_env = env if environment is None else environment
    num_states, num_actions = policy_probs.shape
    policy_stable = True

    for state in range(num_states):
        old_probs = policy_probs[state].copy()
        old_action = int(old_probs.argmax())

        q_values = np.empty(num_actions)
        for action in range(num_actions):
            next_state, reward, _, _ = sim_env.simulate_step(state, action)
            q_values[action] = reward + gamma * state_values[next_state]

        # argmax returns the first maximum, matching the original strict '>' scan.
        new_action = int(q_values.argmax())
        if q_values[new_action] - q_values[old_action] <= tol:
            # The current action is still (numerically) among the best ones.
            new_action = old_action

        new_probs = np.zeros(num_actions)
        new_probs[new_action] = 1.0
        policy_probs[state] = new_probs

        # Compare whole rows, not just argmaxes: turning a stochastic row
        # (e.g. the uniform start) into a one-hot row is a real change even
        # when its argmax coincides with the greedy action.
        if not np.array_equal(old_probs, new_probs):
            policy_stable = False

    return policy_stable

In [None]:
def policy_iteration(policy_probs, state_values, theta=1e-6, gamma=0.99):
    """Alternate policy evaluation and greedy improvement until stable.

    Both ``policy_probs`` and ``state_values`` are modified in place; the
    function returns None.
    """
    while True:
        # Bring the value estimates up to date for the current policy ...
        policy_evaluation(policy_probs, state_values, theta, gamma)
        # ... then act greedily on them; stop once improvement reports no change.
        if policy_improvement(policy_probs, state_values, gamma):
            return

In [None]:
# Run policy iteration; policy_probs and state_values are updated in place.
policy_iteration(policy_probs=policy_probs, state_values=state_values)

# Printing the final state values and policy

In [None]:
# State-value estimates left by policy iteration, one entry per state.
print(state_values)

In [None]:
# Final policy: after policy iteration every row is one-hot on the greedy action.
print(policy_probs)

# Testing the learned agent

In [None]:
def test_agent(policy, episodes=1, max_steps=None, environment=None):
    """Roll out ``policy`` in the environment with rendering enabled.

    Prints the 1-based episode number after each episode finishes.

    Args:
        policy: Callable mapping a state to an action.
        episodes: Number of episodes to run.
        max_steps: Optional cap on steps per episode, so a policy that never
            reaches a terminal state cannot hang the notebook. ``None`` (the
            default) means no cap, as before.
        environment: Environment to run in; defaults to the global ``env``.
    """
    sim_env = env if environment is None else environment
    sim_env.pygame_init()
    for episode in range(episodes):
        state = sim_env.reset()
        done, terminated = False, False
        steps = 0
        while not (done or terminated):
            if max_steps is not None and steps >= max_steps:
                break
            action = policy(state)
            state, _, done, terminated = sim_env.step(action)
            sim_env.render()
            steps += 1
        print(episode + 1)

In [None]:
# Watch one rendered episode driven by the learned policy.
test_agent(policy=policy, episodes=1)

In [None]:
# Shut down the environment. NOTE(review): presumably this also closes the
# pygame window opened by pygame_init() in test_agent — confirm in CliffWalking.
env.close()