In [1]:
import numpy as np
import itertools
import random
import gym
from libr.envs.cliff_walking2 import CliffWalkingEnv

In [2]:
env = CliffWalkingEnv()

In [3]:
def qLearningJC(env, numEpisodes, discount, epsilon):
    '''This is the main function for computing off-policy TD0 method.
    Inputs - (i) env: OpenAI gym environment
            (ii) numEpisodes: number of episodes
            (iii) discount: the discount factor for infinite horizon discounted DP
            (iv) epsilon: small non-negative scalar for epsilon-greedy policy
    Outputs - Q-Value, greedy policy
    For example, see Figure 6.12 of Sutton and Barto for the pseudocode
    '''
    
    #initialize
    Q = np.zeros([env.observation_space.n,env.action_space.n])
    alpha = 0.1 #learning rate
    rList=[]
    greedyPolicy=np.ones([env.observation_space.n,1]) #Greedy policy for all states
    
    def epsilonGreedyPolicy(state,Q,epsilon):
        if np.random.rand(1) > epsilon:
            a = np.argmax(Q[state])
        else:
            a = env.action_space.sample()
        return a
    
    for ep in range(numEpisodes):
        currState = env.reset()
        rTotalFromEpisode = 0
        for _ in itertools.count():
            currAction = epsilonGreedyPolicy(currState,Q,epsilon)
            
            #generate a sample
            nextState, reward, done, _ = env.step(currAction)
            rTotalFromEpisode += reward
            
            TDerror = reward + discount*np.max(Q[nextState]) - Q[currState,currAction]
            Q[currState,currAction] += alpha*TDerror
            
            if done:
                break
                
            currState = nextState
        rList.append(rTotalFromEpisode)
    greedyPolicy = np.max(Q,1)
    
    return Q, greedyPolicy
            

In [4]:
qLearningJC(env,500,0.9,0.19)

(array([[ -6.78084729,  -6.79271205,  -6.847377  ,  -6.80140967],
        [ -6.64112771,  -6.63280667,  -6.65629485,  -6.62592065],
        [ -6.41692215,  -6.4062033 ,  -6.43778504,  -6.3949237 ],
        [ -6.16774649,  -6.14710661,  -6.14973546,  -6.14389942],
        [ -5.86744489,  -5.83826708,  -5.85272877,  -5.86460827],
        [ -5.50353125,  -5.49876119,  -5.51801434,  -5.53084717],
        [ -5.17732488,  -5.11387806,  -5.12162024,  -5.17894744],
        [ -4.68367885,  -4.67948671,  -4.71229246,  -4.75522376],
        [ -4.20667678,  -4.2125883 ,  -4.24236576,  -4.34850579],
        [ -3.74350846,  -3.71169308,  -3.72735318,  -3.7731774 ],
        [ -3.16707181,  -3.20551517,  -3.18686672,  -3.23550135],
        [ -2.65607002,  -2.66406101,  -2.65343296,  -2.7077994 ],
        [ -6.9532841 ,  -6.95203168,  -6.96013575,  -6.95440869],
        [ -6.7342367 ,  -6.73267406,  -6.7699865 ,  -6.76244921],
        [ -6.48172116,  -6.47986877,  -6.49111461,  -6.49121863],
        [ 