In [3]:
import os
import time
import numpy as np
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader
import anndata as ad
import pandas as pd
from tqdm import tqdm

from nicheformer.models import Nicheformer
from nicheformer.data import NicheformerDataset

# --------------------
# Config
# --------------------
config = {
    'data_dir': "/workspace/Projects/FM/Final data/H5AD/Visium",   # input directory with .h5ad files
    'output_dir': "/workspace/Projects/FM/Final data/Visium_embeddings_nicheformer",  # save results here
    'checkpoint_path': "/workspace/Projects/FM/Codes/nicheformer-main/nicheformer-main/nicheformer.ckpt",
    'batch_size': 8,
    'max_seq_len': 1500,
    'aux_tokens': 30,
    'chunk_size': 1000,
    'num_workers': 4,
    'precision': 32,
    'embedding_layer': -1,
    'embedding_name': 'embeddings'
}

os.makedirs(config['output_dir'], exist_ok=True)

# --------------------
# Set reproducibility
# --------------------
pl.seed_everything(42)

# --------------------
# Load pretrained model once
# --------------------
print("Loading pretrained Nicheformer...")
model = Nicheformer.load_from_checkpoint(checkpoint_path=config['checkpoint_path'], strict=False)
model.eval()

trainer = pl.Trainer(
    accelerator='gpu' if torch.cuda.is_available() else 'cpu',
    devices=1,
    default_root_dir=config['output_dir'],
    precision=config.get('precision', 32),
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# --------------------
# Runtime log
# --------------------
runtime_records = []

# --------------------
# Loop over samples
# --------------------
for fname in os.listdir(config['data_dir']):
    if not fname.endswith(".h5ad"):
        continue

    in_path = os.path.join(config['data_dir'], fname)
    out_path = os.path.join(config['output_dir'], fname.replace(".h5ad", "_with_embeddings.h5ad"))

    print(f"\nüîπ Processing {fname}")
    start_time = time.time()

    # Load AnnData
    adata = ad.read_h5ad(in_path)

    # --- Ensure required fields ---
    adata.obs['modality'] = 4   # spatial
    adata.obs['specie'] = 6     # mouse
    adata.obs['assay'] = 9      # adjust depending on Visium/CosMx etc.

    if 'nicheformer_split' not in adata.obs.columns:
        adata.obs['nicheformer_split'] = 'train'
    else:
        adata.obs['nicheformer_split'] = 'train'

    # Ensure ensembl_id_collapsed
    adata.var["ensembl_id_collapsed"] = adata.var_names

    # Compute technology mean for this dataset
    technology_mean = np.asarray(adata.X.mean(axis=0)).ravel()

    # Create dataset + dataloader
    dataset = NicheformerDataset(
        adata=adata,
        technology_mean=technology_mean,
        split="train",
        max_seq_len=config['max_seq_len'],
        aux_tokens=config['aux_tokens'],
        chunk_size=config['chunk_size'],
        metadata_fields={'obs': ['modality', 'specie', 'assay']}
    )

    dataloader = DataLoader(
        dataset,
        batch_size=config['batch_size'],
        shuffle=False,
        num_workers=config['num_workers'],
        pin_memory=True
    )

    # --- Extract embeddings ---
    embeddings = []
    with torch.no_grad():
        for batch in tqdm(dataloader, desc=f"Extracting {fname}"):
            batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v
                     for k, v in batch.items()}
            emb = model.get_embeddings(
                batch=batch,
                layer=config['embedding_layer']
            )
            embeddings.append(emb.cpu().numpy())

    embeddings = np.concatenate(embeddings, axis=0)

    # --- Store embeddings ---
    embedding_key = f"X_niche_{config['embedding_name']}"
    adata.obsm[embedding_key] = embeddings
    adata.write_h5ad(out_path)

    # Runtime logging
    total_time = time.time() - start_time
    runtime_records.append({
        "sample": fname,
        "cells": adata.n_obs,
        "genes": adata.n_vars,
        "runtime_minutes": round(total_time / 60, 2)
    })

    print(f"‚úÖ Finished {fname} in {total_time/60:.2f} min ‚Üí saved {out_path}")

# --------------------
# Save runtime log as Excel
# --------------------
runtime_df = pd.DataFrame(runtime_records)
excel_path = os.path.join(config['output_dir'], "embedding_runtimes.xlsx")
runtime_df.to_excel(excel_path, index=False)

print(f"\n‚è±Ô∏è Runtime summary saved to {excel_path}")
display(runtime_df)


Seed set to 42


Loading pretrained Nicheformer...


üí° Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs



üîπ Processing 53430.h5ad


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 5/5 [00:04<00:00,  1.07it/s]
Extracting 53430.h5ad: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 534/534 [01:31<00:00,  5.86it/s]


‚úÖ Finished 53430.h5ad in 1.62 min ‚Üí saved /workspace/Projects/FM/Final data/Visium_embeddings_nicheformer/53430_with_embeddings.h5ad

üîπ Processing 26933.h5ad


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3/3 [00:02<00:00,  1.13it/s]
Extracting 26933.h5ad: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 254/254 [00:43<00:00,  5.86it/s]


‚úÖ Finished 26933.h5ad in 0.82 min ‚Üí saved /workspace/Projects/FM/Final data/Visium_embeddings_nicheformer/26933_with_embeddings.h5ad

üîπ Processing 26934.h5ad


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3/3 [00:04<00:00,  1.45s/it]
Extracting 26934.h5ad: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 314/314 [00:53<00:00,  5.85it/s]


‚úÖ Finished 26934.h5ad in 1.05 min ‚Üí saved /workspace/Projects/FM/Final data/Visium_embeddings_nicheformer/26934_with_embeddings.h5ad

üîπ Processing 53433.h5ad


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:04<00:00,  1.01s/it]
Extracting 53433.h5ad: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 439/439 [01:14<00:00,  5.87it/s]


‚úÖ Finished 53433.h5ad in 1.41 min ‚Üí saved /workspace/Projects/FM/Final data/Visium_embeddings_nicheformer/53433_with_embeddings.h5ad

üîπ Processing 26935.h5ad


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3/3 [00:02<00:00,  1.01it/s]
Extracting 26935.h5ad: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 270/270 [00:45<00:00,  5.88it/s]


‚úÖ Finished 26935.h5ad in 0.89 min ‚Üí saved /workspace/Projects/FM/Final data/Visium_embeddings_nicheformer/26935_with_embeddings.h5ad

üîπ Processing 26932.h5ad


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3/3 [00:02<00:00,  1.09it/s]
Extracting 26932.h5ad: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 273/273 [00:46<00:00,  5.82it/s]


‚úÖ Finished 26932.h5ad in 0.90 min ‚Üí saved /workspace/Projects/FM/Final data/Visium_embeddings_nicheformer/26932_with_embeddings.h5ad

üîπ Processing 53435.h5ad


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2/2 [00:02<00:00,  1.10s/it]
Extracting 53435.h5ad: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 239/239 [00:40<00:00,  5.87it/s]


‚úÖ Finished 53435.h5ad in 0.78 min ‚Üí saved /workspace/Projects/FM/Final data/Visium_embeddings_nicheformer/53435_with_embeddings.h5ad

üîπ Processing 53431.h5ad


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3/3 [00:03<00:00,  1.04s/it]
Extracting 53431.h5ad: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 340/340 [00:58<00:00,  5.85it/s]


‚úÖ Finished 53431.h5ad in 1.10 min ‚Üí saved /workspace/Projects/FM/Final data/Visium_embeddings_nicheformer/53431_with_embeddings.h5ad

üîπ Processing 53434.h5ad


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:03<00:00,  1.23it/s]
Extracting 53434.h5ad: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 397/397 [01:07<00:00,  5.87it/s]


‚úÖ Finished 53434.h5ad in 1.26 min ‚Üí saved /workspace/Projects/FM/Final data/Visium_embeddings_nicheformer/53434_with_embeddings.h5ad

üîπ Processing 53432.h5ad


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:03<00:00,  1.28it/s]
Extracting 53432.h5ad: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 422/422 [01:11<00:00,  5.87it/s]


‚úÖ Finished 53432.h5ad in 1.32 min ‚Üí saved /workspace/Projects/FM/Final data/Visium_embeddings_nicheformer/53432_with_embeddings.h5ad

‚è±Ô∏è Runtime summary saved to /workspace/Projects/FM/Final data/Visium_embeddings_nicheformer/embedding_runtimes.xlsx


Unnamed: 0,sample,cells,genes,runtime_minutes
0,53430.h5ad,4271,14868,1.62
1,26933.h5ad,2031,14868,0.82
2,26934.h5ad,2505,14868,1.05
3,53433.h5ad,3511,14868,1.41
4,26935.h5ad,2154,14868,0.89
5,26932.h5ad,2182,14868,0.9
6,53435.h5ad,1910,14868,0.78
7,53431.h5ad,2718,14868,1.1
8,53434.h5ad,3176,14868,1.26
9,53432.h5ad,3376,14868,1.32
