This is a tutorial on how to train Reg2ST on your own computer.
It assumes you have already created a Reg2ST environment and downloaded all the required datasets.

### Step One: Import the required modules.

In [None]:
from herst import ViT_HER2ST, ViT_SKIN
from model import Reg2ST
from pytorch_lightning.loggers import CSVLogger
import torch
import numpy as np
import os
import argparse

from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

from performance import get_R
from utils import *

### Step Two: Set some parameters.

In [None]:
def parser_option():
    """Build the option namespace shared by the training and testing steps.

    Options are read from ``sys.argv`` with ``parse_known_args``, so
    arguments this parser does not define are ignored instead of aborting.

    Returns:
        argparse.Namespace: parsed options, including ``fold`` (the
        cross-validation fold read by ``train`` and ``predict``).
    """
    # allow_abbrev=False: without it argparse would treat any unique prefix
    # of an option as that option. A stray ``--f=<path>`` on the command
    # line (e.g. the connection-file argument some Jupyter kernels are
    # launched with) would then be taken for ``--fold`` and fail int parsing.
    parser = argparse.ArgumentParser(
        'gene prediction', add_help=False, allow_abbrev=False)
    parser.add_argument('--name', type=str, default='Reg2ST')

    # preprocess
    parser.add_argument('--dataset', type=str, default='her2st')

    # model
    parser.add_argument('--dim_in', type=int, default=1024)
    parser.add_argument('--dim_hidden', type=int, default=256)
    parser.add_argument('--dim_out', type=int, default=785)
    parser.add_argument('--dropout', type=float, default=0.2)
    parser.add_argument('--wikg_top', type=int, default=6)
    parser.add_argument('--decoder_layer', type=int, default=6)
    parser.add_argument('--decoder_head', type=int, default=8)

    # loss / masking hyper-parameters
    parser.add_argument('--mask_rate', type=float, default=0.75)
    parser.add_argument('--w_con', type=float, default=0.5)
    parser.add_argument('--w_zinb', type=float, default=0.25)

    # training
    parser.add_argument('--epochs', type=int, default=400)
    # train() and predict() both read args.fold, so it must always exist.
    parser.add_argument('--fold', type=int, default=0, help='fold number')
    parser.add_argument('--device_id', type=int, default=0)

    # Unknown arguments are returned separately and deliberately discarded.
    args_cmd, _ = parser.parse_known_args()
    return args_cmd

# Parse the options once; this namespace is passed to train() and predict() below.
args = parser_option()

### Step Three: Train the model.
All code is implemented with PyTorch Lightning. Below are detailed explanations of the classes used in the code.

The `Trainer` is the core class in PyTorch Lightning that simplifies the training process. It encapsulates all logic related to training, validation, testing, and prediction, allowing us to focus solely on model implementation. You can control the training flow flexibly by setting parameters such as `max_epochs`, `accelerator`, `devices`, and `callbacks`.

`CSVLogger` is a logging utility that records training and validation metrics (such as loss and accuracy) for each epoch into a .csv file. This file can be used for further analysis, such as plotting performance curves or comparing different experiments. 

`ModelCheckpoint` is a callback function used to automatically save the model during training. You can configure it to save models based on specific metrics (e.g., validation loss or accuracy) and choose whether to save only the best model or one for each epoch. It is very useful for model recovery, hyperparameter tuning, and deployment.

In [None]:
def train(args):
    """Train Reg2ST on one cross-validation fold and save the final checkpoint.

    Args:
        args: Option namespace (see ``parser_option``). Reads ``dataset``,
            ``epochs``, ``device_id`` and ``fold``; fold 0 is used when
            ``fold`` is absent. ``dim_out`` is overwritten with 171 when
            the dataset is not ``"her2st"``.
    """
    torch.set_float32_matmul_precision('high')
    fold = getattr(args, "fold", 0)

    save_dir = f"{args.dataset}_model/"

    # Only the end-of-training ("last") checkpoint is kept: save_top_k=0
    # disables the top-k checkpoints that `filename` would otherwise name.
    checkpoint_callback = ModelCheckpoint(
        dirpath=save_dir,
        filename=f"fold{fold}_model",
        save_last=True,
        save_top_k=0,
        save_on_train_epoch_end=True,
    )
    # The "last" checkpoint takes its name from CHECKPOINT_NAME_LAST
    # (default "last"), not from `filename`. Override it so the file is
    # written as {save_dir}fold{fold}_model.ckpt, the path predict() loads.
    checkpoint_callback.CHECKPOINT_NAME_LAST = f"fold{fold}_model"

    # Training and validation metrics are logged as CSV under
    # {dataset}_fold{fold}/fold{fold}/.
    logger_name = f'{args.dataset}_fold{fold}/'
    csv_logger = CSVLogger(logger_name, name=f"fold{fold}")

    # The output dimension depends on the dataset: the parser default is
    # kept for HER2+, and 171 is used for the skin (cSCC) dataset.
    if args.dataset == "her2st":
        train_data = ViT_HER2ST(train=True, flatten=False, ori=True, adj=False, fold=fold)
        val_data = ViT_HER2ST(train=False, flatten=False, ori=True, adj=False, fold=fold)
    else:
        args.dim_out = 171
        train_data = ViT_SKIN(train=True, flatten=False, ori=True, adj=False, fold=fold)
        val_data = ViT_SKIN(train=False, flatten=False, ori=True, adj=False, fold=fold)

    train_loader = DataLoader(train_data, batch_size=1, shuffle=False, pin_memory=True)
    val_loader = DataLoader(val_data, batch_size=1, shuffle=False, pin_memory=True)

    # Build the model only after dim_out has been settled above.
    model = Reg2ST(args=args)
    trainer = pl.Trainer(
        logger=csv_logger,
        precision=32,
        max_epochs=args.epochs,
        accelerator='gpu',
        devices=[args.device_id],
        callbacks=[checkpoint_callback],
        log_every_n_steps=5,
    )
    trainer.fit(model, train_loader, val_loader)
    
# Run training with the options parsed by parser_option().
train(args)

### Step Four: Test

In [None]:
def predict(args):
    """Evaluate a saved Reg2ST checkpoint on the held-out data of one fold.

    Args:
        args: Option namespace (see ``parser_option``). Reads ``dataset``,
            ``epochs``, ``device_id`` and ``fold``; fold 0 is used when
            ``fold`` is absent. ``dim_out`` is overwritten to match the
            selected dataset (785 for ``"her2st"``, otherwise 171).
    """
    torch.set_float32_matmul_precision('high')
    fold = getattr(args, "fold", 0)

    # dim_out must be set before the model is built so its output size
    # matches the dataset, whatever value args carried in.
    if args.dataset == "her2st":
        val_data = ViT_HER2ST(train=False, flatten=False, ori=True, adj=False, fold=fold)
        args.dim_out = 785
    else:
        val_data = ViT_SKIN(train=False, flatten=False, ori=True, adj=False, fold=fold)
        args.dim_out = 171
    val_loader = DataLoader(val_data, batch_size=1, shuffle=False, pin_memory=True)

    # The checkpoint is expected at {dataset}_model/fold{fold}_model.ckpt,
    # relative to the current directory (download it there if needed).
    ckpt_dir = f"{args.dataset}_model/fold{fold}_model.ckpt"
    print(f"Loading checkpoints from {ckpt_dir}")
    model = Reg2ST(args)
    trainer = pl.Trainer(precision=32, max_epochs=args.epochs,
                         accelerator='gpu', devices=[args.device_id])
    # The weights are restored from ckpt_path before testing.
    trainer.test(model, val_loader, ckpt_path=ckpt_dir)
    
# predict() requires the option namespace; calling it with no argument raises TypeError.
predict(args)