In [1]:
import torch
import numpy as np
import pytorch_lightning as pl
import torchvision.transforms as tf
from tqdm import tqdm
from predict import *
from HIST2ST import *
from dataset import ViT_HER2ST, ViT_SKIN
from scipy.stats import pearsonr,spearmanr
from torch.utils.data import DataLoader
from pytorch_lightning.loggers import TensorBoardLogger
from copy import deepcopy as dcp
from collections import defaultdict as dfd
from sklearn.metrics import adjusted_rand_score as ari_score
from sklearn.metrics.cluster import normalized_mutual_info_score as nmi_score
import os

os.environ['CUDA_LAUNCH_BLOCKING'] = '1'  # For debugging CUDA errors

2025-07-22 12:06:11.515162: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX512F, in other operations, rebuild TensorFlow with the appropriate compiler flags.


[easydl] tensorflow not available!


# Data Loading

In [2]:
# Generate a list of names for different groups (A-G) with specified ranges
name = [*[f'A{i}' for i in range(2,7)], *[f'B{i}' for i in range(1,7)],
    *[f'C{i}' for i in range(1,7)], *[f'D{i}' for i in range(1,7)],
    *[f'E{i}' for i in range(1,4)], *[f'F{i}' for i in range(1,4)], *[f'G{i}' for i in range(1,4)]]

# List of patient identifiers
patients = ['P2', 'P5', 'P9', 'P10']

# List of replicate identifiers
reps = ['rep1', 'rep2', 'rep3']

# Generate skin sample names by combining patient and replicate info
skinname = []
for i in patients:
    for j in reps:
        skinname.append(i + '_ST_' + j)

# Set device for computation
device = 'cuda'

# Model configuration tag (used for hyperparameters)
tag = '5-7-2-8-4-16-32'

# Unpack model hyperparameters from tag string
k, p, d1, d2, d3, h, c = map(lambda x: int(x), tag.split('-'))

# Set dropout rate for model
dropout = 0.2

# Set random seeds for reproducibility
random.seed(12000)
np.random.seed(12000)
torch.manual_seed(12000)
torch.cuda.manual_seed(12000)
torch.cuda.manual_seed_all(12000)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [3]:
# Create dictionaries to convert between Ensembl gene IDs and HGNC symbols
conversion = pd.read_csv('../../../../tahsin/HEST/assets/gene_ids/hgnc_complete_set.txt', sep="\t").loc[:, ['ensembl_gene_id', 'symbol']]
convert_to_ens = conversion.set_index('symbol')['ensembl_gene_id'].to_dict()
convert_to_sym = conversion.set_index('ensembl_gene_id')['symbol'].to_dict()

In [4]:
def emb_to_sym(label):
    """
    Convert the label to a list of gene symbols.
    """
    if isinstance(label, str):
        for gene in label.split(' '):
            if gene not in convert_to_sym and gene in convert_to_ens:
                # print(f"Already a symbol.")
                return label
            elif gene not in convert_to_sym:
                print(f"Error: {gene} not in conversion dictionary.")
            else:
                conversion = [convert_to_sym[gene] for gene in label.split(' ')]
                return ' '.join(conversion)
    elif isinstance(label, list):
        conversion = [convert_to_sym[gene] for gene in label]
        return ' '.join(conversion)
    else:
        return []


# Hist2ST Prediction

### To run the trained model, please select the trained model and replace the value of the variable fold with the number in the name of the selected trained model.

In [5]:
fold=5
data='her2st'
# data = 'hest1k'
if data == 'her2st':
    prune = 'Grid'
    genes = 785
elif data == 'cscc':
    prune = 'NA'
    genes = 171
elif data == 'hest1k':
    prune = 'NA'# IDK??
    genes = 785
prune='Grid' if data=='her2st' else 'NA'
genes=171 if data=='cscc' else 785

In [6]:
import os

hest_path = "/work/bose_lab/tahsin/data/HEST"

meta_df = pd.read_csv(os.path.join(hest_path, "HEST_v1_1_0.csv"))
meta_df = meta_df[meta_df['species'] == 'Homo sapiens']


# Filter the dataframe if you need
# meta_df = meta_df[(meta_df['organ'] == 'Heart') & (meta_df['species'] == 'Homo sapiens')]
# print(meta_df.shape)

# Separate dataframe into training and validation sets
# meta_df = meta_df.sample(frac=1, random_state=42)  # Shuffle the dataframe
# train_frac = 0.8
# train_size = int(len(meta_df) * train_frac)
# train_df_80 = meta_df.iloc[:train_size]
# val_df_20 = meta_df.iloc[train_size:]

ids = meta_df['id']

In [7]:
gene_list = list(np.load('data/her_hvg_cut_1000.npy',allow_pickle=True))
print(np.array(gene_list))

['HPS6' 'TNC' 'NR1H2' 'NUP93' 'HNRNPUL2' 'MARS' 'SUGP1' 'DEAF1' 'WDR13'
 'FLNA' 'ELN' 'MX2' 'COQ4' 'LARP7' 'TXNDC17' 'MIIP' 'AQP1' 'POSTN' 'PSCA'
 'FAM117A' 'AKT1S1' 'DVL1' 'SPIN1' 'LUC7L' 'TTC31' 'CIR1' 'GAR1'
 'RALGAPA2' 'TIMP1' 'GNAI2' 'WDR73' 'UXS1' 'CNN2' 'C1QB' 'CCDC106'
 'ARHGAP10' 'ATL2' 'UBL7' 'NDUFA12' 'PTRF' 'PPP1R1A' 'KAT6B' 'CORO1A'
 'EZH2' 'CSF1' 'CD79A' 'SDHB' 'GATC' 'TMEM160' 'UXT' 'ZNF787' 'ATF6B'
 'SNRPD3' 'ZNF217' 'SLC35A2' 'MMP9' 'ECHDC2' 'PEX11G' 'SF3B4' 'SYDE1'
 'IGKC' 'RASGRP2' 'RND3' 'PNISR' 'HSPB1' 'UBR2' 'TESK1' 'STAMBP' 'DPYSL3'
 'CPSF1' 'ZMYND19' 'COMP' 'TBCA' 'PPHLN1' 'MYL6B' 'AHCY' 'TRAP1' 'R3HDM2'
 'TMEM123' 'TOM1L2' 'VPS13D' 'CD3D' 'MID1IP1' 'DPP7' 'LTBP3' 'GANAB'
 'HLA-DOA' 'FAM193B' 'BRPF1' 'ITGB6' 'GBP5' 'CADM4' 'ARID1B' 'CEP250'
 'PTRH1' 'GPSM1' 'DLG1' 'MLLT4' 'APBB2' 'C19orf24' 'CCS' 'DHX16' 'CRACR2B'
 'CRABP2' 'MEN1' 'ARHGEF40' 'NDUFV3' 'HMGB2' 'ZFPL1' 'POLR2E' 'DAPK3'
 'FAM213B' 'RNF135' 'FDPS' 'GPS1' 'DDAH1' 'CSF1R' 'CHCHD6' 'ARID3A'
 'RABGAP1L' 

In [8]:
# gene_lists = []
# list_o_lists = []
# for id in ids:
#     adata_path = os.path.join(hest_path, "st", id + ".h5ad")
#     if os.path.exists(adata_path):
#         adata = sc.read_h5ad(adata_path)
#         print(f"Loaded {id} with shape {adata.shape}")
#         gene_list = adata.var_names.tolist()
#         gene_list = [emb_to_sym(gene) for gene in gene_list if gene.startswith("ENSG")]
#         str = ' '.join(gene_list)
#         gene_lists.append(str)
#         list_o_lists.append(gene_list)

In [9]:
# import importlib
# import predict
# importlib.reload(predict)
# from predict import *

In [10]:
# import importlib
# import dataset
# importlib.reload(dataset)
# from dataset import ViT_HEST1K



In [11]:
testset = pk_load(fold,'test',dataset=data,flatten=False,adj=True,ori=True,prune=prune)
test_loader = DataLoader(testset, batch_size=1, num_workers=0, shuffle=False)
# label=testset.label[testset.names[0]]
label = meta_df[meta_df['id'] == testset.names[0]]['disease_state']
genes=785
model=Hist2ST(
    depth1=d1, depth2=d2,depth3=d3,n_genes=genes, 
    kernel_size=k, patch_size=p,
    heads=h, channel=c, dropout=0.2,
    zinb=0.25, nb=False,
    bake=5, lamb=0.5, 
)
model.load_state_dict(torch.load(f'./downloads/{fold}-Hist2ST.ckpt'))
print(f"X embedding dimension: {model.x_embed.num_embeddings}")
print(f"Y embedding dimension: {model.y_embed.num_embeddings}")
sample_data = testset[0] 
# patches, positions, expression, adj, centers = sample_data
# print(f"Position coordinate range: {positions.min()} to {positions.max()}")


['B1']
Loading imgs...
Loading metadata...
X embedding dimension: 64
Y embedding dimension: 64


In [12]:
# Your current parameters from tag '5-7-2-8-4-16-32'
print(f"Current params: k={k}, p={p}, d1={d1}, d2={d2}, d3={d3}, h={h}, c={c}")
print(f"patch_size={p}, channel={c}")

# Check what the model actually expects vs. receives
sample_data = testset[0]
# patches, positions, expression, adj, centers = sample_data
# print(f"Input patch shape: {patches.shape}")
# print(f"Model expects channel dimension: {c}")

Current params: k=5, p=7, d1=2, d2=8, d3=4, h=16, c=32
patch_size=7, channel=32


In [13]:
# Load and inspect the checkpoint
checkpoint = torch.load(f'./downloads/{fold}-Hist2ST.ckpt')

print("Checkpoint keys (first 20):")
for i, key in enumerate(checkpoint.keys()):
    print(f"{i+1:2d}: {key}")
    if i >= 19:
        break
print(f"... total keys: {len(checkpoint.keys())}")

# If the checkpoint contains metadata about architecture
if 'hyper_parameters' in checkpoint:
    print("\nSaved hyperparameters:")
    for k, v in checkpoint['hyper_parameters'].items():
        print(f"  {k}: {v}")

Checkpoint keys (first 20):
 1: patch_embedding.weight
 2: patch_embedding.bias
 3: x_embed.weight
 4: y_embed.weight
 5: vit.transformer.layer1.0.dw.0.weight
 6: vit.transformer.layer1.0.dw.0.bias
 7: vit.transformer.layer1.0.dw.1.weight
 8: vit.transformer.layer1.0.dw.1.bias
 9: vit.transformer.layer1.0.dw.1.running_mean
10: vit.transformer.layer1.0.dw.1.running_var
11: vit.transformer.layer1.0.dw.1.num_batches_tracked
12: vit.transformer.layer1.0.dw.3.weight
13: vit.transformer.layer1.0.dw.3.bias
14: vit.transformer.layer1.0.dw.4.weight
15: vit.transformer.layer1.0.dw.4.bias
16: vit.transformer.layer1.0.dw.4.running_mean
17: vit.transformer.layer1.0.dw.4.running_var
18: vit.transformer.layer1.0.dw.4.num_batches_tracked
19: vit.transformer.layer1.0.pw.0.weight
20: vit.transformer.layer1.0.pw.0.bias
... total keys: 162


In [14]:
test_set = pk_load(2, 'test', dataset=data, flatten=False, adj=True, ori=True, prune=prune)
test2_loader = DataLoader(test_set, batch_size=1, num_workers=0, shuffle=False)

['A4']
Loading imgs...
Loading metadata...


In [15]:
for d in iter(test2_loader):
    for k in d:
        print(f"Processing sample with label: {k.shape}")
    print("=====")

Processing sample with label: torch.Size([1, 343, 3, 112, 112])
Processing sample with label: torch.Size([1, 343, 2])
Processing sample with label: torch.Size([1, 343, 785])
Processing sample with label: torch.Size([1, 343, 343])
Processing sample with label: torch.Size([1, 343, 785])
Processing sample with label: torch.Size([1, 343])
Processing sample with label: torch.Size([1, 343, 2])
=====


In [16]:
for d in iter(test_loader):
    for k in d:
        print(f"Processing sample with label: {k.shape}")
    print("=====")

Processing sample with label: torch.Size([1, 295, 3, 112, 112])
Processing sample with label: torch.Size([1, 295, 2])
Processing sample with label: torch.Size([1, 295, 785])
Processing sample with label: torch.Size([1, 295, 295])
Processing sample with label: torch.Size([1, 295, 785])
Processing sample with label: torch.Size([1, 295])
Processing sample with label: torch.Size([1, 295, 2])
=====


In [17]:
all_ids = meta_df['id'].tolist()
np.random.seed(42)
np.random.shuffle(all_ids)

split_idx = int(len(all_ids) * 0.8)
sample_ids = all_ids[split_idx:]

In [18]:
# for sample_id in sample_ids:
#     hestset = pk_load(fold, 'test', sample_ids=[sample_id], dataset='hest1k', flatten=False, adj=True, ori=True, prune=prune)
#     hest_loader = DataLoader(hestset, batch_size=1, num_workers=0, shuffle=False)
#     for d in iter(hest_loader):
#         for k in d:
#             print(f"Processing HEST sample with label: {k.shape}")
#         print("=====")

In [19]:
meta_df

Unnamed: 0,dataset_title,id,image_filename,organ,disease_state,oncotree_code,species,patient,st_technology,data_publication_date,...,treatment_comment,pixel_size_um_embedded,pixel_size_um_estimated,magnification,fullres_px_width,fullres_px_height,tissue,disease_comment,subseries,hest_version_added
1,FFPE Human Skin Primary Dermal Melanoma with 5...,TENX158,TENX158.tif,Skin,Cancer,SKCM,Homo sapiens,,Xenium,7/31/24,...,,0.273777,0.273754,40x,18669,35787,Skin,Primary Dermal Melanoma,,v1_1_0
2,FFPE Human Prostate Adenocarcinoma with 5K Hum...,TENX157,TENX157.tif,Prostate,Cancer,PRAD,Homo sapiens,,Xenium,7/31/24,...,,0.273772,0.273741,40x,25002,49976,Prostate,,,v1_1_0
3,Characterization of immune cell populations in...,TENX156,TENX156.tif,Bowel,Cancer,COAD,Homo sapiens,Patient 1,Visium HD,7/11/24,...,,0.264583,0.273802,40x,71106,58791,Colon,Stage II-A,"Visium HD, Sample P1 CRC",v1_1_0
4,Characterization of immune cell populations in...,TENX155,TENX155.tif,Bowel,Cancer,COAD,Homo sapiens,Patient 1,Visium HD,7/11/24,...,,0.264583,0.273874,40x,75250,48740,Colon,,"Visium HD, Sample P2 CRC",v1_1_0
5,Characterization of immune cell populations in...,TENX154,TENX154.tif,Bowel,Cancer,COAD,Homo sapiens,Patient 1,Visium HD,7/11/24,...,,0.264583,0.273771,40x,72897,64370,Colon,,"Visium HD, Sample P5 CRC",v1_1_0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1224,spatialLIBD,MISC5,MISC5.tif,Brain,Healthy,,Homo sapiens,,Visium,,...,,,0.727113,20x,13332,13332,dorsolateral prefrontal cortex,,151672,v1_0_0
1225,spatialLIBD,MISC4,MISC4.tif,Brain,Healthy,,Homo sapiens,,Visium,,...,,,0.726109,20x,13332,13332,dorsolateral prefrontal cortex,,151673,v1_0_0
1226,spatialLIBD,MISC3,MISC3.tif,Brain,Healthy,,Homo sapiens,,Visium,,...,,,0.725124,20x,13332,13332,dorsolateral prefrontal cortex,,151674,v1_0_0
1227,spatialLIBD,MISC2,MISC2.tif,Brain,Healthy,,Homo sapiens,,Visium,,...,,,0.726109,20x,13332,13332,dorsolateral prefrontal cortex,,151675,v1_0_0


In [20]:
def get_R_robust(data1, data2):
    """Robust correlation calculation handling edge cases"""
    x1 = data1.X
    x2 = data2.X
    
    print(f"Data shapes: pred {x1.shape}, gt {x2.shape}")
    
    if x1.shape != x2.shape:
        print("ERROR: Shape mismatch between predicted and ground truth data")
        return np.array([]), np.array([])
    
    r_values = []
    
    for i in range(x1.shape[1]):  # For each gene
        col1 = x1[:, i]
        col2 = x2[:, i]
        
        # Check for constant columns (zero variance)
        if np.var(col1) == 0 or np.var(col2) == 0:
            r_values.append(0.0)
            continue
            
        # Check for valid data points
        mask = np.isfinite(col1) & np.isfinite(col2)
        if mask.sum() < 2:
            r_values.append(0.0)
            continue
            
        try:
            r, _ = pearsonr(col1[mask], col2[mask])
            r_values.append(r if np.isfinite(r) else 0.0)
        except:
            r_values.append(0.0)
    
    return np.array(r_values), np.array([])



In [21]:
correlations = []
for sample_id in sample_ids:
    hestset = pk_load(fold, 'test', sample_ids=[sample_id], dataset='hest1k', flatten=False, adj=True, ori=True, prune=prune)
    hest_loader = DataLoader(hestset, batch_size=1, num_workers=0, shuffle=False)
    pred, gt = test(model, hest_loader,'cuda')
    # Use robust correlation
    R, _ = get_R_robust(pred, gt)
    print(f'Pearson Correlation: {np.nanmean(R):.4f}')
    print(f'Valid correlations: {np.sum(~np.isnan(R))}/{len(R)}')
    correlations.append([sample_id, np.nanmean(R)])

Error: MARS not in conversion dictionary.
Error: CIR1 not in conversion dictionary.
Error: PTRF not in conversion dictionary.
Error: MLLT4 not in conversion dictionary.
Error: C19orf24 not in conversion dictionary.
Error: FAM213B not in conversion dictionary.
Error: METTL12 not in conversion dictionary.
Error: SLC9A3R2 not in conversion dictionary.
Error: C4orf48 not in conversion dictionary.
Error: SLC22A18 not in conversion dictionary.
Error: C5orf38 not in conversion dictionary.
Error: ATP5O not in conversion dictionary.
Error: C19orf52 not in conversion dictionary.
Error: KIAA1211L not in conversion dictionary.
Error: C19orf60 not in conversion dictionary.
Error: WDR34 not in conversion dictionary.
Error: C12orf45 not in conversion dictionary.
Error: KIAA1715 not in conversion dictionary.
Error: SSSCA1 not in conversion dictionary.
Error: MFSD7 not in conversion dictionary.
Error: FAM173A not in conversion dictionary.
Error: GLTSCR2 not in conversion dictionary.
Error: C10orf54 not

  0%|                                                                           | 0/1 [00:00<?, ?it/s]

Sample NCBI474 has 758 common genes with the dataset
Batch 0: patch shape = torch.Size([1, 598, 3, 112, 112])


100%|███████████████████████████████████████████████████████████████████| 1/1 [00:02<00:00,  2.18s/it]


Final shapes - preds: (598, 785), ct: (598, 2), gt: (598, 785)
Data shapes: pred (598, 785), gt (598, 785)
Pearson Correlation: 0.0000
Valid correlations: 785/785
Error: MARS not in conversion dictionary.
Error: CIR1 not in conversion dictionary.
Error: PTRF not in conversion dictionary.
Error: MLLT4 not in conversion dictionary.
Error: C19orf24 not in conversion dictionary.
Error: FAM213B not in conversion dictionary.
Error: METTL12 not in conversion dictionary.
Error: SLC9A3R2 not in conversion dictionary.
Error: C4orf48 not in conversion dictionary.
Error: SLC22A18 not in conversion dictionary.
Error: C5orf38 not in conversion dictionary.
Error: ATP5O not in conversion dictionary.
Error: C19orf52 not in conversion dictionary.
Error: KIAA1211L not in conversion dictionary.
Error: C19orf60 not in conversion dictionary.
Error: WDR34 not in conversion dictionary.
Error: C12orf45 not in conversion dictionary.
Error: KIAA1715 not in conversion dictionary.
Error: SSSCA1 not in conversion d

  0%|                                                                           | 0/1 [00:00<?, ?it/s]

Sample NCBI539 has 771 common genes with the dataset


100%|███████████████████████████████████████████████████████████████████| 1/1 [00:02<00:00,  2.24s/it]

Batch 0: patch shape = torch.Size([1, 475, 3, 112, 112])
Final shapes - preds: (475, 785), ct: (475, 2), gt: (475, 785)
Data shapes: pred (475, 785), gt (475, 785)
Pearson Correlation: 0.0000
Valid correlations: 785/785





Error: MARS not in conversion dictionary.
Error: CIR1 not in conversion dictionary.
Error: PTRF not in conversion dictionary.
Error: MLLT4 not in conversion dictionary.
Error: C19orf24 not in conversion dictionary.
Error: FAM213B not in conversion dictionary.
Error: METTL12 not in conversion dictionary.
Error: SLC9A3R2 not in conversion dictionary.
Error: C4orf48 not in conversion dictionary.
Error: SLC22A18 not in conversion dictionary.
Error: C5orf38 not in conversion dictionary.
Error: ATP5O not in conversion dictionary.
Error: C19orf52 not in conversion dictionary.
Error: KIAA1211L not in conversion dictionary.
Error: C19orf60 not in conversion dictionary.
Error: WDR34 not in conversion dictionary.
Error: C12orf45 not in conversion dictionary.
Error: KIAA1715 not in conversion dictionary.
Error: SSSCA1 not in conversion dictionary.
Error: MFSD7 not in conversion dictionary.
Error: FAM173A not in conversion dictionary.
Error: GLTSCR2 not in conversion dictionary.
Error: C10orf54 not

  0%|                                                                           | 0/1 [00:00<?, ?it/s]

Sample MISC61 has 755 common genes with the dataset


100%|███████████████████████████████████████████████████████████████████| 1/1 [00:04<00:00,  4.95s/it]

Batch 0: patch shape = torch.Size([1, 2047, 3, 112, 112])
Final shapes - preds: (2047, 785), ct: (2047, 2), gt: (2047, 785)





Data shapes: pred (2047, 785), gt (2047, 785)
Pearson Correlation: 0.0405
Valid correlations: 785/785
Error: MARS not in conversion dictionary.
Error: CIR1 not in conversion dictionary.
Error: PTRF not in conversion dictionary.
Error: MLLT4 not in conversion dictionary.
Error: C19orf24 not in conversion dictionary.
Error: FAM213B not in conversion dictionary.
Error: METTL12 not in conversion dictionary.
Error: SLC9A3R2 not in conversion dictionary.
Error: C4orf48 not in conversion dictionary.
Error: SLC22A18 not in conversion dictionary.
Error: C5orf38 not in conversion dictionary.
Error: ATP5O not in conversion dictionary.
Error: C19orf52 not in conversion dictionary.
Error: KIAA1211L not in conversion dictionary.
Error: C19orf60 not in conversion dictionary.
Error: WDR34 not in conversion dictionary.
Error: C12orf45 not in conversion dictionary.
Error: KIAA1715 not in conversion dictionary.
Error: SSSCA1 not in conversion dictionary.
Error: MFSD7 not in conversion dictionary.
Error: 

  0%|                                                                           | 0/1 [00:02<?, ?it/s]

Sample SPA102 has 739 common genes with the dataset





KeyError: np.str_('HPS6')

In [None]:
meta_df.columns

Index(['dataset_title', 'id', 'image_filename', 'organ', 'disease_state',
       'oncotree_code', 'species', 'patient', 'st_technology',
       'data_publication_date', 'license', 'study_link', 'download_page_link1',
       'inter_spot_dist', 'spot_diameter', 'spots_under_tissue',
       'preservation_method', 'nb_genes', 'treatment_comment',
       'pixel_size_um_embedded', 'pixel_size_um_estimated', 'magnification',
       'fullres_px_width', 'fullres_px_height', 'tissue', 'disease_comment',
       'subseries', 'hest_version_added'],
      dtype='object')

In [None]:
meta_df[meta_df['id']=='MISC139']

Unnamed: 0,dataset_title,id,image_filename,organ,disease_state,oncotree_code,species,patient,st_technology,data_publication_date,...,treatment_comment,pixel_size_um_embedded,pixel_size_um_estimated,magnification,fullres_px_width,fullres_px_height,tissue,disease_comment,subseries,hest_version_added
42,Spatially resolved multiomics of human cardiac...,MISC139,MISC139.tif,Heart,Healthy,,Homo sapiens,,Visium,,...,,31.750063,0.455581,40x,17634,17674,heart right ventricle,,HCAHeartST10550730,v1_1_0


In [None]:
import anndata
ad = anndata.read_h5ad(os.path.join(hest_path, "st", "MISC139.h5ad"))
print(type(ad.var_names[0]))

<class 'str'>


In [None]:
# Clean the data by replacing NaN/Inf values
def clean_data(adata):
    """Remove NaN and Inf values from AnnData object"""
    X = adata.X.copy()
    # Replace NaN with 0 or mean
    X = np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0)
    # Or use median imputation
    # mask = np.isnan(X) | np.isinf(X)
    # X[mask] = np.nanmedian(X)
    adata_clean = ad.AnnData(X)
    adata_clean.obsm = adata.obsm.copy()
    return adata_clean

# Clean both datasets before correlation
pred_clean = clean_data(pred)

R = get_R(pred_clean, gt)[0]
print('Pearson Correlation:', np.nanmean(R))

Pearson Correlation: nan


In [None]:
# Before calling get_R, inspect your data
print("Checking pred data:")
print(f"  Shape: {pred.X.shape}")
print(f"  Has NaN: {np.isnan(pred.X).any()}")
print(f"  Has Inf: {np.isinf(pred.X).any()}")
print(f"  Min value: {np.nanmin(pred.X)}")
print(f"  Max value: {np.nanmax(pred.X)}")

print("Checking gt data:")
print(f"  Shape: {gt.X.shape}")
print(f"  Has NaN: {np.isnan(gt.X).any()}")
print(f"  Has Inf: {np.isinf(gt.X).any()}")
print(f"  Min value: {np.nanmin(gt.X)}")
print(f"  Max value: {np.nanmax(gt.X)}")

Checking pred data:
  Shape: (598, 785)
  Has NaN: True
  Has Inf: False
  Min value: nan
  Max value: nan
Checking gt data:
  Shape: (598, 785)
  Has NaN: False
  Has Inf: False
  Min value: 0.0
  Max value: 4.0000433921813965


In [None]:
R=get_R(pred,gt)[0]
print('Pearson Correlation:',np.nanmean(R))


# clus,ARI=cluster(pred,label)
# print('ARI:',ARI)


Pearson Correlation: 0.2887870599966082
ARI: 0.431
