In [1]:
import numpy as np

In [2]:
# Initialize states and policy

def initialization():
    """Build the starting value table and policy for the 4x4 gridworld.

    Returns:
        (V, P) where V is a 4x4 numpy array of zeros and P is a 4x4 nested list.
        P[i][j] holds the moves allowed from cell (i, j)
        (L = left, R = right, U = up, D = down). The two empty lists sit at
        (0, 0) and (3, 3), which the rest of the notebook treats as terminal.
    """
    def interior_row():
        # A fresh set of lists on every call, so rows 1 and 2 share no objects.
        return [['U', 'R', 'D'], ['U', 'R', 'L', 'D'], ['U', 'R', 'L', 'D'], ['U', 'L', 'D']]

    policy = [
        [[], ['L', 'R', 'D'], ['L', 'R', 'D'], ['L', 'D']],
        interior_row(),
        interior_row(),
        [['U', 'R'], ['L', 'R', 'U'], ['L', 'R', 'U'], []],
    ]
    values = np.zeros((4, 4))
    return values, policy

In [3]:
# Policy evaluation
def policy_evaluation(V):
    num_iterations = 0
    V = np.zeros([4,4])
    V1 = np.zeros([4,4])
    # Loop for convergence of V
    while num_iterations < 1000:
        # Loops to update all states
        for i in range(0,4):
            for j in range(0,4):
                if (i == 0 and j == 0) or (i == 3 and j == 3):
                    V1[i][j] = 0
                    continue
                elif i == 0 and j == 3:
                    V1[i][j] = 0.25*(-1 + V[i][j-1]) + 0.25*(-1 + V[i+1][j]) + 2*0.25*(-1 + V[i][j])
                elif i == 3 and j == 0:
                    V1[i][j] = 0.25*(-1 + V[i-1][j]) + 0.25*(-1 + V[i][j+1]) + 2*0.25*(-1 + V[i][j])
                elif i == 0 and j != 0:
                    V1[i][j] = 0.25*(-1 + V[i][j-1]) + 0.25*(-1 + V[i][j+1]) + 0.25*(-1 + V[i+1][j]) + 0.25*(-1 + V[i][j])
                elif i == 3 and j != 3:
                    V1[i][j] = 0.25*(-1 + V[i][j-1]) + 0.25*(-1 + V[i][j+1]) + 0.25*(-1 + V[i-1][j]) + 0.25*(-1 + V[i][j])
                elif j == 0 and i != 0:
                    V1[i][j] = 0.25*(-1 + V[i][j+1]) + 0.25*(-1 + V[i-1][j]) + 0.25*(-1 + V[i+1][j]) + 0.25*(-1 + V[i][j])
                elif j == 3 and i != 3:
                    V1[i][j] = 0.25*(-1 + V[i][j-1]) + 0.25*(-1 + V[i-1][j]) + 0.25*(-1 + V[i+1][j]) + 0.25*(-1 + V[i][j])
                else:
                    V1[i][j] = 0.25*(-1 + V[i-1][j]) + 0.25*(-1 + V[i+1][j]) + 0.25*(-1 + V[i][j-1]) + 0.25*(-1 + V[i][j+1])
        
        V = V1
        num_iterations += 1
    
    return V

In [4]:
# Policy improvement
def policy_improvement(V,P):
    policy_stable = True
    # Implementation of Bellman Optimality Equation for all states
    for i in range(4):
        for j in range(4):
            if (i == 0 and j == 0) or (i == 3 and j == 3):
                continue
            old_action = P[i][j]
            max_val = -1e10
            new_action = []
            # Find optimal action for the state from the given policy
            for action in P[i][j]:
                if action == 'L':
                    val_l = -1 + V[i][j-1]
                    if val_l > max_val:
                        new_action = ['L']
                        max_val = val_l
                    elif val_l == max_val:
                        new_action.append('L')
                if action == 'U':
                    val_u = -1 + V[i-1][j]
                    if val_u > max_val:
                        new_action = ['U']
                        max_val = val_u
                    elif val_u == max_val:
                        new_action.append('U')
                if action == 'R':
                    val_r = -1 + V[i][j+1]
                    if val_r > max_val:
                        new_action = ['R']
                        max_val = val_r
                    elif val_r == max_val:
                        new_action.append('R')
                if action == 'D':
                    val_d = -1 + V[i+1][j]
                    if val_d > max_val:
                        new_action = ['D']
                        max_val = val_d
                    elif val_d == max_val:
                        new_action.append('D')
                        
            # If policy is updated for atleast one state, we have to repeat the process again
            if old_action != new_action:
                P[i][j] = new_action
                policy_stable = False

    return policy_stable

In [5]:
V,P = initialization()

V = policy_evaluation(V)
while True:
    policy_stable = policy_improvement(V,P)
    if policy_stable == True:
        break
    V = policy_evaluation(V)
# Print optimal value function
print(np.round(V))
# Print optimal policy
# Each cell denotes the optimal action that can be taken from that state
for row in P:
    print(row)

[[  0. -14. -20. -22.]
 [-14. -18. -20. -20.]
 [-20. -20. -18. -14.]
 [-22. -20. -14.   0.]]
[[], ['L'], ['L'], ['L', 'D']]
[['U'], ['U', 'L'], ['L', 'D'], ['D']]
[['U'], ['U', 'R'], ['R', 'D'], ['D']]
[['U', 'R'], ['R'], ['R'], []]
