# Self-Play Reinforcement Learning for No-Press Diplomacy

**Project:** Improve Self-Play for Diplomacy  
**Authors:** Giacomo Colosio, Maciej Tasarz, Jakub Seliga, Luka Ivcevic  
**Course:** ISP - UPC Barcelona, Fall 2025/26

---

## Overview

This notebook implements **Self-Play Reinforcement Learning** for No-Press Diplomacy following:

1. **Silver et al. (2017)** - AlphaGo Zero: Tabula rasa self-play
2. **Bakhtin et al. (2021)** - DORA: Double Oracle RL for Diplomacy
3. **Bakhtin et al. (2022)** - Diplodocus: Human-regularized self-play

### Architecture
- **Policy Network**: Predicts action probabilities π(a|s)
- **Value Network**: Estimates state value V(s)
- **Training**: PPO (Proximal Policy Optimization) with self-play
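
For reference, the clipped PPO objective of Schulman et al. (2017) can be written with the policy π and value V named above. The symbols $\hat{A}_t$ (advantage estimate), $\hat{R}_t$ (return target), $\epsilon$ (clip range), and the weights $c_1$, $c_2$ are notation introduced here for illustration. In this notebook they correspond to the `clip_epsilon`, `value_coef`, and `entropy_coef` arguments of the PPO agent.

```latex
r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\text{old}}}(a_t \mid s_t)}

L^{\text{CLIP}}(\theta) = \mathbb{E}_t\left[ \min\left( r_t(\theta)\,\hat{A}_t,\; \operatorname{clip}\big(r_t(\theta),\, 1-\epsilon,\, 1+\epsilon\big)\,\hat{A}_t \right) \right]

L(\theta) = -L^{\text{CLIP}}(\theta) + c_1\,\mathbb{E}_t\left[ \big(V_\theta(s_t) - \hat{R}_t\big)^2 \right] - c_2\,\mathbb{E}_t\left[ H\big[\pi_\theta(\cdot \mid s_t)\big] \right]
```

The total loss $L(\theta)$ is minimized, so the entropy term enters with a negative sign and acts as an exploration bonus.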

### Research Questions Addressed
- **RQ1**: Quantify overfitting in pure self-play
- **RQ2**: Establish baseline for comparison with human-regularized RL

**Requirements:** GPU runtime recommended

## 1. Setup & Imports

In [None]:
# Install diplomacy environment
!pip install diplomacy torch numpy matplotlib tqdm --quiet

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical

import numpy as np
import json
import os
import random
import copy
from collections import deque, defaultdict, namedtuple
from typing import Dict, List, Tuple, Optional
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from datetime import datetime

# Check device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {device}')
if torch.cuda.is_available():
    print(f'GPU: {torch.cuda.get_device_name(0)}')

## 2. Game Constants & Environment

In [None]:
# Diplomacy Constants
POWERS = ['AUSTRIA', 'ENGLAND', 'FRANCE', 'GERMANY', 'ITALY', 'RUSSIA', 'TURKEY']
NUM_POWERS = len(POWERS)

# All 75 locations on standard map
LOCATIONS = [
    # Supply centers (34)
    'ANK', 'BEL', 'BER', 'BRE', 'BUD', 'BUL', 'CON', 'DEN', 'EDI', 'GRE',
    'HOL', 'KIE', 'LON', 'LVP', 'MAR', 'MOS', 'MUN', 'NAP', 'NWY', 'PAR',
    'POR', 'ROM', 'RUM', 'SER', 'SEV', 'SMY', 'SPA', 'STP', 'SWE', 'TRI',
    'TUN', 'VEN', 'VIE', 'WAR',
    # Non-supply center land (22)
    'ALB', 'APU', 'ARM', 'BOH', 'BUR', 'CLY', 'FIN', 'GAL', 'GAS', 'LVN',
    'NAF', 'PIC', 'PIE', 'PRU', 'RUH', 'SIL', 'SYR', 'TUS', 'TYR', 'UKR',
    'WAL', 'YOR',
    # Sea zones (19)
    'ADR', 'AEG', 'BAL', 'BAR', 'BLA', 'BOT', 'EAS', 'ENG', 'GOL', 'HEL',
    'ION', 'IRI', 'MAO', 'NAO', 'NTH', 'NWG', 'SKA', 'TYS', 'WES'
]
NUM_LOCATIONS = len(LOCATIONS)
SUPPLY_CENTERS = LOCATIONS[:34]
VICTORY_THRESHOLD = 18  # SCs needed to win

# Location indices
LOC_TO_IDX = {loc: i for i, loc in enumerate(LOCATIONS)}
IDX_TO_LOC = {i: loc for i, loc in enumerate(LOCATIONS)}

# Starting positions
STARTING_UNITS = {
    'AUSTRIA': ['A VIE', 'A BUD', 'F TRI'],
    'ENGLAND': ['F LON', 'F EDI', 'A LVP'],
    'FRANCE': ['F BRE', 'A PAR', 'A MAR'],
    'GERMANY': ['F KIE', 'A BER', 'A MUN'],
    'ITALY': ['F NAP', 'A ROM', 'A VEN'],
    'RUSSIA': ['F STP/SC', 'A MOS', 'A WAR', 'F SEV'],
    'TURKEY': ['F ANK', 'A CON', 'A SMY']
}

STARTING_CENTERS = {
    'AUSTRIA': ['VIE', 'BUD', 'TRI'],
    'ENGLAND': ['LON', 'EDI', 'LVP'],
    'FRANCE': ['BRE', 'PAR', 'MAR'],
    'GERMANY': ['KIE', 'BER', 'MUN'],
    'ITALY': ['NAP', 'ROM', 'VEN'],
    'RUSSIA': ['STP', 'MOS', 'WAR', 'SEV'],
    'TURKEY': ['ANK', 'CON', 'SMY']
}

print(f'Powers: {NUM_POWERS}')
print(f'Locations: {NUM_LOCATIONS}')
print(f'Supply Centers: {len(SUPPLY_CENTERS)}')
print(f'Victory requires: {VICTORY_THRESHOLD} SCs')

In [None]:
# Adjacency map (simplified - key connections)
ADJACENCIES = {
    'VIE': ['BOH', 'GAL', 'BUD', 'TRI', 'TYR'],
    'BUD': ['VIE', 'GAL', 'RUM', 'SER', 'TRI'],
    'TRI': ['VIE', 'BUD', 'SER', 'ALB', 'ADR', 'VEN', 'TYR'],
    'LON': ['WAL', 'YOR', 'NTH', 'ENG'],
    'EDI': ['YOR', 'NTH', 'NWG', 'CLY', 'LVP'],
    'LVP': ['EDI', 'CLY', 'NAO', 'IRI', 'WAL', 'YOR'],
    'BRE': ['PIC', 'PAR', 'GAS', 'MAO', 'ENG'],
    'PAR': ['PIC', 'BUR', 'GAS', 'BRE'],
    'MAR': ['SPA', 'GAS', 'BUR', 'PIE', 'GOL'],
    'KIE': ['HOL', 'RUH', 'MUN', 'BER', 'BAL', 'HEL', 'DEN'],
    'BER': ['KIE', 'MUN', 'SIL', 'PRU', 'BAL'],
    'MUN': ['KIE', 'BER', 'SIL', 'BOH', 'TYR', 'BUR', 'RUH'],
    'NAP': ['ROM', 'APU', 'ION', 'TYS'],
    'ROM': ['NAP', 'APU', 'VEN', 'TUS', 'TYS'],
    'VEN': ['ROM', 'APU', 'TRI', 'TYR', 'PIE', 'TUS', 'ADR'],
    'MOS': ['STP', 'LVN', 'WAR', 'UKR', 'SEV'],
    'WAR': ['MOS', 'LVN', 'PRU', 'SIL', 'GAL', 'UKR'],
    'SEV': ['MOS', 'UKR', 'RUM', 'BLA', 'ARM'],
    'STP': ['MOS', 'LVN', 'FIN', 'NWY', 'BAR', 'BOT'],
    'ANK': ['CON', 'SMY', 'ARM', 'BLA'],
    'CON': ['ANK', 'SMY', 'BUL', 'BLA', 'AEG'],
    'SMY': ['ANK', 'CON', 'ARM', 'SYR', 'AEG', 'EAS'],
    # Add more as needed...
}

# Fill missing adjacencies with empty lists
for loc in LOCATIONS:
    if loc not in ADJACENCIES:
        ADJACENCIES[loc] = []

print(f'Adjacencies defined for {len([k for k,v in ADJACENCIES.items() if v])} locations')
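
The table above lists neighbours only for the 22 home supply centers, so an edge such as `VIE -> BOH` has no reverse entry. One possible way to fill part of the gap is to mirror every listed edge, as in the sketch below. The helper name `symmetrize` is hypothetical, and the approach assumes adjacency is symmetric, which ignores army/fleet restrictions and split coasts.

```python
from collections import defaultdict

def symmetrize(adjacencies):
    """Return a copy in which every edge src -> dst also appears as dst -> src.

    Hypothetical helper; assumes adjacency is symmetric for all unit types.
    """
    full = defaultdict(list)
    for src, neighbours in adjacencies.items():
        for dst in neighbours:
            if dst not in full[src]:
                full[src].append(dst)
            if src not in full[dst]:
                full[dst].append(src)
    return dict(full)

# Toy subset of the table: BOH and GAL have no entries of their own
partial = {'VIE': ['BOH', 'GAL', 'BUD'], 'BUD': ['VIE', 'GAL']}
complete = symmetrize(partial)
print(complete['BOH'])  # ['VIE']
print(complete['GAL'])  # ['VIE', 'BUD']
```

Locations that appear in no edge at all are still missing after this step and would need to be entered by hand or taken from a full map definition.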

## 3. State Encoder (simplified 1216-dimensional encoding, after Bakhtin et al.)

In [None]:
class DiplomacyStateEncoder:
    """
    Encodes Diplomacy state into fixed-size vector.
    Based on Bakhtin et al. (2022) encoding scheme.
    
    Features per location (75 locations, 16 features each):
        - [0-6]  Unit presence by power (index 0 = encoding power)
        - [7]    Unit type (army=1, fleet or empty=0)
        - [8-14] SC ownership by power (same relative ordering)
        - [15]   Location is a supply center
    Subtotal: 16 features × 75 locations = 1200
    
    Global features (16):
        - [0-6]  SC counts per power (divided by 18)
        - [7-13] Unit counts per power (divided by 17)
        - [14]   Year, [15] season
    
    Total: 1200 + 16 = 1216 dimensions. A fuller per-location scheme
    (14 unit + 8 SC + 14 dislodged = 36 features, 2700 in total) is
    reduced to this layout for efficiency.
    """
    
    def __init__(self):
        self.num_locations = NUM_LOCATIONS
        self.num_powers = NUM_POWERS
        self.features_per_loc = 16  # Simplified
        self.global_features = 16
        self.state_size = self.num_locations * self.features_per_loc + self.global_features
        
    def encode(self, game_state: 'GameState', power: str) -> np.ndarray:
        """
        Encode state from perspective of given power.
        Power-specific encoding helps the model learn power-specific strategies.
        """
        features = np.zeros(self.state_size, dtype=np.float32)
        power_idx = POWERS.index(power)
        
        # Encode each location
        for loc_idx, loc in enumerate(LOCATIONS):
            offset = loc_idx * self.features_per_loc
            
            # Unit presence (relative to current power)
            for p_idx, p in enumerate(POWERS):
                # Reorder so current power is always index 0
                rel_idx = (p_idx - power_idx) % NUM_POWERS
                
                for unit in game_state.units.get(p, []):
                    unit_loc = self._get_unit_location(unit)
                    if unit_loc == loc:
                        features[offset + rel_idx] = 1.0
                        # Unit type: army=1, fleet=0
                        features[offset + 7] = 1.0 if unit.startswith('A') else 0.0
            
            # SC ownership (relative)
            if loc in SUPPLY_CENTERS:
                features[offset + 15] = 1.0  # Is SC
                for p_idx, p in enumerate(POWERS):
                    rel_idx = (p_idx - power_idx) % NUM_POWERS
                    if loc in game_state.centers.get(p, []):
                        features[offset + 8 + rel_idx] = 1.0
        
        # Global features
        g_offset = self.num_locations * self.features_per_loc
        
        # SC counts (normalized, relative ordering)
        for p_idx, p in enumerate(POWERS):
            rel_idx = (p_idx - power_idx) % NUM_POWERS
            sc_count = len(game_state.centers.get(p, []))
            features[g_offset + rel_idx] = sc_count / VICTORY_THRESHOLD
        
        # Unit counts (normalized)
        for p_idx, p in enumerate(POWERS):
            rel_idx = (p_idx - power_idx) % NUM_POWERS
            unit_count = len(game_state.units.get(p, []))
            features[g_offset + 7 + rel_idx] = unit_count / 17.0
        
        # Phase info
        features[g_offset + 14] = (game_state.year - 1901) / 20.0
        features[g_offset + 15] = {'S': 0.0, 'F': 0.5, 'W': 1.0}.get(game_state.season, 0.0)
        
        return features
    
    def _get_unit_location(self, unit: str) -> str:
        """Extract location from unit string like 'A PAR' or 'F STP/SC'."""
        parts = unit.split()
        if len(parts) >= 2:
            return parts[1].split('/')[0]
        return ''

state_encoder = DiplomacyStateEncoder()
print(f'State size: {state_encoder.state_size}')
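
The encoder rotates the power ordering so that the encoding power always occupies slot 0. The sketch below reproduces only that index arithmetic on its own. The function names `relative_index` and `unit_feature_index` are hypothetical and introduced here for illustration.

```python
POWERS = ['AUSTRIA', 'ENGLAND', 'FRANCE', 'GERMANY', 'ITALY', 'RUSSIA', 'TURKEY']
FEATURES_PER_LOC = 16  # same value as features_per_loc in the encoder

def relative_index(power, viewer):
    """Slot of `power` when the state is encoded from `viewer`'s perspective."""
    return (POWERS.index(power) - POWERS.index(viewer)) % len(POWERS)

def unit_feature_index(loc_idx, power, viewer):
    """Flat index of the unit-presence feature for `power` at location `loc_idx`."""
    return loc_idx * FEATURES_PER_LOC + relative_index(power, viewer)

print(relative_index('FRANCE', 'FRANCE'))         # 0
print(relative_index('AUSTRIA', 'FRANCE'))        # 5
print(unit_feature_index(3, 'FRANCE', 'FRANCE'))  # 48
```

Because the rotation is cyclic, the same opponent lands in a different slot for each viewer, so the network sees opponents by relative position in the list and not by name.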

## 4. Action Space

In [None]:
class ActionSpace:
    """
    Defines valid actions for Diplomacy.
    
    Action types:
        - HOLD: Unit stays in place
        - MOVE: Unit moves to adjacent location
        - SUPPORT: Unit supports another unit
        - CONVOY: Fleet convoys army
    
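    We use a simplified action encoding:
        action = (unit_loc, order_type, target_loc)
    
    Total theoretical actions: 75 × 5 × 75 = 28,125 (five order types,
    counting SUPPORT_HOLD and SUPPORT_MOVE separately).
    The vocabulary built below covers only HOLD and MOVE orders,
    and invalid actions are masked per state.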
    """
    
    ORDER_TYPES = ['HOLD', 'MOVE', 'SUPPORT_HOLD', 'SUPPORT_MOVE', 'CONVOY']
    
    def __init__(self):
        self.num_order_types = len(self.ORDER_TYPES)
        # Simplified: encode as (source, target) pairs + order type
        # Action index = source * num_locations * num_types + target * num_types + order_type
        self.action_size = NUM_LOCATIONS * NUM_LOCATIONS * self.num_order_types
        
        # For practical use, we'll work with a vocabulary
        self.action_vocab = {}
        self.idx_to_action = {}
        self._build_vocab()
        
    def _build_vocab(self):
        """Build vocabulary of common actions."""
        idx = 0
        
        # Add HOLD for each location
        for loc in LOCATIONS:
            for unit_type in ['A', 'F']:
                action = f'{unit_type} {loc} H'
                self.action_vocab[action] = idx
                self.idx_to_action[idx] = action
                idx += 1
        
        # Add MOVE for each location pair
        for src in LOCATIONS:
            for dst in ADJACENCIES.get(src, []):
                for unit_type in ['A', 'F']:
                    action = f'{unit_type} {src} - {dst}'
                    self.action_vocab[action] = idx
                    self.idx_to_action[idx] = action
                    idx += 1
        
        self.vocab_size = len(self.action_vocab)
        print(f'Action vocabulary size: {self.vocab_size}')
    
    def get_valid_actions(self, game_state: 'GameState', power: str) -> List[int]:
        """
        Get valid action indices for a power in given state.
        Returns list of valid action indices.
        """
        valid = []
        units = game_state.units.get(power, [])
        
        for unit in units:
            unit_type = unit[0]  # 'A' or 'F'
            unit_loc = self._get_unit_location(unit)
            
            # HOLD is always valid
            hold_action = f'{unit_type} {unit_loc} H'
            if hold_action in self.action_vocab:
                valid.append(self.action_vocab[hold_action])
            
            # MOVE to adjacent locations
            for adj in ADJACENCIES.get(unit_loc, []):
                move_action = f'{unit_type} {unit_loc} - {adj}'
                if move_action in self.action_vocab:
                    valid.append(self.action_vocab[move_action])
        
        return valid if valid else [0]  # Return at least one action
    
    def _get_unit_location(self, unit: str) -> str:
        parts = unit.split()
        if len(parts) >= 2:
            return parts[1].split('/')[0]
        return ''
    
    def decode_action(self, idx: int) -> str:
        return self.idx_to_action.get(idx, 'UNKNOWN')
    
    def encode_action(self, action: str) -> int:
        return self.action_vocab.get(action, 0)

action_space = ActionSpace()
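
The `action_size` formula in `ActionSpace.__init__` implies a flat index over (source, target, order type) triples. A minimal sketch of that mapping and its inverse follows. The function names `flat_index` and `unflatten` are hypothetical, and the vocabulary actually used for training does not rely on this scheme.

```python
NUM_LOCATIONS = 75
NUM_ORDER_TYPES = 5  # HOLD, MOVE, SUPPORT_HOLD, SUPPORT_MOVE, CONVOY

def flat_index(source, target, order_type):
    """source * num_locations * num_types + target * num_types + order_type."""
    return (source * NUM_LOCATIONS + target) * NUM_ORDER_TYPES + order_type

def unflatten(index):
    """Inverse of flat_index: recover (source, target, order_type)."""
    rest, order_type = divmod(index, NUM_ORDER_TYPES)
    source, target = divmod(rest, NUM_LOCATIONS)
    return source, target, order_type

# PAR has index 19 and BUR index 38 in LOCATIONS; order type 1 is MOVE
idx = flat_index(19, 38, 1)
print(idx)             # 7316
print(unflatten(idx))  # (19, 38, 1)
```

The largest index is `flat_index(74, 74, 4) = 28124`, one below the theoretical size of 28,125.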

## 5. Game State & Simplified Environment

In [None]:
class GameState:
    """
    Represents Diplomacy game state.
    Simplified implementation for self-play training.
    """
    
    def __init__(self):
        self.units = copy.deepcopy(STARTING_UNITS)
        self.centers = copy.deepcopy(STARTING_CENTERS)
        self.year = 1901
        self.season = 'S'  # S=Spring, F=Fall, W=Winter
        self.phase = 'M'   # M=Movement, R=Retreat, A=Adjustment
        
    def clone(self) -> 'GameState':
        new_state = GameState.__new__(GameState)
        new_state.units = copy.deepcopy(self.units)
        new_state.centers = copy.deepcopy(self.centers)
        new_state.year = self.year
        new_state.season = self.season
        new_state.phase = self.phase
        return new_state
    
    def get_phase_name(self) -> str:
        return f'{self.season}{self.year}{self.phase}'
    
    def get_sc_count(self, power: str) -> int:
        return len(self.centers.get(power, []))
    
    def get_winner(self) -> Optional[str]:
        """Return winner if someone has 18+ SCs, else None."""
        for power in POWERS:
            if self.get_sc_count(power) >= VICTORY_THRESHOLD:
                return power
        return None
    
    def is_eliminated(self, power: str) -> bool:
        return len(self.units.get(power, [])) == 0
    
    def get_alive_powers(self) -> List[str]:
        return [p for p in POWERS if not self.is_eliminated(p)]

In [None]:
class DiplomacyEnv:
    """
    Simplified Diplomacy environment for self-play.
    
    Key simplifications:
    - Simultaneous move resolution without SUPPORT or CONVOY
    - No retreat phases (conflicting moves bounce; no unit is dislodged)
    - No build phases (unit counts are not adjusted to SC counts)
    
    This allows faster training iterations.
    For full rules, use the 'diplomacy' package.
    """
    
    MAX_YEARS = 20  # Game ends after 1920
    
    def __init__(self):
        self.state = None
        self.action_space = action_space
        self.state_encoder = state_encoder
        
    def reset(self) -> GameState:
        """Reset to initial game state."""
        self.state = GameState()
        return self.state
    
    def step(self, actions: Dict[str, List[str]]) -> Tuple[GameState, Dict[str, float], bool]:
        """
        Execute one game step with all powers' actions.
        
        Args:
            actions: Dict mapping power -> list of order strings
            
        Returns:
            (new_state, rewards, done)
        """
        # Resolve moves (simplified)
        self._resolve_moves(actions)
        
        # Update phase
        self._advance_phase()
        
        # Calculate rewards
        rewards = self._calculate_rewards()
        
        # Check if game is done
        done = self._is_game_over()
        
        return self.state, rewards, done
    
    def _resolve_moves(self, actions: Dict[str, List[str]]):
        """
        Simplified move resolution.
        
        Rules:
        - A move succeeds if it is the only move into its destination
          and the destination is empty or vacated this turn
        - Conflicting moves: all bounce back
        - Units that hold, bounce, or receive no valid order stay in place
        - SUPPORT and CONVOY orders are not resolved
        """
        # Index current units by location: loc -> (power, unit_type)
        occupants = {}
        for power, units in self.state.units.items():
            for unit in units:
                occupants[self._parse_location(unit)] = (power, unit[0])
        
        # Parse orders; only the first order issued to each unit counts
        moves = {}       # src -> dst
        ordered = set()  # locations whose unit already has an order
        for power, orders in actions.items():
            for order in orders:
                src = self._parse_location(order)
                if occupants.get(src, (None, None))[0] != power or src in ordered:
                    continue
                ordered.add(src)
                # Test for ' - ' rather than ' H': the latter also matches
                # moves such as 'A HOL - BEL'
                parts = order.split(' - ')
                if len(parts) == 2 and parts[1].split():
                    moves[src] = parts[1].split()[0].split('/')[0]
        
        # Group movers by destination
        destinations = defaultdict(list)  # dst -> [src, ...]
        for src, dst in moves.items():
            destinations[dst].append(src)
        
        # Tentatively accept uncontested moves that are not head-to-head swaps
        success = {srcs[0] for dst, srcs in destinations.items()
                   if len(srcs) == 1 and moves.get(dst) != srcs[0]}
        
        # A move fails if its destination stays occupied; repeat until stable
        # because one bounce can block the unit behind it
        changed = True
        while changed:
            changed = False
            for src in list(success):
                dst = moves[src]
                if dst in occupants and dst not in success:
                    success.discard(src)
                    changed = True
        
        # Apply results and capture supply centers
        new_units = {p: [] for p in POWERS}
        for loc, (power, unit_type) in occupants.items():
            new_loc = moves[loc] if loc in success else loc
            new_units[power].append(f'{unit_type} {new_loc}')
            if loc in success and new_loc in SUPPLY_CENTERS:
                for p in POWERS:
                    if new_loc in self.state.centers.get(p, []):
                        self.state.centers[p].remove(new_loc)
                self.state.centers[power].append(new_loc)
        
        # Update state
        self.state.units = new_units
    
    def _parse_location(self, s: str) -> str:
        """Extract location from string."""
        parts = s.split()
        for p in parts:
            loc = p.split('/')[0]
            if loc in LOCATIONS:
                return loc
        return ''
    
    def _advance_phase(self):
        """Advance to next phase."""
        if self.state.season == 'S':
            self.state.season = 'F'
        elif self.state.season == 'F':
            self.state.season = 'W'
        else:  # Winter
            self.state.season = 'S'
            self.state.year += 1
    
    def _calculate_rewards(self) -> Dict[str, float]:
        """
        Calculate rewards for each power.
        
        Reward scheme:
        - +1.0 for winning (18+ SCs)
        - -1/(NUM_POWERS - 1) for every other power once a winner exists
        - -1.0 for elimination
        - +0.01 per SC owned otherwise (intermediate reward)
        
        Per-step bonuses for gaining or losing an SC (±0.1) are not implemented.
        """
        rewards = {}
        winner = self.state.get_winner()
        
        for power in POWERS:
            if winner == power:
                rewards[power] = 1.0
            elif winner is not None:
                rewards[power] = -1.0 / (NUM_POWERS - 1)
            elif self.state.is_eliminated(power):
                rewards[power] = -1.0
            else:
                # Intermediate reward based on SC count
                sc_count = self.state.get_sc_count(power)
                rewards[power] = 0.01 * sc_count
        
        return rewards
    
    def _is_game_over(self) -> bool:
        """Check if game has ended."""
        # Someone won
        if self.state.get_winner() is not None:
            return True
        
        # Time limit
        if self.state.year >= 1901 + self.MAX_YEARS:
            return True
        
        # Only one power left
        alive = self.state.get_alive_powers()
        if len(alive) <= 1:
            return True
        
        return False
    
    def get_observation(self, power: str) -> np.ndarray:
        """Get encoded state from power's perspective."""
        return self.state_encoder.encode(self.state, power)

# Test environment
env = DiplomacyEnv()
state = env.reset()
print(f'Initial state: {state.get_phase_name()}')
print(f'Units per power: {[(p, len(u)) for p, u in state.units.items()]}')
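
When a power reaches 18 supply centers, `_calculate_rewards` gives the winner +1.0 and each other power -1/(NUM_POWERS - 1), which makes the terminal rewards sum to zero. The short check below illustrates that arithmetic on its own. The function name `terminal_rewards` is hypothetical and covers only the terminal case, not the elimination penalty or the per-SC intermediate reward.

```python
POWERS = ['AUSTRIA', 'ENGLAND', 'FRANCE', 'GERMANY', 'ITALY', 'RUSSIA', 'TURKEY']

def terminal_rewards(powers, winner):
    """+1 for the winner, -1/(N-1) for every other power (hypothetical helper)."""
    loss = -1.0 / (len(powers) - 1)
    return {p: 1.0 if p == winner else loss for p in powers}

rewards = terminal_rewards(POWERS, 'FRANCE')
print(rewards['FRANCE'])                   # 1.0
print(abs(sum(rewards.values())) < 1e-9)   # True: terminal rewards are zero-sum
```

The intermediate reward of 0.01 per supply center is not zero-sum, so the totals accumulated over an episode generally do not cancel across powers.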

## 6. Actor-Critic Network (PPO)

In [None]:
class ActorCritic(nn.Module):
    """
    Actor-Critic network for PPO.
    
    Architecture follows AlphaGo Zero / Diplodocus:
    - Shared feature extractor
    - Policy head (actor): outputs action probabilities
    - Value head (critic): outputs state value
    """
    
    def __init__(self, state_size: int, action_size: int, hidden_size: int = 512):
        super().__init__()
        
        self.state_size = state_size
        self.action_size = action_size
        
        # Shared feature extractor
        self.shared = nn.Sequential(
            nn.Linear(state_size, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.ReLU(),
        )
        
        # Policy head (actor)
        self.policy = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Linear(hidden_size // 2, action_size)
        )
        
        # Value head (critic)
        self.value = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Linear(hidden_size // 2, 1)
        )
        
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass.
        
        Returns:
            (policy_logits, value)
        """
        features = self.shared(x)
        policy_logits = self.policy(features)
        value = self.value(features)
        return policy_logits, value
    
    def get_action(self, state: torch.Tensor, valid_actions: List[int] = None,
                   deterministic: bool = False) -> Tuple[int, torch.Tensor, torch.Tensor]:
        """
        Sample action from policy.
        
        Args:
            state: Encoded state tensor
            valid_actions: List of valid action indices (for masking)
            deterministic: If True, return argmax action
            
        Returns:
            (action_idx, log_prob, value)
        """
        policy_logits, value = self.forward(state)
        
        # Mask invalid actions
        if valid_actions is not None and len(valid_actions) > 0:
            mask = torch.ones_like(policy_logits) * float('-inf')
            mask[0, valid_actions] = 0
            policy_logits = policy_logits + mask
        
        probs = F.softmax(policy_logits, dim=-1)
        dist = Categorical(probs)
        
        if deterministic:
            action = probs.argmax(dim=-1)
        else:
            action = dist.sample()
        
        log_prob = dist.log_prob(action)
        
        return action.item(), log_prob, value.squeeze(-1)
    
    def evaluate_action(self, state: torch.Tensor, action: torch.Tensor,
                        valid_actions_batch: List[List[int]] = None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Evaluate actions for PPO update.
        
        Returns:
            (log_probs, values, entropy)
        """
        policy_logits, value = self.forward(state)
        
        # Apply masking if provided
        if valid_actions_batch is not None:
            for i, valid_actions in enumerate(valid_actions_batch):
                if valid_actions:
                    mask = torch.ones(policy_logits.size(-1), device=policy_logits.device) * float('-inf')
                    mask[valid_actions] = 0
                    policy_logits[i] = policy_logits[i] + mask
        
        probs = F.softmax(policy_logits, dim=-1)
        dist = Categorical(probs)
        
        log_probs = dist.log_prob(action)
        entropy = dist.entropy()
        
        return log_probs, value.squeeze(-1), entropy

# Test model
model = ActorCritic(state_encoder.state_size, action_space.vocab_size).to(device)
print(f'Model parameters: {sum(p.numel() for p in model.parameters()):,}')
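
Both `get_action` and `evaluate_action` mask invalid actions by adding negative infinity to their logits before the softmax. The pure-Python sketch below shows the effect on a four-action toy example. The function name `masked_softmax` is hypothetical, and the sketch assumes at least one valid action.

```python
import math

def masked_softmax(logits, valid_actions):
    """Softmax over `logits` with every index outside `valid_actions` forced to 0.

    Hypothetical helper; requires a non-empty `valid_actions`.
    """
    valid = set(valid_actions)
    masked = [x if i in valid else float('-inf') for i, x in enumerate(logits)]
    peak = max(masked)  # subtract the maximum for numerical stability
    exps = [math.exp(x - peak) for x in masked]
    total = sum(exps)
    return [e / total for e in exps]

probs = masked_softmax([2.0, 1.0, 0.5, 3.0], valid_actions=[0, 2])
print(probs[1], probs[3])  # 0.0 0.0
```

Action 3 has the largest raw logit but receives zero probability, so the remaining mass is shared only between the valid actions 0 and 2.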

## 7. PPO Agent

In [None]:
# Experience tuple
Experience = namedtuple('Experience', 
    ['state', 'action', 'reward', 'next_state', 'done', 'log_prob', 'value', 'valid_actions'])


class PPOAgent:
    """
    PPO Agent for Diplomacy self-play.
    
    Implements Proximal Policy Optimization (Schulman et al., 2017)
    with clipped objective for stable training.
    """
    
    def __init__(self, state_size: int, action_size: int,
                 lr: float = 3e-4,
                 gamma: float = 0.99,
                 gae_lambda: float = 0.95,
                 clip_epsilon: float = 0.2,
                 value_coef: float = 0.5,
                 entropy_coef: float = 0.01,
                 max_grad_norm: float = 0.5):
        
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.clip_epsilon = clip_epsilon
        self.value_coef = value_coef
        self.entropy_coef = entropy_coef
        self.max_grad_norm = max_grad_norm
        
        self.network = ActorCritic(state_size, action_size).to(device)
        self.optimizer = optim.Adam(self.network.parameters(), lr=lr)
        
        # Experience buffer
        self.buffer = []
        
    def select_action(self, state: np.ndarray, valid_actions: List[int] = None,
                      deterministic: bool = False) -> Tuple[int, float, float]:
        """
        Select action given state.
        
        Returns:
            (action, log_prob, value)
        """
        state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
        
        with torch.no_grad():
            action, log_prob, value = self.network.get_action(
                state_tensor, valid_actions, deterministic
            )
        
        return action, log_prob.item(), value.item()
    
    def store_experience(self, exp: Experience):
        """Store experience in buffer."""
        self.buffer.append(exp)
    
    def compute_gae(self, rewards: List[float], values: List[float], 
                    dones: List[bool], next_value: float) -> Tuple[List[float], List[float]]:
        """
        Compute Generalized Advantage Estimation.
        
        Returns:
            (advantages, returns)
        """
        advantages = []
        returns = []
        gae = 0
        
        # Process in reverse order
        for t in reversed(range(len(rewards))):
            next_val = next_value if t == len(rewards) - 1 else values[t + 1]
            # Mask bootstrapping when the transition at step t ended its episode,
            # so value estimates never leak across episode boundaries
            non_terminal = 1.0 - float(dones[t])
            
            delta = rewards[t] + self.gamma * next_val * non_terminal - values[t]
            gae = delta + self.gamma * self.gae_lambda * non_terminal * gae
            
            advantages.insert(0, gae)
            returns.insert(0, gae + values[t])
        
        return advantages, returns
    
    def update(self, num_epochs: int = 4, batch_size: int = 64) -> Dict[str, float]:
        """
        Update policy using PPO.
        
        Returns:
            Dict with loss metrics
        """
        if len(self.buffer) < batch_size:
            return {}
        
        # Extract data from buffer
        states = np.array([e.state for e in self.buffer])
        actions = np.array([e.action for e in self.buffer])
        rewards = [e.reward for e in self.buffer]
        dones = [e.done for e in self.buffer]
        old_log_probs = np.array([e.log_prob for e in self.buffer])
        values = [e.value for e in self.buffer]
        valid_actions_list = [e.valid_actions for e in self.buffer]
        
        # Compute advantages
        advantages, returns = self.compute_gae(rewards, values, dones, 0.0)
        
        # Convert to tensors
        states_t = torch.FloatTensor(states).to(device)
        actions_t = torch.LongTensor(actions).to(device)
        old_log_probs_t = torch.FloatTensor(old_log_probs).to(device)
        advantages_t = torch.FloatTensor(advantages).to(device)
        returns_t = torch.FloatTensor(returns).to(device)
        
        # Normalize advantages
        advantages_t = (advantages_t - advantages_t.mean()) / (advantages_t.std() + 1e-8)
        
        # PPO update
        total_loss = 0
        policy_loss_sum = 0
        value_loss_sum = 0
        entropy_sum = 0
        
        dataset_size = len(self.buffer)
        indices = np.arange(dataset_size)
        
        for _ in range(num_epochs):
            np.random.shuffle(indices)
            
            for start in range(0, dataset_size, batch_size):
                end = min(start + batch_size, dataset_size)
                batch_indices = indices[start:end]
                
                batch_states = states_t[batch_indices]
                batch_actions = actions_t[batch_indices]
                batch_old_log_probs = old_log_probs_t[batch_indices]
                batch_advantages = advantages_t[batch_indices]
                batch_returns = returns_t[batch_indices]
                batch_valid_actions = [valid_actions_list[i] for i in batch_indices]
                
                # Evaluate actions
                new_log_probs, new_values, entropy = self.network.evaluate_action(
                    batch_states, batch_actions, batch_valid_actions
                )
                
                # Policy loss (clipped PPO objective)
                ratio = torch.exp(new_log_probs - batch_old_log_probs)
                surr1 = ratio * batch_advantages
                surr2 = torch.clamp(ratio, 1 - self.clip_epsilon, 1 + self.clip_epsilon) * batch_advantages
                policy_loss = -torch.min(surr1, surr2).mean()
                
                # Value loss
                value_loss = F.mse_loss(new_values, batch_returns)
                
                # Entropy bonus (encourages exploration)
                entropy_loss = -entropy.mean()
                
                # Total loss
                loss = policy_loss + self.value_coef * value_loss + self.entropy_coef * entropy_loss
                
                # Update
                self.optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(self.network.parameters(), self.max_grad_norm)
                self.optimizer.step()
                
                total_loss += loss.item()
                policy_loss_sum += policy_loss.item()
                value_loss_sum += value_loss.item()
                entropy_sum += (-entropy_loss.item())
        
        # Clear buffer
        self.buffer = []
        
        # Ceiling division: the final partial batch of each epoch is also counted
        num_batches = ((dataset_size + batch_size - 1) // batch_size) * num_epochs
        return {
            'total_loss': total_loss / max(num_batches, 1),
            'policy_loss': policy_loss_sum / max(num_batches, 1),
            'value_loss': value_loss_sum / max(num_batches, 1),
            'entropy': entropy_sum / max(num_batches, 1)
        }
    
    def save(self, path: str):
        """Save model checkpoint."""
        torch.save({
            'network_state_dict': self.network.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
        }, path)
    
    def load(self, path: str):
        """Load model checkpoint."""
        checkpoint = torch.load(path, map_location=device)
        self.network.load_state_dict(checkpoint['network_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

print('PPO Agent defined')
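
Generalized Advantage Estimation is easiest to check on a tiny trajectory. The standalone sketch below assumes that `dones[t]` marks the transition at step `t` as the last one of its episode. The function name `gae_example` and the toy numbers are hypothetical.

```python
def gae_example(rewards, values, dones, next_value, gamma=0.99, lam=0.95):
    """Standalone GAE sketch; dones[t] means step t ended its episode."""
    advantages = [0.0] * len(rewards)
    gae = 0.0
    for t in reversed(range(len(rewards))):
        next_val = next_value if t == len(rewards) - 1 else values[t + 1]
        non_terminal = 1.0 - float(dones[t])
        delta = rewards[t] + gamma * next_val * non_terminal - values[t]
        gae = delta + gamma * lam * non_terminal * gae
        advantages[t] = gae
    returns = [a + v for a, v in zip(advantages, values)]
    return advantages, returns

# With gamma = lam = 1 the advantage is the Monte Carlo return minus the value
adv, ret = gae_example([0.0, 0.0, 1.0], [0.5, 0.5, 0.5],
                       [False, False, True], 0.0, gamma=1.0, lam=1.0)
print(adv)  # [0.5, 0.5, 0.5]
print(ret)  # [1.0, 1.0, 1.0]
```

Setting `non_terminal` to zero at an episode boundary stops both the bootstrap value and the running advantage from carrying over into the preceding trajectory in the buffer.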

## 8. Self-Play Training Loop

In [None]:
class SelfPlayTrainer:
    """
    Self-Play trainer following AlphaGo Zero paradigm.
    
    Key features:
    - All 7 powers controlled by the same network
    - Experience collected from all powers' perspectives
    - Periodic evaluation against fixed checkpoints
    """
    
    def __init__(self, config: Dict):
        self.config = config
        
        # Environment
        self.env = DiplomacyEnv()
        
        # Agent
        self.agent = PPOAgent(
            state_size=state_encoder.state_size,
            action_size=action_space.vocab_size,
            lr=config.get('lr', 3e-4),
            gamma=config.get('gamma', 0.99),
            gae_lambda=config.get('gae_lambda', 0.95),
            clip_epsilon=config.get('clip_epsilon', 0.2),
            entropy_coef=config.get('entropy_coef', 0.01)
        )
        
        # Training history
        self.history = {
            'episode_rewards': [],
            'episode_lengths': [],
            'wins': defaultdict(int),
            'losses': [],
            'entropies': []
        }
        
        # Checkpoints for evaluation
        self.checkpoints = []
        
    def play_episode(self) -> Dict:
        """
        Play one full game with self-play.
        All powers use the same policy network.
        
        Returns:
            Episode statistics
        """
        state = self.env.reset()
        
        episode_rewards = {p: 0.0 for p in POWERS}
        episode_length = 0
        
        # Store experiences per power
        power_experiences = {p: [] for p in POWERS}
        
        done = False
        max_steps = self.config.get('max_steps_per_episode', 100)
        
        while not done and episode_length < max_steps:
            # Collect actions from all powers
            actions = {}
            power_data = {}  # Store state, action, log_prob, value per power
            
            for power in POWERS:
                if state.is_eliminated(power):
                    actions[power] = []
                    continue
                
                # Get observation from this power's perspective
                obs = self.env.get_observation(power)
                
                # Get valid actions
                valid_actions = action_space.get_valid_actions(state, power)
                
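                # Select one order per unit. All units of a power share the same
                # observation and valid-action list. NOTE: power_data[power] is
                # overwritten on every iteration, so only the last unit's action
                # is stored as this power's experience for the step.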
                power_orders = []
                for unit in state.units.get(power, []):
                    action_idx, log_prob, value = self.agent.select_action(obs, valid_actions)
                    action_str = action_space.decode_action(action_idx)
                    power_orders.append(action_str)
                    
                    power_data[power] = {
                        'obs': obs.copy(),
                        'action': action_idx,
                        'log_prob': log_prob,
                        'value': value,
                        'valid_actions': valid_actions
                    }
                
                actions[power] = power_orders
            
            # Execute step
            next_state, rewards, done = self.env.step(actions)
            
            # Store experiences
            for power in POWERS:
                if power in power_data:
                    data = power_data[power]
                    next_obs = self.env.get_observation(power)
                    
                    exp = Experience(
                        state=data['obs'],
                        action=data['action'],
                        reward=rewards.get(power, 0.0),
                        next_state=next_obs,
                        done=done,
                        log_prob=data['log_prob'],
                        value=data['value'],
                        valid_actions=data['valid_actions']
                    )
                    self.agent.store_experience(exp)
                    episode_rewards[power] += rewards.get(power, 0.0)
            
            state = next_state
            episode_length += 1
        
        # Determine winner
        winner = state.get_winner()
        
        return {
            'rewards': episode_rewards,
            'length': episode_length,
            'winner': winner,
            'final_scs': {p: state.get_sc_count(p) for p in POWERS}
        }
    
    def train(self, num_episodes: int, update_freq: int = 10,
              eval_freq: int = 50, save_freq: int = 100):
        """
        Main training loop.
        
        Args:
            num_episodes: Total episodes to train
            update_freq: Episodes between PPO updates
            eval_freq: Episodes between evaluations
            save_freq: Episodes between checkpoint saves
        """
        print('='*60)
        print('SELF-PLAY TRAINING')
        print('='*60)
        print(f'Episodes: {num_episodes}')
        print(f'Update frequency: {update_freq}')
        print(f'Device: {device}')
        print('='*60)
        
        pbar = tqdm(range(num_episodes), desc='Training')
        
        for episode in pbar:
            # Play episode
            stats = self.play_episode()
            
            # Record stats
            total_reward = sum(stats['rewards'].values())
            self.history['episode_rewards'].append(total_reward)
            self.history['episode_lengths'].append(stats['length'])
            
            if stats['winner']:
                self.history['wins'][stats['winner']] += 1
            
            # Update policy
            if (episode + 1) % update_freq == 0:
                loss_info = self.agent.update(
                    num_epochs=self.config.get('ppo_epochs', 4),
                    batch_size=self.config.get('batch_size', 64)
                )
                if loss_info:
                    self.history['losses'].append(loss_info['total_loss'])
                    self.history['entropies'].append(loss_info['entropy'])
            
            # Update progress bar
            avg_reward = np.mean(self.history['episode_rewards'][-100:])
            avg_length = np.mean(self.history['episode_lengths'][-100:])
            pbar.set_postfix({
                'reward': f'{avg_reward:.2f}',
                'length': f'{avg_length:.1f}',
                'wins': sum(self.history['wins'].values())
            })
            
            # Save checkpoint
            if (episode + 1) % save_freq == 0:
                self.save_checkpoint(episode + 1)
            
            # Evaluation
            if (episode + 1) % eval_freq == 0:
                self.evaluate()
        
        print('\nTraining complete!')
        return self.history
    
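    def evaluate(self, num_games: int = 10):
        """Evaluate the current policy in self-play (every seat uses the same network)."""
        wins = defaultdict(int)
        
        # Set aside pending training experiences so that evaluation games
        # neither mix into them nor erase them.
        pending = self.agent.buffer
        self.agent.buffer = []
        
        for _ in range(num_games):
            stats = self.play_episode()
            if stats['winner']:
                wins[stats['winner']] += 1
            # Discard evaluation experiences (don't train on eval games)
            self.agent.buffer = []
        
        # Restore whatever was collected since the last PPO update
        self.agent.buffer = pending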
        
        print(f"\nEvaluation ({num_games} games):")
        for power in POWERS:
            print(f"  {power}: {wins[power]} wins ({wins[power]/num_games*100:.1f}%)")
    
    def save_checkpoint(self, episode: int):
        """Save training checkpoint."""
        path = f'checkpoint_ep{episode}.pt'
        self.agent.save(path)
        self.checkpoints.append(path)
        print(f'\nSaved checkpoint: {path}')

print('SelfPlayTrainer defined')

## 9. Training Configuration

In [None]:
# Training configuration
CONFIG = {
    # Training
    'num_episodes': 1000,
    'max_steps_per_episode': 60,  # ~20 years × 3 phases
    'update_freq': 10,
    'eval_freq': 100,
    'save_freq': 200,
    
    # PPO
    'lr': 3e-4,
    'gamma': 0.99,
    'gae_lambda': 0.95,
    'clip_epsilon': 0.2,
    'entropy_coef': 0.01,
    'ppo_epochs': 4,
    'batch_size': 64,
}

print('Configuration:')
for k, v in CONFIG.items():
    print(f'  {k}: {v}')
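
The `gamma` and `gae_lambda` entries control how rewards are propagated backwards when advantages are estimated. As a rough illustration, the standard Generalized Advantage Estimation recursion can be sketched in plain Python as below. The function name `compute_gae` and its arguments are assumptions made for this example; it is not necessarily how `PPOAgent.update` computes advantages.

```python
from typing import List, Tuple


def compute_gae(
    rewards: List[float],
    values: List[float],
    dones: List[bool],
    last_value: float = 0.0,
    gamma: float = 0.99,
    gae_lambda: float = 0.95,
) -> Tuple[List[float], List[float]]:
    """Illustrative GAE over ONE trajectory (hypothetical helper).

    last_value is the value estimate of the state after the final step;
    it is ignored when the final step is terminal.
    """
    advantages = [0.0] * len(rewards)
    gae = 0.0
    next_value = last_value
    for t in reversed(range(len(rewards))):
        # Do not bootstrap across a terminal transition
        nonterminal = 0.0 if dones[t] else 1.0
        delta = rewards[t] + gamma * next_value * nonterminal - values[t]
        gae = delta + gamma * gae_lambda * nonterminal * gae
        advantages[t] = gae
        next_value = values[t]
    # Value targets are advantages plus the baseline values
    returns = [a + v for a, v in zip(advantages, values)]
    return advantages, returns


# Sparse terminal reward with a zero value baseline
adv, ret = compute_gae(
    rewards=[0.0, 0.0, 1.0],
    values=[0.0, 0.0, 0.0],
    dones=[False, False, True],
)
print(adv)
```

With `gae_lambda=1.0` the advantages reduce to discounted returns minus the baseline; smaller values weight nearby rewards more heavily, trading variance for bias. The recursion assumes the steps belong to a single trajectory, so experiences from different powers or episodes would need to be separated before applying it.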

## 10. Run Training!

In [None]:
# Create trainer
trainer = SelfPlayTrainer(CONFIG)

# Train!
history = trainer.train(
    num_episodes=CONFIG['num_episodes'],
    update_freq=CONFIG['update_freq'],
    eval_freq=CONFIG['eval_freq'],
    save_freq=CONFIG['save_freq']
)

## 11. Training Results

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Episode rewards
ax = axes[0, 0]
rewards = history['episode_rewards']
ax.plot(rewards, alpha=0.3, color='blue')
# Moving average
window = 50
if len(rewards) >= window:
    ma = np.convolve(rewards, np.ones(window)/window, mode='valid')
    ax.plot(range(window-1, len(rewards)), ma, color='red', linewidth=2, label=f'MA-{window}')
ax.set_xlabel('Episode')
ax.set_ylabel('Total Reward')
ax.set_title('Episode Rewards')
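# Only draw a legend when the moving average was actually plotted
if ax.get_legend_handles_labels()[0]:
    ax.legend()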
ax.grid(True, alpha=0.3)

# Episode lengths
ax = axes[0, 1]
lengths = history['episode_lengths']
ax.plot(lengths, alpha=0.3, color='green')
if len(lengths) >= window:
    ma = np.convolve(lengths, np.ones(window)/window, mode='valid')
    ax.plot(range(window-1, len(lengths)), ma, color='red', linewidth=2)
ax.set_xlabel('Episode')
ax.set_ylabel('Steps')
ax.set_title('Episode Length')
ax.grid(True, alpha=0.3)

# Win distribution
ax = axes[1, 0]
wins = history['wins']
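powers = POWERS + ['Draw']
# Games without a solo winner (including those cut off at max_steps) count as draws.
# Use the number of episodes actually played rather than the configured total.
num_games = len(history['episode_rewards'])
win_counts = [wins.get(p, 0) for p in POWERS] + [num_games - sum(wins.values())]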
colors = plt.cm.tab10(range(len(powers)))
ax.bar(powers, win_counts, color=colors)
ax.set_xlabel('Power')
ax.set_ylabel('Wins')
ax.set_title('Win Distribution')
ax.tick_params(axis='x', rotation=45)

# Loss curve
ax = axes[1, 1]
if history['losses']:
    ax.plot(history['losses'], label='Loss', color='purple')
    ax.set_xlabel('Update')
    ax.set_ylabel('Loss')
    ax.set_title('Training Loss')
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('self_play_training.png', dpi=150)
plt.show()

In [None]:
# Print final statistics
print('='*60)
print('TRAINING SUMMARY')
print('='*60)
num_games = len(history['episode_rewards'])  # episodes actually played
print(f"\nTotal episodes: {num_games}")
print(f"Average reward (last 100): {np.mean(history['episode_rewards'][-100:]):.3f}")
print(f"Average length (last 100): {np.mean(history['episode_lengths'][-100:]):.1f}")

print("\nWin distribution:")
total_wins = sum(history['wins'].values())
for power in POWERS:
    power_wins = history['wins'].get(power, 0)
    pct = power_wins / max(num_games, 1) * 100
    print(f"  {power}: {power_wins} ({pct:.1f}%)")

# Games without a solo winner, including those cut off at max_steps
draws = num_games - total_wins
print(f"  Draws: {draws} ({draws / max(num_games, 1) * 100:.1f}%)")
print('='*60)

## 12. Save Final Model

In [None]:
# Save final model
trainer.agent.save('self_play_final.pt')

# Save training history (json is already imported in the setup cell)
with open('training_history.json', 'w') as f:
    # Convert defaultdict to regular dict for JSON
    history_save = {
        'episode_rewards': history['episode_rewards'],
        'episode_lengths': history['episode_lengths'],
        'wins': dict(history['wins']),
        'losses': history['losses'],
        'entropies': history['entropies'],
        'config': CONFIG
    }
    json.dump(history_save, f)

print('Saved:')
print('  - self_play_final.pt')
print('  - training_history.json')
print('  - self_play_training.png')

## 13. Download Files

In [None]:
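# google.colab is only available inside a Colab runtime
try:
    from google.colab import files
except ImportError:
    files = None

if files is not None:
    for name in ('self_play_final.pt', 'training_history.json', 'self_play_training.png'):
        files.download(name)
    print('Files downloaded!')
else:
    print('Not running in Colab; the files are in the working directory.')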

## 14. Analysis: Measuring Overfitting (RQ1)

To quantify overfitting in pure self-play, we need to:
1. Evaluate against **fixed checkpoints** from earlier in training
2. Evaluate against a **random policy**
3. Compare win rates across these opponents

Signs of overfitting:
- High win rate vs recent checkpoints
- Low win rate vs old checkpoints (strategy cycling)
- Narrow strategy distribution
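
Win rates from a few dozen games are noisy, so a comparison between old and recent checkpoints is easier to interpret with an interval around each estimate. The sketch below uses the Wilson score interval and flags possible strategy cycling when the interval against the oldest checkpoint lies entirely below the interval against the newest one. The helper names and the example counts are hypothetical and serve only to illustrate the comparison; they are not results from this notebook.

```python
import math
from typing import Dict, Tuple


def wilson_interval(wins: int, games: int, z: float = 1.96) -> Tuple[float, float]:
    """Approximate 95% Wilson score interval for a win rate."""
    if games == 0:
        return (0.0, 1.0)
    p = wins / games
    denom = 1 + z ** 2 / games
    centre = (p + z ** 2 / (2 * games)) / denom
    half = z * math.sqrt(p * (1 - p) / games + z ** 2 / (4 * games ** 2)) / denom
    return (max(0.0, centre - half), min(1.0, centre + half))


def cycling_suspected(results: Dict[int, Tuple[int, int]]) -> bool:
    """results maps checkpoint episode -> (wins, games) for the current agent.

    Returns True when the agent does clearly worse against the oldest
    checkpoint than against the most recent one.
    """
    oldest, newest = min(results), max(results)
    _, old_hi = wilson_interval(*results[oldest])
    new_lo, _ = wilson_interval(*results[newest])
    return old_hi < new_lo


# Hypothetical counts, for illustration only
example = {200: (3, 40), 1000: (22, 40)}
for episode, (w, n) in sorted(example.items()):
    lo, hi = wilson_interval(w, n)
    print(f"vs checkpoint {episode}: {w}/{n} wins, interval [{lo:.2f}, {hi:.2f}]")
print('Cycling suspected:', cycling_suspected(example))
```

Note that in a seven-player game the chance-level win rate for a single seat is well below 50%, so the intervals should be read against that baseline rather than against an even split.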

In [None]:
def evaluate_vs_random(agent, num_games: int = 20):
    """
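    Evaluate the trained agent against a random policy.
    The agent plays as one power; all other powers play randomly.
    Each power submits a single order per step, and games are capped
    at 60 steps; games without a solo winner count as non-wins.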
    """
    env = DiplomacyEnv()
    wins = 0
    agent_power = 'FRANCE'  # Agent plays as France
    
    for _ in tqdm(range(num_games), desc='Evaluating vs Random'):
        state = env.reset()
        done = False
        steps = 0
        
        while not done and steps < 60:
            actions = {}
            
            for power in POWERS:
                if state.is_eliminated(power):
                    actions[power] = []
                    continue
                
                valid = action_space.get_valid_actions(state, power)
                
                if not valid:
                    # No legal orders available for this power
                    actions[power] = []
                elif power == agent_power:
                    # Agent action (greedy)
                    obs = env.get_observation(power)
                    action_idx, _, _ = agent.select_action(obs, valid, deterministic=True)
                    actions[power] = [action_space.decode_action(action_idx)]
                else:
                    # Random action
                    action_idx = random.choice(valid)
                    actions[power] = [action_space.decode_action(action_idx)]
            
            state, _, done = env.step(actions)
            steps += 1
        
        winner = state.get_winner()
        if winner == agent_power:
            wins += 1
    
    win_rate = wins / num_games
    print(f"\nWin rate vs Random: {win_rate*100:.1f}% ({wins}/{num_games})")
    return win_rate

# Evaluate
win_rate = evaluate_vs_random(trainer.agent, num_games=20)

## 15. Summary & Next Steps

### What We Built
- Full self-play RL pipeline for No-Press Diplomacy
- PPO-based policy optimization
- Actor-Critic network with action masking
- Experience collection from all 7 powers' perspectives

### Results
- See training curves above
- Win distribution shows which powers the agent favors
- Comparison vs random baseline measures basic competence

### Limitations of Pure Self-Play (RQ1)
As predicted by the literature:
- The agent may overfit to its own strategies
- Strategy diversity is limited
- Performance may degrade against novel opponents

### Next Steps
1. **Human-Regularized RL**: Add KL penalty toward human policy (BC model)
2. **Population-Based Training**: Maintain diverse opponent pool
3. **Full Game Rules**: Integrate with `diplomacy` package for accurate simulation
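
The KL penalty in the first item can be illustrated with a small, dependency-free sketch. It assumes discrete action distributions and uses the direction KL(policy || human) as one possible choice; the function names and the `kl_coef` weight are hypothetical and not part of this notebook. In the actual training loop the same term would be computed on tensors and added to the PPO loss.

```python
import math
from typing import Sequence


def kl_divergence(p: Sequence[float], q: Sequence[float], eps: float = 1e-8) -> float:
    """KL(p || q) for two discrete distributions over the same actions."""
    return sum(
        pi * math.log((pi + eps) / (qi + eps))
        for pi, qi in zip(p, q)
        if pi > 0  # terms with p_i = 0 contribute nothing
    )


def regularized_loss(
    ppo_loss: float,
    policy_probs: Sequence[float],
    human_probs: Sequence[float],
    kl_coef: float = 0.1,
) -> float:
    """PPO loss plus a penalty for drifting away from the human (BC) policy."""
    return ppo_loss + kl_coef * kl_divergence(policy_probs, human_probs)


# A policy that has collapsed onto one action is penalized relative to
# one that stays close to the (assumed) behavior-cloned distribution.
human = [0.5, 0.3, 0.2]
print(regularized_loss(1.0, [0.5, 0.3, 0.2], human))
print(regularized_loss(1.0, [1.0, 0.0, 0.0], human))
```

A larger `kl_coef` keeps the policy closer to the human anchor at the cost of less freedom to improve on it, so the weight is a tuning choice.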