This notebook is a short example of how to train S2S benchmark models with the PyTorch Lightning framework. The first example trains the model defined in `segformer_s2s.yaml`; the UNet and FNO configurations are trained afterwards.

The complete training script, `train.py`, is located in the root directory of the repository.

In [1]:
# Enable IPython's autoreload extension; mode 2 reloads all imported modules before
# executing code, so edits to the local `chaosbench` package are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
# Standard library
import os
import sys

# Third-party
import yaml
import torch
import lightning.pytorch as pl
from lightning.pytorch import loggers as pl_loggers
from lightning.pytorch.callbacks import ModelCheckpoint

# Make the local `chaosbench` package importable: '..' is resolved against the kernel's
# working directory (normally this notebook's folder), matching the '../chaosbench/...'
# config paths used below.
sys.path.append('..')

from chaosbench import dataset, config, utils, criterion
from chaosbench.models import model, mlp, cnn, ae, fno, vit

# Seed Python's `random`, NumPy and torch RNGs once, after every import has run, so that
# any randomness consumed at import time cannot shift the seeded state.
SEED = 42
pl.seed_everything(SEED)


[rank: 0] Global seed set to 42
  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Every model config is a YAML file bundling what is needed to fit/evaluate a model:
# a `model_args` section (model specification) and a `data_args` section (data definition).
model_config_filepath = '../chaosbench/configs/segformer_s2s.yaml'

with open(model_config_filepath, 'r') as config_file:
    hyperparams = yaml.load(config_file, Loader=yaml.FullLoader)

# Split the config into the two sections passed to the benchmark model.
model_args = hyperparams['model_args']
data_args = hyperparams['data_args']


In [None]:
# Show the loaded configuration. It is structured as:
#   `model_args` -- model specification (e.g. `model_name`, `epochs`)
#   `data_args`  -- data definition
# Kept as the cell's last bare expression so Jupyter renders it as the cell output.

hyperparams


In [None]:
# Build the Lightning module from the two config sections (`model_args` = model
# specification, `data_args` = data definition).
baseline = model.S2SBenchmarkModel(
    model_args=model_args,
    data_args=data_args,
)

# NOTE(review): Trainer.fit() also invokes the LightningModule `setup` hook, so this
# explicit call may run setup twice -- confirm it is needed (e.g. to inspect the data
# before fitting).
baseline.setup()


In [None]:
# Trainer configuration:
#   * ModelCheckpoint keeps the best checkpoint (Lightning's default save_top_k=1),
#     ranked by the lowest logged `val_loss`.
#   * TensorBoard logs are written under `logs/<MODEL_NAME>`.
#   * devices=-1 trains on all available GPUs.
checkpoint_callback = ModelCheckpoint(monitor='val_loss', mode='min')
tb_logger = pl_loggers.TensorBoardLogger(save_dir=f'logs/{model_args["model_name"]}')

trainer = pl.Trainer(
    accelerator='gpu',
    devices=-1,
    strategy='auto',
    max_epochs=model_args['epochs'],
    callbacks=[checkpoint_callback],
    logger=tb_logger,
)


In [None]:
# Train `baseline`. TensorBoard logs and the best `val_loss` checkpoint end up under
# `logs/<MODEL_NAME>` (the logger's `save_dir`).

trainer.fit(model=baseline)


## Training the UNet and FNO models from their configuration files

In [None]:
# Load the UNet and FNO configs. Each YAML file holds everything needed to fit/evaluate
# one model: `model_args` (model specification) and `data_args` (data definition).

unet_model_config_filepath = '../chaosbench/configs/unet_s2s.yaml'
fno_model_config_filepath = '../chaosbench/configs/fno_s2s.yaml'


def _load_hyperparams(config_filepath):
    """Open the YAML file at `config_filepath` and return its parsed contents.

    The file is closed before returning. The result is indexed below by its
    `model_args` and `data_args` keys.
    """
    with open(config_filepath, 'r') as config_file:
        return yaml.load(config_file, Loader=yaml.FullLoader)


unet_hyperparams = _load_hyperparams(unet_model_config_filepath)
fno_hyperparams = _load_hyperparams(fno_model_config_filepath)

unet_model_args = unet_hyperparams['model_args']
unet_data_args = unet_hyperparams['data_args']

fno_model_args = fno_hyperparams['model_args']
fno_data_args = fno_hyperparams['data_args']


In [None]:
# Inspect both configs. Each is structured as:
#   `model_args` -- model specification (e.g. `model_name`, `epochs`)
#   `data_args`  -- data definition
# Jupyter renders only a cell's *last* bare expression, so listing the two names on
# separate lines would show `fno_hyperparams` alone; `display` renders both.

display(unet_hyperparams, fno_hyperparams)


In [None]:
# Build one Lightning module per architecture from its config sections, then call its
# `setup()`.
# NOTE(review): Trainer.fit() also invokes the `setup` hook, so these explicit calls may
# run setup twice -- confirm they are needed.
unet_baseline = model.S2SBenchmarkModel(
    model_args=unet_model_args,
    data_args=unet_data_args,
)
unet_baseline.setup()

fno_baseline = model.S2SBenchmarkModel(
    model_args=fno_model_args,
    data_args=fno_data_args,
)
fno_baseline.setup()


In [None]:
# Set up one trainer per model, each with its own TensorBoard logger and checkpoint callback.


def _build_trainer(run_model_args):
    """Create the TensorBoard logger, checkpoint callback and GPU Trainer for one model.

    Parameters
    ----------
    run_model_args : dict
        The `model_args` section of a model config. `model_name` selects the log
        directory (`logs/<MODEL_NAME>`), and `epochs` sets `max_epochs`.

    Returns
    -------
    tuple
        `(tb_logger, checkpoint_callback, trainer)`.
    """
    tb_logger = pl_loggers.TensorBoardLogger(save_dir=f'logs/{run_model_args["model_name"]}')
    # Keeps the best checkpoint (Lightning's default save_top_k=1) ranked by lowest `val_loss`.
    checkpoint_callback = ModelCheckpoint(monitor='val_loss', mode='min')
    trainer = pl.Trainer(
        devices=-1,  # all available GPUs
        accelerator='gpu',
        strategy='auto',
        max_epochs=run_model_args['epochs'],
        logger=tb_logger,
        callbacks=[checkpoint_callback],
    )
    return tb_logger, checkpoint_callback, trainer


unet_tb_logger, unet_checkpoint_callback, unet_trainer = _build_trainer(unet_model_args)
fno_tb_logger, fno_checkpoint_callback, fno_trainer = _build_trainer(fno_model_args)


In [None]:
# Train each model with its own trainer, UNet first and then FNO. Each run's logs and
# best `val_loss` checkpoint go under `logs/<MODEL_NAME>` (its logger's `save_dir`).

for run_trainer, run_model in (
    (unet_trainer, unet_baseline),
    (fno_trainer, fno_baseline),
):
    run_trainer.fit(model=run_model)
