# Behavioral Cloning for No-Press Diplomacy

**Project:** Improve Self-Play for Diplomacy  
**Authors:** Giacomo Colosio, Maciej Tasarz, Jakub Seliga, Luka Ivcevic  
**Course:** ISP - UPC Barcelona, Fall 2025/26

---

This notebook trains a neural network to imitate human Diplomacy players using behavioral cloning.

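**Requirements:** Google Colab with a GPU runtime (Runtime -> Change runtime type -> GPU). The code falls back to the CPU when no GPU is available, but training is then considerably slower.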

## 1. Setup & GPU Check

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import numpy as np
import json
import os
import re
from collections import Counter, defaultdict
from typing import Dict, List, Tuple, Optional
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

# Select the compute device once; later cells move models and batches to it.
_has_cuda = torch.cuda.is_available()
device = torch.device('cuda' if _has_cuda else 'cpu')
print(f'Device: {device}')
if _has_cuda:
    _gpu_props = torch.cuda.get_device_properties(0)
    print(f'GPU: {torch.cuda.get_device_name(0)}')
    print(f'Memory: {_gpu_props.total_memory / 1e9:.1f} GB')

## 2. Upload Data

Upload `standard_no_press.jsonl` (one JSON-encoded game per line) from your local machine, or use Option 2 below to read it from Google Drive.

In [None]:
from google.colab import files

# Option 1: Upload file directly
print("Upload 'standard_no_press.jsonl' file:")
uploaded = files.upload()  # dict: uploaded file name -> file contents

# Use the name of the .jsonl file that was actually uploaded, so an upload
# under a different name still works. Fall back to the default name when no
# .jsonl file was uploaded (e.g. the file is already in the working directory).
DATA_PATH = next(
    (name for name in uploaded if name.endswith('.jsonl')),
    'standard_no_press.jsonl',
)
if DATA_PATH in uploaded:
    print(f"\nFile uploaded: {DATA_PATH}")
else:
    print(f"\nNo .jsonl file uploaded; expecting '{DATA_PATH}' to exist already.")

In [None]:
# Option 2: If using Google Drive
# from google.colab import drive
# drive.mount('/content/drive')
# DATA_PATH = '/content/drive/MyDrive/diplomacy/standard_no_press.jsonl'

## 3. Constants

In [None]:
POWERS = ['AUSTRIA', 'ENGLAND', 'FRANCE', 'GERMANY', 'ITALY', 'RUSSIA', 'TURKEY']

# Supply centers (34)
SUPPLY_CENTERS = [
    'ANK', 'BEL', 'BER', 'BRE', 'BUD', 'BUL', 'CON', 'DEN', 'EDI', 'GRE',
    'HOL', 'KIE', 'LON', 'LVP', 'MAR', 'MOS', 'MUN', 'NAP', 'NWY', 'PAR',
    'POR', 'ROM', 'RUM', 'SER', 'SEV', 'SMY', 'SPA', 'STP', 'SWE', 'TRI',
    'TUN', 'VEN', 'VIE', 'WAR',
]
# Non-supply center land (22)
NON_SC_LAND = [
    'ALB', 'APU', 'ARM', 'BOH', 'BUR', 'CLY', 'FIN', 'GAL', 'GAS', 'LVN',
    'NAF', 'PIC', 'PIE', 'PRU', 'RUH', 'SIL', 'SYR', 'TUS', 'TYR', 'UKR',
    'WAL', 'YOR',
]
# Sea zones (19)
SEA_ZONES = [
    'ADR', 'AEG', 'BAL', 'BAR', 'BLA', 'BOT', 'EAS', 'ENG', 'GOL', 'HEL',
    'ION', 'IRI', 'MAO', 'NAO', 'NTH', 'NWG', 'SKA', 'TYS', 'WES',
]

# Order matters: a location's position in this list is its index in the
# state vector. Supply centers come first, so indices 0-33 are the SCs.
LOCATIONS = SUPPLY_CENTERS + NON_SC_LAND + SEA_ZONES

LOC_TO_IDX = {loc: i for i, loc in enumerate(LOCATIONS)}
POWER_TO_IDX = {p: i for i, p in enumerate(POWERS)}

# A repeated abbreviation would silently map two provinces to one index.
if len(LOC_TO_IDX) != len(LOCATIONS):
    raise ValueError('LOCATIONS contains duplicate province abbreviations')

print(f'Powers: {len(POWERS)}')
print(f'Locations: {len(LOCATIONS)}')
print(f'Supply Centers: {len(SUPPLY_CENTERS)}')

## 4. State Encoder

In [None]:
class StateEncoder:
    """
    Encodes a Diplomacy game state into a fixed-size float32 vector.

    Per location (75 locations, in LOCATIONS order):
        - 7 features: which power has a unit there (one-hot over POWERS)
        - 1 feature: unit type, army (1.0) or fleet (0.0)
        - 7 features: which power owns the supply center (one-hot)
        - 1 feature: is supply center
    Total per location: 16 features

    Global features (16):
        - 7: supply centers owned per power / 18
        - 7: units per power / 17
        - 1: (year - 1901) / 20, parsed from the phase name (e.g. 'S1901M')
        - 1: season from the phase name's first letter: S=0.0, F=0.5, W=1.0
             (0.0 for anything else)
    Total: 75 * 16 + 16 = 1216 features
    """

    def __init__(self):
        self.num_locations = len(LOCATIONS)
        self.num_powers = len(POWERS)
        self.features_per_loc = 16
        self.global_features = 16
        self.state_size = self.num_locations * self.features_per_loc + self.global_features
        # Precomputed lookups so encode() does O(1) index/membership tests
        # instead of rescanning lists for every location.
        self._loc_to_idx = {loc: i for i, loc in enumerate(LOCATIONS)}
        self._supply_centers = frozenset(SUPPLY_CENTERS)

    def encode(self, state: Dict, phase_name: str = '') -> np.ndarray:
        """
        Encode one game state.

        Args:
            state: dict with optional 'units' (power -> list of unit strings
                such as 'A PAR' or 'F STP/SC') and 'centers' (power -> list of
                owned supply-center names). Missing or None entries count as
                empty.
            phase_name: phase name such as 'S1901M'. Characters 1-4 are parsed
                as the year (ignored if not an integer) and the first character
                as the season.

        Returns:
            float32 array of length self.state_size.
        """
        features = np.zeros(self.state_size, dtype=np.float32)
        fpl = self.features_per_loc

        units = state.get('units') or {}
        centers = state.get('centers') or {}

        # Units: one pass over every unit. Only a power's first unit in a given
        # province is used. If several powers share a province, each of their
        # power bits is set, and the army/fleet bit reflects the last such
        # power in POWERS order.
        # NOTE(review): a unit string with a leading '*' (if the data marks
        # dislodged units that way) is placed at its province but encoded as a
        # fleet, since only strings starting with 'A ' count as armies.
        for power_idx, power in enumerate(POWERS):
            placed = set()
            for unit in units.get(power) or []:
                loc = self._parse_unit_location(unit)
                loc_idx = self._loc_to_idx.get(loc)
                if loc_idx is None or loc in placed:
                    continue
                placed.add(loc)
                offset = loc_idx * fpl
                features[offset + power_idx] = 1.0
                features[offset + 7] = 1.0 if unit.startswith('A ') else 0.0

        # Supply-center ownership: the first power (in POWERS order) that lists
        # a center is recorded as its owner.
        owner_of = {}
        for power_idx, power in enumerate(POWERS):
            for center in centers.get(power) or []:
                owner_of.setdefault(center, power_idx)

        for loc_idx, loc in enumerate(LOCATIONS):
            if loc not in self._supply_centers:
                continue
            offset = loc_idx * fpl
            features[offset + 15] = 1.0
            owner = owner_of.get(loc)
            if owner is not None:
                features[offset + 8 + owner] = 1.0

        # Global features
        global_offset = self.num_locations * fpl

        for power_idx, power in enumerate(POWERS):
            features[global_offset + power_idx] = len(centers.get(power) or []) / 18.0
            features[global_offset + 7 + power_idx] = len(units.get(power) or []) / 17.0

        if phase_name:
            try:
                year = int(phase_name[1:5])
            except ValueError:
                # No numeric year in the name: leave the year feature at 0.
                pass
            else:
                features[global_offset + 14] = (year - 1901) / 20.0
            features[global_offset + 15] = {'S': 0.0, 'F': 0.5, 'W': 1.0}.get(phase_name[0], 0.0)

        return features

    def _parse_unit_location(self, unit: str) -> str:
        """Return the province of a unit string ('F STP/SC' -> 'STP'), or '' if malformed."""
        parts = unit.split()
        if len(parts) >= 2:
            return parts[1].split('/')[0]
        return ''

# Test
encoder = StateEncoder()
print(f'State size: {encoder.state_size}')

## 5. Action Encoder

In [None]:
class ActionEncoder:
    """
    Maps Diplomacy order strings to integer indices and back.

    The vocabulary is built from training data: index 0 is '<PAD>', index 1 is
    '<UNK>', and indices 2.. are the most frequent normalized orders.
    Normalization strips whitespace, upper-cases, and removes coast suffixes
    such as '/SC'. As a result, decoded orders never carry a coast.
    """

    # Coast specifiers: '/' followed by two capital letters (e.g. '/NC', '/SC').
    _COAST_RE = re.compile(r'/[A-Z]{2}')

    def __init__(self):
        self.order_to_idx = {}
        self.idx_to_order = {}
        self.vocab_size = 0

    def build_vocab(self, games: List[Dict], max_vocab: int = 10000):
        """
        Build the vocabulary from all orders in all phases of `games`.

        Args:
            games: list of game dicts with 'phases' -> [{'orders': {power: [order, ...]}}].
                Missing or None 'phases'/'orders'/per-power lists are skipped.
            max_vocab: total vocabulary size, including '<PAD>' and '<UNK>'.
        """
        order_counts = Counter()

        for game in games:
            for phase in game.get('phases') or []:
                for orders in (phase.get('orders') or {}).values():
                    for order in orders or []:
                        norm = self._normalize(order)
                        if norm:
                            order_counts[norm] += 1

        most_common = order_counts.most_common(max(max_vocab - 2, 0))

        self.order_to_idx = {'<PAD>': 0, '<UNK>': 1}
        self.idx_to_order = {0: '<PAD>', 1: '<UNK>'}

        for idx, (order, _) in enumerate(most_common, start=2):
            self.order_to_idx[order] = idx
            self.idx_to_order[idx] = order

        self.vocab_size = len(self.order_to_idx)
        print(f'Vocabulary size: {self.vocab_size}')

    def _normalize(self, order: str) -> Optional[str]:
        """Return the canonical form of `order`, or None if it is not a usable order string."""
        if not isinstance(order, str):
            return None
        order = self._COAST_RE.sub('', order.strip().upper())
        return order if len(order) >= 3 else None

    def encode(self, order: str) -> int:
        """Return the vocabulary index of `order` (1 / '<UNK>' if unknown or invalid)."""
        return self.order_to_idx.get(self._normalize(order), 1)

    def decode(self, idx: int) -> str:
        """Return the normalized order for `idx`, or '<UNK>' if out of vocabulary."""
        return self.idx_to_order.get(idx, '<UNK>')

action_encoder = ActionEncoder()
print('Action encoder ready')

## 6. Dataset

In [None]:
class DiplomacyDataset(Dataset):
    """
    PyTorch Dataset for behavioral cloning.

    Builds one sample per order issued in a movement phase (a phase whose name
    ends in 'M'). Each sample is (encoded board state, index of the issuing
    power in POWERS, vocabulary index of the order). Orders that encode to
    <PAD>/<UNK> (index <= 1) are skipped. All samples from the same phase
    share one encoded state array.
    """

    def __init__(self, games: List[Dict], state_encoder: StateEncoder,
                 action_encoder: ActionEncoder):
        self.state_encoder = state_encoder
        self.action_encoder = action_encoder
        self.samples = []
        self._process(games)

    def _process(self, games: List[Dict]):
        """Encode every movement-phase order in `games` into self.samples."""
        for game in tqdm(games, desc='Processing games'):
            for phase in game.get('phases') or []:
                phase_name = phase.get('name') or ''
                # Skip non-movement phases before doing any state encoding.
                if not phase_name.endswith('M'):
                    continue

                orders = phase.get('orders') or {}
                encoded_state = self.state_encoder.encode(phase.get('state') or {}, phase_name)

                for power_idx, power in enumerate(POWERS):
                    # A power's order list may be missing or None in the data.
                    for order in orders.get(power) or []:
                        action_idx = self.action_encoder.encode(order)
                        if action_idx <= 1:  # <PAD> or <UNK>
                            continue

                        self.samples.append({
                            'state': encoded_state,
                            'power': power_idx,
                            'action': action_idx
                        })

        print(f'Total samples: {len(self.samples):,}')

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return (state float32 (state_size,), power long (1,), action long (1,))."""
        sample = self.samples[idx]
        return (
            torch.tensor(sample['state'], dtype=torch.float32),
            torch.tensor([sample['power']], dtype=torch.long),
            torch.tensor([sample['action']], dtype=torch.long),
        )

## 7. Model

In [None]:
class BCModel(nn.Module):
    """
    MLP policy for behavioral cloning.

    Maps an encoded state to unnormalized scores (logits) over the order
    vocabulary: state_size -> hidden -> hidden -> hidden/2 -> vocab_size.
    Each hidden layer is followed by LayerNorm, ReLU and Dropout(0.2).

    NOTE(review): forward() receives only the state, not the acting power or
    the unit being ordered. For a given state it can therefore learn only one
    distribution over all orders issued by all powers in that state.
    """

    def __init__(self, state_size: int, vocab_size: int, hidden_size: int = 512):
        super().__init__()

        self.net = nn.Sequential(
            nn.Linear(state_size, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.ReLU(),
            nn.Dropout(0.2),

            nn.Linear(hidden_size, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.ReLU(),
            nn.Dropout(0.2),

            nn.Linear(hidden_size, hidden_size // 2),
            nn.LayerNorm(hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(0.2),

            nn.Linear(hidden_size // 2, vocab_size)
        )

    def forward(self, x):
        """Return logits of shape (..., vocab_size) for states x of shape (..., state_size)."""
        return self.net(x)

    def predict(self, x, temperature=1.0):
        """
        Return softmax probabilities over the vocabulary, sharpened (<1) or
        flattened (>1) by `temperature`.

        Does not change train/eval mode or disable autograd; call model.eval()
        and use torch.no_grad() as needed.

        Raises:
            ValueError: if temperature is not positive (0 would divide by zero,
                a negative value would invert the ranking).
        """
        if temperature <= 0:
            raise ValueError(f'temperature must be > 0, got {temperature}')
        logits = self.forward(x)
        return F.softmax(logits / temperature, dim=-1)


class TransformerBCModel(nn.Module):
    """
    Transformer-encoder policy for behavioral cloning.

    The whole state vector is projected to a single d_model token. A learned
    embedding is added, the token passes through a TransformerEncoder, and an
    MLP head maps it to vocabulary logits.

    NOTE(review): with a sequence length of 1, self-attention can only attend
    to the token itself, so this acts like a deep residual MLP. Letting
    attention relate board locations would require one token per location.
    """

    def __init__(self, state_size: int, vocab_size: int,
                 d_model: int = 256, nhead: int = 8, num_layers: int = 4):
        super().__init__()

        self.input_proj = nn.Linear(state_size, d_model)
        # Learned embedding for the single token position.
        self.pos_emb = nn.Parameter(torch.randn(1, 1, d_model) * 0.02)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead,
            dim_feedforward=d_model * 4,
            dropout=0.1, batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)

        self.output = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(d_model, vocab_size)
        )

    def forward(self, x):
        """Return logits (batch, vocab_size) for states x of shape (batch, state_size)."""
        x = self.input_proj(x).unsqueeze(1)  # (batch, 1, d_model)
        x = x + self.pos_emb
        x = self.transformer(x).squeeze(1)   # (batch, d_model)
        return self.output(x)

    def predict(self, x, temperature=1.0):
        """
        Return softmax probabilities over the vocabulary (same contract as
        BCModel.predict). Does not change train/eval mode or disable autograd.

        Raises:
            ValueError: if temperature is not positive.
        """
        if temperature <= 0:
            raise ValueError(f'temperature must be > 0, got {temperature}')
        return F.softmax(self.forward(x) / temperature, dim=-1)

print('Models defined')

## 8. Load Data

In [None]:
# Configuration (hyperparameters used by the cells below)
MAX_GAMES = 10000  # Increase for better results
BATCH_SIZE = 128
LEARNING_RATE = 1e-3
EPOCHS = 30
HIDDEN_SIZE = 512
MODEL_TYPE = 'mlp'  # 'mlp' or 'transformer'

print('Config:')
for _label, _value in (
    ('Max games', MAX_GAMES),
    ('Batch size', BATCH_SIZE),
    ('Epochs', EPOCHS),
    ('Model', MODEL_TYPE),
):
    print(f'  {_label}: {_value}')

In [None]:
# Load games: one JSON-encoded game per line, stopping after MAX_GAMES games.
print('Loading games...')
games = []
with open(DATA_PATH, 'r', encoding='utf-8') as f:
    for line_no, line in enumerate(f, start=1):
        if len(games) >= MAX_GAMES:
            break
        line = line.strip()
        if not line:  # tolerate blank lines instead of failing in json.loads
            continue
        try:
            games.append(json.loads(line))
        except json.JSONDecodeError as err:
            raise ValueError(f'{DATA_PATH}: invalid JSON on line {line_no}') from err
        if len(games) % 2000 == 0:
            print(f'  Loaded {len(games)} games...')

print(f'Total games: {len(games)}')

In [None]:
# Build the order vocabulary (most frequent normalized orders) from the
# loaded games.
# NOTE(review): this runs before the train/val split, so validation games
# also influence which orders enter the vocabulary.
print('\nBuilding vocabulary...')
action_encoder.build_vocab(games=games)

In [None]:
# Split data: first 90% of games (in file order, no shuffling) for training,
# the rest for validation.
split_idx = int(0.9 * len(games))
train_games, val_games = games[:split_idx], games[split_idx:]

for _split_name, _subset in (('Train', train_games), ('Val', val_games)):
    print(f'{_split_name}: {len(_subset)} games')

In [None]:
# Create datasets. Both splits share the same encoders, so they use one
# state layout and one order vocabulary.
state_encoder = StateEncoder()

print('\nCreating train dataset...')
train_dataset = DiplomacyDataset(
    games=train_games, state_encoder=state_encoder, action_encoder=action_encoder
)

print('\nCreating val dataset...')
val_dataset = DiplomacyDataset(
    games=val_games, state_encoder=state_encoder, action_encoder=action_encoder
)

In [None]:
# Create dataloaders. Pinned host memory speeds up host-to-GPU copies; it is
# only enabled when training on a GPU.
_pin_memory = device.type == 'cuda'
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=2, pin_memory=_pin_memory)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                        num_workers=2, pin_memory=_pin_memory)

print(f'Train batches: {len(train_loader)}')
print(f'Val batches: {len(val_loader)}')

## 9. Create Model

In [None]:
# Create model
if MODEL_TYPE == 'transformer':
    model = TransformerBCModel(state_encoder.state_size, action_encoder.vocab_size)
elif MODEL_TYPE == 'mlp':
    model = BCModel(state_encoder.state_size, action_encoder.vocab_size, HIDDEN_SIZE)
else:
    # Fail loudly rather than silently training an MLP because of a typo.
    raise ValueError(f"Unknown MODEL_TYPE {MODEL_TYPE!r}; expected 'mlp' or 'transformer'")

model = model.to(device)

num_params = sum(p.numel() for p in model.parameters())
print(f'Model: {MODEL_TYPE.upper()}')
print(f'Parameters: {num_params:,}')
print(f'State size: {state_encoder.state_size}')
print(f'Vocab size: {action_encoder.vocab_size}')

In [None]:
# AdamW (decoupled weight decay) with a cosine learning-rate schedule whose
# period is the full training run; cross-entropy over the order vocabulary.
optimizer = optim.AdamW(params=model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=EPOCHS)
criterion = nn.CrossEntropyLoss(reduction='mean')

## 10. Training Functions

In [None]:
def train_epoch(model, loader, optimizer, criterion, device):
    """
    Run one training epoch.

    Args:
        model: network mapping a batch of states to logits (batch, num_classes).
        loader: iterable of (states, powers, actions) batches. `actions` has
            shape (batch, 1) and is squeezed to (batch,); `powers` is unused.
        optimizer: optimizer stepped once per batch.
        criterion: loss function; assumed to average over the batch
            (e.g. CrossEntropyLoss with reduction='mean').
        device: device each batch is moved to.

    Returns:
        (mean loss per sample, top-1 accuracy) over the epoch, or (0.0, 0.0)
        if the loader yields no samples.
    """
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0

    pbar = tqdm(loader, desc='Training')
    for states, _powers, actions in pbar:
        states = states.to(device)
        actions = actions.squeeze(1).to(device)
        batch_size = actions.size(0)

        optimizer.zero_grad()
        logits = model(states)
        loss = criterion(logits, actions)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        # Weight by batch size so a smaller final batch doesn't skew the mean.
        total_loss += batch_loss * batch_size
        correct += (logits.argmax(dim=1) == actions).sum().item()
        total += batch_size

        pbar.set_postfix({'loss': f'{batch_loss:.4f}', 'acc': f'{correct/total:.4f}'})

    if total == 0:
        return 0.0, 0.0
    return total_loss / total, correct / total


def validate(model, loader, criterion, device, k: int = 5):
    """
    Evaluate `model` on `loader` without gradient tracking.

    Args:
        model: network mapping a batch of states to logits (batch, num_classes).
        loader: iterable of (states, powers, actions) batches. `actions` has
            shape (batch, 1); `powers` is unused.
        criterion: loss function; assumed to average over the batch.
        device: device each batch is moved to.
        k: top-k accuracy cutoff, clamped to the number of classes (default 5).

    Returns:
        (mean loss per sample, top-1 accuracy, top-k accuracy), or
        (0.0, 0.0, 0.0) if the loader yields no samples.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    topk_correct = 0
    total = 0

    with torch.no_grad():
        for states, _powers, actions in tqdm(loader, desc='Validating'):
            states = states.to(device)
            actions = actions.squeeze(1).to(device)
            batch_size = actions.size(0)

            logits = model(states)
            # Weight by batch size so a smaller final batch doesn't skew the mean.
            total_loss += criterion(logits, actions).item() * batch_size

            correct += (logits.argmax(dim=1) == actions).sum().item()

            # topk raises if k exceeds the number of classes, so clamp it.
            _, topk = logits.topk(min(k, logits.size(1)), dim=1)
            topk_correct += (topk == actions.unsqueeze(1)).any(dim=1).sum().item()

            total += batch_size

    if total == 0:
        return 0.0, 0.0, 0.0
    return total_loss / total, correct / total, topk_correct / total

## 11. Train!

In [None]:
history = {
    'train_loss': [], 'train_acc': [],
    'val_loss': [], 'val_acc': [], 'val_top5': []
}
# Start at -inf (not 0) so the first epoch always writes a checkpoint, even if
# its validation accuracy is 0; the model-loading cell expects the file.
best_val_acc = float('-inf')

print('='*60)
print('TRAINING')
print('='*60)

for epoch in range(EPOCHS):
    print(f'\nEpoch {epoch + 1}/{EPOCHS}')

    train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device)
    val_loss, val_acc, val_top5 = validate(model, val_loader, criterion, device)
    scheduler.step()  # the cosine schedule advances once per epoch

    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)
    history['val_top5'].append(val_top5)

    print(f'  Train - Loss: {train_loss:.4f}, Acc: {train_acc:.4f}')
    print(f'  Val   - Loss: {val_loss:.4f}, Acc: {val_acc:.4f}, Top-5: {val_top5:.4f}')

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save({
            'model_state_dict': model.state_dict(),
            'vocab': action_encoder.order_to_idx,
            'config': {
                'state_size': state_encoder.state_size,
                'vocab_size': action_encoder.vocab_size,
                # Needed to rebuild the matching architecture when reloading.
                'model_type': MODEL_TYPE,
                'hidden_size': HIDDEN_SIZE,
            },
            'epoch': epoch + 1,
        }, 'best_bc_model.pt')
        print('  -> Saved best model!')

print('\n' + '='*60)
print(f'Best Val Accuracy: {best_val_acc:.4f}')
print('='*60)

## 12. Results

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# The training log numbers epochs from 1, so use the same numbering here.
epoch_nums = range(1, len(history['train_loss']) + 1)

# Loss
axes[0].plot(epoch_nums, history['train_loss'], label='Train', linewidth=2)
axes[0].plot(epoch_nums, history['val_loss'], label='Val', linewidth=2)
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training Loss')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Accuracy
axes[1].plot(epoch_nums, history['train_acc'], label='Train', linewidth=2)
axes[1].plot(epoch_nums, history['val_acc'], label='Val', linewidth=2)
axes[1].plot(epoch_nums, history['val_top5'], label='Val Top-5', linewidth=2, linestyle='--')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Accuracy')
axes[1].set_title('Training Accuracy')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('training_curves.png', dpi=150)
plt.show()

## 13. Test Predictions

In [None]:
# Load best model. map_location puts the saved tensors on the current device,
# so a checkpoint saved on a GPU also loads on a CPU-only runtime.
checkpoint = torch.load('best_bc_model.pt', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

print('Model loaded!')

In [None]:
# Compare the model with the human orders on the first phase of the first
# validation game.
test_game = val_games[0]
test_phase = test_game['phases'][0]

print(f"Phase: {test_phase['name']}")
print("\nActual orders:")
phase_orders = test_phase.get('orders') or {}
for power in POWERS:
    orders = phase_orders.get(power) or []
    if orders:
        print(f"  {power}: {orders}")

# Get model predictions. The model is not conditioned on a power, so these are
# its highest-scoring orders across all powers for this board state.
state = test_phase['state']
encoded = state_encoder.encode(state, test_phase['name'])
x = torch.FloatTensor(encoded).unsqueeze(0).to(device)

model.eval()
with torch.no_grad():
    probs = F.softmax(model(x), dim=-1)
    top_k = min(5, probs.size(-1))  # topk fails if k exceeds the vocabulary size
    top_probs, top_idx = probs.topk(top_k, dim=-1)

print(f"\nModel's top {top_k} predictions:")
for rank, (idx, prob) in enumerate(zip(top_idx[0].tolist(), top_probs[0].tolist()), start=1):
    print(f"  {rank}. {action_encoder.decode(idx)} ({prob:.2%})")

## 14. Download Model

In [None]:
# Send the best checkpoint and the training-curve figure to the browser.
from google.colab import files

for _artifact in ('best_bc_model.pt', 'training_curves.png'):
    files.download(_artifact)

print('Files downloaded!')

## 15. Summary

### Results
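- **Train Accuracy**: fraction of training orders for which the model's top-1 prediction matches the human order (see plot above)
- **Val Accuracy**: the same metric on the held-out validation games (see plot above)
- **Top-5 Accuracy**: fraction of validation orders for which the human order is among the model's 5 highest-scoring predictions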

### Next Steps
1. **Self-Play**: Use BC model as starting point for RL
2. **Human-Regularized RL**: Add KL penalty to stay close to human policy
3. **Population Training**: Train diverse opponents