# Stage A Diagnostic Notebook

Ce notebook reproduit un entraînement *Stage A* minimal du modèle `PhysicallyInformedAE` afin de vérifier que les pertes diminuent correctement et que le modèle apprend bien.

Les étapes clés sont :
1. Préparation d'un jeu de données synthétique réduit (bruit inclus).
2. Instanciation du modèle avec la configuration `TrainingConfig`.
3. Lancement d'un Stage A de quelques époques, avec les callbacks qui propagent l'époque courante aux jeux de données d'entraînement et de validation.
4. Vérification rapide de la perte sur un mini-lot de validation et comparaison des paramètres prédits aux paramètres cibles.


In [None]:
import os
import sys
from pathlib import Path

import torch
import pytorch_lightning as pl

# Locate the repository root: the nearest directory (cwd or one of its
# ancestors) that contains the `project` package. This keeps the
# `project.*` imports below working when the notebook is launched from a
# sub-directory such as `notebooks/`. If no such directory is found, fall
# back to the current working directory (the previous behaviour).
_cwd = Path.cwd()
PROJECT_ROOT = next(
    (
        candidate
        for candidate in (_cwd, *_cwd.parents)
        if (candidate / "project").is_dir()
    ),
    _cwd,
)
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

# Seed the RNGs so the diagnostic run is reproducible.
pl.seed_everything(42)
# Allow faster reduced-precision float32 matmul kernels where supported.
torch.set_float32_matmul_precision("high")


In [None]:
from project.config.training_config import TrainingConfig
from project.data.dataset import SpectraDataset
from project.data.loaders import create_dataloader
from project.models.autoencoder import PhysicallyInformedAE
from project.training.callbacks.epoch_sync import (
    UpdateEpochInDataset,
    UpdateEpochInValDataset,
)
from project.training.stages import train_stage_A

# Compact configuration for a fast diagnostic run.
cfg = TrainingConfig(
    n_points=512,
    n_train=512,
    n_val=128,
    batch_size=32,
    learning_rates=(3e-4, 1e-4),
)

# No explicit spectral lines for this quick test.
transitions_dict = {}
poly_freq = None

# Constructor arguments common to the training and validation datasets.
shared_dataset_kwargs = dict(
    num_points=cfg.n_points,
    poly_freq_CH4=poly_freq,
    transitions_dict=transitions_dict,
    with_noise=True,
)

# Training set: parameter draws are frozen; `freeze_noise` is not passed,
# so the dataset's default applies.
train_dataset = SpectraDataset(
    n_samples=cfg.n_train,
    sample_ranges=cfg.resolved_train_ranges(),
    **shared_dataset_kwargs,
)
train_dataset.freeze_parameter_draws(True)

# Validation set: both the noise and the parameter draws are frozen
# (presumably so every validation pass sees identical samples).
val_dataset = SpectraDataset(
    n_samples=cfg.n_val,
    sample_ranges=cfg.resolved_val_ranges(),
    freeze_noise=True,
    **shared_dataset_kwargs,
)
val_dataset.freeze_parameter_draws(True)

# One loader per split, created in order (train, then val); only the
# training loader shuffles. num_workers=0 for both.
train_loader, val_loader = (
    create_dataloader(
        split_dataset,
        batch_size=cfg.batch_size,
        shuffle=should_shuffle,
        num_workers=0,
    )
    for split_dataset, should_shuffle in (
        (train_dataset, True),
        (val_dataset, False),
    )
)


In [None]:
# Pull one training batch and print the tensor shapes it contains.
batch = next(iter(train_loader))
for shape_label, batch_key in (
    ("Noisy spectra shape:", "noisy_spectra"),
    ("Clean spectra shape:", "clean_spectra"),
    ("Params shape:", "params"),
):
    print(shape_label, batch[batch_key].shape)

# Mean absolute value of the noisy input, as a quick scale sanity check.
noisy_abs_mean = batch["noisy_spectra"].abs().mean().item()
print("Noisy mean |value|:", noisy_abs_mean)


In [None]:
# Build the Stage A model: the config's base model kwargs, with the
# Stage-A-specific overrides applied on top (dict.update: override keys win).
model_kwargs = cfg.model_kwargs()
model_kwargs.update(cfg.stage_overrides("A"))
model = PhysicallyInformedAE(
    **model_kwargs,
    transitions_dict=transitions_dict,
    poly_freq_CH4=poly_freq,
)

# Epoch-sync callbacks for the train and validation datasets (per their
# names, they push the current epoch into the datasets — see epoch_sync).
callbacks = [
    UpdateEpochInDataset(),
    UpdateEpochInValDataset(),
]

# Short diagnostic run. The learning rate is read from the config rather
# than re-typed here, so the notebook and the config cannot drift apart.
# NOTE(review): assumes learning_rates[0] is the Stage A rate — confirm
# against TrainingConfig.
STAGE_A_EPOCHS = 5
trained_model = train_stage_A(
    model,
    train_loader,
    val_loader,
    epochs=STAGE_A_EPOCHS,
    base_lr=cfg.learning_rates[0],
    enable_progress_bar=True,
    callbacks=callbacks,
    accelerator="cpu",
    devices=1,
)


In [None]:
import torch.nn.functional as F

# Inference mode (e.g. disables dropout) for the post-training checks.
trained_model.eval()
val_batch = next(iter(val_loader))

# Loss on one validation batch, computed through the model's own step.
with torch.no_grad():
    loss_val = trained_model._common_step(val_batch, "val")

print(f"Validation loss (snapshot après Stage A): {loss_val.item():.6f}")

# Quick comparison of predicted parameters vs. targets.
with torch.no_grad():
    # unsqueeze(1) inserts a singleton axis at dim 1 (presumably the channel
    # axis the backbone expects).
    latent, _ = trained_model.backbone(val_batch["noisy_spectra"].unsqueeze(1))
    feats = trained_model._pool_features(latent)
    params_pred_norm = trained_model._predict_params_from_features(feats)

targets = val_batch["params"]
print("Mean predicted param (normed):", params_pred_norm.mean().item())
print("Mean target param (normed):", targets.mean().item())

# Global means can agree even when the model ignores its input (e.g. it
# always predicts the average parameter vector), so also report the mean
# absolute error per parameter over the batch.
# NOTE(review): assumes `targets` is in the same normalized space as
# `params_pred_norm` — confirm against SpectraDataset.
if params_pred_norm.shape == targets.shape:
    per_param_mae = (params_pred_norm - targets).abs().mean(dim=0)
    print("Per-parameter MAE (normed):", per_param_mae.tolist())
else:
    print(
        "Shape mismatch, skipping per-parameter MAE:",
        tuple(params_pred_norm.shape),
        tuple(targets.shape),
    )
