# Chess RL Agent - Training

**Strategy:** Split into safe 5-6 hour sessions (no browser babysitting needed)

**Workflow:**
1. Run cells 1-6 once → Auto-trains or resumes (~5-6 hours per session)
2. Training completes, checkpoints auto-backup to Drive
3. Close browser, take a break
4. Next session: Re-run cells 1-6 → Automatically resumes where you left off

**Total time:** 2 sessions × 5-6 hours = 10-12 hours for proof of concept
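The session arithmetic above can be sketched in plain Python. This is a minimal illustration, not code from the repository. `session_schedule` is a hypothetical helper that mirrors cell 6's `min(SESSION_ITERATIONS, remaining)` rule. The 5-6 hours per session is this page's estimate, not a measurement.

```python
def session_schedule(total_target, session_iterations):
    """Split total_target iterations into sessions of at most session_iterations each.

    Hypothetical helper: mirrors cell 6, which runs min(SESSION_ITERATIONS, remaining).
    """
    if session_iterations < 1:
        raise ValueError("session_iterations must be at least 1")
    schedule = []
    completed = 0
    while completed < total_target:
        this_session = min(session_iterations, total_target - completed)
        schedule.append(this_session)
        completed += this_session
    return schedule


# Defaults from cell 5: TOTAL_TARGET = 10, SESSION_ITERATIONS = 5
schedule = session_schedule(10, 5)

# Assumed 5-6 hours per session (the estimate above, not a measured figure)
HOURS_PER_SESSION = (5, 6)
low = len(schedule) * HOURS_PER_SESSION[0]
high = len(schedule) * HOURS_PER_SESSION[1]
print(f"{len(schedule)} sessions {schedule}: roughly {low}-{high} hours")
# → 2 sessions [5, 5]: roughly 10-12 hours
```

If the target is not a multiple of the session size, the last session is shorter. For example, `session_schedule(12, 5)` gives `[5, 5, 2]`. The hour range above treats every session as full length, so it overestimates slightly in that case.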

## 1. Verify GPU

In [None]:
import torch

print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'None'}")
print(f"CUDA: {torch.cuda.is_available()}")

if not torch.cuda.is_available():
    print("\n⚠ Go to Runtime → Change runtime type → Select GPU (L4 recommended)")
else:
    print("\n✓ GPU ready")

## 2. Mount Google Drive

In [None]:
from google.colab import drive

drive.mount('/content/drive')
!mkdir -p /content/drive/MyDrive/chess_checkpoints

print("✓ Drive mounted")

## 3. Clone Repository

In [None]:
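# Start from /content so a re-run doesn't clone inside the previous checkout.
# Note: this also deletes local checkpoints under rl_chess_agent/checkpoints (Drive backups are kept).
%cd /content
!rm -rf rl_chess_agent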
!git clone https://github.com/capacap/rl_chess_agent.git
%cd rl_chess_agent

print("✓ Repository cloned")

## 4. Install Dependencies

In [None]:
# Install chess library (Colab has torch, numpy, etc.)
!pip install -q -r requirements-colab.txt

# Verify imports
import chess
from model.network import ChessNet

print(f"✓ Dependencies installed")
print(f"  chess: {chess.__version__}")
print(f"  torch: {torch.__version__}")

## 5. Configure Training

In [None]:
# === Project Configuration ===
# Set this once and keep it the same across all sessions
PROJECT_NAME = "proof_of_concept"  # Change this for different training runs

# === Session Configuration ===
SESSION_ITERATIONS = 5    # Iterations per session (~1.2-1.5 h each, per the estimate printed below)
TOTAL_TARGET = 10         # Total iterations you want overall

# Training parameters
GAMES_PER_ITER = 50       # Games per iteration
SIMULATIONS = 20          # MCTS simulations per move
ARENA_GAMES = 20          # Arena evaluation games

# Advanced (rarely need to change)
BATCH_SIZE = 256
EPOCHS = 5
LEARNING_RATE = 1e-3

# Checkpoint directories
CHECKPOINT_DIR = f"checkpoints/{PROJECT_NAME}"
GDRIVE_BACKUP = "/content/drive/MyDrive/chess_checkpoints"

print("Training Configuration:")
print(f"  Project: {PROJECT_NAME}")
print(f"  Target: {TOTAL_TARGET} iterations total")
print(f"  Session size: {SESSION_ITERATIONS} iterations (~{SESSION_ITERATIONS * 1.2:.0f}-{SESSION_ITERATIONS * 1.5:.0f} hours)")
print(f"  {GAMES_PER_ITER} games/iter, {SIMULATIONS} MCTS sims")
print(f"\nCheckpoints: {CHECKPOINT_DIR}")
print(f"Drive backup: {GDRIVE_BACKUP}/{PROJECT_NAME}/")

## 6. Train (auto-resumes if checkpoint exists)

In [None]:
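import os
import glob
import shutil

def iteration_number(path):
    # "checkpoints/<project>/iteration_7.pt" -> 7
    return int(os.path.basename(path).split('_')[-1].split('.')[0])

# A new Colab VM starts with no local checkpoints, so copy any Drive backups back first.
# Assumes train.py backs up to {GDRIVE_BACKUP}/{PROJECT_NAME}/, the path printed in cell 5.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
for backup in glob.glob(f"{GDRIVE_BACKUP}/{PROJECT_NAME}/iteration_*.pt"):
    local_copy = os.path.join(CHECKPOINT_DIR, os.path.basename(backup))
    if not os.path.exists(local_copy):
        shutil.copy2(backup, local_copy)

# Check for existing checkpoints (numeric sort: iteration_10 must come after iteration_9)
existing_checkpoints = sorted(glob.glob(f"{CHECKPOINT_DIR}/iteration_*.pt"), key=iteration_number)

if existing_checkpoints:
    # Resume from latest checkpoint
    latest = existing_checkpoints[-1]
    completed = iteration_number(latest)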
    remaining = TOTAL_TARGET - completed
    
    print("="*60)
    print("RESUMING TRAINING")
    print("="*60)
    print(f"  Progress: {completed}/{TOTAL_TARGET} iterations complete")
    print(f"  Resuming from: {os.path.basename(latest)}")
    
    if remaining > 0:
        iterations_this_session = min(SESSION_ITERATIONS, remaining)
        print(f"  This session: {iterations_this_session} iterations (~{iterations_this_session * 1.2:.0f}-{iterations_this_session * 1.5:.0f} hours)")
        print(f"  Remaining after: {remaining - iterations_this_session} iterations\n")
        
        # Resume training
        !python train.py \
          --resume {latest} \
          --iterations {iterations_this_session} \
          --games-per-iter {GAMES_PER_ITER} \
          --simulations {SIMULATIONS} \
          --arena-games {ARENA_GAMES} \
          --batch-size {BATCH_SIZE} \
          --epochs {EPOCHS} \
          --lr {LEARNING_RATE} \
          --checkpoint-dir {CHECKPOINT_DIR} \
          --gdrive-backup-dir {GDRIVE_BACKUP}
        
        new_completed = completed + iterations_this_session
        print(f"\n{'='*60}")
        print(f"SESSION COMPLETE")
        print(f"{'='*60}")
        print(f"  Progress: {new_completed}/{TOTAL_TARGET} iterations")
        print(f"  Checkpoints: {GDRIVE_BACKUP}/{PROJECT_NAME}/")
        
        if new_completed >= TOTAL_TARGET:
            print(f"\n🎉 TRAINING COMPLETE!")
            print(f"  Final model: {CHECKPOINT_DIR}/iteration_{TOTAL_TARGET}.pt")
        else:
            print(f"\nNext steps:")
            print(f"  1. Close browser (checkpoints saved to Drive)")
            print(f"  2. When ready: Re-run cells 1-6 to continue")
            print(f"  3. Remaining: {TOTAL_TARGET - new_completed} iterations")
    else:
        print(f"\n🎉 TRAINING ALREADY COMPLETE!")
        print(f"  Final model: {CHECKPOINT_DIR}/iteration_{TOTAL_TARGET}.pt")
        print(f"  Download from: {GDRIVE_BACKUP}/{PROJECT_NAME}/")

else:
    # Start fresh training
    print("="*60)
    print("STARTING NEW TRAINING")
    print("="*60)
    print(f"  Project: {PROJECT_NAME}")
    print(f"  This session: {SESSION_ITERATIONS} iterations (~{SESSION_ITERATIONS * 1.2:.0f}-{SESSION_ITERATIONS * 1.5:.0f} hours)")
    print(f"  Total target: {TOTAL_TARGET} iterations\n")
    
    # Start training
    !python train.py \
      --iterations {SESSION_ITERATIONS} \
      --games-per-iter {GAMES_PER_ITER} \
      --simulations {SIMULATIONS} \
      --arena-games {ARENA_GAMES} \
      --batch-size {BATCH_SIZE} \
      --epochs {EPOCHS} \
      --lr {LEARNING_RATE} \
      --checkpoint-dir {CHECKPOINT_DIR} \
      --gdrive-backup-dir {GDRIVE_BACKUP}
    
    print(f"\n{'='*60}")
    print(f"SESSION COMPLETE")
    print(f"{'='*60}")
    print(f"  Progress: {SESSION_ITERATIONS}/{TOTAL_TARGET} iterations")
    print(f"  Checkpoints: {GDRIVE_BACKUP}/{PROJECT_NAME}/")
    
    if SESSION_ITERATIONS >= TOTAL_TARGET:
        print(f"\n🎉 TRAINING COMPLETE!")
        print(f"  Final model: {CHECKPOINT_DIR}/iteration_{TOTAL_TARGET}.pt")
    else:
        print(f"\nNext steps:")
        print(f"  1. Close browser (checkpoints saved to Drive)")
        print(f"  2. When ready: Re-run cells 1-6 to continue")
        print(f"  3. Remaining: {TOTAL_TARGET - SESSION_ITERATIONS} iterations")
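The resume decision in cell 6 can be isolated as a pure function, which makes it easy to check without a GPU or a Drive mount. This is a minimal sketch. `plan_resume` and `checkpoint_iteration` are hypothetical names that do not exist in the repository, and the `iteration_<N>.pt` filename pattern is the one cell 6 globs for. Sorting by the parsed number matters: a plain `sorted()` on the paths puts `iteration_10.pt` before `iteration_9.pt`.

```python
import os
import re

# Hypothetical helper: parse N from ".../iteration_N.pt" (the pattern cell 6 globs for)
_CHECKPOINT_RE = re.compile(r"iteration_(\d+)\.pt$")


def checkpoint_iteration(path):
    match = _CHECKPOINT_RE.search(os.path.basename(path))
    return int(match.group(1)) if match else None


def plan_resume(checkpoint_paths, total_target, session_iterations):
    """Return (resume_path or None, completed, iterations_this_session)."""
    numbered = [(checkpoint_iteration(p), p) for p in checkpoint_paths]
    numbered = [(n, p) for n, p in numbered if n is not None]
    if not numbered:
        # No checkpoint: start fresh, capped at the overall target
        return None, 0, min(session_iterations, total_target)
    # Numeric max, so iteration_10 wins over iteration_9
    completed, latest = max(numbered)
    remaining = max(total_target - completed, 0)
    return latest, completed, min(session_iterations, remaining)


paths = ["checkpoints/proof_of_concept/iteration_9.pt",
         "checkpoints/proof_of_concept/iteration_10.pt",
         "checkpoints/proof_of_concept/iteration_2.pt"]
print(sorted(paths)[-1])          # lexicographic pick: .../iteration_9.pt (wrong)
print(plan_resume(paths, 15, 5))  # ('checkpoints/proof_of_concept/iteration_10.pt', 10, 5)
```

A result of `0` iterations for this session corresponds to cell 6's "TRAINING ALREADY COMPLETE" branch. One deliberate difference: the sketch caps a fresh start at `total_target`, while cell 6 always runs `SESSION_ITERATIONS` on a fresh start.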

## Troubleshooting

**Out of memory:** In cell 5, lower `BATCH_SIZE` (e.g. to 128) or `GAMES_PER_ITER` (e.g. to 25)

**Too slow:** In cell 5, lower `SIMULATIONS` (e.g. to 15) or `ARENA_GAMES` (e.g. to 10)

**Start over:** Change `PROJECT_NAME` in cell 5 to create a new training run

**Download checkpoints:** Already in Google Drive at `/MyDrive/chess_checkpoints/{PROJECT_NAME}/`