In [1]:
%matplotlib inline
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib
from collections import defaultdict

import sys
if "../" not in sys.path:
  sys.path.append("../")

import gym
import plotting
from lib.envs.gridworld import GridworldEnv
import itertools
matplotlib.style.use('ggplot')

In [2]:
env = GridworldEnv()

In [10]:
env.render()

T  o  o  o
o  o  o  o
o  o  x  o
o  o  o  T
[4, 4]


In [8]:
#Create an initial epsilon soft policy
def epsilon_greedy_policy(Q, epsilon, state, nA):
    A = np.ones(nA, dtype=float) * epsilon / nA
    best_action = np.argmax(Q[state])
    A[best_action] += (1.0 - epsilon)
    return A

In [45]:
def online_sarsa_lambda(env, num_episodes, discount=1.0, epsilon=0.1, alpha=0.5, lbda=0.9, debug=False):
    
    Q = defaultdict(lambda: np.zeros(env.action_space.n, dtype=float))
    
    for i_episode in range(1, num_episodes+1):
        
        if debug:
            if i_episode % 1000 == 0:
                print("\rEpisode {}/{}.".format(i_episode, num_episodes))
                
        E = {key:np.zeros(4, dtype=int) for key in np.arange(16)}
        state = env.reset()
        action_probs = epsilon_greedy_policy(Q, epsilon, state, env.nA)
        action = np.random.choice(np.arange(len(action_probs)), p=action_probs)
        for t in itertools.count():
            next_state, reward, end, _ = env.step(action)
            
            next_action_probs = epsilon_greedy_policy(Q, epsilon, next_state, env.nA)
            next_action = np.random.choice(np.arange(len(next_action_probs)), p=next_action_probs)
            
            delta = reward + discount*Q[next_state][next_action] - Q[state][action]
            E[state][action] += 1
            
            for s in E.keys():
                for a in E[s]:
                    Q[s][a] += alpha*delta*E[s][a]
                    E[s][a] = discount*lbda*E[s][a]
                    
            if end:
                break
                
            state = next_state
            action = next_action
            
    return Q    

In [46]:
Q = online_sarsa_lambda(env, num_episodes=10000, debug=True)

Episode 1000/10000.
Episode 2000/10000.
Episode 3000/10000.
Episode 4000/10000.
Episode 5000/10000.
Episode 6000/10000.
Episode 7000/10000.
Episode 8000/10000.
Episode 9000/10000.
Episode 10000/10000.


In [47]:
Q

defaultdict(<function __main__.<lambda>>,
            {0: array([ 0.,  0.,  0.,  0.]),
             1: array([-1.00334249, -3.24963391, -3.42187521,  0.        ]),
             2: array([-3.24970931, -4.25767707, -2.2578125 , -2.24963379]),
             3: array([-4.26403069, -4.27333278, -4.0234375 , -3.2578125 ]),
             4: array([-1.        , -2.4777315 , -5.27484555, -1.5       ]),
             5: array([-1.00083547, -4.00062179, -3.8125    , -2.953125  ]),
             6: array([-3.24962602, -4.38325268, -3.90625   , -2.99707031]),
             7: array([-4.40579551, -4.10019576, -3.76989174, -3.0234375 ]),
             8: array([-2.00004227, -4.36602511, -4.6146492 , -6.06161064]),
             9: array([-3.96138778, -3.30563219, -3.8996314 , -5.78905885]),
             10: array([-3.93345338, -1.75675408, -4.1645967 , -8.73046327]),
             11: array([-4.069583  , -1.75012255, -0.75      , -3.00000286]),
             12: array([-3.06552368, -4.15733446, -5.31050544, -

## Run the Optimal Policy

In [56]:
state = env.reset()
print env.render()
print '#################################'
while(True):
    action = np.argmax(Q[state])
    next_state, reward, done, _ = env.step(action)
    state = next_state
    
    print env.render()
    print '#################################'
    
    if done:
        break
    

T  o  o  o
o  o  o  o
x  o  o  o
o  o  o  T
None
#################################
T  o  o  o
x  o  o  o
o  o  o  o
o  o  o  T
None
#################################
x  o  o  o
o  o  o  o
o  o  o  o
o  o  o  T
None
#################################
