# COURSE:   PGP [AI&ML]

## Learner :  Chaitanya Kumar Battula
## Module  : RNN
## Topic   : Policy Improvement in GridWorld

## **Problem Statement:**

Company Robo.ai is building a robot that can traverse the environment unassisted and reach the food counter. Instead of creating its own environment, the company plans to use a prebuilt 4x4 grid world. You are a researcher who has to choose between the policy iteration and value iteration methods to tackle this task. You have decided to go with the policy iteration method.
You have already performed the first step to get a new policy. Now, improve that policy using policy improvement.

## **Environment**

This environment has two terminal states, located at:<br>
* Top left corner
* Bottom right corner

<br>
The 4x4 grid looks as follows:<br>
T  o  o  o<br>
o  x  o  o<br>
o  o  o  o<br>
o  o  o  T<br>
Here x is the position of the agent, each T marks a terminal state, and each o is an ordinary (non-terminal) state.<br>

<br>
The allowed actions are as follows:
* UP = 0 
* RIGHT = 1 
* DOWN = 2 
* LEFT = 3 <br>


    Note: If an action would take the agent off the edge of the grid, the agent stays in its current state.

Rewards:
The agent is granted a reward of -1 at each step until it reaches a terminal state.
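
The movement and reward rules above can be sketched as a small step function. This is only an illustration, not the code inside the `gridworld` module: the helper `step`, the row-major state numbering (0 to 15), and the zero reward inside a terminal state are assumptions made for this sketch.

```python
# Hypothetical helper for illustration; not part of the gridworld module.
UP, RIGHT, DOWN, LEFT = 0, 1, 2, 3
GRID_SIZE = 4
TERMINAL_STATES = {0, GRID_SIZE * GRID_SIZE - 1}  # top left and bottom right

def step(state, action):
    """Return (next_state, reward, done) for one deterministic move."""
    if state in TERMINAL_STATES:
        # Assumption: a terminal state is absorbing and gives no further reward.
        return state, 0.0, True
    row, col = divmod(state, GRID_SIZE)
    if action == UP:
        row = max(row - 1, 0)
    elif action == RIGHT:
        col = min(col + 1, GRID_SIZE - 1)
    elif action == DOWN:
        row = min(row + 1, GRID_SIZE - 1)
    elif action == LEFT:
        col = max(col - 1, 0)
    next_state = row * GRID_SIZE + col
    return next_state, -1.0, next_state in TERMINAL_STATES

print(step(5, UP))     # (1, -1.0, False): the agent at x moves up
print(step(3, RIGHT))  # (3, -1.0, False): off-grid move, the agent stays put
```

Clamping the row and column to the grid bounds is what keeps the agent in place when a move would leave the grid, while the step still costs -1.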

Environment courtesy: Sutton and Barto's Reinforcement Learning: An Introduction, Chapter 4.


### **Dependencies**
* Discrete
* Gridworld

    Note: The steps for policy evaluation are present in this document.

## **Import libraries and environment**

In [1]:
import numpy as nump
import sys
from gridworld import GridworldEnv

In [2]:
environment = GridworldEnv()

## **Evaluate the policy**

Arguments:
    
* policy = [S, A] shaped matrix
* environment.P = Transition probabilities
* environment.P[s][a] = List of transition tuples (prob, next_state, reward, done)
* environment.nS = Number of states 
* environment.nA = Number of actions
* theta = Convergence threshold; the evaluation stops once the change in the value function is less than theta for all the states
* discount_factor = Gamma discount factor
* Returns = Value function in the form of a vector of length environment.nS
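
For reference, the update applied to each state during policy evaluation can be written as the Bellman expectation backup. The notation follows Sutton and Barto and is not used elsewhere in this notebook: $\pi(a \mid s)$ corresponds to `policy[s][a]`, $p(s' \mid s, a)$ and $r(s, a, s')$ to the `prob` and `reward` entries of `environment.P[s][a]`, $\gamma$ to `discount_factor`, and $\theta$ to `theta`.

$$
v(s) \leftarrow \sum_{a} \pi(a \mid s) \sum_{s'} p(s' \mid s, a)\,\bigl[\, r(s, a, s') + \gamma\, v(s') \,\bigr]
$$

Sweeps over all states are repeated until the largest change in a sweep satisfies

$$
\max_{s} \bigl| v_{\text{new}}(s) - v_{\text{old}}(s) \bigr| < \theta .
$$

The code overwrites each entry of the value vector as soon as it is computed, so later states in the same sweep already use the updated values.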
        

In [3]:
def policy_eval(policy, environment, discount_factor=1.0, theta=0.00001):
    
    # Start with an all-zero value function (the value is 0 for every state).
    Val_function = nump.zeros(environment.nS)
    while True:
      
        delta = 0
        # Perform a "full backup" for each state
        for s in range(environment.nS):
            v = 0
            
            # Look at all the possible next actions
            for a, action_prob in enumerate(policy[s]):
              
                # Look at the possible next states for this action
                for  prob, next_state, reward, done in environment.P[s][a]:
                  
                    # Calculate the expected value
                    v += action_prob * prob * (reward + discount_factor * Val_function[next_state])
                    
            # Track the largest change in the value function across all states
            delta = max(delta, nump.abs(v - Val_function[s]))
            Val_function[s] = v
              
        # Stop the evaluation once the largest change is below the threshold theta
        if delta < theta:
            break
    return nump.array(Val_function)
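
As a sanity check on iterative evaluation, the value function of the uniform random policy can also be obtained by solving the Bellman equations directly as a linear system. The sketch below swaps in that direct solve in place of the sweep-based method above, and it rebuilds the grid dynamics by hand instead of using `GridworldEnv`. The row-major state numbering and all variable names are assumptions made for this example.

```python
import numpy as nump

N = 4                                        # grid side length
n_states = N * N
terminals = {0, n_states - 1}
moves = [(-1, 0), (0, 1), (1, 0), (0, -1)]   # UP, RIGHT, DOWN, LEFT

# State-to-state transition matrix and expected one-step reward
# under the uniform random policy (each action has probability 0.25).
P_pi = nump.zeros((n_states, n_states))
r_pi = nump.zeros(n_states)
for s in range(n_states):
    if s in terminals:
        continue                             # terminal values stay at 0
    row, col = divmod(s, N)
    for d_row, d_col in moves:
        new_row = min(max(row + d_row, 0), N - 1)   # clamp at the edges
        new_col = min(max(col + d_col, 0), N - 1)
        P_pi[s, new_row * N + new_col] += 0.25
    r_pi[s] = -1.0

# Solve (I - P_pi) v = r_pi restricted to the non-terminal states
# (discount factor 1, terminal values fixed at 0).
inner = [s for s in range(n_states) if s not in terminals]
A = nump.eye(len(inner)) - P_pi[nump.ix_(inner, inner)]
v = nump.zeros(n_states)
v[inner] = nump.linalg.solve(A, r_pi[inner])

print(nump.round(v.reshape(N, N), 1))
```

Calling `policy_eval` with the uniform random policy should agree with this result up to the tolerance set by `theta`.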

## **Improve the Policy**

Arguments:

* policy_eval_fn: Policy Evaluation function that takes 3 arguments:
  * policy
  * environment
  * discount_factor
* Returns: A tuple (policy, Val_function) containing the final policy and its value function
* one_step_lookahead returns: A vector of length environment.nA that contains the expected value of each action
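
In equation form, the one-step lookahead and the greedy update it feeds can be written as follows. The symbols are borrowed from Sutton and Barto and are not part of the notebook: $v$ is the value function returned by the evaluation step, $p(s' \mid s, a)$ and $r(s, a, s')$ are the `prob` and `reward` entries of `environment.P[s][a]`, and $\gamma$ is `discount_factor`.

$$
q(s, a) = \sum_{s'} p(s' \mid s, a)\,\bigl[\, r(s, a, s') + \gamma\, v(s') \,\bigr],
\qquad
\pi'(s) = \arg\max_{a}\, q(s, a)
$$

Evaluation and greedy improvement alternate until $\pi'(s) = \pi(s)$ for every state, at which point the policy is stable.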

In [4]:
# Uses policy_eval (defined above) to evaluate each candidate policy
def policy_improvement(environment, policy_eval_fn=policy_eval, discount_factor=1.0):

    # One-step lookahead: compute the expected value of every action in a given state
    def one_step_lookahead(state, Val_function):
        
        A = nump.zeros(environment.nA)
        for a in range(environment.nA):
            for prob, next_state, reward, done in environment.P[state][a]:
                A[a] += prob * (reward + discount_factor * Val_function[next_state])
        return A
      
    # Start with a uniform random policy (every action equally likely)
    policy = nump.ones([environment.nS, environment.nA]) / environment.nA
    
    while True:
        # Evaluate the current policy
        Val_function = policy_eval_fn(policy, environment, discount_factor)
        
        # Set to False below if the policy changes in any state
        policy_stable = True
        
        # For each state
        for s in range(environment.nS):
            # The action chosen under the current policy
            chosen_a = nump.argmax(policy[s])
            
            # One-step lookahead finds the best action;
            # ties are resolved arbitrarily (argmax picks the lowest index)
            action_values = one_step_lookahead(s, Val_function)
            best_a = nump.argmax(action_values)
            
            # Greedy update of the policy
            if chosen_a != best_a:
                policy_stable = False
            policy[s] = nump.eye(environment.nA)[best_a]
        
        # If the policy is stable, it is optimal: return it
        if policy_stable:
            return policy, Val_function
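
To make the greedy update concrete, here is a minimal hand-worked sketch for a single state: state 5, the agent position x in the grid above. It assumes row-major state numbering and uses the values of the uniform random policy as given in Sutton and Barto, Chapter 4. The variable names `v_random` and `successors` are made up for this example.

```python
import numpy as nump

# Value function of the uniform random policy on the 4x4 grid
# (Sutton and Barto, Chapter 4), flattened in row-major order.
v_random = nump.array([
    [  0., -14., -20., -22.],
    [-14., -18., -20., -20.],
    [-20., -20., -18., -14.],
    [-22., -20., -14.,   0.],
]).ravel()

# Deterministic successors of state 5 for UP, RIGHT, DOWN, LEFT.
successors = [1, 6, 9, 4]
reward, discount_factor = -1.0, 1.0

# One-step lookahead: expected value of each action from state 5.
action_values = nump.array(
    [reward + discount_factor * v_random[s2] for s2 in successors]
)
best_a = nump.argmax(action_values)   # the first maximum wins on ties
greedy_row = nump.eye(4)[best_a]      # one-hot row, as in policy[s] = ...

print(action_values)  # [-15. -21. -21. -15.]
print(best_a)         # 0 (UP); LEFT ties but has a higher index
print(greedy_row)     # [1. 0. 0. 0.]
```

UP and LEFT are equally good from this state. `argmax` returns the first maximum, so the greedy policy settles on UP.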

In [5]:
policy, v = policy_improvement(environment)
print("Policy Probability Distribution:")
print(policy)
print("")

print("Reshaped Grid Policy (0=up, 1=right, 2=down, 3=left):")
print(nump.reshape(nump.argmax(policy, axis=1), environment.shape))
print("")

print("Value Function:")
print(v)
print("")

print("Reshaped Grid Value Function:")
print(v.reshape(environment.shape))
print("")

Policy Probability Distribution:
[[1. 0. 0. 0.]
 [0. 0. 0. 1.]
 [0. 0. 0. 1.]
 [0. 0. 1. 0.]
 [1. 0. 0. 0.]
 [1. 0. 0. 0.]
 [1. 0. 0. 0.]
 [0. 0. 1. 0.]
 [1. 0. 0. 0.]
 [1. 0. 0. 0.]
 [0. 1. 0. 0.]
 [0. 0. 1. 0.]
 [1. 0. 0. 0.]
 [0. 1. 0. 0.]
 [0. 1. 0. 0.]
 [1. 0. 0. 0.]]

Reshaped Grid Policy (0=up, 1=right, 2=down, 3=left):
[[0 3 3 2]
 [0 0 0 2]
 [0 0 1 2]
 [0 1 1 0]]

Value Function:
[ 0. -1. -2. -3. -1. -2. -3. -2. -2. -3. -2. -1. -3. -2. -1.  0.]

Reshaped Grid Value Function:
[[ 0. -1. -2. -3.]
 [-1. -2. -3. -2.]
 [-2. -3. -2. -1.]
 [-3. -2. -1.  0.]]
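
With a reward of -1 per step and a discount factor of 1.0, the optimal value of a state should equal the negative of the number of moves needed to reach the nearest terminal corner. The sketch below computes that quantity independently so it can be compared with the reshaped value function printed above. The variable names are made up for this check.

```python
import numpy as nump

N = 4
rows, cols = nump.indices((N, N))

# Manhattan distance from every cell to each terminal corner.
dist_top_left = rows + cols
dist_bottom_right = (N - 1 - rows) + (N - 1 - cols)

# Fewest moves to the nearest terminal state; each move costs -1.
steps = nump.minimum(dist_top_left, dist_bottom_right)
expected_v = (-steps).astype(float)

print(expected_v)
```

If `v` from the cell above is still in scope, `nump.allclose(v.reshape(N, N), expected_v)` should evaluate to `True`.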

