In [1]:
import gym
from matplotlib import pyplot
import matplotlib.pyplot as plt

In [2]:
from collections import defaultdict
from functools import partial

In [3]:
%matplotlib inline
# Plot styling applied to any figures drawn in this notebook.
plt.style.use('ggplot')

In [4]:
# Create the Blackjack environment that every episode below is sampled from.
# NOTE(review): no seed is set, so sampled episodes (and the value estimates
# derived from them) are presumably not reproducible run-to-run.
env = gym.make('Blackjack-v0')

[2018-10-19 20:53:31,027] Making new env: Blackjack-v0


In [5]:
# Blackjack reward scheme (reward returned by env.step):
# +1 if the player wins the game
# -1 if the player loses the game
#  0 if the game is a draw

In [6]:
def sample_policy(observation, threshold=20):
    '''
    Fixed-threshold Blackjack policy: stick once the player's sum reaches
    ``threshold``, otherwise hit.

    Used as the ``policy`` argument of generate_episode(), which calls it with
    each observation. To evaluate a different threshold without writing a new
    function, pass e.g. ``partial(sample_policy, threshold=18)``.

    Parameters
    ----------
    observation : tuple
        ``(score, dealer_score, usable_ace)`` as returned by the environment.
        Only ``score`` (the player's current sum) affects the decision.
    threshold : int, optional
        Minimum player sum at which to stick. Defaults to 20, the value that
        was previously hard-coded.

    Returns
    -------
    int
        0 to stick, 1 to hit (Gym Blackjack action encoding).
    '''
    # Unpacking all three fields also rejects malformed observations early.
    score, _dealer_score, _usable_ace = observation
    return 0 if score >= threshold else 1

In [10]:
# The policy evaluated in the cells below.
my_policy = sample_policy

In [8]:
def generate_episode(policy, env):
    '''
    Play one complete episode of ``env`` while following ``policy``.

    The environment is reset, then at every step the current observation is
    recorded, ``policy`` picks an action, and the action is applied with
    ``env.step``. The rollout stops as soon as the environment reports that
    the episode is done.

    Parameters
    ----------
    policy : callable
        Maps an observation to an action.
    env : object
        Environment exposing ``reset()`` -> observation and
        ``step(action)`` -> ``(observation, reward, done, info)``.

    Returns
    -------
    tuple of (list, list, list)
        ``(states, actions, rewards)``, aligned by time step: ``rewards[t]`` is
        the reward returned after taking ``actions[t]`` in ``states[t]``.
    '''
    visited, chosen, payoffs = [], [], []
    current = env.reset()
    finished = False
    while not finished:
        visited.append(current)
        move = policy(current)
        chosen.append(move)
        current, payoff, finished, _info = env.step(move)
        payoffs.append(payoff)
    return visited, chosen, payoffs
        
        

In [16]:
# Sanity check: roll out a single episode; the result is (states, actions, rewards).
generate_episode(policy=my_policy, env=env)

([(8, 10, False), (19, 10, True), (17, 10, False)], [1, 1, 1], [0, 0, -1])

In [17]:
def first_visit_mc_prediction(policy, env, n_episodes, gamma=1.0):
    '''
    Estimate the state-value function of ``policy`` with first-visit Monte Carlo.

    ``n_episodes`` episodes are sampled with generate_episode(). Each episode is
    walked backwards to accumulate the (discounted) return G_t. The return from
    the *first* time step at which a state appears in that episode is folded
    into a running mean for that state.

    Parameters
    ----------
    policy : callable
        Maps an observation to an action; passed through to generate_episode().
    env : object
        Environment passed through to generate_episode().
    n_episodes : int
        Number of episodes to sample.
    gamma : float, optional
        Discount factor for future rewards. Defaults to 1.0 (undiscounted),
        which reproduces the previous behaviour exactly.

    Returns
    -------
    collections.defaultdict(float)
        Maps each visited state to its estimated value; unseen states read as 0.0.
    '''
    value_table = defaultdict(float)
    visit_counts = defaultdict(int)

    for _ in range(n_episodes):
        states, _, rewards = generate_episode(policy, env)

        # Time step of each state's first occurrence in this episode. This
        # replaces the per-step `S not in states[:t]` scan (O(T) each, O(T^2)
        # per episode) with an O(1) lookup.
        first_step = {}
        for t, state in enumerate(states):
            first_step.setdefault(state, t)

        returns = 0.0
        # Walk backwards so G_t = rewards[t] + gamma * G_{t+1} builds incrementally.
        for t in reversed(range(len(states))):
            returns = rewards[t] + gamma * returns
            state = states[t]
            if first_step[state] == t:
                visit_counts[state] += 1
                # Incremental mean: V <- V + (G - V) / N
                value_table[state] += (returns - value_table[state]) / visit_counts[state]

    return value_table

In [19]:
# Smoke test with only 10 episodes — far too few for stable value estimates.
first_visit_mc_prediction(policy=my_policy, env=env, n_episodes=10)

defaultdict(float,
            {(21, 2, False): 1.0,
             (17, 2, False): 1.0,
             (19, 8, False): -1.0,
             (18, 8, False): -1.0,
             (16, 8, False): -1.0,
             (16, 10, False): -1.0,
             (20, 5, False): 0.0,
             (21, 10, True): 1.0,
             (20, 10, False): 1.0,
             (16, 4, False): -1.0,
             (6, 4, False): -1.0,
             (21, 2, True): 1.0,
             (17, 2, True): 1.0,
             (21, 9, False): 1.0,
             (11, 9, False): 1.0,
             (15, 9, False): -1.0,
             (5, 9, False): -1.0})