# Importing libraries

In [None]:
import numpy as np 
import matplotlib.pyplot as plt 

from env import Grid, test_agent

# Initialising the environment

In [None]:
# Side length of the square grid; the policy and value tables are sized from it.
envsize = 4
env = Grid(type="random", size=envsize)

In [None]:
# Reset the environment. Entries 0 and 1 of the value returned by reset()
# are reported as the start and end positions.
info = env.reset()
print(f"The starting position is {info[0]}")
print(f"The ending position is {info[1]}")

# Defining the policy

In [None]:
# Equiprobable random policy: for every (row, col) cell, each of the 4 actions has probability 1/4.
policy_probs = np.ones((envsize, envsize, 4)) / 4

In [None]:
def policy(state):
    """Return the per-action probabilities for ``state``.

    Looks the state up in the module-level ``policy_probs`` table. The table
    is read at call time, so later in-place improvements are reflected.
    """
    probabilities = policy_probs[state]
    return probabilities

In [None]:
# Show the action distribution the current policy assigns to state (0, 0).
action_probabilities = policy((0, 0))
for action, prob in enumerate(action_probabilities[:4]):  # report at most the 4 actions
    print(f"Probability of taking action {action}: {prob}")

# Defining the value table

In [None]:
# One value estimate per grid cell, all starting at zero.
state_values = np.zeros((envsize, envsize))

In [None]:
# Value table before any evaluation (all zeros at this point).
print(state_values)

# Implementing the policy iteration algorithm

In [None]:
def policy_evaluation(policy_probs, state_values, theta=1e-6, gamma=0.99):
    """Iteratively evaluate a policy, updating ``state_values`` in place.

    Each sweep visits every (row, col) cell of the ``envsize`` x ``envsize``
    grid. It replaces the cell's value with the expected one-step backup
    ``sum_a pi(a|s) * (reward + gamma * V(next_state))``. Values already
    updated earlier in the same sweep are reused (in-place sweeps). Sweeping
    stops once the largest absolute change in a sweep is no longer greater
    than ``theta``.

    Args:
        policy_probs: table indexed as ``policy_probs[row, col]`` giving the
            probability of each action for that cell.
        state_values: value table indexed as ``state_values[row, col]``;
            overwritten in place.
        theta: convergence threshold on the per-sweep maximum change.
        gamma: discount factor.

    NOTE(review): only the first two entries of ``env.simulate_step`` are
    used. The remaining ones (presumably a done flag and info) are ignored,
    so terminal cells are backed up like any other. Confirm this matches
    Grid's terminal semantics.
    """
    largest_change = float("inf")

    while largest_change > theta:
        largest_change = 0
        for cell in np.ndindex(envsize, envsize):  # row-major, same order as nested range loops
            previous = state_values[cell]
            backup = 0.

            for action, prob in enumerate(policy_probs[cell]):
                next_state, reward, _, _ = env.simulate_step(cell, action)
                backup += prob * (reward + gamma * state_values[next_state])

            state_values[cell] = backup
            largest_change = max(largest_change, abs(previous - backup))


In [None]:
def policy_improvement(policy_probs, state_values, gamma=0.99, tol=1e-6):
    """Make ``policy_probs`` greedy with respect to ``state_values``, in place.

    For every (row, col) cell, the one-step action values
    ``q(s, a) = reward + gamma * state_values[next_state]`` are computed via
    ``env.simulate_step``. The cell's action distribution is then replaced by
    a one-hot distribution on a greedy action.

    Args:
        policy_probs: table indexed as ``policy_probs[row, col]`` giving the
            probability of each of the 4 actions; overwritten in place.
        state_values: value table indexed as ``state_values[row, col]``.
        gamma: discount factor.
        tol: an action whose q-value is within ``tol`` of the best one is
            treated as tied with it. ``state_values`` comes from an
            approximate evaluation (stopped at a ``theta`` threshold), so
            truly equal actions can differ by tiny amounts between rounds.

    Returns:
        True if no cell's action distribution changed, meaning the policy was
        already greedy and policy iteration can stop. False otherwise.
    """
    policy_stable = True

    for row in range(envsize):
        for col in range(envsize):
            state = (row, col)
            old_probs = policy_probs[state].copy()
            old_action = int(old_probs.argmax())

            q_values = np.empty(4)
            for action in range(4):
                next_state, reward, _, _ = env.simulate_step(state, action)
                q_values[action] = reward + gamma * state_values[next_state]

            new_action = int(q_values.argmax())
            # Keep the current action while it is still (near-)greedy. Always
            # jumping to the first maximiser lets policy iteration flip forever
            # between equally good policies and never report stability.
            if q_values[old_action] >= q_values[new_action] - tol:
                new_action = old_action

            action_probs = np.zeros(4)
            action_probs[new_action] = 1.

            # Compare whole distributions rather than argmaxes. A stochastic
            # policy (e.g. the uniform start, whose argmax is 0) turning
            # deterministic is a real change. The new policy must be
            # re-evaluated before stopping, so the final state_values match it.
            if not np.array_equal(old_probs, action_probs):
                policy_stable = False

            policy_probs[state] = action_probs

    return policy_stable

In [None]:
def policy_iteration(policy_probs, state_values, theta = 1e-6, gamma = 0.99):
    """Alternate evaluation and greedy improvement until the policy is stable.

    Each round calls ``policy_evaluation`` to refresh ``state_values`` for the
    current policy. It then calls ``policy_improvement`` to make the policy
    greedy. The loop ends when ``policy_improvement`` returns a truthy value.
    Both arrays are modified in place and nothing is returned.
    """
    while True:
        policy_evaluation(policy_probs, state_values, theta, gamma)
        stable = policy_improvement(policy_probs, state_values, gamma)
        if stable:
            break
        

In [None]:
# Run policy iteration; policy_probs and state_values are updated in place.
policy_iteration(policy_probs, state_values)

In [None]:
# Resulting policy: after improvement every cell holds a one-hot action distribution.
print(policy_probs)

In [None]:
# State-value table left by the last policy-evaluation pass.
print(state_values)