# Connect4 MCTS Distribution Generator

- This notebook first builds a single rolling `mcts_distribution.npz` that aggregates MCTS root distributions per board.  
    - The first MCTS pass generated approximately 610,000 unique board positions with `nsteps=200` (MCTS iterations per position).
- Each data entry consists of the following:
    - 6x7 Connect 4 board grid with +1 (player to move), -1 (opponent), 0 (empty).
    - 6x7x2 Encoding (channel 0 = to-move; channel 1 = opponent).
    - Length-7 raw visit counts from MCTS (one per column).
    - Length-7 net scores from MCTS (rollout wins minus losses, per column).
    - Length-7 "policy": the normalized visit counts, representing move probabilities.
    - Length-7 "Q": the expected outcome of each move.
    - Length-1 "value", or single expected outcome for the board.
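- The 61,000 most-sampled boards (by total visit count) were then re-scored with `nsteps=3000`.
    - This combines the breadth of the initial dataset (many unique boards) with depth (higher-quality labels) on the most frequently reached boards.
- The dataset can then be augmented with the horizontal mirror image of each board.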

**Other Key ideas:**
- The first two moves of each self-play game are random and not recorded. This preserves opening randomness.
- After that, each board is labeled with a distribution over 7 moves for the player to move.
- Repeated boards are not stored twice: each new MCTS run on a board adds its visits and scores to that board's running totals, so the derived targets are visit-weighted averages over all runs on that board.
- An opening book is injected into the data as one-hot distributions. Its moves come from https://connect4.gamesolver.org/, a perfect Connect 4 solver by Pascal Pons. The book contains exactly 8 boards: the empty board (used when we play first) and the 7 boards after each possible opponent first move (used when we play second).
- Connect 4 is left-right symmetric, so a board and its horizontal mirror are equivalent, with the per-move labels reversed.
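Since mirroring only reverses the column order, a mirrored copy of the dataset can be produced by flipping each board left to right and reversing every length-7 per-move array; the board-level `value` is unchanged. A minimal sketch, assuming the array shapes this notebook saves (`boards` as `(N, 6, 7)`, per-move arrays as `(N, 7)`); `mirror_rows` is a hypothetical helper, not part of the notebook:

```python
import numpy as np

def mirror_rows(boards, per_move):
    """Return left-right mirrored boards and per-move arrays.

    boards:   (N, 6, 7) array of +1/-1/0.
    per_move: dict mapping a name (e.g. "visits", "scores", "policy", "q") to an (N, 7) array.
    Board-level targets such as `value` do not change under mirroring.
    """
    mirrored_boards = boards[:, :, ::-1].copy()
    # Column c of the original board corresponds to column 6 - c of the mirror.
    mirrored = {name: arr[:, ::-1].copy() for name, arr in per_move.items()}
    return mirrored_boards, mirrored

# Example: opponent piece in column 1, all visits on column 2.
boards = np.zeros((1, 6, 7), dtype=np.int8)
boards[0, 5, 1] = -1
visits = np.zeros((1, 7), dtype=np.float32)
visits[0, 2] = 10.0
m_boards, m_arrays = mirror_rows(boards, {"visits": visits})
print(np.argwhere(m_boards[0] != 0).tolist())  # [[5, 5]]
print(int(np.argmax(m_arrays["visits"][0])))   # 4
```

The example reproduces the symmetric pair in `OPENING_BOOK_LABELS` (reply 2 to column 1, reply 4 to column 5). Boards that are their own mirror, such as the empty board, map onto themselves, so check for duplicate board keys before concatenating mirrored rows with the originals.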


In [None]:
# from google.colab import drive
# drive.mount('/content/drive')



In [12]:
from pathlib import Path
import json
import random
import shutil
import time
import numpy as np

# Make arrays print more readably in notebook output.
np.set_printoptions(suppress=True, linewidth=120)

# Output location for the rolling dataset.
# Always resolve relative to the project root to avoid duplicate paths.
PROJECT_ROOT = Path.cwd()
# If running from dist_workflow/ or guided_mcts/, use the parent as project root.
if PROJECT_ROOT.name in {"dist_workflow", "guided_mcts"}:
    PROJECT_ROOT = PROJECT_ROOT.parent
DATA_DIR = PROJECT_ROOT / "dist_workflow" / "mcts_distribution"
DATA_DIR.mkdir(parents=True, exist_ok=True)
NPZ_PATH = DATA_DIR / "mcts_distribution.npz"

# MCTS and dataset settings.
DEFAULT_MCTS_STEPS = 200  # fast coverage; refine later with higher nsteps
TARGET_UNIQUE = 1_000_000
CHECKPOINT_EVERY = 100  # save after every N new unique boards
REPORT_EVERY_GAMES = 25
REPORT_EVERY_SECONDS = 60
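BACKUP_EVERY = 10000  # copy the npz to BACKUP_DIR every N new unique boards (checked only at checkpoints, so keep it a multiple of CHECKPOINT_EVERY)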
BACKUP_DIR = PROJECT_ROOT / "dist_workflow" / "recovery_copies"

# Fixed opening-book labels for training data (single pick per opening).
# Source: Pascal Pons, https://connect4.gamesolver.org/
# Keys use a single opponent piece in the given column; -1 represents the empty board.
OPENING_BOOK_LABELS = {
    -1: 3,
    0: 3,
    1: 2,
    2: 3,
    3: 3,
    4: 3,
    5: 4,
    6: 3,
}


## Core game logic
These helpers implement Connect4 mechanics used by MCTS and self-play.
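`check_for_win` is fast because it only examines lines through the most recently filled cell. A slower scan of the whole board is handy as a reference when testing it. The sketch below is a hypothetical helper (not part of the notebook) for the same 6x7 +1/-1/0 boards; it returns +1, -1, or 0 instead of the notebook's string labels.

```python
import numpy as np

def winner_full_scan(board):
    """Return +1 or -1 if that side has four in a row anywhere on the board, else 0."""
    nrow, ncol = board.shape
    # (row step, column step): horizontal, vertical, diagonal down-right, diagonal down-left.
    directions = [(0, 1), (1, 0), (1, 1), (1, -1)]
    for r in range(nrow):
        for c in range(ncol):
            for dr, dc in directions:
                end_r, end_c = r + 3 * dr, c + 3 * dc
                if not (0 <= end_r < nrow and 0 <= end_c < ncol):
                    continue  # the four-cell line would leave the board
                total = sum(int(board[r + k * dr, c + k * dc]) for k in range(4))
                if abs(total) == 4:
                    return 1 if total > 0 else -1
    return 0

# Four plus pieces along the bottom row.
row_win = np.zeros((6, 7), dtype=np.int8)
row_win[5, 0:4] = 1
# Four minus pieces on a rising diagonal (gravity support is not checked here).
diag_win = np.zeros((6, 7), dtype=np.int8)
for k in range(4):
    diag_win[5 - k, 1 + k] = -1
print(winner_full_scan(row_win), winner_full_scan(diag_win))  # 1 -1
```

One way to use it: play random games with `update_board` and assert after every move that `check_for_win(board, move) != "nobody"` agrees with `winner_full_scan(board) != 0`. Because a game stops at its first four-in-a-row, the two checks should agree on every position reached.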


In [13]:
# Drop a checker into the specified column for the given color and return the new board.
# color: "plus" or "minus". Board is 6x7 with values +1, -1, or 0.
def update_board(board_temp, color, column):
    # Work on a copy so callers do not mutate the original board.
    board = board_temp.copy()
    nrow = board.shape[0]
    # Count how many cells are occupied in this column to find the landing row.
    colsum = np.abs(board[:, column]).sum()
    row = int(nrow - 1 - colsum)
    # row == -1 means the column is full; the board is then returned unchanged.
    if row >= 0:
        board[row, column] = 1 if color == "plus" else -1
    return board

# Fast win check that only inspects lines through the piece just dropped in column `col`.
def check_for_win(board, col):
    nrow, ncol = board.shape
    colsum = np.abs(board[:, col]).sum()
    row = int(nrow - colsum)
    # vertical
    if row + 3 < nrow:
        vert = board[row, col] + board[row+1, col] + board[row+2, col] + board[row+3, col]
        if vert == 4:
            return "v-plus"
        elif vert == -4:
            return "v-minus"
    # horizontal (four spans through this column)
    if col + 3 < ncol:
        hor = board[row, col] + board[row, col+1] + board[row, col+2] + board[row, col+3]
        if hor == 4:
            return "h-plus"
        elif hor == -4:
            return "h-minus"
    if col - 1 >= 0 and col + 2 < ncol:
        hor = board[row, col-1] + board[row, col] + board[row, col+1] + board[row, col+2]
        if hor == 4:
            return "h-plus"
        elif hor == -4:
            return "h-minus"
    if col - 2 >= 0 and col + 1 < ncol:
        hor = board[row, col-2] + board[row, col-1] + board[row, col] + board[row, col+1]
        if hor == 4:
            return "h-plus"
        elif hor == -4:
            return "h-minus"
    if col - 3 >= 0:
        hor = board[row, col-3] + board[row, col-2] + board[row, col-1] + board[row, col]
        if hor == 4:
            return "h-plus"
        elif hor == -4:
            return "h-minus"
    # diag down-right
    if row < 3 and col < 4:
        dr = board[row, col] + board[row+1, col+1] + board[row+2, col+2] + board[row+3, col+3]
        if dr == 4:
            return "d-plus"
        elif dr == -4:
            return "d-minus"
    if row - 1 >= 0 and col - 1 >= 0 and row + 2 < 6 and col + 2 < 7:
        dr = board[row-1, col-1] + board[row, col] + board[row+1, col+1] + board[row+2, col+2]
        if dr == 4:
            return "d-plus"
        elif dr == -4:
            return "d-minus"
    if row - 2 >= 0 and col - 2 >= 0 and row + 1 < 6 and col + 1 < 7:
        dr = board[row-2, col-2] + board[row-1, col-1] + board[row, col] + board[row+1, col+1]
        if dr == 4:
            return "d-plus"
        elif dr == -4:
            return "d-minus"
    if row - 3 >= 0 and col - 3 >= 0:
        dr = board[row-3, col-3] + board[row-2, col-2] + board[row-1, col-1] + board[row, col]
        if dr == 4:
            return "d-plus"
        elif dr == -4:
            return "d-minus"
    # diag down-left
    if row + 3 < 6 and col - 3 >= 0:
        dl = board[row, col] + board[row+1, col-1] + board[row+2, col-2] + board[row+3, col-3]
        if dl == 4:
            return "d-plus"
        elif dl == -4:
            return "d-minus"
    if row - 1 >= 0 and col + 1 < 7 and row + 2 < 6 and col - 2 >= 0:
        dl = board[row-1, col+1] + board[row, col] + board[row+1, col-1] + board[row+2, col-2]
        if dl == 4:
            return "d-plus"
        elif dl == -4:
            return "d-minus"
    if row - 2 >= 0 and col + 2 < 7 and row + 1 < 6 and col - 1 >= 0:
        dl = board[row-2, col+2] + board[row-1, col+1] + board[row, col] + board[row+1, col-1]
        if dl == 4:
            return "d-plus"
        elif dl == -4:
            return "d-minus"
    if row - 3 >= 0 and col + 3 < 7:
        dl = board[row-3, col+3] + board[row-2, col+2] + board[row-1, col+1] + board[row, col]
        if dl == 4:
            return "d-plus"
        elif dl == -4:
            return "d-minus"
    return "nobody"

# Legal moves are columns whose top cell is empty.
def find_legal(board):
    return [i for i in range(7) if abs(board[0, i]) < 0.1]

# Check if the current player can win immediately by placing in some column.
def look_for_win(board_, color):
    board_copy = board_.copy()
    legal = find_legal(board_copy)
    for m in legal:
        bt = update_board(board_copy, color, m)
        wi = check_for_win(bt, m)
        if wi[2:] == color:
            return m
    return -1

# Avoid moves that let the opponent win immediately on their next turn.
def find_all_nonlosers(board, color):
    opp = "minus" if color == "plus" else "plus"
    legal = find_legal(board)
    poss_boards = [update_board(board, color, l) for l in legal]
    poss_legal = [find_legal(b) for b in poss_boards]
    allowed = []
    for i in range(len(legal)):
        # If opponent has a winning response, exclude this move.
        wins = [j for j in poss_legal[i] if check_for_win(update_board(poss_boards[i], opp, j), j) != "nobody"]
        if len(wins) == 0:
            allowed.append(legal[i])
    return allowed

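# Backpropagate a rollout result through the visited path in the MCTS tree.
# Each state stores [visit_count, score]. The score is net wins from the perspective of the
# player who made the move into that state, so signs alternate by ply (odd i = root player's move).
# Ties add a visit but no score.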
def back_prop(winner, path, color0, md):
    for i in range(len(path)):
        board_temp = path[i]
        md[board_temp][0] += 1
        if winner[2] == color0[0]:
            # Root player win: alternating signs by ply.
            if i % 2 == 1:
                md[board_temp][1] += 1
            else:
                md[board_temp][1] -= 1
        elif winner[2] == "e":
            pass
        else:
            # Root player loss.
            if i % 2 == 1:
                md[board_temp][1] -= 1
            else:
                md[board_temp][1] += 1

# Random playout from a given board until someone wins or ties.
def rollout(board, next_player):
    winner = "nobody"
    player = next_player
    while winner == "nobody":
        legal = find_legal(board)
        if len(legal) == 0:
            return "tie"
        move = random.choice(legal)
        board = update_board(board, player, move)
        winner = check_for_win(board, move)
        player = "minus" if player == "plus" else "plus"
    return winner


## MCTS policy + value

Each MCTS run produces **raw stats** for the 7 root moves:
- `visits`: how many times MCTS selected that move at the root.
- `scores`: net wins from rollouts for that move (wins - losses).

From those raw stats we derive training targets:
- **policy** (length 7) = `visits / sum(visits)`
- **q** (length 7) = `scores / visits` (expected outcome per move, can be negative)
- **value** (length 1) = `sum(scores) / sum(visits)` (expected outcome for the board)
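A worked example of the three formulas on made-up root statistics (the numbers are illustrative, not taken from the dataset). It uses float64 for readable output, whereas the notebook stores float32, and gives unvisited moves `q = 0`, matching the masking in `save_distribution`:

```python
import numpy as np

# Illustrative root stats: three of seven columns were searched, 100 iterations total.
visits = np.array([0, 0, 20, 60, 20, 0, 0], dtype=np.float64)
scores = np.array([0, 0, -4, 30, 2, 0, 0], dtype=np.float64)

policy = visits / visits.sum()
q = np.zeros_like(scores)
searched = visits > 0
q[searched] = scores[searched] / visits[searched]
value = scores.sum() / visits.sum()

print(policy.tolist())  # [0.0, 0.0, 0.2, 0.6, 0.2, 0.0, 0.0]
print(q.tolist())       # [0.0, 0.0, -0.2, 0.5, 0.1, 0.0, 0.0]
print(value)            # 0.28
```

Note that `value` is the visit-weighted mean of `q` (here 0.2 * -0.2 + 0.6 * 0.5 + 0.2 * 0.1 = 0.28), so a board's value is dominated by the moves the search preferred.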

Important intuition: **policy is not the win rate**; it is the search preference.
UCB1 scores each move as its mean rollout score plus an exploration bonus that shrinks as the move is visited more,
so higher win-rate moves tend to get more visits, but the policy is still an indirect signal. The direct win-rate
signal is in **q/value**.
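To see why visits are only an indirect signal, here is the UCB1 rule from `mcts_policy_value` as a standalone function: mean score plus `2 * sqrt(ln(N_parent) / n_child)`, with unvisited children given a very large score (the notebook uses `10 * nsteps`). The function name and example numbers are illustrative:

```python
import math

def ucb1_scores(child_visits, child_scores, parent_visits, c=2.0, unvisited_score=1e9):
    """UCB1 as in the notebook: mean score + c * sqrt(ln(parent_visits) / child_visits)."""
    out = []
    for n, s in zip(child_visits, child_scores):
        if n == 0:
            out.append(unvisited_score)  # unseen children are always tried first
        else:
            out.append(s / n + c * math.sqrt(math.log(parent_visits) / n))
    return out

# Move 0: mean 0.5 after 90 visits. Move 1: mean 0.3 after 10 visits.
ucb = ucb1_scores(child_visits=[90, 10], child_scores=[45, 3], parent_visits=100)
next_move = max(range(len(ucb)), key=lambda i: ucb[i])
print(next_move)  # 1
```

Move 1 gets the next visit despite its lower mean because its exploration bonus (about 1.36 versus 0.45) is larger. Over many iterations the higher-mean move still collects most of the visits, which is why the policy tracks, but does not equal, the win rate.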


In [14]:
def mcts_policy_value(board_temp, color0, nsteps):
    """Run MCTS and return per-move visit counts and net scores."""
    board = board_temp.copy()
    legal = find_legal(board)

    # Forced win: return a one-hot policy and +1 score for that move.
    win_column = look_for_win(board, color0)
    if win_column > -0.5:
        visits = np.zeros(7, dtype=np.float32)
        scores = np.zeros(7, dtype=np.float32)
        visits[win_column] = 1.0
        scores[win_column] = 1.0
        return visits, scores

    # Standard MCTS loop with UCB1 selection and random rollouts.
    mcts_dict = {tuple(board.ravel()): [0, 0]}
    for _ in range(nsteps):
        color = color0
        winner = "nobody"
        board_mcts = board.copy()
        path = [tuple(board_mcts.ravel())]
        while winner == "nobody":
            legal_loop = find_legal(board_mcts)
            if len(legal_loop) == 0:
                winner = "tie"
                back_prop(winner, path, color0, mcts_dict)
                break
            board_list = [tuple(update_board(board_mcts, color, col).ravel()) for col in legal_loop]
            for bl in board_list:
                if bl not in mcts_dict:
                    mcts_dict[bl] = [0, 0]
            # UCB1 balances exploration (unseen states) vs exploitation.
            ucb1 = np.zeros(len(legal_loop))
            for i in range(len(legal_loop)):
                num_denom = mcts_dict[board_list[i]]
                if num_denom[0] == 0:
                    ucb1[i] = 10 * nsteps  # force exploration of unseen states
                else:
                    ucb1[i] = num_denom[1] / num_denom[0] + 2 * np.sqrt(np.log(mcts_dict[path[-1]][0]) / mcts_dict[board_list[i]][0])
            chosen = int(np.argmax(ucb1))
            board_mcts = update_board(board_mcts, color, legal_loop[chosen])
            path.append(tuple(board_mcts.ravel()))
            winner = check_for_win(board_mcts, legal_loop[chosen])
            if winner[2] == color[0]:
                back_prop(winner, path, color0, mcts_dict)
                break
            color = "minus" if color == "plus" else "plus"
            if mcts_dict[tuple(board_mcts.ravel())][0] == 0:
                winner = rollout(board_mcts, color)
                back_prop(winner, path, color0, mcts_dict)
                break

    # Convert the root child stats into per-move visits and scores.
    visits = np.zeros(7, dtype=np.float32)
    scores = np.zeros(7, dtype=np.float32)
    for col in legal:
        child = tuple(update_board(board, color0, col).ravel())
        v, s = mcts_dict.get(child, [0, 0])
        visits[col] = float(v)
        scores[col] = float(s)
    return visits, scores


## Encoding helpers
We store boards from the perspective of the side to move (always "+1" in channel 0),
so every distribution is for the next player to move.
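Going the other way is handy when inspecting a saved `X` row: channel 0 minus channel 1 recovers the plus-perspective board. A small sketch (`decode_two_channel` is a hypothetical name, not part of the notebook) that also walks through the flip for a position where `"minus"` is to move:

```python
import numpy as np

def decode_two_channel(x):
    """Invert the two-channel encoding: (..., 6, 7, 2) -> (..., 6, 7) with +1/-1/0."""
    return (x[..., 0] - x[..., 1]).astype(np.int8)

# Absolute board (plus = +1, minus = -1) after three moves, so "minus" is to move.
board = np.zeros((6, 7), dtype=np.int8)
board[5, 3] = 1    # plus
board[4, 3] = -1   # minus
board[5, 2] = 1    # plus

plus_view = -board  # same as to_plus_perspective(board, "minus")
x = np.stack([plus_view == 1, plus_view == -1], axis=-1).astype(np.float32)
print(x[..., 0].sum(), x[..., 1].sum())  # 1.0 2.0 (mover has 1 piece, opponent 2)
```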


In [15]:
# Convert a 6x7 board of +1/-1/0 into two channels (plus, minus).
def encode_two_channel(board_6x7):
    plus = (board_6x7 == 1).astype(np.float32)
    minus = (board_6x7 == -1).astype(np.float32)
    return np.stack([plus, minus], axis=-1)

# Make a board viewpoint where the side-to-move is always encoded as plus.
# This keeps labels consistent for the next player to move.
def to_plus_perspective(board_6x7, player):
    if player == "plus":
        return board_6x7
    return -board_6x7


## Dataset management

Each row in the saved `.npz` includes (N = number of boards):
- `boards[i]`: 6x7 `int8` board with +1 (to-move), -1 (opponent), 0 (empty).
- `X[i]`: 6x7x2 `float32` two-channel encoding (channel 0 = to-move, channel 1 = opponent).
- `policy[i]`: length-7 move probabilities from visits.
- `q[i]`: length-7 expected outcomes per move (score / visits; 0 for unvisited moves).
- `value[i]`: single expected outcome for the board (the array has shape `(N, 1)`).
- `visits[i]`: length-7 raw visit counts.
- `scores[i]`: length-7 raw net scores.
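A sketch of a consistency check for a loaded file, assuming the keys listed above and the array shapes written by `save_distribution`; `check_dataset` is a hypothetical helper, not part of the notebook. It verifies shapes, that policy rows with visits sum to 1, and that `X` agrees with `boards`:

```python
import numpy as np

def check_dataset(d):
    """Sanity-check a mapping with the keys written by save_distribution; return the row count."""
    n = d["boards"].shape[0]
    assert d["boards"].shape == (n, 6, 7)
    assert d["X"].shape == (n, 6, 7, 2)
    for key in ("visits", "scores", "policy", "q"):
        assert d[key].shape == (n, 7), key
    assert d["value"].shape == (n, 1)
    # Rows with any visits should have a policy that sums to 1.
    has_visits = d["visits"].sum(axis=1) > 0
    assert np.allclose(d["policy"][has_visits].sum(axis=1), 1.0, atol=1e-5)
    # Channel 0 minus channel 1 should reproduce the +1/-1/0 board.
    assert np.array_equal(d["X"][..., 0] - d["X"][..., 1], d["boards"].astype(np.float32))
    return n

# Tiny synthetic example: the empty board and one opening-book-style board.
boards = np.zeros((2, 6, 7), dtype=np.int8)
boards[1, 5, 3] = -1
visits = np.array([[0, 0, 0, 1, 0, 0, 0], [0, 0, 2, 6, 2, 0, 0]], dtype=np.float32)
example = {
    "boards": boards,
    "X": np.stack([boards == 1, boards == -1], axis=-1).astype(np.float32),
    "visits": visits,
    "scores": np.zeros_like(visits),
    "policy": visits / visits.sum(axis=1, keepdims=True),
    "q": np.zeros_like(visits),
    "value": np.zeros((2, 1), dtype=np.float32),
}
print(check_dataset(example))  # 2

# With the real file (path from the settings cell):
# with np.load(NPZ_PATH) as npz:
#     print(check_dataset(npz), "rows OK")
```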


## How these targets are used later

**Policy-only network**:
- Train the network to predict `policy` from `X`.
- At play time, choose the move by argmax (or sampling) from the predicted policy.

**Policy + value network (AlphaZero-style)**:
- Train a policy head on `policy` and a value head on `value`.
- During MCTS, use the policy head as priors and the value head to evaluate leaves.
- The final move is chosen by MCTS visit counts (not necessarily the raw policy).

**Per-move Q (optional)**:
- `q` can be used for diagnostics or as an auxiliary target, but is not required for AlphaZero.
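For reference, the AlphaGo Zero / AlphaZero training loss combines a squared error on the value with a cross-entropy between the target policy and the network's move probabilities, plus L2 weight regularization (often left to the optimizer). A framework-free NumPy sketch of the first two terms for one batch; the network outputs `pred_logits` and `pred_value` are assumed, and here the value target is the MCTS `value` rather than a final game result:

```python
import numpy as np

def log_softmax(logits):
    """Numerically stable log-softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

def policy_value_loss(pred_logits, pred_value, target_policy, target_value):
    """Batch mean of (v - z)^2 - sum(pi * log p), with p = softmax(pred_logits).

    pred_logits, target_policy: (B, 7); pred_value, target_value: (B, 1).
    """
    policy_ce = -(target_policy * log_softmax(pred_logits)).sum(axis=1)
    value_se = (pred_value[:, 0] - target_value[:, 0]) ** 2
    return float((policy_ce + value_se).mean())

# Uniform logits and a perfect value prediction: the loss reduces to ln(7).
target_policy = np.array([[0.0, 0.0, 0.2, 0.6, 0.2, 0.0, 0.0]])
target_value = np.array([[0.28]])
loss = policy_value_loss(np.zeros((1, 7)), target_value.copy(), target_policy, target_value)
print(round(loss, 4))  # 1.9459
```

For the policy-only network, drop the value term. Illegal columns have zero target probability, so they contribute nothing to the cross-entropy.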


In [16]:
def one_hot_visits(move):
    # One-hot visit vector for opening-book entries.
    visits = np.zeros(7, dtype=np.float32)
    visits[move] = 1.0
    return visits

def load_distribution(npz_path=NPZ_PATH):
    # Load existing rolling dataset if it exists.
    if not npz_path.exists():
        return {}
    npz = np.load(npz_path, allow_pickle=False)
    boards = npz["boards"]
    if "visits" in npz and "scores" in npz:
        visits = npz["visits"]
        scores = npz["scores"]
    elif "counts" in npz:
        # Backward compatibility: old datasets used counts without scores.
        visits = npz["counts"]
        scores = np.zeros_like(visits)
    else:
        raise ValueError("Dataset missing visits/scores")
    board_map = {}
    for i in range(boards.shape[0]):
        board_map[tuple(boards[i].ravel())] = (
            visits[i].astype(np.float32),
            scores[i].astype(np.float32),
        )
    return board_map

def save_distribution(board_map, npz_path=NPZ_PATH):
    # Save boards plus MCTS stats and derived targets to a compressed npz.
    keys = list(board_map.keys())
    if len(keys) == 0:
        boards = np.zeros((0, 6, 7), dtype=np.int8)
        visits = np.zeros((0, 7), dtype=np.float32)
        scores = np.zeros((0, 7), dtype=np.float32)
    else:
        boards = np.array(keys, dtype=np.int8).reshape(-1, 6, 7)
        visits = np.stack([board_map[k][0] for k in keys]).astype(np.float32)
        scores = np.stack([board_map[k][1] for k in keys]).astype(np.float32)
    # Policy target is just normalized visit counts.
    policy = visits.copy()
    row_sums = policy.sum(axis=1, keepdims=True)
    row_sums[row_sums == 0] = 1.0
    policy = policy / row_sums
    # Q per move is score / visits (can be negative).
    q = np.zeros_like(scores)
    with np.errstate(divide="ignore", invalid="ignore"):
        mask = visits > 0
        q[mask] = scores[mask] / visits[mask]
    # Value per board is total score / total visits.
    total_visits = visits.sum(axis=1, keepdims=True)
    total_scores = scores.sum(axis=1, keepdims=True)
    value = np.zeros((visits.shape[0], 1), dtype=np.float32)
    mask_v = total_visits[:, 0] > 0
    value[mask_v, 0] = (total_scores[mask_v, 0] / total_visits[mask_v, 0]).astype(np.float32)
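    # Two-channel encoding for model input (vectorized equivalent of encode_two_channel;
    # unlike stacking per-board arrays, this also works when the dataset is empty).
    X = np.stack([boards == 1, boards == -1], axis=-1).astype(np.float32)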
    np.savez_compressed(
        npz_path,
        boards=boards,
        visits=visits,
        scores=scores,
        policy=policy,
        q=q,
        value=value,
        X=X,
    )

def add_position(board_map, board_plus, visits, scores):
    # Add visit and score stats to the running totals for this board.
    key = tuple(board_plus.ravel())
    if key not in board_map:
        board_map[key] = (visits.astype(np.float32), scores.astype(np.float32))
        return True
    v, s = board_map[key]
    board_map[key] = (v + visits.astype(np.float32), s + scores.astype(np.float32))
    return False

def add_opening_book(board_map):
    # Empty board (no pieces).
    empty = np.zeros((6, 7), dtype=np.int8)
    visits = one_hot_visits(OPENING_BOOK_LABELS[-1])
    scores = np.zeros(7, dtype=np.float32)
    add_position(board_map, empty, visits, scores)
    # Single opponent move in column col (opponent is minus, side to move is plus).
    for col, move in OPENING_BOOK_LABELS.items():
        if col == -1:
            continue
        board = np.zeros((6, 7), dtype=np.int8)
        board = update_board(board, "minus", col)
        visits = one_hot_visits(move)
        scores = np.zeros(7, dtype=np.float32)
        add_position(board_map, board, visits, scores)


## Self-play generator

- First two moves are random and **not recorded** to preserve opening randomness.
- From move 3 onward, we record the board and add the MCTS root distribution.
- Each recorded board is stored in plus-perspective so labels always target the player to move.
- For fast coverage, we **sample** moves from the distribution instead of always taking argmax.


In [17]:
def generate_distribution(
    target_unique=TARGET_UNIQUE,
    nsteps=DEFAULT_MCTS_STEPS,
    checkpoint_every=CHECKPOINT_EVERY,
    backup_every=BACKUP_EVERY,
    seed=0,
    report_every_games=REPORT_EVERY_GAMES,
    report_every_seconds=REPORT_EVERY_SECONDS,
    play_mode="sample",  # "sample" for coverage, "argmax" for strength
):
    # Seed randomness for reproducibility.
    random.seed(seed)
    np.random.seed(seed)

    # Load existing dataset if present.
    board_map = load_distribution()
    BACKUP_DIR.mkdir(parents=True, exist_ok=True)
    # Only inject the opening book on a fresh dataset.
    if not NPZ_PATH.exists():
        add_opening_book(board_map)

    unique_start = len(board_map)
    new_unique = 0
    game_idx = 0
    start_time = time.time()
    last_report_time = start_time
    last_report_games = 0
    last_report_unique = unique_start

    while len(board_map) < target_unique:
        game_idx += 1
        board = np.zeros((6, 7), dtype=np.int8)
        player = "plus"

        # Two random opening moves, unrecorded.
        # This forces early-game diversity without polluting labels.
        for _ in range(2):
            legal = find_legal(board)
            if len(legal) == 0:
                break
            move = random.choice(legal)
            board = update_board(board, player, move)
            winner = check_for_win(board, move)
            if winner != "nobody":
                break
            player = "minus" if player == "plus" else "plus"

        # Continue game from move 3 onward, recording positions.
        # Every recorded board adds visits/scores to its running totals.
        winner = "nobody"
        while winner == "nobody":
            legal = find_legal(board)
            if len(legal) == 0:
                break
            # Encode board from side-to-move perspective.
            board_plus = to_plus_perspective(board, player)
            # MCTS always assumes the root player is "plus" in this view.
            visits, scores = mcts_policy_value(board_plus, "plus", nsteps)
            is_new = add_position(board_map, board_plus, visits, scores)
            if is_new:
                new_unique += 1
                if new_unique % checkpoint_every == 0:
                    save_distribution(board_map)
                    if new_unique % backup_every == 0:
                        backup_path = BACKUP_DIR / f"mcts_distribution_{len(board_map):06d}.npz"
                        shutil.copy2(NPZ_PATH, backup_path)
                        print(f"Backup saved: {backup_path.name}")
                    elapsed = time.time() - start_time
                    boards_per_min = new_unique / max(elapsed, 1.0) * 60.0  # boards added in this run only
                    print(
                        f"Checkpoint: total_unique={len(board_map)} (new {new_unique}) | "
                        f"games={game_idx} | elapsed={elapsed/60:.1f}m | boards/min={boards_per_min:.2f}"
                    )

            # Choose a move for self-play.
            # Sampling increases coverage; argmax increases strength.
            policy = visits.copy()
            if policy.sum() > 0:
                policy = policy / policy.sum()
            else:
                policy[legal] = 1.0 / len(legal)
            if play_mode == "sample":
                move = int(np.random.choice(7, p=policy))
            else:
                move = int(np.argmax(policy))
            if move not in legal:
                move = random.choice(legal)
            board = update_board(board, player, move)
            winner = check_for_win(board, move)
            player = "minus" if player == "plus" else "plus"

        now = time.time()
        if (game_idx % report_every_games == 0) or (now - last_report_time >= report_every_seconds):
            elapsed = now - start_time
            interval = max(now - last_report_time, 1e-6)
            games_delta = game_idx - last_report_games
            unique_delta = len(board_map) - last_report_unique
            games_per_min = games_delta / interval * 60.0
            boards_per_min = unique_delta / interval * 60.0
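            # Count only boards added in this run so that resuming from a saved file does not inflate the rate or ETA.
            overall_boards_per_min = (len(board_map) - unique_start) / max(elapsed, 1.0) * 60.0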
            remaining = max(target_unique - len(board_map), 0)
            eta_min = remaining / max(overall_boards_per_min, 1e-6)
            print(
                f"Progress: total_unique={len(board_map)} | games={game_idx} | "
                f"elapsed={elapsed/60:.1f}m | interval boards/min={boards_per_min:.2f} | "
                f"overall boards/min={overall_boards_per_min:.2f} | ETA={eta_min:.1f}m"
            )
            last_report_time = now
            last_report_games = game_idx
            last_report_unique = len(board_map)

    save_distribution(board_map)
    total_elapsed = time.time() - start_time
    print(
        f"Done. Unique boards: {len(board_map)} (start {unique_start}) | "
        f"games={game_idx} | elapsed={total_elapsed/60:.1f}m"
    )


## Run
Uncomment to start generation.


In [None]:
# generate_distribution(
#     target_unique=1_200_000,
#     nsteps=200,
#     checkpoint_every=100,
#     backup_every=10000,
#     seed=85,
#     # seeds used so far: [0, 85]
#     report_every_games=25,
#     report_every_seconds=60,
#     play_mode="sample",
# )


### Refinement pass
Use rescoring to improve the **most common** boards (largest total visit counts) after a large fast-coverage run.
This focuses the expensive high-`nsteps` searches on the positions that self-play reached most often.
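Here "most common" means the largest total visit count. Each full MCTS call adds exactly `nsteps` root visits (every iteration passes through one root child), while forced wins and opening-book entries add 1, so before any rescoring the total is roughly `nsteps` times the number of times self-play reached the board. A vectorized sketch of the ranking over the saved `visits` array (`top_k_rows` is a hypothetical helper):

```python
import numpy as np

def top_k_rows(visits, k):
    """Row indices of the k boards with the most total visits, largest first (ties keep file order)."""
    totals = visits.sum(axis=1)
    order = np.argsort(-totals, kind="stable")
    return order[:min(k, totals.shape[0])]

# Illustrative totals at nsteps=200: 200 (seen once), 600 (seen three times), 200 (seen once).
visits = np.array([
    [0, 0, 0, 200, 0, 0, 0],
    [50, 50, 100, 200, 100, 50, 50],
    [0, 0, 100, 100, 0, 0, 0],
], dtype=np.float32)
print(top_k_rows(visits, 2).tolist())  # [1, 0]
```

After a rescoring pass, every rescored board carries an extra `nsteps` visits (e.g. 3000), so this ranking no longer reflects raw self-play frequency; rank before rescoring if that distinction matters.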


## Inspect most-sampled boards
Use this to confirm that some boards are visited more often than others.


In [18]:
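def top_k_by_samples(k=10, include_opening=False, npz_path=NPZ_PATH):
    """Print the top-K boards by total visits, with full entry details.

    With include_opening=False this calls is_opening_board, defined in the re-scoring cell below.
    """
    # Read the running stats and the per-row targets from the same file.
    board_map = load_distribution(npz_path=Path(npz_path))
    npz = np.load(npz_path)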

    # Build lookup from board -> row index in npz.
    board_index = {tuple(b.ravel()): i for i, b in enumerate(npz["boards"])}

    items = []
    for key, stats in board_map.items():
        visits, scores = stats
        board = np.array(key, dtype=np.int8).reshape(6, 7)
        if not include_opening and is_opening_board(board):
            continue
        items.append((visits.sum(), key))
    items.sort(key=lambda x: x[0], reverse=True)

    for i, (total, key) in enumerate(items[:k]):
        idx = board_index[key]
        print(f"rank={i+1} total_visits={total:.1f} row={idx}")
        print("boards:\n", npz["boards"][idx])
        print("X shape:", npz["X"][idx].shape)
        print("X channel0 sum:", float(npz["X"][idx][:,:,0].sum()), "channel1 sum:", float(npz["X"][idx][:,:,1].sum()))
        print("visits:", npz["visits"][idx].tolist())
        print("scores:", npz["scores"][idx].tolist())
        print("policy:", npz["policy"][idx].tolist())
        print("q:", npz["q"][idx].tolist())
        print("value:", float(npz["value"][idx][0]))
        print()

# Example:
# top_k_by_samples(k=10, include_opening=False)


## Re-score existing boards

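Use this to re-evaluate only the existing boards with a new `nsteps` (e.g., 1500 or 3000).
Opening-book boards are skipped, and no new boards are added. The new visits and scores are added to each board's existing totals rather than replacing them, so a high-`nsteps` pass dominates the averaged targets.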
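Because rescoring adds to the stored totals (`v + visits`, `s + scores`), the refined targets are visit-weighted averages of every run on that board. A worked example with illustrative numbers, one 200-step run followed by one 3000-step run:

```python
import numpy as np

old_visits = np.array([0, 0, 40, 120, 40, 0, 0], dtype=np.float64)     # 200-step run
old_scores = np.array([0, 0, -8, 30, 4, 0, 0], dtype=np.float64)       # value 26/200 = 0.13
new_visits = np.array([0, 0, 300, 2400, 300, 0, 0], dtype=np.float64)  # 3000-step run
new_scores = np.array([0, 0, -90, 960, 30, 0, 0], dtype=np.float64)    # value 900/3000 = 0.3

visits = old_visits + new_visits
scores = old_scores + new_scores
weight_new = new_visits.sum() / visits.sum()
value = scores.sum() / visits.sum()
print(weight_new)  # 0.9375
print(value)       # 0.289375, i.e. 0.0625 * 0.13 + 0.9375 * 0.3
```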


In [26]:
def is_opening_board(board):
    # Empty board.
    if np.all(board == 0):
        return True
    # Single opponent piece in bottom row.
    if np.count_nonzero(board) != 1:
        return False
    row, col = np.argwhere(board != 0)[0]
    return row == 5 and board[row, col] == -1

def rescore_existing_boards(
    nsteps,
    seed=0,
    passes=1,
    top_k=None,  # if set, rescore only the most-sampled K boards
    checkpoint_every=1000,
    report_every_seconds=60,
    npz_path=NPZ_PATH,
):
    # Re-run MCTS on the current dataset to refine policy + value targets.
    random.seed(seed)
    np.random.seed(seed)

    board_map = load_distribution(npz_path=npz_path)
    keys = [k for k in board_map.keys() if not is_opening_board(np.array(k, dtype=np.int8).reshape(6, 7))]
    # Rank by how often each board has been sampled (sum of visits).
    if top_k is not None and top_k < len(keys):
        keys = sorted(keys, key=lambda k: board_map[k][0].sum(), reverse=True)[:top_k]
    if len(keys) == 0:
        print("No non-opening boards to rescore.")
        return

    start_time = time.time()
    last_report = start_time
    total = len(keys) * passes
    processed = 0

    for p in range(passes):
        np.random.shuffle(keys)
        for key in keys:
            board_plus = np.array(key, dtype=np.int8).reshape(6, 7)
            visits, scores = mcts_policy_value(board_plus, "plus", nsteps)
            v, s = board_map[key]
            board_map[key] = (v + visits.astype(np.float32), s + scores.astype(np.float32))
            processed += 1

            if processed % checkpoint_every == 0:
                save_distribution(board_map, npz_path=npz_path)
                elapsed = time.time() - start_time
                per_min = processed / max(elapsed, 1.0) * 60.0
                print(f"Checkpoint: rescored={processed}/{total} | elapsed={elapsed/60:.1f}m | boards/min={per_min:.2f}")

            now = time.time()
            if now - last_report >= report_every_seconds:
                elapsed = now - start_time
                per_min = processed / max(elapsed, 1.0) * 60.0
                remaining = total - processed
                eta_min = remaining / max(per_min, 1e-6)
                print(f"Progress: rescored={processed}/{total} | elapsed={elapsed/60:.1f}m | boards/min={per_min:.2f} | ETA={eta_min:.1f}m")
                last_report = now

    save_distribution(board_map, npz_path=npz_path)
    elapsed = time.time() - start_time
    print(f"Done. Rescored {processed} boards in {elapsed/60:.1f}m")


### Run rescore (optional)
Uncomment to re-evaluate existing boards with higher nsteps.


In [None]:
# rescore_existing_boards(
#     nsteps=3000,
#     seed=23,
#     passes=1,
#     top_k=61000,
#     checkpoint_every=1000,
#     report_every_seconds=60,
# )
