# Tutorial 2: Model pre-training

In [None]:
import datetime
import os
from argparse import Namespace

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from cosine_annealing_warmup import CosineAnnealingWarmupRestarts
from dateutil import tz
from einops import rearrange
from pytorch_lightning import LightningModule, Trainer, seed_everything
from pytorch_lightning.callbacks import (EarlyStopping, LearningRateMonitor,
                                         ModelCheckpoint, Callback)
from pytorch_lightning.loggers import WandbLogger
from model.datasets.data_module import DataModule
from model.datasets.pretrain_dataset import (SpatialRadiusDataset, 
                                             my_collate_fn)
from model.emb_gen.backbones.encoder import BertEncoder
from step1_pretrain import Omics, EpochCallback
# Global PyTorch runtime settings for this notebook.
# Anomaly detection makes autograd raise on NaN-producing backward passes and
# report the forward op responsible; it adds overhead, so consider turning it
# off once training is known to be stable.
torch.autograd.set_detect_anomaly(True)
# Reproducibility: request deterministic cuDNN kernels and keep the cuDNN
# auto-tuner (benchmark mode) off. The auto-tuner may select a different
# convolution algorithm from run to run, which works against the fixed seed
# and `deterministic=True` used later in this notebook.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# `__file__` is not defined inside a Jupyter kernel, so fall back to the
# current working directory when this code runs as a notebook cell. When the
# code is executed as a regular script the original behaviour is kept.
try:
    BASE_DIR = os.path.dirname(os.path.abspath(__file__))
except NameError:
    BASE_DIR = os.getcwd()
from model.constants import *

# Run-mode switches.
#   is_debug     -> when True, Weights & Biases is disabled through the
#                   WANDB_MODE environment variable.
#   is_save_ckpt -> read after training to decide whether the checkpoint
#                   summary YAML is written.
is_debug, is_save_ckpt = False, True
if is_debug:
    os.environ['WANDB_MODE'] = 'disabled'

## Initialize the arguments and fix the random seed

In [None]:
# Restrict this process to the first GPU.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Training parameters, taken from the original command line. The same
# namespace is later unpacked into the model constructor and handed to the
# Trainer, so the attribute names must stay exactly as they are.
args = Namespace(**{
    "gpus": 1,
    "batch_size": 500,
    "epochs": 3,
    "max_epochs": 3,
    "mask_ratio": 0.4,
    "experiment_name": "cosmx_1",
    "learning_rate": 5e-6,
    "config": os.path.join(BASE_DIR, '../../configs/bert_config.json'),
    "dataset_name": "cosmx_lung5_rep1",
    "fold": "fold_1",
    "max_points": 20,
    "radius": 20,
    "mask_function": "dynamic",
    "num_workers": 10,
    "seed": 42,
    "data_pct": 1.0,
    "deterministic": True,
    "freeze_bert": False,
    "emb_dim": 192,
    "lambda_1": 1.0,
    "momentum": 0.9,
    "weight_decay": 0.05,
    "output_dim": 768,
    "hidden_dim": 768,
    "ct_obs": "cell_class",
    "ckpt_path": None,
})

# Seed every random number generator so the experiment is reproducible.
seed_everything(args.seed)

## Select the train/validation split for the chosen fold

In [None]:
# Pick the fold table that belongs to the dataset, then read the train and
# validation splits for the requested fold. Several dataset names share one
# table; an unknown dataset name raises NotImplementedError.
dataset_name = args.dataset_name
if dataset_name == 'seqfish':  # dataset: seqFISH+ 3T3
    fold_table = SEQFISH_FOLDS
elif dataset_name == 'merfish':  # MERFISH U2-OS
    fold_table = MERFISH_FOLDS
elif dataset_name in ('mop1', 'mop1_filtered'):  # dataset: MOp
    fold_table = MOP_FOLDS1
elif dataset_name == 'cosmx_lung5_rep1':  # dataset: CosMx lung
    fold_table = COSMX_FOLDS51
elif dataset_name in ('xenium_hbc1', 'xenium_hbc1_rep2'):  # dataset: Xenium HBC
    fold_table = XENIUM_HBC_FOLDS1
elif dataset_name in (
    # dataset: STARmap PLUS (64-gene panel)
    'AD_64g_m9721', 'AD_64g_m9781', 'AD_64g_m9919', 'AD_64g_m9930',
    # dataset: STARmap PLUS (2766-gene panel)
    'AD_2766g_m9498', 'AD_2766g_m9707', 'AD_2766g_m9735', 'AD_2766g_m9494',
    'AD_2766g_m11346', 'AD_2766g_m9723', 'AD_2766g_m11351',
):
    fold_table = AD_FOLDS
else:
    raise NotImplementedError

train_split = fold_table[args.fold]['train']
val_split = fold_table[args.fold]['val']

## Initialize the data module and the SpotFormer model

In [None]:
# This callback instance is given to the data module here and is also
# registered in the trainer's callback list later on.
epoch_callback = EpochCallback()

datamodule = DataModule(
    SpatialRadiusDataset,
    my_collate_fn,
    args.data_pct,
    args.batch_size,
    args.num_workers,
    radius=args.radius,
    mask_ratio=args.mask_ratio,
    mask_function=args.mask_function,
    dataset_name=args.dataset_name,
    max_points=args.max_points,
    train_split=train_split,
    val_split=val_split,
    label_type='pretrain',
    callback=epoch_callback,
)

# Restore the model from a checkpoint when a path is given; otherwise build a
# fresh model from the full set of training arguments.
model = (
    Omics.load_from_checkpoint(args.ckpt_path)
    if args.ckpt_path
    else Omics(**vars(args))
)

## Initialize the trainer

In [None]:
# Timestamped run name, shared by the checkpoint directory and the W&B run.
now = datetime.datetime.now(tz.tzlocal())
extension = now.strftime("%Y_%m_%d_%H_%M_%S")
run_name = f"{args.experiment_name}_{args.dataset_name}_{extension}"

# Directory that receives the checkpoints of this run.
ckpt_dir = os.path.join(BASE_DIR, f"../../../data/ckpts/Omics/{run_name}")
os.makedirs(ckpt_dir, exist_ok=True)

lr_monitor = LearningRateMonitor(logging_interval="step")
# Keep only the single best checkpoint, ranked by the lowest "val_loss".
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    dirpath=ckpt_dir,
    save_last=False,
    mode="min",
    save_top_k=1,
)
# Stop when "val_loss" has not improved for 5 consecutive validation checks.
early_stopping = EarlyStopping(
    monitor="val_loss",
    min_delta=0.,
    patience=5,
    verbose=False,
    mode="min",
)
# Order matters: the training cell reads the checkpoint callback at index 1.
callbacks = [lr_monitor, checkpoint_callback, early_stopping, epoch_callback]

# Weights & Biases logger; its files are stored under the shared data folder.
logger_dir = os.path.join(BASE_DIR, "../../../data")
os.makedirs(logger_dir, exist_ok=True)
wandb_logger = WandbLogger(
    project="SpotFormer_pretrain",
    save_dir=logger_dir,
    name=run_name,
)

# Trainer options that are present on `args` are taken from the namespace;
# the keyword arguments below are passed in addition.
# NOTE(review): Trainer.from_argparse_args is not available in PyTorch
# Lightning 2.x - confirm that the installed version is 1.x.
trainer = Trainer.from_argparse_args(
    args=args,
    callbacks=callbacks,
    logger=wandb_logger,
    precision=16,
    gradient_clip_val=0.5,
)

# Store the total number of training steps on the model
# (presumably consumed by its learning-rate schedule - TODO confirm).
model.training_steps = model.num_training_steps(trainer, datamodule)
print('number of model training steps: ', model.training_steps)

## Model training

In [None]:
# Run the pre-training loop.
trainer.fit(model, datamodule=datamodule)

if is_save_ckpt:
    # Index 1 of the callback list is the ModelCheckpoint instance; dump its
    # checkpoint summary as YAML next to the checkpoint files.
    checkpoint_callback = callbacks[1]
    best_ckpt_path = os.path.join(ckpt_dir, "best_ckpts.yaml")
    checkpoint_callback.to_yaml(filepath=best_ckpt_path)