In [1]:
import numpy as np
from collections import defaultdict
from cliff_walking import CliffWalkingEnv

In [2]:
# Hyperparameters shared by the cells below.
epsilon = 0.1  # probability mass spread uniformly over all actions (exploration)
alpha = 0.1    # step size applied to each TD error
gamma = 1.0    # discount factor applied to the bootstrapped next-state value

# Environment instance and the size of its discrete action set.
env = CliffWalkingEnv()
nA = env.action_space.n

In [3]:
def get_epision_greedy_action_policy(Q, observation, eps=None, n_actions=None):
    """Return epsilon-greedy action probabilities for a single state.

    Every action receives an equal share ``eps / n_actions`` of the
    exploration mass. The action with the largest value in ``Q[observation]``
    (the first such index on ties, as returned by ``np.argmax``) additionally
    receives ``1 - eps``.

    Parameters
    ----------
    Q : mapping
        Maps a state to an array-like of per-action values. It is indexed
        with ``Q[observation]``, so a ``defaultdict`` will create the entry
        on first lookup.
    observation : hashable
        The state whose action distribution is requested.
    eps : float, optional
        Exploration rate. When omitted, the module-level ``epsilon`` is used.
    n_actions : int, optional
        Number of available actions. When omitted, the module-level ``nA``
        is used.

    Returns
    -------
    numpy.ndarray
        Float array of length ``n_actions`` whose entries sum to 1 (up to
        floating-point rounding).
    """
    # Fall back to the notebook-level settings only when the caller did not
    # pass explicit values; `is None` keeps eps=0.0 a valid explicit choice.
    if eps is None:
        eps = epsilon
    if n_actions is None:
        n_actions = nA

    # Uniform exploration share for every action ...
    probs = np.full(n_actions, eps / n_actions, dtype=float)
    # ... plus the remaining mass on the currently greedy action.
    probs[np.argmax(Q[observation])] += 1.0 - eps

    return probs

In [4]:
def qlearning(total_episodes):
    """Run tabular Q-learning on the module-level ``env``.

    Behaviour actions are sampled from the epsilon-greedy distribution
    returned by ``get_epision_greedy_action_policy``; the update bootstraps
    from the greedy (arg-max) value of the next state.

    Parameters
    ----------
    total_episodes : int
        Number of episodes to run; each episode starts with ``env.reset()``
        and ends when ``env.step`` reports ``done``.

    Returns
    -------
    collections.defaultdict
        Maps each visited state to a NumPy array of action values, one entry
        per action. Unseen states default to an all-zero array.
    """
    Q = defaultdict(lambda: np.zeros(env.action_space.n))

    for _ in range(total_episodes):
        state = env.reset()
        done = False

        while not done:
            # Behaviour policy: sample an action epsilon-greedily w.r.t. Q.
            action_probs = get_epision_greedy_action_policy(Q, state)
            action = np.random.choice(np.arange(nA), p=action_probs)

            next_state, reward, done, _info = env.step(action)

            # Target policy: bootstrap from the greedy action in next_state.
            greedy_next = np.argmax(Q[next_state])
            target = reward + gamma * Q[next_state][greedy_next]

            # Move the current estimate a fraction alpha towards the target.
            Q[state][action] += alpha * (target - Q[state][action])

            state = next_state

    return Q

In [5]:
# Train for 100 episodes and keep the resulting action-value table.
Q = qlearning(100)

In [6]:
# Show the learned table: each state key maps to an array of per-action values.
Q

defaultdict(<function __main__.qlearning.<locals>.<lambda>>,
            {0: array([-5.74115812, -5.76274022, -5.81423673, -5.78061167]),
             1: array([-5.56057005, -5.54730542, -5.5527676 , -5.62917289]),
             2: array([-5.27868193, -5.2359664 , -5.29809639, -5.34491866]),
             3: array([-4.87912312, -4.90269976, -4.92935982, -4.95863317]),
             4: array([-4.58543447, -4.60895259, -4.62320193, -4.55497123]),
             5: array([-4.37347797, -4.26681055, -4.31319584, -4.25999066]),
             6: array([-3.96786757, -3.92798133, -3.93544779, -3.87951782]),
             7: array([-3.58468709, -3.59280577, -3.60657531, -3.54166157]),
             8: array([-3.3       , -3.23168297, -3.29963524, -3.29871575]),
             9: array([-2.88119342, -2.8324609 , -2.85587904, -2.96546845]),
             10: array([-2.48706185, -2.48456809, -2.49887629, -2.55260755]),
             11: array([-2.18553672, -2.2       , -2.12554231, -2.3343677 ]),
             