In [1]:
import numpy as np
from libr.envs.gridworld import GridworldEnv

In [2]:
env = GridworldEnv()

In [3]:
def valueIterationJC(env, discount, tol = 10**(-4)):
    '''This is the main function for computing policy iteration method.
    Inputs - (i) env: the environment
            (ii) discount: the discount factor for infinite horizon discounted DP
            (iii) tol: tolerance for stopping the iterations; the default value is 10^(-4)
    Outputs - Value and policy  
    For example, see Figure 4.5 of Sutton and Barto for the pseudocode
    '''
    
    
    V = np.zeros(env.nS) #initialize Value vector for all states (number of states = env.nS)
    V_old = np.copy(V) # keep a copy of V to compare the error in iterations
    
    policy = np.zeros(env.nS) # policy is an array with env.nS number of rows and env.nA number of columns
    
        
    def spanNorm(V, V_old):
        '''V and V_old are arrays of same dimention'''
        
        diff_vec = V-V_old
        diff_span = np.max(diff_vec) - np.min(diff_vec)
        return diff_span

    def oneStepValueExp(V_old):
        actionvalueExp = np.zeros([env.nS, env.nA]) # array of actions 
        V = np.zeros(env.nS)
        
        for s in range(env.nS):
            for a in range(env.nA):
                for prob, nextState, reward, done in env.P[s][a]:
                    #Bellman update
                    actionvalueExp[s][a] += prob*(reward+discount*V_old[nextState])
            V[s] = np.max(actionvalueExp[s])
            policy[s] = np.argmax(actionvalueExp[s])                                         
        return V, policy
                                                
   
    while True:
        V_old = np.copy(V)                                    
        V, policy = oneStepValueExp(V_old)
        err = spanNorm(V,V_old)
        if err < tol:
            break
                                                  
    return V, policy                                               

In [4]:
valueIterationJC(env, 0.9)

(array([ 0.  , -1.  , -1.9 , -2.71, -1.  , -1.9 , -2.71, -1.9 , -1.9 ,
        -2.71, -1.9 , -1.  , -2.71, -1.9 , -1.  ,  0.  ]),
 array([ 0.,  3.,  3.,  2.,  0.,  0.,  0.,  2.,  0.,  0.,  1.,  2.,  0.,
         1.,  1.,  0.]))