# blackjack_vi_pi

In [1]:


import time
import random
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import gymnasium as gym

# ---------------------------------------------------------
# Global config
# ---------------------------------------------------------


seed_list = list(range(1,11))
print("len(seed_list): ", len(seed_list))
print("seed_list: ", seed_list)

GAMMA = 0.98
DELTA = 5e-6
PI_GAMMA = 0.95

TERMINAL = ("terminal", "terminal", False)

len(seed_list):  10
seed_list:  [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]


In [2]:


# ---------------------------------------------------------
# Utils
# ---------------------------------------------------------

def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)


def estimate_transitions(env, n_samples: int, seed: int = 0):
    """
    Empirical estimation of P(s'|s,a) and R(s,a) using random policy.
    """
    set_seed(seed)
    counts = {}
    rewards = {}

    actions = [0, 1]

    for _ in range(n_samples):
        state, _ = env.reset()
        done = False

        while not done:
            a = np.random.choice(actions)
            next_state, r, terminated, truncated, info = env.step(a)
            done = terminated or truncated

            counts.setdefault((state, a, next_state), 0)
            counts[(state, a, next_state)] += 1

            rewards.setdefault((state, a), [])
            rewards[(state, a)].append(r)

            state = next_state

    P = {}
    R = {}

    # Build transition probabilities
    for (s, a, s2), c in counts.items():
        P.setdefault((s, a), {})
        P[(s, a)][s2] = P[(s, a)].get(s2, 0) + c

    for key in P:
        total = sum(P[key].values())
        for s2 in P[key]:
            P[key][s2] /= total

    # Build mean rewards
    for (s, a), rew_list in rewards.items():
        R[(s, a)] = float(np.mean(rew_list))

    return P, R


# ---------------------------------------------------------
# Value Iteration
# ---------------------------------------------------------

def value_iteration(
    states,
    actions,
    P: Dict,
    R: Dict,
    gamma: float = GAMMA,
    delta: float = DELTA,
    patience: int = 3,
    init_random: bool = True,
):
    """
    Standard Value Iteration with per-iteration ΔV logging.
    If init_random=True, V is randomly initialized (seed-dependent).
    """
    if init_random:
        V = {s: float(np.random.randn() * 1e-3) for s in states}
    else:
        V = {s: 0.0 for s in states}

    V[TERMINAL] = 0.0

    best_policy = {s: 0 for s in states}
    deltas: List[float] = []
    no_improve = 0

    while True:
        max_diff = 0.0
        V_new = {}

        for s in states:
            q_values = []
            for a in actions:
                q = R.get((s, a), 0.0)
                if (s, a) in P:
                    for s2, p in P[(s, a)].items():
                        if s2 not in V:
                            s2 = TERMINAL
                        q += gamma * p * V[s2]
                q_values.append(q)

            best_q = max(q_values)
            best_policy[s] = int(np.argmax(q_values))
            V_new[s] = best_q

            max_diff = max(max_diff, abs(V_new[s] - V[s]))

        V = V_new
        deltas.append(max_diff)

        if max_diff < delta:
            no_improve += 1
        else:
            no_improve = 0

        if no_improve >= patience:
            break

    return V, best_policy, deltas


def run_vi_seed(seed, states, actions, P, R, gamma=GAMMA):
    set_seed(seed)
    start_time = time.time()
    V, pi, deltas = value_iteration(states, actions, P, R, gamma, init_random=True)
    wall_clock = time.time() - start_time

    vals = np.array([v for s, v in V.items() if s != TERMINAL])
    avg_return = float(vals.mean())

    return {
        "seed": seed,
        "V": V,
        "pi": pi,
        "deltas": deltas,
        "avg_return": avg_return,
        "wall_clock": wall_clock,
    }


# ---------------------------------------------------------
# Policy Iteration (with stats)
# ---------------------------------------------------------

def policy_evaluation(policy, states, actions, P, R, gamma=PI_GAMMA, delta=DELTA):
    V = {s: 0.0 for s in states}
    V[TERMINAL] = 0.0

    while True:
        max_diff = 0.0
        V_new = {}

        for s in states:
            a = policy[s]
            q = R.get((s, a), 0.0)
            if (s, a) in P:
                for s2, p in P[(s, a)].items():
                    if s2 not in V:
                        s2 = TERMINAL
                    q += gamma * p * V[s2]
            V_new[s] = q
            max_diff = max(max_diff, abs(V_new[s] - V[s]))

        V = V_new
        if max_diff < delta:
            break

    return V


def policy_iteration_with_stats(
    states,
    actions,
    P,
    R,
    gamma=PI_GAMMA,
):
    """
    Policy Iteration with:
      - ΔV_k = max_s |V^{π_k}(s) - V^{π_{k-1}}(s)| per outer iteration
      - policy_change_frac_k = fraction of states whose action changes at iteration k
    """
    # random initial policy
    policy = {s: np.random.choice(actions) for s in states}

    iteration = 0
    stable = False
    deltaV_list: List[float] = []
    policy_change_list: List[float] = []

    V_prev = {s: 0.0 for s in states}
    V_prev[TERMINAL] = 0.0

    non_terminal_states = [s for s in states if s != TERMINAL]
    n_non_terminal = len(non_terminal_states)

    while not stable:
        iteration += 1

        # Policy evaluation
        V = policy_evaluation(policy, states, actions, P, R, gamma)

        # ΔV between this and previous iteration
        diffs = [abs(V[s] - V_prev[s]) for s in non_terminal_states]
        deltaV = max(diffs) if diffs else 0.0
        deltaV_list.append(deltaV)
        V_prev = V

        # Policy improvement
        stable = True
        changed_states = 0
        for s in non_terminal_states:
            old_a = policy[s]
            q_vals = []
            for a in actions:
                q = R.get((s, a), 0.0)
                if (s, a) in P:
                    for s2, p in P[(s, a)].items():
                        if s2 not in V:
                            s2 = TERMINAL
                        q += gamma * p * V[s2]
                q_vals.append(q)

            best_a = int(np.argmax(q_vals))
            policy[s] = best_a

            if best_a != old_a:
                changed_states += 1
                stable = False

        policy_change_frac = changed_states / n_non_terminal if n_non_terminal > 0 else 0.0
        policy_change_list.append(policy_change_frac)

    return V, policy, deltaV_list, policy_change_list


def run_pi_seed(seed, states, actions, P, R, gamma=PI_GAMMA):
    set_seed(seed)
    start_time = time.time()
    V, pi, deltaVs, policy_changes = policy_iteration_with_stats(
        states, actions, P, R, gamma
    )
    wall_clock = time.time() - start_time

    vals = np.array([v for s, v in V.items() if s != TERMINAL])
    avg_return = float(vals.mean())

    return {
        "seed": seed,
        "V": V,
        "pi": pi,
        "deltaVs": deltaVs,
        "policy_changes": policy_changes,
        "avg_return": avg_return,
        "wall_clock": wall_clock,
    }


# ---------------------------------------------------------
# Aggregation helpers
# ---------------------------------------------------------

def pad_and_stack(sequences: List[List[float]]) -> np.ndarray:
    """
    Take a list of 1D sequences of possibly different lengths and
    return a 2D array (n_sequences, max_len) padded with NaNs.
    """
    max_len = max(len(seq) for seq in sequences)
    arr = np.full((len(sequences), max_len), np.nan, dtype=float)
    for i, seq in enumerate(sequences):
        arr[i, : len(seq)] = np.array(seq, dtype=float)
    return arr


def plot_mean_iqr(
    values_2d: np.ndarray,
    title: str,
    ylabel: str,
    filename: str,
    xlabel: str = "Iteration",
):
    """
    values_2d: shape (n_seeds, n_iters)
    """
    iters = np.arange(values_2d.shape[1])
    mean = np.nanmean(values_2d, axis=0)
    q25, q75 = np.nanpercentile(values_2d, [25, 75], axis=0)

    plt.figure(figsize=(6, 4))
    plt.plot(iters, mean, label="Mean")
    plt.fill_between(iters, q25, q75, alpha=0.3, label="IQR")
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.savefig(filename, bbox_inches="tight")
    plt.close()
    print(f"Saved: {filename}")


# ---------------------------------------------------------
# Heatmaps
# ---------------------------------------------------------

def plot_value_heatmap(V: Dict, filename: str, title: str):
    """
    V: dict state -> value
    Aggregates over usable ace dimension and plots player sum vs dealer card.
    """
    values = np.zeros((22, 11), dtype=float)  # indices: [player, dealer]
    counts = np.zeros((22, 11), dtype=float)

    for s, v in V.items():
        if s == TERMINAL:
            continue
        player, dealer, ace = s
        values[player, dealer] += v
        counts[player, dealer] += 1.0

    avg_values = np.divide(values, counts, out=np.zeros_like(values), where=counts != 0)

    plt.figure(figsize=(6, 5))
    plt.imshow(avg_values[4:22, 1:11], origin="lower", aspect="auto")
    plt.colorbar(label="Value")
    plt.xlabel("Dealer Card (1–10)")
    plt.ylabel("Player Sum (4–21)")
    plt.title(title)
    plt.tight_layout()
    plt.savefig(filename, bbox_inches="tight")
    plt.close()
    print(f"Saved: {filename}")


def plot_policy_heatmap(pi: Dict, filename: str, title: str):
    """
    pi: dict state -> action (0=stick,1=hit)
    Aggregates over usable ace dimension and plots average action.
    """
    policy_map = np.zeros((22, 11), dtype=float)
    counts = np.zeros((22, 11), dtype=float)

    for s, a in pi.items():
        if s == TERMINAL:
            continue
        player, dealer, ace = s
        policy_map[player, dealer] += a
        counts[player, dealer] += 1.0

    avg_policy = np.divide(policy_map, counts, out=np.zeros_like(policy_map), where=counts != 0)

    plt.figure(figsize=(6, 5))
    plt.imshow(avg_policy[4:22, 1:11], origin="lower", aspect="auto", vmin=0, vmax=1)
    plt.colorbar(label="Action (0=stick, 1=hit, averaged)")
    plt.xlabel("Dealer Card (1–10)")
    plt.ylabel("Player Sum (4–21)")
    plt.title(title)
    plt.tight_layout()
    plt.savefig(filename, bbox_inches="tight")
    plt.close()
    print(f"Saved: {filename}")





In [3]:
# ---------------------------------------------------------
# Main experiment
# ---------------------------------------------------------

def main():
    # ---- Build Blackjack MDP model ----
    env = gym.make("Blackjack-v1", sab=True)

    # Enumerate states
    states = []
    for player in range(4, 22):
        for dealer in range(1, 11):
            for ace in [False, True]:
                states.append((player, dealer, ace))
    states.append(TERMINAL)

    actions = [0, 1]  # 0=stick, 1=hit

    print(f"Number of states (including terminal): {len(states)}")
    print(f"Actions: {actions}")

    # Estimate transitions once (fixed model)
    n_samples = 400_000
    print(f"Estimating transitions with {n_samples} samples...")
    P, R = estimate_transitions(env, n_samples=n_samples, seed=0)
    print(f"Estimated |P|={len(P)}, |R|={len(R)}")

    # ---- Run VI over seeds ----
    print("\n=== Running Value Iteration over seeds ===")
    vi_results = []
    for seed in seed_list:
        res = run_vi_seed(seed, states, actions, P, R, gamma=GAMMA)
        vi_results.append(res)

    vi_returns = np.array([r["avg_return"] for r in vi_results])
    vi_wall_clocks = np.array([r["wall_clock"] for r in vi_results])

    print(f"VI: mean return = {vi_returns.mean():.4f} ± {vi_returns.std():.4f}")
    print(
        f"VI: wall-clock mean = {vi_wall_clocks.mean():.4f}s, "
        f"total = {vi_wall_clocks.sum():.4f}s"
    )

    # ΔV vs iterations (VI) with mean + IQR
    vi_deltas_mat = pad_and_stack([r["deltas"] for r in vi_results])
    plot_mean_iqr(
        vi_deltas_mat,
        title="Blackjack VI: ΔV vs Iterations",
        ylabel="Max ΔV",
        filename="blackjack_vi_deltaV.pdf",
    )

    # ---- Run PI over seeds ----
    print("\n=== Running Policy Iteration over seeds ===")
    pi_results = []
    for seed in seed_list:
        res = run_pi_seed(seed, states, actions, P, R, gamma=PI_GAMMA)
        pi_results.append(res)

    pi_returns = np.array([r["avg_return"] for r in pi_results])
    pi_wall_clocks = np.array([r["wall_clock"] for r in pi_results])

    print(f"PI: mean return = {pi_returns.mean():.4f} ± {pi_returns.std():.4f}")
    print(
        f"PI: wall-clock mean = {pi_wall_clocks.mean():.4f}s, "
        f"total = {pi_wall_clocks.sum():.4f}s"
    )

    # ΔV vs iterations and policy stability (PI)
    pi_deltaV_mat = pad_and_stack([r["deltaVs"] for r in pi_results])
    pi_polchg_mat = pad_and_stack([r["policy_changes"] for r in pi_results])

    # One figure, two subplots (ΔV and policy change fraction)
    iters = np.arange(pi_deltaV_mat.shape[1])

    delta_mean = np.nanmean(pi_deltaV_mat, axis=0)
    delta_q25, delta_q75 = np.nanpercentile(pi_deltaV_mat, [25, 75], axis=0)

    pol_mean = np.nanmean(pi_polchg_mat, axis=0)
    pol_q25, pol_q75 = np.nanpercentile(pi_polchg_mat, [25, 75], axis=0)

    fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True, figsize=(6, 6))

    # ΔV subplot
    ax1.plot(iters, delta_mean, label="Mean ΔV")
    ax1.fill_between(iters, delta_q25, delta_q75, alpha=0.3, label="IQR")
    ax1.set_ylabel("Max ΔV")
    ax1.set_title("Blackjack PI: ΔV and Policy Stability")
    ax1.grid(True)
    ax1.legend()

    # Policy stability subplot
    ax2.plot(iters, pol_mean, label="Mean policy change fraction")
    ax2.fill_between(iters, pol_q25, pol_q75, alpha=0.3, label="IQR")
    ax2.set_xlabel("Policy Iteration")
    ax2.set_ylabel("Fraction of states changed")
    ax2.grid(True)
    ax2.legend()

    plt.tight_layout()
    plt.savefig("blackjack_pi_deltaV_policy_stability.pdf", bbox_inches="tight")
    plt.close()
    print("Saved: blackjack_pi_deltaV_policy_stability.pdf")

    # ---- Final heatmaps (using first seed's solution) ----
    vi_V_example = vi_results[0]["V"]
    vi_pi_example = vi_results[0]["pi"]

    pi_V_example = pi_results[0]["V"]
    pi_pi_example = pi_results[0]["pi"]

    # VI heatmaps
    plot_value_heatmap(
        vi_V_example,
        filename="blackjack_vi_value_heatmap.pdf",
        title="Blackjack VI: Final Value Function",
    )
    plot_policy_heatmap(
        vi_pi_example,
        filename="blackjack_vi_policy_heatmap.pdf",
        title="Blackjack VI: Final Greedy Policy",
    )

    # PI heatmaps
    plot_value_heatmap(
        pi_V_example,
        filename="blackjack_pi_value_heatmap.pdf",
        title="Blackjack PI: Final Value Function",
    )
    plot_policy_heatmap(
        pi_pi_example,
        filename="blackjack_pi_policy_heatmap.pdf",
        title="Blackjack PI: Final Policy",
    )

    for res in vi_results:
        print(f"Seed {res['seed']} – VI iterations: {len(res['deltas'])}")

    for res in pi_results:
        print(f"Seed {res['seed']} – PI iterations: {len(res['deltaVs'])}")

    # ---------------------------------------------------------
    # Wall-Clock Time per Seed Plots (VI and PI)
    # ---------------------------------------------------------

    # VI wall-clock vs seed
    plt.figure(figsize=(6,4))
    plt.plot(seed_list, vi_wall_clocks, marker='o')
    plt.xlabel("Seed")
    plt.ylabel("Wall-Clock Time (s)")
    plt.title("Blackjack VI: Wall-Clock Time per Seed")
    plt.grid(True)
    plt.tight_layout()
    plt.savefig("blackjack_vi_wallclock_per_seed.pdf", bbox_inches="tight")
    plt.close()
    print("Saved: blackjack_vi_wallclock_per_seed.pdf")

    # PI wall-clock vs seed
    plt.figure(figsize=(6,4))
    plt.plot(seed_list, pi_wall_clocks, marker='o')
    plt.xlabel("Seed")
    plt.ylabel("Wall-Clock Time (s)")
    plt.title("Blackjack PI: Wall-Clock Time per Seed")
    plt.grid(True)
    plt.tight_layout()
    plt.savefig("blackjack_pi_wallclock_per_seed.pdf", bbox_inches="tight")
    plt.close()
    print("Saved: blackjack_pi_wallclock_per_seed.pdf")


    # ---- Final report of seeds and wall-clock ----
    print("\n=== Summary ===")
    print(f"Seeds used: {seed_list}")
    print(f"Total VI wall-clock over all seeds: {vi_wall_clocks.sum():.4f}s")
    print(f"Total PI wall-clock over all seeds: {pi_wall_clocks.sum():.4f}s")
    print(
        f"Total wall-clock (VI + PI): {(vi_wall_clocks.sum() + pi_wall_clocks.sum()):.4f}s"
    )

In [4]:
main()

Number of states (including terminal): 361
Actions: [0, 1]
Estimating transitions with 400000 samples...
Estimated |P|=560, |R|=560

=== Running Value Iteration over seeds ===
VI: mean return = 13.2174 ± 0.0000
VI: wall-clock mean = 1.2718s, total = 12.7183s
Saved: blackjack_vi_deltaV.pdf

=== Running Policy Iteration over seeds ===
PI: mean return = 4.9400 ± 0.0000
PI: wall-clock mean = 0.3490s, total = 3.4901s
Saved: blackjack_pi_deltaV_policy_stability.pdf
Saved: blackjack_vi_value_heatmap.pdf
Saved: blackjack_vi_policy_heatmap.pdf
Saved: blackjack_pi_value_heatmap.pdf
Saved: blackjack_pi_policy_heatmap.pdf
Seed 1 – VI iterations: 607
Seed 2 – VI iterations: 607
Seed 3 – VI iterations: 607
Seed 4 – VI iterations: 607
Seed 5 – VI iterations: 607
Seed 6 – VI iterations: 607
Seed 7 – VI iterations: 607
Seed 8 – VI iterations: 607
Seed 9 – VI iterations: 607
Seed 10 – VI iterations: 607
Seed 1 – PI iterations: 3
Seed 2 – PI iterations: 3
Seed 3 – PI iterations: 3
Seed 4 – PI iterations: