# Policy Iteration

In [1]:
# Import necessary libraries

import numpy as np
import matplotlib.pyplot as plt

In [2]:
# Arguments for default/terminating cases:
# convergence tolerance used to decide when iterative policy evaluation stops.

MAX_ERROR = 10 ** -6

In [3]:
# Open the two MDP description files in text mode; each handle is parsed
# and then closed in a later cell.

mdp1 = open("mdp-10-5.txt", mode="rt")
mdp2 = open("mdp-2-2.txt", mode="rt")

In [4]:
# Extract each line from the text files

def extract_info(mdp):
    # To store each line of the file
    lines = []
    
    # To store each line of the file
    for line in mdp: 
        lines.append(line)
    
    # To store number of states and actions of the mdp
    n_states = int(lines[0].split()[1])
    n_actions = int(lines[1].split()[1])
    
    # To define values of states and actions
    state = [i for i in range(n_states)]
    action = [i for i in range(n_actions)]
    
    # To store the reward and transition probability for a particular set of initial state, action and final state 
    trans_model = [[[(0,0) for i in range(n_states)] for j in range(n_actions)] for k in range(n_states)]
    
    for line in lines[2:len(lines)-1]:
        splitline = line.split()
        
        # Accessing each 'token' of the line to store the information in the right place
        trans_model[eval(splitline[1])][eval(splitline[2])][eval(splitline[3])] = (eval(splitline[4]),eval(splitline[5]))
    
    gamma = float(lines[len(lines)-1].split()[1])
    
    return n_states, n_actions, state, action, trans_model, gamma

In [5]:
def policy_iter(num_s, num_a, state, action, trans_model, gamma, max_error=None):
    """Solve an MDP with policy iteration.

    Alternates (1) iterative evaluation of the *current* policy and
    (2) greedy policy improvement, until no state's action changes.

    Args:
        num_s: number of states.
        num_a: number of actions.
        state: list of state indices ``0 .. num_s - 1``. Kept for backward
            compatibility; the computation is vectorized over all states.
        action: list of action indices ``0 .. num_a - 1``. Kept for backward
            compatibility.
        trans_model: nested list where ``trans_model[s][a][s_next]`` is a
            ``(reward, probability)`` pair.
        gamma: discount factor. NOTE(review): with gamma == 1 evaluation only
            terminates if every policy visited reaches absorbing states.
        max_error: tolerance for stopping policy evaluation, also the minimum
            gain needed to switch a state's action. Defaults to the
            module-level ``MAX_ERROR``.

    Returns:
        (values, policy): float array of state values under the final policy,
        and an int array with the chosen action for each state.
    """
    if max_error is None:
        max_error = MAX_ERROR

    # model[s, a, s_next] = (reward, probability); reshape also checks sizes.
    model = np.asarray(trans_model, dtype=float).reshape(num_s, num_a, num_s, 2)
    rewards = model[..., 0]
    probs = model[..., 1]
    # Expected immediate reward of action a in state s: sum_s' p * r.
    expected_reward = (probs * rewards).sum(axis=2)

    all_states = np.arange(num_s)
    values = np.zeros(num_s)
    # Deterministic start (action 0 everywhere) so repeated runs produce the
    # same policy when several actions are equally good.
    policy = np.zeros(num_s, dtype=int)

    while True:
        # 1. Policy evaluation: iterate V <- R_pi + gamma * P_pi V for the
        #    current policy only (no max over actions), so values can also
        #    become negative when rewards are negative.
        p_pi = probs[all_states, policy]            # (num_s, num_s)
        r_pi = expected_reward[all_states, policy]  # (num_s,)
        while True:
            new_values = r_pi + gamma * (p_pi @ values)
            delta = np.max(np.abs(new_values - values), initial=0.0)
            values = new_values
            if delta < max_error:
                break

        # 2. Policy improvement: q[s, a] = sum_s' p * (r + gamma * V[s']).
        q = expected_reward + gamma * (probs @ values)
        best = np.argmax(q, axis=1)
        # Switch only on a real gain: flipping between (numerically) tied
        # actions would otherwise keep the outer loop running forever.
        improves = q[all_states, best] > q[all_states, policy] + max_error
        if not improves.any():
            break
        policy = np.where(improves, best, policy)

    return values, policy

In [6]:
# Parse the first MDP (file handle mdp1)

num_s, num_a, state, action, trans_model, gamma = extract_info(mdp=mdp1)

In [7]:
# Parsing of mdp1 is done above, so release its file handle.
mdp1.close()

In [8]:
# Solve the first MDP with policy iteration: state values and per-state action.
val1, pol1 = policy_iter(
    num_s=num_s, num_a=num_a, state=state, action=action,
    trans_model=trans_model, gamma=gamma,
)

In [9]:
# Write the results into a text file: one "<value> <action>" line per state,
# with the value formatted to 6 decimal places.

with open('sol-PI-mdp-10-5.txt', 'w') as file:
    # Iterate the results themselves rather than the global `state`, which is
    # overwritten when the next MDP is parsed (re-running this cell afterwards
    # would otherwise write the wrong number of lines). `.6f` already rounds,
    # so no separate round() call is needed.
    lines = [f"{value:.6f} {act}\n" for value, act in zip(val1, pol1)]
    file.writelines(lines)

In [10]:
# Parse the second MDP (file handle mdp2)

num_s, num_a, state, action, trans_model, gamma = extract_info(mdp=mdp2)

In [11]:
# Parsing of mdp2 is done above, so release its file handle.
mdp2.close()

In [12]:
# Solve the second MDP with policy iteration: state values and per-state action.
val2, pol2 = policy_iter(
    num_s=num_s, num_a=num_a, state=state, action=action,
    trans_model=trans_model, gamma=gamma,
)

In [13]:
# Write the results into a text file: one "<value> <action>" line per state,
# with the value formatted to 6 decimal places.

with open('sol-PI-mdp-2-2.txt', 'w') as file:
    # Iterate the results themselves rather than the global `state`, so the
    # output always matches val2/pol2 regardless of which MDP was parsed last.
    # `.6f` already rounds, so no separate round() call is needed.
    lines = [f"{value:.6f} {act}\n" for value, act in zip(val2, pol2)]
    file.writelines(lines)

#### Compared with the provided solutions, the obtained results agree up to the least significant digit (the 6th and 4th decimal places). The agreement can be improved further by tightening the convergence tolerance (MAX_ERROR): for example, with MAX_ERROR = 10^-8 or lower, the values match up to the 6th decimal place.