In [1]:
import torch
import torch.nn.functional as F
from tqdm import tqdm
import matplotlib.pyplot as plt

import config as CFG
from models import *
from dataset import *
import scanpy as sc
from torch.utils.data import DataLoader

import os
import numpy as np
import pandas as pd

import scanpy as sc
from itertools import chain

In [2]:
#print the current scanpy version
print(sc.__version__)

1.9.5


In [3]:
fold=5
data='her2st' #### Change here to test different dataset 'her2st' 'cscc'

prune='Grid' if data=='her2st' else 'NA'
genes=171 if data=='cscc' else 785

def pk_load(fold,mode='test',flatten=False,dataset='her2st',r=4,ori=True,adj=True,prune='Grid',neighs=8): #r=4 Hist2ST
    assert dataset in ['her2st','cscc']
    if dataset=='her2st':
        dataset = CLIP_HER2ST(
            train=(mode=='train'),fold=fold,flatten=flatten,
            ori=ori,neighs=neighs,adj=adj,prune=prune,r=r
        )
    elif dataset=='cscc':
        dataset = CLIP_SKIN(
            train=(mode=='train'),fold=fold,flatten=flatten,
            ori=ori,neighs=neighs,adj=adj,prune=prune,r=r
        )
    return dataset

def build_loaders_inference():
    print("Building loaders")
    trainset = pk_load(fold,'train',dataset=data,flatten=False,adj=True,ori=True,prune=prune)
    train_loader = DataLoader(trainset, batch_size=1, num_workers=0, shuffle=True)
    testset = pk_load(fold,'test',dataset=data,flatten=False,adj=True,ori=True,prune=prune)
    test_loader = DataLoader(testset, batch_size=1, num_workers=0, shuffle=False)
    print("Finished building loaders")
    return trainset, testset, train_loader, test_loader

#2265x256, 2277x256
def find_matches(spot_embeddings, query_embeddings, top_k=1):
    #find the closest matches 
    spot_embeddings = torch.tensor(spot_embeddings)
    query_embeddings = torch.tensor(query_embeddings)
    query_embeddings = F.normalize(query_embeddings, p=2, dim=-1)
    spot_embeddings = F.normalize(spot_embeddings, p=2, dim=-1)
    dot_similarity = query_embeddings @ spot_embeddings.T   #2277x2265
    print("dot_similarity.shape = spots * reference_spots = ",dot_similarity.shape)
    _, indices = torch.topk(dot_similarity.squeeze(0), k=top_k)
    
    return indices.cpu().numpy()

In [4]:
### Loading data

trainset, testset, train_loader, test_loader = build_loaders_inference()
train_loader = chain(train_loader, test_loader)

print("Finished loading data")

Building loaders
Test set names: ['A1', 'B1', 'C1', 'D1', 'E1', 'F1', 'G2', 'H1']
Train set names: ['F2', 'C6', 'D2', 'D6', 'G1', 'A2', 'B4', 'B6', 'C5', 'E2', 'B2', 'B5', 'D4', 'A3', 'C2', 'H3', 'D3', 'G3', 'C4', 'D5', 'H2', 'A5', 'C3', 'A4', 'A6', 'E3', 'F3', 'B3']
Loading imgs...
Loading metadata...
Test set names: ['A1', 'B1', 'C1', 'D1', 'E1', 'F1', 'G2', 'H1']
Train set names: ['F2', 'C6', 'D2', 'D6', 'G1', 'A2', 'B4', 'B6', 'C5', 'E2', 'B2', 'B5', 'D4', 'A3', 'C2', 'H3', 'D3', 'G3', 'C4', 'D5', 'H2', 'A5', 'C3', 'A4', 'A6', 'E3', 'F3', 'B3']
Loading imgs...
Loading metadata...


  x = np.around(x).astype(int)
  y = np.around(y).astype(int)
  x = np.around(x).astype(int)
  y = np.around(y).astype(int)


Finished building loaders
Finished loading data


In [5]:

# model_path ="clip/best.pt"
model_path ="clip/ICL2ST_HER2.pt"
if data =='cscc':
    model_path ="clip/ICL2ST_cSCC.pt"
save_path = "clip/embeddings/"
model = myModel().cuda()

state_dict = torch.load(model_path)
new_state_dict = {}
for key in state_dict.keys():
    new_key = key.replace('module.', '')  # remove the prefix 'module.'
    new_key = new_key.replace('well', 'spot') # for compatibility with prior naming
    if "image_encoder.gnn" in new_key: # Special to GNN because GNN use torch_geometric.nn
        new_key = new_key.replace("module_1.","module_1.module.")  
    new_state_dict[new_key] = state_dict[key]

model.load_state_dict(new_state_dict)
model.eval()

print("Finished loading model")

Finished loading model


In [6]:
if not os.path.exists(save_path):
        os.makedirs(save_path)
        
adj_dict = {}
exp_dict = {}
center_dict = {}
with torch.no_grad():
    for batch in tqdm(train_loader):
        ID, patch, center, exp, adj, oris, sfs, centers = batch
        print("Processing image ", ID)
        B,N,C,H,W = patch.shape
        patch = patch.reshape(B*N,C,H,W)  # (N,3,112,112)
        if adj.dim() == 3:
            adj = adj.squeeze(0)
        if exp.dim() == 3:
            exp = exp.squeeze(0)
            centers = centers.squeeze().numpy()
        adj_dict[ID] = adj 
        exp_dict[ID] = exp 
        center_dict[ID] = centers
        
        image_features = model.image_encoder(patch.cuda())
        spot_features = model.spot_encoder(exp.cuda(), adj.cuda())
        
        image_embeddings = model.image_projection(image_features).cpu().numpy()
        spot_embeddings = (model.spot_projection(spot_features.cuda()))
        
        spot_encoding = model.spot_autoencoder.encode(spot_embeddings, adj.cuda())
        spot_reconstruction, extra = model.spot_autoencoder.decode(spot_encoding.cuda())
        
        spot_embeddings = spot_embeddings.cpu().numpy()
        spot_encoding = spot_encoding.cpu().numpy()
        spot_reconstruction = spot_reconstruction.cpu().numpy()
        
        # print(image_embeddings.shape)
        # print(spot_embeddings.shape)
        np.save(save_path + "img_embeddings_" + str(ID[0]) + ".npy", image_embeddings.T)
        np.save(save_path + "spot_embeddings_" + str(ID[0]) + ".npy", spot_embeddings.T)

0it [00:00, ?it/s]

Processing image  ('G1',)


2it [00:03,  1.34s/it]

Processing image  ('D2',)


3it [00:03,  1.17it/s]

Processing image  ('C3',)


4it [00:04,  1.20it/s]

Processing image  ('F3',)


5it [00:04,  1.32it/s]

Processing image  ('B5',)


6it [00:05,  1.43it/s]

Processing image  ('C4',)


7it [00:05,  1.50it/s]

Processing image  ('B4',)


8it [00:06,  1.63it/s]

Processing image  ('C2',)


9it [00:07,  1.19it/s]

Processing image  ('H3',)


10it [00:08,  1.32it/s]

Processing image  ('D5',)


11it [00:08,  1.63it/s]

Processing image  ('C6',)
Processing image  ('C5',)


13it [00:09,  2.10it/s]

Processing image  ('A3',)


14it [00:10,  1.43it/s]

Processing image  ('E2',)


15it [00:11,  1.52it/s]

Processing image  ('A4',)


16it [00:11,  1.41it/s]

Processing image  ('E3',)


17it [00:12,  1.32it/s]

Processing image  ('A2',)


18it [00:13,  1.26it/s]

Processing image  ('D6',)


19it [00:14,  1.14it/s]

Processing image  ('H2',)


20it [00:15,  1.18it/s]

Processing image  ('G3',)


21it [00:16,  1.31it/s]

Processing image  ('D3',)


22it [00:16,  1.30it/s]

Processing image  ('B3',)


23it [00:17,  1.33it/s]

Processing image  ('B6',)


24it [00:17,  1.55it/s]

Processing image  ('B2',)


25it [00:18,  1.56it/s]

Processing image  ('A6',)


26it [00:19,  1.30it/s]

Processing image  ('F2',)


27it [00:20,  1.33it/s]

Processing image  ('D4',)


28it [00:21,  1.36it/s]

Processing image  ('A5',)


29it [00:21,  1.45it/s]

Processing image  ('A1',)


30it [00:22,  1.63it/s]

Processing image  ('B1',)


31it [00:22,  1.76it/s]

Processing image  ('C1',)


32it [00:23,  1.56it/s]

Processing image  ('D1',)
Processing image  ('E1',)


34it [00:25,  1.05it/s]

Processing image  ('F1',)


35it [00:26,  1.12it/s]

Processing image  ('G2',)


36it [00:27,  1.30it/s]

Processing image  ('H1',)





In [7]:
save_path = "clip/embeddings/"
all_files = os.listdir(save_path)

# exp_dict = {}
# for batch in tqdm(train_loader):
#     ID, patch, center, exp, adj, oris, sfs, *_ = batch
#     print(ID)
#     print(exp.shape)
#     exp_dict[ID] = exp  # Assuming ID and exp are tensors, we fetch their first elements

image_embeddings_dict = {}
spot_embeddings_dict = {}
ID_list = []

for file in all_files:
    if file.endswith(".npy"):
        # Extract the ID from the filename (e.g., A2, C3, etc.)
        if data=='her2st': 
            if 'rep' not in file:
                ID = file.split("_")[2].split(".")[0]
        elif data=='cscc':
            if 'rep' in file:
                ID = "_".join(file.split("_")[2:-1]) + "_" + file.split("_")[-1].split(".")[0]
        
        if (ID,) in adj_dict:
            adj_dict[ID] = adj_dict.pop((ID,))
        if (ID,) in exp_dict:
            exp_dict[ID] = exp_dict.pop((ID,))
        if (ID,) in center_dict:
            center_dict[ID] = center_dict.pop((ID,))
            
        # Determine the type of file based on its prefix and load the data
        if "img_embeddings" in file:
            image_embeddings_dict[ID] = np.load(os.path.join(save_path, file))
            ID_list.append(ID)
        elif "spot_embeddings" in file:
            spot_embeddings_dict[ID] = np.load(os.path.join(save_path, file))

# Now, image_embeddings_dict and spot_embeddings_dict contain the required data
print(image_embeddings_dict.keys())  # Should list all the image embedding IDs
print(spot_embeddings_dict.keys())  # Should list all the spot embedding IDs
print(exp_dict.keys())  # Should list all the spot embedding IDs
print(ID_list)

if data=='her2st':   
    fold=[0,6,12,18,24,27,31,33]
    test_ID = ['A1','B1','C1','D1','E1','F1','G2','H1']
elif data=='cscc':
    fold=[0,3,6,9]
    test_ID = ['P2_ST_rep1', 'P5_ST_rep1', 'P9_ST_rep1', 'P10_ST_rep1']
    
# test_ID = [ID_list[i] for i in fold]
print("Test set names:", test_ID)
train_ID = list(set(ID_list)-set(test_ID))
print("Train set names:",train_ID)



dict_keys(['D2', 'C5', 'A2', 'C6', 'A6', 'D6', 'A4', 'B4', 'C4', 'C2', 'A3', 'B6', 'D3', 'B5', 'E2', 'F3', 'B2', 'H2', 'G1', 'A5', 'F2', 'E3', 'C3', 'D5', 'G3', 'H3', 'D4', 'B3', 'A1', 'B1', 'C1', 'D1', 'E1', 'F1', 'G2', 'H1'])
dict_keys(['D2', 'C5', 'A2', 'C6', 'A6', 'D6', 'A4', 'B4', 'C4', 'C2', 'A3', 'B6', 'D3', 'B5', 'E2', 'F3', 'B2', 'H2', 'G1', 'A5', 'F2', 'E3', 'C3', 'D5', 'G3', 'H3', 'D4', 'B3', 'A1', 'B1', 'C1', 'D1', 'E1', 'F1', 'G2', 'H1'])
dict_keys(['D2', 'C5', 'A2', 'C6', 'A6', 'D6', 'A4', 'B4', 'C4', 'C2', 'A3', 'B6', 'D3', 'B5', 'E2', 'F3', 'B2', 'H2', 'G1', 'A5', 'F2', 'E3', 'C3', 'D5', 'G3', 'H3', 'D4', 'B3', 'A1', 'B1', 'C1', 'D1', 'E1', 'F1', 'G2', 'H1'])
['D2', 'C5', 'A2', 'C6', 'A6', 'D6', 'A4', 'B4', 'C4', 'C2', 'A3', 'B6', 'D3', 'B5', 'E2', 'F3', 'B2', 'H2', 'G1', 'A5', 'F2', 'E3', 'C3', 'D5', 'G3', 'H3', 'D4', 'B3', 'A1', 'B1', 'C1', 'D1', 'E1', 'F1', 'G2', 'H1', 'H1', 'H1', 'H1', 'H1', 'H1', 'H1', 'H1', 'H1', 'H1', 'H1', 'H1', 'H1']
Test set names: ['A1', 'B1'

In [8]:
#query
# test_ID.remove('A1')
print(test_ID)
# image_query = [spot_embeddings_dict[ID] for ID in test_ID]
# expression_gt = [exp_dict[ID].numpy().T for ID in test_ID]

# image_train_data = [image_embeddings_dict[ID] for ID in train_ID]
spot_train_data = [spot_embeddings_dict[ID] for ID in train_ID]
expression_train_data = [exp_dict[ID].numpy().T for ID in train_ID]

spot_key = np.concatenate(spot_train_data, axis=1)
expression_key = np.concatenate(expression_train_data, axis=1)

# print(image_query.shape)
# print(expression_gt.shape)
print(spot_key.shape)
print(expression_key.shape)

if spot_key.shape[1] != 256:
    spot_key = spot_key.T
    print("spot_key shape: ", spot_key.shape)
if expression_key.shape[0] != spot_key.shape[0]:
    expression_key = expression_key.T
    print("expression_key shape: ", expression_key.shape)

['A1', 'B1', 'C1', 'D1', 'E1', 'F1', 'G2', 'H1']
(256, 10139)
(785, 10139)
spot_key shape:  (10139, 256)
expression_key shape:  (10139, 785)


In [9]:
import torch
import numpy as np
import scanpy as sc
import anndata as ad
from tqdm import tqdm
from scipy.stats import pearsonr,spearmanr
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_rand_score as ari_score
from sklearn.metrics.cluster import normalized_mutual_info_score as nmi_score

def test(model,test,device='cuda'):
    model=model.to(device)
    model.eval()
    preds=None
    ct=None
    gt=None
    loss=0
    with torch.no_grad():
        for patch, position, exp, adj, *_, center in tqdm(test):
            patch, position, adj = patch.to(device), position.to(device), adj.to(device).squeeze(0)
            pred = model(patch, position, adj)[0]
            preds = pred.squeeze().cpu().numpy()
            ct = center.squeeze().cpu().numpy()
            gt = exp.squeeze().cpu().numpy()
    adata = ad.AnnData(preds)
    adata.obsm['spatial'] = ct
    adata_gt = ad.AnnData(gt)
    adata_gt.obsm['spatial'] = ct
    return adata,adata_gt

def cluster(adata,label):
    idx = label != 'undetermined'
    tmp=adata[idx]
    l=label[idx]
    print("cluster number:",len(set(l)))
    sc.pp.pca(tmp)
    sc.tl.tsne(tmp)
    kmeans = KMeans(n_clusters=len(set(l)), init="k-means++", random_state=0).fit(tmp.obsm['X_pca'])
    p=kmeans.labels_.astype(str)
    lbl=np.full(len(adata),str(len(set(l))))
    lbl[idx]=p
    adata.obs['kmeans']=lbl
    return p,round(ari_score(p,l),3)

def get_R(data1,data2,dim=1,func=pearsonr):
    adata1=data1.X
    adata2=data2.X
    r1,p1=[],[]
    for g in range(data1.shape[dim]):
        if dim==1:
            r,pv=func(adata1[:,g],adata2[:,g])
        elif dim==0:
            r,pv=func(adata1[g,:],adata2[g,:])
        r1.append(r)
        p1.append(pv)
    r1=np.array(r1)
    p1=np.array(p1)
    return r1,p1

def get_top_values(arr, num_top_values=10, lowest=False):
    return sorted([(i, arr[i]) for i in range(len(arr))], key=lambda x: x[1], reverse=not lowest)[:num_top_values]
top_k = 50
results = {}
top_results = {}
selected_folds = [5]



In [10]:
import warnings
warnings.filterwarnings('ignore')

from sklearn.metrics import mean_squared_error
from math import sqrt

for ID in test_ID:
    print("Begin Processing Image", ID)
    image_query = spot_embeddings_dict[ID]
    expression_gt = exp_dict[ID].numpy().T

    method = "weighted_average" # "average" "weighted_average"
    save_path = ""
    if image_query.shape[1] != 256:
        image_query = image_query.T
        print("image query shape: ", image_query.shape)
    if expression_gt.shape[0] != image_query.shape[0]:
        expression_gt = expression_gt.T
        print("expression_gt shape: ", expression_gt.shape)
    if spot_key.shape[1] != 256:
        spot_key = spot_key.T
        print("spot_key shape: ", spot_key.shape)
    if expression_key.shape[0] != spot_key.shape[0]:
        expression_key = expression_key.T
        print("expression_key shape: ", expression_key.shape)

    if method == "simple":
        indices = find_matches(spot_key, image_query, top_k=1)
        matched_spot_embeddings_pred = spot_key[indices[:,0],:]
        print("matched spot embeddings pred shape: ", matched_spot_embeddings_pred.shape)
        matched_spot_expression_pred = expression_key[indices[:,0],:]
        print("matched spot expression pred shape: ", matched_spot_expression_pred.shape)

    if method == "average":
        print("finding matches, using average of top 50 expressions")
        indices = find_matches(spot_key, image_query, top_k=50)
        matched_spot_embeddings_pred = np.zeros((indices.shape[0], spot_key.shape[1]))
        matched_spot_expression_pred = np.zeros((indices.shape[0], expression_key.shape[1]))
        for i in range(indices.shape[0]):
            matched_spot_embeddings_pred[i,:] = np.average(spot_key[indices[i,:],:], axis=0)
            matched_spot_expression_pred[i,:] = np.average(expression_key[indices[i,:],:], axis=0)
        
        print("matched spot embeddings pred shape: ", matched_spot_embeddings_pred.shape)
        print("matched spot expression pred shape: ", matched_spot_expression_pred.shape)

    if method == "weighted_average":
        print("finding matches, using weighted average of top 50 expressions")
        indices = find_matches(spot_key, image_query, top_k=100)
        # print("indices = ", indices)
        matched_spot_embeddings_pred = np.zeros((indices.shape[0], spot_key.shape[1]))
        matched_spot_expression_pred = np.zeros((indices.shape[0], expression_key.shape[1]))
        for i in range(indices.shape[0]):
            a = np.sum((spot_key[indices[i,0],:] - image_query[i,:])**2) #the smallest MSE
            weights = np.exp(-(np.sum((spot_key[indices[i,:],:] - image_query[i,:])**2, axis=1)-a+1))
            # weights = a/np.sum((spot_key[indices[i,:],:] - image_query[i,:])**2, axis=1)
            # a = np.sqrt(np.sum((spot_key[indices[i,0],:] - image_query[i,:])**2)) #the smallest RMSE
            # weights = np.exp(-(np.sqrt(np.sum((spot_key[indices[i,:],:] - image_query[i,:])**2, axis=1))-a+1))
            
            # sorted_indices = np.argsort(weights)[::-1]  # 
            # top_10_weights = weights[sorted_indices[:10]]
            # least_10_weights = weights[sorted_indices[-10:]]
            # print("Top 10 weights: ", top_10_weights)
            # print("least 10 weights: ", least_10_weights)
            
            # if i == 0:
            #     print("weights: ", weights)
            matched_spot_embeddings_pred[i,:] = np.average(spot_key[indices[i,:],:], axis=0, weights=weights)
            matched_spot_expression_pred[i,:] = np.average(expression_key[indices[i,:],:], axis=0, weights=weights)
        
        # print("matched spot embeddings pred shape: ", matched_spot_embeddings_pred.shape)
        # print("matched spot expression pred shape: ", matched_spot_expression_pred.shape)

    true = expression_gt
    pred = matched_spot_expression_pred
    adj = adj_dict[ID]
    
    model.eval()
    
    # Create the directory if it doesn't exist
    output_dir = './figures/show'
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    
    with torch.no_grad():
        pred_features = model.spot_encoder(torch.tensor(pred, dtype=torch.float32).cuda(), adj.cuda())
        pred_embeddings = model.spot_projection(torch.tensor(pred_features, dtype=torch.float32).cuda())
        pred_encoding = model.spot_autoencoder.encode(torch.tensor(pred_embeddings, dtype=torch.float32).cuda(), adj.cuda())
        pred_reconstruction, extra = model.spot_autoencoder.decode(torch.tensor(pred_encoding, dtype=torch.float32).cuda())

        pred_features = pred_features.cpu().numpy()
        pred_embeddings = pred_embeddings.cpu().numpy()
        pred_encoding = pred_encoding.cpu().numpy()
        pred_reconstruction = pred_reconstruction.cpu().numpy()

    print("pred.shape",pred.shape)
    print("true.shape",true.shape)
    print("np.max(pred)",np.max(pred))
    print("np.max(true)",np.max(true))
    print("np.min(pred)",np.min(pred))
    print("np.min(true)",np.min(true))
    
    ####### Prediction PCC performance
    mix = (pred + pred_reconstruction)/2
    
    def evaluate_gene_expression(pred, true, ID, top_k, fold, top_results, testset):
        # Genewise correlation across cells
        corr_cells = np.zeros(pred.shape[0])
        for i in range(pred.shape[0]):
            corr_cells[i] = np.corrcoef(pred[i, :], true[i, :])[0, 1]
        # Remove NaN
        corr_cells = corr_cells[~np.isnan(corr_cells)]
        print("Cell Mean R: ", np.mean(corr_cells))
        
        # Calculate RMSE across cells
        mse_cells = mean_squared_error(pred, true)
        rmse_cells = sqrt(mse_cells)
        print("MSE across cells: ", mse_cells)
        print("RMSE across cells: ", rmse_cells)
        
        # Genewise correlation across genes
        corr_genes = np.zeros(pred.shape[1])
        p_values = np.zeros(pred.shape[1])
        for i in range(pred.shape[1]):
            # corr_genes[i] = np.corrcoef(pred[:, i], true[:, i])[0, 1]
            corr_genes[i], p_values[i] = pearsonr(pred[:, i], true[:, i])
        # Remove NaN
        valid_indices = ~np.isnan(corr_genes)
        corr_genes = corr_genes[valid_indices]
        p_values = p_values[valid_indices]
        
        if corr_genes.size == 0:
            print("corr_genes is an empty array")
        elif np.isnan(corr_genes).all():
            print("corr_genes is an array of NaNs")
        else:
            print("Max correlation across genes:", np.nanmax(corr_genes))
        
        print("Genes mean R: ", np.mean(corr_genes))
        print("Gene median R: ", np.median(corr_genes))
        print("number of genes with correlation > 0.3: ", np.sum(corr_genes > 0.3))
        
        mlog_p_values = -np.log10(p_values)
        # Top-k genes
        # top_k_indices = np.argsort(corr_genes)[-top_k:] ## highest R
        top_k_indices = np.argsort(mlog_p_values)[-top_k:] ## highest -log10 p-values
        top_R_values = corr_genes[top_k_indices]
        top_pred_values = pred[:, top_k_indices]
        top_results[ID] = (top_R_values, top_pred_values)
        print(f'Top {top_k} Genes Mean Pearson Correlation:', np.nanmean(top_R_values))
        print(f'Top {top_k} Genes Mean Pearson Correlation:', np.nanmedian(top_R_values))
        
        # Get top gene correlations
        top_R_values = get_top_values(corr_genes)
        print('Fold', ID, "Top 10 genes with highest -log10 p-values:")
        for gene_id, r_value in top_R_values:
            gene_name = testset.gene_set[gene_id]
            print(f"Gene ID: {gene_id}, Gene Name: {gene_name}, R: {r_value}, p_values: {p_values[gene_id]}")

    # Example usage:
    print(f"The Prediction: prediction")
    evaluate_gene_expression(pred, true, ID, top_k, fold, top_results, testset)
    # print(f"\n The Prediction Matrix: pred_reconstruction")
    # evaluate_gene_expression(pred_reconstruction, true, ID, top_k, fold, top_results, testset)
    # print(f"\n The Prediction Matrix: mix")
    # evaluate_gene_expression(mix, true, ID, top_k, fold, top_results, testset)
    
    ####### Clustering 
    ### Change the type of pred to AnnData for the next clustering task
    pred = sc.AnnData(pred)
    pred.obsm['spatial'] = center_dict[ID]
    true = ad.AnnData(true)
    true.obsm['spatial'] = center_dict[ID]
    pred_features = sc.AnnData(pred_features)
    pred_features.obsm['spatial'] = center_dict[ID]
    pred_embeddings = sc.AnnData(pred_embeddings)
    pred_embeddings.obsm['spatial'] = center_dict[ID]
    # pred_encoding = sc.AnnData(pred_encoding)
    # pred_encoding.obsm['spatial'] = center_dict[ID]
    pred_reconstruction = sc.AnnData(pred_reconstruction)
    pred_reconstruction.obsm['spatial'] = center_dict[ID]
    mix = sc.AnnData(mix)
    mix.obsm['spatial'] = center_dict[ID]

    if data=='her2st':
        ####### Generate cluster figure
        label = testset.label[ID]
        # print("label = ",label)
        # clus, ARI = cluster(pred, label)
        # print('Fold:', fold, 'ARI:', ARI)
        # title = f"{ID} ARI = {ARI:.3f}"  # Format title with ARI value   
        # sc.pl.spatial(pred, img=testset.get_img(ID), color='kmeans', spot_size=112, title=title, save=f"/mymodel_Her2_{ID}.pdf")    
        
        
        # clus, Top_ARI = cluster(top_pred_values, label)
        # print('Fold:', fold, 'Top 100 ARI:', Top_ARI)
        # clus, feature_ARI = cluster(pred_features, label)
        # print('Fold:', fold, 'Expression features ARI:', feature_ARI)
        # title = f"{ID} ARI = {ARI:.3f}"  # Format title with ARI value   
        # sc.pl.spatial(pred_features, img=testset.get_img(ID), color='kmeans', spot_size=112, title=title, save=f"/mymodel_Her2_{ID}_features.pdf")   
        
        # clus, Emb_ARI = cluster(pred_embeddings, label)
        # print('Fold:', fold, 'Expression Embeddings ARI:', Emb_ARI)
        # title = f"{ID} ARI = {Emb_ARI:.3f}"  # Format title with ARI value   
        # sc.pl.spatial(pred_embeddings, img=testset.get_img(ID), color='kmeans', spot_size=112, title = title, save=f"/mymodel_Her2_{ID}_Emb.pdf")

        
        # clus, Re_ARI = cluster(pred_reconstruction, label)
        # print('Fold:', fold, 'Reconstruction ARI:', Re_ARI)
        # title = f"{ID} ARI = {Re_ARI:.3f}"  # Format title with ARI value   
        # sc.pl.spatial(pred_reconstruction, img=testset.get_img(ID), color='kmeans', spot_size=112, title = title, save=f"/mymodel_Her2_{ID}_Reconstruction.pdf")    
        
        # clus, mix_ARI = cluster(mix, label)
        # print('Fold:', fold, 'Mixed Reconstruction ARI:', mix_ARI)
        # title = f"{ID} ARI = {mix_ARI:.3f}"  # Format title with ARI value   
        # sc.pl.spatial(mix, img=testset.get_img(ID), color='kmeans', spot_size=112, title = title, save=f"/mymodel_Her2_{ID}_mix.pdf") 

    print("Result of ", ID, " ended! ")
    print("\n\n")
    
    
    # if save_path != "":
    #     np.save(save_path + "matched_spot_embeddings_pred.npy", matched_spot_embeddings_pred.T)
    #     np.save(save_path + "matched_spot_expression_pred.npy", matched_spot_expression_pred.T)



Begin Processing Image A1
image query shape:  (346, 256)
expression_gt shape:  (346, 785)
finding matches, using weighted average of top 50 expressions
dot_similarity.shape = spots * reference_spots =  torch.Size([346, 10139])
pred.shape (346, 785)
true.shape (346, 785)
np.max(pred) 3.063084840774536
np.max(true) 3.5403168
np.min(pred) 0.0
np.min(true) 0.0
The Prediction: prediction
Cell Mean R:  0.5604347543336636
MSE across cells:  0.3341227162293125
RMSE across cells:  0.5780334905775897
Max correlation across genes: 0.6865587995609885
Genes mean R:  0.21658230593425473
Gene median R:  0.20637844165557437
number of genes with correlation > 0.3:  180
Top 50 Genes Mean Pearson Correlation: 0.46693745814690374
Top 50 Genes Mean Pearson Correlation: 0.43699430610935225
Fold A1 Top 10 genes with highest -log10 p-values:
Gene ID: 154, Gene Name: MUCL1, R: 0.6865587995609885, p_values: 1.5079007983123628e-49
Gene ID: 227, Gene Name: SCD, R: 0.6333224276797179, p_values: 3.428820600102361e-

/opt/conda/conda-bld/pytorch_1682343967769/work/aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [9,0,0], thread: [54,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
/opt/conda/conda-bld/pytorch_1682343967769/work/aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [9,0,0], thread: [62,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
/opt/conda/conda-bld/pytorch_1682343967769/work/aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [8,0,0], thread: [94,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
/opt/conda/conda-bld/pytorch_1682343967769/work/aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [8,0,0], thread: [118,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
/opt/conda/conda-bld/pytorch_1682343967769/work/aten/src/ATen/native/cuda/ScatterGa

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [None]:
print(pred.shape)
print(pred_features.shape)
print(pred_embeddings.shape)
print(pred_reconstruction.shape)

(613, 785)
(613, 785)
(613, 256)
(613, 785)
