# PyTorch Lightning fine-tuning template

by Andrés Muñoz-Jaramillo

This notebook is meant to act as a template for training and using a Surya model to implement a downstream (DS) application.

It focuses on defining a modified Surya model, loading its weights, and using a PyTorch Lightning training loop to train it.

This notebook assumes familiarity with the concepts of datasets and dataloaders contained in the **_0_dataset_dataloader_template.ipynb_**

It doesn't require having seen the baselines template, but they are meant to complement each other.  **_In fact, they are almost identical on purpose!_**

## Set your CUDA visible device

**IMPORTANT:** Since we are sharing resources, please make sure that the CUDA visible device you put here is the one assigned to your team and your machine.

In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "5"  # replace "5" with the device assigned to your team
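`CUDA_VISIBLE_DEVICES` only takes effect if it is set before CUDA is initialized, which is why this cell runs before `import torch`. The following is a minimal sketch of a guard that fails early on a malformed device id. `set_visible_device` is a hypothetical helper for illustration, not part of the workshop code.

```python
import os


def set_visible_device(device_id: str) -> str:
    """Restrict this process to a single GPU (hypothetical helper)."""
    # Accept only one non-negative integer index such as "5".
    if not device_id.isdigit():
        raise ValueError(f"Expected a single GPU index like '5', got {device_id!r}")
    os.environ["CUDA_VISIBLE_DEVICES"] = device_id
    return os.environ["CUDA_VISIBLE_DEVICES"]


print(set_visible_device("5"))
```

Rejecting values such as `"5,6"` is a design choice for a shared machine where each team is assigned exactly one device.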

Here we initialize the variables related to Weights and Biases, our online logging system, to ensure they are user-specific.

In [2]:
# Make sure wandb logs are stored in a user-specific directory
# Set writable directories
os.environ["WANDB_DIR"] = "./wandb/wandb_logs"
os.environ["WANDB_CACHE_DIR"] = "./wandb/wandb_cache"
os.environ["WANDB_CONFIG_DIR"] = "./wandb/wandb_config"
# Optional:
os.environ["TMPDIR"] = "./wandb/wandb_tmp"

# Ensure directories exist (optional, wandb usually creates them)
os.makedirs(os.environ["WANDB_DIR"], exist_ok=True)
os.makedirs(os.environ["WANDB_CACHE_DIR"], exist_ok=True)
os.makedirs(os.environ["WANDB_CONFIG_DIR"], exist_ok=True)
os.makedirs(os.environ["TMPDIR"], exist_ok=True)
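The cell above writes into `./wandb` next to the notebook. If several people share one checkout, one option is to add a per-user subfolder. The sketch below assumes a hypothetical helper `configure_wandb_dirs` and a hypothetical layout `<root>/<user>/...`; only the three `WANDB_*` variable names come from the cell above.

```python
import os
import tempfile
from pathlib import Path


def configure_wandb_dirs(root, user=None):
    """Create per-user W&B folders under `root` and export them (hypothetical helper)."""
    # Fall back to the USER environment variable when no name is given.
    user = user or os.environ.get("USER", "unknown_user")
    base = Path(root) / user
    dirs = {
        "WANDB_DIR": base / "wandb_logs",
        "WANDB_CACHE_DIR": base / "wandb_cache",
        "WANDB_CONFIG_DIR": base / "wandb_config",
    }
    for name, path in dirs.items():
        path.mkdir(parents=True, exist_ok=True)
        os.environ[name] = str(path)
    return dirs


# Demo in a throwaway directory so nothing is written next to the notebook.
demo_dirs = configure_wandb_dirs(tempfile.mkdtemp(), user="alice")
print(demo_dirs["WANDB_DIR"])
```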

In [3]:

import sys
from torch.utils.data import DataLoader

import torch
import yaml

import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger, WandbLogger

# Append base path.  May need to be modified if the folder structure changes.
# It gives the notebook access to the workshop_infrastructure folder.
sys.path.append("../../")
 
# Append Surya path. May need to be modified if the folder structure changes.
# It gives the notebook access to surya's release code.
sys.path.append("../../Surya")

from surya.utils.data import build_scalers  # Data scaling utilities for Surya stacks
from workshop_infrastructure.utils import apply_peft_lora
torch.set_float32_matmul_precision('medium')



## Download scalers and Weights
Surya input data needs to be scaled properly for the model to work, and this cell downloads the scaling information.  In this notebook we also download the model weights for finetuning.


- If the cell below fails, try running the provided shell script directly in the terminal.
- Sometimes the download may fail due to network or server issues. If that happens, simply re-run the script a few times until it completes successfully.

In [4]:
!sh download_scalers_and_weights.sh

==> Checking assets directory at: /home/haodijiang/surya_workshop/downstream_apps/haodi/assets
==> Downloading scalers and model weights into: /home/haodijiang/surya_workshop/downstream_apps/haodi/assets
/home/haodijiang/surya_workshop/downstream_apps/haodi/assets
/home/haodijiang/surya_workshop/downstream_apps/haodi/assets
✓ Done. Files are in: /home/haodijiang/surya_workshop/downstream_apps/haodi/assets


## Load configuration

Surya was designed to read a configuration file that defines many aspects of the model,
including the data it uses. We use this config file to set default values that do not
need to be modified, but also to define values specific to our downstream application.

In [5]:
# Configuration paths - modify these if your files are in different locations
config_path = "./configs/config.yaml"

# Load configuration
print("📋 Loading configuration...")
try:
    # Context managers close each file as soon as it has been parsed
    with open(config_path, "r") as f:
        config = yaml.safe_load(f)
    with open(config["data"]["scalers_path"], "r") as f:
        config["data"]["scalers"] = yaml.safe_load(f)
    print("✅ Configuration loaded successfully!")
except FileNotFoundError as e:
    print(f"❌ Error: {e}")
    print(f"Make sure {config_path} and the scalers file it points to exist")
    raise

scalers = build_scalers(info=config["data"]["scalers"])

📋 Loading configuration...
✅ Configuration loaded successfully!
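Many later cells index into `config` directly, so a missing entry only surfaces as a `KeyError` deep in the notebook. A small sketch of an up-front check follows. `missing_keys` is a hypothetical helper, and `example_config` holds placeholder values rather than the real workshop configuration.

```python
def missing_keys(config: dict, required: list) -> list:
    """Return the dotted key paths from `required` that are absent in `config`."""
    missing = []
    for dotted in required:
        node = config
        for part in dotted.split("."):
            # Stop at the first level where the path cannot be followed.
            if not isinstance(node, dict) or part not in node:
                missing.append(dotted)
                break
            node = node[part]
    return missing


# Placeholder config; only the key names mirror those read in this notebook.
example_config = {
    "data": {"scalers_path": "placeholder.yaml", "channels": []},
    "model": {"img_size": 0},
}
required = [
    "data.scalers_path",
    "data.train_data_path",
    "model.img_size",
    "pretrained_path",
]
print(missing_keys(example_config, required))
# ['data.train_data_path', 'pretrained_path']
```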


## Define Downstream (DS) datasets

This child class takes as input all expected HelioFM parameters, plus additional parameters relevant to the downstream application.  Here we focus in particular on the DS index and the parameters necessary to combine it with the HelioFM index.

Another important component of creating a dataset class for your DS is normalization.  Here we apply a log normalization to the X-ray flux that acts as the output target, which makes the normalized log10(xray_flux) strictly positive with 66% of its values between 0 and 1.
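As a rough illustration of this kind of target normalization, the sketch below shifts and rescales the log10 of a positive quantity and then inverts the transform. The function names, the formula, and the constants are assumptions for illustration; the convention actually used inside `FlareDSDataset` may differ.

```python
import math


def normalize_log10(value: float, log10_min: float, log10_range: float) -> float:
    """Map a positive value to (log10(value) - log10_min) / log10_range."""
    return (math.log10(value) - log10_min) / log10_range


def denormalize_log10(normalized: float, log10_min: float, log10_range: float) -> float:
    """Invert normalize_log10 to recover the physical value."""
    return 10.0 ** (normalized * log10_range + log10_min)


# Illustrative constants only, not the values used by the workshop dataset.
LOG10_MIN, LOG10_RANGE = -9.0, 3.0

flux = 1e-6
z = normalize_log10(flux, LOG10_MIN, LOG10_RANGE)
recovered = denormalize_log10(z, LOG10_MIN, LOG10_RANGE)
print(z, recovered)
```

Keeping the inverse next to the forward transform makes it easier to report metrics in physical units after training.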

In this case we will define both a training and a validation dataset using the indices pointed at in the config

**_Important:  In this notebook we set max_number_of_samples (20 for training and 10 for validation) to avoid going through the whole dataset as we explore it.  Keep this in mind in case the dataset seems smaller than you expect._**


In [6]:
from downstream_apps.haodi.datasets.template_dataset_haodi import FlareDSDataset

In [7]:
train_dataset = FlareDSDataset(
    #### All these lines are required by the parent HelioNetCDFDataset class
    index_path=config["data"]["train_data_path"],
    time_delta_input_minutes=config["data"]["time_delta_input_minutes"],
    time_delta_target_minutes=config["data"]["time_delta_target_minutes"],
    n_input_timestamps=config["model"]["time_embedding"]["time_dim"],
    rollout_steps=config["rollout_steps"],
    channels=config["data"]["channels"],
    drop_hmi_probability=config["drop_hmi_probability"],
    use_latitude_in_learned_flow=config["use_latitude_in_learned_flow"],
    scalers=scalers,
    phase="train",
    s3_use_simplecache = False,
    s3_cache_dir= "/tmp/helio_s3_cache",    
    #### Put your downstream (DS) specific parameters below this line
    return_surya_stack=True,
    max_number_of_samples=20,  # small subset for quick experiments; increase for real training
    # ds_flare_index_path="./data/hek_flare_catalog.csv",
    # ds_time_column="start_time",
    # ds_time_tolerance = "4d",
    # ds_match_direction = "forward"    
    ds_flare_index_path="./data/caiik_2011_2013_EVE_13.5_sample_75_aws.csv",
    ds_time_column="timestep",
    ds_time_tolerance = "6h",
    ds_match_direction = "nearest"
)

# The Validation dataset changes the index we read
val_dataset = FlareDSDataset(
    #### All these lines are required by the parent HelioNetCDFDataset class
    index_path=config["data"]["valid_data_path"],  #<---------------- different index path
    time_delta_input_minutes=config["data"]["time_delta_input_minutes"],
    time_delta_target_minutes=config["data"]["time_delta_target_minutes"],
    n_input_timestamps=config["model"]["time_embedding"]["time_dim"],
    rollout_steps=config["rollout_steps"],
    channels=config["data"]["channels"],
    drop_hmi_probability=config["drop_hmi_probability"],
    use_latitude_in_learned_flow=config["use_latitude_in_learned_flow"],
    scalers=scalers,
    s3_use_simplecache = False,
    s3_cache_dir= "/tmp/helio_s3_cache",    
    #### Put your downstream (DS) specific parameters below this line
    return_surya_stack=True,
    max_number_of_samples=10,
    # ds_flare_index_path="./data/hek_flare_catalog.csv",
    # ds_time_column="start_time",
    # ds_time_tolerance = "4d",
    # ds_match_direction = "forward"    
    ds_flare_index_path="./data/caiik_2011_2013_EVE_13.5_sample_75_aws.csv",
    ds_time_column="timestep",
    ds_time_tolerance = "6h",
    ds_match_direction = "nearest"
)
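The `ds_time_tolerance="6h"` and `ds_match_direction="nearest"` arguments suggest that each Surya timestamp is paired with the closest DS record, provided it lies within the tolerance. The sketch below illustrates that idea with the standard library only. `match_nearest` is a hypothetical function, and the actual matching inside `FlareDSDataset` may be implemented differently.

```python
from bisect import bisect_left
from datetime import datetime, timedelta


def match_nearest(query, candidates, tolerance):
    """Return the candidate closest to `query`, or None if it is beyond `tolerance`.

    `candidates` must be sorted in ascending order.
    """
    if not candidates:
        return None
    pos = bisect_left(candidates, query)
    # Only the neighbours around the insertion point can be the nearest ones.
    neighbours = candidates[max(pos - 1, 0):pos + 1]
    best = min(neighbours, key=lambda t: abs(t - query))
    return best if abs(best - query) <= tolerance else None


ds_times = [datetime(2011, 1, 1, 0), datetime(2011, 1, 1, 12), datetime(2011, 1, 3, 0)]
tolerance = timedelta(hours=6)

# 10:00 is two hours from the 12:00 record, so it matches.
print(match_nearest(datetime(2011, 1, 1, 10), ds_times, tolerance))
# The nearest record to Jan 2 00:00 is twelve hours away, so there is no match.
print(match_nearest(datetime(2011, 1, 2, 0), ds_times, tolerance))
```

Timestamps without a match are one reason a dataset can end up with fewer samples than its index suggests.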

We also initialize separate training and validation dataloaders.   Since we are working in a shared environment, using multiprocessing_context="spawn" helps avoid lockups.

In [8]:
batch_size = 2

train_data_loader = DataLoader(
                dataset=train_dataset,
                batch_size=batch_size,
                shuffle=True,
                num_workers=8,
                multiprocessing_context="spawn",
                persistent_workers=True,
                pin_memory=True,
            )

val_data_loader = DataLoader(
                dataset=val_dataset,
                batch_size=batch_size,
                num_workers=8,
                multiprocessing_context="spawn",
                persistent_workers=True,
                pin_memory=True,
            )

## Initialize the HelioSpectformer model

This is the main difference between the notebook that trains the simple model and the one that fine-tunes Surya.

In the case of the finetuning exercise one of the main differences between DS applications is the dimensionality of the output.  In this notebook we use a modified HelioSpectformer that projects into a 1D space. 

**_IMPORTANT: If your DS application is 2D you need to use the HelioSpectformer2D_**

In [9]:
from workshop_infrastructure.models.finetune_models import HelioSpectformer1D
# from surya.models.helio_spectformer import HelioSpectFormer



Now the config file really comes into play. The Spectformer has a metric ton of hyperparameters.

In [10]:
model = HelioSpectformer1D(
    img_size=config["model"]["img_size"],
    patch_size=config["model"]["patch_size"],
    in_chans=config["model"]["in_channels"],
    embed_dim=config["model"]["embed_dim"],
    time_embedding=config["model"]["time_embedding"],
    depth=config["model"]["depth"],
    num_heads=config["model"]["num_heads"],
    mlp_ratio=config["model"]["mlp_ratio"],
    drop_rate=config["model"]["drop_rate"],
    dtype=config["dtype"],
    window_size=config["model"]["window_size"],
    dp_rank=config["model"]["dp_rank"],
    learned_flow=config["model"]["learned_flow"],
    use_latitude_in_learned_flow=config["use_latitude_in_learned_flow"],
    init_weights=config["model"]["init_weights"],
    checkpoint_layers=config["model"]["checkpoint_layers"],
    n_spectral_blocks=config["model"]["spectral_blocks"],
    rpe=config["model"]["rpe"],
    ensemble=config["model"]["ensemble"],
    finetune=config["model"]["finetune"],
    nglo=config["model"]["nglo"],
    # Put finetuning additions below this line
    dropout=config["model"]["dropout"],
    num_penultimate_transformer_layers=0,
    num_penultimate_heads=0,
    num_outputs=1,
    config=config,
)

## Load model weights

Here we load the pre-trained checkpoint and its weights.  The loading step follows the idea of using as many of the pretrained weights as possible.  This is accomplished through filtered_checkpoint_state, which keeps a checkpoint tensor only if its name exists in your finetuning architecture and its shape matches.  Any parameter that fails either check is left in its random initialization.

In [11]:
model_state = model.state_dict()
checkpoint_state = torch.load(config["pretrained_path"], weights_only=True, map_location="cpu")

# 1. Keep only checkpoint tensors whose name and shape match the model
filtered_checkpoint_state = {
    k: v
    for k, v in checkpoint_state.items()
    if k in model_state and v.shape == model_state[k].shape
}

# 2. Load the filtered weights; parameters without a match keep their initialization
model_state.update(filtered_checkpoint_state)
model.load_state_dict(model_state, strict=True)

<All keys matched successfully>
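Because the state dict passed to `load_state_dict` is built from the model itself, the message above appears even when some layers kept their random initialization. A short report of what was actually transferred can help. The sketch below works on plain name-to-shape dictionaries; `checkpoint_match_report` and the toy layer names are hypothetical.

```python
def checkpoint_match_report(model_shapes: dict, checkpoint_shapes: dict) -> dict:
    """Classify parameter names by whether the checkpoint can fill them."""
    report = {"loaded": [], "shape_mismatch": [], "not_in_checkpoint": []}
    for name, shape in model_shapes.items():
        if name not in checkpoint_shapes:
            report["not_in_checkpoint"].append(name)
        elif checkpoint_shapes[name] != shape:
            report["shape_mismatch"].append(name)
        else:
            report["loaded"].append(name)
    # Checkpoint entries that the fine-tuning model has no place for.
    report["unused_in_checkpoint"] = [
        name for name in checkpoint_shapes if name not in model_shapes
    ]
    return report


# Toy shapes; the layer names are made up for illustration.
model_shapes = {"backbone.w": (4, 4), "head.w": (1, 4), "head.b": (1,)}
checkpoint_shapes = {"backbone.w": (4, 4), "head.w": (8, 4), "decoder.w": (4, 4)}

report = checkpoint_match_report(model_shapes, checkpoint_shapes)
for category, names in report.items():
    print(f"{category}: {names}")
```

With the notebook's objects, the shape dictionaries could be built with something like `{k: tuple(v.shape) for k, v in checkpoint_state.items()}`.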

## To LoRA or not to LoRA

This cell gives you two options.  On the one hand we have the classic freezing of the backbone (the initial layers of the model).   On the other hand we have the use of a LoRA.

LoRAs have been a remarkable addition to our arsenal of models.   They have the advantage of keeping the pretrained weights intact and adding only small low-rank modifications to selected layers as needed.

In [12]:
use_LoRa = True

if use_LoRa:
    model = apply_peft_lora(model, config)
else:
    for name, param in model.named_parameters():
        if "embedding" in name or "backbone" in name:
            param.requires_grad = False
    parameters_with_grads = []
    for name, param in model.named_parameters():
        if param.requires_grad:
            parameters_with_grads.append(name)
    print(
        f"{len(parameters_with_grads)} parameters require gradients: {', '.join(parameters_with_grads)}."
    )

Applying PEFT LoRA with configuration: {'r': 8, 'lora_alpha': 8, 'target_modules': ['q_proj', 'v_proj', 'k_proj', 'out_proj', 'fc1', 'fc2'], 'lora_dropout': 0.1, 'bias': 'none'}
trainable params: 1,024,000 || all params: 360,333,313 || trainable%: 0.28%
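The numbers in this output can be checked by hand. LoRA adds two small matrices to each adapted linear layer, so the extra parameter count per layer is `r * (in_features + out_features)`. The sketch below reproduces the reported percentage and applies the formula to a hypothetical layer size, since the real layer sizes are not shown in this notebook.

```python
def lora_param_count(in_features: int, out_features: int, rank: int) -> int:
    """Extra trainable parameters LoRA adds to one linear layer."""
    # A has shape (rank, in_features) and B has shape (out_features, rank).
    return rank * (in_features + out_features)


# Numbers reported in the output above.
trainable_params = 1_024_000
all_params = 360_333_313
trainable_percent = 100 * trainable_params / all_params
print(f"trainable: {trainable_percent:.2f}%")  # trainable: 0.28%

# Hypothetical 1024 x 1024 projection adapted with r=8.
print(lora_param_count(1024, 1024, rank=8))  # 16384

# With r=8 and lora_alpha=8, the usual LoRA scaling lora_alpha / r equals 1.
lora_scaling = 8 / 8
print(lora_scaling)
```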


We can now test that this model manipulates a batch as expected and returns an estimate of flare intensity as we did for the simple baseline.

We pass the input stack 'ts' to the model to transform it into our regression output.   Note that since this model was trained for a different task, it's likely it won't perform very well.  As with the simple baseline, this only acts as a test that our model forward doesn't have dimension problems.

Dimension problems are the dominant source of error in this kind of work.

Note that our output now has the size of our batch.

In [13]:
# batch = next(iter(train_data_loader))
# output = model.forward(batch)  # Get rid of singleton dimension
# output

## Define your metrics

Metrics are a very important part of training AI models.   They quantify your model's error, which in turn shifts the weights towards better-performing models.  They also provide a way for you to monitor performance, identify overfitting, and quantify value added.

We now initialize the metrics class, which lets you control which metrics are used as the "loss" (i.e. the metrics that backpropagate through your model) and which ones are used to monitor performance.  As with other components, this takes the form of a loaded module that can later be used in a training script.

In [14]:
from downstream_apps.haodi.metrics.template_metrics import FlareMetrics

In [15]:
train_loss_metrics = FlareMetrics("train_loss")
train_evaluation_metrics = FlareMetrics("train_metrics")
validation_evaluation_metrics = FlareMetrics("val_metrics")

Now they can be evaluated on our model's output and our ground truth.   First the loss that will actually backpropagate, in this case Mean Squared Error.

In [16]:
#train_loss_metrics(output, batch["forecast"])

Then a training evaluation that will not backpropagate or inform our model, but that we can keep an eye on. Note that reporting lots of metrics during training will slow the training process.  I'm including it here as an example, but it is often better to put the diagnostics only in the validation evaluation metrics.

Here we are calculating the Root Relative Squared Error: https://lightning.ai/docs/torchmetrics/stable/regression/rse.html

A value below one means the prediction is better than predicting the average.  It is unlikely that this metric will be lower than one with a randomly initialized model
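To make the "better than predicting the average" statement concrete, here is a plain-Python sketch of RRSE for a single output. It is a hand-written illustration, not the torchmetrics implementation used by `FlareMetrics`.

```python
import math


def rrse(predictions, targets):
    """Root Relative Squared Error: model error relative to a mean predictor."""
    mean_target = sum(targets) / len(targets)
    squared_error = sum((p - t) ** 2 for p, t in zip(predictions, targets))
    baseline_error = sum((t - mean_target) ** 2 for t in targets)
    return math.sqrt(squared_error / baseline_error)


targets = [1.0, 2.0, 3.0, 4.0]

# Always predicting the mean of the targets gives exactly 1.
rrse_mean = rrse([2.5, 2.5, 2.5, 2.5], targets)
print(rrse_mean)  # 1.0

# Predictions close to the targets give a value well below 1.
rrse_good = rrse([1.1, 2.1, 2.9, 3.9], targets)
print(rrse_good)
```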

In [17]:
# train_evaluation_metrics(output, batch["forecast"])

In the validation evaluation metrics we report both MSE and RRSE

In [18]:
# validation_evaluation_metrics(output, batch["forecast"])

## Define your PyTorch Lightning module

In this workshop we will use PyTorch Lightning to train our models.  PyTorch Lightning reduces the amount of code required to implement a training loop in comparison to plain PyTorch (at the expense of control and versatility).

Opening the FlareLightningModule shows a simple Lightning model implementation.  It consists of:

- An initialization of the class (metrics, model, and learning rate).
- The forward code that runs evaluation of the model.
- Training and validation steps.
- Configuration of optimizers.

**_Note that it is the same Lightning module we used for the baseline!!_**

In [19]:
from downstream_apps.haodi.lightning_modules.pl_simple_baseline import FlareLightningModule
# modify FlareLightningModule to handle caiik inputs

## Set your global seeds

Since training AI models generally uses stochastic gradient descent, it is a good idea to fix your random seeds so that your training exercise is reproducible.    

In [20]:
L.seed_everything(42, workers=True)

Seed set to 42


42
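`L.seed_everything` seeds Python's `random` module, NumPy, and PyTorch in one call. The effect is easiest to see with the standard library alone, as in this small sketch (`draw` is a hypothetical helper):

```python
import random


def draw(seed: int, n: int = 3) -> list:
    """Seed Python's generator and draw n pseudo-random numbers."""
    random.seed(seed)
    return [random.random() for _ in range(n)]


same_seed_matches = draw(42) == draw(42)
different_seed_matches = draw(42) == draw(43)
print(same_seed_matches, different_seed_matches)
```

The same reasoning applies to weight initialization, dropout, and shuffling in the training loop, although some GPU operations can remain non-deterministic even with fixed seeds.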

## Initialize Lightning module

Now we properly initialize the Lightning module to enable training, including passing the dictionary of metrics.

In [21]:
metrics = {'train_loss': train_loss_metrics,
           'train_metrics': train_evaluation_metrics,
           'val_metrics': validation_evaluation_metrics}

learning_rate = 1e-4 # change from 1e-3 to 1e-4
# lit_model = FlareLightningModule(model, metrics, lr=learning_rate, batch_size=batch_size)

# add after 1/15/2026
lit_model = FlareLightningModule(
    model=model,
    metrics=metrics,
    lr=learning_rate,
    batch_size=batch_size,
    eve_log10_min=train_dataset.eve_log10_min,
    eve_log10_scale=train_dataset.eve_log10_scale,
)
print(
    "EVE inverse params:",
    lit_model.eve_log10_min,
    lit_model.eve_log10_scale
)

EVE inverse params: -5.078961027729049 0.12238584732591484


## Logging

In order to properly compare experiments against each other, it is very useful to log evaluation metrics in a place where they can be compared against other training runs.  In this workshop we will use Weights and Biases (WandB). 

The first time you run WandB on a machine, it will ask you to log in.  You should have received an invitation to our project.  To log in:

- Select option 2 (existing account).   In VS Code the dialog opens a box at the top of your screen.
- Click on get API Key (this will open a browser).
- Generate an API key.
- Paste it in the dialog box at the top of your VS Code window.

In [22]:
# project_name = "template_flare_regression"
project_name = "template_EVE_13.5_regression"
# run_name = "baseline_experiment_1"
run_name = "HJ_surya_finetune_experiment_t20_v10_e10_lr-4_eve13p5"  # t20: 20 training samples, v10: 10 validation samples, e10: 10 epochs, lr-4: lr=1e-4 (lr=1e-5 did not work well)

wandb_logger = WandbLogger(
    entity="surya_handson",
    project=project_name,
    name=run_name,
    log_model=False,
    save_dir="./wandb/wandb_tmp",
)

# csv_logger = CSVLogger("runs", name="simple_flare")
csv_logger = CSVLogger("runs", name="simple_EVE_13.5_3")
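Run names that encode hyperparameters drift easily when a value changes and the string is not updated. One option is to build the name from the values themselves. The sketch below uses a hypothetical helper, `build_run_name`, that reproduces the naming pattern used above.

```python
import math


def build_run_name(prefix, n_train, n_val, epochs, lr, tag):
    """Encode the main hyperparameters in the run name (hypothetical helper)."""
    # Assumes lr is a power of ten, e.g. 1e-4 becomes "lr-4".
    lr_exponent = round(math.log10(lr))
    return f"{prefix}_t{n_train}_v{n_val}_e{epochs}_lr{lr_exponent}_{tag}"


generated_name = build_run_name(
    "HJ_surya_finetune_experiment", n_train=20, n_val=10, epochs=10, lr=1e-4, tag="eve13p5"
)
print(generated_name)
# HJ_surya_finetune_experiment_t20_v10_e10_lr-4_eve13p5
```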

## Initialize trainer

With the loggers done, now the trainer needs to be defined.  The trainer defines several properties of your training run. Here we define:

- The max number of epochs (one epoch represents your model seeing your entire training dataset).
- Where the training run will take place (auto uses the GPU if available; otherwise, the CPU).
- The loggers.
- The callbacks (here we save the model with the lowest validation loss).
- Logging frequency (because we are working with a small dataset it needs to be small).


**Note that in this notebook we also set a mixed precision to reduce the model's footprint in memory.**

In [23]:
max_epochs = 10  # increase for longer training runs

# -------------------------------------------------------------------------
# Trainer
# -------------------------------------------------------------------------
trainer = L.Trainer(
    max_epochs=max_epochs,
    accelerator="auto",
    devices="auto",
    precision="bf16-mixed", 
    logger=[wandb_logger, csv_logger],
    callbacks=[
        ModelCheckpoint(
            monitor="val_loss",
            mode="min",
            save_top_k=1,
        )
    ],
    log_every_n_steps=2,
)

Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
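The small `log_every_n_steps` value follows from simple arithmetic: Lightning's default is 50, which is more than the number of optimizer steps in one epoch here. The sketch below assumes the training dataset really yields 20 samples, a single device, and no gradient accumulation; `training_steps` is a hypothetical helper.

```python
import math


def training_steps(n_samples: int, batch_size: int, max_epochs: int):
    """Return (steps per epoch, total steps) when the last partial batch is kept."""
    steps_per_epoch = math.ceil(n_samples / batch_size)
    return steps_per_epoch, steps_per_epoch * max_epochs


steps_per_epoch, total_steps = training_steps(n_samples=20, batch_size=2, max_epochs=10)
print(steps_per_epoch, total_steps)  # 10 100
```

With only 10 steps per epoch, a logging interval of 2 gives several training points per epoch in the WandB charts.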


## Fit the model

Finally we fit the model.  We pass the Lightning module and our dataloaders.

In [24]:
trainer.fit(lit_model, train_data_loader, val_data_loader)

[34m[1mwandb[0m: Currently logged in as: [33mhaodijiang[0m ([33msurya_handson[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [5]
/opt/anaconda3/envs/surya_ws/lib/python3.12/site-packages/lightning/pytorch/utilities/model_summary/model_summary.py:242: Precision bf16-mixed is not supported by the model summary.  Estimated model size in MB will not be accurate. Using 32 bits instead.

  | Name                 | Type      | Params | Mode  | FLOPs
-------------------------------------------------------------------
0 | model                | PeftModel | 360 M  | train | 0    
1 | caiik_to_surya_layer | Conv2d    | 26     | train | 0    
-------------------------------------------------------------------
1.0 M     Trainable params
359 M     Non-trainable params
360 M     Total params
1,441.333 Total estimated model params size (MB)
370       Modules in train mode
0         Modules in eval mode
0         Total Flops


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=10` reached.


## Conclusion

With this we have now integrated our dataset, dataloaders, metrics, and DS model into an end-to-end training loop, and we are ready to experiment!