In [1]:
import sys
import os
import gc
import cv2
import tqdm
import torch
import random
import pytorch_lightning as pl
import pandas as pd
import numpy as np
import anndata as ad
import argparse

# Add relative path
sys.path.append("../../")

from torch.utils.data import DataLoader
from pytorch_lightning import seed_everything
from Models.DeepPT.DeepPT_GNN import *
# from Dataloader.Dataset import *
from Dataloader.Dataset_wiener import *


# parser = argparse.ArgumentParser()
# parser.add_argument('--fold', type=int, default=5, help='dataset fold.')
# parser.add_argument('--seed', type=int, default=42, help='random seed.')
# parser.add_argument('--colornorm', type=str, default="reinhard", help='Color normalization methods.')
# parser.add_argument('--dataset_name', type=str, default="SCC_Chenhao", help='Dataset choice.')
# parser.add_argument('--model_name', type=str, default="DeepPT_GNN", help='Model choice.')
# parser.add_argument('--gene_list', type=str, default="func", help='Gene list choice.')
# parser.add_argument('--hpc', type=str, default="wiener", help='Clusters choice')
# args = parser.parse_args()

# fold = args.fold
# seed = args.seed
# colornorm = args.colornorm
# dataset_name = args.dataset_name
# model_name = args.model_name
# gene_list = args.gene_list
# hpc = args.hpc

"""
Hyperparameters settings
"""
# Run configuration. Edit these values to change the experiment.
hpc = "wiener"             # cluster key; selects the directory layout below
fold = 1                   # forwarded to WeightedGraph_Anndata as `fold`
seed = 42                  # passed to seed_everything
dataset_name = "BC_visium"
colornorm = "raw"    # "reinhard", "raw"
model_name = "DeepPT_GNN"
gene_list = "func"         # selects Gene_list_{gene_list}_{dataset_name}.npy
exp_norm = "lognorm"
# Graph flags forwarded unchanged to WeightedGraph_Anndata.
PAG = True
SLG = False
HSG = False

print("Start training!")
print("Hyperparameters are as follows:")
print("Fold:", fold)
print("Color normalization method:", colornorm)
# print("Dataset_name:", dataset_name)
print("Model_name:", model_name)
print("gene_list:", gene_list)
print("exp_norm:", exp_norm)
print("cluster:", hpc)


def resolve_cluster_paths(cluster):
    """Return the project directories used on a given cluster.

    Parameters
    ----------
    cluster : str
        One of "wiener", "vmgpu" or "bunya".

    Returns
    -------
    tuple of str
        ``(abs_path, model_weight_path, res_path, data_path)``.

    Raises
    ------
    ValueError
        If ``cluster`` is not a known key. Without this check an unknown
        value would leave the path variables undefined and only fail later
        with a NameError at first use.
    """
    # cluster key -> (project root, code sub-directory)
    layouts = {
        "wiener": ("/afm03/Q2/Q2051/DeepHis2Exp", "Models/Benchmarking_main"),
        "vmgpu": ("/afm01/UQ/Q2051/DeepHis2Exp", "Implementation"),
        "bunya": ("/QRISdata/Q2051/DeepHis2Exp", "Implementation"),
    }
    if cluster not in layouts:
        raise ValueError(
            f"Unknown cluster {cluster!r}; expected one of {sorted(layouts)}"
        )
    root, code_dir = layouts[cluster]
    return (
        f"{root}/{code_dir}",
        f"{root}/Model_Weights",
        f"{root}/Results",
        f"{root}/Dataset",
    )


# Alternative weights location previously used on wiener:
#  model_weight_path = "/scratch/imb/uqyjia11/Yuanhao/DeepHis2Exp/Model_Weights"
abs_path, model_weight_path, res_path, data_path = resolve_cluster_paths(hpc)

# For reproducing the results
seed_everything(seed)

# Load train and test dataset and wrap dataloader

# Functional genes for visium dataset
# NOTE: allow_pickle=True lets np.load unpickle arbitrary objects; only load
# gene-list files from a trusted location.
target_gene_list = list(np.load(f'{data_path}/Gene_list/Gene_list_{gene_list}_{dataset_name}.npy', allow_pickle=True))

# Keyword arguments shared by the train and test datasets, kept in one place
# so the two splits cannot drift apart. `exp_norm` now comes from the
# configuration variable instead of a hard-coded 'lognorm', so the value
# printed with the hyperparameters is the value actually used.
# NOTE(review): `target` is not assigned in the visible cells of this notebook;
# presumably it is provided by one of the star imports — verify.
dataset_kwargs = dict(
    fold=fold, gene_list=target_gene_list, num_subsets=50,
    r=112, exp_norm=exp_norm, SLG=SLG, HSG=HSG, PAG=PAG,
    neighs=8, color_norm=colornorm, target=target, distance_mode="distance",
)

# Training split: batches of one item, shuffled.
full_train_dataset = WeightedGraph_Anndata(train=True, **dataset_kwargs)
tr_loader = DataLoader(full_train_dataset, batch_size=1, shuffle=True)
gc.collect()

# Held-out split: same settings, original order.
test_dataset = WeightedGraph_Anndata(train=False, **dataset_kwargs)
te_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
gc.collect()


[rank: 0] Global seed set to 42


Dataset path: /afm03/Q2/Q2051
Start training!
Hyperparameters are as follows:
Fold: 1
Color normalization method: raw
Model_name: DeepPT_GNN
gene_list: func
exp_norm: lognorm
cluster: wiener
Datasize:6
Loading whole slide imgs...


100%|██████████| 5/5 [00:11<00:00,  2.24s/it]
100%|██████████| 5/5 [00:01<00:00,  2.86it/s]


subset_size: 81
subset_size: 48
subset_size: 24
subset_size: 23
subset_size: 20
Loading imgs...
Loading spatial coordinates...
Loading gene expression
Loading imgs...


100%|██████████| 250/250 [00:03<00:00, 72.36it/s] 


Calculating adjacency matrices or distance matrices...
Loading pathology annotations...
Loading pathology annotation graph...
Datasize:6
Loading whole slide imgs...


100%|██████████| 1/1 [00:08<00:00,  8.20s/it]
100%|██████████| 1/1 [00:00<00:00,  2.30it/s]


subset_size: 96
Loading imgs...
Loading spatial coordinates...
Loading gene expression
Loading imgs...


100%|██████████| 50/50 [00:02<00:00, 23.21it/s]


Calculating adjacency matrices or distance matrices...
Loading pathology annotations...
Loading pathology annotation graph...


125

In [2]:
# Define model and train
model = CNN_GNN_AE(n_genes=len(target_gene_list), hidden_dim=512, learning_rate=1e-4)

# Empty cache of GPU
torch.cuda.empty_cache()

# Create folder to save model weights. makedirs also creates missing parent
# directories and does not fail when the folder already exists.
os.makedirs(f"{model_weight_path}/{dataset_name}/", exist_ok=True)

# Stop once 'train_loss' has not improved for 10 consecutive checks.
early_stop = pl.callbacks.EarlyStopping(monitor='train_loss', mode='min', patience=10)
# Keep only the single checkpoint with the lowest 'train_loss'. It is written
# to model_weight_path and removed again once its weights have been exported.
checkpoint_callback = pl.callbacks.ModelCheckpoint(save_top_k=1, dirpath=f"{model_weight_path}",
                                                   filename=f"{model_name}_{dataset_name}_{colornorm}_{test_dataset.te_names}",
                                                   monitor="train_loss", mode="min")

trainer = pl.Trainer(accelerator='auto',
                    callbacks=[early_stop, checkpoint_callback],
                    max_epochs=100, logger=False)

# Start training and save best model
trainer.fit(model, tr_loader)

# debug
# trainer.fit(model, te_loader)

print(checkpoint_callback.best_model_path)   # prints path to the best model's checkpoint
print(checkpoint_callback.best_model_score) # and prints it score

# Read the best checkpoint once and reuse its state dict for both restoring
# the in-memory model and writing the exported weights file.
best_state_dict = torch.load(checkpoint_callback.best_model_path)["state_dict"]
model.load_state_dict(best_state_dict)
# load_state_dict returns a report of missing/unexpected keys, not a model, so
# `best_model` is bound to the restored model itself.
best_model = model
torch.save(best_state_dict, f"{model_weight_path}/{dataset_name}/{model_name}_PAG{str(PAG)}_HSG{str(HSG)}_{test_dataset.te_names}_{gene_list}.ckpt")
os.remove(checkpoint_callback.best_model_path)


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name              | Type             | Params
-------------------------------------------------------
0 | feature_extractor | FeatureExtractor | 23.5 M
1 | AE                | Autoencoder      | 1.1 M 
2 | GNN               | GNN              | 23.1 M
3 | pred_head         | Linear           | 836 K 
-------------------------------------------------------
48.5 M    Trainable params
0         Non-trainable params
48.5 M    Total params
193.930   Total estimated model params size (MB)
SLURM auto-requeueing enabled. Setting signal handlers.


Epoch 19:  68%|██████▊   | 171/250 [00:33<00:15,  5.10it/s, train_loss_step=0.167, mse_step=0.166, recon_loss_step=0.000106, train_loss_epoch=0.148, mse_epoch=0.147, recon_loss_epoch=5.32e-5]  

In [2]:
import torch

# Scratch cell: display a 9x9 tensor of zeros.
torch.zeros(size=(9, 9))

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.]])

In [None]:
# Inference

gc.collect()
# From here on, the gene order is the one stored on the test dataset.
target_gene_list = test_dataset.gene_set
out = trainer.predict(model, te_loader)

# Each element of `out` is indexed as (prediction, ground truth); stack the
# per-batch results along the first axis.
pred = ad.AnnData(np.concatenate([batch[0] for batch in out]))
gt_exp = np.concatenate([batch[1] for batch in out])

# Concatenate the per-sub-slide AnnData objects with anndata's own `concat`.
# The previous `sc.concat` relied on an `sc` alias that is never imported in
# this notebook.
gt = ad.concat(list(test_dataset.meta_dict.values()))[:, target_gene_list]
gt.X = gt_exp

# Add the gene list to AnnData
pred.var_names = target_gene_list
gt.var_names = target_gene_list

# Save AnnData to H5AD file
out_dir = f"{res_path}/{dataset_name}"
# makedirs creates missing parents and tolerates an existing folder.
os.makedirs(f"{out_dir}/", exist_ok=True)
# Suffix shared by every artefact written for this run.
run_tag = f"{model_name}_{dataset_name}_{colornorm}_{test_dataset.te_names}_{gene_list}"
pred.write(f"{out_dir}/pred_{run_tag}.h5ad")
gt.write(f"{out_dir}/gt_{run_tag}.h5ad")
gc.collect()

# Save spatial location to numpy array
spatial_loc = np.concatenate([sub_slide.obsm["spatial"] for sub_slide in test_dataset.meta_dict.values()])
np.save(f'{out_dir}/spatial_loc_{run_tag}.npy', spatial_loc)
gc.collect()

print("Finish training!")

In [None]:
from scipy.stats import pearsonr

# Obtain a dense ground-truth matrix once, instead of calling `.toarray()`
# inside the per-gene loop. The inference cell assigns a NumPy array to
# `gt.X`, which has no `.toarray()`; a sparse matrix is still handled.
gt_values = gt.X.toarray() if hasattr(gt.X, "toarray") else np.asarray(gt.X)

# Pearson correlation between ground truth and prediction, one value per gene.
pcc = np.array([
    pearsonr(gt_values[:, g], pred.X[:, g])[0]
    for g in range(len(target_gene_list))
])

# NOTE(review): `sns` is not imported in the visible cells of this notebook;
# presumably seaborn arrives via a star import — verify.
sns.boxplot(pcc)

In [None]:
# Scratch inputs for probing the GNN layers by hand.
# NOTE(review): 81 rows are hard-coded here, while the test-set log above
# reports `subset_size: 96`; if that is the node count, `h` does not match the
# graph taken from `test_dataset[0]` — confirm before relying on this cell.
h = torch.rand(81, 512)

# Fetch the item once so both tensors come from the same __getitem__ call.
first_test_item = test_dataset[0]
edge_index = first_test_item[-2]
edge_weights = first_test_item[-1]


In [None]:
# Scratch probe: apply every layer in model.GNN.conv to the same input `h`
# and stack the outputs along a new leading axis.
jk = [gcn(h, edge_index, edge_weights).unsqueeze(0) for gcn in model.GNN.conv]
x = torch.cat(jk,0).to(torch.float32)

# A freshly initialised (untrained) 2-layer LSTM, used only to check shapes.
# With the default batch_first=False the leading (layer) axis is treated as
# the sequence axis. `torch.nn` is used explicitly because a bare `nn` is
# never imported in this notebook.
x, _ = torch.nn.LSTM(512, 512, 2)(x)

# Average over the leading axis and display the resulting shape.
x = x.mean(0)
x.shape