# 03 - Comparaison des Shields

Ce notebook compare les différentes approches de safety shields:
- **Quantile Shield**: Régression quantile pessimiste
- **Flow Matching Shield**: Estimation de densité via CFM
- **Diffusion Shield**: Estimation de densité via DDPM

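## Métriques comparées
- Retour moyen *(à faire : non calculé dans les cellules ci-dessous)*
- Taux de détection OOD *(visualisé via les frontières OOD, section 2)*
- Qualité des projections *(section 3)*
- Temps d'inférence *(section 1)*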

In [None]:
import sys
sys.path.insert(0, '..')

import numpy as np
import torch
import matplotlib.pyplot as plt
import time
from pathlib import Path

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Device: {device}')

In [None]:
from scripts.models import (
    UDRLPolicy, QuantileShield, FlowMatchingShield, DiffusionShield,
    load_data, load_model
)

# Load data
# Candidate datasets, sorted by file name so the last entry is the "latest"
# (assumes file names sort chronologically -- TODO confirm naming scheme).
data_files = sorted(Path('../data').glob('lunarlander_*.npz'))
if not data_files:
    # Fail fast with an actionable message: every later cell indexes into
    # `data`, so carrying on without it would only surface as a NameError.
    raise FileNotFoundError('No data found! Run collect_data_standalone.py first.')

data_path = str(data_files[-1])  # Latest
print(f'Loading: {data_path}')
data = load_data(data_path)

In [None]:
# Load models
state_dim, action_dim = data['state_dim'], data['action_dim']
env_name = 'lunarlander'

# Build the policy and the three shields on the selected device.
policy = UDRLPolicy(state_dim, action_dim).to(device)
quantile = QuantileShield(state_dim).to(device)
flow = FlowMatchingShield(state_dim).to(device)
diffusion = DiffusionShield(state_dim).to(device)

# (module, saved-model name) pairs handed to load_model, in load order.
checkpoints = (
    (policy, 'policy'),
    (quantile, 'quantile_shield'),
    (flow, 'flow_shield'),
    (diffusion, 'diffusion_shield'),
)

try:
    for module, checkpoint_name in checkpoints:
        load_model(module, env_name, checkpoint_name)
except FileNotFoundError as e:
    # Loading stops at the first missing file; the remaining models are not attempted.
    print(f'Model not found: {e}')
    print('Train models first using the training scripts.')
else:
    print('All models loaded!')

## 1. Comparaison des temps d'inférence

In [None]:
# Benchmark inference time
# Fixed batch (first 100 samples) shared by every benchmark run. The dtype is
# set explicitly to float32, like the command tensors built in the later
# cells, so float64 arrays in the dataset cannot cause a dtype mismatch.
states_t = torch.tensor(data['states'][:100], dtype=torch.float32, device=device)
commands_t = torch.tensor(data['commands'][:100], dtype=torch.float32, device=device)

def benchmark(shield, name, n_runs=10, n_warmup=1):
    """Time a shield's OOD check and projection on the shared benchmark batch.

    Both calls are made with the module-level ``states_t`` / ``commands_t``
    tensors and run under ``torch.no_grad()``, matching how the shields are
    called in the plotting cells of this notebook.

    Args:
        shield: Object exposing ``is_ood(states, commands)`` and
            ``project(states, commands)``.
        name: Label printed above the timing lines.
        n_runs: Number of timed repetitions of each call.
        n_warmup: Number of untimed calls made first, so one-time
            initialisation cost is not attributed to the first timed run.

    Returns:
        Tuple ``(mean_ood_seconds, mean_proj_seconds)``.
    """
    def _sync():
        # CUDA kernels are launched asynchronously; without a synchronize the
        # clock would stop before the GPU has actually finished the work.
        if states_t.is_cuda:
            torch.cuda.synchronize()

    def _timed(fn):
        # Wall-clock duration, in seconds, of one fn(states_t, commands_t) call.
        _sync()
        start = time.perf_counter()
        fn(states_t, commands_t)
        _sync()
        return time.perf_counter() - start

    times_ood = []
    times_proj = []

    with torch.no_grad():
        for _ in range(n_warmup):
            shield.is_ood(states_t, commands_t)
            shield.project(states_t, commands_t)

        for _ in range(n_runs):
            # Same interleaving as before: OOD detection, then projection.
            times_ood.append(_timed(shield.is_ood))
            times_proj.append(_timed(shield.project))

    mean_ood, mean_proj = np.mean(times_ood), np.mean(times_proj)
    print(f'{name}:')
    print(f'  OOD detection: {mean_ood*1000:.2f} ± {np.std(times_ood)*1000:.2f} ms')
    print(f'  Projection:    {mean_proj*1000:.2f} ± {np.std(times_proj)*1000:.2f} ms')
    return mean_ood, mean_proj

# Shields under comparison, keyed by the label used in printouts and plot titles.
shields = {
    'Quantile': quantile,
    'Flow Matching': flow,
    'Diffusion': diffusion,
}

# label -> (mean OOD-detection seconds, mean projection seconds), as returned by benchmark().
benchmarks = {label: benchmark(model, label) for label, model in shields.items()}

In [None]:
# Plot benchmarks
names = list(benchmarks)

# benchmark() reports seconds; the chart shows milliseconds.
ood_times, proj_times = [], []
for label in names:
    ood_s, proj_s = benchmarks[label]
    ood_times.append(ood_s * 1000)
    proj_times.append(proj_s * 1000)

x = np.arange(len(names))
width = 0.35
half_width = width / 2

fig, ax = plt.subplots(figsize=(10, 5))

# One pair of side-by-side bars per shield: OOD detection left, projection right.
series = (
    (x - half_width, ood_times, 'OOD Detection', 'steelblue'),
    (x + half_width, proj_times, 'Projection', 'coral'),
)
for positions, heights, series_label, color in series:
    ax.bar(positions, heights, width, label=series_label, color=color)

ax.set_ylabel('Time (ms)')
ax.set_title('Inference Time Comparison (100 samples)')
ax.set_xticks(x)
ax.set_xticklabels(names)
ax.set_yscale('log')  # log axis so shields with very different timings stay readable together
ax.legend()

plt.tight_layout()
plt.show()

## 2. Visualisation des frontières OOD

In [None]:
# Create grid
# 50x50 grid over the two command dimensions, extended past the observed
# min/max by a fixed margin (30 on column 0, 50 on column 1).
commands = data['commands']
h_range = np.linspace(commands[:, 0].min() - 30, commands[:, 0].max() + 30, 50)
r_range = np.linspace(commands[:, 1].min() - 50, commands[:, 1].max() + 50, 50)
H, R = np.meshgrid(h_range, r_range)

# Every grid command is evaluated against the same state: the mean of the
# first 1000 dataset states. float32 is set explicitly to match grid_commands.
mean_state = torch.tensor(data['states'][:1000].mean(axis=0),
                          dtype=torch.float32, device=device)
mean_state = mean_state.unsqueeze(0).repeat(H.size, 1)
grid_commands = torch.tensor(np.stack([H.flatten(), R.flatten()], axis=1),
                            dtype=torch.float32, device=device)

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

for ax, (name, shield) in zip(axes, shields.items()):
    with torch.no_grad():
        ood = shield.is_ood(mean_state, grid_commands)

    # Back to the grid layout (same C-order as the flatten() above). Cast to
    # float so the 0.5 contour level is well defined even if is_ood returns a
    # boolean mask (presumably it does -- confirm against the shield classes).
    ood_grid = ood.cpu().numpy().astype(float).reshape(H.shape)

    ax.contourf(H, R, ood_grid, levels=1, colors=['lightgreen', 'lightcoral'], alpha=0.5)
    ax.contour(H, R, ood_grid, levels=[0.5], colors='red', linewidths=2)
    # Overlay a sample of the dataset commands for reference.
    ax.scatter(commands[:300, 0], commands[:300, 1], c='blue', s=3, alpha=0.3)
    ax.set_xlabel('Horizon')
    ax.set_ylabel('Return-to-go')
    ax.set_title(f'{name} OOD Boundary')

plt.tight_layout()
# savefig does not create missing directories; make sure the target exists.
figure_path = Path('../results/lunarlander/figures/shield_comparison.png')
figure_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(figure_path, dpi=150)
plt.show()

## 3. Qualité des projections

In [None]:
# Create OOD commands
n_test = 50
test_state = torch.tensor(data['states'][:n_test], dtype=torch.float32, device=device)
# The dataset may hold fewer than 50 states; size the command batch to match
# so project() always receives equally long state and command batches.
n_test = len(test_state)

# OOD commands (unrealistic): three hand-picked patterns, cycled to fill the batch.
ood_patterns = torch.tensor([
    [30, 200],   # Too optimistic
    [20, 300],   # Very optimistic
    [10, 400],   # Extremely optimistic
], dtype=torch.float32, device=device)
ood_commands = ood_patterns.repeat(n_test // 3 + 1, 1)[:n_test]

# Read the reference commands from the dataset directly so this cell does not
# depend on a variable created by an earlier plotting cell.
train_commands = data['commands']

# Project and compare
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

for ax, (name, shield) in zip(axes, shields.items()):
    with torch.no_grad():
        projected = shield.project(test_state, ood_commands)

    orig = ood_commands.cpu().numpy()
    proj = projected.cpu().numpy()

    ax.scatter(orig[:, 0], orig[:, 1], c='red', label='Original OOD', alpha=0.7, s=30)
    ax.scatter(proj[:, 0], proj[:, 1], c='green', label='Projected', alpha=0.7, s=30)

    # Draw arrows from each original command to its projection (first 20 only,
    # to keep the figure readable).
    for src, dst in zip(orig[:20], proj[:20]):
        ax.annotate('', xy=(dst[0], dst[1]), xytext=(src[0], src[1]),
                   arrowprops=dict(arrowstyle='->', color='gray', alpha=0.5))

    ax.scatter(train_commands[:200, 0], train_commands[:200, 1],
               c='blue', s=3, alpha=0.2, label='Training data')
    ax.set_xlabel('Horizon')
    ax.set_ylabel('Return-to-go')
    ax.set_title(f'{name} Projection')
    ax.legend(fontsize=8)

plt.tight_layout()
# savefig does not create missing directories; make sure the target exists.
figure_path = Path('../results/lunarlander/figures/projection_comparison.png')
figure_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(figure_path, dpi=150)
plt.show()

## 4. Résumé

In [None]:
# Assemble the whole report as a list of lines, then emit it in one print.
rule = '=' * 60

# One row per shield; benchmark times are stored in seconds and shown in ms.
table_rows = [
    f'{label:<20} {benchmarks[label][0]*1000:<15.2f} {benchmarks[label][1]*1000:<15.2f}'
    for label in shields
]

report_lines = [
    rule,
    'SHIELD COMPARISON SUMMARY',
    rule,
    '',
    f'{"Shield":<20} {"OOD Time (ms)":<15} {"Proj Time (ms)":<15}',
    '-' * 50,
    *table_rows,
    rule,
    '',
    'Key observations:',
    '- Quantile: Fastest, simple threshold-based detection',
    '- Flow Matching: Good density estimation, moderate speed',
    '- Diffusion: Best density model, slowest inference',
]
print('\n'.join(report_lines))