In [1]:
import numpy as np
import itertools
import random
import gym
from libr.envs.cliff_walking2 import CliffWalkingEnv

In [2]:
# Instantiate the cliff-walking environment provided by libr.envs.cliff_walking2.
env = CliffWalkingEnv()

In [3]:
def SARSAJC(env, numEpisodes, discount, epsilon, alpha=0.1):
    '''On-policy TD(0) control (SARSA) with an epsilon-greedy (epsilon-soft) policy.

    Inputs - (i) env: OpenAI gym environment with discrete observation and action
                 spaces; ``reset()`` must return the start state and ``step(a)``
                 must return the 4-tuple (nextState, reward, done, info)
            (ii) numEpisodes: number of episodes
            (iii) discount: the discount factor for infinite horizon discounted DP
            (iv) epsilon: small non-negative scalar for epsilon-greedy policy
            (v) alpha: learning rate / step size (default 0.1)
    Outputs - Q-Value array of shape (nS, nA), greedy policy (argmax of Q per state)
    For example, see Figure 6.9 (Section 6.4, "Sarsa: On-policy TD Control") of
    Sutton and Barto for the pseudocode
    '''
    Q = np.zeros([env.observation_space.n, env.action_space.n])

    def epsilonGreedyPolicy(state, Q, epsilon):
        # With probability epsilon explore uniformly, otherwise act greedily w.r.t. Q.
        if np.random.rand() < epsilon:
            return env.action_space.sample()
        return np.argmax(Q[state])

    for _ in range(numEpisodes):
        currState = env.reset()
        # SARSA chooses A once per episode; afterwards the action actually taken
        # is always the A' that appeared in the previous TD target (on-policy).
        currAction = epsilonGreedyPolicy(currState, Q, epsilon)
        done = False
        while not done:
            # generate a sample
            nextState, reward, done, _ = env.step(currAction)

            if done:
                # Q(terminal, .) is 0 by definition: do not bootstrap past the episode end.
                nextAction = None
                tdTarget = reward
            else:
                nextAction = epsilonGreedyPolicy(nextState, Q, epsilon)
                tdTarget = reward + discount * Q[nextState, nextAction]

            Q[currState, currAction] += alpha * (tdTarget - Q[currState, currAction])

            currState, currAction = nextState, nextAction

    greedyPolicy = np.argmax(Q, 1)
    return Q, greedyPolicy
    

In [4]:
# Run SARSA for 500 episodes with discount 0.9 and epsilon 0.19; the cell displays the returned (Q, greedyPolicy) tuple.
# NOTE(review): no random seed is set (numpy or the env's action space), so the Q values shown below change between runs.
SARSAJC(env,500,0.9,0.19)

(array([[ -7.98570681,  -7.95481193,  -7.96394024,  -7.96682797],
        [ -7.79127369,  -7.75645655,  -7.84970941,  -7.77303129],
        [ -7.52693416,  -7.51619767,  -7.53929678,  -7.66302969],
        [ -7.36990402,  -7.22983043,  -7.34271855,  -7.32953773],
        [ -6.99499234,  -6.9085327 ,  -7.12908227,  -7.04285069],
        [ -6.68187263,  -6.49547232,  -6.72007108,  -6.8006355 ],
        [ -6.37351918,  -6.10960526,  -6.41727626,  -6.37785343],
        [ -5.9057624 ,  -5.74829066,  -6.07608476,  -6.22246027],
        [ -5.4001074 ,  -5.09313305,  -5.65025183,  -5.70836598],
        [ -4.72383366,  -4.54010338,  -4.45029982,  -5.33776962],
        [ -4.26938418,  -3.90482763,  -3.92045034,  -4.58563933],
        [ -3.67842924,  -3.57495916,  -3.18799883,  -4.08372023],
        [ -8.1188387 ,  -8.25061689,  -9.42801943,  -8.12960335],
        [ -7.92856753,  -7.92632477,  -9.66384328,  -7.94341048],
        [ -7.71184263,  -7.71670511,  -8.09616318,  -7.71660227],
        [ 