## BlackJack: Off-policy n-step Sarsa algorithm for state-action value update
---
* dearler policy: HIT 17

In [23]:
import numpy as np
import pickle

In [24]:
class BlackJack_off_sarsa(object):
    
    def __init__(self, lr=0.1, exp_rate=0.3):
        self.player_Q_Values = {}  # key: [(player_value, show_card, usable_ace)][action] -> [state][action]
        # initialise Q values | (12-21) x (1-10) x (True, False) x (1, 0) 400 in total
        # when player's sum is in range (2,11) there is no risk to bust, automatically pick
        for i in range(12, 22):
            for j in range(1, 12):
                for k in [True, False]:
                    self.player_Q_Values[(i, j, k)] = {}
                    for a in [1, 0]:    # two actions: draw card or don't draw
                        self.player_Q_Values[(i, j, k)][a] = 0
        
        self.player_state_action = []
        self.rewards = []
        self.state = (0, 0, False)  # initial state
        self.actions = [1, 0]  # 1: HIT -> draw  0: STAND -> don't draw
        self.end = False
        self.lr = lr
        self.exp_rate = exp_rate
    
    # give card
    @staticmethod
    def giveCard():
        # 1 stands for ace
        c_list = list(range(1, 11)) + [10, 10, 10]
        return np.random.choice(c_list)
    
    
    def dealerPolicy(self, current_value, usable_ace, is_end):
        """_summary_

        Args:
            current_value (_type_): 
            usable_ace (_type_):
            is_end (bool): 

        Returns:
            state: the next state of dealer
        """
        if current_value > 21:
            if usable_ace:
                current_value -= 10
                usable_ace = False
            else:
                return current_value, usable_ace, True  # having ace and bust -> game ends
        # HIT17, curr_val >= 17 then we stick
        if current_value >= 17:
            return current_value, usable_ace, True
        else:
            card = self.giveCard()
            if card == 1:
                if current_value <= 10:
                    return current_value+11, True, False
                return current_value+1, usable_ace, False
            else:
                return current_value+card, usable_ace, False
            
    def chooseAction(self):
        """_summary_

        Returns:
            0 or 1: the action to perform at this state
        """
        # if current value <= 11, always hit
        current_value = self.state[0]
        if current_value <= 11:
            return 1
        
        # epsilon greedy
        if np.random.uniform(0, 1) <= self.exp_rate:
            action = np.random.choice(self.actions)
        else:
            # greedy action
            v = -999
            action = 0
            for a in self.player_Q_Values[self.state]:
                if self.player_Q_Values[self.state][a] > v:
                    action = a
                    v = self.player_Q_Values[self.state][a]
        return action
            
    # one can only has 1 usable ace 
    # return next state
    def playerNxtState(self, action):
        """_summary_

        Args:
            action (0 or 1): the action of the current state

        get the next state of player
        """
        current_value = self.state[0]
        show_card = self.state[1]
        usable_ace = self.state[2]
        
        if current_value > 21:
            if usable_ace:
                current_value -= 10
                usable_ace = False
            else:
                # should not reach here
                self.end = True
                self.state = (current_value, show_card, usable_ace)
                return
        if action:
            card = self.giveCard()
            if card == 1:
                if current_value <= 10:
                    current_value += 11
                    usable_ace = True
                else:
                    current_value += 1
            else:
                current_value += card
        else:
            # action stand
            self.end = True
        
        if current_value > 21:
            self.end = True
        self.state = (current_value, show_card, usable_ace)
    
    def target_policy(self, state, action):
        """ Returns 1 if the action is the greedy action, otherwise returns 0. """
        best_action = max(self.player_Q_Values[state], key=self.player_Q_Values[state].get)
        return 1 if action == best_action else 0

    def behavior_policy(self, state, action):
        """ Returns the probability of taking an action under the behavior policy, typically ε-greedy. """
        best_action = max(self.player_Q_Values[state], key=self.player_Q_Values[state].get)
        num_actions = len(self.player_Q_Values[state])
        epsilon = 0.1
        return 1 - epsilon + epsilon / num_actions if action == best_action else epsilon / num_actions

            
    def _giveCredit(self,player_value, dealer_value, is_end = True, n=3, gamma=0.9):
        # Initial reward calculation based on game outcome
        reward = 0
        if is_end:
            if player_value > 21:
                if dealer_value > 21:
                    reward = 0  # draw
                else:
                    reward = -1
            else:
                if dealer_value > 21:
                    reward = 1
                elif player_value < dealer_value:
                    reward = -1
                elif player_value > dealer_value:
                    reward = 1
                else:
                    reward = 0  # draw
        self.rewards.append(reward)  # Append the final reward to the rewards list
        T = len(self.player_state_action) - 1  # Time of the last state-action pair
        G = 0  # Initialize G for the n-step accumulated return
        W = 1  # Importance sampling ratio
        
        for timestep in range(T+1):
            tau = timestep - n + 1
            if tau >= 0:
                G = 0
                for i in range(tau, min(tau + n, T + 1)):
                    G += gamma ** (i - tau) * self.rewards[i]
                if tau + n <= T:
                    s = self.player_state_action[tau + n][0]
                    a = self.player_state_action[tau + n][1]
                    G += gamma ** n * self.player_Q_Values[s][a]
                    
                    
                state, action = self.player_state_action[tau][:2]  # State and action at time tau
                
                # Update the importance sampling ratio W
                pi = self.target_policy(state, action)  # Probability under target policy
                b = self.behavior_policy(state, action)  # Probability under behavior policy
                if b == 0:
                    break  # If b is 0, break to avoid division by zero
                W *= pi / b  # Update W with the ratio of the probabilities
                
                # update Q value
                self.player_Q_Values[state][action] += self.lr * W * (G - self.player_Q_Values[state][action])
            
            
    def reset(self):
        self.player_state_action = []
        self.state = (0, 0, False)  # initial state
        self.end = False
    
    def play(self, rounds=1000):
        wins = np.zeros(2)      # the first entry is player win#, the second is dealer win#
        for i in range(rounds):
            if i % 1000 == 0:
                print("round", i)
            # hit 2 cards each
            dealer_value, player_value = 0, 0
            d_usable_ace, p_usable_ace = False, False
            show_card = 0

            # give dealer 2 cards and show 1
            # card 1
            card = self.giveCard()
            if card == 1:
                card = 11
                d_usable_ace = True
            
            show_card = card
            dealer_value += card
            self.state = (player_value, show_card, p_usable_ace)
            # card 2
            card = self.giveCard()
            if card == 1:
                if dealer_value == 11:
                    dealer_value += 1
                else:
                    dealer_value += 11
                    d_usable_ace = True
            else:
                dealer_value += card
            
            # player's turn
            # player gets 2 cards
            card = self.giveCard()
            if card == 1:
                player_value += 11
                p_usable_ace = True
            else:
                player_value += card
            self.state = (player_value, show_card, p_usable_ace)
            
            card = self.giveCard()
            if card == 1:
                if player_value == 11:
                    player_value += 1
                else:
                    player_value += 11
                    p_usable_ace = True
            else:
                player_value += card
            self.state = (player_value, show_card, p_usable_ace)
            
            # judge winner after 2 cards
            if player_value == 21 or dealer_value == 21:
                if player_value == 21:
                    wins[0] += 1
                if dealer_value == 21:
                    wins[1] += 1
                # game end
                print("reach 21 in 2 cards: player value {} | dealer value {}".format(player_value, dealer_value))
            else:   
                while True:
                    action = self.chooseAction()
                    # print("current value {}, action {}".format(self.state[0], action))
                    if self.state[0] >= 12:
                        self.player_state_action.append([self.state, action])
                    # update next state
                    self.playerNxtState(action)
                    if self.end:
                        break    

                # dealer's turn
                is_end = False
                while not is_end:
                    dealer_value, d_usable_ace, is_end = self.dealerPolicy(dealer_value, d_usable_ace, is_end)

                # judge winner
                # give reward and update Q value
                player_value = self.state[0]
                print("player value {} | dealer value {}".format(player_value, dealer_value))
                self._giveCredit(player_value, dealer_value)
            self.reset()
        print("player wins: ",wins[0], ", dealer wins: ", wins[1])
            
    def savePolicy(self, file="policy_sarsa"):
        fw = open(file, 'wb')
        pickle.dump(self.player_Q_Values, fw)
        fw.close()

    def loadPolicy(self, file="policy_sarsa"):
        fr = open(file,'rb')
        self.player_Q_Values = pickle.load(fr)
        fr.close()
        
    def playWithDealer(self, rounds=1000):
        """_summary_

        Args:
            rounds (int, optional): Defaults to 1000.

        Returns:
            results: the number of (wins, draws, losses)
        """
        self.reset()
        self.loadPolicy()
        self.exp_rate = 0   # set exploration to 0, meaning we only exploit the learned policy
        
        result = np.zeros(3)  # player [win, draw, lose]
        for _ in range(rounds):
            # hit 2 cards each
            dealer_value, player_value = 0, 0
            d_usable_ace, p_usable_ace = False, False
            show_card = 0

            # give dealer 2 cards and show 1
            # card 1
            card = self.giveCard()
            if card == 1:
                card = 11
                d_usable_ace = True
            
            show_card = card
            dealer_value += card
            self.state = (player_value, show_card, p_usable_ace)
            # card 2
            card = self.giveCard()
            if card == 1:
                if dealer_value == 11:
                    dealer_value += 1
                else:
                    dealer_value += 11
                    d_usable_ace = True
            else:
                dealer_value += card
            
            # player's turn
            # player gets 2 cards
            card = self.giveCard()
            if card == 1:
                player_value += 11
                p_usable_ace = True
            else:
                player_value += card
            self.state = (player_value, show_card, p_usable_ace)
            
            card = self.giveCard()
            if card == 1:
                if player_value == 11:
                    player_value += 1
                else:
                    player_value += 11
                    p_usable_ace = True
            else:
                player_value += card
            self.state = (player_value, show_card, p_usable_ace)
            
            # judge winner after 2 cards
            if player_value == 21 or dealer_value == 21:
                if player_value == dealer_value:
                    result[1] += 1
                elif player_value > dealer_value:
                    result[0] += 1
                else:
                    result[2] += 1
            else:
                # player's turn
                while True:
                    action = self.chooseAction()
                    # update next state
                    self.playerNxtState(action)
                    if self.end:
                        break    

                # dealer's turn
                is_end = False
                while not is_end:
                    dealer_value, d_usable_ace, is_end = self.dealerPolicy(dealer_value, d_usable_ace, is_end)

                # judge
                player_value = self.state[0]
                # print("player value {} | dealer value {}".format(player_value, dealer_value))
                if player_value > 21:
                    if dealer_value > 21:
                        # draw
                        result[1] += 1
                    else:
                        result[2] += 1
                else:
                    if dealer_value > 21:
                        result[0] += 1
                    else:
                        if player_value < dealer_value:
                            result[2] += 1
                        elif player_value > dealer_value:
                            result[0] += 1
                        else:
                            # draw
                            result[1] += 1
            self.reset()
        return result

In [25]:
b = BlackJack_off_sarsa()
b.play(1000)

round 0
player value 23 | dealer value 20
player value 30 | dealer value 18
player value 22 | dealer value 18
player value 24 | dealer value 25
player value 27 | dealer value 19
player value 25 | dealer value 20
player value 19 | dealer value 20
player value 28 | dealer value 23
player value 24 | dealer value 20
player value 29 | dealer value 18
player value 26 | dealer value 25
player value 29 | dealer value 24
player value 21 | dealer value 18
reach 21 in 2 cards: player value 18 | dealer value 21
player value 19 | dealer value 18
player value 25 | dealer value 20
player value 27 | dealer value 22
player value 20 | dealer value 20
player value 22 | dealer value 18
player value 23 | dealer value 22
player value 27 | dealer value 20
player value 29 | dealer value 17
player value 15 | dealer value 17
player value 12 | dealer value 20
player value 30 | dealer value 21
reach 21 in 2 cards: player value 20 | dealer value 21
player value 13 | dealer value 23
player value 25 | dealer value 2

In [26]:
b.savePolicy()

In [27]:
a = b.playWithDealer(rounds=1000)
print("player win:{}, draw:{}, loose:{}".format(a[0],a[1],a[2]))

player win:125.0, draw:179.0, loose:696.0
