# In [None]:
import os
import glob
import torch
import chess
import numpy as np
from tqdm import tqdm

def compute_legal_masks_for_position(pos_tokens):
    """
    Compute legal-move masks for a single tokenized position.

    Each token is unpacked as
    ``(pid, rank, file, stm, wK, wQ, bK, bQ, ep, halfmove)``, where ``pid``
    indexes ``"PNBRQKpnbrqk"``. A token with ``pid == rank == file == 0`` is
    treated as padding and ignored. The per-position fields (side to move,
    castling flags, en-passant file) are read from the first non-padding
    token -- NOTE(review): this assumes they are repeated identically on
    every token; confirm against the code that built the dataset.

    Args:
        pos_tokens: sequence of 10-field tokens. Every field goes through
            ``int()``, so tuples/lists of ints, numpy rows and integer torch
            tensor rows are all accepted.

    Returns:
        (mask_from, mask_dest): two ``(8, 8)`` bool numpy arrays indexed
        ``[rank, file]``. ``mask_from`` is True on every square some legal
        move starts from; ``mask_dest`` is True on every square some legal
        move lands on. Both are all-False when ``pos_tokens`` contains no
        non-padding token.
    """
    mask_from = np.zeros((8, 8), dtype=bool)
    mask_dest = np.zeros((8, 8), dtype=bool)

    # Normalise every field to a plain int (handles ints, numpy scalars and
    # 0-d tensors alike) and drop padding tokens.
    tokens = [tuple(int(v) for v in tok) for tok in pos_tokens]
    pieces = [tok for tok in tokens if tok[:3] != (0, 0, 0)]
    if not pieces:
        # No pieces -> no legal moves; also avoids indexing an empty list.
        return mask_from, mask_dest

    # Rebuild the board from the piece tokens.
    board = chess.Board.empty()
    for pid, rank_idx, file_idx, *_ in pieces:
        piece = chess.Piece.from_symbol("PNBRQKpnbrqk"[pid])
        board.set_piece_at(chess.square(file_idx, rank_idx), piece)

    _, _, _, stm, w_king, w_queen, b_king, b_queen, ep_file, _ = pieces[0]
    board.turn = chess.WHITE if stm == 0 else chess.BLACK

    # python-chess stores castling rights as a bitmask of the rooks' squares.
    board.castling_rights = (
        (chess.BB_H1 if w_king else 0)
        | (chess.BB_A1 if w_queen else 0)
        | (chess.BB_H8 if b_king else 0)
        | (chess.BB_A8 if b_queen else 0)
    )

    # The en-passant square is the square the capturing pawn moves TO:
    # rank 6 (index 5) when White is to move, rank 3 (index 2) when Black is
    # to move. ep_file == 8 (or anything outside 0-7) means "no en passant".
    if 0 <= ep_file <= 7:
        ep_rank = 5 if board.turn == chess.WHITE else 2
        board.ep_square = chess.square(ep_file, ep_rank)

    for move in board.legal_moves:
        # Squares are numbered rank * 8 + file, so divmod gives (rank, file).
        mask_from[divmod(move.from_square, 8)] = True
        mask_dest[divmod(move.to_square, 8)] = True

    return mask_from, mask_dest

def preprocess_dataset(input_dir, output_dir, allowed_pt_files=None):
    """
    Load ``.pt`` record files, attach legal-move masks, and save copies.

    Every ``*.pt`` file directly inside ``input_dir`` whose basename is in
    ``allowed_pt_files`` is loaded with ``torch.load``. Each record must have
    a ``"position"`` entry, which is passed to
    ``compute_legal_masks_for_position``. The two masks are added to the
    record in place as ``uint8`` numpy arrays under ``"legal_mask_from"`` and
    ``"legal_mask_dest"``. The updated list is saved with ``torch.save``
    under the same basename in ``output_dir``, which is created if missing.

    Args:
        input_dir: directory containing the source ``.pt`` files.
        output_dir: destination directory. Existing files with the same
            basename are overwritten.
        allowed_pt_files: iterable of basenames to process. ``None`` (the
            default) processes every ``*.pt`` file in ``input_dir``.
    """
    os.makedirs(output_dir, exist_ok=True)
    wanted = None if allowed_pt_files is None else set(allowed_pt_files)

    # Sort for a deterministic processing order, and filter before creating
    # the progress bar so its total only counts files actually processed.
    pt_files = sorted(
        path
        for path in glob.glob(os.path.join(input_dir, "*.pt"))
        if wanted is None or os.path.basename(path) in wanted
    )

    # Report requested files that do not exist instead of skipping silently.
    if wanted is not None:
        found = {os.path.basename(path) for path in pt_files}
        for name in sorted(wanted - found):
            tqdm.write(f"Warning: {name} not found in {input_dir}, skipping")

    for pt_file in tqdm(pt_files, desc="Processing files"):
        name = os.path.basename(pt_file)
        # NOTE(review): torch.load can unpickle arbitrary objects unless
        # weights_only=True is in effect; only load files from a trusted source.
        records = torch.load(pt_file)

        for record in tqdm(records, desc=f"Processing {name}", leave=False):
            mask_from, mask_dest = compute_legal_masks_for_position(record["position"])
            # Stored as uint8 arrays; presumably downstream loaders expect
            # this dtype -- confirm before changing it.
            record["legal_mask_from"] = mask_from.astype(np.uint8)
            record["legal_mask_dest"] = mask_dest.astype(np.uint8)

        output_file = os.path.join(output_dir, name)
        torch.save(records, output_file)
        # tqdm.write prints without corrupting the active progress bars.
        tqdm.write(f"Saved {output_file}")

if __name__ == "__main__":
    # Kaggle locations: raw shards in, mask-augmented copies out.
    input_dir = "/kaggle/input/the-big-and-the-beautiful-chess-dataset"
    output_dir = "/kaggle/working/set_"
    # Only the first six shards: batch_000000.pt through batch_000005.pt.
    shard_names = [f"batch_{shard:06d}.pt" for shard in range(6)]
    preprocess_dataset(input_dir, output_dir, shard_names)