# Connect 4 Dataset Generation
Optimized for Colab Pro

## Setup
1. Upload this notebook to Google Colab
2. Run cells in order
3. Output files are saved to Google Drive, with periodic checkpoints

## Output
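- Training examples from MCTS self-play games (generation is time-boxed)
- Boards in 6x7x2 encoding (channel 0: stones of the player to move, channel 1: opponent stones)
- Checkpoint files and a generation log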

In [None]:
# Connect 4 Dataset Generation
# Optimized for Colab Pro

## CELL 1: Mount Google Drive

In [None]:
from google.colab import drive
import os

# Mount Google Drive; every output path below lives on the mounted drive.
drive.mount('/content/drive')

PROJECT_DIR = '/content/drive/MyDrive/Connect4_Project'

# Validate the location before creating anything, so a misconfigured path
# never leaves stray directories behind on the ephemeral Colab disk.
if not PROJECT_DIR.startswith('/content/drive'):
    raise ValueError('PROJECT_DIR must be on Google Drive for checkpoint persistence')

CHECKPOINT_DIR = f'{PROJECT_DIR}/checkpoints'
DATASET_DIR = f'{PROJECT_DIR}/datasets'
LOG_DIR = f'{PROJECT_DIR}/logs'

# exist_ok=True makes this cell safe to re-run.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
os.makedirs(DATASET_DIR, exist_ok=True)
os.makedirs(LOG_DIR, exist_ok=True)

print("Google Drive mounted")
print(f"Project directory: {PROJECT_DIR}")
print(f"Checkpoint directory: {CHECKPOINT_DIR}")

## CELL 2: Check Resources

In [None]:
# Check what resources Colab Pro gave you
!echo "CPU Info:"
!lscpu | grep -E "Model name|CPU\(s\):"
!echo "\nMemory Info:"
!free -h
print("\n Resource check complete")

## CELL 3: Import Dependencies

In [None]:
import numpy as np
import random
import pickle
import json
import time
from datetime import datetime
from collections import defaultdict
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

print("Imports loaded")

## CELL 4: Optimized Connect 4 Engine

In [None]:
# This is the optimized version of the provided Connect4.ipynb
# 3-5x faster with cleaner API

class Connect4Game:
    """Connect 4 on a 6x7 board stored as +1 / -1 / 0.

    Row 0 is the top of the board and row 5 the bottom, so stones stack from
    row 5 upwards. Player +1 moves first.

    Attributes:
        board: (6, 7) int8 array; +1 / -1 mark the players' stones, 0 is empty.
        current_player: +1 or -1, the side that moves next.
        last_move: column of the most recent move, or None before the first move.
        winner: None while the game is running, +1 / -1 for a win, 0 for a draw.
    """

    ROWS = 6
    COLS = 7
    # (d_row, d_col) axes along which four stones can line up:
    # vertical, horizontal, down-right diagonal, down-left diagonal.
    _LINE_AXES = ((1, 0), (0, 1), (1, 1), (1, -1))

    def __init__(self):
        self.board = np.zeros((self.ROWS, self.COLS), dtype=np.int8)
        self.current_player = 1
        self.last_move = None
        self.winner = None

    def copy(self):
        """Return an independent copy of the game (the board array is duplicated)."""
        new_game = Connect4Game()
        new_game.board = self.board.copy()
        new_game.current_player = self.current_player
        new_game.last_move = self.last_move
        new_game.winner = self.winner
        return new_game

    def get_legal_moves(self):
        """Return the columns whose top cell is still empty."""
        return [col for col in range(self.COLS) if self.board[0, col] == 0]

    def make_move(self, col):
        """Drop a stone for ``current_player`` into ``col``.

        Updates ``last_move``, ``winner`` (win or draw) and flips
        ``current_player``. This method does not check whether the game is
        already over; callers stop via ``is_terminal()``.

        Returns:
            True if the stone was placed, False if the column is out of range
            or full (in which case nothing is modified).
        """
        if col < 0 or col >= self.COLS or self.board[0, col] != 0:
            return False

        # Stones stack from the bottom, so the number of stones already in the
        # column determines the landing row.
        stones_in_col = int(np.count_nonzero(self.board[:, col]))
        row = self.ROWS - 1 - stones_in_col

        if row < 0:
            return False

        self.board[row, col] = self.current_player
        self.last_move = col

        if self._check_win_fast(row, col):
            self.winner = self.current_player
        elif len(self.get_legal_moves()) == 0:
            self.winner = 0  # board full without a winner: draw

        self.current_player = -self.current_player
        return True

    def _check_win_fast(self, row, col):
        """Return True if the stone at (row, col) is part of a line of four or more.

        Only lines through (row, col) are inspected. For each axis the
        contiguous run of same-coloured stones is counted in both directions,
        with explicit bounds checks so that indices can never wrap around the
        board edge. An empty cell never counts as a win.
        """
        player = int(self.board[row, col])
        if player == 0:
            return False

        for d_row, d_col in self._LINE_AXES:
            run = 1  # the stone at (row, col) itself
            for sign in (1, -1):
                r = row + sign * d_row
                c = col + sign * d_col
                while 0 <= r < self.ROWS and 0 <= c < self.COLS and self.board[r, c] == player:
                    run += 1
                    r += sign * d_row
                    c += sign * d_col
            if run >= 4:
                return True

        return False

    def is_terminal(self):
        """Return True once a win or a draw has been recorded."""
        return self.winner is not None

    def get_result(self, player):
        """Return the outcome for ``player``: 1.0 win, -1.0 loss, 0.0 draw or unfinished."""
        if self.winner is None:
            return 0.0
        if self.winner == 0:
            return 0.0
        return 1.0 if self.winner == player else -1.0

    def get_board_encoding_b(self):
        """Get 6x7x2 encoding (better for neural networks).

        Channel 0 marks the +1 stones and channel 1 the -1 stones, as float32.
        """
        encoded = np.zeros((self.ROWS, self.COLS, 2), dtype=np.float32)
        encoded[:, :, 0] = (self.board == 1).astype(np.float32)
        encoded[:, :, 1] = (self.board == -1).astype(np.float32)
        return encoded

print(" Connect4Game class loaded")

## CELL 5: Optimized MCTS

In [None]:
# Based on provided mcts() function but optimized for speed and clarity

class MCTSPlayer:
    """Monte Carlo tree search player with heuristic rollouts, threat priors, and LRU cache.

    Search statistics are ``[visits, value_sum]`` lists keyed by
    ``(flattened board, player to move)``. Inside one search, ``value_sum`` is
    measured from the perspective of the player who is choosing the move (the
    "root player"). Values merged into the cross-search cache ``transpo`` are
    normalised to player +1's perspective so that searches run for either side
    can share them.

    The game objects passed in must provide ``board``, ``current_player``,
    ``winner``, ``copy()``, ``get_legal_moves()``, ``make_move(col)``,
    ``is_terminal()`` and ``get_result(player)``.
    """

    # Rollout move preference: centre columns first.
    _CENTER_ORDER = (3, 2, 4, 1, 5, 0, 6)

    def __init__(self, num_simulations=1000, cache_size=200000):
        from collections import OrderedDict  # local import keeps this cell self-contained

        self.num_simulations = num_simulations
        self.cache_size = cache_size
        # (board_key, player) -> [visits, value]; insertion order doubles as LRU order
        self.transpo = OrderedDict()

    def _key(self, game):
        """Return a hashable key for the position: (flattened board, player to move)."""
        return (tuple(game.board.ravel()), game.current_player)

    def _cache_get(self, key):
        """Return the cached ``[visits, value]`` for ``key`` (marking it recently used), or None."""
        if key in self.transpo:
            self.transpo.move_to_end(key)
            return self.transpo[key]
        return None

    def _cache_set(self, key, value):
        """Add ``value`` (``[visits, value_sum]``) to the cache entry for ``key``.

        Existing entries are accumulated; new entries store a copy so later
        accumulation never mutates the caller's list. The least recently used
        entry is evicted once the cache exceeds ``cache_size``.
        """
        if key in self.transpo:
            self.transpo.move_to_end(key)
            self.transpo[key][0] += value[0]
            self.transpo[key][1] += value[1]
        else:
            self.transpo[key] = [value[0], value[1]]
        if len(self.transpo) > self.cache_size:
            self.transpo.popitem(last=False)

    def get_move(self, game, num_simulations=None):
        """Choose a column for ``game.current_player``.

        An immediate win is played first, then an immediate block of the
        opponent's win. Otherwise ``num_simulations`` simulations are run
        (``self.num_simulations`` when None) and the move whose resulting
        position has the best mean value for the mover is returned. If no
        candidate has any statistics, a random legal move is returned.

        Raises:
            IndexError: if the position has no legal moves (from ``random.choice``).
        """
        player = game.current_player

        win_col = self._look_for_win(game, player)
        if win_col is not None:
            return win_col

        block_col = self._look_for_win(game, -player)
        if block_col is not None:
            return block_col

        budget = self.num_simulations if num_simulations is None else num_simulations
        root_key = self._key(game)
        mcts_dict = {root_key: [0, 0]}

        for _ in range(budget):
            self._simulate(game.copy(), mcts_dict, player)

        # Merge into cache. Values are converted to player +1's perspective
        # (multiplying by the root player's sign) so entries written while
        # searching for either side stay comparable; unvisited nodes carry no
        # information and are skipped.
        for k, v in mcts_dict.items():
            if v[0] > 0:
                self._cache_set(k, [v[0], v[1] * player])

        legal_moves = game.get_legal_moves()
        best_move = None
        best_value = -1e9

        for col in legal_moves:
            test_game = game.copy()
            test_game.make_move(col)
            board_key = self._key(test_game)

            value = -1e9
            stats = mcts_dict.get(board_key)
            if stats is not None and stats[0] > 0:
                value = stats[1] / stats[0]
            else:
                cached = self._cache_get(board_key)
                if cached and cached[0] > 0:
                    # convert back from player +1's perspective to the mover's
                    value = player * cached[1] / cached[0]

            if value > best_value:
                best_value = value
                best_move = col

        if best_move is None:
            best_move = random.choice(legal_moves)

        return best_move

    def _simulate(self, game, mcts_dict, root_player):
        """Run one simulation from ``game`` and back its result up into ``mcts_dict``.

        ``game`` is mutated (pass a copy). Selection uses UCB plus the threat
        prior; the first child reached without visits is evaluated with a
        heuristic rollout. Every node on the path, including a terminal node
        reached on a revisit, receives one visit and the result, which is
        measured from ``root_player``'s perspective.
        """
        path = []
        result = 0.0  # neutral outcome if traversal stops without reaching a result

        while not game.is_terminal():
            board_key = self._key(game)
            path.append(board_key)

            legal_moves = game.get_legal_moves()
            if not legal_moves:
                break

            # Values are stored from the root player's point of view, so on
            # the opponent's turn the best child is the one with the LOWEST
            # root value: flip the sign of the exploitation part there.
            perspective = 1 if game.current_player == root_player else -1
            parent_visits = mcts_dict.setdefault(board_key, [0, 0])[0]

            best_col = None
            best_ucb = -1e9

            for col in legal_moves:
                test_game = game.copy()
                test_game.make_move(col)
                next_key = self._key(test_game)
                visits, value_sum = mcts_dict.setdefault(next_key, [0, 0])

                if visits == 0:
                    ucb = 1e9  # always try unvisited children first
                else:
                    exploitation = value_sum / visits
                    exploration = 2 * np.sqrt(np.log(parent_visits + 1) / visits)
                    prior = self._threat_prior(test_game, root_player)
                    ucb = perspective * (exploitation + prior) + exploration

                if ucb > best_ucb:
                    best_ucb = ucb
                    best_col = col

            game.make_move(best_col)

            board_key = self._key(game)
            if board_key not in mcts_dict or mcts_dict[board_key][0] == 0:
                path.append(board_key)
                result = self._rollout(game.copy(), root_player)
                break
        else:
            # The walk ended on a terminal position (not via a rollout):
            # credit that node as well so its statistics keep up with its parents.
            path.append(self._key(game))
            result = game.get_result(root_player)

        for board_key in path:
            if board_key in mcts_dict:
                mcts_dict[board_key][0] += 1
                mcts_dict[board_key][1] += result

    def _rollout(self, game, root_player):
        """Play ``game`` to the end with a fixed heuristic and return the result for ``root_player``.

        Each side takes an immediate win, otherwise blocks the opponent's
        immediate win, otherwise plays the most central legal column.
        ``game`` is mutated.
        """
        while not game.is_terminal():
            legal_moves = game.get_legal_moves()
            if not legal_moves:
                break

            player = game.current_player
            win_col = self._look_for_win(game, player)
            if win_col is not None:
                game.make_move(win_col)
                continue

            block_col = self._look_for_win(game, -player)
            if block_col is not None:
                game.make_move(block_col)
                continue

            for col in self._CENTER_ORDER:
                if col in legal_moves:
                    game.make_move(col)
                    break
        return game.get_result(root_player)

    def _look_for_win(self, game, player):
        """Return a column where ``player`` would win immediately, or None.

        Each legal column is tried on a copy of ``game`` with the side to move
        forced to ``player``, which also allows probing the opponent's threats.
        ``game`` itself is not modified.
        """
        for col in game.get_legal_moves():
            test_game = game.copy()
            test_game.current_player = player
            test_game.make_move(col)
            if test_game.winner == player:
                return col
        return None

    def _threat_prior(self, game, root_player):
        """Return 0.05 per open three-in-a-row that ``root_player`` has on ``game.board``.

        An open three is any window of four aligned cells that holds exactly
        three of the player's stones and one empty cell. Windows are taken
        from a 6x7 board in all four directions.
        """
        board = game.board
        player = root_player

        def is_open_three(window):
            return np.count_nonzero(window == player) == 3 and np.count_nonzero(window == 0) == 1

        step = np.arange(4)
        score = 0

        # Horizontal
        for r in range(6):
            for c in range(4):
                score += is_open_three(board[r, c:c + 4])
        # Vertical
        for c in range(7):
            for r in range(3):
                score += is_open_three(board[r:r + 4, c])
        # Diagonals
        for r in range(3):
            for c in range(4):
                score += is_open_three(board[r + step, c + step])      # down-right
                score += is_open_three(board[r + step, c + 3 - step])  # down-left

        return 0.05 * score

print("MCTSPlayer class loaded")

## CELL 6: Dataset Generator

In [None]:
class DatasetGenerator:
    """Generate (board, move) training examples from self-play games.

    Every recorded board is encoded from the point of view of the player to
    move (channel 0 holds that player's stones), and the label is the column
    that was played.

    ``depth_bands``, ``strong_epsilon`` and ``baseline_eval_games`` are
    optional; when left as None the notebook-level constants ``DEPTH_BANDS``,
    ``STRONG_EPSILON`` and ``BASELINE_EVAL_GAMES`` are read at call time, as
    before.
    """

    def __init__(self, mcts_simulations, random_prob, random_depth, checkpoint_dir, log_file, cache_size,
                 depth_bands=None, strong_epsilon=None, baseline_eval_games=None):
        """
        Args:
            mcts_simulations: default simulation count of the MCTS player.
            random_prob: probability of an unrecorded random move during the opening.
            random_depth: number of opening plies in which random moves may occur.
            checkpoint_dir: directory that receives ``checkpoint_NNNN.pkl`` files.
            log_file: path of the log file to append to, or a falsy value to only print.
            cache_size: size of the MCTS player's LRU cache.
            depth_bands: list of ``(probability, min_sims, max_sims)`` tuples, or None.
            strong_epsilon: probability of replacing an MCTS move by a random one, or None.
            baseline_eval_games: games per baseline evaluation at each checkpoint, or None.
        """
        self.mcts_player = MCTSPlayer(mcts_simulations, cache_size=cache_size)
        self.random_prob = random_prob
        self.random_depth = random_depth
        self.checkpoint_dir = checkpoint_dir
        self.log_file = log_file
        self.depth_bands = depth_bands
        self.strong_epsilon = strong_epsilon
        self.baseline_eval_games = baseline_eval_games

        self.boards = []
        self.moves = []
        self.stats = {'start_time': None, 'games': 0, 'examples': 0}

    # Configuration accessors: explicit constructor values win, otherwise the
    # notebook-level constants are looked up lazily.
    def _depth_bands_in_use(self):
        return DEPTH_BANDS if self.depth_bands is None else self.depth_bands

    def _strong_epsilon_in_use(self):
        return STRONG_EPSILON if self.strong_epsilon is None else self.strong_epsilon

    def _baseline_games_in_use(self):
        return BASELINE_EVAL_GAMES if self.baseline_eval_games is None else self.baseline_eval_games

    def log(self, msg):
        """Print ``msg`` with a timestamp and append it to ``log_file`` when one is set."""
        timestamp = datetime.now().strftime('%H:%M:%S')
        log_msg = f"[{timestamp}] {msg}"
        print(log_msg)
        if self.log_file:
            with open(self.log_file, 'a') as f:
                f.write(log_msg + '\n')

    def _sample_depth(self):
        """Draw a simulation count from the configured depth bands.

        A band is picked according to its probability and the count is drawn
        uniformly from ``[min_sims, max_sims]``. If the probabilities sum to
        less than the random draw, the last band is used. With no bands
        configured, the MCTS player's default simulation count is returned.
        """
        bands = self._depth_bands_in_use()
        if not bands:
            return self.mcts_player.num_simulations

        r = random.random()
        acc = 0.0
        for pct, lo, hi in bands:
            acc += pct
            if r <= acc:
                return random.randint(lo, hi)
        return random.randint(bands[-1][1], bands[-1][2])

    def generate_one_game(self):
        """Play one self-play game and return ``(boards, moves, avg_depth)``.

        Random opening moves (see ``random_prob`` / ``random_depth``) are
        played but not recorded. All other moves are recorded together with
        the board as seen by the player to move. ``avg_depth`` is the mean
        sampled simulation count of the recorded moves (0 if there are none).
        """
        game = Connect4Game()
        game_boards = []
        game_moves = []
        depth_used = []
        move_count = 0
        strong_epsilon = self._strong_epsilon_in_use()

        while not game.is_terminal():
            current_player = game.current_player

            use_random = (move_count < self.random_depth and
                          random.random() < self.random_prob)

            if use_random:
                col = random.choice(game.get_legal_moves())
            else:
                depth = self._sample_depth()
                depth_used.append(depth)
                # NOTE(review): an epsilon-random move is recorded as the
                # training label just like an MCTS move - confirm that this
                # label noise is intended.
                if random.random() < strong_epsilon:
                    col = random.choice(game.get_legal_moves())
                else:
                    col = self.mcts_player.get_move(game, num_simulations=depth)

                if current_player == 1:
                    board = game.get_board_encoding_b()
                else:
                    # Negate the board so channel 0 always holds the mover's stones.
                    flipped_game = game.copy()
                    flipped_game.board = -flipped_game.board
                    board = flipped_game.get_board_encoding_b()

                game_boards.append(board)
                game_moves.append(col)

            game.make_move(col)
            move_count += 1

            if move_count > 42:  # safety net: a 6x7 board holds at most 42 stones
                break

        avg_depth = sum(depth_used) / len(depth_used) if depth_used else 0
        return game_boards, game_moves, avg_depth

    def _baseline_eval(self, num_games=20):
        """Play ``num_games`` games as +1 against a uniformly random opponent.

        Returns:
            ``(wins, losses, ties)`` from the MCTS player's point of view.
        """
        wins = 0
        losses = 0
        ties = 0
        for _ in range(num_games):
            game = Connect4Game()
            while not game.is_terminal():
                if game.current_player == 1:
                    col = self.mcts_player.get_move(game)
                else:
                    col = random.choice(game.get_legal_moves())
                game.make_move(col)

            if game.winner == 1:
                wins += 1
            elif game.winner == -1:
                losses += 1
            else:
                ties += 1
        return wins, losses, ties

    def _last_checkpoint_num(self):
        """Return the highest NNNN among ``checkpoint_NNNN.pkl`` files in ``checkpoint_dir`` (0 if none)."""
        prefix, suffix = 'checkpoint_', '.pkl'
        try:
            names = os.listdir(self.checkpoint_dir)
        except OSError:
            return 0

        highest = 0
        for name in names:
            if name.startswith(prefix) and name.endswith(suffix):
                number = name[len(prefix):-len(suffix)]
                if number.isdigit():
                    highest = max(highest, int(number))
        return highest

    def _save_checkpoint(self, batch_boards, batch_moves, checkpoint_num, start_time, games_done, avg_depth):
        """Pickle one batch of examples, log progress and run the baseline evaluation.

        Returns:
            Path of the written checkpoint file.
        """
        checkpoint_data = {
            'boards': batch_boards,
            'moves': batch_moves,
            'checkpoint_num': checkpoint_num,
            'timestamp': datetime.now().isoformat()
        }
        checkpoint_path = f'{self.checkpoint_dir}/checkpoint_{checkpoint_num:04d}.pkl'
        with open(checkpoint_path, 'wb') as f:
            pickle.dump(checkpoint_data, f)

        elapsed = time.time() - start_time
        rate = games_done / max(1e-9, elapsed / 60.0)
        self.log(
            f"Checkpoint {checkpoint_num} saved. Examples: {len(batch_boards):,}. "
            f"Games: {games_done}. Games/min: {rate:.2f}. Avg depth: {avg_depth:.0f}. "
            f"Elapsed: {elapsed/60:.1f} min"
        )
        baseline_games = self._baseline_games_in_use()
        w, l, t = self._baseline_eval(baseline_games)
        self.log(f"Baseline vs random (as +1) over {baseline_games} games: W {w}, L {l}, T {t}")
        return checkpoint_path

    def generate_full_dataset(self, total_games, checkpoint_interval_minutes=30, max_runtime_minutes=180, keep_in_memory=True):
        """Generate up to ``total_games`` games, checkpointing on a time interval.

        Generation stops early once ``max_runtime_minutes`` have elapsed (the
        check runs between games). Checkpoint numbers continue after the
        highest number already present in ``checkpoint_dir``, so repeated runs
        never overwrite earlier checkpoints.

        With ``keep_in_memory`` the examples are also appended to
        ``self.boards`` / ``self.moves``; these lists persist across calls, so
        the returned arrays include examples from earlier calls on the same
        generator.

        Returns:
            ``(boards, moves)`` as numpy arrays, or two empty arrays when
            ``keep_in_memory`` is False.
        """
        self.stats['start_time'] = datetime.now()
        self.log("Dataset generation start")
        self.log(f"Total games (target): {total_games:,}")
        self.log(f"Depth bands: {self._depth_bands_in_use()}")
        self.log(f"Random prob: {self.random_prob}, random depth: {self.random_depth}, strong epsilon: {self._strong_epsilon_in_use()}")
        self.log(f"Checkpoint interval (min): {checkpoint_interval_minutes}")
        self.log(f"Max runtime (min): {max_runtime_minutes}")

        start_time = time.time()
        last_checkpoint_time = start_time
        checkpoint_num = self._last_checkpoint_num()
        if checkpoint_num:
            self.log(f"Existing checkpoints found; numbering continues after {checkpoint_num}")

        batch_boards = []
        batch_moves = []
        depth_samples = []
        games_done = 0
        examples_done = 0

        while games_done < total_games:
            elapsed_minutes = (time.time() - start_time) / 60.0
            if elapsed_minutes >= max_runtime_minutes:
                self.log("Max runtime reached. Stopping generation.")
                break

            boards, moves, avg_depth_game = self.generate_one_game()
            batch_boards.extend(boards)
            batch_moves.extend(moves)
            depth_samples.append(avg_depth_game)
            games_done += 1
            examples_done += len(boards)

            if keep_in_memory:
                self.boards.extend(boards)
                self.moves.extend(moves)

            if (time.time() - last_checkpoint_time) / 60.0 >= checkpoint_interval_minutes:
                checkpoint_num += 1
                avg_depth = sum(depth_samples) / len(depth_samples) if depth_samples else 0
                self._save_checkpoint(batch_boards, batch_moves, checkpoint_num, start_time, games_done, avg_depth)
                batch_boards = []
                batch_moves = []
                depth_samples = []
                last_checkpoint_time = time.time()

        # Flush whatever was generated since the last timed checkpoint.
        if batch_boards or batch_moves:
            checkpoint_num += 1
            avg_depth = sum(depth_samples) / len(depth_samples) if depth_samples else 0
            self._save_checkpoint(batch_boards, batch_moves, checkpoint_num, start_time, games_done, avg_depth)

        self.stats['games'] = games_done
        # Without the in-memory copy, report the examples written to checkpoints.
        self.stats['examples'] = len(self.boards) if keep_in_memory else examples_done
        self.stats['end_time'] = datetime.now()

        total_time = (self.stats['end_time'] - self.stats['start_time']).total_seconds()
        self.log("Dataset generation complete")
        self.log(f"Games: {games_done}")
        self.log(f"Total time: {total_time/60:.1f} min")

        if keep_in_memory:
            return np.array(self.boards), np.array(self.moves)
        return np.array([]), np.array([])

print("DatasetGenerator class loaded")

## CELL 7: Configuration

In [None]:
# Configuration - every tunable of the generation run lives in this cell.

# Size of the run
TOTAL_GAMES = 15_000
MCTS_SIMS = 800  # default simulation count handed to the MCTS player

# Randomness: opening randomisation and epsilon-random moves
RANDOM_PROB, RANDOM_DEPTH = 0.25, 14
STRONG_EPSILON = 0.03

# Search cache and checkpoint-time evaluation
CACHE_SIZE = 200_000
BASELINE_EVAL_GAMES = 20

# MCTS depth distribution (percent, min, max)
DEPTH_BANDS = [
    (0.85, 700, 1500),
    (0.10, 1500, 3000),
    (0.05, 3000, 8000),
]

# Time control
CHECKPOINT_INTERVAL_MINUTES = 30
MAX_RUNTIME_MINUTES = 180
KEEP_IN_MEMORY = True

# Optional smoke test executed before the main run
RUN_FAST_TEST = False
FAST_TEST_GAMES = 50
FAST_TEST_MAX_MINUTES = 3

# Echo the settings so they end up in the notebook output next to the results.
print(
    "Configuration:",
    f"  Total games (target): {TOTAL_GAMES:,}",
    f"  Random probability: {RANDOM_PROB}",
    f"  Random depth: {RANDOM_DEPTH}",
    f"  Strong epsilon: {STRONG_EPSILON}",
    f"  Cache size: {CACHE_SIZE}",
    f"  Baseline eval games: {BASELINE_EVAL_GAMES}",
    f"  Depth bands: {DEPTH_BANDS}",
    f"  Checkpoint interval (min): {CHECKPOINT_INTERVAL_MINUTES}",
    f"  Max runtime (min): {MAX_RUNTIME_MINUTES}",
    sep="\n",
)


## CELL 8: Initialize Generator

In [None]:
# Build the generator from the constants of the configuration and Drive cells.
generator = DatasetGenerator(
    mcts_simulations=MCTS_SIMS,
    random_prob=RANDOM_PROB,
    random_depth=RANDOM_DEPTH,
    cache_size=CACHE_SIZE,
    checkpoint_dir=CHECKPOINT_DIR,
    log_file=f'{LOG_DIR}/generation.log',
)

print("Generator initialized", "Ready to generate dataset", sep="\n")

In [None]:
# Optional smoke test: a short run with one-minute checkpoints, enabled via RUN_FAST_TEST.
if RUN_FAST_TEST:
    print("Running fast test")
    X_test, y_test = generator.generate_full_dataset(
        total_games=FAST_TEST_GAMES,
        max_runtime_minutes=FAST_TEST_MAX_MINUTES,
        checkpoint_interval_minutes=1,
        keep_in_memory=True,
    )
    print(f"Fast test complete. Examples: {len(X_test):,}")

## CELL 9: GENERATE DATASET (MAIN CELL)

In [None]:
# Main run: generates games until TOTAL_GAMES or MAX_RUNTIME_MINUTES is reached.
print("Starting dataset generation")

X, y = generator.generate_full_dataset(
    total_games=TOTAL_GAMES,
    max_runtime_minutes=MAX_RUNTIME_MINUTES,
    checkpoint_interval_minutes=CHECKPOINT_INTERVAL_MINUTES,
    keep_in_memory=KEEP_IN_MEMORY,
)

print("Generation complete", f"X shape: {X.shape}", f"y shape: {y.shape}", sep="\n")

## CELL 10: Data Augmentation

In [None]:
def augment_dataset(boards, moves):
    """Horizontal flip augmentation.

    Every example is followed directly by its left-right mirror image: the
    board is flipped along the column axis and the move becomes ``6 - move``.

    Args:
        boards: array-like of shape (N, 6, 7, C).
        moves: array-like of N column indices in 0..6.

    Returns:
        ``(aug_boards, aug_moves)`` with 2N entries ordered as
        original_0, mirrored_0, original_1, mirrored_1, ...
        Empty input is returned as empty arrays.

    Raises:
        ValueError: if ``boards`` and ``moves`` differ in length.
    """
    boards = np.asarray(boards)
    moves = np.asarray(moves)
    if len(boards) != len(moves):
        raise ValueError(
            f"boards and moves must have the same length, got {len(boards)} and {len(moves)}"
        )

    print("Augmenting with horizontal flips...")
    if len(boards) == 0:
        return boards.copy(), moves.copy()

    # Axis 2 of the batch is the column axis of each board.
    mirrored_boards = np.flip(boards, axis=2)
    mirrored_moves = 6 - moves

    # Stacking on a new axis 1 and collapsing it interleaves original/mirror pairs.
    aug_boards = np.stack((boards, mirrored_boards), axis=1).reshape((-1,) + boards.shape[1:])
    aug_moves = np.stack((moves, mirrored_moves), axis=1).reshape(-1)
    return aug_boards, aug_moves

# Double the dataset with mirrored positions and report the size change.
X_aug, y_aug = augment_dataset(X, y)

print(
    f"\n Augmentation complete!",
    f"Original: {len(X):,} examples",
    f"Augmented: {len(X_aug):,} examples",
    sep="\n",
)

## CELL 11: Quality Analysis

In [None]:
# Statistics
assert len(y_aug) > 0, "Dataset is empty - generate data before running the analysis"

# minlength=7 guarantees one bin per column even if the highest columns were
# never played; without it the 7-bar plot below fails on a shape mismatch.
move_counts = np.bincount(y_aug, minlength=7)

print("Dataset Statistics:")
print("="*50)
print(f"Total examples: {len(X_aug):,}")
print(f"X shape: {X_aug.shape}")
print(f"y shape: {y_aug.shape}")
print(f"\nMove distribution:")
for col in range(7):
    count = move_counts[col]
    pct = (count / len(y_aug)) * 100
    print(f"  Column {col}: {count:,} ({pct:.1f}%)")

# Visualizations
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Move distribution
axes[0, 0].bar(range(7), move_counts, edgecolor='black')
axes[0, 0].set_title('Move Distribution', fontweight='bold')
axes[0, 0].set_xlabel('Column')
axes[0, 0].set_ylabel('Frequency')

# Game stage: number of stones on each board (sum over both channels)
stages = X_aug.reshape(len(X_aug), -1).sum(axis=1)
axes[0, 1].hist(stages, bins=30, edgecolor='black')
axes[0, 1].set_title('Game Stage Coverage', fontweight='bold')
axes[0, 1].set_xlabel('Total Pieces')
axes[0, 1].set_ylabel('Frequency')

# Sample board
idx = random.randint(0, len(X_aug)-1)
sample = X_aug[idx]
board_2d = sample[:, :, 0] - sample[:, :, 1]
axes[1, 0].imshow(board_2d, cmap='RdBu', vmin=-1, vmax=1)
axes[1, 0].set_title(f'Sample (Move: Col {y_aug[idx]})', fontweight='bold')

# Move by stage. Boolean masks keep the integer dtype, so np.bincount also
# works when a stage bucket happens to be empty.
early = y_aug[stages < 10]
late = y_aug[stages >= 20]
axes[1, 1].bar(np.arange(7)-0.2, np.bincount(early, minlength=7), width=0.4, label='Early', alpha=0.7)
axes[1, 1].bar(np.arange(7)+0.2, np.bincount(late, minlength=7), width=0.4, label='Late', alpha=0.7)
axes[1, 1].set_title('Moves by Game Stage', fontweight='bold')
axes[1, 1].set_xlabel('Column')
axes[1, 1].set_ylabel('Frequency')
axes[1, 1].legend()

plt.tight_layout()
plt.savefig(f'{PROJECT_DIR}/dataset_analysis.png', dpi=150)
plt.show()

print(f"\n Analysis saved to {PROJECT_DIR}/dataset_analysis.png")

## CELL 12: Save Final Dataset

In [None]:
# Generation is time-boxed, so the number of games actually played can be far
# below the TOTAL_GAMES target; record both in the metadata.
final_dataset = {
    'X_train': X_aug,
    'y_train': y_aug,
    'X_original': X,
    'y_original': y,
    'encoding': 'b',
    'metadata': {
        'num_games': generator.stats['games'],
        'target_games': TOTAL_GAMES,
        'mcts_simulations': MCTS_SIMS,
        'depth_bands': DEPTH_BANDS,
        'strong_epsilon': STRONG_EPSILON,
        'random_prob': RANDOM_PROB,
        'random_depth': RANDOM_DEPTH,
        'augmented': True,
        'total_examples': len(X_aug),
        'generation_date': datetime.now().isoformat()
    }
}

# Save pickle
final_path = f'{DATASET_DIR}/connect4_final.pkl'
with open(final_path, 'wb') as f:
    pickle.dump(final_dataset, f)

# Save numpy (compressed)
npz_path = f'{DATASET_DIR}/connect4_final.npz'
np.savez_compressed(
    npz_path,
    X_train=X_aug,
    y_train=y_aug
)

# Save metadata
with open(f'{DATASET_DIR}/metadata.json', 'w') as f:
    json.dump(final_dataset['metadata'], f, indent=2)

print("Dataset saved")
print(f"Pickle: {final_path}")
print(f"NumPy: {npz_path}")
print(f"\nFile sizes:")
print(f"  Pickle: {os.path.getsize(final_path)/1024**2:.1f} MB")
print(f"  NumPy: {os.path.getsize(npz_path)/1024**2:.1f} MB")

## CELL 13: Download (Optional)

In [None]:
from google.colab import files

# Trigger a browser download for the compressed dataset and its metadata.
print("Downloading files...")
for _download_path in (f'{DATASET_DIR}/connect4_final.npz', f'{DATASET_DIR}/metadata.json'):
    files.download(_download_path)
print(" Downloads complete!")

## CELL 14: How to Load in Future

In [None]:
# Example: Load dataset in training notebook
# np.load returns an NpzFile that keeps the archive open; the `with` block
# closes it once the arrays have been read into memory.
with np.load(f'{DATASET_DIR}/connect4_final.npz') as data:
    X_train = data['X_train']
    y_train = data['y_train']

print(f"Loaded dataset:")
print(f"  X shape: {X_train.shape}")  # Should be (N, 6, 7, 2)
print(f"  y shape: {y_train.shape}")  # Should be (N,)
print(f"\n Ready for training!")