# Connect 4 Dataset Generator - Comprehensive Edition

**Designed for Google Colab with High RAM runtime**

## Game Types Included:
- **Deep Strategic Games** - 25+ moves, advanced positional play, forks, double threats
- **Complex Tactical Games** - Traps, sacrifices, forced sequences
- **Random/Beginner Games** - Suboptimal moves to learn from errors
- **Opening Variety** - Diverse first 5-10 moves
- **Endgame Patterns** - Common winning endgame techniques

In [None]:
# Mount Google Drive so the output paths under /content/drive are writable.
# This only works inside a Google Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Notebook shell command: quietly install tqdm and joblib into the runtime.
!pip install -q tqdm joblib

In [None]:
import numpy as np
import random
import time
import pickle
import os
from tqdm.auto import tqdm
from joblib import Parallel, delayed
from collections import defaultdict
import gc
import psutil

def get_memory_info():
    """Print available RAM, total RAM (both in GB) and the percentage in use.

    The figures come from ``psutil.virtual_memory()``; nothing is returned.
    """
    mem = psutil.virtual_memory()
    print(f"RAM: {mem.available / (1024**3):.2f} GB available / {mem.total / (1024**3):.2f} GB total ({mem.percent}% used)")

# Report the memory situation once before any data is generated.
get_memory_info()

## Configuration

In [None]:
# ================== CONFIGURATION ==================

# Game distribution (total games = sum of all types).
# The keys are the game-type names that generate_single_game() dispatches on.
GAME_CONFIG = {
    'deep_strategic': 14000,    # 25+ moves, high MCTS depth
    'complex_tactical': 12000,  # Variable depth, trap detection
    'beginner_random': 12000,    # Low skill, mistakes
    'opening_variety': 12000,   # Focused on diverse openings
    'endgame_patterns': 5000,   # Start from late-game positions
}

TOTAL_GAMES = sum(GAME_CONFIG.values())

# MCTS settings per game type: a (min, max) pair. Each generator draws a
# random integer in this inclusive range and passes it to mcts() as the
# number of search iterations (`nsteps`) for that move.
MCTS_SETTINGS = {
    'deep_strategic': (2500, 5000),   # Very high depth
    'complex_tactical': (1500, 4000), # Medium-high depth
    'beginner_random': (50, 300),     # Very low depth (makes mistakes)
    'opening_variety': (1000, 2500),  # Medium depth
    'endgame_patterns': (2000, 4000), # High depth for endgames
}

# Batch and parallel settings
BATCH_SIZE = 500  # legacy; not used for checkpoints
# NOTE(review): N_JOBS is not referenced anywhere in the code visible here;
# generation runs in a plain sequential loop.
N_JOBS = 4

# Output paths (on the mounted Google Drive); both directories are created
# immediately if they do not exist yet.
OUTPUT_DIR = '/content/drive/MyDrive/Connect4_Dataset'
CHECKPOINT_DIR = f'{OUTPUT_DIR}/checkpoints'
os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

print(f"Total games: {TOTAL_GAMES:,}")
for gtype, count in GAME_CONFIG.items():
    print(f"  {gtype}: {count:,}")

# Reproducibility: seed both Python's `random` and NumPy's global RNG,
# since the generators draw from both.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)

# Opening book (first 5-10 moves) for structured variety
# Each sequence is a list of columns, alternating players
# (column indices 0-6, first entry played by 'plus').
OPENING_BOOK = [
    [3, 2, 3, 2, 3],
    [3, 4, 3, 4, 3],
    [3, 1, 3, 1, 3],
    [2, 3, 2, 3, 2],
    [4, 3, 4, 3, 4],
    [2, 2, 3, 3, 4],
    [4, 4, 3, 3, 2],
    [3, 3, 2, 4, 2],
    [3, 3, 4, 2, 4],
    [1, 2, 3, 2, 1],
    [5, 4, 3, 4, 5],
    [0, 3, 2, 3, 4],
    [6, 3, 4, 3, 2],
    [2, 3, 4, 3, 2],
    [4, 3, 2, 3, 4],
    [1, 3, 2, 4, 2],
    [5, 3, 4, 2, 4],
    [2, 4, 3, 2, 3],
    [4, 2, 3, 4, 3],
    [3, 2, 4, 2, 3],
    [3, 4, 2, 4, 3],
    [2, 3, 2, 4, 3],
    [4, 3, 4, 2, 3],
]

# Tactical filters (used by play_tactical_game to accept or retry a game)
TACTICAL_MIN_MOVES = 4  # minimum tactical motifs in a tactical game

# Endgame filters (used by play_endgame_pattern_game on its setup position)
ENDGAME_MIN_PIECES = 32   # board pieces count to qualify as endgame focus
ENDGAME_MAX_EMPTY = 10    # or remaining empty cells

# Win-type balancing (vertical/horizontal/diagonal).
# generate_full_dataset() regenerates a game once when its win type already
# makes up more than TARGET_WIN_TYPE_BALANCE + 0.15 of the decided games.
WIN_TYPES = ['v', 'h', 'd']
TARGET_WIN_TYPE_BALANCE = 0.33  # soft target; resample if too skewed

# When True, a tactical game is only accepted if at least one fork occurred.
TACTICAL_REQUIRE_FORK = True

# Probability of replacing the MCTS move with a uniformly random legal move.
STRONG_EPSILON = 0.02  # small randomness for variety in strong games

# Checkpointing: save dataset in N equal chunks
CHECKPOINT_SPLITS = 8
CHECKPOINT_EVERY_GAMES = max(1, TOTAL_GAMES // CHECKPOINT_SPLITS)

# Set to True to generate one game of every type before the full run.
RUN_SMOKE_TESTS = False

# Maximum attempts a filtered generator makes before giving up on its filter.
MAX_RETRIES = 50


## Core Game Functions

In [None]:
def update_board(board_temp, color, column):
    """Return a copy of the board with one piece dropped into `column`.

    'plus' pieces are stored as 1, anything else as -1. The piece lands on the
    lowest free row (row 5 is the bottom). If the column is already full the
    copy is returned unchanged. The input board is never modified.
    """
    result = board_temp.copy()
    occupied = sum(abs(result[r, column]) for r in range(6))
    landing_row = int(5 - occupied)
    if landing_row >= 0:
        result[landing_row, column] = -1 if color != 'plus' else 1
    return result

def check_for_win(board, col):
    """Report whether the top piece of column `col` completes four in a row.

    The function inspects only lines that pass through the highest occupied
    cell of `col`, so it is meant to be called right after a move in `col`.

    Args:
        board: 6x7 array with 1 for 'plus', -1 for 'minus' and 0 for empty.
        col: column index (0-6) of the move to inspect.

    Returns:
        '<kind>-plus' or '<kind>-minus', where kind is 'v' (vertical),
        'h' (horizontal) or 'd' (either diagonal); 'nobody' when no line
        through that cell is complete or when the column is empty.
    """
    filled = int(sum(abs(board[r, col]) for r in range(6)))
    if filled == 0:
        # No piece in this column: the "last move" row would be 6, which is
        # off the board, so there is nothing to check.
        return 'nobody'
    row = 6 - filled

    def verdict(total, kind):
        # A window sums to +/-4 only when all four cells belong to one side.
        if total == 4:
            return kind + '-plus'
        if total == -4:
            return kind + '-minus'
        return None

    # Vertical: only the window starting at the top piece and going down.
    if row + 3 < 6:
        found = verdict(sum(board[row + k, col] for k in range(4)), 'v')
        if found:
            return found

    # Horizontal: every 4-wide window in this row that contains `col`.
    for start in range(max(0, col - 3), min(4, col + 1)):
        found = verdict(sum(board[row, start + k] for k in range(4)), 'h')
        if found:
            return found

    # Diagonal running down-right; (r, c) is the window's top-left cell.
    for shift in range(-3, 1):
        r, c = row + shift, col + shift
        if 0 <= r <= 2 and 0 <= c <= 3:
            found = verdict(sum(board[r + k, c + k] for k in range(4)), 'd')
            if found:
                return found

    # Diagonal running down-left; (r, c) is the window's top-right cell.
    for shift in range(-3, 1):
        r, c = row + shift, col - shift
        if 0 <= r <= 2 and 3 <= c <= 6:
            found = verdict(sum(board[r + k, c - k] for k in range(4)), 'd')
            if found:
                return found

    return 'nobody'

def find_legal(board):
    """Return the columns (0-6) whose top cell is still empty, in order."""
    playable = []
    for col in range(7):
        if abs(board[0, col]) < 0.1:
            playable.append(col)
    return playable

def look_for_win(board_, color):
    """Return the first legal column where `color` wins at once, else -1.

    Columns are tried in the order given by find_legal(); `color` is the
    string 'plus' or 'minus' and is compared against the suffix of the label
    returned by check_for_win().
    """
    winning_columns = (
        col
        for col in find_legal(board_)
        if check_for_win(update_board(board_.copy(), color, col), col)[2:] == color
    )
    return next(winning_columns, -1)

def find_all_nonlosers(board, color):
    """List the moves for `color` that do not allow a winning reply.

    A move is kept when, after it is played, no legal reply by the other side
    makes check_for_win() report a winner. If every move allows such a reply,
    all legal moves are returned instead, so the result is empty only when
    the board has no legal move at all.
    """
    rival = 'plus' if color != 'plus' else 'minus'
    candidates = find_legal(board)

    def allows_winning_reply(col):
        after = update_board(board, color, col)
        return any(
            check_for_win(update_board(after, rival, reply), reply) != 'nobody'
            for reply in find_legal(after)
        )

    safe = [col for col in candidates if not allows_winning_reply(col)]
    return safe or candidates

## MCTS Implementation

In [None]:
def back_prop(winner, path, color0, md):
    """Add one playout result to every node on `path`.

    Args:
        winner: result label such as 'v-plus', 'h-minus' or 'tie'; its third
            character identifies the winning side ('p', 'm') or a tie ('e').
        path: board keys from the root (index 0) down to the last node.
        color0: the root player, 'plus' or 'minus'.
        md: dict mapping board key -> [visit count, score]; updated in place.

    Every node gets one more visit. Nodes at odd depth (reached by a root
    player move) gain +1 when the root player won and -1 when the opponent
    won; nodes at even depth get the opposite sign. Ties leave scores as is.
    """
    if not path:
        return

    tag = winner[2]
    if tag == color0[0]:
        odd_delta = 1
    elif tag != 'e':
        odd_delta = -1
    else:
        odd_delta = 0

    for depth, key in enumerate(path):
        node = md[key]
        node[0] += 1
        if odd_delta:
            node[1] += odd_delta if depth % 2 == 1 else -odd_delta

def rollout(board, next_player):
    """Play uniformly random moves from `board` until the game ends.

    `next_player` ('plus' or 'minus') moves first. Returns the label from
    check_for_win() for the winning move, or 'tie' when no legal move is left.
    """
    mover = next_player
    while True:
        options = find_legal(board)
        if not options:
            return 'tie'
        col = random.choice(options)
        board = update_board(board, mover, col)
        outcome = check_for_win(board, col)
        if outcome != 'nobody':
            return outcome
        mover = 'plus' if mover == 'minus' else 'minus'

def mcts(board_temp, color0, nsteps):
    """Choose a column for `color0` using a Monte Carlo tree search.

    Args:
        board_temp: current board (not modified; a copy is searched).
        color0: the player to move, 'plus' or 'minus'.
        nsteps: number of search iterations to run.

    Returns:
        The chosen column. An immediate winning move is returned without
        searching. Returns -1 only when there is no legal move.
    """
    board = board_temp.copy()
    # Shortcut: take an immediate win if one exists.
    winColumn = look_for_win(board, color0)
    if winColumn > -0.5: return winColumn
    
    # Candidate root moves: those that do not allow a winning reply.
    # find_all_nonlosers falls back to all legal moves, so this list is empty
    # only when the board is full.
    legal0 = find_all_nonlosers(board, color0)
    if len(legal0) == 0: return find_legal(board)[0] if find_legal(board) else -1
    
    # Search statistics keyed by the flattened board: [visits, score].
    mcts_dict = {tuple(board.ravel()): [0, 0]}
    
    for _ in range(nsteps):
        # Every iteration walks down from the root position.
        color = color0
        winner = 'nobody'
        board_mcts = board.copy()
        path = [tuple(board_mcts.ravel())]
        
        while winner == 'nobody':
            legal = find_legal(board_mcts)
            if len(legal) == 0:
                # Board is full: score the path as a tie.
                winner = 'tie'
                back_prop(winner, path, color0, mcts_dict)
                break
            
            # Make sure every child position has a statistics entry.
            board_list = [tuple(update_board(board_mcts, color, col).ravel()) for col in legal]
            for bl in board_list:
                if bl not in mcts_dict: mcts_dict[bl] = [0, 0]
            
            # Selection score per child: unvisited children get a large
            # constant so they are tried first; otherwise mean score plus an
            # exploration term based on the parent's visit count.
            ucb1 = np.zeros(len(legal))
            for i in range(len(legal)):
                nd = mcts_dict[board_list[i]]
                if nd[0] == 0: ucb1[i] = 10 * nsteps
                else: ucb1[i] = nd[1]/nd[0] + 2*np.sqrt(np.log(mcts_dict[path[-1]][0])/nd[0])
            
            chosen = np.argmax(ucb1)
            board_mcts = update_board(board_mcts, color, legal[chosen])
            path.append(tuple(board_mcts.ravel()))
            winner = check_for_win(board_mcts, legal[chosen])
            
            # The move just played won the game for the side that made it.
            if winner[2] == color[0]:
                back_prop(winner, path, color0, mcts_dict)
                break
            
            color = 'minus' if color == 'plus' else 'plus'
            # A never-visited node is evaluated with one random playout, which
            # ends this iteration; visited nodes are descended further.
            if mcts_dict[tuple(board_mcts.ravel())][0] == 0:
                winner = rollout(board_mcts, color)
                back_prop(winner, path, color0, mcts_dict)
                break
    
    # Final choice: among the candidate root moves, pick the visited child
    # with the highest mean score; defaults to the first candidate.
    maxval = -np.inf
    best_col = legal0[0]
    for col in legal0:
        nd = mcts_dict.get(tuple(update_board(board, color0, col).ravel()), [0, 0])
        if nd[0] > 0 and nd[1]/nd[0] > maxval:
            maxval = nd[1]/nd[0]
            best_col = col
    return best_col

## Board Utilities

In [None]:
def board_to_tensor(board, current_player):
    """Encode a board as a (6, 7, 2) float32 array from the mover's viewpoint.

    Channel 0 marks the cells of `current_player`, channel 1 the cells of the
    other side. 'plus' pieces are the cells equal to 1; any other
    `current_player` value is treated as the -1 side.
    """
    own = 1 if current_player == 'plus' else -1
    planes = np.zeros((6, 7, 2), dtype=np.float32)
    planes[:, :, 0] = (board == own).astype(np.float32)
    planes[:, :, 1] = (board == -own).astype(np.float32)
    return planes

def mirror_board(board):
    """Return the board flipped left-to-right (column 0 swaps with column 6)."""
    return np.flip(board, axis=1)
def mirror_move(move):
    """Map a column index to its left-right mirror on a 7-column board."""
    last_column = 6
    return last_column - move

def count_pieces(board):
    """Return the sum of absolute cell values as an int (pieces on the board)."""
    magnitudes = np.abs(board)
    return int(magnitudes.sum())

## Pattern Detection for Tactical Games

In [None]:
def count_threats(board, color):
    """Count 4-cell windows holding three of `color`'s pieces and one empty cell.

    All horizontal, vertical and both diagonal windows of the 6x7 board are
    scanned. Whether the empty cell can actually be played is not checked.
    """
    mark = 1 if color == 'plus' else -1

    # (start rows, start columns, row step, column step) per direction.
    layouts = (
        (range(6), range(4), 0, 1),      # horizontal
        (range(3), range(7), 1, 0),      # vertical
        (range(3), range(4), 1, 1),      # diagonal, down-right
        (range(3), range(3, 7), 1, -1),  # diagonal, down-left
    )

    total = 0
    for start_rows, start_cols, d_row, d_col in layouts:
        for r in start_rows:
            for c in start_cols:
                cells = [board[r + k * d_row, c + k * d_col] for k in range(4)]
                own = sum(1 for v in cells if v == mark)
                free = sum(1 for v in cells if v == 0)
                if own == 3 and free == 1:
                    total += 1
    return total

def has_fork(board, color):
    """Return True when `color` has two or more immediate winning moves.

    A fork is a position where at least two different legal columns each
    complete four in a row for `color` on the very next move, so a single
    reply cannot stop both. Each legal column is tried on a copy of the
    board; `board` itself is not modified.

    Args:
        board: 6x7 board array.
        color: 'plus' or 'minus'.
    """
    winning_columns = 0
    for col in find_legal(board):
        # update_board returns a new board, so no explicit copy is needed.
        if check_for_win(update_board(board, color, col), col)[2:] == color:
            winning_columns += 1
            if winning_columns >= 2:
                return True
    return False

## Game Generators by Type

In [None]:
def collect_game_data(board_history, winner):
    """Turn a finished game's move history into training records.

    Args:
        board_history: list of dicts with keys 'board', 'player', 'move',
            'turn' and optionally 'game_type' (one per move played).
        winner: result label; a label containing 'plus' or 'minus' names the
            winning side, anything else is scored as a draw.

    Returns:
        Two records per history entry, the position as played followed by its
        left-right mirror. 'result' is 1.0 / 0.0 / 0.5 (win / loss / draw)
        from the point of view of the player who made the move.
    """
    if 'plus' in winner:
        plus_score = 1.0
    elif 'minus' in winner:
        plus_score = 0.0
    else:
        plus_score = 0.5

    samples = []
    for pos in board_history:
        mover = pos['player']
        score = plus_score if mover == 'plus' else 1.0 - plus_score
        label = pos.get('game_type', 'unknown')
        views = (
            (pos['board'], pos['move']),
            (mirror_board(pos['board']), mirror_move(pos['move'])),
        )
        for grid, column in views:
            samples.append({'board': board_to_tensor(grid, mover), 'move': column,
                            'result': score, 'turn': pos['turn'], 'game_type': label})
    return samples


def play_deep_strategic_game():
    """Self-play a strong game, retrying until one lasts at least 25 moves.

    Each move is an MCTS move with a random iteration budget from
    MCTS_SETTINGS['deep_strategic'], or a random legal move with probability
    STRONG_EPSILON. Up to MAX_RETRIES games are played; the first game that
    reaches 25 moves is returned, otherwise the last game played.

    Returns:
        (records, winner) where records come from collect_game_data().
    """
    lo, hi = MCTS_SETTINGS['deep_strategic']

    for _attempt in range(MAX_RETRIES):
        board = np.zeros((6, 7))
        records = []
        mover = 'plus'
        winner = 'nobody'
        ply = 0

        while winner == 'nobody':
            options = find_legal(board)
            if not options:
                winner = 'tie'
                break

            # The budget is drawn before the epsilon test on every move.
            budget = random.randint(lo, hi)
            explore = random.random() < STRONG_EPSILON
            move = random.choice(options) if explore else mcts(board, mover, budget)
            if move < 0:
                break

            records.append({'board': board.copy(), 'player': mover, 'move': move,
                            'turn': ply, 'game_type': 'deep_strategic'})
            board = update_board(board, mover, move)
            winner = check_for_win(board, move)
            mover = 'plus' if mover == 'minus' else 'minus'
            ply += 1

        if ply >= 25:
            break

    return collect_game_data(records, winner), winner


def play_tactical_game():
    """Self-play a game, retrying until it contains enough tactical moves.

    A move counts as tactical when it raises the mover's count_threats()
    value or when has_fork() is true for the mover afterwards. A game is
    accepted once it has at least TACTICAL_MIN_MOVES tactical moves and, if
    TACTICAL_REQUIRE_FORK is set, at least one fork. After MAX_RETRIES
    attempts the last game is returned regardless.

    Returns:
        (records, winner) where records come from collect_game_data().
    """
    lo, hi = MCTS_SETTINGS['complex_tactical']

    for _attempt in range(MAX_RETRIES):
        board = np.zeros((6, 7))
        records = []
        mover = 'plus'
        winner = 'nobody'
        ply = 0
        tactical_count = 0
        fork_seen = False

        while winner == 'nobody':
            options = find_legal(board)
            if not options:
                winner = 'tie'
                break

            baseline = count_threats(board, mover)
            budget = random.randint(lo, hi)
            explore = random.random() < STRONG_EPSILON
            move = random.choice(options) if explore else mcts(board, mover, budget)
            if move < 0:
                break

            after = update_board(board.copy(), mover, move)
            gained_threat = count_threats(after, mover) > baseline
            forked = has_fork(after, mover)

            fork_seen = fork_seen or forked
            if gained_threat or forked:
                tactical_count += 1

            records.append({'board': board.copy(), 'player': mover, 'move': move,
                            'turn': ply, 'game_type': 'complex_tactical'})
            board = after
            winner = check_for_win(board, move)
            mover = 'plus' if mover == 'minus' else 'minus'
            ply += 1

        fork_ok = fork_seen or not TACTICAL_REQUIRE_FORK
        if tactical_count >= TACTICAL_MIN_MOVES and fork_ok:
            break

    return collect_game_data(records, winner), winner


def play_beginner_game():
    """Self-play one weak game: 40% random moves, otherwise low-budget MCTS.

    The MCTS budget per move is drawn from MCTS_SETTINGS['beginner_random'].

    Returns:
        (records, winner) where records come from collect_game_data().
    """
    lo, hi = MCTS_SETTINGS['beginner_random']
    board = np.zeros((6, 7))
    records = []
    mover = 'plus'
    winner = 'nobody'
    ply = 0

    while winner == 'nobody':
        options = find_legal(board)
        if not options:
            winner = 'tie'
            break

        if random.random() >= 0.4:
            move = mcts(board, mover, random.randint(lo, hi))
            if move < 0:
                move = random.choice(options)
        else:
            move = random.choice(options)

        records.append({'board': board.copy(), 'player': mover, 'move': move,
                        'turn': ply, 'game_type': 'beginner_random'})
        board = update_board(board, mover, move)
        winner = check_for_win(board, move)
        mover = 'plus' if mover == 'minus' else 'minus'
        ply += 1

    return collect_game_data(records, winner), winner


def play_opening_variety_game():
    """Self-play one game with a varied opening followed by MCTS play.

    The first 5-10 moves (length drawn at random) form the opening. With
    probability 0.5 a sequence from OPENING_BOOK is followed for as long as it
    lasts and its move is legal; all other opening moves are drawn at random
    with centre-weighted probabilities. After the opening, moves come from
    MCTS with a budget drawn from MCTS_SETTINGS['opening_variety'].

    Every recorded move is a plain Python int, like in the other generators.

    Returns:
        (records, winner) where records come from collect_game_data().
    """
    min_steps, max_steps = MCTS_SETTINGS['opening_variety']
    column_weights = (1, 2, 3, 4, 3, 2, 1)  # favour the centre columns

    board = np.zeros((6, 7))
    winner = 'nobody'
    color = 'plus'
    history = []
    turn = 0

    random_opening_moves = random.randint(5, 10)
    use_book = random.random() < 0.5
    opening_seq = random.choice(OPENING_BOOK) if use_book else None

    while winner == 'nobody':
        legal = find_legal(board)
        if len(legal) == 0:
            winner = 'tie'
            break

        if turn < random_opening_moves:
            if opening_seq and turn < len(opening_seq) and opening_seq[turn] in legal:
                move = opening_seq[turn]
            else:
                # Zero out full columns, then normalise. The total is always
                # positive here: `legal` is non-empty and all weights are > 0.
                weights = [w if c in legal else 0 for c, w in enumerate(column_weights)]
                total = sum(weights)
                probabilities = [w / total for w in weights]
                # np.random.choice returns a NumPy integer; store a Python int.
                move = int(np.random.choice(7, p=probabilities))
        else:
            nsteps = random.randint(min_steps, max_steps)
            move = mcts(board, color, nsteps)
            if move < 0:
                move = random.choice(legal)

        history.append({'board': board.copy(), 'player': color, 'move': move,
                        'turn': turn, 'game_type': 'opening_variety'})
        board = update_board(board, color, move)
        winner = check_for_win(board, move)
        color = 'minus' if color == 'plus' else 'plus'
        turn += 1

    return collect_game_data(history, winner), winner


def play_endgame_pattern_game():
    """Build a late-game position with random moves, then play it out.

    Setup phase (not recorded): random moves are played until the board holds
    enough pieces to pass the endgame filter, i.e. at least
    ENDGAME_MIN_PIECES pieces or at most ENDGAME_MAX_EMPTY empty cells.
    Setup moves never complete four in a row and, where possible, do not
    allow an immediate winning reply, so the game is still undecided when
    recording starts. If no such move exists the attempt is abandoned.

    Play phase (recorded): MCTS moves with a budget drawn from
    MCTS_SETTINGS['endgame_patterns'], or a random legal move with
    probability STRONG_EPSILON, until the game ends.

    Returns:
        (records, winner) where records come from collect_game_data(). If no
        usable setup is found within MAX_RETRIES attempts, an empty record
        list and 'tie' are returned.
    """
    min_steps, max_steps = MCTS_SETTINGS['endgame_patterns']

    # Smallest piece count that passes the filter below. The setup must reach
    # it, otherwise every attempt would be rejected and no data produced.
    # Capped at 41 so that at least one move is left to play.
    min_setup = min(41, max(0, min(ENDGAME_MIN_PIECES, 42 - ENDGAME_MAX_EMPTY)))
    max_setup = min(41, min_setup + 4)

    for _ in range(MAX_RETRIES):
        board = np.zeros((6, 7))
        color = 'plus'

        setup_moves = random.randint(min_setup, max_setup)
        for _setup_ply in range(setup_moves):
            # Keep the setup undecided: skip moves that win on the spot.
            quiet = [c for c in find_all_nonlosers(board, color)
                     if check_for_win(update_board(board, color, c), c) == 'nobody']
            if not quiet:
                board = None
                break
            board = update_board(board, color, random.choice(quiet))
            color = 'minus' if color == 'plus' else 'plus'
        if board is None:
            continue

        pieces = count_pieces(board)
        empty = 42 - pieces
        if pieces < ENDGAME_MIN_PIECES and empty > ENDGAME_MAX_EMPTY:
            continue

        winner = 'nobody'
        history = []
        turn = pieces  # recorded turn numbers continue after the setup moves

        while winner == 'nobody':
            legal = find_legal(board)
            if len(legal) == 0:
                winner = 'tie'
                break

            nsteps = random.randint(min_steps, max_steps)
            if random.random() < STRONG_EPSILON:
                move = random.choice(legal)
            else:
                move = mcts(board, color, nsteps)
            if move < 0:
                move = random.choice(legal)

            history.append({'board': board.copy(), 'player': color, 'move': move,
                            'turn': turn, 'game_type': 'endgame_patterns'})
            board = update_board(board, color, move)
            winner = check_for_win(board, move)
            color = 'minus' if color == 'plus' else 'plus'
            turn += 1

        return collect_game_data(history, winner), winner

    return collect_game_data([], 'tie'), 'tie'

In [None]:
def generate_single_game(game_type):
    """Play one game of the requested type and return (records, winner).

    Unknown game types fall back to the tactical generator.
    """
    generators = (
        ('deep_strategic', play_deep_strategic_game),
        ('complex_tactical', play_tactical_game),
        ('beginner_random', play_beginner_game),
        ('opening_variety', play_opening_variety_game),
        ('endgame_patterns', play_endgame_pattern_game),
    )
    for name, play in generators:
        if game_type == name:
            return play()
    return play_tactical_game()

In [None]:
# Optional sanity check: generate one game of every configured type and print
# its winner label and the number of training records it produced.
if RUN_SMOKE_TESTS:
    print("Running smoke tests")
    for gtype in GAME_CONFIG.keys():
        data, winner = generate_single_game(gtype)
        print(f"{gtype}: winner={winner}, positions={len(data)}")

## Dataset Generation

In [None]:
def save_checkpoint(data, batch_num, checkpoint_dir):
    """Pickle `data` to '<checkpoint_dir>/batch_<NNNNN>.pkl' and return the path.

    `batch_num` is zero-padded to five digits in the file name. An existing
    file with the same name is overwritten.
    """
    filepath = f"{checkpoint_dir}/batch_{batch_num:05d}.pkl"
    handle = open(filepath, 'wb')
    try:
        pickle.dump(data, handle)
    finally:
        handle.close()
    return filepath

def generate_full_dataset():
    """Generate every game in GAME_CONFIG and collect the training records.

    Games are generated one at a time in shuffled order. After every
    CHECKPOINT_EVERY_GAMES tasks the records gathered since the last
    checkpoint are pickled to CHECKPOINT_DIR; any remainder is written once
    at the end. A game that raises is reported and skipped.

    Win-type balancing: if a game's win type ('v', 'h' or 'd') already makes
    up more than TARGET_WIN_TYPE_BALANCE + 0.15 of the decided games so far,
    the game is regenerated once and the second result is kept.

    Returns:
        (all_data, stats, win_type_counts): the list of all records, per game
        type counters ('games', 'plus', 'minus', 'tie'), and the number of
        games won per win type.
    """
    all_data = []
    stats = {gt: {'games': 0, 'plus': 0, 'minus': 0, 'tie': 0} for gt in GAME_CONFIG}
    win_type_counts = {wt: 0 for wt in WIN_TYPES}

    # One task per requested game, shuffled so each checkpoint mixes types.
    tasks = []
    for gtype, count in GAME_CONFIG.items():
        tasks.extend([gtype] * count)
    random.shuffle(tasks)

    start_time = time.time()
    batch_num = 0
    batch_data = []

    for i, gtype in enumerate(tqdm(tasks, desc="Generating games")):
        try:
            data, winner = generate_single_game(gtype)
            wtype = win_type_label(winner)

            # Soft balance: if one win type dominates too much, resample
            if wtype in WIN_TYPES:
                total_w = sum(win_type_counts.values()) + 1e-9
                frac = win_type_counts[wtype] / total_w
                if frac > (TARGET_WIN_TYPE_BALANCE + 0.15):
                    data, winner = generate_single_game(gtype)
                    wtype = win_type_label(winner)
            batch_data.extend(data)

            # Update stats
            stats[gtype]['games'] += 1
            if wtype in WIN_TYPES:
                win_type_counts[wtype] += 1
            if 'plus' in winner:
                stats[gtype]['plus'] += 1
            elif 'minus' in winner:
                stats[gtype]['minus'] += 1
            else:
                stats[gtype]['tie'] += 1
        except Exception as e:
            # Best effort: one bad game must not abort a multi-hour run.
            print(f"Error in game {i}: {e}")

        # Checkpointing is handled outside the game's try block so that a
        # failed game on a checkpoint boundary does not skip the checkpoint.
        if (i + 1) % CHECKPOINT_EVERY_GAMES == 0:
            try:
                save_checkpoint(batch_data, batch_num, CHECKPOINT_DIR)
            except Exception as e:
                # Keep the batch in memory; it is written at the next boundary
                # or with the final remainder.
                print(f"Error saving checkpoint {batch_num}: {e}")
                continue
            all_data.extend(batch_data)
            batch_data = []
            batch_num += 1
            gc.collect()

            # Progress update
            elapsed = max(time.time() - start_time, 1e-9)
            rate = (i + 1) / elapsed
            eta = (len(tasks) - i - 1) / rate / 60
            print(f"\nBatch {batch_num}: {len(all_data):,} positions | ETA: {eta:.1f} min")
            get_memory_info()

    # Save the remainder exactly once, so no position is duplicated.
    if batch_data:
        save_checkpoint(batch_data, batch_num, CHECKPOINT_DIR)
        all_data.extend(batch_data)
        batch_data = []
        print(f"Final checkpoint saved. Positions: {len(all_data):,}.")

    return all_data, stats, win_type_counts

In [None]:
# Helper to get win type (v/h/d/tie/nobody)
def win_type_label(winner):
    """Return the win-type letter of a result label.

    'tie' and 'nobody' both map to 'tie'; any other label maps to its first
    character (for example 'v-plus' -> 'v').
    """
    return 'tie' if winner in ('tie', 'nobody') else winner[0]

In [None]:
# Run the full generation. The three results are module-level names that the
# later cells read: `dataset` (list of records), `game_stats` (per-type
# counters) and `win_type_counts` (wins per win type).
print("Starting dataset generation")

dataset, game_stats, win_type_counts = generate_full_dataset()

print("Generation complete")
print(f"Total positions: {len(dataset):,}")
print("Game statistics by type:")
# P1 is the 'plus' player, P2 the 'minus' player.
for gtype, st in game_stats.items():
    print(f"  {gtype}: {st['games']} games (P1 wins: {st['plus']}, P2 wins: {st['minus']}, ties: {st['tie']})")
print("Win type distribution:")
for wt, cnt in win_type_counts.items():
    print(f"  {wt}: {cnt}")

## Prepare and Save Final Dataset

In [None]:
def prepare_final_dataset(data):
    """Pack a list of record dicts into arrays.

    Args:
        data: records with keys 'board', 'move', 'result', 'turn' and
            optionally 'game_type'.

    Returns:
        (X, y_move, y_result, turns, game_types): a float32 array of shape
        (n, 6, 7, 2), int8 moves, float32 results, int8 turn numbers and a
        list of game-type labels ('unknown' when a record has none).
    """
    total = len(data)
    boards = np.zeros((total, 6, 7, 2), dtype=np.float32)
    moves = np.zeros(total, dtype=np.int8)
    outcomes = np.zeros(total, dtype=np.float32)
    turn_numbers = np.zeros(total, dtype=np.int8)
    labels = []

    for row, record in enumerate(tqdm(data, desc="Preparing dataset")):
        boards[row] = record['board']
        moves[row] = record['move']
        outcomes[row] = record['result']
        turn_numbers[row] = record['turn']
        labels.append(record.get('game_type', 'unknown'))

    return boards, moves, outcomes, turn_numbers, labels

# Convert the list of records into arrays.
X, y_move, y_result, turns, game_types = prepare_final_dataset(dataset)

# Shuffle: one permutation from NumPy's global RNG is applied to every array
# and to the label list, so rows stay aligned.
indices = np.random.permutation(len(X))
X, y_move, y_result, turns = X[indices], y_move[indices], y_result[indices], turns[indices]
game_types = [game_types[i] for i in indices]

print(f"\nDataset shapes: X={X.shape}, y_move={y_move.shape}, y_result={y_result.shape}")

In [None]:
# Save dataset as one compressed archive. Note that the per-row game-type
# labels are not written to the .npz; only their counts go into the metadata.
np.savez_compressed(f"{OUTPUT_DIR}/connect4_dataset.npz", X=X, y_move=y_move, y_result=y_result, turns=turns)

# Save metadata: the configuration used plus the statistics gathered during
# generation.
metadata = {
    'total_positions': len(X),
    'game_config': GAME_CONFIG,
    'mcts_settings': MCTS_SETTINGS,
    'game_stats': game_stats,
    'win_type_counts': win_type_counts,
    'seed': SEED,
    'game_types_distribution': {gt: game_types.count(gt) for gt in set(game_types)}
}
with open(f"{OUTPUT_DIR}/metadata.pkl", 'wb') as f:
    pickle.dump(metadata, f)

# List the files directly inside OUTPUT_DIR (sub-directories are skipped)
# with their sizes in MB.
print(f"\n✅ Dataset saved to: {OUTPUT_DIR}")
for f in os.listdir(OUTPUT_DIR):
    if not os.path.isdir(f"{OUTPUT_DIR}/{f}"):
        size = os.path.getsize(f"{OUTPUT_DIR}/{f}") / (1024**2)
        print(f"  {f}: {size:.2f} MB")

## Dataset Analysis

In [None]:
import matplotlib.pyplot as plt

# 2x2 overview figure of the final arrays; saved to OUTPUT_DIR and shown.
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Game type distribution (number of positions per game type)
gt_counts = {gt: game_types.count(gt) for gt in set(game_types)}
axes[0,0].bar(gt_counts.keys(), gt_counts.values(), color='steelblue', edgecolor='black')
axes[0,0].set_title('Positions by Game Type')
axes[0,0].tick_params(axis='x', rotation=45)

# Move distribution (how often each column 0-6 was played)
move_counts = np.bincount(y_move, minlength=7)
axes[0,1].bar(range(7), move_counts, color='green', edgecolor='black')
axes[0,1].set_xlabel('Column')
axes[0,1].set_title('Move Distribution')

# Turn distribution
axes[1,0].hist(turns, bins=range(0, 43), color='purple', edgecolor='black', alpha=0.7)
axes[1,0].set_xlabel('Turn Number')
axes[1,0].set_title('Position Distribution by Turn')

# Result distribution: y_result is 1 / 0 / 0.5 for win / loss / draw from the
# point of view of the player who made the move.
wins = np.sum(y_result == 1)
losses = np.sum(y_result == 0)
draws = np.sum(y_result == 0.5)
axes[1,1].bar(['Loss', 'Draw', 'Win'], [losses, draws, wins], color=['red', 'gray', 'green'])
axes[1,1].set_title('Outcome Distribution')

plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/dataset_analysis.png", dpi=150)
plt.show()

print(f"\n Dataset Summary:")
print(f"Total positions: {len(X):,}")
print(f"Average turn: {np.mean(turns):.1f}")
print(f"Unique game types: {len(set(game_types))}")

In [None]:
# Validation summary
print("\n✅ Validation Summary")
print(f"Total positions: {len(X):,}")
print(f"Average turn: {np.mean(turns):.1f}")
print("Game type counts:", {gt: game_types.count(gt) for gt in set(game_types)})
print("Win type counts:", win_type_counts)

# Opening diversity (first 10 moves): column histogram over all records with
# turn < 10. `opening_move_counts` is only defined when such records exist.
first_moves = [item for item in dataset if item['turn'] < 10]
if first_moves:
    opening_move_counts = np.bincount([i['move'] for i in first_moves], minlength=7)
    print("Opening move distribution (turn < 10):", opening_move_counts.tolist())

In [None]:
# Save QA report: a plain-text file with the same figures as the validation
# summary.
report_path = f"{OUTPUT_DIR}/dataset_qa_report.txt"
with open(report_path, 'w') as f:
    f.write("Connect4 Dataset QA Report\n")
    f.write("===========================\n")
    f.write(f"Total positions: {len(X):,}\n")
    f.write(f"Average turn: {np.mean(turns):.1f}\n")
    f.write(f"Game type counts: { {gt: game_types.count(gt) for gt in set(game_types)} }\n")
    f.write(f"Win type counts: {win_type_counts}\n")
    # `opening_move_counts` exists only if the validation cell found records
    # with turn < 10, hence the globals() check.
    f.write(f"Opening move distribution (turn < 10): {opening_move_counts.tolist() if 'opening_move_counts' in globals() else 'n/a'}\n")
    f.write(f"Seed: {SEED}\n")
print(f"QA report saved to: {report_path}")

## Train/Val/Test Split

In [None]:
from sklearn.model_selection import train_test_split

# 80% train, then the remaining 20% is halved into validation and test.
# NOTE(review): the split is made over individual rows. Each game contributes
# many rows, and every position is stored together with its mirrored copy, so
# rows from the same game (and mirror pairs) can land in different splits.
# Consider splitting by game if validation/test scores must be free of this
# overlap.
X_train, X_temp, y_m_train, y_m_temp, y_r_train, y_r_temp = train_test_split(
    X, y_move, y_result, test_size=0.2, random_state=42)
X_val, X_test, y_m_val, y_m_test, y_r_val, y_r_test = train_test_split(
    X_temp, y_m_temp, y_r_temp, test_size=0.5, random_state=42)

# The split files hold boards, moves and results only (no turn numbers).
np.savez_compressed(f"{OUTPUT_DIR}/train.npz", X=X_train, y_move=y_m_train, y_result=y_r_train)
np.savez_compressed(f"{OUTPUT_DIR}/val.npz", X=X_val, y_move=y_m_val, y_result=y_r_val)
np.savez_compressed(f"{OUTPUT_DIR}/test.npz", X=X_test, y_move=y_m_test, y_result=y_r_test)

print(f"Train: {len(X_train):,} | Val: {len(X_val):,} | Test: {len(X_test):,}")
print("\n✅ All splits saved!")

In [None]:
# Final summary. The per-type numbers printed here are the requested counts
# from GAME_CONFIG; the counts of games actually generated are in game_stats.
print("Dataset generation complete")
print(f"Total positions: {len(X):,}")
print("Game types generated:")
for gt, count in GAME_CONFIG.items():
    print(f"  {gt}: {count:,} games")
print(f"Files saved to: {OUTPUT_DIR}")
get_memory_info()