In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
try:
  import pytorch_lightning as pl
except:
  !pip install pytorch-lightning
  import pytorch_lightning as pl

from torch.utils.data import DataLoader, IterableDataset, TensorDataset
import numpy as np
import random
import time
# NOTE: No Matplotlib imports here yet
import torch.nn.functional as F
import math

# --- Configuration ---
# Module-level constants read by the dataset classes, the Lightning module
# and the training script below.
N_PROBLEMS = 3             # Number of iterations of the outer "problem" loop in the training script
INPUT_DIM = 2              # Dimension of the input points x
HIDDEN_DIM = 128           # Width of the hidden layers of the estimator MLP
N_LAYERS = 4               # n_layers argument passed to the estimator MLP
TRAIN_STEPS_PER_EPOCH = 300  # In the code visible here, only used to derive the trainer's log_every_n_steps
VAL_STEPS_PER_EPOCH = 1  # Number of batches the validation SampleDataset yields per epoch
N_TEST_SAMPLES = 1024      # Number of points sampled for the test set
BATCH_SIZE = 256           # Batch size of the train/val/test loaders
LEARNING_RATE = 1e-5       # Learning rate handed to Adam
MAX_EPOCHS = 400 # max_epochs given to the Lightning Trainer

WGM_NOISE_STD = 2.0        # Std of the zero-mean Gaussian inputs are sampled from; its square defines grad log p in the WGM loss
RFD_EPSILON = 0.01         # Step of the central finite difference in the 'rfd_mse' method
WGM_LAMBDA = 1.0  # Often 1/sigma^2. NOTE(review): stored on the experiment but not used by any loss visible in this file
WGM_N_SAMPLES = 1          # NOTE(review): stored on the experiment but not used by any loss visible in this file

# Seed the three RNGs used in this file (torch, numpy, python) for repeatable runs.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# --- Device Setup ---
# Global device: CUDA when available, otherwise CPU. Datasets below create
# their tensors directly on this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# --- Model Components ---

class Sine(nn.Module):
    """Sinusoidal activation computing ``sin(w0 * x)`` element-wise.

    Args:
        w0: Frequency factor the input is multiplied by before the sine.
    """

    def __init__(self, w0=30.0):
        super().__init__()
        self.w0 = w0

    def forward(self, x):
        """Return the sine of the frequency-scaled input tensor."""
        scaled = self.w0 * x
        return scaled.sin()

# siren_init function is defined but not used by GradientMLP2 currently
def siren_init(linear, fan_in, w0=30.0, is_first_layer=False):
    """Re-initialize a linear layer in place with a uniform distribution.

    The symmetric bound is ``1 / fan_in`` for the first layer and
    ``sqrt(6 / fan_in) / w0`` for every other layer. The bias, when the layer
    has one, is drawn from the same interval as the weight.

    Args:
        linear: Layer exposing ``weight`` and (possibly ``None``) ``bias``.
        fan_in: Number of inputs used to compute the bound.
        w0: Frequency factor dividing the bound of non-first layers.
        is_first_layer: Selects the first-layer bound when true.
    """
    if is_first_layer:
        limit = 1.0 / fan_in
    else:
        limit = math.sqrt(6.0 / fan_in) / w0

    # Parameters are overwritten in place, so autograd must not track this.
    with torch.no_grad():
        linear.weight.uniform_(-limit, limit)
        bias = linear.bias
        if bias is not None:
            bias.uniform_(-limit, limit)


class AnalyticalTarget(nn.Module):
    r"""Classic analytical benchmark functions for regression / optimization.

    Supported ``mode`` values (case-insensitive):

    * ``"rosenbrock"`` -- requires ``dim >= 2``; value multiplied by ``1e-3``
    * ``"ackley"``     -- value multiplied by ``1e-1``; as implemented here the
      additive constant ``20 + e`` of the textbook formula is not included
    * ``"styblinski"`` -- ``dim >= 1``; value multiplied by ``1e-2``
    * ``"booth"``      -- ``dim == 2`` only
    * ``"mccormick"``  -- ``dim == 2`` only

    After the per-mode factor, the value is multiplied by ``scale`` (default 1)
    and by a radial window that equals 1 where ``||x||^2 <= radius_d`` and
    decays as ``exp(-alpha * (||x||^2 - radius_d))`` outside.

    Args:
        mode: Name of the benchmark function.
        dim: Input dimension the function is defined for.
        a: Rosenbrock ``a`` coefficient.
        b: Rosenbrock ``b`` coefficient.
        scale: Global multiplicative factor applied to the output.
        radius_d: Threshold on the *squared* Euclidean norm of ``x`` beyond
            which the window starts decaying.
        alpha: Decay rate of the window outside the threshold.

    Raises:
        ValueError: On an unknown mode or a ``dim`` the mode does not support.
    """
    def __init__(self, mode="rosenbrock", dim=2,
                 a=1.0, b=100.0, scale=1, radius_d = 20,alpha=0.1):
        super().__init__()
        self.mode  = mode.lower()
        self.dim   = dim
        self.a, self.b = a, b
        self.scale = scale
        self.radius_d = radius_d
        self.alpha = alpha
        valid = {"rosenbrock", "ackley",
                 "styblinski", "booth", "mccormick"}
        if self.mode not in valid:
            raise ValueError(f"Unknown mode {self.mode}")
        if self.mode in {"booth", "mccormick"} and dim != 2:
            raise ValueError(f"{self.mode} defined only for dim=2, got dim={dim}")
        if dim < 1:
            raise ValueError("dim must be ≥ 1")
        # Rosenbrock sums over consecutive coordinate pairs; with a single
        # coordinate there is no pair and forward() could not build a tensor.
        if self.mode == "rosenbrock" and dim < 2:
            raise ValueError(f"rosenbrock requires dim ≥ 2, got dim={dim}")

    # -------------------------------------------------- #
    def forward(self, x):                        # x : (B, dim)
        """Evaluate the windowed benchmark function.

        Args:
            x: Tensor indexed as ``(batch, dim)``.

        Returns:
            Tensor of shape ``(batch, 1)``.
        """
        if self.mode == "rosenbrock":
            # Consecutive pairs (x_i, x_{i+1}) for i = 0 .. dim-2.
            xi, xip1 = x[:, : self.dim - 1], x[:, 1 : self.dim]
            y = ((self.a - xi) ** 2
                 + self.b * (xip1 - xi ** 2) ** 2).sum(dim=1)
            y = 1e-3 * y

        elif self.mode == "ackley":
            # Both terms average over the coordinates (1/d), so the formula
            # is valid for any dimension and unchanged for d = 2.
            part1 = -20.0 * torch.exp(
                -0.2 * torch.sqrt((x ** 2).mean(dim=1))
            )
            part2 = -torch.exp(
                torch.cos(2 * math.pi * x).mean(dim=1)
            )
            y = 1e-1 * (part1 + part2)

        elif self.mode == "styblinski":
            # f(x) = 0.5 * sum_i (x_i^4 - 16 x_i^2 + 5 x_i)
            y = 0.5 * (x ** 4 - 16 * x ** 2 + 5 * x).sum(dim=1)
            y = 1e-2 * y

        elif self.mode == "booth":
            x1, x2 = x[:, 0], x[:, 1]
            y = (x1 + 2 * x2 - 7) ** 2 + (2 * x1 + x2 - 5) ** 2

        elif self.mode == "mccormick":
            x1, x2 = x[:, 0], x[:, 1]
            y = torch.sin(x1 + x2) + (x1 - x2) ** 2 - 1.5 * x1 + 2.5 * x2 + 1.0

        else:
            # Only reachable if ``mode`` was mutated after construction.
            raise ValueError(f"Unknown mode {self.mode}")

        y = (self.scale * y).unsqueeze(1)  # (B, 1)

        # Radial window: weight 1 inside the threshold, exponential decay in
        # the squared norm outside of it.
        radial_mask = torch.norm(x, dim=1)**2  # (B,), squared norm
        weight = torch.ones_like(radial_mask)
        inside = radial_mask <= self.radius_d
        weight[~inside] = torch.exp(-self.alpha * (radial_mask[~inside] - self.radius_d))
        y = y * weight.unsqueeze(1)

        return y

    # -------------------------------------------------- #
    def gradient(self, x):
        """Return the gradient of the summed output with respect to ``x``.

        Gradients are enabled locally, so this also works when called from a
        ``torch.no_grad()`` context. The input is detached first and no graph
        is kept for the result.

        Args:
            x: Tensor indexed as ``(batch, dim)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        with torch.enable_grad():
            x_req = x.detach().requires_grad_(True)
            y_sum = self.forward(x_req).sum()
            grad  = torch.autograd.grad(y_sum, x_req,
                                        create_graph=False)[0]
        return grad

class GradientMLP2(nn.Module):
    """Scalar-output MLP with GELU activations and an input-gradient helper.

    The network is ``Linear(input_dim, hidden_dim)`` + GELU, followed by
    ``n_layers - 1`` blocks of ``Linear(hidden_dim, hidden_dim)`` + GELU and a
    final ``Linear(hidden_dim, 1)``.

    Args:
        input_dim: Size of the last dimension of the input.
        n_layers: Number of hidden Linear+GELU blocks.
        hidden_dim: Width of the hidden layers.
        envelope: When true, the output is multiplied by
            ``exp(-0.5 * mean(x**2))`` computed per sample.
        w0: Accepted for interface compatibility; not used by this class.
        w0_hidden: Accepted for interface compatibility; not used by this class.
    """

    def __init__(self, input_dim, n_layers, hidden_dim, envelope=False,
                 w0=1.0, w0_hidden=0.5):
        super().__init__()
        self.envelope = envelope
        modules = [nn.Linear(input_dim, hidden_dim), nn.GELU()]
        for _ in range(n_layers - 1):
            modules += [nn.Linear(hidden_dim, hidden_dim), nn.GELU()]
        modules.append(nn.Linear(hidden_dim, 1))
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        """Return the scalar network output, one value per input row."""
        x = x.to(self.get_device())
        out = self.net(x)
        if not self.envelope:
            return out
        mean_sq = x.pow(2).mean(dim=1, keepdim=True)
        return out * torch.exp(-0.5 * mean_sq)

    def gradient(self, x):
        """Return d(sum of outputs)/dx, keeping the graph for further backprop.

        The input is detached before differentiation, and gradients are
        enabled locally so the call also works under ``torch.no_grad()``.
        """
        x = x.to(self.get_device())
        with torch.enable_grad():
            leaf = x.detach().requires_grad_(True)
            total = self.forward(leaf).sum()
            (grad,) = torch.autograd.grad(total, leaf, create_graph=True)
        return grad

    def get_device(self):
        """Return the device of the first parameter, or the global ``device``
        when the module has no parameters."""
        first_param = next(self.parameters(), None)
        if first_param is None:
            return device
        return first_param.device


# --- Dataset ---
class SampleDataset(IterableDataset):
    """Iterable dataset yielding freshly sampled Gaussian batches.

    Each iteration pass yields ``steps_per_epoch`` tensors of shape
    ``(batch_size, input_dim)`` drawn from a zero-mean normal distribution with
    standard deviation ``distribution_std``, created on the global ``device``.
    Every item is already a full batch, so a DataLoader wrapping this dataset
    is expected to use ``batch_size=None``.

    The dataset is its own iterator: calling ``iter()`` resets the step
    counter, so two simultaneous iterations over the same instance share state.

    Args:
        input_dim: Number of features per sample.
        steps_per_epoch: Number of batches yielded per pass.
        batch_size: Number of samples per yielded batch.
        distribution_std: Standard deviation of the sampling distribution.
    """
    def __init__(self, input_dim, steps_per_epoch, batch_size, distribution_std=WGM_NOISE_STD):
        self.input_dim = input_dim
        self.steps_per_epoch = steps_per_epoch
        self.batch_size = batch_size
        self.distribution_std = distribution_std
        self.step_count = 0

    def __iter__(self):
        # Restart the pass; the instance itself acts as the iterator.
        self.step_count = 0
        return self

    def __next__(self):
        if self.step_count >= self.steps_per_epoch:
            raise StopIteration
        self.step_count += 1
        # The sampled tensor is returned directly (no extra copy needed).
        return torch.randn(self.batch_size, self.input_dim, device=device) * self.distribution_std
class TestDataset(TensorDataset):
    """Fixed set of zero-mean Gaussian points stored on the global ``device``.

    Args:
        n_samples: Number of points to draw.
        input_dim: Number of features per point.
        noise_std: Standard deviation the standard-normal draws are scaled by.
    """
    def __init__(self, n_samples, input_dim=INPUT_DIM, noise_std=WGM_NOISE_STD):
        # Drawn once at construction, so the set stays the same afterwards.
        points = torch.randn(n_samples, input_dim).mul(noise_std)
        super().__init__(points.to(device))

# --- Lightning Module ---
class GradientEstimationExperiment(pl.LightningModule):
    """Lightning module training an estimator whose input-gradient should
    match the gradient of a frozen target function.

    The training loss is selected by ``method``:

    * ``'direct_mse'`` -- MSE between estimator gradient and target gradient.
    * ``'wgm'``        -- loss built from target *values* only, see
      :meth:`_wgm_loss`.
    * ``'rfd_mse'``    -- MSE between the estimator's directional derivative
      and a central finite difference of the target along a random unit vector.
    * ``'surrogate'``  -- MSE between estimator output and target output.
    * ``'mixed'``      -- ``0.75 * surrogate + 0.25 * wgm``.

    Validation and test always report the MSE between the estimator gradient
    and the target gradient, whatever the training method.

    Args:
        target_mlp: Module exposing ``forward(x)`` and ``gradient(x)``; its
            parameters are frozen and it is put in eval mode.
        estimator_mlp: Trainable module exposing ``forward(x)`` and
            ``gradient(x)``.
        method: Name of the training loss (see above).
        lr: Learning rate of the Adam optimizer.
        hparams: Dict of hyperparameters saved via ``save_hyperparameters``;
            the keys ``wgm_noise_std``, ``rfd_epsilon``, ``wgm_n_samples`` and
            ``wgm_lambda`` are read back, with module-level defaults.
    """

    def __init__(self, target_mlp, estimator_mlp, method, lr, hparams):
        super().__init__()
        self.save_hyperparameters(hparams, ignore=['target_mlp', 'estimator_mlp'])
        self.target_mlp = target_mlp
        self.estimator_mlp = estimator_mlp
        self.method = method
        self.lr = lr
        # Variance of the Gaussian sampling distribution p, used for grad log p.
        self.p_std_sq = self.hparams.get("wgm_noise_std", WGM_NOISE_STD) ** 2
        self.rfd_epsilon = self.hparams.get("rfd_epsilon", RFD_EPSILON)
        # Stored for reference; not used by the losses below.
        self.wgm_n_samples = self.hparams.get("wgm_n_samples", WGM_N_SAMPLES)
        self.wgm_lambda = self.hparams.get("wgm_lambda", WGM_LAMBDA)

        # The target is a fixed reference function: never trained.
        for p in self.target_mlp.parameters():
            p.requires_grad = False
        self.target_mlp.eval()

        # One entry per validation step (python floats).
        self.val_mse_history = []

    def forward(self, x):
        """Return the estimator's gradient with respect to ``x``."""
        return self.estimator_mlp.gradient(x)

    # ------------------------------------------------------------------ #
    # Loss helpers
    # ------------------------------------------------------------------ #
    def _surrogate_loss(self, x):
        """MSE between target values and estimator values at ``x``."""
        j_true = self.target_mlp(x.detach()).squeeze()
        j_s_theta = self.estimator_mlp(x).squeeze()
        return F.mse_loss(j_true, j_s_theta)

    def _wgm_loss(self, x):
        """Loss using only target values J(x), never the target gradient.

        With ``s`` the estimator's input-gradient and ``p`` a zero-mean
        Gaussian of variance ``p_std_sq``, returns the batch mean of
        ``|s|^2 + 2 * J * (div s + s . grad log p)``.

        NOTE(review): this looks like the integration-by-parts form of a
        gradient-matching objective, assuming ``x`` is sampled from ``p`` --
        confirm against the method's reference.
        """
        x = x.clone().detach().requires_grad_(True)
        with torch.no_grad():
            j_val = self.target_mlp(x).squeeze()

        s_output = self.estimator_mlp(x)  # scalar per sample
        s_theta = torch.autograd.grad(s_output.sum(), x, create_graph=True)[0]

        # Exact divergence: sum over every input coordinate k of d s_k / d x_k.
        div_s = 0.0
        for k in range(x.shape[1]):
            grad_k = torch.autograd.grad(
                s_theta[:, k].sum(), x, create_graph=True, allow_unused=True
            )[0]
            # None means s_k does not depend on x, i.e. a zero contribution.
            if grad_k is not None:
                div_s = div_s + grad_k[:, k]

        grad_log_p = -x / self.p_std_sq

        term1 = (s_theta ** 2).sum(dim=1)
        term2 = 2 * j_val * (div_s + (s_theta * grad_log_p).sum(dim=1))
        return (term1 + term2).mean()

    def _rfd_loss(self, x):
        """MSE between the estimator's directional derivative and a central
        finite difference of the target along a random unit direction."""
        with torch.no_grad():
            v = torch.randn_like(x)
            v = v / (torch.norm(v, dim=1, keepdim=True) + 1e-8)
            x_detached = x.detach()
            j_plus = self.target_mlp(x_detached + self.rfd_epsilon * v).squeeze()
            j_minus = self.target_mlp(x_detached - self.rfd_epsilon * v).squeeze()
            fd_dir_deriv = (j_plus - j_minus) / (2 * self.rfd_epsilon)
        s_theta = self.estimator_mlp.gradient(x)
        s_theta_dot_v = (s_theta * v).sum(dim=1)
        return F.mse_loss(s_theta_dot_v, fd_dir_deriv)

    # ------------------------------------------------------------------ #
    # Lightning hooks
    # ------------------------------------------------------------------ #
    def training_step(self, batch, batch_idx):
        """Compute and log the training loss selected by ``self.method``.

        Raises:
            ValueError: If ``self.method`` is not one of the supported names.
        """
        x = batch[0] if isinstance(batch, (list, tuple)) else batch

        if self.method == 'direct_mse':
            grad_j_true = self.target_mlp.gradient(x.detach())
            s_theta = self.estimator_mlp.gradient(x)
            loss = F.mse_loss(s_theta, grad_j_true)
        elif self.method == 'wgm':
            loss = self._wgm_loss(x)
        elif self.method == 'rfd_mse':
            loss = self._rfd_loss(x)
        elif self.method == 'surrogate':
            loss = self._surrogate_loss(x)
        elif self.method == "mixed":
            loss = 0.75 * self._surrogate_loss(x) + 0.25 * self._wgm_loss(x)
        else:
            # A constant loss would carry no gradient; fail with a clear message.
            raise ValueError(f"Unknown method {self.method!r}")

        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log and record the gradient MSE on a validation batch."""
        x = batch[0] if isinstance(batch, (list, tuple)) else batch
        grad_j_true = self.target_mlp.gradient(x.detach())
        s_theta = self.forward(x)  # estimator_mlp.gradient(x)
        val_mse = F.mse_loss(s_theta, grad_j_true)
        self.log('val_mse', val_mse, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.val_mse_history.append(val_mse.item())
        return val_mse

    def test_step(self, batch, batch_idx):
        """Log the gradient MSE on a test batch."""
        x = batch[0] if isinstance(batch, (list, tuple)) else batch
        grad_j_true = self.target_mlp.gradient(x.detach())
        s_theta = self.forward(x)
        test_mse = F.mse_loss(s_theta, grad_j_true)
        self.log('test_mse', test_mse, on_step=False, on_epoch=True, logger=True)
        return test_mse

    def configure_optimizers(self):
        """Adam on the estimator parameters with per-epoch exponential decay."""
        optimizer = optim.Adam(self.estimator_mlp.parameters(), lr=self.lr)
        scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.999)
        return {"optimizer": optimizer,"lr_scheduler": {"scheduler": scheduler,"interval": "epoch", "frequency": 1}}

# ================= MAIN TRAINING SCRIPT ======================
if __name__ == "__main__":
    # Training methods to compare; each name must be handled by
    # GradientEstimationExperiment.training_step.
    methods_to_run = ['wgm', 'surrogate',"mixed"]

    # Test MSE per method, one entry per problem.
    all_results = {method: [] for method in methods_to_run}
    # Trained estimators per method, one entry per problem. The visualization
    # cell reads this dictionary.
    trained_estimators = {method: [] for method in methods_to_run}

    # A single target function is shared by every problem. Change the mode
    # here (e.g. "styblinski") to benchmark another function.
    target = AnalyticalTarget(mode="rosenbrock").to(device)

    for i in range(N_PROBLEMS):
        print(f"\n{'='*10} Training Problem {i+1}/{N_PROBLEMS} {'='*10}")

        # Fixed training set of 8 batches, re-sampled for every problem and
        # shared by all methods within the problem.
        train_x = torch.randn(8 * BATCH_SIZE, INPUT_DIM) * WGM_NOISE_STD
        train_ds = TensorDataset(train_x.to(device))

        # Validation batches are sampled on the fly; the test set is fixed.
        val_dataset   = SampleDataset(INPUT_DIM, VAL_STEPS_PER_EPOCH, BATCH_SIZE)
        X_test = torch.randn(N_TEST_SAMPLES, INPUT_DIM, device=device) * WGM_NOISE_STD
        test_dataset = TensorDataset(X_test)

        train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
        # batch_size=None: SampleDataset already yields whole batches.
        val_loader   = DataLoader(val_dataset, batch_size=None, num_workers=0)
        test_loader  = DataLoader(test_dataset, batch_size=BATCH_SIZE, num_workers=0)
        for method in methods_to_run:
            print(f"\n--- Training method: {method} ---")
            estimator_mlp = GradientMLP2(INPUT_DIM, N_LAYERS, HIDDEN_DIM).to(device)
            hparams_dict = {
                "input_dim": INPUT_DIM, "hidden_dim": HIDDEN_DIM, "n_layers": N_LAYERS,
                "lr": LEARNING_RATE, "method": method, "wgm_lambda": WGM_LAMBDA,
                "wgm_noise_std": WGM_NOISE_STD, "wgm_n_samples": WGM_N_SAMPLES,
                "rfd_epsilon": RFD_EPSILON, "target_function": target.mode,
                "problem_idx": i, "epochs": MAX_EPOCHS, "batch_size": BATCH_SIZE,
            }
            experiment = GradientEstimationExperiment(target, estimator_mlp, method, LEARNING_RATE, hparams_dict)
            logger = pl.loggers.TensorBoardLogger("lightning_logs/", name=f"problem_{i}", version=f"{method}")
            trainer = pl.Trainer(
                max_epochs=MAX_EPOCHS, accelerator="auto", devices=1, logger=logger,
                enable_checkpointing=False, enable_progress_bar=True,
                enable_model_summary=False,  # quieter output
                num_sanity_val_steps=0,
                # Validation/test steps call autograd to obtain gradients, so
                # inference mode must stay disabled.
                inference_mode=False,
                log_every_n_steps=TRAIN_STEPS_PER_EPOCH // 5,
            )

            print(f"    Starting training for {method}...")
            trainer.fit(experiment, train_loader, val_loader)
            print(f"    Finished training for {method}.")

            print(f"    Starting testing for {method}...")
            test_results = trainer.test(experiment, dataloaders=test_loader, verbose=False)

            # Record NaN when no test metric is available so that every
            # method keeps one entry per problem.
            final_mse = test_results[0].get('test_mse') if test_results else None
            if final_mse is not None:
                all_results[method].append(final_mse)
                print(f"    Method: {method}, Final Test MSE (Gradient): {final_mse:.6f}")
            else:
                all_results[method].append(float('nan'))

            # Store the trained estimator on CPU in eval mode for later plotting.
            trained_estimators[method].append(experiment.estimator_mlp.cpu().eval())

    # --- Final Summary ---
    print(f"\n{'='*10} Final Results Summary {'='*10}")
    for method in methods_to_run:
        results = [res for res in all_results[method] if not np.isnan(res)]
        if results:
            # Mean and standard deviation over the problems that produced a metric.
            avg_mse = np.mean(results)
            std_mse = np.std(results)
            print(f"Method: {method:<12} | Avg Test MSE (Gradient): {avg_mse:.6f} +/- {std_mse:.6f}")
        else:
            print(f"Method: {method:<12} | No successful runs.")

    # 'trained_estimators', 'methods_to_run' and 'target' stay defined at
    # module level and are used by the visualization cell.
    print("\nTraining complete. Models stored in 'trained_estimators' dictionary.")

In [None]:
# ================= VISUALIZATION SCRIPT ======================
# Make sure this part is run AFTER the main training script above,
# so that 'target' and 'trained_estimators' are defined.

import matplotlib.pyplot as plt
# La ligne suivante n'est plus strictement nécessaire avec les versions récentes
# de matplotlib, mais elle était classiquement utilisée pour activer les outils 3D.
# from mpl_toolkits.mplot3d import Axes3D

print(f"\n{'='*10} Generating Visualizations {'='*10}")

# --- 1. Define the visualization grid ---
# Adjust the limits to the target function and to WGM_NOISE_STD: a small
# sampling std calls for tighter limits.
# For 'styblinski' the interesting minima are around +/-2.9.
# For 'ackley' the minimum is at (0, 0).
# A range of [-5, 5] or [-7, 7] is often a good start for these functions;
# [-10, 10] shows the global behaviour.

lim_plot = 3  # half-width of the square plotting window
x_vis = np.linspace(-lim_plot, lim_plot, 100)
y_vis = np.linspace(-lim_plot, lim_plot, 100)
X_vis, Y_vis = np.meshgrid(x_vis, y_vis)

# Flatten the grid into rows of (x, y) points, the (batch_size, input_dim)
# layout the PyTorch models are called with.
xy_flat = np.stack([X_vis.ravel(), Y_vis.ravel()], axis=-1)
xy_tensor = torch.tensor(xy_flat, dtype=torch.float32).cpu()  # keep on CPU for plotting

# --- Fonction utilitaire pour tracer ---
def plot_3d_surface_custom(X, Y, Z, title_str, fig_ax_pair, cmap_choice=plt.cm.viridis, z_offset_contours=-1.0):
    """Draw a meshed 3D surface with contour lines projected below it.

    Args:
        X, Y, Z: Grid arrays passed to ``plot_surface`` and ``contour``.
        title_str: Title of the axes.
        fig_ax_pair: Tuple ``(fig, ax)``; ``ax`` must be a 3D axes.
        cmap_choice: Colormap shared by the surface and the contours.
        z_offset_contours: Offset added to ``min(Z)`` to place the contour
            plane (ignored for a flat surface, where ``min(Z) - 0.5`` is used).

    Returns:
        The surface object returned by ``ax.plot_surface``.
    """
    figure, axes = fig_ax_pair

    # Surface with a thin black mesh.
    surface = axes.plot_surface(X, Y, Z, cmap=cmap_choice,
                                linewidth=0.5, edgecolor='k',
                                antialiased=True)

    # Contour plane: shifted from the minimum of Z, with a fixed fallback
    # when the surface is numerically flat.
    z_low = np.min(Z)
    z_high = np.max(Z)
    if (z_high - z_low) > 1e-6:
        contour_z = z_low + z_offset_contours
    else:
        contour_z = z_low - 0.5

    axes.contour(X, Y, Z, zdir='z', offset=contour_z,
                 levels=15, cmap=cmap_choice,
                 linewidths=1)

    # Labels and title.
    axes.set_xlabel('Axe X')
    axes.set_ylabel('Axe Y')
    axes.set_zlabel('Axe Z')
    axes.set_title(title_str)

    # Z limits: leave some room below the contour plane and above the maximum.
    z_floor = contour_z - abs(0.5 * z_offset_contours)
    if z_high >= 0:
        z_ceiling = z_high * 1.1
    else:
        z_ceiling = z_high * 0.9
    if abs(z_ceiling - z_floor) < 1e-6:
        # Degenerate range: widen it so the axes remain drawable.
        z_floor -= 0.5
        z_ceiling += 0.5
    axes.set_zlim(z_floor, z_ceiling)

    figure.colorbar(surface, ax=axes, shrink=0.5, aspect=10, label='Valeur Z')

    # Viewing angle.
    axes.view_init(elev=30., azim=45)
    return surface


# --- 2. Evaluate and plot the TARGET function ---
# Evaluate the target on CPU in eval mode over the flattened grid, then
# reshape the values back to the grid shape for plotting.
target_cpu = target.cpu().eval()
with torch.no_grad():
    Z_target_flat = target_cpu(xy_tensor).numpy()
Z_target = Z_target_flat.reshape(X_vis.shape)

fig_target = plt.figure(figsize=(10, 8))
ax_target = fig_target.add_subplot(111, projection='3d')
plot_3d_surface_custom(X_vis, Y_vis, Z_target,
                         f"Fonction Cible: {target_cpu.mode.capitalize()}",
                         (fig_target, ax_target),
                         cmap_choice=plt.cm.viridis) # viridis for the target
plt.tight_layout()
plt.savefig(f'surface_3d_target_{target_cpu.mode}.png', dpi=300)
plt.show()


# --- 3. Evaluate and plot the ESTIMATORS ---
# trained_estimators holds one model per problem for each method; index 0
# selects the estimator trained on the first problem.
problem_idx_to_plot = 0 # first problem

# NOTE(review): the [:1] slice restricts plotting to the first method of
# methods_to_run only -- confirm this is intended.
for method in methods_to_run[:1]:
    if trained_estimators[method] and len(trained_estimators[method]) > problem_idx_to_plot:
        estimator_model = trained_estimators[method][problem_idx_to_plot]
        # The training script already stores models on CPU in eval mode; this
        # call re-applies both.
        estimator_model.cpu().eval()

        # Plot the estimator's scalar output (not its gradient) on the grid.
        with torch.no_grad():
            Z_estimator_flat = estimator_model(xy_tensor).numpy()
        Z_estimator = Z_estimator_flat.reshape(X_vis.shape)

        fig_est = plt.figure(figsize=(10, 8))
        ax_est = fig_est.add_subplot(111, projection='3d')
        plot_3d_surface_custom(X_vis, Y_vis, Z_estimator,
                                 f"Estimateur ({method.upper()}) pour {target_cpu.mode.capitalize()}",
                                 (fig_est, ax_est),
                                 cmap_choice=plt.cm.plasma) # different colormap to tell estimators from the target
        plt.tight_layout()
        # NOTE(review): saved at dpi=50 while the target figure uses dpi=300 -- confirm.
        plt.savefig(f'surface_3d_estimator_{method}_{target_cpu.mode}.png', dpi=50)
        plt.show()
    else:
        print(f"Pas de modèle entraîné trouvé pour la méthode '{method}' pour le problème {problem_idx_to_plot + 1} à visualiser.")

print(f"{'='*10} Visualizations Complete {'='*10}")

# Optionnel: si vous voulez afficher Cible et Estimateurs sur la même figure avec des subplots
# (Ceci est plus complexe si les échelles Z sont très différentes)

# num_plots = 1 + len(methods_to_run)
# cols = 2
# rows = (num_plots + cols - 1) // cols

# fig_all = plt.figure(figsize=(7 * cols, 6 * rows))
# plot_idx = 1

# # Plot Target
# ax_all_target = fig_all.add_subplot(rows, cols, plot_idx, projection='3d')
# plot_3d_surface_custom(X_vis, Y_vis, Z_target,
#                          f"Cible: {target_cpu.mode.capitalize()}",
#                          (fig_all, ax_all_target),
#                          cmap_choice=plt.cm.viridis)
# plot_idx += 1

# # Plot Estimators
# for method in methods_to_run:
#     if trained_estimators[method] and len(trained_estimators[method]) > problem_idx_to_plot:
#         estimator_model = trained_estimators[method][problem_idx_to_plot].cpu().eval()
#         with torch.no_grad():
#             Z_estimator_flat = estimator_model(xy_tensor).numpy()
#         Z_estimator = Z_estimator_flat.reshape(X_vis.shape)

#         ax_all_est = fig_all.add_subplot(rows, cols, plot_idx, projection='3d')
#         plot_3d_surface_custom(X_vis, Y_vis, Z_estimator,
#                                  f"Estim. ({method.upper()})", # Titre plus court pour subplot
#                                  (fig_all, ax_all_est),
#                                  cmap_choice=plt.cm.plasma)
#         plot_idx += 1

# plt.tight_layout(pad=3.0) # Ajouter du padding pour éviter chevauchement des titres/labels
# plt.savefig(f'surface_3d_comparison_{target_cpu.mode}.png', dpi=300)
# plt.show()