In [4]:
import matplotlib.pyplot as plt
import numpy as np

from custom_classes import CustomBlackjackEnv

# Fixed start state shared by every experiment in this notebook.
# NOTE(review): the tuple looks like (player sum, dealer's showing card,
# usable ace) -- confirm against CustomBlackjackEnv.
starting_state = (13, 2, True)
env = CustomBlackjackEnv()

# Put the environment into the chosen start state and show the resulting hand.
env.reset(observation=starting_state)
env.render()

Player's hand: ['2', 'Ace'] with sum: 13
Dealer's showing card: 2


In [8]:
# Target policy: all probability on action 0 once the first state component
# reaches 20, otherwise all probability on action 1.
# NOTE(review): presumably state[0] is the player's sum and actions 0/1 are
# stick/hit -- confirm against CustomBlackjackEnv.
target_probability = {
    state: ([1.0, 0.0] if state[0] >= 20 else [0.0, 1.0])
    for state in env.state_space
}

# Behavior policy: both actions are equally likely in every state.
behavior_policy = {state: [0.5, 0.5] for state in env.state_space}

# Roll out the behavior policy from the fixed start state and record the
# undiscounted sum of rewards of each episode.
returns = []
for episode_idx in range(1_000_000):
    state = env.reset(observation=starting_state)
    total_reward = 0
    while True:
        action = np.random.choice([0, 1], p=behavior_policy[state])
        state, reward, done = env.step(action)
        total_reward += reward
        if done:
            break
    returns.append(total_reward)

returns = np.array(returns)
print('Average random policy return:', returns.mean())

Average random policy return: -0.17714315018083826


In [24]:
# Roll out the target policy itself from the fixed start state; the mean of
# these returns is the reference value an off-policy estimate should match.
returns = []
for episode_idx in range(500_000):
    state = env.reset(observation=starting_state)
    total_reward = 0
    while True:
        # Actions are drawn from the target policy's distribution for the
        # state the agent is currently in.
        action = np.random.choice([0, 1], p=target_probability[state])
        state, reward, done = env.step(action)
        total_reward += reward
        if done:
            break
    returns.append(total_reward)

returns = np.array(returns)
print('Average target policy returns:', returns.mean())

Average target policy returns: -0.282934


In [23]:
# Ordinary Importance Sampling
# Estimate the target policy's value from episodes generated by the behavior
# policy. Each episode's plain return G is scaled once by the full-trajectory
# ratio
#     rho = prod_t target_probability[s_t][a_t] / behavior_policy[s_t][a_t]
# and the scaled returns are averaged over all episodes.
returns = []
for _ in range(500_000):
    episode_return = 0
    state = env.reset(observation=starting_state)
    done = False
    importance_weight = 1.0
    while not done:
        action = np.random.choice([0, 1], p=behavior_policy[state])
        next_state, reward, done = env.step(action)
        # The ratio belongs to the state the action was taken in, so it is
        # updated before `state` is advanced to `next_state`.
        importance_weight *= target_probability[state][action] / behavior_policy[state][action]
        state = next_state
        # Accumulate the unweighted return here. Weighting each reward by the
        # partial product instead would give the per-decision estimator, which
        # matches ordinary importance sampling only when the sole non-zero
        # reward arrives on the final step.
        episode_return += reward
    # Scale the whole episode return by the complete product of ratios.
    returns.append(importance_weight * episode_return)

returns = np.array(returns)
print('Average ordinary importance sampling return:', returns.mean())

Average ordinary importance sampling return: -0.276528
