In [75]:
%matplotlib inline
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib
from collections import defaultdict

import sys
if "../" not in sys.path:
  sys.path.append("../")

import gym
import plotting
from lib.envs.gridworld import GridworldEnv
import itertools
matplotlib.style.use('ggplot')

In [76]:
# Build the gridworld environment using the constructor's default arguments.
env = GridworldEnv()

In [51]:
def random_policy(state, nA):
    """Return a uniform probability distribution over ``nA`` actions.

    Args:
        state: Current state. Accepted for interface compatibility; it is not
            consulted, so every state gets the same distribution.
        nA: Number of available actions.

    Returns:
        A float numpy array of length ``nA`` whose entries are all ``1 / nA``.
    """
    uniform = np.ones(nA, dtype=float)
    uniform /= nA
    return uniform

In [45]:
# Number of discrete states assumed for the environment (hard-coded constant).
# NOTE(review): presumably this should equal the size of env.observation_space —
# confirm against the environment rather than relying on the literal.
state_space_size = 70

## Online Tabular TD(λ)

In [59]:
def online_td_lambda(env, num_episodes, discount=1.0, epsilon=0.1, alpha=0.5, lbda=0.5, gamma=0.5, debug=False, policy=None):
    """Estimate state values with online tabular TD(lambda) using accumulating traces.

    After every environment step the TD error is computed and immediately
    applied to all states that carry an eligibility trace in the current
    episode; the traces are then decayed by ``gamma * lbda``.

    Args:
        env: Environment exposing ``reset() -> state``,
            ``step(action) -> (next_state, reward, done, info)`` and an ``nA``
            attribute (number of actions). States must be hashable.
        num_episodes: Number of episodes to run.
        discount: Unused. Kept only so existing callers keep working; the
            discount factor actually applied is ``gamma``.
        epsilon: Unused. Kept only so existing callers keep working.
        alpha: Step size of the value update.
        lbda: Trace-decay parameter (lambda).
        gamma: Discount factor used in the TD target and in the trace decay.
        debug: When true, print a progress line every 100 episodes.
        policy: Optional callable ``policy(state, nA)`` returning action
            probabilities. Defaults to the module-level ``random_policy``.

    Returns:
        A ``defaultdict(float)`` mapping state -> estimated value. States that
        were never encountered have no explicit entry and read as ``0.0``.
    """
    if policy is None:
        # Resolved at call time so the default tracks the notebook's current
        # definition of random_policy.
        policy = random_policy

    V = defaultdict(float)
    # Loop-invariant trace decay factor.
    trace_decay = gamma * lbda

    for i_episode in range(1, num_episodes + 1):

        if debug and i_episode % 100 == 0:
            print("\rEpisode {}/{}.".format(i_episode, num_episodes))

        # Traces are created lazily for the states actually visited, so the
        # function needs no externally supplied state count and works for any
        # hashable state. Unvisited states would have a zero trace anyway.
        E = defaultdict(float)
        state = env.reset()
        while True:
            action_probs = policy(state, env.nA)
            action = np.random.choice(np.arange(len(action_probs)), p=action_probs)
            next_state, reward, done, _ = env.step(action)

            # The target always bootstraps from V[next_state], including on the
            # final transition of an episode.
            delta = reward + gamma * V[next_state] - V[state]
            E[state] += 1

            step = alpha * delta
            for s in E:
                V[s] += step * E[s]
                E[s] = trace_decay * E[s]

            if done:
                break

            state = next_state

    return V

In [62]:
# Run 200 episodes with the default hyperparameters; debug=True prints a
# progress line every 100 episodes.
# NOTE(review): no random seed is set in the cells shown here, so the learned
# values presumably vary slightly between runs — confirm and seed if needed.
V = online_td_lambda(env, num_episodes=200, debug=True)

Episode 100/200.
Episode 200/200.


In [65]:
# Display the learned state -> value mapping.
V

defaultdict(float,
            {0: -1.9999999999999998,
             1: -1.9999999999999991,
             2: -1.9999999999999971,
             3: -1.999999999999994,
             4: -1.9999999999999827,
             5: -1.9999999999995695,
             6: -1.999999999970696,
             7: -1.9999999998057143,
             8: -1.9999999983563992,
             9: -1.9999999950599563,
             10: -1.9999999999999953,
             11: -1.9999999999999951,
             12: -1.9999999999999922,
             13: -1.9999999999999432,
             14: -1.9999999999997848,
             15: -1.9999999999701155,
             16: -1.999999999328964,
             17: -1.9999997964549732,
             18: -1.999999998239193,
             19: -1.9999673585262316,
             20: -1.9999999999999365,
             21: -1.9999999999999833,
             22: -1.9999999999999816,
             23: -1.999999999999962,
             24: -1.999999999997015,
             25: -1.9999999998958584,
         