In [136]:
import numpy as np
import sys
from gym.envs.toy_text import discrete

from contextlib import closing
from io import StringIO

In [137]:
# Action indices; they are used as the keys of each state's transition table.
UP, RIGHT, DOWN, LEFT = range(4)

class GRID_WORLD(discrete.DiscreteEnv):
    """Deterministic grid world with terminal states in the first and last cell.

    States are the cells of a ``shape`` grid numbered in row-major order.
    Every action taken from a non-terminal cell yields a reward of -1; an
    action that would leave the grid keeps the agent in the same cell.
    State 0 and state ``number_states - 1`` are absorbing: every action
    taken there returns to the same state with reward 0 and ``done=True``.
    """

    # Defined on the class (it used to be an unused local inside __init__)
    # so that the supported render modes are visible as ``env.metadata``.
    metadata = {'render.modes': ['human', 'ansi']}

    def __init__(self):
        """Build the transition table ``P`` and initialise the base environment."""
        self.number_actions = 4
        self.shape = (4, 4)
        self.number_states = np.prod(self.shape)
        P = {}
        for s in range(self.number_states):
            position = np.unravel_index(s, self.shape)
            # Offsets are (row, column); row 0 is the first printed row.
            P[s] = {
                UP: self.tarnsition_p(position, [-1, 0]),
                RIGHT: self.tarnsition_p(position, [0, 1]),
                DOWN: self.tarnsition_p(position, [1, 0]),
                LEFT: self.tarnsition_p(position, [0, -1]),
            }
        self.P = P
        # Uniform distribution over all states, handed to the base class
        # together with the transition table.
        isd = np.ones(self.number_states) / self.number_states
        super(GRID_WORLD, self).__init__(self.number_states, self.number_actions, P, isd)

    def tarnsition_p(self, pose_array, action):
        """Return the transition list for applying ``action`` at ``pose_array``.

        The method name keeps its original spelling so existing callers work.

        Args:
            pose_array: (row, column) position of the current cell.
            action: (row, column) offset to apply.

        Returns:
            A one-element list ``[(probability, next_state, reward, done)]``.
            Terminal cells map to themselves with reward 0 and ``done=True``;
            every other cell moves to the clipped neighbour with reward -1.
        """
        sing_state = np.ravel_multi_index(tuple(pose_array), self.shape)
        if sing_state == 0 or sing_state == self.number_states - 1:
            return [(1.0, sing_state, 0, True)]
        new_pose = np.array(pose_array) + np.array(action)
        new_pose = self.limitations(new_pose).astype(int)
        new_sing_state = np.ravel_multi_index(tuple(new_pose), self.shape)
        return [(1.0, new_sing_state, -1, False)]

    def limitations(self, array):
        """Clamp a (row, column) position to the grid, in place, and return it."""
        array[0] = min(array[0], self.shape[0] - 1)
        array[0] = max(array[0], 0)
        array[1] = min(array[1], self.shape[1] - 1)
        array[1] = max(array[1], 0)
        return array

    def render(self, mode='human'):
        """Draw the grid as text.

        `` x `` marks the current state, `` T `` a terminal cell and `` o ``
        any other cell.  With ``mode='ansi'`` the drawing is returned as a
        string; otherwise it is written to ``sys.stdout`` and ``None`` is
        returned.
        """
        outfile = StringIO() if mode == 'ansi' else sys.stdout
        # The current state is read from ``self.s`` (expected to be kept up to
        # date by the base environment).  The previous check compared the
        # loop index with ``number_states`` and therefore could never match.
        current_state = getattr(self, 's', None)
        for s in range(self.number_states):
            position = np.unravel_index(s, self.shape)
            if current_state is not None and s == current_state:
                output = " x "
            elif s == 0 or s == self.number_states - 1:
                output = " T "
            else:
                output = " o "
            # Trim the padding at the row edges and end each row with a newline.
            if position[1] == 0:
                output = output.lstrip()
            if position[1] == self.shape[1] - 1:
                output = output.rstrip()
                output += '\n'
            outfile.write(output)
        outfile.write('\n')
        if mode != 'human':
            with closing(outfile):
                return outfile.getvalue()

In [242]:
def policy_evaluation(policy, env, discount_factor=1.0, theta=0.00001):
    """Iteratively estimate the state values of ``policy`` on ``env``.

    Each sweep computes every state's new value from the values of the
    previous sweep only, and sweeping stops once the largest change in a
    sweep drops below ``theta``.

    Args:
        policy: Indexable by state; ``policy[s]`` yields one probability per action.
        env: Object exposing ``number_states`` and a transition table
            ``P[s][a]`` of ``(prob, next_state, reward, done)`` tuples.
        discount_factor: Weight applied to the successor state's value.
        theta: Convergence threshold on the largest per-sweep change.

    Returns:
        Array of length ``env.number_states`` holding the estimated values.
    """
    n_states = env.number_states
    values = np.zeros(n_states)
    while True:
        # Fresh buffer so that this sweep only reads last sweep's values.
        next_values = np.copy(values)
        largest_change = 0
        for state in range(n_states):
            next_values[state] = sum(
                action_prob * prob * (reward + discount_factor * values[successor])
                for action, action_prob in enumerate(policy[state])
                for prob, successor, reward, _done in env.P[state][action]
            )
            change = np.abs(next_values[state] - values[state])
            largest_change = max(largest_change, change)

        values = next_values
        if largest_change < theta:
            return values
                    
                
            

In [243]:
def q_greedify_policy(env, V, pi, s, gamma):
    """Make row ``s`` of ``pi`` greedy with respect to ``V``, in place.

    The one-step look-ahead value of every action in state ``s`` is computed
    from ``env.P``; the actions that attain the maximum share probability
    equally and all other actions get probability 0.  Nothing is returned.

    Args:
        env: Object exposing ``number_actions`` and the transition table ``P``.
        V: State values, indexable by state.
        pi: 2-D policy array indexed as ``pi[state, action]``; modified in place.
        s: State whose row is updated.
        gamma: Weight applied to the successor state's value.
    """
    n_actions = env.number_actions
    action_values = np.zeros(n_actions, dtype=float)
    for action in range(n_actions):
        for probability, successor, reward, _done in env.P[s][action]:
            action_values[action] += probability * (reward + gamma * V[successor])

    # Exact comparison against the maximum: ties are kept, not broken.
    is_greedy = action_values == action_values.max()
    n_greedy = int(np.count_nonzero(is_greedy))
    for action, greedy in enumerate(is_greedy):
        pi[s, action] = 1 / n_greedy if greedy else 0
  

        

In [244]:
def improve_policy(env, V, pi, gamma):
    """Greedify every row of ``pi`` against ``V`` and report whether it changed.

    Args:
        env: Object exposing ``number_states``; also passed to ``q_greedify_policy``.
        V: State values used for the greedy update.
        pi: Policy array indexed by state; its rows are updated in place.
        gamma: Passed through to ``q_greedify_policy``.

    Returns:
        ``(pi, policy_stable)`` where ``policy_stable`` is ``True`` only if no
        row differs from its value before the update.
    """
    changed_rows = 0
    for state in range(env.number_states):
        # Snapshot the row first: the greedy update overwrites it in place.
        previous_row = pi[state].copy()
        q_greedify_policy(env, V, pi, state, gamma)
        if not np.array_equal(previous_row, pi[state]):
            changed_rows += 1
    return pi, changed_rows == 0


In [245]:
def policy_iteration(env, gamma, theta):
    """Alternate policy evaluation and greedy improvement until the policy is stable.

    Starts from the uniform policy over ``env.number_actions`` actions.

    Args:
        env: Object exposing ``number_states``, ``number_actions`` and ``P``.
        gamma: Passed to evaluation and improvement as the discount weight.
        theta: Convergence threshold passed to ``policy_evaluation``.

    Returns:
        ``(V, pi)``: the values from the last evaluation and the final policy.
    """
    n_states, n_actions = env.number_states, env.number_actions
    pi = np.ones([n_states, n_actions]) / n_actions
    while True:
        V = policy_evaluation(pi, env, gamma, theta)
        pi, stable = improve_policy(env, V, pi, gamma)
        if stable:
            return V, pi
            

In [256]:
# Solve the grid world with policy iteration and show the resulting values and policy.
env = GRID_WORLD()
# policy_iteration creates its own uniform starting policy, so none is built here.
v, pi = policy_iteration(env, gamma=1, theta=0.00001)
# Lay the values out on the environment's grid rather than a hard-coded (4, 4).
print(v.reshape(env.shape))
print(pi)

[[ 0. -1. -2. -3.]
 [-1. -2. -3. -2.]
 [-2. -3. -2. -1.]
 [-3. -2. -1.  0.]]
[[0.25 0.25 0.25 0.25]
 [0.   0.   0.   1.  ]
 [0.   0.   0.   1.  ]
 [0.   0.   0.5  0.5 ]
 [1.   0.   0.   0.  ]
 [0.5  0.   0.   0.5 ]
 [0.25 0.25 0.25 0.25]
 [0.   0.   1.   0.  ]
 [1.   0.   0.   0.  ]
 [0.25 0.25 0.25 0.25]
 [0.   0.5  0.5  0.  ]
 [0.   0.   1.   0.  ]
 [0.5  0.5  0.   0.  ]
 [0.   1.   0.   0.  ]
 [0.   1.   0.   0.  ]
 [0.25 0.25 0.25 0.25]]
