# Connect4 High-Depth Dataset (2 Hour Cap)

Generate a smaller, higher-quality Connect 4 dataset by running a deeper MCTS search for each move (2,000–8,000 simulations), stopping at a strict 2-hour runtime cap.

In [None]:
# Mount Google Drive
from google.colab import drive
import os

# Mount Google Drive at /content/drive; every project path below lives under it.
drive.mount('/content/drive')

# Project root on Drive and the three output folders derived from it.
PROJECT_DIR = '/content/drive/MyDrive/Connect4_Project'
DATASET_DIR = f'{PROJECT_DIR}/datasets'
CHECKPOINT_DIR = f'{PROJECT_DIR}/checkpoints_high_depth'
LOG_DIR = f'{PROJECT_DIR}/logs'

# Create whichever output folders are missing; existing ones are left as they are.
for _folder in (DATASET_DIR, CHECKPOINT_DIR, LOG_DIR):
    os.makedirs(_folder, exist_ok=True)

print('Drive mounted')
print('Dataset dir:', DATASET_DIR)
print('Checkpoint dir:', CHECKPOINT_DIR)

In [None]:
import numpy as np
import random
import pickle
import time
from datetime import datetime
from tqdm.notebook import tqdm

In [None]:
# Config
# Run limits: generation stops at the game target or the time cap, whichever
# comes first. A checkpoint is written once this many minutes have passed
# since the previous one.
TOTAL_GAMES_TARGET = 1500
MAX_RUNTIME_MINUTES = 120
CHECKPOINT_INTERVAL_MINUTES = 30

# Search-depth bands, one tuple per band:
#   (probability of the band, min simulations, max simulations)
DEPTH_BANDS = [
    (0.70, 2000, 4000),
    (0.20, 4000, 6000),
    (0.10, 6000, 8000),
]

# During the first RANDOM_DEPTH moves of a game, each move is replaced by a
# uniformly random one with probability RANDOM_PROB.
RANDOM_DEPTH = 10
RANDOM_PROB = 0.10
# Probability that a searched move is replaced by a random legal move.
STRONG_EPSILON = 0.02

# Echo the key settings so they are visible in the cell output.
print('Config:')
for _label, _value in (
    ('  target games:', TOTAL_GAMES_TARGET),
    ('  time cap (min):', MAX_RUNTIME_MINUTES),
    ('  depth bands:', DEPTH_BANDS),
):
    print(_label, _value)

In [None]:
# Connect4 engine
class Connect4:
    """Connect 4 position on a 6-row by 7-column board.

    Pieces are stored as 1 and -1 (0 means empty).  Row 0 is the bottom of the
    board: a piece dropped into column ``c`` lands on row ``heights[c]``.
    """

    def __init__(self):
        # board[r, c]: 1 or -1 for a piece of that player, 0 when empty.
        self.board = np.zeros((6, 7), dtype=np.int8)
        # heights[c]: number of pieces already in column c.
        self.heights = np.zeros(7, dtype=np.int8)
        # Side to move, 1 or -1; player 1 moves first.
        self.current_player = 1
        # None while the game is running, 1 / -1 for a win, 0 for a draw.
        self.winner = None
        self.move_count = 0

    def copy(self):
        """Return an independent copy of this position."""
        g = Connect4()
        g.board = self.board.copy()
        g.heights = self.heights.copy()
        g.current_player = self.current_player
        g.winner = self.winner
        g.move_count = self.move_count
        return g

    def legal_moves(self):
        """Return the columns (0-6) that still have room for a piece."""
        return [c for c in range(7) if self.heights[c] < 6]

    def make_move(self, col):
        """Drop a piece for the side to move into ``col``.

        Returns ``False`` and changes nothing when the column is full.
        Otherwise places the piece, sets ``winner`` if the move wins or fills
        the board, hands the turn to the other player and returns ``True``.
        """
        if self.heights[col] >= 6:
            return False
        row = int(self.heights[col])
        self.board[row, col] = self.current_player
        self.heights[col] += 1
        self.move_count += 1
        if self._check_win(row, col):
            self.winner = self.current_player
        elif self.move_count >= 42:
            # Board is full with no winner: draw.
            self.winner = 0
        self.current_player *= -1
        return True

    def _check_win(self, row, col):
        """Return True if the piece at (row, col) is part of a four-in-a-row."""
        player = int(self.board[row, col])
        target = 4 * player

        # Vertical: the new piece is always the top of its column, so the only
        # possible vertical line is the four cells ending at this piece
        # (rows row-3 .. row).  Looking upward would only ever see empty cells.
        if row >= 3 and int(self.board[row - 3:row + 1, col].sum()) == target:
            return True

        # Horizontal: every 4-wide window in this row that contains col.
        for c in range(max(0, col - 3), min(4, col + 1)):
            if int(self.board[row, c:c + 4].sum()) == target:
                return True

        # Both diagonals: count contiguous own pieces outward in each direction.
        for dr, dc in [(1, 1), (1, -1)]:
            count = 1
            for sign in [1, -1]:
                r, c = row + sign * dr, col + sign * dc
                while 0 <= r < 6 and 0 <= c < 7 and self.board[r, c] == player:
                    count += 1
                    r += sign * dr
                    c += sign * dc
            if count >= 4:
                return True
        return False

    def is_terminal(self):
        """Return True once the game has a winner or is drawn."""
        return self.winner is not None

    def get_result(self, player):
        """Return 1.0 if ``player`` won, -1.0 if they lost, else 0.0.

        0.0 covers both a draw and a game that is still in progress.
        """
        if self.winner is None:
            return 0.0
        return 1.0 if self.winner == player else (-1.0 if self.winner != 0 else 0.0)

    def encode(self):
        """Return a (6, 7, 2) float32 array of the board.

        Channel 0 marks the pieces of player 1, channel 1 those of player -1.
        """
        enc = np.zeros((6, 7, 2), dtype=np.float32)
        enc[:, :, 0] = (self.board == 1)
        enc[:, :, 1] = (self.board == -1)
        return enc

In [None]:
# MCTS
class MCTS:
    """Monte Carlo tree search player with a one-move win/block shortcut.

    The game object handed to this class must provide ``current_player``,
    ``winner``, ``board`` (a NumPy array), ``copy()``, ``legal_moves()``,
    ``make_move(col)`` and ``is_terminal()``.
    """

    def __init__(self, sims=2000):
        # Default number of simulations per search when get_move gets none.
        self.sims = sims

    def get_move(self, game, sims=None):
        """Choose a move for the side to move in ``game``.

        An immediate winning move is played if one exists; otherwise an
        immediate winning move of the opponent is blocked; otherwise the
        result of a ``sims``-simulation search is returned (``self.sims``
        when ``sims`` is None or 0).
        """
        sims = sims or self.sims
        win = self._find_winning_move(game, game.current_player)
        if win is not None:
            return win
        block = self._find_winning_move(game, -game.current_player)
        if block is not None:
            return block
        return self._mcts(game, sims)

    def _find_winning_move(self, game, player):
        """Return a column in which ``player`` would win at once, else None."""
        for col in game.legal_moves():
            test = game.copy()
            # Pretend it is ``player``'s turn so the same probe also answers
            # "where would the opponent win if they moved now?".
            test.current_player = player
            test.make_move(col)
            if test.winner == player:
                return col
        return None

    def _mcts(self, game, sims):
        """Run ``sims`` simulations from ``game`` and return the best column.

        Statistics are shared between move orders that reach the same board.
        Values are accumulated from the root player's point of view.
        """
        root_player = game.current_player
        # Board bytes -> [visit count, summed result for root_player].
        # Keyed on the bytes themselves, so two distinct boards can never be
        # merged by a hash collision.
        stats = {}

        for _ in range(sims):
            node = game.copy()
            state = node.board.tobytes()
            stats.setdefault(state, [0, 0.0])
            path = [state]

            # Selection / expansion: walk down by UCB until reaching a state
            # that has never been visited, or the end of the game.
            while not node.is_terminal():
                moves = node.legal_moves()
                if not moves:
                    break
                parent_visits = stats[state][0]
                # Stored values are from root_player's view.  When the
                # opponent is to move they pick what is worst for the root
                # player, so the exploitation term is negated for them.
                sign = 1.0 if node.current_player == root_player else -1.0
                best_child, best_key, best_ucb = None, None, -1e9
                for col in moves:
                    child = node.copy()
                    child.make_move(col)
                    key = child.board.tobytes()
                    visits, value = stats.setdefault(key, [0, 0.0])
                    if visits == 0:
                        # Unvisited children are always tried first.
                        ucb = 1e9
                    else:
                        exploit = sign * value / visits
                        explore = 1.4 * np.sqrt(np.log(parent_visits + 1) / visits)
                        ucb = exploit + explore
                    if ucb > best_ucb:
                        best_child, best_key, best_ucb = child, key, ucb
                node, state = best_child, best_key
                # Every state stepped into is recorded, terminal ones included,
                # so their visit counts keep up with how often they are chosen.
                path.append(state)
                if stats[state][0] == 0:
                    break

            # Rollout: random play, cut off after 20 moves.
            depth = 0
            while not node.is_terminal() and depth < 20:
                moves = node.legal_moves()
                if not moves:
                    break
                node.make_move(random.choice(moves))
                depth += 1

            # Draws and cut-off rollouts score 0.
            if node.winner == root_player:
                result = 1.0
            elif node.winner == -root_player:
                result = -1.0
            else:
                result = 0.0

            # Backpropagation along the visited path.
            for st in path:
                entry = stats[st]
                entry[0] += 1
                entry[1] += result

        # Pick the root move whose resulting state has the best mean value.
        legal = game.legal_moves()
        best_move, best_val = None, -1e9
        for col in legal:
            test = game.copy()
            test.make_move(col)
            visits, value = stats.get(test.board.tobytes(), (0, 0.0))
            val = value / visits if visits > 0 else -1e9
            if val > best_val:
                best_val, best_move = val, col
        return best_move if best_move is not None else random.choice(legal)

In [None]:
# Depth sampler

def sample_depth(bands=None):
    """Draw a random simulation count from a table of weighted depth bands.

    Args:
        bands: sequence of ``(probability, low, high)`` tuples.  Defaults to
            the module-level ``DEPTH_BANDS``.

    Returns:
        An integer drawn uniformly from ``[low, high]`` (both inclusive) of
        the band that was picked.

    Raises:
        ValueError: if ``bands`` is empty.
    """
    if bands is None:
        bands = DEPTH_BANDS
    if not bands:
        raise ValueError("sample_depth needs at least one depth band")

    r = random.random()  # in [0, 1)
    acc = 0.0
    for pct, lo, hi in bands:
        acc += pct
        # Strict comparison: a band with probability 0 is never selected.
        if r < acc:
            return random.randint(lo, hi)

    # Reached when the probabilities sum to less than r, e.g. through float
    # rounding (0.7 + 0.2 + 0.1 is slightly below 1.0): use the last band.
    _, lo, hi = bands[-1]
    return random.randint(lo, hi)

In [None]:
# Generate dataset with checkpoints
player = MCTS(sims=2000)

# Parallel lists: all_boards[i] is an encoded position seen from the side to
# move, all_moves[i] is the column that was played in it.
all_boards = []
all_moves = []

games_done = 0
start_time = time.time()
last_ckpt = start_time
ckpt_num = 0

# Absolute wall-clock deadline for the whole run.
deadline = start_time + MAX_RUNTIME_MINUTES * 60.0

while games_done < TOTAL_GAMES_TARGET:
    game = Connect4()
    game_boards = []
    game_moves = []
    time_cap_hit = False

    while not game.is_terminal():
        # The cap is checked before every move, not only between games, so a
        # single long game cannot carry the run far past the time budget.
        if time.time() >= deadline:
            time_cap_hit = True
            break

        current_player = game.current_player
        if game.move_count < RANDOM_DEPTH and random.random() < RANDOM_PROB:
            # Random opening move for variety; it is played but not recorded.
            col = random.choice(game.legal_moves())
        else:
            if random.random() < STRONG_EPSILON:
                # NOTE(review): this random move is also stored as the label
                # for the position - confirm that is intended.
                col = random.choice(game.legal_moves())
            else:
                col = player.get_move(game, sims=sample_depth())

            # Encode from the mover's point of view: for player -1 the piece
            # signs are flipped first, so channel 0 always holds the pieces
            # of the side to move.
            if current_player == 1:
                board = game.encode()
            else:
                flipped = game.copy()
                flipped.board *= -1
                board = flipped.encode()

            game_boards.append(board)
            game_moves.append(col)

        game.make_move(col)

    # Each example is a standalone (position, move) pair, so the ones already
    # recorded for a game cut short by the time cap are kept as well.
    all_boards.extend(game_boards)
    all_moves.extend(game_moves)

    if time_cap_hit:
        print('Time cap reached, stopping')
        break

    # Only games played to the end are counted.
    games_done += 1

    # checkpoint
    if (time.time() - last_ckpt) / 60.0 >= CHECKPOINT_INTERVAL_MINUTES:
        ckpt_num += 1
        ckpt_path = f"{CHECKPOINT_DIR}/high_depth_ckpt_{ckpt_num:03d}.pkl"
        with open(ckpt_path, 'wb') as f:
            pickle.dump({'boards': all_boards, 'moves': all_moves, 'games': games_done}, f)
        print(f"Checkpoint {ckpt_num} saved: games={games_done}, examples={len(all_boards)}")
        last_ckpt = time.time()

# Save final dataset
npz_path = f"{DATASET_DIR}/connect4_high_depth.npz"
np.savez_compressed(npz_path, X_train=np.array(all_boards), y_train=np.array(all_moves))
print('Saved:', npz_path)
print('Games:', games_done, 'Examples:', len(all_boards))