In [1]:
import numpy as np
import sys
from collections import defaultdict
from blackjack import BlackjackEnv

In [2]:
env = BlackjackEnv()

In [3]:
epsilon=0.1
nA=2

In [4]:
def get_epision_greedy_action_policy(Q,observation):
    
    A = np.ones(nA, dtype=float) * epsilon / nA
    best_action = np.argmax(Q[observation])
    A[best_action] += (1.0 - epsilon)
    
    return A

In [5]:
def mc_control_epsilon_greedy(total_episodes):
  
    returns_sum = defaultdict(float)
    states_count = defaultdict(float)
    
    Q = defaultdict(lambda: np.zeros(env.action_space.n))
    
    for k in range(total_episodes):
        
        episode = generate_episode(Q)
        
        state_actions_in_episode = list(set([(sar[0], sar[1]) for sar in episode]))
        
        for i,sa_pair in enumerate(state_actions_in_episode):
            state, action = sa_pair
    
            G = sum([sar[2] for i,sar in enumerate(episode[i:])])
            
            returns_sum[sa_pair] += G
            states_count[sa_pair] += 1.0
            Q[state][action] = returns_sum[sa_pair] / states_count[sa_pair]
        
        
    return Q

In [6]:
def generate_episode(Q):
    episode = []
    current_state = env.reset()
    
    while(True):
        
        prob_scores = get_epision_greedy_action_policy(Q,current_state)
        action = np.random.choice(np.arange(len(prob_scores)), p=prob_scores) #0 or 1
        
        next_state, reward, done, _ = env.step(action)
        episode.append((current_state, action, reward))
        if done:
            break
        current_state = next_state    
        
    return episode

In [7]:
mc_control_epsilon_greedy(50000)

defaultdict(<function __main__.mc_control_epsilon_greedy.<locals>.<lambda>>,
            {(12, 1, False): array([-0.7826087, -0.5907781]),
             (12, 1, True): array([-0.84615385, -1.        ]),
             (12, 2, False): array([-0.65217391, -0.28678304]),
             (12, 2, True): array([-0.03448276, -0.33333333]),
             (12, 3, False): array([-0.26666667, -0.45      ]),
             (12, 3, True): array([-0.41176471,  0.2       ]),
             (12, 4, False): array([-0.37931034, -0.16795866]),
             (12, 4, True): array([0., 0.]),
             (12, 5, False): array([-0.09511568, -0.44444444]),
             (12, 5, True): array([-0.33333333,  0.31578947]),
             (12, 6, False): array([-0.09137056, -0.2       ]),
             (12, 6, True): array([-1. ,  0.4]),
             (12, 7, False): array([-0.54545455, -0.27777778]),
             (12, 7, True): array([-0.4, -1. ]),
             (12, 8, False): array([-0.65517241, -0.42093023]),
             (12, 

Action value function tells us how good is it to take that action