In [1]:
import numpy as np
#from itertools import permutations, repeat, product
import itertools
import gym
from libr.envs.cliff_walking2 import CliffWalkingEnv

In [2]:
env = CliffWalkingEnv()

In [3]:
def TDlambdaJC(env, numEpisodes, discount, lam):
    '''This is the main function for computing TD(lambda) policy evaluation method. I have used a random policy 
    to be evaluated.
    Inputs - (i) env: OpenAI gym environment
            (ii) numEpisodes: number of episodes
            (iii) discount: the discount factor for infinite horizon discounted DP
            (iv) lam: a scalar between 0 and 1, used as a parameter for TD(lambda)
    Outputs - Value
    For example, see Figure 7.7 of Sutton and Barto for the pseudocode
    '''
    #initialize
    V = np.zeros(env.observation_space.n)
    e = np.zeros_like(V) # eligibility trace vector
    alpha = 0.1 #learning rate
    rList=[]
    
    def randomPolicy(state):
        a=env.action_space.sample()
        return a
    
    for ep in range(numEpisodes):
        currState = env.reset()
        rTotalFromEpisode = 0
        for _ in itertools.count():
            currAction = randomPolicy(currState)
            
            #generate a sample
            nextState, reward, done, _ = env.step(currAction)            
            rTotalFromEpisode += reward
            
            TDerror = reward + discount*V[nextState] - V[currState]
            e[currState] += 1
            
            for state in range(env.observation_space.n):
                V[state] += alpha*TDerror*e[state]
                e[state] *= discount*lam
            
            if done:
                break
            currState = nextState
            
        rList.append(rTotalFromEpisode)
    return V
            
    

In [4]:
TDlambdaJC(env,500,0.9,0.7)

array([-27.5159222 , -30.68113286, -28.57320245, -26.94329109,
       -24.4708182 , -23.40134546, -20.21310359, -15.80273216,
       -13.03014196,  -3.57068323,   0.        ,   0.        ,
       -43.65890909, -35.64556172, -33.98397795, -35.29458492,
       -34.42900525, -41.36900042, -37.77329352, -32.83665137,
       -19.63939271,  -6.62936479,   0.        ,   0.        ,
       -58.42530351, -68.86284483, -63.84273792, -66.46671154,
       -67.00492011, -52.40826697, -51.67939691, -44.45007282,
       -16.81733887, -11.30936941,   0.        ,   0.        ,
       -73.94460568,   0.        ,   0.        ,   0.        ,
         0.        ,   0.        ,   0.        ,   0.        ,
         0.        ,   0.        ,   0.        ,   0.        ])