In [5]:
import sys
import seaborn as sns
import pandas as pd 
import numpy as np
from itertools import combinations
from scipy.spatial.distance import squareform, pdist
from scipy.spatial.distance import cdist
import matplotlib.pyplot as plt
import torch
import anndata as an
import scanpy as sc
import os
import gc
from importlib import reload

from datasets import Dataset, load_from_disk
from datasets import load_dataset
from geneformer import EmbExtractor
import geneformer as gtu

# classifer tools
import xgboost
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split, KFold, cross_val_score
from sklearn.metrics import accuracy_score, classification_report
from sklearn.metrics import confusion_matrix

# local imports
sys.path.insert(0, '../../scripts/')
import geneformer_utils as gtu

sns.set_style('white')
torch.cuda.empty_cache()

In [None]:
# Parameters we will need: model path, token mapping path (if we want it to be arbitrary and not just use the local one), sample_size, gene_list 

# Load In Things

In [6]:
if torch.cuda.is_available(): 
    print("CUDA is available! Devices: ", torch.cuda.device_count()) 
    print("Current CUDA device: ", torch.cuda.current_device()) 
    print("Device name: ", torch.cuda.get_device_name(torch.cuda.current_device())) 
else: print("CUDA is not available")

CUDA is available! Devices:  1
Current CUDA device:  0
Device name:  NVIDIA A100 80GB PCIe MIG 3g.40gb


In [7]:
"""Load the model"""
# model_path = "/nfs/turbo/umms-indikar/shared/projects/geneformer/fine_tuned_models/geneformer-6L-30M_CellClassifier_cardiomyopathies_220224/"
# model_path = "/nfs/turbo/umms-indikar/shared/projects/geneformer/geneformer-12L-30M/"

##################################################  this is a parameter  ##################################################
model_path = "/scratch/indikar_root/indikar1/shared_data/geneformer/fine_tune/240715_geneformer_cellClassifier_no_induced/ksplit1/"
###########################################################################################################################
model = gtu.load_model(model_path)
print('loaded!')

Some weights of BertForMaskedLM were not initialized from the model checkpoint at /scratch/indikar_root/indikar1/shared_data/geneformer/fine_tune/240715_geneformer_cellClassifier_no_induced/ksplit1/ and are newly initialized: ['cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


loaded!


In [8]:
token_data_path = "/scratch/indikar_root/indikar1/shared_data/geneformer/resources/token_mapping.csv"
token_df = pd.read_csv(token_data_path)
token_df.head()

Unnamed: 0,gene_id,token_id,gene_name,nonzero_median,gene_version,gene_biotype,Chromosome,Start,End,scenic_tf
0,<pad>,0,,,,,,,,False
1,<mask>,1,,,,,,,,False
2,ENSG00000000003,2,TSPAN6,2.001186,15.0,protein_coding,X,100627107.0,100639991.0,False
3,ENSG00000000005,3,TNMD,3.228213,6.0,protein_coding,X,100584935.0,100599885.0,False
4,ENSG00000000419,4,DPM1,2.218874,14.0,protein_coding,20,50934866.0,50959140.0,False


# Convert to the right format

In [15]:
data_path = "/scratch/indikar_root/indikar1/shared_data/geneformer/fine_tune/hsc.dataset"

# Load from pre-trained data
raw_data = load_from_disk(data_path)

# Convert to DataFrame for filtering
df = raw_data.to_pandas()
print("\nOriginal Dataset:")
print(f"  - Number of samples: {df.shape[0]:,}")
print(f"  - Number of columns: {df.shape[1]:,}")

# Cell types to filter on
cell_types = ['HSC', 'Fibroblast']

# Filtering
df = df[df['standardized_cell_type'].isin(cell_types)]

# sampling 

###################   this is a parameter ################################
sample_size = 10
##########################################################################
df = df.sample(sample_size)
df = df.reset_index(drop=True)

# add a cell id
df['cell_id'] = [f"cell_{i+1}" for i in range(len(df))]

print("\nFiltered Dataset:")
print(f"  - Number of samples: {df.shape[0]:,}")   # Nicer formatting with commas
print(f"  - Number of columns: {df.shape[1]:,}")

# Value counts with sorting
print("\nCell Type Distribution (Filtered):")
print(df['standardized_cell_type'].value_counts().sort_index())  # Sort for readability

# Convert back to Dataset
data = Dataset.from_pandas(df)
print(f"\nDataset converted back: {data}")


Original Dataset:
  - Number of samples: 214,715
  - Number of columns: 8

Filtered Dataset:
  - Number of samples: 10
  - Number of columns: 9

Cell Type Distribution (Filtered):
standardized_cell_type
Fibroblast    5
HSC           5
Name: count, dtype: int64

Dataset converted back: Dataset({
    features: ['input_ids', 'cell_type', 'dataset', 'length', 'ignore', 'standardized_cell_type', 'broad_type', '__index_level_0__', 'cell_id'],
    num_rows: 10
})


# Perturbations

## prelims

In [11]:
###############################################    this is a parameter #############################
gene_list = [
    'GATA2', 
    'GFI1B', 
    'FOS', 
    'STAT5A',
    'REL',
    'FOSB',
    'IKZF1',
    'RUNX3',
    'MEF2C',
    'ETV6',
]
####################################################################################################
genes = token_df[token_df['gene_name'].isin(gene_list)]
tf_map = dict(zip(genes['token_id'].values, genes['gene_name'].values))

genes

Unnamed: 0,gene_id,token_id,gene_name,nonzero_median,gene_version,gene_biotype,Chromosome,Start,End,scenic_tf
404,ENSG00000020633,404,RUNX3,3.195369,19.0,protein_coding,1,24899510.0,24965121.0,True
1532,ENSG00000081189,1532,MEF2C,7.818396,16.0,protein_coding,5,88717116.0,88904257.0,True
5675,ENSG00000125740,5675,FOSB,5.344128,14.0,protein_coding,19,45467994.0,45475179.0,True
5806,ENSG00000126561,5806,STAT5A,2.177263,18.0,protein_coding,17,42287546.0,42311943.0,True
7725,ENSG00000139083,7725,ETV6,3.312123,11.0,protein_coding,12,11649673.0,11895377.0,True
10804,ENSG00000162924,10804,REL,3.891583,16.0,protein_coding,2,60881490.0,60931612.0,True
11599,ENSG00000165702,11599,GFI1B,2.131079,15.0,protein_coding,9,132943999.0,132991687.0,True
12698,ENSG00000170345,12698,FOS,16.001316,10.0,protein_coding,14,75278825.0,75282230.0,True
14409,ENSG00000179348,14409,GATA2,2.523616,13.0,protein_coding,3,128479426.0,128493201.0,True
15641,ENSG00000185811,15641,IKZF1,3.329544,21.0,protein_coding,7,50304067.0,50405101.0,True


In [12]:
# compute all possible combinations of 5 TFs,
##################### this, right now, is also a parameter #####################################
n_tf = 5
###############################################################################################
inputs = list(combinationsgenes['token_id'], n_tf))
print(f'Number of recipes: {len(inputs)}')

def map_tfs(tokens):
    return list(map(tf_map.get, tokens))

print(inputs[0])
print(map_tfs(inputs[0]))

Number of recipes: 252
(404, 1532, 5675, 5806, 7725)
['RUNX3', 'MEF2C', 'FOSB', 'STAT5A', 'ETV6']


In [13]:
def add_perturbations_to_cell(cell_tokens, perturbation_tokens):
    """
    Modifies a list of cell tokens by adding perturbation tokens and padding.

    Args:
        cell_tokens (list): A list of integers representing gene tokens.
        perturbation_tokens (list): A list of integers representing perturbation tokens.

    Returns:
        list: A new list of tokens with perturbations added, existing perturbations removed,
             and truncated/padded to the original length.
    """

    original_length = len(cell_tokens)

    # Remove existing perturbation tokens from the cell
    cell_tokens = [token for token in cell_tokens if token not in perturbation_tokens]

    # Add perturbations, then slice or pad to match original length
    final_tokens = (perturbation_tokens + cell_tokens)[:original_length]  # Slice if too long
    final_tokens += [0] * (original_length - len(final_tokens))            # Pad if too short

    return final_tokens



[16916 19437  3992   811  9009 16876  5357 12908  5950  2124]
[404, 1532, 5675, 5806, 7725, 16916, 19437, 3992, 811, 9009]


## Running, for all sublists

### Get a df of the initial cells ( i have a pen...)

In [19]:
# Filter out just the fibroblasts (initial cells....)

fb_df = df[df['standardized_cell_type'] == 'Fibroblast'].reset_index(drop=True)
fb_data = Dataset.from_pandas(fb_df)

reload(gtu)
torch.cuda.empty_cache()
fb_embs = gtu.extract_embedding_in_mem(
    model, 
    fb_data, 
    layer_to_quant=-1,
    forward_batch_size=100,
)
print(f"{fb_embs.shape=}")

# translate into an anndata object and plot
fb_adata = gtu.embedding_to_adata(fb_embs)
fb_adata.obs = fb_df.copy()
fb_adata.obs.head()

#this stuff doesn't work but eeeehehhhhhh who needs plots right

##############################    these only worn rn for specific values of n_comps and n_neighbors, which depend on the size of the fb_adata df. fix #######################
# sc.tl.pca(fb_adata, n_comps=25)
# sc.pp.neighbors(fb_adata, n_neighbors=200)
# sc.tl.umap(fb_adata, 
#            min_dist=0.75,
#           )


100%|██████████| 1/1 [00:00<00:00,  5.18it/s]


fb_embs.shape=(5, 512)


### Get a df of the target cells (i have an apple...)

In [21]:
hsc_df = df[df['standardized_cell_type'] == 'HSC'].reset_index(drop=True)
hsc_data = Dataset.from_pandas(hsc_df)

reload(gtu)
torch.cuda.empty_cache()
hsc_embs = gtu.extract_embedding_in_mem(
    model, 
    hsc_data, 
    layer_to_quant=-1,
    forward_batch_size=100,
)
print(f"{hsc_embs.shape=}")

# translate into an anndata object and plot
hsc_adata = gtu.embedding_to_adata(hsc_embs)
hsc_adata.obs = hsc_df.copy()

####################################### hrrmmmm, as above, same problem, yoda thinks  ####################
# sc.tl.pca(hsc_adata, n_comps=25)
# sc.pp.neighbors(hsc_adata, n_neighbors=200)
# sc.tl.umap(hsc_adata, 
#            min_dist=0.75,
#           )

# hsc_adata

###################################### comment this whole section out, yoda will  #########################

100%|██████████| 1/1 [00:00<00:00,  4.88it/s]

hsc_embs.shape=(5, 512)





### pen pineapple apple pen

In [23]:

#why did we set this again like chill out
# sample size refers to the number of cells to perturb. this can be cleaned up; we didn't need the first filtering but do need this one. requires that sample_size <= size fb_df, which is at most sample_size as prev set to size of df. SLOPPY 
sample_size = 5

raw_cells = fb_df.sample(sample_size).reset_index(drop=True)
print(f"{raw_cells.shape=}")
raw_cells['recipe'] = 'raw'
raw_cells['type'] = 'initial'

hsc_sample = hsc_df.sample(sample_size).reset_index(drop=True)
hsc_sample['recipe'] = 'hsc'
hsc_sample['type'] = 'target'

reprogramming_df = [
    raw_cells,
    hsc_sample,
]


raw_cells.shape=(5, 9)


In [25]:

for i, tfs in enumerate(inputs):
    
    if i % 25 == 0:
        print(f"Pertubation {i}/{len(inputs)}...")
    
    # make the dataframe easily useable
    perturb = raw_cells.copy()
    recipe = ";".join(map_tfs(tfs))
    perturb['recipe'] = recipe
    perturb['type'] = 'reprogrammed'
    
    # do the actual perturbation
    perturb['input_ids'] = perturb['input_ids'].apply(lambda x: add_perturbations_to_cell(x, list(tfs)))
    
    # store the updated data
    reprogramming_df.append(perturb)
    
reprogramming_df = pd.concat(reprogramming_df)
reprogramming_df = reprogramming_df.reset_index(drop=True)
print(f"{reprogramming_df.shape=}")
reprogramming_df.sample(10)

Pertubation 0/252...
Pertubation 25/252...
Pertubation 50/252...
Pertubation 75/252...
Pertubation 100/252...
Pertubation 125/252...
Pertubation 150/252...
Pertubation 175/252...
Pertubation 200/252...
Pertubation 225/252...
Pertubation 250/252...
reprogramming_df.shape=(2395, 11)


Unnamed: 0,input_ids,cell_type,dataset,length,ignore,standardized_cell_type,broad_type,__index_level_0__,cell_id,recipe,type
1000,"[5675, 5806, 7725, 10804, 14409, 1950, 16916, ...",fibroblast,TS_Vasculature,2048,fibroblast,Fibroblast,fibroblast,119476,cell_5,FOSB;STAT5A;ETV6;REL;GATA2,reprogrammed
233,"[404, 1532, 7725, 12698, 15641, 16916, 3878, 9...",fibroblast,TS_Vasculature,2048,fibroblast,Fibroblast,fibroblast,112156,cell_7,RUNX3;MEF2C;ETV6;FOS;IKZF1,reprogrammed
1818,"[1532, 5675, 5806, 11599, 14409, 16916, 3878, ...",fibroblast,TS_Vasculature,2048,fibroblast,Fibroblast,fibroblast,112156,cell_7,MEF2C;FOSB;STAT5A;GFI1B;GATA2,reprogrammed
1926,"[1532, 5675, 11599, 12698, 15641, 9009, 12119,...",fibroblast,TS_Vasculature,2048,fibroblast,Fibroblast,fibroblast,115663,cell_2,MEF2C;FOSB;GFI1B;FOS;IKZF1,reprogrammed
1563,"[404, 5675, 10804, 12698, 15641, 16916, 3878, ...",fibroblast,TS_Vasculature,2048,fibroblast,Fibroblast,fibroblast,112156,cell_7,RUNX3;FOSB;REL;FOS;IKZF1,reprogrammed
1028,"[5675, 5806, 7725, 12698, 14409, 16916, 3878, ...",fibroblast,TS_Vasculature,2048,fibroblast,Fibroblast,fibroblast,112156,cell_7,FOSB;STAT5A;ETV6;FOS;GATA2,reprogrammed
1208,"[404, 1532, 5675, 10804, 15641, 16916, 3878, 9...",fibroblast,TS_Vasculature,2048,fibroblast,Fibroblast,fibroblast,112156,cell_7,RUNX3;MEF2C;FOSB;REL;IKZF1,reprogrammed
1314,"[404, 1532, 5806, 14409, 15641, 16916, 454, 90...",Fibroblasts,TS_Fat,2048,Fibroblasts,Fibroblast,fibroblast,170147,cell_10,RUNX3;MEF2C;STAT5A;GATA2;IKZF1,reprogrammed
1292,"[404, 1532, 5806, 11599, 14409, 16916, 9009, 6...",Fibroblasts,TS_Fat,2048,Fibroblasts,Fibroblast,fibroblast,164975,cell_6,RUNX3;MEF2C;STAT5A;GFI1B;GATA2,reprogrammed
520,"[404, 5806, 10804, 11599, 14409, 1950, 16916, ...",fibroblast,TS_Vasculature,2048,fibroblast,Fibroblast,fibroblast,119476,cell_5,RUNX3;STAT5A;REL;GFI1B;GATA2,reprogrammed


## For just the specified list

In [None]:
#### In order to get this to run, right now, just set n_tf = number of elements of gene_list.

In [None]:
reload(gtu)
torch.cuda.empty_cache()

reprogramming_data = Dataset.from_pandas(reprogramming_df)

reprogramming_embs = gtu.extract_embedding_in_mem(
    model, 
    reprogramming_data, 
    layer_to_quant=-1,
    forward_batch_size=100,
)
print(f"{reprogramming_embs.shape=}")

# translate into an anndata object and plot
reprogramming_adata = gtu.embedding_to_adata(reprogramming_embs)
reprogramming_adata.obs = reprogramming_df.copy()

# sc.tl.pca(reprogramming_adata, n_comps=25)
# sc.pp.neighbors(reprogramming_adata, n_neighbors=200)
# sc.tl.umap(reprogramming_adata, 
#            min_dist=0.75,
#           )

reprogramming_adata.obs.head()

# Get Distance info

In [None]:
initial = reprogramming_adata.obs[reprogramming_adata.obs['type'] == 'initial']
target = reprogramming_adata.obs[reprogramming_adata.obs['type'] == 'target']
repro = reprogramming_adata.obs[reprogramming_adata.obs['type'] == 'reprogrammed']

# precompute all distances
metric = 'cosine'
D = squareform(pdist(reprogramming_adata.X, metric=metric))
print(f"{D.shape=}")

inital_to_target = D[initial.index, target.index].mean()
print(f"{inital_to_target=:.4f}")

result = []

for i, (recipe, group) in enumerate(repro.groupby('recipe')):

    # might want to 
    if i % 25 == 0:
        print(f"Recipe {i}/{len(inputs)}...")
    
    # compute group to intial
    recipe_to_initial = D[group.index, initial.index].mean() # average over all cells
    
    # compute group to target
    recipe_to_target = D[group.index, target.index].mean() # average over all cells
    
    row = {
        'recipe' : recipe,
        'recipe_to_initial' : recipe_to_initial,
        'recipe_to_target' : recipe_to_target,
        'recipe_diff' : inital_to_target - recipe_to_target,
    }
    result.append(row)
    
result = pd.DataFrame(result)
result.head(10)