# Tic-Tac-Toe Dataset Generation


In [1]:
import numpy as np
import pandas as pd
from tqdm import tqdm
from collections import deque
from multiprocessing import Pool, Manager, cpu_count
import itertools
from functools import lru_cache

In [2]:
# Side length of the square board.
# NOTE(review): a 5x5 board has up to 3**25 (~8.5e11) raw cell assignments, so
# exhaustive enumeration at this size may not finish; 3 or 4 is far cheaper.
BOARD_SIZE: int = 5
# Marker for an unoccupied cell.
EMPTY_TOKEN: str = "-"
# Every value a cell may hold: the two players' tokens plus the empty marker.
TOKENS: tuple = ("X", "O", EMPTY_TOKEN)

In [None]:
@lru_cache(maxsize=10000)
def is_board_ended_cached(board_str):
    """Memoised ``is_board_ended`` for a board given as a flat row-major string.

    The side length is the integer part of ``sqrt(len(board_str))``, so the
    string is expected to describe a square board (e.g. a canonical key).
    Results for up to 10000 distinct strings are kept in an LRU cache.
    """
    side = int(len(board_str) ** 0.5)
    grid = []
    for r in range(side):
        grid.append(list(board_str[r * side:(r + 1) * side]))
    return is_board_ended(grid)

def is_board_ended(board, win_len=None):
    """Return True if some player has ``win_len`` of their tokens in a row.

    Rows, columns and both diagonal directions are scanned. A run is broken
    by EMPTY_TOKEN or by a change of token. Despite the name, a full board
    without a winning run returns False; fullness is checked separately
    (see ``is_board_full``).

    Args:
        board: square grid indexed as ``board[row][col]``.
        win_len: run length needed to win. Defaults to ``min(len(board), 5)``,
            i.e. a complete line on boards up to 5x5 and five-in-a-row on
            larger boards.

    Returns:
        True if a winning run exists, otherwise False.

    Raises:
        ValueError: if an explicit ``win_len`` is smaller than 1.
    """
    n = len(board)
    if win_len is None:
        win_len = min(n, 5)
    elif win_len < 1:
        raise ValueError(f"win_len must be >= 1, got {win_len}")

    if n == 3 and win_len == 3:
        # Fast path for classic 3x3: test the 8 possible lines directly.
        lines = (
            # Rows
            (board[0][0], board[0][1], board[0][2]),
            (board[1][0], board[1][1], board[1][2]),
            (board[2][0], board[2][1], board[2][2]),
            # Columns
            (board[0][0], board[1][0], board[2][0]),
            (board[0][1], board[1][1], board[2][1]),
            (board[0][2], board[1][2], board[2][2]),
            # Diagonals
            (board[0][0], board[1][1], board[2][2]),
            (board[0][2], board[1][1], board[2][0]),
        )
        return any(a != EMPTY_TOKEN and a == b == c for a, b, c in lines)

    def has_run(line):
        """Return True if *line* holds win_len equal non-empty tokens in a row."""
        run_token = None
        run_length = 0
        for token in line:
            if token == EMPTY_TOKEN:
                run_token, run_length = None, 0
            elif token == run_token:
                run_length += 1
            else:
                run_token, run_length = token, 1
            # Checked after every token (not only on repeats) so that
            # win_len == 1 is detected as well.
            if run_length >= win_len:
                return True
        return False

    for r in range(n):
        if has_run(board[r]):
            return True

    for c in range(n):
        if has_run([board[r][c] for r in range(n)]):
            return True

    for s in range(2 * n - 1):
        # Anti-diagonal: cells whose row + col == s.
        diag = [board[r][s - r] for r in range(max(0, s - n + 1), min(n, s + 1))]
        if len(diag) >= win_len and has_run(diag):
            return True

    for d in range(-(n - 1), n):
        # Main-direction diagonal: cells whose col - row == d.
        diag = [board[r][r + d] for r in range(max(0, -d), min(n, n - d))]
        if len(diag) >= win_len and has_run(diag):
            return True

    return False

def rotate_matrix_90(board):
    """Return a new square board equal to *board* turned 90 degrees clockwise.

    Column ``c`` of the input, read bottom-to-top, becomes row ``c`` of the
    result. The input is not modified.
    """
    size = len(board)
    rotated = []
    for col in range(size):
        rotated.append([board[row][col] for row in reversed(range(size))])
    return rotated



In [4]:
def print_board(board):
    """Print *board* as a text grid followed by a blank line.

    Cells in a row are joined with ``' | '`` and shown as a space when they
    equal EMPTY_TOKEN. Consecutive rows are separated by a line of
    ``4 * len(board) - 1`` dashes.
    """
    separator = '-' * (len(board) * 4 - 1)
    last_index = len(board) - 1
    for index, row in enumerate(board):
        shown = [' ' if cell == EMPTY_TOKEN else cell for cell in row]
        print(' | '.join(shown))
        if index != last_index:
            print(separator)
    print()

In [None]:
def _hash_board(board):
    return ''.join(''.join(row) for row in board)

def _flip_horizontal(board):
    return [row[::-1] for row in board]

def _get_all_symmetries(board):
    """Return the 8 rotation/reflection images of *board*.

    Order: *board* itself (the same object, not a copy) and its three
    successive ``rotate_matrix_90`` turns, then the horizontal mirror and
    that mirror's three successive turns.
    """
    def quarter_turns(start):
        # start, then each further 90-degree rotation, four boards in total.
        chain = [start]
        while len(chain) < 4:
            chain.append(rotate_matrix_90(chain[-1]))
        return chain

    return quarter_turns(board) + quarter_turns(_flip_horizontal(board))

def get_canonical_form(board):
    """Return the symmetry-invariant string key of *board*.

    The key is the smallest (by string comparison) flattened form among the
    board's 8 rotations/reflections. Boards that are rotations or mirror
    images of one another therefore share a key.
    """
    return min(map(_hash_board, _get_all_symmetries(board)))

class BoardStateMap:
    """De-duplicate boards by canonical form and stream new keys to a file.

    Each newly seen canonical string is buffered and appended to
    ``output_file`` (one per line) once ``buffer_size`` entries accumulate.
    Call ``close()``, or use the instance as a context manager, to write
    whatever is still buffered. The output file is truncated on construction.
    """

    def __init__(self, output_file="unique.out", buffer_size=5000):
        self.seen = set()                # canonical strings already recorded
        self.output_file = output_file
        self.buffer = []                 # recorded strings not yet written
        self.buffer_size = buffer_size   # larger buffer => fewer file opens
        # Create/truncate the file so every run starts from an empty output.
        with open(self.output_file, "w", encoding="utf-8"):
            pass

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # Flush on every exit path (including exceptions such as
        # KeyboardInterrupt) so buffered states are not lost; never suppress.
        self.close()
        return False

    def check_and_add(self, board):
        """Record *board*'s canonical form.

        Returns:
            True if the canonical form was new (and was queued for writing),
            False if it had been seen before.
        """
        best_key = get_canonical_form(board)

        if best_key in self.seen:
            return False

        self.seen.add(best_key)
        self.buffer.append(best_key)

        if len(self.buffer) >= self.buffer_size:
            self._flush()

        return True

    def _flush(self):
        """Append all buffered strings to the output file, then clear the buffer."""
        if not self.buffer:
            return
        with open(self.output_file, "a", encoding="utf-8") as f:
            f.write('\n'.join(self.buffer) + '\n')
        self.buffer.clear()

    def close(self):
        """Write any strings still buffered to the output file."""
        self._flush()

def count_tokens(board):
    """Return ``(x_count, o_count)``: how many 'X' and 'O' appear on *board*."""
    flat = _hash_board(board)
    return flat.count('X'), flat.count('O')

In [None]:
def is_board_full(board):
    """Return True when no cell of *board* equals EMPTY_TOKEN."""
    return not any(cell == EMPTY_TOKEN for row in board for cell in row)

def process_batch_fast(args):
    """Classify a batch of boards and expand the ongoing ones by one move.

    Written as a ``Pool.map`` worker, so it takes a single tuple argument.

    Args:
        args: ``(boards_batch, board_size)``. Each board is a list of row
            lists indexed ``board[row][col]``; ``board_size`` is its side.

    Returns:
        dict with
          ``'stats'``: counts of ``'win'`` / ``'draw'`` / ``'ongoing'`` boards
              in the batch (a board with a winning run counts as a win even
              if it is also full);
          ``'new_children'``: canonical strings of every one-move successor
              of the ongoing boards, de-duplicated within this batch and kept
              in first-seen order.
    """
    boards_batch, board_size = args
    stats = {'win': 0, 'draw': 0, 'ongoing': 0}
    # dict used as an insertion-ordered set. Children of different parents
    # (and symmetric moves on one parent) often share a canonical form.
    # Dropping repeats here shrinks what is pickled back to the parent
    # process, without changing the caller's first-occurrence dedup order.
    children = {}

    for board in boards_batch:
        if is_board_ended(board):
            stats['win'] += 1
            continue
        if is_board_full(board):
            stats['draw'] += 1
            continue

        # Ongoing game - generate children
        stats['ongoing'] += 1
        x_count, o_count = count_tokens(board)
        # X plays on the empty board, so X is to move whenever counts are equal.
        current_token = 'X' if x_count == o_count else 'O'

        for i in range(board_size):
            for j in range(board_size):
                if board[i][j] == EMPTY_TOKEN:
                    new_board = [row[:] for row in board]
                    new_board[i][j] = current_token
                    children[get_canonical_form(new_board)] = None

    return {'stats': stats, 'new_children': list(children)}

def canonical_to_board(canonical_str, board_size):
    """Rebuild a ``board_size`` x ``board_size`` grid of single-character
    cells from a flat row-major string (the inverse of ``_hash_board``)."""
    rows = []
    for r in range(board_size):
        start = r * board_size
        rows.append(list(canonical_str[start:start + board_size]))
    return rows

def generate_all_boards(board_size=BOARD_SIZE, n_workers=None, chunk_size=30000):
    """Enumerate every reachable board, up to symmetry, by breadth-first search.

    Starting from the empty board, boards are expanded one move at a time.
    Children are reduced to their canonical form, so boards related by a
    rotation/reflection are stored once. Boards with a winning run or no
    empty cell are counted but not expanded. Every newly discovered
    canonical string, including the empty board, is appended to
    ``unique.out`` through BoardStateMap as it is found. Buffered entries
    are flushed even if the run is interrupted.

    Every canonical string is kept in memory (seen set and returned list),
    so memory grows with the number of reachable states.

    Args:
        board_size: side length of the square board.
        n_workers: worker processes used for expansion. Defaults to
            ``cpu_count() - 1`` (at least 1); with 1 everything runs
            in-process.
        chunk_size: maximum number of queued states expanded per iteration.

    Returns:
        list of boards (row lists), one per unique canonical state. The order
        follows set iteration and carries no meaning.

    Raises:
        ValueError: if ``n_workers`` or ``chunk_size`` is smaller than 1.
    """
    if n_workers is None:
        n_workers = max(1, cpu_count() - 1)
    if n_workers < 1:
        raise ValueError(f"n_workers must be >= 1, got {n_workers}")
    if chunk_size < 1:
        # A zero-sized chunk would never drain the queue and loop forever.
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

    print(f"Using {n_workers} worker processes")
    print(f"Chunk size: {chunk_size:,} states per iteration")

    board_map = BoardStateMap()
    seen_canonical = set()

    def record(canonical):
        """Mark *canonical* as seen and queue it for the output file."""
        seen_canonical.add(canonical)
        board_map.buffer.append(canonical)
        if len(board_map.buffer) >= board_map.buffer_size:
            board_map._flush()

    empty_board = [[EMPTY_TOKEN] * board_size for _ in range(board_size)]
    empty_canonical = get_canonical_form(empty_board)
    # The root is recorded like any other state so it also reaches the file.
    record(empty_canonical)
    queue = deque([empty_canonical])

    stats = {'win': 0, 'draw': 0, 'ongoing': 0}

    # One pool for the whole run: spawning processes for every chunk is costly.
    pool = Pool(processes=n_workers) if n_workers > 1 else None
    try:
        with tqdm(desc="Generating boards", unit=" states", dynamic_ncols=True) as pbar:
            while queue:
                # popleft costs O(chunk); slicing a list copied the whole
                # remaining queue on every iteration.
                take = min(chunk_size, len(queue))
                process_chunk = [queue.popleft() for _ in range(take)]
                boards_to_process = [canonical_to_board(c, board_size) for c in process_chunk]

                batch_size = max(50, len(boards_to_process) // (n_workers * 2))
                batches = [
                    (boards_to_process[i:i + batch_size], board_size)
                    for i in range(0, len(boards_to_process), batch_size)
                ]

                if pool is not None and len(batches) > 1:
                    batch_results = pool.map(process_batch_fast, batches)
                else:
                    batch_results = [process_batch_fast(batch) for batch in batches]
                # Release the expanded boards before de-duplicating children.
                del process_chunk, boards_to_process, batches

                # Results are merged in batch order, so queue/file order is
                # the first-occurrence order across the whole chunk.
                for result in batch_results:
                    for key, value in result['stats'].items():
                        stats[key] += value
                    for canonical in result['new_children']:
                        if canonical not in seen_canonical:
                            record(canonical)
                            queue.append(canonical)
                del batch_results

                pbar.update(take)
                pbar.set_postfix({
                    "Wins": stats['win'],
                    "Draws": stats['draw'],
                    "Ongoing": stats['ongoing'],
                    "Unique": len(seen_canonical),
                    "Queue": len(queue),
                })
    finally:
        if pool is not None:
            pool.terminate()
            pool.join()
        # Flush buffered states even when interrupted (e.g. KeyboardInterrupt).
        board_map.close()

    print(f"\n{'='*60}")
    print(f"Total unique boards: {len(seen_canonical):,}")
    print(f"  Wins: {stats['win']:,}")
    print(f"  Draws: {stats['draw']:,}")
    print(f"  Ongoing: {stats['ongoing']:,}")
    print(f"{'='*60}")

    return [canonical_to_board(c, board_size) for c in seen_canonical]

In [None]:
# Enumerate the canonical state space for the configured board size.
# chunk_size caps how many queued states are expanded per round; lower it
# if RAM is tight.
all_boards = generate_all_boards(
    board_size=BOARD_SIZE,
    n_workers=7,
    chunk_size=2_000_000,
)
print(f"\n✓ Successfully generated {len(all_boards):,} unique board states")

Using 7 worker processes
Chunk size: 2,000,000 states per iteration


Generating boards: 10521369 states [42:34, 4119.10 states/s, Wins=0, Draws=0, Ongoing=1.05e+7, Unique=4.11e+7, Queue=3.06e+7] 


KeyboardInterrupt: 