In [None]:
import pandas as pd
import numpy as np
import scanpy as sc
import matplotlib.pyplot as plt
import argparse

from palettable.cartocolors.diverging import *
from palettable.scientific.diverging import *

In [None]:
# Options later handed to the project data-loader factory. Every flag is an
# integer, so they are declared from a small table rather than one call each.
parser = argparse.ArgumentParser(description='oh visualization')
for flag, default_value, help_text in (
    ('--img_size', 256, None),
    ('--workers', 4, "number of data loading workers (default: 4)"),
    ('--train_batch', 32, None),
    ('--test_batch', 32, None),
):
    parser.add_argument(flag, default=default_value, type=int, help=help_text)

# parse_known_args() returns unrecognised arguments separately instead of
# exiting on them, so extra arguments on the process command line are tolerated.
args, _ = parser.parse_known_args()

# Load dataset

In [None]:
# NOTE(review): n_top_genes is not referenced anywhere else in the visible
# notebook; kept in case code outside this view relies on it.
n_top_genes = 3000

rna_path = '/home/wzk/ST_data/2024_nmethods_SpatialGlue_Human_lymph_node_3slides/slice2/s2_adata_rna.h5ad'
protein_path = '/home/wzk/ST_data/2024_nmethods_SpatialGlue_Human_lymph_node_3slides/slice2/s2_adata_adt.h5ad'

# RNA slide: copy the two spatial coordinate columns into obs, both negated.
adata_rna_testing = sc.read_h5ad(rna_path)
rna_xy = adata_rna_testing.obsm['spatial']
adata_rna_testing.obs['array_row'] = -rna_xy[:, 0]
adata_rna_testing.obs['array_col'] = -rna_xy[:, 1]

# Protein slide: same two obs columns, but taken as-is.
# NOTE(review): the RNA coordinates above are sign-flipped while these are
# not — confirm the asymmetry is intentional.
adata_msi_testing = sc.read_h5ad(protein_path)
msi_xy = adata_msi_testing.obsm['spatial']
adata_msi_testing.obs['array_row'] = msi_xy[:, 0]
adata_msi_testing.obs['array_col'] = msi_xy[:, 1]

In [None]:
# Project-local helpers.
# NOTE(review): the star imports hide where names come from; `Lymph_node` and
# `human_node_dataloader` are presumably provided by these two modules —
# confirm, and prefer importing the needed names explicitly.
from datasets.human_lymph_node_data_manager import *
from utils.utils_dataloader import *

# Build the dataset object, then ask the loader factory for its loaders; only
# the second returned loader is kept, the first return value is discarded.
dataset = Lymph_node()
_, testloader = human_node_dataloader(args, dataset)

# Load model

In [None]:
# NOTE(review): `nn` and `torch` are not imported explicitly in the visible
# notebook; presumably they arrive through this star import — confirm.
from model.nicheTrans import *

# Input/output widths are read from the dataset object.
source_dimension = dataset.rna_length
target_dimension = dataset.msi_length

network = NicheTrans(
    source_length=source_dimension,
    target_length=target_dimension,
    noise_rate=0.2,
    dropout_rate=0.1,
)
# The checkpoint is loaded into the DataParallel wrapper, not the bare network,
# so the wrapping has to happen before load_state_dict().
model = nn.DataParallel(network).cuda()

checkpoint = torch.load('./last.pth')
model.load_state_dict(checkpoint)

# Switch to evaluation mode; left as the final expression of the cell.
model.eval()

# Inference 

In [None]:
# Run the model over every batch of the test loader and collect, per batch,
# the predictions (pd_value) and the matching targets (gt_value).
# NOTE(review): the two dictionaries below are never read in the visible
# notebook; kept in case code outside this view uses them.
pd_dictionary, gt_dictionary = defaultdict(), defaultdict()
pd_value, gt_value = [], []

with torch.no_grad():
    # Each batch is a 4-tuple; its last element is not needed here.
    for rna, protein, rna_neighbors, _ in testloader:
        # Only the model inputs have to live on the GPU.
        outputs = model(rna.cuda(), rna_neighbors.cuda())

        # Move predictions to host memory immediately so GPU usage does not
        # grow with the size of the test set; targets are never sent to the
        # GPU at all (.cpu() is a no-op for tensors already on the host).
        pd_value.append(outputs.cpu())
        gt_value.append(protein.cpu())

# Model evaluation

In [None]:
# Join the per-batch tensors along the first axis and convert each side to a
# NumPy array. The source tuple is evaluated before either name is rebound.
# NOTE: this rebinds pd_value / gt_value from lists to arrays, so the cell
# cannot be re-run without re-running the inference cell first.
pd_value, gt_value = (
    torch.cat(batches, dim=0).cpu().numpy()
    for batches in (pd_value, gt_value)
)

In [None]:
from scipy.stats import spearmanr

# Per-target agreement between prediction and ground truth, one entry per
# element of dataset.target_panel (column i of both arrays). The metrics are
# computed on the values as they are at this point in the notebook, i.e.
# before the later exp/de-normalisation cell.
pcc, spcc, rmse = [], [], []
for i in range(len(dataset.target_panel)):
    predicted, observed = pd_value[:, i], gt_value[:, i]
    pcc.append(np.corrcoef(predicted, observed)[0, 1])          # Pearson r
    spcc.append(spearmanr(predicted, observed)[0])              # Spearman rho
    rmse.append(np.sqrt(np.mean((predicted - observed) ** 2)))  # root-mean-square error

# Named `metrics` rather than `dict`: rebinding `dict` would shadow the
# builtin for every cell executed afterwards in this kernel.
metrics = {
    "pearson": pcc,
    "spearman": spcc,
    "rmse": rmse
}

# One row per target, labelled with the target panel names.
df = pd.DataFrame(metrics)
df.index = dataset.target_panel

df.to_csv('pcc_spcc_rmse.csv')

In [None]:
import os

# Create the output directory up front so the writes below do not depend on
# it having been created by hand; exist_ok makes the cell safe to re-run.
os.makedirs('./results', exist_ok=True)

# Clone the ground-truth AnnData and replace its matrix with the predictions,
# so the exported object keeps the same obs/var annotations.
# NOTE(review): this assumes the rows of pd_value are in the same order as the
# observations of adata_msi_testing (i.e. the test loader does not shuffle or
# drop samples) — confirm against the data loader.
pd_adata = adata_msi_testing.copy()
pd_adata.X = pd_value

pd_adata.write('./results/pd_msi.h5ad')
adata_msi_testing.write('./results/gt_msi.h5ad')

In [None]:
# Map both arrays through exp(x * std + mean) using the dataset's statistics.
# Presumably this inverts a log + z-score normalisation applied by the
# dataset — confirm against the data manager.
# NOTE: the names are rebound from their own previous values, so running this
# cell twice applies the transform twice.
pd_value, gt_value = (
    np.exp((values * dataset.std) + dataset.mean)
    for values in (pd_value, gt_value)
)

In [None]:
# Names used to label the per-target columns, taken from var['gene_ids'].
proteins = adata_msi_testing.var['gene_ids'].values

# For every target, store its predicted and ground-truth column in obs as
# 'pd_<name>' and 'gt_<name>' (in that order) so they can be used as plot colours.
for index, protein in enumerate(proteins):
    for prefix, matrix in (('pd_', pd_value), ('gt_', gt_value)):
        adata_msi_testing.obs[prefix + protein] = matrix[:, index]
    

In [None]:
# Target to visualise. Alternatives previously tried in this cell: 'PAX5', 'VIM'.
protein = 'HLA-DRA'

# One entry per panel: (obs-column / file prefix, title prefix, extra plot kwargs).
# NOTE(review): only the ground-truth panel caps its colour scale (vmax=90000),
# so the two panels do not share a colour range — confirm this is intended.
panels = (
    ('pd', 'prediction', {}),
    ('gt', 'Ground Truth', {'vmax': 90000}),
)

for prefix, heading, extra_kwargs in panels:
    fig, ax = plt.subplots(1, figsize=(4, 3), dpi=100)
    sc.pl.embedding(
        adata_msi_testing,
        basis='spatial',
        color=prefix + '_' + protein,
        title=f'{heading} {protein}',
        ax=ax,
        show=False,
        cmap=Tropic_7.mpl_colormap,
        size=30,
        **extra_kwargs,
    )

    # Each panel is saved twice: EPS first, then PNG.
    for extension in ('eps', 'png'):
        fig.savefig(
            './results/{}_{}.{}'.format(prefix, protein, extension),
            format=extension,
            dpi=300,
            bbox_inches='tight',
        )

In [None]:
# Target to visualise in this cell.
protein = 'CD3E'

# One entry per panel: (obs-column / file prefix, title prefix).
# Both panels use the default colour range here.
panels = (
    ('pd', 'prediction'),
    ('gt', 'Ground Truth'),
)

for prefix, heading in panels:
    fig, ax = plt.subplots(1, figsize=(4, 3), dpi=100)
    sc.pl.embedding(
        adata_msi_testing,
        basis='spatial',
        color=prefix + '_' + protein,
        title=f'{heading} {protein}',
        ax=ax,
        show=False,
        cmap=Tropic_7.mpl_colormap,
        size=30,
    )

    # Each panel is saved twice: EPS first, then PNG.
    for extension in ('eps', 'png'):
        fig.savefig(
            './results/{}_{}.{}'.format(prefix, protein, extension),
            format=extension,
            dpi=300,
            bbox_inches='tight',
        )