### Assignment : Week 2
## Finding best policies in simple MDPs

Great work making the MDPs in Week 1!

In this assignment, we'll use two of the simplest RL techniques, policy iteration and value iteration, to find the best policies (those that maximize the discounted total reward) in our MDPs from last week.

Feel free to use your own MDPs, or import them from the OpenAI Gym library.

You can start this assignment during/after reading Grokking Ch-3.

Let us recall the equation for the value function of the agent's states under a policy $\pi$:
$$v_{\pi}(s) = \sum _{a} \pi(a|s) ~ \left( ~ \sum _{s', r} ~ p(s', r | s, a) ~ \left[r + \gamma v_{\pi}(s') \right] ~ \right)$$

We can observe that the value function $v_{\pi}$ has circular dependencies: the value of each state depends on the values of other states, which may in turn depend on it.

One way to solve such equations is to repeatedly evaluate the RHS using the current estimates and assign the result to the LHS, until the $v_{\pi}(s)$ values converge.

The point of convergence makes all the equations simultaneously true and hence is the required solution.
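
Written as an update rule, the procedure above looks like this, where the sweep index $k$ and the starting guess $v_0$ are notation introduced here for illustration (starting from all zeros is one common choice, not a requirement):

$$v_{k+1}(s) = \sum _{a} \pi(a|s) ~ \left( ~ \sum _{s', r} ~ p(s', r | s, a) ~ \left[r + \gamma v_{k}(s') \right] ~ \right)$$

The sweeps stop once $\max_s |v_{k+1}(s) - v_{k}(s)|$ falls below a small tolerance $\epsilon$.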

Let us calculate the value functions for some policies in the MDPs we created last week.

## Environment 0 - Bandit Walk

Again, we consider the BW environment on Page 39.

Let's consider what seems to be the most natural policy - always go Right.

This environment is so simple that we can calculate the value function by hand.

Note that, by convention, the terminal states have value zero:
$$v_{\pi}(0) = v_{\pi}(2) = 0$$

Now, 
$$v_{\pi}(1) = 1 + \gamma \cdot v_{\pi}(2) = 1$$

Note that both summations have just one term each, because both the environment and the policy are deterministic (check which summation corresponds to which source of randomness).
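
The hand calculation can be double-checked with a few lines of Python. This is only a sketch: `bw_mdp` is a hypothetical hand-written dictionary that mimics the layout of gym's `env.P`, and the names `always_right` and `bellman_backup` are made up for this example.

```python
# Hypothetical hand-written Bandit Walk MDP in the same layout as gym's env.P:
# state -> action -> list of (probability, next_state, reward, is_terminal).
# States 0 and 2 are terminal; action 0 is Left and action 1 is Right.
bw_mdp = {
    0: {0: [(1.0, 0, 0.0, True)], 1: [(1.0, 0, 0.0, True)]},
    1: {0: [(1.0, 0, 0.0, True)], 1: [(1.0, 2, 1.0, True)]},
    2: {0: [(1.0, 2, 0.0, True)], 1: [(1.0, 2, 0.0, True)]},
}

always_right = {0: 1, 1: 1, 2: 1}
gamma = 0.9  # any discount gives the same answer here, since v(2) = 0


def bellman_backup(state, values):
    """Evaluate the right-hand side of the Bellman equation for one state."""
    total = 0.0
    for prob, next_state, reward, done in bw_mdp[state][always_right[state]]:
        total += prob * (reward + gamma * values[next_state] * (not done))
    return total


values = {0: 0.0, 1: 0.0, 2: 0.0}
values = {s: bellman_backup(s, values) for s in bw_mdp}
print(values)  # {0: 0.0, 1: 1.0, 2: 0.0}
```

A single sweep is enough here because every transition ends the episode, so no state's value depends on another non-terminal state.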

## Environment 1 - Slippery Walk

Let's now try to solve the SWF environment from Page 67 for the naturally adversarial policy - always go Left.

Since we have 5 coupled equations for states 1-5 with 5 unknowns, we'll use Python to brute-force the solution.

To align with Grokking, let us consider an unusual $\gamma = 1$.

## Imports

In [1]:
# Step 0 is to import stuff

import gym, gym_walk
import numpy as np
from gym.envs.toy_text.frozen_lake import generate_random_map
import pprint
import random

## Load MDP from gym

In [2]:
# Step 1 is to get the MDP

env = gym.make('SlipperyWalkFive-v0')
swf_mdp = env.P
pprint.pprint(swf_mdp)

# Note that in this mdp in gym, action "Left" is "0" and "Right" is "1"

{0: {0: [(0.5000000000000001, 0, 0.0, True),
         (0.3333333333333333, 0, 0.0, True),
         (0.16666666666666666, 0, 0.0, True)],
     1: [(0.5000000000000001, 0, 0.0, True),
         (0.3333333333333333, 0, 0.0, True),
         (0.16666666666666666, 0, 0.0, True)]},
 1: {0: [(0.5000000000000001, 0, 0.0, True),
         (0.3333333333333333, 1, 0.0, False),
         (0.16666666666666666, 2, 0.0, False)],
     1: [(0.5000000000000001, 2, 0.0, False),
         (0.3333333333333333, 1, 0.0, False),
         (0.16666666666666666, 0, 0.0, True)]},
 2: {0: [(0.5000000000000001, 1, 0.0, False),
         (0.3333333333333333, 2, 0.0, False),
         (0.16666666666666666, 3, 0.0, False)],
     1: [(0.5000000000000001, 3, 0.0, False),
         (0.3333333333333333, 2, 0.0, False),
         (0.16666666666666666, 1, 0.0, False)]},
 3: {0: [(0.5000000000000001, 2, 0.0, False),
         (0.3333333333333333, 3, 0.0, False),
         (0.16666666666666666, 4, 0.0, False)],
     1: [(0.5000000000000

## Initiate Policy

In [3]:
# Step 2 is to write the policy
# policy recommends action for each state (either deterministically or stochastically)

pi = {
    0 : 0,
    1 : 0,
    2 : 0,
    3 : 0,
    4 : 0,
    5 : 0,
    6 : 0
}

# Or you can initialize action for each state randomly

pprint.pprint(pi)

{0: 0, 1: 0, 2: 0, 3: 0, 4: 0, 5: 0, 6: 0}


## Initiate Value function

In [4]:
# Step 3 is computing the value function for this environment and policy

val = dict()
# We could start with a random value function (after ensuring that values for all terminal states are 0)
# Or, instead, we can simply initialize the value function to 0 for all states
for state in swf_mdp:
    val[state] = 0

pprint.pprint(val)

{0: 0, 1: 0, 2: 0, 3: 0, 4: 0, 5: 0, 6: 0}


## Get new value function

In [5]:
def get_new_value_fn(val, mdp, pi, gamma = 1.0):
    # val is old value function
    # mdp is mdp
    # pi contains the action for each state
    # gamma is discount
    # new_val is new value function after 1 iteration over old value function
    new_val = dict()
    
    for state in mdp:
        action = pi[state] # action to be taken now given by deterministic policy
        probabilities = mdp[state][action] # list of tuples
        # for 5 and action 0 this is 
        """
        [(0.5000000000000001, 4, 0.0, False),
         (0.3333333333333333, 5, 0.0, False),
         (0.16666666666666666, 6, 1.0, True)]
        """
        new_value_for_state = 0
        
        for (prob, next_state, reward, isTerminal) in probabilities:
            # multiplying by (not isTerminal) guards against a value function
            # that is not set to 0 in all terminal states
            new_value_for_state += prob * (reward + gamma * val[next_state] * (not isTerminal))

        new_val[state] = new_value_for_state
    return new_val

In [6]:
# some trials to ensure function is correct (matches book)
new_value_fn1 = get_new_value_fn(val, swf_mdp, pi, gamma = 1.0) 
pprint.pprint(new_value_fn1)
new_value_fn2 = get_new_value_fn(new_value_fn1, swf_mdp, pi, gamma = 1.0)
pprint.pprint(new_value_fn2)
new_value_fn3 = get_new_value_fn(new_value_fn2, swf_mdp, pi, gamma = 1.0)
pprint.pprint(new_value_fn3)

{0: 0.0, 1: 0.0, 2: 0.0, 3: 0.0, 4: 0.0, 5: 0.16666666666666666, 6: 0.0}
{0: 0.0,
 1: 0.0,
 2: 0.0,
 3: 0.0,
 4: 0.027777777777777776,
 5: 0.2222222222222222,
 6: 0.0}
{0: 0.0,
 1: 0.0,
 2: 0.0,
 3: 0.004629629629629629,
 4: 0.046296296296296294,
 5: 0.25462962962962965,
 6: 0.0}


## Determine when value functions begin to converge

In [7]:
# this helper function has been written for policy evaluation to determine when 
# value functions have begun to converge

def low_difference(value_fn1, value_fn2, epsilon = 1e-10):
    # both must have same keys
    similar = True
    for i in value_fn1:
        if abs(value_fn1[i] - value_fn2[i]) > epsilon:
            similar = False
    return similar # every element must be within epsilon of element of other value fn
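
An equivalent way to phrase this check is through the largest absolute gap between the two value functions. The sketch below uses a hypothetical helper name, `max_abs_difference`, and made-up numbers purely for illustration.

```python
def max_abs_difference(value_fn1, value_fn2):
    """Largest absolute gap between two value functions with the same keys."""
    return max(abs(value_fn1[s] - value_fn2[s]) for s in value_fn1)


old = {0: 0.0, 1: 0.25, 2: 0.5}
new = {0: 0.0, 1: 0.25 + 1e-12, 2: 0.5}

# A gap below the tolerance means the two sweeps are treated as converged.
print(max_abs_difference(old, new) <= 1e-10)
```

Comparing this single number against `epsilon` gives the same verdict as checking every state individually.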

## Policy Evaluation

In [8]:
# Use the above functions to get the final value function

def policy_evaluation(val, mdp, pi, epsilon=1e-10, gamma=1.0):
    count = 0
    while True:
        val_new = get_new_value_fn(val, mdp, pi, gamma)
        if low_difference(val_new, val, epsilon):
            return val, count + 1
        else:
            count += 1
            val = val_new
            
# Iteratively calculate the value function until the difference between the new and old value function 
# is less than epsilon; also return the number of iterations it took to converge

In [9]:
pprint.pprint(policy_evaluation(val, swf_mdp, pi, gamma = 1)) # for swf, take gamma = 1 (as in book)

({0: 0.0,
  1: 0.0027472526825522074,
  2: 0.010989010794909371,
  3: 0.03571428532608248,
  4: 0.10989010930780505,
  5: 0.3324175818352776,
  6: 0.0},
 104)
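
As a cross-check on the iterative result, the same five equations can be solved directly as a linear system. This is a different technique from the iteration used above (a direct solve by Gauss-Jordan elimination), sketched with exact fractions. It assumes the probabilities 1/2, 1/3 and 1/6 read off the printed MDP for action Left, and all variable names are made up for this example.

```python
from fractions import Fraction

# Probabilities under "always Left", read off the printed MDP above:
# 1/2 move left, 1/3 stay put, 1/6 slip right. Landing in state 6 pays reward 1.
p_left, p_stay, p_right = Fraction(1, 2), Fraction(1, 3), Fraction(1, 6)
states = [1, 2, 3, 4, 5]  # non-terminal states; v(0) = v(6) = 0
n = len(states)

# Build the augmented matrix [A | b] for (I - P) v = r with gamma = 1.
rows = []
for i in range(n):
    row = [Fraction(0)] * (n + 1)
    row[i] = 1 - p_stay
    if i > 0:
        row[i - 1] = -p_left
    if i < n - 1:
        row[i + 1] = -p_right
    else:
        row[n] = p_right  # expected reward for slipping into the goal
    rows.append(row)

# Gauss-Jordan elimination; the pivots stay non-zero for this matrix,
# so no row swaps are needed.
for i in range(n):
    pivot = rows[i][i]
    rows[i] = [x / pivot for x in rows[i]]
    for j in range(n):
        if j != i:
            factor = rows[j][i]
            rows[j] = [a - factor * b for a, b in zip(rows[j], rows[i])]

exact = {s: rows[i][n] for i, s in enumerate(states)}
print({s: float(v) for s, v in exact.items()})
```

The exact solution is $v_{\pi}(1) = 1/364$ up to $v_{\pi}(5) = 121/364$, which agrees with the iterative values above to roughly nine decimal places.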


## Determine action with max expected return for state in q function

In [10]:
# for policy improvement, we need, for each state, the action in the q function
# whose expected return (taking that action and thereafter
# following policy pi) is maximum

# this helper determines key with max value in dict (in q[state])

def argmax(dict1):
    items = list(dict1.items())
    max_key = items[0][0]
    max_val = items[0][1]
    for (key, val) in items:
        if val > max_val:
            max_key = key
            max_val = val

    return max_key

# print(argmax(policy_evaluation(val, swf_mdp, pi)[0])) # trial
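
For comparison, the same lookup can be written with Python's built-in `max`. The helper name `argmax_builtin` is hypothetical and the q-values are made up; this is a sketch of an alternative, not a replacement for the function above.

```python
def argmax_builtin(dict1):
    """Key with the largest value; ties go to the first key in insertion order."""
    return max(dict1, key=dict1.get)


print(argmax_builtin({0: 0.25, 1: 0.75}))  # 1
print(argmax_builtin({0: 0.0, 1: 0.0}))    # 0 (the tie goes to the first key)
```

The tie-breaking matters in practice: when all q-values of a state are equal (for example at terminal states), both versions fall back to the first action.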

## Policy Improvement

In [11]:
# Perform policy improvement using the policy and the value function and return a new policy, 
# the action value function (q function) should be a nested dictionary

def policy_improvement(val, mdp, gamma=1.0):
    new_pi = dict()
    q = dict()
    # q must be something like
    # {0: {0: val, 1: val}, 1: {0: val, 1: val},...}

    for state in mdp:
        q[state] = dict() # initialization, each value will be a dictionary
        for action in mdp[state]:
            q[state][action] = 0
            # prob_tuples = mdp[state][action]
            for (prob, next_state, reward, isTerminal) in mdp[state][action]:
                q[state][action] += prob * (reward + gamma * val[next_state] * (not isTerminal))
            
    # after q has been made, 
    for state in q:
        new_pi[state] = argmax(q[state]) # q[state] is a dict
        
    return new_pi, q

In [12]:
pprint.pprint(policy_improvement(val, swf_mdp, gamma = 1.0)[0])

{0: 0, 1: 0, 2: 0, 3: 0, 4: 0, 5: 1, 6: 0}
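
In equation form, the two steps inside `policy_improvement` can be sketched as follows, where $q_{\pi}$ and $\pi'$ are the usual symbols for the action-value function and the improved policy (the notation is assumed here rather than taken from the code):

$$q_{\pi}(s, a) = \sum _{s', r} ~ p(s', r | s, a) ~ \left[r + \gamma v_{\pi}(s') \right]$$

$$\pi'(s) = \arg\max_{a} ~ q_{\pi}(s, a)$$

With `val` still all zeros, $q$ reduces to the expected one-step reward. Only state 5 can reach the goal in one step, and Right is more likely than Left to do so, which is why the printed policy differs from always-Left only at state 5.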


## Policy Iteration

In [13]:
# Use the above functions to get the optimal policy and optimal value function 
# and return the total number of iterations it took to converge.

def policy_iteration(mdp, epsilon=1e-10, gamma=1.0):
    
    optimal_pi = {s: 0 for s in mdp}
    optimal_val = {s: 0 for s in mdp} # both pi and val are all zeroes initially
    # as always, it is possible to create a random policy and value function to start with
    count = 0
    
    while True:
        new_pi = policy_improvement(optimal_val, mdp, gamma)[0]
        if new_pi == optimal_pi: # convergence when policy cannot be optimized further
            return optimal_pi, optimal_val, count + 1
        else:
            count += 1
            optimal_pi = new_pi
            # pass epsilon through so that the tolerance given to policy_iteration is actually used
            optimal_val = policy_evaluation(optimal_val, mdp, optimal_pi, epsilon = epsilon, gamma = gamma)[0]

In [14]:
pprint.pprint(policy_iteration(swf_mdp, gamma = 1.0))

({0: 0, 1: 1, 2: 1, 3: 1, 4: 1, 5: 1, 6: 0},
 {0: 0.0,
  1: 0.6675824170597608,
  2: 0.8901098895872336,
  3: 0.9642857139372767,
  4: 0.9890109888367704,
  5: 0.9972527471946744,
  6: 0.0},
 3)


## Value Iteration

In [22]:
# Now perform value iteration, note that the value function is a dictionary and not a list,
# also return the number of iterations it took to converge
def value_iteration(mdp, gamma=1.0, epsilon=1e-10):
    val = {s: 0 for s in mdp}
    count = 0
    
    while True:

        q = dict()

        for state in mdp:
            q[state] = dict()
            for action in mdp[state]:
                q[state][action] = 0
                for (prob, next_state, reward, isTerminal) in mdp[state][action]:
                    q[state][action] += prob * (reward + gamma * val[next_state] * (not isTerminal))

        new_val = {state: max(q[state].values()) for state in mdp}
                    
        if low_difference(val, new_val, epsilon):
            break

        val = new_val.copy()
        count += 1

    pi = {s: 0 for s in mdp}
    for state in mdp:
        pi[state] = argmax(q[state])
        
    return val, pi, count + 1

In [23]:
pprint.pprint(value_iteration(swf_mdp, gamma = 1.0))

({0: 0.0,
  1: 0.6675824169918751,
  2: 0.8901098895193477,
  3: 0.9642857138920196,
  4: 0.9890109888141417,
  5: 0.9972527471871315,
  6: 0.0},
 {0: 0, 1: 1, 2: 1, 3: 1, 4: 1, 5: 1, 6: 0},
 122)
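
The sweep inside `value_iteration` corresponds to the Bellman optimality update, sketched here with a sweep index $k$ that is introduced only for illustration:

$$v_{k+1}(s) = \max_{a} ~ \sum _{s', r} ~ p(s', r | s, a) ~ \left[r + \gamma v_{k}(s') \right]$$

Unlike policy iteration, no policy is stored between sweeps; the greedy policy is read off from the final $q$ values once the updates change by less than $\epsilon$.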


## Environment 2 - Frozen Lake

Repeat the above steps for the Frozen Lake environment. Don't create new functions; reuse the ones defined above.

You can also write a function `test_policy()` to test your policy after training and find how often it reaches the goal state.

In [17]:
env2 = gym.make('FrozenLake-v1')
mdp2 = env2.P
pprint.pprint(mdp2)

# ACTIONS ARE ASSIGNED NUMBERS AS FOLLOWS: (to match the FrozenLake env in OpenAI Gym)
# "Left" -> 0
# "Down" -> 1
# "Right" -> 2
# "Up" -> 3

{0: {0: [(0.3333333333333333, 0, 0.0, False),
         (0.3333333333333333, 0, 0.0, False),
         (0.3333333333333333, 4, 0.0, False)],
     1: [(0.3333333333333333, 0, 0.0, False),
         (0.3333333333333333, 4, 0.0, False),
         (0.3333333333333333, 1, 0.0, False)],
     2: [(0.3333333333333333, 4, 0.0, False),
         (0.3333333333333333, 1, 0.0, False),
         (0.3333333333333333, 0, 0.0, False)],
     3: [(0.3333333333333333, 1, 0.0, False),
         (0.3333333333333333, 0, 0.0, False),
         (0.3333333333333333, 0, 0.0, False)]},
 1: {0: [(0.3333333333333333, 1, 0.0, False),
         (0.3333333333333333, 0, 0.0, False),
         (0.3333333333333333, 5, 0.0, True)],
     1: [(0.3333333333333333, 0, 0.0, False),
         (0.3333333333333333, 5, 0.0, True),
         (0.3333333333333333, 2, 0.0, False)],
     2: [(0.3333333333333333, 5, 0.0, True),
         (0.3333333333333333, 2, 0.0, False),
         (0.3333333333333333, 1, 0.0, False)],
     3: [(0.3333333333333333,

In [18]:
pi1, val1, count1 = policy_iteration(mdp2, gamma = 0.99)
pprint.pprint(pi1)
pprint.pprint(val1)
pprint.pprint(count1)

{0: 0,
 1: 3,
 2: 3,
 3: 3,
 4: 0,
 5: 0,
 6: 0,
 7: 0,
 8: 3,
 9: 1,
 10: 0,
 11: 0,
 12: 0,
 13: 2,
 14: 1,
 15: 0}
{0: 0.542025930296267,
 1: 0.4988031849539974,
 2: 0.4706956878752193,
 3: 0.45685169676602944,
 4: 0.5584509586627534,
 5: 0.0,
 6: 0.3583480707644934,
 7: 0.0,
 8: 0.5917987435152554,
 9: 0.6430798237640515,
 10: 0.6152075569912918,
 11: 0.0,
 12: 0.0,
 13: 0.7417204382759435,
 14: 0.8628374297788133,
 15: 0.0}
7


In [25]:
# value_iteration returns (val, pi, count), in that order
val2, pi2, count2 = value_iteration(mdp2, gamma = 0.99)
pprint.pprint(val2)
pprint.pprint(pi2)
pprint.pprint(count2)

{0: 0.5420259303047927,
 1: 0.49880318496538073,
 2: 0.4706956878886318,
 3: 0.45685169678049486,
 4: 0.5584509586706584,
 5: 0.0,
 6: 0.35834807077058933,
 7: 0.0,
 8: 0.5917987435219645,
 9: 0.6430798237690762,
 10: 0.6152075569957234,
 11: 0.0,
 12: 0.0,
 13: 0.7417204382795115,
 14: 0.8628374297806647,
 15: 0.0}
{0: 0,
 1: 3,
 2: 3,
 3: 3,
 4: 0,
 5: 0,
 6: 0,
 7: 0,
 8: 3,
 9: 1,
 10: 0,
 11: 0,
 12: 0,
 13: 2,
 14: 1,
 15: 0}
571


## Testing an optimal policy for frozen lake

In [27]:
# ACTIONS ARE ASSIGNED NUMBERS AS FOLLOWS: (to match the FrozenLake env in OpenAI Gym)
# "Left" -> 0
# "Down" -> 1
# "Right" -> 2
# "Up" -> 3

terminal_states = [5, 7, 11, 12, 15]

additions_dict = { # dictionary of actions with their corresponding additions to the current state
    0: -1,
    1: 4,
    2: 1,
    3: -4
}
actions_orthogonal_dict = { # for each key, value is set of actions orthogonal to it
    0: (1, 3),
    1: (0, 2),
    2: (1, 3),
    3: (0, 2),
}

def is_terminal(n):
    return n in terminal_states

def reward(n):
    return float(n == 15) # (to match the FrozenLake env in OpenAI Gym)

def add_change_to_state(cur_state, addition):
    if abs(addition) == 4: # up or down, check only if new state remains inside lake
        if cur_state + addition in range(16):
            return cur_state + addition
        else:
            return cur_state
            
    elif abs(addition) == 1: # right or left, check only if new state is in same row as original state
        if (cur_state // 4) == ((cur_state + addition) // 4):
            return cur_state + addition
        else:
            return cur_state

In [42]:
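def test_policy(pi, env, goalstate = 15):
    # Simulate 10000 episodes from state 0 and count how often the goal is reached.
    # Each step takes the intended move or one of the two perpendicular moves,
    # each with probability 1/3; env is unused because the dynamics are hand-coded above.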
    successes = 0
    failures = 0
    
    for i in range(10000):
        current_state = 0
        while True:
            chance = random.random()
            move_to_make = pi[current_state]
            if chance < 1/3:
                current_state = add_change_to_state(current_state, additions_dict[move_to_make])
            elif chance < 2/3:
                current_state = add_change_to_state(current_state, additions_dict[actions_orthogonal_dict[move_to_make][0]])
            else:
                current_state = add_change_to_state(current_state, additions_dict[actions_orthogonal_dict[move_to_make][1]])

            if current_state == goalstate:
                successes += 1
                break
                
            elif current_state in terminal_states:
                failures += 1
                break                   
                
    return f"{successes/100} % times it reached end state"

In [43]:
for i in range(10):
    print(test_policy(pi1, mdp2))

82.84 % times it reached end state
83.24 % times it reached end state
82.02 % times it reached end state
81.67 % times it reached end state
82.54 % times it reached end state
81.53 % times it reached end state
82.4 % times it reached end state
82.41 % times it reached end state
81.97 % times it reached end state
81.9 % times it reached end state
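
An alternative to hand-coding the grid dynamics is to sample transitions straight from the MDP dictionary. The sketch below assumes the `env.P` layout shown earlier and treats a positive reward on the final transition as success. The function name, the `max_steps` cap, and the toy MDP are all hypothetical.

```python
import random


def rollout_success_rate(pi, mdp, start_state=0, episodes=1000, max_steps=200, seed=0):
    """Fraction of episodes that end with a positive reward.

    `mdp` uses the env.P layout: state -> action -> list of
    (probability, next_state, reward, is_terminal) tuples.
    """
    rng = random.Random(seed)
    successes = 0
    for _ in range(episodes):
        state = start_state
        for _ in range(max_steps):  # cap on episode length (an assumption)
            outcomes = mdp[state][pi[state]]
            weights = [prob for prob, _, _, _ in outcomes]
            _, state, reward, done = rng.choices(outcomes, weights=weights)[0]
            if done:
                successes += reward > 0
                break
    return successes / episodes


# Hypothetical two-state MDP used only to exercise the function:
# from state 0, action 0 reaches the goal (state 1) with probability 1.
toy_mdp = {
    0: {0: [(1.0, 1, 1.0, True)]},
    1: {0: [(1.0, 1, 0.0, True)]},
}
print(rollout_success_rate({0: 0, 1: 0}, toy_mdp))  # 1.0
```

Because it reads the probabilities from the dictionary, the same function would work for any MDP in this layout, for example by calling it with `pi1` and `mdp2` from the cells above.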
