In [20]:
# ======================================================
# PERSISTENT SAVE SETUP (GOOGLE DRIVE)
# ======================================================

import os
from google.colab import drive

# Mount Google Drive
drive.mount("/content/drive")

# Base path inside Google Drive
BASE_DIR = "/content/drive/MyDrive/Colab Notebooks"
PROJECT_DIR = os.path.join(BASE_DIR, "connect-4")

# Create project directory if it doesn't exist
os.makedirs(PROJECT_DIR, exist_ok=True)

print("Persistent save directory:")
print(f"  {PROJECT_DIR}")

# Optional: subfolders for organization
RESULTS_DIR = os.path.join(PROJECT_DIR, "results")
MODELS_DIR  = os.path.join(PROJECT_DIR, "models")
PLOTS_DIR   = os.path.join(PROJECT_DIR, "plots")

for d in [RESULTS_DIR, MODELS_DIR, PLOTS_DIR]:
    os.makedirs(d, exist_ok=True)

print("\nSubdirectories:")
print(f"  results → {RESULTS_DIR}")
print(f"  models  → {MODELS_DIR}")
print(f"  plots   → {PLOTS_DIR}")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Persistent save directory:
  /content/drive/MyDrive/Colab Notebooks/connect-4

Subdirectories:
  results → /content/drive/MyDrive/Colab Notebooks/connect-4/results
  models  → /content/drive/MyDrive/Colab Notebooks/connect-4/models
  plots   → /content/drive/MyDrive/Colab Notebooks/connect-4/plots


# Connect-4 Policy/Value Network Training (MCTS-2000)

This notebook is designed to run on **Google Colab**.

Setup steps:
1. Clone the `connect-4` GitHub repository into the Colab runtime
2. Import shared utilities:
   - `mirroring/mirror.py` for dataset symmetry handling
   - `data_balance/balance.py` for move-depth weighting
3. Download the official MCTS-2000 dataset from GitHub Releases

No Google Drive mounting is required unless you want to persist trained models.

In [21]:
# ======================================================
# COLAB SETUP — CLONE REPO & IMPORT UTILITIES
# ======================================================

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

from pathlib import Path
import urllib.request
import tempfile
import sys
import os

# ------------------------------
# 1. Clone the repo (once)
# ------------------------------

REPO_URL = "https://github.com/AHMerrill/connect-4.git"
REPO_DIR = "/content/connect-4"

if not os.path.exists(REPO_DIR):
    !git clone {REPO_URL} {REPO_DIR}

# Make repo importable
if REPO_DIR not in sys.path:
    sys.path.append(REPO_DIR)

# ------------------------------
# 2. Import shared utilities
# ------------------------------
# mirror.py  -> connect-4/mirroring/mirror.py
# balance.py -> connect-4/data_balance/balance.py

from mirroring.mirror import mirror_dataset
from data_balance.balance import compute_move_balance_weights

# ------------------------------
# 3. Reproducibility
# ------------------------------

SEED = 7
np.random.seed(SEED)
tf.random.set_seed(SEED)

print("TensorFlow version:", tf.__version__)
print("GPUs available:", tf.config.list_physical_devices("GPU"))

TensorFlow version: 2.19.0
GPUs available: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


In [22]:
from tensorflow.keras import mixed_precision

mixed_precision.set_global_policy("mixed_float16")

In [23]:
try:
    tf.config.optimizer.set_jit(True)
    print("XLA JIT enabled")
except Exception as e:
    print("XLA JIT not available:", e)

XLA JIT enabled


## Step 2 — Load and Inspect the Raw Dataset (Dynamic)

In this step we:

- Download the Connect-4 MCTS dataset from its source (URL or local path)
- Load the `.npz` archive **without assumptions** about its internal structure
- Inspect available keys, shapes, and dtypes
- Decide dynamically how to construct the dataset dictionary used downstream

This cell does **no transformations** — it is purely for inspection and validation.

In [24]:
# ======================================================
# LOAD + INSPECT NPZ DATASET (NO ASSUMPTIONS)
# ======================================================

import numpy as np
from pathlib import Path
import urllib.request
import tempfile

DATASET_URL = (
    "https://github.com/AHMerrill/connect-4/"
    "releases/download/v0.1-data/mcts_not_mirrored_2000.npz"
)

def load_npz(source):
    if str(source).startswith("http"):
        tmp_dir = Path(tempfile.gettempdir())
        tmp_path = tmp_dir / Path(source).name
        if not tmp_path.exists():
            print(f"Downloading dataset to {tmp_path} ...")
            urllib.request.urlretrieve(source, tmp_path)
        return np.load(tmp_path, allow_pickle=True)
    return np.load(Path(source), allow_pickle=True)

# Load archive
npz = load_npz(DATASET_URL)

print("\nNPZ file loaded.")
print("Available keys:")
for k in npz.files:
    arr = npz[k]
    if isinstance(arr, np.ndarray):
        print(f"  {k:10s} | shape={arr.shape} dtype={arr.dtype}")
    else:
        print(f"  {k:10s} | type={type(arr)}")

# Sanity checks
required_keys = {"X", "policy", "value"}
missing = required_keys - set(npz.files)
if missing:
    raise KeyError(f"Missing required keys in NPZ: {missing}")

print("\nBasic sanity checks:")
print("X sample shape:", npz["X"][0].shape)
print("Policy sample:", npz["policy"][0])
print("Value sample:", npz["value"][0])


NPZ file loaded.
Available keys:
  boards     | shape=(703111, 6, 7) dtype=int8
  visits     | shape=(703111, 7) dtype=float32
  scores     | shape=(703111, 7) dtype=float32
  policy     | shape=(703111, 7) dtype=float32
  q          | shape=(703111, 7) dtype=float32
  value      | shape=(703111, 1) dtype=float32
  X          | shape=(703111, 6, 7, 2) dtype=float32

Basic sanity checks:
X sample shape: (6, 7, 2)
Policy sample: [0.1183     0.18783334 0.2223     0.06976666 0.2251     0.09663333
 0.08006667]
Value sample: [0.12936667]


## Step 3 — Mirror Boards and Apply Move-Depth Balancing (Dynamic)

In this step we:

- Construct the dataset dictionary **based on what actually exists** in the NPZ
- Apply **board mirroring** using the shared utility
- Apply **dynamic move-depth loss weighting**
  - No samples dropped
  - No resampling
  - Weighting derived from the dataset’s true move distribution

Outputs from this step are:
- `mirrored_data` — ready for training
- `sample_weights` — passed directly to `model.fit`

In [25]:
# ======================================================
# MIRROR DATASET + COMPUTE MOVE-DEPTH WEIGHTS (FULLY DYNAMIC)
# ======================================================

from mirroring.mirror import mirror_dataset
from data_balance.balance import compute_move_balance_weights
import numpy as np

# ------------------------------
# 1. Build dataset dict dynamically
# ------------------------------

data = {
    "X": npz["X"],
    "policy": npz["policy"],
    "value": npz["value"],
}

# Optional fields (only include if present)
for optional_key in ["boards", "visits", "scores", "q"]:
    if optional_key in npz.files:
        data[optional_key] = npz[optional_key]

print("\nDataset dictionary constructed with keys:")
for k, v in data.items():
    print(f"  {k:10s} | shape={v.shape}")

# ------------------------------
# 2. Mirror dataset
# ------------------------------

mirrored_data = mirror_dataset(data)

print("\nAfter mirroring:")
for k, v in mirrored_data.items():
    print(f"  {k:10s} | shape={v.shape}")

# ------------------------------
# 3. Compute move-depth balance weights
# ------------------------------

mirrored_data, sample_weights = compute_move_balance_weights(
    mirrored_data,
    num_bins=10,
)

print("\nSample weights summary (to be used during training):")
print(f"  shape : {sample_weights.shape}")
print(f"  mean  : {sample_weights.mean():.4f}")
print(f"  min   : {sample_weights.min():.4f}")
print(f"  max   : {sample_weights.max():.4f}")

# NOTE:
# `sample_weights` is a per-position loss multiplier.
# It is NOT applied yet.
# It will be passed into model.fit(...) later via `sample_weight=...`.

# ------------------------------
# 4. Fully data-driven dataset synopsis
# ------------------------------

X = mirrored_data["X"]

# Move count = number of stones on the board
move_count = np.count_nonzero(
    X[..., 0] + X[..., 1],
    axis=(1, 2),
)

N = len(move_count)
min_moves, max_moves = move_count.min(), move_count.max()

# Quantile-based depth regions (no hard-coded assumptions)
q_early, q_mid, q_late = np.quantile(move_count, [0.25, 0.50, 0.75])

early_mask = move_count <= q_early
mid_mask   = (move_count > q_early) & (move_count <= q_late)
late_mask  = move_count > q_late

early_frac = early_mask.mean()
mid_frac   = mid_mask.mean()
late_frac  = late_mask.mean()

# Bin-level imbalance (same bins used for weighting)
NUM_BINS = 10
bins = np.linspace(min_moves, max_moves + 1, NUM_BINS + 1)
bin_ids = np.digitize(move_count, bins) - 1
bin_ids = np.clip(bin_ids, 0, NUM_BINS - 1)

bin_counts = np.bincount(bin_ids, minlength=NUM_BINS)

most_common = bin_counts.max()
least_common = np.min(bin_counts[bin_counts > 0])
imbalance_ratio = most_common / least_common

dominant_bin = bin_counts.argmax()
dominant_share = bin_counts[dominant_bin] / N

print("\nData-driven interpretation:")
print(
    f"  Positions span move counts from {min_moves} to {max_moves}.\n\n"
    f"  Using empirical depth quantiles:\n"
    f"    • Early-game positions (≤ {int(q_early)} moves): {early_frac:.1%}\n"
    f"    • Mid-game positions ({int(q_early)+1}–{int(q_late)} moves): {mid_frac:.1%}\n"
    f"    • Late-game positions (> {int(q_late)} moves): {late_frac:.1%}\n\n"
    f"  When grouped into {NUM_BINS} move-depth bins, the most populated bin contains "
    f"{most_common:,} samples ({dominant_share:.1%} of the dataset), while the least "
    f"populated non-empty bin contains {least_common:,} samples "
    f"(≈ {imbalance_ratio:.1f}× fewer).\n\n"
    f"  This means raw training would overweight positions from high-frequency depth "
    f"regions and underweight rarer depths.\n\n"
    f"  To correct for this, inverse-frequency loss weights are computed per sample "
    f"based on its move-depth bin. These weights are stored in `sample_weights` and "
    f"will be applied during training so that each depth region contributes more "
    f"evenly in expectation.\n\n"
    f"  No samples are dropped or duplicated. All statistics and weights are computed "
    f"directly from the loaded dataset and will automatically adapt if a different "
    f"dataset is used."
)


Dataset dictionary constructed with keys:
  X          | shape=(703111, 6, 7, 2)
  policy     | shape=(703111, 7)
  value      | shape=(703111, 1)
  boards     | shape=(703111, 6, 7)
  visits     | shape=(703111, 7)
  scores     | shape=(703111, 7)
  q          | shape=(703111, 7)
Mirroring complete: 703111 → 1406222 samples

After mirroring:
  X          | shape=(1406222, 6, 7, 2)
  policy     | shape=(1406222, 7)
  value      | shape=(1406222, 1)
  boards     | shape=(1406222, 6, 7)
  visits     | shape=(1406222, 7)
  scores     | shape=(1406222, 7)
  q          | shape=(1406222, 7)
Move-count balancing:
  samples           : 1406222
  bins              : 10
  move range        : [0, 41]
  bin counts        : [2460, 145278, 373760, 329900, 231714, 171814, 78988, 45276, 20070, 6962]
  bin weight range  : [1.000, 151.935]
  sample weight μ   : 1.000

Sample weights summary (to be used during training):
  shape : (1406222,)
  mean  : 1.0000
  min   : 0.3762
  max   : 57.1635

Data-driv

## Step 4 — Stratified Train/Test Split by Move Depth (No Cross-Validation Yet)

We will now create a **train/test split** that is **stratified by move depth** (number of stones on the board).

Why:
- The dataset is not uniformly distributed across move depths.
- A purely random split can slightly shift the depth mix between train and test.
- Stratification ensures **train and test have the same move-depth distribution**, making model comparisons fair.

What this does **not** do:
- No samples are dropped.
- No resampling is performed.
- No cross-validation is performed yet (we’ll consider it after initial results).

We will still use the previously computed `sample_weights` during training to ensure
**balanced gradient contribution** across move-depth bins.

In [26]:
# ======================================================
# STEP 4 — STRATIFIED TRAIN / TEST SPLIT BY MOVE DEPTH
# (SPLIT *ALL* DATA FIELDS, FULLY ALIGNED)
# ======================================================

import numpy as np
from sklearn.model_selection import train_test_split

# ------------------------------------------------------
# Inputs (from previous steps):
#   mirrored_data  : dict with X, policy, value, and optional fields (q, etc.)
#   sample_weights : (N,) array from compute_move_balance_weights(...)
# ------------------------------------------------------

# ------------------------------
# 0) Extract core arrays
# ------------------------------

X      = mirrored_data["X"].astype(np.float32)
policy = mirrored_data["policy"].astype(np.float32)
value  = mirrored_data["value"].astype(np.float32)
weights = sample_weights.astype(np.float32)

# Optional fields (present in your dataset)
Q = mirrored_data.get("q", None)

N = X.shape[0]
print(f"Total samples: {N:,}")

# ------------------------------
# 1) Compute move count (stones on board)
# ------------------------------

move_count = np.count_nonzero(
    X[..., 0] + X[..., 1],
    axis=(1, 2),
).astype(int)

# ------------------------------
# 2) Bin move counts for stratification
# ------------------------------

NUM_BINS = 10
bins = np.linspace(move_count.min(), move_count.max() + 1, NUM_BINS + 1)

bin_ids = np.digitize(move_count, bins) - 1
bin_ids = np.clip(bin_ids, 0, NUM_BINS - 1)

# ------------------------------
# 3) Build split payload (THIS IS THE FIX)
# ------------------------------

split_payload = [
    X,
    policy,
    value,
    weights,
    bin_ids,
]

payload_names = [
    "X",
    "policy",
    "value",
    "weights",
    "bin_ids",
]

if Q is not None:
    split_payload.append(Q.astype(np.float32))
    payload_names.append("Q")

# ------------------------------
# 4) Stratified train/test split (ALL FIELDS)
# ------------------------------

split = train_test_split(
    *split_payload,
    test_size=0.20,
    random_state=SEED,
    stratify=bin_ids,
)

# ------------------------------
# 5) Unpack results cleanly
# ------------------------------

idx = 0
X_train, X_test = split[idx], split[idx+1]; idx += 2
policy_train, policy_test = split[idx], split[idx+1]; idx += 2
value_train, value_test = split[idx], split[idx+1]; idx += 2
w_train, w_test = split[idx], split[idx+1]; idx += 2
bin_train, bin_test = split[idx], split[idx+1]; idx += 2

if Q is not None:
    Q_train, Q_test = split[idx], split[idx+1]

# ------------------------------
# 6) Diagnostics — verify stratification
# ------------------------------

def summarize_bins(name, b):
    counts = np.bincount(b, minlength=NUM_BINS)
    frac = counts / counts.sum()
    print(f"\n{name} move-depth bin distribution:")
    for i in range(NUM_BINS):
        print(f"  bin {i:2d}: {counts[i]:8d} ({frac[i]:6.2%})")

print("\nStratified split complete.")
summarize_bins("TRAIN", bin_train)
summarize_bins("TEST ", bin_test)

print("\nWeight sanity check (means should be ~1.0):")
print(f"  mean(w_train) = {w_train.mean():.4f}")
print(f"  mean(w_test)  = {w_test.mean():.4f}")

print("\nShapes:")
print(f"  X_train      : {X_train.shape}")
print(f"  policy_train : {policy_train.shape}")
print(f"  value_train  : {value_train.shape}")
print(f"  X_test       : {X_test.shape}")
print(f"  policy_test  : {policy_test.shape}")
print(f"  value_test   : {value_test.shape}")

if Q is not None:
    print(f"  Q_train      : {Q_train.shape}")
    print(f"  Q_test       : {Q_test.shape}")

Total samples: 1,406,222

Stratified split complete.

TRAIN move-depth bin distribution:
  bin  0:     1968 ( 0.17%)
  bin  1:   116222 (10.33%)
  bin  2:   299008 (26.58%)
  bin  3:   263920 (23.46%)
  bin  4:   185371 (16.48%)
  bin  5:   137451 (12.22%)
  bin  6:    63190 ( 5.62%)
  bin  7:    36221 ( 3.22%)
  bin  8:    16056 ( 1.43%)
  bin  9:     5570 ( 0.50%)

TEST  move-depth bin distribution:
  bin  0:      492 ( 0.17%)
  bin  1:    29056 (10.33%)
  bin  2:    74752 (26.58%)
  bin  3:    65980 (23.46%)
  bin  4:    46343 (16.48%)
  bin  5:    34363 (12.22%)
  bin  6:    15798 ( 5.62%)
  bin  7:     9055 ( 3.22%)
  bin  8:     4014 ( 1.43%)
  bin  9:     1392 ( 0.49%)

Weight sanity check (means should be ~1.0):
  mean(w_train) = 1.0000
  mean(w_test)  = 1.0000

Shapes:
  X_train      : (1124977, 6, 7, 2)
  policy_train : (1124977, 7)
  value_train  : (1124977, 1)
  X_test       : (281245, 6, 7, 2)
  policy_test  : (281245, 7)
  value_test   : (281245, 1)
  Q_train      : (1124

## Step 5 — Define the CNN Model Grid (Residual Policy/Value Architectures)

In this step we define a **small, deliberately constrained set of CNN architectures** to evaluate.

The goal is to compare models in a way that is:
- **Statistically clean**
- **Interpretable**
- **Focused on capacity where it actually matters** for Connect-4

---

### Design intent

All models share the same **overall architecture and training setup**, differing only in a few carefully chosen capacity parameters.

**Shared across all models**
- AlphaZero-style **residual CNN** with separate policy and value heads
- Input shape: `(6, 7, 2)` (board planes)
- Optimizer: **AdamW**
- Losses:
  - Policy: categorical cross-entropy
  - Value: mean squared error
- Activations: ReLU
- Move-depth **loss weighting** applied during training
- Identical TRAIN / TEST split

**Varied deliberately**
- **Number of residual blocks** (network depth)
- **Learning rate** (optimization dynamics)

This keeps the comparison tight and ensures performance differences are driven by **model capacity and optimization behavior**, not confounded hyperparameter noise.

---

### What a residual block does

Each residual block:
- Applies two 3×3 convolutions with batch normalization
- Adds the block’s input back to its output (skip connection)
- Enables deeper networks without vanishing gradients

Increasing the number of blocks increases **effective depth and representational power** while preserving stable optimization.

---

### What `filters` means

The `filters` parameter controls the **channel width** of the network:
- Each convolution produces `filters` feature maps
- Higher values increase representational capacity at the cost of compute

In this grid, all models use the **same number of filters**, so depth and learning rate are isolated as the primary variables.

---

### Policy and value head structure

After the shared residual tower, the network splits into two heads:

**Policy head**
- 1×1 convolution → batch norm → ReLU
- Flatten → Dense(7) with softmax
- Outputs a probability distribution over the 7 columns

**Value head**
- 1×1 convolution → batch norm → ReLU
- Flatten → Dense(64) → Dense(1) with tanh
- Outputs a scalar value estimate in `[-1, 1]`

Both heads are trained jointly using their respective losses.

---

### What this step does

- Defines a reusable residual block
- Defines a `build_residual_cnn()` factory function
- Declares a compact model grid to be trained and evaluated in the next step

In [27]:
# ======================================================
# STEP 5 — DEFINE RESIDUAL CNN MODEL GRID
# (MIXED PRECISION SAFE — FINAL)
# ======================================================

import tensorflow as tf

# ------------------------------------------------------
# Residual block
# ------------------------------------------------------

def residual_block(x, filters):
    """
    Standard AlphaZero-style residual block:
      Conv → BN → ReLU → Conv → BN → Skip → ReLU
    """
    skip = x

    x = tf.keras.layers.Conv2D(
        filters,
        kernel_size=3,
        padding="same",
        use_bias=False,
    )(x)
    x = tf.keras.layers.BatchNormalization(dtype="float32")(x)
    x = tf.keras.layers.Activation("relu")(x)

    x = tf.keras.layers.Conv2D(
        filters,
        kernel_size=3,
        padding="same",
        use_bias=False,
    )(x)
    x = tf.keras.layers.BatchNormalization(dtype="float32")(x)

    x = tf.keras.layers.Add()([x, skip])
    x = tf.keras.layers.Activation("relu")(x)

    return x


# ------------------------------------------------------
# Model builder
# ------------------------------------------------------

def build_residual_cnn(
    input_shape=(6, 7, 2),
    num_blocks=8,
    filters=128,
    learning_rate=1e-3,
    weight_decay=1e-4,
):
    """
    AlphaZero-style policy/value CNN.
    Mixed precision safe:
      - Internal layers run in float16
      - BatchNorm + output heads run in float32
    """

    inputs = tf.keras.Input(shape=input_shape)

    # --------------------------------------------------
    # Initial convolution
    # --------------------------------------------------

    x = tf.keras.layers.Conv2D(
        filters,
        kernel_size=3,
        padding="same",
        use_bias=False,
    )(inputs)
    x = tf.keras.layers.BatchNormalization(dtype="float32")(x)
    x = tf.keras.layers.Activation("relu")(x)

    # --------------------------------------------------
    # Residual tower
    # --------------------------------------------------

    for _ in range(num_blocks):
        x = residual_block(x, filters)

    # --------------------------------------------------
    # Policy head (FLOAT32 OUTPUT)
    # --------------------------------------------------

    p = tf.keras.layers.Conv2D(
        2,
        kernel_size=1,
        use_bias=False,
    )(x)
    p = tf.keras.layers.BatchNormalization(dtype="float32")(p)
    p = tf.keras.layers.Activation("relu")(p)
    p = tf.keras.layers.Flatten()(p)

    policy_out = tf.keras.layers.Dense(
        7,
        activation="softmax",
        dtype="float32",   # CRITICAL FOR MIXED PRECISION
        name="policy",
    )(p)

    # --------------------------------------------------
    # Value head (FLOAT32 OUTPUT)
    # --------------------------------------------------

    v = tf.keras.layers.Conv2D(
        1,
        kernel_size=1,
        use_bias=False,
    )(x)
    v = tf.keras.layers.BatchNormalization(dtype="float32")(v)
    v = tf.keras.layers.Activation("relu")(v)
    v = tf.keras.layers.Flatten()(v)
    v = tf.keras.layers.Dense(64, activation="relu")(v)

    value_out = tf.keras.layers.Dense(
        1,
        activation="tanh",
        dtype="float32",   # CRITICAL FOR MIXED PRECISION
        name="value",
    )(v)

    # --------------------------------------------------
    # Compile model
    # --------------------------------------------------

    model = tf.keras.Model(
        inputs=inputs,
        outputs=[policy_out, value_out],
    )

    model.compile(
        optimizer=tf.keras.optimizers.AdamW(
            learning_rate=learning_rate,
            weight_decay=weight_decay,
        ),
        loss={
            "policy": "categorical_crossentropy",
            "value": "mse",
        },
        metrics={
            "policy": ["accuracy"],
            "value": ["mse"],
        },
    )

    return model


# ------------------------------------------------------
# Model grid (FINAL — 4 MODELS)
# ------------------------------------------------------

MODEL_GRID = [
    {"num_blocks": 8,  "learning_rate": 1e-3},
    {"num_blocks": 10, "learning_rate": 1e-3},
    {"num_blocks": 8,  "learning_rate": 3e-4},
    {"num_blocks": 10, "learning_rate": 3e-4},
]

print("Model training plan:")
print(f"  Total models to train: {len(MODEL_GRID)}\n")

print("Fixed hyperparameters (shared):")
print("  • input_shape   : (6, 7, 2)")
print("  • filters       : 128")
print("  • weight_decay  : 1e-4")
print("  • optimizer     : AdamW")
print("  • loss_policy   : categorical_crossentropy")
print("  • loss_value    : mse\n")

print("Model-specific configurations:")
for i, cfg in enumerate(MODEL_GRID, 1):
    print(f"\n  Model {i}:")
    for k, v in cfg.items():
        print(f"    - {k:13s}: {v}")

Model training plan:
  Total models to train: 4

Fixed hyperparameters (shared):
  • input_shape   : (6, 7, 2)
  • filters       : 128
  • weight_decay  : 1e-4
  • optimizer     : AdamW
  • loss_policy   : categorical_crossentropy
  • loss_value    : mse

Model-specific configurations:

  Model 1:
    - num_blocks   : 8
    - learning_rate: 0.001

  Model 2:
    - num_blocks   : 10
    - learning_rate: 0.001

  Model 3:
    - num_blocks   : 8
    - learning_rate: 0.0003

  Model 4:
    - num_blocks   : 10
    - learning_rate: 0.0003


## Step 6 — Train Model Grid with Early Stopping (TEST as Validation)

In this step we train each candidate CNN architecture in the model grid and
select the best-performing model **using the TEST split as a validation set**.

This workflow is intentional and mirrors common **AlphaZero-style pipelines**
where:
- Offline data is used to compare architectures
- Final evaluation happens through **actual gameplay**, not a static test set

---

### Key decisions (explicit and deliberate)

- **TRAIN is used for gradient updates**
- **TEST is used as validation**
  - Monitors `val_loss` during training
  - Drives early stopping
  - Used to compare architectures and learning rates
- **TEST is not considered a final benchmark**
  - It is allowed to influence model selection
  - It will be reincorporated later for final training
- **True evaluation will be gameplay performance**
  - Elo, win rate, or head-to-head matches vs baselines

---

### Early stopping behavior

- Early stopping monitors **`val_loss` on TEST**
- Training stops when TEST loss no longer improves
- Best weights (lowest TEST loss) are restored automatically
- This reveals:
  - Effective capacity of each architecture
  - Which models generalize better beyond TRAIN

---

### Loss weighting

- Move-depth loss weights (`sample_weights`) are applied during training
- This ensures:
  - Balanced learning across early, mid, and late game positions
  - No resampling or data duplication
- The same weighting scheme is used consistently across all models

---

### What this step produces

For each model in the grid, we record:
- Training history (loss & accuracy curves)
- Best validation (TEST) loss
- Epoch where early stopping occurred
- Training time

These outputs are used to:
- Select the **best architecture and learning rate**
- Decide a **reasonable epoch count** for final training
- Determine whether further tuning or cross-validation is warranted

---

### What this step does *not* do

- Does **not** provide a final unbiased performance estimate
- Does **not** evaluate gameplay strength
- Does **not** claim generalization beyond this dataset

All final claims about strength will come from **self-play or match results**,
not this TEST split.

In [None]:
# ======================================================
# STEP 6 — TRAIN MODEL GRID WITH EARLY STOPPING
# (TEST USED AS VALIDATION + PERSISTENT SAVES)
# ======================================================

from tensorflow.keras.callbacks import EarlyStopping
import time
import numpy as np
import pandas as pd
import pickle
import os

# ------------------------------------------------------
# Training configuration
# ------------------------------------------------------

EPOCHS_MAX = 100
BATCH_SIZE = 512
PATIENCE = 10
FILTERS = 128   # fixed across all models

results = []            # summary per model
histories = {}          # full training curves
trained_models = {}     # trained model objects

print("\nSaving outputs to:")
print(f"  RESULTS_DIR = {RESULTS_DIR}")
print(f"  MODELS_DIR  = {MODELS_DIR}")

# ------------------------------------------------------
# Train each model
# ------------------------------------------------------

for i, cfg in enumerate(MODEL_GRID, start=1):
    print("\n" + "=" * 70)
    print(f"Training model {i}/{len(MODEL_GRID)} | config = {cfg}")
    print("=" * 70)

    model = build_residual_cnn(
        input_shape=(6, 7, 2),
        num_blocks=cfg["num_blocks"],
        filters=FILTERS,
        learning_rate=cfg["learning_rate"],
    )

    # Early stopping based on VALIDATION (TEST) loss
    early_stop = EarlyStopping(
        monitor="val_loss",
        patience=PATIENCE,
        restore_best_weights=True,
        verbose=1,
    )

    start_time = time.time()

    history = model.fit(
        X_train,
        [policy_train, value_train],
        sample_weight=[w_train, w_train],
        validation_data=(
            X_test,
            [policy_test, value_test],
            [w_test, w_test],
        ),
        epochs=EPOCHS_MAX,
        batch_size=BATCH_SIZE,
        callbacks=[early_stop],
        verbose=2,
        shuffle=True,
    )

    elapsed = time.time() - start_time

    # --------------------------------------------------
    # Record validation-driven results
    # --------------------------------------------------

    best_epoch = int(np.argmin(history.history["val_loss"]) + 1)
    best_val_loss = float(np.min(history.history["val_loss"]))

    results.append({
        "model_id": i,
        "num_blocks": cfg["num_blocks"],
        "filters": FILTERS,
        "learning_rate": cfg["learning_rate"],
        "best_epoch_val": best_epoch,
        "best_val_loss": best_val_loss,
        "train_time_sec": elapsed,
    })

    histories[i] = history
    trained_models[i] = model

    print(
        f"\nModel {i} summary:"
        f"\n  Residual blocks  : {cfg['num_blocks']}"
        f"\n  Filters          : {FILTERS}"
        f"\n  Learning rate    : {cfg['learning_rate']}"
        f"\n  Best epoch (val) : {best_epoch}"
        f"\n  Best val loss    : {best_val_loss:.4f}"
        f"\n  Train time (s)   : {elapsed:.1f}"
    )

# ------------------------------------------------------
# Results table (sorted by validation loss)
# ------------------------------------------------------

results_df = (
    pd.DataFrame(results)
    .sort_values("best_val_loss")
    .reset_index(drop=True)
)

print("\n=== MODEL COMPARISON (VALIDATION / TEST) ===")
display(results_df)

# ------------------------------------------------------
# PERSIST RESULTS TO GOOGLE DRIVE
# ------------------------------------------------------

# 1) Save results table
results_csv_path = os.path.join(RESULTS_DIR, "model_grid_results.csv")
results_df.to_csv(results_csv_path, index=False)

# 2) Save full training histories
histories_path = os.path.join(RESULTS_DIR, "training_histories.pkl")
with open(histories_path, "wb") as f:
    pickle.dump(histories, f)

# 3) Save best model (lowest validation loss)
best_model_id = int(results_df.iloc[0]["model_id"])
best_model_path = os.path.join(MODELS_DIR, "best_model.keras")
trained_models[best_model_id].save(best_model_path)

print("\nSaved artifacts:")
print(f"  Results table     → {results_csv_path}")
print(f"  Training histories→ {histories_path}")
print(f"  Best model        → {best_model_path}")


Saving outputs to:
  RESULTS_DIR = /content/drive/MyDrive/Colab Notebooks/connect-4/results
  MODELS_DIR  = /content/drive/MyDrive/Colab Notebooks/connect-4/models

Training model 1/4 | config = {'num_blocks': 8, 'learning_rate': 0.001}
Epoch 1/100
2198/2198 - 85s - 39ms/step - loss: 1.5890 - policy_accuracy: 0.5251 - policy_loss: 1.4008 - val_loss: 1.3949 - val_policy_accuracy: 0.6444 - val_policy_loss: 1.2677 - val_value_loss: 0.1270 - val_value_mse: 0.0658 - value_loss: 0.1884 - value_mse: 0.1029
Epoch 2/100
2198/2198 - 18s - 8ms/step - loss: 1.3372 - policy_accuracy: 0.6673 - policy_loss: 1.2377 - val_loss: 1.3464 - val_policy_accuracy: 0.6690 - val_policy_loss: 1.2368 - val_value_loss: 0.1094 - val_value_mse: 0.0696 - value_loss: 0.0998 - value_mse: 0.0478
Epoch 3/100
2198/2198 - 17s - 8ms/step - loss: 1.2655 - policy_accuracy: 0.7010 - policy_loss: 1.2055 - val_loss: 1.2681 - val_policy_accuracy: 0.7113 - val_policy_loss: 1.2026 - val_value_loss: 0.0652 - val_value_mse: 0.0434 -

Unnamed: 0,model_id,num_blocks,filters,learning_rate,best_epoch_val,best_val_loss,train_time_sec
0,2,10,128,0.001,12,1.181843,493.142571
1,1,8,128,0.001,32,1.183266,804.454851
2,4,10,128,0.0003,31,1.187543,914.900691
3,3,8,128,0.0003,17,1.198297,524.967894



Saved artifacts:
  Results table     → /content/drive/MyDrive/Colab Notebooks/connect-4/results/model_grid_results.csv
  Training histories→ /content/drive/MyDrive/Colab Notebooks/connect-4/results/training_histories.pkl
  Best model        → /content/drive/MyDrive/Colab Notebooks/connect-4/models/best_model.keras


## Step 7 — Targeted Capacity Probe (Model 2 + Wider Network)

In this step we run a **single, focused experiment** to test whether additional
representational capacity improves performance.

### Motivation

From Step 6, **Model 2** emerged as the best-performing architecture:
- 10 residual blocks
- 128 filters
- learning rate = 1e-3
- fast convergence and lowest validation loss

This suggests the model may be **capacity-limited**, not optimization-limited.

### What we change (and only this)

- Increase network width:
  - **filters: 128 → 256**
- Reduce batch size to maintain stability and avoid OOM:
  - **batch size: 512 → 256**

### What we keep fixed

- Architecture (residual structure, heads)
- Number of residual blocks (10)
- Learning rate (1e-3)
- Loss functions and move-depth weighting
- TRAIN / TEST split
- Early stopping configuration

### Evaluation

- Train the wider model using TEST as validation (same as Step 6)
- Load the previously saved `best_model.keras`
- Compare:
  - Best validation loss (primary)
  - Convergence behavior (epoch count)

This isolates the effect of **width alone** and tells us whether a wider network
is worth adopting before moving to self-play or further architectural changes.

In [None]:
# ======================================================
# STEP 7 — TRAIN WIDER MODEL (MODEL 2 + 256 FILTERS)
# + COMPARE AGAINST SAVED BEST MODEL
# ======================================================

from tensorflow.keras.callbacks import EarlyStopping
import numpy as np
import time
import os
import tensorflow as tf

# ------------------------------------------------------
# Configuration (ONLY width & batch size change)
# ------------------------------------------------------

NUM_BLOCKS = 10           # same as Model 2
LEARNING_RATE = 1e-3      # same as Model 2
FILTERS_WIDE = 256        # wider network
BATCH_SIZE_WIDE = 256     # reduced batch size
EPOCHS_MAX = 100
PATIENCE = 10

print("\nRunning targeted capacity probe:")
print(f"  Residual blocks : {NUM_BLOCKS}")
print(f"  Filters         : {FILTERS_WIDE}")
print(f"  Learning rate   : {LEARNING_RATE}")
print(f"  Batch size      : {BATCH_SIZE_WIDE}")

# ------------------------------------------------------
# Build wide model
# ------------------------------------------------------

wide_model = build_residual_cnn(
    input_shape=(6, 7, 2),
    num_blocks=NUM_BLOCKS,
    filters=FILTERS_WIDE,
    learning_rate=LEARNING_RATE,
)

early_stop = EarlyStopping(
    monitor="val_loss",
    patience=PATIENCE,
    restore_best_weights=True,
    verbose=1,
)

# ------------------------------------------------------
# Train
# ------------------------------------------------------

start_time = time.time()

wide_history = wide_model.fit(
    X_train,
    [policy_train, value_train],
    sample_weight=[w_train, w_train],
    validation_data=(
        X_test,
        [policy_test, value_test],
        [w_test, w_test],
    ),
    epochs=EPOCHS_MAX,
    batch_size=BATCH_SIZE_WIDE,
    callbacks=[early_stop],
    verbose=2,
    shuffle=True,
)

elapsed = time.time() - start_time

best_epoch_wide = int(np.argmin(wide_history.history["val_loss"]) + 1)
best_val_loss_wide = float(np.min(wide_history.history["val_loss"]))

print("\nWide model summary:")
print(f"  Best epoch (val) : {best_epoch_wide}")
print(f"  Best val loss    : {best_val_loss_wide:.6f}")
print(f"  Train time (s)   : {elapsed:.1f}")

# ------------------------------------------------------
# Load previously saved best model (from Step 6)
# ------------------------------------------------------

best_model_path = os.path.join(MODELS_DIR, "best_model.keras")
best_model = tf.keras.models.load_model(best_model_path)

best_model_val = best_model.evaluate(
    X_test,
    [policy_test, value_test],
    sample_weight=[w_test, w_test],
    verbose=0,
)

best_model_val_loss = float(best_model_val[0])

# ------------------------------------------------------
# Comparison
# ------------------------------------------------------

print("\n=== VALIDATION LOSS COMPARISON ===")
print(f"  Previous best model (128 filters): {best_model_val_loss:.6f}")
print(f"  Wide model (256 filters)          : {best_val_loss_wide:.6f}")

delta = best_val_loss_wide - best_model_val_loss
print(f"\n  Δ val_loss (wide - previous) = {delta:+.6f}")

if delta < 0:
    print("Wider model improves validation loss.")
else:
    print("Wider model does NOT improve validation loss.")

# ------------------------------------------------------
# Optional: save wide model if it wins
# ------------------------------------------------------

if delta < 0:
    wide_model_path = os.path.join(MODELS_DIR, "best_model_256f.keras")
    wide_model.save(wide_model_path)
    print(f"\nSaved improved wide model → {wide_model_path}")


Running targeted capacity probe:
  Residual blocks : 10
  Filters         : 256
  Learning rate   : 0.001
  Batch size      : 256
Epoch 1/100
4395/4395 - 124s - 28ms/step - loss: 1.5364 - policy_accuracy: 0.5590 - policy_loss: 1.3662 - val_loss: 1.4283 - val_policy_accuracy: 0.6682 - val_policy_loss: 1.2422 - val_value_loss: 0.1860 - val_value_mse: 0.1265 - value_loss: 0.1703 - value_mse: 0.0898
Epoch 2/100
4395/4395 - 53s - 12ms/step - loss: 1.2897 - policy_accuracy: 0.6925 - policy_loss: 1.2172 - val_loss: 1.2481 - val_policy_accuracy: 0.7157 - val_policy_loss: 1.2000 - val_value_loss: 0.0480 - val_value_mse: 0.0251 - value_loss: 0.0726 - value_mse: 0.0380
Epoch 3/100
4395/4395 - 52s - 12ms/step - loss: 1.2275 - policy_accuracy: 0.7229 - policy_loss: 1.1898 - val_loss: 1.2293 - val_policy_accuracy: 0.7349 - val_policy_loss: 1.1878 - val_value_loss: 0.0414 - val_value_mse: 0.0272 - value_loss: 0.0378 - value_mse: 0.0231
Epoch 4/100
4395/4395 - 52s - 12ms/step - loss: 1.2026 - policy_

## Step 7 — Post-Training Diagnostics & Model Comparison

At this point, we have trained and saved multiple AlphaZero-style
policy/value networks using the same dataset, loss weighting, and
training protocol.

Two candidate models are now available:
- A baseline model (128 filters)
- A wider model (256 filters)

The validation loss difference between these models is real but small,
suggesting that **model capacity is no longer the dominant bottleneck**.

Before training larger or more complex architectures, we now shift from
*model search* to *model diagnosis*.

The purpose of this step is to understand **what the models have learned,
where they differ, and what signal they may be missing**.

---

### Why we pause architecture changes here

Adding more layers, filters, or dense heads only helps if:
- the model is under-parameterized **relative to the learning signal**
- or the model is failing to fit clear structure in the data

A small improvement from doubling filters suggests:
- the network can use extra capacity
- but the **supervision signal may be too coarse**, especially early in the game

This is a classic AlphaZero failure mode:  
the network learns late-game outcomes well, but struggles to extract
useful value signal from early or ambiguous positions.

---

### What is `Q` and why it matters

In this dataset, each position includes a `Q` value from MCTS.

Conceptually:
- **`value`** is typically a game outcome target  
  (e.g., win/loss or final result propagated back)
- **`Q`** is the **expected value estimate from search**, averaged over simulations

Key difference:
- `value` answers: *“Who eventually won?”*
- `Q` answers: *“How good is this position according to search?”*

In AlphaZero-style systems:
- `Q` is often a **richer, lower-variance signal**
- especially for early-game positions where outcomes are far away

Using `Q` (or a blend of `Q` and outcome value) as the value target can:
- improve value calibration
- stabilize early-game learning
- reduce noisy gradients from distant outcomes

However, this should **only be done after confirming** that:
- the current value head is miscalibrated
- or poorly correlated with search estimates

---

### What we will do next (no retraining yet)

We will **evaluate and compare the two trained models** on the same
validation set along three diagnostic axes:

1. **Value calibration**
   - Correlation between model value output and MCTS `Q`
   - Error by move depth (early / mid / late game)

2. **Policy sharpness**
   - Entropy of predicted policy distributions
   - KL divergence vs MCTS visit distribution (if available)

3. **Early-game confidence**
   - How quickly value and policy predictions stabilize as depth increases
   - Whether the wider model improves early-game signal or only late-game fit

These diagnostics tell us **why** the wider model helped and whether
improving the value target is the correct next step.

---

### Possible outcomes and decisions

- If both models show weak value–Q correlation:
  → Adjust the value target (use `Q` or a hybrid target)

- If the wider model improves early-game calibration:
  → Capacity matters, consider scaling carefully

- If both models behave similarly:
  → Architecture is not the bottleneck; move to better targets or gameplay

Only after this analysis do we consider:
- retraining with `Q` as a value target
- or proceeding directly to head-to-head gameplay evaluation

No new data is required for this step.

In [None]:
# ============================================================
# STEP 7 — LOAD MODELS + USE EXISTING TEST DATA
# ============================================================

import tensorflow as tf
import numpy as np
import os

# ------------------------------------------------------------
# Paths
# ------------------------------------------------------------

MODELS_DIR = "/content/drive/MyDrive/Colab Notebooks/connect-4/models"

MODEL_128_PATH = os.path.join(MODELS_DIR, "best_model.keras")
MODEL_256_PATH = os.path.join(MODELS_DIR, "best_model_256f.keras")

# ------------------------------------------------------------
# Load models
# ------------------------------------------------------------

model_128 = tf.keras.models.load_model(MODEL_128_PATH)
model_256 = tf.keras.models.load_model(MODEL_256_PATH)

print("Models loaded successfully.")

# ------------------------------------------------------------
# Use TEST split already created in Step 4
# ------------------------------------------------------------
# These MUST already exist in memory:
#   X_test
#   policy_test
#   value_test
#   w_test
#   bin_test  (optional)
#   mirrored_data["q"] (optional but important)

X_eval = X_test
policy_eval = policy_test
value_eval = value_test

print("\nEvaluation dataset:")
print("  X_eval shape     :", X_eval.shape)
print("  policy_eval shape:", policy_eval.shape)
print("  value_eval shape :", value_eval.shape)

# ------------------------------------------------------------
# Extract Q-values aligned with TEST split (if available)
# ------------------------------------------------------------

if "q" in mirrored_data:
    Q_all = mirrored_data["q"]
    Q_eval = Q_all[bin_test.index] if False else Q_all[:len(Q_all)]  # placeholder safety
    print("  Q available      :", True)
else:
    Q_eval = None
    print("  Q available      :", False)

# ------------------------------------------------------------
# Compute move depth (stones on board)
# ------------------------------------------------------------

move_depth_eval = np.count_nonzero(
    X_eval[..., 0] + X_eval[..., 1],
    axis=(1, 2),
)

print("  move_depth_eval  :", move_depth_eval.shape)

# ------------------------------------------------------------
# Forward pass sanity check
# ------------------------------------------------------------

policy_128, value_128 = model_128.predict(X_eval[:128], verbose=0)
policy_256, value_256 = model_256.predict(X_eval[:128], verbose=0)

print("\nSanity check:")
print("  policy output shape:", policy_128.shape)
print("  value output shape :", value_128.shape)

Models loaded successfully.

Evaluation dataset:
  X_eval shape     : (281245, 6, 7, 2)
  policy_eval shape: (281245, 7)
  value_eval shape : (281245, 1)
  Q available      : True
  move_depth_eval  : (281245,)

Sanity check:
  policy output shape: (128, 7)
  value output shape : (128, 1)


## Step 8 — Value Head Calibration vs Search Estimates (Diagnostic)

In this step we **evaluate how well the trained value head agrees with MCTS search**.
This is a **pure diagnostic step** — no retraining, no model selection, no architecture changes.

The goal is to determine whether the **value target itself is limiting performance**, or whether the models have already extracted essentially all useful signal from the data.

---

### What we are comparing

For each trained model (128 filters and 256 filters), we compare:

- **Model value output**  
  The scalar value predicted by the network for each position.

Against two search-derived references:

- **Search-expected Q**  
  The policy-weighted average of Q-values from MCTS.  
  This reflects how good the position is *according to search*, not just the final outcome.

- **Max Q (optimistic bound)**  
  The best Q-value available from search.  
  This is not a training target — it is used only to understand optimism vs conservatism.

---

### Why this matters

The value head is usually trained from game outcomes, which means:

- Early-game positions can have **high variance**
- Learning signal can be **slow to propagate backward**
- The gradient can be noisy even with lots of data

Meanwhile, MCTS Q-values represent **search-refined evaluations** that already incorporate lookahead.

If the value head is poorly aligned with Q:
- Using Q (or a blend of outcome value + Q) as the value target can help
- This is a common AlphaZero refinement step

If the value head already matches Q well:
- Changing the value target will not materially improve performance
- The bottleneck lies elsewhere (policy quality, search, or self-play)

---

### What we measure

On the held-out evaluation set we compute:

1. **Global calibration**
   - Correlation between predicted value and expected Q
   - Mean absolute error (MAE)

2. **Optimism check**
   - Correlation and MAE vs max-Q

3. **Depth-aware calibration**
   - Same metrics computed separately for early, mid, and late game positions
   - Reveals whether value quality degrades with move depth

---

### How to interpret the results

- **Very high correlation and low MAE across all depths**  
  → The value head is already learning the search signal well.  
  → Switching to a Q-based value target is unlikely to help.

- **Strong late-game calibration but weak early-game calibration**  
  → Value target may be too noisy early; Q-based targets could help.

- **Similar behavior between 128f and 256f models**  
  → Capacity is not the main bottleneck; supervision quality likely is.

If Step 8 looks strong (as it does here), the correct next steps are **policy diagnostics**, **head-to-head gameplay**, or **self-play iteration**, not further value-head tuning.

In [None]:
# ============================================================
# STEP 8 (FIXED & HARDENED) — VALUE CALIBRATION VS SEARCH Q
# ============================================================

import numpy as np
from scipy.stats import pearsonr

# ------------------------------------------------------------
# Use TEST split explicitly
# ------------------------------------------------------------

X_eval = X_test
policy_eval = policy_test      # MCTS policy targets (π_MCTS)
value_target_eval = value_test # outcome-based value targets
Q_eval = Q_test                # MCTS Q(s,a)
depth_eval = move_depth_eval

# ------------------------------------------------------------
# Sanity checks (DO NOT SKIP)
# ------------------------------------------------------------

N = X_eval.shape[0]

assert policy_eval.shape == (N, 7)
assert Q_eval.shape == (N, 7)
assert depth_eval.shape == (N,)

print("Sanity checks passed.")
print(f"N = {N:,}")

# ------------------------------------------------------------
# Predict model outputs
# ------------------------------------------------------------

policy_128, value_128 = model_128.predict(
    X_eval, batch_size=1024, verbose=0
)
policy_256, value_256 = model_256.predict(
    X_eval, batch_size=1024, verbose=0
)

value_128 = value_128.squeeze()
value_256 = value_256.squeeze()

# ------------------------------------------------------------
# Construct scalar V_Q(s) from MCTS
# ------------------------------------------------------------

# Expected value under search policy
VQ_expected = np.sum(policy_eval * Q_eval, axis=1)

# Optimistic bound (best action according to search)
VQ_max = np.max(Q_eval, axis=1)

# ------------------------------------------------------------
# Global calibration
# ------------------------------------------------------------

def summarize(name, pred, target):
    corr = pearsonr(pred, target)[0]
    mae = np.mean(np.abs(pred - target))
    print(f"{name:10s} | corr: {corr:.4f} | MAE: {mae:.4f}")

print("\n=== VALUE vs SEARCH-EXPECTED Q ===")
summarize("128f", value_128, VQ_expected)
summarize("256f", value_256, VQ_expected)

print("\n=== VALUE vs MAX Q (optimistic upper bound) ===")
summarize("128f", value_128, VQ_max)
summarize("256f", value_256, VQ_max)

# ------------------------------------------------------------
# Calibration by move depth
# ------------------------------------------------------------

def summarize_by_depth(pred, target, depth, label):
    bins = np.percentile(depth, [0, 25, 50, 75, 100]).astype(int)

    print(f"\n--- {label} ---")
    for i in range(4):
        lo, hi = bins[i], bins[i+1]
        mask = (depth >= lo) & (depth < hi)
        if mask.sum() == 0:
            continue

        corr = pearsonr(pred[mask], target[mask])[0]
        mae = np.mean(np.abs(pred[mask] - target[mask]))

        print(
            f"Moves {lo:2d}–{hi:2d} | "
            f"samples: {mask.sum():6d} | "
            f"corr: {corr:.3f} | MAE: {mae:.3f}"
        )

summarize_by_depth(value_128, VQ_expected, depth_eval, "128f vs VQ")
summarize_by_depth(value_256, VQ_expected, depth_eval, "256f vs VQ")

Sanity checks passed.
N = 281,245

=== VALUE vs SEARCH-EXPECTED Q ===
128f       | corr: 0.9816 | MAE: 0.0540
256f       | corr: 0.9834 | MAE: 0.0584

=== VALUE vs MAX Q (optimistic upper bound) ===
128f       | corr: 0.9282 | MAE: 0.1332
256f       | corr: 0.9279 | MAE: 0.1104

--- 128f vs VQ ---
Moves  0–11 | samples:  65470 | corr: 0.992 | MAE: 0.035
Moves 11–15 | samples:  74025 | corr: 0.986 | MAE: 0.048
Moves 15–20 | samples:  67201 | corr: 0.984 | MAE: 0.057
Moves 20–41 | samples:  74292 | corr: 0.975 | MAE: 0.075

--- 256f vs VQ ---
Moves  0–11 | samples:  65470 | corr: 0.993 | MAE: 0.041
Moves 11–15 | samples:  74025 | corr: 0.987 | MAE: 0.054
Moves 15–20 | samples:  67201 | corr: 0.985 | MAE: 0.062
Moves 20–41 | samples:  74292 | corr: 0.978 | MAE: 0.076


## Phase 2 — Final Supervised Model + Gameplay Evaluation

Now that we finished hyperparameter/model-size search using an 80/20 split,
we will:

1) Train a final supervised model on 100% of the mirrored dataset
   (TRAIN + TEST combined), using the best architecture we found.

2) Evaluate playing strength by running head-to-head matches against:
   - Random policy
   - Pure MCTS with random leaf evaluation
   - NN-guided MCTS (policy priors + value head)

This gives a gameplay-based check of whether improvements in validation loss
actually translate into stronger Connect-4 play.

In [None]:
# ======================================================
# FINAL SUPERVISED TRAINING ON 100% MIRRORED DATA
# (TRAIN + TEST COMBINED) — FIXED EPOCH SCHEDULE (NO EARLY STOP)
# ======================================================

import os
import numpy as np
import tensorflow as tf
from tensorflow.keras.callbacks import ReduceLROnPlateau
import time

# ------------------------------
# Final architecture choice
# You said: Model 2 + 256 filters
# ------------------------------
FINAL_NUM_BLOCKS = 10
FINAL_FILTERS = 256
FINAL_LR = 1e-3
FINAL_WEIGHT_DECAY = 1e-4

# ------------------------------
# Training knobs (NO VALIDATION => NO EARLY STOPPING)
# Pick a fixed epoch count based on earlier "best epoch" signals.
# Recommended starting points:
#   - 25 epochs (conservative)
#   - 35 epochs (more aggressive)
# ------------------------------
FINAL_BATCH_SIZE = 256
FINAL_EPOCHS = 25

# Output paths
FINAL_MODEL_DIR = MODELS_DIR
FINAL_MODEL_KERAS = os.path.join(FINAL_MODEL_DIR, "final_supervised_256f.keras")
FINAL_MODEL_H5    = os.path.join(FINAL_MODEL_DIR, "final_supervised_256f.h5")

print("Final model will be saved to:")
print(" ", FINAL_MODEL_KERAS)
print(" ", FINAL_MODEL_H5)

print("\nFinal training config:")
print(f"  blocks     : {FINAL_NUM_BLOCKS}")
print(f"  filters    : {FINAL_FILTERS}")
print(f"  lr         : {FINAL_LR}")
print(f"  weight_decay: {FINAL_WEIGHT_DECAY}")
print(f"  batch_size : {FINAL_BATCH_SIZE}")
print(f"  epochs     : {FINAL_EPOCHS}")

# ------------------------------
# 1) Build full arrays from mirrored_data (100% of mirrored set)
# ------------------------------
X_full      = mirrored_data["X"].astype(np.float32)
policy_full = mirrored_data["policy"].astype(np.float32)
value_full  = mirrored_data["value"].astype(np.float32)

# sample_weights returned by compute_move_balance_weights aligns with mirrored_data
w_full = sample_weights.astype(np.float32)

print("\nFull mirrored dataset:")
print("  X      :", X_full.shape)
print("  policy :", policy_full.shape)
print("  value  :", value_full.shape)
print("  weights:", w_full.shape)

# ------------------------------
# 2) Build and train final model (fixed schedule)
# ------------------------------
final_model = build_residual_cnn(
    input_shape=(6, 7, 2),
    num_blocks=FINAL_NUM_BLOCKS,
    filters=FINAL_FILTERS,
    learning_rate=FINAL_LR,
    weight_decay=FINAL_WEIGHT_DECAY,
)

# Keep LR reduction (fine without validation; it watches training loss)
callbacks = [
    ReduceLROnPlateau(
        monitor="loss",
        factor=0.5,
        patience=2,
        min_lr=1e-5,
        verbose=1,
    )
]

start = time.time()

history = final_model.fit(
    X_full,
    [policy_full, value_full],
    sample_weight=[w_full, w_full],   # weight BOTH policy and value (your choice)
    epochs=FINAL_EPOCHS,
    batch_size=FINAL_BATCH_SIZE,
    callbacks=callbacks,
    shuffle=True,
    verbose=2,
)

elapsed = time.time() - start
final_loss = float(history.history["loss"][-1])
best_loss = float(np.min(history.history["loss"]))
best_epoch = int(np.argmin(history.history["loss"]) + 1)

print("\nFinal supervised training complete.")
print(f"  final_loss       = {final_loss:.6f}")
print(f"  best_epoch(loss) = {best_epoch}")
print(f"  best_loss        = {best_loss:.6f}")
print(f"  train_time_sec   = {elapsed:.1f}")

# ------------------------------
# 3) Save both formats for teammate / portability
# ------------------------------
final_model.save(FINAL_MODEL_KERAS)
print("\nSaved:", FINAL_MODEL_KERAS)

final_model.save(FINAL_MODEL_H5)
print("Saved:", FINAL_MODEL_H5)

Final model will be saved to:
  /content/drive/MyDrive/Colab Notebooks/connect-4/models/final_supervised_256f.keras
  /content/drive/MyDrive/Colab Notebooks/connect-4/models/final_supervised_256f.h5

Final training config:
  blocks     : 10
  filters    : 256
  lr         : 0.001
  weight_decay: 0.0001
  batch_size : 256
  epochs     : 25

Full mirrored dataset:
  X      : (1406222, 6, 7, 2)
  policy : (1406222, 7)
  value  : (1406222, 1)
  weights: (1406222,)
Epoch 1/25
5494/5494 - 121s - 22ms/step - loss: 1.5033 - policy_accuracy: 0.5772 - policy_loss: 1.3449 - value_loss: 0.1584 - value_mse: 0.0829 - learning_rate: 1.0000e-03
Epoch 2/25
5494/5494 - 61s - 11ms/step - loss: 1.2765 - policy_accuracy: 0.7008 - policy_loss: 1.2116 - value_loss: 0.0649 - value_mse: 0.0336 - learning_rate: 1.0000e-03
Epoch 3/25
5494/5494 - 60s - 11ms/step - loss: 1.2174 - policy_accuracy: 0.7302 - policy_loss: 1.1848 - value_loss: 0.0326 - value_mse: 0.0199 - learning_rate: 1.0000e-03
Epoch 4/25
5494/5494 




Saved: /content/drive/MyDrive/Colab Notebooks/connect-4/models/final_supervised_256f.keras
Saved: /content/drive/MyDrive/Colab Notebooks/connect-4/models/final_supervised_256f.h5


## Gameplay Evaluation — Using the Final Supervised Policy Network

At this stage, we have produced a **final supervised policy/value network**
trained on **100% of the mirrored MCTS dataset** and exported as a `.h5` file
for portability.

The goal now is **not training**, but **evaluation**.

Specifically, we want to answer a simple question:

> *Is this network actually good at playing Connect 4?*

To do that, we will:
- Load the final `.h5` model
- Use **policy inference only**
- Select moves via **argmax over legal actions**
- Play full games against a variety of opponents
- Measure outcomes (win / loss / draw)

Important constraints:
- **No temperature**
- **No heuristics** (no hard-coded win/block logic)
- **No retraining**
- **No self-play yet**

This ensures we are evaluating the network **as-is**, exactly as it would be
used by another teammate loading the `.h5`.

### Evaluation opponents (initial set)

We will test against:
1. **Random player**
2. **Random + legality-aware player** (uniform over legal moves)
3. **MCTS opponent** (configurable rollout count, e.g. 800)

The MCTS rollout count will be a **single tunable parameter** at the top
of the gameplay cell so results are easy to interpret and reproduce.

The network’s value head will be ignored for now.
Only the policy head will be used to choose actions.

In [28]:
# ======================================================
# LOAD FINAL MODEL + POLICY INFERENCE (INFERENCE-ONLY)
# ======================================================

import numpy as np
import tensorflow as tf

FINAL_MODEL_H5 = "/content/drive/MyDrive/Colab Notebooks/connect-4/models/final_supervised_256f.h5"

# IMPORTANT:
# compile=False avoids Keras trying to reload training losses/metrics (e.g. "mse")
# This is the correct way to load a model for inference only.
model = tf.keras.models.load_model(
    FINAL_MODEL_H5,
    compile=False
)
print("Final model loaded (inference-only).")

NUM_ROWS = 6
NUM_COLS = 7

def policy_move(model, board):
    """
    Select a move using pure argmax(policy).

    Parameters
    ----------
    model : tf.keras.Model
        Loaded policy/value network
    board : np.ndarray
        Shape (6, 7, 2), float32
        board[..., 0] = current player stones
        board[..., 1] = opponent stones

    Returns
    -------
    int
        Column index [0–6]
    """
    # Add batch dimension
    x = board[None, ...]  # (1, 6, 7, 2)

    # Forward pass
    policy, _ = model.predict(x, verbose=0)
    policy = policy[0]  # (7,)

    # Mask illegal columns (full columns)
    col_heights = board[:, :, 0].sum(axis=0) + board[:, :, 1].sum(axis=0)
    illegal = col_heights >= NUM_ROWS

    policy = policy.copy()
    policy[illegal] = -np.inf

    # Pure argmax (no temperature, no sampling)
    move = int(np.argmax(policy))

    return move

Final model loaded (inference-only).


## Evaluation Gameplay: Argmax Policy vs (Win/Block + MCTS)

**Goal:** measure how strong your *policy network alone* is as a move selector.

**Your agent (you / model):**
- Plays **pure argmax(policy)** every turn
- No heuristics, no search

**Opponent:**
1. If it has an immediate winning move → take it  
2. Else if you have an immediate winning move next → block it  
3. Else run **MCTS** with **tunable rollouts** (e.g., 800)  
4. Optional: for the first `RANDOM_OPENING_PLIES`, opponent plays random moves (also tunable)

**Protocol:**
- Play `GAMES_PER_MATCHUP` games
- Alternate who goes first (50/50)
- Report W/L/D from the model’s perspective

In [40]:
# ======================================================
# CONNECT-4 GAMEPLAY EVAL (TUNABLE MODEL + OPPONENT MCTS)
# ======================================================

import numpy as np
import math
import time
from dataclasses import dataclass
from typing import List, Optional, Dict

# ======================================================
# TUNABLES
# ======================================================

GAMES_PER_MATCHUP      = 20

# ---- YOUR MODEL ----
MODEL_MCTS_ROLLOUTS    = 100       # 0 = pure argmax, >0 = model + MCTS
MODEL_MCTS_C_PUCT     = 1.4

# ---- OPPONENT ----
OPPONENT_TYPE          = "mcts"  # "random" | "mcts"
OPPONENT_MCTS_ROLLOUTS = 800
OPPONENT_RANDOM_PLIES  = 5

# ---- GLOBAL ----
ROLLOUT_MAX_PLIES      = 42
SEED_EVAL              = 7

rng = np.random.default_rng(SEED_EVAL)

# ======================================================
# BITBOARD HELPERS
# ======================================================

ROWS, COLS = 6, 7

def bit_index(r, c): return r * COLS + c

def check_win_bb(bb):
    for s in (1, 7, 6, 8):
        m = bb & (bb >> s)
        if m & (m >> (2 * s)):
            return True
    return False

def legal_moves(h): return [c for c in range(COLS) if h[c] < ROWS]

def apply_move(p1, p2, h, turn, col):
    b = 1 << bit_index(h[col], col)
    if turn == 1: p1 |= b
    else: p2 |= b
    h[col] += 1
    return p1, p2, -turn

def is_draw(h): return all(x == ROWS for x in h)

def to_model_planes(p1, p2, turn):
    cur = p1 if turn == 1 else p2
    opp = p2 if turn == 1 else p1
    board = np.zeros((6, 7, 2), dtype=np.float32)
    for i in range(42):
        b = 1 << i
        r, c = divmod(i, 7)
        if cur & b: board[r, c, 0] = 1
        elif opp & b: board[r, c, 1] = 1
    return board

# ======================================================
# MODEL POLICY
# ======================================================

def policy_move(model, board):
    p, _ = model.predict(board[None, ...], verbose=0)
    p = p[0]
    illegal = board[:, :, 0].sum(0) + board[:, :, 1].sum(0) == 6
    p[illegal] = -np.inf
    return int(np.argmax(p))

# ======================================================
# MCTS CORE
# ======================================================

@dataclass
class MCTSNode:
    visits: int = 0
    value_sum: float = 0.0
    children: Dict[int, "MCTSNode"] = None
    def __post_init__(self): self.children = self.children or {}
    @property
    def value(self): return 0 if self.visits == 0 else self.value_sum / self.visits

def uct(parent_n, child, c):
    if child.visits == 0: return float("inf")
    return child.value + c * math.sqrt(math.log(parent_n + 1) / child.visits)

def model_leaf_value(model, p1, p2, turn):
    _, v = model.predict(to_model_planes(p1, p2, turn)[None, ...], verbose=0)
    return float(v[0, 0])

def mcts(model, p1, p2, h, turn, rollouts, c, use_model_value):
    root = MCTSNode()
    root_player = turn

    for m in legal_moves(h):
        root.children[m] = MCTSNode()

    for _ in range(rollouts):
        P1, P2, H, T = p1, p2, h.copy(), turn
        node, path = root, []

        while True:
            if check_win_bb(P1) or check_win_bb(P2) or is_draw(H): break
            moves = legal_moves(H)
            for m in moves:
                node.children.setdefault(m, MCTSNode())
            parent_n = max(node.visits, 1)
            m = max(moves, key=lambda x: uct(parent_n, node.children[x], c))
            path.append((node, m))
            node = node.children[m]
            P1, P2, T = apply_move(P1, P2, H, T, m)

        if check_win_bb(P1): outcome = 1 if root_player == 1 else -1
        elif check_win_bb(P2): outcome = 1 if root_player == -1 else -1
        elif is_draw(H): outcome = 0
        else:
            outcome = model_leaf_value(model, P1, P2, T) if use_model_value else 0

        root.visits += 1
        root.value_sum += outcome
        for n, m in path:
            n.visits += 1
            n.value_sum += outcome
            n.children[m].visits += 1
            n.children[m].value_sum += outcome

    return max(root.children, key=lambda m: root.children[m].visits)

# ======================================================
# MOVE SELECTION
# ======================================================

def agent_move(model, p1, p2, h, turn):
    if MODEL_MCTS_ROLLOUTS == 0:
        return policy_move(model, to_model_planes(p1, p2, turn))
    return mcts(model, p1, p2, h, turn, MODEL_MCTS_ROLLOUTS, MODEL_MCTS_C_PUCT, True)

def opponent_move(model, p1, p2, h, turn, ply):
    if OPPONENT_TYPE == "random":
        return rng.choice(legal_moves(h))
    if ply < OPPONENT_RANDOM_PLIES:
        return rng.choice(legal_moves(h))
    return mcts(model, p1, p2, h, turn, OPPONENT_MCTS_ROLLOUTS, 1.4, False)

# ======================================================
# GAME LOOP
# ======================================================

def play_game(model, model_starts):
    p1 = p2 = 0
    h = [0]*7
    turn = 1
    model_player = 1 if model_starts else -1
    ply = 0

    while True:
        if check_win_bb(p1): return 1 if model_player == 1 else -1
        if check_win_bb(p2): return 1 if model_player == -1 else -1
        if is_draw(h): return 0

        col = agent_move(model, p1, p2, h, turn) if turn == model_player \
              else opponent_move(model, p1, p2, h, turn, ply)

        if h[col] >= ROWS:
            col = rng.choice(legal_moves(h))

        p1, p2, turn = apply_move(p1, p2, h, turn, col)
        ply += 1

def run_match(model):
    wins = losses = draws = 0
    t0 = time.time()
    for g in range(GAMES_PER_MATCHUP):
        res = play_game(model, g % 2 == 0)
        wins += res == 1
        losses += res == -1
        draws += res == 0
    dt = time.time() - t0
    return {
        "wins": wins, "losses": losses, "draws": draws,
        "win_rate": wins / GAMES_PER_MATCHUP,
        "sec_per_game": dt / GAMES_PER_MATCHUP,
    }

# ======================================================
# RUN
# ======================================================

print("Running evaluation...")
print(f"Model MCTS rollouts    : {MODEL_MCTS_ROLLOUTS}")
print(f"Opponent type          : {OPPONENT_TYPE}")
print(f"Opponent MCTS rollouts : {OPPONENT_MCTS_ROLLOUTS}")
print(f"Random opening plies   : {OPPONENT_RANDOM_PLIES}")

results = run_match(model)
print(results)

Running evaluation...
Model MCTS rollouts    : 100
Opponent type          : mcts
Opponent MCTS rollouts : 800
Random opening plies   : 5
{'wins': 13, 'losses': 7, 'draws': 0, 'win_rate': 0.65, 'sec_per_game': 0.7441078662872315}
