# 01 - Quick Start: FlowShield-UDRL

Ce notebook montre comment utiliser FlowShield-UDRL de manière interactive.

## Contenu
1. Collecte de données
2. Entraînement des modèles
3. Visualisation de la détection OOD
4. Test du Shield

In [None]:
import sys
sys.path.insert(0, '..')

import numpy as np
import torch
import matplotlib.pyplot as plt
import gymnasium as gym
from tqdm.notebook import tqdm

# Pick the compute device: a CUDA GPU when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Using device: {device}')

## 1. Collecte de données

Collectons des trajectoires avec une politique aléatoire.

In [None]:
def collect_trajectories(env, n_episodes=100, policy=None):
    """Roll out episodes and label every transition with a hindsight command.

    Args:
        env: Gymnasium-style environment: ``reset()`` returns ``(obs, info)``
            and ``step(a)`` returns ``(obs, reward, terminated, truncated, info)``.
        n_episodes: Number of episodes to collect.
        policy: Optional callable ``policy(obs) -> action``. When ``None`` (the
            default), actions are drawn with ``env.action_space.sample()``,
            i.e. a random policy.

    Returns:
        ``(trajectories, episode_returns)`` where ``trajectories`` is a flat list
        of transition dicts with keys ``state``, ``action``, ``reward``,
        ``next_state``, ``done`` and ``command``. ``command`` is a float32 array
        ``[horizon, return_to_go]`` measured from that transition to the end of
        its episode (so the last transition has horizon 1).
        ``episode_returns`` holds the undiscounted return of each episode.
    """
    act = policy if policy is not None else (lambda _obs: env.action_space.sample())
    trajectories = []
    episode_returns = []

    for _ in tqdm(range(n_episodes), desc='Collecting'):
        obs, _ = env.reset()
        trajectory = []
        episode_return = 0
        done = False

        while not done:
            action = act(obs)
            next_obs, reward, terminated, truncated, _ = env.step(action)
            # Both termination and time-limit truncation end the episode here.
            done = terminated or truncated

            trajectory.append({
                'state': obs, 'action': action, 'reward': reward,
                'next_state': next_obs, 'done': done
            })
            episode_return += reward
            obs = next_obs

        # Hindsight commands: walk the episode backwards, accumulating the
        # return-to-go, so labelling is O(T) rather than re-summing the tail
        # of the episode at every step (O(T^2)).
        return_to_go = 0.0
        for horizon, trans in enumerate(reversed(trajectory), start=1):
            return_to_go += trans['reward']
            trans['command'] = np.array([horizon, return_to_go], dtype=np.float32)

        trajectories.extend(trajectory)
        episode_returns.append(episode_return)

    return trajectories, episode_returns

In [None]:
# Continuous-action LunarLander: report observation/action sizes, then build
# the offline dataset from random-policy rollouts.
env = gym.make('LunarLander-v3', continuous=True)
obs_size = env.observation_space.shape[0]
act_size = env.action_space.shape[0]
print(f'State dim: {obs_size}')
print(f'Action dim: {act_size}')

trajectories, returns = collect_trajectories(env, n_episodes=200)
n_transitions = len(trajectories)
mean_return, std_return = np.mean(returns), np.std(returns)
print(f'\nCollected {n_transitions} transitions')
print(f'Mean return: {mean_return:.2f} ± {std_return:.2f}')

In [None]:
# Stack per-transition fields into float32 arrays (one row per transition).
states, actions, commands = (
    np.array([t[field] for t in trajectories], dtype=np.float32)
    for field in ('state', 'action', 'command')
)

# Side-by-side histograms of the two hindsight-command components.
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
command_panels = [(0, 'blue', 'Horizon'), (1, 'green', 'Return-to-go')]
for ax, (column, colour, label) in zip(axes, command_panels):
    ax.hist(commands[:, column], bins=50, color=colour, alpha=0.7)
    ax.set_xlabel(label)
    ax.set_ylabel('Count')
    ax.set_title(f'{label} Distribution')

plt.tight_layout()
plt.show()

## 2. Entraînement des modèles

### 2.1 UDRL Policy

In [None]:
from scripts.models import UDRLPolicy, FlowMatchingShield, QuantileShield

# Size the command-conditioned policy from the environment's spaces.
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]

policy = UDRLPolicy(state_dim, action_dim, hidden_dim=128)
policy = policy.to(device)
n_policy_params = sum(p.numel() for p in policy.parameters())
print(f'Policy parameters: {n_policy_params:,}')

In [None]:
from torch.optim import Adam

# Move the whole dataset to the training device once.
states_t, actions_t, commands_t = (
    torch.tensor(arr, device=device) for arr in (states, actions, commands)
)

# Supervised UDRL objective: minimise -log pi(a | s, c) over shuffled mini-batches.
optimizer = Adam(policy.parameters(), lr=1e-3)
batch_size = 256
n_epochs = 50
losses = []
n_samples = len(states)

for epoch in tqdm(range(n_epochs), desc='Training Policy'):
    order = np.random.permutation(n_samples)
    batch_losses = []
    for start in range(0, n_samples, batch_size):
        idx = order[start:start + batch_size]
        optimizer.zero_grad()
        nll = -policy.log_prob(states_t[idx], commands_t[idx], actions_t[idx]).mean()
        nll.backward()
        optimizer.step()
        batch_losses.append(nll.item())
    # Mean of the per-batch losses for this epoch.
    losses.append(sum(batch_losses) / len(batch_losses))

# Learning curve
plt.figure(figsize=(8, 4))
plt.plot(losses)
plt.title('UDRL Policy Training')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.grid(True, alpha=0.3)
plt.show()

### 2.2 Flow Matching Shield

In [None]:
# Flow-matching shield over (state, command) pairs, trained with its own CFM loss.
flow_shield = FlowMatchingShield(state_dim, hidden_dim=128).to(device)

optimizer = Adam(flow_shield.parameters(), lr=1e-3)
flow_losses = []
n_rows = len(states)

for epoch in tqdm(range(n_epochs), desc='Training Flow Shield'):
    shuffled = np.random.permutation(n_rows)
    per_batch = []
    for lo in range(0, n_rows, batch_size):
        batch = shuffled[lo:lo + batch_size]
        optimizer.zero_grad()
        cfm_loss = flow_shield.loss(states_t[batch], commands_t[batch])
        cfm_loss.backward()
        optimizer.step()
        per_batch.append(cfm_loss.item())
    # Mean of the per-batch losses for this epoch.
    flow_losses.append(sum(per_batch) / len(per_batch))

plt.figure(figsize=(8, 4))
plt.plot(flow_losses)
plt.title('Flow Matching Shield Training')
plt.xlabel('Epoch')
plt.ylabel('CFM Loss')
plt.grid(True, alpha=0.3)
plt.show()

## 3. Visualisation OOD

In [None]:
# Visualize the Flow Shield's OOD decision over a (horizon, return-to-go) grid
# that extends beyond the range of the training commands.
h_range = np.linspace(commands[:, 0].min() - 50, commands[:, 0].max() + 50, 40)
r_range = np.linspace(commands[:, 1].min() - 100, commands[:, 1].max() + 100, 40)
H, R = np.meshgrid(h_range, r_range)

# Evaluate every grid command at the same state: the mean training state.
mean_state = states_t.mean(dim=0, keepdim=True).repeat(len(h_range) * len(r_range), 1)
grid_commands = torch.tensor(np.stack([H.ravel(), R.ravel()], axis=1),
                             dtype=torch.float32, device=device)

# Compute OOD flags (presumably one boolean per grid row -- confirm against
# FlowMatchingShield.is_ood).
with torch.no_grad():
    ood = flow_shield.is_ood(mean_state, grid_commands)

# 0.0 = in-distribution, 1.0 = OOD.
ood_grid = ood.float().cpu().numpy().reshape(H.shape)

# A random subsample of training commands: rows of `commands` are stored
# episode by episode, so the first rows would only show the earliest episodes.
rng = np.random.default_rng(0)
shown = rng.choice(len(commands), size=min(500, len(commands)), replace=False)

plt.figure(figsize=(10, 8))
# Explicit band edges put 0 (in-distribution) and 1 (OOD) in separate bands.
ood_fill = plt.contourf(H, R, ood_grid, levels=[-0.5, 0.5, 1.5],
                        colors=['lightgreen', 'lightcoral'], alpha=0.5)
plt.contour(H, R, ood_grid, levels=[0.5], colors='red', linewidths=2)
plt.scatter(commands[shown, 0], commands[shown, 1], c='blue', s=5, alpha=0.3,
            label='Training data')
plt.xlabel('Horizon (H)')
plt.ylabel('Return-to-go (R)')
plt.title('Flow Shield OOD Detection')
# Pass the filled contour explicitly: a bare plt.colorbar() would attach to the
# most recent mappable, i.e. the constant-colour scatter above.
cbar = plt.colorbar(ood_fill, label='OOD', ticks=[0, 1])
cbar.ax.set_yticklabels(['in-dist', 'OOD'])
plt.legend()
plt.show()

## 4. Test du Shield

In [None]:
# Roll out the trained policy with and without a shield.
def run_episode(policy, env, command, shield=None, device='cpu', max_steps=500, seed=None):
    """Run one episode of ``policy`` conditioned on a (horizon, return) command.

    At each step the policy receives the *remaining* command:
    horizon ``max(target_h - steps, 1)`` and return ``target_r - episode_return``.
    If ``shield`` is given and ``shield.is_ood`` flags that command, it is
    replaced by ``shield.project(...)`` before the action is sampled.

    Args:
        policy: Object with ``sample(state, command, deterministic=True)``;
            presumably returns a ``(1, action_dim)`` tensor (first row is used).
        env: Gymnasium-style environment.
        command: ``(target_horizon, target_return)`` pair.
        shield: Optional object exposing ``is_ood(state, cmd)`` and
            ``project(state, cmd)``.
        device: Torch device for the state/command tensors.
        max_steps: Hard cap on the number of environment steps.
        seed: Optional seed forwarded to ``env.reset``; ``None`` resets without
            reseeding (the original behaviour).

    Returns:
        ``(episode_return, steps, projections)``: undiscounted return, number
        of steps taken, and how many times the shield projected the command.
    """
    state, _ = env.reset() if seed is None else env.reset(seed=seed)
    episode_return = 0
    target_h, target_r = command
    steps = 0
    projections = 0

    while steps < max_steps:
        # as_tensor + unsqueeze avoids building a tensor from a Python list of
        # ndarrays, which is slow and triggers a PyTorch warning.
        state_t = torch.as_tensor(np.asarray(state), dtype=torch.float32,
                                  device=device).unsqueeze(0)
        # Clamp the remaining horizon at 1: hindsight horizons in the training
        # data are always >= 1, so a negative horizon (episode outliving the
        # command) would be a command the policy never saw.
        remaining_h = max(target_h - steps, 1)
        cmd_t = torch.tensor([[remaining_h, target_r - episode_return]],
                             dtype=torch.float32, device=device)

        if shield is not None and shield.is_ood(state_t, cmd_t).any():
            cmd_t = shield.project(state_t, cmd_t)
            projections += 1

        with torch.no_grad():
            action = policy.sample(state_t, cmd_t, deterministic=True)
        # Clip to [-1, 1] before stepping the environment.
        action = np.clip(action.cpu().numpy()[0], -1, 1)

        state, reward, term, trunc, _ = env.step(action)
        episode_return += reward
        steps += 1

        if term or trunc:
            break

    return episode_return, steps, projections

# Test commands, from realistic to clearly outside the random-policy data.
test_commands = [
    (100, -200),  # Realistic
    (50, 100),    # Ambitious
    (30, 250),    # Very OOD
]

# Reseed the environment before every rollout so that, for a given command,
# the run without the shield and the run with it start from the same initial
# state; otherwise their returns also differ by environment randomness.
EVAL_SEED = 0

print('Testing without shield:')
for cmd in test_commands:
    env.reset(seed=EVAL_SEED)
    ret, steps, _ = run_episode(policy, env, cmd, shield=None, device=device)
    print(f'  Command H={cmd[0]}, R={cmd[1]}: Return={ret:.1f}, Steps={steps}')

print('\nTesting with Flow Shield:')
for cmd in test_commands:
    env.reset(seed=EVAL_SEED)
    ret, steps, proj = run_episode(policy, env, cmd, shield=flow_shield, device=device)
    print(f'  Command H={cmd[0]}, R={cmd[1]}: Return={ret:.1f}, Steps={steps}, Projections={proj}')

In [None]:
# Release the environment's resources now that evaluation is finished.
env.close()
print("Done!")