### Assignment : Week 2
## Finding best policies in simple MDPs

Great work making the MDPs in Week 1!

In this assignment, we'll use the simplest RL techniques - Policy and Value iteration to find the best policies (which maximize the discounted total reward) in our MDPs from last week.

Feel free to use your own MDPs, or import them from the OpenAI Gym library.

You can start this assignment during/after reading Grokking Ch-3.

Let us recall the equation to find the value function of agent's states under a policy $\pi$ -
$$v_{\pi}(s) = \sum _{a} \pi(a|s) ~ \left( ~ \sum _{s', r} ~ p(s', r | s, a) ~ \left[r + \gamma v_{\pi}(s') \right] ~ \right)$$

We can observe that the value function $v_{\pi}$ has a lot of circular dependencies on different states. 

To solve such equations, one of the ways is to iteratively calculate the RHS and replace the LHS by it until the $v_{\pi}(s)$ values start to converge. 

The point of convergence makes all the equations simultaneously true and hence is the required solution.

Let us calculate the value functions for some policies in the MDPs we created last week.

## Environment 0 - Bandit Walk

Again, we consider the BW environment on Page 39.

Let's consider what seems to be the most natural policy - always go Right.

This environment is so simple, that we can simply calculate the value functions by hand.

Note that by convention for the terminal states, 
$$v_{\pi}(0) = v_{\pi}(2) = 0$$

Now, 
$$v_{\pi}(1) = 1 + \gamma \cdot v_{\pi}(2) = 1$$

Note both the summations just have one term due to the deterministic nature of the environment and the policy (check which summation was corresponding to which stochastic variable)

## Environment 1 - Slippery Walk

Let's now try to solve the SWF environment from Page 67 for the naturally adversarial policy - always go Left.

Since we have 5 coupled equations for states 1-5 with 5 unknown variables, we'll use Python to bruteforce the solution.

To align with Grokking, let us consider an unusual $\gamma = 1$.

In [11]:
# Step 0 is to import stuff
import gym, gym_walk
import numpy as np
from gym.envs.toy_text.frozen_lake import generate_random_map

In [12]:
# Step 1 is to get the MDP

env = gym.make('SlipperyWalkFive-v0')
swf_mdp = env.P
# swf_mdp

# Note that in Gym, action "Left" is "0" and "Right" is "1"

In [69]:
swf_mdp

{0: {0: [(0.5000000000000001, 0, 0.0, True),
   (0.3333333333333333, 0, 0.0, True),
   (0.16666666666666666, 0, 0.0, True)],
  1: [(0.5000000000000001, 0, 0.0, True),
   (0.3333333333333333, 0, 0.0, True),
   (0.16666666666666666, 0, 0.0, True)]},
 1: {0: [(0.5000000000000001, 0, 0.0, True),
   (0.3333333333333333, 1, 0.0, False),
   (0.16666666666666666, 2, 0.0, False)],
  1: [(0.5000000000000001, 2, 0.0, False),
   (0.3333333333333333, 1, 0.0, False),
   (0.16666666666666666, 0, 0.0, True)]},
 2: {0: [(0.5000000000000001, 1, 0.0, False),
   (0.3333333333333333, 2, 0.0, False),
   (0.16666666666666666, 3, 0.0, False)],
  1: [(0.5000000000000001, 3, 0.0, False),
   (0.3333333333333333, 2, 0.0, False),
   (0.16666666666666666, 1, 0.0, False)]},
 3: {0: [(0.5000000000000001, 2, 0.0, False),
   (0.3333333333333333, 3, 0.0, False),
   (0.16666666666666666, 4, 0.0, False)],
  1: [(0.5000000000000001, 4, 0.0, False),
   (0.3333333333333333, 3, 0.0, False),
   (0.16666666666666666, 2, 0.0, Fa

In [59]:
# Step 2 is to write the policy

pi = {
    0 : 0,
    1 : 0,
    2 : 0,
    3 : 0,
    4 : 0,
    5 : 0,
    6 : 0
}

# Or you can do it randomly
# pi = dict()
# for state in mdp:
#     pi[state] = np.random.choice(mdp[state].keys())

In [60]:
# Step 3 is computing the value function for this envi and policy

# Let us start with a random value function

val = dict()
for state in swf_mdp:
    val[state] = 0

# Since 0 and 6 are terminal states, we know their values are 0

val[0] = 0
val[6] = 0

#Or you could do it randomly, remember to set the terminal states to 0. You can also implement this while evaluating the value function using 
# val = dict()
# for state in mdp:
#     val[state] = np.random.random()
#     if mdp[state][0][0][0] == 0: # if the first action in the first outcome of the first state is 0, then it is a terminal state
#         val[state] = 0

#instead of doing thsi you can simply intialize the value function to 0 for all states 
# for state in swf_mdp:
#   val[state] = 0

In [142]:
def create_epsilon_greedy_policy(mdp, epsilon=0.1):
    policy = {
        state: [1 if len(actions) == 1 else 1 - epsilon if action == list(actions.keys())[0] else epsilon/(len(actions)-1) for action in actions.keys()]
        for state, actions in mdp.items()
    }
    return policy

In [131]:
def get_new_value_fn(val, mdp, pi, gamma = 0.7):
    new_val = {
        state: sum([tup[0]*(tup[2] + gamma*val[tup[1]]) for tup in mdp[state][np.random.choice(range(len(mdp[state])), p = pi[state])]])
        for state in mdp
    }
    print(new_val[11])
    return new_val

In [122]:
#Use to above function to get the new value function, also print how many iterations it took to converge
def policy_evaluation(val, mdp, pi, epsilon=1e-10, gamma=0.7):
    count = 0
    diff = 10
    while (diff > epsilon):
        count += 1
        diff = 0
        new_val = get_new_value_fn(val, mdp, pi, gamma)
        for state in val.keys():
            diff = max(diff, abs(val[state] - new_val[state]))
        val = {state: value for state, value in new_val.items()}
    # Complete this function to iteratively caluculate the value function until the difference between the new and old value function is less than epsilon
    # Also return the number of iterations it took to converge
    # print(count)
    return val, count 

In [150]:
# Perform policy improvement using the polivy and the value function and return a new policy, the action value function should be a nested dictionary
def policy_improvement(val, mdp, gamma=0.7, epsilon=0.1):
    q = {
        state: {
            action: sum([tup[0]*val[tup[1]] for tup in result])
            for action, result in mdp[state].items()
        }
        for state in mdp
    }
    new_pi = dict()
    for state in q:
        a = np.argmax(list(q[state].values()))
        if len(q[state]) == 1:
            new_pi[state] = [1]
        else:
            new_pi[state] = [1 - epsilon if action == a else epsilon/(len(q[state]) - 1) for action in q[state]]
    new_pi = {
        state: [1 if len(q[state]) == 1 else 1 - epsilon if action == list(actions.keys())[0] else epsilon/(len(actions)-1) for action in actions.keys()]
        for state, actions in mdp.items()
    }
    # new_pi = {
    #     state: np.argmax(list(q[state].values()))
    #     for state in mdp
    # }
    # print(new_pi)
    # Complete this function to get the new policy given the value function and the mdp
    return new_pi, q


In [140]:
# Use the above functions to get the optimal policy and optimal value function and return the total number of iterations it took to converge
# Create a random policy and value function to start with or use the ones defined above
def policy_iteration(mdp, epsilon=1e-10, gamma=1.0):
    global pi, val
    count = 0
    # Complete this function to get the optimal policy and value function and return the total number of iterations it took to converge
    return pi, val, count

In [154]:
#Now perform value iteration, note that the value function is a dictionary and not a list, also return the number of iterations it took to converge
def value_iteration(mdp, pi = {}, val={}, gamma=0.7, epsilon=1e-10):
    if not pi:
        pi = create_epsilon_greedy_policy(mdp)
    if not val:
        val = {s: 0 for s in mdp}
    count = 0
    q = {s: {act: 0 for act in action} for s, action in mdp.items()}
    diff = 10
    new_pi = {}
    while any([pi.get(state) != new_pi.get(state) for state in mdp]):
        if new_pi:
            pi = {state: [val for val in action] for state, action in new_pi.items()}
        count += 1
        diff = 0
        val, iteration_count = policy_evaluation(val, mdp, pi, epsilon, gamma)
        new_pi, q = policy_improvement(val, mdp, gamma)
        print(new_pi)
    pi = {state: np.argmax(action) for state, action in new_pi.items()}
    # Complete this function to get the optimal policy, optimal value function and return the total number of iterations it took to converge
    return pi, val, count    

In [100]:
pi1, val1, count1 = value_iteration(swf_mdp, pi1, val1)

{0: 0.0, 1: 0.015304941824200887, 2: 0.03352511060277252, 3: 0.06833430921122735, 4: 0.13850964038422622, 5: 0.2806239662843198, 6: 0.0}
{0: {0: 0.0, 1: 0.0}, 1: {0: 0.010689165708529048, 1: 0.021864202576119894}, 2: {0: 0.030216559314895843, 1: 0.047893015110571334}, 3: {0: 0.06262559843583308, 1: 0.09762044169631766}, 4: {0: 0.12710769578107572, 1: 0.1978709148054399}, 5: {0: 0.16279614228688638, 1: 0.11662626215881096}, 6: {0: 0.0, 1: 0.0}}


In [101]:
val1

{0: 0.0,
 1: 0.015304941824200887,
 2: 0.03352511060277252,
 3: 0.06833430921122735,
 4: 0.13850964038422622,
 5: 0.2806239662843198,
 6: 0.0}

In [102]:
pi1

{0: 0, 1: 1, 2: 1, 3: 1, 4: 1, 5: 0, 6: 0}

## Enviroment 2 - Frozen Lake

Repeat the above steps for the frozen lake environment. Don't create new functions , use the old functions.

You can also write a function `test_policy()` to test your policy after training to find the number of times you reached the goal state

In [78]:
env2 = gym.make('FrozenLake-v1',desc=generate_random_map(size=4))
mdp2 = env2.P

In [126]:
terminal_states = [10, 12, 15]
winning_states = [15]
dirs = {0: -1, 1: 4, 2:1, 3: -4}
mdp2 = {
    state: {
        action: [(1, state+dirs[action]*(0<=state + dirs[action]<=15), 1*(state + dirs[action] in winning_states), state+dirs[action]*(0<=state + dirs[action]<=15) in terminal_states)]
        if dirs[action == 0]
        else
        []
        if dirs[action] == 1
        for action in range(4)
    }
    if state not in terminal_states
    else
    {
        action: [(1, state, 0, state in terminal_states)]
        for action in range(4)
    }
    for state in mdp2
}

In [127]:
mdp2

{0: {0: [(1, 0, 0, False)],
  1: [(1, 4, 0, False)],
  2: [(1, 1, 0, False)],
  3: [(1, 0, 0, False)]},
 1: {0: [(1, 0, 0, False)],
  1: [(1, 5, 0, False)],
  2: [(1, 2, 0, False)],
  3: [(1, 1, 0, False)]},
 2: {0: [(1, 1, 0, False)],
  1: [(1, 6, 0, False)],
  2: [(1, 3, 0, False)],
  3: [(1, 2, 0, False)]},
 3: {0: [(1, 2, 0, False)],
  1: [(1, 7, 0, False)],
  2: [(1, 4, 0, False)],
  3: [(1, 3, 0, False)]},
 4: {0: [(1, 3, 0, False)],
  1: [(1, 8, 0, False)],
  2: [(1, 5, 0, False)],
  3: [(1, 0, 0, False)]},
 5: {0: [(1, 4, 0, False)],
  1: [(1, 9, 0, False)],
  2: [(1, 6, 0, False)],
  3: [(1, 1, 0, False)]},
 6: {0: [(1, 5, 0, False)],
  1: [(1, 10, 0, True)],
  2: [(1, 7, 0, False)],
  3: [(1, 2, 0, False)]},
 7: {0: [(1, 6, 0, False)],
  1: [(1, 11, 0, False)],
  2: [(1, 8, 0, False)],
  3: [(1, 3, 0, False)]},
 8: {0: [(1, 7, 0, False)],
  1: [(1, 12, 0, True)],
  2: [(1, 9, 0, False)],
  3: [(1, 4, 0, False)]},
 9: {0: [(1, 8, 0, False)],
  1: [(1, 13, 0, False)],
  2: [(1,

In [None]:
2: GO RIGHT
1: GO DOWN
0: GO LEFT
3: GO UP

In [168]:
# pi1, val1, count1 = policy_iteration(mdp2)
pi2, val2, count2 = value_iteration(mdp2)

0.0
{0: [0.9, 0.03333333333333333, 0.03333333333333333, 0.03333333333333333], 1: [0.9, 0.03333333333333333, 0.03333333333333333, 0.03333333333333333], 2: [0.9, 0.03333333333333333, 0.03333333333333333, 0.03333333333333333], 3: [0.9, 0.03333333333333333, 0.03333333333333333, 0.03333333333333333], 4: [0.9, 0.03333333333333333, 0.03333333333333333, 0.03333333333333333], 5: [0.9, 0.03333333333333333, 0.03333333333333333, 0.03333333333333333], 6: [0.9, 0.03333333333333333, 0.03333333333333333, 0.03333333333333333], 7: [0.9, 0.03333333333333333, 0.03333333333333333, 0.03333333333333333], 8: [0.9, 0.03333333333333333, 0.03333333333333333, 0.03333333333333333], 9: [0.9, 0.03333333333333333, 0.03333333333333333, 0.03333333333333333], 10: [0.9, 0.03333333333333333, 0.03333333333333333, 0.03333333333333333], 11: [0.9, 0.03333333333333333, 0.03333333333333333, 0.03333333333333333], 12: [0.9, 0.03333333333333333, 0.03333333333333333, 0.03333333333333333], 13: [0.9, 0.03333333333333333, 0.0333333333

In [169]:
pi2

{0: 0,
 1: 0,
 2: 0,
 3: 0,
 4: 0,
 5: 0,
 6: 0,
 7: 0,
 8: 0,
 9: 0,
 10: 0,
 11: 0,
 12: 0,
 13: 0,
 14: 0,
 15: 0}

In [153]:
val2

{0: 0.0,
 1: 0.0,
 2: 0.0,
 3: 0.0,
 4: 0.0,
 5: 0.0,
 6: 0.0,
 7: 0.0,
 8: 0.0,
 9: 0.0,
 10: 0.0,
 11: 0.0,
 12: 0.0,
 13: 0.0,
 14: 0.0,
 15: 0.0}

In [None]:
def test_policy(pi, env, goalstate):
    # Complete this function to test the policy
    return