# Self-Play Reinforcement Learning for No-Press Diplomacy
## Using the Official `diplomacy` Package

**Project:** Improve Self-Play for Diplomacy  
**Authors:** Giacomo Colosio, Maciej Tasarz, Jakub Seliga, Luka Ivcevic  
**Course:** ISP - UPC Barcelona, Fall 2025/26

---

### References
- **Silver et al. (2017)** - Mastering the Game of Go without Human Knowledge (AlphaGo Zero)
- **Schulman et al. (2017)** - Proximal Policy Optimization Algorithms (PPO)
- **Paquette et al. (2019)** - No-Press Diplomacy: Modeling Multi-Agent Gameplay
- **Bakhtin et al. (2021)** - No-Press Diplomacy from Scratch (DORA)
- **Bakhtin et al. (2022)** - Mastering the Game of No-Press Diplomacy via Human-Regularized Reinforcement Learning and Planning (Diplodocus)

### This Notebook
- Uses the official `diplomacy` package for accurate game simulation
- Implements PPO (Proximal Policy Optimization) for policy learning
- Uses an actor-critic architecture with dedicated state and action encoders
- Addresses RQ1: quantifying overfitting in pure self-play

**Requirements:** GPU runtime recommended (Runtime → Change runtime type → GPU); the code falls back to CPU if no GPU is available

## 1. Setup & Installation

In [None]:
# Install required packages
!pip install diplomacy torch numpy matplotlib tqdm tensorboard --quiet
print("Installation complete!")

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical
from torch.utils.tensorboard import SummaryWriter

import numpy as np
import random
import json
import os
import copy
from collections import defaultdict, deque, namedtuple
from typing import Dict, List, Tuple, Optional, Set
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from datetime import datetime

# Diplomacy package
from diplomacy import Game
from diplomacy.utils.export import to_saved_game_format

# Set seeds for reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {device}')
if torch.cuda.is_available():
    print(f'GPU: {torch.cuda.get_device_name(0)}')
    print(f'Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB')

## 2. Game Constants

In [None]:
# Standard Diplomacy constants
POWERS = ['AUSTRIA', 'ENGLAND', 'FRANCE', 'GERMANY', 'ITALY', 'RUSSIA', 'TURKEY']
NUM_POWERS = 7

# All 75 provinces on the standard map
LOCATIONS = [
    # 34 Supply Centers
    'ANK', 'BEL', 'BER', 'BRE', 'BUD', 'BUL', 'CON', 'DEN', 'EDI', 'GRE',
    'HOL', 'KIE', 'LON', 'LVP', 'MAR', 'MOS', 'MUN', 'NAP', 'NWY', 'PAR',
    'POR', 'ROM', 'RUM', 'SER', 'SEV', 'SMY', 'SPA', 'STP', 'SWE', 'TRI',
    'TUN', 'VEN', 'VIE', 'WAR',
    # 22 Non-SC land provinces
    'ALB', 'APU', 'ARM', 'BOH', 'BUR', 'CLY', 'FIN', 'GAL', 'GAS', 'LVN',
    'NAF', 'PIC', 'PIE', 'PRU', 'RUH', 'SIL', 'SYR', 'TUS', 'TYR', 'UKR',
    'WAL', 'YOR',
    # 19 Sea provinces (the package's standard map abbreviates the Gulf of Lyon as 'LYO')
    'ADR', 'AEG', 'BAL', 'BAR', 'BLA', 'BOT', 'EAS', 'ENG', 'HEL', 'ION',
    'IRI', 'LYO', 'MAO', 'NAO', 'NTH', 'NWG', 'SKA', 'TYS', 'WES'
]
NUM_LOCATIONS = len(LOCATIONS)  # 75

SUPPLY_CENTERS = set(LOCATIONS[:34])
VICTORY_CENTERS = 18

# Mappings
LOC_TO_IDX = {loc: i for i, loc in enumerate(LOCATIONS)}
IDX_TO_LOC = {i: loc for i, loc in enumerate(LOCATIONS)}
POWER_TO_IDX = {p: i for i, p in enumerate(POWERS)}
IDX_TO_POWER = {i: p for i, p in enumerate(POWERS)}

print(f'Powers: {NUM_POWERS}')
print(f'Locations: {NUM_LOCATIONS}')
print(f'Supply Centers: {len(SUPPLY_CENTERS)}')
print(f'Victory requires: {VICTORY_CENTERS} SCs')

## 3. State Encoder

Based on the encoding scheme of Bakhtin et al. (2022):
- Board state encoding per location
- Relative encoding from each power's perspective
- Season and year information
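
The relative encoding can be illustrated with a small standalone sketch. It assumes that a power's slot is its index rotated so that the observing power always lands in slot 0, which is how the encoder cell indexes powers; the helper name `relative_index` is hypothetical and exists only for this example.

```python
# Illustrative sketch: rotate power indices so the observer is always slot 0.
POWERS = ['AUSTRIA', 'ENGLAND', 'FRANCE', 'GERMANY', 'ITALY', 'RUSSIA', 'TURKEY']

def relative_index(power: str, observer: str) -> int:
    """Slot of `power` as seen by `observer` (hypothetical helper)."""
    n = len(POWERS)
    return (POWERS.index(power) - POWERS.index(observer)) % n

# From France's perspective, France is slot 0 and Germany is slot 1.
print(relative_index('FRANCE', 'FRANCE'))   # 0
print(relative_index('GERMANY', 'FRANCE'))  # 1
print(relative_index('AUSTRIA', 'FRANCE'))  # 5
```

Because every power sees itself in slot 0, a single shared network can be queried from any seat without learning seven separate "who am I" conventions.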

In [None]:
class DiplomacyStateEncoder:
    """
    Encodes Diplomacy game state into fixed-size tensor.
    
    Features per location (75 locations × 23 features = 1725):
        - Unit presence: 7 powers × 2 unit types = 14 (one-hot, relative power order)
        - SC ownership: 8 (7 powers in relative order + neutral)
        - Is supply center: 1
    
    Global features (29):
        - SC count per power: 7 (normalized)
        - Unit count per power: 7 (normalized)
        - Build/disband count per power: 7
        - Season: 3 (one-hot: spring, fall, winter)
        - Year: 1 (normalized)
        - Phase type: 3 (movement, retreat, adjustment)
        - Current power: 1 (index, normalized)
    
    Total: 1725 + 29 = 1754 features. Reachability, dislodged-unit, and
    area-type features are not encoded in this version.
    """
    
    def __init__(self, compressed_size: int = 2048):
        self.compressed_size = compressed_size
        self.raw_loc_features = 23  # Features per location
        self.raw_global_features = 29
        self.raw_size = NUM_LOCATIONS * self.raw_loc_features + self.raw_global_features
        
        # No compression layer is implemented: encode() always returns the raw
        # feature vector, so the network input size must equal raw_size.
        self.use_compression = False
        self.state_size = self.raw_size
        
    def encode(self, game: Game, power_name: str) -> np.ndarray:
        """
        Encode game state from the perspective of power_name.
        
        Args:
            game: Diplomacy Game object
            power_name: Power whose perspective to encode from
            
        Returns:
            numpy array of shape (state_size,)
        """
        features = np.zeros(self.raw_size, dtype=np.float32)
        power_idx = POWER_TO_IDX[power_name]
        
        # Get game state
        state = game.get_state()
        units = state['units']
        centers = state['centers']
        
        # Parse current phase
        phase_name = game.get_current_phase()
        year = self._parse_year(phase_name)
        season = self._parse_season(phase_name)
        phase_type = self._parse_phase_type(phase_name)
        
        # Build unit location map
        unit_locations = {}  # loc -> (power, unit_type)
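        for pwr, pwr_units in units.items():
            for unit in pwr_units:
                # The state lists dislodged units with a leading '*'; skip them
                # so they do not overwrite the unit now occupying the province
                if unit.startswith('*'):
                    continue
                unit_type, loc = self._parse_unit(unit)
                if loc:
                    unit_locations[loc] = (pwr, unit_type)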
        
        # Encode each location
        for loc_idx, loc in enumerate(LOCATIONS):
            offset = loc_idx * self.raw_loc_features
            
            # Unit presence (relative to current power)
            if loc in unit_locations:
                pwr, unit_type = unit_locations[loc]
                rel_pwr_idx = (POWER_TO_IDX[pwr] - power_idx) % NUM_POWERS
                
                if unit_type == 'A':
                    features[offset + rel_pwr_idx] = 1.0  # Army
                else:
                    features[offset + NUM_POWERS + rel_pwr_idx] = 1.0  # Fleet
            
            # SC ownership (relative)
            sc_offset = offset + 14
            if loc in SUPPLY_CENTERS:
                features[offset + 22] = 1.0  # Is SC
                owned = False
                for pwr, pwr_centers in centers.items():
                    if loc in pwr_centers:
                        rel_pwr_idx = (POWER_TO_IDX[pwr] - power_idx) % NUM_POWERS
                        features[sc_offset + rel_pwr_idx] = 1.0
                        owned = True
                        break
                if not owned:
                    features[sc_offset + 7] = 1.0  # Neutral
        
        # Global features
        g_offset = NUM_LOCATIONS * self.raw_loc_features
        
        # SC counts (normalized, relative order)
        for pwr in POWERS:
            rel_idx = (POWER_TO_IDX[pwr] - power_idx) % NUM_POWERS
            sc_count = len(centers.get(pwr, []))
            features[g_offset + rel_idx] = sc_count / VICTORY_CENTERS
        
        # Unit counts (normalized)
        for pwr in POWERS:
            rel_idx = (POWER_TO_IDX[pwr] - power_idx) % NUM_POWERS
            unit_count = len(units.get(pwr, []))
            features[g_offset + 7 + rel_idx] = unit_count / 17.0
        
        # Build counts
        for pwr in POWERS:
            rel_idx = (POWER_TO_IDX[pwr] - power_idx) % NUM_POWERS
            sc_count = len(centers.get(pwr, []))
            unit_count = len(units.get(pwr, []))
            build_count = sc_count - unit_count
            features[g_offset + 14 + rel_idx] = np.clip(build_count / 5.0, -1, 1)
        
        # Season (one-hot)
        season_offset = g_offset + 21
        if season == 'S':
            features[season_offset] = 1.0
        elif season == 'F':
            features[season_offset + 1] = 1.0
        else:
            features[season_offset + 2] = 1.0
        
        # Year (normalized)
        features[g_offset + 24] = (year - 1901) / 20.0
        
        # Phase type (one-hot)
        phase_offset = g_offset + 25
        if phase_type == 'M':
            features[phase_offset] = 1.0
        elif phase_type == 'R':
            features[phase_offset + 1] = 1.0
        else:
            features[phase_offset + 2] = 1.0
        
        # Current power index
        features[g_offset + 28] = power_idx / (NUM_POWERS - 1)
        
        # Both branches of the previous conditional returned the full raw vector
        return features
    
    def _parse_unit(self, unit: str) -> Tuple[str, str]:
        """Parse unit string like 'A PAR' or 'F STP/SC'."""
        parts = unit.split()
        if len(parts) >= 2:
            unit_type = parts[0]
            loc = parts[1].split('/')[0]
            return unit_type, loc
        return '', ''
    
    def _parse_year(self, phase: str) -> int:
        try:
            return int(phase[1:5])
        except ValueError:  # e.g. a non-standard phase name such as 'COMPLETED'
            return 1901
    
    def _parse_season(self, phase: str) -> str:
        return phase[0] if phase else 'S'
    
    def _parse_phase_type(self, phase: str) -> str:
        return phase[-1] if phase else 'M'

# Test encoder
state_encoder = DiplomacyStateEncoder(compressed_size=1754)
print(f'Raw state size: {state_encoder.raw_size}')
print(f'Used state size: {state_encoder.state_size}')
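
The private `_parse_*` helpers above rely on the short phase names returned by `get_current_phase()`, such as `S1901M` (spring 1901, movement). The standalone sketch below shows the same split in one place. The function `parse_phase` and its fallback to spring 1901 movement are assumptions made for illustration; they are not part of the encoder.

```python
def parse_phase(phase: str):
    """Split a short phase name such as 'S1901M' into (season, year, phase_type)."""
    if len(phase) != 6 or not phase[1:5].isdigit():
        # Non-standard name: fall back to spring 1901 movement (sketch assumption)
        return 'S', 1901, 'M'
    return phase[0], int(phase[1:5]), phase[5]

print(parse_phase('S1901M'))  # ('S', 1901, 'M')
print(parse_phase('F1902R'))  # ('F', 1902, 'R')
print(parse_phase('W1903A'))  # ('W', 1903, 'A')
```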

## 4. Action Encoder

Handles encoding/decoding of Diplomacy orders using the official game's action space.
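
As a rough illustration of the order-to-index mapping, the sketch below builds a toy vocabulary from three hand-written order strings. The class name `OrderVocab` is hypothetical, and the orders are examples only; the real vocabulary in the next cell is collected from simulated games.

```python
class OrderVocab:
    """Toy order <-> index mapping with reserved <PAD>=0 and <UNK>=1 (illustrative)."""

    def __init__(self, orders):
        self.order_to_idx = {'<PAD>': 0, '<UNK>': 1}
        # Sort so that index assignment is deterministic across runs
        for order in sorted(set(orders)):
            self.order_to_idx[order] = len(self.order_to_idx)
        self.idx_to_order = {i: o for o, i in self.order_to_idx.items()}

    def encode(self, order: str) -> int:
        # Unseen orders map to the <UNK> index
        return self.order_to_idx.get(order, 1)

    def decode(self, idx: int) -> str:
        return self.idx_to_order.get(idx, '<UNK>')

vocab = OrderVocab(['A PAR - BUR', 'A PAR H', 'F BRE - MAO'])
print(vocab.encode('A PAR H'))      # 3
print(vocab.encode('A MUN - BUR'))  # 1 (unseen order)
print(vocab.decode(2))              # A PAR - BUR
```

Any order that never appeared while the vocabulary was built collapses to `<UNK>`, so coverage of the sampled games directly limits which orders the policy can express.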

In [None]:
class DiplomacyActionEncoder:
    """
    Encodes Diplomacy orders to indices and vice versa.
    
    Uses the official diplomacy package to get valid orders,
    then maps them to a vocabulary for the neural network.
    """
    
    def __init__(self):
        self.order_to_idx = {'<PAD>': 0, '<UNK>': 1}
        self.idx_to_order = {0: '<PAD>', 1: '<UNK>'}
        self.vocab_size = 2
        self._build_base_vocab()
        
    def _build_base_vocab(self):
        """
        Build vocabulary from possible orders.
        We generate a comprehensive set covering most game situations.
        """
        idx = 2
        
        # Generate orders by playing sample games
        orders_seen = set()
        
        # Play random games to collect order vocabulary
        for _ in range(50):
            game = Game()
            for _ in range(30):  # 30 phases max
                if game.is_game_done:
                    break
                
                # Query the engine once per phase; the result covers every power
                possible = game.get_all_possible_orders()
                for loc_orders in possible.values():
                    orders_seen.update(loc_orders)
                
                # Submit random orders for every orderable location, which also
                # covers retreats and builds (not only units currently on the board)
                for power_name in POWERS:
                    orders = [
                        random.choice(possible[loc])
                        for loc in game.get_orderable_locations(power_name)
                        if possible.get(loc)
                    ]
                    game.set_orders(power_name, orders)
                
                game.process()
        
        # Add all seen orders to vocabulary
        for order in sorted(orders_seen):
            if order not in self.order_to_idx:
                self.order_to_idx[order] = idx
                self.idx_to_order[idx] = order
                idx += 1
        
        self.vocab_size = len(self.order_to_idx)
        print(f'Built vocabulary with {self.vocab_size} orders')
    
    def encode(self, order: str) -> int:
        """Encode order string to index."""
        return self.order_to_idx.get(order, 1)  # 1 = UNK
    
    def decode(self, idx: int) -> str:
        """Decode index to order string."""
        return self.idx_to_order.get(idx, '<UNK>')
    
    def get_valid_action_mask(self, game: Game, power_name: str) -> Tuple[List[int], Dict[int, str]]:
        """
        Get the valid action indices for a power (a list of indices, not a dense mask).
        
        Returns:
            (list of valid action indices, mapping from idx to order)
        """
        valid_indices = []
        idx_to_order = {}
        
        power = game.get_power(power_name)
        possible = game.get_all_possible_orders()
        
        for unit in power.units:
            loc = unit.split()[-1].split('/')[0]
            if loc in possible:
                for order in possible[loc]:
                    idx = self.encode(order)
                    if idx > 1:  # Not PAD or UNK
                        valid_indices.append(idx)
                        idx_to_order[idx] = order
        
        # If no valid order is in the vocabulary, fall back to the <UNK> index (1)
        if not valid_indices:
            valid_indices = [1]
            
        return valid_indices, idx_to_order
    
    def save(self, path: str):
        """Save vocabulary to file."""
        with open(path, 'w') as f:
            json.dump({
                'order_to_idx': self.order_to_idx,
                'vocab_size': self.vocab_size
            }, f)
    
    def load(self, path: str):
        """Load vocabulary from file."""
        with open(path, 'r') as f:
            data = json.load(f)
        self.order_to_idx = data['order_to_idx']
        self.idx_to_order = {int(v): k for k, v in self.order_to_idx.items()}
        self.vocab_size = data['vocab_size']

# Build action encoder
print('Building action vocabulary (this may take a minute)...')
action_encoder = DiplomacyActionEncoder()

## 5. Actor-Critic Network

In [None]:
class ActorCriticNetwork(nn.Module):
    """
    Actor-Critic network for PPO.
    
    Architecture:
    - Shared MLP backbone for feature extraction
    - Policy head (Actor): outputs action logits
    - Value head (Critic): outputs state value V(s)
    
    Based on architectures from:
    - Schulman et al. (2017) - PPO
    - Bakhtin et al. (2022) - Diplodocus
    """
    
    def __init__(self, state_size: int, action_size: int, 
                 hidden_sizes: Tuple[int, ...] = (1024, 512, 256)):
        super().__init__()
        
        self.state_size = state_size
        self.action_size = action_size
        
        # Shared backbone
        layers = []
        prev_size = state_size
        for hidden_size in hidden_sizes:
            layers.extend([
                nn.Linear(prev_size, hidden_size),
                nn.LayerNorm(hidden_size),
                nn.ReLU(),
                nn.Dropout(0.1)
            ])
            prev_size = hidden_size
        self.backbone = nn.Sequential(*layers)
        
        # Policy head (Actor)
        self.policy_head = nn.Sequential(
            nn.Linear(hidden_sizes[-1], hidden_sizes[-1]),
            nn.ReLU(),
            nn.Linear(hidden_sizes[-1], action_size)
        )
        
        # Value head (Critic)
        self.value_head = nn.Sequential(
            nn.Linear(hidden_sizes[-1], hidden_sizes[-1] // 2),
            nn.ReLU(),
            nn.Linear(hidden_sizes[-1] // 2, 1)
        )
        
        # Initialize weights
        self._init_weights()
        
    def _init_weights(self):
        """Initialize network weights."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.orthogonal_(module.weight, gain=np.sqrt(2))
                nn.init.constant_(module.bias, 0)
        
        # Smaller init for output layers
        nn.init.orthogonal_(self.policy_head[-1].weight, gain=0.01)
        nn.init.orthogonal_(self.value_head[-1].weight, gain=1.0)
    
    def forward(self, x: torch.Tensor, 
                action_mask: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass.
        
        Args:
            x: State tensor (batch_size, state_size)
            action_mask: Binary mask for valid actions (batch_size, action_size)
            
        Returns:
            (policy_logits, value)
        """
        features = self.backbone(x)
        
        # Policy logits
        logits = self.policy_head(features)
        
        # Apply action mask if provided
        if action_mask is not None:
            # Mask invalid actions with -inf so they receive zero probability
            logits = logits.masked_fill(~action_mask.bool(), float('-inf'))
        
        # Value
        value = self.value_head(features).squeeze(-1)
        
        return logits, value
    
    def get_action_and_value(self, state: torch.Tensor,
                              valid_actions: List[int] = None,
                              deterministic: bool = False) -> Tuple[int, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Get action, log prob, entropy, and value for a state.
        
        Args:
            state: State tensor (1, state_size)
            valid_actions: List of valid action indices
            deterministic: If True, return argmax action
            
        Returns:
            (action, log_prob, entropy, value)
        """
        # Create action mask
        action_mask = None
        if valid_actions is not None:
            action_mask = torch.zeros(1, self.action_size, device=state.device)
            action_mask[0, valid_actions] = 1.0
        
        logits, value = self.forward(state, action_mask)
        
        # Get probabilities
        probs = F.softmax(logits, dim=-1)
        dist = Categorical(probs)
        
        if deterministic:
            action = probs.argmax(dim=-1)
        else:
            action = dist.sample()
        
        log_prob = dist.log_prob(action)
        entropy = dist.entropy()
        
        return action.item(), log_prob, entropy, value
    
    def evaluate_actions(self, states: torch.Tensor, 
                         actions: torch.Tensor,
                         action_masks: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Evaluate actions for PPO update.
        
        Returns:
            (log_probs, entropy, values)
        """
        logits, values = self.forward(states, action_masks)
        
        probs = F.softmax(logits, dim=-1)
        dist = Categorical(probs)
        
        log_probs = dist.log_prob(actions)
        entropy = dist.entropy()
        
        return log_probs, entropy, values

# Test network
test_net = ActorCriticNetwork(state_encoder.state_size, action_encoder.vocab_size)
print(f'Network parameters: {sum(p.numel() for p in test_net.parameters()):,}')
del test_net
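
The effect of the action mask in `forward` can be reproduced without PyTorch. The sketch below is a plain-Python illustration, not the network's code path: `masked_softmax` is a hypothetical helper that sets invalid logits to negative infinity before normalizing, so invalid actions end up with probability exactly zero.

```python
import math

def masked_softmax(logits, valid):
    """Softmax over `logits` restricted to the indices in `valid`."""
    valid = set(valid)
    if not valid:
        raise ValueError('at least one valid action is required')
    masked = [x if i in valid else float('-inf') for i, x in enumerate(logits)]
    m = max(masked)                            # subtract the max for stability
    exps = [math.exp(x - m) for x in masked]   # exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]

probs = masked_softmax([2.0, 1.0, 0.5, 3.0], valid=[0, 1])
print([round(p, 3) for p in probs])  # [0.731, 0.269, 0.0, 0.0]
```

Note that index 3 has the largest raw logit but is never selected, because it is outside the valid set.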

## 6. Experience Buffer

In [None]:
class RolloutBuffer:
    """
    Buffer for storing rollout experiences.
    Handles GAE computation and batch generation.
    """
    
    def __init__(self):
        self.states = []
        self.actions = []
        self.rewards = []
        self.dones = []
        self.log_probs = []
        self.values = []
        self.action_masks = []
        
    def add(self, state, action, reward, done, log_prob, value, action_mask=None):
        """Add experience to buffer."""
        self.states.append(state)
        self.actions.append(action)
        self.rewards.append(reward)
        self.dones.append(done)
        self.log_probs.append(log_prob)
        self.values.append(value)
        self.action_masks.append(action_mask)
    
    def compute_returns_and_advantages(self, last_value: float, 
                                        gamma: float = 0.99, 
                                        gae_lambda: float = 0.95) -> Tuple[np.ndarray, np.ndarray]:
        """
        Compute returns and GAE advantages.
        
        Args:
            last_value: Value estimate for final state
            gamma: Discount factor
            gae_lambda: GAE lambda parameter
            
        Returns:
            (returns, advantages)
        """
        rewards = np.asarray(self.rewards, dtype=np.float32)
        values = np.asarray(self.values + [last_value], dtype=np.float32)
        # dones[t] is True when transition t ended its episode
        dones = np.asarray(self.dones, dtype=np.float32)
        
        # GAE computation
        advantages = np.zeros_like(rewards)
        last_gae = 0.0
        
        for t in reversed(range(len(rewards))):
            # Do not bootstrap from values[t + 1] across an episode boundary
            next_non_terminal = 1.0 - dones[t]
            delta = rewards[t] + gamma * values[t + 1] * next_non_terminal - values[t]
            last_gae = delta + gamma * gae_lambda * next_non_terminal * last_gae
            advantages[t] = last_gae
        
        returns = advantages + np.array(self.values)
        
        return returns, advantages
    
    def get_batches(self, batch_size: int, returns: np.ndarray, 
                    advantages: np.ndarray) -> List[Dict]:
        """
        Generate mini-batches for training.
        """
        n_samples = len(self.states)
        indices = np.arange(n_samples)
        np.random.shuffle(indices)
        
        batches = []
        for start in range(0, n_samples, batch_size):
            end = min(start + batch_size, n_samples)
            batch_indices = indices[start:end]
            
            batch = {
                'states': np.array([self.states[i] for i in batch_indices]),
                'actions': np.array([self.actions[i] for i in batch_indices]),
                'log_probs': np.array([self.log_probs[i] for i in batch_indices]),
                'returns': returns[batch_indices],
                'advantages': advantages[batch_indices],
            }
            batches.append(batch)
        
        return batches
    
    def clear(self):
        """Clear buffer."""
        self.states = []
        self.actions = []
        self.rewards = []
        self.dones = []
        self.log_probs = []
        self.values = []
        self.action_masks = []
    
    def __len__(self):
        return len(self.states)

print('RolloutBuffer defined')
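
A small worked example may help when checking the GAE recursion by hand. The sketch below is a plain-Python version under the convention that `dones[t]` is true when step `t` ended its episode; the function name `gae` is hypothetical. With `gamma = lam = 1` the returns reduce to the undiscounted sum of future rewards, which makes the result easy to verify.

```python
def gae(rewards, values, dones, last_value, gamma=0.99, lam=0.95):
    """Generalized Advantage Estimation over one rollout (illustrative sketch)."""
    advantages = [0.0] * len(rewards)
    next_value, running = last_value, 0.0
    for t in reversed(range(len(rewards))):
        non_terminal = 0.0 if dones[t] else 1.0
        # One-step TD error; bootstrapping is cut at episode boundaries
        delta = rewards[t] + gamma * next_value * non_terminal - values[t]
        running = delta + gamma * lam * non_terminal * running
        advantages[t] = running
        next_value = values[t]
    returns = [a + v for a, v in zip(advantages, values)]
    return returns, advantages

returns, adv = gae(rewards=[0.0, 0.0, 1.0], values=[0.5, 0.5, 0.5],
                   dones=[False, False, True], last_value=0.0,
                   gamma=1.0, lam=1.0)
print(returns)  # [1.0, 1.0, 1.0]
print(adv)      # [0.5, 0.5, 0.5]
```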

## 7. PPO Algorithm

In [None]:
class PPOAgent:
    """
    PPO Agent implementation.
    
    Based on:
    - Schulman et al. (2017) - Proximal Policy Optimization
    - Implementation details from OpenAI baselines
    """
    
    def __init__(self, state_size: int, action_size: int,
                 lr: float = 2.5e-4,
                 gamma: float = 0.99,
                 gae_lambda: float = 0.95,
                 clip_epsilon: float = 0.2,
                 value_coef: float = 0.5,
                 entropy_coef: float = 0.01,
                 max_grad_norm: float = 0.5,
                 target_kl: float = 0.03):
        
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.clip_epsilon = clip_epsilon
        self.value_coef = value_coef
        self.entropy_coef = entropy_coef
        self.max_grad_norm = max_grad_norm
        self.target_kl = target_kl
        
        # Network
        self.network = ActorCriticNetwork(state_size, action_size).to(device)
        self.optimizer = optim.Adam(self.network.parameters(), lr=lr, eps=1e-5)
        
        # Learning rate scheduler
        self.scheduler = None  # Set externally if needed
        
        # Buffer
        self.buffer = RolloutBuffer()
        
        # Stats
        self.train_step = 0
        
    def select_action(self, state: np.ndarray, 
                      valid_actions: List[int] = None,
                      deterministic: bool = False) -> Tuple[int, float, float]:
        """
        Select action given state.
        
        Returns:
            (action_idx, log_prob, value)
        """
        state_t = torch.FloatTensor(state).unsqueeze(0).to(device)
        
        with torch.no_grad():
            action, log_prob, entropy, value = self.network.get_action_and_value(
                state_t, valid_actions, deterministic
            )
        
        return action, log_prob.item(), value.item()
    
    def store_transition(self, state, action, reward, done, log_prob, value, action_mask=None):
        """Store transition in buffer."""
        self.buffer.add(state, action, reward, done, log_prob, value, action_mask)
    
    def update(self, num_epochs: int = 4, batch_size: int = 64) -> Dict[str, float]:
        """
        Update policy using PPO.
        
        Returns:
            Dictionary of training metrics
        """
        if len(self.buffer) < batch_size:
            return {}
        
        # Compute returns and advantages
        returns, advantages = self.buffer.compute_returns_and_advantages(
            last_value=0.0,
            gamma=self.gamma,
            gae_lambda=self.gae_lambda
        )
        
        # Normalize advantages
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
        
        # Training metrics
        total_policy_loss = 0
        total_value_loss = 0
        total_entropy = 0
        total_kl = 0
        num_updates = 0
        
        # PPO epochs
        for epoch in range(num_epochs):
            batches = self.buffer.get_batches(batch_size, returns, advantages)
            
            for batch in batches:
                states_t = torch.FloatTensor(batch['states']).to(device)
                actions_t = torch.LongTensor(batch['actions']).to(device)
                old_log_probs_t = torch.FloatTensor(batch['log_probs']).to(device)
                returns_t = torch.FloatTensor(batch['returns']).to(device)
                advantages_t = torch.FloatTensor(batch['advantages']).to(device)
                
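                # Evaluate actions.
                # NOTE: no action mask is passed here, so if actions were sampled
                # under a mask, the stored log-probs come from a different
                # (masked) distribution than the ones recomputed below.
                new_log_probs, entropy, values = self.network.evaluate_actions(
                    states_t, actions_t
                )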
                
                # Policy loss (clipped PPO objective)
                ratio = torch.exp(new_log_probs - old_log_probs_t)
                surr1 = ratio * advantages_t
                surr2 = torch.clamp(ratio, 1 - self.clip_epsilon, 1 + self.clip_epsilon) * advantages_t
                policy_loss = -torch.min(surr1, surr2).mean()
                
                # Value loss (unclipped MSE against the GAE returns)
                value_loss = F.mse_loss(values, returns_t)
                
                # Entropy bonus
                entropy_loss = -entropy.mean()
                
                # Total loss
                loss = policy_loss + self.value_coef * value_loss + self.entropy_coef * entropy_loss
                
                # Update
                self.optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(self.network.parameters(), self.max_grad_norm)
                self.optimizer.step()
                
                # Track metrics
                total_policy_loss += policy_loss.item()
                total_value_loss += value_loss.item()
                total_entropy += entropy.mean().item()
                
                # Approximate KL divergence
                with torch.no_grad():
                    kl = (old_log_probs_t - new_log_probs).mean().item()
                    total_kl += kl
                
                num_updates += 1
            
            # Early stopping once the running mean KL (over all updates so far) exceeds target_kl
            if total_kl / max(num_updates, 1) > self.target_kl:
                break
        
        # Clear buffer
        self.buffer.clear()
        self.train_step += 1
        
        return {
            'policy_loss': total_policy_loss / max(num_updates, 1),
            'value_loss': total_value_loss / max(num_updates, 1),
            'entropy': total_entropy / max(num_updates, 1),
            'kl': total_kl / max(num_updates, 1),
            'num_updates': num_updates
        }
    
    def save(self, path: str):
        """Save agent checkpoint."""
        torch.save({
            'network_state_dict': self.network.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'train_step': self.train_step
        }, path)
    
    def load(self, path: str):
        """Load agent checkpoint."""
        checkpoint = torch.load(path, map_location=device)
        self.network.load_state_dict(checkpoint['network_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.train_step = checkpoint.get('train_step', 0)

print('PPOAgent defined')
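
To make the clipped objective in `update` concrete, the sketch below evaluates it for a single sample in plain Python. The helper `clipped_surrogate` is hypothetical and only mirrors the per-sample term of the PPO objective from Schulman et al. (2017); the agent itself computes the same quantity on tensors and negates its mean to obtain a loss.

```python
import math

def clipped_surrogate(new_log_prob, old_log_prob, advantage, clip_eps=0.2):
    """Per-sample PPO objective: min(r * A, clip(r, 1 - eps, 1 + eps) * A)."""
    ratio = math.exp(new_log_prob - old_log_prob)   # r = pi_new / pi_old
    clipped = max(1.0 - clip_eps, min(ratio, 1.0 + clip_eps))
    return min(ratio * advantage, clipped * advantage)

# Ratio 1.5 with a positive advantage: the gain is capped at 1.2 * A
print(round(clipped_surrogate(math.log(0.6), math.log(0.4), advantage=1.0), 3))   # 1.2
# Ratio 0.5 with a negative advantage: the pessimistic term 0.8 * A is kept
print(round(clipped_surrogate(math.log(0.2), math.log(0.4), advantage=-1.0), 3))  # -0.8
```

In both cases the objective stops rewarding further movement of the ratio away from 1, which is what limits the size of each policy update.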

## 8. Reward Shaping

Reward shaping is critical for learning because the game outcome only arrives at the end of an episode. The design is based on Bakhtin et al. (2022).
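
The intermediate (non-terminal) part of the reward can be sketched for a single power as follows. The function `sc_delta_reward` is a hypothetical standalone helper; its default coefficients mirror the `RewardShaper` defaults in the next cell, and it deliberately ignores elimination and end-of-game outcomes.

```python
def sc_delta_reward(prev_scs, curr_scs, gain=0.1, loss=-0.1, survival=0.01):
    """Intermediate reward for one surviving power from its supply-center change."""
    delta = curr_scs - prev_scs
    reward = survival              # small bonus for still being in the game
    if delta > 0:
        reward += gain * delta     # reward each supply center gained
    elif delta < 0:
        reward += loss * abs(delta)  # penalize each supply center lost
    return reward

print(round(sc_delta_reward(3, 5), 2))  # 0.21
print(round(sc_delta_reward(5, 4), 2))  # -0.09
print(round(sc_delta_reward(4, 4), 2))  # 0.01
```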

In [None]:
class RewardShaper:
    """
    Computes shaped rewards for Diplomacy.
    
    Based on Bakhtin et al. (2022) reward design:
    - Final outcome reward (win/draw/loss)
    - Supply center delta reward
    - Survival bonus
    """
    
    def __init__(self, 
                 win_reward: float = 1.0,
                 draw_reward: float = 0.0,
                 loss_reward: float = -1.0,
                 elimination_reward: float = -1.0,
                 sc_gain_reward: float = 0.1,
                 sc_loss_reward: float = -0.1,
                 survival_reward: float = 0.01):
        
        self.win_reward = win_reward
        self.draw_reward = draw_reward
        self.loss_reward = loss_reward
        self.elimination_reward = elimination_reward
        self.sc_gain_reward = sc_gain_reward
        self.sc_loss_reward = sc_loss_reward
        self.survival_reward = survival_reward
        
        # Track SC counts for delta computation
        self.prev_sc_counts = {}
    
    def reset(self, game: Game):
        """Reset tracker for new game."""
        state = game.get_state()
        self.prev_sc_counts = {
            power: len(state['centers'].get(power, []))
            for power in POWERS
        }
    
    def compute_rewards(self, game: Game, done: bool) -> Dict[str, float]:
        """
        Compute rewards for all powers.
        
        Args:
            game: Current game state
            done: Whether game is finished
            
        Returns:
            Dict mapping power name to reward
        """
        rewards = {power: 0.0 for power in POWERS}
        state = game.get_state()
        centers = state['centers']
        units = state['units']
        
        # Current SC counts
        current_sc_counts = {
            power: len(centers.get(power, []))
            for power in POWERS
        }
        
        # Find winner (if any)
        winner = None
        for power in POWERS:
            if current_sc_counts[power] >= VICTORY_CENTERS:
                winner = power
                break
        
        for power in POWERS:
            # Check elimination
            is_eliminated = len(units.get(power, [])) == 0 and current_sc_counts[power] == 0
            
            if done:
                # Final rewards
                if winner == power:
                    rewards[power] = self.win_reward
                elif winner is not None:
                    rewards[power] = self.loss_reward
                elif is_eliminated:
                    rewards[power] = self.elimination_reward
                else:
                    # Draw - reward based on relative position
                    total_scs = sum(current_sc_counts.values())
                    if total_scs > 0:
                        share = current_sc_counts[power] / total_scs
                        rewards[power] = self.draw_reward + share * 0.5
            else:
                # Intermediate rewards
                if is_eliminated:
                    rewards[power] = self.elimination_reward
                else:
                    # SC delta reward
                    sc_delta = current_sc_counts[power] - self.prev_sc_counts.get(power, 0)
                    if sc_delta > 0:
                        rewards[power] += self.sc_gain_reward * sc_delta
                    elif sc_delta < 0:
                        rewards[power] += self.sc_loss_reward * abs(sc_delta)
                    
                    # Survival bonus
                    rewards[power] += self.survival_reward
        
        # Update previous SC counts
        self.prev_sc_counts = current_sc_counts.copy()
        
        return rewards

print('RewardShaper defined')
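
To make the intermediate reward concrete, the following minimal sketch recomputes the per-phase term for a single power that is still alive, outside the class. The function name `step_reward` is hypothetical, and the default coefficients are simply the values used in this notebook's configuration (0.1, -0.1, 0.02). It illustrates the arithmetic only and is not a replacement for `RewardShaper`.

```python
def step_reward(prev_scs: int, curr_scs: int,
                sc_gain_reward: float = 0.1,
                sc_loss_reward: float = -0.1,
                survival_reward: float = 0.02) -> float:
    """Per-phase shaped reward for one power that is still alive."""
    sc_delta = curr_scs - prev_scs
    reward = 0.0
    if sc_delta > 0:
        reward += sc_gain_reward * sc_delta
    elif sc_delta < 0:
        # sc_loss_reward is negative, so multiply by the size of the loss
        reward += sc_loss_reward * abs(sc_delta)
    return reward + survival_reward


# Gaining two centers, holding steady, and losing one center
for prev, curr in [(3, 5), (4, 4), (5, 4)]:
    print(f'{prev} -> {curr}: {step_reward(prev, curr):+.2f}')
```

With these coefficients a two-center gain yields 0.22 and a one-center loss yields -0.08, so the survival bonus is small relative to a single supply-center change.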

## 9. Self-Play Trainer

In [None]:
class SelfPlayTrainer:
    """
    Self-Play training loop for Diplomacy.
    
    All 7 powers are controlled by the same policy network.
    Experience is collected from all powers' perspectives.
    """
    
    def __init__(self, config: Dict):
        self.config = config
        
        # Components
        self.state_encoder = state_encoder
        self.action_encoder = action_encoder
        self.reward_shaper = RewardShaper(
            win_reward=config.get('win_reward', 1.0),
            sc_gain_reward=config.get('sc_gain_reward', 0.1),
            sc_loss_reward=config.get('sc_loss_reward', -0.1),
            survival_reward=config.get('survival_reward', 0.02)
        )
        
        # Agent
        self.agent = PPOAgent(
            state_size=self.state_encoder.state_size,
            action_size=self.action_encoder.vocab_size,
            lr=config.get('lr', 2.5e-4),
            gamma=config.get('gamma', 0.99),
            gae_lambda=config.get('gae_lambda', 0.95),
            clip_epsilon=config.get('clip_epsilon', 0.2),
            entropy_coef=config.get('entropy_coef', 0.01)
        )
        
        # History
        self.history = {
            'episode_rewards': [],
            'episode_lengths': [],
            'wins': defaultdict(int),
            'draws': 0,
            'sc_counts': [],
            'policy_loss': [],
            'value_loss': [],
            'entropy': []
        }
        
        # Tensorboard
        self.writer = SummaryWriter(log_dir=f'runs/selfplay_{datetime.now().strftime("%Y%m%d_%H%M%S")}')
    
    def play_game(self, deterministic: bool = False) -> Dict:
        """
        Play one complete game with self-play.
        
        Returns:
            Game statistics
        """
        game = Game()
        self.reward_shaper.reset(game)
        
        episode_rewards = {p: 0.0 for p in POWERS}
        steps = 0
        max_steps = self.config.get('max_game_length', 100)
        
        while not game.is_game_done and steps < max_steps:
            # Valid orders depend only on the board, so query them once per phase
            possible_orders = game.get_all_possible_orders()
            
            # Buffer slots written during this phase, as (buffer_index, power_name)
            phase_entries = []
            
            # Collect orders from all powers
            for power_name in POWERS:
                power = game.get_power(power_name)
                
                # Skip if no units or eliminated
                if not power.units:
                    continue
                
                # Get state encoding
                state = self.state_encoder.encode(game, power_name)
                
                # Choose one order per unit
                orders = []
                
                for unit in power.units:
                    # Get unit location (drop unit type and coast, e.g. 'F STP/SC' -> 'STP')
                    unit_loc = unit.split()[-1].split('/')[0]
                    
                    # Get valid orders for this unit
                    if unit_loc in possible_orders and possible_orders[unit_loc]:
                        unit_orders = possible_orders[unit_loc]
                        
                        # Encode valid actions
                        valid_indices = []
                        idx_to_order = {}
                        for order in unit_orders:
                            idx = self.action_encoder.encode(order)
                            if idx > 1:  # Not PAD or UNK
                                valid_indices.append(idx)
                                idx_to_order[idx] = order
                        
                        if valid_indices:
                            # Select action
                            action_idx, log_prob, value = self.agent.select_action(
                                state, valid_indices, deterministic
                            )
                            
                            # Get order string
                            if action_idx in idx_to_order:
                                order = idx_to_order[action_idx]
                            else:
                                order = random.choice(unit_orders)
                                action_idx = self.action_encoder.encode(order)
                                log_prob = 0.0
                            
                            orders.append(order)
                            
                            # Store experience (reward is filled in after the phase is processed)
                            self.agent.buffer.add(
                                state=state,
                                action=action_idx,
                                reward=0.0,  # Placeholder
                                done=False,
                                log_prob=log_prob,
                                value=value
                            )
                            phase_entries.append((len(self.agent.buffer) - 1, power_name))
                        else:
                            # No order is in the vocabulary: fall back to random, nothing stored
                            orders.append(random.choice(unit_orders))
                    else:
                        # No valid orders - unit holds
                        pass
                
                # Submit orders
                game.set_orders(power_name, orders)
            
            # Process the phase
            game.process()
            steps += 1
            
            # Compute rewards
            done = game.is_game_done or steps >= max_steps
            rewards = self.reward_shaper.compute_rewards(game, done)
            
            # Accumulate per-power episode totals
            for power_name, reward in rewards.items():
                episode_rewards[power_name] += reward
            
            # Give every experience from this phase its own power's reward.
            # The buffer holds one entry per unit, so a phase can add more than 7 entries.
            for idx, power_name in phase_entries:
                self.agent.buffer.rewards[idx] = rewards[power_name]
                self.agent.buffer.dones[idx] = done
        
        # Determine winner
        winner = None
        final_scs = {}
        state = game.get_state()
        for power_name in POWERS:
            sc_count = len(state['centers'].get(power_name, []))
            final_scs[power_name] = sc_count
            if sc_count >= VICTORY_CENTERS:
                winner = power_name
        
        return {
            'winner': winner,
            'steps': steps,
            'rewards': episode_rewards,
            'final_scs': final_scs,
            'phase': game.get_current_phase()
        }
    
    def train(self, num_games: int, 
              update_every: int = 5,
              eval_every: int = 50,
              save_every: int = 100):
        """
        Main training loop.
        
        Args:
            num_games: Total games to play
            update_every: Games between policy updates
            eval_every: Games between evaluations
            save_every: Games between checkpoint saves
        """
        print('='*60)
        print('SELF-PLAY TRAINING')
        print('='*60)
        print(f'Games: {num_games}')
        print(f'Update every: {update_every} games')
        print(f'Device: {device}')
        print('='*60 + '\n')
        
        pbar = tqdm(range(num_games), desc='Training')
        
        for game_num in pbar:
            # Play game
            stats = self.play_game()
            
            # Record stats
            total_reward = sum(stats['rewards'].values())
            self.history['episode_rewards'].append(total_reward)
            self.history['episode_lengths'].append(stats['steps'])
            self.history['sc_counts'].append(stats['final_scs'])
            
            if stats['winner']:
                self.history['wins'][stats['winner']] += 1
            else:
                self.history['draws'] += 1
            
            # Log to tensorboard
            self.writer.add_scalar('Reward/Total', total_reward, game_num)
            self.writer.add_scalar('Game/Length', stats['steps'], game_num)
            
            # Update policy
            if (game_num + 1) % update_every == 0 and len(self.agent.buffer) > 0:
                metrics = self.agent.update(
                    num_epochs=self.config.get('ppo_epochs', 4),
                    batch_size=self.config.get('batch_size', 64)
                )
                
                if metrics:
                    self.history['policy_loss'].append(metrics['policy_loss'])
                    self.history['value_loss'].append(metrics['value_loss'])
                    self.history['entropy'].append(metrics['entropy'])
                    
                    self.writer.add_scalar('Loss/Policy', metrics['policy_loss'], game_num)
                    self.writer.add_scalar('Loss/Value', metrics['value_loss'], game_num)
                    self.writer.add_scalar('Loss/Entropy', metrics['entropy'], game_num)
            
            # Update progress bar
            recent_rewards = self.history['episode_rewards'][-100:]
            recent_lengths = self.history['episode_lengths'][-100:]
            total_wins = sum(self.history['wins'].values())
            
            pbar.set_postfix({
                'reward': f'{np.mean(recent_rewards):.2f}',
                'length': f'{np.mean(recent_lengths):.1f}',
                'wins': total_wins,
                'draws': self.history['draws']
            })
            
            # Save checkpoint
            if (game_num + 1) % save_every == 0:
                self.save_checkpoint(game_num + 1)
            
            # Evaluation
            if (game_num + 1) % eval_every == 0:
                self.evaluate(num_games=10)
        
        self.writer.close()
        print('\nTraining complete!')
        return self.history
    
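    def evaluate(self, num_games: int = 10):
        """
        Evaluate current policy with deterministic actions.
        
        Note: play_game() writes to the agent's buffer, which is cleared after
        every evaluation game. Any experience still waiting for a PPO update is
        discarded too, so keep eval_every a multiple of update_every.
        """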
        print(f'\n--- Evaluation ({num_games} games) ---')
        
        eval_wins = defaultdict(int)
        eval_scs = defaultdict(list)
        
        for _ in range(num_games):
            stats = self.play_game(deterministic=True)
            self.agent.buffer.clear()  # Don't train on eval games
            
            if stats['winner']:
                eval_wins[stats['winner']] += 1
            
            for power, scs in stats['final_scs'].items():
                eval_scs[power].append(scs)
        
        # Print results
        for power in POWERS:
            wins = eval_wins[power]
            avg_sc = np.mean(eval_scs[power]) if eval_scs[power] else 0
            print(f'  {power}: {wins} wins, avg {avg_sc:.1f} SCs')
        
        draws = num_games - sum(eval_wins.values())
        print(f'  Draws: {draws}')
        print()
    
    def save_checkpoint(self, game_num: int):
        """Save training checkpoint."""
        path = f'selfplay_checkpoint_{game_num}.pt'
        self.agent.save(path)
        print(f'\nCheckpoint saved: {path}')

print('SelfPlayTrainer defined')
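
`play_game` passes `valid_indices` to `select_action`, whose implementation lives in the `PPOAgent` cell. As a rough illustration of what restricting a policy to legal orders can look like, the sketch below applies a softmax over only the valid entries of a logit vector. The helper names `masked_probs` and `sample_masked` are hypothetical and use plain Python lists rather than the notebook's tensors; the actual `PPOAgent` may mask differently.

```python
import math
import random
from typing import Dict, List


def masked_probs(logits: List[float], valid_indices: List[int]) -> Dict[int, float]:
    """Softmax restricted to valid_indices; all other actions get zero mass."""
    max_logit = max(logits[i] for i in valid_indices)  # for numerical stability
    exps = {i: math.exp(logits[i] - max_logit) for i in valid_indices}
    total = sum(exps.values())
    return {i: e / total for i, e in exps.items()}


def sample_masked(logits, valid_indices, deterministic=False, rng=None):
    """Return (action_index, log_prob) chosen among the valid actions only."""
    probs = masked_probs(logits, valid_indices)
    if deterministic:
        # Greedy choice: the most probable valid action
        action = max(probs, key=probs.get)
    else:
        rng = rng or random
        action = rng.choices(list(probs), weights=list(probs.values()), k=1)[0]
    return action, math.log(probs[action])


logits = [0.0, 0.0, 2.0, 1.0, -1.0]   # indices 0 and 1 stand in for PAD / UNK
action, log_prob = sample_masked(logits, valid_indices=[2, 4], deterministic=True)
print(action, round(log_prob, 3))
```

Because the normalization runs over the valid set only, the returned log-probability is that of the masked distribution, which is the quantity a PPO ratio would need to stay consistent between collection and update.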

## 10. Training Configuration

In [None]:
CONFIG = {
    # Training
    'num_games': 500,  # Start with 500, increase for better results
    'max_game_length': 100,  # Max phases per game
    'update_every': 5,  # Games between PPO updates
    'eval_every': 100,  # Games between evaluations
    'save_every': 100,  # Games between saves
    
    # PPO hyperparameters
    'lr': 2.5e-4,
    'gamma': 0.99,
    'gae_lambda': 0.95,
    'clip_epsilon': 0.2,
    'entropy_coef': 0.01,
    'ppo_epochs': 4,
    'batch_size': 64,
    
    # Reward shaping
    'win_reward': 1.0,
    'sc_gain_reward': 0.1,
    'sc_loss_reward': -0.1,
    'survival_reward': 0.02,
}

print('Training Configuration:')
print('-' * 40)
for k, v in CONFIG.items():
    print(f'  {k}: {v}')
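
The three interval settings interact. With `num_games = 500` and `update_every = 5` there are at most 100 PPO updates, and because `evaluate()` clears the experience buffer, an evaluation that falls between two updates would discard pending training data. The sketch below, using a hypothetical helper `schedule_summary`, counts the events implied by a config and flags that misalignment.

```python
from typing import Dict


def schedule_summary(config: Dict) -> Dict:
    """Count updates, evaluations and checkpoints implied by the intervals."""
    n = config['num_games']
    return {
        # Upper bound: an update is skipped if the buffer happens to be empty
        'ppo_updates': n // config['update_every'],
        'evaluations': n // config['eval_every'],
        'checkpoints': n // config['save_every'],
        # True when every evaluation coincides with a policy update
        'eval_aligned_with_update': config['eval_every'] % config['update_every'] == 0,
    }


example = {'num_games': 500, 'update_every': 5, 'eval_every': 100, 'save_every': 100}
print(schedule_summary(example))
```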

## 11. Run Training!

In [None]:
# Create trainer
trainer = SelfPlayTrainer(CONFIG)

# Train!
history = trainer.train(
    num_games=CONFIG['num_games'],
    update_every=CONFIG['update_every'],
    eval_every=CONFIG['eval_every'],
    save_every=CONFIG['save_every']
)

## 12. Results & Visualization

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# 1. Episode Rewards
ax = axes[0, 0]
rewards = history['episode_rewards']
ax.plot(rewards, alpha=0.3, color='blue', label='Raw')
window = 50
if len(rewards) >= window:
    ma = np.convolve(rewards, np.ones(window)/window, mode='valid')
    ax.plot(range(window-1, len(rewards)), ma, color='red', linewidth=2, label=f'MA-{window}')
ax.set_xlabel('Game')
ax.set_ylabel('Total Reward')
ax.set_title('Episode Rewards')
ax.legend()
ax.grid(True, alpha=0.3)

# 2. Game Lengths
ax = axes[0, 1]
lengths = history['episode_lengths']
ax.plot(lengths, alpha=0.3, color='green', label='Raw')
if len(lengths) >= window:
    ma = np.convolve(lengths, np.ones(window)/window, mode='valid')
    ax.plot(range(window-1, len(lengths)), ma, color='red', linewidth=2, label=f'MA-{window}')
ax.set_xlabel('Game')
ax.set_ylabel('Phases')
ax.set_title('Game Length')
ax.legend()
ax.grid(True, alpha=0.3)

# 3. Win Distribution
ax = axes[1, 0]
wins = history['wins']
categories = POWERS + ['Draw']
counts = [wins.get(p, 0) for p in POWERS] + [history['draws']]
colors = plt.cm.Set3(range(len(categories)))
bars = ax.bar(categories, counts, color=colors, edgecolor='black')
ax.set_xlabel('Outcome')
ax.set_ylabel('Count')
ax.set_title('Win Distribution')
ax.tick_params(axis='x', rotation=45)

# Add value labels
for bar, count in zip(bars, counts):
    if count > 0:
        ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 1,
                str(count), ha='center', va='bottom', fontsize=9)

# 4. Training Losses
ax = axes[1, 1]
if history['policy_loss']:
    ax.plot(history['policy_loss'], label='Policy Loss', color='purple')
    ax.plot(history['value_loss'], label='Value Loss', color='orange')
    ax.set_xlabel('Update')
    ax.set_ylabel('Loss')
    ax.set_title('Training Losses')
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('selfplay_training_results.png', dpi=150)
plt.show()
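
The smoothed curves use a `'valid'` moving average, which only produces a value once a full window is available; that is why the x-axis of the smoothed line starts at `window - 1`. A small pure-Python sketch of the same computation (the helper name `moving_average` is hypothetical) makes the output length and offset explicit.

```python
from typing import List, Tuple


def moving_average(values: List[float], window: int) -> Tuple[List[int], List[float]]:
    """Trailing mean over full windows only, like np.convolve(..., mode='valid')."""
    if window <= 0 or len(values) < window:
        return [], []
    xs, ys = [], []
    running = sum(values[:window])
    for end in range(window - 1, len(values)):
        if end >= window:
            # Slide the window: add the newest value, drop the oldest
            running += values[end] - values[end - window]
        xs.append(end)
        ys.append(running / window)
    return xs, ys


xs, ys = moving_average([1.0, 2.0, 3.0, 4.0, 5.0], window=3)
print(xs, ys)
```

For a series of length `n` the result has `n - window + 1` points, so with `window = 50` no smoothed curve appears until at least 50 games have been recorded.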

In [None]:
# Print final statistics
print('='*60)
print('TRAINING SUMMARY')
print('='*60)

print(f"\nTotal games: {CONFIG['num_games']}")
print(f"Average reward (last 100): {np.mean(history['episode_rewards'][-100:]):.3f}")
print(f"Average length (last 100): {np.mean(history['episode_lengths'][-100:]):.1f} phases")

print(f"\nWin Distribution:")
total_wins = sum(history['wins'].values())
for power in POWERS:
    count = history['wins'].get(power, 0)
    pct = count / CONFIG['num_games'] * 100
    print(f"  {power}: {count} ({pct:.1f}%)")

print(f"  Draws: {history['draws']} ({history['draws']/CONFIG['num_games']*100:.1f}%)")

# Share of games that ended in a solo victory rather than a draw
if total_wins > 0:
    print(f"\nDecisive games (solo victory): {total_wins/(total_wins+history['draws'])*100:.1f}%")

print('='*60)

## 13. Evaluate vs Random Baseline

In [None]:
def evaluate_vs_random(trainer, agent_power: str = 'FRANCE', num_games: int = 20):
    """
    Evaluate trained agent vs random opponents.
    Agent plays as one power, all others play randomly.
    """
    print(f'\nEvaluating {agent_power} vs Random ({num_games} games)...')
    
    wins = 0
    total_scs = []
    
    for game_idx in tqdm(range(num_games), desc='Eval vs Random'):
        game = Game()
        steps = 0
        
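        # Use the same phase limit as training instead of a hard-coded value
        while not game.is_game_done and steps < trainer.config.get('max_game_length', 100):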
            possible_orders = game.get_all_possible_orders()
            
            for power_name in POWERS:
                power = game.get_power(power_name)
                if not power.units:
                    continue
                
                orders = []
                
                if power_name == agent_power:
                    # Agent's turn - use trained policy
                    state = trainer.state_encoder.encode(game, power_name)
                    
                    for unit in power.units:
                        unit_loc = unit.split()[-1].split('/')[0]
                        if unit_loc in possible_orders and possible_orders[unit_loc]:
                            unit_orders = possible_orders[unit_loc]
                            
                            # Get valid indices
                            valid_indices = []
                            idx_to_order = {}
                            for order in unit_orders:
                                idx = trainer.action_encoder.encode(order)
                                if idx > 1:
                                    valid_indices.append(idx)
                                    idx_to_order[idx] = order
                            
                            if valid_indices:
                                action_idx, _, _ = trainer.agent.select_action(
                                    state, valid_indices, deterministic=True
                                )
                                if action_idx in idx_to_order:
                                    orders.append(idx_to_order[action_idx])
                                else:
                                    orders.append(random.choice(unit_orders))
                            else:
                                orders.append(random.choice(unit_orders))
                else:
                    # Random opponent
                    for unit in power.units:
                        unit_loc = unit.split()[-1].split('/')[0]
                        if unit_loc in possible_orders and possible_orders[unit_loc]:
                            orders.append(random.choice(possible_orders[unit_loc]))
                
                game.set_orders(power_name, orders)
            
            game.process()
            steps += 1
        
        # Check result
        state = game.get_state()
        agent_scs = len(state['centers'].get(agent_power, []))
        total_scs.append(agent_scs)
        
        # Only a solo victory (reaching VICTORY_CENTERS) counts as a win;
        # finishing with the most SCs short of that threshold does not.
        if agent_scs >= VICTORY_CENTERS:
            wins += 1
    
    win_rate = wins / num_games
    avg_scs = np.mean(total_scs)
    
    print(f'\nResults:')
    print(f'  Win rate: {win_rate*100:.1f}% ({wins}/{num_games})')
    print(f'  Average SCs: {avg_scs:.1f}')
    
    return win_rate, avg_scs

# Run evaluation
win_rate, avg_scs = evaluate_vs_random(trainer, agent_power='FRANCE', num_games=20)
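
With only 20 evaluation games, the reported win rate carries wide uncertainty. One common way to express this, not part of the original notebook, is a Wilson score interval for a binomial proportion. The helper `wilson_interval` below is a hypothetical addition; it assumes independent games and uses z = 1.96 for roughly 95% coverage.

```python
import math
from typing import Tuple


def wilson_interval(wins: int, games: int, z: float = 1.96) -> Tuple[float, float]:
    """Wilson score interval for a win rate estimated from `games` trials."""
    if games <= 0:
        raise ValueError('games must be positive')
    p_hat = wins / games
    denom = 1.0 + z * z / games
    center = (p_hat + z * z / (2 * games)) / denom
    half_width = z * math.sqrt(p_hat * (1 - p_hat) / games
                               + z * z / (4 * games * games)) / denom
    # Clamp to [0, 1] to absorb floating-point rounding at the boundaries
    return max(0.0, center - half_width), min(1.0, center + half_width)


low, high = wilson_interval(wins=0, games=20)
print(f'0/20 wins -> 95% interval [{low:.3f}, {high:.3f}]')
```

Even a run with zero solo victories in 20 games is compatible with a true win rate of up to about 16% under these assumptions, which is one reason to report average SCs alongside the win rate.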

## 14. Save Final Model & Results

In [None]:
# Save final model
trainer.agent.save('selfplay_final_model.pt')

# Save action encoder vocabulary
trainer.action_encoder.save('action_vocab.json')

# Save training history
history_save = {
    'episode_rewards': history['episode_rewards'],
    'episode_lengths': history['episode_lengths'],
    'wins': dict(history['wins']),
    'draws': history['draws'],
    'policy_loss': history['policy_loss'],
    'value_loss': history['value_loss'],
    'entropy': history['entropy'],
    'config': CONFIG,
    'eval_vs_random': {
        'win_rate': win_rate,
        'avg_scs': avg_scs
    }
}

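# default=float converts NumPy scalars that json cannot serialize natively (e.g. np.float32)
with open('training_history.json', 'w') as f:
    json.dump(history_save, f, indent=2, default=float)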

print('Saved files:')
print('  - selfplay_final_model.pt')
print('  - action_vocab.json')
print('  - training_history.json')
print('  - selfplay_training_results.png')

## 15. Download Files

In [None]:
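# google.colab is only available inside a Colab runtime
try:
    from google.colab import files
except ImportError:
    files = None
    print('Not running in Colab; files are in the current working directory.')

if files is not None:
    for filename in ['selfplay_final_model.pt', 'action_vocab.json',
                     'training_history.json', 'selfplay_training_results.png']:
        files.download(filename)
    print('\nFiles downloaded!')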

## 16. Summary & Next Steps

### What We Built
- Complete self-play RL pipeline using official `diplomacy` package
- PPO algorithm with proper state encoding and action masking
- Reward shaping for faster learning
- Evaluation framework vs random baseline

### Key Results
- Win distribution across powers
- Learning curves (rewards, game length, losses)
- Performance vs random opponents

### Addressing RQ1: Overfitting in Pure Self-Play
To fully quantify overfitting:
1. Evaluate against **checkpoints from earlier training**
2. Evaluate against **diverse opponent policies**
3. Check for **strategy collapse** (narrow strategy distribution)
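
The first check can be organized as a cross-play table in which each checkpoint controls a focal power against copies of every other checkpoint. The sketch below is one hypothetical way to structure it: `play_match` is an assumed callback (not defined in this notebook) that returns the focal policy's final SC count, and the stub used here only stands in for real games with made-up numbers.

```python
from typing import Callable, Dict, List, Tuple

# play_match(focal, opponent) -> focal policy's final SC count (assumed interface)
PlayMatch = Callable[[str, str], float]


def cross_play_table(checkpoints: List[str], play_match: PlayMatch,
                     games_per_pair: int = 4) -> Dict[Tuple[str, str], float]:
    """Average score of each checkpoint when facing each checkpoint as opponents."""
    table = {}
    for focal in checkpoints:
        for opponent in checkpoints:
            scores = [play_match(focal, opponent) for _ in range(games_per_pair)]
            table[(focal, opponent)] = sum(scores) / len(scores)
    return table


def self_play_gap(table, latest: str, others: List[str]) -> float:
    """Score vs. itself minus mean score vs. other checkpoints.

    A large positive gap suggests the latest policy does well mainly
    against its own style, one possible symptom of overfitting.
    """
    vs_others = sum(table[(latest, o)] for o in others) / len(others)
    return table[(latest, latest)] - vs_others


# Stub standing in for real games: purely illustrative numbers, not results
fake_scores = {('ckpt_500', 'ckpt_500'): 6.0, ('ckpt_500', 'ckpt_100'): 3.0,
               ('ckpt_100', 'ckpt_500'): 4.0, ('ckpt_100', 'ckpt_100'): 5.0}
table = cross_play_table(['ckpt_100', 'ckpt_500'],
                         lambda f, o: fake_scores[(f, o)], games_per_pair=2)
print(self_play_gap(table, 'ckpt_500', ['ckpt_100']))
```

In practice `play_match` would load two saved checkpoints and run a full game, and the gap would be tracked across training to see whether it widens.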

### Next Steps
1. **Human-Regularized RL (RQ2)**: Add KL divergence penalty toward BC policy
2. **Population-Based Training (RQ3)**: Train against diverse opponents
3. **Full Evaluation**: Compare self-play vs human-regularized vs population-based
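
For the human-regularized direction, the KL penalty can be written as an extra term added to the policy loss, weighted by a coefficient. The sketch below computes KL(pi || pi_BC) for discrete action distributions in plain Python. The names `kl_divergence`, `regularized_loss` and the coefficient `beta` are hypothetical, and how the behavioral-cloning (BC) policy is obtained is left open.

```python
import math
from typing import Sequence


def kl_divergence(pi: Sequence[float], pi_bc: Sequence[float], eps: float = 1e-12) -> float:
    """KL(pi || pi_bc) for two discrete distributions over the same actions."""
    if len(pi) != len(pi_bc):
        raise ValueError('distributions must have the same length')
    total = 0.0
    for p, q in zip(pi, pi_bc):
        if p > 0.0:
            # eps guards against log(0) when the BC policy assigns zero mass
            total += p * math.log(p / max(q, eps))
    return total


def regularized_loss(policy_loss: float, pi, pi_bc, beta: float = 0.1) -> float:
    """Policy loss plus a penalty for drifting away from the BC policy."""
    return policy_loss + beta * kl_divergence(pi, pi_bc)


print(kl_divergence([0.5, 0.5], [0.5, 0.5]))
print(round(kl_divergence([0.5, 0.5], [0.9, 0.1]), 4))
```

The penalty is zero when the learned policy matches the BC policy and grows as probability mass moves to actions the BC policy considers unlikely; `beta` controls how strongly that drift is discouraged.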