### Assignment: Week 2
## Finding best policies in simple MDPs

Great work making the MDPs in Week 1!

In this assignment, we'll use the simplest RL techniques, policy iteration and value iteration, to find the best policies (those that maximize the expected discounted total reward) in our MDPs from last week.

Feel free to use your own MDPs, or import them from the Gymnasium library (the maintained successor to OpenAI Gym).

You can start this assignment during/after reading Grokking Ch-3.

For this you need to install Gymnasium, an API standard for reinforcement learning with a diverse collection of reference environments. You can install it by running:

    pip install gymnasium

## Frozen Lake

Let's now try to solve the Frozen Lake environment for some cases.

In [473]:
# Step 0 is to import stuff

import gymnasium as gym
import numpy as np
from gymnasium.envs.toy_text.frozen_lake import generate_random_map
import pprint

In [474]:
# Step 1 is to get the MDP
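env = gym.make('FrozenLake-v1', desc=None, map_name="4x4", is_slippery=True)
env = env.unwrapped
mdp_transitions = env.P  # mdp_transitions[state][action] -> list of (prob, next_state, reward, done)
init_state, _ = env.reset()  # reset() returns (observation, info)
goal_state = 15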
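
Before writing any algorithm, it helps to look at the shape of the transition table. The sketch below does not touch Gymnasium at all: it builds a tiny hand-written MDP in the same nested layout as `env.P` and runs two sanity checks on it. The names `toy_mdp`, `check_mdp` and `terminal_states` are made up for this illustration.

```python
import math

# Hypothetical 3-state chain in the same layout as env.P:
# toy_mdp[state][action] -> list of (prob, next_state, reward, done)
LEFT, RIGHT = 0, 1
toy_mdp = {
    0: {LEFT:  [(1.0, 0, 0.0, False)],
        RIGHT: [(0.8, 1, 0.0, False), (0.2, 0, 0.0, False)]},
    1: {LEFT:  [(1.0, 0, 0.0, False)],
        RIGHT: [(0.8, 2, 1.0, True), (0.2, 1, 0.0, False)]},
    # State 2 is terminal: every action is a zero-reward self-loop.
    2: {LEFT:  [(1.0, 2, 0.0, True)],
        RIGHT: [(1.0, 2, 0.0, True)]},
}

def check_mdp(mdp):
    """Return True if the outcome probabilities of every (state, action) sum to 1."""
    return all(
        math.isclose(sum(prob for prob, _, _, _ in outcomes), 1.0)
        for actions in mdp.values()
        for outcomes in actions.values()
    )

def terminal_states(mdp):
    """Return the states whose every outcome is a self-loop flagged as done."""
    return [
        state for state, actions in mdp.items()
        if all(next_state == state and done
               for outcomes in actions.values()
               for _, next_state, _, done in outcomes)
    ]

print(check_mdp(toy_mdp))        # True
print(terminal_states(toy_mdp))  # [2]
```

The same two helpers should work on `mdp_transitions`, since it follows the same layout.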

In [475]:
# Step 2 is to write the policy

# This is according to the convention of gymnasium
LEFT, DOWN, RIGHT, UP = range(4)

pi = {
    0:RIGHT, 1:RIGHT, 2:DOWN, 3:LEFT,
    4:DOWN, 5:LEFT, 6:DOWN, 7:LEFT,
    8:RIGHT, 9:RIGHT, 10:DOWN, 11:LEFT,
    12:LEFT, 13:RIGHT, 14:RIGHT, 15:LEFT
}

# Or you can do it randomly
# pi = dict()
# for state in mdp_transitions:
#     pi[state] = np.random.choice(list(mdp_transitions[state].keys()))

In [476]:
# Step 3 is computing the value function for this environment and policy

# Let us start with a value function that is 0 everywhere

val = dict()
for state in mdp_transitions:
    val[state] = 0

Holes = [5,7,11,12,15]
# States 5, 7, 11 and 12 are holes and 15 is the goal. All five are terminal, so we know their values are 0

# Or you could initialize the values randomly; remember to set the terminal states to 0. You can also handle this while evaluating the value function
# val = dict()
# for state in mdp_transitions:
#     val[state] = np.random.random()
#     # a terminal state only has self-loop outcomes flagged as done
#     if all(next_state == state and done for _, next_state, _, done in mdp_transitions[state][0]):
#         val[state] = 0

# Instead of doing this, you can simply initialize the value function to 0 for all states, as done above
# for state in mdp_transitions:
#   val[state] = 0

In [477]:
def get_new_value_fn(val, mdp, pi, gamma = 1.0):
    # One sweep of the Bellman expectation backup for the fixed policy pi
    new_val = dict()
    for state in mdp:
        if state in Holes:
            # Terminal states (the holes and the goal) keep a value of 0
            new_val[state] = 0
        else:
            state_new_val = 0
            action = pi[state]
            for prob, next_state, reward, done in mdp[state][action]:
                state_new_val += prob * (reward + gamma * val[next_state])
            new_val[state] = state_new_val
    return new_val

print(get_new_value_fn(val, mdp_transitions, pi))

{0: 0.0, 1: 0.0, 2: 0.0, 3: 0.0, 4: 0.0, 5: 0, 6: 0.0, 7: 0, 8: 0.0, 9: 0.0, 10: 0.0, 11: 0, 12: 0, 13: 0.0, 14: 0.3333333333333333, 15: 0}
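
In equation form, one call to `get_new_value_fn` applies the Bellman expectation backup for the fixed policy to every non-terminal state. The symbols below are notation introduced here for illustration: $p(s' \mid s, a)$ and $r(s, a, s')$ stand for the probability and reward stored in each tuple of `mdp[s][a]`, and $v_k$ is the value function passed in.

$$
v_{k+1}(s) = \sum_{s'} p\big(s' \mid s, \pi(s)\big)\,\Big[\, r\big(s, \pi(s), s'\big) + \gamma\, v_k(s') \Big], \qquad v_{k+1}(s) = 0 \ \text{ for terminal } s
$$

For example, with every value at zero, state 14 under RIGHT has three equally likely outcomes and only the move into the goal pays a reward of 1, so $v_1(14) = \tfrac{1}{3}(1 + 0) = \tfrac{1}{3}$. That is the single nonzero entry in the output above.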


In [478]:
def policy_evaluation(val, mdp, pi, epsilon=1e-10, gamma=1.0):
    # Repeat the backup until the largest per-state change is at most epsilon
    count = 0
    prev_v = val
    while True:
        next_v = get_new_value_fn(prev_v, mdp, pi, gamma)
        max_difference = max(abs(next_v[state] - prev_v[state]) for state in prev_v)
        count += 1
        prev_v = next_v
        if max_difference <= epsilon:
            break
    return next_v, count

print(policy_evaluation(val,mdp_transitions,pi,1e-10,1))

({0: 0.03749999984671091, 1: 0.024999999910466038, 2: 0.04999999989517735, 3: 0.024999999910466038, 4: 0.049999999910112294, 5: 0, 6: 0.09999999995153354, 7: 0, 8: 0.09999999989482361, 9: 0.24999999986164584, 10: 0.2999999998948236, 11: 0, 12: 0, 13: 0.4499999998049359, 14: 0.6499999999101123, 15: 0}, 68)


In [479]:
# Perform policy improvement using the policy and the value function, and return a new policy. The action-value function q should be a nested dictionary (q[state][action])
def policy_improvement(val, mdp, pi, gamma=1.0):
    new_pi = dict()
    q = dict()
    # Complete this function to get the new policy given the value function and the mdp
    return new_pi, q
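
If you are unsure where to start, here is a rough sketch of one greedy improvement step on a tiny hand-written MDP. It is only one possible shape, not the expected solution: the names `toy_mdp`, `toy_val` and `greedy_policy` are made up for this illustration, and the sketch assumes terminal states already have value 0.

```python
# Hypothetical deterministic 3-state chain in the same layout as env.P:
# toy_mdp[state][action] -> list of (prob, next_state, reward, done)
toy_mdp = {
    0: {0: [(1.0, 0, 0.0, False)], 1: [(1.0, 1, 0.0, False)]},
    1: {0: [(1.0, 0, 0.0, False)], 1: [(1.0, 2, 1.0, True)]},
    2: {0: [(1.0, 2, 0.0, True)],  1: [(1.0, 2, 0.0, True)]},  # terminal
}

def greedy_policy(val, mdp, gamma=1.0):
    """Build q[state][action] from val, then pick the best action per state."""
    q = {}
    new_pi = {}
    for state, actions in mdp.items():
        q[state] = {}
        for action, outcomes in actions.items():
            # Expected one-step return of taking `action` in `state`
            q[state][action] = sum(
                prob * (reward + gamma * val[next_state])
                for prob, next_state, reward, _ in outcomes
            )
        # max() returns the first maximum, so ties go to the lowest action index
        new_pi[state] = max(q[state], key=q[state].get)
    return new_pi, q

toy_val = {0: 0.0, 1: 1.0, 2: 0.0}
new_pi, q = greedy_policy(toy_val, toy_mdp, gamma=0.9)
print(new_pi)  # {0: 1, 1: 1, 2: 0}
```

In state 0 the sketch prefers action 1 because it leads to state 1, whose value is 1.0, giving `q[0][1] = 0.9 * 1.0 = 0.9` against `q[0][0] = 0.0`.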


In [480]:
# Use the above functions to get the optimal policy and optimal value function and return the total number of iterations it took to converge
# Create a random policy and value function to start with or use the ones defined above
def policy_iteration(mdp, epsilon=1e-10, gamma=1.0):
    pi = dict()
    val = dict()
    count = 0
    # Complete this function to get the optimal policy and value function and return the total number of iterations it took to converge
    return pi, val, count

In [481]:
# Now perform value iteration. Note that the value function is a dictionary, not a list. Also return the number of iterations it took to converge
def value_iteration(mdp, gamma=1.0, epsilon=1e-10):
    val = {s: 0 for s in mdp}
    count = 0
    q = dict()
    # Complete this function to get the optimal policy, optimal value function and return the total number of iterations it took to converge
    pi = {s: max(q[s], key=q[s].get) for s in mdp}
    return pi, val, count
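
As a reminder of what the loop needs to compute, here is the value iteration update in a standard form. The notation is introduced here for illustration: $p(s' \mid s, a)$ and $r(s, a, s')$ are the probability and reward stored in each tuple of `mdp[s][a]`.

$$
q_{k+1}(s, a) = \sum_{s'} p(s' \mid s, a)\,\Big[\, r(s, a, s') + \gamma\, v_k(s') \Big], \qquad v_{k+1}(s) = \max_{a}\, q_{k+1}(s, a)
$$

One reasonable stopping rule, mirroring `policy_evaluation`, is to stop once $\max_s \lvert v_{k+1}(s) - v_k(s) \rvert \le \epsilon$. The greedy policy $\pi(s) = \arg\max_a q(s, a)$ is then read off from the final `q`, which is what the last line of the stub does.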
    

In [482]:
#Function to print the policy you got after running the policy iteration or value iteration on the 4x4 FrozenLake environment
def print_policy(policy, env):
    """
    Prints the policy for the 4x4 FrozenLake environment in a grid layout.
    """
    action_symbols = {0: '←', 1: '↓', 2: '→', 3: '↑'}  #action symbols
    grid_size = env.desc.shape  #get the grid dimensions (e.g., 4x4)
    
    policy_symbols = np.array([action_symbols[policy[state]] for state in sorted(policy)])  #order cells by state index
    policy_grid = policy_symbols.reshape(grid_size)  #reshape into a grid

    print("Policy Grid:")
    for row in policy_grid:
        print(" ".join(row))


In [483]:
pi1, val1, count1 = policy_iteration(mdp_transitions)
pi2, val2, count2 = value_iteration(mdp_transitions)

You can also write a function `test_policy()` to test your policy after training by counting how many times it reaches the goal state.

In [None]:
def test_policy(pi, env, goalstate):
    # Complete this function to test the policy
    return
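
One way to approach this is to run many episodes and report the fraction that end in the goal. The sketch below is an assumption-laden illustration, not the expected solution: instead of stepping the Gymnasium `env`, it samples outcomes directly from a transition table in the `env.P` layout, and the names `toy_mdp` and `estimate_success_rate` are made up for this example.

```python
import random

# Hypothetical deterministic 3-state chain in the same layout as env.P:
# toy_mdp[state][action] -> list of (prob, next_state, reward, done)
toy_mdp = {
    0: {0: [(1.0, 0, 0.0, False)], 1: [(1.0, 1, 0.0, False)]},
    1: {0: [(1.0, 0, 0.0, False)], 1: [(1.0, 2, 1.0, True)]},
    2: {0: [(1.0, 2, 0.0, True)],  1: [(1.0, 2, 0.0, True)]},  # terminal goal
}

def estimate_success_rate(pi, mdp, start_state, goal_state,
                          episodes=1000, max_steps=100, seed=0):
    """Fraction of simulated episodes that end in goal_state under policy pi."""
    rng = random.Random(seed)  # local generator so runs are repeatable
    successes = 0
    for _ in range(episodes):
        state = start_state
        for _ in range(max_steps):
            outcomes = mdp[state][pi[state]]
            weights = [prob for prob, _, _, _ in outcomes]
            # Sample one (prob, next_state, reward, done) tuple by probability
            _, state, _, done = rng.choices(outcomes, weights=weights)[0]
            if done:
                break
        successes += (state == goal_state)
    return successes / episodes

always_right = {0: 1, 1: 1, 2: 0}
always_left = {0: 0, 1: 0, 2: 0}
print(estimate_success_rate(always_right, toy_mdp, start_state=0, goal_state=2))  # 1.0
print(estimate_success_rate(always_left, toy_mdp, start_state=0, goal_state=2))   # 0.0
```

The `max_steps` cap matters because a poor policy, like `always_left` here, may never reach a terminal state.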
