In [2]:
import torch
import torch.nn.functional as F
from tqdm import tqdm
import matplotlib.pyplot as plt

import config as CFG
from models import *
from dataset import *
import scanpy as sc
from torch.utils.data import DataLoader

import os
import numpy as np
import pandas as pd

import scanpy as sc
from itertools import chain

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
#print the current scanpy version
print(sc.__version__)

1.9.8


In [5]:
fold=5
data='her2st' #### Change here to test different dataset 'her2st' 'cscc'
model_name = 'BLEEP' ## SGCL2ST, BLEEP

prune='Grid' if data=='her2st' else 'NA'
genes=171 if data=='cscc' else 785

# set CUDA device to use
cuda_device = 4
torch.cuda.set_device(cuda_device)

def pk_load(fold,mode='test',flatten=False,dataset='her2st',r=4,ori=True,adj=True,prune='Grid',neighs=8): #r=4 Hist2ST
    assert dataset in ['her2st','cscc']
    if dataset=='her2st':
        dataset = CLIP_HER2ST(
            train=(mode=='train'),fold=fold,flatten=flatten,
            ori=ori,neighs=neighs,adj=adj,prune=prune,r=r
        )
    elif dataset=='cscc':
        dataset = CLIP_SKIN(
            train=(mode=='train'),fold=fold,flatten=flatten,
            ori=ori,neighs=neighs,adj=adj,prune=prune,r=r
        )
    return dataset

def build_loaders_inference():
    print("Building loaders")
    trainset = pk_load(fold,'train',dataset=data,flatten=False,adj=True,ori=True,prune=prune)
    train_loader = DataLoader(trainset, batch_size=1, num_workers=0, shuffle=True)
    testset = pk_load(fold,'test',dataset=data,flatten=False,adj=True,ori=True,prune=prune)
    test_loader = DataLoader(testset, batch_size=1, num_workers=0, shuffle=False)
    print("Finished building loaders")
    return trainset, testset, train_loader, test_loader

#2265x256, 2277x256
def find_matches(spot_embeddings, query_embeddings, top_k=1):
    #find the closest matches 
    spot_embeddings = torch.tensor(spot_embeddings)
    query_embeddings = torch.tensor(query_embeddings)
    query_embeddings = F.normalize(query_embeddings, p=2, dim=-1)
    spot_embeddings = F.normalize(spot_embeddings, p=2, dim=-1)
    dot_similarity = query_embeddings @ spot_embeddings.T   #2277x2265
    print("dot_similarity.shape = spots * reference_spots = ",dot_similarity.shape)
    _, indices = torch.topk(dot_similarity.squeeze(0), k=top_k)
    
    return indices.cpu().numpy()

In [6]:
### Loading data

trainset, testset, train_loader, test_loader = build_loaders_inference()
train_loader = chain(train_loader, test_loader)

print("Finished loading data")

Building loaders
Test set names: ['A1', 'B1', 'C1', 'D1', 'E1', 'F1', 'G2', 'H1']
Train set names: ['D6', 'A3', 'H3', 'C4', 'E2', 'B6', 'B4', 'D2', 'B2', 'E3', 'D5', 'B5', 'C3', 'B3', 'G1', 'C2', 'C6', 'F2', 'A5', 'C5', 'D4', 'F3', 'A2', 'A6', 'A4', 'D3', 'G3', 'H2']
Loading imgs...
Loading metadata...
Test set names: ['A1', 'B1', 'C1', 'D1', 'E1', 'F1', 'G2', 'H1']
Train set names: ['D6', 'A3', 'H3', 'C4', 'E2', 'B6', 'B4', 'D2', 'B2', 'E3', 'D5', 'B5', 'C3', 'B3', 'G1', 'C2', 'C6', 'F2', 'A5', 'C5', 'D4', 'F3', 'A2', 'A6', 'A4', 'D3', 'G3', 'H2']
Loading imgs...
Loading metadata...


  x = np.around(x).astype(int)
  y = np.around(y).astype(int)
  x = np.around(x).astype(int)
  y = np.around(y).astype(int)


Finished building loaders
Finished loading data


In [7]:

# model_path ="clip/best.pt"
if data =='her2st':
    model_path =f"clip/{model_name}_her2.pt"
    save_path = "clip/embeddings/her2st/"
if data =='cscc':
    model_path =f"clip/{model_name}_cscc.pt"
    save_path = "clip/embeddings/cscc/"

if model_name == 'SGCL2ST':
    model = myModel().cuda()
if model_name == 'BLEEP':
    model = CLIPModel_resnet50().cuda()

state_dict = torch.load(model_path)
new_state_dict = {}
for key in state_dict.keys():
    new_key = key.replace('module.', '')  # remove the prefix 'module.'
    new_key = new_key.replace('well', 'spot') # for compatibility with prior naming
    if "image_encoder.gnn" in new_key: # Special to GNN because GNN use torch_geometric.nn
        new_key = new_key.replace("module_1.","module_1.module.")  
    new_state_dict[new_key] = state_dict[key]

model.load_state_dict(new_state_dict)
model.eval()

print("Finished loading model")

Finished loading model


In [8]:
if not os.path.exists(save_path):
        os.makedirs(save_path)
        
adj_dict = {}
exp_dict = {}
center_dict = {}
with torch.no_grad():
    for batch in tqdm(train_loader):
        ID, patch, center, exp, adj, oris, sfs, centers = batch
        print("Processing image ", ID)
        B,N,C,H,W = patch.shape
        patch = patch.reshape(B*N,C,H,W)  # (N,3,112,112)
        if adj.dim() == 3:
            adj = adj.squeeze(0)
        if exp.dim() == 3:
            exp = exp.squeeze(0)
            centers = centers.squeeze().numpy()
        adj_dict[ID] = adj 
        exp_dict[ID] = exp 
        center_dict[ID] = centers
        
        image_features = model.image_encoder(patch.cuda())
        # spot_features = model.spot_encoder(exp.cuda(), adj.cuda())
        spot_features = exp.cuda()
        
        image_embeddings = model.image_projection(image_features).cpu().numpy()
        spot_embeddings = (model.spot_projection(spot_features.cuda()))
        
        # spot_encoding = model.spot_autoencoder.encode(spot_embeddings, adj.cuda())
        # spot_reconstruction, extra = model.spot_autoencoder.decode(spot_encoding.cuda())
        
        spot_embeddings = spot_embeddings.cpu().numpy()
        # spot_encoding = spot_encoding.cpu().numpy()
        # spot_reconstruction = spot_reconstruction.cpu().numpy()
        
        # print(image_embeddings.shape)
        # print(spot_embeddings.shape)
        np.save(save_path + "img_embeddings_" + str(ID[0]) + ".npy", image_embeddings.T)
        np.save(save_path + "spot_embeddings_" + str(ID[0]) + ".npy", spot_embeddings.T)

0it [00:00, ?it/s]

Processing image  ('D3',)


2it [00:05,  2.15s/it]

Processing image  ('D5',)


3it [00:05,  1.38s/it]

Processing image  ('F3',)


4it [00:06,  1.02s/it]

Processing image  ('H2',)


5it [00:06,  1.37it/s]

Processing image  ('B4',)


6it [00:06,  1.78it/s]

Processing image  ('A2',)


7it [00:06,  1.89it/s]

Processing image  ('E3',)


8it [00:07,  2.16it/s]

Processing image  ('B2',)


9it [00:07,  2.33it/s]

Processing image  ('H3',)


10it [00:07,  2.73it/s]

Processing image  ('B6',)


11it [00:08,  3.07it/s]

Processing image  ('A4',)


12it [00:08,  3.28it/s]

Processing image  ('B5',)
Processing image  ('A5',)


14it [00:08,  4.29it/s]

Processing image  ('C2',)
Processing image  ('D4',)


16it [00:09,  4.47it/s]

Processing image  ('B3',)
Processing image  ('C4',)


18it [00:09,  5.31it/s]

Processing image  ('D6',)


19it [00:09,  4.92it/s]

Processing image  ('D2',)


20it [00:10,  3.49it/s]

Processing image  ('F2',)


21it [00:10,  4.30it/s]

Processing image  ('C5',)


22it [00:10,  3.43it/s]

Processing image  ('E2',)


23it [00:11,  3.21it/s]

Processing image  ('G1',)


24it [00:11,  3.05it/s]

Processing image  ('G3',)
Processing image  ('C6',)


26it [00:11,  4.25it/s]

Processing image  ('A3',)
Processing image  ('C3',)


28it [00:12,  4.65it/s]

Processing image  ('A6',)
Processing image  ('A1',)


31it [00:12,  5.17it/s]

Processing image  ('B1',)
Processing image  ('C1',)


32it [00:12,  5.26it/s]

Processing image  ('D1',)


33it [00:13,  4.19it/s]

Processing image  ('E1',)


34it [00:13,  3.53it/s]

Processing image  ('F1',)


35it [00:13,  3.56it/s]

Processing image  ('G2',)


36it [00:14,  2.53it/s]

Processing image  ('H1',)





In [9]:

all_files = os.listdir(save_path)

# exp_dict = {}
# for batch in tqdm(train_loader):
#     ID, patch, center, exp, adj, oris, sfs, *_ = batch
#     print(ID)
#     print(exp.shape)
#     exp_dict[ID] = exp  # Assuming ID and exp are tensors, we fetch their first elements

image_embeddings_dict = {}
spot_embeddings_dict = {}
ID_list = []

for file in all_files:
    if file.endswith(".npy"):
        # Extract the ID from the filename (e.g., A2, C3, etc.)
        if data=='her2st': 
            if 'rep' not in file:
                ID = file.split("_")[2].split(".")[0]
        elif data=='cscc':
            if 'rep' in file:
                ID = "_".join(file.split("_")[2:-1]) + "_" + file.split("_")[-1].split(".")[0]
        
        if (ID,) in adj_dict:
            adj_dict[ID] = adj_dict.pop((ID,))
        if (ID,) in exp_dict:
            exp_dict[ID] = exp_dict.pop((ID,))
        if (ID,) in center_dict:
            center_dict[ID] = center_dict.pop((ID,))
            
        # Determine the type of file based on its prefix and load the data
        if "img_embeddings" in file:
            image_embeddings_dict[ID] = np.load(os.path.join(save_path, file))
            ID_list.append(ID)
        elif "spot_embeddings" in file:
            spot_embeddings_dict[ID] = np.load(os.path.join(save_path, file))

# Now, image_embeddings_dict and spot_embeddings_dict contain the required data
print(image_embeddings_dict.keys())  # Should list all the image embedding IDs
print(spot_embeddings_dict.keys())  # Should list all the spot embedding IDs
print(exp_dict.keys())  # Should list all the spot embedding IDs
print(ID_list)

if data=='her2st':   
    fold=[0,6,12,18,24,27,31,33]
    test_ID = ['A1','B1','C1','D1','E1','F1','G2','H1']
elif data=='cscc':
    fold=[0,3,6,9]
    test_ID = ['P2_ST_rep1', 'P5_ST_rep1', 'P9_ST_rep1', 'P10_ST_rep1']
    
# test_ID = [ID_list[i] for i in fold]
print("Test set names:", test_ID)
train_ID = list(set(ID_list)-set(test_ID))
print("Train set names:",train_ID)



dict_keys(['E3', 'E2', 'G2', 'F2', 'C1', 'A6', 'C6', 'B1', 'H1', 'G3', 'H3', 'B6', 'B2', 'C5', 'A5', 'D6', 'D4', 'D2', 'A2', 'D1', 'E1', 'A3', 'B4', 'G1', 'D5', 'C2', 'F1', 'H2', 'A1', 'C4', 'A4', 'B3', 'F3', 'D3', 'B5', 'C3'])
dict_keys(['C6', 'C3', 'E1', 'C4', 'D6', 'C2', 'F1', 'F2', 'A6', 'A1', 'D5', 'B6', 'D2', 'C5', 'H1', 'D3', 'C1', 'A5', 'G1', 'G2', 'B3', 'A2', 'G3', 'B5', 'B2', 'E2', 'B1', 'A3', 'B4', 'E3', 'H2', 'D4', 'H3', 'A4', 'D1', 'F3'])
dict_keys(['C6', 'E3', 'C3', 'E1', 'E2', 'C4', 'G2', 'F2', 'D6', 'C2', 'C1', 'F1', 'A6', 'B1', 'A1', 'D5', 'B6', 'H1', 'D2', 'C5', 'G3', 'H3', 'B2', 'D3', 'A5', 'D4', 'G1', 'B3', 'A2', 'B5', 'D1', 'A3', 'B4', 'H2', 'A4', 'F3'])
['E3', 'E2', 'G2', 'F2', 'C1', 'A6', 'C6', 'B1', 'H1', 'G3', 'H3', 'B6', 'B2', 'C5', 'A5', 'D6', 'D4', 'D2', 'A2', 'D1', 'E1', 'A3', 'B4', 'G1', 'D5', 'C2', 'F1', 'H2', 'A1', 'C4', 'A4', 'B3', 'F3', 'D3', 'B5', 'C3']
Test set names: ['A1', 'B1', 'C1', 'D1', 'E1', 'F1', 'G2', 'H1']
Train set names: ['D6', 'A3', 'H3'

In [10]:
#query
# test_ID.remove('A1')
print(test_ID)
# image_query = [spot_embeddings_dict[ID] for ID in test_ID]
# expression_gt = [exp_dict[ID].numpy().T for ID in test_ID]

# image_train_data = [image_embeddings_dict[ID] for ID in train_ID]
spot_train_data = [spot_embeddings_dict[ID] for ID in train_ID]
expression_train_data = [exp_dict[ID].numpy().T for ID in train_ID]

spot_key = np.concatenate(spot_train_data, axis=1)
expression_key = np.concatenate(expression_train_data, axis=1)

# print(image_query.shape)
# print(expression_gt.shape)
print(spot_key.shape)
print(expression_key.shape)

if spot_key.shape[1] != 256:
    spot_key = spot_key.T
    print("spot_key shape: ", spot_key.shape)
if expression_key.shape[0] != spot_key.shape[0]:
    expression_key = expression_key.T
    print("expression_key shape: ", expression_key.shape)

['A1', 'B1', 'C1', 'D1', 'E1', 'F1', 'G2', 'H1']
(256, 10139)
(785, 10139)
spot_key shape:  (10139, 256)
expression_key shape:  (10139, 785)


In [11]:
import torch
import numpy as np
import scanpy as sc
import anndata as ad
from tqdm import tqdm
from scipy.stats import pearsonr,spearmanr
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_rand_score as ari_score
from sklearn.metrics.cluster import normalized_mutual_info_score as nmi_score

def test(model,test,device='cuda'):
    model=model.to(device)
    model.eval()
    preds=None
    ct=None
    gt=None
    loss=0
    with torch.no_grad():
        for patch, position, exp, adj, *_, center in tqdm(test):
            patch, position, adj = patch.to(device), position.to(device), adj.to(device).squeeze(0)
            pred = model(patch, position, adj)[0]
            preds = pred.squeeze().cpu().numpy()
            ct = center.squeeze().cpu().numpy()
            gt = exp.squeeze().cpu().numpy()
    adata = ad.AnnData(preds)
    adata.obsm['spatial'] = ct
    adata_gt = ad.AnnData(gt)
    adata_gt.obsm['spatial'] = ct
    return adata,adata_gt

def cluster(adata,label):
    idx = label != 'undetermined'
    tmp=adata[idx]
    l=label[idx]
    print("cluster number:",len(set(l)))
    sc.pp.pca(tmp)
    sc.tl.tsne(tmp)
    kmeans = KMeans(n_clusters=len(set(l)), init="k-means++", random_state=0).fit(tmp.obsm['X_pca'])
    p=kmeans.labels_.astype(str)
    lbl=np.full(len(adata),str(len(set(l))))
    lbl[idx]=p
    adata.obs['kmeans']=lbl
    return p,round(ari_score(p,l),3)

def get_R(data1,data2,dim=1,func=pearsonr):
    adata1=data1.X
    adata2=data2.X
    r1,p1=[],[]
    for g in range(data1.shape[dim]):
        if dim==1:
            r,pv=func(adata1[:,g],adata2[:,g])
        elif dim==0:
            r,pv=func(adata1[g,:],adata2[g,:])
        r1.append(r)
        p1.append(pv)
    r1=np.array(r1)
    p1=np.array(p1)
    return r1,p1

def get_top_values(arr, num_top_values=10, lowest=False):
    return sorted([(i, arr[i]) for i in range(len(arr))], key=lambda x: x[1], reverse=not lowest)[:num_top_values]
top_k = 50
results = {}
top_results = {}
selected_folds = [5]



In [12]:
import warnings
warnings.filterwarnings('ignore')

from sklearn.metrics import mean_squared_error
from math import sqrt

for ID in test_ID:
    print("Begin Processing Image", ID)
    # image_query = spot_embeddings_dict[ID]
    image_query = image_embeddings_dict[ID]
    expression_gt = exp_dict[ID].numpy().T

    method = "weighted_average" # "average" "weighted_average"
    save_path = ""
    if image_query.shape[1] != 256:
        image_query = image_query.T
        print("image query shape: ", image_query.shape)
    if expression_gt.shape[0] != image_query.shape[0]:
        expression_gt = expression_gt.T
        print("expression_gt shape: ", expression_gt.shape)
    if spot_key.shape[1] != 256:
        spot_key = spot_key.T
        print("spot_key shape: ", spot_key.shape)
    if expression_key.shape[0] != spot_key.shape[0]:
        expression_key = expression_key.T
        print("expression_key shape: ", expression_key.shape)

    if method == "simple":
        indices = find_matches(spot_key, image_query, top_k=1)
        matched_spot_embeddings_pred = spot_key[indices[:,0],:]
        print("matched spot embeddings pred shape: ", matched_spot_embeddings_pred.shape)
        matched_spot_expression_pred = expression_key[indices[:,0],:]
        print("matched spot expression pred shape: ", matched_spot_expression_pred.shape)

    if method == "average":
        print("finding matches, using average of top 50 expressions")
        indices = find_matches(spot_key, image_query, top_k=50)
        matched_spot_embeddings_pred = np.zeros((indices.shape[0], spot_key.shape[1]))
        matched_spot_expression_pred = np.zeros((indices.shape[0], expression_key.shape[1]))
        for i in range(indices.shape[0]):
            matched_spot_embeddings_pred[i,:] = np.average(spot_key[indices[i,:],:], axis=0)
            matched_spot_expression_pred[i,:] = np.average(expression_key[indices[i,:],:], axis=0)
        
        print("matched spot embeddings pred shape: ", matched_spot_embeddings_pred.shape)
        print("matched spot expression pred shape: ", matched_spot_expression_pred.shape)

    if method == "weighted_average":
        print("finding matches, using weighted average of top 50 expressions")
        indices = find_matches(spot_key, image_query, top_k=100)
        # print("indices = ", indices)
        matched_spot_embeddings_pred = np.zeros((indices.shape[0], spot_key.shape[1]))
        matched_spot_expression_pred = np.zeros((indices.shape[0], expression_key.shape[1]))
        for i in range(indices.shape[0]):
            a = np.sum((spot_key[indices[i,0],:] - image_query[i,:])**2) #the smallest MSE
            weights = np.exp(-(np.sum((spot_key[indices[i,:],:] - image_query[i,:])**2, axis=1)-a+1))
            # weights = a/np.sum((spot_key[indices[i,:],:] - image_query[i,:])**2, axis=1)
            # a = np.sqrt(np.sum((spot_key[indices[i,0],:] - image_query[i,:])**2)) #the smallest RMSE
            # weights = np.exp(-(np.sqrt(np.sum((spot_key[indices[i,:],:] - image_query[i,:])**2, axis=1))-a+1))
            
            # sorted_indices = np.argsort(weights)[::-1]  # 
            # top_10_weights = weights[sorted_indices[:10]]
            # least_10_weights = weights[sorted_indices[-10:]]
            # print("Top 10 weights: ", top_10_weights)
            # print("least 10 weights: ", least_10_weights)
            
            # if i == 0:
            #     print("weights: ", weights)
            matched_spot_embeddings_pred[i,:] = np.average(spot_key[indices[i,:],:], axis=0, weights=weights)
            matched_spot_expression_pred[i,:] = np.average(expression_key[indices[i,:],:], axis=0, weights=weights)
        
        # print("matched spot embeddings pred shape: ", matched_spot_embeddings_pred.shape)
        # print("matched spot expression pred shape: ", matched_spot_expression_pred.shape)

    true = expression_gt
    pred = matched_spot_expression_pred
    adj = adj_dict[ID]
    
    model.eval()
    
    # Create the directory if it doesn't exist
    output_dir = './figures/show'
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    # Additional subdirectories
    subdirs = ['gene', 'clus']
    for subdir in subdirs:
        subdir_path = os.path.join(output_dir, subdir)
        if not os.path.exists(subdir_path):
            os.makedirs(subdir_path)
    
    with torch.no_grad():
        # pred_features = model.spot_encoder(torch.tensor(pred, dtype=torch.float32).cuda(), adj.cuda())
        pred_features = torch.tensor(pred, dtype=torch.float32).cuda()
        pred_embeddings = model.spot_projection(torch.tensor(pred_features, dtype=torch.float32).cuda())
        # pred_encoding = model.spot_autoencoder.encode(torch.tensor(pred_embeddings, dtype=torch.float32).cuda(), adj.cuda())
        # pred_reconstruction, extra = model.spot_autoencoder.decode(torch.tensor(pred_encoding, dtype=torch.float32).cuda())

        pred_features = pred_features.cpu().numpy()
        pred_embeddings = pred_embeddings.cpu().numpy()
        # pred_encoding = pred_encoding.cpu().numpy()
        # pred_reconstruction = pred_reconstruction.cpu().numpy()

    print("pred.shape",pred.shape)
    print("true.shape",true.shape)
    print("np.max(pred)",np.max(pred))
    print("np.max(true)",np.max(true))
    print("np.min(pred)",np.min(pred))
    print("np.min(true)",np.min(true))
    
    save_path = "clip/results/"
    filename = f"{model_name}_predict_{ID}.npy"
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    full_path = os.path.join(save_path, filename)
    np.save(full_path, pred)
    print(f"File saved to {full_path}")
    
    ####### Prediction PCC performance
    # mix = (pred + pred_reconstruction)/2
    
    def evaluate_gene_expression(pred, true, ID, top_k, fold, top_results, testset):
        # Genewise correlation across cells
        corr_cells = np.zeros(pred.shape[0])
        for i in range(pred.shape[0]):
            corr_cells[i] = np.corrcoef(pred[i, :], true[i, :])[0, 1]
        # Remove NaN
        corr_cells = corr_cells[~np.isnan(corr_cells)]
        print("Cell Mean R: ", np.mean(corr_cells))
        
        # Calculate RMSE across cells
        mse_cells = mean_squared_error(pred, true)
        rmse_cells = sqrt(mse_cells)
        print("MSE across cells: ", mse_cells)
        print("RMSE across cells: ", rmse_cells)
        
        # Genewise correlation across genes
        corr_genes = np.zeros(pred.shape[1])
        p_values = np.zeros(pred.shape[1])
        for i in range(pred.shape[1]):
            # corr_genes[i] = np.corrcoef(pred[:, i], true[:, i])[0, 1]
            corr_genes[i], p_values[i] = pearsonr(pred[:, i], true[:, i])
        # Remove NaN
        valid_indices = ~np.isnan(corr_genes)
        corr_genes = corr_genes[valid_indices]
        p_values = p_values[valid_indices]
        
        if corr_genes.size == 0:
            print("corr_genes is an empty array")
        elif np.isnan(corr_genes).all():
            print("corr_genes is an array of NaNs")
        else:
            print("Max correlation across genes:", np.nanmax(corr_genes))
        
        print("Genes mean R: ", np.mean(corr_genes))
        print("Gene median R: ", np.median(corr_genes))
        print("number of genes with correlation > 0.3: ", np.sum(corr_genes > 0.3))
        
        mlog_p_values = -np.log10(p_values)
        # Top-k genes
        # top_k_indices = np.argsort(corr_genes)[-top_k:] ## highest R
        top_k_indices = np.argsort(mlog_p_values)[-top_k:] ## highest -log10 p-values
        top_R_values = corr_genes[top_k_indices]
        top_pred_values = pred[:, top_k_indices]
        top_results[ID] = (top_R_values, top_pred_values)
        print(f'Top {top_k} Genes Mean Pearson Correlation:', np.nanmean(top_R_values))
        print(f'Top {top_k} Genes Median Pearson Correlation:', np.nanmedian(top_R_values))
        
        # Get top gene correlations
        top_R_values = get_top_values(corr_genes)
        print('Fold', ID, "Top 10 genes with highest -log10 p-values:")
        for gene_id, r_value in top_R_values:
            gene_name = testset.gene_set[gene_id]
            print(f"Gene ID: {gene_id}, Gene Name: {gene_name}, R: {r_value}, p_values: {p_values[gene_id]}")
        
        return corr_genes

    # Example usage:
    print(f"The Prediction: prediction")
    corr_genes = evaluate_gene_expression(pred, true, ID, top_k, fold, top_results, testset)
    
    save_path = "clip/results/"
    filename = f"{model_name}_corr_{ID}.npy"
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    full_path = os.path.join(save_path, filename)
    np.save(full_path, corr_genes)
    print(f"File saved to {full_path}")
    
    # print(f"\n The Prediction Matrix: pred_reconstruction")
    # evaluate_gene_expression(pred_reconstruction, true, ID, top_k, fold, top_results, testset)
    # print(f"\n The Prediction Matrix: mix")
    # evaluate_gene_expression(mix, true, ID, top_k, fold, top_results, testset)
    

    ####### Clustering 
    ### Change the type of pred to AnnData for the next clustering task
    pred = sc.AnnData(pred)
    pred.obsm['spatial'] = center_dict[ID]
    true = ad.AnnData(true)
    true.obsm['spatial'] = center_dict[ID]
    if data == "her2st":
        pred.var_names = list(np.load('data/her_hvg_cut_1000.npy',allow_pickle=True))
        true.var_names = list(np.load('data/her_hvg_cut_1000.npy',allow_pickle=True))
    elif data == "cscc":
        pred.var_names = list(np.load('data/skin_hvg_cut_1000.npy',allow_pickle=True))
        true.var_names = list(np.load('data/skin_hvg_cut_1000.npy',allow_pickle=True))
    pred_features = sc.AnnData(pred_features)
    pred_features.obsm['spatial'] = center_dict[ID]
    pred_embeddings = sc.AnnData(pred_embeddings)
    pred_embeddings.obsm['spatial'] = center_dict[ID]
    # pred_encoding = sc.AnnData(pred_encoding)
    # pred_encoding.obsm['spatial'] = center_dict[ID]
    # pred_reconstruction = sc.AnnData(pred_reconstruction)
    # pred_reconstruction.obsm['spatial'] = center_dict[ID]
    # mix = sc.AnnData(mix)
    # mix.obsm['spatial'] = center_dict[ID]
    
    # # Extract top 2 genes based on -log10 p-values
    # top_R_values_2 = get_top_values(corr_genes, num_top_values=3)
    # # Visualize the top 2 genes for this ID
    # for gene_id, r_value in top_R_values_2:
    #     gene_name = testset.gene_set[gene_id]
    #     title = f"ID {ID} Gene: {gene_name} R = {r_value:.3f}"
    #     file_path = f"/gene/{ID}_{gene_name}_{r_value:.3f}.pdf"
    #     sc.pl.spatial(pred, img=testset.get_img(ID), color=gene_name, spot_size=112, title=title, color_map='magma', save=file_path)
        
        # #if want gt
        # title = f"ID {ID} Gene: {gene_name} gt"
        # file_path = f"/gene/{ID}_{gene_name}_gt.pdf"
        # sc.pl.spatial(true, img=testset.get_img(ID), color=gene_name, spot_size=112, title=title, color_map='magma', save=file_path) 
    
    
    if data=='her2st':
        ####### Generate cluster figure
        label = testset.label[ID]
        # print("label = ",label)
        clus, ARI = cluster(pred, label)
        print('Fold:', fold, 'ARI:', ARI)
        title = f"SGCL2ST {ID} ARI = {ARI:.3f}"  # Format title with ARI value   
        # sc.pl.spatial(pred, img=testset.get_img(ID), color='kmeans', spot_size=112, title=title, save=f"/SGCL2ST_Her2_{ID}_{ARI:.3f}.pdf")    
        
        
        # clus, Top_ARI = cluster(top_pred_values, label)
        # print('Fold:', fold, 'Top 100 ARI:', Top_ARI)
        clus, feature_ARI = cluster(pred_features, label)
        print('Fold:', fold, 'Expression features ARI:', feature_ARI)
        title = f"SGCL2ST {ID} ARI = {feature_ARI:.3f}"  # Format title with ARI value   
        # sc.pl.spatial(pred_features, img=testset.get_img(ID), color='kmeans', spot_size=112, title=title, save=f"/SGCL2ST_Her2_{ID}_features_{feature_ARI:.3f}.pdf")   
        
        clus, Emb_ARI = cluster(pred_embeddings, label)
        print('Fold:', fold, 'Expression Embeddings ARI:', Emb_ARI)
        title = f"SGCL2ST {ID} ARI = {Emb_ARI:.3f}"  # Format title with ARI value   
        # sc.pl.spatial(pred_embeddings, img=testset.get_img(ID), color='kmeans', spot_size=112, title = title, save=f"/SGCL2ST_Her2_{ID}_Emb_{Emb_ARI:.3f}.pdf")

        
        # clus, Re_ARI = cluster(pred_reconstruction, label)
        # print('Fold:', fold, 'Reconstruction ARI:', Re_ARI)
        # title = f"SGCL2ST {ID} ARI = {Re_ARI:.3f}"  # Format title with ARI value   
        # sc.pl.spatial(pred_reconstruction, img=testset.get_img(ID), color='kmeans', spot_size=112, title = title, save=f"/SGCL2ST_Her2_{ID}_Reconstruction_{Re_ARI:.3f}.pdf")    

        # clus, true_ARI = cluster(true, label) # Observed Gene Expression clustering
        # print('Fold:', fold, 'Observed Gene Expression ARI:', true_ARI)
        # title = f"Observed Gene Expression {ID} ARI = {true_ARI:.3f}"  # Format title with ARI value   
        # sc.pl.spatial(true, img=testset.get_img(ID), color='kmeans', spot_size=112, title = title, save=f"/SGCL2ST_Her2_{ID}_true_{true_ARI:.3f}.pdf")    
     
    
    print("Result of ", ID, " ended! ")
    print("\n\n")
    
    
    # if save_path != "":
    #     np.save(save_path + "matched_spot_embeddings_pred.npy", matched_spot_embeddings_pred.T)
    #     np.save(save_path + "matched_spot_expression_pred.npy", matched_spot_expression_pred.T)



Begin Processing Image A1
image query shape:  (346, 256)
expression_gt shape:  (346, 785)
finding matches, using weighted average of top 50 expressions
dot_similarity.shape = spots * reference_spots =  torch.Size([346, 10139])
pred.shape (346, 785)
true.shape (346, 785)
np.max(pred) 3.0342977046966553
np.max(true) 3.5403168
np.min(pred) 0.0
np.min(true) 0.0
File saved to clip/results/BLEEP_predict_A1.npy
The Prediction: prediction
Cell Mean R:  0.4265043024617445
MSE across cells:  0.4164348637399364
RMSE across cells:  0.6453176456133339
Max correlation across genes: 0.368449623793983
Genes mean R:  0.023193442367097045
Gene median R:  0.024555584584137168
number of genes with correlation > 0.3:  1
Top 50 Genes Mean Pearson Correlation: 0.11800751523187246
Top 50 Genes Median Pearson Correlation: 0.1443975118163467
Fold A1 Top 10 genes with highest -log10 p-values:
Gene ID: 366, Gene Name: FASN, R: 0.368449623793983, p_values: 1.4482517751877577e-12
Gene ID: 134, Gene Name: GNAS, R: 0

: 

In [None]:
print(pred.shape)
print(pred_features.shape)
print(pred_embeddings.shape)
print(pred_reconstruction.shape)

(613, 785)
(613, 785)
(613, 256)


NameError: name 'pred_reconstruction' is not defined

: 

### Visualization of the predicted gene expression

In [None]:
# Visualization of pred
pred_dict = {}
true_dict = {}

for ID in train_ID + test_ID:
# for ID in test_ID:
    print("Begin Processing Image", ID)
    image_query = spot_embeddings_dict[ID]
    expression_gt = exp_dict[ID].numpy().T

    method = "weighted_average" # "average" "weighted_average"
    save_path = ""
    if image_query.shape[1] != 256:
        image_query = image_query.T
        print("image query shape: ", image_query.shape)
    if expression_gt.shape[0] != image_query.shape[0]:
        expression_gt = expression_gt.T
        print("expression_gt shape: ", expression_gt.shape)
    if spot_key.shape[1] != 256:
        spot_key = spot_key.T
        print("spot_key shape: ", spot_key.shape)
    if expression_key.shape[0] != spot_key.shape[0]:
        expression_key = expression_key.T
        print("expression_key shape: ", expression_key.shape)

    if method == "simple":
        indices = find_matches(spot_key, image_query, top_k=1)
        matched_spot_embeddings_pred = spot_key[indices[:,0],:]
        print("matched spot embeddings pred shape: ", matched_spot_embeddings_pred.shape)
        matched_spot_expression_pred = expression_key[indices[:,0],:]
        print("matched spot expression pred shape: ", matched_spot_expression_pred.shape)

    if method == "average":
        print("finding matches, using average of top 50 expressions")
        indices = find_matches(spot_key, image_query, top_k=50)
        matched_spot_embeddings_pred = np.zeros((indices.shape[0], spot_key.shape[1]))
        matched_spot_expression_pred = np.zeros((indices.shape[0], expression_key.shape[1]))
        for i in range(indices.shape[0]):
            matched_spot_embeddings_pred[i,:] = np.average(spot_key[indices[i,:],:], axis=0)
            matched_spot_expression_pred[i,:] = np.average(expression_key[indices[i,:],:], axis=0)
        
        print("matched spot embeddings pred shape: ", matched_spot_embeddings_pred.shape)
        print("matched spot expression pred shape: ", matched_spot_expression_pred.shape)

    if method == "weighted_average":
        print("finding matches, using weighted average of top 50 expressions")
        indices = find_matches(spot_key, image_query, top_k=100)
        # print("indices = ", indices)
        matched_spot_embeddings_pred = np.zeros((indices.shape[0], spot_key.shape[1]))
        matched_spot_expression_pred = np.zeros((indices.shape[0], expression_key.shape[1]))
        for i in range(indices.shape[0]):
            a = np.sum((spot_key[indices[i,0],:] - image_query[i,:])**2) #the smallest MSE
            weights = np.exp(-(np.sum((spot_key[indices[i,:],:] - image_query[i,:])**2, axis=1)-a+1))

            # if i == 0:
            #     print("weights: ", weights)
            matched_spot_embeddings_pred[i,:] = np.average(spot_key[indices[i,:],:], axis=0, weights=weights)
            matched_spot_expression_pred[i,:] = np.average(expression_key[indices[i,:],:], axis=0, weights=weights)
        
        # print("matched spot embeddings pred shape: ", matched_spot_embeddings_pred.shape)
        # print("matched spot expression pred shape: ", matched_spot_expression_pred.shape)

    true = expression_gt
    pred = matched_spot_expression_pred
    adj = adj_dict[ID]
    
    model.eval()
    
    # Create the directory if it doesn't exist
    output_dir = './figures/show'
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    # Additional subdirectories
    subdirs = ['gene', 'clus']
    for subdir in subdirs:
        subdir_path = os.path.join(output_dir, subdir)
        if not os.path.exists(subdir_path):
            os.makedirs(subdir_path)
    
    with torch.no_grad():
        pred_features = model.spot_encoder(torch.tensor(pred, dtype=torch.float32).cuda(), adj.cuda())
        pred_embeddings = model.spot_projection(torch.tensor(pred_features, dtype=torch.float32).cuda())
        pred_encoding = model.spot_autoencoder.encode(torch.tensor(pred_embeddings, dtype=torch.float32).cuda(), adj.cuda())
        pred_reconstruction, extra = model.spot_autoencoder.decode(torch.tensor(pred_encoding, dtype=torch.float32).cuda())

        pred_features = pred_features.cpu().numpy()
        pred_embeddings = pred_embeddings.cpu().numpy()
        pred_encoding = pred_encoding.cpu().numpy()
        pred_reconstruction = pred_reconstruction.cpu().numpy()

    print("pred.shape",pred.shape)
    print("true.shape",true.shape)
    print("np.max(pred)",np.max(pred))
    print("np.max(true)",np.max(true))
    print("np.min(pred)",np.min(pred))
    print("np.min(true)",np.min(true))
    
    pred_dict[ID] = pred
    true_dict[ID] = true
    ####### Prediction PCC performance
    # mix = (pred + pred_reconstruction)/2
    
import pickle
# save pred_dict to file
output_dir = './clip'
if not os.path.exists(output_dir):
    os.makedirs(output_dir)

# # set save path pred_dict
# output_file_path = os.path.join(output_dir, 'pred_dict.pkl')
# # use pickle to write pred_dict to file. 
# with open(output_file_path, 'wb') as file:
#     pickle.dump(pred_dict, file)
# print(f"Dictionary saved to {output_file_path}")
    
# # set save path true_dict
# output_file_path = os.path.join(output_dir, 'true_dict.pkl')
# # use pickle to write pred_dict to file. 
# with open(output_file_path, 'wb') as file:
#     pickle.dump(true_dict, file)
# print(f"Dictionary saved to {output_file_path}")


Begin Processing Image C2
image query shape:  (187, 256)
expression_gt shape:  (187, 785)
finding matches, using weighted average of top 50 expressions
dot_similarity.shape = spots * reference_spots =  torch.Size([187, 10139])
pred.shape (187, 785)
true.shape (187, 785)
np.max(pred) 3.232665777206421
np.max(true) 3.483227
np.min(pred) 0.0
np.min(true) 0.0
Begin Processing Image E3
image query shape:  (570, 256)
expression_gt shape:  (570, 785)
finding matches, using weighted average of top 50 expressions
dot_similarity.shape = spots * reference_spots =  torch.Size([570, 10139])


pred.shape (570, 785)
true.shape (570, 785)
np.max(pred) 3.277336835861206
np.max(true) 3.4187984
np.min(pred) 0.0
np.min(true) 0.0
Begin Processing Image B4
image query shape:  (283, 256)
expression_gt shape:  (283, 785)
finding matches, using weighted average of top 50 expressions
dot_similarity.shape = spots * reference_spots =  torch.Size([283, 10139])
pred.shape (283, 785)
true.shape (283, 785)
np.max(pred) 3.1323516368865967
np.max(true) 4.0000434
np.min(pred) 0.0
np.min(true) 0.0
Begin Processing Image E2
image query shape:  (572, 256)
expression_gt shape:  (572, 785)
finding matches, using weighted average of top 50 expressions
dot_similarity.shape = spots * reference_spots =  torch.Size([572, 10139])
pred.shape (572, 785)
true.shape (572, 785)
np.max(pred) 3.273228645324707
np.max(true) 3.5101275
np.min(pred) 0.0
np.min(true) 0.0
Begin Processing Image H2
image query shape:  (603, 256)
expression_gt shape:  (603, 785)
finding matches, using weighted average of top 50 expressio

In [None]:
# pred_dict['A1']

In [None]:
from math import log10
def get_top_genes(pred_dict, true_dict, num_genes=50):
    # Initialize a dictionary to store gene p-values across all images
    gene_p_values = {}
    
    # Loop through each image and its predicted expression in pred_dict
    for ID, pred in pred_dict.items():
        print("Processing Image", ID)
        true = true_dict[ID]
        # Calculate p-values for each gene in this image
        for gene_idx in range(pred.shape[1]):
            _, p_value = pearsonr(pred[:, gene_idx], true[:, gene_idx])
            if not np.isnan(p_value):  # Only consider valid p-values
                if gene_idx not in gene_p_values:
                    gene_p_values[gene_idx] = []
                gene_p_values[gene_idx].append(p_value)
    
    # Calculate the average -log10(p-value) for each gene across all images
    avg_log_p_values = {}
    for gene_idx, p_values in gene_p_values.items():
        avg_p_value = np.mean(p_values)
        avg_log_p_value = -log10(avg_p_value)
        avg_log_p_values[gene_idx] = avg_log_p_value
    
    # Sort genes by the average -log10(p-value) and get the top 50
    top_genes = sorted(avg_log_p_values, key=avg_log_p_values.get, reverse=True)[:num_genes]
    
    # Return the top genes and their average -log10(p-value)
    return [(gene_idx, avg_log_p_values[gene_idx]) for gene_idx in top_genes]

# Usage:
# Assume testset has an attribute 'gene_set' which is a list of gene names
# true_dict is a dictionary with the same keys as pred_dict and contains the true gene expressions

# Get top 50 genes based on average -log10(p-value) across all images
top_genes_info = get_top_genes(pred_dict, true_dict, num_genes=50)

# Print the names and p-values of the top genes
for gene_idx, log_p_value in top_genes_info:
    gene_name = testset.gene_set[gene_idx]
    print(f"Gene Name: {gene_name}, Average -log10(p-value): {log_p_value}")

# top_genes_info = top_genes_info[:10]
# top_genes_list = [testset.gene_set[gene_idx] for gene_idx, _ in top_genes_info]

Processing Image C2
Processing Image E3


Processing Image B4
Processing Image E2
Processing Image H2
Processing Image A4
Processing Image G1
Processing Image C3
Processing Image F2
Processing Image C6
Processing Image A3
Processing Image A5
Processing Image A6
Processing Image B6
Processing Image C5
Processing Image D5
Processing Image A2
Processing Image D4
Processing Image G3
Processing Image H3
Processing Image D3
Processing Image D2
Processing Image B5
Processing Image C4
Processing Image F3
Processing Image B3
Processing Image B2
Processing Image D6
Processing Image A1
Processing Image B1
Processing Image C1
Processing Image D1
Processing Image E1
Processing Image F1
Processing Image G2
Processing Image H1
Gene Name: FN1, Average -log10(p-value): 17.936663529842242
Gene Name: FASN, Average -log10(p-value): 16.887232256434874
Gene Name: HLA-DRA, Average -log10(p-value): 15.998644534392458
Gene Name: CLDN4, Average -log10(p-value): 14.853041963351883
Gene Name: COL3A1, Average -log10(p-value): 13.482417238086274
Gene Name:

In [None]:
top_n_genes_info = top_genes_info[:12]
top_n_genes_list = [testset.gene_set[gene_idx] for gene_idx, _ in top_n_genes_info]
# top_n_genes_list = ['FN1', 'FASN', 'HLA-DRA', 'CLDN4', 'COL3A1', 'C3', 'GNAS', 'LUM', 'CD74', 'HLA-B', 'MYL12B', 'CCT4']
# top_n_genes_list = ['FN1', 'FASN', 'HLA-DRA', 'CLDN4', 'GNAS', 'MYL12B']
# index_to_name = {index: name for index, name in enumerate(testset.gene_set)}
# top_n_genes_info = [(gene_idx, log_p_value) for gene_idx, log_p_value in top_genes_info if index_to_name[gene_idx] in top_n_genes_list]

In [None]:
print("top_n_genes_list:", top_n_genes_list)

# 计算并保存前6个基因的信息
def calculate_and_save_top_genes(pred_dict, true_dict, testset):
    # # 从 get_top_genes 函数获取基因信息
    # top_genes_info = get_top_genes(pred_dict, true_dict, num_genes=num_genes)
    # top_6_genes_info = top_genes_info[:6]  # 取前6个基因的信息

    # 遍历所有图像ID，计算相关系数和p值
    gene_correlations = {gene_idx: [] for gene_idx, _ in top_n_genes_info}
    for ID in test_ID:
        pred = pred_dict[ID]
        true = true_dict[ID]
        for gene_idx, _ in top_n_genes_info:
            if not np.isnan(pred).any() and not np.isnan(true).any():
                R, p_value = pearsonr(pred[:, gene_idx], true[:, gene_idx])
                gene_correlations[gene_idx].append((R, p_value, ID))

    # # 确保保存图像的目录存在
    # output_dir = './figures/gene'
    # if not os.path.exists(output_dir):
    #     os.makedirs(output_dir)

    # 为每个基因选择最佳的图像ID并保存图像
    for gene_idx, _ in top_n_genes_info:
        print(gene_idx)
        # 找到相关系数最高的记录
        best_record = max(gene_correlations[gene_idx], key=lambda x: x[0])
        R, p_value, ID = best_record
        gene_name = testset.gene_set[gene_idx]
        print(f"Gene: {gene_name}, Best R: {R}, p-value: {p_value}, Image ID: {ID}")
        
        pred_gene = sc.AnnData(pred_dict[ID])
        pred_gene.obsm['spatial'] = center_dict[ID]
        true_gene = sc.AnnData(true_dict[ID])
        true_gene.obsm['spatial'] = center_dict[ID]
        if data == "her2st":
            pred_gene.var_names = list(np.load('data/her_hvg_cut_1000.npy',allow_pickle=True))
            true_gene.var_names = list(np.load('data/her_hvg_cut_1000.npy',allow_pickle=True))
        elif data == "cscc":
            pred_gene.var_names = list(np.load('data/skin_hvg_cut_1000.npy',allow_pickle=True))
            true_gene.var_names = list(np.load('data/skin_hvg_cut_1000.npy',allow_pickle=True))
        title = f"ID {ID} Gene: {gene_name} R = {R:.3f}"
        file_path = f"/gene/{ID}_{gene_name}_{R:.3f}.pdf"
        # sc.pl.spatial(pred_gene, img=testset.get_img(ID), color=gene_name, spot_size=112, title=title, color_map='magma', save=file_path)
        
        title = f"ID {ID} Gene: {gene_name} Observed Gene Expression"
        file_path = f"/gene/{ID}_{gene_name}_Observed Gene Expression.pdf"
        # sc.pl.spatial(pred_gene, img=testset.get_img(ID), color=gene_name, spot_size=112, title=title, color_map='magma', save=file_path)

calculate_and_save_top_genes(pred_dict, true_dict, testset)
        

top_n_genes_list: ['FN1', 'FASN', 'HLA-DRA', 'CLDN4', 'COL3A1', 'C3', 'GNAS', 'LUM', 'CD74', 'HLA-B', 'MYL12B', 'CCT4']
495
Gene: FN1, Best R: 0.767801452332277, p-value: 9.872979991558361e-61, Image ID: D1
366
Gene: FASN, Best R: 0.6728431114677937, p-value: 1.4947723387434071e-24, Image ID: C1
698
Gene: HLA-DRA, Best R: 0.6953386324018485, p-value: 6.161677165227055e-44, Image ID: B1
431
Gene: CLDN4, Best R: 0.7812694340098176, p-value: 1.8718903353446333e-37, Image ID: C1
561
Gene: COL3A1, Best R: 0.5851436258630178, p-value: 1.4867453175654812e-17, Image ID: C1
275
Gene: C3, Best R: 0.6364151236984426, p-value: 3.7864984913867283e-36, Image ID: D1
134
Gene: GNAS, Best R: 0.7772286416467751, p-value: 7.591387002938116e-37, Image ID: C1
648
Gene: LUM, Best R: 0.5814862499746591, p-value: 2.6256440230116933e-17, Image ID: C1
736
Gene: CD74, Best R: 0.7200550190552399, p-value: 2.036668280473254e-29, Image ID: C1
197
Gene: HLA-B, Best R: 0.6918753917460574, p-value: 2.403815032371665e-