# Connect4 MCTS Distribution Generator

- This notebook first builds a single rolling `mcts_distribution.npz` that aggregates MCTS root distributions per board.  
    - The MCTS generated approx 500,000 unique board positions at rollout nsteps 2000.
- Each data entry consists of the following:
    - 6x7 Connect 4 board grid with +1 (player to move), -1 (opponent), 0 (empty).
    - 6x7x2 Encoding (channel 0 = to-move; channel 1 = opponent).
    - Length-7 raw visit count from the MCTS process (one for each possible connect 4 column).
    - Length-7 net win scores from MCTS.
    - Length-7 "policy", which is normalized visits and represents move probabilities.
    - Length-7 "Q", which is the per move expected outcome.
    - Length-1 "value", or single expected outcome for the board.
- The last part of the code are cells to view random samples of boards from various stages of the game. 

**Other Key ideas:**
- The number of random opening moves is sampled from a skewed distribution (4-14, mode 6) and not recorded. This preserves opening diversity.
- After that, each board is labeled with a distribution over 7 moves for the player to move.
- The code was run a few times with different numbers of RANDOM_OPENING_MOVES to try to increase strength in early, mid, and late stage boards
- Repeated boards are accumulated (no deduping) so the distribution becomes an average over many MCTS runs.
- Connect 4 has symmetry, so boards and their mirrors are equivalent. 


In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

Mounted at /content/drive


In [20]:
from pathlib import Path
import json
import random
import shutil
import time
import numpy as np

# Make arrays print more readably in notebook output.
np.set_printoptions(suppress=True, linewidth=120)

# Output location for the rolling dataset.
# Always resolve relative to the project root to avoid duplicate paths.
PROJECT_ROOT = Path.cwd()
# If running from guided_mcts/, use the parent as project root.
if PROJECT_ROOT.name in {"dist_workflow", "guided_mcts"}:
    PROJECT_ROOT = PROJECT_ROOT.parent
DATA_DIR = PROJECT_ROOT / "dist_workflow" / "mcts_distribution"
DATA_DIR.mkdir(parents=True, exist_ok=True)
NPZ_PATH = DATA_DIR / "mcts_distribution.npz"

# MCTS and dataset settings.
DEFAULT_MCTS_STEPS = 2000  # fast coverage; refine later with higher nsteps

# This section sets up the random opening moves into a distribution
# The distribution is centered at 6, but heavy right-tailed to ensure mid and late game coverage
OPENING_MIN = 4
OPENING_MODE = 6
OPENING_MAX = 14
OPENING_LEFT_DECAY = 1.8  # faster drop on the left tail
OPENING_RIGHT_DECAY = 6.5  # slower drop on the right tail
OPENING_MOVES = np.arange(OPENING_MIN, OPENING_MAX + 1)
OPENING_WEIGHTS = np.where(
    OPENING_MOVES <= OPENING_MODE,
    np.exp(-(OPENING_MODE - OPENING_MOVES) / OPENING_LEFT_DECAY),
    np.exp(-(OPENING_MOVES - OPENING_MODE) / OPENING_RIGHT_DECAY),
)
OPENING_WEIGHTS = OPENING_WEIGHTS / OPENING_WEIGHTS.sum()

def sample_opening_moves():
    return int(np.random.choice(OPENING_MOVES, p=OPENING_WEIGHTS))
TARGET_UNIQUE = 1_000_000
CHECKPOINT_EVERY = 100  # save after every N new unique boards
REPORT_EVERY_GAMES = 25
REPORT_EVERY_SECONDS = 60
BACKUP_EVERY = 10_000
BACKUP_DIR = PROJECT_ROOT / "dist_workflow" / "recovery_copies"



In [21]:
# Quick checker: print the opening-move distribution
def show_opening_distribution():
    for moves, weight in zip(OPENING_MOVES, OPENING_WEIGHTS):
        print(f"{moves}: {weight:.3f}")

show_opening_distribution()


4: 0.053
5: 0.093
6: 0.162
7: 0.139
8: 0.119
9: 0.102
10: 0.088
11: 0.075
12: 0.064
13: 0.055
14: 0.047


## Core game logic
These helpers implement Connect4 mechanics used by MCTS and self-play.


In [22]:
# Drop a checker into the specified column for the given color and return the new board.
# color: "plus" or "minus". Board is 6x7 with values +1, -1, or 0.
def update_board(board_temp, color, column):
    # Work on a copy so callers do not mutate the original board.
    board = board_temp.copy()
    nrow = board.shape[0]
    # Count how many cells are occupied in this column to find the landing row.
    colsum = np.abs(board[:, column]).sum()
    row = int(nrow - 1 - colsum)
    if row > -0.5:
        board[row, column] = 1 if color == "plus" else -1
    return board

# Fast win check that only inspects lines through the last-played column.
def check_for_win(board, col):
    nrow, ncol = board.shape
    colsum = np.abs(board[:, col]).sum()
    row = int(nrow - colsum)
    # vertical
    if row + 3 < nrow:
        vert = board[row, col] + board[row+1, col] + board[row+2, col] + board[row+3, col]
        if vert == 4:
            return "v-plus"
        elif vert == -4:
            return "v-minus"
    # horizontal (four spans through this column)
    if col + 3 < ncol:
        hor = board[row, col] + board[row, col+1] + board[row, col+2] + board[row, col+3]
        if hor == 4:
            return "h-plus"
        elif hor == -4:
            return "h-minus"
    if col - 1 >= 0 and col + 2 < ncol:
        hor = board[row, col-1] + board[row, col] + board[row, col+1] + board[row, col+2]
        if hor == 4:
            return "h-plus"
        elif hor == -4:
            return "h-minus"
    if col - 2 >= 0 and col + 1 < ncol:
        hor = board[row, col-2] + board[row, col-1] + board[row, col] + board[row, col+1]
        if hor == 4:
            return "h-plus"
        elif hor == -4:
            return "h-minus"
    if col - 3 >= 0:
        hor = board[row, col-3] + board[row, col-2] + board[row, col-1] + board[row, col]
        if hor == 4:
            return "h-plus"
        elif hor == -4:
            return "h-minus"
    # diag down-right
    if row < 3 and col < 4:
        dr = board[row, col] + board[row+1, col+1] + board[row+2, col+2] + board[row+3, col+3]
        if dr == 4:
            return "d-plus"
        elif dr == -4:
            return "d-minus"
    if row - 1 >= 0 and col - 1 >= 0 and row + 2 < 6 and col + 2 < 7:
        dr = board[row-1, col-1] + board[row, col] + board[row+1, col+1] + board[row+2, col+2]
        if dr == 4:
            return "d-plus"
        elif dr == -4:
            return "d-minus"
    if row - 2 >= 0 and col - 2 >= 0 and row + 1 < 6 and col + 1 < 7:
        dr = board[row-2, col-2] + board[row-1, col-1] + board[row, col] + board[row+1, col+1]
        if dr == 4:
            return "d-plus"
        elif dr == -4:
            return "d-minus"
    if row - 3 >= 0 and col - 3 >= 0:
        dr = board[row-3, col-3] + board[row-2, col-2] + board[row-1, col-1] + board[row, col]
        if dr == 4:
            return "d-plus"
        elif dr == -4:
            return "d-minus"
    # diag down-left
    if row + 3 < 6 and col - 3 >= 0:
        dl = board[row, col] + board[row+1, col-1] + board[row+2, col-2] + board[row+3, col-3]
        if dl == 4:
            return "d-plus"
        elif dl == -4:
            return "d-minus"
    if row - 1 >= 0 and col + 1 < 7 and row + 2 < 6 and col - 2 >= 0:
        dl = board[row-1, col+1] + board[row, col] + board[row+1, col-1] + board[row+2, col-2]
        if dl == 4:
            return "d-plus"
        elif dl == -4:
            return "d-minus"
    if row - 2 >= 0 and col + 2 < 7 and row + 1 < 6 and col - 1 >= 0:
        dl = board[row-2, col+2] + board[row-1, col+1] + board[row, col] + board[row+1, col-1]
        if dl == 4:
            return "d-plus"
        elif dl == -4:
            return "d-minus"
    if row - 3 >= 0 and col + 3 < 7:
        dl = board[row-3, col+3] + board[row-2, col+2] + board[row-1, col+1] + board[row, col]
        if dl == 4:
            return "d-plus"
        elif dl == -4:
            return "d-minus"
    return "nobody"

# Legal moves are columns whose top cell is empty.
def find_legal(board):
    return [i for i in range(7) if abs(board[0, i]) < 0.1]

# Check if the current player can win immediately by placing in some column.
def look_for_win(board_, color):
    board_copy = board_.copy()
    legal = find_legal(board_copy)
    for m in legal:
        bt = update_board(board_copy, color, m)
        wi = check_for_win(bt, m)
        if wi[2:] == color:
            return m
    return -1

# Avoid moves that let the opponent win immediately on their next turn.
def find_all_nonlosers(board, color):
    opp = "minus" if color == "plus" else "plus"
    legal = find_legal(board)
    poss_boards = [update_board(board, color, l) for l in legal]
    poss_legal = [find_legal(b) for b in poss_boards]
    allowed = []
    for i in range(len(legal)):
        # If opponent has a winning response, exclude this move.
        wins = [j for j in poss_legal[i] if check_for_win(update_board(poss_boards[i], opp, j), j) != "nobody"]
        if len(wins) == 0:
            allowed.append(legal[i])
    return allowed

# Backpropagate a rollout result through the visited path in the MCTS tree.
# Each state stores [visit_count, score], where score is net wins for root player.
def back_prop(winner, path, color0, md):
    for i in range(len(path)):
        board_temp = path[i]
        md[board_temp][0] += 1
        if winner[2] == color0[0]:
            # Root player win: alternating signs by ply.
            if i % 2 == 1:
                md[board_temp][1] += 1
            else:
                md[board_temp][1] -= 1
        elif winner[2] == "e":
            pass
        else:
            # Root player loss.
            if i % 2 == 1:
                md[board_temp][1] -= 1
            else:
                md[board_temp][1] += 1

# Random playout from a given board until someone wins or ties.
def rollout(board, next_player):
    winner = "nobody"
    player = next_player
    while winner == "nobody":
        legal = find_legal(board)
        if len(legal) == 0:
            return "tie"
        move = random.choice(legal)
        board = update_board(board, player, move)
        winner = check_for_win(board, move)
        player = "minus" if player == "plus" else "plus"
    return winner


## MCTS policy + value

Each MCTS run produces **raw stats** for the 7 root moves:
- `visits`: how many times MCTS selected that move at the root.
- `scores`: net wins from rollouts for that move (wins - losses).

From those raw stats we derive training targets:
- **policy** (length 7) = `visits / sum(visits)`
- **q** (length 7) = `scores / visits` (expected outcome per move, can be negative)
- **value** (length 1) = `sum(scores) / sum(visits)` (expected outcome for the board)

Important intuition: **policy is not the win rate**. It is the search preference.
UCB selects moves using a mix of exploration and win rate, so higher win-rate moves
tend to get more visits, but the policy is still an indirect signal. The direct win-rate
signal is in **q/value**.


In [23]:
def mcts_policy_value(board_temp, color0, nsteps):
    """Run MCTS and return per-move visit counts and net scores."""
    board = board_temp.copy()
    legal = find_legal(board)

    # Forced win: return a one-hot policy and +1 score for that move.
    win_column = look_for_win(board, color0)
    if win_column > -0.5:
        visits = np.zeros(7, dtype=np.float32)
        scores = np.zeros(7, dtype=np.float32)
        visits[win_column] = 1.0
        scores[win_column] = 1.0
        return visits, scores

    # Standard MCTS loop with UCB1 selection and random rollouts.
    mcts_dict = {tuple(board.ravel()): [0, 0]}
    for _ in range(nsteps):
        color = color0
        winner = "nobody"
        board_mcts = board.copy()
        path = [tuple(board_mcts.ravel())]
        while winner == "nobody":
            legal_loop = find_legal(board_mcts)
            if len(legal_loop) == 0:
                winner = "tie"
                back_prop(winner, path, color0, mcts_dict)
                break
            board_list = [tuple(update_board(board_mcts, color, col).ravel()) for col in legal_loop]
            for bl in board_list:
                if bl not in mcts_dict:
                    mcts_dict[bl] = [0, 0]
            # UCB1 balances exploration (unseen states) vs exploitation.
            ucb1 = np.zeros(len(legal_loop))
            for i in range(len(legal_loop)):
                num_denom = mcts_dict[board_list[i]]
                if num_denom[0] == 0:
                    ucb1[i] = 10 * nsteps  # force exploration of unseen states
                else:
                    ucb1[i] = num_denom[1] / num_denom[0] + 2 * np.sqrt(np.log(mcts_dict[path[-1]][0]) / mcts_dict[board_list[i]][0])
            chosen = int(np.argmax(ucb1))
            board_mcts = update_board(board_mcts, color, legal_loop[chosen])
            path.append(tuple(board_mcts.ravel()))
            winner = check_for_win(board_mcts, legal_loop[chosen])
            if winner[2] == color[0]:
                back_prop(winner, path, color0, mcts_dict)
                break
            color = "minus" if color == "plus" else "plus"
            if mcts_dict[tuple(board_mcts.ravel())][0] == 0:
                winner = rollout(board_mcts, color)
                back_prop(winner, path, color0, mcts_dict)
                break

    # Convert the root child stats into per-move visits and scores.
    visits = np.zeros(7, dtype=np.float32)
    scores = np.zeros(7, dtype=np.float32)
    for col in legal:
        child = tuple(update_board(board, color0, col).ravel())
        v, s = mcts_dict.get(child, [0, 0])
        visits[col] = float(v)
        scores[col] = float(s)
    return visits, scores


## Encoding helpers
We store boards from the perspective of the side to move (always "+1" in channel 0),
so every distribution is for the next player to move.


In [24]:
# Convert a 6x7 board of +1/-1/0 into two channels (plus, minus).
def encode_two_channel(board_6x7):
    plus = (board_6x7 == 1).astype(np.float32)
    minus = (board_6x7 == -1).astype(np.float32)
    return np.stack([plus, minus], axis=-1)

# Make a board viewpoint where the side-to-move is always encoded as plus.
# This keeps labels consistent for the next player to move.
def to_plus_perspective(board_6x7, player):
    if player == "plus":
        return board_6x7
    return -board_6x7


## Dataset management

Each row in the saved `.npz` includes:
- `boards[i]`: 6x7 board with +1 (to-move), -1 (opponent), 0 (empty).
- `X[i]`: 6x7x2 two-channel encoding (channel 0 = to-move, channel 1 = opponent).
- `policy[i]`: length-7 move probabilities from visits.
- `q[i]`: length-7 expected outcomes per move (score / visits).
- `value[i]`: single expected outcome for the board.
- `visits[i]`: length-7 raw visit counts.
- `scores[i]`: length-7 raw net scores.


## How these targets are used later

**Policy-only network**:
- Train the network to predict `policy` from `X`.
- At play time, choose the move by argmax (or sampling) from the predicted policy.

**Policy + value network (AlphaZero-style)**:
- Train a policy head on `policy` and a value head on `value`.
- During MCTS, use the policy head as priors and the value head to evaluate leaves.
- The final move is chosen by MCTS visit counts (not necessarily the raw policy).

**Per-move Q (optional)**:
- `q` can be used for diagnostics or as an auxiliary target, but is not required for AlphaZero.


In [25]:
def load_distribution(npz_path=NPZ_PATH):
    # Load existing rolling dataset if it exists.
    if not npz_path.exists():
        return {}
    npz = np.load(npz_path, allow_pickle=False)
    boards = npz["boards"]
    if "visits" in npz and "scores" in npz:
        visits = npz["visits"]
        scores = npz["scores"]
    elif "counts" in npz:
        # Backward compatibility: old datasets used counts without scores.
        visits = npz["counts"]
        scores = np.zeros_like(visits)
    else:
        raise ValueError("Dataset missing visits/scores")
    board_map = {}
    for i in range(boards.shape[0]):
        board_map[tuple(boards[i].ravel())] = (
            visits[i].astype(np.float32),
            scores[i].astype(np.float32),
        )
    return board_map

def save_distribution(board_map, npz_path=NPZ_PATH):
    # Save boards plus MCTS stats and derived targets to a compressed npz.
    keys = list(board_map.keys())
    if len(keys) == 0:
        boards = np.zeros((0, 6, 7), dtype=np.int8)
        visits = np.zeros((0, 7), dtype=np.float32)
        scores = np.zeros((0, 7), dtype=np.float32)
    else:
        boards = np.array(keys, dtype=np.int8).reshape(-1, 6, 7)
        visits = np.stack([board_map[k][0] for k in keys]).astype(np.float32)
        scores = np.stack([board_map[k][1] for k in keys]).astype(np.float32)
    # Policy target is just normalized visit counts.
    policy = visits.copy()
    row_sums = policy.sum(axis=1, keepdims=True)
    row_sums[row_sums == 0] = 1.0
    policy = policy / row_sums
    # Q per move is score / visits (can be negative).
    q = np.zeros_like(scores)
    with np.errstate(divide="ignore", invalid="ignore"):
        mask = visits > 0
        q[mask] = scores[mask] / visits[mask]
    # Value per board is total score / total visits.
    total_visits = visits.sum(axis=1, keepdims=True)
    total_scores = scores.sum(axis=1, keepdims=True)
    value = np.zeros((visits.shape[0], 1), dtype=np.float32)
    mask_v = total_visits[:, 0] > 0
    value[mask_v, 0] = (total_scores[mask_v, 0] / total_visits[mask_v, 0]).astype(np.float32)
    # Two-channel encoding for model input.
    X = np.stack([encode_two_channel(b) for b in boards], axis=0).astype(np.float32)
    np.savez_compressed(
        npz_path,
        boards=boards,
        visits=visits,
        scores=scores,
        policy=policy,
        q=q,
        value=value,
        X=X,
    )

def add_position(board_map, board_plus, visits, scores):
    # Add visit and score stats to the running totals for this board.
    key = tuple(board_plus.ravel())
    if key not in board_map:
        board_map[key] = (visits.astype(np.float32), scores.astype(np.float32))
        return True
    v, s = board_map[key]
    board_map[key] = (v + visits.astype(np.float32), s + scores.astype(np.float32))
    return False



## Self-play generator

- A sampled number of opening moves (4-14, mode 6) are random and **not recorded** to preserve opening diversity.
- We do **not** filter out moves that allow an immediate opponent win during self-play. This keeps the policy targets honest and improves robustness, even if some moves are blunders.
- After the random opening phase, we record the board and add the MCTS root distribution.
- Each recorded board is stored in plus-perspective so labels always target the player to move.
- For fast coverage, we **sample** moves from the distribution instead of always taking argmax.


In [26]:
def generate_distribution(
    target_unique=TARGET_UNIQUE,
    nsteps=DEFAULT_MCTS_STEPS,
    checkpoint_every=CHECKPOINT_EVERY,
    backup_every=BACKUP_EVERY,
    seed=0,
    report_every_games=REPORT_EVERY_GAMES,
    report_every_seconds=REPORT_EVERY_SECONDS,
    play_mode="sample",  # "sample" for coverage, "argmax" for strength
):
    # Seed randomness for reproducibility.
    random.seed(seed)
    np.random.seed(seed)

    # Load existing dataset if present.
    board_map = load_distribution()
    BACKUP_DIR.mkdir(parents=True, exist_ok=True)
    unique_start = len(board_map)
    new_unique = 0
    game_idx = 0
    start_time = time.time()
    last_report_time = start_time
    last_report_games = 0
    last_report_unique = unique_start

    while len(board_map) < target_unique:
        game_idx += 1
        board = np.zeros((6, 7), dtype=np.int8)
        player = "plus"

        # Random opening moves, unrecorded.
        # This forces early-game diversity without polluting labels.
        for _ in range(sample_opening_moves()):
            legal = find_legal(board)
            if len(legal) == 0:
                break
            move = random.choice(legal)
            board = update_board(board, player, move)
            winner = check_for_win(board, move)
            if winner != "nobody":
                break
            player = "minus" if player == "plus" else "plus"

        # Continue game after the opening phase, recording positions.
        # Every recorded board adds visits/scores to its running totals.
        winner = "nobody"
        while winner == "nobody":
            legal = find_legal(board)
            if len(legal) == 0:
                break
            # Encode board from side-to-move perspective.
            board_plus = to_plus_perspective(board, player)
            # MCTS always assumes the root player is "plus" in this view.
            visits, scores = mcts_policy_value(board_plus, "plus", nsteps)
            is_new = add_position(board_map, board_plus, visits, scores)
            if is_new:
                new_unique += 1
                if new_unique % checkpoint_every == 0:
                    save_distribution(board_map)
                    if new_unique % backup_every == 0:
                        backup_path = BACKUP_DIR / f"mcts_distribution_{len(board_map):06d}.npz"
                        shutil.copy2(NPZ_PATH, backup_path)
                        print(f"Backup saved: {backup_path.name}")
                    elapsed = time.time() - start_time
                    boards_per_min = len(board_map) / max(elapsed, 1.0) * 60.0
                    print(
                        f"Checkpoint: total_unique={len(board_map)} (new {new_unique}) | "
                        f"games={game_idx} | elapsed={elapsed/60:.1f}m | boards/min={boards_per_min:.2f}"
                    )

            # Choose a move for self-play.
            # Sampling increases coverage; argmax increases strength.
            policy = visits.copy()
            if policy.sum() > 0:
                policy = policy / policy.sum()
            else:
                policy[legal] = 1.0 / len(legal)
            if play_mode == "sample":
                move = int(np.random.choice(7, p=policy))
            else:
                move = int(np.argmax(policy))
            if move not in legal:
                move = random.choice(legal)
            board = update_board(board, player, move)
            winner = check_for_win(board, move)
            player = "minus" if player == "plus" else "plus"

        now = time.time()
        if (game_idx % report_every_games == 0) or (now - last_report_time >= report_every_seconds):
            elapsed = now - start_time
            interval = max(now - last_report_time, 1e-6)
            games_delta = game_idx - last_report_games
            unique_delta = len(board_map) - last_report_unique
            games_per_min = games_delta / interval * 60.0
            boards_per_min = unique_delta / interval * 60.0
            overall_boards_per_min = len(board_map) / max(elapsed, 1.0) * 60.0
            remaining = max(target_unique - len(board_map), 0)
            eta_min = remaining / max(overall_boards_per_min, 1e-6)
            print(
                f"Progress: total_unique={len(board_map)} | games={game_idx} | "
                f"elapsed={elapsed/60:.1f}m | interval boards/min={boards_per_min:.2f} | "
                f"overall boards/min={overall_boards_per_min:.2f} | ETA={eta_min:.1f}m"
            )
            last_report_time = now
            last_report_games = game_idx
            last_report_unique = len(board_map)

    save_distribution(board_map)
    total_elapsed = time.time() - start_time
    print(
        f"Done. Unique boards: {len(board_map)} (start {unique_start}) | "
        f"games={game_idx} | elapsed={total_elapsed/60:.1f}m"
    )


## Run
Uncomment to start generation.


In [27]:
generate_distribution(
    target_unique=1_000_000,
    nsteps=2000,
    checkpoint_every=100,
    backup_every=10_000,
    seed=54,
    # seeds used so far: [85, 23, 3, 84, 86, 54]
    report_every_games=25,
    report_every_seconds=60,
    play_mode="sample",
)


Checkpoint: total_unique=687611 (new 100) | games=11 | elapsed=1.1m | boards/min=636836.51
Progress: total_unique=687628 | games=11 | elapsed=1.2m | interval boards/min=97.06 | overall boards/min=570463.27 | ETA=0.5m
Checkpoint: total_unique=687711 (new 200) | games=17 | elapsed=2.1m | boards/min=332336.19
Progress: total_unique=687745 | games=18 | elapsed=2.3m | interval boards/min=103.55 | overall boards/min=294496.87 | ETA=1.1m
Checkpoint: total_unique=687811 (new 300) | games=23 | elapsed=3.1m | boards/min=224706.43
Progress: total_unique=687836 | games=25 | elapsed=3.3m | interval boards/min=97.20 | overall boards/min=210250.56 | ETA=1.5m
Checkpoint: total_unique=687911 (new 400) | games=35 | elapsed=4.1m | boards/min=166331.46
Progress: total_unique=687927 | games=36 | elapsed=4.3m | interval boards/min=90.16 | overall boards/min=160700.95 | ETA=1.9m
Checkpoint: total_unique=688011 (new 500) | games=44 | elapsed=5.2m | boards/min=132602.16
Progress: total_unique=688027 | games=44

Traceback (most recent call last):
  File "/opt/anaconda3/envs/workbench/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3701, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/var/folders/dh/4l0nymtn4znckrgkqw4lfytm0000gn/T/ipykernel_60421/2909685135.py", line 1, in <module>
    generate_distribution(
  File "/var/folders/dh/4l0nymtn4znckrgkqw4lfytm0000gn/T/ipykernel_60421/2958018659.py", line 54, in generate_distribution
    visits, scores = mcts_policy_value(board_plus, "plus", nsteps)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/var/folders/dh/4l0nymtn4znckrgkqw4lfytm0000gn/T/ipykernel_60421/552671075.py", line None, in mcts_policy_value
KeyboardInterrupt

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/opt/anaconda3/envs/workbench/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 2201, in showtraceback
    stb = self.Inter

## Inspect most-sampled boards
Use this to see samples of boards from various stages of the game.


In [37]:
# Inspect random samples from different game stages (by move count).
import random

def _load_npz(npz_path_override=None):
    npz_path = Path(npz_path_override) if npz_path_override else NPZ_PATH
    return np.load(npz_path)

def _board_move_count(board):
    # number of non-empty cells
    return int((board != 0).sum())

def _print_board_entry(npz, idx, label=None):
    if label:
        print(label)
    print(f"row={idx}")
    total_visits = float(npz["visits"][idx].sum())
    print(f"total_visits={total_visits:.1f}")
    print("boards:\n", npz["boards"][idx])
    print("X shape:", npz["X"][idx].shape)
    print("X channel0 sum:", float(npz["X"][idx][:,:,0].sum()), "channel1 sum:", float(npz["X"][idx][:,:,1].sum()))
    print("visits:", npz["visits"][idx].tolist())
    print("scores:", npz["scores"][idx].tolist())
    print("policy:", npz["policy"][idx].tolist())
    print("q:", npz["q"][idx].tolist())
    print("value:", float(npz["value"][idx][0]))
    print()

def sample_boards_by_move_count(n=3, min_moves=0, max_moves=42, npz_path_override=None, seed=0):
    """Print n random boards whose move-count is within [min_moves, max_moves]."""
    rng = random.Random(seed)
    npz = _load_npz(npz_path_override=npz_path_override)
    candidates = []
    for i, board in enumerate(npz["boards"]):
        m = _board_move_count(board)
        if min_moves <= m <= max_moves:
            candidates.append((i, m))
    if not candidates:
        print(f"No boards found in moves {min_moves}-{max_moves}.")
        return
    picks = rng.sample(candidates, k=min(n, len(candidates)))
    for j, (idx, m) in enumerate(picks, start=1):
        _print_board_entry(npz, idx, label=f"sample {j} | moves={m}")


In [44]:
# Early stage (0-6 moves)
sample_boards_by_move_count(n=3, min_moves=0, max_moves=4, seed=1)


sample 1 | moves=3
row=184561
total_visits=36000.0
boards:
 [[ 0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0]
 [-1  0 -1  1  0  0  0]]
X shape: (6, 7, 2)
X channel0 sum: 1.0 channel1 sum: 2.0
visits: [3695.0, 2489.0, 6505.0, 10367.0, 3772.0, 4457.0, 4715.0]
scores: [357.0, 7.0, 1178.0, 2552.0, 360.0, 580.0, 629.0]
policy: [0.10263888537883759, 0.06913889199495316, 0.18069444596767426, 0.28797221183776855, 0.10477777570486069, 0.12380555272102356, 0.130972221493721]
q: [0.09661705046892166, 0.002812374383211136, 0.18109147250652313, 0.24616572260856628, 0.0954400822520256, 0.13013237714767456, 0.13340403139591217]
value: 0.1573055535554886

sample 2 | moves=4
row=225771
total_visits=4000.0
boards:
 [[ 0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0]
 [-1  0  0  0  0  0  0]
 [-1  1  0  1  0  0  0]]
X shape: (6, 7, 2)
X channel0 sum: 2.0 channel1 sum: 2.0
visits: [741.0, 371.0

In [46]:
# Early-mid stage (7-14 moves)
sample_boards_by_move_count(n=3, min_moves=6, max_moves=10, seed=2)


sample 1 | moves=7
row=44389
total_visits=8000.0
boards:
 [[ 0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0]
 [ 1  0  0 -1  0  0  0]
 [-1 -1  0  1  1 -1  0]]
X shape: (6, 7, 2)
X channel0 sum: 3.0 channel1 sum: 4.0
visits: [1525.0, 1051.0, 549.0, 1323.0, 1613.0, 1360.0, 579.0]
scores: [-239.0, -224.0, -190.0, -229.0, -220.0, -199.0, -191.0]
policy: [0.19062499701976776, 0.1313749998807907, 0.06862500309944153, 0.1653749942779541, 0.2016250044107437, 0.17000000178813934, 0.07237499952316284]
q: [-0.1567213088274002, -0.21313035488128662, -0.34608379006385803, -0.17309145629405975, -0.13639181852340698, -0.1463235318660736, -0.32987910509109497]
value: -0.18649999797344208

sample 2 | moves=6
row=73847
total_visits=2000.0
boards:
 [[ 0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0]
 [ 0  0  0  1  0  0  0]
 [ 0  1  0 -1  0  0  0]
 [-1 -1  0  1  0  0  0]]
X shape: (6, 7, 2)
X channel0 sum: 3.0 channel1 sum: 3.0
visits: [154.0

In [47]:
# Mid-late stage (15-24 moves)
sample_boards_by_move_count(n=3, min_moves=11, max_moves=14, seed=3)


sample 1 | moves=11
row=258552
total_visits=2000.0
boards:
 [[ 0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0]
 [ 0  1  0  0  0  0  0]
 [ 0 -1  0  0  0  0  0]
 [ 0 -1 -1  0  0  0  1]
 [ 1 -1  1  0 -1 -1  1]]
X shape: (6, 7, 2)
X channel0 sum: 5.0 channel1 sum: 6.0
visits: [208.0, 154.0, 231.0, 181.0, 578.0, 324.0, 324.0]
scores: [-51.0, -47.0, -52.0, -49.0, -49.0, -55.0, -55.0]
policy: [0.10400000214576721, 0.07699999958276749, 0.11550000309944153, 0.09049999713897705, 0.289000004529953, 0.16200000047683716, 0.16200000047683716]
q: [-0.24519230425357819, -0.30519479513168335, -0.22510822117328644, -0.2707182466983795, -0.08477509021759033, -0.1697530895471573, -0.1697530895471573]
value: -0.17900000512599945

sample 2 | moves=13
row=594548
total_visits=2000.0
boards:
 [[ 0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0]
 [ 0  0  0  1  0  0  0]
 [ 0  0  0 -1  0  1  0]
 [-1  0  0 -1  1 -1  0]
 [-1  0  1  1 -1 -1  1]]
X shape: (6, 7, 2)
X channel0 sum: 6.0 channel1 sum: 7.0
visits: [235.0, 65.0, 

In [48]:
# Late stage (25-42 moves)
sample_boards_by_move_count(n=3, min_moves=14, max_moves=42, seed=4)


sample 1 | moves=17
row=238833
total_visits=2000.0
boards:
 [[ 0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0]
 [ 1  0 -1  1  0  0  0]
 [ 1  0  1 -1 -1  0  0]
 [-1 -1  1  1 -1  0  0]
 [-1 -1  1 -1  1  0  0]]
X shape: (6, 7, 2)
X channel0 sum: 8.0 channel1 sum: 9.0
visits: [134.0, 125.0, 86.0, 130.0, 125.0, 1263.0, 137.0]
scores: [-80.0, -77.0, -62.0, -79.0, -77.0, -348.0, -81.0]
policy: [0.06700000166893005, 0.0625, 0.0430000014603138, 0.06499999761581421, 0.0625, 0.6315000057220459, 0.06849999725818634]
q: [-0.5970149040222168, -0.6159999966621399, -0.7209302186965942, -0.607692301273346, -0.6159999966621399, -0.275534451007843, -0.5912408828735352]
value: -0.4020000100135803

sample 2 | moves=15
row=300715
total_visits=2000.0
boards:
 [[ 0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0]
 [ 0  0  0  1  0  0  1]
 [ 0  0  0 -1  0  0 -1]
 [ 1  0  1 -1  0 -1 -1]
 [-1  1 -1  1  0  1 -1]]
X shape: (6, 7, 2)
X channel0 sum: 7.0 channel1 sum: 8.0
visits: [112.0, 207.0, 239.0, 155.0, 49.0, 1098.0, 140