In [1]:
import random
import torch
import torch.nn as nn
from copy import deepcopy
from collections import deque
from blackjack import Game
from utils import get_game_state, get_game_state_sparse

In [2]:
class Player(nn.Module):
    """
    Q-network: a three-layer MLP mapping a game-state vector to two raw
    action values (index 0 = hit, index 1 = stay).

    The attribute names (stack_1 / stack_2 / stack_3) and the layer indices
    inside each Sequential determine the state_dict keys, so they must stay
    stable for saved checkpoints to keep loading.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        half_width = hidden_size // 2

        # input_size -> hidden_size, followed by ReLU
        self.stack_1 = nn.Sequential(nn.Linear(input_size, hidden_size), nn.ReLU())
        # hidden_size -> hidden_size // 2, followed by ReLU
        self.stack_2 = nn.Sequential(nn.Linear(hidden_size, half_width), nn.ReLU())
        # hidden_size // 2 -> 2 unactivated outputs (one Q-value per action)
        self.stack_3 = nn.Linear(half_width, 2)

    def forward(self, x):
        for layer in (self.stack_1, self.stack_2, self.stack_3):
            x = layer(x)
        return x

In [3]:
# Reproducibility: seed Python's `random` (drives epsilon-greedy exploration and
# replay-buffer sampling) and torch (drives weight initialisation).
# NOTE(review): if `blackjack.Game` shuffles with another RNG, seed that as well.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
learning_rate = 5e-4  # initial Adam learning rate; the training loop halves it periodically
# input_size=6 presumably matches the length of get_game_state()'s output — confirm in utils
model = Player(input_size=6, hidden_size=256).to(device)
criterion = nn.MSELoss()  # regression loss between predicted and target Q-values
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [4]:
# Target network: a copy of `model` used only to compute bootstrap targets, for
# stability during learning. In this notebook it is refreshed solely through
# load_state_dict() in the training loop, so it never needs gradients. Freezing it
# stops loss.backward() from accumulating unused .grad tensors on its parameters.
target_network = deepcopy(model)
target_network.requires_grad_(False)

memory = deque(maxlen=25_000)  # experience-replay buffer; oldest transitions are evicted once full
batch_size = 128  # transitions sampled per gradient step
gamma = 0.99  # discount factor in the Bellman target
epsilon = 0.8  # initial exploration probability for epsilon-greedy; decayed during training

In [5]:
def choose_action(e, state) -> int:
    """
    Choose an action according to the epsilon-greedy policy. 0 for hit, 1 for stay

    With probability `e` a uniformly random action is returned; otherwise the
    action with the largest value predicted by `model` is returned.
    ---------------------------
    Parameters:
        e (float): the exploration chance
        state (torch.tensor): the current state of the game
    Returns:
        int: the chosen action index
    """
    if random.random() < e:
        return random.choice([0, 1])

    # Action selection is pure inference. Skipping autograd avoids building a
    # graph on every step of every training episode.
    with torch.no_grad():
        output = model(state.to(device))
    return torch.argmax(output).item()


def train():
    """
    Run one DQN gradient step on a random minibatch from the replay buffer.

    Does nothing until `memory` holds at least `batch_size` transitions. Each
    transition is a tuple (state, action, reward, done, next_state), where
    next_state is None for terminal transitions.

    Target: reward + gamma * max_a' target_network(next_state)[a'] for
    non-terminal transitions, and just reward for terminal ones. The target is
    computed without gradient tracking, so only `model` receives gradients.
    """
    if len(memory) < batch_size:
        return

    batch = random.sample(memory, batch_size)
    states, actions, rewards, dones, next_states = zip(*batch)

    states = torch.stack(states).to(device)
    actions = torch.tensor(actions, dtype=torch.int64, device=device).reshape(batch_size)
    rewards = torch.tensor(rewards, dtype=torch.float32, device=device).reshape(batch_size)
    dones = torch.tensor(dones, dtype=torch.float32, device=device).reshape(batch_size)

    # Q(s, a) for the actions that were actually taken.
    q_values = model(states)
    predicted_q_values = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)

    with torch.no_grad():  # the Bellman target is a constant w.r.t. the loss
        next_q_values = torch.zeros(batch_size, device=device)

        # Build the mask from the same predicate used to filter the list, so the
        # mask and the stacked next states always line up element-for-element.
        non_terminal_mask = torch.tensor(
            [s is not None for s in next_states], dtype=torch.bool, device=device
        )
        non_terminal_next_states = [s for s in next_states if s is not None]

        if non_terminal_next_states:
            next_batch = torch.stack(non_terminal_next_states).to(device)
            next_q_values[non_terminal_mask] = target_network(next_batch).max(dim=1).values

        target_q_values = rewards + gamma * next_q_values * (1 - dones)

    # nn.MSELoss expects (input, target).
    loss = criterion(predicted_q_values, target_q_values)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Training loop: each episode is one fresh game played epsilon-greedily. Every
# transition goes into the replay buffer, and one train() step runs per episode.
NUM_EPISODES = 500_000
EPSILON_DECAY_EVERY = 100  # episodes between epsilon decays
LOG_EVERY = 1_000          # episodes between progress logs (also the target-network sync interval)
LR_HALVE_EVERY = 50_000    # episodes between learning-rate halvings

model.train()  # the evaluation cell switches to eval mode; restore train mode on re-runs
win = 0
recent_wins = 0
for episode in range(1, NUM_EPISODES + 1):
    if episode % EPSILON_DECAY_EVERY == 0:
        epsilon = max(epsilon * 0.99, 0.05)  # decay the exploration chance, floored at 0.05

    game = Game(five_card_charlie=True)  # each episode starts a new game; the deck is NOT re-used.
    while not game.is_over:
        state = get_game_state(game)
        action = choose_action(epsilon, state)
        game.take_action("H" if action == 0 else "S")  # 0 = hit, 1 = stay

        reward = game.score
        done = int(game.is_over)
        next_state = None if done else get_game_state(game)

        memory.append((state, action, reward, done, next_state))

    if game.score == 1:
        win += 1
        recent_wins += 1

    train()

    if episode % LOG_EVERY == 0:
        target_network.load_state_dict(model.state_dict())
        print(f"Episode: {episode}, Lifetime winrate: {win / episode:.4f}, Recent Winrate: {recent_wins / LOG_EVERY:.4f}")
        recent_wins = 0

    if episode % LR_HALVE_EVERY == 0:
        # Lower the LR in place. Constructing a new Adam here (as before) would
        # discard its accumulated first/second-moment estimates.
        learning_rate /= 2
        for param_group in optimizer.param_groups:
            param_group["lr"] = learning_rate

Episode: 1000, Lifetime winrate: 0.3040, Recent Winrate: 0.3040
Episode: 2000, Lifetime winrate: 0.3130, Recent Winrate: 0.3220
Episode: 3000, Lifetime winrate: 0.3343, Recent Winrate: 0.3770
Episode: 4000, Lifetime winrate: 0.3350, Recent Winrate: 0.3370
Episode: 5000, Lifetime winrate: 0.3412, Recent Winrate: 0.3660
Episode: 6000, Lifetime winrate: 0.3387, Recent Winrate: 0.3260
Episode: 7000, Lifetime winrate: 0.3450, Recent Winrate: 0.3830
Episode: 8000, Lifetime winrate: 0.3459, Recent Winrate: 0.3520
Episode: 9000, Lifetime winrate: 0.3457, Recent Winrate: 0.3440
Episode: 10000, Lifetime winrate: 0.3467, Recent Winrate: 0.3560
Episode: 11000, Lifetime winrate: 0.3488, Recent Winrate: 0.3700
Episode: 12000, Lifetime winrate: 0.3494, Recent Winrate: 0.3560
Episode: 13000, Lifetime winrate: 0.3518, Recent Winrate: 0.3810
Episode: 14000, Lifetime winrate: 0.3510, Recent Winrate: 0.3400
Episode: 15000, Lifetime winrate: 0.3525, Recent Winrate: 0.3730
Episode: 16000, Lifetime winrate: 

In [6]:
# Persist only the learned weights (the state_dict), not the pickled module object.
torch.save(obj=model.state_dict(), f="bj_player_dense_big.pth")

In [4]:
# Restore weights written by the torch.save cell. map_location lets a checkpoint
# saved on a CUDA machine load on a CPU-only kernel. weights_only=True restricts
# unpickling to tensors and plain containers.
model.load_state_dict(
    torch.load("bj_player_dense_big.pth", map_location=device, weights_only=True)
)

<All keys matched successfully>

In [7]:
# Evaluation: play `num_games` fresh games with the purely greedy policy (no
# exploration) and report the percentage won (game.score == 1).
model.eval()
wins = 0
num_games = 10_000
with torch.no_grad():  # inference only; no autograd graph needed
    for _ in range(num_games):
        game = Game(five_card_charlie=True)
        while not game.is_over:
            state = get_game_state(game).to(device)
            action = torch.argmax(model(state)).item()  # 0 = hit, 1 = stay
            game.take_action("H" if action == 0 else "S")

        if game.score == 1:
            wins += 1

print(f"Win rate: {100 * wins / num_games:.4f}%")

Win rate: 42.4600%
