# Tutorial 7: Train NicheTrans on 10x Xenium data

In [2]:
import os, time, datetime, warnings

import torch
import torch.nn as nn
from torch.optim import lr_scheduler

from model.nicheTrans_img import *
from datasets.data_manager_breast_cancer import Breast_cancer

from utils.utils import *
from utils.utils_training_breast_cancer import train, test
from utils.utils_dataloader import *

# Install a catch-all "ignore" filter so library warnings do not clutter the tutorial output.
warnings.simplefilter("ignore")

## Load the arguments and fix the random seeds

In [3]:
%run ./args/args_breast_cancer.py
args = args

set_seed(args.seed)
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices

print("==========\nArgs:{}\n==========".format(args))

Args:Namespace(adata_path='/home/wzk/ST_data/2024_NicheTrans_upload/2023_nc_10x_breast_cancer/HBC_rep1_cell_nucleus_3channel_strength_mean.h5ad', coordinate_path='/home/wzk/ST_data/2023_nc_Xenium_breast/In_situ_sample_1_replicate_1/outs/cells.csv.gz', ct_path='/home/wzk/ST_data/2023_nc_Xenium_breast/Cell_Barcode_Type_Matrices.xlsx', dropout_rate=0.2, eval_step=1, gamma=0.1, gpu_devices='0', lr=0.0003, max_epoch=40, noise_rate=0.2, optimizer='adam', seed=1, stepsize=20, test_batch=32, train_batch=32, weight_decay=0.0005, workers=4)


## Initialize dataloaders and NicheTrans

In [4]:
# Build the Xenium breast-cancer dataset from the configured file paths,
# then split it into the train/test dataloaders.
dataset = Breast_cancer(
    adata_path=args.adata_path,
    coordinate_path=args.coordinate_path,
    ct_path=args.ct_path,
)
trainloader, testloader = breast_cancer_dataloader(args, dataset)

# NicheTrans input/output sizes come from the dataset: source = RNA features,
# target = protein features.
source_dimension = dataset.rna_length
target_dimension = dataset.protein_length
model = NicheTrans(
    source_length=source_dimension,
    target_length=target_dimension,
    noise_rate=args.noise_rate,
    dropout_rate=args.dropout_rate,
)

# Wrap in nn.DataParallel (splits each batch across the visible GPUs at forward
# time), then move the parameters to CUDA.
model = nn.DataParallel(model)
model = model.cuda()

------Calculating spatial graph...
The graph contains 1185564 edges, 98797 cells.
12.0000 neighbors per cell on average.
------Calculating spatial graph...
The graph contains 827796 edges, 68983 cells.
12.0000 neighbors per cell on average.
=> AD Mouse loaded
Dataset statistics:
  ------------------------------
  subset   | # num | 
  ------------------------------
  train    |  98797 spots, 98659 positive CD20, 84043 positive HER2 
  test     |  68983 spots, 67600 positive CD20, 36904 positive HER2 
  ------------------------------


## Initialize the loss function (criterion), optimizer, and learning-rate scheduler

In [5]:
# Mean-squared-error loss, passed to `train` below.
criterion = nn.MSELoss()

# The optimizer name is matched case-insensitively ('adam'/'Adam', 'sgd'/'SGD').
optimizer_name = args.optimizer.lower()
if optimizer_name == 'adam':
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
elif optimizer_name == 'sgd':
    # NOTE(review): unlike the Adam branch, args.weight_decay is not applied here — confirm intended.
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
else:
    # Fail fast: printing and continuing would leave `optimizer` undefined (or stale
    # from an earlier run of this cell) and only break later inside the training loop.
    raise ValueError('unexpected optimizer: {!r}'.format(args.optimizer))

# Multiply the learning rate by `gamma` every `stepsize` scheduler steps; disabled when
# stepsize <= 0. Reset to None in that case so no stale scheduler survives a re-run.
if args.stepsize > 0:
    scheduler = lr_scheduler.StepLR(optimizer, step_size=args.stepsize, gamma=args.gamma)
else:
    scheduler = None

## Model training and testing

In [None]:
start_time = time.time()

for epoch in range(args.max_epoch):
    current_epoch = epoch + 1  # 1-based epoch index used for logging and the eval schedule
    last_epoch = current_epoch == args.max_epoch

    print("==> Epoch {}/{}".format(current_epoch, args.max_epoch))

    # Run this epoch's training, then advance the LR schedule once per epoch
    # (the scheduler is only active when args.stepsize > 0).
    train(args, model, criterion, optimizer, trainloader, dataset.target_panel)
    if args.stepsize > 0:
        scheduler.step()

    # Evaluate every `eval_step` epochs (a non-positive eval_step disables periodic
    # evaluation instead of dividing by zero) and ALWAYS on the final epoch, so the
    # `last_epoch=True` call to `test` is not skipped when eval_step does not divide max_epoch.
    periodic_eval = args.eval_step > 0 and current_epoch % args.eval_step == 0
    if periodic_eval or last_epoch:
        pearson = test(args, model, testloader, dataset.target_panel, last_epoch)

    if last_epoch:
        # `model` is wrapped in nn.DataParallel, so the saved keys carry a 'module.'
        # prefix. Load this checkpoint into a DataParallel-wrapped NicheTrans, or strip
        # the prefix before loading it into a bare model.
        torch.save(model.state_dict(), 'NicheTrans_breast_cancer_last.pth')

elapsed = round(time.time() - start_time)
elapsed = str(datetime.timedelta(seconds=elapsed))
print("Finished. Total elapsed time (h:m:s): {}".format(elapsed))