# Importing the libraries

In [None]:
from env import SimpleGridWorld

import numpy as np

# Initialising the environment

In [None]:
# Number of states in the grid world; the policy and value tables are sized from it.
envSize = 15
# The goal index is derived from envSize instead of being hard-coded (was 14),
# so changing the size keeps the goal on the last index.
# NOTE(review): assumes SimpleGridWorld numbers its states 0..size-1 - confirm in env.py.
env = SimpleGridWorld(size=envSize, start=0, end=envSize - 1)

In [None]:
# Put the environment into its starting configuration and display it.
# NOTE(review): exact behaviour of reset()/render() is defined in env.SimpleGridWorld,
# which is not visible here - presumably render() draws the current grid.
env.reset()
env.render()

# Creating the Policy

In [None]:
# One row per state, one column per action; every action starts at probability 0.5,
# i.e. a uniform random policy over the two actions.
policy_probablities = np.full(shape=(envSize, 2), fill_value=0.5)

In [None]:
def policy(state):
    """Sample an action index for ``state`` from the current policy table.

    Reads the module-level ``policy_probablities`` array, so it always
    reflects the most recent policy-improvement step.

    Args:
        state: Row index into ``policy_probablities``.

    Returns:
        An action index drawn with the probabilities stored for ``state``.
    """
    action_probs = policy_probablities[state]
    # Take the action count from the table instead of hard-coding 2, so the
    # sampler stays valid if the table is built with more actions.
    return np.random.choice(len(action_probs), p=action_probs)

# Initialising the state-value table

In [None]:
# State-value estimates, one entry per state, all starting at zero.
state_values = np.zeros(envSize)

# Implementing the policy iteration algorithm

In [None]:
def policy_evaluation(policy_probablities, state_values, theta=1e-6, gamma=0.99):
    """Evaluate a policy in place with repeated expected one-step backups.

    Each sweep visits every state and overwrites ``state_values[state]`` with
    the probability-weighted sum of ``reward + gamma * value(next_state)``
    over the actions of that state. Updates are written immediately, so
    states visited later in a sweep already see the values refreshed earlier
    in the same sweep. Sweeping stops once the largest absolute change seen
    in a sweep is no greater than ``theta``.

    Args:
        policy_probablities: Table indexed ``[state][action]`` holding the
            probability of each action in each state.
        state_values: 1-D array of state values; modified in place.
        theta: Convergence threshold on the largest per-sweep value change.
        gamma: Discount factor applied to the successor state's value.

    Returns:
        None. The result is written into ``state_values``.

    Note:
        Transitions come from the module-level ``env.simulate_step``; the
        third element it returns is ignored.
    """
    # Size the sweep from the table that was passed in rather than from the
    # global ``envSize``, so the function works for any consistent table.
    n_states = len(state_values)
    delta = float("inf")

    while delta > theta:
        delta = 0.
        for state in range(n_states):
            old_value = state_values[state]
            new_value = 0.

            # Expected one-step return under the current policy.
            for action, prob in enumerate(policy_probablities[state]):
                next_state, reward, _ = env.simulate_step(state, action)
                new_value += prob * (reward + gamma * state_values[next_state])

            state_values[state] = new_value
            delta = max(delta, abs(old_value - new_value))

In [None]:
def policy_improvement(policy_probablities, state_values, gamma=0.99):
    """Make the policy greedy with respect to ``state_values``, in place.

    For every state the one-step look-ahead value
    ``reward + gamma * state_values[next_state]`` is computed for each
    action, and the state's row in ``policy_probablities`` is replaced by a
    one-hot vector on the best action. Ties go to the lowest action index
    because only a strictly larger value replaces the current best.

    Args:
        policy_probablities: Table indexed ``[state][action]``; each row is
            overwritten with a one-hot distribution.
        state_values: 1-D array of state values used for the look-ahead.
        gamma: Discount factor applied to the successor state's value.

    Returns:
        True if no state's greedy action differs from the arg-max action of
        its previous row, False otherwise.

    Note:
        Transitions come from the module-level ``env.simulate_step``; the
        third element it returns is ignored.
    """
    policy_stable = True

    # State and action counts are taken from the tables passed in rather than
    # from the global ``envSize`` and a hard-coded 2.
    for state in range(len(state_values)):
        n_actions = len(policy_probablities[state])
        old_action = policy_probablities[state].argmax()

        new_action = None
        max_qsa = float("-inf")
        for action in range(n_actions):
            next_state, reward, _ = env.simulate_step(state, action)
            qsa = reward + gamma * state_values[next_state]

            if qsa > max_qsa:
                new_action, max_qsa = action, qsa

        # Deterministic greedy policy: all probability mass on the best action.
        action_probs = np.zeros(n_actions)
        action_probs[new_action] = 1.
        policy_probablities[state] = action_probs

        if new_action != old_action:
            policy_stable = False

    return policy_stable

In [None]:
def policy_iteration(policy_probablities, state_values, theta = 1e-6, gamma = 0.99):
    """Alternate policy evaluation and greedy improvement until stable.

    Both tables are updated in place by the helper functions. The loop runs
    at least once and ends as soon as ``policy_improvement`` reports that
    the policy did not change.

    Args:
        policy_probablities: Policy table passed through to both helpers.
        state_values: State-value array passed through to both helpers.
        theta: Convergence threshold forwarded to ``policy_evaluation``.
        gamma: Discount factor forwarded to both helpers.

    Returns:
        None.
    """
    while True:
        policy_evaluation(policy_probablities, state_values, theta, gamma)
        if policy_improvement(policy_probablities, state_values, gamma):
            break

In [None]:
# Run policy iteration with the default theta and gamma; both tables are
# updated in place.
policy_iteration(
    policy_probablities=policy_probablities,
    state_values=state_values,
)

# Testing the algorithm

In [None]:
# Hand the sampling function `policy` to the environment's test helper for 3 episodes.
# NOTE(review): test_agent is defined in env.SimpleGridWorld (not visible here);
# presumably it rolls out and renders each episode - confirm.
env.test_agent(policy, episodes=3)