In [1]:
import torch
import torch.nn.functional as F
from tqdm import tqdm
import matplotlib.pyplot as plt

import config as CFG
from models import *
from dataset import *
import scanpy as sc
from torch.utils.data import DataLoader

import os
import numpy as np
import pandas as pd

import scanpy as sc
from itertools import chain

In [2]:
#print the current scanpy version
print(sc.__version__)

1.9.5


In [3]:
fold=5
data='cscc'
prune='Grid' if data=='her2st' else 'NA'
genes=171 if data=='cscc' else 785

def pk_load(fold,mode='test',flatten=False,dataset='her2st',r=4,ori=True,adj=True,prune='Grid',neighs=8): #r=4 Hist2ST
    assert dataset in ['her2st','cscc']
    if dataset=='her2st':
        dataset = CLIP_HER2ST(
            train=(mode=='train'),fold=fold,flatten=flatten,
            ori=ori,neighs=neighs,adj=adj,prune=prune,r=r
        )
    elif dataset=='cscc':
        dataset = CLIP_SKIN(
            train=(mode=='train'),fold=fold,flatten=flatten,
            ori=ori,neighs=neighs,adj=adj,prune=prune,r=r
        )
    return dataset

def build_loaders_inference():
    print("Building loaders")
    trainset = pk_load(fold,'train',dataset=data,flatten=False,adj=True,ori=True,prune=prune)
    train_loader = DataLoader(trainset, batch_size=1, num_workers=0, shuffle=True)
    testset = pk_load(fold,'test',dataset=data,flatten=False,adj=True,ori=True,prune=prune)
    test_loader = DataLoader(testset, batch_size=1, num_workers=0, shuffle=False)
    print("Finished building loaders")
    return trainset, testset, train_loader, test_loader

#2265x256, 2277x256
def find_matches(spot_embeddings, query_embeddings, top_k=1):
    #find the closest matches 
    spot_embeddings = torch.tensor(spot_embeddings)
    query_embeddings = torch.tensor(query_embeddings)
    query_embeddings = F.normalize(query_embeddings, p=2, dim=-1)
    spot_embeddings = F.normalize(spot_embeddings, p=2, dim=-1)
    dot_similarity = query_embeddings @ spot_embeddings.T   #2277x2265
    print(dot_similarity.shape)
    _, indices = torch.topk(dot_similarity.squeeze(0), k=top_k)
    
    return indices.cpu().numpy()

In [4]:
### Loading data

trainset, testset, train_loader, test_loader = build_loaders_inference()
train_loader = chain(train_loader, test_loader)

print("Finished loading data")

Building loaders
All sample names: ['P2_ST_rep1', 'P2_ST_rep2', 'P2_ST_rep3', 'P5_ST_rep1', 'P5_ST_rep2', 'P5_ST_rep3', 'P9_ST_rep1', 'P9_ST_rep2', 'P9_ST_rep3', 'P10_ST_rep1', 'P10_ST_rep2', 'P10_ST_rep3']
Test set names: ['P2_ST_rep1', 'P5_ST_rep1', 'P9_ST_rep1', 'P10_ST_rep1']
Train set names: ['P2_ST_rep2', 'P9_ST_rep3', 'P9_ST_rep2', 'P2_ST_rep3', 'P10_ST_rep2', 'P5_ST_rep2', 'P5_ST_rep3', 'P10_ST_rep3']
['P2_ST_rep1', 'P5_ST_rep1', 'P9_ST_rep1', 'P10_ST_rep1']
Loading imgs...


Loading metadata...
All sample names: ['P2_ST_rep1', 'P2_ST_rep2', 'P2_ST_rep3', 'P5_ST_rep1', 'P5_ST_rep2', 'P5_ST_rep3', 'P9_ST_rep1', 'P9_ST_rep2', 'P9_ST_rep3', 'P10_ST_rep1', 'P10_ST_rep2', 'P10_ST_rep3']
Test set names: ['P2_ST_rep1', 'P5_ST_rep1', 'P9_ST_rep1', 'P10_ST_rep1']
Train set names: ['P2_ST_rep2', 'P9_ST_rep3', 'P9_ST_rep2', 'P2_ST_rep3', 'P10_ST_rep2', 'P5_ST_rep2', 'P5_ST_rep3', 'P10_ST_rep3']
['P2_ST_rep1', 'P5_ST_rep1', 'P9_ST_rep1', 'P10_ST_rep1']
Loading imgs...
Loading metadata...
Finished building loaders
Finished loading data


In [7]:
# datasize = [2378, 2349, 2277, 2265]
# model_path ="clip/best.pt"
# model_path = "clip/mymodel_HER2.pt"
model_path ="clip/ICL2ST_cSCC.pt"
save_path = "clip/embeddings/"
# model = GraphCLIP().cuda()
model = myModel().cuda()

state_dict = torch.load(model_path)
new_state_dict = {}
for key in state_dict.keys():
    new_key = key.replace('module.', '')  # remove the prefix 'module.'
    new_key = new_key.replace('well', 'spot') # for compatibility with prior naming
    if "image_encoder.gnn" in new_key: # Special to GNN because GNN use torch_geometric.nn
        new_key = new_key.replace("module_1.","module_1.module.")  
    new_state_dict[new_key] = state_dict[key]

model.load_state_dict(new_state_dict)
model.eval()

print("Finished loading model")

FileNotFoundError: [Errno 2] No such file or directory: 'clip/mymodel_HER2.pt'

In [6]:
if not os.path.exists(save_path):
        os.makedirs(save_path)
        
adj_dict = {}
exp_dict = {}
center_dict = {}
with torch.no_grad():
    for batch in tqdm(train_loader):
        ID, patch, center, exp, adj, oris, sfs, centers = batch
        print("Processing image ", ID)
        B,N,C,H,W = patch.shape
        patch = patch.reshape(B*N,C,H,W)  # (N,3,224,224)
        if adj.dim() == 3:
            adj = adj.squeeze(0)
        if exp.dim() == 3:
            exp = exp.squeeze(0)
            centers = centers.squeeze().numpy()
        adj_dict[ID] = adj 
        exp_dict[ID] = exp 
        center_dict[ID] = centers
        
        image_features = model.image_encoder(patch.cuda())
        spot_features = model.spot_encoder(exp.cuda(), adj.cuda())
        
        image_embeddings = model.image_projection(image_features).cpu().numpy()
        spot_embeddings = (model.spot_projection(spot_features.cuda()))
        
        spot_encoding = model.spot_autoencoder.encode(spot_embeddings, adj.cuda())
        spot_reconstruction, extra = model.spot_autoencoder.decode(spot_encoding.cuda())
        
        spot_embeddings = spot_embeddings.cpu().numpy()
        spot_encoding = spot_encoding.cpu().numpy()
        spot_reconstruction = spot_reconstruction.cpu().numpy()
        
        # print(image_embeddings.shape)
        # print(spot_embeddings.shape)
        np.save(save_path + "img_embeddings_" + str(ID[0]) + ".npy", image_embeddings.T)
        np.save(save_path + "spot_embeddings_" + str(ID[0]) + ".npy", spot_embeddings.T)

0it [00:00, ?it/s]

Processing image  ('P2_ST_rep2',)


0it [00:00, ?it/s]


RuntimeError: Given normalized_shape=[785], expected input with shape [*, 785], but got input of size[646, 171]

In [None]:
save_path = "clip/embeddings/"
all_files = os.listdir(save_path)

# exp_dict = {}
# for batch in tqdm(train_loader):
#     ID, patch, center, exp, adj, oris, sfs, *_ = batch
#     print(ID)
#     print(exp.shape)
#     exp_dict[ID] = exp  # Assuming ID and exp are tensors, we fetch their first elements

image_embeddings_dict = {}
spot_embeddings_dict = {}
ID_list = []

for file in all_files:
    if file.endswith(".npy"):
        # Extract the ID from the filename (e.g., A2, C3, etc.)
        if data=='her2st': 
            if 'rep' not in file:
                ID = file.split("_")[2].split(".")[0]
        elif data=='cscc':
            if 'rep' in file:
                ID = "_".join(file.split("_")[2:-1]) + "_" + file.split("_")[-1].split(".")[0]
        
        if (ID,) in adj_dict:
            adj_dict[ID] = adj_dict.pop((ID,))
        if (ID,) in exp_dict:
            exp_dict[ID] = exp_dict.pop((ID,))
        if (ID,) in center_dict:
            center_dict[ID] = center_dict.pop((ID,))
            
        # Determine the type of file based on its prefix and load the data
        if "img_embeddings" in file:
            image_embeddings_dict[ID] = np.load(os.path.join(save_path, file))
            ID_list.append(ID)
        elif "spot_embeddings" in file:
            spot_embeddings_dict[ID] = np.load(os.path.join(save_path, file))

# Now, image_embeddings_dict and spot_embeddings_dict contain the required data
print(image_embeddings_dict.keys())  # Should list all the image embedding IDs
print(spot_embeddings_dict.keys())  # Should list all the spot embedding IDs
print(exp_dict.keys())  # Should list all the spot embedding IDs
print(ID_list)

if data=='her2st':   
    fold=[0,6,12,18,24,27,31,33]
    test_ID = ['A1','B1','C1','D1','E1','F1','G2','H1']
elif data=='cscc':
    fold=[0,3,6,9]
    test_ID = ['P2_ST_rep1', 'P5_ST_rep1', 'P9_ST_rep1', 'P10_ST_rep1']
    
# test_ID = [ID_list[i] for i in fold]
print("Test set names:", test_ID)
train_ID = list(set(ID_list)-set(test_ID))
print("Train set names:",train_ID)



dict_keys(['H1', 'A3', 'A4', 'H2', 'E3', 'D6', 'B6', 'B4', 'A6', 'A5', 'G3', 'E2', 'D3', 'C6', 'H3', 'C2', 'D5', 'C3', 'F3', 'C5', 'A2', 'B2', 'F2', 'B5', 'B3', 'D2', 'G1', 'C4', 'D4', 'A1', 'B1', 'C1', 'D1', 'E1', 'F1', 'G2'])
dict_keys(['H1', 'A3', 'A4', 'H2', 'E3', 'D6', 'B6', 'B4', 'A6', 'A5', 'G3', 'E2', 'D3', 'C6', 'H3', 'C2', 'D5', 'C3', 'F3', 'C5', 'A2', 'B2', 'F2', 'B5', 'B3', 'D2', 'G1', 'C4', 'D4', 'A1', 'B1', 'C1', 'D1', 'E1', 'F1', 'G2'])
dict_keys(['A3', 'A4', 'H2', 'E3', 'D6', 'B6', 'B4', 'A6', 'A5', 'G3', 'E2', 'D3', 'C6', 'H3', 'C2', 'D5', 'C3', 'F3', 'C5', 'A2', 'B2', 'F2', 'B5', 'B3', 'D2', 'G1', 'C4', 'D4', 'A1', 'B1', 'C1', 'D1', 'E1', 'F1', 'G2', 'H1'])
['H1', 'H1', 'H1', 'H1', 'H1', 'H1', 'H1', 'H1', 'H1', 'H1', 'H1', 'H1', 'A3', 'A4', 'H2', 'E3', 'D6', 'B6', 'B4', 'A6', 'A5', 'G3', 'E2', 'D3', 'C6', 'H3', 'C2', 'D5', 'C3', 'F3', 'C5', 'A2', 'B2', 'F2', 'B5', 'B3', 'D2', 'G1', 'C4', 'D4', 'A1', 'B1', 'C1', 'D1', 'E1', 'F1', 'G2', 'H1']
Test set names: ['A1', 'B1'

In [None]:
#query
# test_ID.remove('A1')
print(test_ID)
# image_query = [spot_embeddings_dict[ID] for ID in test_ID]
# expression_gt = [exp_dict[ID].numpy().T for ID in test_ID]

# image_train_data = [image_embeddings_dict[ID] for ID in train_ID]
spot_train_data = [spot_embeddings_dict[ID] for ID in train_ID]
expression_train_data = [exp_dict[ID].numpy().T for ID in train_ID]

spot_key = np.concatenate(spot_train_data, axis=1)
expression_key = np.concatenate(expression_train_data, axis=1)

# print(image_query.shape)
# print(expression_gt.shape)
print(spot_key.shape)
print(expression_key.shape)

if spot_key.shape[1] != 256:
    spot_key = spot_key.T
    print("spot_key shape: ", spot_key.shape)
if expression_key.shape[0] != spot_key.shape[0]:
    expression_key = expression_key.T
    print("expression_key shape: ", expression_key.shape)

['A1', 'B1', 'C1', 'D1', 'E1', 'F1', 'G2', 'H1']
(256, 10139)
(785, 10139)
spot_key shape:  (10139, 256)
expression_key shape:  (10139, 785)


In [None]:
import torch
import numpy as np
import scanpy as sc
import anndata as ad
from tqdm import tqdm
from scipy.stats import pearsonr,spearmanr
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_rand_score as ari_score
from sklearn.metrics.cluster import normalized_mutual_info_score as nmi_score

def test(model,test,device='cuda'):
    model=model.to(device)
    model.eval()
    preds=None
    ct=None
    gt=None
    loss=0
    with torch.no_grad():
        for patch, position, exp, adj, *_, center in tqdm(test):
            patch, position, adj = patch.to(device), position.to(device), adj.to(device).squeeze(0)
            pred = model(patch, position, adj)[0]
            preds = pred.squeeze().cpu().numpy()
            ct = center.squeeze().cpu().numpy()
            gt = exp.squeeze().cpu().numpy()
    adata = ad.AnnData(preds)
    adata.obsm['spatial'] = ct
    adata_gt = ad.AnnData(gt)
    adata_gt.obsm['spatial'] = ct
    return adata,adata_gt

def cluster(adata,label):
    idx = label != 'undetermined'
    tmp=adata[idx]
    l=label[idx]
    print("cluster number:",len(set(l)))
    sc.pp.pca(tmp)
    sc.tl.tsne(tmp)
    kmeans = KMeans(n_clusters=len(set(l)), init="k-means++", random_state=0).fit(tmp.obsm['X_pca'])
    p=kmeans.labels_.astype(str)
    lbl=np.full(len(adata),str(len(set(l))))
    lbl[idx]=p
    adata.obs['kmeans']=lbl
    return p,round(ari_score(p,l),3)

def get_R(data1,data2,dim=1,func=pearsonr):
    adata1=data1.X
    adata2=data2.X
    r1,p1=[],[]
    for g in range(data1.shape[dim]):
        if dim==1:
            r,pv=func(adata1[:,g],adata2[:,g])
        elif dim==0:
            r,pv=func(adata1[g,:],adata2[g,:])
        r1.append(r)
        p1.append(pv)
    r1=np.array(r1)
    p1=np.array(p1)
    return r1,p1

def get_top_values(arr, num_top_values=10, lowest=False):
    return sorted([(i, arr[i]) for i in range(len(arr))], key=lambda x: x[1], reverse=not lowest)[:num_top_values]
top_k = 50
results = {}
top_results = {}
selected_folds = [5]



In [None]:
from sklearn.metrics import mean_squared_error
from math import sqrt

for ID in test_ID:    
    print("Processing Image", ID)
    image_query = spot_embeddings_dict[ID]
    expression_gt = exp_dict[ID].numpy().T

    method = "weighted_average" # "average" "weighted_average"
    save_path = ""
    if image_query.shape[1] != 256:
        image_query = image_query.T
        print("image query shape: ", image_query.shape)
    if expression_gt.shape[0] != image_query.shape[0]:
        expression_gt = expression_gt.T
        print("expression_gt shape: ", expression_gt.shape)
    if spot_key.shape[1] != 256:
        spot_key = spot_key.T
        print("spot_key shape: ", spot_key.shape)
    if expression_key.shape[0] != spot_key.shape[0]:
        expression_key = expression_key.T
        print("expression_key shape: ", expression_key.shape)

    if method == "simple":
        indices = find_matches(spot_key, image_query, top_k=1)
        matched_spot_embeddings_pred = spot_key[indices[:,0],:]
        print("matched spot embeddings pred shape: ", matched_spot_embeddings_pred.shape)
        matched_spot_expression_pred = expression_key[indices[:,0],:]
        print("matched spot expression pred shape: ", matched_spot_expression_pred.shape)

    if method == "average":
        print("finding matches, using average of top 50 expressions")
        indices = find_matches(spot_key, image_query, top_k=50)
        matched_spot_embeddings_pred = np.zeros((indices.shape[0], spot_key.shape[1]))
        matched_spot_expression_pred = np.zeros((indices.shape[0], expression_key.shape[1]))
        for i in range(indices.shape[0]):
            matched_spot_embeddings_pred[i,:] = np.average(spot_key[indices[i,:],:], axis=0)
            matched_spot_expression_pred[i,:] = np.average(expression_key[indices[i,:],:], axis=0)
        
        print("matched spot embeddings pred shape: ", matched_spot_embeddings_pred.shape)
        print("matched spot expression pred shape: ", matched_spot_expression_pred.shape)

    if method == "weighted_average":
        print("finding matches, using weighted average of top 50 expressions")
        indices = find_matches(spot_key, image_query, top_k=100)
        # print("indices = ", indices)
        matched_spot_embeddings_pred = np.zeros((indices.shape[0], spot_key.shape[1]))
        matched_spot_expression_pred = np.zeros((indices.shape[0], expression_key.shape[1]))
        for i in range(indices.shape[0]):
            a = np.sum((spot_key[indices[i,0],:] - image_query[i,:])**2) #the smallest MSE
            weights = np.exp(-(np.sum((spot_key[indices[i,:],:] - image_query[i,:])**2, axis=1)-a+1))
            # weights = a/np.sum((spot_key[indices[i,:],:] - image_query[i,:])**2, axis=1)
            # a = np.sqrt(np.sum((spot_key[indices[i,0],:] - image_query[i,:])**2)) #the smallest RMSE
            # weights = np.exp(-(np.sqrt(np.sum((spot_key[indices[i,:],:] - image_query[i,:])**2, axis=1))-a+1))
            
            # sorted_indices = np.argsort(weights)[::-1]  # 
            # top_10_weights = weights[sorted_indices[:10]]
            # least_10_weights = weights[sorted_indices[-10:]]
            # print("Top 10 weights: ", top_10_weights)
            # print("least 10 weights: ", least_10_weights)
            
            # if i == 0:
            #     print("weights: ", weights)
            matched_spot_embeddings_pred[i,:] = np.average(spot_key[indices[i,:],:], axis=0, weights=weights)
            matched_spot_expression_pred[i,:] = np.average(expression_key[indices[i,:],:], axis=0, weights=weights)
        
        # print("matched spot embeddings pred shape: ", matched_spot_embeddings_pred.shape)
        # print("matched spot expression pred shape: ", matched_spot_expression_pred.shape)

    true = expression_gt
    pred = matched_spot_expression_pred
    adj = adj_dict[ID]
    
    model.eval()
    
    # Create the directory if it doesn't exist
    output_dir = './figures/show'
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    
    with torch.no_grad():
        pred_features = model.spot_encoder(torch.tensor(pred, dtype=torch.float32).cuda(), adj.cuda())
        pred_embeddings = model.spot_projection(torch.tensor(pred_features, dtype=torch.float32).cuda())
        pred_encoding = model.spot_autoencoder.encode(torch.tensor(pred_embeddings, dtype=torch.float32).cuda(), adj.cuda())
        pred_reconstruction, extra = model.spot_autoencoder.decode(torch.tensor(pred_encoding, dtype=torch.float32).cuda())

        pred_features = pred_features.cpu().numpy()
        pred_embeddings = pred_embeddings.cpu().numpy()
        pred_encoding = pred_encoding.cpu().numpy()
        pred_reconstruction = pred_reconstruction.cpu().numpy()

    print(pred.shape)
    print(true.shape)
    print(np.max(pred))
    print(np.max(true))
    print(np.min(pred))
    print(np.min(true))
    
    ####### Prediction PCC performance
    mix = (pred + pred_reconstruction)/2
    
    def evaluate_gene_expression(pred, true, ID, top_k, fold, top_results, testset):
        # Genewise correlation across cells
        corr_cells = np.zeros(pred.shape[0])
        for i in range(pred.shape[0]):
            corr_cells[i] = np.corrcoef(pred[i, :], true[i, :])[0, 1]
        # Remove NaN
        corr_cells = corr_cells[~np.isnan(corr_cells)]
        print("Mean correlation across cells: ", np.mean(corr_cells))
        
        # Calculate RMSE across cells
        mse_cells = mean_squared_error(pred, true)
        rmse_cells = sqrt(mse_cells)
        print("MSE across cells: ", mse_cells)
        print("RMSE across cells: ", rmse_cells)
        
        # Genewise correlation across genes
        corr_genes = np.zeros(pred.shape[1])
        for i in range(pred.shape[1]):
            corr_genes[i] = np.corrcoef(pred[:, i], true[:, i])[0, 1]
        # Remove NaN
        corr_genes = corr_genes[~np.isnan(corr_genes)]
        
        if corr_genes.size == 0:
            print("corr_genes is an empty array")
        elif np.isnan(corr_genes).all():
            print("corr_genes is an array of NaNs")
        else:
            print("Max correlation across genes:", np.nanmax(corr_genes))
        
        print('Fold', ID, "mean correlation across genes: ", np.mean(corr_genes))
        print('Fold', ID, "number of genes with correlation > 0.3: ", np.sum(corr_genes > 0.3))
        
        # Top-k genes
        top_k_indices = np.argsort(corr_genes)[-top_k:]
        top_R_values = corr_genes[top_k_indices]
        top_pred_values = pred[:, top_k_indices]
        top_results[ID] = (top_R_values, top_pred_values)
        print('Fold', ID, f':Top {top_k} Genes Mean Pearson Correlation:', np.nanmean(top_R_values))
        
        # Get top gene correlations
        top_R_values = get_top_values(corr_genes)
        print('Fold', ID, "Top 10 genes with highest correlation:")
        for gene_id, r_value in top_R_values:
            gene_name = testset.gene_set[gene_id]
            print(f"Gene ID: {gene_id}, Gene Name: {gene_name}, R: {r_value}")

    # Example usage:
    print(f"The Prediction: prediction")
    evaluate_gene_expression(pred, true, ID, top_k, fold, top_results, testset)
    # print(f"\n The Prediction Matrix: pred_reconstruction")
    # evaluate_gene_expression(pred_reconstruction, true, ID, top_k, fold, top_results, testset)
    # print(f"\n The Prediction Matrix: mix")
    # evaluate_gene_expression(mix, true, ID, top_k, fold, top_results, testset)
    
    ####### Clustering 
    ### Change the type of pred to AnnData for the next clustering task
    pred = sc.AnnData(pred)
    pred.obsm['spatial'] = center_dict[ID]
    true = ad.AnnData(true)
    true.obsm['spatial'] = center_dict[ID]
    pred_features = sc.AnnData(pred_features)
    pred_features.obsm['spatial'] = center_dict[ID]
    pred_embeddings = sc.AnnData(pred_embeddings)
    pred_embeddings.obsm['spatial'] = center_dict[ID]
    # pred_encoding = sc.AnnData(pred_encoding)
    # pred_encoding.obsm['spatial'] = center_dict[ID]
    pred_reconstruction = sc.AnnData(pred_reconstruction)
    pred_reconstruction.obsm['spatial'] = center_dict[ID]
    mix = sc.AnnData(mix)
    mix.obsm['spatial'] = center_dict[ID]

    
    ####### Generate cluster figure
    label = testset.label[ID]
    # print("label = ",label)
    clus, ARI = cluster(pred, label)
    print('Fold:', fold, 'ARI:', ARI)
    # title = f"{ID} ARI = {ARI:.3f}"  # Format title with ARI value   
    # sc.pl.spatial(pred, img=testset.get_img(ID), color='kmeans', spot_size=112, title=title, save=f"/mymodel_Her2_{ID}.pdf")    
    
    
    # clus, Top_ARI = cluster(top_pred_values, label)
    # print('Fold:', fold, 'Top 100 ARI:', Top_ARI)
    clus, feature_ARI = cluster(pred_features, label)
    print('Fold:', fold, 'Expression features ARI:', feature_ARI)
    # title = f"{ID} ARI = {ARI:.3f}"  # Format title with ARI value   
    # sc.pl.spatial(pred_features, img=testset.get_img(ID), color='kmeans', spot_size=112, title=title, save=f"/mymodel_Her2_{ID}_features.pdf")   
    
    clus, Emb_ARI = cluster(pred_embeddings, label)
    print('Fold:', fold, 'Expression Embeddings ARI:', Emb_ARI)
    # title = f"{ID} ARI = {Emb_ARI:.3f}"  # Format title with ARI value   
    # sc.pl.spatial(pred_embeddings, img=testset.get_img(ID), color='kmeans', spot_size=112, title = title, save=f"/mymodel_Her2_{ID}_Emb.pdf")

    
    clus, Re_ARI = cluster(pred_reconstruction, label)
    print('Fold:', fold, 'Reconstruction ARI:', Re_ARI)
    # title = f"{ID} ARI = {Re_ARI:.3f}"  # Format title with ARI value   
    # sc.pl.spatial(pred_reconstruction, img=testset.get_img(ID), color='kmeans', spot_size=112, title = title, save=f"/mymodel_Her2_{ID}_Reconstruction.pdf")    
    
    clus, mix_ARI = cluster(mix, label)
    print('Fold:', fold, 'Mixed Reconstruction ARI:', mix_ARI)
    # title = f"{ID} ARI = {mix_ARI:.3f}"  # Format title with ARI value   
    # sc.pl.spatial(mix, img=testset.get_img(ID), color='kmeans', spot_size=112, title = title, save=f"/mymodel_Her2_{ID}_mix.pdf") 

    
    # print("\n")
    # if save_path != "":
    #     np.save(save_path + "matched_spot_embeddings_pred.npy", matched_spot_embeddings_pred.T)
    #     np.save(save_path + "matched_spot_expression_pred.npy", matched_spot_expression_pred.T)



Processing Image A1
image query shape:  (346, 256)
expression_gt shape:  (346, 785)
finding matches, using weighted average of top 50 expressions
torch.Size([346, 10139])
(346, 785)
(346, 785)
3.03171968460083
3.5403168
0.0
0.0
The Prediction: prediction
Mean correlation across cells:  0.5540155412618086
MSE across cells:  0.3321556132846759
RMSE across cells:  0.5763294312150612
Max correlation across genes: 0.6266315942844823
Fold A1 mean correlation across genes:  0.20242404514684853
Fold A1 number of genes with correlation > 0.3:  142
Fold A1 :Top 50 Genes Mean Pearson Correlation: 0.4512985892220058
Fold A1 Top 10 genes with highest correlation:
Gene ID: 154, Gene Name: MUCL1, R: 0.6266315942844823
Gene ID: 227, Gene Name: SCD, R: 0.6091376789158621
Gene ID: 698, Gene Name: HLA-DRA, R: 0.5775756990842519
Gene ID: 60, Gene Name: IGKC, R: 0.5751661266024105
Gene ID: 217, Gene Name: JCHAIN, R: 0.5509313155780583
Gene ID: 447, Gene Name: IGHG3, R: 0.549650460091019
Gene ID: 366, Gene 

  pred_embeddings = model.spot_projection(torch.tensor(pred_features, dtype=torch.float32).cuda())
  pred_encoding = model.spot_autoencoder.encode(torch.tensor(pred_embeddings, dtype=torch.float32).cuda(), adj.cuda())
  pred_reconstruction, extra = model.spot_autoencoder.decode(torch.tensor(pred_encoding, dtype=torch.float32).cuda())
  super()._check_params_vs_input(X, default_n_init=10)


Fold: [0, 6, 12, 18, 24, 27, 31, 33] ARI: 0.067
cluster number: 5


  super()._check_params_vs_input(X, default_n_init=10)


Fold: [0, 6, 12, 18, 24, 27, 31, 33] Expression features ARI: 0.118
cluster number: 5


  super()._check_params_vs_input(X, default_n_init=10)


Fold: [0, 6, 12, 18, 24, 27, 31, 33] Expression Embeddings ARI: 0.13
cluster number: 5


  super()._check_params_vs_input(X, default_n_init=10)


Fold: [0, 6, 12, 18, 24, 27, 31, 33] Reconstruction ARI: 0.107
cluster number: 5


  super()._check_params_vs_input(X, default_n_init=10)
  pred_embeddings = model.spot_projection(torch.tensor(pred_features, dtype=torch.float32).cuda())
  pred_encoding = model.spot_autoencoder.encode(torch.tensor(pred_embeddings, dtype=torch.float32).cuda(), adj.cuda())
  pred_reconstruction, extra = model.spot_autoencoder.decode(torch.tensor(pred_encoding, dtype=torch.float32).cuda())
  c /= stddev[:, None]
  c /= stddev[None, :]


Fold: [0, 6, 12, 18, 24, 27, 31, 33] Mixed Reconstruction ARI: 0.079
Processing Image B1
image query shape:  (295, 256)
expression_gt shape:  (295, 785)
finding matches, using weighted average of top 50 expressions
torch.Size([295, 10139])
(295, 785)
(295, 785)
3.126047372817993
3.494989
0.0
0.0
The Prediction: prediction
Mean correlation across cells:  0.5675477011055287
MSE across cells:  0.25619484341671256
RMSE across cells:  0.5061569355612077
Max correlation across genes: 0.7303738287000091
Fold B1 mean correlation across genes:  0.35370891753996103
Fold B1 number of genes with correlation > 0.3:  517
Fold B1 :Top 50 Genes Mean Pearson Correlation: 0.6109298545118932
Fold B1 Top 10 genes with highest correlation:
Gene ID: 495, Gene Name: FN1, R: 0.7303738287000091
Gene ID: 704, Gene Name: STMN1, R: 0.6929894146725173
Gene ID: 197, Gene Name: HLA-B, R: 0.68820725855525
Gene ID: 227, Gene Name: SCD, R: 0.6858634636546882
Gene ID: 733, Gene Name: UBA52, R: 0.6669335511212615
Gene ID

  super()._check_params_vs_input(X, default_n_init=10)


Fold: [0, 6, 12, 18, 24, 27, 31, 33] ARI: 0.319
cluster number: 4


  super()._check_params_vs_input(X, default_n_init=10)


Fold: [0, 6, 12, 18, 24, 27, 31, 33] Expression features ARI: 0.252
cluster number: 4


  super()._check_params_vs_input(X, default_n_init=10)


Fold: [0, 6, 12, 18, 24, 27, 31, 33] Expression Embeddings ARI: 0.312
cluster number: 4


  super()._check_params_vs_input(X, default_n_init=10)


Fold: [0, 6, 12, 18, 24, 27, 31, 33] Reconstruction ARI: 0.288
cluster number: 4


  super()._check_params_vs_input(X, default_n_init=10)
  pred_embeddings = model.spot_projection(torch.tensor(pred_features, dtype=torch.float32).cuda())
  pred_encoding = model.spot_autoencoder.encode(torch.tensor(pred_embeddings, dtype=torch.float32).cuda(), adj.cuda())
  pred_reconstruction, extra = model.spot_autoencoder.decode(torch.tensor(pred_encoding, dtype=torch.float32).cuda())
  c /= stddev[:, None]
  c /= stddev[None, :]


Fold: [0, 6, 12, 18, 24, 27, 31, 33] Mixed Reconstruction ARI: 0.275
Processing Image C1
image query shape:  (176, 256)
expression_gt shape:  (176, 785)
finding matches, using weighted average of top 50 expressions
torch.Size([176, 10139])
(176, 785)
(176, 785)
3.2588746547698975
3.4836807
0.0
0.0
The Prediction: prediction
Mean correlation across cells:  0.6813865948761647
MSE across cells:  0.20749962044008663
RMSE across cells:  0.4555212623358943
Max correlation across genes: 0.7775209154667049
Fold C1 mean correlation across genes:  0.31718705261213637
Fold C1 number of genes with correlation > 0.3:  421
Fold C1 :Top 50 Genes Mean Pearson Correlation: 0.6095614646258611
Fold C1 Top 10 genes with highest correlation:
Gene ID: 328, Gene Name: MGP, R: 0.7775209154667049
Gene ID: 227, Gene Name: SCD, R: 0.7765447250133358
Gene ID: 431, Gene Name: CLDN4, R: 0.7418182180305909
Gene ID: 154, Gene Name: MUCL1, R: 0.7354138848661179
Gene ID: 134, Gene Name: GNAS, R: 0.7088249867499148
Gene

  super()._check_params_vs_input(X, default_n_init=10)
  super()._check_params_vs_input(X, default_n_init=10)


Fold: [0, 6, 12, 18, 24, 27, 31, 33] ARI: 0.099
cluster number: 3
Fold: [0, 6, 12, 18, 24, 27, 31, 33] Expression features ARI: 0.027
cluster number: 3


  super()._check_params_vs_input(X, default_n_init=10)
  super()._check_params_vs_input(X, default_n_init=10)


Fold: [0, 6, 12, 18, 24, 27, 31, 33] Expression Embeddings ARI: 0.098
cluster number: 3
Fold: [0, 6, 12, 18, 24, 27, 31, 33] Reconstruction ARI: 0.123
cluster number: 3


  super()._check_params_vs_input(X, default_n_init=10)
  pred_embeddings = model.spot_projection(torch.tensor(pred_features, dtype=torch.float32).cuda())
  pred_encoding = model.spot_autoencoder.encode(torch.tensor(pred_embeddings, dtype=torch.float32).cuda(), adj.cuda())
  pred_reconstruction, extra = model.spot_autoencoder.decode(torch.tensor(pred_encoding, dtype=torch.float32).cuda())


Fold: [0, 6, 12, 18, 24, 27, 31, 33] Mixed Reconstruction ARI: 0.105
Processing Image D1
image query shape:  (306, 256)
expression_gt shape:  (306, 785)
finding matches, using weighted average of top 50 expressions
torch.Size([306, 10139])
(306, 785)
(306, 785)
2.7603611946105957
3.1218235
0.0
0.0
The Prediction: prediction
Mean correlation across cells:  0.6774986087137128
MSE across cells:  0.24519159346335645
RMSE across cells:  0.4951682476324148
Max correlation across genes: 0.800537847679012
Fold D1 mean correlation across genes:  0.23912903713979938
Fold D1 number of genes with correlation > 0.3:  196
Fold D1 :Top 50 Genes Mean Pearson Correlation: 0.49319748813828185
Fold D1 Top 10 genes with highest correlation:
Gene ID: 495, Gene Name: FN1, R: 0.800537847679012
Gene ID: 89, Gene Name: ITGB6, R: 0.6409442839929919
Gene ID: 275, Gene Name: C3, R: 0.6314906922794241
Gene ID: 60, Gene Name: IGKC, R: 0.6175426565129241
Gene ID: 736, Gene Name: CD74, R: 0.5761914636153195
Gene ID: 

  super()._check_params_vs_input(X, default_n_init=10)


Fold: [0, 6, 12, 18, 24, 27, 31, 33] ARI: 0.273
cluster number: 3


  super()._check_params_vs_input(X, default_n_init=10)


Fold: [0, 6, 12, 18, 24, 27, 31, 33] Expression features ARI: 0.16
cluster number: 3


  super()._check_params_vs_input(X, default_n_init=10)
  super()._check_params_vs_input(X, default_n_init=10)


Fold: [0, 6, 12, 18, 24, 27, 31, 33] Expression Embeddings ARI: 0.356
cluster number: 3
Fold: [0, 6, 12, 18, 24, 27, 31, 33] Reconstruction ARI: 0.303
cluster number: 3


  super()._check_params_vs_input(X, default_n_init=10)
  pred_embeddings = model.spot_projection(torch.tensor(pred_features, dtype=torch.float32).cuda())
  pred_encoding = model.spot_autoencoder.encode(torch.tensor(pred_embeddings, dtype=torch.float32).cuda(), adj.cuda())
  pred_reconstruction, extra = model.spot_autoencoder.decode(torch.tensor(pred_encoding, dtype=torch.float32).cuda())


Fold: [0, 6, 12, 18, 24, 27, 31, 33] Mixed Reconstruction ARI: 0.327
Processing Image E1
image query shape:  (587, 256)
expression_gt shape:  (587, 785)
finding matches, using weighted average of top 50 expressions
torch.Size([587, 10139])
(587, 785)
(587, 785)
3.289362907409668
3.4946344
0.0
0.0
The Prediction: prediction
Mean correlation across cells:  0.7150138701687232
MSE across cells:  0.17073682302039525
RMSE across cells:  0.4132031256178918
Max correlation across genes: 0.5465361953520842
Fold E1 mean correlation across genes:  0.18129913826230495
Fold E1 number of genes with correlation > 0.3:  106
Fold E1 :Top 50 Genes Mean Pearson Correlation: 0.3977087023551698
Fold E1 Top 10 genes with highest correlation:
Gene ID: 134, Gene Name: GNAS, R: 0.5465361953520842
Gene ID: 154, Gene Name: MUCL1, R: 0.5435806711582167
Gene ID: 366, Gene Name: FASN, R: 0.5251120246396335
Gene ID: 311, Gene Name: IGHA1, R: 0.5102078587064454
Gene ID: 431, Gene Name: CLDN4, R: 0.5007848275022152
Ge

  super()._check_params_vs_input(X, default_n_init=10)


Fold: [0, 6, 12, 18, 24, 27, 31, 33] ARI: -0.016
cluster number: 3


  super()._check_params_vs_input(X, default_n_init=10)


Fold: [0, 6, 12, 18, 24, 27, 31, 33] Expression features ARI: 0.003
cluster number: 3


  super()._check_params_vs_input(X, default_n_init=10)


Fold: [0, 6, 12, 18, 24, 27, 31, 33] Expression Embeddings ARI: -0.004
cluster number: 3


  super()._check_params_vs_input(X, default_n_init=10)


Fold: [0, 6, 12, 18, 24, 27, 31, 33] Reconstruction ARI: -0.003
cluster number: 3


  super()._check_params_vs_input(X, default_n_init=10)
  pred_embeddings = model.spot_projection(torch.tensor(pred_features, dtype=torch.float32).cuda())
  pred_encoding = model.spot_autoencoder.encode(torch.tensor(pred_embeddings, dtype=torch.float32).cuda(), adj.cuda())
  pred_reconstruction, extra = model.spot_autoencoder.decode(torch.tensor(pred_encoding, dtype=torch.float32).cuda())


Fold: [0, 6, 12, 18, 24, 27, 31, 33] Mixed Reconstruction ARI: -0.004
Processing Image F1
image query shape:  (691, 256)
expression_gt shape:  (691, 785)
finding matches, using weighted average of top 50 expressions
torch.Size([691, 10139])
(691, 785)
(691, 785)
3.2279176712036133
3.5189433
0.0
0.0
The Prediction: prediction
Mean correlation across cells:  0.6431287514543348
MSE across cells:  0.2771359538643484
RMSE across cells:  0.5264370369420719
Max correlation across genes: 0.6324748661397126
Fold F1 mean correlation across genes:  0.18731396350608034
Fold F1 number of genes with correlation > 0.3:  100
Fold F1 :Top 50 Genes Mean Pearson Correlation: 0.363574946682831
Fold F1 Top 10 genes with highest correlation:
Gene ID: 60, Gene Name: IGKC, R: 0.6324748661397126
Gene ID: 507, Gene Name: IGLC2, R: 0.4301629078846443
Gene ID: 366, Gene Name: FASN, R: 0.4298318548860491
Gene ID: 756, Gene Name: CAPG, R: 0.40609201142336004
Gene ID: 636, Gene Name: LUC7L3, R: 0.40132136870325513
G

  super()._check_params_vs_input(X, default_n_init=10)


Fold: [0, 6, 12, 18, 24, 27, 31, 33] ARI: 0.003
cluster number: 3


  super()._check_params_vs_input(X, default_n_init=10)


Fold: [0, 6, 12, 18, 24, 27, 31, 33] Expression features ARI: -0.004
cluster number: 3


  super()._check_params_vs_input(X, default_n_init=10)


Fold: [0, 6, 12, 18, 24, 27, 31, 33] Expression Embeddings ARI: 0.008
cluster number: 3


  super()._check_params_vs_input(X, default_n_init=10)


Fold: [0, 6, 12, 18, 24, 27, 31, 33] Reconstruction ARI: 0.01
cluster number: 3


  super()._check_params_vs_input(X, default_n_init=10)
  pred_embeddings = model.spot_projection(torch.tensor(pred_features, dtype=torch.float32).cuda())
  pred_encoding = model.spot_autoencoder.encode(torch.tensor(pred_embeddings, dtype=torch.float32).cuda(), adj.cuda())
  pred_reconstruction, extra = model.spot_autoencoder.decode(torch.tensor(pred_encoding, dtype=torch.float32).cuda())


Fold: [0, 6, 12, 18, 24, 27, 31, 33] Mixed Reconstruction ARI: 0.004
Processing Image G2
image query shape:  (467, 256)
expression_gt shape:  (467, 785)
finding matches, using weighted average of top 50 expressions
torch.Size([467, 10139])
(467, 785)
(467, 785)
3.214796304702759
3.564969
0.0
0.0
The Prediction: prediction
Mean correlation across cells:  0.6096632283095837
MSE across cells:  0.27298479329255515
RMSE across cells:  0.5224794668621487
Max correlation across genes: 0.6773509436671871
Fold G2 mean correlation across genes:  0.22580273391238387
Fold G2 number of genes with correlation > 0.3:  189
Fold G2 :Top 50 Genes Mean Pearson Correlation: 0.4818500058063149
Fold G2 Top 10 genes with highest correlation:
Gene ID: 495, Gene Name: FN1, R: 0.6773509436671871
Gene ID: 78, Gene Name: TMEM123, R: 0.6009672275636952
Gene ID: 17, Gene Name: POSTN, R: 0.581464413621956
Gene ID: 89, Gene Name: ITGB6, R: 0.5742668501398198
Gene ID: 134, Gene Name: GNAS, R: 0.5677914254436377
Gene I

  super()._check_params_vs_input(X, default_n_init=10)


Fold: [0, 6, 12, 18, 24, 27, 31, 33] ARI: 0.217
cluster number: 6


  super()._check_params_vs_input(X, default_n_init=10)


Fold: [0, 6, 12, 18, 24, 27, 31, 33] Expression features ARI: 0.165
cluster number: 6


  super()._check_params_vs_input(X, default_n_init=10)


Fold: [0, 6, 12, 18, 24, 27, 31, 33] Expression Embeddings ARI: 0.159
cluster number: 6


  super()._check_params_vs_input(X, default_n_init=10)


Fold: [0, 6, 12, 18, 24, 27, 31, 33] Reconstruction ARI: 0.139
cluster number: 6


  super()._check_params_vs_input(X, default_n_init=10)
  pred_embeddings = model.spot_projection(torch.tensor(pred_features, dtype=torch.float32).cuda())
  pred_encoding = model.spot_autoencoder.encode(torch.tensor(pred_embeddings, dtype=torch.float32).cuda(), adj.cuda())
  pred_reconstruction, extra = model.spot_autoencoder.decode(torch.tensor(pred_encoding, dtype=torch.float32).cuda())


Fold: [0, 6, 12, 18, 24, 27, 31, 33] Mixed Reconstruction ARI: 0.173
Processing Image H1
image query shape:  (613, 256)
expression_gt shape:  (613, 785)
finding matches, using weighted average of top 50 expressions
torch.Size([613, 10139])
(613, 785)
(613, 785)
3.31240177154541
3.5241778
0.0
0.0
The Prediction: prediction
Mean correlation across cells:  0.6524734131032347
MSE across cells:  0.20294628023504424
RMSE across cells:  0.45049559402400846
Max correlation across genes: 0.6319622303525718
Fold H1 mean correlation across genes:  0.2058572421022836
Fold H1 number of genes with correlation > 0.3:  150
Fold H1 :Top 50 Genes Mean Pearson Correlation: 0.43281122542700784
Fold H1 Top 10 genes with highest correlation:
Gene ID: 311, Gene Name: IGHA1, R: 0.6319622303525718
Gene ID: 60, Gene Name: IGKC, R: 0.5954130858853829
Gene ID: 698, Gene Name: HLA-DRA, R: 0.5392008919429748
Gene ID: 134, Gene Name: GNAS, R: 0.503354220631834
Gene ID: 201, Gene Name: VIM, R: 0.4961680364355265
Gene

  super()._check_params_vs_input(X, default_n_init=10)


Fold: [0, 6, 12, 18, 24, 27, 31, 33] ARI: 0.246
cluster number: 6


  super()._check_params_vs_input(X, default_n_init=10)


Fold: [0, 6, 12, 18, 24, 27, 31, 33] Expression features ARI: 0.217
cluster number: 6


  super()._check_params_vs_input(X, default_n_init=10)


Fold: [0, 6, 12, 18, 24, 27, 31, 33] Expression Embeddings ARI: 0.228
cluster number: 6


  super()._check_params_vs_input(X, default_n_init=10)


Fold: [0, 6, 12, 18, 24, 27, 31, 33] Reconstruction ARI: 0.25
cluster number: 6
Fold: [0, 6, 12, 18, 24, 27, 31, 33] Mixed Reconstruction ARI: 0.221


  super()._check_params_vs_input(X, default_n_init=10)


In [None]:
print(fold)
print(top_results)
print(pred)
print(top_k_indices)

[0, 3, 6, 9]
{}
[[0.90636307 0.01716013 1.27345228 ... 0.96279126 1.40579331 1.40914977]
 [0.78615433 0.16812551 0.97852707 ... 0.71877563 1.29658175 1.17757511]
 [0.57783115 0.2906794  0.61150104 ... 0.6798892  1.19054627 0.96232635]
 ...
 [0.36123031 0.07967214 0.33944419 ... 0.29751673 0.53201169 0.26382375]
 [0.15348776 0.07810118 0.2704705  ... 0.17751326 0.44164845 0.21989258]
 [0.11655529 0.07476368 0.20791259 ... 0.37653571 0.28566301 0.14665827]]


NameError: name 'top_k_indices' is not defined

In [None]:
print(pred.shape)
print(pred_features.shape)
print(pred_embeddings.shape)
print(pred_reconstruction.shape)

(613, 785)
(613, 785)
(613, 256)
(613, 785)
