# Q-Learning blackjack

In [8]:

import time
import random
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import gymnasium as gym

# ---------------------------------------------------------
# Global config
# ---------------------------------------------------------

seed_list = list(range(1, 3))   # use 1..10 for final experiments if you want
print("len(seed_list): ", len(seed_list))
print("seed_list: ", seed_list)

GAMMA = 0.98

N_EPISODES = 30_000
ALPHA = 0.05
EPSILON_START = 1.0
EPSILON_END = 0.05
EPSILON_DECAY = 5_000

TERMINAL = ("terminal", "terminal", False)

len(seed_list):  2
seed_list:  [1, 2]


In [9]:


def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)

def epsilon_greedy(Q: Dict[Tuple, float], state, actions, epsilon: float) -> int:
    """
    ε-greedy policy over Q for a single state.
    Q is a dict keyed by (state, action).
    """
    if np.random.rand() < epsilon:
        return np.random.choice(actions)
    q_vals = [Q.get((state, a), 0.0) for a in actions]
    return int(np.argmax(q_vals))

def pad_and_stack(sequences: List[List[float]]) -> np.ndarray:
    """
    Take a list of 1D sequences of possibly different lengths and
    return a 2D array (n_sequences, max_len) padded with NaNs.
    """
    max_len = max(len(seq) for seq in sequences)
    arr = np.full((len(sequences), max_len), np.nan, dtype=float)
    for i, seq in enumerate(sequences):
        arr[i, : len(seq)] = np.array(seq, dtype=float)
    return arr

def plot_mean_iqr(
    values_2d: np.ndarray,
    title: str,
    ylabel: str,
    filename: str,
    xlabel: str = "Episode",
):
    """
    values_2d: shape (n_seeds, n_episodes)
    """
    episodes = np.arange(values_2d.shape[1])
    mean = np.nanmean(values_2d, axis=0)
    q25, q75 = np.nanpercentile(values_2d, [25, 75], axis=0)

    plt.figure(figsize=(6, 4))
    plt.plot(episodes, mean, label="Mean")
    plt.fill_between(episodes, q25, q75, alpha=0.3, label="IQR")
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.savefig(filename, bbox_inches="tight")
    plt.close()
    print(f"Saved: {filename}")

def plot_comparison_mean(
    sarsa_mat: np.ndarray,
    ql_mat: np.ndarray,
    title: str,
    ylabel: str,
    filename: str,
    xlabel: str = "Episode",
):
    """
    Plot mean return vs episodes for SARSA and Q-Learning on one figure.
    """
    n_episodes = min(sarsa_mat.shape[1], ql_mat.shape[1])
    episodes = np.arange(n_episodes)

    sarsa_mean = np.nanmean(sarsa_mat[:, :n_episodes], axis=0)
    ql_mean = np.nanmean(ql_mat[:, :n_episodes], axis=0)

    plt.figure(figsize=(6, 4))
    plt.plot(episodes, sarsa_mean, label="SARSA")
    plt.plot(episodes, ql_mean, label="Q-Learning")
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.savefig(filename, bbox_inches="tight")
    plt.close()
    print(f"Saved: {filename}")

# ---------------------------------------------------------
# Heatmaps
# ---------------------------------------------------------

def plot_value_heatmap(V: Dict, filename: str, title: str):
    """
    V: dict state -> value
    Aggregates over usable ace dimension and plots player sum vs dealer card.
    """
    values = np.zeros((22, 11), dtype=float)  # indices: [player, dealer]
    counts = np.zeros((22, 11), dtype=float)

    for s, v in V.items():
        if s == TERMINAL:
            continue
        player, dealer, ace = s
        values[player, dealer] += v
        counts[player, dealer] += 1.0

    avg_values = np.divide(values, counts, out=np.zeros_like(values), where=counts != 0)

    plt.figure(figsize=(6, 5))
    plt.imshow(avg_values[4:22, 1:11], origin="lower", aspect="auto")
    plt.colorbar(label="Value")
    plt.xlabel("Dealer Card (1–10)")
    plt.ylabel("Player Sum (4–21)")
    plt.title(title)
    plt.tight_layout()
    plt.savefig(filename, bbox_inches="tight")
    plt.close()
    print(f"Saved: {filename}")

def plot_policy_heatmap(pi: Dict, filename: str, title: str):
    """
    pi: dict state -> action (0=stick,1=hit)
    Aggregates over usable ace dimension and plots average action.
    """
    policy_map = np.zeros((22, 11), dtype=float)
    counts = np.zeros((22, 11), dtype=float)

    for s, a in pi.items():
        if s == TERMINAL:
            continue
        player, dealer, ace = s
        policy_map[player, dealer] += a
        counts[player, dealer] += 1.0

    avg_policy = np.divide(policy_map, counts, out=np.zeros_like(policy_map), where=counts != 0)

    plt.figure(figsize=(6, 5))
    plt.imshow(avg_policy[4:22, 1:11], origin="lower", aspect="auto", vmin=0, vmax=1)
    plt.colorbar(label="Action (0=stick, 1=hit, averaged)")
    plt.xlabel("Dealer Card (1–10)")
    plt.ylabel("Player Sum (4–21)")
    plt.title(title)
    plt.tight_layout()
    plt.savefig(filename, bbox_inches="tight")
    plt.close()
    print(f"Saved: {filename}")

# ---------------------------------------------------------
# Q -> policy/value helpers
# ---------------------------------------------------------

def q_to_v_pi_from_dict(
    Q: Dict[Tuple, float],
    states: List[Tuple],
    actions: List[int],
):
    """
    From Q(s,a) dict to:
      - V(s) = max_a Q(s,a)
      - π(s) = argmax_a Q(s,a)
    """
    V = {}
    pi = {}
    for s in states:
        if s == TERMINAL:
            continue
        q_vals = [Q.get((s, a), 0.0) for a in actions]
        best_a = int(np.argmax(q_vals))
        V[s] = float(np.max(q_vals))
        pi[s] = best_a
    V[TERMINAL] = 0.0
    pi[TERMINAL] = 0
    return V, pi

# ---------------------------------------------------------
# SARSA
# ---------------------------------------------------------

def run_sarsa_seed(
    seed: int,
    n_episodes: int,
    gamma: float,
    alpha: float,
    epsilon_start: float,
    epsilon_end: float,
    epsilon_decay: int,
):
    """
    Tabular SARSA on Blackjack-v1 (sab=True).

    Returns:
      - per-episode returns
      - per-episode ΔQ (max |ΔQ| in that episode)
      - final Q dict
      - wall-clock time
    """
    set_seed(seed)
    env = gym.make("Blackjack-v1", sab=True)

    actions = [0, 1]  # 0=stick, 1=hit
    Q: Dict[Tuple, float] = {}

    returns_per_episode: List[float] = []
    deltaQ_per_episode: List[float] = []

    start_time = time.time()
    epsilon = epsilon_start

    for episode in range(n_episodes):
        # Decay epsilon over time
        frac = min(1.0, episode / max(1, epsilon_decay))
        epsilon = epsilon_start + frac * (epsilon_end - epsilon_start)

        state, _ = env.reset(seed=seed + episode)
        done = False
        a = epsilon_greedy(Q, state, actions, epsilon)

        G = 0.0
        max_delta_this_ep = 0.0

        while not done:
            next_state, r, terminated, truncated, info = env.step(a)
            done = terminated or truncated

            G += r

            # On-policy target (SARSA)
            if not done:
                next_a = epsilon_greedy(Q, next_state, actions, epsilon)
                target = r + gamma * Q.get((next_state, next_a), 0.0)
            else:
                next_a = None
                target = r

            old_q = Q.get((state, a), 0.0)
            new_q = old_q + alpha * (target - old_q)
            Q[(state, a)] = new_q

            delta = abs(new_q - old_q)
            if delta > max_delta_this_ep:
                max_delta_this_ep = delta

            state = next_state
            a = next_a if next_a is not None else 0

        returns_per_episode.append(G)
        deltaQ_per_episode.append(max_delta_this_ep)

    wall_clock = time.time() - start_time
    env.close()

    return {
        "seed": seed,
        "returns": returns_per_episode,
        "deltaQs": deltaQ_per_episode,
        "Q": Q,
        "wall_clock": wall_clock,
    }

# ---------------------------------------------------------
# Q-Learning (off-policy, no Double Q)
# ---------------------------------------------------------

def run_q_learning_seed(
    seed: int,
    n_episodes: int,
    gamma: float,
    alpha: float,
    epsilon_start: float,
    epsilon_end: float,
    epsilon_decay: int,
):
    """
    Tabular Q-Learning on Blackjack-v1 (sab=True).

    Returns:
      - per-episode returns
      - per-episode ΔQ (max |ΔQ| in that episode)
      - final Q dict
      - wall-clock time
    """
    set_seed(seed)
    env = gym.make("Blackjack-v1", sab=True)

    actions = [0, 1]  # 0=stick, 1=hit
    Q: Dict[Tuple, float] = {}

    returns_per_episode: List[float] = []
    deltaQ_per_episode: List[float] = []

    start_time = time.time()
    epsilon = epsilon_start

    for episode in range(n_episodes):
        # Decay epsilon over time
        frac = min(1.0, episode / max(1, epsilon_decay))
        epsilon = epsilon_start + frac * (epsilon_end - epsilon_start)

        state, _ = env.reset(seed=seed + 10_000 + episode)  # shift seeds to differ from SARSA
        done = False
        G = 0.0
        max_delta_this_ep = 0.0

        while not done:
            # Behaviour: ε-greedy on current Q
            a = epsilon_greedy(Q, state, actions, epsilon)
            next_state, r, terminated, truncated, info = env.step(a)
            done = terminated or truncated

            G += r

            # Off-policy target: max_a' Q(next_state, a')
            if not done:
                q_next_vals = [Q.get((next_state, a2), 0.0) for a2 in actions]
                best_next = np.max(q_next_vals)
                target = r + gamma * best_next
            else:
                target = r

            old_q = Q.get((state, a), 0.0)
            new_q = old_q + alpha * (target - old_q)
            Q[(state, a)] = new_q

            delta = abs(new_q - old_q)
            if delta > max_delta_this_ep:
                max_delta_this_ep = delta

            state = next_state

        returns_per_episode.append(G)
        deltaQ_per_episode.append(max_delta_this_ep)

    wall_clock = time.time() - start_time
    env.close()

    return {
        "seed": seed,
        "returns": returns_per_episode,
        "deltaQs": deltaQ_per_episode,
        "Q": Q,
        "wall_clock": wall_clock,
    }

# %%
# ---- Enumerate all blackjack states (same as VI/PI code) ----



In [10]:
states = []
for player in range(4, 22):
    for dealer in range(1, 11):
        for ace in [False, True]:
            states.append((player, dealer, ace))
states.append(TERMINAL)

actions = [0, 1]  # 0=stick, 1=hit

print(f"Number of states (including terminal): {len(states)}")
print(f"Actions: {actions}")
print(f"Episodes: {N_EPISODES}")
print(f"alpha={ALPHA}, gamma={GAMMA}, eps_start={EPSILON_START}, eps_end={EPSILON_END}")

# ---------------------------------------------------------
# Run SARSA for each seed
# ---------------------------------------------------------
sarsa_results = []
for seed in seed_list:
    print(f"Running SARSA for seed {seed}...")
    res = run_sarsa_seed(
        seed=seed,
        n_episodes=N_EPISODES,
        gamma=GAMMA,
        alpha=ALPHA,
        epsilon_start=EPSILON_START,
        epsilon_end=EPSILON_END,
        epsilon_decay=EPSILON_DECAY,
    )
    sarsa_results.append(res)

# ---------------------------------------------------------
# Run Q-Learning for each seed
# ---------------------------------------------------------
qlearning_results = []
for seed in seed_list:
    print(f"Running Q-Learning for seed {seed}...")
    res = run_q_learning_seed(
        seed=seed,
        n_episodes=N_EPISODES,
        gamma=GAMMA,
        alpha=ALPHA,
        epsilon_start=EPSILON_START,
        epsilon_end=EPSILON_END,
        epsilon_decay=EPSILON_DECAY,
    )
    qlearning_results.append(res)

# ---------------------------------------------------------
# Aggregate learning curves and ΔQ
# ---------------------------------------------------------

# SARSA
returns_mat_sarsa = pad_and_stack([r["returns"] for r in sarsa_results])
deltaQ_mat_sarsa = pad_and_stack([r["deltaQs"] for r in sarsa_results])
wall_clocks_sarsa = np.array([r["wall_clock"] for r in sarsa_results])

# Q-Learning
returns_mat_ql = pad_and_stack([r["returns"] for r in qlearning_results])
deltaQ_mat_ql = pad_and_stack([r["deltaQs"] for r in qlearning_results])
wall_clocks_ql = np.array([r["wall_clock"] for r in qlearning_results])

# ---------------------------------------------------------
# Learning curves: return vs episodes
# ---------------------------------------------------------
plot_mean_iqr(
    returns_mat_sarsa,
    title="Blackjack SARSA: Return vs Episodes",
    ylabel="Episode Return",
    filename="blackjack_sarsa_learning_curve.pdf",
    xlabel="Episode",
)

plot_mean_iqr(
    returns_mat_ql,
    title="Blackjack Q-Learning: Return vs Episodes",
    ylabel="Episode Return",
    filename="blackjack_qlearning_learning_curve.pdf",
    xlabel="Episode",
)

# ---------------------------------------------------------
# ΔQ vs episodes
# ---------------------------------------------------------
plot_mean_iqr(
    deltaQ_mat_sarsa,
    title="Blackjack SARSA: ΔQ vs Episodes",
    ylabel="Max |ΔQ| per Episode",
    filename="blackjack_sarsa_deltaQ.pdf",
    xlabel="Episode",
)

plot_mean_iqr(
    deltaQ_mat_ql,
    title="Blackjack Q-Learning: ΔQ vs Episodes",
    ylabel="Max |ΔQ| per Episode",
    filename="blackjack_qlearning_deltaQ.pdf",
    xlabel="Episode",
)

# ---------------------------------------------------------
# SARSA vs Q-Learning comparison curve (return vs episodes)
# ---------------------------------------------------------
plot_comparison_mean(
    returns_mat_sarsa,
    returns_mat_ql,
    title="Blackjack: SARSA vs Q-Learning (Return vs Episodes)",
    ylabel="Episode Return",
    filename="blackjack_sarsa_vs_qlearning_learning_curve.pdf",
    xlabel="Episode",
)

# ---------------------------------------------------------
# Final policy & value maps (from Q-Learning Q-table)
# ---------------------------------------------------------
# Use first seed's Q-table for Q-Learning
Q_ql_example = qlearning_results[0]["Q"]
V_ql_example, pi_ql_example = q_to_v_pi_from_dict(Q_ql_example, states, actions)

plot_value_heatmap(
    V_ql_example,
    filename="blackjack_qlearning_value_heatmap.pdf",
    title="Blackjack Q-Learning: Final Value Map (max_a Q)",
)

plot_policy_heatmap(
    pi_ql_example,
    filename="blackjack_qlearning_policy_heatmap.pdf",
    title="Blackjack Q-Learning: Final Policy Map (argmax_a Q)",
)

# (If you also want SARSA maps, you can still keep your previous SARSA plots)
Q_sarsa_example = sarsa_results[0]["Q"]
V_sarsa_example, pi_sarsa_example = q_to_v_pi_from_dict(Q_sarsa_example, states, actions)

plot_value_heatmap(
    V_sarsa_example,
    filename="blackjack_sarsa_value_heatmap.pdf",
    title="Blackjack SARSA: Final Value Map (max_a Q)",
)

plot_policy_heatmap(
    pi_sarsa_example,
    filename="blackjack_sarsa_policy_heatmap.pdf",
    title="Blackjack SARSA: Final Policy Map (argmax_a Q)",
)

# ---------------------------------------------------------
# Wall-Clock Time per Seed (SARSA and Q-Learning)
# ---------------------------------------------------------
plt.figure(figsize=(6,4))
plt.plot(seed_list, wall_clocks_sarsa, marker='o', label="SARSA")
plt.plot(seed_list, wall_clocks_ql, marker='s', label="Q-Learning")
plt.xlabel("Seed")
plt.ylabel("Wall-Clock Time (s)")
plt.title("Blackjack: Wall-Clock Time per Seed (SARSA vs Q-Learning)")
plt.grid(True)
plt.legend()
plt.tight_layout()
plt.savefig("blackjack_sarsa_vs_qlearning_wallclock_per_seed.pdf", bbox_inches="tight")
plt.close()
print("Saved: blackjack_sarsa_vs_qlearning_wallclock_per_seed.pdf")

# ---------------------------------------------------------
# Text summary
# ---------------------------------------------------------
mean_return_last_100_sarsa = np.nanmean(returns_mat_sarsa[:, -100:], axis=1)
mean_return_last_100_ql = np.nanmean(returns_mat_ql[:, -100:], axis=1)

print("\n=== Per-seed mean return over last 100 episodes (SARSA) ===")
for seed, mr in zip(seed_list, mean_return_last_100_sarsa):
    print(f"SARSA Seed {seed}: {mr:.4f}")

print("\n=== Per-seed mean return over last 100 episodes (Q-Learning) ===")
for seed, mr in zip(seed_list, mean_return_last_100_ql):
    print(f"Q-Learning Seed {seed}: {mr:.4f}")

print("\n=== Wall-clock summary (SARSA) ===")
print(f"Mean wall-clock per seed: {wall_clocks_sarsa.mean():.4f}s ± {wall_clocks_sarsa.std():.4f}s")
print(f"Total wall-clock over all seeds: {wall_clocks_sarsa.sum():.4f}s")

print("\n=== Wall-clock summary (Q-Learning) ===")
print(f"Mean wall-clock per seed: {wall_clocks_ql.mean():.4f}s ± {wall_clocks_ql.std():.4f}s")
print(f"Total wall-clock over all seeds: {wall_clocks_ql.sum():.4f}s")

Number of states (including terminal): 361
Actions: [0, 1]
Episodes: 30000
alpha=0.05, gamma=0.98, eps_start=1.0, eps_end=0.05
Running SARSA for seed 1...
Running SARSA for seed 2...
Running Q-Learning for seed 1...
Running Q-Learning for seed 2...
Saved: blackjack_sarsa_learning_curve.pdf
Saved: blackjack_qlearning_learning_curve.pdf
Saved: blackjack_sarsa_deltaQ.pdf
Saved: blackjack_qlearning_deltaQ.pdf
Saved: blackjack_sarsa_vs_qlearning_learning_curve.pdf
Saved: blackjack_qlearning_value_heatmap.pdf
Saved: blackjack_qlearning_policy_heatmap.pdf
Saved: blackjack_sarsa_value_heatmap.pdf
Saved: blackjack_sarsa_policy_heatmap.pdf
Saved: blackjack_sarsa_vs_qlearning_wallclock_per_seed.pdf

=== Per-seed mean return over last 100 episodes (SARSA) ===
SARSA Seed 1: 0.0000
SARSA Seed 2: -0.0700

=== Per-seed mean return over last 100 episodes (Q-Learning) ===
Q-Learning Seed 1: -0.0500
Q-Learning Seed 2: -0.0600

=== Wall-clock summary (SARSA) ===
Mean wall-clock per seed: 5.0005s ± 0.0338s