# GT-DQN Training on Google Colab

This notebook runs the GT-DQN poker training on Google Colab's GPU.

## Setup Steps:
1. Upload this notebook to Google Colab
2. Enable GPU: Runtime → Change runtime type → GPU
3. Run all cells in order

In [ ]:
# Verify GPU is available
import torch

# Query CUDA once and reuse the answer for both the report and the branch.
_cuda_ok = torch.cuda.is_available()
print(f"GPU available: {_cuda_ok}")
if _cuda_ok:
    # Device index 0 is the first visible CUDA device.
    print(f"GPU device: {torch.cuda.get_device_name(0)}")

In [ ]:
# Clone repository and install dependencies
# `!` runs the rest of the line as a shell command; `%cd` changes the
# notebook's own working directory (a plain `!cd` would not persist).
# NOTE(review): re-running this cell after a successful clone will make
# `git clone` fail because the target directory already exists.
!git clone https://github.com/antimaf/GTDQN
%cd GTDQN
# Install the packages listed in the cloned repository's requirements file.
!pip install -r requirements.txt

In [ ]:
# Mount Google Drive for checkpoint saving
from google.colab import drive
drive.mount('/content/drive')

# Create checkpoint directory
# CHECKPOINT_DIR is read by the later training cell when saving and loading
# checkpoints.
CHECKPOINT_DIR = '/content/drive/MyDrive/GTDQN_checkpoints'
# IPython substitutes {CHECKPOINT_DIR} into the shell command; `-p` makes
# mkdir succeed even when the directory already exists.
!mkdir -p {CHECKPOINT_DIR}

In [ ]:
import os
import time
from datetime import datetime

def save_checkpoint(trainer, episode, metrics, checkpoint_dir):
    """Write the trainer's current state to a timestamped ``.pt`` file.

    Args:
        trainer: Object exposing ``policy_net`` and ``optimizer``; the
            ``state_dict()`` of each is stored.
        episode: Episode number recorded in the file and embedded in its name.
        metrics: Arbitrary metrics object stored alongside the weights.
        checkpoint_dir: Existing directory the file is written into.
    """
    # The timestamp keeps repeated saves for the same episode from colliding.
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    path = os.path.join(
        checkpoint_dir,
        f'checkpoint_ep{episode}_{timestamp}.pt',
    )

    torch.save(
        {
            'episode': episode,
            'model_state_dict': trainer.policy_net.state_dict(),
            'optimizer_state_dict': trainer.optimizer.state_dict(),
            'metrics': metrics,
        },
        path,
    )
    print(f"Saved checkpoint at episode {episode} to {path}")

def load_latest_checkpoint(trainer, checkpoint_dir):
    """Restore the trainer from the checkpoint with the highest episode number.

    Args:
        trainer: Object exposing ``policy_net`` and ``optimizer``; both are
            updated in place through ``load_state_dict``.
        checkpoint_dir: Directory holding files named
            ``checkpoint_ep<episode>_<timestamp>.pt`` (the naming used by
            ``save_checkpoint``).

    Returns:
        Tuple ``(episode, metrics)`` read from the chosen checkpoint, or
        ``(0, {})`` when the directory is missing or contains no checkpoint.
    """
    prefix = 'checkpoint_ep'

    def _recency(filename):
        # Plain string sorting ranks "ep10000_" before "ep9000_", so compare
        # the episode number numerically. The full name breaks ties: its
        # YYYYMMDD_HHMMSS stamp orders same-episode saves chronologically.
        episode_text = filename[len(prefix):].split('_', 1)[0]
        episode = int(episode_text) if episode_text.isdecimal() else -1
        return (episode, filename)

    try:
        names = os.listdir(checkpoint_dir)
    except FileNotFoundError:
        # No directory yet means there is nothing to resume from.
        return 0, {}

    checkpoints = [f for f in names if f.startswith(prefix)]
    if not checkpoints:
        return 0, {}

    latest = os.path.join(checkpoint_dir, max(checkpoints, key=_recency))
    print(f"Loading checkpoint: {latest}")
    # Deserialize onto the CPU so a checkpoint written on a GPU runtime can
    # still be read without one; load_state_dict then copies the values onto
    # whatever device the model and optimizer already live on.
    checkpoint = torch.load(latest, map_location='cpu')

    trainer.policy_net.load_state_dict(checkpoint['model_state_dict'])
    trainer.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['episode'], checkpoint['metrics']

In [ ]:
from train import PokerTrainer
import numpy as np

# Initialize trainer
trainer = PokerTrainer(device="cuda")

# Training parameters
TOTAL_EPISODES = 1000000
CHECKPOINT_FREQ = 1000  # Save every 1000 episodes
BATCH_SIZE = 256  # Increased for GPU
GAMMA = 0.99

# Names of the trainer attributes that are persisted in every checkpoint.
_METRIC_NAMES = ('episode_rewards', 'win_rates', 'nash_distances')


def _snapshot_metrics(source):
    """Return the trainer's current metric histories keyed by attribute name."""
    return {name: getattr(source, name) for name in _METRIC_NAMES}


# Load previous checkpoint if exists
start_episode, metrics = load_latest_checkpoint(trainer, CHECKPOINT_DIR)

# Put the saved histories back on the trainer. Without this a resumed run
# starts from empty lists and the next checkpoint overwrites the old history.
for _name in _METRIC_NAMES:
    if _name in metrics:
        setattr(trainer, _name, metrics[_name])

# Bind `episode` up front so the interrupt handler can always reference it,
# even if the interrupt arrives before the first loop iteration.
episode = start_episode

# Training loop with checkpointing
try:
    for episode in range(start_episode, TOTAL_EPISODES):
        # Run episode
        episode_reward = trainer.run_episode()

        # Optimize model
        loss = trainer.optimize_model(BATCH_SIZE, GAMMA)

        # Update metrics
        trainer.episode_rewards.append(episode_reward)

        # Print progress
        if episode % 100 == 0:
            avg_reward = np.mean(trainer.episode_rewards[-100:])
            # NOTE(review): assumes optimize_model may return None when it
            # skips an update (e.g. too few stored transitions) -- confirm
            # against train.py. Formatting None with ':.4f' would raise.
            loss_text = f"{loss:.4f}" if loss is not None else "n/a"
            print(f"Episode {episode}/{TOTAL_EPISODES} | Avg Reward: {avg_reward:.2f} | Loss: {loss_text}")

        # Save checkpoint
        if episode % CHECKPOINT_FREQ == 0:
            metrics = _snapshot_metrics(trainer)
            save_checkpoint(trainer, episode, metrics, CHECKPOINT_DIR)

except KeyboardInterrupt:
    print("\nTraining interrupted! Saving checkpoint...")
    # Take a fresh snapshot: the `metrics` variable may date from the last
    # periodic save (or from the checkpoint that was loaded at startup).
    metrics = _snapshot_metrics(trainer)
    save_checkpoint(trainer, episode, metrics, CHECKPOINT_DIR)
else:
    # Only report completion when the loop really ran to the end.
    print("Training completed!")

## Training Progress Visualization

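Run this cell after the training cell has run (it reads the metric histories from the in-memory `trainer` object) to visualize training metrics: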

In [ ]:
import matplotlib.pyplot as plt
import seaborn as sns

# Set style
# Matplotlib 3.6 renamed the bundled 'seaborn' style sheet to 'seaborn-v0_8'
# and later releases removed the old name, so use whichever one this
# installation provides instead of hard-coding a name that may not exist.
_style_name = next(
    (name for name in ('seaborn-v0_8', 'seaborn') if name in plt.style.available),
    'default',
)
plt.style.use(_style_name)
sns.set_palette("husl")

# Create figure with subplots
fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(12, 15))

# One row per metric: (axis, data series, title, y-axis label).
_panels = [
    (ax1, trainer.episode_rewards, 'Episode Rewards', 'Reward'),
    (ax2, trainer.win_rates, 'Win Rates', 'Win Rate'),
    (ax3, trainer.nash_distances, 'Nash Distances', 'Distance'),
]

for _ax, _series, _title, _ylabel in _panels:
    _ax.plot(_series)
    _ax.set_title(_title)
    _ax.set_xlabel('Episode')
    _ax.set_ylabel(_ylabel)

plt.tight_layout()
plt.show()