# EX: BLACKJACK

In [1]:
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

In [33]:
class game():
    """Blackjack environment used by the Monte Carlo experiments.

    Cards are drawn independently (with replacement): ranks 1..13, with face
    cards (11-13) worth 10 and an ace (1) worth 11 unless that would bust the
    hand, in which case it is demoted to 1.

    Args:
        HIT: encoding of the "hit" action (default 0).
        STICK: encoding of the "stick" action (default 1).
    """

    def __init__(self, HIT=0, STICK=1):
        self.hit = HIT
        self.stick = STICK
        self.actions = [self.hit, self.stick]
        # Player target policy indexed by the player's sum: stick only on 20
        # or 21, hit otherwise. Filling with `hit` (not zeros) keeps the low
        # sums correct for non-default action encodings; the builtin `int`
        # replaces `np.int`, which NumPy 1.24 removed.
        self.policy_player = np.full(22, self.hit, dtype=int)
        self.policy_player[20:22] = self.stick
        # Dealer policy indexed by the dealer's sum: hit below 17, stick on 17..21.
        self.policy_dealer = np.full(22, self.hit, dtype=int)
        self.policy_dealer[17:22] = self.stick

    # target policy
    def target_policy(self, ace_player, player_sum, dealer_card):
        """Return the fixed target-policy action for `player_sum`.

        The usable-ace flag and the dealer's card are ignored.
        """
        return self.policy_player[player_sum]

    # behavior policy
    def behavior_policy(self, ace_player, player_sum, dealer_card):
        """Return stick or hit with probability 0.5 each, ignoring the state."""
        if np.random.binomial(1, 0.5) == 1:
            return self.stick
        return self.hit

    # get a new card
    def get_card(self):
        """Draw a card value in 1..10 (ranks 11-13 count as 10)."""
        card = np.random.randint(1, 14)
        card = min(card, 10)
        return card

    # get the value of a card (11 for ace).
    def card_value(self, card_id):
        """Return the hand value of a card, counting an ace (1) as 11."""
        if card_id == 1:
            return 11
        return card_id

    def initialize_dealer(self, dealer_card1, dealer_card2):
        """Return ``(dealer_sum, ace_dealer)`` for the dealer's two opening cards.

        `ace_dealer` is True when the dealer holds an ace still counted as 11.
        """
        dealer_sum = self.card_value(dealer_card1) + self.card_value(dealer_card2)
        ace_dealer = 1 in (dealer_card1, dealer_card2)
        if dealer_sum > 21:
            # Only two aces (11 + 11) can exceed 21; count one of them as 1.
            assert dealer_sum == 22
            dealer_sum -= 10
        assert dealer_sum <= 21
        return dealer_sum, ace_dealer

    def _add_card(self, total, usable_ace, card):
        """Add `card` to a hand and return the new ``(total, usable_ace)``.

        Aces counted as 11 are demoted to 1 while the total exceeds 21; the
        returned total may still exceed 21, meaning the hand is bust.
        """
        ace_count = int(usable_ace) + (card == 1)
        total += self.card_value(card)
        while total > 21 and ace_count:
            total -= 10
            ace_count -= 1
        return total, ace_count == 1

    # start playing the game
    def play(self, policy, initial_state=None, initial_action=None):
        """Play one episode of blackjack.

        Args:
            policy: callable ``policy(ace_player, player_sum, dealer_card) -> action``
                used for every player decision not fixed by `initial_action`.
            initial_state: optional ``[usable_ace, player_sum, dealer_card1]``.
                When None, the player's cards are drawn at random until the
                sum is at least 12, and both dealer cards are drawn.
            initial_action: optional action forced for the first decision.

        Returns:
            ``(state, reward, player_moves)``: the initial
            ``[usable_ace, player_sum, dealer_card1]``, the reward (+1 win,
            0 draw, -1 loss) and a list of
            ``[(usable_ace, player_sum, dealer_card1), action]`` entries, one
            per player decision.
        """
        player_sum = 0
        player_moves = []
        ace_player = False

        if initial_state is None:
            # Draw until the sum is at least 12 (below that a hit cannot bust).
            while player_sum < 12:
                card = self.get_card()
                player_sum += self.card_value(card)
                if player_sum > 21:
                    # Sum was 11 and an ace was drawn: count that ace as 1.
                    assert player_sum == 22
                    player_sum -= 10
                else:
                    ace_player |= (card == 1)
            dealer_card1 = self.get_card()
            dealer_card2 = self.get_card()
        else:
            ace_player, player_sum, dealer_card1 = initial_state
            dealer_card2 = self.get_card()
        state = [ace_player, player_sum, dealer_card1]
        dealer_sum, ace_dealer = self.initialize_dealer(dealer_card1, dealer_card2)
        assert player_sum <= 21

        # player's turn
        while True:
            if initial_action is not None:
                action = initial_action
                initial_action = None
            else:
                # Later decisions come from the caller-supplied policy.
                action = policy(ace_player, player_sum, dealer_card1)
            player_moves.append([(ace_player, player_sum, dealer_card1), action])
            if action == self.stick:
                break
            player_sum, ace_player = self._add_card(player_sum, ace_player, self.get_card())
            # player busts: loses
            if player_sum > 21:
                return state, -1, player_moves

        # dealer's turn
        while True:
            if self.policy_dealer[dealer_sum] == self.stick:
                break
            dealer_sum, ace_dealer = self._add_card(dealer_sum, ace_dealer, self.get_card())
            # dealer busts: player wins
            if dealer_sum > 21:
                return state, 1, player_moves

        # choose the winner
        assert player_sum <= 21 and dealer_sum <= 21
        if player_sum > dealer_sum:
            return state, 1, player_moves
        if player_sum == dealer_sum:
            return state, 0, player_moves
        return state, -1, player_moves


In [31]:
class policies():
    """Monte Carlo prediction and control routines for the blackjack `game`."""

    def __init__(self):
        self.g = game()
        # Running return sums and visit counts for Q(s, a), indexed by
        # [player_sum - 12, dealer_card - 1, usable_ace, action]. Counts start
        # at 1 so the averages are defined before any visit.
        self.state_action_values = np.zeros((10, 10, 2, 2))
        self.state_action_pair_count = np.ones((10, 10, 2, 2))

    # Greedy Policy
    def behavior_policy(self, ace, player_sum, dealer_card):
        """Act greedily w.r.t. the current Q estimates, breaking ties at random.

        Returns the index (0 or 1) of a best action. NOTE: this equals the
        game's action encoding only for the default HIT=0, STICK=1.
        """
        ace = int(ace)
        player_sum = player_sum - 12
        dealer_card = dealer_card - 1
        # average returns Q(s, .) for this state
        values_ = (self.state_action_values[player_sum, dealer_card, ace, :]
                   / self.state_action_pair_count[player_sum, dealer_card, ace, :])
        best = np.max(values_)
        return np.random.choice([action_ for action_, value_ in enumerate(values_) if value_ == best])

    # Policy 1: on policy
    def mc_on_policy(self, episodes):
        """Estimate state values of the game's target policy from `episodes` episodes.

        Returns:
            ``(usable_ace_values, no_usable_ace_values)``, two 10x10 arrays
            indexed [player_sum - 12, dealer_card - 1]. Counts start at 1 to
            avoid division by zero, which slightly shrinks estimates toward 0.
        """
        states_ace = np.zeros((10, 10))
        states_ace_count = np.ones((10, 10))
        states_no_ace = np.zeros((10, 10))
        states_no_ace_count = np.ones((10, 10))
        for _ in tqdm(range(episodes)):
            _state, reward, player_moves = self.g.play(self.g.target_policy)
            for (ace, player_sum, dealer_card), _action in player_moves:
                player_sum = player_sum - 12
                dealer_card = dealer_card - 1
                if ace:
                    states_ace_count[player_sum, dealer_card] += 1
                    states_ace[player_sum, dealer_card] += reward
                else:
                    states_no_ace_count[player_sum, dealer_card] += 1
                    states_no_ace[player_sum, dealer_card] += reward
        return states_ace / states_ace_count, states_no_ace / states_no_ace_count

    # Policy 2: Exploring Starts
    def mc_es(self, episodes):
        """Monte Carlo control with exploring starts.

        Each episode starts from a random state and action. The first episode
        follows the game's target policy; later ones act greedily via
        `behavior_policy` on the estimates stored on `self`.

        Returns:
            Q estimates shaped (10, 10, 2, 2), indexed
            [player_sum - 12, dealer_card - 1, usable_ace, action].
        """
        self.state_action_values = np.zeros((10, 10, 2, 2))
        self.state_action_pair_count = np.ones((10, 10, 2, 2))
        for episode in tqdm(range(episodes)):
            initial_state = [bool(np.random.choice([0, 1])),
                             np.random.choice(range(12, 22)),
                             np.random.choice(range(1, 11))]
            initial_action = np.random.choice(self.g.actions)
            if episode:
                current_policy = self.behavior_policy
            else:
                current_policy = self.g.target_policy
            _state, reward, moves = self.g.play(current_policy, initial_state, initial_action)
            for (ace, player_sum, dealer_card), action in moves:
                idx = (player_sum - 12, dealer_card - 1, int(ace), action)
                # update values of state-action pairs
                self.state_action_values[idx] += reward
                self.state_action_pair_count[idx] += 1
        return self.state_action_values / self.state_action_pair_count

    # Policy 3: Off-Policy
    def mc_off_policy(self, episodes):
        """Evaluate the target policy from state [True, 13, 2] off-policy.

        Episodes are generated with the game's random behavior policy and
        reweighted by importance sampling.

        Returns:
            ``(ordinary_sampling, weighted_sampling)``: arrays of length
            `episodes` with the running ordinary and weighted
            importance-sampling estimates after each episode.
        """
        initial_state = [True, 13, 2]
        rhos = []     # importance sampling ratio per episode
        returns = []  # reward per episode
        for _episode in range(episodes):
            _state, reward, player_moves = self.g.play(self.g.behavior_policy,
                                                       initial_state=initial_state)
            # The target policy is deterministic (prob 1 on its action) and the
            # behavior policy picks each action with prob 0.5, so the ratio is
            # 1 / 0.5**k if all k moves agree with the target, else 0.
            numerator = 1.0
            denominator = 1.0
            for (ace, player_sum, dealer_card), action in player_moves:
                if action == self.g.target_policy(ace, player_sum, dealer_card):
                    denominator *= 0.5
                else:
                    numerator = 0.0
                    break
            rhos.append(numerator / denominator)
            returns.append(reward)
        rhos = np.asarray(rhos)
        returns = np.asarray(returns)
        weighted_returns = np.add.accumulate(rhos * returns)
        rhos = np.add.accumulate(rhos)
        ordinary_sampling = weighted_returns / np.arange(1, episodes + 1)
        # Weighted estimate is undefined (0/0) until some ratio is non-zero.
        with np.errstate(divide='ignore', invalid='ignore'):
            weighted_sampling = np.where(rhos != 0, weighted_returns / rhos, 0)
        return ordinary_sampling, weighted_sampling

In [27]:
class figure_plotting():
    """Build and save the figure_5_1, figure_5_2 and figure_5_3 plots."""

    def __init__(self):
        self.p = policies()

    def plot(self, states, titles, axes, name):
        """Draw one heat map per (state, title, axis) triple and save to `name`.

        Args:
            states: iterable of 10x10 arrays indexed
                [player_sum - 12, dealer_card - 1].
            titles: subplot titles, parallel to `states`.
            axes: matplotlib axes to draw into, parallel to `states`.
            name: output path passed to ``plt.savefig``.
        """
        for state, title, axis in zip(states, titles, axes):
            # Flip rows so the highest player sum is drawn at the top.
            fig = sns.heatmap(np.flipud(state), cmap="YlGnBu", ax=axis,
                              xticklabels=range(1, 11),
                              yticklabels=list(reversed(range(12, 22))))
            fig.set_ylabel('player sum', fontsize=30)
            fig.set_xlabel('dealer showing', fontsize=30)
            fig.set_title(title, fontsize=30)
        plt.savefig(name)
        plt.close()

    def figure_5_1(self):
        """Plot on-policy state-value estimates after 10000 and 500000 episodes."""
        states_ace_1, states_no_ace_1 = self.p.mc_on_policy(10000)
        states_ace_2, states_no_ace_2 = self.p.mc_on_policy(500000)
        states = [states_ace_1, states_ace_2, states_no_ace_1, states_no_ace_2]
        titles = ['Usable Ace, 10000 Episodes', 'Usable Ace, 500000 Episodes',
                  'No Usable Ace, 10000 Episodes', 'No Usable Ace, 500000 Episodes']
        _, axes = plt.subplots(2, 2, figsize=(40, 30))
        plt.subplots_adjust(wspace=0.1, hspace=0.2)
        axes = axes.flatten()
        self.plot(states, titles, axes, 'figure_5_1.png')

    def figure_5_2(self):
        """Plot the greedy policy and state values learned by exploring starts."""
        state_action_values = self.p.mc_es(500000)
        state_value_no_ace = np.max(state_action_values[:, :, 0, :], axis=-1)
        state_value_ace = np.max(state_action_values[:, :, 1, :], axis=-1)
        # get the optimal policy (action index with the largest Q)
        action_no_ace = np.argmax(state_action_values[:, :, 0, :], axis=-1)
        action_ace = np.argmax(state_action_values[:, :, 1, :], axis=-1)
        images = [action_ace, state_value_ace, action_no_ace, state_value_no_ace]
        titles = ['Optimal policy with usable Ace', 'Optimal value with usable Ace',
                  'Optimal policy without usable Ace', 'Optimal value without usable Ace']
        _, axes = plt.subplots(2, 2, figsize=(40, 30))
        plt.subplots_adjust(wspace=0.1, hspace=0.2)
        axes = axes.flatten()
        self.plot(images, titles, axes, 'figure_5_2.png')

    def figure_5_3(self):
        """Plot mean squared error of ordinary vs weighted importance sampling."""
        # Reference value of the evaluated start state, taken as given.
        true_value = -0.27726
        episodes = 10000
        runs = 100
        error_ordinary = np.zeros(episodes)
        error_weighted = np.zeros(episodes)
        for _ in tqdm(range(runs)):
            ordinary_sampling_, weighted_sampling_ = self.p.mc_off_policy(episodes)
            # get the squared error
            error_ordinary += np.power(ordinary_sampling_ - true_value, 2)
            error_weighted += np.power(weighted_sampling_ - true_value, 2)
        error_ordinary /= runs
        error_weighted /= runs
        plt.plot(error_ordinary, label='Ordinary Importance Sampling')
        plt.plot(error_weighted, label='Weighted Importance Sampling')
        plt.xlabel('Episodes (log scale)')
        plt.ylabel('Mean square error')
        plt.xscale('log')
        plt.legend()
        # Saved next to the other figures; '../images/' is never created here.
        plt.savefig('figure_5_3.png')
        plt.close()

In [39]:
if __name__ == '__main__':
    # Generate all three figures; the plotting class is `figure_plotting`.
    f = figure_plotting()
    f.figure_5_1()
    f.figure_5_2()
    f.figure_5_3()

100%|██████████| 10000/10000 [00:00<00:00, 73529.33it/s]
100%|██████████| 500000/500000 [00:07<00:00, 65172.15it/s]
100%|██████████| 500000/500000 [00:29<00:00, 16857.75it/s]
100%|██████████| 100/100 [00:19<00:00,  5.14it/s]
