In [1]:
import numpy as np
import itertools
import random
import gym
from libr.envs.cliff_walking2 import CliffWalkingEnv

In [2]:
# Instantiate the project's cliff-walking environment (libr.envs.cliff_walking2) with no constructor arguments.
# NOTE(review): presumably a variant of gym's toy-text CliffWalking with discrete observation/action spaces — confirm.
env = CliffWalkingEnv()

In [3]:
def SARSAJC(env, numEpisodes, discount, epsilon, alpha=0.1):
    '''Tabular SARSA: on-policy TD(0) control with an epsilon-greedy behaviour policy.

    Inputs - (i) env: OpenAI-gym-style environment with discrete spaces; uses
                 observation_space.n, action_space.n, action_space.sample(),
                 reset() -> state and step(a) -> (state, reward, done, info).
            (ii) numEpisodes: number of episodes to run.
            (iii) discount: discount factor gamma for the infinite-horizon discounted return.
            (iv) epsilon: exploration probability of the epsilon-greedy policy
                 (0 = always greedy, 1 = always a random action).
            (v) alpha: learning rate / step size (default 0.1, the previously
                hard-coded value).
    Outputs - (Q, greedyPolicy): Q is an (nStates, nActions) array of action-value
              estimates; greedyPolicy[s] = argmax_a Q[s, a].
    See the Sarsa pseudocode in Sutton & Barto, Section 6.4 ("Sarsa: On-policy TD Control").
    '''
    Q = np.zeros([env.observation_space.n, env.action_space.n])

    def epsilonGreedyPolicy(state, Q, epsilon):
        # Exploit with probability (1 - epsilon), otherwise take a random action.
        if np.random.rand() > epsilon:
            return np.argmax(Q[state])
        return env.action_space.sample()

    for _ in range(numEpisodes):
        currState = env.reset()
        # SARSA picks A once per episode; afterwards the action executed at each
        # step is exactly the A' that was used in the previous TD target.
        currAction = epsilonGreedyPolicy(currState, Q, epsilon)
        done = False
        while not done:
            nextState, reward, done, _ = env.step(currAction)
            if done:
                # Terminal transition: the return ends here, so do not bootstrap.
                nextAction = None
                target = reward
            else:
                nextAction = epsilonGreedyPolicy(nextState, Q, epsilon)
                target = reward + discount * Q[nextState, nextAction]
            Q[currState, currAction] += alpha * (target - Q[currState, currAction])
            currState, currAction = nextState, nextAction

    greedyPolicy = np.argmax(Q, 1)
    return Q, greedyPolicy
    

In [4]:
# Train SARSA on the cliff-walking env: 500 episodes, gamma = 0.9, epsilon = 0.19.
# NOTE: no random seed is set in this notebook, so re-running yields different Q values.
SARSAJC(env, numEpisodes=500, discount=0.9, epsilon=0.19)

(array([[ -8.03836915,  -7.97451196,  -8.09624163,  -7.99308364],
        [ -7.8040971 ,  -7.7249689 ,  -7.75125782,  -7.90957757],
        [ -7.51329041,  -7.46128581,  -7.8405537 ,  -7.47658116],
        [ -7.34476925,  -7.17189708,  -7.35319043,  -7.39649645],
        [ -6.9795931 ,  -6.84916436,  -7.00558608,  -7.12301216],
        [ -6.7145006 ,  -6.51850674,  -6.5926254 ,  -6.58422688],
        [ -6.10739868,  -6.07670468,  -6.10532918,  -6.43824652],
        [ -5.84995133,  -5.61423528,  -5.73410054,  -5.98414357],
        [ -5.31944721,  -5.1205522 ,  -5.17726678,  -5.67979084],
        [ -4.67343603,  -4.59992402,  -4.58488182,  -5.16832218],
        [ -4.15228212,  -4.01290535,  -4.06922602,  -4.3294205 ],
        [ -3.5469889 ,  -3.4982771 ,  -3.31997252,  -3.96645757],
        [ -8.16402694,  -8.24350604,  -8.62181519,  -8.16324956],
        [ -7.87999185,  -7.8763601 , -10.33867033,  -7.90135403],
        [ -7.32854992,  -7.28456785, -10.56935775,  -7.37183941],
        [ 