# Connect 4 Additional Data Generator (Notebook Format)

**Speed-optimized MCTS data generator** — same format as your existing notebooks.

## Features
- **Numba JIT** for fast rollouts (~10–20x speedup)
- **Zero-leakage deduplication** across 4 board transformations
- **6×7×2 encoding** — same as Connect4_DeepSearch and Connect4_Dataset_Generator
- **Output format** — `npz` with `X_train`, `y_train` (DeepSearch style) and optional `X`, `y_move`, `y_result`, `turns` (Dataset_Generator style)

## Output formats (choose one or both)
1. **DeepSearch style**: `X_train`, `y_train` in `.npz`
2. **Dataset_Generator style**: `X`, `y_move`, `y_result`, `turns` in `.npz` + train/val/test splits

## Usage
1. Run cells in order.
2. Or, if you already have `best_20k_X.npy` / `best_20k_Y.npy`, set `GENERATE_NEW = False` in Cell 7 to load them (Option B) instead of generating new data.

## Cell 1: Environment setup & directories

In [None]:
import os
import sys

# ---------- Auto-detect Colab vs Local ----------
# Probing for the google.colab package is the only signal used to decide
# where the project directory lives.
try:
    import google.colab
except ImportError:
    IS_COLAB = False
else:
    IS_COLAB = True

if not IS_COLAB:
    # Local -- prefer ./Additional_DataGenerator when it already exists,
    # otherwise fall back to the current working directory.
    PROJECT_DIR = os.path.join(os.getcwd(), 'Additional_DataGenerator')
    if not os.path.exists(PROJECT_DIR):
        PROJECT_DIR = os.getcwd()
    print("Running locally")
else:
    from google.colab import drive
    drive.mount('/content/drive')
    PROJECT_DIR = '/content/drive/MyDrive/Connect4_AdditionalData'
    print("Running on Google Colab")

# Output folders are created eagerly so later cells can write without checks.
DATASET_DIR = os.path.join(PROJECT_DIR, 'datasets')
CHECKPOINT_DIR = os.path.join(PROJECT_DIR, 'checkpoints')
for _required_dir in (DATASET_DIR, CHECKPOINT_DIR):
    os.makedirs(_required_dir, exist_ok=True)

print(f"Project dir: {PROJECT_DIR}")
print(f"Dataset dir: {DATASET_DIR}")

## Cell 2: Configuration

In [None]:
# ==================== CONFIGURATION ====================
import random
import numpy as np

# --- Workload ---
NUM_GAMES = 10_000            # self-play games to generate
MCTS_ROLLOUTS = 10_000        # MCTS iterations per recorded move
RANDOM_OPENING_MOVES = 6      # each game starts with 0..N uniformly random plies
NUM_WORKERS = None            # None = auto (cpu_count)

# --- Output naming / checkpointing ---
SAVE_PREFIX = "fast_10k"
CHECKPOINT_EVERY = 500

# --- Output format: 'both' | 'deepsearch' | 'dataset_generator' ---
#   deepsearch        -> X_train, y_train only (npz)
#   dataset_generator -> X, y_move, y_result, turns + train/val/test splits
#   both              -> save in both formats
OUTPUT_FORMAT = 'both'

# --- Reproducibility ---
# NOTE(review): this seeds the `random` and NumPy generators of the current
# process only; whether pool workers and Numba-compiled rollouts follow this
# seed should be confirmed before relying on run-to-run reproducibility.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)

print(f"Games: {NUM_GAMES:,} | MCTS rollouts: {MCTS_ROLLOUTS:,} | Opening moves: 0-{RANDOM_OPENING_MOVES}")

## Cell 3: Imports

In [None]:
import numpy as np
import time
import random
import math
import hashlib
import pickle
from collections import Counter, defaultdict
from multiprocessing import Pool, cpu_count

try:
    from numba import njit
except ImportError:
    HAS_NUMBA = False

    def njit(*args, **kwargs):
        """No-op stand-in for ``numba.njit`` when Numba is not installed.

        Supports both decorator spellings used in this notebook: bare
        ``@njit`` (the function arrives as the first positional argument)
        and parameterised ``@njit(cache=True)`` (a decorator is returned).
        """
        if args and callable(args[0]):
            return args[0]

        def _passthrough(func):
            return func

        return _passthrough
else:
    HAS_NUMBA = True

print(f"Numba: {'YES' if HAS_NUMBA else 'NO (install numba for speed)'}")

## Cell 4: Numba JIT game engine

In [None]:
@njit(cache=True)
def find_row(board, col):
    """Return the row index where a piece dropped in ``col`` would land.

    Rows are scanned from index 5 down to index 0 and the first empty cell
    (value 0) wins. Returns -1 when the column has no empty cell.
    """
    r = 5
    while r >= 0:
        if board[r, col] == 0:
            return r
        r -= 1
    return -1

@njit(cache=True)
def find_legal(board):
    """Return an int32 array of the columns whose row-0 cell is still empty.

    Columns are listed in ascending order; the result is empty when every
    column is full.
    """
    buf = np.empty(7, dtype=np.int32)
    count = 0
    for col in range(7):
        if board[0, col] != 0:
            continue
        buf[count] = col
        count += 1
    return buf[:count]

@njit(cache=True)
def check_win(board, row, col):
    """Return True if the piece at ``(row, col)`` completes four in a row.

    An empty cell never wins. Four directions are examined:
    vertical (only the three cells at larger row indices), the whole of
    ``row`` horizontally, and both full diagonals passing through
    ``(row, col)``.
    """
    piece = board[row, col]
    if piece == 0:
        return False

    # Vertical: the three cells directly underneath (row+1..row+3).
    if (row <= 2
            and board[row + 1, col] == piece
            and board[row + 2, col] == piece
            and board[row + 3, col] == piece):
        return True

    # Horizontal: longest run of `piece` anywhere in this row.
    streak = 0
    for x in range(7):
        if board[row, x] != piece:
            streak = 0
        else:
            streak += 1
            if streak >= 4:
                return True

    # Diagonal with row and column increasing together, starting at the
    # edge cell of the diagonal through (row, col).
    offset = min(row, col)
    r = row - offset
    c = col - offset
    streak = 0
    while r < 6 and c < 7:
        if board[r, c] != piece:
            streak = 0
        else:
            streak += 1
            if streak >= 4:
                return True
        r += 1
        c += 1

    # Diagonal with row decreasing while column increases, starting at the
    # edge cell of the anti-diagonal through (row, col).
    offset = min(5 - row, col)
    r = row + offset
    c = col - offset
    streak = 0
    while r >= 0 and c < 7:
        if board[r, c] != piece:
            streak = 0
        else:
            streak += 1
            if streak >= 4:
                return True
        r -= 1
        c += 1

    return False

@njit(cache=True)
def try_win(board, col, player):
    """Report whether ``player`` would win by dropping a piece in ``col``.

    The piece is placed temporarily and removed again before returning, so
    ``board`` is left unchanged. A full column yields False.
    """
    landing = find_row(board, col)
    if landing >= 0:
        board[landing, col] = player
        is_win = check_win(board, landing, col)
        board[landing, col] = 0  # undo the trial move
        return is_win
    return False

@njit(cache=True)
def rollout_jit(board, next_player):
    """Play a semi-random game to the end and return the winner.

    Works on a copy of ``board``; ``next_player`` moves first. Each turn the
    mover (1) wins immediately if any legal column wins, (2) otherwise plays
    the first legal column that would give the opponent an immediate win,
    (3) otherwise plays column 3 with probability 0.2 when it is legal, else
    a uniformly random legal column.

    Returns the value of the winning player, or 0 when no legal move remains
    (or after 42 plies without a winner).
    """
    b = board.copy()
    p = next_player
    for _ in range(42):
        legal = find_legal(b)
        if len(legal) == 0:
            return 0
        opp = -p
        # -1 is the "nothing chosen yet" sentinel; int32 matches the dtype of `legal`.
        chosen = np.int32(-1)
        # 1) Immediate win: the result is known, so the move is not even placed.
        for i in range(len(legal)):
            if try_win(b, legal[i], p):
                return p
        # 2) Forced block: take the first column where the opponent would win.
        for i in range(len(legal)):
            if try_win(b, legal[i], opp):
                chosen = legal[i]
                break
        if chosen == -1:
            # 3) No tactical move: bias toward the centre column, else random.
            has_center = False
            for i in range(len(legal)):
                if legal[i] == 3:
                    has_center = True
                    break
            # The random draw only happens when the centre is available (short-circuit).
            if has_center and np.random.random() < 0.2:
                chosen = np.int32(3)
            else:
                chosen = legal[np.random.randint(len(legal))]
        row = find_row(b, chosen)
        b[row, chosen] = p
        if check_win(b, row, chosen):
            return p
        p = opp
    return 0

print("Game engine loaded")

## Cell 5: MCTS & data generation

In [None]:
def mcts_fast(board, color_val, nsteps):
    """Choose a column for ``color_val`` on ``board`` using ``nsteps`` MCTS iterations.

    Steps:
      1. If any legal column wins immediately, return it.
      2. Build ``allowed``: legal columns after which the opponent has no
         immediate winning reply (all legal columns if every move loses).
      3. Run ``nsteps`` iterations of selection / rollout / back-propagation.
         Node statistics live in a dict keyed by ``board.tobytes()`` and hold
         ``[visit_count, value_sum]``.
      4. Return the ``allowed`` column whose child node has the best mean
         value; ``allowed[0]`` when none of them was visited.

    ``board`` is modified temporarily while probing moves but is restored
    before returning. The caller is expected to pass a board with at least
    one legal move (``allowed[0]`` fails on an empty list otherwise).
    """
    legal0 = find_legal(board)
    # Tactical shortcut: play an immediately winning column if one exists.
    for c in legal0:
        if try_win(board, int(c), color_val):
            return int(c)
    opp = -color_val
    # Safety filter: keep moves that do not hand the opponent an immediate win.
    allowed = []
    for c in legal0:
        col = int(c)
        row = int(find_row(board, col))
        board[row, col] = color_val
        opp_legal = find_legal(board)
        loses = False
        for oc in opp_legal:
            if try_win(board, int(oc), opp):
                loses = True
                break
        board[row, col] = 0
        if not loses:
            allowed.append(col)
    if not allowed:
        # Every move loses to an immediate reply; consider all of them.
        allowed = [int(c) for c in legal0]
    root_key = board.tobytes()
    # md maps board bytes -> [visits, value_sum]; positions reached through
    # different move orders share one entry.
    md = {root_key: [0, 0]}
    _sqrt = math.sqrt
    _log = math.log
    for _ in range(nsteps):
        cv = color_val
        b = board.copy()
        path = [root_key]
        while True:
            legal = find_legal(b)
            n_legal = len(legal)
            if n_legal == 0:
                # Full board with no winner: count the visit, leave values unchanged.
                for key in path:
                    md[key][0] += 1
                break
            # Compute (and register) the key of every child position.
            keys = [None] * n_legal
            for i in range(n_legal):
                col = int(legal[i])
                row = int(find_row(b, col))
                b[row, col] = cv
                k = b.tobytes()
                keys[i] = k
                if k not in md:
                    md[k] = [0, 0]
                b[row, col] = 0
            parent_n = md[path[-1]][0]
            best_ucb = -1e30
            best_i = 0
            if parent_n == 0:
                # Unvisited parent: log(0) is undefined, so take the first child.
                best_i = 0
            else:
                log_p = _log(parent_n)
                for i in range(n_legal):
                    nd = md[keys[i]]
                    if nd[0] == 0:
                        # The first unvisited child is tried before any UCB comparison.
                        best_i = i
                        break
                    # Mean value plus exploration bonus (constant 2.0).
                    ucb = nd[1] / nd[0] + 2.0 * _sqrt(log_p / nd[0])
                    if ucb > best_ucb:
                        best_ucb = ucb
                        best_i = i
            chosen_col = int(legal[best_i])
            row = int(find_row(b, chosen_col))
            b[row, chosen_col] = cv
            path.append(keys[best_i])
            if check_win(b, row, chosen_col):
                # Terminal win for the player who just moved (cv).
                winner = cv
                # path[j] with odd j is a position reached by a color_val move,
                # even j by an opponent move (j == 0 is the root); each node
                # is credited from the point of view of the player who moved
                # into it.
                for j, key in enumerate(path):
                    md[key][0] += 1
                    if winner == color_val:
                        md[key][1] += 1 if j % 2 == 1 else -1
                    else:
                        md[key][1] += -1 if j % 2 == 1 else 1
                break
            cv = -cv
            if md[keys[best_i]][0] == 0:
                # First visit of this node: estimate it with one rollout, cv to move.
                result = int(rollout_jit(b, cv))
                for j, key in enumerate(path):
                    md[key][0] += 1
                    if result == 0:
                        pass
                    elif result == color_val:
                        md[key][1] += 1 if j % 2 == 1 else -1
                    else:
                        md[key][1] += -1 if j % 2 == 1 else 1
                break
    # Final choice: best mean value among the allowed root moves.
    best_val = -1e30
    best_col = allowed[0]
    for col in allowed:
        row = int(find_row(board, col))
        board[row, col] = color_val
        k = board.tobytes()
        board[row, col] = 0
        if k in md and md[k][0] > 0:
            val = md[k][1] / md[k][0]
            if val > best_val:
                best_val = val
                best_col = col
    return best_col


def board_to_canonical(board, current_player_val):
    """Encode ``board`` as a (6, 7, 2) float32 tensor from the mover's view.

    Channel 0 marks cells equal to ``current_player_val``; channel 1 marks
    cells equal to ``-current_player_val``. All other cells are 0.
    """
    planes = np.zeros((6, 7, 2), dtype=np.float32)
    own_mask = board == current_player_val
    opp_mask = board == -current_player_val
    planes[..., 0] = own_mask.astype(np.float32)
    planes[..., 1] = opp_mask.astype(np.float32)
    return planes


def play_single_game(args):
    """Play one self-play game and return its labelled positions.

    ``args`` is ``(game_id, mcts_rollouts, random_opening_moves)``. The game
    opens with a random number (0..random_opening_moves) of uniformly random
    plies, which are not recorded; if one of them wins the game, an empty
    list is returned. Every later move is chosen by ``mcts_fast`` and
    recorded as ``(encoded_board, column, metadata)``.

    Once the game ends each metadata dict also receives ``winner`` and
    ``result`` (1.0 win / 0.0 loss / 0.5 draw for the player who moved).
    """
    game_id, mcts_rollouts, random_opening_moves = args
    board = np.zeros((6, 7), dtype=np.int8)
    side = np.int8(1)
    samples = []

    # --- Unrecorded random opening ---
    opening_plies = random.randint(0, random_opening_moves)
    for _ in range(opening_plies):
        options = find_legal(board)
        if len(options) == 0:
            break
        col = int(random.choice(options))
        row = int(find_row(board, col))
        board[row, col] = side
        if check_win(board, row, col):
            return []
        side = np.int8(-side)

    # --- MCTS-driven play; winner stays 0 if the board fills up ---
    winner = 0
    while len(find_legal(board)) > 0:
        mover = int(side)
        move = mcts_fast(board, mover, mcts_rollouts)
        encoded = board_to_canonical(board, mover)
        info = {
            'game_id': game_id,
            'move_number': len(samples),
            'current_player': mover,
        }
        samples.append((encoded, move, info))
        row = int(find_row(board, move))
        board[row, move] = side
        if check_win(board, row, move):
            winner = mover
            break
        side = np.int8(-side)

    # --- Attach the final outcome to every recorded position ---
    for _, _, info in samples:
        info['winner'] = winner
        if winner == info['current_player']:
            info['result'] = 1.0
        elif winner != 0:
            info['result'] = 0.0
        else:
            info['result'] = 0.5
    return samples

print("MCTS & play_single_game loaded")

## Cell 6: Deduplication & verification

In [None]:
def hash_board(board):
    """Return the hex MD5 digest of ``board.tobytes()``."""
    digest = hashlib.md5()
    digest.update(board.tobytes())
    return digest.hexdigest()

def get_all_equivalent_boards(board, move):
    """List the four board/move variants treated as duplicates of each other.

    ``board`` is indexed as (row, column, channel). The variants are, in
    order: the original, the column mirror (move becomes ``6 - move``), the
    channel swap, and the mirror with channels swapped. Every board returned
    is a fresh copy.

    Returns a list of ``(board, move, transform_name)`` tuples.
    """
    swapped_channels = [1, 0]
    mirrored_move = 6 - move
    mirrored = board[:, ::-1, :].copy()
    return [
        (board.copy(), move, 'original'),
        (mirrored, mirrored_move, 'mirror'),
        (board[:, :, swapped_channels].copy(), move, 'perspective_flip'),
        (mirrored[:, :, swapped_channels].copy(), mirrored_move, 'mirror_and_flip'),
    ]

def normalize_to_canonical_form(board, move):
    """Pick one representative among the equivalent variants of a sample.

    The representative is the variant whose board has the smallest MD5 hex
    digest; on equal digests the earliest variant in the list wins.

    Returns the chosen ``(board, move, transform_name)`` tuple.
    """
    candidates = get_all_equivalent_boards(board, move)
    # min() keeps the first of several equal keys, matching a strict "<" scan.
    return min(candidates, key=lambda candidate: hash_board(candidate[0]))

def deduplicate_all_transformations(raw_data):
    """Collapse samples that are equal up to the board transformations.

    ``raw_data`` is an iterable-with-length of ``(board, label, metadata)``.

    Phase 1 maps every sample to its canonical variant and groups samples by
    the hash of that canonical board. Phase 2 keeps one sample per group:
    the first entry carrying the most common label. Groups with more than
    two distinct labels are dropped entirely.

    Returns a list of ``(canonical_board, canonical_label, metadata)``.
    """
    total = len(raw_data)

    print("  Phase 1: Normalizing to canonical form...")
    groups = defaultdict(list)
    for position, (board, label, metadata) in enumerate(raw_data):
        if position > 0 and position % 50000 == 0:
            print(f"    {position:,} / {total:,}...")
        canon_board, canon_move, transform = normalize_to_canonical_form(board, label)
        groups[hash_board(canon_board)].append({
            'board': canon_board, 'label': canon_move,
            'metadata': metadata, 'transform': transform,
        })
    print(f"  Phase 1 done: {total:,} -> {len(groups):,} unique")

    print("  Phase 2: Majority voting...")
    clean_data = []
    for entries in groups.values():
        votes = Counter(entry['label'] for entry in entries)
        if len(votes) > 2:
            # Three or more competing labels: treat as a conflict and skip.
            continue
        majority_label = votes.most_common(1)[0][0]
        keeper = next(entry for entry in entries if entry['label'] == majority_label)
        clean_data.append((keeper['board'], keeper['label'], keeper['metadata']))
    print(f"  Final: {len(clean_data):,} samples")
    return clean_data

def warmup_numba():
    """Trigger Numba compilation of the game engine before real work starts.

    Numba compiles one specialisation per argument-type signature. The
    search code calls ``find_row``, ``check_win``, ``try_win`` and
    ``rollout_jit`` directly from Python with plain ``int`` scalars (see the
    ``int(...)`` conversions in ``mcts_fast`` / ``play_single_game``), so the
    warm-up uses those same argument types; warming with ``np.int8`` scalars
    would compile a signature that is never used afterwards.

    Does nothing except print a notice when Numba is unavailable.
    """
    if not HAS_NUMBA:
        print("[WARMUP] Numba not available")
        return
    print("[WARMUP] Compiling Numba...", end=" ", flush=True)
    t0 = time.time()
    b = np.zeros((6, 7), dtype=np.int8)  # scratch board, same dtype as real games
    _ = find_legal(b)
    row = int(find_row(b, 3))
    _ = try_win(b, 3, 1)
    b[row, 3] = 1
    _ = check_win(b, row, 3)
    b[row, 3] = 0
    # Both signs are exercised; they share one signature but this also
    # confirms the rollout runs end to end for either side to move.
    _ = rollout_jit(b, 1)
    _ = rollout_jit(b, -1)
    print(f"done in {time.time()-t0:.1f}s")

print("Deduplication & verification loaded")

## Cell 7: Run generation (or load existing .npy)

In [None]:
# ---------- Option A: Generate new data ----------
GENERATE_NEW = True  # Set False to only load existing .npy files

if GENERATE_NEW:
    warmup_numba()
    nw = NUM_WORKERS or max(1, cpu_count())
    args_list = [(i, MCTS_ROLLOUTS, RANDOM_OPENING_MOVES) for i in range(NUM_GAMES)]
    all_data = []
    # Each game adds a variable number of samples, so the running total
    # rarely lands exactly on a multiple of 50,000. Report whenever the total
    # crosses the next threshold instead of testing for an exact multiple.
    PROGRESS_EVERY = 50000
    next_report = PROGRESS_EVERY
    start = time.time()
    with Pool(nw) as pool:
        # imap_unordered yields each game's samples as soon as a worker finishes.
        for game_data in pool.imap_unordered(play_single_game, args_list, chunksize=1):
            all_data.extend(game_data)
            if len(all_data) >= next_report:
                print(f"  Samples: {len(all_data):,}")
                next_report = (len(all_data) // PROGRESS_EVERY + 1) * PROGRESS_EVERY
    print(f"Raw: {len(all_data):,} samples in {time.time()-start:.1f}s")
    clean_data = deduplicate_all_transformations(all_data)
    X = np.array([b for b, _, _ in clean_data], dtype=np.float32)
    Y = np.array([m for _, m, _ in clean_data], dtype=np.int64)
    metadata = [m for _, _, m in clean_data]
else:
    # ---------- Option B: Load existing .npy (e.g. best_20k_X.npy, best_20k_Y.npy) ----------
    LOAD_PATH = os.path.join(PROJECT_DIR, 'best_20k')  # prefix without _X / _Y
    X = np.load(LOAD_PATH + '_X.npy')
    Y = np.load(LOAD_PATH + '_Y.npy')
    metadata = []
    if os.path.exists(LOAD_PATH + '_metadata.pkl'):
        # SECURITY: pickle.load can execute arbitrary code; only load
        # metadata files that you created yourself.
        with open(LOAD_PATH + '_metadata.pkl', 'rb') as f:
            metadata = pickle.load(f)
    else:
        # NOTE(review): placeholder metadata -- move_number and result are
        # synthetic here, not derived from real games.
        metadata = [{'move_number': i % 21, 'result': 0.5} for i in range(len(X))]
    print(f"Loaded: X={X.shape}, Y={Y.shape}")

print(f"X shape: {X.shape} | Y shape: {Y.shape}")

## Cell 8: Build dataset in notebook format (X_train, y_train, etc.)

In [None]:
# Same format as Connect4_DeepSearch and Connect4_Dataset_Generator
n = len(X)

# Move labels are stored as int8; reuse Y untouched when it already is.
y_move = Y if Y.dtype == np.int8 else Y.astype(np.int8)

# Per-sample turn index and game result come from metadata when present,
# otherwise neutral defaults (turn 0, result 0.5) are used.
if metadata:
    turns = np.array([m.get('move_number', 0) for m in metadata], dtype=np.int8)
    y_result = np.array([m.get('result', 0.5) for m in metadata], dtype=np.float32)
else:
    turns = np.zeros(n, dtype=np.int8)
    y_result = np.full(n, 0.5, dtype=np.float32)

# Shuffle: one permutation applied to every parallel array.
idx = np.random.permutation(n)
X, y_move, turns, y_result = X[idx], y_move[idx], turns[idx], y_result[idx]
if metadata:
    metadata = [metadata[i] for i in idx]

print(f"Prepared: X={X.shape}, y_move={y_move.shape}, y_result={y_result.shape}, turns={turns.shape}")

## Cell 9: Save in notebook format (npz + optional splits)

In [None]:
from sklearn.model_selection import train_test_split

# DeepSearch-style archive: X_train / y_train only.
if OUTPUT_FORMAT in ('both', 'deepsearch'):
    npz_path = os.path.join(DATASET_DIR, 'connect4_additional_data.npz')
    np.savez_compressed(npz_path, X_train=X, y_train=y_move)
    print(f"Saved (DeepSearch style): {npz_path}")

# Dataset_Generator-style archive plus train/val/test split files.
if OUTPUT_FORMAT in ('both', 'dataset_generator'):
    full_path = os.path.join(DATASET_DIR, 'connect4_additional_full.npz')
    np.savez_compressed(full_path, X=X, y_move=y_move, y_result=y_result, turns=turns)
    print(f"Saved (Dataset_Generator style): {full_path}")

    # First cut: 80% train, 20% held out.
    X_tr, X_tmp, y_m_tr, y_m_tmp, y_r_tr, y_r_tmp, t_tr, t_tmp = train_test_split(
        X, y_move, y_result, turns, test_size=0.2, random_state=SEED)
    # Second cut: the held-out part is halved into validation and test.
    X_val, X_te, y_m_val, y_m_te, y_r_val, y_r_te, t_val, t_te = train_test_split(
        X_tmp, y_m_tmp, y_r_tmp, t_tmp, test_size=0.5, random_state=SEED)

    _split_files = (
        ('train.npz', X_tr, y_m_tr, y_r_tr, t_tr),
        ('val.npz', X_val, y_m_val, y_r_val, t_val),
        ('test.npz', X_te, y_m_te, y_r_te, t_te),
    )
    for _fname, _xs, _moves, _results, _turns in _split_files:
        np.savez_compressed(os.path.join(DATASET_DIR, _fname),
                            X=_xs, y_move=_moves, y_result=_results, turns=_turns)
    print(f"Saved train/val/test splits to {DATASET_DIR}")

print("Done.")

## Cell 10: Optional merge with existing dataset

In [None]:
# Merge with existing dataset (e.g. from DeepSearch or Dataset_Generator)
EXISTING_DATASET_PATH = None  # e.g. 'Connect4_DeepSearch_Output/datasets/connect4_deep_search.npz'

if not EXISTING_DATASET_PATH:
    print("No existing path set. Set EXISTING_DATASET_PATH to merge.")
elif not os.path.exists(EXISTING_DATASET_PATH):
    # A path was given but nothing is there; say so instead of claiming
    # that no path was set.
    print(f"Existing dataset not found: {EXISTING_DATASET_PATH}")
else:
    # The archive object holds an open file; the `with` block closes it as
    # soon as the arrays have been read into memory.
    with np.load(EXISTING_DATASET_PATH) as existing:
        # Accept either key naming: X_train/y_train or X/y_move.
        X_ex = existing['X_train'] if 'X_train' in existing else existing['X']
        y_ex = existing['y_train'] if 'y_train' in existing else existing['y_move']
    # NOTE(review): the merged samples are not deduplicated against the
    # existing dataset in this cell.
    X_merged = np.concatenate([X_ex, X], axis=0)
    y_merged = np.concatenate([y_ex, y_move], axis=0)
    idx = np.random.permutation(len(X_merged))
    X_merged = X_merged[idx]
    y_merged = y_merged[idx]
    merged_path = os.path.join(DATASET_DIR, 'connect4_merged.npz')
    np.savez_compressed(merged_path, X_train=X_merged, y_train=y_merged)
    print(f"Merged: {merged_path} | X={X_merged.shape}")

## How to load in your training notebook

In [None]:
# Load DeepSearch-style
# NOTE(review): `data` is the archive object returned by np.load; consider
# calling data.close() once the arrays below have been read.
_deepsearch_npz = os.path.join(DATASET_DIR, 'connect4_additional_data.npz')
data = np.load(_deepsearch_npz)
X_train, y_train = data['X_train'], data['y_train']  # (N, 6, 7, 2) and (N,) column 0-6
print(f"X: {X_train.shape} | y: {y_train.shape}")

# Load Dataset_Generator-style (if you saved it)
# data = np.load(os.path.join(DATASET_DIR, 'connect4_additional_full.npz'))
# X, y_move, y_result, turns = data['X'], data['y_move'], data['y_result'], data['turns']