# Stage A Training with YAML Configurations

This notebook shows how to launch the Stage A training of **PhysicallyInformedAE** by relying exclusively on the YAML configuration files shipped with the project. The workflow mirrors `physae.py` while keeping everything declarative: parameter ranges, noise settings, and spectroscopic transitions are all loaded from the YAML bundles in `project/config/data/`.

> ℹ️ **Tip:** You can adapt the same pattern for Stage B or the denoiser simply by swapping the training helper that is imported at the end of the notebook.

## 1. Locate the repository and configuration files

When the notebook is opened from `project/examples/notebooks`, the helper below walks up the directory tree from the current working directory until it finds `project/__init__.py`, which marks the repository root. Adjust `CONFIG_DIR` if you keep your YAML files elsewhere.

In [None]:
from pathlib import Path
import sys

NOTEBOOK_DIR = Path.cwd().resolve()
PROJECT_ROOT = NOTEBOOK_DIR
while not (PROJECT_ROOT / 'project' / '__init__.py').exists():
    if PROJECT_ROOT.parent == PROJECT_ROOT:
        raise RuntimeError('Could not locate the repository root; run this notebook from within the physae repo.')
    PROJECT_ROOT = PROJECT_ROOT.parent

CONFIG_DIR = PROJECT_ROOT / 'project' / 'config' / 'data'
print(f'Repository root: {PROJECT_ROOT}')
print(f'Configuration directory: {CONFIG_DIR}')
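PROJECT_PACKAGE_DIR = str(PROJECT_ROOT / 'project')
if PROJECT_PACKAGE_DIR not in sys.path:  # avoid duplicate entries when the cell is re-run
    sys.path.insert(0, PROJECT_PACKAGE_DIR)
print(f'Python path primed with: {PROJECT_PACKAGE_DIR}')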
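
Before moving on, it can help to confirm that the three YAML files used in the next section are actually present. The helper below is a minimal sketch: `missing_config_files` and `EXPECTED_FILES` are hypothetical names, not part of the project, and the file names are simply those referenced later in this notebook.

```python
from pathlib import Path

# File names taken from the loading cell in section 2 of this notebook.
EXPECTED_FILES = (
    'parameters_default.yaml',
    'noise_default.yaml',
    'transitions_sample.yaml',
)


def missing_config_files(config_dir, names=EXPECTED_FILES):
    """Return the entries of `names` that are not regular files in `config_dir`."""
    config_dir = Path(config_dir)
    return [name for name in names if not (config_dir / name).is_file()]
```

Calling `missing_config_files(CONFIG_DIR)` should return an empty list when the default bundles are in place; any names it does return point to files you need to supply or rename.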

## 2. Load YAML-driven parameter ranges, noise profile, and transitions

The helper functions from `config.data_config` parse the YAML files and update the global normalisation tables (`config.params.NORM_PARAMS`, `LOG_SCALE_PARAMS`) so that datasets and models stay consistent.

In [None]:
from config.data_config import load_parameter_ranges, load_noise_profile, load_transitions
from config.params import NORM_PARAMS, LOG_SCALE_PARAMS

param_ranges = load_parameter_ranges(CONFIG_DIR / 'parameters_default.yaml')
noise_profile = load_noise_profile(CONFIG_DIR / 'noise_default.yaml')
transitions, poly_freq = load_transitions(CONFIG_DIR / 'transitions_sample.yaml', include_poly_freq=True)

# The sample catalogue includes H2O lines, as in physae.py.
# They require a mole-fraction range; we mirror physae.py's defaults here.
if 'H2O' in transitions and 'mf_H2O' not in NORM_PARAMS:
    h2o_range = (1.0e-7, 5.0e-4)
    NORM_PARAMS['mf_H2O'] = h2o_range
    param_ranges['mf_H2O'] = h2o_range
    LOG_SCALE_PARAMS.add('mf_H2O')

print('Loaded parameter ranges:')
for name, (lo, hi) in param_ranges.items():
    print(f'  - {name}: {lo:.3e} → {hi:.3e}')

print('\nTransitions summary:')
for mol, entries in transitions.items():
    print(f'  - {mol}: {len(entries)} lines, poly coeffs = {poly_freq.get(mol)}')
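
To make the role of the two tables concrete, here is a minimal sketch of how a `(lo, hi)` range and a log-scale flag could map a physical value onto `[0, 1]`. This is an assumption for illustration only: the project's actual normalisation lives in its own modules and may differ. `normalise`, `denormalise`, and the local `ranges` and `log_scale` variables are hypothetical stand-ins for `NORM_PARAMS` and `LOG_SCALE_PARAMS`.

```python
import math

# Hypothetical stand-ins for NORM_PARAMS and LOG_SCALE_PARAMS.
ranges = {'mf_H2O': (1.0e-7, 5.0e-4)}
log_scale = {'mf_H2O'}


def normalise(name, value):
    """Min-max scale `value` to [0, 1], in log10 space for log-scale parameters."""
    lo, hi = ranges[name]
    if name in log_scale:
        lo, hi, value = math.log10(lo), math.log10(hi), math.log10(value)
    return (value - lo) / (hi - lo)


def denormalise(name, unit_value):
    """Inverse of `normalise`: map a value in [0, 1] back to physical units."""
    lo, hi = ranges[name]
    if name in log_scale:
        lo, hi = math.log10(lo), math.log10(hi)
        return 10.0 ** (lo + unit_value * (hi - lo))
    return lo + unit_value * (hi - lo)
```

Under this assumed scheme, a mole fraction that spans several orders of magnitude is spread evenly across the unit interval instead of being squeezed near zero, which is the usual motivation for flagging a parameter as log-scaled.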

## 3. Build synthetic datasets from the YAML configuration

`TrainingConfig` consumes the YAML-driven ranges/noise dictionaries and exposes convenience helpers to create consistent training and validation datasets.

In [None]:
from config.training_config import TrainingConfig
from data.dataset import SpectraDataset
from torch.utils.data import DataLoader

config = TrainingConfig(
    n_points=800,
    n_train=4096,
    n_val=512,
    batch_size=32,
    train_ranges=param_ranges,
    val_ranges=param_ranges,
    noise_train=noise_profile,
    noise_val=noise_profile,
    learning_rates=(1e-4, 1e-5),
)

noise_train, noise_val = config.resolved_noise_profiles()
train_dataset = SpectraDataset(
    n_samples=config.n_train,
    num_points=config.n_points,
    poly_freq_CH4=poly_freq.get('CH4'),
    transitions_dict=transitions,
    sample_ranges=config.resolved_train_ranges(),
    with_noise=True,
    noise_profile=noise_train,
)
val_dataset = SpectraDataset(
    n_samples=config.n_val,
    num_points=config.n_points,
    poly_freq_CH4=poly_freq.get('CH4'),
    transitions_dict=transitions,
    sample_ranges=config.resolved_val_ranges(),
    with_noise=False,
    noise_profile=noise_val,
)

train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False, num_workers=4)

batch = next(iter(train_loader))
print('Mini-batch keys:', list(batch))
print('Noisy spectra shape:', batch['noisy_spectra'].shape)
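
When debugging a data pipeline, a summary of every entry in the mini-batch is often more useful than a single shape. The sketch below assumes the batch is a dictionary, as the cell above suggests. `describe_batch` is a hypothetical helper; it relies only on a `.shape` attribute, so it should work for tensors and arrays alike.

```python
def describe_batch(batch):
    """Map each key to its shape tuple, or to the value's type name if it has no shape."""
    summary = {}
    for key, value in batch.items():
        shape = getattr(value, 'shape', None)
        summary[key] = tuple(shape) if shape is not None else type(value).__name__
    return summary
```

Printing `describe_batch(batch)` right after fetching the first mini-batch gives a quick check that every tensor has the batch size and number of points you configured.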

## 4. Instantiate the Lightning module

The autoencoder uses the same defaults as `physae.py`. We pass the YAML-derived transitions and the optional polynomial frequency coefficients directly to the constructor.

In [None]:
from models.autoencoder import PhysicallyInformedAE

model = PhysicallyInformedAE(
    poly_freq_CH4=poly_freq.get('CH4'),
    transitions_dict=transitions,
    **config.model_kwargs(),
)

print('Predicting parameters:', model.predict_params)
print('Base learning rate:', model.base_lr)
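
Because the staged training relies on freezing and unfreezing parts of the network, it can be useful to check how many parameters are currently trainable. The sketch below assumes `model` behaves like a standard `torch.nn.Module`, which Lightning modules do. `count_parameters` is a hypothetical helper, not part of the project.

```python
def count_parameters(module):
    """Return (trainable, total) parameter counts for a torch.nn.Module-like object."""
    trainable = total = 0
    for param in module.parameters():
        n = param.numel()
        total += n
        if param.requires_grad:  # frozen parameters have requires_grad=False
            trainable += n
    return trainable, total
```

Calling `count_parameters(model)` before and after a training stage is a cheap way to confirm that the freezing logic touched the layers you expected.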

## 5. Launch a short Stage A run

We reuse the high-level helper from `training.stages` so the freezing/unfreezing logic matches the original training script. For interactive exploration, a single epoch with a reduced number of batches keeps runtime under a minute on CPU.

In [None]:
from training.stages import train_stage_A
from training.callbacks.loss_curves import LossCurvePlotCallback

callbacks = [LossCurvePlotCallback()]

train_stage_A(
    model,
    train_loader,
    val_loader,
    epochs=1,
    enable_progress_bar=True,
    callbacks=callbacks,
    limit_train_batches=2,
    limit_val_batches=2,
    accelerator='cpu',
    devices=1,
)
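
To see how small this smoke test is compared with a full epoch, the arithmetic below uses the values from this notebook. It assumes the loaders keep PyTorch's default `drop_last=False` and that `train_stage_A` forwards `limit_train_batches` to the Lightning `Trainer`, where an integer value caps the number of batches. `batches_per_epoch` is a hypothetical helper.

```python
import math


def batches_per_epoch(n_samples, batch_size):
    """Number of batches a DataLoader yields when drop_last=False."""
    return math.ceil(n_samples / batch_size)


full_train = batches_per_epoch(4096, 32)   # 128 batches in a full training epoch
full_val = batches_per_epoch(512, 32)      # 16 batches in a full validation pass
capped_train_samples = 2 * 32              # limit_train_batches=2 -> 64 spectra
fraction = 2 / full_train                  # 0.015625, i.e. about 1.6% of an epoch
```

Under these assumptions the capped run therefore exercises the full code path while seeing only a small fraction of the training set, so its loss values should not be read as meaningful.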

## 6. Next steps

- Increase `epochs`, remove the `limit_*_batches` caps, and change `accelerator` and `devices` to target your GPU(s) for a full training run.
- Swap `train_stage_A` for `train_stage_B1`, `train_stage_B2`, or `train_refiner_idx` to continue the staged fine-tuning.
- Provide your own YAML files and re-run the notebook; the same code path will load them without changes.
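
For the first item, one way to choose the device arguments at runtime is sketched below. It assumes `train_stage_A` accepts the same `accelerator` and `devices` keywords used in section 5. `pick_accelerator` is a hypothetical helper, not part of the project.

```python
def pick_accelerator():
    """Return accelerator/devices keyword arguments based on CUDA availability."""
    try:
        import torch
    except ImportError:
        # Without PyTorch there is nothing to train on; fall back to CPU settings.
        return {'accelerator': 'cpu', 'devices': 1}
    if torch.cuda.is_available():
        return {'accelerator': 'gpu', 'devices': torch.cuda.device_count()}
    return {'accelerator': 'cpu', 'devices': 1}
```

The result can be unpacked into the training call with `**pick_accelerator()` in place of the hard-coded `accelerator='cpu', devices=1`.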