# Extract Embeddings from Pre-trained Nicheformer Model

This notebook extracts embeddings from a pre-trained Nicheformer model and stores them in an AnnData object.

In [1]:
#make autoreload

%load_ext autoreload
%autoreload 2

In [2]:
import os
from typing import Any, Dict, Optional

import anndata as ad
import numpy as np
import pytorch_lightning as pl
import torch
from rich import print
from torch.utils.data import DataLoader
from tqdm import tqdm

from nicheformer.data import NicheformerDataset
from nicheformer.models import Nicheformer

In [3]:
adata_nicheformer_check = ad.read_h5ad("/home/labs/nyosef/nathanl/nicheformer-benchmark/data/adata_M1_M2_core_6_sections_nicheformer.h5ad")
adata_nicheformer_check

AnnData object with n_obs × n_vars = 250790 × 20310
    obs: 'soma_joinid', 'is_primary_data', 'dataset_id', 'donor_id', 'assay', 'cell_type', 'development_stage', 'disease', 'tissue', 'tissue_general', 'specie', 'technology', 'dataset', 'x', 'y', 'assay_ontology_term_id', 'sex_ontology_term_id', 'organism_ontology_term_id', 'tissue_ontology_term_id', 'suspension_type', 'condition_id', 'tissue_type', 'library_key', 'organism', 'sex', 'niche', 'region', 'nicheformer_split', 'author_cell_type', 'batch', 'modality'
    layers: 'counts'

In [8]:
# adata_nicheformer_check.obs["modality"] 
adata_nicheformer_check.obs["specie"]
# adata_nicheformer_check.obs["technology"] 

99119258951401284221813451836785008968-0     6
95610536408155075407048305276227824006-0     6
94627996433995692958551448758698866723-0     6
10618021078266319649091845116122351158-0     6
89038554173344823793949504602728291142-0     6
                                            ..
198804031042921944377635730961789487507-1    6
156934588469937564968781740624428245657-1    6
155321557493369021799199720361468240742-1    6
24226418575755266420023086883481701930-1     6
308372198820486739554187228970484789236-1    6
Name: specie, Length: 250790, dtype: int64

In [11]:
adata_nicheformer_check.X[0].toarray().max()

5.591953754425049

## Configuration

Set up the configuration parameters for the embedding extraction.

In [60]:
config = {
    "data_path": "/home/nathanlevy/sda/Data/adata_M1_M2_core_6_sections.h5ad",  #'path/to/your/data.h5ad',  # Path to your AnnData file
    "technology_mean_path": "/home/nathanlevy/sda/nicheformer/data/model_means/merfish_mean_script.npy",  #'path/to/technology_mean.npy',  # Path to technology mean file
    "checkpoint_path": "/home/nathanlevy/sda/nicheformer/nicheformer.ckpt",  # Path to model checkpoint
    "output_path": "data_with_embeddings.h5ad",  # Where to save the result, it is a new h5ad
    "output_dir": ".",  # Directory for any intermediate outputs
    "batch_size": 32,
    "max_seq_len": 1500,
    "aux_tokens": 30,
    "chunk_size": 1000,  # to prevent OOM
    "num_workers": 4,
    "precision": 32,
    "embedding_layer": -1,  # Which layer to extract embeddings from (-1 for last layer)
    "embedding_name": "embeddings",  # Name suffix for the embedding key in adata.obsm
}

## Load Data and Create Dataset

In [61]:
adata_nicheformer = ad.read_h5ad("/home/nathanlevy/sda/nicheformer/data/model_means/model.h5ad")
adata_nicheformer

AnnData object with n_obs × n_vars = 1 × 20310
    obs: 'soma_joinid', 'is_primary_data', 'dataset_id', 'donor_id', 'assay', 'cell_type', 'development_stage', 'disease', 'tissue', 'tissue_general', 'specie', 'technology', 'dataset', 'x', 'y', 'assay_ontology_term_id', 'sex_ontology_term_id', 'organism_ontology_term_id', 'tissue_ontology_term_id', 'suspension_type', 'condition_id', 'tissue_type', 'library_key', 'organism', 'sex', 'niche', 'region', 'nicheformer_split', 'author_cell_type', 'batch'

In [62]:
adata_nicheformer.var_names.tolist()

['ENSG00000000003',
 'ENSG00000000005',
 'ENSG00000000419',
 'ENSG00000000457',
 'ENSG00000000460',
 'ENSG00000000938',
 'ENSG00000000971',
 'ENSG00000001036',
 'ENSG00000001084',
 'ENSG00000001167',
 'ENSG00000001460',
 'ENSG00000001461',
 'ENSG00000001497',
 'ENSG00000001561',
 'ENSG00000001617',
 'ENSG00000001626',
 'ENSG00000001629',
 'ENSG00000001630',
 'ENSG00000001631',
 'ENSG00000002016',
 'ENSG00000002330',
 'ENSG00000002549',
 'ENSG00000002586',
 'ENSG00000002587',
 'ENSG00000002726',
 'ENSG00000002745',
 'ENSG00000002746',
 'ENSG00000002822',
 'ENSG00000002834',
 'ENSG00000002919',
 'ENSG00000002933',
 'ENSG00000003056',
 'ENSG00000003096',
 'ENSG00000003137',
 'ENSG00000003147',
 'ENSG00000003249',
 'ENSG00000003393',
 'ENSG00000003400',
 'ENSG00000003402',
 'ENSG00000003436',
 'ENSG00000003509',
 'ENSG00000003756',
 'ENSG00000003987',
 'ENSG00000003989',
 'ENSG00000004059',
 'ENSG00000004139',
 'ENSG00000004142',
 'ENSG00000004399',
 'ENSG00000004455',
 'ENSG00000004468',


In [63]:
# Set random seed for reproducibility
pl.seed_everything(42)

# Load data
adata_mouse = ad.read_h5ad(config["data_path"])
technology_mean = np.load(config["technology_mean_path"])

Global seed set to 42


In [64]:
adata_mouse

AnnData object with n_obs × n_vars = 250790 × 1122
    obs: 'organism_ontology_term_id', 'donor_id', 'development_stage_ontology_term_id', 'sex_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'disease_ontology_term_id', 'tissue_ontology_term_id', 'cell_type_ontology_term_id', 'assay_ontology_term_id', 'suspension_type', 'cluster_id_transfer', 'subclass_transfer', 'cluster_confidence_score', 'subclass_confidence_score', 'high_quality_transfer', 'major_brain_region', 'ccf_region_name', 'brain_section_label', 'tissue_type', 'is_primary_data', 'cell_type', 'assay', 'disease', 'organism', 'sex', 'tissue', 'self_reported_ethnicity', 'development_stage', 'observation_joinid', 'n_counts', 'batch', '_scvi_batch', '_scvi_labels', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_50_genes', 'pct_counts_in_top_100_genes', 'pct_counts_in_top_200_genes', 'pct_counts_in_top_500_genes', 'total_counts_mt', 'log1p_total_counts_mt', 'p

In [65]:
# Find shared genes between mouse and model datasets
shared_genes = adata_mouse.var_names.intersection(adata_nicheformer.var_names)

len(shared_genes)

58

In [66]:
adata_mouse.var_names

Index(['ENSMUSG00000024798', 'ENSMUSG00000042385', 'ENSMUSG00000036198',
       'ENSMUSG00000028780', 'ENSMUSG00000015843', 'ENSMUSG00000026768',
       'ENSMUSG00000049928', 'ENSMUSG00000041046', 'ENSMUSG00000032373',
       'ENSMUSG00000004633',
       ...
       'ENSMUSG00000024064', 'ENSMUSG00000035580', 'ENSMUSG00000010136',
       'ENSMUSG00000024376', 'ENSMUSG00000022324', 'ENSMUSG00000015619',
       'ENSMUSG00000070047', 'ENSMUSG00000002266', 'ENSMUSG00000036111',
       'ENSMUSG00000033063'],
      dtype='object', name='gene_id', length=1122)

In [67]:
# Subset both AnnData objects to retain only shared genes
adata_mouse_shared = adata_mouse[:, shared_genes].copy()
adata_model_shared = adata_nicheformer[:, shared_genes].copy()

adata_mouse_not_shared = adata_mouse[:, ~adata_mouse.var_names.isin(shared_genes)].copy()
adata_model_not_shared = adata_nicheformer[:, ~adata_nicheformer.var_names.isin(shared_genes)].copy()

In [68]:
# from pybiomart import Dataset

# # Connect to the BioMart database
# mouse_dataset = Dataset(name="mmusculus_gene_ensembl", host="http://www.ensembl.org")
# human_dataset = Dataset(name="hsapiens_gene_ensembl", host="http://www.ensembl.org")

# # Retrieve orthologous mappings
# mapping = mouse_dataset.query(
#     attributes=[
#         "ensembl_gene_id",  # Mouse gene ID
#         "hsapiens_homolog_ensembl_gene",  # Human ortholog ID
#     ]
# )

# # Convert to dictionary for quick lookup
# mouse_to_human = dict(zip(mapping["Gene stable ID"], mapping["Human gene stable ID"]))

# # Map mouse genes to human genes in adata_mouse
# mapped_genes = adata_mouse.var_names.map(mouse_to_human)

In [69]:
from pybiomart import Dataset

# Initialize BioMart dataset for mouse genes
mouse_dataset = Dataset(name="mmusculus_gene_ensembl", host="http://www.ensembl.org")

# Fetch mouse-to-human gene mappings
orthologs = mouse_dataset.query(
    attributes=[
        "ensembl_gene_id",  # Mouse gene IDs
        "hsapiens_homolog_ensembl_gene",  # Human gene orthologs
    ]
)

# Convert to a dictionary for easy mapping
mouse_to_human = orthologs.set_index("Gene stable ID")["Human gene stable ID"].to_dict()

In [70]:
print(adata_mouse_not_shared.shape)

# Map mouse gene IDs to human gene IDs
adata_mouse_not_shared.var["human_gene"] = adata_mouse_not_shared.var_names.map(mouse_to_human)

# Remove genes without human orthologs
adata_mouse_mapped = adata_mouse_not_shared[:, adata_mouse_not_shared.var["human_gene"].notna()].copy()

print(adata_mouse_not_shared.var["human_gene"].isna()[adata_mouse_not_shared.var["human_gene"].isna()])

# Rename the genes to their human counterparts
adata_mouse_mapped.var_names = adata_mouse_mapped.var["human_gene"]

print(adata_mouse_mapped.shape)

In [84]:
# import requests

# # List of mouse gene IDs
# mouse_genes = ["ENSMUSG00000033581", "ENSMUSG00000046550"]

# # Base URL for MyGene.info API
# url = "http://mygene.info/v3/query"

# # Placeholder for results
# results = []

# # Query each gene
# for gene in mouse_genes:
#     # Define the query parameters
#     params = {
#         "q": gene,  # The gene to query
#         "scopes": "ensembl.gene",  # Input gene ID type
#         "fields": "homologene",  # Get homologene data
#         "species": "mouse",  # Specify the species
#     }
#     # Send the GET request
#     response = requests.get(url, params=params)
#     if response.ok:
#         # Parse and store the result
#         results.append(response.json())
#     else:
#         print(f"Failed to fetch data for {gene}: {response.status_code}")

# # Print the results
# for result in results:
#     print(result)

import requests

# List of mouse gene IDs
mouse_genes = ["ENSMUSG00000033581", "ENSMUSG00000046550"]

# Base URL for MyGene.info API
url = "http://mygene.info/v3/query"

# Placeholder for results
results = []

# Query each mouse gene
for gene in mouse_genes:
    # Define the query parameters
    params = {
        "q": gene,  # The gene to query
        "scopes": "ensembl.gene",  # Input gene ID type
        "fields": "homologene",  # Get homologene data
        "species": "mouse",  # Specify the species
    }
    # Send the GET request
    response = requests.get(url, params=params)
    if response.ok:
        results.append(response.json())
    else:
        print(f"Failed to fetch data for {gene}: {response.status_code}")

# Extract human orthologs and fetch their Ensembl IDs
human_ensembl_ids = []
for result in results:
    if "hits" in result:
        for hit in result["hits"]:
            homologene_data = hit.get("homologene", {})
            orthologs = homologene_data.get("genes", [])
            # Find human orthologs (species ID 9606)
            human_ortholog = next((gene[1] for gene in orthologs if gene[0] == 9606), None)
            if human_ortholog:
                # Query MyGene.info to get the Ensembl ID for the human ortholog
                params = {
                    "q": human_ortholog,  # The human ortholog ID
                    "scopes": "entrezgene",  # Input type is Entrez Gene ID
                    "fields": "ensembl.gene",  # Retrieve Ensembl Gene ID
                    "species": "human",  # Specify the species
                }
                response = requests.get(url, params=params)
                if response.ok:
                    data = response.json()
                    if "hits" in data:
                        for hit in data["hits"]:
                            ensembl_id = hit.get("ensembl", {}).get("gene", None)
                            if ensembl_id:
                                human_ensembl_ids.append((gene, ensembl_id))

# Display the results
for mouse_gene, human_ensembl in human_ensembl_ids:
    print(f"Mouse Gene: {mouse_gene} -> Human Ortholog (Ensembl): {human_ensembl}")


In [87]:
'ENSG00000073792' in adata_nicheformer.var_names

True

In [85]:
missing_genes

Index(['ENSMUSG00000033581', 'ENSMUSG00000046550'], dtype='object', name='gene_id')

In [76]:
import requests

missing_genes = adata_mouse_not_shared.var_names[adata_mouse_not_shared.var["human_gene"].isna()]

for gene in missing_genes[:10]:  # Querying first 10 as an example
    response = requests.get(f"https://www.orthodb.org/v10/lookup/{gene}")
    if response.ok:
        print(response.json())

In [82]:
import mygene

# Initialize MyGene.info client
mg = mygene.MyGeneInfo()

# List of mouse genes to query
mouse_genes = ["ENSMUSG00000033581", "ENSMUSG00000046550"]

# Query MyGene.info for orthologs
results = mg.querymany(
    mouse_genes,
    scopes="ensembl.gene",  # Specify that you're querying Ensembl gene IDs
    fields="homologene",  # Request homology information
    species="mouse",  # Specify mouse as the input species
)

# Process the results
orthologs = {}
for res in results:
    if "homologene" in res and "human" in res["homologene"]:
        orthologs[res["query"]] = res["homologene"]["human"]

# Output the orthologs
print("Orthologs found:")
for mouse_gene, human_gene in orthologs.items():
    print(f"{mouse_gene} -> {human_gene}")

[autoreload of numpy.matrixlib failed: Traceback (most recent call last):
  File "/home/nathanlevy/mambaforge/envs/scvi/lib/python3.11/site-packages/IPython/extensions/autoreload.py", line 276, in check
    superreload(m, reload, self.old_objects)
  File "/home/nathanlevy/mambaforge/envs/scvi/lib/python3.11/site-packages/IPython/extensions/autoreload.py", line 475, in superreload
    module = reload(module)
             ^^^^^^^^^^^^^^
  File "/home/nathanlevy/mambaforge/envs/scvi/lib/python3.11/importlib/__init__.py", line 169, in reload
    _bootstrap._exec(spec, module)
  File "<frozen importlib._bootstrap>", line 621, in _exec
  File "<frozen importlib._bootstrap_external>", line 940, in exec_module
  File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed
  File "/home/nathanlevy/mambaforge/envs/scvi/lib/python3.11/site-packages/numpy/matrixlib/__init__.py", line 6, in <module>
    __all__ = defmatrix.__all__
              ^^^^^^^^^
NameError: name 'defmatrix' 

ImportError: cannot import name 'alwayslist' from 'biothings_client' (/home/nathanlevy/mambaforge/envs/scvi/lib/python3.11/site-packages/biothings_client/__init__.py)

In [32]:
del adata_mouse_mapped.varm["PCs"]
del adata_mouse_shared.varm["PCs"]

In [33]:
adata_mouse_mapped.varm

AxisArrays with keys: 

In [34]:
# Combine shared and mapped genes
adata_combined = ad.concat([adata_mouse_shared, adata_mouse_mapped], join="outer", axis=1)
adata_combined.var_names_make_unique()
adata_combined

  utils.warn_names_duplicates("var")
  utils.warn_names_duplicates("var")


AnnData object with n_obs × n_vars = 250790 × 1120
    var: 'gene_name', 'feature_is_filtered', 'feature_name', 'feature_reference', 'feature_biotype', 'feature_length', 'mt', 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts', 'human_gene'
    layers: 'counts'

In [35]:
False in adata_combined.var_names.notnull()

False

In [36]:
True in adata_combined.var_names.isna()

False

In [37]:
# format data properly with the model
adata_concat = ad.concat([adata_nicheformer, adata_combined], join="outer", axis=0)
# dropping the first observation
adata_concat = adata_concat[1:].copy()

  warn(


In [38]:
# # Update gene names in adata_mouse to match the model
# adata_mouse.var["original_mouse_gene"] = adata_mouse.var_names  # Save original names
# adata_mouse.var_names = mapped_genes

# # Filter out unmapped genes
# adata_mouse = adata_mouse[:, adata_mouse.var_names.notnull()].copy()

# # Verify
# print(f"Number of mapped genes: {adata_mouse.n_vars}")

In [39]:
print("Model obs index unique:", adata_nicheformer.obs.index.is_unique)
print("Model var index unique:", adata_nicheformer.var.index.is_unique)

print("Mouse obs index unique:", adata_mouse.obs.index.is_unique)
print("Mouse var index unique:", adata_mouse.var.index.is_unique)

In [40]:
adata_mouse.var_names_make_unique()
adata_nicheformer.var_names_make_unique()

In [41]:
# # format data properly with the model
# adata_concat = ad.concat([adata_nicheformer, adata_mouse], join="outer", axis=0)
# # dropping the first observation
# adata_concat = adata_concat[1:].copy()

In [42]:
adata_concat

AnnData object with n_obs × n_vars = 250790 × 20310
    obs: 'soma_joinid', 'is_primary_data', 'dataset_id', 'donor_id', 'assay', 'cell_type', 'development_stage', 'disease', 'tissue', 'tissue_general', 'specie', 'technology', 'dataset', 'x', 'y', 'assay_ontology_term_id', 'sex_ontology_term_id', 'organism_ontology_term_id', 'tissue_ontology_term_id', 'suspension_type', 'condition_id', 'tissue_type', 'library_key', 'organism', 'sex', 'niche', 'region', 'nicheformer_split', 'author_cell_type', 'batch'
    layers: 'counts'

As a reference, the metadata tokens are 

modality_dict = {
    'dissociated': 3,
    'spatial': 4,}

specie_dict = {
    'human': 5,
    'Homo sapiens': 5,
    'Mus musculus': 6,
    'mouse': 6,}

technology_dict = {
    "merfish": 7,
    "MERFISH": 7,
    "cosmx": 8,
    "NanoString digital spatial profiling": 8,
    "visium": 9,
    "10x 5' v2": 10,
    "10x 3' v3": 11,
    "10x 3' v2": 12,
    "10x 5' v1": 13,
    "10x 3' v1": 14,
    "10x 3' transcription profiling": 15, 
    "10x transcription profiling": 15,
    "10x 5' transcription profiling": 16,
    "CITE-seq": 17, 
    "Smart-seq v4": 18,
}

In [56]:
# Change accordingly

adata_concat.obs["modality"] = 4  # spatial
adata_concat.obs["specie"] = 6  # mouse
adata_concat.obs["technology"] = 7  # mouse
adata_concat

AnnData object with n_obs × n_vars = 250790 × 20310
    obs: 'soma_joinid', 'is_primary_data', 'dataset_id', 'donor_id', 'assay', 'cell_type', 'development_stage', 'disease', 'tissue', 'tissue_general', 'specie', 'technology', 'dataset', 'x', 'y', 'assay_ontology_term_id', 'sex_ontology_term_id', 'organism_ontology_term_id', 'tissue_ontology_term_id', 'suspension_type', 'condition_id', 'tissue_type', 'library_key', 'organism', 'sex', 'niche', 'region', 'nicheformer_split', 'author_cell_type', 'batch', 'modality'
    layers: 'counts'

In [57]:
adata_concat.obs["nicheformer_split"] = "train"
adata_concat.obs["nicheformer_split"]

99119258951401284221813451836785008968-0     train
95610536408155075407048305276227824006-0     train
94627996433995692958551448758698866723-0     train
10618021078266319649091845116122351158-0     train
89038554173344823793949504602728291142-0     train
                                             ...  
198804031042921944377635730961789487507-1    train
156934588469937564968781740624428245657-1    train
155321557493369021799199720361468240742-1    train
24226418575755266420023086883481701930-1     train
308372198820486739554187228970484789236-1    train
Name: nicheformer_split, Length: 250790, dtype: object

In [58]:
# Create dataset
dataset = NicheformerDataset(
    adata=adata_concat,
    technology_mean=technology_mean,
    split="train",
    max_seq_len=1500,
    aux_tokens=config.get("aux_tokens", 30),
    chunk_size=config.get("chunk_size", 1000),
)

# Create dataloader
dataloader = DataLoader(
    dataset, batch_size=config["batch_size"], shuffle=False, num_workers=config.get("num_workers", 4), pin_memory=True
)

Processing 250790 cells for train split


100%|██████████| 251/251 [02:25<00:00,  1.72it/s]


## Load Model and Set Up Trainer

In [59]:
# Load pre-trained model
model = Nicheformer.load_from_checkpoint(checkpoint_path=config["checkpoint_path"], strict=False)
model.eval()  # Set to evaluation mode

# Configure trainer
trainer = pl.Trainer(
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    devices=1,
    default_root_dir=config["output_dir"],
    precision=config.get("precision", 32),
)

  rank_zero_warn(
  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


## Extract Embeddings

In [60]:
print("Extracting embeddings...")
embeddings = []
device = model.embeddings.weight.device

with torch.no_grad():
    for batch in tqdm(dataloader):
        # Move batch to device
        batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}

        # Get embeddings from the model
        emb = model.get_embeddings(
            batch=batch,
            layer=config.get("embedding_layer", -1),  # Default to last layer
        )
        embeddings.append(emb.cpu().numpy())


# Concatenate all embeddings
embeddings = np.concatenate(embeddings, axis=0)

  return torch._transformer_encoder_layer_fwd(
100%|██████████| 7838/7838 [1:45:29<00:00,  1.24it/s]


## Save Results

In [61]:
# Store embeddings in AnnData object
embedding_key = f"X_niche_{config.get('embedding_name', 'embeddings')}"
adata_concat.obsm[embedding_key] = embeddings

# Save updated AnnData
adata_concat.write_h5ad(config["output_path"])

print(f"Embeddings saved to {config['output_path']} in obsm['{embedding_key}']")

In [63]:
embeddings.shape

(250790, 512)

In [78]:
adata_embedding = ad.read_h5ad(config["output_path"])
adata_embedding

AnnData object with n_obs × n_vars = 250790 × 20310
    obs: 'soma_joinid', 'is_primary_data', 'dataset_id', 'donor_id', 'assay', 'cell_type', 'development_stage', 'disease', 'tissue', 'tissue_general', 'specie', 'technology', 'dataset', 'x', 'y', 'assay_ontology_term_id', 'sex_ontology_term_id', 'organism_ontology_term_id', 'tissue_ontology_term_id', 'suspension_type', 'condition_id', 'tissue_type', 'library_key', 'organism', 'sex', 'niche', 'region', 'nicheformer_split', 'author_cell_type', 'batch', 'modality'
    obsm: 'X_niche_embeddings'
    layers: 'counts'

In [80]:
adata_mouse.obsm['X_nicheformer'] = adata_embedding.obsm['X_niche_embeddings'].copy()

In [81]:
adata_mouse.write_h5ad(config["data_path"])