# PhysAE data generation walkthrough

This notebook documents a compact scenario for building synthetic spectra with `SpectraDataset` while keeping track of the clean and noisy normalisation steps.


## Notebook overview

This workflow covers:
- Loading the baseline data configuration and narrowing the training domain.
- Cloning and clamping validation ranges so they stay nested inside the training space.
- Registering the normalisation parameters used by `SpectraDataset`.
- Instantiating train and validation datasets with consistent noise settings.
- Inspecting individual samples with denormalised parameters.
- Visualising the clean and noisy spectra before and after normalisation.


In [None]:
from __future__ import annotations

from copy import deepcopy
from typing import Dict, Iterable, Mapping, Tuple

try:
    import torch
except ImportError as exc:
    raise RuntimeError("PyTorch is required to run this notebook end-to-end.") from exc

try:
    import matplotlib.pyplot as plt
except ImportError as exc:
    raise RuntimeError("Matplotlib is required for the visualisations.") from exc

from physae.config_loader import load_data_config
from physae import config as physae_config
from physae.config import PARAMS, assert_subset
from physae.dataset import SpectraDataset
from physae.normalization import unnorm_param_tensor
from physae.physics import batch_physics_forward_multimol_vgrid, parse_csv_transitions

torch.set_default_dtype(torch.float32)

def to_interval_dict(mapping: Mapping[str, Iterable[float]]) -> Dict[str, Tuple[float, float]]:
    # Unpack rather than index so any two-element iterable (list or tuple) is accepted.
    return {name: (float(lo), float(hi)) for name, (lo, hi) in mapping.items()}

def display_ranges(title: str, ranges: Mapping[str, Tuple[float, float]]) -> None:
    print(title)
    for name, (lo, hi) in ranges.items():
        print(f"  {name:>9}: [{lo:.6g}, {hi:.6g}]")

def normalise_noise_cfg(cfg: Mapping[str, object]) -> Dict[str, object]:
    result: Dict[str, object] = {}
    for key, value in cfg.items():
        if isinstance(value, list):
            result[key] = tuple(float(v) for v in value)
        else:
            result[key] = value
    return result

def denorm_params(params_tensor: torch.Tensor) -> Dict[str, torch.Tensor]:
    values: Dict[str, torch.Tensor] = {}
    for idx, name in enumerate(PARAMS):
        values[name] = unnorm_param_tensor(name, params_tensor[..., idx])
    return values

def regenerate_clean_spectrum(
    params_phys: Mapping[str, float],
    *,
    num_points: int,
    poly_freq_CH4,
    transitions_dict,
) -> torch.Tensor:
    device = torch.device('cpu')
    dtype = torch.float32
    sig0 = torch.tensor([params_phys['sig0']], dtype=dtype, device=device)
    dsig = torch.tensor([params_phys['dsig']], dtype=dtype, device=device)
    mf_CH4 = torch.tensor([params_phys['mf_CH4']], dtype=dtype, device=device)
    baseline_coeffs = torch.tensor([[
        params_phys['baseline0'],
        params_phys['baseline1'],
        params_phys['baseline2'],
    ]], dtype=dtype, device=device)
    pressure = torch.tensor([params_phys['P']], dtype=dtype, device=device)
    temperature = torch.tensor([params_phys['T']], dtype=dtype, device=device)
    v_grid_idx = torch.arange(num_points, dtype=dtype, device=device)
    spectra_clean, _ = batch_physics_forward_multimol_vgrid(
        sig0,
        dsig,
        poly_freq_CH4,
        v_grid_idx,
        baseline_coeffs,
        transitions_dict,
        pressure,
        temperature,
        {'CH4': mf_CH4},
        device=device,
    )
    return spectra_clean[0].detach().cpu()


## 1. Load the default data configuration

We start from the packaged `default` scenario and create a copy that we can modify without mutating the YAML-backed dictionary.


In [None]:
base_cfg = load_data_config(name='default')
custom_cfg = deepcopy(base_cfg)

base_train = to_interval_dict(custom_cfg['train_ranges_base'])
base_val = to_interval_dict(custom_cfg['val_ranges'])
display_ranges('Baseline train ranges', base_train)
display_ranges('Baseline validation ranges', base_val)
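
The reason for `deepcopy` rather than a plain copy is that the configuration is nested: a shallow copy shares the inner range lists with the original. A minimal sketch of the difference, using a hypothetical two-key dictionary in place of the real output of `load_data_config`:

```python
from copy import copy, deepcopy

# Hypothetical stand-in for the dictionary returned by load_data_config.
base_cfg = {"val_ranges": {"P": [400.0, 600.0]}, "n_points": 800}

shallow_cfg = copy(base_cfg)    # new outer dict, shared inner objects
deep_cfg = deepcopy(base_cfg)   # fully independent copy

# Editing the deep copy leaves the original untouched.
deep_cfg["val_ranges"]["P"][0] = 450.0
original_after_deep_edit = list(base_cfg["val_ranges"]["P"])

# Editing the shallow copy leaks into the original, because the nested
# list is the same object in both dictionaries.
shallow_cfg["val_ranges"]["P"][1] = 550.0
original_after_shallow_edit = list(base_cfg["val_ranges"]["P"])

print(original_after_deep_edit)
print(original_after_shallow_edit)
```

Only the shallow edit changes `base_cfg`, which is why the notebook works on a deep copy.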


## 2. Define a narrower training domain

The experiment below focuses on a tighter region of the parameter space. We override the training ranges accordingly while keeping the structure consistent with the YAML configuration.


In [None]:
narrow_train_ranges = {
    'sig0': (3085.435, 3085.447),
    'dsig': (0.0015235, 0.0015335),
    'mf_CH4': (5.0e-06, 1.4e-05),
    'baseline0': (0.995, 1.005),
    'baseline1': (-0.00038, -0.00031),
    'baseline2': (-3.8e-08, -3.2e-08),
    'P': (450.0, 550.0),
    'T': (305.0, 309.0),
}
custom_cfg['train_ranges'] = {name: [float(lo), float(hi)] for name, (lo, hi) in narrow_train_ranges.items()}
display_ranges('Custom train ranges', narrow_train_ranges)
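
Hand-typed ranges are easy to get wrong, especially for negative bounds such as `baseline1`, where the number with the larger magnitude is the lower bound. The sketch below is one possible sanity check; `validate_ranges` is a hypothetical helper, not part of `physae`.

```python
import math
from typing import Mapping, Tuple


def validate_ranges(ranges: Mapping[str, Tuple[float, float]]) -> None:
    """Raise ValueError if any interval is non-finite or has lo > hi."""
    for name, (lo, hi) in ranges.items():
        if not (math.isfinite(lo) and math.isfinite(hi)):
            raise ValueError(f"{name}: bounds must be finite, got ({lo}, {hi})")
        if lo > hi:
            raise ValueError(f"{name}: lower bound {lo} exceeds upper bound {hi}")


# Two entries copied from the narrowed ranges above pass the check.
validate_ranges({"P": (450.0, 550.0), "baseline1": (-0.00038, -0.00031)})

# A swapped pair is rejected.
try:
    validate_ranges({"baseline1": (-0.00031, -0.00038)})
    message = ""
except ValueError as err:
    message = str(err)

print(message)
```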


## 3. Clamp validation ranges inside the new training domain

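After shrinking the training region we clone the validation ranges and clamp every interval (including `mf_CH4`) so that it lies within the corresponding training interval. If a validation interval does not overlap the training interval at all, it collapses to the midpoint of the training interval. Parameters without a validation entry fall back to the training interval itself.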


In [None]:
original_val = deepcopy(custom_cfg.get('val_ranges', {}))
adjusted_val = {}
for name, train_interval in custom_cfg['train_ranges'].items():
    train_min, train_max = map(float, train_interval)
    val_min, val_max = map(float, original_val.get(name, train_interval))
    adj_min = max(val_min, train_min)
    adj_max = min(val_max, train_max)
    if adj_min > adj_max:
        centre = 0.5 * (train_min + train_max)
        adj_min = adj_max = centre
    adjusted_val[name] = [adj_min, adj_max]
custom_cfg['val_ranges'] = adjusted_val
val_ranges = to_interval_dict(custom_cfg['val_ranges'])
assert_subset(val_ranges, narrow_train_ranges, 'validation', 'train')
display_ranges('Adjusted validation ranges', val_ranges)
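
To make the three possible outcomes of the clamping rule concrete, here is the same logic applied to toy pressure intervals. `clamp_interval` is a hypothetical helper that mirrors the loop above; the numeric validation intervals are made up for illustration.

```python
from typing import Tuple

Interval = Tuple[float, float]


def clamp_interval(val: Interval, train: Interval) -> Interval:
    """Clamp val into train; collapse to the train midpoint if they are disjoint."""
    lo = max(val[0], train[0])
    hi = min(val[1], train[1])
    if lo > hi:
        centre = 0.5 * (train[0] + train[1])
        return (centre, centre)
    return (lo, hi)


train = (450.0, 550.0)
overlapping = clamp_interval((400.0, 500.0), train)  # partly outside the train interval
contained = clamp_interval((480.0, 520.0), train)    # already nested
disjoint = clamp_interval((600.0, 700.0), train)     # no overlap at all

print(overlapping, contained, disjoint)
```

The disjoint case yields a zero-width interval, so every validation sample for that parameter would take the same value. Whether that is acceptable depends on the experiment.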


## 4. Register normalisation parameters

`SpectraDataset` pulls its parameter normalisation bounds from `physae.config.NORM_PARAMS`. We therefore register the updated training ranges before instantiating any dataset objects.


In [None]:
physae_config.set_norm_params({name: (float(lo), float(hi)) for name, (lo, hi) in narrow_train_ranges.items()})
physae_config.get_norm_params()
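
The notebook does not show how the registered bounds are applied internally. A common choice is min-max scaling to [0, 1], and the sketch below assumes exactly that; `norm_param` and `unnorm_param` are hypothetical names, not the `physae` API. It illustrates why the bounds must be registered before any dataset is built: the same normalised value decodes to a different physical value under different bounds.

```python
from typing import Tuple

Bounds = Tuple[float, float]


def norm_param(value: float, bounds: Bounds) -> float:
    """Map a physical value to [0, 1] (assumed min-max convention)."""
    lo, hi = bounds
    return (value - lo) / (hi - lo)


def unnorm_param(unit_value: float, bounds: Bounds) -> float:
    """Inverse of norm_param under the same bounds."""
    lo, hi = bounds
    return lo + unit_value * (hi - lo)


narrow_bounds = (450.0, 550.0)  # pressure range registered in this notebook
wide_bounds = (400.0, 600.0)    # made-up wider range, for contrast only

unit = norm_param(475.0, narrow_bounds)
restored = unnorm_param(unit, narrow_bounds)     # round trip with matching bounds
mismatched = unnorm_param(unit, wide_bounds)     # decoding with the wrong bounds

print(unit, restored, mismatched)
```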


## 5. Instantiate train and validation datasets

With the ranges and normalisation in place we build both datasets. The two constructors differ only in sample count, sampling ranges, and noise profile, plus one flag: the validation split sets `freeze_noise=True` for reproducibility.


In [None]:
torch.manual_seed(int(custom_cfg.get('seed', 42)))

poly_freq_CH4 = [-2.3614803e-07, 1.2103413e-10, -3.1617856e-14]
transitions_ch4_str = """
6;1;3085.861015;1.013E-19;0.06;0.078;219.9411;0.73;-0.00712;0.0;0.0221;0.96;0.584;1.12
6;1;3085.832038;1.693E-19;0.0597;0.078;219.9451;0.73;-0.00712;0.0;0.0222;0.91;0.173;1.11
6;1;3085.893769;1.011E-19;0.0602;0.078;219.9366;0.73;-0.00711;0.0;0.0184;1.14;-0.516;1.37
6;1;3086.030985;1.659E-19;0.0595;0.078;219.9197;0.73;-0.00711;0.0;0.0193;1.17;-0.204;0.97
6;1;3086.071879;1.000E-19;0.0585;0.078;219.9149;0.73;-0.00703;0.0;0.0232;1.09;-0.0689;0.82
6;1;3086.085994;6.671E-20;0.055;0.078;219.9133;0.70;-0.00610;0.0;0.0300;0.54;0.00;0.0
"""
transitions_dict = {'CH4': parse_csv_transitions(transitions_ch4_str)}

n_points = int(custom_cfg.get('n_points', 800))
n_train = int(custom_cfg.get('n_train', 50000))
n_val = int(custom_cfg.get('n_val', 5000))
train_ranges = narrow_train_ranges
val_ranges = to_interval_dict(custom_cfg['val_ranges'])
noise_train = normalise_noise_cfg(custom_cfg['noise']['train'])
noise_val = normalise_noise_cfg(custom_cfg['noise']['val'])

train_dataset = SpectraDataset(
    n_samples=n_train,
    num_points=n_points,
    poly_freq_CH4=poly_freq_CH4,
    transitions_dict=transitions_dict,
    sample_ranges=train_ranges,
    strict_check=True,
    with_noise=True,
    noise_profile=noise_train,
    freeze_noise=False,
)
val_dataset = SpectraDataset(
    n_samples=n_val,
    num_points=n_points,
    poly_freq_CH4=poly_freq_CH4,
    transitions_dict=transitions_dict,
    sample_ranges=val_ranges,
    strict_check=True,
    with_noise=True,
    noise_profile=noise_val,
    freeze_noise=True,
)
len(train_dataset), len(val_dataset)
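
How `SpectraDataset` implements `freeze_noise` is not shown here. One common way to get repeatable noise is to seed a fresh generator from the sample index on every access, and the toy class below sketches that idea with the standard library only. `ToyNoisyDataset` and its `base_seed` argument are hypothetical.

```python
import random
from typing import List


class ToyNoisyDataset:
    """Hypothetical stand-in for a dataset with optional frozen noise."""

    def __init__(self, n_samples: int, freeze_noise: bool, base_seed: int = 0) -> None:
        self.n_samples = n_samples
        self.freeze_noise = freeze_noise
        self.base_seed = base_seed
        self._live_rng = random.Random(base_seed)

    def __len__(self) -> int:
        return self.n_samples

    def __getitem__(self, idx: int) -> List[float]:
        if self.freeze_noise:
            # Fresh generator seeded from the index: the draw repeats on every access.
            rng = random.Random(self.base_seed + idx)
        else:
            # Shared generator: its state advances on every access.
            rng = self._live_rng
        return [rng.gauss(0.0, 0.01) for _ in range(4)]


frozen = ToyNoisyDataset(n_samples=3, freeze_noise=True)
live = ToyNoisyDataset(n_samples=3, freeze_noise=False)

frozen_repeats = frozen[1] == frozen[1]
live_repeats = live[1] == live[1]
print(frozen_repeats, live_repeats)
```

With frozen noise, validation metrics computed on the same index are comparable across epochs; with live noise, each access sees a new realisation, which acts as augmentation during training.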


## 6. Inspect a sample and denormalise its parameters

We extract one training example and map its normalised parameter vector back to physical units with `denorm_params`. The resulting dictionary is what the next section uses to regenerate the clean spectrum.


In [None]:
sample = train_dataset[0]
params_norm = sample['params'].unsqueeze(0)
params_phys_tensors = denorm_params(params_norm)
params_phys = {name: tensor[0].item() for name, tensor in params_phys_tensors.items()}
params_phys


## 7. Reconstruct clean spectra and compare normalisations

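The `regenerate_clean_spectrum` helper defined in the setup cell rebuilds the clean spectrum from the denormalised parameters. The raw noisy spectrum is recovered by multiplying the normalised one by the sample's `scale`. We compare raw and normalised clean/noisy curves to make each normalisation step explicit.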


In [None]:
clean_raw = regenerate_clean_spectrum(
    params_phys,
    num_points=n_points,
    poly_freq_CH4=poly_freq_CH4,
    transitions_dict=transitions_dict,
)
noisy_raw = sample['noisy_spectra'] * sample['scale']
clean_norm = sample['clean_spectra']
noisy_norm = sample['noisy_spectra']

clean_max = clean_raw.max().item()
noisy_max = noisy_raw.max().item()
print(f'Clean raw max: {clean_max:.4f}')
print(f'Clean normalised max: {clean_norm.max().item():.4f}')
print(f'Noisy raw max: {noisy_max:.4f}')
print(f'Noisy normalised max: {noisy_norm.max().item():.4f}')
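
The multiplication by `sample['scale']` above undoes a per-sample scaling of the spectrum. How `scale` is computed is not shown in this notebook; the sketch below assumes it is the maximum of the noisy spectrum, which is only one plausible convention. `normalise_by_max` and the four-point spectrum are made up for illustration.

```python
from typing import List, Tuple


def normalise_by_max(spectrum: List[float]) -> Tuple[List[float], float]:
    """Divide by the maximum so the peak is 1.0; return the scale for later inversion."""
    scale = max(spectrum)
    return [v / scale for v in spectrum], scale


toy_noisy_raw = [0.98, 1.02, 0.75, 1.00]
toy_noisy_norm, toy_scale = normalise_by_max(toy_noisy_raw)

# Multiplying by the stored scale recovers the raw values up to rounding.
toy_recovered = [v * toy_scale for v in toy_noisy_norm]

print(max(toy_noisy_norm), toy_scale)
```

Under this assumption the normalised noisy maximum is exactly 1, while the normalised clean maximum can differ slightly if both curves are divided by the same noisy-derived scale.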


In [None]:
fig, axes = plt.subplots(2, 1, figsize=(10, 8), sharex=True)
axes[0].plot(clean_raw.numpy(), label='clean (raw)', linewidth=2)
axes[0].plot(noisy_raw.numpy(), label='noisy (raw)', linewidth=1, alpha=0.8)
axes[0].set_title('Raw spectra')
axes[0].legend()
axes[0].set_ylabel('Intensity')

axes[1].plot(clean_norm.numpy(), label='clean (normalised)', linewidth=2)
axes[1].plot(noisy_norm.numpy(), label='noisy (normalised)', linewidth=1, alpha=0.8)
axes[1].set_title('Normalised spectra')
axes[1].legend()
axes[1].set_xlabel('Spectral point index')
axes[1].set_ylabel('Scaled intensity')
plt.tight_layout()
plt.show()


## 8. Summary

- Validation intervals are cloned after overriding the training domain and clamped to guarantee they remain nested.
- The registered normalisation parameters align with the narrowed training space so every downstream component sees consistent scaling.
- The plots show clean and noisy spectra side by side in both raw and normalised form, making each step of the data-generation pipeline visible.
