In [1]:
# main_physics_benchmark_final_normalized.py

import torch
import torch.nn as nn
import torch.optim as optim
from tqdm.auto import tqdm
import numpy as np
import random
import pandas as pd
import warnings
import torch.nn.functional as F
from pathlib import Path
from torch.utils.data import DataLoader, IterableDataset, TensorDataset

# Tenter d'importer PyTorch Lightning
try:
    import pytorch_lightning as pl
    from pytorch_lightning.callbacks import ModelCheckpoint
except ImportError:
    print("Pytorch Lightning n'est pas installé. Veuillez l'installer avec : pip install pytorch-lightning")
    pl = None

# --- General configuration ---
# Lightning is mandatory for this script: stop here when it is unavailable.
if pl is None:
    raise ImportError("Pytorch Lightning est requis pour exécuter ce script.")

# One seed shared by torch, NumPy and the stdlib RNG.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Performance setting for recent GPUs.
if torch.cuda.is_available():
    torch.set_float32_matmul_precision('medium')

# Silence non-critical warnings.
warnings.filterwarnings("ignore", ".*does not have many workers.*")
warnings.filterwarnings("ignore", ".*Checkpoint directory.*exists and is not empty.*")
warnings.filterwarnings("ignore", ".*Using a target size.*")

# --- Benchmark hyperparameters ---
IMG_SIZE = 8
IMG_CHANNELS = 1
BATCH_SIZE = 128
MAX_EPOCHS = 20

# Hyperparameter search grid.
LEARNING_RATES = [10, 5, 1, 5e-1, 1e-1, 5e-2, 1e-2, 5e-3, 1e-3, 1e-4]
N_HUTCHINSON_SAMPLES_LIST = [1, 5, 10]
MIXED_ALPHAS = [0.1, 0.01]
SPSA_EPSILONS = [1e-1, 1e-3]
ACTIVATIONS = {"GELU": nn.GELU}
J_NOISE_CONFIGS = [
    {"name": "deterministic", "noise_std": 0.0},
    {"name": "stochastic", "noise_std": "auto"},
]
# Number of random inputs used to estimate the target's mean and std.
N_SAMPLES_FOR_STATS = 4096

# --- Fonctions Cibles Basées sur la Physique ---

class PhysicsTarget(nn.Module):
    """Base class for scalar target functions J(x) evaluated on a batch.

    Subclasses implement ``forward``; ``gradient`` then returns dJ/dx
    obtained through autograd.

    Args:
        img_size: Stored as ``self.img_size``; not used by this base class.
        noise_std: Stored as ``self.noise_std``; not used by this base class.
    """

    def __init__(self, img_size=16, noise_std=0.0):
        super().__init__()
        self.img_size = img_size
        self.noise_std = noise_std
        # No parameters are registered at this point, so this loop is a
        # no-op here; it is kept to preserve the original construction steps.
        for p in self.parameters():
            p.requires_grad = False

    def forward(self, x):
        """Return J(x); must be provided by subclasses."""
        raise NotImplementedError

    def gradient(self, x):
        """Return the gradient of ``forward(x).sum()`` with respect to ``x``.

        The input is detached first, so the caller's tensor is never marked
        as requiring grad. Autograd is force-enabled, which lets this be
        called from inside a ``torch.no_grad()`` context.

        Args:
            x: Input tensor passed to ``forward``.

        Returns:
            A tensor with the same shape as ``x``.
        """
        x_req = x.detach().requires_grad_(True)
        original_training_state = self.training
        # Always differentiate the noise-free path. Only this module's own
        # flag is flipped (as in the original); child modules are untouched.
        self.training = False
        try:
            with torch.enable_grad():
                j_sum = self.forward(x_req).sum()
                grad = torch.autograd.grad(j_sum, x_req)[0]
        finally:
            # Restore the flag even if forward/autograd raised, so a failed
            # call cannot leave the module silently stuck in eval mode.
            self.training = original_training_state
        return grad

class NormalizedTarget(PhysicsTarget):
    """Standardize a target J(x) with statistics estimated once at build time.

    The wrapped target is evaluated on ``n_samples_for_stats`` standard-normal
    inputs; the resulting mean and standard deviation are stored as buffers
    and used to centre and scale every later evaluation.

    Args:
        base_target: Target whose output is standardized.
        n_samples_for_stats: Number of random inputs used for the statistics.
    """

    def __init__(self, base_target: PhysicsTarget, n_samples_for_stats=N_SAMPLES_FOR_STATS):
        super().__init__(img_size=base_target.img_size, noise_std=base_target.noise_std)
        self.base_target = base_target

        print(f"Calculating normalization stats for {base_target.__class__.__name__}...")
        with torch.no_grad():
            # Evaluate the wrapped target in eval mode, then put it back.
            was_training = self.base_target.training
            self.base_target.train(False)

            probe_inputs = torch.randn(n_samples_for_stats, IMG_CHANNELS, self.img_size, self.img_size)
            probe_loader = DataLoader(TensorDataset(probe_inputs), batch_size=BATCH_SIZE*2)

            energies = torch.cat([
                self.base_target(chunk[0].to(device))
                for chunk in tqdm(probe_loader, desc="Stat Computation", leave=False)
            ])
            mean = energies.mean().item()
            std = energies.std().item()
            self.base_target.train(was_training)

        self.register_buffer('mean', torch.tensor(mean, device=device))
        # The small offset keeps the later division finite for a constant target.
        self.register_buffer('std', torch.tensor(std, device=device) + 1e-8)

        print(f"Stats: Mean={self.mean.item():.4f}, Std={self.std.item():.4f}")

    def forward(self, x):
        """Return the standardized energy, plus Gaussian noise in training mode.

        Noise of scale ``self.noise_std`` is added only when that value is
        positive and this wrapper itself is in training mode.
        """
        standardized = (self.base_target(x) - self.mean) / self.std

        if self.noise_std > 0 and self.training:
            return standardized + torch.randn_like(standardized) * self.noise_std

        return standardized

class IsingTarget(PhysicsTarget):
    """Ising-style energy evaluated on ``tanh(x)`` spins.

    Args:
        img_size: Forwarded to ``PhysicsTarget``.
        J: Multiplier of the neighbour-interaction term.
        h: Multiplier of the per-site field term.
        noise_std: Forwarded to ``PhysicsTarget``.
    """

    def __init__(self, img_size=16, J=1.0, h=0.0, noise_std=0.0):
        super().__init__(img_size, noise_std)
        self.J = J
        self.h = h
        # 3x3 cross-shaped stencil selecting the four nearest neighbours.
        cross = torch.tensor([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=torch.float32)
        self.register_buffer('neighbor_kernel', cross.view(1, 1, 3, 3))

    def forward(self, x):
        """Return one energy value per sample (summed over dims 1, 2, 3)."""
        spins = torch.tanh(x)
        stencil = self.neighbor_kernel.to(x.device)
        neighbor_sum = F.conv2d(spins, stencil, padding='same')
        # Each neighbouring pair is visited from both of its sites, hence the /2.
        pair_term = (-self.J * (spins * neighbor_sum)).sum(dim=(1, 2, 3)) / 2
        field_term = (-self.h * spins).sum(dim=(1, 2, 3))
        # Noise, if any, is added by the NormalizedTarget wrapper, not here.
        return pair_term + field_term

class XYTarget(PhysicsTarget):
    """XY-style energy: minus ``J`` times the cosines of neighbour differences.

    Args:
        img_size: Forwarded to ``PhysicsTarget``.
        J: Multiplier of the cosine terms.
        noise_std: Forwarded to ``PhysicsTarget``.
    """

    def __init__(self, img_size=16, J=1.0, noise_std=0.0):
        super().__init__(img_size, noise_std)
        self.J = J
        # Finite-difference stencils: right neighbour minus centre, and
        # lower neighbour minus centre.
        to_right = torch.tensor([[0, 0, 0], [0, -1, 1], [0, 0, 0]], dtype=torch.float32)
        to_down = torch.tensor([[0, 0, 0], [0, -1, 0], [0, 1, 0]], dtype=torch.float32)
        self.register_buffer('kernel_right', to_right.view(1, 1, 3, 3))
        self.register_buffer('kernel_down', to_down.view(1, 1, 3, 3))

    def forward(self, x):
        """Return one energy value per sample, treating ``x`` as the angles."""
        theta = x
        cos_right = torch.cos(F.conv2d(theta, self.kernel_right.to(x.device), padding='same'))
        cos_down = torch.cos(F.conv2d(theta, self.kernel_down.to(x.device), padding='same'))
        site_energy = -self.J * (cos_right + cos_down)
        return site_energy.sum(dim=(1, 2, 3))

# --- Modèles Estimateurs ---
class Estimator_Potential_2D(nn.Module):
    """Small CNN mapping an image batch to one scalar potential per sample.

    Args:
        in_channels: Number of input channels.
        image_size: Accepted for interface compatibility; not used.
        activation_fn: Zero-argument factory for the activation layers.
    """

    def __init__(self, in_channels, image_size, activation_fn=nn.GELU):
        super().__init__()
        widths = [in_channels, 16, 32, 16]
        layers = []
        # Three 3x3 conv + activation stages, built in order.
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Conv2d(c_in, c_out, kernel_size=3, padding='same'))
            layers.append(activation_fn())
        # Global average pooling followed by a 1x1 conv down to one channel.
        layers.append(nn.AdaptiveAvgPool2d((1, 1)))
        layers.append(nn.Conv2d(16, 1, kernel_size=1))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the network output with the two trailing unit dims removed."""
        pooled = self.net(x)
        return torch.squeeze(torch.squeeze(pooled, -1), -1)

    def gradient(self, x):
        """Return the gradient of ``forward(x).sum()`` with respect to ``x``.

        The input is detached first and autograd is force-enabled, so this
        also works inside ``torch.no_grad()``.
        """
        inputs = x.detach().requires_grad_(True)
        with torch.enable_grad():
            total = self.forward(inputs).sum()
            (input_grad,) = torch.autograd.grad(total, inputs)
        return input_grad

class Estimator_Direct_2D(nn.Module):
    """Fully convolutional CNN whose output keeps the input's spatial size.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        activation_fn: Zero-argument factory for the activation layers.
    """

    def __init__(self, in_channels, out_channels, activation_fn=nn.GELU):
        super().__init__()
        widths = [in_channels, 16, 32, 16]
        layers = []
        # Three 3x3 conv + activation stages, built in order.
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Conv2d(c_in, c_out, kernel_size=3, padding='same'))
            layers.append(activation_fn())
        # Final 3x3 conv projects to the requested number of channels.
        layers.append(nn.Conv2d(16, out_channels, kernel_size=3, padding='same'))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the raw network output for ``x``."""
        return self.net(x)

# --- Dataset et Module PyTorch Lightning ---
class SampleDatasetCNN_2D(IterableDataset):
    """Endless-style source of Gaussian image batches, ``steps`` per pass.

    Each item is already a full batch of shape
    ``(batch_size, channels, size, size)`` created on the module-level
    ``device`` and scaled by ``std``.
    """

    def __init__(self, channels, size, steps, batch_size, std=1.0):
        self.channels = channels
        self.size = size
        self.steps = steps
        self.batch_size = batch_size
        self.std = std

    def __iter__(self):
        batch_shape = (self.batch_size, self.channels, self.size, self.size)
        # Fresh noise is drawn on every step and on every pass.
        yield from (
            torch.randn(*batch_shape, device=device) * self.std
            for _ in range(self.steps)
        )
            
class GradEstPLModule_2D(pl.LightningModule):
    """Lightning module that trains ``estimator_cnn`` against a frozen ``target_cnn``.

    The training objective is selected by ``hparams.method``
    ('surrogate', 'wgm', 'mixed' or 'spsa'). Validation and test steps
    compare the estimated input-gradient with ``target_cnn.gradient``.
    """
    def __init__(self, target_cnn, estimator_cnn, hparams):
        """Store both networks, record ``hparams`` and freeze the target's parameters."""
        super().__init__(); self.save_hyperparameters(hparams)
        self.target_cnn, self.estimator_cnn = target_cnn, estimator_cnn
        for p in self.target_cnn.parameters(): p.requires_grad = False

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for the configured method.

        Returns the loss tensor, or ``None`` for 'spsa' when half of the
        batch is empty. For a method name that matches no branch, the loss
        stays the constant zero tensor created below.
        """
        # Target is put in training mode for the loss evaluations below.
        self.target_cnn.train()
        method = self.hparams.method
        loss = torch.tensor(0.0, device=self.device)

        if method == 'surrogate':
            # Regress the estimator output onto the target value.
            j_hat = self.estimator_cnn(batch).squeeze()
            with torch.no_grad(): j_true = self.target_cnn(batch).squeeze()
            loss = F.mse_loss(j_hat, j_true)
        elif method == 'wgm':
            loss = self.wgm_loss(batch)
        elif method == 'mixed':
            # Surrogate regression plus the wgm loss weighted by mixed_alpha.
            j_hat = self.estimator_cnn(batch).squeeze()
            with torch.no_grad(): j_true = self.target_cnn(batch).squeeze()
            loss_surrogate = F.mse_loss(j_hat, j_true)
            loss_wgm = self.wgm_loss(batch)
            loss = loss_surrogate + self.hparams.mixed_alpha * loss_wgm
        elif method == 'spsa':
            # Only the first half of the batch is used; each kept sample
            # costs two target evaluations (x + eps*v and x - eps*v).
            x = batch[:batch.size(0) // 2]
            if x.shape[0] == 0: return None
            v_theta = self.estimator_cnn(x)
            with torch.no_grad():
                eps = self.hparams.spsa_epsilon
                # Random direction, normalized to unit L2 norm per sample.
                v = torch.randn_like(x)
                v_norm = torch.norm(v.flatten(1), p=2, dim=1).view(-1, 1, 1, 1) + 1e-8
                v = v / v_norm
                j_plus = self.target_cnn(x + eps * v)
                j_minus = self.target_cnn(x - eps * v)
                # Central finite difference of the target along v.
                target_directional_derivative = (j_plus - j_minus) / (2 * eps)
            # Projection of the estimator output onto the same direction.
            predicted_directional_derivative = (v_theta * v).sum(dim=(1,2,3))
            loss = F.mse_loss(predicted_directional_derivative, target_directional_derivative)

        self.log('train_loss', loss, prog_bar=False, logger=True); return loss

    def wgm_loss(self, x):
        """Return the 'wgm' loss for batch ``x``.

        With ``s`` the estimated vector field, ``v`` a standard-normal probe
        and ``J`` the target value, the returned scalar is the mean over all
        repeated samples of ``|s|^2 + 2 * (J - mean(J)) * h``, where
        ``h = v . grad_x(s . v) + s . (-x / wgm_noise_std**2)``.
        Each input is repeated ``wgm_n_samples`` times, so each copy gets its
        own probe ``v``.
        """
        M = self.hparams.wgm_n_samples
        x_rep = x.detach().repeat_interleave(M, dim=0).requires_grad_(True)
        if self.hparams.is_conservative:
            # Field is the input-gradient of the scalar estimator output;
            # create_graph keeps it differentiable for the steps below.
            potential = self.estimator_cnn(x_rep)
            s_theta = torch.autograd.grad(potential.sum(), x_rep, create_graph=True)[0]
        else:
            # Field is the estimator output itself.
            s_theta = self.estimator_cnn(x_rep)
        v = torch.randn_like(s_theta)
        s_dot_v_sum = (s_theta * v).sum()
        grad_s_dot_v = torch.autograd.grad(s_dot_v_sum, x_rep, create_graph=True)[0]
        # v . grad_x(s . v): presumably a Hutchinson-style single-probe
        # estimate of the divergence of s — NOTE(review): confirm intent.
        div_terms = (v * grad_s_dot_v).sum(dim=(1, 2, 3))
        # -x / sigma^2 is the score of a zero-mean isotropic Gaussian with
        # standard deviation wgm_noise_std.
        grad_log_p_terms = (s_theta * (-x_rep / self.hparams.wgm_noise_std**2)).sum(dim=(1,2,3))
        h_all = div_terms + grad_log_p_terms
        with torch.no_grad():
            j_val = self.target_cnn(x).squeeze()
            j_val_rep = j_val.repeat_interleave(M, dim=0)
            c_opt = j_val_rep.mean() # Per-batch centring to reduce variance
        term1_all = s_theta.pow(2).sum(dim=(1, 2, 3))
        term2_all = 2 * (j_val_rep - c_opt) * h_all
        loss = (term1_all + term2_all).mean()
        return loss

    def _compute_eval_metrics(self, batch):
        """Return ``{'mse', 'cos_sim'}`` between estimated and true input-gradients.

        The target is switched to eval mode first. The estimate is the raw
        estimator output for 'spsa' or non-conservative runs, and
        ``estimator_cnn.gradient`` otherwise. ``cos_sim`` is the mean over
        samples of the cosine similarity of the flattened gradients.
        """
        self.target_cnn.eval()
        with torch.no_grad():
            grad_j_true = self.target_cnn.gradient(batch)
        with torch.enable_grad():
            if self.hparams.method == 'spsa' or not self.hparams.is_conservative:
                estimated_grad = self.estimator_cnn(batch)
            else: 
                estimated_grad = self.estimator_cnn.gradient(batch)
        mse = F.mse_loss(estimated_grad.detach(), grad_j_true.detach())
        cos_sim = F.cosine_similarity(estimated_grad.detach().flatten(1), grad_j_true.detach().flatten(1)).mean()
        return {'mse': mse, 'cos_sim': cos_sim}

    def validation_step(self, batch, batch_idx):
        """Log the evaluation metrics as ``val_mse`` / ``val_cos_sim`` and return them."""
        metrics = self._compute_eval_metrics(batch)
        self.log('val_mse', metrics['mse'], prog_bar=True, logger=True)
        self.log('val_cos_sim', metrics['cos_sim'], prog_bar=False, logger=True)
        return metrics

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        """Log the evaluation metrics as ``test_mse`` / ``test_cos_sim`` and return them."""
        metrics = self._compute_eval_metrics(batch)
        self.log('test_mse', metrics['mse'])
        self.log('test_cos_sim', metrics['cos_sim'])
        return metrics

    def configure_optimizers(self): 
        """Return an Adam optimizer over the estimator's parameters with ``hparams.lr``."""
        return optim.Adam(self.estimator_cnn.parameters(), lr=self.hparams.lr)

# --- Fonctions d'Exécution et de Logging ---
def run_single_training(target_cnn, hparams):
    """Train one estimator configuration and return its final test metrics.

    The estimator is built from ``hparams["estimator_class"]``, trained for
    ``MAX_EPOCHS`` epochs on freshly sampled Gaussian batches, and the
    checkpoint with the lowest ``val_mse`` is then evaluated on the
    validation loader.

    Checkpoints are written to a private temporary directory that is removed
    before returning. Previously every run left a checkpoint in the shared
    default directory, so a full benchmark accumulated thousands of files
    that nothing ever read again.

    Args:
        target_cnn: Target module handed to ``GradEstPLModule_2D``.
        hparams: Dict with at least ``seed``, ``estimator_class`` and
            ``activation_name``, plus whatever the selected method reads.

    Returns:
        ``{'final_mse': ..., 'final_cos_sim': ...}``; both values are NaN
        when the test run reports no results.
    """
    # Function-scope imports keep this block self-contained.
    import shutil
    import tempfile

    torch.manual_seed(hparams['seed'])
    estimator_class = hparams["estimator_class"]
    # The second positional argument is image_size for the potential
    # estimator and out_channels for the direct estimator.
    second_arg = IMG_SIZE if estimator_class is Estimator_Potential_2D else IMG_CHANNELS
    estimator = estimator_class(
        IMG_CHANNELS,
        second_arg,
        activation_fn=ACTIVATIONS[hparams["activation_name"]]
    ).to(device)
    experiment = GradEstPLModule_2D(target_cnn, estimator, hparams)

    checkpoint_dir = tempfile.mkdtemp(prefix="gradest_ckpt_")
    try:
        checkpoint_callback = ModelCheckpoint(
            dirpath=checkpoint_dir, monitor="val_mse", mode="min", save_top_k=1
        )
        trainer = pl.Trainer(
            max_epochs=MAX_EPOCHS, accelerator="auto", devices=1, logger=False,
            callbacks=[checkpoint_callback], enable_model_summary=False, num_sanity_val_steps=0,
            enable_progress_bar=False, inference_mode=False
        )
        # batch_size=None: the datasets already yield complete batches.
        train_loader = DataLoader(SampleDatasetCNN_2D(IMG_CHANNELS, IMG_SIZE, 50, BATCH_SIZE), batch_size=None)
        val_loader = DataLoader(SampleDatasetCNN_2D(IMG_CHANNELS, IMG_SIZE, 5, BATCH_SIZE), batch_size=None)
        trainer.fit(model=experiment, train_dataloaders=train_loader, val_dataloaders=val_loader)
        test_results = trainer.test(ckpt_path="best", dataloaders=val_loader, verbose=False)
    finally:
        # Best-effort cleanup: a failed removal must not abort the benchmark.
        shutil.rmtree(checkpoint_dir, ignore_errors=True)

    if test_results:
        return {'final_mse': test_results[0]['test_mse'], 'final_cos_sim': test_results[0]['test_cos_sim']}
    return {'final_mse': float('nan'), 'final_cos_sim': float('nan')}

def log_and_save_results(case_name, all_raw_results, all_hparams_info, problem_names):
    """Write per-run and per-family ("champion") reports for one benchmark case.

    Three files are written to ``benchmark_results/``:
    ``all_runs_<case>.csv`` (one row per problem and configuration),
    ``champions_<case>.csv`` (per problem, the lowest-MSE configuration of
    each method family) and ``summary_<case>.txt`` (champion statistics with
    MSE normalized by the 'Surrogate' champion). The summary table is also
    printed.

    Args:
        case_name: Label used in file names and in the summary header.
        all_raw_results: ``{'mse': {config: [...]}, 'cos_sim': {config: [...]}}``
            with one list entry per problem, in ``problem_names`` order.
        all_hparams_info: Mapping whose keys are the configuration names; a
            name's family is the text before its first space or '('.
        problem_names: Row labels, one per problem.

    Raises:
        KeyError: If no configuration belongs to the 'Surrogate' family.
    """

    def _family_of(config_name):
        # "WGM-C(M=5) lr=0.1" -> "WGM-C"
        return config_name.split(' ')[0].split('(')[0]

    output_dir = Path("benchmark_results")
    output_dir.mkdir(exist_ok=True)
    txt_file = output_dir / f"summary_{case_name}.txt"
    raw_csv_file = output_dir / f"all_runs_{case_name}.csv"
    champions_csv_file = output_dir / f"champions_{case_name}.csv"

    # --- Raw results: long format, one row per (problem, configuration) ---
    df_raw_mse = pd.DataFrame(all_raw_results['mse'], index=problem_names)
    df_raw_cos_sim = pd.DataFrame(all_raw_results['cos_sim'], index=problem_names)
    df_to_save_raw = df_raw_mse.stack().reset_index()
    df_to_save_raw.columns = ['problem_name', 'method_name', 'final_mse']
    df_cos_sim_flat = df_raw_cos_sim.stack().reset_index()
    df_cos_sim_flat.columns = ['problem_name', 'method_name', 'final_cos_sim']
    df_to_save_raw = pd.merge(df_to_save_raw, df_cos_sim_flat, on=['problem_name', 'method_name'])
    df_to_save_raw.to_csv(raw_csv_file, index=False)
    print(f"Rapport CSV avec tous les runs sauvegardé dans '{raw_csv_file}'")

    # --- Champions: best (lowest-MSE) configuration of each family ---
    families = sorted({_family_of(name) for name in all_hparams_info})
    champion_mse_scores = {fam: [] for fam in families}
    champion_cos_sim_scores = {fam: [] for fam in families}
    for problem_name in problem_names:
        problem_mse = df_raw_mse.loc[problem_name].dropna()
        problem_cos_sim = df_raw_cos_sim.loc[problem_name].dropna()
        for family in families:
            # Exact family match: a plain prefix test would also capture any
            # family whose name merely starts with this one.
            family_cols_mse = [col for col in problem_mse.index if _family_of(col) == family]
            if not family_cols_mse:
                champion_mse_scores[family].append(np.nan)
                champion_cos_sim_scores[family].append(np.nan)
                continue
            best_config_mse = problem_mse[family_cols_mse].idxmin()
            champion_mse_scores[family].append(problem_mse[best_config_mse])
            # The cosine similarity reported is that of the MSE champion.
            champion_cos_sim_scores[family].append(problem_cos_sim.get(best_config_mse, np.nan))

    df_champions_mse = pd.DataFrame(champion_mse_scores, index=problem_names)
    df_champions_cos_sim = pd.DataFrame(champion_cos_sim_scores, index=problem_names)
    df_champions_to_save = df_champions_mse.stack().reset_index()
    df_champions_to_save.columns = ['problem_name', 'family_name', 'champion_mse']
    df_champions_cos_sim_flat = df_champions_cos_sim.stack().reset_index()
    df_champions_cos_sim_flat.columns = ['problem_name', 'family_name', 'champion_cos_sim']
    df_champions_to_save = pd.merge(df_champions_to_save, df_champions_cos_sim_flat, on=['problem_name', 'family_name'])
    df_champions_to_save.to_csv(champions_csv_file, index=False)
    print(f"Rapport CSV avec les champions de chaque famille sauvegardé dans '{champions_csv_file}'")

    # --- Summary: MSE relative to the Surrogate champion of each problem ---
    baseline = df_champions_mse['Surrogate'] + 1e-9  # offset avoids division by zero
    df_champions_norm = df_champions_mse.div(baseline, axis=0)
    champion_stats = {
        'norm_mse_mean': df_champions_norm.mean(),
        'norm_mse_std': df_champions_norm.std(),
        'cos_sim_mean': df_champions_cos_sim.mean(),
        'cos_sim_std': df_champions_cos_sim.std(),
    }
    df_champion_stats = pd.DataFrame(champion_stats).sort_values('norm_mse_mean')

    header = f"--- Summary for Case: {case_name.upper()} ---\n"
    table_lines = [
        "--- Champion vs. Champion Analysis (Mean over all problems) ---",
        f"{'Family':<20} | {'Norm MSE Mean':<15} | {'Norm MSE Std':<15} | {'Cos Sim Mean':<15} | {'Cos Sim Std'}",
        "-" * 95,
    ]
    for family, stats in df_champion_stats.iterrows():
        # Numbers are padded to the header's column widths so the '|'
        # separators line up regardless of how many digits a value has.
        table_lines.append(
            f"{family:<20} | {stats['norm_mse_mean']:<15.4f} | {stats['norm_mse_std']:<15.4f} | "
            f"{stats['cos_sim_mean']:<15.4f} | {stats['cos_sim_std']:.4f}"
        )
    body = "\n".join(table_lines) + "\n"

    summary_content = header + body
    summary_content += "\n\n--- Champion MSE per Problem ---\n" + df_champions_mse.to_string()
    summary_content += "\n\n--- Champion Cosine Similarity per Problem ---\n" + df_champions_cos_sim.to_string()
    with open(txt_file, "w", encoding="utf-8") as f:
        f.write(summary_content)
    print("\n" + header + body)
    print(f"Résumé détaillé sauvegardé dans '{txt_file}'")

# --- Point d'Entrée Principal ---
if __name__ == "__main__":
    # --- Method families (before the learning-rate sweep) ---
    # Insertion order matters: it fixes the order in which runs execute.
    base_methods = {
        "Surrogate": {"method": "surrogate", "is_conservative": True, "estimator_class": Estimator_Potential_2D},
    }
    for n_probes in N_HUTCHINSON_SAMPLES_LIST:
        base_methods[f"WGM-NC(M={n_probes})"] = {
            "method": "wgm", "is_conservative": False,
            "estimator_class": Estimator_Direct_2D, "wgm_n_samples": n_probes,
        }
        base_methods[f"WGM-C(M={n_probes})"] = {
            "method": "wgm", "is_conservative": True,
            "estimator_class": Estimator_Potential_2D, "wgm_n_samples": n_probes,
        }
        for mix_weight in MIXED_ALPHAS:
            base_methods[f"Mixed(a={mix_weight},M={n_probes})"] = {
                "method": "mixed", "is_conservative": True,
                "estimator_class": Estimator_Potential_2D,
                "mixed_alpha": mix_weight, "wgm_n_samples": n_probes,
            }
    for fd_step in SPSA_EPSILONS:
        base_methods[f"SPSA(e={fd_step})"] = {
            "method": "spsa", "is_conservative": False,
            "estimator_class": Estimator_Direct_2D, "spsa_epsilon": fd_step,
        }

    # --- Cross every family with every learning rate ---
    all_hparams_info = {}
    for family_name, family_cfg in base_methods.items():
        # 'wgm' and 'mixed' runs use the base rate divided by the pixel count.
        scale_by_pixels = family_cfg['method'] in ['wgm', 'mixed']
        for base_lr in LEARNING_RATES:
            applied_lr = base_lr / (IMG_SIZE * IMG_SIZE) if scale_by_pixels else base_lr
            all_hparams_info[f"{family_name} lr={base_lr}"] = {
                **family_cfg,
                "base_lr": base_lr,
                "lr": applied_lr,
                "activation_name": "GELU",
                "wgm_noise_std": 1.0,
            }

    TARGET_PROBLEMS = {
        "Ising_Ferromagnetic_H0": {"class": IsingTarget, "params": {"J": 1.0, "h": 0.0}},
        "Ising_Ferromagnetic_H1": {"class": IsingTarget, "params": {"J": 1.0, "h": 1.0}},
        "Ising_Antiferromagnetic": {"class": IsingTarget, "params": {"J": -1.0, "h": 0.0}},
        "XY_Model": {"class": XYTarget, "params": {"J": 1.0}},
    }

    for noise_config in J_NOISE_CONFIGS:
        case_name = f"PhysicsModels_{noise_config['name']}_Normalized"
        all_raw_results = {
            'mse': {cfg_key: [] for cfg_key in all_hparams_info},
            'cos_sim': {cfg_key: [] for cfg_key in all_hparams_info},
        }
        problem_names = list(TARGET_PROBLEMS)

        # "auto" resolves to a noise standard deviation of 0.1.
        requested_noise = noise_config["noise_std"]
        noise_level = 0.1 if requested_noise == "auto" else requested_noise

        pbar_problems = tqdm(enumerate(TARGET_PROBLEMS.items()), total=len(TARGET_PROBLEMS), desc=f"Problèmes ({case_name})")
        for problem_idx, (problem_name, problem_cfg) in pbar_problems:
            # The raw target is built noise-free; the normalizing wrapper
            # is the one that carries (and applies) the noise level.
            raw_target = problem_cfg["class"](img_size=IMG_SIZE, **problem_cfg["params"], noise_std=0.0).to(device)
            final_target = NormalizedTarget(raw_target).to(device)
            final_target.noise_std = noise_level

            pbar_hparams = tqdm(all_hparams_info.items(), desc=f"HParams for {problem_name}", leave=False)
            for config_name, hparams in pbar_hparams:
                # Seed derived from the problem, the learning rate and the name.
                name_checksum = sum(ord(c) for c in config_name)
                run_hparams = {
                    **hparams,
                    'problem_idx': problem_idx,
                    'seed': SEED * 1000 + problem_idx * 100 + int(hparams['base_lr'] * 1e5) + name_checksum,
                    'method_name': config_name,
                }

                final_metrics = run_single_training(final_target, run_hparams)

                all_raw_results['mse'][config_name].append(final_metrics['final_mse'])
                all_raw_results['cos_sim'][config_name].append(final_metrics['final_cos_sim'])

        log_and_save_results(case_name, all_raw_results, all_hparams_info, problem_names)
    print("\nBenchmark complet terminé.")

  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda


Problèmes (PhysicsModels_deterministic_Normalized):   0%|          | 0/4 [00:00<?, ?it/s]

Calculating normalization stats for IsingTarget...



Stat Computation:   0%|          | 0/16 [00:00<?, ?it/s][A
Stat Computation:   6%|▋         | 1/16 [00:00<00:12,  1.19it/s][A
                                                                [A

Stats: Mean=-0.0432, Std=4.1586



HParams for Ising_Ferromagnetic_H0:   0%|          | 0/150 [00:00<?, ?it/s][AGPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(
`Trainer.fit` stopped: `max_epochs=20` reached.
Restoring states from the checkpoint path at F:\rebuttal_neurips\resnet\CNN\extract_champions\tanh\gibbs\spsa_correct\benchmark_results\few_data\checkpoints\epoch=10-step=550-v82.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at F:\rebuttal_neurips\resnet\CNN\extract_champions\tanh\gibbs\spsa_correct\benchmark_results\few_data\checkpoints\epoch=10-step=550-v82.ckpt

HParams for Ising_Ferromagnetic_H0:   1%|          | 1/150 [00:03<07:46,  3.13s/it][AGPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available

Calculating normalization stats for IsingTarget...



Stat Computation:   0%|          | 0/16 [00:00<?, ?it/s][A
                                                        [A

Stats: Mean=-0.0300, Std=6.2963



HParams for Ising_Ferromagnetic_H1:   0%|          | 0/150 [00:00<?, ?it/s][AGPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
`Trainer.fit` stopped: `max_epochs=20` reached.
Restoring states from the checkpoint path at F:\rebuttal_neurips\resnet\CNN\extract_champions\tanh\gibbs\spsa_correct\benchmark_results\few_data\checkpoints\epoch=12-step=650-v120.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at F:\rebuttal_neurips\resnet\CNN\extract_champions\tanh\gibbs\spsa_correct\benchmark_results\few_data\checkpoints\epoch=12-step=650-v120.ckpt

HParams for Ising_Ferromagnetic_H1:   1%|          | 1/150 [00:03<07:29,  3.02s/it][AGPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_

Calculating normalization stats for IsingTarget...



Stat Computation:   0%|          | 0/16 [00:00<?, ?it/s][A
                                                        [A

Stats: Mean=-0.0063, Std=4.1749



HParams for Ising_Antiferromagnetic:   0%|          | 0/150 [00:00<?, ?it/s][AGPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
`Trainer.fit` stopped: `max_epochs=20` reached.
Restoring states from the checkpoint path at F:\rebuttal_neurips\resnet\CNN\extract_champions\tanh\gibbs\spsa_correct\benchmark_results\few_data\checkpoints\epoch=19-step=1000-v1668.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at F:\rebuttal_neurips\resnet\CNN\extract_champions\tanh\gibbs\spsa_correct\benchmark_results\few_data\checkpoints\epoch=19-step=1000-v1668.ckpt

HParams for Ising_Antiferromagnetic:   1%|          | 1/150 [00:03<07:41,  3.10s/it][AGPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VI

Calculating normalization stats for XYTarget...



Stat Computation:   0%|          | 0/16 [00:00<?, ?it/s][A
                                                        [A

Stats: Mean=-51.1251, Std=9.9154



HParams for XY_Model:   0%|          | 0/150 [00:00<?, ?it/s][AGPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
`Trainer.fit` stopped: `max_epochs=20` reached.
Restoring states from the checkpoint path at F:\rebuttal_neurips\resnet\CNN\extract_champions\tanh\gibbs\spsa_correct\benchmark_results\few_data\checkpoints\epoch=12-step=650-v125.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at F:\rebuttal_neurips\resnet\CNN\extract_champions\tanh\gibbs\spsa_correct\benchmark_results\few_data\checkpoints\epoch=12-step=650-v125.ckpt

HParams for XY_Model:   1%|          | 1/150 [00:02<07:23,  2.98s/it][AGPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
`Trainer.fit` s

Rapport CSV avec tous les runs sauvegardé dans 'benchmark_results\all_runs_PhysicsModels_deterministic_Normalized.csv'
Rapport CSV avec les champions de chaque famille sauvegardé dans 'benchmark_results\champions_PhysicsModels_deterministic_Normalized.csv'

--- Summary for Case: PHYSICSMODELS_DETERMINISTIC_NORMALIZED ---
--- Champion vs. Champion Analysis (Mean over all problems) ---
Family               | Norm MSE Mean   | Norm MSE Std    | Cos Sim Mean    | Cos Sim Std
-----------------------------------------------------------------------------------------------
SPSA                 | 0.2163           | 0.0904           | 0.9982          | 0.0004
Mixed                | 0.9351           | 0.1568           | 0.9917          | 0.0015
Surrogate            | 1.0000           | 0.0000           | 0.9909          | 0.0026
WGM-NC               | 2.4640           | 0.9004           | 0.9796          | 0.0017
WGM-C                | 2.6528           | 1.0158           | 0.9778          | 0.002

Problèmes (PhysicsModels_stochastic_Normalized):   0%|          | 0/4 [00:00<?, ?it/s]

Calculating normalization stats for IsingTarget...



Stat Computation:   0%|          | 0/16 [00:00<?, ?it/s][A
                                                        [A

Stats: Mean=-0.0016, Std=4.1599



HParams for Ising_Ferromagnetic_H0:   0%|          | 0/150 [00:00<?, ?it/s][AGPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(
`Trainer.fit` stopped: `max_epochs=20` reached.
Restoring states from the checkpoint path at F:\rebuttal_neurips\resnet\CNN\extract_champions\tanh\gibbs\spsa_correct\benchmark_results\few_data\checkpoints\epoch=4-step=250-v81.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at F:\rebuttal_neurips\resnet\CNN\extract_champions\tanh\gibbs\spsa_correct\benchmark_results\few_data\checkpoints\epoch=4-step=250-v81.ckpt

HParams for Ising_Ferromagnetic_H0:   1%|          | 1/150 [00:03<07:33,  3.04s/it][AGPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: 

Calculating normalization stats for IsingTarget...



Stat Computation:   0%|          | 0/16 [00:00<?, ?it/s][A
                                                        [A

Stats: Mean=-0.0300, Std=6.2963



HParams for Ising_Ferromagnetic_H1:   0%|          | 0/150 [00:00<?, ?it/s][AGPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
`Trainer.fit` stopped: `max_epochs=20` reached.
Restoring states from the checkpoint path at F:\rebuttal_neurips\resnet\CNN\extract_champions\tanh\gibbs\spsa_correct\benchmark_results\few_data\checkpoints\epoch=1-step=100-v72.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at F:\rebuttal_neurips\resnet\CNN\extract_champions\tanh\gibbs\spsa_correct\benchmark_results\few_data\checkpoints\epoch=1-step=100-v72.ckpt

HParams for Ising_Ferromagnetic_H1:   1%|          | 1/150 [00:03<07:28,  3.01s/it][AGPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVI

Calculating normalization stats for IsingTarget...



Stat Computation:   0%|          | 0/16 [00:00<?, ?it/s][A
                                                        [A

Stats: Mean=-0.0063, Std=4.1749



HParams for Ising_Antiferromagnetic:   0%|          | 0/150 [00:00<?, ?it/s][AGPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
`Trainer.fit` stopped: `max_epochs=20` reached.
Restoring states from the checkpoint path at F:\rebuttal_neurips\resnet\CNN\extract_champions\tanh\gibbs\spsa_correct\benchmark_results\few_data\checkpoints\epoch=1-step=100-v75.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at F:\rebuttal_neurips\resnet\CNN\extract_champions\tanh\gibbs\spsa_correct\benchmark_results\few_data\checkpoints\epoch=1-step=100-v75.ckpt

HParams for Ising_Antiferromagnetic:   1%|          | 1/150 [00:03<07:29,  3.02s/it][AGPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DE

Calculating normalization stats for XYTarget...



Stat Computation:   0%|          | 0/16 [00:00<?, ?it/s][A
                                                        [A

Stats: Mean=-51.1251, Std=9.9154



HParams for XY_Model:   0%|          | 0/150 [00:00<?, ?it/s][AGPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
`Trainer.fit` stopped: `max_epochs=20` reached.
Restoring states from the checkpoint path at F:\rebuttal_neurips\resnet\CNN\extract_champions\tanh\gibbs\spsa_correct\benchmark_results\few_data\checkpoints\epoch=19-step=1000-v2031.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at F:\rebuttal_neurips\resnet\CNN\extract_champions\tanh\gibbs\spsa_correct\benchmark_results\few_data\checkpoints\epoch=19-step=1000-v2031.ckpt

HParams for XY_Model:   1%|          | 1/150 [00:03<07:43,  3.11s/it][AGPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
`Trainer.fi

Rapport CSV avec tous les runs sauvegardé dans 'benchmark_results\all_runs_PhysicsModels_stochastic_Normalized.csv'
Rapport CSV avec les champions de chaque famille sauvegardé dans 'benchmark_results\champions_PhysicsModels_stochastic_Normalized.csv'

--- Summary for Case: PHYSICSMODELS_STOCHASTIC_NORMALIZED ---
--- Champion vs. Champion Analysis (Mean over all problems) ---
Family               | Norm MSE Mean   | Norm MSE Std    | Cos Sim Mean    | Cos Sim Std
-----------------------------------------------------------------------------------------------
Mixed                | 0.8927           | 0.1849           | 0.9899          | 0.0017
Surrogate            | 1.0000           | 0.0000           | 0.9881          | 0.0048
WGM-C                | 2.0364           | 0.8346           | 0.9776          | 0.0048
WGM-NC               | 2.1157           | 0.9159           | 0.9784          | 0.0030
SPSA                 | 11.6161           | 3.2614           | 0.8609          | 0.0275

Résum




In [2]:
# main_physics_benchmark_final_normalized.py

import torch
import torch.nn as nn
import torch.optim as optim
from tqdm.auto import tqdm
import numpy as np
import random
import pandas as pd
import warnings
import torch.nn.functional as F
from pathlib import Path
from torch.utils.data import DataLoader, IterableDataset, TensorDataset

# Tenter d'importer PyTorch Lightning
try:
    import pytorch_lightning as pl
    from pytorch_lightning.callbacks import ModelCheckpoint
except ImportError:
    print("Pytorch Lightning n'est pas installé. Veuillez l'installer avec : pip install pytorch-lightning")
    pl = None

# --- General configuration ---
# Abort early: everything below relies on PyTorch Lightning being importable.
if pl is None:
    raise ImportError("Pytorch Lightning est requis pour exécuter ce script.")

# Seed the three RNGs used in this file (torch, numpy, random).
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Performance setting for recent GPUs
if torch.cuda.is_available():
    torch.set_float32_matmul_precision('medium')

# Silence non-critical warnings
warnings.filterwarnings("ignore", ".*does not have many workers.*")
warnings.filterwarnings("ignore", ".*Checkpoint directory.*exists and is not empty.*")
warnings.filterwarnings("ignore", ".*Using a target size.*")

# --- Benchmark hyperparameters ---
IMG_SIZE, IMG_CHANNELS = 16, 1
BATCH_SIZE, MAX_EPOCHS = 128, 20

# Hyperparameter search grid
LEARNING_RATES = [10,5,1,5e-1,1e-1,5e-2, 1e-2, 5e-3, 1e-3, 1e-4]
N_HUTCHINSON_SAMPLES_LIST, MIXED_ALPHAS = [1, 5, 10], [0.1, 0.01]
SPSA_EPSILONS = [1e-1, 1e-3]
ACTIVATIONS = {"GELU": nn.GELU}
# "auto" is resolved to a concrete noise level by the main loop.
J_NOISE_CONFIGS = [{"name": "deterministic", "noise_std": 0.0}, {"name": "stochastic", "noise_std": "auto"}]
N_SAMPLES_FOR_STATS = 4096 # Number of samples used to compute the mean and std

# --- Physics-based target functions ---

class PhysicsTarget(nn.Module):
    """Base class for the physics energy functions J(x) used as targets.

    Subclasses implement ``forward``; ``gradient`` then provides dJ/dx via
    autograd.

    Args:
        img_size: Side length stored on the instance as ``img_size``.
        noise_std: Noise level stored on the instance as ``noise_std``; this
            base class only stores it.
    """

    def __init__(self, img_size=16, noise_std=0.0):
        super().__init__()
        self.img_size = img_size
        self.noise_std = noise_std
        # Freeze whatever parameters are registered at this point.
        for p in self.parameters():
            p.requires_grad = False

    def forward(self, x):
        """Return J(x); must be overridden by subclasses."""
        raise NotImplementedError

    def gradient(self, x):
        """Return the gradient of ``forward(x).sum()`` with respect to ``x``.

        The module's own ``training`` flag is forced to ``False`` for the
        duration of the call so that the gradient is always computed without
        noise, and is restored afterwards even if ``forward`` raises.

        Args:
            x: Input tensor; it is detached, so the caller's tensor is not
                modified and no graph is attached to it.

        Returns:
            Tensor with the same shape as ``x``.
        """
        x_req = x.detach().requires_grad_(True)
        original_training_state = self.training
        # Only this module's flag is flipped (child modules are left alone),
        # which is what the original implementation did.
        self.training = False
        try:
            # enable_grad makes this usable from inside torch.no_grad() blocks.
            with torch.enable_grad():
                j_sum = self.forward(x_req).sum()
                grad = torch.autograd.grad(j_sum, x_req)[0]
        finally:
            self.training = original_training_state
        return grad

class NormalizedTarget(PhysicsTarget):
    """
    Wrapper that centres and rescales a target function J(x) using global
    statistics that are computed once, at construction time.

    The statistics are the mean and standard deviation of the wrapped target
    evaluated on ``n_samples_for_stats`` standard-normal inputs. In training
    mode, Gaussian noise of scale ``noise_std`` is added to the normalised
    output when ``noise_std > 0``.
    """

    def __init__(self, base_target: PhysicsTarget, n_samples_for_stats=N_SAMPLES_FOR_STATS):
        super().__init__(img_size=base_target.img_size, noise_std=base_target.noise_std)
        self.base_target = base_target

        print(f"Calculating normalization stats for {base_target.__class__.__name__}...")
        with torch.no_grad():
            # Evaluate the wrapped target in eval mode, then put it back.
            was_training = self.base_target.training
            self.base_target.train(False)

            probe_inputs = torch.randn(n_samples_for_stats, IMG_CHANNELS, self.img_size, self.img_size)
            probe_loader = DataLoader(TensorDataset(probe_inputs), batch_size=BATCH_SIZE*2)

            energies = torch.cat([
                self.base_target(chunk[0].to(device))
                for chunk in tqdm(probe_loader, desc="Stat Computation", leave=False)
            ])
            mean = energies.mean().item()
            std = energies.std().item()
            self.base_target.train(was_training)

        self.register_buffer('mean', torch.tensor(mean, device=device))
        # Small offset keeps the later division well defined.
        self.register_buffer('std', torch.tensor(std, device=device) + 1e-8)

        print(f"Stats: Mean={self.mean.item():.4f}, Std={self.std.item():.4f}")

    def forward(self, x):
        """Return the normalised energy, plus noise when training with noise_std > 0."""
        normalized_energy = (self.base_target(x) - self.mean) / self.std

        if not (self.noise_std > 0 and self.training):
            return normalized_energy

        noise = torch.randn_like(normalized_energy) * self.noise_std
        return normalized_energy + noise

class IsingTarget(PhysicsTarget):
    """Ising-style energy on a 2D grid of soft spins ``tanh(x)``.

    Args:
        img_size: Forwarded to ``PhysicsTarget``.
        J: Coupling constant multiplying the nearest-neighbour term.
        h: External field multiplying the per-site term.
        noise_std: Forwarded to ``PhysicsTarget``; not used by ``forward``.
    """

    def __init__(self, img_size=16, J=1.0, h=0.0, noise_std=0.0):
        super().__init__(img_size, noise_std)
        self.J = J
        self.h = h
        # Cross-shaped stencil: sums the four direct neighbours of each site.
        stencil = torch.tensor(
            [[0, 1, 0],
             [1, 0, 1],
             [0, 1, 0]],
            dtype=torch.float32,
        )
        self.register_buffer('neighbor_kernel', stencil.view(1, 1, 3, 3))

    def forward(self, x):
        """Return one total energy per sample (summed over dims 1, 2, 3)."""
        spins = torch.tanh(x)
        stencil = self.neighbor_kernel.to(x.device)
        neighbor_sum = F.conv2d(spins, stencil, padding='same')

        pair_energy = -self.J * (spins * neighbor_sum)
        site_energy = -self.h * spins

        # The symmetric stencil visits every neighbour pair twice, hence the /2.
        pair_total = pair_energy.sum(dim=(1, 2, 3)) / 2
        site_total = site_energy.sum(dim=(1, 2, 3))
        # Note: noise is handled by the NormalizedTarget wrapper (when used)
        return pair_total + site_total

class XYTarget(PhysicsTarget):
    """XY-style energy: ``-J * cos`` of the angle difference to the right and
    lower neighbour of every site, with the input interpreted as angles.

    Args:
        img_size: Forwarded to ``PhysicsTarget``.
        J: Coupling constant.
        noise_std: Forwarded to ``PhysicsTarget``; not used by ``forward``.
    """

    def __init__(self, img_size=16, J=1.0, noise_std=0.0):
        super().__init__(img_size, noise_std)
        self.J = J
        # Finite-difference stencils: (right neighbour - centre) and
        # (lower neighbour - centre).
        to_right = torch.tensor(
            [[0, 0, 0],
             [0, -1, 1],
             [0, 0, 0]],
            dtype=torch.float32,
        )
        to_below = torch.tensor(
            [[0, 0, 0],
             [0, -1, 0],
             [0, 1, 0]],
            dtype=torch.float32,
        )
        self.register_buffer('kernel_right', to_right.view(1, 1, 3, 3))
        self.register_buffer('kernel_down', to_below.view(1, 1, 3, 3))

    def forward(self, x):
        """Return one total energy per sample (summed over dims 1, 2, 3)."""
        angles = x
        # padding='same' zero-pads, so sites on the right/bottom edge are
        # differenced against 0.
        delta_right = F.conv2d(angles, self.kernel_right.to(x.device), padding='same')
        delta_down = F.conv2d(angles, self.kernel_down.to(x.device), padding='same')
        bond_energy = -self.J * (torch.cos(delta_right) + torch.cos(delta_down))
        return bond_energy.sum(dim=(1, 2, 3))

# --- Modèles Estimateurs ---
class Estimator_Potential_2D(nn.Module):
    """CNN that maps an image to a single scalar potential per sample.

    Args:
        in_channels: Number of input channels.
        image_size: Accepted for signature compatibility; not used by the
            layers built here.
        activation_fn: Zero-argument callable returning an activation module.
    """

    def __init__(self, in_channels, image_size, activation_fn=nn.GELU):
        super().__init__()
        channel_plan = (in_channels, 16, 32, 16)
        stages = []
        # Three 3x3 conv + activation stages, built in order so that layer
        # indices (and therefore state-dict keys) are net.0 ... net.5.
        for n_in, n_out in zip(channel_plan[:-1], channel_plan[1:]):
            stages.append(nn.Conv2d(n_in, n_out, kernel_size=3, padding='same'))
            stages.append(activation_fn())
        # Global average pool followed by a 1x1 conv gives one value per sample.
        stages.append(nn.AdaptiveAvgPool2d((1, 1)))
        stages.append(nn.Conv2d(16, 1, kernel_size=1))
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Return the potential with the two trailing size-1 spatial dims removed."""
        pooled = self.net(x)
        pooled = pooled.squeeze(-1)
        return pooled.squeeze(-1)

    def gradient(self, x):
        """Return d(sum of potentials)/dx, with the same shape as ``x``."""
        probe = x.detach().requires_grad_(True)
        with torch.enable_grad():
            total_potential = self.forward(probe).sum()
            (grad,) = torch.autograd.grad(total_potential, probe)
        return grad

class Estimator_Direct_2D(nn.Module):
    """Fully convolutional network that outputs a field with the same spatial
    size as its input (all convolutions use 3x3 kernels with 'same' padding).

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        activation_fn: Zero-argument callable returning an activation module.
    """

    def __init__(self, in_channels, out_channels, activation_fn=nn.GELU):
        super().__init__()
        channel_plan = (in_channels, 16, 32, 16)
        stages = []
        # Conv + activation stages occupy indices net.0 ... net.5.
        for n_in, n_out in zip(channel_plan[:-1], channel_plan[1:]):
            stages.append(nn.Conv2d(n_in, n_out, kernel_size=3, padding='same'))
            stages.append(activation_fn())
        # Final projection (net.6) has no activation after it.
        stages.append(nn.Conv2d(16, out_channels, kernel_size=3, padding='same'))
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Return the network output for ``x``."""
        field = self.net(x)
        return field

# --- Dataset et Module PyTorch Lightning ---
class SampleDatasetCNN_2D(IterableDataset):
    """Iterable dataset that yields ``steps`` freshly sampled Gaussian batches.

    Each item is already a full batch of shape
    ``(batch_size, channels, size, size)``, created on the module-level
    ``device`` and scaled by ``std``.
    """

    def __init__(self, channels, size, steps, batch_size, std=1.0):
        self.channels = channels
        self.size = size
        self.steps = steps
        self.batch_size = batch_size
        self.std = std

    def __iter__(self):
        for _step in range(self.steps):
            standard_noise = torch.randn(
                self.batch_size, self.channels, self.size, self.size, device=device
            )
            yield standard_noise * self.std
            
class GradEstPLModule_2D(pl.LightningModule):
    """LightningModule that trains ``estimator_cnn`` against a frozen ``target_cnn``.

    The training objective is selected by ``hparams.method``:
    ``'surrogate'`` (regress the target value), ``'wgm'`` (loss from
    ``wgm_loss``), ``'mixed'`` (surrogate + ``mixed_alpha`` * wgm) or
    ``'spsa'`` (match a finite-difference directional derivative).
    Evaluation always compares an estimated gradient field with
    ``target_cnn.gradient``.
    """
    def __init__(self, target_cnn, estimator_cnn, hparams):
        super().__init__(); self.save_hyperparameters(hparams)
        self.target_cnn, self.estimator_cnn = target_cnn, estimator_cnn
        # The target is a fixed reference: none of its parameters are trained.
        for p in self.target_cnn.parameters(): p.requires_grad = False

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for the configured method.

        Returns the loss tensor, or ``None`` for 'spsa' when the half-batch is
        empty.
        """
        # Target is evaluated in training mode during training steps.
        self.target_cnn.train()
        method = self.hparams.method
        # Placeholder returned unchanged if `method` matches no branch below.
        loss = torch.tensor(0.0, device=self.device)

        if method == 'surrogate':
            # Plain regression of the estimator output onto the target value.
            j_hat = self.estimator_cnn(batch).squeeze()
            with torch.no_grad(): j_true = self.target_cnn(batch).squeeze()
            loss = F.mse_loss(j_hat, j_true)
        elif method == 'wgm':
            loss = self.wgm_loss(batch)
        elif method == 'mixed':
            # Value regression plus a weighted WGM term.
            j_hat = self.estimator_cnn(batch).squeeze()
            with torch.no_grad(): j_true = self.target_cnn(batch).squeeze()
            loss_surrogate = F.mse_loss(j_hat, j_true)
            loss_wgm = self.wgm_loss(batch)
            loss = loss_surrogate + self.hparams.mixed_alpha * loss_wgm
        elif method == 'spsa':
            # Only the first half of the batch is used for this method.
            x = batch[:batch.size(0) // 2]
            if x.shape[0] == 0: return None
            v_theta = self.estimator_cnn(x)
            with torch.no_grad():
                eps = self.hparams.spsa_epsilon
                # Random direction, normalised to unit L2 norm per sample.
                v = torch.randn_like(x)
                v_norm = torch.norm(v.flatten(1), p=2, dim=1).view(-1, 1, 1, 1) + 1e-8
                v = v / v_norm
                # Central finite difference of the target along v.
                j_plus = self.target_cnn(x + eps * v)
                j_minus = self.target_cnn(x - eps * v)
                target_directional_derivative = (j_plus - j_minus) / (2 * eps)
            # Projection of the estimated field onto the same direction.
            predicted_directional_derivative = (v_theta * v).sum(dim=(1,2,3))
            loss = F.mse_loss(predicted_directional_derivative, target_directional_derivative)

        self.log('train_loss', loss, prog_bar=False, logger=True); return loss

    def wgm_loss(self, x):
        """Return the WGM loss for batch ``x``.

        Per (repeated) sample the loss is ``||s||^2 + 2 * (J - c) * h`` where
        ``s`` is the estimated field, ``h`` is a stochastic divergence term
        plus ``s . (-x / wgm_noise_std**2)``, ``J`` is the target value and
        ``c`` is the batch mean of ``J``. The result is the mean over samples.
        """
        M = self.hparams.wgm_n_samples
        # Each input is repeated M times so that M independent probes v are drawn per input.
        x_rep = x.detach().repeat_interleave(M, dim=0).requires_grad_(True)
        if self.hparams.is_conservative:
            # Field is the input-gradient of the estimator's scalar output;
            # create_graph keeps it differentiable for the terms below.
            potential = self.estimator_cnn(x_rep)
            s_theta = torch.autograd.grad(potential.sum(), x_rep, create_graph=True)[0]
        else:
            # Estimator outputs the field directly.
            s_theta = self.estimator_cnn(x_rep)
        v = torch.randn_like(s_theta)
        s_dot_v_sum = (s_theta * v).sum()
        grad_s_dot_v = torch.autograd.grad(s_dot_v_sum, x_rep, create_graph=True)[0]
        # v^T (d(s.v)/dx): random-probe estimate of the divergence of s.
        div_terms = (v * grad_s_dot_v).sum(dim=(1, 2, 3))
        # -x / sigma^2 is the log-density gradient of a zero-mean Gaussian with std sigma.
        grad_log_p_terms = (s_theta * (-x_rep / self.hparams.wgm_noise_std**2)).sum(dim=(1,2,3))
        h_all = div_terms + grad_log_p_terms
        with torch.no_grad():
            j_val = self.target_cnn(x).squeeze()
            j_val_rep = j_val.repeat_interleave(M, dim=0)
            c_opt = j_val_rep.mean() # Per-batch centering to reduce variance
        term1_all = s_theta.pow(2).sum(dim=(1, 2, 3))
        term2_all = 2 * (j_val_rep - c_opt) * h_all
        loss = (term1_all + term2_all).mean()
        return loss

    def _compute_eval_metrics(self, batch):
        """Return ``{'mse', 'cos_sim'}`` between estimated and true gradients on ``batch``."""
        self.target_cnn.eval()
        with torch.no_grad():
            # target_cnn.gradient re-enables grad internally.
            grad_j_true = self.target_cnn.gradient(batch)
        with torch.enable_grad():
            if self.hparams.method == 'spsa' or not self.hparams.is_conservative:
                # These estimators output the gradient field directly.
                estimated_grad = self.estimator_cnn(batch)
            else: 
                # Potential estimators: differentiate the scalar output w.r.t. the input.
                estimated_grad = self.estimator_cnn.gradient(batch)
        mse = F.mse_loss(estimated_grad.detach(), grad_j_true.detach())
        # Cosine similarity per sample over the flattened field, then averaged.
        cos_sim = F.cosine_similarity(estimated_grad.detach().flatten(1), grad_j_true.detach().flatten(1)).mean()
        return {'mse': mse, 'cos_sim': cos_sim}

    def validation_step(self, batch, batch_idx):
        """Log ``val_mse`` and ``val_cos_sim`` for one validation batch."""
        metrics = self._compute_eval_metrics(batch)
        self.log('val_mse', metrics['mse'], prog_bar=True, logger=True)
        self.log('val_cos_sim', metrics['cos_sim'], prog_bar=False, logger=True)
        return metrics

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        """Log ``test_mse`` and ``test_cos_sim`` for one test batch."""
        metrics = self._compute_eval_metrics(batch)
        self.log('test_mse', metrics['mse'])
        self.log('test_cos_sim', metrics['cos_sim'])
        return metrics

    def configure_optimizers(self): 
        """Adam over the estimator's parameters only, with learning rate ``hparams.lr``."""
        return optim.Adam(self.estimator_cnn.parameters(), lr=self.hparams.lr)

# --- Fonctions d'Exécution et de Logging ---
def run_single_training(target_cnn, hparams):
    """Train one estimator configuration against ``target_cnn`` and score it.

    Args:
        target_cnn: Target module passed to ``GradEstPLModule_2D``.
        hparams: Dict with at least ``seed``, ``estimator_class``,
            ``activation_name`` and the method-specific keys read by
            ``GradEstPLModule_2D``.

    Returns:
        Dict with ``final_mse`` and ``final_cos_sim`` measured with the best
        checkpoint (lowest ``val_mse``); both are NaN if testing returned no
        results.
    """
    # Local imports: only this function needs them.
    import shutil
    import tempfile

    torch.manual_seed(hparams['seed'])
    estimator_class = hparams["estimator_class"]
    # Second positional argument differs per estimator: image size for the
    # potential network, number of output channels for the direct network.
    second_arg = IMG_SIZE if estimator_class is Estimator_Potential_2D else IMG_CHANNELS
    estimator = estimator_class(
        IMG_CHANNELS,
        second_arg,
        activation_fn=ACTIVATIONS[hparams["activation_name"]]
    ).to(device)
    experiment = GradEstPLModule_2D(target_cnn, estimator, hparams)

    # Every run used to write its checkpoint into the shared default
    # "checkpoints" folder, where files piled up indefinitely under
    # auto-versioned names. A private directory, removed once the best
    # checkpoint has been evaluated, keeps the disk footprint bounded.
    checkpoint_dir = tempfile.mkdtemp(prefix="grad_est_ckpt_")
    try:
        checkpoint_callback = ModelCheckpoint(
            dirpath=checkpoint_dir, monitor="val_mse", mode="min", save_top_k=1
        )
        trainer = pl.Trainer(
            max_epochs=MAX_EPOCHS, accelerator="auto", devices=1, logger=False,
            callbacks=[checkpoint_callback], enable_model_summary=False, num_sanity_val_steps=0,
            # inference_mode=False: evaluation needs autograd to compute gradients.
            enable_progress_bar=False, inference_mode=False
        )
        # The datasets already yield full batches, hence batch_size=None.
        train_loader = DataLoader(SampleDatasetCNN_2D(IMG_CHANNELS, IMG_SIZE, 50, BATCH_SIZE), batch_size=None)
        val_loader = DataLoader(SampleDatasetCNN_2D(IMG_CHANNELS, IMG_SIZE, 5, BATCH_SIZE), batch_size=None)
        trainer.fit(model=experiment, train_dataloaders=train_loader, val_dataloaders=val_loader)
        test_results = trainer.test(ckpt_path="best", dataloaders=val_loader, verbose=False)
    finally:
        shutil.rmtree(checkpoint_dir, ignore_errors=True)

    if test_results:
        return {'final_mse': test_results[0]['test_mse'], 'final_cos_sim': test_results[0]['test_cos_sim']}
    return {'final_mse': float('nan'), 'final_cos_sim': float('nan')}

def _family_of(config_name):
    """Return the method family of a run name, e.g. 'WGM-C(M=5) lr=0.1' -> 'WGM-C'."""
    return config_name.split(' ')[0].split('(')[0]


def _wide_to_long(df_wide, key_column, value_column):
    """Flatten a problems-by-names table into (problem_name, key, value) rows."""
    df_long = df_wide.stack().reset_index()
    df_long.columns = ['problem_name', key_column, value_column]
    return df_long


def log_and_save_results(case_name, all_raw_results, all_hparams_info, problem_names):
    """Write the per-run CSV, the per-family champion CSV and a text summary.

    For every problem and every method family, the "champion" is the run with
    the lowest MSE; its cosine similarity is reported alongside. Champion MSEs
    are also normalised by the 'Surrogate' champion of the same problem.

    Args:
        case_name: Suffix used in the three output file names.
        all_raw_results: ``{'mse': {run_name: [...]}, 'cos_sim': {run_name: [...]}}``
            with one list entry per problem, in ``problem_names`` order.
        all_hparams_info: Mapping whose keys are the run names; only the keys
            are used, to derive the set of method families.
        problem_names: Row labels, one per problem.

    Raises:
        KeyError: If no run belongs to the 'Surrogate' family.

    Side effects:
        Creates ``benchmark_results/`` in the current working directory and
        writes ``summary_<case>.txt``, ``all_runs_<case>.csv`` and
        ``champions_<case>.csv`` there; prints the summary table.
    """
    output_dir = Path("benchmark_results")
    output_dir.mkdir(exist_ok=True)
    txt_file = output_dir / f"summary_{case_name}.txt"
    raw_csv_file = output_dir / f"all_runs_{case_name}.csv"
    champions_csv_file = output_dir / f"champions_{case_name}.csv"

    # --- All runs, long format ---
    df_raw_mse = pd.DataFrame(all_raw_results['mse'], index=problem_names)
    df_raw_cos_sim = pd.DataFrame(all_raw_results['cos_sim'], index=problem_names)
    df_to_save_raw = pd.merge(
        _wide_to_long(df_raw_mse, 'method_name', 'final_mse'),
        _wide_to_long(df_raw_cos_sim, 'method_name', 'final_cos_sim'),
        on=['problem_name', 'method_name'],
    )
    df_to_save_raw.to_csv(raw_csv_file, index=False)
    print(f"Rapport CSV avec tous les runs sauvegardé dans '{raw_csv_file}'")

    # --- Champion of each family for each problem ---
    families = sorted({_family_of(name) for name in all_hparams_info})
    champion_mse_scores = {fam: [] for fam in families}
    champion_cos_sim_scores = {fam: [] for fam in families}
    for problem_name in problem_names:
        problem_mse = df_raw_mse.loc[problem_name].dropna()
        problem_cos_sim = df_raw_cos_sim.loc[problem_name].dropna()
        for family in families:
            # Exact family match: a prefix test would let a family whose name
            # is a prefix of another one absorb that family's runs.
            family_cols_mse = [col for col in problem_mse.index if _family_of(col) == family]
            if not family_cols_mse:
                champion_mse_scores[family].append(np.nan)
                champion_cos_sim_scores[family].append(np.nan)
                continue
            best_config_mse = problem_mse[family_cols_mse].idxmin()
            champion_mse_scores[family].append(problem_mse[best_config_mse])
            if best_config_mse in problem_cos_sim.index:
                champion_cos_sim_scores[family].append(problem_cos_sim[best_config_mse])
            else:
                champion_cos_sim_scores[family].append(np.nan)

    df_champions_mse = pd.DataFrame(champion_mse_scores, index=problem_names)
    df_champions_cos_sim = pd.DataFrame(champion_cos_sim_scores, index=problem_names)
    df_champions_to_save = pd.merge(
        _wide_to_long(df_champions_mse, 'family_name', 'champion_mse'),
        _wide_to_long(df_champions_cos_sim, 'family_name', 'champion_cos_sim'),
        on=['problem_name', 'family_name'],
    )
    df_champions_to_save.to_csv(champions_csv_file, index=False)
    print(f"Rapport CSV avec les champions de chaque famille sauvegardé dans '{champions_csv_file}'")

    # --- Normalised statistics across problems ---
    # The small offset guards the division when the baseline MSE is 0.
    baseline = df_champions_mse['Surrogate'] + 1e-9
    df_champions_norm = df_champions_mse.div(baseline, axis=0)
    champion_stats = {
        'norm_mse_mean': df_champions_norm.mean(),
        'norm_mse_std': df_champions_norm.std(),
        'cos_sim_mean': df_champions_cos_sim.mean(),
        'cos_sim_std': df_champions_cos_sim.std(),
    }
    df_champion_stats = pd.DataFrame(champion_stats).sort_values('norm_mse_mean')

    # --- Text summary ---
    header = f"--- Summary for Case: {case_name.upper()} ---\n"
    body = "--- Champion vs. Champion Analysis (Mean over all problems) ---\n"
    body += f"{'Family':<20} | {'Norm MSE Mean':<15} | {'Norm MSE Std':<15} | {'Cos Sim Mean':<15} | {'Cos Sim Std'}\n"
    body += ("-" * 95) + "\n"
    # Rows use the same field widths as the header so the columns line up.
    body += "".join(
        f"{family:<20} | {stats['norm_mse_mean']:<15.4f} | {stats['norm_mse_std']:<15.4f} | "
        f"{stats['cos_sim_mean']:<15.4f} | {stats['cos_sim_std']:.4f}\n"
        for family, stats in df_champion_stats.iterrows()
    )
    summary_content = header + body
    summary_content += "\n\n--- Champion MSE per Problem ---\n" + df_champions_mse.to_string()
    summary_content += "\n\n--- Champion Cosine Similarity per Problem ---\n" + df_champions_cos_sim.to_string()
    with open(txt_file, "w", encoding="utf-8") as f:
        f.write(summary_content)
    print("\n" + header + body)
    print(f"Résumé détaillé sauvegardé dans '{txt_file}'")

# --- Main entry point ---
if __name__ == "__main__":
    # 1) Catalogue of base methods. Insertion order defines the run order.
    base_methods = {
        "Surrogate": dict(method="surrogate", is_conservative=True, estimator_class=Estimator_Potential_2D),
    }
    for n_probes in N_HUTCHINSON_SAMPLES_LIST:
        base_methods[f"WGM-NC(M={n_probes})"] = dict(
            method="wgm", is_conservative=False, estimator_class=Estimator_Direct_2D, wgm_n_samples=n_probes)
        base_methods[f"WGM-C(M={n_probes})"] = dict(
            method="wgm", is_conservative=True, estimator_class=Estimator_Potential_2D, wgm_n_samples=n_probes)
        for mix_weight in MIXED_ALPHAS:
            base_methods[f"Mixed(a={mix_weight},M={n_probes})"] = dict(
                method="mixed", is_conservative=True, estimator_class=Estimator_Potential_2D,
                mixed_alpha=mix_weight, wgm_n_samples=n_probes)
    for step_size in SPSA_EPSILONS:
        base_methods[f"SPSA(e={step_size})"] = dict(
            method="spsa", is_conservative=False, estimator_class=Estimator_Direct_2D, spsa_epsilon=step_size)

    # 2) Cross every base method with every learning rate.
    #    For 'wgm' and 'mixed' the effective learning rate is divided by the pixel count.
    all_hparams_info = {}
    for method_label, method_cfg in base_methods.items():
        uses_scaled_lr = method_cfg['method'] in ['wgm', 'mixed']
        for base_lr in LEARNING_RATES:
            all_hparams_info[f"{method_label} lr={base_lr}"] = {
                **method_cfg,
                "base_lr": base_lr,
                "lr": base_lr / (IMG_SIZE * IMG_SIZE) if uses_scaled_lr else base_lr,
                "activation_name": "GELU",
                "wgm_noise_std": 1.0,
            }

    # 3) Target problems.
    TARGET_PROBLEMS = {
        "Ising_Ferromagnetic_H0": {"class": IsingTarget, "params": {"J": 1.0, "h": 0.0}},
        "Ising_Ferromagnetic_H1": {"class": IsingTarget, "params": {"J": 1.0, "h": 1.0}},
        "Ising_Antiferromagnetic": {"class": IsingTarget, "params": {"J": -1.0, "h": 0.0}},
        "XY_Model": {"class": XYTarget, "params": {"J": 1.0}},
    }

    # 4) One full sweep per noise setting.
    for noise_config in J_NOISE_CONFIGS:
        case_name = f"PhysicsModels_{noise_config['name']}_Normalized"
        all_raw_results = {
            'mse': {run_name: [] for run_name in all_hparams_info},
            'cos_sim': {run_name: [] for run_name in all_hparams_info},
        }
        problem_names = list(TARGET_PROBLEMS)

        pbar_problems = tqdm(
            enumerate(TARGET_PROBLEMS.items()),
            total=len(TARGET_PROBLEMS),
            desc=f"Problèmes ({case_name})",
        )
        for problem_idx, (problem_name, problem_spec) in pbar_problems:
            # The raw target is built noise-free and then wrapped for normalisation.
            raw_target = problem_spec["class"](
                img_size=IMG_SIZE, **problem_spec["params"], noise_std=0.0
            ).to(device)
            normalized_target = NormalizedTarget(raw_target).to(device)

            # "auto" resolves to a fixed noise level of 0.1.
            current_noise_std = noise_config["noise_std"]
            if current_noise_std == "auto":
                current_noise_std = 0.1

            # Noise is applied by the wrapper; it only needs the parameter.
            final_target = normalized_target
            final_target.noise_std = current_noise_std

            pbar_hparams = tqdm(all_hparams_info.items(), desc=f"HParams for {problem_name}", leave=False)
            for config_name, hparams in pbar_hparams:
                # Deterministic per-run seed derived from problem, learning rate and run name.
                run_seed = (
                    SEED * 1000
                    + problem_idx * 100
                    + int(hparams['base_lr'] * 1e5)
                    + sum(map(ord, config_name))
                )
                run_hparams = {
                    **hparams,
                    'problem_idx': problem_idx,
                    'seed': run_seed,
                    'method_name': config_name,
                }

                final_metrics = run_single_training(final_target, run_hparams)

                all_raw_results['mse'][config_name].append(final_metrics['final_mse'])
                all_raw_results['cos_sim'][config_name].append(final_metrics['final_cos_sim'])

        log_and_save_results(case_name, all_raw_results, all_hparams_info, problem_names)
    print("\nBenchmark complet terminé.")

Using device: cuda


Problèmes (PhysicsModels_deterministic_Normalized):   0%|          | 0/4 [00:00<?, ?it/s]

Calculating normalization stats for IsingTarget...



Stat Computation:   0%|          | 0/16 [00:00<?, ?it/s][A
                                                        [A

Stats: Mean=0.0707, Std=8.6648



HParams for Ising_Ferromagnetic_H0:   0%|          | 0/150 [00:00<?, ?it/s][AGPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(
`Trainer.fit` stopped: `max_epochs=20` reached.
Restoring states from the checkpoint path at F:\rebuttal_neurips\resnet\CNN\extract_champions\tanh\gibbs\spsa_correct\benchmark_results\few_data\checkpoints\epoch=18-step=950-v590.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at F:\rebuttal_neurips\resnet\CNN\extract_champions\tanh\gibbs\spsa_correct\benchmark_results\few_data\checkpoints\epoch=18-step=950-v590.ckpt

HParams for Ising_Ferromagnetic_H0:   1%|          | 1/150 [00:03<07:50,  3.16s/it][AGPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU availab

Calculating normalization stats for IsingTarget...



Stat Computation:   0%|          | 0/16 [00:00<?, ?it/s][A
                                                        [A

Stats: Mean=-0.1310, Std=13.3570



HParams for Ising_Ferromagnetic_H1:   0%|          | 0/150 [00:00<?, ?it/s][AGPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
`Trainer.fit` stopped: `max_epochs=20` reached.
Restoring states from the checkpoint path at F:\rebuttal_neurips\resnet\CNN\extract_champions\tanh\gibbs\spsa_correct\benchmark_results\few_data\checkpoints\epoch=15-step=800-v146.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at F:\rebuttal_neurips\resnet\CNN\extract_champions\tanh\gibbs\spsa_correct\benchmark_results\few_data\checkpoints\epoch=15-step=800-v146.ckpt

HParams for Ising_Ferromagnetic_H1:   1%|          | 1/150 [00:03<07:46,  3.13s/it][AGPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_

Calculating normalization stats for IsingTarget...



Stat Computation:   0%|          | 0/16 [00:00<?, ?it/s][A
                                                        [A

Stats: Mean=-0.1100, Std=8.6845



HParams for Ising_Antiferromagnetic:   0%|          | 0/150 [00:00<?, ?it/s][AGPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
`Trainer.fit` stopped: `max_epochs=20` reached.
Restoring states from the checkpoint path at F:\rebuttal_neurips\resnet\CNN\extract_champions\tanh\gibbs\spsa_correct\benchmark_results\few_data\checkpoints\epoch=15-step=800-v149.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at F:\rebuttal_neurips\resnet\CNN\extract_champions\tanh\gibbs\spsa_correct\benchmark_results\few_data\checkpoints\epoch=15-step=800-v149.ckpt

HParams for Ising_Antiferromagnetic:   1%|          | 1/150 [00:03<07:50,  3.16s/it][AGPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBL

Calculating normalization stats for XYTarget...



Stat Computation:   0%|          | 0/16 [00:00<?, ?it/s][A
                                                        [A

Stats: Mean=-196.0367, Std=20.3032



HParams for XY_Model:   0%|          | 0/150 [00:00<?, ?it/s][AGPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
`Trainer.fit` stopped: `max_epochs=20` reached.
Restoring states from the checkpoint path at F:\rebuttal_neurips\resnet\CNN\extract_champions\tanh\gibbs\spsa_correct\benchmark_results\few_data\checkpoints\epoch=18-step=950-v640.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at F:\rebuttal_neurips\resnet\CNN\extract_champions\tanh\gibbs\spsa_correct\benchmark_results\few_data\checkpoints\epoch=18-step=950-v640.ckpt

HParams for XY_Model:   1%|          | 1/150 [00:03<07:40,  3.09s/it][AGPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
`Trainer.fit` s

Rapport CSV avec tous les runs sauvegardé dans 'benchmark_results\all_runs_PhysicsModels_deterministic_Normalized.csv'
Rapport CSV avec les champions de chaque famille sauvegardé dans 'benchmark_results\champions_PhysicsModels_deterministic_Normalized.csv'

--- Summary for Case: PHYSICSMODELS_DETERMINISTIC_NORMALIZED ---
--- Champion vs. Champion Analysis (Mean over all problems) ---
Family               | Norm MSE Mean   | Norm MSE Std    | Cos Sim Mean    | Cos Sim Std
-----------------------------------------------------------------------------------------------
SPSA                 | 0.1375           | 0.0502           | 0.9968          | 0.0005
Mixed                | 0.8790           | 0.1175           | 0.9776          | 0.0085
WGM-C                | 0.8937           | 0.2958           | 0.9785          | 0.0064
WGM-NC               | 0.9272           | 0.3006           | 0.9781          | 0.0035
Surrogate            | 1.0000           | 0.0000           | 0.9748          | 0.007

Problèmes (PhysicsModels_stochastic_Normalized):   0%|          | 0/4 [00:00<?, ?it/s]

Calculating normalization stats for IsingTarget...



Stat Computation:   0%|          | 0/16 [00:00<?, ?it/s][A
Stat Computation:  81%|████████▏ | 13/16 [00:00<00:00, 66.74it/s][A
                                                                 [A

Stats: Mean=-0.0091, Std=8.6187



HParams for Ising_Ferromagnetic_H0:   0%|          | 0/150 [00:00<?, ?it/s][AGPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(
`Trainer.fit` stopped: `max_epochs=20` reached.
Restoring states from the checkpoint path at F:\rebuttal_neurips\resnet\CNN\extract_champions\tanh\gibbs\spsa_correct\benchmark_results\few_data\checkpoints\epoch=9-step=500-v106.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at F:\rebuttal_neurips\resnet\CNN\extract_champions\tanh\gibbs\spsa_correct\benchmark_results\few_data\checkpoints\epoch=9-step=500-v106.ckpt

HParams for Ising_Ferromagnetic_H0:   1%|          | 1/150 [00:03<07:30,  3.03s/it][AGPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available

Calculating normalization stats for IsingTarget...



Stat Computation:   0%|          | 0/16 [00:00<?, ?it/s][A
                                                        [A

Stats: Mean=-0.1310, Std=13.3570



HParams for Ising_Ferromagnetic_H1:   0%|          | 0/150 [00:00<?, ?it/s][AGPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
`Trainer.fit` stopped: `max_epochs=20` reached.
Restoring states from the checkpoint path at F:\rebuttal_neurips\resnet\CNN\extract_champions\tanh\gibbs\spsa_correct\benchmark_results\few_data\checkpoints\epoch=1-step=100-v93.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at F:\rebuttal_neurips\resnet\CNN\extract_champions\tanh\gibbs\spsa_correct\benchmark_results\few_data\checkpoints\epoch=1-step=100-v93.ckpt

HParams for Ising_Ferromagnetic_H1:   1%|          | 1/150 [00:02<07:26,  2.99s/it][AGPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVI

Calculating normalization stats for IsingTarget...



Stat Computation:   0%|          | 0/16 [00:00<?, ?it/s][A
                                                        [A

Stats: Mean=-0.1100, Std=8.6845



HParams for Ising_Antiferromagnetic:   0%|          | 0/150 [00:00<?, ?it/s][AGPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
`Trainer.fit` stopped: `max_epochs=20` reached.
Restoring states from the checkpoint path at F:\rebuttal_neurips\resnet\CNN\extract_champions\tanh\gibbs\spsa_correct\benchmark_results\few_data\checkpoints\epoch=19-step=1000-v2617.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at F:\rebuttal_neurips\resnet\CNN\extract_champions\tanh\gibbs\spsa_correct\benchmark_results\few_data\checkpoints\epoch=19-step=1000-v2617.ckpt

HParams for Ising_Antiferromagnetic:   1%|          | 1/150 [00:03<08:09,  3.29s/it][AGPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VI

Calculating normalization stats for XYTarget...



Stat Computation:   0%|          | 0/16 [00:00<?, ?it/s][A
                                                        [A

Stats: Mean=-196.0367, Std=20.3032



HParams for XY_Model:   0%|          | 0/150 [00:00<?, ?it/s][AGPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
`Trainer.fit` stopped: `max_epochs=20` reached.
Restoring states from the checkpoint path at F:\rebuttal_neurips\resnet\CNN\extract_champions\tanh\gibbs\spsa_correct\benchmark_results\few_data\checkpoints\epoch=15-step=800-v170.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at F:\rebuttal_neurips\resnet\CNN\extract_champions\tanh\gibbs\spsa_correct\benchmark_results\few_data\checkpoints\epoch=15-step=800-v170.ckpt

HParams for XY_Model:   1%|          | 1/150 [00:03<08:11,  3.30s/it][AGPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
`Trainer.fit` s

Rapport CSV avec tous les runs sauvegardé dans 'benchmark_results\all_runs_PhysicsModels_stochastic_Normalized.csv'
Rapport CSV avec les champions de chaque famille sauvegardé dans 'benchmark_results\champions_PhysicsModels_stochastic_Normalized.csv'

--- Summary for Case: PHYSICSMODELS_STOCHASTIC_NORMALIZED ---
--- Champion vs. Champion Analysis (Mean over all problems) ---
Family               | Norm MSE Mean   | Norm MSE Std    | Cos Sim Mean    | Cos Sim Std
-----------------------------------------------------------------------------------------------
WGM-C                | 0.7841           | 0.2139           | 0.9784          | 0.0042
Mixed                | 0.8008           | 0.0996           | 0.9772          | 0.0048
WGM-NC               | 0.8036           | 0.3258           | 0.9782          | 0.0047
Surrogate            | 1.0000           | 0.0000           | 0.9708          | 0.0085
SPSA                 | 6.7603           | 2.0911           | 0.8024          | 0.0093

Résumé




In [None]:
# main_physics_benchmark_final_normalized.py

import torch
import torch.nn as nn
import torch.optim as optim
from tqdm.auto import tqdm
import numpy as np
import random
import pandas as pd
import warnings
import torch.nn.functional as F
from pathlib import Path
from torch.utils.data import DataLoader, IterableDataset, TensorDataset

# Tenter d'importer PyTorch Lightning
try:
    import pytorch_lightning as pl
    from pytorch_lightning.callbacks import ModelCheckpoint
except ImportError:
    print("Pytorch Lightning n'est pas installé. Veuillez l'installer avec : pip install pytorch-lightning")
    pl = None

# --- General configuration ---
if pl is None:
    raise ImportError("Pytorch Lightning est requis pour exécuter ce script.")

# Seed every RNG used by the script (torch, numpy, stdlib random).
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Performance setting for recent GPUs
if torch.cuda.is_available():
    torch.set_float32_matmul_precision('medium')

# Ignore non-critical warnings
warnings.filterwarnings("ignore", ".*does not have many workers.*")
warnings.filterwarnings("ignore", ".*Checkpoint directory.*exists and is not empty.*")
warnings.filterwarnings("ignore", ".*Using a target size.*")

# --- Benchmark hyperparameters ---
IMG_SIZE, IMG_CHANNELS = 16, 1
BATCH_SIZE, MAX_EPOCHS = 128, 20

# Hyperparameter search grid
LEARNING_RATES = [10,5,1,5e-1,1e-1,5e-2, 1e-2, 5e-3, 1e-3, 1e-4]
N_HUTCHINSON_SAMPLES_LIST, MIXED_ALPHAS = [1, 5, 10], [0.1, 0.01]
SPSA_EPSILONS = [1e-1, 1e-3]
ACTIVATIONS = {"GELU": nn.GELU}
# "auto" is resolved to a concrete noise level by the main loop before use.
J_NOISE_CONFIGS = [{"name": "deterministic", "noise_std": 0.0}, {"name": "stochastic", "noise_std": "auto"}]
N_SAMPLES_FOR_STATS = 4096 # Number of samples used to compute the mean and std

# --- Fonctions Cibles Basées sur la Physique ---

class PhysicsTarget(nn.Module):
    """Base class for differentiable energy functions J(x).

    Subclasses implement ``forward`` (one energy value per input sample).
    ``gradient`` returns the autograd gradient of the summed energies with
    respect to the input.

    Args:
        img_size: Side length of the square lattice; stored as an attribute.
        noise_std: Noise level; stored as an attribute for subclasses/wrappers
            that add observation noise in training mode.
    """

    def __init__(self, img_size=16, noise_std=0.0):
        super().__init__()
        self.img_size = img_size
        self.noise_std = noise_std
        # Freeze whatever parameters are already registered at this point.
        for p in self.parameters():
            p.requires_grad = False

    def forward(self, x):
        raise NotImplementedError

    def gradient(self, x):
        """Return d(sum J)/dx for ``x``, evaluated with ``self.training`` off.

        The input is detached first, so the caller's tensor and graph are left
        untouched. Only this module's own ``training`` flag is toggled (not
        its children's), and it is restored even if ``forward`` raises.
        """
        x_req = x.detach().requires_grad_(True)
        original_training_state = self.training
        self.training = False  # Always compute the gradient without noise
        try:
            # enable_grad makes this usable from inside a no_grad context.
            with torch.enable_grad():
                j_sum = self.forward(x_req).sum()
                grad = torch.autograd.grad(j_sum, x_req)[0]
        finally:
            self.training = original_training_state
        return grad

class NormalizedTarget(PhysicsTarget):
    """
    Wrapper that centres and rescales a target function J(x) using global
    statistics computed once, at construction time, on standard-normal inputs.

    In training mode, Gaussian noise of standard deviation ``noise_std`` is
    added to the normalized output when ``noise_std`` is positive.
    """
    def __init__(self, base_target: PhysicsTarget, n_samples_for_stats=N_SAMPLES_FOR_STATS):
        super().__init__(img_size=base_target.img_size, noise_std=base_target.noise_std)
        self.base_target = base_target

        print(f"Calculating normalization stats for {base_target.__class__.__name__}...")
        with torch.no_grad():
            # Evaluate the wrapped target in eval mode, then put it back.
            was_training = self.base_target.training
            self.base_target.train(False)

            probe_inputs = torch.randn(n_samples_for_stats, IMG_CHANNELS, self.img_size, self.img_size)
            probe_loader = DataLoader(TensorDataset(probe_inputs), batch_size=BATCH_SIZE*2)

            energies = torch.cat([
                self.base_target(inputs.to(device))
                for (inputs,) in tqdm(probe_loader, desc="Stat Computation", leave=False)
            ])
            mean = energies.mean().item()
            std = energies.std().item()
            self.base_target.train(was_training)

        self.register_buffer('mean', torch.tensor(mean, device=device))
        # Small offset guards the division in forward against a zero std.
        self.register_buffer('std', torch.tensor(std, device=device) + 1e-8)

        print(f"Stats: Mean={self.mean.item():.4f}, Std={self.std.item():.4f}")

    def forward(self, x):
        """Return (J(x) - mean) / std, plus noise when training with noise_std > 0."""
        standardized = (self.base_target(x) - self.mean) / self.std

        if not (self.noise_std > 0 and self.training):
            return standardized

        return standardized + torch.randn_like(standardized) * self.noise_std

class IsingTarget(PhysicsTarget):
    """Ising-like energy on a 2D lattice with soft spins ``tanh(x)``.

    ``J`` scales the nearest-neighbour interaction and ``h`` the external
    field term. The neighbour kernel has a single input and output channel.
    """

    def __init__(self, img_size=16, J=1.0, h=0.0, noise_std=0.0):
        super().__init__(img_size, noise_std)
        self.J = J
        self.h = h
        # Cross-shaped stencil: sums the four nearest neighbours of each site.
        cross = [[0, 1, 0], [1, 0, 1], [0, 1, 0]]
        self.register_buffer(
            'neighbor_kernel',
            torch.tensor(cross, dtype=torch.float32).view(1, 1, 3, 3),
        )

    def forward(self, x):
        """Return one total energy per sample (summed over dims 1, 2, 3)."""
        spins = torch.tanh(x)
        stencil = self.neighbor_kernel.to(x.device)
        neighbor_sum = F.conv2d(spins, stencil, padding='same')
        # Every neighbouring pair appears twice in the site-wise sum, hence /2.
        pair_part = (-self.J * (spins * neighbor_sum)).sum(dim=(1, 2, 3)) / 2
        field_part = (-self.h * spins).sum(dim=(1, 2, 3))
        # Note: noise is handled by the NormalizedTarget wrapper (if used)
        return pair_part + field_part

class XYTarget(PhysicsTarget):
    """XY-like energy on a 2D lattice where the input holds the angles.

    The energy is ``-J * (cos(right difference) + cos(down difference))``
    summed over the lattice. Differences are taken with zero padding, so at
    the last column/row the missing neighbour is treated as angle 0.
    """

    def __init__(self, img_size=16, J=1.0, noise_std=0.0):
        super().__init__(img_size, noise_std)
        self.J = J
        # Finite-difference stencils: neighbour minus centre.
        stencils = {
            'kernel_right': [[0, 0, 0], [0, -1, 1], [0, 0, 0]],
            'kernel_down': [[0, 0, 0], [0, -1, 0], [0, 1, 0]],
        }
        for buffer_name, weights in stencils.items():
            self.register_buffer(
                buffer_name,
                torch.tensor(weights, dtype=torch.float32).view(1, 1, 3, 3),
            )

    def forward(self, x):
        """Return one total energy per sample (summed over dims 1, 2, 3)."""
        target_device = x.device
        delta_right = F.conv2d(x, self.kernel_right.to(target_device), padding='same')
        delta_down = F.conv2d(x, self.kernel_down.to(target_device), padding='same')
        bond_energy = -self.J * (torch.cos(delta_right) + torch.cos(delta_down))
        return bond_energy.sum(dim=(1, 2, 3))

# --- Modèles Estimateurs ---
class Estimator_Potential_2D(nn.Module):
    """Convolutional network producing one scalar potential per input image.

    ``image_size`` is accepted for signature compatibility but is not used:
    the global average pooling makes the network independent of input size.
    """

    def __init__(self, in_channels, image_size, activation_fn=nn.GELU):
        super().__init__()
        channel_plan = [in_channels, 16, 32, 16]
        stack = []
        for c_in, c_out in zip(channel_plan[:-1], channel_plan[1:]):
            stack.append(nn.Conv2d(c_in, c_out, kernel_size=3, padding='same'))
            stack.append(activation_fn())
        # Collapse the spatial dimensions, then map the 16 features to 1 value.
        stack.append(nn.AdaptiveAvgPool2d((1, 1)))
        stack.append(nn.Conv2d(16, 1, kernel_size=1))
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Return the potential with the two trailing 1x1 spatial dims removed."""
        pooled = self.net(x)
        return pooled[..., 0, 0]

    def gradient(self, x):
        """Return the gradient of the summed potential with respect to ``x``."""
        probe = x.detach().requires_grad_(True)
        with torch.enable_grad():
            total_potential = self.forward(probe).sum()
            (grad,) = torch.autograd.grad(total_potential, probe)
        return grad

class Estimator_Direct_2D(nn.Module):
    """Fully convolutional network mapping an image to a same-size field.

    All convolutions use 3x3 kernels with 'same' padding, so the spatial
    size of the output equals that of the input.
    """

    def __init__(self, in_channels, out_channels, activation_fn=nn.GELU):
        super().__init__()
        hidden_plan = [in_channels, 16, 32, 16]
        stack = []
        for c_in, c_out in zip(hidden_plan[:-1], hidden_plan[1:]):
            stack.append(nn.Conv2d(c_in, c_out, kernel_size=3, padding='same'))
            stack.append(activation_fn())
        # Output layer has no activation.
        stack.append(nn.Conv2d(hidden_plan[-1], out_channels, kernel_size=3, padding='same'))
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the convolutional stack to ``x``."""
        return self.net(x)

# --- Dataset et Module PyTorch Lightning ---
class SampleDatasetCNN_2D(IterableDataset):
    """Iterable dataset that yields freshly sampled Gaussian batches.

    Each iteration produces ``steps`` tensors of shape
    ``(batch_size, channels, size, size)`` drawn from a normal distribution
    scaled by ``std``, created directly on the module-level ``device``.
    """

    def __init__(self, channels, size, steps, batch_size, std=1.0):
        self.channels = channels
        self.size = size
        self.steps = steps
        self.batch_size = batch_size
        self.std = std

    def __iter__(self):
        for _ in range(self.steps):
            batch_shape = (self.batch_size, self.channels, self.size, self.size)
            yield torch.randn(batch_shape, device=device) * self.std
            
class GradEstPLModule_2D(pl.LightningModule):
    """Lightning module that trains an estimator of the gradient of a target J(x).

    The training objective is selected by ``hparams.method``:

    * ``'surrogate'``: regress the estimator output onto the target values.
    * ``'wgm'``: the loss returned by ``wgm_loss``.
    * ``'mixed'``: surrogate loss plus ``hparams.mixed_alpha`` times the WGM loss.
    * ``'spsa'``: regress a predicted directional derivative onto a central
      finite difference of the target along a random unit direction.

    Other hyperparameters read by this class: ``is_conservative``,
    ``wgm_n_samples``, ``wgm_noise_std``, ``spsa_epsilon`` and ``lr``.

    Args:
        target_cnn: Target module; called for values and must expose
            ``gradient(x)`` for evaluation. Its parameters are frozen.
        estimator_cnn: Trainable estimator. When ``is_conservative`` is true it
            is treated as a potential (and must expose ``gradient(x)``),
            otherwise its output is used directly as the gradient field.
        hparams: Hyperparameters, stored via ``save_hyperparameters``.
    """

    def __init__(self, target_cnn, estimator_cnn, hparams):
        super().__init__()
        self.save_hyperparameters(hparams)
        self.target_cnn, self.estimator_cnn = target_cnn, estimator_cnn
        for p in self.target_cnn.parameters():
            p.requires_grad = False

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for the configured method.

        Returns ``None`` (step skipped) for SPSA when the half-batch is empty.

        Raises:
            ValueError: If ``hparams.method`` is not a supported method. The
                previous fall-through produced a constant loss without an
                autograd graph, which cannot be optimised.
        """
        # Training mode: the target may add noise to its outputs.
        self.target_cnn.train()
        method = self.hparams.method

        if method == 'surrogate':
            j_hat = self.estimator_cnn(batch).squeeze()
            with torch.no_grad(): j_true = self.target_cnn(batch).squeeze()
            loss = F.mse_loss(j_hat, j_true)
        elif method == 'wgm':
            loss = self.wgm_loss(batch)
        elif method == 'mixed':
            j_hat = self.estimator_cnn(batch).squeeze()
            with torch.no_grad(): j_true = self.target_cnn(batch).squeeze()
            loss_surrogate = F.mse_loss(j_hat, j_true)
            loss_wgm = self.wgm_loss(batch)
            loss = loss_surrogate + self.hparams.mixed_alpha * loss_wgm
        elif method == 'spsa':
            # Only the first half of the batch is used; presumably because each
            # sample costs two target evaluations (x + eps*v and x - eps*v).
            x = batch[:batch.size(0) // 2]
            if x.shape[0] == 0: return None
            v_theta = self.estimator_cnn(x)
            with torch.no_grad():
                eps = self.hparams.spsa_epsilon
                # Random direction, normalised to unit L2 norm per sample.
                v = torch.randn_like(x)
                v_norm = torch.norm(v.flatten(1), p=2, dim=1).view(-1, 1, 1, 1) + 1e-8
                v = v / v_norm
                j_plus = self.target_cnn(x + eps * v)
                j_minus = self.target_cnn(x - eps * v)
                target_directional_derivative = (j_plus - j_minus) / (2 * eps)
            predicted_directional_derivative = (v_theta * v).sum(dim=(1,2,3))
            loss = F.mse_loss(predicted_directional_derivative, target_directional_derivative)
        else:
            raise ValueError(f"Unknown training method: {method!r}")

        self.log('train_loss', loss, prog_bar=False, logger=True)
        return loss

    def wgm_loss(self, x):
        """Return the WGM loss for batch ``x``.

        Every sample is repeated ``wgm_n_samples`` times, each copy with its
        own random probe vector. The loss is the mean of
        ``|s|^2 + 2 * (J - c) * (div_estimate + s . (-x / wgm_noise_std**2))``
        where ``s`` is the estimated field, ``div_estimate = v . d(s . v)/dx``
        is a random-probe estimate of its divergence, and ``c`` is the batch
        mean of the target values.
        """
        M = self.hparams.wgm_n_samples
        x_rep = x.detach().repeat_interleave(M, dim=0).requires_grad_(True)
        if self.hparams.is_conservative:
            # Field is the input-gradient of the scalar potential; keep the
            # graph so the loss can be differentiated through it.
            potential = self.estimator_cnn(x_rep)
            s_theta = torch.autograd.grad(potential.sum(), x_rep, create_graph=True)[0]
        else:
            s_theta = self.estimator_cnn(x_rep)
        v = torch.randn_like(s_theta)
        s_dot_v_sum = (s_theta * v).sum()
        grad_s_dot_v = torch.autograd.grad(s_dot_v_sum, x_rep, create_graph=True)[0]
        div_terms = (v * grad_s_dot_v).sum(dim=(1, 2, 3))
        # -x / sigma^2 is the score of a zero-mean Gaussian with std sigma.
        grad_log_p_terms = (s_theta * (-x_rep / self.hparams.wgm_noise_std**2)).sum(dim=(1,2,3))
        h_all = div_terms + grad_log_p_terms
        with torch.no_grad():
            j_val = self.target_cnn(x).squeeze()
            j_val_rep = j_val.repeat_interleave(M, dim=0)
            c_opt = j_val_rep.mean()  # Per-batch centring to reduce variance
        term1_all = s_theta.pow(2).sum(dim=(1, 2, 3))
        term2_all = 2 * (j_val_rep - c_opt) * h_all
        loss = (term1_all + term2_all).mean()
        return loss

    def _compute_eval_metrics(self, batch):
        """Return MSE and mean cosine similarity between estimated and true gradients."""
        self.target_cnn.eval()
        with torch.no_grad():
            grad_j_true = self.target_cnn.gradient(batch)
        # Gradients of a potential need autograd even during evaluation.
        with torch.enable_grad():
            if self.hparams.method == 'spsa' or not self.hparams.is_conservative:
                estimated_grad = self.estimator_cnn(batch)
            else:
                estimated_grad = self.estimator_cnn.gradient(batch)
        mse = F.mse_loss(estimated_grad.detach(), grad_j_true.detach())
        cos_sim = F.cosine_similarity(estimated_grad.detach().flatten(1), grad_j_true.detach().flatten(1)).mean()
        return {'mse': mse, 'cos_sim': cos_sim}

    def validation_step(self, batch, batch_idx):
        """Log ``val_mse`` and ``val_cos_sim`` for one validation batch."""
        metrics = self._compute_eval_metrics(batch)
        self.log('val_mse', metrics['mse'], prog_bar=True, logger=True)
        self.log('val_cos_sim', metrics['cos_sim'], prog_bar=False, logger=True)
        return metrics

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        """Log ``test_mse`` and ``test_cos_sim`` for one test batch."""
        metrics = self._compute_eval_metrics(batch)
        self.log('test_mse', metrics['mse'])
        self.log('test_cos_sim', metrics['cos_sim'])
        return metrics

    def configure_optimizers(self):
        """Adam over the estimator's parameters only, with ``hparams.lr``."""
        return optim.Adam(self.estimator_cnn.parameters(), lr=self.hparams.lr)

# --- Fonctions d'Exécution et de Logging ---
def run_single_training(target_cnn, hparams):
    """Train one estimator configuration and return its final test metrics.

    Seeds torch with ``hparams['seed']``, builds the estimator named by
    ``hparams['estimator_class']``, fits it for ``MAX_EPOCHS`` while keeping
    the checkpoint with the lowest ``val_mse``, then tests that checkpoint.

    Returns:
        dict with ``final_mse`` and ``final_cos_sim`` (NaN for both when the
        test run yields no results).
    """
    torch.manual_seed(hparams['seed'])

    estimator_cls = hparams["estimator_class"]
    # The potential estimator takes the image size as second argument, the
    # direct estimator takes its number of output channels.
    second_arg = IMG_SIZE if estimator_cls == Estimator_Potential_2D else IMG_CHANNELS
    activation = ACTIVATIONS[hparams["activation_name"]]
    estimator = estimator_cls(IMG_CHANNELS, second_arg, activation_fn=activation).to(device)

    module = GradEstPLModule_2D(target_cnn, estimator, hparams)
    best_ckpt = ModelCheckpoint(monitor="val_mse", mode="min", save_top_k=1)
    trainer = pl.Trainer(
        max_epochs=MAX_EPOCHS,
        accelerator="auto",
        devices=1,
        logger=False,
        callbacks=[best_ckpt],
        enable_model_summary=False,
        num_sanity_val_steps=0,
        enable_progress_bar=False,
        inference_mode=False,
    )

    def make_loader(n_steps):
        # The dataset already yields full batches, hence batch_size=None.
        dataset = SampleDatasetCNN_2D(IMG_CHANNELS, IMG_SIZE, n_steps, BATCH_SIZE)
        return DataLoader(dataset, batch_size=None)

    train_loader = make_loader(50)
    val_loader = make_loader(5)

    trainer.fit(model=module, train_dataloaders=train_loader, val_dataloaders=val_loader)
    test_results = trainer.test(ckpt_path="best", dataloaders=val_loader, verbose=False)

    if not test_results:
        return {'final_mse': float('nan'), 'final_cos_sim': float('nan')}

    first = test_results[0]
    return {'final_mse': first['test_mse'], 'final_cos_sim': first['test_cos_sim']}

def _family_of(config_name):
    """Return the method family of a run name, e.g. 'WGM-C(M=5) lr=0.1' -> 'WGM-C'."""
    return config_name.split(' ')[0].split('(')[0]


def log_and_save_results(case_name, all_raw_results, all_hparams_info, problem_names):
    """Write per-run and per-family ("champion") reports for one benchmark case.

    Files written under ``benchmark_results/``:

    * ``all_runs_<case>.csv``: final MSE and cosine similarity of every run.
    * ``champions_<case>.csv``: for each problem and method family, the run
      with the lowest MSE and that same run's cosine similarity.
    * ``summary_<case>.txt``: champion MSE normalised by the 'Surrogate'
      champion, aggregated over problems, plus the per-problem tables.

    Args:
        case_name: Label used in file names and in the summary header.
        all_raw_results: ``{'mse': {run_name: [...]}, 'cos_sim': {...}}`` with
            one value per problem, in ``problem_names`` order.
        all_hparams_info: Mapping whose keys are the run names; only the keys
            are used, to derive the method families.
        problem_names: Problem labels, used as the row index.

    Raises:
        KeyError: If there is no 'Surrogate' family to use as the baseline.
    """
    output_dir = Path("benchmark_results")
    output_dir.mkdir(parents=True, exist_ok=True)
    txt_file = output_dir / f"summary_{case_name}.txt"
    raw_csv_file = output_dir / f"all_runs_{case_name}.csv"
    champions_csv_file = output_dir / f"champions_{case_name}.csv"

    # --- Raw results: one row per (problem, run) ---
    df_raw_mse = pd.DataFrame(all_raw_results['mse'], index=problem_names)
    df_raw_cos_sim = pd.DataFrame(all_raw_results['cos_sim'], index=problem_names)
    df_to_save_raw = df_raw_mse.stack().reset_index()
    df_to_save_raw.columns = ['problem_name', 'method_name', 'final_mse']
    df_cos_sim_flat = df_raw_cos_sim.stack().reset_index()
    df_cos_sim_flat.columns = ['problem_name', 'method_name', 'final_cos_sim']
    df_to_save_raw = pd.merge(df_to_save_raw, df_cos_sim_flat, on=['problem_name', 'method_name'])
    df_to_save_raw.to_csv(raw_csv_file, index=False)
    print(f"Rapport CSV avec tous les runs sauvegardé dans '{raw_csv_file}'")

    # --- Champions: best (lowest-MSE) run of each family, per problem ---
    families = sorted({_family_of(name) for name in all_hparams_info})
    champion_mse_scores = {fam: [] for fam in families}
    champion_cos_sim_scores = {fam: [] for fam in families}
    for problem_name in problem_names:
        # Runs whose metric is NaN are not eligible.
        problem_mse = df_raw_mse.loc[problem_name].dropna()
        problem_cos_sim = df_raw_cos_sim.loc[problem_name].dropna()
        for family in families:
            # Exact family match: a prefix test would let a family such as
            # "WGM" also claim the runs of "WGM-C".
            family_cols_mse = [col for col in problem_mse.index if _family_of(col) == family]
            if not family_cols_mse:
                champion_mse_scores[family].append(np.nan)
                champion_cos_sim_scores[family].append(np.nan)
                continue
            best_config_mse = problem_mse[family_cols_mse].idxmin()
            champion_mse_scores[family].append(problem_mse[best_config_mse])
            if best_config_mse in problem_cos_sim.index:
                champion_cos_sim_scores[family].append(problem_cos_sim[best_config_mse])
            else:
                champion_cos_sim_scores[family].append(np.nan)

    df_champions_mse = pd.DataFrame(champion_mse_scores, index=problem_names)
    df_champions_cos_sim = pd.DataFrame(champion_cos_sim_scores, index=problem_names)
    df_champions_to_save = df_champions_mse.stack().reset_index()
    df_champions_to_save.columns = ['problem_name', 'family_name', 'champion_mse']
    df_champions_cos_sim_flat = df_champions_cos_sim.stack().reset_index()
    df_champions_cos_sim_flat.columns = ['problem_name', 'family_name', 'champion_cos_sim']
    df_champions_to_save = pd.merge(df_champions_to_save, df_champions_cos_sim_flat, on=['problem_name', 'family_name'])
    df_champions_to_save.to_csv(champions_csv_file, index=False)
    print(f"Rapport CSV avec les champions de chaque famille sauvegardé dans '{champions_csv_file}'")

    # --- Summary: champion MSE relative to the Surrogate champion ---
    baseline = df_champions_mse['Surrogate'] + 1e-9  # offset avoids division by zero
    df_champions_norm = df_champions_mse.div(baseline, axis=0)
    champion_stats = {
        'norm_mse_mean': df_champions_norm.mean(),
        'norm_mse_std': df_champions_norm.std(),
        'cos_sim_mean': df_champions_cos_sim.mean(),
        'cos_sim_std': df_champions_cos_sim.std(),
    }
    df_champion_stats = pd.DataFrame(champion_stats).sort_values('norm_mse_mean')

    header = f"--- Summary for Case: {case_name.upper()} ---\n"
    lines = [
        "--- Champion vs. Champion Analysis (Mean over all problems) ---\n",
        f"{'Family':<20} | {'Norm MSE Mean':<15} | {'Norm MSE Std':<15} | {'Cos Sim Mean':<15} | {'Cos Sim Std'}\n",
        ("-" * 95) + "\n",
    ]
    for family, stats in df_champion_stats.iterrows():
        # Same column widths as the header row so the table lines up.
        lines.append(
            f"{family:<20} | {stats['norm_mse_mean']:<15.4f} | {stats['norm_mse_std']:<15.4f} | "
            f"{stats['cos_sim_mean']:<15.4f} | {stats['cos_sim_std']:.4f}\n"
        )
    body = "".join(lines)

    summary_content = header + body
    summary_content += "\n\n--- Champion MSE per Problem ---\n" + df_champions_mse.to_string()
    summary_content += "\n\n--- Champion Cosine Similarity per Problem ---\n" + df_champions_cos_sim.to_string()
    with open(txt_file, "w", encoding="utf-8") as f:
        f.write(summary_content)
    print("\n" + header + body)
    print(f"Résumé détaillé sauvegardé dans '{txt_file}'")

# --- Point d'Entrée Principal ---
if __name__ == "__main__":
    # Build the method configurations. Insertion order matters: it determines
    # the order in which runs are executed and the column order of the results.
    base_methods = { "Surrogate": {"method": "surrogate", "is_conservative": True, "estimator_class": Estimator_Potential_2D} }
    for M in N_HUTCHINSON_SAMPLES_LIST:
        base_methods[f"WGM-NC(M={M})"] = {"method": "wgm", "is_conservative": False, "estimator_class": Estimator_Direct_2D, "wgm_n_samples": M}
        base_methods[f"WGM-C(M={M})"] = {"method": "wgm", "is_conservative": True, "estimator_class": Estimator_Potential_2D, "wgm_n_samples": M}
        for alpha in MIXED_ALPHAS:
            base_methods[f"Mixed(a={alpha},M={M})"] = {"method": "mixed", "is_conservative": True, "estimator_class": Estimator_Potential_2D, "mixed_alpha": alpha, "wgm_n_samples": M}
    for eps in SPSA_EPSILONS:
        base_methods[f"SPSA(e={eps})"] = {"method": "spsa", "is_conservative": False, "estimator_class": Estimator_Direct_2D, "spsa_epsilon": eps}
    # Cross every method configuration with every learning rate.
    all_hparams_info = {}
    for name, config in base_methods.items():
        for lr in LEARNING_RATES:
            config_name = f"{name} lr={lr}"
            # WGM and Mixed runs use the grid learning rate divided by the
            # number of pixels; the other methods use it unchanged.
            effective_lr = lr / (IMG_SIZE * IMG_SIZE) if config['method'] in ['wgm', 'mixed'] else lr
            hparams = config.copy()
            hparams.update({"base_lr": lr, "lr": effective_lr, "activation_name": "GELU", "wgm_noise_std": 1.0})
            all_hparams_info[config_name] = hparams

    # Target energies to benchmark: target class plus its constructor arguments.
    TARGET_PROBLEMS = {
        "Ising_Ferromagnetic_H0": {"class": IsingTarget, "params": {"J": 1.0, "h": 0.0}},
        "Ising_Ferromagnetic_H1": {"class": IsingTarget, "params": {"J": 1.0, "h": 1.0}},
        "Ising_Antiferromagnetic": {"class": IsingTarget, "params": {"J": -1.0, "h": 0.0}},
        "XY_Model": {"class": XYTarget, "params": {"J": 1.0}}
    }
    
    # One full benchmark (all problems x all configurations) per noise setting.
    for noise_config in J_NOISE_CONFIGS:
        noise_name = noise_config["name"]
        case_name = f"PhysicsModels_{noise_name}_Normalized"
        # Per configuration, one result per problem, appended in problem order.
        all_raw_results = {'mse': {name: [] for name in all_hparams_info.keys()}, 'cos_sim': {name: [] for name in all_hparams_info.keys()}}
        problem_names = list(TARGET_PROBLEMS.keys())
        
        pbar_problems = tqdm(enumerate(TARGET_PROBLEMS.items()), total=len(TARGET_PROBLEMS), desc=f"Problèmes ({case_name})")
        for problem_idx, (problem_name, config) in pbar_problems:
            TargetClass = config["class"]
            target_params = config["params"]
            
            # The base target is built noise-free; normalization statistics are
            # computed inside the NormalizedTarget constructor.
            base_target_with_noise_param = TargetClass(img_size=IMG_SIZE, **target_params, noise_std=0.0).to(device)
            normalized_target = NormalizedTarget(base_target_with_noise_param).to(device)
            
            # "auto" is resolved to a fixed noise level of 0.1.
            current_noise_std = noise_config["noise_std"]
            if current_noise_std == "auto":
                current_noise_std = 0.1
            
            final_target = normalized_target
            # The noise is handled by the wrapper; it only needs the parameter.
            final_target.noise_std = current_noise_std
            
            pbar_hparams = tqdm(all_hparams_info.items(), desc=f"HParams for {problem_name}", leave=False)
            for config_name, hparams in pbar_hparams:
                run_hparams = hparams.copy()
                run_hparams['problem_idx'] = problem_idx
                # Seed derived from the global seed, the problem index, the
                # base learning rate and the characters of the run name.
                run_hparams['seed'] = SEED * 1000 + problem_idx * 100 + int(hparams['base_lr'] * 1e5) + sum(ord(c) for c in config_name)
                run_hparams['method_name'] = config_name
                
                final_metrics = run_single_training(final_target, run_hparams)
                
                all_raw_results['mse'][config_name].append(final_metrics['final_mse'])
                all_raw_results['cos_sim'][config_name].append(final_metrics['final_cos_sim'])
                
        # Reports are written once per noise setting, after all problems ran.
        log_and_save_results(case_name, all_raw_results, all_hparams_info, problem_names)
    print("\nBenchmark complet terminé.")