# Stage B2 Training: Full-Model Fine-Tuning

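Stage B2 re-enables the EfficientNet backbone and jointly fine-tunes all modules with a small learning rate (`3e-6` by default).
For best results, start from the best Stage B1 checkpoint. The configuration below defaults to `checkpoints/stage_b1/last.ckpt`, which is typically the most recent checkpoint rather than the best one. Update `stage_b1_checkpoint` if they differ.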

## Prerequisites

* Completed Stage A and Stage B1 training
* PyTorch + PyTorch Lightning environment with matching dependencies
* Path to the Stage B1 checkpoint

In [None]:
# !pip install -q torch pytorch-lightning lion-pytorch numpy pandas matplotlib scipy
# Uncomment if you need to install the runtime dependencies.

In [None]:
from __future__ import annotations

from pathlib import Path
from types import SimpleNamespace
import sys

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint
from torch.utils.data import DataLoader
import torch

# Walk upwards from the working directory until we reach the repository root,
# recognised by the presence of a `project/` sub-directory.
REPO_ROOT = Path.cwd().resolve()
while not (REPO_ROOT / 'project').exists():
    if REPO_ROOT.parent == REPO_ROOT:
        raise RuntimeError('Run this notebook from inside the physae repository.')
    REPO_ROOT = REPO_ROOT.parent

# The data YAML files are read from `project/config/data/` further down, which
# suggests the importable packages (`config`, `data`, `models`, `physics`,
# `training`) live under `project/`. Register both candidate import roots:
# inserting PROJECT_ROOT first and REPO_ROOT second leaves REPO_ROOT at the
# front, so the original resolution still wins and `project/` is the fallback.
PROJECT_ROOT = REPO_ROOT / 'project'
for import_root in (PROJECT_ROOT, REPO_ROOT):
    if str(import_root) not in sys.path:
        sys.path.insert(0, str(import_root))

from config.data_config import load_noise_profile, load_parameter_ranges, load_transitions
from config.params import PARAMS, NORM_PARAMS
from data.dataset import SpectraDataset
from models.autoencoder import PhysicallyInformedAE
from physics.tips import Tips2021QTpy, find_qtpy_dir
from training.callbacks.epoch_sync import UpdateEpochInDataset


In [None]:
# All Stage B2 settings in one namespace; later cells read `stage_b2.<name>`.
stage_b2 = SimpleNamespace(
    seed=42,
    # Dataset sizes and spectrum length (passed to SpectraDataset / the model).
    train_samples=4096,
    val_samples=512,
    num_points=1024,
    # DataLoader settings.
    batch_size=48,
    num_workers=4,
    # Optimisation settings.
    epochs=30,
    base_lr=3e-6,
    refiner_lr=3e-6,
    weight_decay=1e-2,
    # Model construction overrides (forwarded to PhysicallyInformedAE).
    mlp_dropout=0.10,
    refiner_dropout=0.05,
    backbone_variant='s',
    refiner_variant='s',
    refine_steps=2,
    refine_delta_scale=0.05,
    stage3_lr_shrink=0.5,
    # Paths.
    qtpy_dir=REPO_ROOT / 'QTpy',
    checkpoint_dir=REPO_ROOT / 'checkpoints' / 'stage_b2',
    log_dir=REPO_ROOT / 'logs' / 'stage_b2',
    # NOTE: `last.ckpt` is typically the most recent Stage B1 checkpoint, not
    # necessarily the one with the lowest validation loss. Point this at the
    # best B1 checkpoint if it differs.
    stage_b1_checkpoint=REPO_ROOT / 'checkpoints' / 'stage_b1' / 'last.ckpt',
    # Trainer settings.
    gpus=1,
    precision='32',
    log_every_n_steps=50,
)

# Fail fast, before any directories are created, when the Stage B1 checkpoint
# is missing. An explicit raise is used instead of `assert`, because asserts
# are stripped when Python runs with `-O`.
if not stage_b2.stage_b1_checkpoint.exists():
    raise FileNotFoundError(f'Stage B1 checkpoint not found: {stage_b2.stage_b1_checkpoint}')

stage_b2.checkpoint_dir.mkdir(parents=True, exist_ok=True)
stage_b2.log_dir.mkdir(parents=True, exist_ok=True)

pl.seed_everything(stage_b2.seed)


In [None]:
def make_linear_frequency_grid(num_points: int, start: float = 5995.0, end: float = 6005.0) -> list[float]:
    """Return ``[start, step, 0.0]`` describing an evenly spaced frequency axis.

    The result is passed on as ``poly_freq_CH4``. Presumably the three values
    are polynomial coefficients (offset, linear, quadratic) giving the
    frequency of sample index ``i``; TODO confirm against SpectraDataset.

    ``step`` equals ``(end - start) / num_points``, so the span is divided into
    ``num_points`` intervals. Under the polynomial interpretation, the last of
    ``num_points`` samples therefore lands at ``end - step``, not at ``end``.
    Passing ``num_points == 0`` raises ZeroDivisionError.
    """
    span = end - start
    return [start, span / num_points, 0.0]

# Data configuration files shipped with the repository.
config_data_dir = REPO_ROOT / 'project' / 'config' / 'data'
parameters_path = config_data_dir / 'parameters_default.yaml'
noise_path = config_data_dir / 'noise_default.yaml'
transitions_path = config_data_dir / 'transitions_sample.yaml'

# NOTE(review): `parameter_ranges` is loaded but not used below; both datasets
# sample from `NORM_PARAMS` instead. Confirm that this is intentional.
parameter_ranges = load_parameter_ranges(parameters_path)
noise_profile = load_noise_profile(noise_path)
transitions = load_transitions(transitions_path)
# Keep only the CH4 entry; a missing key yields an empty transition list.
transitions = {'CH4': transitions.get('CH4', [])}
if len(transitions['CH4']) == 0:
    # Surface this rather than silently training on an empty CH4 transition list.
    print(f'Warning: no CH4 transitions found in {transitions_path}.')
poly_freq = make_linear_frequency_grid(stage_b2.num_points)

# Partition functions are optional: fall back to `tipspy=None` when the QTpy
# data directory cannot be located.
try:
    qtpy_dir = find_qtpy_dir(stage_b2.qtpy_dir)
    tipspy = Tips2021QTpy(qtpy_dir, device='cpu')
except FileNotFoundError:
    tipspy = None
    print('QTpy directory not found; continuing without partition functions.')

# Arguments shared by the training and validation datasets. The two differ
# only in sample count and in the `freeze_noise` flag.
dataset_kwargs = dict(
    num_points=stage_b2.num_points,
    poly_freq_CH4=poly_freq,
    transitions_dict=transitions,
    sample_ranges=NORM_PARAMS,
    with_noise=True,
    noise_profile=noise_profile,
    tipspy=tipspy,
)
train_dataset = SpectraDataset(n_samples=stage_b2.train_samples, freeze_noise=False, **dataset_kwargs)
val_dataset = SpectraDataset(n_samples=stage_b2.val_samples, freeze_noise=True, **dataset_kwargs)

# Pinned host memory only helps when batches are copied to a CUDA device, so
# this mirrors the accelerator choice made in the trainer cell. Without CUDA,
# PyTorch ignores the flag and emits a warning.
pin_memory = stage_b2.gpus > 0 and torch.cuda.is_available()
loader_kwargs = dict(
    batch_size=stage_b2.batch_size,
    num_workers=stage_b2.num_workers,
    pin_memory=pin_memory,
    # Keep worker processes alive between epochs; only valid with num_workers > 0.
    persistent_workers=stage_b2.num_workers > 0,
)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)


In [None]:
# Constructor overrides for the Stage B2 run. Lightning forwards keyword
# arguments of `load_from_checkpoint` to the model's __init__, where they
# override the hyper-parameters saved in the Stage B1 checkpoint.
model_kwargs = {
    'n_points': stage_b2.num_points,
    'param_names': PARAMS,
    'poly_freq_CH4': poly_freq,
    'transitions_dict': transitions,
    'tipspy': tipspy,
    'lr': stage_b2.base_lr,
    'mlp_dropout': stage_b2.mlp_dropout,
    'refiner_mlp_dropout': stage_b2.refiner_dropout,
    'backbone_variant': stage_b2.backbone_variant,
    'refiner_variant': stage_b2.refiner_variant,
    'refine_steps': stage_b2.refine_steps,
    'refine_delta_scale': stage_b2.refine_delta_scale,
    'stage3_lr_shrink': stage_b2.stage3_lr_shrink,
}
checkpoint_path = stage_b2.stage_b1_checkpoint
model = PhysicallyInformedAE.load_from_checkpoint(checkpoint_path, **model_kwargs)

# Stage B2 optimiser attributes, presumably read when the optimisers are
# configured. The statement order is kept deliberately: `set_stage_mode` may
# read or overwrite some of these attributes (TODO confirm).
model.weight_decay = stage_b2.weight_decay
model.base_lr = stage_b2.base_lr
model.refiner_lr = stage_b2.refiner_lr
model.set_stage_mode(
    'B2',
    refine_steps=stage_b2.refine_steps,
    delta_scale=stage_b2.refine_delta_scale,
)
model.stage3_lr_shrink = stage_b2.stage3_lr_shrink
print('Loaded checkpoint:', checkpoint_path)


In [None]:
# Keep the three checkpoints with the lowest `val_loss`, plus `last.ckpt`.
checkpoint_cb = ModelCheckpoint(
    dirpath=stage_b2.checkpoint_dir,
    filename='physae-stage-b2-{epoch:03d}-{val_loss:.4f}',
    monitor='val_loss',
    mode='min',
    save_top_k=3,
    save_last=True,
)
# Stop after 10 epochs without a `val_loss` improvement.
early_stop_cb = EarlyStopping(
    monitor='val_loss',
    mode='min',
    patience=10,
    verbose=True,
)
lr_monitor = LearningRateMonitor(logging_interval='epoch')
epoch_sync = UpdateEpochInDataset()

# Use the requested GPUs only when CUDA is actually available; otherwise fall
# back to a single CPU device.
use_gpu = stage_b2.gpus > 0 and torch.cuda.is_available()
if use_gpu:
    accelerator, devices = 'gpu', stage_b2.gpus
else:
    accelerator, devices = 'cpu', 1

trainer = pl.Trainer(
    max_epochs=stage_b2.epochs,
    accelerator=accelerator,
    devices=devices,
    precision=stage_b2.precision,
    gradient_clip_val=1.0,
    log_every_n_steps=stage_b2.log_every_n_steps,
    default_root_dir=stage_b2.log_dir,
    callbacks=[checkpoint_cb, early_stop_cb, lr_monitor, epoch_sync],
    enable_model_summary=True,
)


In [None]:
# Run Stage B2 fine-tuning; checkpoints land in `stage_b2.checkpoint_dir`.
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)


## Recommendations

* Initialise from the best Stage B1 checkpoint (lowest validation loss) by pointing `stage_b2.stage_b1_checkpoint` at it; the default is `last.ckpt`.
* Reduce `refine_steps` back to 1 if overfitting appears early.
* After B2 converges you can optionally enable the denoiser stage (DEN) using the same pattern.