In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import gym
from tqdm import tqdm

# Seed NumPy's global random generator.
# NOTE(review): this presumably does not seed the gym environment or its action
# space, which may keep their own generators -- confirm before relying on
# run-to-run reproducibility.
np.random.seed(42)

# Blackjack environment

Source code of the Gym environment used below: https://github.com/openai/gym/blob/master/gym/envs/toy_text/blackjack.py

<img src="https://www.blackjack.org/wp-content/uploads/2018/12/Blackjack-values.png">

In [None]:
# The created environment is discarded; the call is here for its side effect.
_ = gym.make('Blackjack-v0')  # makes sure gym.envs.toy_text.BlackjackEnv is imported before it is subclassed


class BlackJack(gym.envs.toy_text.BlackjackEnv):
    """
    Blackjack environment whose observations consist of plain integers.

    Observation: a 3-tuple
    (sum of the player's hand, dealer's showing card, usable ace).
    Actions: 0 = stick, 1 = hit.

    The player can request additional cards (hit) until they decide to stop
    (stick) or exceed 21 (bust).

    The dealer draws cards until their sum is 17 or greater.

    If neither player nor dealer busts, the outcome (win, lose, draw) is
    decided by whose sum is closer to 21. The reward for winning is +1,
    drawing is 0, and losing is -1.
    """

    def _get_obs(self):
        """Return the parent's observation with the usable-ace flag cast to ``int``.

        The cast lets every component of the observation be used directly as
        an array index.
        """
        player_sum, dealer_showing, ace_flag = super()._get_obs()
        return player_sum, dealer_showing, int(ace_flag)

# Agent class

In [None]:
class Agent:
    """Tabular Monte Carlo control agent for the blackjack environment.

    The agent keeps a table of action-value estimates (``q_values``) indexed by
    ``(sum_hand, dealer_card, usable_ace, action)`` and a deterministic policy
    table (``policy``) indexed by ``(sum_hand, dealer_card, usable_ace)``.
    After each episode the estimates of the visited state-action pairs are
    updated and the policy is made greedy for the visited states.
    """

    def __init__(self, observation_space, action_space, discount=1., alpha=0.01):
        """
        Args:
            observation_space: iterable of discrete spaces; the ``n`` of each
                element is the size of one observation dimension.
            action_space: discrete space whose ``n`` is the number of actions.
            discount (float): discount factor applied to future rewards.
            alpha (float): constant step size of the action-value update.
        """
        state_shape = [space.n for space in observation_space]
        self.q_values = np.zeros(state_shape + [action_space.n])
        self.policy = self.initialize_policy(observation_space)
        self.discount = discount
        self.alpha = alpha

    def initialize_policy(self, observation_space):
        """Initial policy is to hit whenever the sum in the hand is 19 or less,
        regardless of what the dealer's showing card is.

        Returns:
            Integer array with one entry per observation (0=stick, 1=hit).
        """
        policy = np.zeros([space.n for space in observation_space], dtype='int16')
        policy[:20, :, :] = 1
        return policy

    @staticmethod
    def _as_index(obs, action=None):
        """Build an index tuple of plain ints from an observation (and action).

        NumPy interprets a ``bool`` inside an index tuple as a mask, not as
        0/1, so a boolean usable-ace flag would silently address the wrong
        table entries. Casting every component to ``int`` avoids that.
        """
        index = tuple(int(component) for component in obs)
        if action is not None:
            index += (int(action),)
        return index

    def get_action(self, sum_hand, dealer_card, usable_ace):
        """Return the action the current policy prescribes for the given state.

        Args:
            sum_hand: sum of the player's hand.
            dealer_card: the dealer's showing card.
            usable_ace: whether the player holds a usable ace (0/1 or bool).
        """
        return self.policy[self._as_index((sum_hand, dealer_card, usable_ace))]

    def estimate_q_values_and_update_policy(self, episode_history):
        """Update the action values from one episode and improve the policy.

        Args:
            episode_history: chronological list of
                ``(observation, action, reward)`` tuples, where ``observation``
                is ('sum of players hand', 'dealer showing card', 'usable ace').

        For every state-action pair in the episode the discounted return that
        followed its first occurrence is computed, the estimate is moved
        towards that return with step size ``alpha``, and the policy for that
        state is set to the action with the highest estimate.
        """
        # Replay the episode backwards, accumulating the discounted return.
        # Earlier occurrences of a pair overwrite later ones in the dict, so
        # the stored value is the return following the first visit.
        G = 0
        returns = dict()
        for obs, a, r in reversed(episode_history):
            G = self.discount * G + r
            returns[self._as_index(obs, a)] = G

        # Update the estimates, then make the policy greedy for each visited state.
        for k, G in returns.items():
            state = k[:-1]
            self.q_values[k] = self.q_values[k] + self.alpha * (G - self.q_values[k])
            self.policy[state] = np.argmax(self.q_values[state])

    @staticmethod
    def _two_panel_figure():
        """Create the shared layout: a full-height left axes and a right axes in the top row."""
        fig = plt.figure(figsize=(10, 8), constrained_layout=True)
        spec = fig.add_gridspec(ncols=2, nrows=2, width_ratios=[1, 1], height_ratios=[5, 3])
        ax_left = fig.add_subplot(spec[:, :-1])
        ax_right = fig.add_subplot(spec[:-1, 1:])
        return fig, ax_left, ax_right

    @staticmethod
    def _draw_panel(ax, table, title, first_sum, **heatmap_kwargs):
        """Draw one heatmap (rows: hand sums from ``first_sum``, columns: dealer card A..10)."""
        sns.heatmap(table, linewidths=1, linecolor='grey', ax=ax, **heatmap_kwargs)
        ax.invert_yaxis()  # smallest hand sum at the bottom
        ax.set_title(title, fontsize=18)
        ax.set_xlabel("Dealer Showing", fontsize=16)
        ax.set_xticklabels(['A'] + list(range(2, 11)), fontsize=14)
        ax.set_ylabel("Sum Hand", fontsize=16)
        ax.set_yticklabels(range(first_sum, first_sum + table.shape[0]), fontsize=14)

    def render_policy(self):
        """Plot the current policy as two heatmaps.

        The left panel shows hand sums 4-21 without a usable ace, the right
        panel hand sums 12-21 with a usable ace.
        """
        fig, ax1, ax2 = self._two_panel_figure()

        self._draw_panel(ax1, self.policy[4:22, 1:11, 0], "No Usable Ace", 4, cbar=False)
        ax1.text(5, 3, "HIT", ha='center', va='center', fontsize=56)
        ax1.text(5, 16.5, "STICK", ha='center', va='center', fontsize=56, color='white')

        self._draw_panel(ax2, self.policy[12:22, 1:11, 1], "Usable Ace", 12, cbar=False)
        ax2.text(5, 2.5, "HIT", ha='center', va='center', fontsize=56)
        ax2.text(5, 8.5, "STICK", ha='center', va='center', fontsize=56, color='white')

    def render_q_values(self, hit_or_stick):
        """Plot the action-value estimates of one action as two heatmaps.

        Args:
            hit_or_stick (int): 0=stick, 1=hit

        For the hit action the panel without a usable ace leaves out the row
        for a hand sum of 21.
        """
        # A bool would be treated as a mask when used as an array index below.
        hit_or_stick = int(hit_or_stick)
        action = 'HIT' if hit_or_stick else 'STICK'
        fig, ax1, ax2 = self._two_panel_figure()
        fig.text(0.5, 1.05, "Action values for {}".format(action), fontsize=20, ha='center', va='center')

        value_scale = dict(vmin=-1, vmax=1, cmap='coolwarm')
        self._draw_panel(ax1, self.q_values[4:22-hit_or_stick, 1:11, 0, hit_or_stick],
                         "No Usable Ace", 4, cbar=True, **value_scale)
        self._draw_panel(ax2, self.q_values[12:22, 1:11, 1, hit_or_stick],
                         "Usable Ace", 12, cbar=False, **value_scale)

# Function to run one episode

In [None]:
def episode():
    """Play one episode and record it.

    Uses the notebook-level ``env`` and ``agent``. The first action is sampled
    from ``env.action_space`` (exploring starts); every later action comes
    from ``agent.get_action``.

    Returns:
        list of ``(observation, action, reward)`` tuples in chronological
        order, where ``observation`` is the one in which ``action`` was taken
        and ``reward`` is what ``env.step`` returned for it.
    """
    observation = env.reset()
    action = env.action_space.sample()  # initial action is random (exploring starts)
    history = []
    while True:
        next_observation, reward, terminal, _info = env.step(action)
        history.append((observation, action, reward))
        if terminal:
            # No further action is needed, so the policy is not consulted
            # for the terminal observation.
            return history
        observation = next_observation
        action = agent.get_action(*observation)  # next action based on policy

# Setup environment and agent and run

In [None]:
# Seed the environment and its action space explicitly: they draw from their
# own random generators, which np.random.seed() does not control, so without
# this the dealt cards and the random first action differ between runs.
SEED = 42

env = BlackJack()
env.seed(SEED)
env.action_space.seed(SEED)
agent = Agent(env.observation_space, env.action_space, alpha=0.001)

In [None]:
# Number of episodes to play; lower it for a quick trial run.
N_EPISODES = 2500000

# After every episode the agent updates its action values and its policy.
for i in tqdm(range(N_EPISODES)):
    history = episode()
    agent.estimate_q_values_and_update_policy(history)

# Visualize policy

In [None]:
# Heatmaps of the agent's current policy, without and with a usable ace.
agent.render_policy()

# Visualize action values (q-values)

In [None]:
# Action-value estimates for sticking (action 0).
agent.render_q_values(0)

In [None]:
# Action-value estimates for hitting (action 1).
agent.render_q_values(1)