Extract cell type embeddings from layer prior to classification heads.

Prompting with all genes and no metadata.

In [1]:
import os
import math
import torch
import warnings
import numpy as np
import scanpy as sc
import pandas as pd
import matplotlib.pyplot as plt
from braceexpand import braceexpand
from tqdm import tqdm

# for flex attention
import torch._dynamo
torch._dynamo.config.suppress_errors = True

# DEVICE = torch.device('cuda:1')
DEVICE = torch.device('cuda:0')
sc.set_figure_params(figsize=(4, 4))

from cellarium.ml.utilities.inference.cellarium_gpt_inference import \
    CellariumGPTInferenceContext, \
    GeneNetworkAnalysisBase

2025-03-16 20:25:25.441184: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
ROOT_PATH = "/work/hdd/bbjr/mallina1/data/mb-ml-dev-vm"
TRAIN_ROOT_PATH = "/work/hdd/bbjr/mallina1/data/human_cellariumgpt_v2/extract_files"
CHECKPOINT_PATH = "/work/hdd/bbjr/mallina1/cellarium/models/compute_optimal_checkpoints/epoch=6-step=63560.ckpt"

REF_ADATA_PATH = os.path.join(ROOT_PATH, "data", "extract_0.h5ad")
GENE_INFO_PATH = os.path.join(ROOT_PATH, "gene_info", "gene_info.tsv")

ctx = CellariumGPTInferenceContext(
    cellarium_gpt_ckpt_path=CHECKPOINT_PATH,
    ref_adata_path=REF_ADATA_PATH,
    gene_info_tsv_path=GENE_INFO_PATH,
    device=DEVICE,
    attention_backend="mem_efficient"
)

Get the names of the 100 most expressed cell types in the first train chunk and filter down on these. To be updated later with a custom list.

In [3]:
train_fnames = list(braceexpand('extract_{0..10}.h5ad'))

tr_adata = sc.read_h5ad(os.path.join(TRAIN_ROOT_PATH, train_fnames[0]))
cell_type_filter_list = list(tr_adata.obs['cell_type'].value_counts().to_dict().keys())[:11]
cell_type_filter_list.remove('unknown')

print(cell_type_filter_list)

cell_type_filter_list = set(cell_type_filter_list)

['neuron', 'L2/3-6 intratelencephalic projecting glutamatergic neuron', 'glutamatergic neuron', 'oligodendrocyte', 'CD4-positive, alpha-beta T cell', 'CD8-positive, alpha-beta T cell', 'B cell', 'classical monocyte', 'natural killer cell', 'fibroblast']


In [4]:
# print(tr_adata.obs.iloc[0].keys())
# print(tr_adata.obs.iloc[0])
print(tr_adata.X.shape)
print(tr_adata.X[0].todense())

expr_samp = tr_adata.X[0].todense()
sort_idx = np.argsort(expr_samp)

(10000, 36601)
[[0. 0. 0. ... 0. 0. 0.]]


In [5]:
model_var_names = np.asarray(tr_adata.var_names)
model_var_names_set = set(model_var_names)
var_name_to_index_map = {var_name: i for i, var_name in enumerate(model_var_names)}

print(len(model_var_names))
print(var_name_to_index_map)

sample_idx = np.random.choice(np.arange(len(tr_adata.var_names)), size=4096, replace=False)

print(tr_adata.var_names)
print(tr_adata.X.shape)

from copy import copy

new_adata = copy(tr_adata)
new_adata = new_adata[:, sample_idx]

# tr_adata.var_names = tr_adata.var_names[sample_idx]
# tr_adata.X = tr_adata.X[:, sample_idx]

# print(tr_adata.var_names)
# print(tr_adata.X.shape)

36601
{'ENSG00000187642': 0, 'ENSG00000078808': 1, 'ENSG00000272106': 2, 'ENSG00000162585': 3, 'ENSG00000272088': 4, 'ENSG00000204624': 5, 'ENSG00000162490': 6, 'ENSG00000177000': 7, 'ENSG00000011021': 8, 'ENSG00000120949': 9, 'ENSG00000116721': 10, 'ENSG00000237700': 11, 'ENSG00000169991': 12, 'ENSG00000158748': 13, 'ENSG00000162543': 14, 'ENSG00000142798': 15, 'ENSG00000125945': 16, 'ENSG00000158055': 17, 'ENSG00000286061': 18, 'ENSG00000121769': 19, 'ENSG00000233775': 20, 'ENSG00000121905': 21, 'ENSG00000286899': 22, 'ENSG00000183317': 23, 'ENSG00000287987': 24, 'ENSG00000238287': 25, 'ENSG00000287400': 26, 'ENSG00000117385': 27, 'ENSG00000198520': 28, 'ENSG00000187048': 29, 'ENSG00000123473': 30, 'ENSG00000143001': 31, 'ENSG00000132854': 32, 'ENSG00000227485': 33, 'ENSG00000118473': 34, 'ENSG00000142864': 35, 'ENSG00000235200': 36, 'ENSG00000285778': 37, 'ENSG00000142875': 38, 'ENSG00000016490': 39, 'ENSG00000137947': 40, 'ENSG00000272672': 41, 'ENSG00000233593': 42, 'ENSG000001430

In [6]:
print(new_adata.shape)
print(new_adata.var_names)
print(new_adata.X.shape)

(10000, 4096)
Index(['ENSG00000188986', 'ENSG00000278981', 'ENSG00000171723',
       'ENSG00000163497', 'ENSG00000225163', 'ENSG00000270177',
       'ENSG00000267764', 'ENSG00000274211', 'ENSG00000187821',
       'ENSG00000286575',
       ...
       'ENSG00000122859', 'ENSG00000128274', 'ENSG00000226702',
       'ENSG00000253710', 'ENSG00000168631', 'ENSG00000260454',
       'ENSG00000233308', 'ENSG00000221983', 'ENSG00000256588',
       'ENSG00000241180'],
      dtype='object', length=4096)
(10000, 4096)


In [7]:
metadata_prompt_dict = {
    "cell_type": False,
    "tissue": False,
    "disease": False,
    "sex": False,
    "development_stage": False
}

train_fnames = train_fnames[:1]

cell_type_embeddings = []
disease_embeddings = []

susp_type_labels = []
assay_labels = []
cell_type_labels = []
disease_labels = []

batch_size = 160
for train_fname in train_fnames:
    tr_adata = sc.read_h5ad(os.path.join(TRAIN_ROOT_PATH, train_fname))

    # only preserve rows where the cell type is in the list we want
    tr_adata_filtered = tr_adata[tr_adata.obs.cell_type.isin(cell_type_filter_list)]

    sample_idx = np.random.choice(np.arange(len(tr_adata.var_names)), size=4096, replace=False)
    tr_adata_filtered = tr_adata_filtered[:, sample_idx]

    for obs_idx in tqdm(range(0, tr_adata_filtered.obs.shape[0], batch_size)):
        obs_indices = [obs_idx + x for x in range(batch_size)]
        try:
            tokens_dict, context_indices = ctx.generate_tokens_from_adata(tr_adata_filtered, obs_index=obs_indices, query_var_names=[],
                                                                        metadata_prompt_masks_dict=metadata_prompt_dict)
        except:
            # incomplete batch, stop iterating
            break

        query_cell_type_idx = context_indices['query_cell_type']
        query_disease_idx = context_indices['query_disease']

        with torch.inference_mode():
            hidden_states = ctx.get_activations_from_tokens(tokens_dict, to_cpu=True)
            hidden_states_ncd = hidden_states[-1]

        cell_type_embeddings.append(hidden_states_ncd[:, query_cell_type_idx, :])
        disease_embeddings.append(hidden_states_ncd[:, query_disease_idx, :])

        cell_type_labels.append(tr_adata_filtered.obs.iloc[obs_idx].cell_type)
        disease_labels.append(tr_adata_filtered.obs.iloc[obs_idx].disease)
        susp_type_labels.append(tr_adata_filtered.obs.iloc[obs_idx].suspension_type)
        assay_labels.append(tr_adata_filtered.obs.iloc[obs_idx].assay)

        # if len(cell_type_embeddings) >= math.ceil(1000/batch_size):
        #     break

cell_type_embeddings = torch.cat(cell_type_embeddings, dim=0).squeeze()
disease_embeddings = torch.cat(disease_embeddings, dim=0).squeeze()


  5%|█████▌                                                                                                        | 1/20 [00:34<11:00, 34.78s/it]


KeyboardInterrupt: 

In [None]:
cell_type_embeddings.shape
disease_embeddings.shape

NameError: name 'cell_type_embeddings' is not defined