# Adversarial Training: FGSM → PGD Protocol

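Main contribution: train robust boosting tickets with the FGSM→PGD protocol — FGSM adversarial training of a dense network to find a pruning mask, followed by PGD retraining of the pruned network rewound to its initial weights.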

In [None]:
import sys
sys.path.append('..')

import numpy as np
import matplotlib.pyplot as plt
import torch
import json
from copy import deepcopy
import time
from pathlib import Path

from src.utils.config import load_config, get_device
from src.models.deep_hedging import DeepHedgingNetwork
from src.attacks.adversarial_trainer import AdversarialTrainer
from src.data.preprocessor import create_dataloaders
from src.pruning.magnitude import magnitude_pruning, rewind_weights, save_mask, get_sparsity  
from src.evaluation.metrics import evaluate_robustness

## Setup

In [None]:
# Experiment configuration and the compute device selected from it.
config = load_config('../config.yaml')
device = get_device(config)

# Pre-processed arrays are stored one .npy file per (array, split) pair.
processed_dir = Path('../data/processed')


def _load_split(split):
    """Return the (S, v, Z) arrays saved for the given split name, in that order."""
    return tuple(
        np.load(processed_dir / f'{prefix}_{split}.npy')
        for prefix in ('S', 'v', 'Z')
    )


S_train, v_train, Z_train = _load_split('train')
S_val, v_val, Z_val = _load_split('val')
S_test, v_test, Z_test = _load_split('test')

# A falsy configured batch size (e.g. null or 0) falls back to 256.
batch_size = config['training']['batch_size'] or 256
num_workers = config['compute']['num_parallel_workers']
train_loader, val_loader, test_loader = create_dataloaders(
    S_train, v_train, Z_train,
    S_val, v_val, Z_val,
    S_test, v_test, Z_test,
    batch_size, num_workers,
)

# Scalars forwarded to every fit/evaluate call further down the notebook.
data_cfg = config['data']
K, T, dt = data_cfg['heston']['K'], data_cfg['T'], data_cfg['dt']

# All artifacts of this notebook are written below this directory.
output_dir = Path('../experiments/adversarial_training')
output_dir.mkdir(parents=True, exist_ok=True)

## Phase 1: FGSM Adversarial Training

In [None]:
print("Phase 1: FGSM Adversarial Training", "=" * 60, sep="\n")

# Fresh dense network; its untrained weights (theta_0) are written to disk
# before any training so that later cells can load them again by path.
model_fgsm = DeepHedgingNetwork(config['model'])
init_weights_path = output_dir / 'theta_0.pt'
torch.save(model_fgsm.state_dict(), init_weights_path)
print(f"Initial weights saved to {init_weights_path}")

# Work on a private copy so the FGSM-specific learning rate and epoch count
# never leak into the shared `config`.
config_fgsm = deepcopy(config)
fgsm_phase_cfg = config['adversarial_training']['fgsm_phase']
config_fgsm['training']['learning_rate'] = fgsm_phase_cfg['lr']
config_fgsm['training']['epochs'] = fgsm_phase_cfg['epochs']

# Dense training: no pruning mask exists yet, so pass mask=None.
trainer_fgsm = AdversarialTrainer(
    model_fgsm,
    config_fgsm,
    attack_type='fgsm',
    device=device,
    mask=None,
)

# Wall-clock time of the FGSM phase; later cells add it to the PGD retraining time.
start_time = time.time()
trainer_fgsm.fit(train_loader, val_loader, K, T, dt)
fgsm_time = time.time() - start_time

print(f"\nFGSM training time: {fgsm_time:.2f} seconds")

# Persist the FGSM-trained weights under their own sub-directory.
fgsm_dir = output_dir / 'fgsm_phase'
fgsm_dir.mkdir(parents=True, exist_ok=True)
fgsm_model_path = fgsm_dir / 'model.pt'
torch.save(model_fgsm.state_dict(), fgsm_model_path)

# Report clean and PGD-10 metrics of the FGSM-trained model on the test loader.
results_fgsm = evaluate_robustness(model_fgsm, test_loader, config, K, T, dt, device)
clean_cvar = results_fgsm['clean']['cvar_005']
pgd10_cvar = results_fgsm['pgd10']['cvar_005']
robustness_gap = results_fgsm['robustness_gap_pgd10']
print("\nFGSM Model Robustness:")
print(f"  Clean CVaR: {clean_cvar:.6f}")
print(f"  PGD-10 CVaR: {pgd10_cvar:.6f}")
print(f"  Robustness Gap: {robustness_gap:.6f}")

## Phase 2: Pruning

In [None]:
print("\nPhase 2: Pruning 80%", "=" * 60, sep="\n")

# Magnitude-prune the FGSM-trained network with a 0.8 sparsity target and keep
# the returned mask for the retraining phase.
mask = magnitude_pruning(model_fgsm, sparsity=0.8)
achieved_sparsity = get_sparsity(model_fgsm)
print(f"Sparsity: {achieved_sparsity:.2%}")

# Persist the mask next to the other experiment artifacts (save_mask is given a str path).
mask_path = output_dir / 'mask.pt'
save_mask(mask, str(mask_path))
print(f"Mask saved to {mask_path}")

print("Pruning complete")

## Phase 3: PGD Retraining with Warmup

In [None]:
print("\nPhase 3: PGD Retraining")
print("="*60)

# Sweep over the candidate retraining lengths listed in the config. For each
# candidate a fresh ticket is built from theta_0 + mask, retrained with PGD,
# evaluated, and saved; `results_retrain` maps epochs -> metrics and timings.
pgd_phase_cfg = config['adversarial_training']['pgd_phase']
epochs_candidates = pgd_phase_cfg['epochs_candidates']
results_retrain = {}

for epochs in epochs_candidates:
    print(f"\nTesting {epochs} epochs...")

    # Build the ticket: a new network passed through rewind_weights together
    # with the saved theta_0 checkpoint path and the pruning mask.
    # NOTE(review): presumably this restores theta_0 and applies the mask —
    # confirm against src.pruning.magnitude.rewind_weights.
    model_ticket = DeepHedgingNetwork(config['model'])
    rewind_weights(model_ticket, str(init_weights_path), mask)
    model_ticket = model_ticket.to(device)

    # Hand the mask to the trainer as well (the dense runs pass mask=None).
    config_pgd = deepcopy(config)
    trainer_pgd = AdversarialTrainer(model_ticket, config_pgd, attack_type='pgd', device=device, mask=mask)

    # Retrain with the warmup schedule and time it.
    start_time = time.time()
    trainer_pgd.fit_with_warmup(
        train_loader, val_loader, K, T, dt,
        epochs=epochs,
        lr_start=pgd_phase_cfg['lr_start'],
        lr_end=pgd_phase_cfg['lr_end'],
        warmup_epochs=10
    )
    pgd_time = time.time() - start_time
    total_time = fgsm_time + pgd_time

    # NOTE(review): candidates are compared on test_loader, so the "best"
    # epoch count below is selected on test data — consider selecting on
    # val_loader and reporting test metrics only for the winner.
    results = evaluate_robustness(model_ticket, test_loader, config, K, T, dt, device)

    # Coerce metrics to built-in floats: json.dump below raises TypeError on
    # numpy / torch scalar types, and float() is a no-op for plain floats.
    results_retrain[epochs] = {
        'natural_cvar': float(results['clean']['cvar_005']),
        'robust_cvar_pgd10': float(results['pgd10']['cvar_005']),
        'robust_cvar_pgd20': float(results['pgd20']['cvar_005']),
        'training_time': pgd_time,
        'total_time': total_time
    }

    # Save the retrained ticket; the comparison cell reloads the best one from here.
    model_path = output_dir / f'pgd_retrain_{epochs}epochs' / 'model.pt'
    model_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(model_ticket.state_dict(), model_path)

    print(f"  Natural CVaR: {results['clean']['cvar_005']:.6f}")
    print(f"  Robust CVaR (PGD-10): {results['pgd10']['cvar_005']:.6f}")
    print(f"  Training time: {pgd_time:.2f}s")
    print(f"  Total time: {total_time:.2f}s")

# Fail with an explicit message instead of min()'s "empty sequence" error
# when the config lists no candidates (same exception type as before).
if not results_retrain:
    raise ValueError(
        "config['adversarial_training']['pgd_phase']['epochs_candidates'] is empty; "
        "no PGD retraining run was performed"
    )

# Best candidate = smallest PGD-10 CVaR.
best_epochs = min(results_retrain, key=lambda e: results_retrain[e]['robust_cvar_pgd10'])
print(f"\nBest retraining epochs: {best_epochs}")

# Persist the sweep results (the integer epoch keys become strings in JSON).
with open(output_dir / 'pgd_retrain_results.json', 'w') as f:
    json.dump(results_retrain, f, indent=2)

## Baseline Comparisons

In [None]:
print("\nBaseline Comparisons")
print("="*60)

# Reference run: a dense (unpruned, mask=None) network trained with PGD from
# scratch, timed so the FGSM→PGD pipeline can be compared against it.
print("Training Dense PGD Baseline...")
model_baseline = DeepHedgingNetwork(config['model'])
trainer_baseline = AdversarialTrainer(model_baseline, config, attack_type='pgd', device=device, mask=None)

start_time = time.time()
trainer_baseline.fit(train_loader, val_loader, K, T, dt)
baseline_time = time.time() - start_time

metrics_baseline = evaluate_robustness(model_baseline, test_loader, config, K, T, dt, device)

# Our method: reload the best retrained ticket saved by the Phase 3 cell.
# map_location makes the load work even if the checkpoint tensors were saved
# on a device that is not available (or not selected) in this session.
our_model = DeepHedgingNetwork(config['model'])
best_model_path = output_dir / f'pgd_retrain_{best_epochs}epochs' / 'model.pt'
our_model.load_state_dict(torch.load(best_model_path, map_location=device))
our_model = our_model.to(device)
metrics_ours = evaluate_robustness(our_model, test_loader, config, K, T, dt, device)

# Comparison table. Metrics are coerced to built-in floats because json.dump
# raises TypeError on numpy / torch scalar types; float() leaves plain floats
# unchanged. Our training time is the FGSM phase plus the PGD retraining.
ours_total_time = results_retrain[best_epochs]['total_time']
comparison = {
    'Dense PGD Baseline': {
        'natural_cvar': float(metrics_baseline['clean']['cvar_005']),
        'robust_cvar_pgd10': float(metrics_baseline['pgd10']['cvar_005']),
        'robust_cvar_pgd20': float(metrics_baseline['pgd20']['cvar_005']),
        'training_time': baseline_time,
        'time_ratio': 1.0
    },
    'Our Method (FGSM→PGD)': {
        'natural_cvar': float(metrics_ours['clean']['cvar_005']),
        'robust_cvar_pgd10': float(metrics_ours['pgd10']['cvar_005']),
        'robust_cvar_pgd20': float(metrics_ours['pgd20']['cvar_005']),
        'training_time': ours_total_time,
        'time_ratio': ours_total_time / baseline_time
    }
}

# Save comparison
with open(output_dir / 'comparison.json', 'w') as f:
    json.dump(comparison, f, indent=2)

# Print comparison as a fixed-width table, one row per method.
print(f"\n{'Method':<30} {'Natural CVaR':<15} {'Robust CVaR':<15} {'Time (s)':<15} {'Time Ratio':<15}")
print("-"*90)
for method, metrics in comparison.items():
    print(f"{method:<30} {metrics['natural_cvar']:<15.6f} {metrics['robust_cvar_pgd10']:<15.6f} {metrics['training_time']:<15.2f} {metrics['time_ratio']:<15.2f}")

# Percentage of the baseline's wall-clock time saved by our method
# (negative if our pipeline was slower than the baseline).
time_savings = (1 - comparison['Our Method (FGSM→PGD)']['time_ratio']) * 100
print(f"\nTime savings: {time_savings:.1f}%")

## Summary

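Robust boosting tickets achieve robustness comparable to the dense PGD baseline with 40-50% time savings (see the comparison table and the "Time savings" figure printed above for the values measured in this run).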