In [1]:
import numpy as np
from collections import defaultdict
from blackjack import BlackjackEnv

In [2]:
# Single module-level environment instance; generate_episode() resets and steps it.
env = BlackjackEnv()

In [3]:
def get_action_policy(state):
    """Fixed threshold policy used to generate episodes.

    Args:
        state: A 3-item sequence unpacked as (player score, dealer score,
            usable-ace flag). Only the player score influences the decision.

    Returns:
        0 when the player score is 20 or more, otherwise 1.
        (Presumably 0 = stick and 1 = hit -- confirm against BlackjackEnv.)

    Raises:
        ValueError / TypeError: if ``state`` cannot be unpacked into exactly
            three items.
    """
    # Unpack all three fields so a malformed state fails loudly here.
    player_score, _dealer_score, _usable_ace = state

    if player_score >= 20:
        return 0
    return 1

In [4]:

def mc_prediction_blackjack(total_episodes, discount_factor=1.0):
    """Estimate the state-value function with first-visit Monte Carlo prediction.

    Episodes are sampled with ``generate_episode()``. For every state that
    occurs in an episode, the return observed from the state's *first*
    occurrence onward is averaged into that state's value estimate.

    Args:
        total_episodes: Number of episodes to sample. Values <= 0 sample
            nothing and yield an empty table.
        discount_factor: Discount applied per time step when accumulating the
            return. The default of 1.0 gives the plain (undiscounted) sum of
            rewards.

    Returns:
        A ``defaultdict(float)`` mapping each visited state to the mean of its
        first-visit returns. Looking up a state that was never visited yields
        0.0 (and, as with any defaultdict, inserts that key).
    """
    returns_sum = defaultdict(float)
    states_count = defaultdict(float)
    V = defaultdict(float)

    for _ in range(total_episodes):
        # Each step is indexed as (state, action, reward).
        episode = generate_episode()

        # Time step at which each state is first visited. The index must come
        # from the episode itself: a set of states has no defined order, so
        # its enumeration index says nothing about when a state occurred.
        first_visit_step = {}
        for t, step in enumerate(episode):
            first_visit_step.setdefault(step[0], t)

        # Single backward pass: G_t = r_t + discount_factor * G_{t+1}.
        returns_from_step = [0.0] * len(episode)
        G = 0.0
        for t in range(len(episode) - 1, -1, -1):
            G = episode[t][2] + discount_factor * G
            returns_from_step[t] = G

        for state, t in first_visit_step.items():
            G = returns_from_step[t]

            # Sample-average update (suited to stationary problems).
            returns_sum[state] += G
            states_count[state] += 1.0
            V[state] = returns_sum[state] / states_count[state]

            # Equivalent incremental form:
            #   V[state] = V[state] + 1 / states_count[state] * (G - V[state])
            # For non-stationary problems a constant step size can be used
            # instead, e.g. alpha = 0.5:
            #   V[state] = V[state] + alpha * (G - V[state])

    return V

In [5]:
def generate_episode():
    """Play one episode on the module-level ``env`` using ``get_action_policy``.

    The environment is reset, then stepped with the policy's action until
    ``env.step`` reports the episode is done.

    Returns:
        A list of ``(state, action, reward)`` tuples, one per step taken, where
        ``state`` is the state the action was chosen in and ``reward`` is the
        reward returned by that step.
    """
    trajectory = []
    state = env.reset()
    finished = False

    while not finished:
        action = get_action_policy(state)  # 0 or 1
        # The fourth value returned by step() is not used.
        following_state, reward, finished, _info = env.step(action)
        trajectory.append((state, action, reward))
        state = following_state

    return trajectory


In [6]:
# Estimate state values from 10,000 sampled episodes; the returned table is this cell's output.
mc_prediction_blackjack(10000)

defaultdict(float,
            {(12, 1, False): -0.7244897959183674,
             (12, 1, True): -0.16666666666666666,
             (12, 2, False): -0.5,
             (12, 2, True): -0.14285714285714285,
             (12, 3, False): -0.6813186813186813,
             (12, 3, True): -1.0,
             (12, 4, False): -0.379746835443038,
             (12, 4, True): -0.625,
             (12, 5, False): -0.6436781609195402,
             (12, 5, True): -1.0,
             (12, 6, False): -0.6161616161616161,
             (12, 6, True): 0.16666666666666666,
             (12, 7, False): -0.5052631578947369,
             (12, 7, True): 1.0,
             (12, 8, False): -0.717391304347826,
             (12, 8, True): 0.0,
             (12, 9, False): -0.38636363636363635,
             (12, 9, True): 0.5,
             (12, 10, False): -0.5421052631578948,
             (12, 10, True): -0.10526315789473684,
             (13, 1, False): -0.7142857142857143,
             (13, 1, True): -0.454545454545

The state-value function tells us how good it is to be in a given state, i.e. the expected return from that state when following the policy.