In [1]:
import numpy as np

In [2]:
# Experiment configuration shared by all cells below.
n = 10        # number of states, including the two terminal ends (0 and n-1)
gamma = 0.9   # discount factor
theta = 1e-5  # convergence threshold for policy evaluation

In [3]:
def gridworld_1D_policy_eval(n, theta, gamma, pi):
    """Iterative policy evaluation (in-place sweeps) for the 1-D gridworld.

    States 0 and n-1 are terminal and keep value 0. Stepping left out of
    state 1 earns +10, stepping right out of state n-2 earns -5, and every
    other move earns 0.

    Args:
        n: number of states, including the two terminal ends.
        theta: stop once a full sweep changes no state value by more than theta.
        gamma: discount factor.
        pi: per-state action probabilities; pi[s][0] = P(left),
            pi[s][1] = P(right). Only rows 1..n-2 are read.

    Returns:
        np.ndarray of shape (n,) holding the estimated state values.
    """
    state_vals = np.zeros(n)
    state_vals[1:n - 1] = 1  # arbitrary initial guess for the non-terminal states

    while True:
        delta = 0.0
        for s in range(1, n - 1):  # only non-terminal states are updated
            # Independent checks (not if/elif): when n == 3, state 1 borders
            # both terminals and must receive both rewards.
            reward_left = 10 if s == 1 else 0
            reward_right = -5 if s == n - 2 else 0
            old_val = state_vals[s]
            # Values are overwritten in place, so the update of state s+1
            # already sees the fresh value of state s within the same sweep.
            state_vals[s] = (pi[s][0] * (reward_left + gamma * state_vals[s - 1])
                             + pi[s][1] * (reward_right + gamma * state_vals[s + 1]))
            delta = max(delta, abs(old_val - state_vals[s]))
        if delta <= theta:
            return state_vals

In [4]:
# Evaluate the equiprobable random policy: column 0 = P(left), column 1 = P(right).
pi = np.full((n, 2), 0.5)
print(gridworld_1D_policy_eval(n, theta, gamma, pi))

[ 0.          6.88156669  4.18125175  2.41009525  1.1745072   0.19991375
 -0.73025963 -1.82271621 -3.32022229  0.        ]


In [5]:
def gridworld_1D(n, theta, gamma):
    """Policy iteration for the 1-D gridworld; returns a greedy deterministic policy.

    States 0 and n-1 are terminal. Stepping left out of state 1 earns +10,
    stepping right out of state n-2 earns -5, and every other move earns 0.
    These are the same rewards used in the greedy step below.

    Args:
        n: number of states, including the two terminal ends.
        theta: convergence threshold forwarded to ``gridworld_1D_policy_eval``.
        gamma: discount factor.

    Returns:
        ``pi`` of shape (n, 2). ``pi[s] == [1, 0]`` means "go left" and
        ``[0, 1]`` means "go right". Terminal rows 0 and n-1 are never updated
        and keep their initial 0.5 / 0.5 entries.
    """
    pi = np.zeros((n, 2)) + 0.5  # start from the equiprobable policy

    policy_stable = False
    while not policy_stable:
        state_vals = gridworld_1D_policy_eval(n, theta, gamma, pi)  # evaluation step

        # Improvement step. The policy is stable only if NO state changes its
        # action; stopping as soon as any single state keeps its action would
        # end policy iteration before convergence.
        policy_stable = True
        for s in range(1, n - 1):
            # Independent checks (not if/elif): when n == 3, state 1 borders
            # both terminals and must see both rewards.
            reward_left = 10 if s == 1 else 0
            reward_right = -5 if s == n - 2 else 0
            q_left = reward_left + gamma * state_vals[s - 1]
            q_right = reward_right + gamma * state_vals[s + 1]

            # Both columns always change together, so comparing one suffices.
            old_left_prob = pi[s, 0]
            pi[s] = (1, 0) if q_left >= q_right else (0, 1)  # ties go left
            if pi[s, 0] != old_left_prob:
                policy_stable = False

    return pi

In [6]:
print(gridworld_1D(n, theta, gamma)[1:-1])  # drop terminal rows 0 and n-1

[[1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]]
