In [60]:
import gym
import numpy as np
from collections import defaultdict

In [61]:
def best_policy(Q):
    """Return the greedy policy implied by an action-value table.

    Q maps each state to an array of action values; the result maps each
    state to the index of its largest value (np.argmax, so ties go to the
    lowest index).
    """
    return {s: np.argmax(values) for s, values in Q.items()}

In [62]:
def best_value(Q):
    """Return the greedy state-value table implied by an action-value table.

    Each state in Q is mapped to the maximum of its action-value array.
    """
    return {s: np.max(values) for s, values in Q.items()}

In [63]:
# Shared environment: get_trajectory and the training cell read the global `env`.
# NOTE(review): NChain-v0 is only registered in older gym releases — confirm the pinned gym version.
env_id = 'NChain-v0'
env = gym.make(env_id)

In [64]:
def get_action_prob(Q_s, state, nA, eps):
    """Epsilon-greedy action distribution for a single state.

    Every action receives probability eps / nA. The greedy action (the first
    argmax of Q_s) additionally receives the remaining 1 - eps mass.

    Parameters
    ----------
    Q_s : array of action values for the state.
    state : unused; kept so existing callers keep working.
    nA : number of actions.
    eps : exploration rate.

    Returns a length-nA float array of probabilities.
    """
    greedy = np.argmax(Q_s)
    probs = np.full(nA, eps / nA, dtype=float)
    probs[greedy] = (1 - eps) + (eps / nA)
    return probs

In [65]:
def get_trajectory(Q, eps, nA):
    """Roll out one episode in the global `env` and record it.

    States already present in Q are acted on epsilon-greedily, using
    get_action_prob. States not yet in Q get a uniformly random action
    from env.action_space.

    Parameters
    ----------
    Q : mapping state -> array of nA action values.
    eps : exploration rate passed to get_action_prob.
    nA : number of actions.

    Returns a list of (state, action, reward) tuples in time order.

    NOTE(review): uses the classic gym API (reset() -> obs, step() -> 4-tuple);
    newer gym/gymnasium releases changed both signatures.
    """
    trajectory = []
    state = env.reset()
    done = False
    while not done:
        # Test membership BEFORE indexing. Q is a defaultdict, so Q[state] would
        # insert the key and make the `in` check always true, which silently
        # disabled the random-action branch for unseen states.
        if state in Q:
            action_prob = get_action_prob(Q[state], state, nA, eps)
            action = np.random.choice(np.arange(nA), p=action_prob)
        else:
            action = env.action_space.sample()

        next_state, reward, done, _ = env.step(action)
        trajectory.append((state, action, reward))
        state = next_state

    return trajectory

In [66]:
def update_Q(tra, Q, alpha, gamma=1.0):
    """Every-visit, constant-alpha Monte Carlo update of Q from one episode.

    For each step t, moves Q[s_t][a_t] toward the return G_t by step size alpha:
    Q <- Q + alpha * (G_t - Q), where G_t = r_t + gamma * G_{t+1}.

    Parameters
    ----------
    tra : iterable of (state, action, reward) tuples in time order.
    Q : mapping state -> array of action values; modified in place.
    alpha : step size.
    gamma : discount factor. The default of 1.0 gives the undiscounted return,
        which matches the original behaviour.

    Returns Q (the same object that was passed in).
    An empty trajectory leaves Q untouched.
    """
    tra = list(tra)
    if not tra:
        return Q

    states, actions, rewards = zip(*tra)

    # Accumulate returns backwards in a single O(n) pass. The previous
    # sum(rewards[i:]) at every step was O(n^2) in the episode length.
    returns = [0.0] * len(rewards)
    G = 0.0
    for t in reversed(range(len(rewards))):
        G = rewards[t] + gamma * G
        returns[t] = G

    # Apply the updates in forward time order, as before, so repeated visits to
    # the same (state, action) pair are applied in the same sequence.
    for state, action, G_t in zip(states, actions, returns):
        old_Q = Q[state][action]
        Q[state][action] = old_Q + alpha * (G_t - old_Q)

    return Q

In [67]:
def MC_Control(alpha, eps_start, eps_min, eps_decay, Q, nA, iters):
    """Constant-alpha Monte Carlo control with a decaying epsilon-greedy policy.

    Parameters
    ----------
    alpha : step size passed to update_Q.
    eps_start : initial exploration rate.
    eps_min : floor for the exploration rate.
    eps_decay : multiplicative decay applied to eps once per episode.
    Q : mapping state -> array of nA action values; updated by update_Q.
    nA : number of actions.
    iters : number of episodes to run (may be 0).

    Returns (Q, policy), where policy is the greedy policy from best_policy(Q).
    """
    eps = eps_start
    for _ in range(iters):
        # eps decays before each episode, so the first episode already uses
        # max(eps_min, eps_start * eps_decay).
        eps = max(eps_min, eps * eps_decay)
        trajectory = get_trajectory(Q, eps, nA)
        Q = update_Q(trajectory, Q, alpha)

    # Derive the greedy policy once, after training. Recomputing it every
    # episode was wasted work, and with iters == 0 the loop never ran, which
    # left `policy` unbound and raised UnboundLocalError.
    policy = best_policy(Q)
    return Q, policy

In [68]:
# Hyperparameters for the Monte Carlo control run.
nA = env.action_space.n
Q = defaultdict(lambda: np.zeros(nA))
eps_start = 1
eps_min = 0.05
eps_decay = 0.995
iters = 10
alpha = 0.02
SEED = 0  # fixed so that Restart & Run All reproduces the printed results

# Seed every RNG the run draws from: numpy (np.random.choice in the policy),
# the environment, and the action space used for random actions.
# NOTE(review): env.seed / action_space.seed belong to the classic gym seeding API,
# which matches the 4-tuple step() used in get_trajectory. Confirm this for the
# installed gym version.
np.random.seed(SEED)
env.seed(SEED)
env.action_space.seed(SEED)

Q, policy = MC_Control(alpha, eps_start, eps_min, eps_decay, Q, nA, iters)
V = best_value(Q)
V = best_value(Q)

In [77]:
# Show the learned action values, the greedy state values, and the greedy policy.
for result in (Q, V, policy):
    print(result)

defaultdict(<function <lambda> at 0x7ff90b65ba60>, {0: array([255.00305524, 260.56271811]), 1: array([448.31021838, 387.06572951]), 2: array([524.46928803, 532.24413373]), 3: array([585.47985549, 559.34196663]), 4: array([557.21829031, 548.04194945])})
{0: 260.56271811467764, 1: 448.31021837503255, 2: 532.2441337266217, 3: 585.4798554926188, 4: 557.2182903123618}
{0: 1, 1: 0, 2: 1, 3: 0, 4: 0}
