# Policy Iteration on Grid World

In [1]:
import gym
import numpy as np
from gym import spaces

## A Gridworld Example
### Reward is 0 when the agent reaches the top-left or bottom-right corner (the terminal states); otherwise it is -1


In [2]:
# Custom 2D grid world environment which extends gym.Env
class GridWorld(gym.Env):
    """
    A square grid world environment that follows the book's Example 4.1.

    The agent position is a ``[row, col]`` list with ``0 <= row, col < size``.
    The top-left and bottom-right cells are terminal.  A move that lands on a
    terminal cell yields reward 0; every other move yields reward -1.  A move
    that would leave the grid is clamped, so the agent stays on the border.
    """
    metadata = {'render.modes': ['console']}

    # actions available
    UP   = 0
    LEFT = 1
    DOWN = 2
    RIGHT= 3

    def __init__(self, size):
        """
        Build a ``size`` x ``size`` grid with the agent on a random cell.

        Parameters
        ----------
        size : int
            Number of rows (and columns) of the grid.
        """
        super(GridWorld, self).__init__()

        self.size = size # size of the grid world
        self.end_state = [[0,0],[size-1,size-1]] # top left and bottom right

        # randomly assign the initial location of the agent
        self.agent_position = [np.random.randint(0,self.size),np.random.randint(0,self.size)]

        # four discrete actions, encoded by the class constants UP, LEFT, DOWN, RIGHT
        self.action_space = spaces.Discrete(4)

        # an observation is the agent's [row, col]; both coordinates lie in
        # [0, size-1], so the upper bound must follow the grid size
        self.observation_space = spaces.Box(low=0, high=size-1, shape=(2,), dtype=np.uint8)

        # every cell of the grid as a [row, col] list, in row-major order
        self.states = [[i,j] for i in range(size) for j in range(size)]

    def bound_position(self):
        """Clamp the agent position so that it stays inside the grid."""
        last = self.size-1
        self.agent_position[0] = min(max(self.agent_position[0], 0), last)
        self.agent_position[1] = min(max(self.agent_position[1], 0), last)

    def step(self,action):
        """
        Apply ``action`` and return ``(position, reward, done, info)``.

        If the agent already sits on a terminal cell nothing moves and the
        result is ``(position, 0, True, info)``.  The returned position is a
        copy, so it does not change when the environment is stepped again.

        Raises
        ------
        ValueError
            If ``action`` is not one of UP, LEFT, DOWN, RIGHT.
        """
        info = {} # additional information

        if self.agent_position in self.end_state:
            return list(self.agent_position), 0, True, info

        if action == self.UP:
            self.agent_position[0] -=1

        elif action == self.LEFT:
            self.agent_position[1] -=1

        elif action == self.DOWN:
            self.agent_position[0] +=1

        elif action == self.RIGHT:
            self.agent_position[1] +=1

        else:
            raise ValueError("Received invalid action={} which is not part of the action space".format(action))

        self.bound_position()
        done = bool(self.agent_position in self.end_state)

        # reward is 0 when the move lands on a terminal cell, otherwise -1
        reward = 0 if done else -1

        return list(self.agent_position), reward, done, info

    def render(self, mode='console'):
        '''
            Print the grid to stdout: 'X' marks the agent, '.' any other cell.
        '''
        if mode != 'console':
            raise NotImplementedError()

        row  = self.agent_position[0]
        col  = self.agent_position[1]

        for r in range(self.size):
            for c in range(self.size):
                if r == row and c == col:
                    print("X",end='')
                else:
                    print('.',end='')
            print('')

    def reset(self,position=None):
        """
        Place the agent and return its position.

        Parameters
        ----------
        position : sequence of two ints, optional
            ``[row, col]`` to place the agent on.  It is copied into a new
            list, so the caller's object is never mutated by later steps.
            When omitted, a random cell is drawn as in ``__init__``.

        Returns
        -------
        list
            A copy of the new agent position.
        """
        if position is None:
            self.agent_position = [np.random.randint(0,self.size),np.random.randint(0,self.size)]
        else:
            self.agent_position = list(position)
        return list(self.agent_position)

    def close(self):
        """Nothing to release."""
        pass

## Policy Evaluation Function

In [3]:
def get_expected_value(state,Pi,V,env=None):
    '''
    Compute the one-step expected value of a single state under policy Pi.

    For every action the environment is reset to `state`, stepped once, and
    `prob * (reward + V[next_state])` is accumulated (no discounting).

    Parameters
    ----------
    state : sequence of two ints
        The [row, col] state to evaluate.
    Pi : dict
        Maps a state tuple to a list of action probabilities.
    V : dict
        Maps a state tuple to its current value estimate.
    env : optional
        Environment offering `reset(position)` and `step(action)`.  When
        omitted, the notebook-level variable `env` is used, which keeps
        existing three-argument calls working.

    Returns
    -------
    float
        The expected value, rounded to one decimal place.

    Raises
    ------
    NameError
        If `env` is omitted and no notebook-level `env` exists.
    '''
    if env is None:
        # backward compatibility: fall back to the global environment
        try:
            env = globals()['env']
        except KeyError:
            raise NameError("name 'env' is not defined; pass env explicitly") from None

    #pi(s)
    action_prob = Pi[tuple(state)]

    expected_value = 0
    for action, prob in enumerate(action_prob):
        # hand the environment a fresh list so `state` itself is never mutated
        env.reset(list(state))
        next_state, reward, done, info = env.step(action)
        expected_value += prob*(reward + V[tuple(next_state)])

    # NOTE(review): the rounded number is what gets stored in V, so rounding
    # error feeds into later sweeps (e.g. -0.75 is stored as -0.8).  Kept to
    # preserve the notebook's recorded outputs; consider rounding only at
    # display time.
    return round(expected_value,1)

In [4]:
def policy_evaluation_single_round(V,Pi,env):
    """
    Run one synchronous evaluation sweep over every state of `env`.

    A new value table is built from the old one (`V` is not modified):
    terminal states get 0, every other state gets the result of
    `get_expected_value` under policy `Pi`.

    Returns the new dict mapping state tuples to values.
    """
    terminal_states = env.end_state
    return {
        tuple(cell): 0 if cell in terminal_states else get_expected_value(cell,Pi,V)
        for cell in env.states
    }

## Policy Improvement Function

In [5]:
def policy_improvement(V,Pi,env):
    """
    Build the greedy policy with respect to the state-value table `V`.

    For each non-terminal state every action is tried once from that state
    and scored as `reward + V[next_state]`.  The actions with the highest
    score share the probability mass equally; all others get 0.

    Returns `(policy_stable, Pi_new)` where `policy_stable` is True only if
    no state's probability list differs from the one in `Pi`.
    """
    n_actions = env.action_space.n
    improved = {}
    unchanged = True

    for state in env.states:
        # terminal states carry no policy entry
        if state in env.end_state:
            continue

        key = tuple(state)
        previous_probs = Pi[key]

        # one-step lookahead for every action, always starting from `state`
        q_values = []
        for candidate in range(n_actions):
            env.reset(state.copy())
            successor, reward, done, info = env.step(candidate)
            q_values.append(reward + V[tuple(successor)])

        greedy = np.argwhere(np.asarray(q_values) == np.amax(q_values)).flatten()

        # split the probability evenly among the tied best actions
        share = 1/len(greedy)
        improved[key] = [share if candidate in greedy else 0 for candidate in range(n_actions)]

        if not np.array_equal(previous_probs, improved[key]):
            unchanged = False

    return unchanged, improved

In [15]:
#rendering function
def plot(V,Pi,size):
    """
    Print the value table and the policy of a `size` x `size` grid.

    First the values `V[(r, c)]` are printed row by row, then a separator
    line, then for every cell the arrows of the actions that have the
    highest probability in `Pi[(r, c)]`.  Cells without an entry in `Pi`
    (the terminal states in this notebook) are shown as 'X'.

    Parameters
    ----------
    V : dict
        Maps (row, col) to a value.
    Pi : dict
        Maps (row, col) to a list of action probabilities.
    size : int
        Number of rows and columns to print.
    """
    for r in range(size):
        for c in range(size):
            print(V[(r,c)],end = '\t|')
        print('')

    print('-------------------------------')
    # action index -> arrow: up, left, down, right
    action_map = {0:'\u2191',1:'\u2190',2:'\u2193',3:'\u2192'}
    for r in range(size):
        for c in range(size):
            # only a missing policy entry means "no action here"; any other
            # problem (e.g. an unknown action index) is a real error and is
            # allowed to propagate instead of being printed as 'X'
            if (r,c) not in Pi:
                print('X',end = '\t|')
                continue
            probs = np.asarray(Pi[(r,c)])
            best_actions = np.argwhere(probs == np.amax(probs)).flatten()
            arrows = ''.join(action_map[int(action)] for action in best_actions)
            print(arrows,end = '\t|')
        print('')
    print(' ')

## Initialize a 4\*4 gridworld

In [16]:
size = 4
grid_cells = [(i,j) for i in range(size) for j in range(size)]
terminal_cells = {(0,0),(size-1,size-1)} # top-left and bottom-right corners
V = {cell: 0 for cell in grid_cells} # values as 0
# initial policy: uniform over the four actions for every non-terminal state.
# A comprehension gives each state its own list; dict.fromkeys would make
# all states share one list object.
Pi = {cell: [0.25]*4 for cell in grid_cells if cell not in terminal_cells}
env = GridWorld(size)
plot(V,Pi,size)

0	|0	|0	|0	|
0	|0	|0	|0	|
0	|0	|0	|0	|
0	|0	|0	|0	|
-------------------------------
X	|↑←↓→	|↑←↓→	|↑←↓→	|
↑←↓→	|↑←↓→	|↑←↓→	|↑←↓→	|
↑←↓→	|↑←↓→	|↑←↓→	|↑←↓→	|
↑←↓→	|↑←↓→	|↑←↓→	|X	|
 


In [17]:
N_ROUNDS = 5 # number of evaluation + improvement rounds to run

for i in range(N_ROUNDS):

    print('Round#: ',i)
    # policy evaluation: a single sweep, not run to convergence
    V = policy_evaluation_single_round(V,Pi,env)

    # policy improvement: greedy with respect to the freshly updated V
    policy_stable,Pi = policy_improvement(V,Pi,env)

    # plot with the size of the environment that is actually being used
    plot(V,Pi,env.size)
    print("Policy is stable: ", policy_stable)
    print('=====================================')
    

Round#:  0
0	|-0.8	|-1.0	|-1.0	|
-0.8	|-1.0	|-1.0	|-1.0	|
-1.0	|-1.0	|-1.0	|-0.8	|
-1.0	|-1.0	|-0.8	|0	|
-------------------------------
X	|←	|←	|↑←↓→	|
↑	|↑←	|↑←↓→	|↓	|
↑	|↑←↓→	|↓→	|↓	|
↑←↓→	|→	|→	|X	|
 
Policy is stable:  False
Round#:  1
0	|0.0	|-1.8	|-2.0	|
0.0	|-1.8	|-2.0	|-1.8	|
-1.8	|-2.0	|-1.8	|0.0	|
-2.0	|-1.8	|0.0	|0	|
-------------------------------
X	|←	|←	|←↓	|
↑	|↑←	|↑←↓→	|↓	|
↑	|↑←↓→	|↓→	|↓	|
↑→	|→	|→	|X	|
 
Policy is stable:  False
Round#:  2
0	|0.0	|-1.0	|-2.8	|
0.0	|-1.0	|-2.8	|-1.0	|
-1.0	|-2.8	|-1.0	|0.0	|
-2.8	|-1.0	|0.0	|0	|
-------------------------------
X	|←	|←	|←↓	|
↑	|↑←	|↑←↓→	|↓	|
↑	|↑←↓→	|↓→	|↓	|
↑→	|→	|→	|X	|
 
Policy is stable:  True
Round#:  3
0	|0.0	|-1.0	|-2.0	|
0.0	|-1.0	|-2.0	|-1.0	|
-1.0	|-2.0	|-1.0	|0.0	|
-2.0	|-1.0	|0.0	|0	|
-------------------------------
X	|←	|←	|←↓	|
↑	|↑←	|↑←↓→	|↓	|
↑	|↑←↓→	|↓→	|↓	|
↑→	|→	|→	|X	|
 
Policy is stable:  True
Round#:  4
0	|0.0	|-1.0	|-2.0	|
0.0	|-1.0	|-2.0	|-1.0	|
-1.0	|-2.0	|-1.0	|0.0	|
-2.0	|-1.0	|0.0	|0	|
-----