# Stage A Training: Backbone Encoder

This notebook walks through Stage A of the Physically Informed Autoencoder (PhysAE) training pipeline.
Stage A optimises the EfficientNet backbone and parameter prediction heads while the cascade refiners remain frozen.

## Prerequisites

* Python 3.10 or newer
* PyTorch, PyTorch Lightning, and the optional Lion optimizer (`pip install torch pytorch-lightning lion-pytorch`)
* (Optional) HITRAN TIPS_2021 partition files placed in a `QTpy/` directory at the repository root

When using a fresh environment, uncomment and run the first code cell below to install the runtime dependencies.

In [None]:
# %pip install -q torch pytorch-lightning lion-pytorch numpy pandas matplotlib scipy
# Uncomment the line above when running inside a clean environment.
# `%pip` (unlike `!pip`) installs into the environment of the running kernel.
# Pin versions (e.g. `torch==X.Y.Z`) if you need a reproducible environment.

In [None]:
from __future__ import annotations

from pathlib import Path
from types import SimpleNamespace
import sys

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint
from torch.utils.data import DataLoader
import torch

# Resolve the repository root: the first ancestor of the working directory that
# contains a `project/` folder.
REPO_ROOT = Path.cwd().resolve()
while not (REPO_ROOT / 'project').exists():
    if REPO_ROOT.parent == REPO_ROOT:
        raise RuntimeError('Run this notebook from inside the physae repository.')
    REPO_ROOT = REPO_ROOT.parent

# The imports below are top-level packages (`config`, `data`, `models`, ...).
# The data config files are read from `project/config/data`, so `config` lives
# under `project/`, and `project/` itself must be importable. REPO_ROOT is kept
# first on sys.path so any existing resolution from the repo root is unchanged.
PROJECT_ROOT = REPO_ROOT / 'project'
for import_root in (PROJECT_ROOT, REPO_ROOT):
    if str(import_root) not in sys.path:
        sys.path.insert(0, str(import_root))

from config.data_config import load_noise_profile, load_parameter_ranges, load_transitions
from config.params import PARAMS, NORM_PARAMS
from data.dataset import SpectraDataset
from models.autoencoder import PhysicallyInformedAE
from physics.tips import Tips2021QTpy, find_qtpy_dir
from training.callbacks.epoch_sync import UpdateEpochInDataset


In [None]:
# -----------------------------------------------------------------------------
# Configuration
# -----------------------------------------------------------------------------
stage_a = SimpleNamespace(
    seed=42,
    train_samples=4096,
    val_samples=512,
    num_points=1024,
    batch_size=64,
    num_workers=4,
    epochs=60,
    lr=1e-4,
    weight_decay=1e-2,
    mlp_dropout=0.10,
    refiner_dropout=0.05,
    backbone_variant='s',
    refiner_variant='s',
    refine_steps=1,
    refine_delta_scale=0.10,
    refine_warmup_epochs=30,
    freeze_base_epochs=20,
    stage3_lr_shrink=0.33,
    gpus=1,
    precision='32',
    log_every_n_steps=50,
    checkpoint_dir=REPO_ROOT / 'checkpoints' / 'stage_a',
    log_dir=REPO_ROOT / 'logs' / 'stage_a',
    qtpy_dir=REPO_ROOT / 'QTpy',
)

stage_a.checkpoint_dir.mkdir(parents=True, exist_ok=True)
stage_a.log_dir.mkdir(parents=True, exist_ok=True)

# `workers=True` makes Lightning seed NumPy, `random` and torch separately in each
# DataLoader worker (num_workers > 0 below). Without it, forked workers inherit
# identical NumPy/`random` states. That matters if SpectraDataset draws samples or
# noise from those RNGs (TODO confirm).
pl.seed_everything(stage_a.seed, workers=True)


In [None]:
def make_linear_frequency_grid(num_points: int, start: float = 5995.0, end: float = 6005.0) -> list[float]:
    """Return polynomial coefficients ``[c0, c1, c2]`` for a linear wavenumber grid.

    The step is chosen so that, assuming the consumer evaluates
    ``c0 + c1 * i + c2 * i**2`` for pixel indices ``i = 0 .. num_points - 1``
    (TODO confirm against SpectraDataset), the first pixel lands on ``start`` and
    the last pixel lands exactly on ``end``. The quadratic term is zero.

    Args:
        num_points: Number of spectral pixels; must be at least 2.
        start: Wavenumber of the first pixel (cm^-1).
        end: Wavenumber of the last pixel (cm^-1).

    Returns:
        ``[start, step, 0.0]`` with ``step = (end - start) / (num_points - 1)``.

    Raises:
        ValueError: If ``num_points < 2`` (a grid needs two points to span a range).
    """
    if num_points < 2:
        raise ValueError(f'num_points must be >= 2 to span [start, end], got {num_points}')
    # N points span N - 1 intervals; dividing by N would stop one step short of `end`.
    step = (end - start) / (num_points - 1)
    coeffs = [start, step, 0.0]
    print(f'Frequency grid: {num_points} points from {start} to {end} cm^-1')
    print(f'Polynomial coefficients: {coeffs}')
    return coeffs

# All data configuration files for this example live under project/config/data.
config_data_dir = REPO_ROOT / 'project' / 'config' / 'data'
parameters_path = config_data_dir / 'parameters_default.yaml'
noise_path = config_data_dir / 'noise_default.yaml'
transitions_path = config_data_dir / 'transitions_sample.yaml'

# NOTE(review): `parameter_ranges` is loaded but not used below; the datasets sample
# from `NORM_PARAMS`. Confirm which ranges Stage A is meant to draw from.
parameter_ranges = load_parameter_ranges(parameters_path)
noise_profile = load_noise_profile(noise_path)
all_transitions = load_transitions(transitions_path)

# For the quick-start example we keep only CH4 lines to match the default PARAMS list.
# Fail loudly if there are none. An empty line list would otherwise be passed
# silently to the datasets and the model.
ch4_transitions = all_transitions.get('CH4', [])
if len(ch4_transitions) == 0:
    raise ValueError(
        f'No CH4 transitions found in {transitions_path}; Stage A needs at least one CH4 line.'
    )
transitions = {'CH4': ch4_transitions}

poly_freq = make_linear_frequency_grid(stage_a.num_points)

# Partition functions are optional: fall back to `tipspy = None` if the QTpy data is absent.
try:
    qtpy_dir = find_qtpy_dir(stage_a.qtpy_dir)
    tipspy = Tips2021QTpy(qtpy_dir, device='cpu')
    print(f'TIPS data loaded from: {qtpy_dir}')
except FileNotFoundError:
    tipspy = None
    print('QTpy directory not found; continuing without partition functions.')

# Pinned host memory only helps when batches are copied to a GPU. Without one,
# DataLoader just warns and ignores the setting.
pin_host_memory = stage_a.gpus > 0 and torch.cuda.is_available()


def _make_spectra_dataset(n_samples: int, freeze_noise: bool) -> SpectraDataset:
    """Build a noisy synthetic spectra dataset from the shared Stage A settings.

    Args:
        n_samples: Forwarded to ``SpectraDataset`` as ``n_samples``.
        freeze_noise: Forwarded to ``SpectraDataset``. The validation set passes
            ``True``, presumably to keep its noise fixed across epochs (confirm).
    """
    return SpectraDataset(
        n_samples=n_samples,
        num_points=stage_a.num_points,
        poly_freq_CH4=poly_freq,
        transitions_dict=transitions,
        sample_ranges=NORM_PARAMS,
        with_noise=True,
        noise_profile=noise_profile,
        freeze_noise=freeze_noise,
        tipspy=tipspy,
    )


def _make_loader(dataset: SpectraDataset, shuffle: bool) -> DataLoader:
    """Wrap ``dataset`` in a DataLoader using the Stage A batch and worker settings."""
    return DataLoader(
        dataset,
        batch_size=stage_a.batch_size,
        shuffle=shuffle,
        num_workers=stage_a.num_workers,
        pin_memory=pin_host_memory,
        # Persistent workers are only allowed (and only useful) when num_workers > 0.
        persistent_workers=stage_a.num_workers > 0,
    )


train_dataset = _make_spectra_dataset(stage_a.train_samples, freeze_noise=False)
val_dataset = _make_spectra_dataset(stage_a.val_samples, freeze_noise=True)

train_loader = _make_loader(train_dataset, shuffle=True)
val_loader = _make_loader(val_dataset, shuffle=False)

print(f'Training batches: {len(train_loader)}')
print(f'Validation batches: {len(val_loader)}')


In [None]:
# Constructor arguments for the autoencoder, taken from the Stage A configuration
# and the data objects built above.
model_kwargs = {
    'n_points': stage_a.num_points,
    'param_names': PARAMS,
    'poly_freq_CH4': poly_freq,
    'transitions_dict': transitions,
    'tipspy': tipspy,
    'lr': stage_a.lr,
    'mlp_dropout': stage_a.mlp_dropout,
    'refiner_mlp_dropout': stage_a.refiner_dropout,
    'backbone_variant': stage_a.backbone_variant,
    'refiner_variant': stage_a.refiner_variant,
    'refine_steps': stage_a.refine_steps,
    'refine_delta_scale': stage_a.refine_delta_scale,
    'refine_warmup_epochs': stage_a.refine_warmup_epochs,
    'freeze_base_epochs': stage_a.freeze_base_epochs,
    'stage3_lr_shrink': stage_a.stage3_lr_shrink,
}
model = PhysicallyInformedAE(**model_kwargs)

# Optimiser settings are assigned as attributes after construction.
# NOTE(review): values set this way bypass the constructor, so they may not be
# recorded in saved hyper-parameters. Confirm before reloading checkpoints.
model.weight_decay = stage_a.weight_decay
model.base_lr = stage_a.lr
model.refiner_lr = stage_a.lr

# Switch the model into Stage A mode (applied after the attributes above, as before).
model.set_stage_mode(
    'A',
    refine_steps=stage_a.refine_steps,
    delta_scale=stage_a.refine_delta_scale,
)
print(model)


In [None]:
# Keep the three best checkpoints by validation loss, plus the most recent one.
checkpoint_cb = ModelCheckpoint(
    dirpath=stage_a.checkpoint_dir,
    filename='physae-stage-a-{epoch:03d}-{val_loss:.4f}',
    monitor='val_loss',
    mode='min',
    save_top_k=3,
    save_last=True,
)
early_stop_cb = EarlyStopping(monitor='val_loss', mode='min', patience=15, verbose=True)
lr_monitor = LearningRateMonitor(logging_interval='epoch')
# NOTE(review): the loaders use persistent worker processes, which keep their own
# copies of the dataset. If this callback updates an epoch attribute on the dataset
# in the main process, workers may not see it. Verify UpdateEpochInDataset.
epoch_sync = UpdateEpochInDataset()

# Use the requested GPUs only when CUDA is actually available; otherwise one CPU device.
use_gpu = stage_a.gpus > 0 and torch.cuda.is_available()
if use_gpu:
    accelerator, devices = 'gpu', stage_a.gpus
else:
    accelerator, devices = 'cpu', 1

trainer = pl.Trainer(
    accelerator=accelerator,
    devices=devices,
    precision=stage_a.precision,
    max_epochs=stage_a.epochs,
    gradient_clip_val=1.0,
    callbacks=[checkpoint_cb, early_stop_cb, lr_monitor, epoch_sync],
    log_every_n_steps=stage_a.log_every_n_steps,
    default_root_dir=stage_a.log_dir,
    enable_model_summary=True,
)


In [None]:
# Run Stage A optimisation; validation runs on `val_loader` each epoch.
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)


## Next steps

* Explore the saved checkpoints under `checkpoints/stage_a/`.
* Launch Stage B1 training (refiners with the backbone frozen) once validation loss plateaus.
* Monitor TensorBoard logs from `logs/stage_a/` to inspect optimisation curves.