In [17]:
import numpy as np
import chess        # pip install python-chess

def san_to_heatmaps(board: chess.Board, san: str, *, dtype=np.float32):
    """
    Encode a SAN move (e.g. 'Nf3', 'O-O', 'e8=Q') as two 8 × 8 one-hot planes.

    Parameters
    ----------
    board : chess.Board
        Position the move is played from (i.e. *before* the move).
    san : str
        The move in Standard Algebraic Notation, interpreted against `board`.
    dtype : np.dtype, default np.float32
        Element type of both returned planes.

    Returns
    -------
    origin_map, dest_map : np.ndarray, np.ndarray
        Two (8, 8) arrays indexed as [rank, file] from White's side, so
        rank 1 is row 0 and file 'a' is column 0.  `origin_map` holds a
        single 1 on the square the piece leaves, `dest_map` a single 1 on
        the square it lands on; every other entry is 0.

    Raises
    ------
    ValueError
        Propagated from `board.parse_san` when the SAN cannot be resolved
        to a legal move in this position.
    """
    parsed = board.parse_san(san)

    planes = []
    for square in (parsed.from_square, parsed.to_square):
        # Square indices run 0..63 with a1 = 0, so rank = index // 8 and
        # file = index % 8.
        rank, file = square // 8, square % 8
        plane = np.zeros((8, 8), dtype=dtype)
        plane[rank, file] = 1
        planes.append(plane)

    origin_map, dest_map = planes
    return origin_map, dest_map


In [23]:
import pandas as pd
import numpy as np
import chess
import os
from pathlib import Path
import json

# def add_heatmaps_to_csv(input_dir, output_dir=None):
#     """
#     Process all CSV files in the input directory, replacing the SAN move column with
#     origin and destination heatmaps.
    
#     Parameters
#     ----------
#     input_dir : str
#         Directory containing CSV files with FEN and MOVE (SAN) columns
#     output_dir : str, optional
#         Directory where processed files will be saved. If None, input files will be overwritten.
#     """
#     # Create output directory if specified
#     if output_dir is not None:
#         os.makedirs(output_dir, exist_ok=True)
    
#     # Get list of all CSV files in input directory
#     csv_files = list(Path(input_dir).glob("*.csv"))
#     print(f"Found {len(csv_files)} CSV files to process")
    
#     total_moves_processed = 0
#     files_processed = 0
#     errors = 0
    
#     for csv_file in csv_files:
#         print(f"Processing {csv_file.name}...")
        
#         # Determine output path
#         output_path = Path(output_dir) / csv_file.name if output_dir else csv_file
        
#         try:
#             # Read the CSV file
#             df = pd.read_csv(csv_file, header=0)
            
#             # Initialize new columns for origin and destination heatmaps
#             origin_maps = []
#             dest_maps = []
            
#             # Process each row
#             for i, row in df.iterrows():
#                 try:
#                     # Extract FEN and move
#                     fen = row['FEN']
#                     san_move = row['MOVE']
                    
#                     # Create board from FEN
#                     board = chess.Board(fen)
                    
#                     # Generate heatmaps
#                     origin, dest = san_to_heatmaps(board, san_move)
                    
#                     # Convert heatmaps to JSON strings
#                     origin_json = json.dumps(origin.tolist())
#                     dest_json = json.dumps(dest.tolist())
                    
#                     # Add to lists
#                     origin_maps.append(origin_json)
#                     dest_maps.append(dest_json)
                    
#                     total_moves_processed += 1
                    
#                     # Print progress for large files
#                     if i > 0 and i % 10000 == 0:
#                         print(f"  Processed {i} moves in {csv_file.name}")
                        
#                 except Exception as e:
#                     # Handle errors in individual moves
#                     errors += 1
#                     origin_maps.append(None)
#                     dest_maps.append(None)
#                     print(f"  Error processing move at row {i}: {e}")
            
#             # Create a new dataframe with FEN, ORIGIN, and DEST columns
#             new_df = pd.DataFrame({
#                 'FEN': df['FEN'],
#                 'MOVE': df['MOVE'],
#                 'ORIGIN': origin_maps,
#                 'DEST': dest_maps
#             })
            
#             # Save the updated dataframe
#             new_df.to_csv(output_path, index=False)
            
#             files_processed += 1
            
#         except Exception as e:
#             print(f"Error processing file {csv_file.name}: {e}")
    
#     print(f"\nProcessing complete!")
#     print(f"Files processed: {files_processed}/{len(csv_files)}")
#     print(f"Total moves processed: {total_moves_processed}")
#     print(f"Errors encountered: {errors}")

# ...existing code...
def add_heatmaps_to_csv(input_dir, output_dir=None):
    """
    Add origin/destination heatmap columns to every CSV file in `input_dir`.

    Each input file must have `FEN` and `MOVE` (SAN) columns.  The output
    file has the columns `FEN`, `MOVE`, `ORIGIN` and `DEST`, where `ORIGIN`
    and `DEST` are the planes returned by `san_to_heatmaps`, serialised as
    JSON nested lists.  Rows that fail to parse are DROPPED; only the first
    10 failures are printed.

    Parameters
    ----------
    input_dir : str
        Directory whose `*.csv` files are processed (non-recursive).
    output_dir : str, optional
        Directory for the processed files (created if needed).  If omitted,
        each input file is overwritten in place.

    Notes
    -----
    * A crazyhouse pocket suffix in the board field (e.g. ``...[Nn]``) is
      stripped before parsing, and the stripped FEN is what gets written.
    * A file lacking the `FEN` or `MOVE` column is skipped with an error
      message and never written, so an in-place run cannot replace such a
      file with an empty table.
    """
    if output_dir is not None:
        os.makedirs(output_dir, exist_ok=True)

    csv_files = list(Path(input_dir).glob("*.csv"))
    print(f"Found {len(csv_files)} CSV files to process")

    total_moves_processed = 0
    files_processed = 0
    errors = 0  # cumulative across all files

    for csv_file in csv_files:
        print(f"Processing {csv_file.name}...")
        output_path = Path(output_dir) / csv_file.name if output_dir else csv_file

        try:
            df = pd.read_csv(csv_file, header=0)

            # Fail the whole file up front: without this check every row would
            # raise KeyError, be "dropped", and an empty table would be written
            # (destroying the input when running in place).
            missing = [col for col in ('FEN', 'MOVE') if col not in df.columns]
            if missing:
                raise ValueError(f"missing required column(s): {missing}")

            kept_fens = []
            kept_moves = []
            origin_maps = []
            dest_maps = []

            # Iterating the two columns directly avoids building a Series per row.
            for i, fen, san_move in zip(df.index, df['FEN'], df['MOVE']):
                try:
                    # strip crazyhouse pocket if present (e.g. ...[Nn])
                    parts = fen.split()
                    if '[' in parts[0]:
                        parts[0] = parts[0].split('[', 1)[0]
                        fen = ' '.join(parts)

                    board = chess.Board(fen)
                    origin, dest = san_to_heatmaps(board, san_move)

                    origin_json = json.dumps(origin.tolist())
                    dest_json = json.dumps(dest.tolist())

                    kept_fens.append(fen)
                    kept_moves.append(san_move)
                    origin_maps.append(origin_json)
                    dest_maps.append(dest_json)

                    total_moves_processed += 1
                    if total_moves_processed % 10000 == 0:
                        print(f"  Processed {total_moves_processed} total moves so far...")
                except Exception as e:
                    # Deliberately broad: any malformed row (bad FEN, illegal
                    # SAN, missing value) is dropped rather than aborting the file.
                    errors += 1
                    if errors <= 10:
                        print(f"  Dropped row {i} (error: {e})")
                    continue

            new_df = pd.DataFrame({
                'FEN': kept_fens,
                'MOVE': kept_moves,
                'ORIGIN': origin_maps,
                'DEST': dest_maps
            })

            new_df.to_csv(output_path, index=False)
            print(f"  Wrote {len(new_df)} clean rows (dropped {errors} so far cumulative).")
            files_processed += 1

        except Exception as e:
            print(f"Error processing file {csv_file.name}: {e}")

    print(f"\nProcessing complete!")
    print(f"Files processed: {files_processed}/{len(csv_files)}")
    print(f"Total kept moves: {total_moves_processed}")
    print(f"Total rows dropped: {errors}")
# ...existing code...

In [24]:
# Generate heatmap CSVs into a separate directory so the raw FEN/MOVE files
# in "fen_moves_output" are not overwritten.
add_heatmaps_to_csv("fen_moves_output", "fen_moves_output_heatmaps")

Found 18 CSV files to process
Processing all_GMWSO_black_2020-08-14_to_2025-08-13_fen_moves.csv...
  Dropped row 348 (error: invalid half-move clock in fen: 'rn1qkbnr/pbpppppp/1p6/8/2B1P3/2N5/PPPP1PPP/R1BQK1NR b KQkq - 3+3 3 3')
  Dropped row 349 (error: invalid half-move clock in fen: 'rn1qkbnr/pbpp1ppp/1p2p3/8/2B1P3/2N2N2/PPPP1PPP/R1BQK2R b KQkq - 3+3 1 4')
  Dropped row 350 (error: invalid half-move clock in fen: 'rn1qkb1r/pbpp1ppp/1p2pn2/8/2B1P3/2N2N2/PPPP1PPP/R1BQ1RK1 b kq - 3+3 3 5')
  Dropped row 351 (error: invalid half-move clock in fen: 'rn1qkb1r/pbpp1ppp/1p2p3/8/2B1N3/5N2/PPPP1PPP/R1BQ1RK1 b kq - 3+3 0 6')
  Dropped row 352 (error: invalid half-move clock in fen: 'rn1qkb1r/p1pp1ppp/1p2p3/8/2B1b3/5N2/PPPP1PPP/R1BQR1K1 b kq - 3+3 1 7')
  Dropped row 353 (error: invalid half-move clock in fen: 'rn1qkb1r/p1pp1ppp/1p2p1b1/8/2BP4/5N2/PPP2PPP/R1BQR1K1 b kq - 3+3 0 8')
  Dropped row 354 (error: invalid half-move clock in fen: 'rn1qk2r/p1ppbppp/1p2p1b1/3P4/2B5/5N2/PPP2PPP/R1BQR1K1 b 

In [25]:
import chess
import chess.pgn
import os
import csv

import numpy as np

def fen_to_piece_tokens(fen: str, *, tensor: bool = False, max_halfmove: int = 100):
    """
    Convert a FEN position to the *piece-token patching* encoding.

    Parameters
    ----------
    fen : str
        Standard Forsyth–Edwards Notation string with exactly six
        whitespace-separated fields.
    tensor : bool, default False
        If True, return a `torch.LongTensor`; otherwise a `np.ndarray`.
    max_halfmove : int, default 100
        Clamp the half-move clock to this upper bound (helps normalisation).

    Returns
    -------
    tokens : np.ndarray | torch.LongTensor
        Shape = (N_pieces, 10), including (0, 10) for an empty board.
        The array is int16 (the tensor is int64) with columns:

        0  piece_id              0–5 = white P,N,B,R,Q,K ; 6–11 = black P…K
        1  rank                  0–7  (0 = rank 1 from White's view)
        2  file                  0–7  (0 = file 'a')
        3  side_to_move          0 = white, 1 = black
        4  white_castle_K        0/1
        5  white_castle_Q        0/1
        6  black_castle_K        0/1
        7  black_castle_Q        0/1
        8  en_passant_file       0–7 if ep target exists else **8**
        9  halfmove_clock        clamped to at most `max_halfmove`

    Raises
    ------
    ValueError
        If `fen` does not have exactly six fields (variant FENs that carry
        an extra field, such as a remaining-checks counter, are rejected)
        or the half-move clock is not an integer.
    KeyError
        If the board field contains a character that is neither a digit
        nor one of ``PNBRQKpnbrqk``.
    """
    # --- split FEN into its 6 fields ---------------------------------------
    fields = fen.split()
    if len(fields) != 6:
        raise ValueError(
            f"expected 6 whitespace-separated FEN fields, got {len(fields)}: {fen!r}"
        )
    board_fen, stm, castles, ep_sq, half_clk, _ = fields

    piece_to_id = {c: i for i, c in enumerate("PNBRQKpnbrqk")}

    # --- global-state scalars (replicated into every token) -----------------
    stm_bit = 0 if stm == "w" else 1
    castle_bits = (
        int("K" in castles),  # white K-side
        int("Q" in castles),  # white Q-side
        int("k" in castles),  # black K-side
        int("q" in castles),  # black Q-side
    )
    ep_file = ord(ep_sq[0]) - ord("a") if ep_sq != "-" else 8
    half_clk = min(int(half_clk), max_halfmove)

    # --- walk through the 8×8 board ----------------------------------------
    tokens = []
    rank_idx = 7                              # FEN starts from rank 8
    for row in board_fen.split("/"):
        file_idx = 0
        for ch in row:
            if ch.isdigit():                  # run of empty squares
                file_idx += int(ch)
            else:                             # occupied square → token
                tokens.append([
                    piece_to_id[ch],
                    rank_idx,
                    file_idx,
                    stm_bit,
                    *castle_bits,
                    ep_file,
                    half_clk,
                ])
                file_idx += 1
        rank_idx -= 1

    # reshape keeps the documented (N, 10) layout even when there are no
    # pieces; np.asarray([]) alone would yield shape (0,).
    arr = np.asarray(tokens, dtype=np.int16).reshape(-1, 10)
    if tensor:
        import torch
        return torch.as_tensor(arr, dtype=torch.long)
    return arr


# Sanity check on a sample position: Black to move, all four castling rights,
# no en-passant target (column 8 should be 8) and a half-move clock of 0.
fen_to_piece_tokens('rnbqkbnr/pp1ppp1p/6p1/2p5/2P5/2N3P1/PP1PPP1P/R1BQKBNR b KQkq - 0 3', tensor=False, max_halfmove=100)

array([[ 9,  7,  0,  1,  1,  1,  1,  1,  8,  0],
       [ 7,  7,  1,  1,  1,  1,  1,  1,  8,  0],
       [ 8,  7,  2,  1,  1,  1,  1,  1,  8,  0],
       [10,  7,  3,  1,  1,  1,  1,  1,  8,  0],
       [11,  7,  4,  1,  1,  1,  1,  1,  8,  0],
       [ 8,  7,  5,  1,  1,  1,  1,  1,  8,  0],
       [ 7,  7,  6,  1,  1,  1,  1,  1,  8,  0],
       [ 9,  7,  7,  1,  1,  1,  1,  1,  8,  0],
       [ 6,  6,  0,  1,  1,  1,  1,  1,  8,  0],
       [ 6,  6,  1,  1,  1,  1,  1,  1,  8,  0],
       [ 6,  6,  3,  1,  1,  1,  1,  1,  8,  0],
       [ 6,  6,  4,  1,  1,  1,  1,  1,  8,  0],
       [ 6,  6,  5,  1,  1,  1,  1,  1,  8,  0],
       [ 6,  6,  7,  1,  1,  1,  1,  1,  8,  0],
       [ 6,  5,  6,  1,  1,  1,  1,  1,  8,  0],
       [ 6,  4,  2,  1,  1,  1,  1,  1,  8,  0],
       [ 0,  3,  2,  1,  1,  1,  1,  1,  8,  0],
       [ 1,  2,  2,  1,  1,  1,  1,  1,  8,  0],
       [ 0,  2,  6,  1,  1,  1,  1,  1,  8,  0],
       [ 0,  1,  0,  1,  1,  1,  1,  1,  8,  0],
       [ 0,  1,  1, 

In [26]:
def compute_legal_masks_for_position(pos_tokens):
    """
    Compute legal-move masks for a single position.

    Parameters
    ----------
    pos_tokens : sequence of 10-value tokens (list, np.ndarray or tensor)
        One row per piece in the layout produced by `fen_to_piece_tokens`:
        (piece_id, rank, file, side_to_move, wK, wQ, bK, bQ, ep_file,
        halfmove).  Rows with piece_id == rank == file == 0 are treated as
        padding and ignored.

    Returns
    -------
    mask_from, mask_dest : np.ndarray, np.ndarray
        Two (8, 8) bool arrays indexed [rank, file].  `mask_from` marks
        every square some legal move starts from, `mask_dest` every square
        some legal move ends on.  Both are all-False when there are no
        non-padding tokens.
    """
    mask_from = np.zeros((8, 8), dtype=bool)
    mask_dest = np.zeros((8, 8), dtype=bool)

    # Work on plain Python ints so lists, arrays and tensors behave alike.
    if hasattr(pos_tokens, "tolist"):
        rows = pos_tokens.tolist()
    else:
        rows = [[int(v) for v in tok] for tok in pos_tokens]

    pieces = [row for row in rows if not (row[0] == 0 and row[1] == 0 and row[2] == 0)]
    if not pieces:
        # Empty or all-padding input: no board to rebuild, no legal moves.
        return mask_from, mask_dest

    # Rebuild board from tokens
    board = chess.Board.empty()
    for pid, r, f, *_ in pieces:
        piece_symbol = "PNBRQKpnbrqk"[pid]
        board.set_piece_at(chess.square(f, r), chess.Piece.from_symbol(piece_symbol))

    # Global state is replicated in every token; read it from a real piece
    # rather than from row 0, which may be padding.
    _, _, _, stm, wK, wQ, bK, bQ, ep, *_ = pieces[0]

    board.turn = chess.WHITE if stm == 0 else chess.BLACK
    board.castling_rights = (
        (chess.BB_H1 if wK else 0) |
        (chess.BB_A1 if wQ else 0) |
        (chess.BB_H8 if bK else 0) |
        (chess.BB_A8 if bQ else 0)
    )

    # En passant: the target square lies behind the pawn that just made a
    # double step, i.e. on rank 6 (index 5) when White is to move and on
    # rank 3 (index 2) when Black is to move.
    if 0 <= ep <= 7:
        ep_rank = 5 if board.turn == chess.WHITE else 2
        board.ep_square = chess.square(ep, ep_rank)

    # Collect legal moves
    for mv in board.legal_moves:
        r_from, f_from = divmod(mv.from_square, 8)
        r_to, f_to = divmod(mv.to_square, 8)
        mask_from[r_from, f_from] = True
        mask_dest[r_to, f_to] = True

    return mask_from, mask_dest

In [27]:
import torch
import numpy as np
import pandas as pd
from pathlib import Path
import os
import gc
import json
import time

def preprocess_csv_to_tensors_robust(input_dir, output_dir, batch_size=50000, resume=True, max_positions=None):
    """
    Convert heatmap CSV files into batched ``.pt`` files of per-position records.

    Each CSV row must provide ``FEN``, ``ORIGIN`` and ``DEST`` columns, the
    latter two being JSON-encoded planes (as written by `add_heatmaps_to_csv`).
    A row becomes a dict with the keys ``position`` (tokens from
    `fen_to_piece_tokens`), ``move`` (origin and destination planes stacked
    into one float32 tensor), ``source_file`` (the CSV stem) and
    ``legal_mask_from`` / ``legal_mask_dest`` (uint8 arrays from
    `compute_legal_masks_for_position`).  Records are accumulated across
    files and written as ``batch_NNNNNN.pt`` lists via a temp file + rename.

    Besides the batches, `output_dir` receives ``processing_progress.json``
    (resume state), ``batch_index.json`` (batch file names and sizes) and,
    after a complete run, ``dataset_metadata.json``.

    Parameters
    ----------
    input_dir : str
        Directory containing the CSV files (processed in sorted order).
    output_dir : str
        Directory where the output tensor files will be saved.
    batch_size : int
        Number of positions to accumulate before saving a batch to disk.
    resume : bool
        Whether to resume from the progress file of a previous run.
    max_positions : int or None
        Stop once this many positions were processed (None/0 for unlimited).

    Returns
    -------
    int
        Total number of positions processed, including those counted by the
        resumed progress file.

    Notes
    -----
    * An error in a row aborts the rest of that file (the file is not marked
      as processed); processing continues with the next file.
    * NOTE(review): a file is marked as processed while its last rows may
      still sit in the in-memory batch.  They are flushed on normal exit, on
      errors and on KeyboardInterrupt, but a hard kill in between loses them.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Paths for metadata and progress tracking
    metadata_file = Path(output_dir) / "dataset_metadata.json"
    progress_file = Path(output_dir) / "processing_progress.json"
    index_file = Path(output_dir) / "batch_index.json"

    def new_progress():
        return {
            'total_positions': 0,
            'files_processed': [],
            'current_file': None,
            'current_file_position': 0,
            'last_batch_num': -1
        }

    def new_batch_index():
        return {
            'batch_files': [],
            'positions_per_batch': [],
            'total_positions': 0
        }

    progress = new_progress()
    batch_index = new_batch_index()

    # Check if we can resume from previous processing
    if resume and progress_file.exists():
        try:
            with open(progress_file, 'r') as f:
                progress = json.load(f)
            print(f"Resuming from previous processing. Already processed {progress['total_positions']} positions.")

            if index_file.exists():
                with open(index_file, 'r') as f:
                    batch_index = json.load(f)
                print(f"Found {len(batch_index['batch_files'])} existing batch files")
        except Exception as e:
            print(f"Could not load progress file: {e}. Starting from scratch.")
            progress = new_progress()
            batch_index = new_batch_index()

    # Get the list of CSV files
    all_csv_files = sorted(list(Path(input_dir).glob("*.csv")))

    # Filter out already processed files
    csv_files = [f for f in all_csv_files if str(f) not in progress['files_processed']]

    print(f"Found {len(all_csv_files)} total CSV files, {len(csv_files)} remaining to process")

    if len(csv_files) == 0:
        print("All files already processed. Nothing to do.")
        return progress['total_positions']

    # Remember the resume point BEFORE the loop overwrites progress['current_file'];
    # the offset only applies to the file that was being read when the previous
    # run stopped, never to any other file.
    resume_file = progress.get('current_file')
    resume_name = Path(resume_file).name if resume_file else None
    resume_position = progress.get('current_file_position', 0)

    # Initialize batch tracking
    batch_data = []
    total_positions = progress['total_positions']
    batch_num = progress['last_batch_num'] + 1

    # Function to save a batch as a separate file
    def save_batch(batch_data):
        nonlocal batch_num

        if not batch_data:
            return

        # Create a unique filename for this batch
        batch_file = f"batch_{batch_num:06d}.pt"
        output_path = Path(output_dir) / batch_file
        temp_path = output_path.with_suffix('.tmp')

        try:
            # Write through our own handle so the data can be flushed and
            # fsync'ed before the rename makes the file visible.
            with open(temp_path, 'wb') as f:
                torch.save(batch_data, f)
                f.flush()
                os.fsync(f.fileno())

            # os.replace is atomic and overwrites on every platform
            os.replace(temp_path, output_path)

            # Update the batch index
            batch_index['batch_files'].append(batch_file)
            batch_index['positions_per_batch'].append(len(batch_data))
            batch_index['total_positions'] += len(batch_data)

            # Move to next batch
            batch_num += 1
            progress['last_batch_num'] = batch_num - 1

            print(f"  Saved batch {batch_num-1} with {len(batch_data)} positions to {batch_file}")

        except Exception as e:
            print(f"  Error saving batch: {e}")
            if temp_path.exists():
                try:
                    temp_path.unlink()
                except OSError:
                    pass

        # Clear batch data and collect garbage.
        # NOTE(review): this also runs after a failed save, so those records
        # are dropped (best-effort behaviour kept from the original).
        batch_data.clear()
        gc.collect()

    # Function to save progress
    def save_progress():
        with open(progress_file, 'w') as f:
            json.dump(progress, f, indent=2)
        with open(index_file, 'w') as f:
            json.dump(batch_index, f, indent=2)

    # Process each CSV file
    try:
        for csv_file in csv_files:
            # Skip if already fully processed
            if str(csv_file) in progress['files_processed']:
                continue

            print(f"Processing {csv_file.name}...")

            # Resume in the middle of a file only for the interrupted file,
            # and only once.
            if resume_name is not None and csv_file.name == resume_name:
                start_idx = resume_position
                resume_name = None
            else:
                start_idx = 0
            progress['current_file'] = str(csv_file)
            progress['current_file_position'] = start_idx

            try:
                # Read the CSV file
                df = pd.read_csv(csv_file)

                # Process each FEN string
                for idx, row in df.iloc[start_idx:].iterrows():
                    # Check if we've reached the maximum positions
                    if max_positions and total_positions >= max_positions:
                        print(f"Reached maximum of {max_positions} positions. Stopping.")
                        # Save any remaining data
                        if batch_data:
                            save_batch(batch_data)
                        save_progress()
                        return total_positions

                    fen = row['FEN']
                    origin_data = json.loads(row['ORIGIN'])
                    dest_data = json.loads(row['DEST'])
                    origin_tensor = torch.as_tensor(origin_data, dtype=torch.float32)
                    dest_tensor = torch.as_tensor(dest_data, dtype=torch.float32)
                    move = torch.stack([origin_tensor, dest_tensor], dim=0)

                    # Convert FEN to tensor
                    position_tensor = fen_to_piece_tokens(fen, tensor=True)

                    mask_from, mask_dest = compute_legal_masks_for_position(position_tensor)

                    # Store the tensor along with the move
                    batch_data.append({
                        'position': position_tensor,
                        'move': move,
                        'source_file': csv_file.stem,
                        'legal_mask_from': mask_from.astype(np.uint8),
                        'legal_mask_dest': mask_dest.astype(np.uint8),
                    })

                    total_positions += 1
                    progress['total_positions'] = total_positions
                    progress['current_file_position'] = idx + 1

                    # Print progress periodically
                    if total_positions % 10000 == 0:
                        print(f"  Processed {total_positions} positions so far...")

                    # Save batch when reaching batch_size
                    if len(batch_data) >= batch_size:
                        save_batch(batch_data)
                        # Update and save progress
                        save_progress()

                # Mark file as fully processed
                progress['files_processed'].append(str(csv_file))
                progress['current_file_position'] = 0
                save_progress()

            except Exception as e:
                print(f"  Error processing {csv_file.name}: {e}")

                # Flush the partial batch first so the progress/index files
                # written next already describe it.
                if batch_data:
                    save_batch(batch_data)
                save_progress()

        # Save any remaining data in the final batch
        if batch_data:
            print(f"  Saving final batch of {len(batch_data)} positions...")
            save_batch(batch_data)
        # Persist the index/progress once more: without this the final batch
        # is missing from batch_index.json and a later resume would reuse
        # its batch number.
        save_progress()

    except KeyboardInterrupt:
        print("\nProcess interrupted by user.")
        # Save any partial batch data
        if batch_data:
            save_batch(batch_data)
        # Save progress
        save_progress()
        return total_positions

    print(f"\nProcessing complete. Processed a total of {total_positions} positions.")

    # Save the final metadata file
    metadata = {
        'total_positions': total_positions,
        'source_files': [f.stem for f in all_csv_files],
        'creation_date': pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S'),
        'batch_processing': True,
        'batch_size': batch_size,
        'total_batches': len(batch_index['batch_files'])
    }

    with open(metadata_file, 'w') as f:
        json.dump(metadata, f, indent=2)

    # Keep the progress file for information
    return total_positions

In [None]:
# Build tensor batches from the heatmap CSVs. batch_size is 3,000,000 positions
# per batch file; resume=True continues from any progress file already present
# in the output directory.
preprocess_csv_to_tensors_robust("fen_moves_output_heatmaps", "fen_moves_output_tensors", batch_size=3000000, resume=True, max_positions=None)

Found 18 total CSV files, 18 remaining to process
Processing all_AnishGiri_black_2020-08-14_to_2025-08-13_fen_moves.csv...
  Processed 10000 positions so far...
  Processed 20000 positions so far...
Processing all_AnishGiri_white_2020-08-14_to_2025-08-13_fen_moves.csv...
  Processed 30000 positions so far...


In [22]:
# ...existing code...

import os, glob
import torch
from collections import defaultdict
from tqdm import tqdm

def regroup_batches_by_source_prefix(
    src_dir: str,
    dst_dir: str,
    prefix_parts: int = 2,
    flush_size: int = 50_000,
    append_existing: bool = False
):
    """
    Group records from batch_*.pt files by the first `prefix_parts` underscore-
    separated tokens of `record['source_file']` and write one .pt per group.

    Example:
      source_file = "all_FabianoCaruana_black_2020-08-14_to_2025-08-13_fen_moves"
      prefix_parts = 2  -> key = "all_FabianoCaruana" -> all_FabianoCaruana.pt

    Parameters
    ----------
    src_dir : directory containing original batch_*.pt files
    dst_dir : output directory for grouped .pt files (created if needed)
    prefix_parts : number of leading underscore tokens to form the group key;
        a source_file with fewer tokens is used whole as the key
    flush_size : flush buffer to disk after this many accumulated records per
        key (0/None keeps everything in memory until the end)
    append_existing : if False (default), a group file left in `dst_dir` by
        an earlier run is replaced, so re-running the cell does not duplicate
        records.  If True, records are appended to such files instead.

    Returns
    -------
    None.  A per-group summary is printed.

    Notes
    -----
    Files are read with ``weights_only=False``, which unpickles arbitrary
    objects; only point this at batch files you produced yourself.
    """
    os.makedirs(dst_dir, exist_ok=True)

    batch_files = sorted(glob.glob(os.path.join(src_dir, "batch_*.pt")))
    if not batch_files:
        print("No batch_*.pt files found.")
        return

    buffers = defaultdict(list)
    counts = defaultdict(int)
    written_keys = set()  # keys whose file was already written during this run

    def flush_key(key):
        """Write the buffer for key to its .pt file and clear the buffer."""
        buf = buffers[key]
        if not buf:
            return
        out_path = os.path.join(dst_dir, f"{key}.pt")
        extend_file = (append_existing or key in written_keys) and os.path.exists(out_path)
        if extend_file:
            # Load, extend, save (simple, may be I/O heavy for huge files)
            records = torch.load(out_path, map_location='cpu', weights_only=False)
            records.extend(buf)
        else:
            records = buf
        # Write to a temp file first so an interrupted save cannot leave a
        # truncated group file behind.
        tmp_path = out_path + ".tmp"
        torch.save(records, tmp_path)
        os.replace(tmp_path, out_path)
        written_keys.add(key)
        buf.clear()

    print(f"Processing {len(batch_files)} source batch files...")
    for bf in tqdm(batch_files, desc="Regrouping"):
        try:
            data = torch.load(bf, map_location='cpu', weights_only=False)
        except Exception as e:
            print(f"  Skipping {bf} (load error: {e})")
            continue

        for rec in data:
            sf = rec.get("source_file", "")
            # Drop trailing extension if present
            sf_core = sf.rsplit('.', 1)[0]
            parts = sf_core.split('_')
            if len(parts) < prefix_parts:
                key = sf_core  # fallback
            else:
                key = "_".join(parts[:prefix_parts])
            buffers[key].append(rec)
            counts[key] += 1

            if flush_size and len(buffers[key]) >= flush_size:
                flush_key(key)

        # Free memory from loaded batch list
        del data

    # Flush remaining buffers
    for key in list(buffers.keys()):
        flush_key(key)

    # Summary
    print("\nGrouping complete. Created:")
    for k in sorted(counts.keys()):
        out_path = os.path.join(dst_dir, f"{k}.pt")
        size_mb = os.path.getsize(out_path) / (1024*1024) if os.path.exists(out_path) else 0
        print(f"  {k}.pt  {counts[k]} records  ({size_mb:.1f} MB)")

    total = sum(counts.values())
    print(f"\nTotal records regrouped: {total}")

# Run grouping: with prefix_parts=2 each record is keyed by the first two
# underscore-separated tokens of its source_file (e.g. "all_AnishGiri"), and
# one .pt file per key is written to the grouped directory.
regroup_batches_by_source_prefix(
    src_dir="fen_moves_output_tensors",
    dst_dir="fen_moves_output_tensors_grouped",
    prefix_parts=2,
    flush_size=50_000
)
# ...existing code...

Processing 6 source batch files...


Regrouping: 100%|██████████| 6/6 [01:47<00:00, 17.96s/it]



Grouping complete. Created:
  all_AnishGiri.pt  43839 records  (262.1 MB)
  all_Chefshouse.pt  11212 records  (68.4 MB)
  all_FabianoCaruana.pt  26609 records  (161.0 MB)
  all_GukeshDommaraju.pt  25965 records  (156.0 MB)
  all_LevonAronian.pt  32563 records  (194.1 MB)
  all_LyonBeast.pt  10934 records  (65.8 MB)
  all_MagnusCarlsen.pt  128777 records  (765.8 MB)
  all_WesleySo.pt  39377 records  (118.6 MB)
  all_rpragchess.pt  62109 records  (221.5 MB)

Total records regrouped: 381385
