In [1]:
import os
import numpy as np
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader
import anndata as ad
from typing import Optional, Dict, Any
from tqdm import tqdm

from nicheformer.models import Nicheformer
from nicheformer.data import NicheformerDataset
import anndata as ad

In [2]:
import anndata as ad

## load nicheformer model 

In [3]:
os.getcwd()

'/root/capsule/LC_NE_analayis'

In [15]:
model_path=os.path.join("/root/capsule/data/nicheformer_weights/","nicheformer.ckpt")

In [13]:
model_path

'/root/capsule/data/nicheformer_weights/nicheformer_weights/nicheformer.ckpt'

In [7]:
config = {
    'data_path': '/root/capsule/data/LC_NE/adata_sc_pre.h5ad', #'path/to/your/data.h5ad',  # Path to your AnnData file
    'technology_mean_path': '/root/capsule/data/model_means/merfish_mean_script.npy', #'path/to/technology_mean.npy',  # Path to technology mean file
    'checkpoint_path': '/root/capsule/data/nicheformer_weights/nicheformer.ckpt',  # Path to model checkpoint
    'output_path': 'test_data_with_embeddings.h5ad',  # Where to save the result, it is a new h5ad
    'output_dir': '/root/capsule/LC_NE_output/',  # Directory for any intermediate outputs
    'batch_size': 10,
    'max_seq_len': 1500, 
    'aux_tokens': 30, 
    'chunk_size': 1000, # to prevent OOM
    'num_workers': 4,
    'precision': 32,
    'embedding_layer': -1,  # Which layer to extract embeddings from (-1 for last layer)
    'embedding_name': 'embeddings'  # Name suffix for the embedding key in adata.obsm
}

In [16]:
model = Nicheformer.load_from_checkpoint(model_path, strict=False)
model.eval()  # Set to evaluation mode

# Configure trainer
trainer = pl.Trainer(
    accelerator='gpu' if torch.cuda.is_available() else 'cpu',
    devices=1,
    default_root_dir=config['output_dir'],
    precision=config.get('precision', 32),
)

/opt/conda/lib/python3.10/site-packages/pytorch_lightning/core/saving.py:191: Found keys that are in the model state dict but not in the checkpoint: ['pos']
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default


In [17]:
config['data_path']

'/root/capsule/data/LC_NE/adata_sc_pre.h5ad'

In [18]:
adata_0 = ad.read_h5ad(config['data_path'])
technology_mean=np.load(config['technology_mean_path'])
adata_orig = adata_0.copy()

In [19]:
adata_orig.var.index.shape,np.unique(list(adata_orig.var.index)).shape

((24343,), (24343,))

In [20]:
from nicheformer.data.constants import ObsConstants

In [21]:
adata_0.obs[ObsConstants.DATASET] = adata_0.uns['title']
adata_0.obs[ObsConstants.SPLIT] = 'train'
adata_0.obs[ObsConstants.NICHE] = 'nan'
adata_0.obs[ObsConstants.REGION] = 'nan'


# chnage orthologus genes

In [22]:
import pandas as pd

In [23]:
path_to_orthologus_csv = '/root/capsule/data/LC_NE/mouse_to_human.csv'
orthologus_csv = pd.read_csv(path_to_orthologus_csv)

In [24]:
orthologus_csv = orthologus_csv[orthologus_csv.hsapiens_homolog_ensembl_gene.notna()]
orthologus_csv

Unnamed: 0.1,Unnamed: 0,ensembl_gene_id,hsapiens_homolog_ensembl_gene
5,5,ENSMUSG00000064341,ENSG00000198888
9,9,ENSMUSG00000064345,ENSG00000198763
15,15,ENSMUSG00000064351,ENSG00000198804
18,18,ENSMUSG00000064354,ENSG00000198712
20,20,ENSMUSG00000064356,ENSG00000228253
...,...,...,...
83509,83509,ENSMUSG00000108908,ENSG00000186092
83510,83510,ENSMUSG00000108908,ENSG00000176695
83517,83517,ENSMUSG00000026679,ENSG00000151023
83518,83518,ENSMUSG00002076042,ENSG00000252396


In [25]:
foo = dict(zip(orthologus_csv.ensembl_gene_id, orthologus_csv.hsapiens_homolog_ensembl_gene))

In [26]:
adata_0.var=adata_0.var.rename(index=foo)

In [27]:
technology_mean=np.load(config['technology_mean_path'])

In [28]:
technology_mean.shape

(20310,)

In [29]:
model = ad.read_h5ad('/root/capsule/data/model_means/model.h5ad')

In [30]:
adata_0.var.index.shape,np.unique(list(adata_0.var.index)).shape

((24343,), (23911,))

In [32]:
dup_indx = ~adata_0.var.index.duplicated(keep='first')
#adata_0.var = adata_0.var[dup_indx]
adata_0 = adata_0[:,dup_indx]

In [33]:
adata_0.var.shape, np.unique(list(adata_0.var.index)).shape

((23911, 13), (23911,))

In [34]:
keep_genes = [m in set(model.var.index) for m in adata_0.var.index]

In [35]:
adata_0 = adata_0[:, keep_genes]

In [36]:
adata_0.shape

(5040, 15661)

In [37]:
adata = ad.concat([model, adata_0], join='outer', axis=0)
# dropping the first observation 
adata = adata[1:].copy()

In [38]:
adata.shape

(5040, 20310)

In [39]:
# Create dataset
dataset = NicheformerDataset(
    adata=adata,
    technology_mean=technology_mean,
    split='train',
    max_seq_len=1500,
    aux_tokens=config.get('aux_tokens', 30),
    chunk_size=config.get('chunk_size', 1000)
)

100%|██████████| 6/6 [00:10<00:00,  1.83s/it]


In [40]:


# Create dataloader
dataloader = DataLoader(
    dataset,
    batch_size=config['batch_size'],
    shuffle=False,
    num_workers=config.get('num_workers', 4),
    pin_memory=True
)

In [41]:
dataloader

<torch.utils.data.dataloader.DataLoader at 0x7fb4914faad0>

In [42]:
trainer = pl.Trainer(
    accelerator='gpu' if torch.cuda.is_available() else 'cpu',
    devices=1,
    default_root_dir=config['output_dir'],
    precision=config.get('precision', 32),
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [44]:
Niche_model = Nicheformer.load_from_checkpoint(model_path, strict=False)
Niche_model.eval()  # Set to evaluation mode

# Configure trainer
trainer = pl.Trainer(
    accelerator='gpu' if torch.cuda.is_available() else 'cpu',
    devices=1,
    default_root_dir=config['output_dir'],
    precision=config.get('precision', 32),
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [45]:
print("Extracting embeddings...")
embeddings = []
device = Niche_model.embeddings.weight.device

with torch.no_grad():
    for batch in tqdm(dataloader):
        # Move batch to device
        batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v
                for k, v in batch.items()}

        # Get embeddings from the model
        emb = Niche_model.get_embeddings(
            batch=batch,
            layer=config.get('embedding_layer', -1)  # Default to last layer
        )
        embeddings.append(emb.cpu().numpy())


# Concatenate all embeddings
embeddings = np.concatenate(embeddings, axis=0)

Extracting embeddings...


 59%|█████▉    | 297/504 [04:58<03:28,  1.01s/it]


KeyboardInterrupt: 

In [339]:
adata.obsm['nicheformer'] = embeddings

adata.write_h5ad('/home/ec2-user/SageMaker/output_dir/adata_sc_embedded.h5ad')