### Assignment : Week 2
## Finding best policies in simple MDPs

Great work making the MDPs in Week 1!

In this assignment, we'll use the simplest RL techniques - Policy and Value iteration to find the best policies (which maximize the discounted total reward) in our MDPs from last week.

Feel free to use your own MDPs, or import them from the OpenAI Gym library.

You can start this assignment during/after reading Grokking Ch-3.

For this you have to install gymnasium, which is an API standard for reinforcement learning with a diverse collection of reference environments. This can be easily done by running:

    pip install gymnasium

## Frozen Lake

Let's now try to solve the Frozen Lake environment for some cases

In [47]:
# Step 0 is to import stuff

import gymnasium as gym
import numpy as np
from gymnasium.envs.toy_text.frozen_lake import generate_random_map
import pprint
import random

In [48]:
# Step 1 is to get the MDP
env = gym.make('FrozenLake-v1', desc=None, map_name="4x4", is_slippery=True)
env = env.unwrapped
mdp_transitions = env.P
init_state = env.reset()
goal_state = 15

pprint.pprint(mdp_transitions)

{0: {0: [(0.3333333333333333, 0, 0.0, False),
         (0.3333333333333333, 0, 0.0, False),
         (0.3333333333333333, 4, 0.0, False)],
     1: [(0.3333333333333333, 0, 0.0, False),
         (0.3333333333333333, 4, 0.0, False),
         (0.3333333333333333, 1, 0.0, False)],
     2: [(0.3333333333333333, 4, 0.0, False),
         (0.3333333333333333, 1, 0.0, False),
         (0.3333333333333333, 0, 0.0, False)],
     3: [(0.3333333333333333, 1, 0.0, False),
         (0.3333333333333333, 0, 0.0, False),
         (0.3333333333333333, 0, 0.0, False)]},
 1: {0: [(0.3333333333333333, 1, 0.0, False),
         (0.3333333333333333, 0, 0.0, False),
         (0.3333333333333333, 5, 0.0, True)],
     1: [(0.3333333333333333, 0, 0.0, False),
         (0.3333333333333333, 5, 0.0, True),
         (0.3333333333333333, 2, 0.0, False)],
     2: [(0.3333333333333333, 5, 0.0, True),
         (0.3333333333333333, 2, 0.0, False),
         (0.3333333333333333, 1, 0.0, False)],
     3: [(0.3333333333333333,

In [49]:
# Step 2 is to write the policy

# This is according to the convention of gymnasium
LEFT, DOWN, RIGHT, UP = range(4)

pi = {
    0:RIGHT, 1:RIGHT, 2:DOWN, 3:LEFT,
    4:DOWN, 5:LEFT, 6:DOWN, 7:LEFT,
    8:RIGHT, 9:RIGHT, 10:DOWN, 11:LEFT,
    12:LEFT, 13:RIGHT, 14:RIGHT, 15:LEFT
}

# Or you can do it randomly
# pi = dict()
# for state in mdp:
#     pi[state] = np.random.choice(mdp[state].keys())

In [50]:
# Step 3 is computing the value function for this envi and policy

# Let us start with a random value function

val = dict()
for state in mdp_transitions:
    val[state] = 0

Holes = [5,7,11,12,15]
# Since 5, 7, 11, 12 and 15 are terminal states, we know their values are 0

#Or you could do it randomly, remember to set the terminal states to 0. You can also implement this while evaluating the value function using 
# val = dict()
# for state in mdp:
#     val[state] = np.random.random()
#     if mdp[state][0][0][0] == 0: # if the first action in the first outcome of the first state is 0, then it is a terminal state
#         val[state] = 0

#instead of doing this you can simply intialize the value function to 0 for all states 
# for state in swf_mdp:
#   val[state] = 0

In [51]:
def get_new_value_fn(val, mdp, pi, gamma = 1.0):
    new_val = dict()
    for state in range(16):
        if state in Holes:
            new_val[state]=0
        else:
            state_new_val = 0
            action = pi[state]
            for my_tupple in mdp[state][action]:
                prob,next_state,new_reward,done = my_tupple
                state_new_val += prob*(new_reward + gamma * val[next_state])
            new_val[state] = state_new_val          
    return new_val

In [52]:
def policy_evaluation(val, mdp, pi, epsilon=1e-10, gamma=1.0):
    count = 0
    start_v = val
    max_difference = 1
    while max_difference>epsilon:
        prev_v = start_v
        next_v = get_new_value_fn(prev_v,mdp,pi,gamma)
        keys = sorted(prev_v.keys())
        prev_values = [prev_v[key] for key in keys]
        next_values = [next_v[key] for key in keys]
        differences = [abs(next_val - prev_val) for next_val, prev_val in zip(next_values, prev_values)]
        max_difference = max(differences)
        count +=1
        start_v = next_v 
    return next_v,count

In [53]:
def policy_improvement(val, mdp, pi, gamma=1.0):
    new_pi = dict()
    q = dict()
    for state in range(16):
        empty_q_1 = {0:0, 1:0, 2:0, 3:0}
        q[state] = empty_q_1
    for state in range(len(mdp)):
        actions_dict = mdp[state]
        for action in range(len(mdp[0])):
            action_list = actions_dict[action]
            q_value = 0
            for _ in action_list:
                prob,next_state,new_reward,done = _
                q_value += prob*(new_reward + gamma * val[next_state]*(not done))
            q[state][action] = q_value  
    list_actions = [0,1,2,3]    
    for state in q:
        temp_dict = q[state]
        max_key = max(temp_dict, key=temp_dict.get)
        new_pi[state] = list_actions[max_key]
    return new_pi, q

In [54]:
def policy_iteration(mdp, epsilon=1e-10, gamma=1.0):
    pi = dict()
    val = {s: 0 for s in mdp}
    count = 0
    old_pi = {s: random.randint(0,3) for s in mdp} #starting with a random policy
    temp_val,_ = policy_evaluation(val,mdp,old_pi,epsilon,gamma)
    while True:
        new_val,_ = policy_evaluation(temp_val, mdp, old_pi, epsilon, gamma)
        new_pi,_ = policy_improvement(new_val, mdp, old_pi, gamma)
        count += 1
        if new_pi == old_pi:
            pi = new_pi
            val = policy_evaluation(new_val, mdp, new_pi, epsilon, gamma)
            break
        else:
            old_pi = new_pi
    return pi, val, count

policy_iteration(mdp_transitions,1e-10,1)

({0: 0,
  1: 3,
  2: 3,
  3: 3,
  4: 0,
  5: 0,
  6: 0,
  7: 0,
  8: 3,
  9: 1,
  10: 0,
  11: 0,
  12: 0,
  13: 2,
  14: 1,
  15: 0},
 ({0: 0.8235294095045046,
   1: 0.8235294087468735,
   2: 0.823529408208908,
   3: 0.8235294079297658,
   4: 0.8235294096690228,
   5: 0,
   6: 0.5294117630897942,
   7: 0,
   8: 0.823529409986084,
   9: 0.8235294104326095,
   10: 0.7647058811781084,
   11: 0,
   12: 0,
   13: 0.8823529402305985,
   14: 0.9411764700974368,
   15: 0},
  1),
 4)

In [55]:
#Now perform value iteration, note that the value function is a dictionary and not a list, also return the number of iterations it took to converge
def value_iteration(mdp, gamma=1.0, epsilon=1e-10):
    val = {s: 0 for s in mdp}  # Initialize value function to zero for all states
    count = 0
    while True:
        delta = 0
        for state in mdp:
            v = val[state]
            q_values = []
            for action in mdp[state]:
                q_value = 0
                for prob, next_state, reward, done in mdp[state][action]:
                    q_value += prob * (reward + gamma * val[next_state])
                q_values.append(q_value)
            val[state] = max(q_values)
            delta = max(delta, abs(v - val[state]))
        count += 1
        if delta < epsilon:
            break

    pi = {state: max(mdp[state], key=lambda action: sum(prob * (reward + gamma * val[next_state])
                                                        for prob, next_state, reward, done in mdp[state][action]))
          for state in mdp}

    return pi, val, count
    

In [56]:
#Function to print the policy you got after running the policy iteration or value iteration on the 4x4 FrozenLake environment
def print_policy(policy, env):
    """
    Prints the policy for the 4x4 FrozenLake environment in a grid layout.
    """
    action_symbols = {0: '←', 1: '↓', 2: '→', 3: '↑'}  #action symbols
    grid_size = env.desc.shape  #get the grid dimensions (e.g., 4x4)
    
    policy_symbols = np.array([action_symbols[action] for cell,action in policy.items()])
    policy_grid = policy_symbols.reshape(grid_size)  #reshape into a grid

    print("Policy Grid:")
    for row in policy_grid:
        print(" ".join(row))
        


In [57]:
pi1, val1, count1 = policy_iteration(mdp_transitions,1e-10,1)
pi2, val2, count2 = value_iteration(mdp_transitions,1,1e-10)

print(pi1,val1,count1)
print(pi2,val2,count2)

print_policy(pi1,env)

{0: 0, 1: 3, 2: 3, 3: 3, 4: 0, 5: 0, 6: 0, 7: 0, 8: 3, 9: 1, 10: 0, 11: 0, 12: 0, 13: 2, 14: 1, 15: 0} ({0: 0.8235294094743778, 1: 0.8235294087066481, 2: 0.8235294081615119, 3: 0.823529407878649, 4: 0.8235294096410889, 5: 0, 6: 0.529411763068253, 7: 0, 8: 0.8235294099623763, 9: 0.8235294104148537, 10: 0.7647058811624488, 11: 0, 12: 0, 13: 0.8823529402179906, 14: 0.9411764700908947, 15: 0}, 1) 3
{0: 0, 1: 3, 2: 3, 3: 3, 4: 0, 6: 0, 8: 3, 9: 1, 10: 0, 13: 2, 14: 1} {0: 0.8235294100518615, 1: 0.8235294094898987, 2: 0.8235294090993286, 3: 0.8235294089005962, 4: 0.8235294102241579, 5: 0, 6: 0.5294117635379374, 7: 0, 8: 0.8235294104939872, 9: 0.8235294108399842, 10: 0.7647058815425976, 11: 0, 12: 0, 13: 0.8823529405337791, 14: 0.9411764702612166, 15: 0} 591
Policy Grid:
← ↑ ↑ ↑
← ← ← ←
↑ ↓ ← ←
← → ↓ ←


You can also write a function `test_policy()` to test your policy after training to find the number of times you reached the goal state

In [58]:
def test_policy(pi, env, goalstate):
    # Complete this function to test the policy
    return
