In [None]:
import os
import gc
import h5py
import pickle
import torch
import mygene
import numpy as np
import pandas as pd
import datasets
import traceback
import warnings
import scanpy as sc
from tqdm import tqdm
from pathlib import Path
from scipy.sparse import csr_matrix
from huggingface_hub import login, hf_hub_download
from tqdm import tqdm

In [2]:
# --- Input / output locations ------------------------------------------
# All paths are derived from BASE_DIR so the notebook can be relocated by
# editing a single value.
BASE_DIR = "."
HEST_DATA_DIR = os.path.join(BASE_DIR, "hest_data")
ST_DIR, PATCHES_DIR = (
    os.path.join(HEST_DATA_DIR, subdir) for subdir in ("st", "patches")
)

# Plain-text / CSV artefacts that live directly under BASE_DIR.
VOCAB_PATH = os.path.join(BASE_DIR, "whole_human_vocab.csv")
COMMON_GENES_PATH = os.path.join(BASE_DIR, "common_overlap_genes.txt")
COMMON_SAMPLES_PATH = os.path.join(BASE_DIR, "common_overlap_samples.txt")
FILE_NAMES_PATH = os.path.join(BASE_DIR, "file_names.txt")

# --- Gene vocabulary ---------------------------------------------------
# The vocabulary CSV must provide a "SYMBOL" and an "ID" column.
gene_vocab = pd.read_csv(VOCAB_PATH)
genes_in_vocab = gene_vocab["SYMBOL"]
# If a symbol appears more than once, the last row wins.
symbol_to_id = {
    symbol: gene_id
    for symbol, gene_id in zip(genes_in_vocab, gene_vocab["ID"])
}

In [3]:
def _names_with_suffix(directory, suffix):
    """Return the entries of `directory` whose name ends with `suffix`, in os.listdir order."""
    return [name for name in os.listdir(directory) if name.endswith(suffix)]


# Patch archives (.h5) and expression matrices (.h5ad) available on disk.
h5_files = _names_with_suffix(PATCHES_DIR, ".h5")
h5ad_files = _names_with_suffix(ST_DIR, ".h5ad")

In [5]:
# Sanity check: report how many files of each kind were found.
n_h5, n_h5ad = len(h5_files), len(h5ad_files)
print(f"Current number of files - h5 files : {n_h5}, h5ad files : {n_h5ad}")

Current number of files - h5 files : 649, h5ad files : 649


In [None]:
# Bucket every .h5ad file by the naming scheme found in its var_names.
# A file with at least one name starting with "GRCh" goes to files_with_GRCh;
# otherwise, a file with at least one name starting with "ENSG" goes to
# files_with_ENSG. Files matching neither prefix are left out of both lists.
files_with_GRCh = []
files_with_ENSG = []

for file in tqdm(h5ad_files):
    file_path = os.path.join(ST_DIR, file)

    try:
        adata = sc.read_h5ad(file_path)
        genes = set(adata.var_names)

        # First matching prefix wins, so a file is never placed in both lists.
        for prefix, bucket in (("GRCh", files_with_GRCh), ("ENSG", files_with_ENSG)):
            if any(gene.startswith(prefix) for gene in genes):
                bucket.append(file)
                break

    except Exception as e:
        # Unreadable files are reported and skipped; the scan carries on.
        print(f"Error loading {file}: {e}")

In [None]:
# Rename Ensembl gene IDs ("ENSG...") in var_names to gene symbols using
# MyGene.info, then rewrite each affected .h5ad file in place.
mg = mygene.MyGeneInfo()

# Lookups are shared across files so each Ensembl ID is sent to the web
# service at most once instead of once per file.
#   ensembl_to_symbol : unversioned Ensembl ID -> symbol (or the ID itself when
#                       the response carries no "symbol" field)
#   queried_ids       : every ID already sent, including IDs for which the
#                       service returned no entry at all
ensembl_to_symbol = {}
queried_ids = set()

for file in tqdm(files_with_ENSG):
    try:
        file_path = os.path.join(ST_DIR, file)
        adata = sc.read_h5ad(file_path)

        adata.var_names = adata.var_names.astype(str)

        # Ensembl IDs of this file with any ".<version>" suffix removed.
        ensembl_ids = sorted({
            gene.split(".")[0]
            for gene in adata.var_names
            if gene.startswith("ENSG")
        })

        if not ensembl_ids:
            print(f"{file}: No ENSG genes found")
            continue

        # Only ask the API about IDs that no earlier file has already resolved.
        new_ids = [gene_id for gene_id in ensembl_ids if gene_id not in queried_ids]
        if new_ids:
            gene_info = mg.querymany(
                new_ids,
                scopes="ensembl.gene",
                fields="symbol",
                species="human",
                verbose=False
            )

            # When one query yields several hits, the last hit wins.
            for item in gene_info:
                if 'query' in item:
                    ensembl_to_symbol[item['query']] = item.get('symbol', item['query'])

            # Recorded only after a successful call, so a failed request is
            # retried when the next file needs the same IDs.
            queried_ids.update(new_ids)

        # Names without a mapping (non-ENSG names, or IDs the service did not
        # answer for) are kept exactly as they were.
        adata.var_names = [
            ensembl_to_symbol.get(gene.split(".")[0], gene)
            for gene in adata.var_names
        ]

        adata.write_h5ad(file_path)
        print(f"{file}: Converted successfully")

    except Exception as e:
        # Report and continue: one bad file must not abort the whole batch.
        print(f"{file}: Error - {e}")
        traceback.print_exc()

In [None]:
# Remove the leading "GRCh38_" + one-or-more-underscores prefix from gene
# names and overwrite each affected .h5ad file in place.
for file in files_with_GRCh:
    file_path = os.path.join(ST_DIR, file)
    try:
        adata = sc.read_h5ad(file_path)

        stripped_names = adata.var_names.str.replace(r"^GRCh38__+", "", regex=True)
        adata.var_names = stripped_names

        adata.write_h5ad(file_path)
    except Exception as e:
        # Report the failure with its traceback and move on to the next file.
        print(f"{file} Error occured: {e}")
        traceback.print_exc()
    else:
        print(f"{file} Converted. Saved at: {file_path}")

## Check overlapping genes across samples

In [None]:
# For every .h5ad file, find the genes that (a) appear exactly once in
# var_names and (b) are present in the vocabulary.
#   overlap_counts : file path -> number of such genes
#   overlap_genes  : file path -> list of such genes
#
# Only ".h5ad" entries are considered; any other file or folder inside ST_DIR
# would just fail in sc.read_h5ad and clutter the output with errors.
adata_paths = [
    os.path.join(ST_DIR, file)
    for file in os.listdir(ST_DIR)
    if file.endswith(".h5ad")
]

overlap_counts = {}
overlap_genes = {}

for file in tqdm(adata_paths):
    try:
        adata = sc.read_h5ad(file)

        gene_series = pd.Series(adata.var_names)

        # Keep names occurring exactly once; every copy of a duplicated name
        # is dropped.
        occurrences = gene_series.map(gene_series.value_counts())
        unique_genes = pd.Index(gene_series[occurrences == 1].unique())

        overlapping_genes = unique_genes[unique_genes.isin(genes_in_vocab)]
        overlap_count = len(overlapping_genes)

        overlap_counts[file] = overlap_count
        overlap_genes[file] = list(overlapping_genes)

    except Exception as e:
        # Unreadable files are reported and left out of both dictionaries.
        print(f"{file} Error occured: {e}")

In [None]:
# Summarise the per-file overlap, keep the files with at least 10000
# vocabulary genes, and write the genes shared by ALL kept files to
# COMMON_GENES_PATH (one gene per line, sorted).
df_oc = pd.DataFrame(
    list(overlap_counts.items()),
    columns=["Files", "overlap_counts"]
)
print(df_oc)

file_paths = sorted(list(df_oc[df_oc["overlap_counts"] >= 10000]["Files"]))
# Sample name = file name up to its first "."
file_names = [os.path.basename(file).split(".")[0] for file in file_paths]

df_oc_sorted = df_oc.sort_values(by = 'overlap_counts', ascending = False)

# overlap_genes already holds, for every file in overlap_counts, the genes that
# occur exactly once in that file and are in the vocabulary. Reusing it avoids
# reading every selected .h5ad file from disk a second time to rebuild the
# same lists.
overlap_genes_filter = {
    path: sorted(overlap_genes[path])
    for path in file_paths
}

if overlap_genes_filter:
    common_overlap_genes = set.intersection(*(set(genes) for genes in overlap_genes_filter.values()))
else:
    # No file passed the threshold, so there is nothing to intersect.
    common_overlap_genes = set()

with open(COMMON_GENES_PATH, "w", encoding = "utf-8") as file:
    for gene in sorted(common_overlap_genes):
        file.write(gene + "\n")


In [4]:
def _read_nonempty_lines(path):
    """Return the whitespace-stripped, non-blank lines of a UTF-8 text file."""
    with open(path, "r", encoding="utf-8") as handle:
        return [line.strip() for line in handle if line.strip()]


# Reload the sample names and the shared gene list from disk.
# NOTE(review): no cell visible in this notebook writes COMMON_SAMPLES_PATH;
# confirm where that file is produced before running from a fresh checkout.
file_names = _read_nonempty_lines(COMMON_SAMPLES_PATH)
common_overlap_genes = _read_nonempty_lines(COMMON_GENES_PATH)

In [None]:
import os, numpy as np, pandas as pd, scanpy as sc, h5py
from scipy.sparse import csr_matrix
from tqdm import tqdm

# Build one "gene sentence" per spot: the names of up to TOP_K_GENES common
# genes with the largest normalized expression, in descending order of value,
# joined by spaces. Results are appended sample by sample to OUT_CSV with the
# columns "id" (<sample>_<barcode>) and "sentence".
h5ad_paths = [os.path.join(ST_DIR, sample + ".h5ad") for sample in file_names]
OUT_CSV   = "./st_sentences/top100_sentences.csv"
TOP_K_GENES = 100

# The CSV is written in append mode, one chunk per sample. Start from a clean
# file so re-running this cell does not duplicate rows or insert a second
# header line, and make sure the output folder exists.
os.makedirs(os.path.dirname(OUT_CSV), exist_ok=True)
if os.path.exists(OUT_CSV):
    os.remove(OUT_CSV)

sample_stats = []          # per-sample summary of non-zero genes per spot
total_nonzero = 0          # stored non-zero entries across all samples
total_cells = 0            # spots across all samples
header_written = False     # the CSV header goes out with the first chunk only

# Loop-invariant: the gene list used to subset every sample.
common_genes_sorted = sorted(common_overlap_genes)

for file in tqdm(h5ad_paths):
    sample = os.path.basename(file).split(".")[0]
    h5_file = os.path.join(PATCHES_DIR, sample + ".h5")

    # Only the barcodes are needed from the patch file, so it is closed again
    # before the expression data is processed.
    with h5py.File(h5_file, "r") as h5_data:
        h5_barcodes = h5_data["barcode"][:].astype(str).reshape(-1)

    adata = sc.read_h5ad(file)
    adata.var_names_make_unique()
    adata = adata[:, adata.var.index.intersection(common_genes_sorted)]

    # Row position of the first occurrence of every spot barcode.
    first_position = {}
    for position, barcode in enumerate(adata.obs_names):
        first_position.setdefault(barcode, position)

    # Reorder the spots to follow the patch-file barcode order with one dict
    # lookup per barcode. Patch barcodes without a matching spot are skipped
    # (a per-barcode np.where lookup would raise IndexError on them).
    spot_order = np.array(
        [first_position[barcode] for barcode in h5_barcodes if barcode in first_position],
        dtype=int,
    )
    n_missing = len(h5_barcodes) - spot_order.size
    if n_missing:
        print(f"{sample}: {n_missing} patch barcodes have no matching spot in the .h5ad; skipped")
    if spot_order.size == 0:
        continue

    adata = adata[spot_order].copy()

    # Per-spot normalization to 1e4 total counts, then log1p.
    sc.pp.normalize_total(adata, target_sum=1e4)
    sc.pp.log1p(adata)

    X = csr_matrix(adata.X)
    genes = adata.var.index.to_numpy()
    barcodes = adata.obs.index.to_numpy()

    # Number of stored entries in each CSR row = non-zero genes per spot.
    nnz_per_cell = np.diff(X.indptr)
    if nnz_per_cell.size > 0:
        sample_stats.append({
            "sample": sample,
            "n_cells": int(nnz_per_cell.size),
            "mean_nonzero_genes": float(nnz_per_cell.mean()),
            "median_nonzero_genes": float(np.median(nnz_per_cell)),
            "std_nonzero_genes": float(nnz_per_cell.std()),
        })
        total_cells += int(nnz_per_cell.size)
        total_nonzero += int(nnz_per_cell.sum())

    rows = []
    for i in range(X.shape[0]):
        row = X.getrow(i)
        if row.nnz == 0:
            # Spots with no expressed common gene get no sentence.
            continue

        # Sort the stored values in descending order and keep the top genes.
        order = np.argsort(-row.data)
        top_idx = row.indices[order][:TOP_K_GENES]
        sentence = " ".join(genes[top_idx])

        rows.append({"id": f"{sample}_{barcodes[i]}", "sentence": sentence})

    if rows:
        df_out = pd.DataFrame(rows, columns=["id", "sentence"])
        df_out.to_csv(OUT_CSV, mode="a", index=False, header=(not header_written))
        header_written = True


In [None]:
# Inputs for the image clean-up step: the folder holding the spot images and
# the sentence table.
images_dir = "./st_images"
sentences_path = "./st_sentences/top100_sentences.csv"

# Load the sentence table first, then snapshot the image folder contents.
spot_sentences = pd.read_csv(sentences_path)
spot_image_files = os.listdir(images_dir)

In [None]:
# Delete every .png in images_dir whose name (the part before ".png") is not
# an id in the sentence table. This removes files from disk permanently.
valid_ids = set(spot_sentences['id'].astype(str))

for img_file in spot_image_files:
    if not img_file.endswith(".png"):
        # Anything that is not a .png is left alone.
        continue

    sample_id = img_file.split(".png")[0]
    if sample_id in valid_ids:
        # The image has a matching sentence, so it stays.
        continue

    img_path = os.path.join(images_dir, img_file)
    os.remove(img_path)
    print(f"Deleted: {img_file}")

In [12]:
# After the clean-up: print the number of sentences, then the number of files
# left in the image folder.
spot_image_files_ar = os.listdir(images_dir)
for count in (len(spot_sentences), len(spot_image_files_ar)):
    print(count)

1012725
1012725
