# Tutorial 9: Train NicheTrans on embryonic mouse brain data

In [None]:
import os, time, datetime, warnings

import torch
import torch.nn as nn
from torch.optim import lr_scheduler

from model.nicheTrans_hd import *
from datasets.data_manager_MISAR_seq import ATAC_RNA_Seq

from utils.utils import *
from utils.utils_training_embryonic_mouse_brain import *
from utils.utils_dataloader import *

# Silence every warning category for the remainder of the notebook session so
# that library warnings do not clutter the tutorial output.
warnings.simplefilter("ignore")

### Load the arguments and fix the random seeds

In [None]:
%run ./args/args_MISAR_seq.py
args = args

set_seed(args.seed)
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices

print("==========\nArgs:{}\n==========".format(args))

### Initialize dataloaders and NicheTrans

In [None]:
# Build the dataset from the configured AnnData path, then split it into the
# training and testing dataloaders.
dataset = ATAC_RNA_Seq(
    peak_threshold=args.peak_threshold,
    hvg_gene=args.hvg_gene,
    adata_path=args.adata_path,
    RNA2ATAC=True,
    knn_smoothing=args.knn_smooth,
)
trainloader, testloader = embryonic_mouse_brain(args, dataset)

# The network's input/output widths follow the sizes of the dataset's source
# and target panels.
source_dimension = len(dataset.source_panel)
target_dimension = len(dataset.target_panel)
model = NicheTrans(
    source_length=source_dimension,
    target_length=target_dimension,
    noise_rate=args.noise_rate,
    dropout_rate=args.dropout_rate,
)

# Wrap the network in DataParallel first, then move it onto the GPU.
model = nn.DataParallel(model)
model = model.cuda()

### Initialize the loss function (criterion), optimizer, and learning-rate scheduler

In [None]:
# Binary cross-entropy loss; nn.BCELoss expects predictions that are already
# probabilities in [0, 1].
criterion = nn.BCELoss()

# Select the optimizer named in the configuration.
if args.optimizer == 'adam':
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
elif args.optimizer == 'SGD':
    # NOTE(review): unlike the Adam branch, args.weight_decay is not applied
    # here — confirm this is intentional.
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
else:
    # Fail fast with a clear message: merely printing would leave `optimizer`
    # undefined and surface later as an unrelated NameError.
    raise ValueError(
        "unexpected optimizer: {!r} (expected 'adam' or 'SGD')".format(args.optimizer)
    )

# A positive step size enables step-wise learning-rate decay; otherwise no
# scheduler is used and the name is bound to None so it is always defined.
if args.stepsize > 0:
    scheduler = lr_scheduler.StepLR(optimizer, step_size=args.stepsize, gamma=args.gamma)
else:
    scheduler = None

### Model training and testing

In [None]:
# Wall-clock start; the elapsed time reported at the end covers training,
# the final evaluation and the checkpoint save.
start_time = time.time()

for epoch in range(args.max_epoch):
    print("==> Epoch {}/{}".format(epoch+1, args.max_epoch))

    # train_binary is invoked once per epoch; the learning-rate schedule is
    # advanced afterwards, and only when step decay is enabled.
    train_binary(model, criterion, optimizer, trainloader)
    if args.stepsize > 0:
        scheduler.step()

# Evaluate once, after the last epoch.
# test_binary(args, model, testloader)
test_regression(model, testloader, if_sigmoid=True)

# NOTE(review): `model` was wrapped in nn.DataParallel when it was created, so
# the saved state-dict keys carry a "module." prefix — loaders must wrap the
# network the same way (or strip the prefix). Confirm downstream tutorials
# expect this format before changing it.
torch.save(model.state_dict(), 'NicheTrans_embryonic_mouse_brain_rna2atac.pth')

# Report the total run time rounded to whole seconds, formatted as h:m:s.
elapsed = round(time.time() - start_time)
elapsed = str(datetime.timedelta(seconds=elapsed))
print("Finished. Total elapsed time (h:m:s): {}".format(elapsed))