# Policy Iteration on Grid World

In [108]:
import gym
import numpy as np
from gym import spaces

## A Gridworld Example
### if on top-left or botton-right corner than reward is 0, else -1


In [109]:
# custom 2d grid world enviroment which extends gym.Env
class GridWorld(gym.Env):
    """
        A grid world env follows the book Example 4.1
    """
    metadata = {'render.modes': ['console']}
    
    # actions available 
    UP   = 0
    LEFT = 1
    DOWN = 2
    RIGHT= 3
    
    def __init__(self, size):
        super(GridWorld, self).__init__()
        
        self.size = size # size of the grid world
        self.end_state = [[0,0],[size-1,size-1]] # top left and bottom right
        
        # randomly assign the inital location of agent
        self.agent_position = [np.random.randint(0,self.size),np.random.randint(0,self.size)]
        
        # respective actions of agents : up, down, left and right
        self.action_space = spaces.Discrete(4)
        
        # set the observation space to (1,) to represent agent position in the grid world 
        # staring from [0,size*size)
        self.observation_space = spaces.Box(low=0, high=4, shape=(2,), dtype=np.uint8)
        
        self.states = [[i,j] for i in range(size) for j in range(size)]
    
    def bound_position(self):
        self.agent_position[0] = 0 if self.agent_position[0]<0 else self.agent_position[0]
        self.agent_position[0] = self.size-1 if self.agent_position[0]>self.size-1 else self.agent_position[0]
        self.agent_position[1] = 0 if self.agent_position[1]<0 else self.agent_position[1]
        self.agent_position[1] = self.size-1 if self.agent_position[1]>self.size-1 else self.agent_position[1]
        
    def step(self,action):
        info = {} # additional information
        
        if self.agent_position in self.end_state:
            return self.agent_position, 0, True, info
        
        
        if action == self.UP:
            self.agent_position[0] -=1

        elif action == self.LEFT:
            self.agent_position[1] -=1

        elif action == self.DOWN:
            self.agent_position[0] +=1

        elif action == self.RIGHT:
            self.agent_position[1] +=1

        else:
            raise ValueError("Received invalid action={} which is not part of the action space".format(action))
        
        self.bound_position()
        done = bool(self.agent_position in self.end_state)
        
        # reward agent when it is in the terminal cell, else reward = 0
        reward = 0 if done else -1
        
        return self.agent_position, reward, done, info
    
    def render(self, mode='console'):
        '''
            render the state
        '''
        if mode != 'console':
          raise NotImplementedError()
        
        row  = self.agent_position[0]
        col  = self.agent_position[1]
        
        for r in range(self.size):
            for c in range(self.size):
                if r == row and c == col:
                    print("X",end='')
                else:
                    print('.',end='')
            print('')

    def reset(self,position):
        
        self.agent_position = position
    
    def close(self):
        pass

## Policy Evaluation Function

In [118]:
def policy_evaluation(theta,V,Pi,env,max_iter = 1):
    
    delta = 0
    for i in range(max_iter):
        V_new = {}
        
        #for all non-terminal states
        for state in env.states:
            if state in env.end_state:
                V_new[tuple(state)] = 0
                continue
                
            #old state value
            v = V[tuple(state)]
            
            #action under current policy
            action = Pi[tuple(state)]
            
            #take the action and get reward
            env.reset(state.copy())
            next_state, reward, done, info = env.step(action)
            
            #calculate new state value
            new_v = reward + V[tuple(next_state)]
            
            V_new[tuple(state)] = new_v
            
            #update delta
            delta = max(abs(new_v-v), delta)
            
        V = V_new
        
        if delta < theta:
            break
            
    return V

## Policy Improvement Function

In [119]:
def policy_improvement(V,Pi,env):
    policy_stable = True
    Pi_new = {}
    
    #for all non-terminal states
    for state in env.states:
        if state in env.end_state:
            continue
        #old action
        action = Pi[tuple(state)]
        
        #find new optimal action under current state-value function
        action_values = []
        
        #take all actions and find the largest value
        for a in range(env.action_space.n):
            env.reset(state.copy())
            next_state, reward, done, info = env.step(a)
            next_state_value = V[tuple(next_state)]
            action_values.append(reward+next_state_value)
        best_action = np.argmax(action_values)
        
        #update policy
        Pi_new[tuple(state)] = best_action
        
        if action != best_action:
            policy_stable = False

    return policy_stable, Pi_new

In [120]:
#rendering function
def plot(V,Pi,size):
    for r in range(size):
        for c in range(size):
            print(V[(r,c)],end = '\t|')
        print('')
        
    
    print('===========================')
    action_map = {0:'up',1:'left',2:'down',3:'right'}
    for r in range(size):
        for c in range(size):
            try:
                print(action_map[Pi[(r,c)]],end = '\t|')
            except:
                print('X',end = '\t|')
        print('')

In [127]:
size = 4
V = dict.fromkeys([(i,j) for i in range(size) for j in range(size)], 0) # values as 0
Pi = dict.fromkeys([(i,j) for i in range(size) for j in range(size) if i+j != 0 and i+j != 6], 0) # inital actions as UP
env = GridWorld(size)
theta = 0.1

plot(V,Pi,size)

0	|0	|0	|0	|
0	|0	|0	|0	|
0	|0	|0	|0	|
0	|0	|0	|0	|
X	|up	|up	|up	|
up	|up	|up	|up	|
up	|up	|up	|up	|
up	|up	|up	|X	|


### Round#1

In [128]:
V = policy_evaluation(theta,V,Pi,env)
policy_stable, Pi = policy_improvement(V,Pi,env)
print("Policy is stable: ", policy_stable)
plot(V,Pi,size)


Policy is stable:  False
0	|-1	|-1	|-1	|
0	|-1	|-1	|-1	|
-1	|-1	|-1	|-1	|
-1	|-1	|-1	|0	|
X	|left	|up	|up	|
up	|left	|up	|up	|
up	|up	|up	|down	|
up	|up	|right	|X	|


### Round#2

In [129]:
V = policy_evaluation(theta,V,Pi,env)
policy_stable, Pi = policy_improvement(V,Pi,env)
print("Policy is stable: ", policy_stable)
plot(V,Pi,size)


Policy is stable:  False
0	|0	|-2	|-2	|
0	|-1	|-2	|-2	|
-1	|-2	|-2	|0	|
-2	|-2	|0	|0	|
X	|left	|left	|up	|
up	|up	|left	|down	|
up	|up	|down	|down	|
up	|right	|right	|X	|


In [130]:
V = policy_evaluation(theta,V,Pi,env)
policy_stable, Pi = policy_improvement(V,Pi,env)
print("Policy is stable: ", policy_stable)
plot(V,Pi,size)


Policy is stable:  False
0	|0	|-1	|-3	|
0	|-1	|-2	|-1	|
-1	|-2	|-1	|0	|
-2	|-1	|0	|0	|
X	|left	|left	|left	|
up	|up	|up	|down	|
up	|up	|down	|down	|
up	|right	|right	|X	|


In [131]:
V = policy_evaluation(theta,V,Pi,env)
policy_stable, Pi = policy_improvement(V,Pi,env)
print("Policy is stable: ", policy_stable)
plot(V,Pi,size)

Policy is stable:  True
0	|0	|-1	|-2	|
0	|-1	|-2	|-1	|
-1	|-2	|-1	|0	|
-2	|-1	|0	|0	|
X	|left	|left	|left	|
up	|up	|up	|down	|
up	|up	|down	|down	|
up	|right	|right	|X	|
