In [1]:
#%pip install gymnasium
#%pip install torch

import gymnasium as gym
import numpy as np
import random

import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
import torch

print("CUDA available:", torch.cuda.is_available())

CUDA available: False


In [2]:
env = gym.make("Blackjack-v1", natural=True, sab=False)
num_episodes = 200_000
gamma = 0.99 #współczynnik znaczenia przyszłej nagrody
actor_lr = 1e-4 #tempo uczenia aktorzyny
critic_lr = 5e-4 #tempo uczenia niekrytego krytyka
entropy_beta = 0.01
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
def state_to_tensor(state):

    player_sum, dealer_card, usable_ace = state
    x = np.array([
        player_sum / 32.0,
        dealer_card / 10.0,
        float(usable_ace)
    ], dtype=np.float32)

    return torch.tensor(x, dtype=torch.float32, device=device)


In [4]:
class Actor(nn.Module):
    def __init__(self, input_dim=3, hidden_dim=128, output_dim=2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
            nn.Softmax(dim=-1)
        )

    def forward(self, x):
        if x.dim() == 1:
            x = x.unsqueeze(0)
        return self.net(x)


class Critic(nn.Module):
    def __init__(self, input_dim=3, hidden_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, x):
        if x.dim() == 1:
            x = x.unsqueeze(0)
        return self.net(x).view(-1)



In [5]:
actor = Actor().to(device)
critic = Critic().to(device)

actor_optimizer = optim.Adam(actor.parameters(), lr=actor_lr)
critic_optimizer = optim.Adam(critic.parameters(), lr=critic_lr)

In [6]:
def a2c_select_action(state):
    state_tensor = state_to_tensor(state)
    probs = actor(state_tensor)
    dist = Categorical(probs=probs)
    action = dist.sample()

    value = critic(state_tensor)

    return action.item(), dist.log_prob(action), value

In [7]:
def a2c_greedy_action(state):
    with torch.no_grad():
        state_tensor = state_to_tensor(state)
        probs = actor(state_tensor)
        action = torch.argmax(probs).item()
    return action

In [8]:
def basic_strategy(state):
    player_sum, dealer_card, usable_ace = state
    if player_sum >= 17:
        return 0 
    else:
        return 1

In [9]:
def evaluate_policy(policy_fn, n_games=100_000):
    wins = 0
    losses = 0
    draws = 0

    for _ in range(n_games):
        state, _ = env.reset()
        done = False

        while not done:
            action = policy_fn(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            state = next_state

        if reward > 0:
            wins += 1
        elif reward < 0:
            losses += 1
        else:
            draws += 1

    return wins, losses, draws

In [10]:
episode_rewards_history = []

for episode in range(1, num_episodes + 1):
    state, _ = env.reset()
    done = False
    ep_reward = 0.0

    while not done:
        s_tensor = state_to_tensor(state)

        probs = actor(s_tensor)             
        dist = Categorical(probs)
        action = dist.sample()               
        log_prob = dist.log_prob(action)     
        value = critic(s_tensor)             

        next_state, reward, terminated, truncated, _ = env.step(action.item())
        done = terminated or truncated
        ep_reward += reward

        if done:
            next_value = torch.tensor(0.0, device=device)
        else:
            ns_tensor = state_to_tensor(next_state)
            with torch.no_grad():
                next_value = critic(ns_tensor)

        td_target = reward + gamma * next_value
        advantage = td_target - value

        actor_loss = -(log_prob * advantage.detach())   
        critic_loss = advantage.pow(2)                  

        loss = actor_loss + critic_loss

        actor_optimizer.zero_grad()
        critic_optimizer.zero_grad()
        loss.backward()
        actor_optimizer.step()
        critic_optimizer.step()

        state = next_state

    episode_rewards_history.append(ep_reward)

    if episode % 10_000 == 0:
        avg_reward = np.mean(episode_rewards_history[-10_000:])
        print(f"Episode {episode}, średnia nagroda z ostatnich 10k epizodów: {avg_reward:.3f}")


Episode 10000, średnia nagroda z ostatnich 10k epizodów: -0.177
Episode 20000, średnia nagroda z ostatnich 10k epizodów: -0.155
Episode 30000, średnia nagroda z ostatnich 10k epizodów: -0.154
Episode 40000, średnia nagroda z ostatnich 10k epizodów: -0.167
Episode 50000, średnia nagroda z ostatnich 10k epizodów: -0.168
Episode 60000, średnia nagroda z ostatnich 10k epizodów: -0.096
Episode 70000, średnia nagroda z ostatnich 10k epizodów: -0.059
Episode 80000, średnia nagroda z ostatnich 10k epizodów: -0.051
Episode 90000, średnia nagroda z ostatnich 10k epizodów: -0.074
Episode 100000, średnia nagroda z ostatnich 10k epizodów: -0.037
Episode 110000, średnia nagroda z ostatnich 10k epizodów: -0.031
Episode 120000, średnia nagroda z ostatnich 10k epizodów: -0.024
Episode 130000, średnia nagroda z ostatnich 10k epizodów: -0.054
Episode 140000, średnia nagroda z ostatnich 10k epizodów: -0.042
Episode 150000, średnia nagroda z ostatnich 10k epizodów: -0.045
Episode 160000, średnia nagroda z 

In [11]:
def a2c_policy(state):
    return a2c_greedy_action(state)

wins_a2c, losses_a2c, draws_a2c = evaluate_policy(a2c_policy)
wins_bs, losses_bs, draws_bs = evaluate_policy(basic_strategy)

print("A2C:          Wins:", wins_a2c, "Losses:", losses_a2c, "Draws:", draws_a2c)
print("BasicStrategy: Wins:", wins_bs,  "Losses:", losses_bs,  "Draws:", draws_bs)


A2C:          Wins: 42337 Losses: 48141 Draws: 9522
BasicStrategy: Wins: 40701 Losses: 48733 Draws: 10566
