# GENERAL: Strategic Game AI Training

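This notebook trains the GENERAL game-playing agent on a GPU using AlphaZero-style reinforcement learning. MCTS-guided self-play generates training positions, and the network is retrained on a replay buffer. Each newly trained model is then evaluated in an arena against the previously accepted model.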

## Requirements
- Google Colab with GPU runtime (T4, V100, or A100)
- Python 3.8+
- PyTorch with CUDA support

## 1. Environment Setup

In [None]:
# Clone the project and move into the package directory that the imports in section 4 expect.
# NOTE: replace YOUR_USERNAME with the GitHub account that hosts Generals-AI before running.
!git clone https://github.com/YOUR_USERNAME/Generals-AI.git
%cd Generals-AI/Generals

In [None]:
# Install the Python dependencies used by this notebook (PyTorch and NumPy).
!pip install torch numpy --quiet

## 2. GPU Verification

In [None]:
import torch

# Report the framework version and whether a CUDA device is visible to PyTorch.
print(f"PyTorch Version: {torch.__version__}")
print(f"CUDA Available: {torch.cuda.is_available()}")

# Handle the CPU-only case first; otherwise describe the first CUDA device.
if not torch.cuda.is_available():
    print("WARNING: No GPU detected. Training will be slow.")
    print("Go to Runtime -> Change runtime type -> GPU")
else:
    print(f"CUDA Version: {torch.version.cuda}")
    print(f"Device Name: {torch.cuda.get_device_name(0)}")
    print(f"Device Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

## 3. Configuration

In [None]:
# Hyperparameters for the self-play / train / evaluate loop in section 5.
TRAINING_CONFIG = dict(
    max_iterations=100,          # number of outer training iterations
    games_per_iter=64,           # self-play games generated per iteration
    mcts_simulations=100,        # MCTS simulations, used by both SelfPlay and Arena
    train_epochs=5,              # Trainer epochs per iteration
    batch_size=64,
    learning_rate=5e-4,
    eval_games=20,               # arena games: latest model vs. last accepted model
    acceptance_threshold=0.55,   # latest is accepted when its win rate exceeds this
)

print("Training Configuration:")
print("\n".join(f"  {name}: {setting}" for name, setting in TRAINING_CONFIG.items()))

## 4. Import and Validate Modules

In [None]:
import sys
# Put the current working directory (the repo's Generals/ folder after the
# %cd in section 1) first on the import path so the project packages below resolve.
sys.path.insert(0, '.')

# Import the project components; this cell doubles as a check that the
# package layout is importable before any training starts.
from env.generals_env import GeneralsEnv
from model.network import GeneralsNet
from mcts.mcts import AsyncMCTS
from training.train import Trainer
from training.replay_buffer import ReplayBuffer
from selfplay.selfplay import SelfPlay
from evaluate.evaluate import Arena
from utils.batched_inference import InferenceServer

print("All modules imported successfully.")

In [None]:
# Smoke-test the environment: reset it once and show the observation shape.
env = GeneralsEnv()
state = env.reset()
print(f"Environment initialized. State shape: {state.shape}")

# Build the network on the GPU when one is visible, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
net = GeneralsNet().to(device)
param_count = sum(tensor.numel() for tensor in net.parameters())
print(f"Network initialized on {device}. Parameters: {param_count:,}")

## 5. Training Execution

In [None]:
import asyncio
import os
import shutil
import time
from pathlib import Path
import numpy as np

# Everything the training loop writes lives under <project root>/data.
PROJECT_ROOT = Path(".").resolve()
CHECKPOINT_DIR, REPLAY_DIR = (
    PROJECT_ROOT / "data" / subdir for subdir in ("checkpoints", "replay")
)

# parents=True also creates data/; exist_ok=True makes re-running the cell safe.
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
REPLAY_DIR.mkdir(parents=True, exist_ok=True)

print(f"Checkpoint directory: {CHECKPOINT_DIR}")
print(f"Replay directory: {REPLAY_DIR}")

In [None]:
async def training_loop(config):
    """Run the AlphaZero-style self-play -> train -> evaluate loop.

    Each iteration:
      1. generates ``config["games_per_iter"]`` self-play games through the
         shared batched ``InferenceServer``;
      2. adds the new positions to the on-disk replay buffer and trains on
         everything the buffer returns, saving ``model_latest.pth``;
      3. reloads ``model_latest.pth`` into the inference server;
      4. plays ``model_latest.pth`` against ``model_old.pth`` in an ``Arena``
         and copies latest over old when the win rate strictly exceeds
         ``config["acceptance_threshold"]``.

    A rejected model is not rolled back. The inference server keeps serving
    ``model_latest.pth`` and the same ``Trainer`` is used next iteration. Only
    the arena baseline (``model_old.pth``) is gated.

    Args:
        config: mapping with the keys of ``TRAINING_CONFIG``
            (``max_iterations``, ``games_per_iter``, ``mcts_simulations``,
            ``train_epochs``, ``batch_size``, ``learning_rate``,
            ``eval_games``, ``acceptance_threshold``).

    Once started, the inference server is stopped in a ``finally`` block.
    A ``KeyboardInterrupt`` raised inside the loop ends training early
    without propagating.
    """
    latest_path = CHECKPOINT_DIR / "model_latest.pth"  # most recently trained weights
    old_path = CHECKPOINT_DIR / "model_old.pth"  # last accepted weights (arena baseline)

    trainer = Trainer(
        lr=config["learning_rate"],
        weight_decay=1e-4,
        batch_size=config["batch_size"],
        epochs=config["train_epochs"],
        checkpoint_dir=str(CHECKPOINT_DIR),
    )

    # The arena needs a baseline checkpoint, so make sure model_old.pth exists.
    if not old_path.exists():
        if latest_path.exists():
            shutil.copyfile(latest_path, old_path)
            print("Copied model_latest to model_old")
        else:
            print("Creating initial model...")
            # Bootstrap a first checkpoint by running Trainer.train on a tiny dummy batch.
            # NOTE(review): the (17, 10, 10) state shape and 10003-way policy presumably
            # must match GeneralsEnv / GeneralsNet — TODO confirm or derive them.
            # NOTE(review): the targets put all policy mass on action 0 with value 0, so
            # this presumably nudges the initial weights toward that output; saving
            # untrained weights directly would avoid the bias if Trainer offers it.
            states = np.zeros((8, 17, 10, 10), dtype=np.float32)
            policies = np.zeros((8, 10003), dtype=np.float32)
            policies[:, 0] = 1.0
            values = np.zeros((8,), dtype=np.float32)
            trainer.train(states, policies, values, save_name="model_latest.pth")
            shutil.copyfile(latest_path, old_path)
            print("Initial model created.")

    replay = ReplayBuffer(save_dir=str(REPLAY_DIR), max_batches=20)
    inference_server = InferenceServer(trainer.net, batch_size=32)
    await inference_server.start()

    try:
        for iteration in range(1, config["max_iterations"] + 1):
            print(f"\n{'='*60}")
            print(f"ITERATION {iteration}/{config['max_iterations']}")
            print(f"{'='*60}")

            sp = SelfPlay(
                GeneralsEnv,
                inference_server,
                games_per_iteration=config["games_per_iter"],
                mcts_simulations=config["mcts_simulations"],
                temperature_threshold=10,
            )

            print(f"Generating {config['games_per_iter']} self-play games...")
            start_time = time.time()
            states, policies, values = await sp.play_iteration()
            elapsed = time.time() - start_time
            print(f"Generated {len(states)} positions in {elapsed:.1f}s")

            replay.add_game(states, policies, values)
            states_all, policies_all, values_all = replay.load_all()

            print(f"Training on {len(states_all)} total positions...")
            trainer.train(states_all, policies_all, values_all, save_name="model_latest.pth")

            # Serve the freshly trained weights for the next round of self-play.
            inference_server.reload_model(str(latest_path))

            print(f"Evaluating new model ({config['eval_games']} games)...")
            arena = Arena(
                model_A_path=str(latest_path),
                model_B_path=str(old_path),
                games=config["eval_games"],
                mcts_simulations=config["mcts_simulations"],
            )
            win_rate = await arena.run()

            if win_rate > config["acceptance_threshold"]:
                print(f"Model ACCEPTED (win rate: {win_rate:.2%})")
                shutil.copyfile(latest_path, old_path)
            else:
                print(f"Model REJECTED (win rate: {win_rate:.2%})")

            # Pause without blocking: time.sleep() inside a coroutine would freeze the
            # event loop, including any background tasks the inference server runs.
            await asyncio.sleep(0.5)

    except KeyboardInterrupt:
        print("\nTraining interrupted by user.")
    finally:
        await inference_server.stop()
        print("Training loop completed.")

In [None]:
# Top-level `await` relies on IPython/Jupyter's autoawait support (it is not valid in a plain .py script).
# This cell blocks until every iteration finishes; interrupt the kernel to stop early.
await training_loop(TRAINING_CONFIG)

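## 6. Download Trained Model

This downloads `model_latest.pth`, the most recently trained weights. The arena may have rejected those weights; the last accepted model is `model_old.pth`, in the same checkpoint directory.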

In [None]:
from google.colab import files

# Offer the most recently trained checkpoint as a browser download (Colab only).
model_path = CHECKPOINT_DIR / "model_latest.pth"
if not model_path.exists():
    print("No model found. Run training first.")
else:
    files.download(str(model_path))
    print(f"Downloaded: {model_path}")

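## 7. Training Monitoring

The training cell blocks the kernel until it finishes, so run this cell after training completes or has been interrupted.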

In [None]:
import os

def show_training_status(checkpoint_dir=None, replay_dir=None):
    """Print a summary of saved checkpoints, replay batches and GPU memory.

    Args:
        checkpoint_dir: directory scanned for ``*.pth`` checkpoints; defaults
            to the notebook's ``CHECKPOINT_DIR``.
        replay_dir: directory scanned for ``*.npz`` replay batches; defaults
            to the notebook's ``REPLAY_DIR``.
    """
    from pathlib import Path

    checkpoint_dir = Path(CHECKPOINT_DIR if checkpoint_dir is None else checkpoint_dir)
    replay_dir = Path(REPLAY_DIR if replay_dir is None else replay_dir)

    print("=" * 50)
    print("TRAINING STATUS")
    print("=" * 50)

    if checkpoint_dir.exists():
        # glob() order is filesystem-dependent; sort for a stable listing.
        checkpoints = sorted(checkpoint_dir.glob("*.pth"))
        print(f"\nCheckpoints: {len(checkpoints)}")
        for cp in checkpoints:
            size_mb = cp.stat().st_size / (1024 * 1024)
            print(f"  - {cp.name}: {size_mb:.2f} MB")
    else:
        print(f"\nCheckpoint directory not found: {checkpoint_dir}")

    if replay_dir.exists():
        replays = list(replay_dir.glob("*.npz"))
        print(f"\nReplay batches: {len(replays)}")
        if replays:
            total_size = sum(r.stat().st_size for r in replays) / (1024 * 1024)
            print(f"  Total size: {total_size:.2f} MB")
    else:
        print(f"\nReplay directory not found: {replay_dir}")

    if torch.cuda.is_available():
        print("\nGPU Memory:")
        print(f"  Allocated: {torch.cuda.memory_allocated(0) / 1e9:.2f} GB")
        # memory_reserved() (formerly memory_cached()) is what the caching allocator holds.
        print(f"  Reserved: {torch.cuda.memory_reserved(0) / 1e9:.2f} GB")

# Print a snapshot of saved checkpoints, replay batches and GPU memory usage.
show_training_status()