In [1]:
import numpy as np
from libr.envs.gridworld import GridworldEnv

In [2]:
env = GridworldEnv()

In [3]:
def policyIterationJC(env, discount, tol = 10**(-4)):
    '''This is the main function for computing policy iteration method.
    Inputs - (i) env: the environment
            (ii) discount: the discount factor for infinite horizon discounted DP
            (iii) tol: tolerance for stopping the iterations; the default value is 10^(-4)
    Outputs - Value and policy  
    For example, see Figure 4.3 of Sutton and Barto for the pseudocode
    '''
    
    #Initialization
    V = np.zeros(env.nS)
    policy_new = np.zeros(env.nS) # alternatively: policy = np.zeros_like(V)
    #policy_old = np.zeros_like(policy_new)
        
    def spanNorm(V, V_old):
        '''V and V_old are arrays of same dimention'''
        
        diff_vec = V-V_old
        diff_span = np.max(diff_vec) - np.min(diff_vec)
        return diff_span


    def policyEvaluation(policy):
        #actionvalueExp = np.zeros([env.nS, env.nA]) # array of actions 
        V = np.zeros(env.nS)
        flag = 1
        while flag ==1:
            V1 = np.copy(V)
            for s in range(env.nS):
                a = policy[s]
                for prob, nextState, reward, done in env.P[s][a]:
                    #Bellman update
                     V[s] += prob*(reward+discount*V[nextState])
            err1 = spanNorm(V,V1)
            if err1 < tol:
                flag = 0
        return V
    
    while True:
        policy_old = np.copy(policy_new)
        V_old = policyEvaluation(policy_old)
                
        actionvalueExp = np.zeros([env.nS, env.nA]) # array of actions 
        for s in range(env.nS):
            for a in range(env.nA):
                for prob, nextState, reward, done in env.P[s][a]:
                    #Compute Q-values with Bellman update
                    actionvalueExp[s][a] += prob*(reward+discount*V_old[nextState])
            policy_new[s] = np.argmax(actionvalueExp[s]) # policy improvement with greedy policy
        V = policyEvaluation(policy_new)   
        err = spanNorm(V, V_old) 
        if err < tol:
            break
    return V, policy_new     

In [4]:
policyIterationJC(env,0.9)



KeyboardInterrupt: 