In [1]:
import numpy as np

In [2]:
def initialization():
    '''
    Set up the 4x4 gridworld used by the rest of the notebook.

    Returns:
            V - 4x4 numpy array of zeros
            R - dict mapping each state to its reward
                (0 for terminal states, -1 for every other state)
            P - dict mapping each state to its list of allowed actions
                (all four moves for non-terminal states, empty for terminal ones)
            states - list of (row, col) tuples, in row-major order
            terminal_states - list holding the two terminal corners
    '''
    num_rows, num_cols = 4, 4
    states = [(i, j) for i in range(num_rows) for j in range(num_cols)]
    terminal_states = [(0,0),(3,3)]
    V = np.zeros([num_rows, num_cols])

    R = {}
    P = {}
    for state in states:
        is_terminal = state in terminal_states
        R[state] = 0 if is_terminal else -1
        # A fresh list per state so policies can be edited independently
        P[state] = [] if is_terminal else ['L','R','D','U']
    return V,R,P,states,terminal_states

In [3]:
def policy_evaluation(V,R,P,states,terminal_states):
    '''
    Evaluate policy P by running 1000 synchronous backup sweeps.

    Each sweep computes, for every non-terminal state, the average over the
    actions listed in P[state] of R[state] + (previous value of the successor).
    A move that would leave the 4x4 grid keeps the agent in its current cell.

    Arguments:
            V - not used; the values are always restarted from 0
            R - dict mapping each state to its reward
            P - dict mapping each state to its list of actions ('L','R','U','D')
            states - list of (row, col) tuples of the gridworld
            terminal_states - list of terminal states (their value stays 0)
    Returns:
            V - dict mapping each state to its value under policy P
    '''
    # Successor cell per action, clamped at the grid edges.
    # The L, R, U, D order is the order in which contributions are summed.
    successors = (
        ('L', lambda row, col: (row, max(0, col-1))),
        ('R', lambda row, col: (row, min(3, col+1))),
        ('U', lambda row, col: (max(0, row-1), col)),
        ('D', lambda row, col: (min(3, row+1), col)),
    )

    V = {cell: 0 for cell in states}
    for _ in range(1000):
        updated = dict.fromkeys(states, 0)
        for cell in states:
            if cell in terminal_states:
                continue
            actions = P[cell]
            moves = len(actions)
            for name, step in successors:
                if name in actions:
                    updated[cell] += (R[cell] + V[step(cell[0], cell[1])])/moves
        # Synchronous update: the next sweep reads only this sweep's results
        V = updated
    return V

In [4]:
def policy_improvement(V,R,P,P1,states,terminal_states):
    '''
    Greedy policy improvement with respect to the value function V.

    For every non-terminal state, each action listed in P[state] that keeps
    the agent on the 4x4 grid is scored as R[state] + V[next_state]. All
    actions tied for the highest score become the new policy for that state.

    Arguments:
            V - dict mapping each state (row, col) to its value
            R - dict mapping each state to its reward
            P - dict mapping each state to its list of candidate actions
                ('L','R','U','D'); it is updated in place
            P1 - dict holding the previous policy, used to detect changes
            states - list of (row, col) tuples of the gridworld
            terminal_states - list of terminal states (left untouched)
    Returns:
            P - the same dict, now holding the greedy actions of each
                non-terminal state (in L, R, U, D order)
            policy_stable - True if no non-terminal state's action list
                differs from the one in P1
    '''
    grid_size = 4
    # (action, row offset, col offset). The order fixes the order in which
    # tied actions are listed in the resulting policy.
    moves = (('L', 0, -1), ('R', 0, 1), ('U', -1, 0), ('D', 1, 0))

    policy_stable = True
    # Iterate over all states to find the optimal policy
    for state in states:
        if state in terminal_states:
            continue
        old_policy = P1[state]
        allowed = P[state]
        optimal_policy = []
        max_val = -1e10
        for action, d_row, d_col in moves:
            # Each move is gated by its own action label
            if action not in allowed:
                continue
            row, col = state[0] + d_row, state[1] + d_col
            # Moves that would leave the grid are not candidates
            if not (0 <= row < grid_size and 0 <= col < grid_size):
                continue
            val = R[state] + V[(row, col)]
            if val > max_val:
                max_val = val
                optimal_policy = [action]
            elif val == max_val:
                optimal_policy.append(action)

        # Done for every non-terminal state, whichever actions it allows
        if old_policy != optimal_policy:
            policy_stable = False
        P[state] = optimal_policy

    return P,policy_stable

In [5]:
# Start from the equiprobable random policy and evaluate it
V,R,P1,states,terminal_states = initialization()
V = policy_evaluation(V,R,P1,states,terminal_states)

# Value function of the initial policy, laid out on the 4x4 grid
V1 = np.zeros((4,4))
for state in states:
    # A (row, col) tuple indexes the array cell directly
    V1[state] = V[state]
print("Initial value function is: ")
print(V1)

# Policy iteration: alternate greedy improvement and evaluation until an
# improvement sweep leaves the policy of every state unchanged
policy_stable = False
while not policy_stable:
    # Every sweep chooses among all four actions in each non-terminal state
    P = {state: ([] if state in terminal_states else ['L','R','U','D'])
         for state in states}

    P1,policy_stable = policy_improvement(V,R,P,P1,states,terminal_states)
    # Only re-evaluate when the improvement step actually changed the policy
    if not policy_stable:
        V = policy_evaluation(V,R,P1,states,terminal_states)

# Print optimal value function
print("\nOptimal value function is: ")
V1 = np.zeros((4,4))
for state in states:
    V1[state] = V[state]
print(np.round(V1))


# Print optimal policy
# Each cell denotes the optimal action(s) that can be taken from that state
P = [[[] for _ in range(4)] for _ in range(4)]
for v,a in P1.items():
    P[v[0]][v[1]] = a
print("\nThe optimal policy is:")
for row in P:
    print(row)


Initial value function is: 
[[  0. -14. -20. -22.]
 [-14. -18. -20. -20.]
 [-20. -20. -18. -14.]
 [-22. -20. -14.   0.]]

Optimal value function is: 
[[ 0. -1. -2. -3.]
 [-1. -2. -3. -2.]
 [-2. -3. -2. -1.]
 [-3. -2. -1.  0.]]

The optimal policy is:
[[], ['L'], ['L'], ['L', 'D']]
[['U'], ['L', 'U'], ['L', 'R', 'U', 'D'], ['D']]
[['U'], ['L', 'R', 'U', 'D'], ['R', 'D'], ['D']]
[['R', 'U'], ['R'], ['R'], []]
