In [None]:
import numpy as np 
from scipy.special import softmax 
import time 

from IPython.display import clear_output
%matplotlib inline
%config InlineBackend.figure_format='retina'

import matplotlib.pyplot as plt 
import seaborn as sns 

from utils.env import frozen_lake
from utils.viz import viz 
viz.get_style()

In [None]:
# Build the frozen-lake environment from a fixed seed, reset it, and draw the
# board as it stands right after the reset.
seed = 1234 
env = frozen_lake(seed=seed)
env.reset()
fig, ax = plt.subplots(1, 1, figsize=(4, 4))
env.render(ax)

In [None]:
# Replay a scripted action sequence, redrawing the board after every move.
# Action codes: 0 = up, 1 = down, 2 = left, 3 = right.
done = False
acts = [3, 3, 3, 1, 1, 1]
# The leading None draws the current position before any action is taken, so
# every later frame shows the state *after* its action -- including the one
# reached by the final action, which would otherwise never be displayed.
for a in [None] + acts:
    if a is not None:
        _, _, done = env.step(a)
    fig, ax = plt.subplots(1, 1, figsize=(4, 4))
    # wait=True defers clearing until the new frame is ready (less flicker)
    clear_output(wait=True)
    env.render(ax)
    plt.show()
    # stop replaying once the episode has ended (its last frame is already drawn)
    if done: break
    time.sleep(.1)
    

## Have a look at the environment 

Actions: 

* 0: up
* 1: down
* 2: left
* 3: right

In [None]:
# Transition-function check: next-state distribution for taking action 2 (left)
# from state 1 (a surface tile), rounded to 2 decimals for display.
env.p_s_next(s=1, a=2).round(2)

In [None]:
# Transition-function check for state 19 (a hole) under action 2 (left),
# rounded to 2 decimals for display.
env.p_s_next(s=19, a=2).round(2)

In [None]:
# Reward-function check for a surface tile (2), a hole (19), and the goal (63).
# Each call yields a (reward, done) pair, which is how the planning code in
# this notebook unpacks it.
env.r(2), env.r(19), env.r(63), 

## Policy evaluation 

In [None]:
# A random stochastic policy: one row per state, obtained as a softmax over
# uniform noise scaled by 5, so every row is a distribution over the actions.
# NOTE: `rng` stays in the kernel namespace and is reused by a later cell.
rng = np.random.RandomState(1234)
pi_rand = softmax(rng.rand(env.nS, env.nA)*5, axis=1)
pi_rand

In [None]:
def policy_eval(pi, V, env, theta=1e-4, gamma=.99):
    """Iterative policy evaluation (in-place sweeps) for a fixed policy.

    Repeatedly sweeps the non-terminal states, replacing each V[s] with

        sum_a pi[s, a] * sum_s' p(s'|s, a) * (r(s') + (1 - done(s')) * gamma * V[s'])

    until the largest change within one sweep is below `theta`.

    Args:
        pi: policy table indexed as pi[s, a] (weight of action a in state s).
        V: initial state values, indexed by state. Overwritten in place;
            values updated earlier in a sweep are already used by the
            states visited later in the same sweep.
        env: environment providing `S` (states), `A` (actions),
            `s_termination` (terminal states, which are skipped),
            `p_s_next(s, a)` (indexable by next state) and `r(s)` returning
            a `(reward, done)` pair.
        theta: convergence threshold on the per-sweep maximum value change.
        gamma: discount factor.

    Returns:
        `V` itself (the same object that was passed in) after convergence.
    """
    # r(s') does not depend on the (s, a) pair being evaluated, so query it
    # once per state instead of once per (s, a, s') triple in every sweep.
    # NOTE(review): assumes env.r is a pure function of the state -- confirm.
    outcomes = [(s_next,) + tuple(env.r(s_next)) for s_next in env.S]

    # loop until convergence
    while True:
        delta = 0
        for s in env.S:
            if s in env.s_termination:
                continue  # terminal states keep the value they came in with
            v_old = V[s]
            v_new = 0
            for a in env.A:
                p = env.p_s_next(s, a)
                for s_next, r, done in outcomes:
                    # (1-done) removes the bootstrap term when s_next is terminal
                    v_new += pi[s, a]*p[s_next]*(r + (1-done)*gamma*V[s_next])
            V[s] = v_new
            # track the largest update of this sweep
            delta = max(delta, abs(v_new - v_old))

        if delta < theta:
            break

    return V

In [None]:
# Evaluate the random policy `pi_rand`.
# initialize V(s) with small random values; this draws from the kernel-global
# `rng`, so re-running this cell on its own produces different numbers
V = rng.rand(env.nS) * 0.001
# terminal states are pinned to V = 0
for s in env.s_termination:
    V[s] = 0
# policy_eval updates V in place and returns it, so v1 and V are the same array
v1 = policy_eval(pi_rand, V, env)

## Policy iteration 

In [None]:
def policy_improve(pi, V, env, theta=1e-4, gamma=.99):
    """Greedy policy improvement with respect to the state values `V`.

    For every state, computes the one-step lookahead value of each action,

        q[a] = sum_s' p(s'|s, a) * (r(s') + (1 - done(s')) * gamma * V[s']),

    and overwrites that state's row of `pi` with a one-hot vector on the
    first maximising action (`np.argmax` tie-breaking).

    Args:
        pi: policy table indexed as pi[s, a]. Overwritten in place.
        V: state values, indexed by state.
        env: environment providing `S`, `A`, `nA`, `p_s_next(s, a)`
            (indexable by next state) and `r(s)` returning `(reward, done)`.
        theta: tolerance used when comparing the new policy with the old one.
        gamma: discount factor.

    Returns:
        (pi, stable): the same `pi` object, and whether every entry of the
        policy moved by less than `theta`.
    """
    pi_old = pi.copy()

    # r(s') is independent of (s, a): look it up once per state.
    # NOTE(review): assumes env.r is a pure function of the state -- confirm.
    outcomes = [(s_next,) + tuple(env.r(s_next)) for s_next in env.S]
    # rows of the identity are the one-hot action vectors
    one_hot = np.eye(env.nA)

    for s in env.S:
        q = np.zeros([env.nA])
        for a in env.A:
            p = env.p_s_next(s, a)
            for s_next, r, done in outcomes:
                # (1-done) removes the bootstrap term when s_next is terminal
                q[a] += p[s_next]*(r + (1-done)*gamma*V[s_next])
        pi[s] = one_hot[np.argmax(q)]

    # check stable
    stable = (np.abs(pi - pi_old) < theta).all()

    return pi, stable

In [None]:
def policy_iter(env, seed=1234, theta=1e-4, gamma=.99):
    """Policy iteration: alternate policy evaluation and greedy improvement.

    Starts from small random state values (terminal states set to 0) and a
    random softmax policy, then repeats `policy_eval` / `policy_improve`
    until the improvement step reports the policy as stable.

    Args:
        env: environment providing `nS`, `nA`, `s_termination`, plus whatever
            `policy_eval` and `policy_improve` require.
        seed: seed for the random initialisation of V and pi.
        theta: convergence threshold forwarded to the evaluation step.
        gamma: discount factor forwarded to both steps.

    Returns:
        (V, pi): the state values and the policy table at termination.
    """
    rng = np.random.RandomState(seed)

    # initialize V(s) arbitrarily ...
    V = rng.rand(env.nS) * 0.001
    # ... except V(terminal) = 0
    for s in env.s_termination:
        V[s] = 0
    # initialize π(s) arbitrarily: each row is a softmax over scaled noise
    pi = softmax(rng.rand(env.nS, env.nA)*5, axis=1)

    while True:
        V = policy_eval(pi, V, env, theta=theta, gamma=gamma)
        # the stability tolerance of policy_improve keeps its own default;
        # only the discount has to agree between the two steps
        pi, stable = policy_improve(pi, V, env, gamma=gamma)
        if stable:
            break

    return V, pi

In [None]:
# Run policy iteration, then show the state values (left panel) next to the
# resulting policy (right panel).
V1, pi1 = policy_iter(env)
fig, (ax_v, ax_pi) = plt.subplots(1, 2, figsize=(8, 4))
env.show_v(ax_v, V1)
env.show_pi(ax_pi, pi1)

## Value iteration

In [None]:
def value_iter(env, seed=1234, theta=1e-4, gamma=.99):
    """Value iteration with in-place sweeps.

    Every sweep replaces V[s] by the best one-step lookahead value

        max_a sum_s' p(s'|s, a) * (r(s') + (1 - done(s')) * gamma * V[s'])

    and records the first maximising action as a one-hot row of the policy.
    Sweeps stop once the largest change in a sweep is below `theta`.

    Terminal states are swept like any other state (so their policy rows are
    filled in too); only their values are reset to 0 before returning.

    Args:
        env: environment providing `S`, `A`, `nS`, `nA`, `s_termination`,
            `p_s_next(s, a)` (indexable by next state) and `r(s)` returning
            a `(reward, done)` pair.
        seed: seed for the random initialisation of V.
        theta: convergence threshold on the per-sweep maximum value change.
        gamma: discount factor.

    Returns:
        (V, pi): state values and the greedy one-hot policy table.
    """
    rng = np.random.RandomState(seed)
    # initialize V(s) arbitrarily ...
    V = rng.rand(env.nS) * 0.001
    # ... except V(terminal) = 0
    for s in env.s_termination:
        V[s] = 0
    # init policy
    pi = np.zeros([env.nS, env.nA])

    # r(s') is independent of (s, a): look it up once per state rather than
    # once per (s, a, s') triple in every sweep.
    # NOTE(review): assumes env.r is a pure function of the state -- confirm.
    outcomes = [(s_next,) + tuple(env.r(s_next)) for s_next in env.S]
    # rows of the identity are the one-hot action vectors
    one_hot = np.eye(env.nA)

    # loop until converge
    while True:
        delta = 0
        for s in env.S:
            v_old = V[s]
            q = np.zeros([env.nA])
            for a in env.A:
                p = env.p_s_next(s, a)
                for s_next, r, done in outcomes:
                    # (1-done) removes the bootstrap term when s_next is terminal
                    q[a] += p[s_next]*(r + (1-done)*gamma*V[s_next])
            V[s] = np.max(q)
            pi[s] = one_hot[np.argmax(q)]
            delta = max(delta, abs(V[s] - v_old))

        if delta < theta:
            break

    # terminal values were overwritten during the sweeps; restore them
    for s in env.s_termination:
        V[s] = 0
    return V, pi
        

In [None]:
# Run value iteration, then show the state values (left panel) next to the
# resulting policy (right panel).
V2, pi2 = value_iter(env)
fig, (ax_v, ax_pi) = plt.subplots(1, 2, figsize=(8, 4))
env.show_v(ax_v, V2)
env.show_pi(ax_pi, pi2)

## TD learning, Q learning 

In [None]:
def e_greedy(q, rng, env, eps):
    """Sample an action epsilon-greedily from the action values `q`.

    One uniform draw decides the branch: if it is below 1 - eps the action is
    sampled uniformly among the maximisers of `q` (random tie-breaking),
    otherwise it is sampled uniformly among all `env.nA` actions.

    Args:
        q: 1-D array of action values for the current state.
        rng: random generator offering `rand()` and `choice(...)`.
        env: environment; only `env.nA` (number of actions) is read.
        eps: exploration parameter.

    Returns:
        The index of the sampled action.
    """
    exploit = rng.rand() < 1-eps
    if not exploit:
        # explore: every action is equally likely
        return rng.choice(env.nA)

    # exploit: spread the probability mass evenly over the tied best actions
    best = np.flatnonzero(q == np.max(q))
    greedy_probs = np.zeros(env.nA)
    greedy_probs[best] = 1 / len(best)
    return rng.choice(env.nA, p=greedy_probs)

In [None]:
def Q_learning(env, alpha=.1, eps=.1, gamma=.99, max_epi=2000, seed=1234, theta=1e-4):
    """Tabular Q-learning with an epsilon-greedy behaviour policy.

    Runs up to `max_epi` episodes. In every step an action is sampled with
    `e_greedy`, the environment is advanced, and Q[s, a] is moved a fraction
    `alpha` towards the target r + gamma * max_a' Q[s', a'] (the bootstrap
    term is dropped when the step ends the episode).

    Training stops early as soon as one whole episode changes no entry of Q
    by `theta` or more.
    NOTE(review): this is a per-episode heuristic -- an uneventful episode
    (e.g. only zero rewards while Q is still all zeros) ends training
    immediately. Confirm this is intended for the reward scheme in use.

    Args:
        env: environment providing `nS`, `nA`, `reset()` returning a 3-tuple
            whose first item is the start state, and `step(a)` returning
            `(s_next, r, done)`.
        alpha: learning rate.
        eps: exploration parameter passed to `e_greedy`.
        gamma: discount factor.
        max_epi: maximum number of episodes.
        seed: seed of the action-sampling random generator.
        theta: early-stopping threshold on the per-episode change of Q.

    Returns:
        (Q, pi): the action-value table and the greedy one-hot policy table.
    """
    rng = np.random.RandomState(seed)
    # initialize Q
    Q = np.zeros([env.nS, env.nA])
    for epi in range(max_epi):
        # only the start state is needed from reset
        s, _, _ = env.reset()
        # snapshot to measure how much this episode changes Q
        q_old = Q.copy()
        while True:
            # sample At, observe Rt, St+1
            a = e_greedy(Q[s, :], rng, env, eps)
            s_next, r, done = env.step(a)
            # off-policy target: greedy value of the next state
            Q_tar = r + gamma*(1-done)*(Q[s_next, :]).max()
            Q[s, a] += alpha*(Q_tar - Q[s, a])
            s = s_next

            if done:
                break

        if (np.abs(q_old - Q)<theta).all():
            break
    pi = np.eye(env.nA)[np.argmax(Q, axis=1)]
    return Q, pi

In [None]:
# Run Q-learning, derive state values as the best action value per state, then
# show the values (left panel) next to the greedy policy (right panel).
Q, pi3 = Q_learning(env)
V3 = Q.max(axis=1)
fig, (ax_v, ax_pi) = plt.subplots(1, 2, figsize=(8, 4))
env.show_v(ax_v, V3)
env.show_pi(ax_pi, pi3)