In [None]:
# export HSA_FORCE_FINE_GRAIN_PCIE=1
# export PYTORCH_ROCM_ARCH="gfx1100"  # For RX 7900 XTX
# export HIP_VISIBLE_DEVICES=0

# pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm5.6

# rocm-smi

In [None]:
#!/usr/bin/env python3
# train_vit_chess.py - AMD RX 7900 XTX Optimized Version
# NOTE: AMD Optimized
"""
Vision‑Transformer chess move‑prediction trainer
------------------------------------------------
• Optimized for AMD RX 7900 XTX (24GB VRAM, RDNA3 architecture)
• Uses ROCm optimizations and efficient memory management
• Larger batch sizes and mixed precision training
"""

import os
import glob
import math
import random
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import numpy as np
import functools
import chess

# Apply GPU backend tuning only when the active device reports an AMD name.
if torch.cuda.is_available() and "AMD" in torch.cuda.get_device_name():
    # Backend flags: allow the TF32 matmul/conv paths and let the backend
    # benchmark kernels instead of forcing deterministic ones.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True

    # Release cached allocator blocks before capping this process's share.
    torch.cuda.empty_cache()

    # Let this process use up to 95% of the device memory.
    torch.cuda.set_per_process_memory_fraction(0.95)

    # Best-effort: switch on every scaled-dot-product-attention backend.
    # Only the failures expected from a build that lacks these toggles are
    # tolerated; anything else (e.g. KeyboardInterrupt) must propagate.
    try:
        torch.backends.cuda.enable_flash_sdp(True)
        torch.backends.cuda.enable_mem_efficient_sdp(True)
        torch.backends.cuda.enable_math_sdp(True)
    except (AttributeError, RuntimeError) as exc:
        print(f"SDP backend toggles unavailable, using defaults: {exc}")

# ────────────────────────────────────────────────────────────────────────────
# 1.  AMD RX 7900 XTX Optimized Hyper‑parameters
# ────────────────────────────────────────────────────────────────────────────
DATA_DIR         = "test/tensor_full_dataset"  # Directory globbed for *.pt record files
BATCH_SIZE       = 768                    # Samples per DataLoader batch
EPOCHS           = 15
LR               = 5e-4                    # Peak AdamW learning rate (cosine-annealed down to LR * 0.1)
DEVICE           = "cuda" if torch.cuda.is_available() else "cpu"
MAX_TOKENS       = 32                      # Positions are padded to this many tokens in collate_fn
EMBED_DIM        = 384                     # Transformer width
DEPTH            = 8                      # Number of encoder layers
N_HEADS          = 12                     # Attention heads per layer (384 / 12 = 32 dims per head)
MLP_RATIO        = 4                       # Feed-forward width = EMBED_DIM * MLP_RATIO
DROPOUT          = 0.1
SAVE_PATH        = "best_vit_amd.pth"
NUM_WORKERS      = 12  # DataLoader worker processes (used for both train and val loaders)
PREFETCH_FACTOR  = 4   # Batches prefetched per worker

# Backend flags set unconditionally (the AMD-only guard near the imports sets them too)
torch.backends.cudnn.benchmark = True     # Let the backend benchmark kernels for the observed input sizes
torch.backends.cuda.matmul.allow_tf32 = True  # Enable TF32 on supported hardware

# Mixed precision training
USE_AMP = True
GRAD_ACCUM_STEPS = 1                      # Micro-batches per optimizer step; effective batch = BATCH_SIZE * GRAD_ACCUM_STEPS (768 as configured)

# ────────────────────────────────────────────────────────────────────────────
# 2.  Optimized Dataset & loader
# ────────────────────────────────────────────────────────────────────────────
class PTBatchDataset(Dataset):
    """Map-style dataset over records stored in a list of ``.pt`` files.

    Every file is loaded once at construction time to count its records. The
    first ``split`` fraction of each file belongs to the training subset and
    the remainder to the validation subset. Samples are addressed as
    ``(file index, record index)`` pairs, and file contents are re-loaded on
    demand through a small first-in-first-out cache.

    Args:
        pt_files: Sequence of paths readable by ``torch.load``; each must hold
            an indexable collection of record mappings.
        train: Select the leading (True) or trailing (False) part of each file.
        split: Fraction of every file assigned to the training subset.
    """

    def __init__(self, pt_files, train=True, split=0.9):
        self.pt_files = pt_files
        self.train = train
        self.split = split
        self.file_indices = []

        # NOTE: torch.load unpickles its input; only point this at trusted files.
        for file_no, path in enumerate(pt_files):
            n_records = len(torch.load(path, map_location='cpu'))
            cut = int(n_records * split)
            chosen = range(cut) if train else range(cut, n_records)
            self.file_indices.extend((file_no, rec_no) for rec_no in chosen)

        # Loaded files, kept in insertion order so the first key is the oldest.
        self._file_cache = {}
        self._cache_size = 3

    def __len__(self):
        return len(self.file_indices)

    def _records_of(self, file_no):
        """Return the records of file ``file_no``, loading and caching on a miss."""
        cache = self._file_cache
        if file_no not in cache:
            if len(cache) >= self._cache_size:
                # Evict the entry that was inserted first.
                del cache[next(iter(cache))]
            cache[file_no] = torch.load(self.pt_files[file_no], map_location='cpu')
        return cache[file_no]

    def __getitem__(self, idx):
        """Return ``(position, move)`` or ``(position, move, mask_from, mask_dest)``.

        The four-element form is produced only when the record carries both
        ``legal_mask_from`` and ``legal_mask_dest``.
        """
        file_no, rec_no = self.file_indices[idx]
        rec = self._records_of(file_no)[rec_no]

        pos = np.asarray(rec["position"], dtype=np.int16)
        move = np.asarray(rec["move"], dtype=np.uint8)

        has_masks = "legal_mask_from" in rec and "legal_mask_dest" in rec
        if not has_masks:
            return pos, move

        mask_from = np.asarray(rec["legal_mask_from"], dtype=bool)
        mask_dest = np.asarray(rec["legal_mask_dest"], dtype=bool)
        return pos, move, mask_from, mask_dest

def _pad_positions(positions):
    """Right-pad variable-length token sequences into one dense batch.

    Returns:
        ``(tokens, valid)`` where ``tokens`` is a ``(B, MAX_TOKENS, 10)`` long
        tensor (zeros where padded) and ``valid`` is a ``(B, MAX_TOKENS)`` bool
        tensor that is True on real tokens and False on padding.

    Raises:
        ValueError: If a position holds more than ``MAX_TOKENS`` tokens.
    """
    batch_size = len(positions)
    tokens = np.zeros((batch_size, MAX_TOKENS, 10), dtype=np.int16)
    valid = np.zeros((batch_size, MAX_TOKENS), dtype=bool)

    for row, pos in enumerate(positions):
        n = len(pos)
        if n > MAX_TOKENS:
            # Same exception type NumPy would raise, but with a usable message.
            raise ValueError(
                f"position {row} has {n} tokens, exceeds MAX_TOKENS={MAX_TOKENS}"
            )
        if n:  # an empty position stays all-padding
            tokens[row, :n] = pos
            valid[row, :n] = True

    return (
        torch.as_tensor(tokens, dtype=torch.long),
        torch.as_tensor(valid, dtype=torch.bool),
    )


def collate_fn(batch):
    """Collate dataset samples into padded tensors.

    The format is decided by the first sample: four-element samples carry
    precomputed legal-move masks, two-element samples do not. No memory pinning
    happens here; that is left to the DataLoader's ``pin_memory`` option.

    Args:
        batch: Non-empty list of ``(position, move)`` or
            ``(position, move, mask_from, mask_dest)`` tuples.

    Returns:
        ``(x, pad, y)`` or ``(x, pad, y, mask_from, mask_dest)``: ``x`` is the
        padded long token tensor, ``pad`` is True on real tokens, ``y`` is the
        stacked moves as float32, and the masks are stacked bool tensors.
    """
    has_masks = len(batch[0]) == 4
    if has_masks:
        positions, moves, masks_from, masks_dest = zip(*batch)
    else:  # Old format without masks
        positions, moves = zip(*batch)

    x, pad = _pad_positions(positions)
    y = torch.as_tensor(np.stack(moves).astype(np.float32), dtype=torch.float32)

    if not has_masks:
        return x, pad, y

    mask_from = torch.as_tensor(np.stack(masks_from).astype(bool), dtype=torch.bool)
    mask_dest = torch.as_tensor(np.stack(masks_dest).astype(bool), dtype=torch.bool)
    return x, pad, y, mask_from, mask_dest

# ────────────────────────────────────────────────────────────────────────────
# 3.  Optimized Model Architecture
# ────────────────────────────────────────────────────────────────────────────
class TokenViT(nn.Module):
    """Transformer encoder over piece tokens that predicts a move as two 8x8 maps.

    Each token is a row of 10 integer fields. Every field is embedded with one
    shared 512-entry table (so field values must lie in [0, 512)), the 10
    embeddings are concatenated and projected to ``EMBED_DIM``, a learned CLS
    token and learned positional embeddings are added, and the CLS output is
    mapped to 128 logits reshaped to ``(B, 2, 8, 8)``.
    """
    def __init__(self):
        super().__init__()
        # Shared embedding table for all 10 token fields
        self.embed_val = nn.Embedding(512, EMBED_DIM)
        # Projects the 10 concatenated field embeddings down to one token vector
        self.col_linear = nn.Linear(10 * EMBED_DIM, EMBED_DIM)
        
        # Xavier-uniform weights and zero bias for the token projection
        nn.init.xavier_uniform_(self.col_linear.weight)
        nn.init.zeros_(self.col_linear.bias)

        # Learned CLS token and positional embeddings (MAX_TOKENS tokens + CLS);
        # created as zeros here and re-initialised in _init_weights()
        self.cls = nn.Parameter(torch.zeros(1, 1, EMBED_DIM))
        self.pos = nn.Parameter(torch.zeros(1, MAX_TOKENS + 1, EMBED_DIM))
        
        # Pre-norm GELU encoder layer operating on (batch, seq, dim) inputs
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=EMBED_DIM, 
            nhead=N_HEADS,
            dim_feedforward=EMBED_DIM * MLP_RATIO,
            dropout=DROPOUT, 
            batch_first=True,
            norm_first=True,  # Pre-norm for better training stability
            activation='gelu'  # Better activation for transformers
        )
        # NOTE(review): the encoder is stored as a torch.compile wrapper, so
        # state_dict keys under `enc` are expected to carry an `_orig_mod.`
        # prefix; confirm before loading checkpoints into an uncompiled model.
        self.enc = torch.compile(
            nn.TransformerEncoder(encoder_layer, num_layers=DEPTH),
            mode="max-autotune"  # Most aggressive autotuning mode; benefit on AMD is assumed, not measured here
        )
        self.norm = nn.LayerNorm(EMBED_DIM)
        # 128 logits = 2 maps (from-square, destination-square) x 8 x 8
        self.head = nn.Linear(EMBED_DIM, 128)
        
        # Initialize parameters
        self._init_weights()

    def _init_weights(self):
        """Initialise CLS/positional embeddings (normal, std 0.02) and the output head."""
        nn.init.normal_(self.cls, std=0.02)
        nn.init.normal_(self.pos, std=0.02)
        nn.init.xavier_uniform_(self.head.weight)
        nn.init.zeros_(self.head.bias)

    def forward(self, tok, padding_mask):
        """Compute move logits for a batch of tokenised positions.

        Args:
            tok: Integer tensor of shape ``(B, T, 10)`` with ``T <= MAX_TOKENS``.
            padding_mask: Bool tensor of shape ``(B, T)``; True marks a real
                token, False marks padding (it is inverted before being handed
                to the encoder as ``src_key_padding_mask``).

        Returns:
            Tensor of shape ``(B, 2, 8, 8)`` holding raw logits.
        """
        B, T, _ = tok.shape
        e = self.embed_val(tok)  # (B, T, 10, EMBED_DIM)
        e = e.view(B, T, -1)  # concatenate the 10 field embeddings per token
        z = self.col_linear(e)

        cls = self.cls.expand(B, -1, -1)
        z = torch.cat([cls, z], dim=1) + self.pos[:, :T+1]
        
        # Encoder mask convention is True = ignore; the prepended CLS slot is never ignored
        pad = torch.cat([torch.zeros(B, 1, device=z.device, dtype=torch.bool), ~padding_mask], dim=1)
        z = self.enc(z, src_key_padding_mask=pad)
        cls_out = self.norm(z[:, 0])  # only the CLS position feeds the head
        logits = self.head(cls_out).view(B, 2, 8, 8)
        return logits

# ────────────────────────────────────────────────────────────────────────────
# 4.  Legal‑move masking utilities (cached for performance)
# ────────────────────────────────────────────────────────────────────────────
@functools.lru_cache(maxsize=20000)  # Increased cache size
def compute_single_legal_mask(pos_tuple):
    """Build legal from-square and destination-square masks for one position.

    Args:
        pos_tuple: Hashable tuple of tokens, each a 10-tuple
            ``(pid, r, f, stm_bit, wK, wQ, bK, bQ, ep, half)``. ``pid`` indexes
            ``"PNBRQKpnbrqk"``, ``r``/``f`` are 0-based rank and file, and
            ``stm_bit == 0`` means White to move. Castling flags and the
            en-passant file (8 = none) are read from the first token. A token
            with ``pid == r == f == 0`` is treated as padding.

    Returns:
        ``(mask_from, mask_dest)``: two ``(8, 8)`` bool arrays indexed
        ``[rank, file]``, True where some legal move starts / ends. The arrays
        are cached and shared between calls, so treat them as read-only. An
        empty ``pos_tuple`` yields all-False masks.
    """
    mask_from = np.zeros((8, 8), dtype=bool)
    mask_dest = np.zeros((8, 8), dtype=bool)
    if not pos_tuple:
        # No tokens at all: nothing to place and no first token to read flags from.
        return mask_from, mask_dest

    board = chess.Board.empty()
    stm = None
    for pid, r, f, stm_bit, _wK, _wQ, _bK, _bQ, _ep, _half in pos_tuple:
        if pid == 0 and r == 0 and f == 0:
            continue  # padding row
        piece = chess.Piece.from_symbol("PNBRQKpnbrqk"[pid])
        board.set_piece_at(chess.square(f, r), piece)
        stm = stm_bit  # the last real token decides the side to move

    board.turn = chess.WHITE if stm == 0 else chess.BLACK

    _, _, _, _, w_king, w_queen, b_king, b_queen, ep_file, _ = pos_tuple[0]
    board.castling_rights = (
        (chess.BB_H1 if w_king else 0) |
        (chess.BB_A1 if w_queen else 0) |
        (chess.BB_H8 if b_king else 0) |
        (chess.BB_A8 if b_queen else 0)
    )

    if ep_file != 8:
        # The en-passant target lies behind the pawn that just advanced two
        # squares: rank index 5 when White is to move (stm == 0), rank index 2
        # when Black is to move. The original had these two ranks swapped, so
        # en-passant captures were never generated.
        board.ep_square = chess.square(ep_file, 2 if stm else 5)

    for mv in board.legal_moves:
        # Square index = rank * 8 + file
        r_from, f_from = divmod(mv.from_square, 8)
        r_to, f_to = divmod(mv.to_square, 8)
        mask_from[r_from, f_from] = True
        mask_dest[r_to, f_to] = True

    return mask_from, mask_dest

def legal_masks(pos_batch):
    """Compute legal-move masks for every position in a batch.

    Each position is converted to a hashable tuple and passed to the cached
    ``compute_single_legal_mask``. The per-position results are stacked on the
    host and moved to ``DEVICE`` in one transfer per mask, instead of one small
    copy per row.

    Args:
        pos_batch: Integer tensor (or array) of shape ``(B, T, 10)`` that
            supports ``.shape`` and ``.tolist()``.

    Returns:
        ``(mask_from, mask_dest)``: bool tensors of shape ``(B, 8, 8)`` on
        ``DEVICE``.
    """
    B = pos_batch.shape[0]
    if B == 0:
        # np.stack rejects an empty list; return empty masks directly.
        mask_from = torch.zeros((0, 8, 8), dtype=torch.bool, device=DEVICE)
        return mask_from, torch.zeros_like(mask_from)

    from_rows, dest_rows = [], []
    # One bulk tolist() yields plain Python ints, which keeps cache keys hashable.
    for tokens in pos_batch.tolist():
        mf, md = compute_single_legal_mask(tuple(map(tuple, tokens)))
        from_rows.append(mf)
        dest_rows.append(md)

    # np.stack copies, so the cached per-position arrays are never aliased.
    mask_from = torch.from_numpy(np.stack(from_rows)).to(DEVICE)
    mask_dest = torch.from_numpy(np.stack(dest_rows)).to(DEVICE)
    return mask_from, mask_dest

# ────────────────────────────────────────────────────────────────────────────
# 5.  Optimized Training Loop with Mixed Precision
# ────────────────────────────────────────────────────────────────────────────
def run_epoch(loader, train=True):
    """Run one pass over ``loader`` and return mean loss and top-1 accuracy.

    Uses the module-level ``model``, ``opt`` and ``criterion``. The element-wise
    loss is zeroed outside the legal from/destination masks before averaging.
    A sample counts as correct only when both the predicted from-square and
    the predicted destination-square match the target.

    Fixes over the previous version:
      * validation runs with autograd disabled, so no graph is kept alive;
      * the loss is divided by ``GRAD_ACCUM_STEPS`` on every training path
        (previously only under AMP, while the metrics always multiplied back);
      * a trailing, incomplete accumulation window is stepped at the end of
        the epoch instead of leaking gradients into the next one.

    Args:
        loader: Iterable of batches from ``collate_fn`` (3 or 5 tensors each).
        train: True to update parameters, False to only evaluate.

    Returns:
        ``(mean_loss_per_sample, top1_accuracy)``.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.train(mode=train)
    total_loss, correct_top1, count = 0.0, 0, 0

    desc = "Training" if train else "Validation"
    pbar = tqdm(loader, desc=desc, leave=False)

    use_amp = USE_AMP and train
    # NOTE(review): the scaler is rebuilt every epoch, so its dynamic loss
    # scale restarts from the default each time.
    scaler = torch.cuda.amp.GradScaler() if use_amp else None
    # Dividing keeps accumulated gradients an average over the micro-batches.
    loss_div = GRAD_ACCUM_STEPS if train else 1

    def _optimizer_step():
        """Clip gradients, apply one optimizer update, then clear gradients."""
        if scaler is not None:
            scaler.unscale_(opt)  # clip on true gradient magnitudes
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(opt)
            scaler.update()
        else:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
        opt.zero_grad(set_to_none=True)

    pending_grads = False
    for batch_idx, batch_data in enumerate(pbar):
        if len(batch_data) == 5:  # Precomputed masks
            pos, pad_mask, y_true, mask_from, mask_dest = batch_data
            mask_from = mask_from.to(DEVICE, non_blocking=True)
            mask_dest = mask_dest.to(DEVICE, non_blocking=True)
        else:  # Compute masks on the fly (needs the CPU copy of pos)
            pos, pad_mask, y_true = batch_data
            mask_from, mask_dest = legal_masks(pos)

        pos = pos.to(DEVICE, non_blocking=True)
        pad_mask = pad_mask.to(DEVICE, non_blocking=True)
        y_true = y_true.to(DEVICE, non_blocking=True)

        # Autograd only while training; autocast only for AMP training, so
        # validation numerics stay full precision as before.
        with torch.set_grad_enabled(train), torch.cuda.amp.autocast(enabled=use_amp):
            logits = model(pos, pad_mask)
            loss_mat = criterion(logits, y_true)
            loss_mat[:, 0] *= mask_from
            loss_mat[:, 1] *= mask_dest
            loss = loss_mat.mean() / loss_div

        if train:
            if scaler is not None:
                scaler.scale(loss).backward()
            else:
                loss.backward()
            pending_grads = True

            # Gradient accumulation: step once per full window
            if (batch_idx + 1) % GRAD_ACCUM_STEPS == 0:
                _optimizer_step()
                pending_grads = False

        # Metrics calculation (undo the accumulation scaling for reporting)
        batch_loss = loss.item() * loss_div
        total_loss += batch_loss * pos.size(0)

        with torch.no_grad():
            pred_from = logits[:, 0].flatten(1).argmax(1)
            pred_dest = logits[:, 1].flatten(1).argmax(1)
            true_from = y_true[:, 0].flatten(1).argmax(1)
            true_dest = y_true[:, 1].flatten(1).argmax(1)
            batch_correct = ((pred_from == true_from) & (pred_dest == true_dest)).sum().item()

        correct_top1 += batch_correct
        count += pos.size(0)

        pbar.set_postfix({
            'loss': f"{batch_loss:.4f}",
            'acc': f"{100 * batch_correct / pos.size(0):.1f}%",
            'mem': f"{torch.cuda.memory_allocated() / 1e9:.1f}GB"
        })

    if train and pending_grads:
        # The batch count was not a multiple of GRAD_ACCUM_STEPS.
        _optimizer_step()

    if count == 0:
        raise ValueError(f"{desc} loader produced no samples")

    return total_loss / count, correct_top1 / count

# ────────────────────────────────────────────────────────────────────────────
# 6.  Main Training Loop
# ────────────────────────────────────────────────────────────────────────────
# Script entry point: build the model, optimizer and loaders, then train for
# EPOCHS epochs and checkpoint whenever the validation loss improves.
# `model`, `opt` and `criterion` stay module-level because run_epoch() reads
# them as globals.
if __name__ == "__main__":
    print(f"Using device: {DEVICE}")
    print(f"PyTorch version: {torch.__version__}")
    if torch.cuda.is_available():
        print(f"CUDA device: {torch.cuda.get_device_name()}")
        print(f"CUDA memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f}GB")

    # Sorted so the sample order does not depend on filesystem enumeration
    # order; fail fast with a clear message when the dataset is missing.
    pt_files = sorted(glob.glob(os.path.join(DATA_DIR, "*.pt")))
    if not pt_files:
        raise FileNotFoundError(f"No .pt files found in {DATA_DIR!r}")

    torch.manual_seed(42)  # For reproducibility

    # Initialize model and optimizer
    model = TokenViT().to(DEVICE)

    opt = torch.optim.AdamW(
        model.parameters(),
        lr=LR,
        weight_decay=1e-4,
        betas=(0.9, 0.95),
        eps=1e-8,
    )

    # Cosine decay from LR to LR * 0.1 over the whole run, stepped once per epoch
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=EPOCHS, eta_min=LR*0.1)

    # Element-wise loss so run_epoch can apply the legal-move masks before averaging
    criterion = nn.BCEWithLogitsLoss(reduction="none")

    train_ds = PTBatchDataset(pt_files, train=True)
    val_ds = PTBatchDataset(pt_files, train=False)

    # Settings shared by both loaders; only shuffling differs
    loader_kwargs = dict(
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
        collate_fn=collate_fn,
        pin_memory=True,
        persistent_workers=True,
        prefetch_factor=PREFETCH_FACTOR,
    )
    train_ld = DataLoader(train_ds, shuffle=True, **loader_kwargs)
    val_ld = DataLoader(val_ds, shuffle=False, **loader_kwargs)

    print(f"Training samples: {len(train_ds)}")
    print(f"Validation samples: {len(val_ds)}")
    print(f"Effective batch size: {BATCH_SIZE * GRAD_ACCUM_STEPS}")

    best_val = math.inf
    for epoch in range(1, EPOCHS+1):
        tr_loss, tr_acc = run_epoch(train_ld, train=True)
        vl_loss, vl_acc = run_epoch(val_ld, train=False)

        scheduler.step()  # Update learning rate
        current_lr = scheduler.get_last_lr()[0]

        print(f"Epoch {epoch:02d} | "
              f"train loss {tr_loss:.4f} acc {tr_acc*100:.1f}% || "
              f"val loss {vl_loss:.4f} acc {vl_acc*100:.1f}% || "
              f"lr {current_lr:.2e}")

        # Keep only the checkpoint with the lowest validation loss so far
        if vl_loss < best_val:
            best_val = vl_loss
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': opt.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'epoch': epoch,
                'best_val_loss': best_val,
                'config': {
                    'EMBED_DIM': EMBED_DIM,
                    'DEPTH': DEPTH,
                    'N_HEADS': N_HEADS,
                    'MAX_TOKENS': MAX_TOKENS
                }
            }, SAVE_PATH)
            print(f"  ✔  saved checkpoint to {SAVE_PATH}")

    print("Training completed!")