In [1]:
import os
import numpy as np
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader
import anndata as ad
from typing import Optional, Dict, Any
from tqdm import tqdm

from nicheformer.models import Nicheformer
from nicheformer.data import NicheformerDataset
import anndata as ad

## Load the Nicheformer model

In [2]:
# Location of the pretrained checkpoint, relative to the working directory.
weights_dir = "weights"
checkpoint_name = "nicheformer.ckpt"
model_path = os.path.join(weights_dir, checkpoint_name)

In [3]:
# Run configuration for the embedding-extraction workflow.
# NOTE(review): not every cell reads its settings from here (e.g. the DataLoader
# is built with an explicit batch size) — confirm which entries are authoritative.
config = {
    # Path to the AnnData (.h5ad) file to embed
    'data_path': '/home/ec2-user/SageMaker/test_data/spatial/preprocessed/Xenium_Preview_Human_Non_diseased_Lung_With_Add_on_FFPE_outs.h5ad',
    # Path to the technology mean (.npy) file
    'technology_mean_path': '/home/ec2-user/SageMaker/nicheformer/data/model_means/merfish_mean_script.npy',
    # Path to the model checkpoint. Relative to the working directory, matching
    # `model_path`; the previous value started with "/" and therefore pointed at
    # the filesystem root instead of the local weights folder.
    'checkpoint_path': os.path.join('weights', 'nicheformer.ckpt'),
    'output_path': 'test_data_with_embeddings.h5ad',  # Where to save the result, it is a new h5ad
    'output_dir': '.',  # Directory for any intermediate outputs
    'batch_size': 32,
    'max_seq_len': 1500,
    'aux_tokens': 30,
    'chunk_size': 100,  # to prevent OOM
    'num_workers': 4,
    'precision': 32,
    'embedding_layer': -1,  # Which layer to extract embeddings from (-1 for last layer)
    'embedding_name': 'embeddings'  # Name suffix for the embedding key in adata.obsm
}

## Process data

In [4]:
# Reference AnnData object (despite the name, this is not the neural network).
# Later cells pass it to ad.concat together with the dataset.
reference_h5ad_path = 'nicheformer/data/model_means/model.h5ad'
model = ad.read_h5ad(reference_h5ad_path)


In [5]:
# Set random seed for reproducibility
pl.seed_everything(42)

# Load the dataset to embed and the technology mean array
adata = ad.read_h5ad(config['data_path'])
technology_mean = np.load(config['technology_mean_path'])

# Format the data to match the reference object: an outer join along the
# observation axis keeps the union of variables from both inputs.
n_reference_obs = model.n_obs
adata = ad.concat([model, adata], join='outer', axis=0)

# Drop the reference rows again. They come first in the concatenation, so slice
# by their actual count rather than assuming there is exactly one of them.
adata = adata[n_reference_obs:].copy()

Seed set to 42


In [6]:
# Restore the pretrained network from the checkpoint. strict=False lets loading
# proceed when state-dict keys do not line up with the checkpoint (a warning is
# emitted for the mismatched keys).
Niche_model = Nicheformer.load_from_checkpoint(model_path, strict=False)
Niche_model.eval()  # Set to evaluation mode

# Configure trainer
# One device is requested: a fixed count of 4 fails on hosts with fewer GPUs and
# does not match the single-device setup this notebook uses for inference.
trainer = pl.Trainer(
    accelerator='gpu' if torch.cuda.is_available() else 'cpu',
    devices=1,
    default_root_dir=config['output_dir'],
    precision=config.get('precision', 32),
)

/home/ec2-user/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/pytorch_lightning/core/saving.py:191: Found keys that are in the model state dict but not in the checkpoint: ['pos']
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/ec2-user/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default


In [7]:
# NOTE(review): this repeats the load/concat already performed in an earlier
# cell; running it resets `adata` and `technology_mean` to their freshly loaded state.
adata = ad.read_h5ad(config['data_path'])
technology_mean = np.load(config['technology_mean_path'])

# Outer join along the observation axis keeps the union of variables.
n_reference_obs = model.n_obs
adata = ad.concat([model, adata], join='outer', axis=0)

# Drop the reference rows (they come first) by their actual count instead of
# assuming the reference object holds exactly one observation.
adata = adata[n_reference_obs:].copy()

In [8]:
# Total number of elements in the technology mean array
np.asarray(technology_mean).size

20310

In [12]:
# Reload the technology mean array from disk and extend it by one trailing
# entry with value 1.0.
# NOTE(review): presumably the extra entry pads the array to the variable count
# of the concatenated AnnData — confirm.
loaded_mean = np.load(config['technology_mean_path'])
technology_mean = np.append(loaded_mean, 1.0)

In [13]:
# Display the extended array for inspection.
# NOTE(review): the recorded output contains NaN entries — confirm that the
# dataset construction below tolerates NaN values in this array.
technology_mean

array([        nan,         nan,         nan, ...,         nan,
       14.23553126,  1.        ])

In [15]:
# Create dataset
# Every tunable is read from `config` so the dataset cannot silently drift from
# the configuration cell (max_seq_len used to be a separate literal here).
# Construction is slow: the recorded run took roughly 16 minutes.
dataset = NicheformerDataset(
    adata=adata,
    technology_mean=technology_mean,
    split='train',
    max_seq_len=config.get('max_seq_len', 1500),
    aux_tokens=config.get('aux_tokens', 30),
    chunk_size=config.get('chunk_size', 100)
)

# Create dataloader
# shuffle=False keeps batches in dataset order.
# NOTE(review): batch_size is set to 10 here and does not follow
# config['batch_size']; kept as in the recorded run — confirm which is intended.
dataloader = DataLoader(
    dataset,
    batch_size=10,
    shuffle=False,
    num_workers=config.get('num_workers', 4),
    pin_memory=True
)

100%|██████████| 2959/2959 [16:24<00:00,  3.00it/s]


In [16]:
# Build a single-device trainer, preferring the GPU when CUDA is usable.
if torch.cuda.is_available():
    accelerator_kind = 'gpu'
else:
    accelerator_kind = 'cpu'

trainer_options = dict(
    accelerator=accelerator_kind,
    devices=1,
    default_root_dir=config['output_dir'],
    precision=config.get('precision', 32),
)
trainer = pl.Trainer(**trainer_options)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [17]:
# Number of batches the loader will yield. Ask the loader directly: wrapping it
# in tqdm only to take its length leaves a stale 0% progress bar in the output.
len(dataloader)

  0%|          | 0/29589 [00:00<?, ?it/s]


29589

In [18]:
print("Extracting embeddings...")

# Pick the GPU whenever one is visible and place the model there explicitly.
# Reading the device off the model's weights silently runs the whole loop on the
# CPU if that is where the checkpoint happened to be restored.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
Niche_model = Niche_model.to(device)
Niche_model.eval()  # stay in evaluation mode during extraction

# Per-batch results are collected on the host and joined once at the end.
embedding_chunks = []

with torch.no_grad():  # inference only, no gradients required
    for batch in tqdm(dataloader):
        # Move tensor entries to the model's device; leave other values as they are
        batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v
                 for k, v in batch.items()}

        # Get embeddings from the model
        emb = Niche_model.get_embeddings(
            batch=batch,
            layer=config.get('embedding_layer', -1)  # Default to last layer
        )
        embedding_chunks.append(emb.cpu().numpy())

# Concatenate all embeddings along the first (batch) axis
embeddings = np.concatenate(embedding_chunks, axis=0)

Extracting embeddings...


  4%|▍         | 1153/29589 [05:58<2:27:22,  3.22it/s]


KeyboardInterrupt: 