# KiyEngine V3: Modern PyTorch Training

**Disclaimer:** This Jupyter Notebook is designed for **single-GPU experimentation, visualization, and debugging**.

For scalable, multi-GPU training, please use the `train.py` script, which is optimized for `DistributedDataParallel` (DDP).

This notebook provides an advanced training pipeline for the KiyEngine V3 MoE-Mamba model. It incorporates several modern PyTorch features to accelerate training, improve stability, and provide better insights into the training process.

**Key Features:**
- **Configuration Driven**: All hyperparameters are loaded from `config.yaml`.
- **JIT Compilation**: Uses `torch.compile` for significant speedups on compatible GPUs.
- **Automatic Mixed Precision (AMP)**: Leverages `fp16` to reduce memory usage and accelerate training.
- **Gradient Clipping**: Prevents exploding gradients for more stable training.
- **Learning Rate Scheduling**: Implements a cosine annealing schedule for better convergence.
- **Live Loss Plotting**: Visualizes the training loss in real-time.
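
To make the cosine annealing schedule concrete, here is a minimal sketch of the closed-form curve that a cosine schedule follows when it is stepped once per batch. The function name `cosine_lr` and the values used (a base rate of `1e-3` over 100 steps) are illustrative assumptions, not values taken from `config.yaml`.

```python
import math


def cosine_lr(step: int, total_steps: int, base_lr: float, min_lr: float = 0.0) -> float:
    """Closed-form cosine annealing from base_lr (step 0) down to min_lr (total_steps)."""
    progress = step / total_steps
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


# Hypothetical values, for illustration only.
schedule = [cosine_lr(s, total_steps=100, base_lr=1e-3) for s in range(101)]
print(f"start={schedule[0]:.6f}  middle={schedule[50]:.6f}  end={schedule[100]:.6f}")
```

The rate falls slowly at first, fastest around the midpoint, and flattens out again near the end, which is why the schedule length should match the total number of optimizer steps.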

## 1. Setup and Imports

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.cuda.amp import GradScaler, autocast
import yaml
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
from IPython.display import display, clear_output

# Import from our local files
from model import KiyEngineV3
from dataset import ChessDataset
from train import save_model_as_safetensors  # Re-use the saving function

## 2. Configuration

In [None]:
def load_config(config_path='config.yaml'):
    with open(config_path, 'r') as f:
        config = yaml.safe_load(f)
    return config

config = load_config()
print("Configuration loaded successfully.")
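
The training function in this notebook indexes a fixed set of keys in the loaded configuration. As a rough sketch, the check below lists those keys (taken from how `config` is indexed in the training code) and reports any that are absent before a long run starts. The helper name `missing_config_keys` and the sample values are hypothetical, and the real `config.yaml` may contain additional sections.

```python
# Keys read by the training code in this notebook, grouped by top-level section.
REQUIRED_KEYS = {
    "model": [],
    "training": [
        "noise_sigma", "batch_size", "learning_rate", "epochs",
        "policy_weight", "value_weight", "aux_loss_lambda",
    ],
    "paths": ["train_data_path", "save_path", "model_save_name"],
}


def missing_config_keys(config: dict) -> list:
    """Return dotted names of required sections/keys that are absent (hypothetical helper)."""
    missing = []
    for section, keys in REQUIRED_KEYS.items():
        if section not in config:
            missing.append(section)
            continue
        section_values = config[section] or {}  # an empty YAML section loads as None
        missing.extend(f"{section}.{key}" for key in keys if key not in section_values)
    return missing


# Illustrative config with made-up values: one training key and the whole
# "paths" section are left out on purpose.
sample_config = {
    "model": {},
    "training": {
        "noise_sigma": 0.01, "batch_size": 32, "learning_rate": 1e-3,
        "epochs": 1, "policy_weight": 1.0, "value_weight": 1.0,
    },
}
problems = missing_config_keys(sample_config)
print(problems)
```

Calling `missing_config_keys(config)` right after `load_config()` surfaces a misspelled or missing key immediately instead of as a `KeyError` partway through setup.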

## 3. The Modern Training Loop

In [None]:
def train_modern(config: dict):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_amp = device.type == "cuda"  # fp16 autocast and loss scaling only apply on CUDA
    print(f"Using device: {device}")

    # --- Model Initialization ---
    model_config = config['model']
    model_config['noise_sigma'] = config['training']['noise_sigma']
    model = KiyEngineV3(model_config).to(device)

    # --- JIT Compilation (PyTorch 2.0+) ---
    if hasattr(torch, 'compile'):
        print("Compiling the model with torch.compile for a speedup...")
        model = torch.compile(model, mode="reduce-overhead")

    # --- Data Loading ---
    train_path = config['paths']['train_data_path']
    if not os.path.exists(train_path):
        print("Warning: Training data not found. Using a dummy dataset.")
        with open("dummy_dataset.pgn", "w") as f:
            f.write('[Event "Dummy"]\n[Result "1-0"]\n\n1. e4 e5 *')
        train_path = "dummy_dataset.pgn"
    dataset = ChessDataset(pgn_file_path=train_path)
    dataloader = DataLoader(dataset, batch_size=config['training']['batch_size'], shuffle=True, num_workers=4, pin_memory=True)

    # --- Optimizer, Scheduler, and Loss ---
    optimizer = optim.Adam(model.parameters(), lr=config['training']['learning_rate'])
    # The scheduler is stepped once per batch, so T_max is the total number of steps
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(dataloader) * config['training']['epochs'])
    policy_loss_fn = nn.CrossEntropyLoss()
    value_loss_fn = nn.MSELoss()
    scaler = GradScaler(enabled=use_amp)  # For Automatic Mixed Precision

    # --- Training Loop ---
    model.train()
    losses = []
    plt.ion()  # Interactive mode for live plotting
    fig, ax = plt.subplots()

    for epoch in range(config['training']['epochs']):
        pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{config['training']['epochs']}")
        for batch in pbar:
            input_seq, policy_target, value_target = [b.to(device) for b in batch]

            optimizer.zero_grad()

            # Automatic Mixed Precision context
            with autocast(enabled=use_amp):
                policy_logits, value_pred, aux_loss = model(input_seq)
                policy_loss = policy_loss_fn(policy_logits, policy_target)
                value_loss = value_loss_fn(value_pred.squeeze(), value_target.squeeze())
                loss = (config['training']['policy_weight'] * policy_loss +
                        config['training']['value_weight'] * value_loss +
                        config['training']['aux_loss_lambda'] * aux_loss)

            scaler.scale(loss).backward()
            # Unscale first so the clipping threshold applies to the true gradients
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # Gradient clipping
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

            loss_value = loss.item()
            losses.append(loss_value)
            pbar.set_postfix({"Loss": f"{loss_value:.4f}", "LR": f"{scheduler.get_last_lr()[0]:.6f}"})

            # Live plotting: redraw the loss curve every 20 steps
            if len(losses) % 20 == 0:
                ax.clear()
                ax.plot(losses)
                ax.set_title("Training Loss")
                ax.set_xlabel("Step")
                ax.set_ylabel("Loss")
                clear_output(wait=True)
                display(fig)

    print("\nTraining finished.")
    plt.ioff()

    # --- Save Final Model ---
    save_dir = config['paths']['save_path']
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, config['paths']['model_save_name'])
    # torch.compile wraps the model; save the underlying module so parameter
    # names are not prefixed with "_orig_mod."
    save_model_as_safetensors(getattr(model, "_orig_mod", model), save_path)
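
Gradient clipping by global norm rescales all gradients by one shared factor so that their combined L2 norm does not exceed the threshold. The pure-Python sketch below mirrors that idea on plain lists; it is not PyTorch's implementation, and `clip_by_global_norm` and the numbers are illustrative assumptions. It also shows why, under mixed precision, gradients multiplied by the loss scale should be unscaled before clipping (with `GradScaler`, that is `scaler.unscale_(optimizer)`): clipping the scaled gradients and unscaling afterwards leaves a far smaller update than intended.

```python
import math


def clip_by_global_norm(grads, max_norm):
    """Scale all gradients by one factor so their joint L2 norm is at most max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    scale = min(1.0, max_norm / (total_norm + 1e-6))
    return [g * scale for g in grads], total_norm


true_grads = [0.3, -0.4]   # L2 norm 0.5, already below the threshold of 1.0
loss_scale = 1024.0        # illustrative loss-scaling factor
scaled = [g * loss_scale for g in true_grads]

# Intended order: unscale, then clip. The gradients pass through unchanged.
unscaled = [g / loss_scale for g in scaled]
clipped_ok, norm_ok = clip_by_global_norm(unscaled, max_norm=1.0)

# Wrong order: clip the scaled gradients, then unscale. The update is crushed.
clipped_scaled, norm_scaled = clip_by_global_norm(scaled, max_norm=1.0)
clipped_bad = [g / loss_scale for g in clipped_scaled]
bad_norm = math.sqrt(sum(g * g for g in clipped_bad))

print("unscale then clip:", clipped_ok)
print("clip then unscale:", clipped_bad)
```

In the wrong order the threshold is compared against a norm inflated by the loss scale, so nearly every step is clipped regardless of the true gradient size.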

## 4. Run Training

In [None]:
if __name__ == '__main__':
    train_modern(config)