In [5]:
import sys
import seaborn as sns
import pandas as pd 
import numpy as np
from itertools import combinations
from scipy.spatial.distance import squareform, pdist
from scipy.spatial.distance import cdist
import matplotlib.pyplot as plt
import torch
import anndata as an
import scanpy as sc
import os
import gc
from importlib import reload

from datasets import Dataset, load_from_disk
from datasets import load_dataset
from geneformer import EmbExtractor
import geneformer as gtu

# classifer tools
import xgboost
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split, KFold, cross_val_score
from sklearn.metrics import accuracy_score, classification_report
from sklearn.metrics import confusion_matrix

# local imports
sys.path.insert(0, '../../scripts/')
import geneformer_utils as gtu

sns.set_style('white')
torch.cuda.empty_cache()

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
### Parameters v
###

# cells to filter on
initial_cell_type = 'Fibroblast'

# list of genes to perturb with at the front of the list

# can set this to be all of the cells in the df
num_initial_cells = 10

model_path = "/scratch/indikar_root/indikar1/shared_data/geneformer/fine_tune/240715_geneformer_cellClassifier_no_induced/ksplit1/"

token_data_path = "/scratch/indikar_root/indikar1/shared_data/geneformer/resources/token_mapping.csv"

data_path = "/scratch/indikar_root/indikar1/shared_data/geneformer/fine_tune/hsc.dataset"

###
### Parameters ^

In [7]:
from itertools import combinations
def ten_choose_five():
# Define the length of sublists
    gene_list = [
        'GATA2', 
        'GFI1B', 
        'FOS', 
        'STAT5A',
        'REL',
        'FOSB',
        'IKZF1',
        'RUNX3',
        'MEF2C',
        'ETV6',
    ]

    len_sublist = 5
    
    # Generate all combinations of the specified length
    sublists = list(combinations(gene_list, len_sublist))
    
    # Create the DataFrame
    df = pd.DataFrame({
        'recipe_iteration': range(1, len(sublists) + 1),
        'recipe_list': [list(sublist) for sublist in sublists]
    })
    return df

# Print the DataFrame
df.head()

Unnamed: 0,recipe_iteration,recipe_list
0,1,"[GATA2, GFI1B, FOS, STAT5A, REL]"
1,2,"[GATA2, GFI1B, FOS, STAT5A, FOSB]"
2,3,"[GATA2, GFI1B, FOS, STAT5A, IKZF1]"
3,4,"[GATA2, GFI1B, FOS, STAT5A, RUNX3]"
4,5,"[GATA2, GFI1B, FOS, STAT5A, MEF2C]"


# Load In Things

In [3]:
### Preliminaries v
###

if torch.cuda.is_available(): 
    print("CUDA is available! Devices: ", torch.cuda.device_count()) 
    print("Current CUDA device: ", torch.cuda.current_device()) 
    print("Device name: ", torch.cuda.get_device_name(torch.cuda.current_device())) 
else: print("CUDA is not available")

model = gtu.load_model(model_path)
print('model loaded!')

token_df = pd.read_csv(token_data_path)
token_df.head()

###
### Preliminaries ^


CUDA is available! Devices:  1
Current CUDA device:  0
Device name:  NVIDIA A100 80GB PCIe MIG 3g.40gb


Some weights of BertForMaskedLM were not initialized from the model checkpoint at /scratch/indikar_root/indikar1/shared_data/geneformer/fine_tune/240715_geneformer_cellClassifier_no_induced/ksplit1/ and are newly initialized: ['cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


model loaded!


Unnamed: 0,gene_id,token_id,gene_name,nonzero_median,gene_version,gene_biotype,Chromosome,Start,End,scenic_tf
0,<pad>,0,,,,,,,,False
1,<mask>,1,,,,,,,,False
2,ENSG00000000003,2,TSPAN6,2.001186,15.0,protein_coding,X,100627107.0,100639991.0,False
3,ENSG00000000005,3,TNMD,3.228213,6.0,protein_coding,X,100584935.0,100599885.0,False
4,ENSG00000000419,4,DPM1,2.218874,14.0,protein_coding,20,50934866.0,50959140.0,False


# Convert to the right format

In [4]:
### Format raw data, print messages to check v
###

# Load from pre-trained data
raw_data = load_from_disk(data_path)

# Convert to DataFrame for filtering
df = raw_data.to_pandas()
print("\nOriginal Dataset:")
print(f"  - Number of samples: {df.shape[0]:,}")
print(f"  - Number of columns: {df.shape[1]:,}")



# Filtering
df = df[df['standardized_cell_type'] == initial_cell_type]

# sampling 


df = df.sample(num_initial_cells)
df = df.reset_index(drop=True)

# add a cell id
df['cell_id'] = [f"cell_{i+1}" for i in range(len(df))]

print("\nFiltered Dataset:")
print(f"  - Number of samples: {df.shape[0]:,}")   # Nicer formatting with commas
print(f"  - Number of columns: {df.shape[1]:,}")

# Value counts with sorting
print("\nCell Type Distribution (Filtered):")
print(df['standardized_cell_type'].value_counts().sort_index())  # Sort for readability

# Convert back to Dataset
data = Dataset.from_pandas(df)
print(f"\nDataset converted back: {data}")

###
### Format raw data, print messages to check ^


Original Dataset:
  - Number of samples: 214,715
  - Number of columns: 8

Filtered Dataset:
  - Number of samples: 10
  - Number of columns: 9

Cell Type Distribution (Filtered):
standardized_cell_type
Fibroblast    10
Name: count, dtype: int64

Dataset converted back: Dataset({
    features: ['input_ids', 'cell_type', 'dataset', 'length', 'ignore', 'standardized_cell_type', 'broad_type', '__index_level_0__', 'cell_id'],
    num_rows: 10
})


# Perturbations

## prelims

In [8]:
# removed since we are no longer looping over combos

# # compute all possible combinations of 5 TFs,
# ##################### this, right now, is also a parameter #####################################
# n_tf = 5
# ###############################################################################################
# inputs = list(combinations(genes['token_id'], n_tf))
# print(f'Number of recipes: {len(inputs)}')

# # used eventually for making the ;-separated recipe lists of actual TFs (instead of just tokens)
# def map_tfs(tokens):
#     return list(map(tf_map.get, tokens))

# print(inputs[0])
# print(map_tfs(inputs[0]))

In [11]:
def add_perturbations_to_cell(cell_tokens, perturbation_tokens):


    original_length = len(cell_tokens)

    # Remove existing perturbation tokens from the cell
    cell_tokens = [token for token in cell_tokens if token not in perturbation_tokens]

    # Add perturbations, then slice or pad to match original length
    final_tokens = (perturbation_tokens + cell_tokens)[:original_length]  # Slice if too long
    final_tokens += [0] * (original_length - len(final_tokens))            # Pad if too short

    return final_tokens


    """
    Modifies a list of cell tokens by adding perturbation tokens and padding.

    Args:
        cell_tokens (list): A list of integers representing gene tokens.
        perturbation_tokens (list): A list of integers representing perturbation tokens.

    Returns:
        list: A new list of tokens with perturbations added, existing perturbations removed,
             and truncated/padded to the original length.
    """


## Running, for all sublists

### Get a df of the initial cells ( i have a pen...)

In [12]:
# Filter out just the fibroblasts (initial cells....)

fb_df = df[df['standardized_cell_type'] == 'Fibroblast'].reset_index(drop=True)
fb_data = Dataset.from_pandas(fb_df)

reload(gtu)
torch.cuda.empty_cache()
fb_embs = gtu.extract_embedding_in_mem(
    model, 
    fb_data, 
    layer_to_quant=-1,
    forward_batch_size=100,
)
print(f"{fb_embs.shape=}")

# translate into an anndata object and plot
fb_adata = gtu.embedding_to_adata(fb_embs)
fb_adata.obs = fb_df.copy()
fb_adata.obs.head()




100%|██████████| 1/1 [00:01<00:00,  1.27s/it]

fb_embs.shape=(10, 512)





Unnamed: 0,input_ids,cell_type,dataset,length,ignore,standardized_cell_type,broad_type,__index_level_0__,cell_id
0,"[9009, 12119, 16916, 3878, 1404, 303, 16166, 1...",fibroblast,TS_Vasculature,2048,fibroblast,Fibroblast,fibroblast,112775,cell_1
1,"[16916, 9009, 15960, 4665, 16876, 15987, 3540,...",Fibroblasts,TS_Fat,2048,Fibroblasts,Fibroblast,fibroblast,163941,cell_2
2,"[16916, 16876, 1950, 9009, 3992, 11699, 18367,...",Fibroblasts,TS_Fat,2048,Fibroblasts,Fibroblast,fibroblast,152201,cell_3
3,"[9009, 18367, 16916, 19437, 61, 454, 7301, 110...",Fibroblasts,TS_Fat,2048,Fibroblasts,Fibroblast,fibroblast,166616,cell_4
4,"[3151, 9190, 4923, 3326, 14898, 10945, 3878, 5...",fibroblast,TS_Vasculature,2048,fibroblast,Fibroblast,fibroblast,124215,cell_5


### Get a df of the target cells (i have an apple...)

In [21]:
hsc_df = df[df['standardized_cell_type'] == 'HSC'].reset_index(drop=True)
hsc_data = Dataset.from_pandas(hsc_df)

reload(gtu)
torch.cuda.empty_cache()
hsc_embs = gtu.extract_embedding_in_mem(
    model, 
    hsc_data, 
    layer_to_quant=-1,
    forward_batch_size=100,
)
print(f"{hsc_embs.shape=}")

# translate into an anndata object and plot
hsc_adata = gtu.embedding_to_adata(hsc_embs)
hsc_adata.obs = hsc_df.copy()

####################################### hrrmmmm, as above, same problem, yoda thinks  ####################
# sc.tl.pca(hsc_adata, n_comps=25)
# sc.pp.neighbors(hsc_adata, n_neighbors=200)
# sc.tl.umap(hsc_adata, 
#            min_dist=0.75,
#           )

# hsc_adata

###################################### comment this whole section out, yoda will  #########################

100%|██████████| 1/1 [00:00<00:00,  4.88it/s]

hsc_embs.shape=(5, 512)





### pen pineapple apple pen

In [None]:
# since we are getting rid of the hsc part of this dataframe entirely, there's no use for raw_cells or reprogramming_df. added 'recipe' and 'type' columns to fb_df to 
# eventually make combining easier

In [14]:

#why did we set this again like chill out
# sample size refers to the number of cells to perturb. this can be cleaned up; we didn't need the first filtering but do need this one. requires that sample_size <= size fb_df, which is at most sample_size as prev set to size of df. SLOPPY 
# sample_size = 5

# raw_cells = fb_df.sample(sample_size).reset_index(drop=True)
# print(f"{raw_cells.shape=}")
# raw_cells['recipe'] = 'raw'
# raw_cells['type'] = 'initial'

# hsc_sample = hsc_df.sample(sample_size).reset_index(drop=True)
# hsc_sample['recipe'] = 'hsc'
# hsc_sample['type'] = 'target'

# reprogramming_df = [
#     raw_cells,
#     hsc_sample,
# ]

# to facil
fb_df['recipe'] = 'raw'  # as opposed to having a speciofic ;-separated recipe list 
fb_df['type'] = 'initial' # this dataframe no longer includes 'target'
# initialize reprogramming_df
reprogramming_df = [
    fb_df
]


In [15]:

# for i, tfs in enumerate(inputs):
    
#     if i % 25 == 0:
#         print(f"Pertubation {i}/{len(inputs)}...")
    
#     # make the dataframe easily useable
#     perturb = raw_cells.copy()
#     recipe = ";".join(map_tfs(tfs))
#     perturb['recipe'] = recipe
#     perturb['type'] = 'reprogrammed'
    
#     # do the actual perturbation
#     perturb['input_ids'] = perturb['input_ids'].apply(lambda x: add_perturbations_to_cell(x, list(tfs)))
    
#     # store the updated data
#     reprogramming_df.append(perturb)
    
# reprogramming_df = pd.concat(reprogramming_df)
# reprogramming_df = reprogramming_df.reset_index(drop=True)
# print(f"{reprogramming_df.shape=}")
# reprogramming_df.sample(10)




# perturb = fb_df.copy()
# recipe = ";".join(gene_list)
# perturb['recipe'] = recipe
# perturb['type'] = 'reprogrammed'
# perturb['input_ids'] = perturb['input_ids'].apply(lambda x: add_perturbations_to_cell(x, gene_list))

# reprogramming_df.append(perturb)
# reprogramming_df = pd.concat(reprogramming_df)
# reprogramming_df = reprogramming_df.reset_index(drop=True)
# print(f"{reprogramming_df.shape=}")
# reprogramming_df.sample(10)



# wrong, reprogramming_df is a list of dataframes
# perturb = 
# recipe = ";".join(gene_list)
# perturb['recipe'] = recipe
# perturb['type'] = 'reprogrammed'
# perturb['input_ids'] = perturb['input_ids'].apply(lambda x: add_perturbations_to_cell(x, gene_list))

# # Concatenate the existing DataFrame with the new perturb DataFrame
# reprogramming_df = pd.concat([reprogramming_df, perturb], ignore_index=True)

# print(f"{reprogramming_df.shape=}")
# reprogramming_df


### Uses token_df to translate from gene_list to tokens_list v
###

# Get a df of the genes we are perturbing with
genes = token_df[token_df['gene_name'].isin(gene_list)]

tf_map = dict(zip(genes['gene_name'].values, genes['token_id'].values))

# Create tokens_list by looking up each gene_name in the tf_map
tokens_list = [tf_map.get(gene_name, gene_name) for gene_name in gene_list]

###
### Uses token_df to translate from gene_list to tokens_list ^



perturb = fb_df.copy()
recipe = ";".join(gene_list)
perturb['recipe'] = recipe
perturb['type'] = 'reprogrammed'
perturb['input_ids'] = perturb['input_ids'].apply(lambda x: add_perturbations_to_cell(x, tokens_list))

reprogramming_df.append(perturb)

reprogramming_df = pd.concat(reprogramming_df, ignore_index=True)

print(f"{reprogramming_df.shape=}")
reprogramming_df.sample(10)



reprogramming_df.shape=(20, 11)


Unnamed: 0,input_ids,cell_type,dataset,length,ignore,standardized_cell_type,broad_type,__index_level_0__,cell_id,recipe,type
14,"[14409, 11599, 12698, 5806, 10804, 5675, 15641...",fibroblast,TS_Vasculature,2048,fibroblast,Fibroblast,fibroblast,124215,cell_5,GATA2;GFI1B;FOS;STAT5A;REL;FOSB;IKZF1;RUNX3;ME...,reprogrammed
5,"[3438, 19437, 9190, 5941, 17126, 15987, 3540, ...",Myofibroblasts,TS_Fat,2048,Myofibroblasts,Fibroblast,fibroblast,166129,cell_6,raw,initial
3,"[9009, 18367, 16916, 19437, 61, 454, 7301, 110...",Fibroblasts,TS_Fat,2048,Fibroblasts,Fibroblast,fibroblast,166616,cell_4,raw,initial
17,"[14409, 11599, 12698, 5806, 10804, 5675, 15641...",Fibroblasts,TS_Fat,2048,Fibroblasts,Fibroblast,fibroblast,164914,cell_8,GATA2;GFI1B;FOS;STAT5A;REL;FOSB;IKZF1;RUNX3;ME...,reprogrammed
2,"[16916, 16876, 1950, 9009, 3992, 11699, 18367,...",Fibroblasts,TS_Fat,2048,Fibroblasts,Fibroblast,fibroblast,152201,cell_3,raw,initial
12,"[14409, 11599, 12698, 5806, 10804, 5675, 15641...",Fibroblasts,TS_Fat,2048,Fibroblasts,Fibroblast,fibroblast,152201,cell_3,GATA2;GFI1B;FOS;STAT5A;REL;FOSB;IKZF1;RUNX3;ME...,reprogrammed
11,"[14409, 11599, 12698, 5806, 10804, 5675, 15641...",Fibroblasts,TS_Fat,2048,Fibroblasts,Fibroblast,fibroblast,163941,cell_2,GATA2;GFI1B;FOS;STAT5A;REL;FOSB;IKZF1;RUNX3;ME...,reprogrammed
16,"[14409, 11599, 12698, 5806, 10804, 5675, 15641...",Fibroblasts,TS_Fat,2048,Fibroblasts,Fibroblast,fibroblast,163321,cell_7,GATA2;GFI1B;FOS;STAT5A;REL;FOSB;IKZF1;RUNX3;ME...,reprogrammed
13,"[14409, 11599, 12698, 5806, 10804, 5675, 15641...",Fibroblasts,TS_Fat,2048,Fibroblasts,Fibroblast,fibroblast,166616,cell_4,GATA2;GFI1B;FOS;STAT5A;REL;FOSB;IKZF1;RUNX3;ME...,reprogrammed
1,"[16916, 9009, 15960, 4665, 16876, 15987, 3540,...",Fibroblasts,TS_Fat,2048,Fibroblasts,Fibroblast,fibroblast,163941,cell_2,raw,initial


## For just the specified list

In [None]:
#### In order to get this to run, right now, just set n_tf = number of elements of gene_list.

In [None]:
# reload(gtu)
# torch.cuda.empty_cache()

# reprogramming_data = Dataset.from_pandas(reprogramming_df)

# reprogramming_embs = gtu.extract_embedding_in_mem(
#     model, 
#     reprogramming_data, 
#     layer_to_quant=-1,
#     forward_batch_size=100,
# )
# print(f"{reprogramming_embs.shape=}")

# # translate into an anndata object and plot
# reprogramming_adata = gtu.embedding_to_adata(reprogramming_embs)
# reprogramming_adata.obs = reprogramming_df.copy()

# # sc.tl.pca(reprogramming_adata, n_comps=25)
# # sc.pp.neighbors(reprogramming_adata, n_neighbors=200)
# # sc.tl.umap(reprogramming_adata, 
# #            min_dist=0.75,
# #           )

# reprogramming_adata.obs.head()

# Get Distance info

In [None]:
initial = reprogramming_adata.obs[reprogramming_adata.obs['type'] == 'initial']
target = reprogramming_adata.obs[reprogramming_adata.obs['type'] == 'target']
repro = reprogramming_adata.obs[reprogramming_adata.obs['type'] == 'reprogrammed']

# precompute all distances
metric = 'cosine'
D = squareform(pdist(reprogramming_adata.X, metric=metric))
print(f"{D.shape=}")

inital_to_target = D[initial.index, target.index].mean()
print(f"{inital_to_target=:.4f}")

result = []

for i, (recipe, group) in enumerate(repro.groupby('recipe')):

    # might want to 
    if i % 25 == 0:
        print(f"Recipe {i}/{len(inputs)}...")
    
    # compute group to intial
    recipe_to_initial = D[group.index, initial.index].mean() # average over all cells
    
    # compute group to target
    recipe_to_target = D[group.index, target.index].mean() # average over all cells
    
    row = {
        'recipe' : recipe,
        'recipe_to_initial' : recipe_to_initial,
        'recipe_to_target' : recipe_to_target,
        'recipe_diff' : inital_to_target - recipe_to_target,
    }
    result.append(row)
    
result = pd.DataFrame(result)
result.head(10)

In [11]:
## Testing string naming
from datetime import datetime

    # Get the current date and time
now = datetime.now()
    
    # Format the date and time
date_time_str = now.strftime("%Y-%m-%d_%H-%M-%S")


jobNumber = 10
filename = f"{date_time_str}_job_number_{jobNumber}"
filepath = "/home/oliven/test_trash" + filename + ".h5ad"
reprogramming_adata.write(filepath)

print(filepath)

NameError: name 'reprogramming_adata' is not defined