In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import pytorch_lightning as pl
from torch.utils.data import DataLoader, IterableDataset, TensorDataset
import numpy as np
import random
import time
import matplotlib.pyplot as plt
import torch.nn.functional as F

# --- Configuration ---
# Experiment size
N_PROBLEMS = 10                 # number of independently drawn target networks
MAX_EPOCHS = 500                # training epochs per (problem, method) run
TRAIN_STEPS_PER_EPOCH = 100     # training batches generated per epoch
VAL_STEPS_PER_EPOCH = 1         # validation batches generated per epoch
N_TEST_SAMPLES = 2048           # size of the fixed test set drawn per problem
BATCH_SIZE = 4096               # fresh samples drawn per generated batch

# Estimator architecture and optimisation
INPUT_DIM = 250
HIDDEN_DIM = 256
N_LAYERS = 4
LEARNING_RATE = 1e-4

# Method-specific settings
WGM_NOISE_STD = 1                               # std of the Gaussian the inputs are sampled from
RFD_EPSILON = 1e-4                              # finite-difference step for the 'rfd_mse' method
WGM_LAMBDA = 1.0 / np.sqrt(float(INPUT_DIM))    # 1/sqrt(d); only recorded in the hyper-parameters

# Reproducibility: seed every random number generator used by the notebook.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")


import math
import torch
import torch.nn as nn

from pytorch_lightning.callbacks import Callback
# ...

# --- History callback ---
class HistoryCallback(Callback):
    """Callback that records one monitored validation metric per validation epoch."""

    def __init__(self, monitor='val_mse'):
        """
        Args:
            monitor: key looked up in ``trainer.callback_metrics`` after each
                validation epoch.
        """
        super().__init__()
        self.monitor = monitor
        # Values of the monitored metric collected during the current run.
        self.history = []

    def on_validation_epoch_end(self, trainer, pl_module):
        """Append the current value of the monitored metric, if it was logged."""
        metrics = trainer.callback_metrics
        if self.monitor not in metrics:
            # Nothing logged under that name for this epoch: skip silently.
            return
        self.history.append(metrics[self.monitor].item())

    def get_history(self):
        """Return the list of values collected so far (the live list, not a copy)."""
        return self.history

    def reset(self):
        """Start a fresh history for a new run (a new list is bound)."""
        self.history = []

class Sine(nn.Module):
    """
    Sine activation: y = sin(w0 * x).

    ``w0`` is a fixed (non-learned) frequency factor applied to the input
    before the sine; larger values give a more oscillatory response.
    """

    def __init__(self, w0=30.0):
        super().__init__()
        # Plain Python attribute, not a Parameter: it is never trained.
        self.w0 = w0

    def forward(self, x):
        """Return ``sin(w0 * x)`` element-wise."""
        scaled = x * self.w0
        return scaled.sin()


def siren_init(linear, fan_in, w0=30.0, is_first_layer=False):
    """
    Re-initialise the weight matrix of ``linear`` in place.

    Despite its name, this does NOT apply the SIREN scheme (uniform bounds
    scaled by ``fan_in`` and ``w0``). Every layer, first or not, gets weights
    drawn from a zero-mean normal distribution with standard deviation 0.3.
    The bias is left untouched.

    Args:
        linear: module exposing a ``weight`` tensor (e.g. ``nn.Linear``).
        fan_in: accepted for signature compatibility; currently ignored.
        w0: accepted for signature compatibility; currently ignored.
        is_first_layer: accepted for signature compatibility; currently
            ignored, since both layer kinds use the same distribution.

    Returns:
        None. ``linear.weight`` is modified in place.
    """
    weight_std = 0.3
    # no_grad: the in-place write must not be recorded by autograd.
    with torch.no_grad():
        linear.weight.normal_(mean=0.0, std=weight_std)


class GradientMLP(nn.Module):
    """
    Sine-activated MLP defining a scalar function J(x), plus a helper that
    returns dJ/dx even when the caller is inside a ``torch.no_grad()`` block.

    Layout of ``self.net``:
        Linear(input_dim, hidden_dim) -> Sine(w0)
        (n_layers - 1) x [Linear(hidden_dim, hidden_dim) -> Sine(w0_hidden)]
        Linear(hidden_dim, 1) -> Sine(w0_hidden)

    Note that the output layer IS followed by a sine, so J(x) lies in [-1, 1]
    before the optional envelope is applied.

    Args:
        input_dim: dimension of the input vectors.
        n_layers: number of sine-activated layers before the output layer
            (first layer included).
        hidden_dim: width of the hidden layers.
        envelope: if True, the output is multiplied by a radial envelope that
            equals 1 inside the ball of radius ``WGM_NOISE_STD * sqrt(input_dim)``
            and decays exponentially with the excess norm outside it.
        w0: frequency factor of the first sine activation.
        w0_hidden: frequency factor of all subsequent sine activations.
    """

    def __init__(self, input_dim, n_layers, hidden_dim, envelope=False,
                 w0=0.5, w0_hidden=0.5):
        super().__init__()
        self.envelope = envelope
        self.input_dim = input_dim
        layers = []
        # (1) First layer
        first_linear = nn.Linear(input_dim, hidden_dim)
        siren_init(first_linear, fan_in=input_dim, w0=w0, is_first_layer=True)
        layers.append(first_linear)
        layers.append(Sine(w0=w0))

        # (2) Hidden layers
        for _ in range(n_layers - 1):
            hidden_linear = nn.Linear(hidden_dim, hidden_dim)
            siren_init(hidden_linear, fan_in=hidden_dim, w0=w0_hidden, is_first_layer=False)
            layers.append(hidden_linear)
            layers.append(Sine(w0=w0_hidden))

        # (3) Output layer: weights ~ N(0, 0.3^2), followed by a final sine
        final_linear = nn.Linear(hidden_dim, 1)
        bound = 0.3
        final_linear.weight.data.normal_(mean=0.0, std=bound)
        layers.append(final_linear)
        layers.append(Sine(w0=w0_hidden))

        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """
        Evaluate J(x).

        Args:
            x: tensor whose dim 1 is the feature axis (the envelope reduces
                over ``dim=1``).

        Returns:
            Tensor with a trailing dimension of size 1 holding J(x).
        """
        scalar_output = self.net(x)
        if self.envelope:
            norm = x.pow(2).sum(dim=1, keepdim=True).sqrt()  # ||x||
            # Radius below which the envelope is exactly 1.
            r = 1*WGM_NOISE_STD * np.sqrt(self.input_dim)
            # Excess norm beyond r, rescaled by sqrt(d) to set the decay rate.
            alpha = torch.clamp(norm-r,min=0)/np.sqrt(self.input_dim)
            envelope = torch.exp(- alpha)
            scalar_output = scalar_output * envelope
        # The output is deterministic: no noise is added and no random
        # numbers are drawn, so single-sample batches stay finite.
        return scalar_output

    def gradient(self, x):
        """
        Return dJ/dx evaluated at ``x``.

        Autograd is force-enabled so this also works when the caller runs
        under ``torch.no_grad()`` (e.g. Lightning validation/test loops).
        The input is cloned and detached first, so the caller's tensor is
        never modified and no gradient flows back into it.

        Returns:
            Tensor of the same shape as ``x``. It carries a graph
            (``create_graph=True``) so higher-order derivatives remain possible.
        """
        with torch.enable_grad():
            x = x.clone().detach().requires_grad_(True)
            y_sum = self.forward(x).sum()
            grad = torch.autograd.grad(y_sum, x, create_graph=True)[0]

        return grad

class GradientMLP2(nn.Module):
    """
    Plain Tanh MLP producing a scalar output, plus a helper that returns the
    gradient of that output with respect to the input even when the caller is
    inside a ``torch.no_grad()`` block.

    Layout of ``self.net``:
        Linear(input_dim, hidden_dim) -> Tanh
        (n_layers - 1) x [Linear(hidden_dim, hidden_dim) -> Tanh]
        Linear(hidden_dim, 1)

    All layers keep PyTorch's default ``nn.Linear`` initialisation.

    Args:
        input_dim: dimension of the input vectors.
        n_layers: number of Tanh-activated layers before the output layer
            (first layer included).
        hidden_dim: width of the hidden layers.
        envelope: stored on the instance; the envelope factor is the constant
            1, so this flag has no effect on the output.
        w0: accepted for signature parity with ``GradientMLP``; unused.
        w0_hidden: accepted for signature parity with ``GradientMLP``; unused.
    """

    def __init__(self, input_dim, n_layers, hidden_dim, envelope=False,
                 w0=1.0, w0_hidden=0.5):
        super().__init__()
        self.envelope = envelope

        # (1) First layer
        layers = [nn.Linear(input_dim, hidden_dim), nn.Tanh()]

        # (2) Hidden layers
        for _ in range(n_layers - 1):
            layers.append(nn.Linear(hidden_dim, hidden_dim))
            layers.append(nn.Tanh())

        # (3) Output layer (linear, no activation)
        layers.append(nn.Linear(hidden_dim, 1))

        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the scalar network output, with a trailing dimension of size 1."""
        # The envelope factor is identically 1, so the network output is
        # returned as is whatever the value of ``self.envelope``.
        return self.net(x)

    def gradient(self, x):
        """
        Return the gradient of the summed output with respect to ``x``.

        Autograd is force-enabled so this also works when the caller runs
        under ``torch.no_grad()`` (e.g. Lightning validation/test loops).
        The input is cloned and detached first; the result keeps its graph
        (``create_graph=True``) so it can be differentiated with respect to
        the network parameters.
        """
        with torch.enable_grad():
            x = x.clone().detach().requires_grad_(True)
            y_sum = self.forward(x).sum()
            grad = torch.autograd.grad(y_sum, x, create_graph=True)[0]
        return grad

# --- Dataset ---
class SampleDataset(IterableDataset):
    """
    Endless-style dataset that synthesises a fixed number of Gaussian batches
    per epoch.

    Each yielded item is already a full batch: ``batch_size`` fresh samples
    from N(0, distribution_std^2) in ``input_dim`` dimensions, created on the
    module-level ``device`` and stacked 5 times along dim 0 (so the returned
    tensor has ``5 * batch_size`` rows, each sample appearing five times).

    The object is its own iterator; calling ``iter()`` on it rewinds the
    per-epoch step counter.
    """

    def __init__(self, input_dim, steps_per_epoch, batch_size, distribution_std=WGM_NOISE_STD):
        self.input_dim = input_dim
        self.steps_per_epoch = steps_per_epoch
        self.batch_size = batch_size
        self.distribution_std = distribution_std
        # Number of batches already produced in the current epoch.
        self.step_count = 0

    def __iter__(self):
        """Restart the epoch and return self as the iterator."""
        self.step_count = 0
        return self

    def __next__(self):
        """Produce the next synthetic batch, or stop once the epoch is exhausted."""
        if self.step_count >= self.steps_per_epoch:
            raise StopIteration

        self.step_count += 1
        samples = torch.randn(self.batch_size, self.input_dim, device=device) * self.distribution_std
        # Tile the batch 5 times along the sample axis.
        return samples.repeat(5, 1)

# --- Lightning Module ---
class GradientEstimationExperiment(pl.LightningModule):
    """
    Trains ``estimator_mlp`` so that its input-gradient matches the
    input-gradient of a frozen ``target_mlp``.

    Supported values of ``method`` (selects the training loss):
        'direct_mse': MSE between estimator and target gradients.
        'wgm':        ||s||^2 + 2 J (div s + s . grad log p), with the
                      divergence estimated from Gaussian random probes.
        'rfd_mse':    MSE between a directional derivative of the estimator
                      and a central finite difference of the target.
        'surrogate':  MSE between the two scalar outputs.
        'mixed':      mean of the 'surrogate' loss and the 'wgm' loss, the
                      latter using +/-1 (Rademacher) probes.

    Validation and test always report the MSE between the two gradients.

    Args:
        target_mlp: frozen module exposing ``forward`` and ``gradient``.
        estimator_mlp: trainable module exposing ``forward`` and ``gradient``.
        method: one of the strings listed above.
        lr: learning rate used only when ``hparams`` has no ``'lr'`` entry.
        hparams: dict of hyper-parameters; saved through
            ``save_hyperparameters`` and read for ``'rfd_epsilon'`` and ``'lr'``.
    """

    def __init__(self, target_mlp, estimator_mlp, method, lr, hparams):
        super().__init__()
        self.save_hyperparameters(hparams)
        self.target_mlp = target_mlp
        self.estimator_mlp = estimator_mlp
        self.method = method
        # Variance of the Gaussian the inputs are assumed to be drawn from.
        self.p_std_sq = WGM_NOISE_STD ** 2
        self.rfd_epsilon = hparams.get("rfd_epsilon", 1e-4)
        # Number of random probes used to estimate the divergence.
        self.wgm_n_samples = 1
        # hparams['lr'] takes precedence; this is only the fallback.
        self._fallback_lr = lr
        # Per-validation-step records, appended to in validation_step.
        self.val_historic = []
        self.l2_historic = []
        self.grad_l2_historic = []

        # The target is a fixed function: freeze it and keep it in eval mode.
        for p in self.target_mlp.parameters():
            p.requires_grad = False
        self.target_mlp.eval()

    def forward(self, x):
        """Return the estimator's gradient with respect to ``x``."""
        return self.estimator_mlp.gradient(x)

    def _surrogate_loss(self, x):
        """MSE between the target's and the estimator's scalar outputs."""
        j = self.target_mlp(x).squeeze()
        j_s_theta = self.estimator_mlp(x).sum(dim=1)
        return F.mse_loss(j, j_s_theta)

    def _wgm_loss(self, x, rademacher=False):
        """
        Loss ``mean(||s||^2 + 2 J (div s + s . grad log p))``.

        ``s`` is the estimator's input-gradient, ``J`` the target value and
        ``grad log p = -x / p_std_sq``. The divergence of ``s`` is estimated
        with ``wgm_n_samples`` random probes ``v`` as ``v . grad_x(s . v)``.

        Args:
            x: input batch; a detached copy that requires grad is used.
            rademacher: draw probes uniformly from {-1, +1} when True,
                from a standard normal otherwise.
        """
        x = x.clone().detach().requires_grad_(True)
        with torch.no_grad():
            j_val = self.target_mlp(x).squeeze()

        s_output = self.estimator_mlp(x)  # scalar output per sample
        s_theta = torch.autograd.grad(s_output.sum(), x, create_graph=True)[0]

        div_s_approx = 0.0
        for _ in range(self.wgm_n_samples):
            if rademacher:
                v = torch.randint_like(x, low=0, high=2).float() * 2.0 - 1.0
            else:
                v = torch.randn_like(x)

            s_dot_v = (s_theta * v).sum(dim=1)
            grad_s_dot_v = torch.autograd.grad(s_dot_v.sum(), x, create_graph=True, allow_unused=True)[0]
            div_s_approx += (grad_s_dot_v * v).sum(dim=1)

        div_s_approx /= (1.0*self.wgm_n_samples)
        grad_log_p = -x / self.p_std_sq

        term1 = (s_theta ** 2).sum(dim=1)
        term2 = 2 * j_val * (div_s_approx + (s_theta * grad_log_p).sum(dim=1))
        return (term1 + 1 * term2).mean()

    def training_step(self, batch, batch_idx):
        """
        Compute and log the training loss selected by ``self.method``.

        Raises:
            ValueError: if ``self.method`` is not a supported method name.
        """
        x = batch

        if self.method == 'direct_mse':
            grad_j_true = self.target_mlp.gradient(x.detach())
            s_theta = self.estimator_mlp.gradient(x)
            loss = F.mse_loss(s_theta, grad_j_true)

        elif self.method == 'wgm':
            loss = self._wgm_loss(x, rademacher=False)

        elif self.method == 'rfd_mse':
            with torch.no_grad():
                # Random unit direction per sample.
                v = torch.randn_like(x)
                v = v / (torch.norm(v, dim=1, keepdim=True) + 1e-8)
                # Central finite difference of the target along v.
                j_plus = self.target_mlp(x + self.rfd_epsilon * v).squeeze()
                j_minus = self.target_mlp(x - self.rfd_epsilon * v).squeeze()
                fd_dir_deriv = (j_plus - j_minus) / (2 * self.rfd_epsilon)

            s_theta_dot_v = (self.estimator_mlp.gradient(x) * v).sum(dim=1)
            loss = F.mse_loss(s_theta_dot_v, fd_dir_deriv)

        elif self.method == 'surrogate':
            loss = self._surrogate_loss(x)

        elif self.method == "mixed":
            # Surrogate term first, then the WGM term, averaged.
            loss = self._surrogate_loss(x)
            loss = loss + self._wgm_loss(x, rademacher=True)
            loss = loss * 0.5

        else:
            # Fail with a clear message instead of logging a None loss.
            raise ValueError(f"Unknown training method: {self.method!r}")

        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=False)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the gradient MSE and record per-step statistics of the target."""
        x = batch
        grad_j_true = self.target_mlp.gradient(x.detach())
        j_true = self.target_mlp(x.detach()).sum(dim=1).detach()

        # Mean squared magnitude of the target gradient and of the target
        # value, kept so the MSE can later be put in perspective.
        self.grad_l2_historic.append((grad_j_true**2).mean(dim=1).mean().detach())
        self.l2_historic.append((j_true**2).mean().detach())

        s_theta = self.estimator_mlp.gradient(x)
        val_mse = F.mse_loss(s_theta, grad_j_true)
        self.log('val_mse', val_mse, on_step=False, on_epoch=True, prog_bar=False)
        self.val_historic.append(val_mse.cpu().detach())
        return val_mse

    def test_step(self, batch, batch_idx):
        """Log the gradient MSE on a test batch."""
        # A TensorDataset yields a 1-element list/tuple; unwrap it.
        x = batch[0] if isinstance(batch, (list, tuple)) else batch
        grad_j_true = self.target_mlp.gradient(x.detach())
        s_theta = self.estimator_mlp.gradient(x)
        test_mse = F.mse_loss(s_theta, grad_j_true)
        self.log('test_mse', test_mse, on_step=False, on_epoch=True)
        return test_mse

    def configure_optimizers(self):
        """Configure Adam on the estimator with a per-epoch ExponentialLR schedule."""

        # The learning rate comes from the saved hyper-parameters; the
        # constructor's ``lr`` argument is used only if that key is absent.
        lr = self.hparams.get('lr', self._fallback_lr)
        optimizer = optim.Adam(self.estimator_mlp.parameters(), lr=lr)

        # Multiplies the LR by 0.998 (a 0.2% decrease) at every epoch.
        scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.998)

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "epoch",  # step the scheduler once per epoch
                "frequency": 1,       # at every such interval
                # 'monitor': 'val_loss' # only needed by schedulers such as ReduceLROnPlateau
            },
        }

# --- Training Script ---
if __name__ == "__main__":
    # Methods trained on every problem, in training order. This single list
    # drives the result containers, the training loop and the final summary,
    # so the three can never disagree on the set of method names.
    methods = ['surrogate', 'wgm']

    all_results = {name: [] for name in methods}             # final test MSE per problem
    trained_estimators = {name: [] for name in methods}      # trained estimator per problem
    val_loss_per_epoch = {name: [] for name in methods}      # validation MSE history per problem
    grad_l2_historic_epoch = {name: [] for name in methods}  # target gradient magnitude history
    l2_historic_epoch = {name: [] for name in methods}       # target value magnitude history

    for i in range(N_PROBLEMS):
        print(f"\n--- Training on Problem {i+1}/{N_PROBLEMS} ---")

        # A fresh random target function defines each problem.
        target_mlp = GradientMLP(INPUT_DIM, 2, 128, envelope=True).to(device)

        train_dataset = SampleDataset(INPUT_DIM, TRAIN_STEPS_PER_EPOCH, BATCH_SIZE)
        val_dataset   = SampleDataset(INPUT_DIM, VAL_STEPS_PER_EPOCH, BATCH_SIZE)
        X_test = torch.randn(N_TEST_SAMPLES, INPUT_DIM, device=device) * WGM_NOISE_STD
        test_dataset = TensorDataset(X_test)

        # batch_size=None: the sample datasets already yield whole batches.
        train_loader = DataLoader(train_dataset, batch_size=None, num_workers=0)
        val_loader   = DataLoader(val_dataset, batch_size=None, num_workers=0)
        test_loader  = DataLoader(test_dataset, batch_size=BATCH_SIZE, num_workers=0)

        for method in methods:
            print(f"  Training method: {method}")

            estimator_mlp = GradientMLP2(INPUT_DIM, N_LAYERS, HIDDEN_DIM, envelope=False).to(device)

            hparams_dict = {
                "input_dim": INPUT_DIM, "hidden_dim": HIDDEN_DIM, "n_layers": N_LAYERS,
                "lr": LEARNING_RATE, "method": method, "wgm_lambda": WGM_LAMBDA,
                "rfd_epsilon": RFD_EPSILON, "problem_idx": i
            }

            experiment = GradientEstimationExperiment(target_mlp, estimator_mlp, method, LEARNING_RATE, hparams_dict)
            logger = pl.loggers.TensorBoardLogger("lightning_logs/", name=f"problem_{i}", version=f"{method}_{int(time.time())}")

            trainer = pl.Trainer(
                max_epochs=MAX_EPOCHS,
                accelerator="auto",
                devices=1,
                logger=logger,
                callbacks=[
                    pl.callbacks.ModelCheckpoint(
                        monitor="val_mse", mode="min",
                        filename=f"{method}-{{epoch:02d}}-{{val_mse:.6f}}"
                    )
                ],
                enable_progress_bar=False,
                enable_model_summary=True,
                num_sanity_val_steps=0,
                limit_train_batches=TRAIN_STEPS_PER_EPOCH,
                limit_val_batches=VAL_STEPS_PER_EPOCH,
                # Validation/test need autograd for the input gradients.
                inference_mode=False,
            )

            print(f"    Starting training for {method}...")
            trainer.fit(experiment, train_loader, val_loader)
            print(f"    Finished training for {method}.")

            print(f"    Starting testing for {method}...")
            # Evaluate the checkpoint with the lowest validation MSE.
            test_results = trainer.test(ckpt_path="best", dataloaders=test_loader, verbose=False)
            if test_results:
                final_mse = test_results[0].get('test_mse')
                if final_mse is not None:
                    all_results[method].append(final_mse)
                    print(f"    Method: {method}, Final Test MSE: {final_mse:.6f}")
                else:
                    print(f"    Method: {method}, 'test_mse' not found.")
                    all_results[method].append(float('nan'))
            else:
                print(f"    Method: {method}, Testing failed.")
                all_results[method].append(float('nan'))

            # Keep the trained estimator (moved to CPU, eval mode) and the
            # recorded histories for later visualisation.
            trained_estimators[method].append(experiment.estimator_mlp.cpu().eval())
            val_loss_per_epoch[method].append(experiment.val_historic)
            grad_l2_historic_epoch[method].append(experiment.grad_l2_historic)
            l2_historic_epoch[method].append(experiment.l2_historic)

    print("\n--- Average Results ---")
    # Summarise exactly the methods that were trained above.
    for method in methods:
        results = [res for res in all_results[method] if not np.isnan(res)]
        if results:
            avg_mse = np.mean(results)
            std_mse = np.std(results)
            print(f"Method: {method:<12} | Avg Test MSE: {avg_mse:.6f} +/- {std_mse:.6f}  ({len(results)} successful runs)")
        else:
            print(f"Method: {method:<12} | No successful runs found.")





Using device: cuda

--- Training on Problem 1/10 ---
  Training method: surrogate


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.


    Starting training for surrogate...


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type         | Params
-----------------------------------------------
0 | target_mlp    | GradientMLP  | 48.8 K
1 | estimator_mlp | GradientMLP2 | 261 K 
-----------------------------------------------
261 K     Trainable params
48.8 K    Non-trainable params
310 K     Total params
1.243     Total estimated model params size (MB)
/home/sbenaich/.local/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.
/home/sbenaich/.local/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to impr

    Finished training for surrogate.
    Starting testing for surrogate...
    Method: surrogate, Final Test MSE: 0.029274
  Training method: wgm


HPU available: False, using: 0 HPUs
`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type         | Params
-----------------------------------------------
0 | target_mlp    | GradientMLP  | 48.8 K
1 | estimator_mlp | GradientMLP2 | 261 K 
-----------------------------------------------
261 K     Trainable params
48.8 K    Non-trainable params
310 K     Total params
1.243     Total estimated model params size (MB)


    Starting training for wgm...


`Trainer.fit` stopped: `max_epochs=500` reached.
Restoring states from the checkpoint path at lightning_logs/problem_0/wgm_1746001296/checkpoints/wgm-epoch=456-val_mse=0.027965.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at lightning_logs/problem_0/wgm_1746001296/checkpoints/wgm-epoch=456-val_mse=0.027965.ckpt
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


    Finished training for wgm.
    Starting testing for wgm...
    Method: wgm, Final Test MSE: 0.029005

--- Training on Problem 2/10 ---
  Training method: surrogate
    Starting training for surrogate...



  | Name          | Type         | Params
-----------------------------------------------
0 | target_mlp    | GradientMLP  | 48.8 K
1 | estimator_mlp | GradientMLP2 | 261 K 
-----------------------------------------------
261 K     Trainable params
48.8 K    Non-trainable params
310 K     Total params
1.243     Total estimated model params size (MB)
`Trainer.fit` stopped: `max_epochs=500` reached.
Restoring states from the checkpoint path at lightning_logs/problem_1/surrogate_1746003252/checkpoints/surrogate-epoch=487-val_mse=0.020299.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at lightning_logs/problem_1/surrogate_1746003252/checkpoints/surrogate-epoch=487-val_mse=0.020299.ckpt


    Finished training for surrogate.
    Starting testing for surrogate...


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type         | Params
-----------------------------------------------
0 | target_mlp    | GradientMLP  | 48.8 K
1 | estimator_mlp | GradientMLP2 | 261 K 
-----------------------------------------------
261 K     Trainable params
48.8 K    Non-trainable params
310 K     Total params
1.243     Total estimated model params size (MB)


    Method: surrogate, Final Test MSE: 0.020562
  Training method: wgm
    Starting training for wgm...


`Trainer.fit` stopped: `max_epochs=500` reached.
Restoring states from the checkpoint path at lightning_logs/problem_1/wgm_1746003661/checkpoints/wgm-epoch=490-val_mse=0.019808.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at lightning_logs/problem_1/wgm_1746003661/checkpoints/wgm-epoch=490-val_mse=0.019808.ckpt


    Finished training for wgm.
    Starting testing for wgm...
    Method: wgm, Final Test MSE: 0.020158

--- Training on Problem 3/10 ---
  Training method: surrogate


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type         | Params
-----------------------------------------------
0 | target_mlp    | GradientMLP  | 48.8 K
1 | estimator_mlp | GradientMLP2 | 261 K 
-----------------------------------------------
261 K     Trainable params
48.8 K    Non-trainable params
310 K     Total params
1.243     Total estimated model params size (MB)


    Starting training for surrogate...


`Trainer.fit` stopped: `max_epochs=500` reached.
Restoring states from the checkpoint path at lightning_logs/problem_2/surrogate_1746005618/checkpoints/surrogate-epoch=478-val_mse=0.024289.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at lightning_logs/problem_2/surrogate_1746005618/checkpoints/surrogate-epoch=478-val_mse=0.024289.ckpt
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type         | Params
-----------------------------------------------
0 | target_mlp    | GradientMLP  | 48.8 K
1 | estimator_mlp | GradientMLP2 | 261 K 
-----------------------------------------------
261 K     Trainable params
48.8 K    Non-trainable params
310 K     Total params
1.243     Total estimated model params size (MB)


    Finished training for surrogate.
    Starting testing for surrogate...
    Method: surrogate, Final Test MSE: 0.024770
  Training method: wgm
    Starting training for wgm...


`Trainer.fit` stopped: `max_epochs=500` reached.
Restoring states from the checkpoint path at lightning_logs/problem_2/wgm_1746006027/checkpoints/wgm-epoch=443-val_mse=0.024100.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at lightning_logs/problem_2/wgm_1746006027/checkpoints/wgm-epoch=443-val_mse=0.024100.ckpt
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.


    Finished training for wgm.
    Starting testing for wgm...
    Method: wgm, Final Test MSE: 0.024501

--- Training on Problem 4/10 ---
  Training method: surrogate
    Starting training for surrogate...


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type         | Params
-----------------------------------------------
0 | target_mlp    | GradientMLP  | 48.8 K
1 | estimator_mlp | GradientMLP2 | 261 K 
-----------------------------------------------
261 K     Trainable params
48.8 K    Non-trainable params
310 K     Total params
1.243     Total estimated model params size (MB)
`Trainer.fit` stopped: `max_epochs=500` reached.
Restoring states from the checkpoint path at lightning_logs/problem_3/surrogate_1746007984/checkpoints/surrogate-epoch=476-val_mse=0.019549.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at lightning_logs/problem_3/surrogate_1746007984/checkpoints/surrogate-epoch=476-val_mse=0.019549.ckpt
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.
LO

    Finished training for surrogate.
    Starting testing for surrogate...
    Method: surrogate, Final Test MSE: 0.019740
  Training method: wgm
    Starting training for wgm...



  | Name          | Type         | Params
-----------------------------------------------
0 | target_mlp    | GradientMLP  | 48.8 K
1 | estimator_mlp | GradientMLP2 | 261 K 
-----------------------------------------------
261 K     Trainable params
48.8 K    Non-trainable params
310 K     Total params
1.243     Total estimated model params size (MB)
`Trainer.fit` stopped: `max_epochs=500` reached.
Restoring states from the checkpoint path at lightning_logs/problem_3/wgm_1746008393/checkpoints/wgm-epoch=489-val_mse=0.019046.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at lightning_logs/problem_3/wgm_1746008393/checkpoints/wgm-epoch=489-val_mse=0.019046.ckpt


    Finished training for wgm.
    Starting testing for wgm...
    Method: wgm, Final Test MSE: 0.019271

--- Training on Problem 5/10 ---
  Training method: surrogate


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type         | Params
-----------------------------------------------
0 | target_mlp    | GradientMLP  | 48.8 K
1 | estimator_mlp | GradientMLP2 | 261 K 
-----------------------------------------------
261 K     Trainable params
48.8 K    Non-trainable params
310 K     Total params
1.243     Total estimated model params size (MB)


    Starting training for surrogate...


`Trainer.fit` stopped: `max_epochs=500` reached.
Restoring states from the checkpoint path at lightning_logs/problem_4/surrogate_1746010350/checkpoints/surrogate-epoch=482-val_mse=0.023146.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at lightning_logs/problem_4/surrogate_1746010350/checkpoints/surrogate-epoch=482-val_mse=0.023146.ckpt


    Finished training for surrogate.
    Starting testing for surrogate...
    Method: surrogate, Final Test MSE: 0.024906
  Training method: wgm


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type         | Params
-----------------------------------------------
0 | target_mlp    | GradientMLP  | 48.8 K
1 | estimator_mlp | GradientMLP2 | 261 K 
-----------------------------------------------
261 K     Trainable params
48.8 K    Non-trainable params
310 K     Total params
1.243     Total estimated model params size (MB)


    Starting training for wgm...


`Trainer.fit` stopped: `max_epochs=500` reached.
Restoring states from the checkpoint path at lightning_logs/problem_4/wgm_1746010760/checkpoints/wgm-epoch=495-val_mse=0.022858.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at lightning_logs/problem_4/wgm_1746010760/checkpoints/wgm-epoch=495-val_mse=0.022858.ckpt


    Finished training for wgm.
    Starting testing for wgm...
    Method: wgm, Final Test MSE: 0.024444

--- Training on Problem 6/10 ---
  Training method: surrogate


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type         | Params
-----------------------------------------------
0 | target_mlp    | GradientMLP  | 48.8 K
1 | estimator_mlp | GradientMLP2 | 261 K 
-----------------------------------------------
261 K     Trainable params
48.8 K    Non-trainable params
310 K     Total params
1.243     Total estimated model params size (MB)


    Starting training for surrogate...


`Trainer.fit` stopped: `max_epochs=500` reached.
Restoring states from the checkpoint path at lightning_logs/problem_5/surrogate_1746012716/checkpoints/surrogate-epoch=464-val_mse=0.026231.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at lightning_logs/problem_5/surrogate_1746012716/checkpoints/surrogate-epoch=464-val_mse=0.026231.ckpt
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.


    Finished training for surrogate.
    Starting testing for surrogate...
    Method: surrogate, Final Test MSE: 0.026997
  Training method: wgm
    Starting training for wgm...


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type         | Params
-----------------------------------------------
0 | target_mlp    | GradientMLP  | 48.8 K
1 | estimator_mlp | GradientMLP2 | 261 K 
-----------------------------------------------
261 K     Trainable params
48.8 K    Non-trainable params
310 K     Total params
1.243     Total estimated model params size (MB)
`Trainer.fit` stopped: `max_epochs=500` reached.
Restoring states from the checkpoint path at lightning_logs/problem_5/wgm_1746013125/checkpoints/wgm-epoch=445-val_mse=0.026633.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at lightning_logs/problem_5/wgm_1746013125/checkpoints/wgm-epoch=445-val_mse=0.026633.ckpt
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.
LOCAL_RANK: 0 - CUDA_VISIB

    Finished training for wgm.
    Starting testing for wgm...
    Method: wgm, Final Test MSE: 0.027396

--- Training on Problem 7/10 ---
  Training method: surrogate
    Starting training for surrogate...


`Trainer.fit` stopped: `max_epochs=500` reached.
Restoring states from the checkpoint path at lightning_logs/problem_6/surrogate_1746015082/checkpoints/surrogate-epoch=491-val_mse=0.017740.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at lightning_logs/problem_6/surrogate_1746015082/checkpoints/surrogate-epoch=491-val_mse=0.017740.ckpt
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type         | Params
-----------------------------------------------
0 | target_mlp    | GradientMLP  | 48.8 K
1 | estimator_mlp | GradientMLP2 | 261 K 
-----------------------------------------------
261 K     Trainable params
48.8 K    Non-trainable params
310 K     Total params
1.243     Total estimated model params size (MB)


    Finished training for surrogate.
    Starting testing for surrogate...
    Method: surrogate, Final Test MSE: 0.018175
  Training method: wgm
    Starting training for wgm...


`Trainer.fit` stopped: `max_epochs=500` reached.
Restoring states from the checkpoint path at lightning_logs/problem_6/wgm_1746015491/checkpoints/wgm-epoch=468-val_mse=0.017504.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at lightning_logs/problem_6/wgm_1746015491/checkpoints/wgm-epoch=468-val_mse=0.017504.ckpt
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type         | Params
-----------------------------------------------
0 | target_mlp    | GradientMLP  | 48.8 K
1 | estimator_mlp | GradientMLP2 | 261 K 
-----------------------------------------------
261 K     Trainable params
48.8 K    Non-trainable params
310 K     Total params
1.243     Total estimated model params size (MB)


    Finished training for wgm.
    Starting testing for wgm...
    Method: wgm, Final Test MSE: 0.017955

--- Training on Problem 8/10 ---
  Training method: surrogate
    Starting training for surrogate...


`Trainer.fit` stopped: `max_epochs=500` reached.
Restoring states from the checkpoint path at lightning_logs/problem_7/surrogate_1746017448/checkpoints/surrogate-epoch=494-val_mse=0.020531.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at lightning_logs/problem_7/surrogate_1746017448/checkpoints/surrogate-epoch=494-val_mse=0.020531.ckpt


    Finished training for surrogate.
    Starting testing for surrogate...
    Method: surrogate, Final Test MSE: 0.020798
  Training method: wgm


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type         | Params
-----------------------------------------------
0 | target_mlp    | GradientMLP  | 48.8 K
1 | estimator_mlp | GradientMLP2 | 261 K 
-----------------------------------------------
261 K     Trainable params
48.8 K    Non-trainable params
310 K     Total params
1.243     Total estimated model params size (MB)


    Starting training for wgm...


`Trainer.fit` stopped: `max_epochs=500` reached.
Restoring states from the checkpoint path at lightning_logs/problem_7/wgm_1746017857/checkpoints/wgm-epoch=433-val_mse=0.020858.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at lightning_logs/problem_7/wgm_1746017857/checkpoints/wgm-epoch=433-val_mse=0.020858.ckpt


    Finished training for wgm.
    Starting testing for wgm...
    Method: wgm, Final Test MSE: 0.020954

--- Training on Problem 9/10 ---
  Training method: surrogate


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type         | Params
-----------------------------------------------
0 | target_mlp    | GradientMLP  | 48.8 K
1 | estimator_mlp | GradientMLP2 | 261 K 
-----------------------------------------------
261 K     Trainable params
48.8 K    Non-trainable params
310 K     Total params
1.243     Total estimated model params size (MB)


    Starting training for surrogate...


`Trainer.fit` stopped: `max_epochs=500` reached.
Restoring states from the checkpoint path at lightning_logs/problem_8/surrogate_1746019812/checkpoints/surrogate-epoch=485-val_mse=0.021170.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at lightning_logs/problem_8/surrogate_1746019812/checkpoints/surrogate-epoch=485-val_mse=0.021170.ckpt
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type         | Params
-----------------------------------------------
0 | target_mlp    | GradientMLP  | 48.8 K
1 | estimator_mlp | GradientMLP2 | 261 K 
-----------------------------------------------
261 K     Trainable params
48.8 K    Non-trainable params
310 K     Total params
1.243     Total estimated model params size (MB)


    Finished training for surrogate.
    Starting testing for surrogate...
    Method: surrogate, Final Test MSE: 0.021557
  Training method: wgm
    Starting training for wgm...


In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt

# --- la fonction d’aide que tu avais déjà ---
def extract_ts(method, results):
    """Stack the per-run series recorded for ``method`` into one 2-D array.

    Args:
        method: Key selecting which entry of ``results`` to read.
        results: Mapping from method name to a list of runs, each run being a
            sequence of per-epoch values (all runs sized like the first one).

    Returns:
        A float32 NumPy array of shape (n_runs, n_epochs).
    """
    series = results[method]
    n_runs, n_epochs = len(series), len(series[0])
    stacked = torch.zeros((n_runs, n_epochs))
    # Iterating a tensor yields row views, so copy_ writes into `stacked`.
    for row, values in zip(stacked, series):
        row.copy_(torch.tensor(values))
    return stacked.cpu().numpy()

import os

# --- Collect the per-epoch series of each method ---
# Every array has shape (n_runs, n_epochs), as returned by extract_ts.
ts_wgm       = extract_ts("wgm",       val_loss_per_epoch)
ts_surrogate = extract_ts("surrogate", val_loss_per_epoch)

# Each variable must be read with the method key matching its own name:
# these four lookups used to be crossed ("wgm" data stored under the
# *_surrogate names and vice versa), which mislabeled the saved files.
l2_wgm       = extract_ts("wgm",       l2_historic_epoch)
l2_surrogate = extract_ts("surrogate", l2_historic_epoch)

grad_l2_wgm       = extract_ts("wgm",       grad_l2_historic_epoch)
grad_l2_surrogate = extract_ts("surrogate", grad_l2_historic_epoch)

# --- Persist the raw series as NumPy (.npy) files ---
os.makedirs("metrics", exist_ok=True)

np.save("metrics/ts_wgm.npy",             ts_wgm)
np.save("metrics/ts_surrogate.npy",       ts_surrogate)
np.save("metrics/l2_surrogate.npy",       l2_surrogate)
np.save("metrics/l2_wgm.npy",             l2_wgm)
np.save("metrics/grad_l2_surrogate.npy",  grad_l2_surrogate)
np.save("metrics/grad_l2_wgm.npy",        grad_l2_wgm)

# --- Mean and standard deviation across runs (axis 0) ---
mean_wgm, std_wgm             = ts_wgm.mean(0), ts_wgm.std(0)
mean_surrogate, std_surrogate = ts_surrogate.mean(0), ts_surrogate.std(0)

epochs = np.arange(len(mean_wgm))  # X-axis

# --- Figure: mean validation loss with a +/- 1 std band per method ---
fig, ax = plt.subplots(figsize=(7, 4))

ax.plot(epochs, mean_wgm, label="WGM")
ax.fill_between(epochs,
                mean_wgm - std_wgm,
                mean_wgm + std_wgm,
                alpha=0.2)

ax.plot(epochs, mean_surrogate, label="Surrogate")
ax.fill_between(epochs,
                mean_surrogate - std_surrogate,
                mean_surrogate + std_surrogate,
                alpha=0.2)

ax.set_xlabel("Epoch")
ax.set_ylabel("Validation loss")
# ax.set_yscale("log")          # uncomment for a logarithmic y-axis
ax.legend()
ax.set_title("Évolution moyenne de la loss (±1 σ)")
fig.tight_layout()
plt.show()

In [None]:
import matplotlib.pyplot as plt
import torch
import numpy as np

def plot_scalar_surfaces(target_mlp, scalar_estimators, methods, fixed_vars=None, resolution=200, range_min=-6, range_max=6):
    """
    Draw filled contour plots of the target scalar function and of each
    estimator on a 2-D slice of the input space, for qualitative comparison.

    The first two input coordinates sweep a regular ``resolution`` x
    ``resolution`` grid over [range_min, range_max]^2; every remaining
    coordinate is held at the corresponding entry of ``fixed_vars``.

    Args:
        target_mlp (nn.Module): Target model returning one scalar per input row.
            Its input width is read from ``target_mlp.net[0].in_features``
            (assumes a GradientMLP-like layout). It is evaluated in whatever
            train/eval mode it is already in.
        scalar_estimators (dict): Trained models, {method_name: estimator_model}.
            Each plotted estimator is switched to eval mode.
        methods (list): Names of the methods to plot, one panel each.
        fixed_vars (list or tensor, optional): Values of the held dimensions
            (input width - 2 entries). Defaults to all ones.
        resolution (int): Number of grid points per axis.
        range_min (float): Lower bound of both plotted axes.
        range_max (float): Upper bound of both plotted axes.

    Raises:
        ValueError: If ``fixed_vars`` does not hold exactly input width - 2 values.
    """
    input_dim = target_mlp.net[0].in_features  # assumes a GradientMLP-like model
    assert input_dim >= 2, "Input dimension must be at least 2 for plotting."

    device = next(target_mlp.parameters()).device
    n_fixed = input_dim - 2

    if fixed_vars is None:
        fixed_vars = torch.ones(n_fixed, device=device)
    else:
        # as_tensor accepts lists and existing tensors alike (torch.tensor warns
        # when handed a tensor); detach so no autograd graph is dragged along.
        fixed_vars = torch.as_tensor(fixed_vars, dtype=torch.float32, device=device).detach().flatten()
        # Validate up front: otherwise the mismatch only surfaces as an opaque
        # shape error inside the model's first layer.
        if fixed_vars.numel() != n_fixed:
            raise ValueError(
                f"fixed_vars must contain {n_fixed} values (input dim {input_dim} minus the "
                f"2 plotted axes), got {fixed_vars.numel()}."
            )

    # Regular grid over the two plotted coordinates.
    axis_values = np.linspace(range_min, range_max, resolution)
    xx, yy = np.meshgrid(axis_values, axis_values)
    grid_points = np.stack([xx.ravel(), yy.ravel()], axis=-1)

    # Full model inputs: (grid point, fixed remaining coordinates) per row.
    grid_tensor = torch.tensor(grid_points, dtype=torch.float32, device=device)
    fixed_exp = fixed_vars.unsqueeze(0).expand(resolution ** 2, -1)
    full_inputs = torch.cat([grid_tensor, fixed_exp], dim=1)

    # Ground truth
    with torch.no_grad():
        gt_values = target_mlp(full_inputs).view(resolution, resolution).cpu().numpy()

    n_panels = len(methods) + 1
    # squeeze=False keeps a 2-D axes array even when there is a single panel.
    fig, axes = plt.subplots(1, n_panels, figsize=(5 * n_panels, 4), squeeze=False)
    panels = axes[0]

    gt_contours = panels[0].contourf(xx, yy, gt_values, levels=50, cmap='viridis')
    fig.colorbar(gt_contours, ax=panels[0])
    panels[0].set_title('Ground Truth')

    for ax, method in zip(panels[1:], methods):
        model = scalar_estimators[method]
        model.eval()
        # The estimator may live on a different device than the target (e.g. the
        # target was moved to CPU by the caller); feed it inputs on its own device.
        first_param = next(model.parameters(), None)
        est_inputs = full_inputs if first_param is None else full_inputs.to(first_param.device)
        with torch.no_grad():
            est_values = model(est_inputs).view(resolution, resolution).cpu().numpy()

        est_contours = ax.contourf(xx, yy, est_values, levels=50, cmap='viridis')
        fig.colorbar(est_contours, ax=ax)
        ax.set_title(f'{method.upper()} Estimate')

    fig.tight_layout()
    plt.show()


In [None]:

# Keep, for each plotted method, the last estimator stored in trained_estimators.
scalar_estimators_dict = {
    name: trained_estimators[name][-1]
    for name in ('wgm', 'surrogate')
}

# Hold every non-plotted input coordinate at -0.20.
fixed_vars = [-0.20] * (INPUT_DIM - 2)
plot_scalar_surfaces(
    target_mlp.cpu().eval(),
    scalar_estimators_dict,
    list(scalar_estimators_dict.keys()),
    fixed_vars=fixed_vars,
)


In [None]:
# Print mean +/- std of the recorded test MSE for each method, skipping NaN entries.
for method in ['direct_mse', 'wgm', 'rfd_mse', 'surrogate', 'mixed']:
    # .get: a method with no entry in all_results (e.g. not trained in this
    # session) is reported as having no successful runs instead of raising KeyError.
    results = [res for res in all_results.get(method, []) if not np.isnan(res)]
    if results:
        avg_mse = np.mean(results)
        std_mse = np.std(results)
        print(f"Method: {method:<12} | Avg Test MSE: {avg_mse:.6f} +/- {std_mse:.6f}  ({len(results)} successful runs)")
    else:
        print(f"Method: {method:<12} | No successful runs found.")
