In [None]:
# main_cnn_2d_full_benchmark_CLEAN_UI.py -- FINAL VERSION WITH CHAMPION ANALYSIS

import torch
import torch.nn as nn
import torch.optim as optim
import itertools
from tqdm.auto import tqdm
import os
import numpy as np
import random
import pandas as pd
import warnings
import torch.nn.functional as F
from pathlib import Path

try:
    import pytorch_lightning as pl
    from pytorch_lightning.callbacks import ModelCheckpoint
except ImportError:
    pass

from torch.utils.data import DataLoader, IterableDataset

# --- Global configuration ---
SEED = 42
# Seed torch, NumPy and the stdlib RNG (in that order) for repeatable runs.
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(SEED)
del _seed_fn

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")
torch.set_float32_matmul_precision('medium')

# Silence warnings whose messages match these regular expressions.
for _pattern in (
    ".*does not have many workers.*",
    ".*Checkpoint directory.*exists and is not empty.*",
    ".*Using a target size.*",
):
    warnings.filterwarnings("ignore", _pattern)
del _pattern

# Input geometry.
IMG_SIZE = 8
IMG_CHANNELS = 1

# Training budget.
BATCH_SIZE = 1048
MAX_EPOCHS = 30
N_TARGET_PROBLEMS = 4

# Hyper-parameter grids swept by the benchmark.
LEARNING_RATES = [1, 5e-1, 1e-1, 5e-2, 1e-2, 1e-3, 1e-4]
N_HUTCHINSON_SAMPLES_LIST = [1, 5]
MIXED_ALPHAS = [0.1]
SPSA_EPSILONS = [1e-1, 1e-3]
ACTIVATIONS = {"GELU": nn.GELU}

# Target-noise scenarios; "auto" is resolved to a number by the main script.
J_NOISE_CONFIGS = [
    {"name": "deterministic", "noise_std": 0.0},
    {"name": "stochastic", "noise_std": "auto"},
]

# --- Models ---
class ResBlock2D(nn.Module):
    """Residual block computing ``x + act(cat([y, -y]))``.

    ``y`` comes from a conv -> SiLU -> conv stack whose last convolution
    halves the channel count. Concatenating ``y`` with ``-y`` along the
    channel axis restores the original count so the result can be added
    to the input.

    Args:
        channels: Number of input and output channels. Must be a positive
            even integer, otherwise the concatenated tensor cannot match
            the input's channel count.
        activation_fn: Zero-argument callable returning the activation
            module applied to the concatenated features.

    Raises:
        ValueError: If ``channels`` is not a positive even integer.
    """

    def __init__(self, channels, activation_fn=nn.GELU):
        super().__init__()
        # Fail early: with an odd count, cat([y, -y]) has channels - 1
        # channels and the residual addition breaks (or silently broadcasts
        # to an empty tensor when channels == 1).
        if channels <= 0 or channels % 2 != 0:
            raise ValueError(
                f"ResBlock2D requires a positive even number of channels, got {channels}"
            )
        self.conv_block = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding='same', bias=True),
            nn.SiLU(),
            nn.Conv2d(channels, channels // 2, kernel_size=3, padding='same', bias=True))
        self.activation = activation_fn()

    def forward(self, x):
        """Return ``x`` plus the activated ``[y, -y]`` features (same shape as ``x``)."""
        y = self.conv_block(x)
        return x + self.activation(torch.cat([y, -y], dim=1))

class TargetCNN_Classic_2D(nn.Module):
    """Convolutional network mapping an image batch to one scalar per sample.

    Args:
        in_channels: Channel count of the input (and of the final conv output).
        noise_std: Standard deviation of the Gaussian noise added to the
            scalar output in training mode; no noise is added when it is 0.
    """

    def __init__(self, in_channels=1, noise_std=0.0):
        super().__init__()
        self.noise_std = noise_std
        width = 32
        # Stem, then four (ResBlock2D -> Conv2d) stages; the last conv maps
        # back to in_channels. Modules are created in the same order as they
        # appear in the network.
        layers = [nn.Conv2d(in_channels, width, kernel_size=3, padding='same'), nn.GELU()]
        for stage_out in (width, width, width, in_channels):
            layers.append(ResBlock2D(width))
            layers.append(nn.Conv2d(width, stage_out, kernel_size=3, padding='same'))
        layers.append(nn.GELU())
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Sum the network output over channel and spatial axes, plus optional noise."""
        energy = self.net(x).sum(dim=(1, 2, 3))
        if not (self.noise_std > 0 and self.training):
            return energy
        return energy + torch.randn_like(energy) * self.noise_std

    def gradient(self, x):
        """Return the gradient of the summed output with respect to ``x``.

        ``x`` is detached first, so the result does not connect to any graph
        that produced ``x``; ``create_graph=True`` keeps the gradient itself
        differentiable.
        """
        leaf = x.detach().requires_grad_(True)
        with torch.enable_grad():
            total = self.forward(leaf).sum()
            (grad,) = torch.autograd.grad(total, leaf, create_graph=True)
        return grad
        
class Estimator_Potential_2D(nn.Module):
    """Small CNN producing one scalar potential per sample.

    Args:
        in_channels: Channel count of the input images.
        image_size: Accepted for interface compatibility; not used, since
            the adaptive average pooling makes the network size-agnostic.
        activation_fn: Zero-argument callable returning the activation module.
    """

    def __init__(self, in_channels, image_size, activation_fn=nn.GELU):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=3, padding='same'), activation_fn(),
            nn.Conv2d(16, 16, kernel_size=3, padding='same'), activation_fn(),
            nn.Conv2d(16, 16, kernel_size=3, padding='same'), activation_fn(),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(16, 1, kernel_size=1))

    def forward(self, x):
        """Return the potential with the two trailing singleton axes removed."""
        return self.net(x).squeeze(-1).squeeze(-1)

    def gradient(self, x):
        """Return the gradient of the summed potential with respect to ``x``.

        ``x`` is detached first. Grad mode is forced on so this also works
        when the caller is inside ``torch.no_grad()`` (e.g. an evaluation
        loop); ``create_graph=True`` keeps the result differentiable.
        """
        x_req = x.detach().requires_grad_(True)
        with torch.enable_grad():
            potential_sum = self.forward(x_req).sum()
            grad = torch.autograd.grad(potential_sum, x_req, create_graph=True)[0]
        return grad

class Estimator_Direct_2D(nn.Module):
    """Fully convolutional network mapping an image to an image-shaped output.

    Args:
        in_channels: Channel count of the input.
        out_channels: Channel count of the output.
        activation_fn: Zero-argument callable returning the activation module.
    """

    def __init__(self, in_channels, out_channels, activation_fn=nn.GELU):
        super().__init__()
        hidden = 16
        layers = []
        # Three conv + activation stages, then a final conv to out_channels.
        for stage_in in (in_channels, hidden, hidden):
            layers.append(nn.Conv2d(stage_in, hidden, kernel_size=3, padding='same'))
            layers.append(activation_fn())
        layers.append(nn.Conv2d(hidden, out_channels, kernel_size=3, padding='same'))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the convolution stack to ``x``."""
        return self.net(x)

class SampleDatasetCNN_2D(IterableDataset):
    """Iterable dataset yielding ``steps`` batches of scaled Gaussian noise.

    Each item is already a full batch of shape
    ``(batch_size, channels, size, size)`` created on the module-level
    ``device`` and multiplied by ``std``.
    """

    def __init__(self, channels, size, steps, batch_size, std=1.0):
        self.channels = channels
        self.size = size
        self.steps = steps
        self.batch_size = batch_size
        self.std = std

    def __iter__(self):
        for _ in range(self.steps):
            shape = (self.batch_size, self.channels, self.size, self.size)
            noise = torch.randn(*shape, device=device)
            yield noise * self.std
            
class GradEstPLModule_2D(pl.LightningModule):
    """Lightning module training an estimator against a frozen target network.

    Evaluation always compares the estimator's gradient field with the true
    input-gradient of the target (MSE and cosine similarity).

    Args:
        target_cnn: Target network; its parameters are frozen here. It must
            be callable on a batch and expose ``gradient(batch)``.
        estimator_cnn: Network being trained. When it represents a potential
            (``is_conservative`` true and method is not 'spsa') it must also
            expose ``gradient(batch)``.
        hparams: Mapping stored through ``save_hyperparameters``. Keys read:
            ``method`` ('surrogate', 'wgm', 'mixed' or 'spsa'), ``lr``,
            ``is_conservative``; ``wgm_n_samples`` and ``wgm_noise_std``
            (wgm / mixed); ``mixed_alpha`` (mixed); ``spsa_epsilon`` (spsa).
    """

    def __init__(self, target_cnn, estimator_cnn, hparams):
        super().__init__()
        self.save_hyperparameters(hparams)
        self.target_cnn, self.estimator_cnn = target_cnn, estimator_cnn
        # The target is a fixed function: only the estimator is optimised.
        for p in self.target_cnn.parameters():
            p.requires_grad = False

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for the configured method.

        Returns:
            The loss tensor, or ``None`` when the SPSA half-batch is empty.

        Raises:
            ValueError: If ``hparams.method`` is not a supported method.
        """
        # Training mode on the target: the target only adds its output
        # noise while in training mode.
        self.target_cnn.train()
        method = self.hparams.method
        if method == 'surrogate':
            loss = self._surrogate_loss(batch)
        elif method == 'wgm':
            loss = self.wgm_loss(batch)
        elif method == 'mixed':
            loss_surrogate = self._surrogate_loss(batch)
            loss_wgm = self.wgm_loss(batch)
            loss = loss_surrogate + self.hparams.mixed_alpha * loss_wgm
        elif method == 'spsa':
            loss = self._spsa_loss(batch)
            if loss is None:
                return None
        else:
            # Previously an unknown method fell through to an unbound `loss`.
            raise ValueError(f"Unknown training method: {method!r}")

        self.log('train_loss', loss, prog_bar=False, logger=True)
        return loss

    def _surrogate_loss(self, batch):
        """MSE between the estimator output and the target's scalar output."""
        j_hat = self.estimator_cnn(batch).squeeze()
        with torch.no_grad():
            j_true = self.target_cnn(batch).squeeze()
        return F.mse_loss(j_hat, j_true)

    def _spsa_loss(self, batch):
        """MSE between predicted and finite-difference directional derivatives.

        Only the first half of the batch is used (each kept sample costs two
        target evaluations -- presumably to match the evaluation budget of
        the other methods; confirm). Returns ``None`` if that half is empty.
        """
        x = batch[:batch.size(0) // 2]
        if x.shape[0] == 0:
            return None
        v_theta = self.estimator_cnn(x)
        with torch.no_grad():
            eps = self.hparams.spsa_epsilon
            # Random direction, normalised to unit L2 norm per sample.
            v = torch.randn_like(x)
            v_norm = torch.norm(v.flatten(1), p=2, dim=1).view(-1, 1, 1, 1) + 1e-8
            v = v / v_norm
            # Central finite difference of the target along v.
            j_plus = self.target_cnn(x + eps * v)
            j_minus = self.target_cnn(x - eps * v)
            target_directional_derivative = (j_plus - j_minus) / (2 * eps)
        predicted_directional_derivative = (v_theta * v).sum(dim=(1, 2, 3))
        return F.mse_loss(predicted_directional_derivative, target_directional_derivative)

    def wgm_loss(self, x):
        """Return the WGM loss: mean of ``|s|^2 + 2 (J - c) h`` over probes.

        ``s`` is the estimated field (estimator output, or the input-gradient
        of the estimator's potential when ``is_conservative``), ``J`` the
        target value, ``c`` the batch mean of ``J`` and
        ``h = v^T grad_x(s . v) + s . (-x / wgm_noise_std**2)`` with a
        Gaussian probe ``v`` drawn independently for each of the
        ``wgm_n_samples`` copies of every sample.
        """
        M = self.hparams.wgm_n_samples
        # Repeat each sample M times so every copy gets its own probe vector.
        x_rep = x.detach().repeat_interleave(M, dim=0).requires_grad_(True)
        if self.hparams.is_conservative:
            potential = self.estimator_cnn(x_rep)
            s_theta = torch.autograd.grad(potential.sum(), x_rep, create_graph=True)[0]
        else:
            s_theta = self.estimator_cnn(x_rep)
        # Hutchinson-style divergence estimate: v^T grad_x (s . v).
        v = torch.randn_like(s_theta)
        s_dot_v_sum = (s_theta * v).sum()
        grad_s_dot_v = torch.autograd.grad(s_dot_v_sum, x_rep, create_graph=True)[0]
        div_terms = (v * grad_s_dot_v).sum(dim=(1, 2, 3))
        # -x / sigma^2 is the score of a zero-mean isotropic Gaussian.
        grad_log_p_terms = (s_theta * (-x_rep / self.hparams.wgm_noise_std**2)).sum(dim=(1, 2, 3))
        h_all = div_terms + grad_log_p_terms
        with torch.no_grad():
            j_val = self.target_cnn(x).squeeze()
            j_val_rep = j_val.repeat_interleave(M, dim=0)
            # Batch-mean baseline subtracted from the target values.
            c_opt = j_val_rep.mean()
        term1_all = s_theta.pow(2).sum(dim=(1, 2, 3))
        term2_all = 2 * (j_val_rep - c_opt) * h_all
        loss = (term1_all + term2_all).mean()
        return loss

    def _compute_eval_metrics(self, batch):
        """Return ``{'mse', 'cos_sim'}`` between estimated and true gradients.

        The estimate is the estimator's raw output for 'spsa' or
        non-conservative configurations, otherwise ``estimator_cnn.gradient``.
        """
        self.target_cnn.eval()
        batch.requires_grad_(True)
        # Grad mode is forced on because evaluation loops may disable it.
        with torch.enable_grad():
            grad_j_true = self.target_cnn.gradient(batch)
            if self.hparams.method == 'spsa' or not self.hparams.is_conservative:
                estimated_grad = self.estimator_cnn(batch)
            else:
                estimated_grad = self.estimator_cnn.gradient(batch)
        batch.requires_grad_(False)
        mse = F.mse_loss(estimated_grad.detach(), grad_j_true.detach())
        cos_sim = F.cosine_similarity(estimated_grad.detach().flatten(1), grad_j_true.detach().flatten(1)).mean()
        return {'mse': mse, 'cos_sim': cos_sim}

    def validation_step(self, batch, batch_idx):
        """Log ``val_mse`` and ``val_cos_sim`` for one validation batch."""
        metrics = self._compute_eval_metrics(batch)
        self.log('val_mse', metrics['mse'], prog_bar=True, logger=True)
        self.log('val_cos_sim', metrics['cos_sim'], prog_bar=True, logger=True)
        return metrics

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        """Log ``test_mse`` and ``test_cos_sim`` for one test batch."""
        metrics = self._compute_eval_metrics(batch)
        self.log('test_mse', metrics['mse'])
        self.log('test_cos_sim', metrics['cos_sim'])
        return metrics

    def configure_optimizers(self):
        """Adam over the estimator's parameters with learning rate ``hparams.lr``."""
        return optim.Adam(self.estimator_cnn.parameters(), lr=self.hparams.lr)

def run_single_training(target_cnn, hparams):
    """Train one estimator configuration and return its test metrics.

    Args:
        target_cnn: Target network shared by all runs of the current problem.
        hparams: Run configuration. Keys read here: ``seed``,
            ``estimator_class`` and ``activation_name``; the whole mapping is
            forwarded to ``GradEstPLModule_2D``.

    Returns:
        ``{'final_mse': ..., 'final_cos_sim': ...}`` measured with the
        checkpoint that had the lowest ``val_mse``; both values are NaN if
        the test run returned nothing.
    """
    # Local imports: only this function needs them.
    import shutil
    import tempfile

    torch.manual_seed(hparams['seed'])
    estimator_class = hparams["estimator_class"]
    # Second positional argument differs per estimator: image size for the
    # potential network, output channel count for the direct one.
    second_arg = IMG_SIZE if estimator_class is Estimator_Potential_2D else IMG_CHANNELS
    estimator = estimator_class(
        IMG_CHANNELS, second_arg,
        activation_fn=ACTIVATIONS[hparams["activation_name"]]
    ).to(device)
    experiment = GradEstPLModule_2D(target_cnn, estimator, hparams)

    # Checkpoints are only needed to restore the best epoch for testing.
    # Keep them in a throw-away directory so hundreds of runs do not pile up
    # versioned files in ./checkpoints.
    ckpt_dir = tempfile.mkdtemp(prefix="grad_est_ckpt_")
    try:
        checkpoint_callback = ModelCheckpoint(dirpath=ckpt_dir, monitor="val_mse", mode="min")
        trainer = pl.Trainer(
            max_epochs=MAX_EPOCHS, accelerator="auto", devices=1, logger=False,
            callbacks=[checkpoint_callback], enable_model_summary=False, num_sanity_val_steps=0,
            enable_progress_bar=False, inference_mode=False
        )
        # The datasets already yield full batches, hence batch_size=None.
        train_loader = DataLoader(SampleDatasetCNN_2D(IMG_CHANNELS, IMG_SIZE, 50, BATCH_SIZE), batch_size=None)
        val_loader = DataLoader(SampleDatasetCNN_2D(IMG_CHANNELS, IMG_SIZE, 1, BATCH_SIZE), batch_size=None)

        trainer.fit(model=experiment, train_dataloaders=train_loader, val_dataloaders=val_loader)
        test_results = trainer.test(ckpt_path="best", dataloaders=val_loader, verbose=False)
    finally:
        shutil.rmtree(ckpt_dir, ignore_errors=True)

    if not test_results:
        return {'final_mse': float('nan'), 'final_cos_sim': float('nan')}
    return {'final_mse': test_results[0]['test_mse'], 'final_cos_sim': test_results[0]['test_cos_sim']}


def log_and_save_results(case_name, all_raw_results, all_hparams_info):
    """Summarise one benchmark case and write the text report and CSV.

    For every problem (row) and method family, the best configuration is the
    "champion": lowest MSE and, independently, highest cosine similarity.
    Champion MSEs are divided by the 'Surrogate' champion of the same
    problem, then mean/std over problems are reported per family.

    Args:
        case_name: Used in the output file names and the report header.
        all_raw_results: ``{'mse': {config_name: [per-problem values]},
            'cos_sim': {config_name: [per-problem values]}}``.
        all_hparams_info: Mapping keyed by configuration name; the family is
            the text before the first space of each key.

    Raises:
        KeyError: If no 'Surrogate' family is present to normalise against.

    Side effects:
        Writes ``security_log_<case_name>.txt`` and
        ``all_runs_<case_name>.csv`` in the current directory and prints the
        report.
    """
    txt_file = Path(f"security_log_{case_name}.txt")
    csv_file = Path(f"all_runs_{case_name}.csv")
    df_raw_mse = pd.DataFrame(all_raw_results['mse'])
    df_raw_cos_sim = pd.DataFrame(all_raw_results['cos_sim'])

    # --- Champion analysis ---
    families = sorted({name.split(' ')[0] for name in all_hparams_info})
    champion_mse_scores = {}
    champion_cos_sim_scores = {}
    for family in families:
        # Exact match on the family token: a prefix test would merge a
        # family into any other family whose name starts with it.
        family_cols = [col for col in df_raw_mse.columns if col.split(' ')[0] == family]
        if not family_cols:
            continue
        # Row-wise best over the family's configurations; NaN runs are
        # skipped and an all-NaN family yields NaN instead of failing.
        champion_mse_scores[family] = df_raw_mse[family_cols].min(axis=1)
        champion_cos_sim_scores[family] = df_raw_cos_sim[family_cols].max(axis=1)

    df_champions_mse = pd.DataFrame(champion_mse_scores)
    df_champions_cos_sim = pd.DataFrame(champion_cos_sim_scores)

    # Normalise champion MSEs by the Surrogate champion of each problem.
    if 'Surrogate' not in df_champions_mse.columns:
        raise KeyError("'Surrogate' family is required as the normalisation baseline")
    baseline = df_champions_mse['Surrogate']
    df_champions_norm = df_champions_mse.div(baseline, axis=0)

    champion_stats = {
        'norm_mse_mean': df_champions_norm.mean(),
        'norm_mse_std': df_champions_norm.std(),
        'cos_sim_mean': df_champions_cos_sim.mean(),
        'cos_sim_std': df_champions_cos_sim.std()
    }
    df_champion_stats = pd.DataFrame(champion_stats).sort_values('norm_mse_mean')

    # --- Display and save ---
    header = f"--- Results for Case: {case_name.upper()} ---\n"
    lines = [
        "--- Champion vs. Champion Analysis ---\n",
        f"{'Family':<15} | {'Norm MSE Mean':<15} | {'Norm MSE Std':<15} | {'Cos Sim Mean':<15} | {'Cos Sim Std'}\n",
        ("-" * 85) + "\n",
    ]
    for family, stats in df_champion_stats.iterrows():
        lines.append(f"{family:<15} | {stats['norm_mse_mean']:.4f}           | {stats['norm_mse_std']:.4f}            | {stats['cos_sim_mean']:.4f}          | {stats['cos_sim_std']:.4f}\n")
    body = "".join(lines)

    with open(txt_file, "w", encoding="utf-8") as f:
        f.write(header + body)
    print("\n" + header + body)

    # Detailed per-run results: one row per (problem, configuration), in
    # problem-major order. Built explicitly so that NaN (diverged) runs are
    # kept instead of being dropped by a stack/merge round trip.
    records = [
        {
            'problem_idx': problem_idx,
            'method_name': method_name,
            'final_mse': df_raw_mse.at[problem_idx, method_name],
            'final_cos_sim': df_raw_cos_sim.at[problem_idx, method_name],
        }
        for problem_idx in df_raw_mse.index
        for method_name in df_raw_mse.columns
    ]
    df_to_save = pd.DataFrame(
        records, columns=['problem_idx', 'method_name', 'final_mse', 'final_cos_sim'])
    df_to_save.to_csv(csv_file, index=False)
    print(f"Rapport CSV complet sauvegardé dans '{csv_file}'")

if __name__ == "__main__":
    # --- Method families and their fixed settings ---
    base_methods = {
        "Surrogate": {"method": "surrogate", "is_conservative": True, "estimator_class": Estimator_Potential_2D}
    }
    for M in N_HUTCHINSON_SAMPLES_LIST:
        base_methods[f"WGM-NC (M={M})"] = {"method": "wgm", "is_conservative": False, "estimator_class": Estimator_Direct_2D, "wgm_n_samples": M}
        base_methods[f"WGM-C (M={M})"] = {"method": "wgm", "is_conservative": True, "estimator_class": Estimator_Potential_2D, "wgm_n_samples": M}
        for alpha in MIXED_ALPHAS:
            base_methods[f"Mixed (α={alpha}, M={M})"] = {"method": "mixed", "is_conservative": True, "estimator_class": Estimator_Potential_2D, "mixed_alpha": alpha, "wgm_n_samples": M}
    for eps in SPSA_EPSILONS:
        base_methods[f"SPSA (ε={eps})"] = {"method": "spsa", "is_conservative": False, "estimator_class": Estimator_Direct_2D, "spsa_epsilon": eps}

    # --- Cross every method with every learning rate ---
    all_hparams_info = {}
    for name, config in base_methods.items():
        for lr in LEARNING_RATES:
            config_name = f"{name} lr={lr}"
            # WGM-based losses use a learning rate divided by the pixel count.
            effective_lr = lr / (IMG_SIZE * IMG_SIZE) if config['method'] in ['wgm', 'mixed'] else lr
            hparams = config.copy()
            hparams.update({"base_lr": lr, "lr": effective_lr, "activation_name": "GELU", "wgm_noise_std": 1.0})
            all_hparams_info[config_name] = hparams

    TargetClass = TargetCNN_Classic_2D
    for noise_config in J_NOISE_CONFIGS:
        noise_name = noise_config["name"]
        case_name = f"{noise_name}"
        print(f"\n{'='*30} DÉBUT DU BENCHMARK: {case_name.upper()} {'='*30}")

        all_raw_results = {
            'mse': {name: [] for name in all_hparams_info.keys()},
            'cos_sim': {name: [] for name in all_hparams_info.keys()}
        }

        pbar_problems = tqdm(range(N_TARGET_PROBLEMS), desc=f"Problèmes ({case_name})")
        for problem_idx in pbar_problems:
            torch.manual_seed(SEED + problem_idx)
            requested_noise_std = noise_config["noise_std"]
            auto_noise = requested_noise_std == "auto"
            # Build the target first so a given problem index has the same
            # weights in every noise case.
            target_cnn = TargetClass(
                IMG_CHANNELS, noise_std=0.0 if auto_noise else requested_noise_std
            ).to(device)
            if auto_noise:
                # Calibrate the noise on the very network that will be used:
                # 10% of the std of its noise-free outputs on Gaussian inputs.
                # (Previously a separate, differently-initialised network was
                # used for this measurement.)
                with torch.no_grad():
                    j_values = target_cnn(torch.randn(4096, IMG_CHANNELS, IMG_SIZE, IMG_SIZE, device=device))
                target_cnn.noise_std = 0.1 * j_values.std().item()

            pbar_hparams = tqdm(all_hparams_info.items(), desc=f"Problem {problem_idx+1}", leave=False)
            for config_name, hparams in pbar_hparams:
                run_hparams = hparams.copy()
                run_hparams['problem_idx'] = problem_idx
                # Per-run seed derived from problem, learning rate and name length.
                run_hparams['seed'] = SEED * 1000 + problem_idx * 100 + int(hparams['base_lr'] * 1e5) + len(config_name)
                run_hparams['method_name'] = config_name
                final_metrics = run_single_training(target_cnn, run_hparams)
                all_raw_results['mse'][config_name].append(final_metrics['final_mse'])
                all_raw_results['cos_sim'][config_name].append(final_metrics['final_cos_sim'])

        log_and_save_results(case_name, all_raw_results, all_hparams_info)

    print("\nBenchmark complet terminé.")

  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda



Problèmes (deterministic):   0%|          | 0/4 [00:00<?, ?it/s]
Problem 1:   0%|          | 0/63 [00:00<?, ?it/s][AGPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(
`Trainer.fit` stopped: `max_epochs=30` reached.
Restoring states from the checkpoint path at F:\rebuttal_neurips\resnet\CNN\extract_champions\tanh\checkpoints\epoch=25-step=1300-v28.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at F:\rebuttal_neurips\resnet\CNN\extract_champions\tanh\checkpoints\epoch=25-step=1300-v28.ckpt

Problem 1:   2%|▏         | 1/63 [00:07<07:19,  7.09s/it][AGPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
`Trainer.fit`


--- Results for Case: DETERMINISTIC ---
--- Champion vs. Champion Analysis ---
Family          | Norm MSE Mean   | Norm MSE Std    | Cos Sim Mean    | Cos Sim Std
-------------------------------------------------------------------------------------
WGM-C           | 0.7790           | 0.4716            | 0.9877          | 0.0040
Mixed           | 0.8027           | 0.0941            | 0.9844          | 0.0066
SPSA            | 0.8894           | 0.4723            | 0.9829          | 0.0135
Surrogate       | 1.0000           | 0.0000            | 0.9794          | 0.0100
WGM-NC          | 1.4576           | 0.6990            | 0.9739          | 0.0139

Rapport CSV complet sauvegardé dans 'all_runs_deterministic.csv'



Problèmes (stochastic):   0%|          | 0/4 [00:00<?, ?it/s]
Problem 1:   0%|          | 0/63 [00:00<?, ?it/s][AGPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(
`Trainer.fit` stopped: `max_epochs=30` reached.
Restoring states from the checkpoint path at F:\rebuttal_neurips\resnet\CNN\extract_champions\tanh\checkpoints\epoch=26-step=1350-v72.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at F:\rebuttal_neurips\resnet\CNN\extract_champions\tanh\checkpoints\epoch=26-step=1350-v72.ckpt

Problem 1:   2%|▏         | 1/63 [00:06<06:29,  6.28s/it][AGPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
`Trainer.fit` st


--- Results for Case: STOCHASTIC ---
--- Champion vs. Champion Analysis ---
Family          | Norm MSE Mean   | Norm MSE Std    | Cos Sim Mean    | Cos Sim Std
-------------------------------------------------------------------------------------
WGM-C           | 0.7698           | 0.2405            | 0.9862          | 0.0064
Mixed           | 0.8679           | 0.0912            | 0.9832          | 0.0081
Surrogate       | 1.0000           | 0.0000            | 0.9812          | 0.0072
WGM-NC          | 1.5486           | 0.5581            | 0.9703          | 0.0177
SPSA            | 4.7605           | 3.4976            | 0.8918          | 0.1195

Rapport CSV complet sauvegardé dans 'all_runs_stochastic.csv'

Benchmark complet terminé.



