In [1]:
import sys
import seaborn as sns
import pandas as pd 
import numpy as np
from itertools import combinations
from scipy.spatial.distance import squareform, pdist
from scipy.spatial.distance import cdist
import matplotlib.pyplot as plt
import torch
import anndata as an
import scanpy as sc
import os
import gc
from importlib import reload

from datasets import Dataset, load_from_disk
from datasets import load_dataset
from geneformer import EmbExtractor
import geneformer as gtu

# classifier tools
import xgboost
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split, KFold, cross_val_score
from sklearn.metrics import accuracy_score, classification_report
from sklearn.metrics import confusion_matrix

# local imports
sys.path.insert(0, '../../scripts/')
import geneformer_utils as gtu

# Use seaborn's plain "white" style for every figure drawn in this notebook
sns.set_style('white')
# Release cached, currently unused GPU memory before the model is loaded further down
torch.cuda.empty_cache()

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
### Parameters v
###

# Cell type the dataset is filtered down to before any perturbation
initial_cell_type = 'Fibroblast'

# Genes to perturb with; their tokens are placed at the front of each cell's token list
gene_list = [
    'GATA2', 'GFI1B', 'FOS', 'STAT5A', 'REL',
    'FOSB', 'IKZF1', 'RUNX3', 'MEF2C', 'ETV6',
]

############### important! ###############
# removed sampling, now all cells
# num_initial_cells = 10
################  (1/2)  #################

# Tokenized cell dataset (read with load_from_disk)
data_path = "/scratch/indikar_root/indikar1/shared_data/geneformer/fine_tune/hsc.dataset"

# CSV that maps gene names to token ids
token_data_path = "/scratch/indikar_root/indikar1/shared_data/geneformer/resources/token_mapping.csv"

# Directory of the model checkpoint to load
model_path = "/scratch/indikar_root/indikar1/shared_data/geneformer/fine_tune/240715_geneformer_cellClassifier_no_induced/ksplit1/"

###
### Parameters ^
# Load In Things

In [3]:


# Uses token_df to translate from gene_list to tokens_list v
def get_tokens_list(gene_list, token_table=None):
    """Translate gene names into token ids, preserving the order of ``gene_list``.

    Parameters
    ----------
    gene_list : list
        Gene names to look up.
    token_table : pandas.DataFrame, optional
        Table with ``gene_name`` and ``token_id`` columns. When omitted, the
        notebook-level ``token_df`` is used, which is the original behaviour.

    Returns
    -------
    list
        One entry per element of ``gene_list``. A name that has no row in the
        table is passed through unchanged, so callers should check for leftover
        names if every entry must be a real token id.
    """
    if token_table is None:
        # Fall back to the table loaded in the preliminaries cell
        token_table = token_df

    # Restrict the table to the requested genes before building the lookup
    genes = token_table[token_table['gene_name'].isin(gene_list)]
    tf_map = dict(zip(genes['gene_name'].values, genes['token_id'].values))

    # Look names up one by one so the output follows the order of gene_list
    return [tf_map.get(gene_name, gene_name) for gene_name in gene_list]

def add_perturbations_to_cell(cell_tokens, perturbation_tokens, pad_token=0):
    """Put the perturbation tokens at the front of a cell, keeping its length.

    Parameters
    ----------
    cell_tokens : sequence
        Token ids of one cell (list, tuple or 1-D array). It is not modified.
    perturbation_tokens : iterable
        Tokens to place at the front, in the given order. It is not modified.
    pad_token : optional
        Value appended when the result is shorter than the input (default 0,
        the value the original implementation hard-coded).

    Returns
    -------
    list
        A new list with exactly ``len(cell_tokens)`` entries: the perturbation
        tokens first, then the cell's remaining tokens in their original
        order, truncated or padded to the original length.
    """
    original_length = len(cell_tokens)

    # Materialise once: the list keeps the order for the output, the set gives
    # constant-time membership tests while filtering the cell
    perturbation_tokens = list(perturbation_tokens)
    perturbation_set = set(perturbation_tokens)

    # Remove existing perturbation tokens from the cell
    remaining_tokens = [token for token in cell_tokens if token not in perturbation_set]

    # Add perturbations, then slice or pad to match original length
    final_tokens = (perturbation_tokens + remaining_tokens)[:original_length]  # Slice if too long
    final_tokens += [pad_token] * (original_length - len(final_tokens))        # Pad if too short

    return final_tokens



In [4]:
### Preliminaries v
###

# Report which GPU, if any, PyTorch can see
if not torch.cuda.is_available():
    print("CUDA is not available")
else:
    print("CUDA is available! Devices: ", torch.cuda.device_count())
    print("Current CUDA device: ", torch.cuda.current_device())
    print("Device name: ", torch.cuda.get_device_name(torch.cuda.current_device()))

# Load the model stored at model_path through the local helper module
model = gtu.load_model(model_path)
print('model loaded!')

# Gene-name / token-id lookup table; its first rows are shown as the cell output
token_df = pd.read_csv(token_data_path)
token_df.head()

###
### Preliminaries ^


Some weights of BertForMaskedLM were not initialized from the model checkpoint at /scratch/indikar_root/indikar1/shared_data/geneformer/fine_tune/240715_geneformer_cellClassifier_no_induced/ksplit1/ and are newly initialized: ['cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


CUDA is available! Devices:  1
Current CUDA device:  0
Device name:  NVIDIA A100 80GB PCIe MIG 3g.40gb
model loaded!


Unnamed: 0,gene_id,token_id,gene_name,nonzero_median,gene_version,gene_biotype,Chromosome,Start,End,scenic_tf
0,<pad>,0,,,,,,,,False
1,<mask>,1,,,,,,,,False
2,ENSG00000000003,2,TSPAN6,2.001186,15.0,protein_coding,X,100627107.0,100639991.0,False
3,ENSG00000000005,3,TNMD,3.228213,6.0,protein_coding,X,100584935.0,100599885.0,False
4,ENSG00000000419,4,DPM1,2.218874,14.0,protein_coding,20,50934866.0,50959140.0,False


# Convert to the right format

In [5]:
### Format raw data, print messages to check v
###

# Read the dataset from disk and view it as a DataFrame so it can be filtered
raw_data = load_from_disk(data_path)
df = raw_data.to_pandas()

print("\nOriginal Dataset:")
print(f"  - Number of samples: {df.shape[0]:,}")
print(f"  - Number of columns: {df.shape[1]:,}")

############### important! ###############
# sampling (REMOVED TO TEST ALL CELLS!)
# previously: fb_df = fb_df.sample(num_initial_cells)
################  (2/2)  #################

# Keep only the starting cell type, renumber the rows, then annotate every cell:
#   cell_id - running 1-based identifier
#   recipe  - 'raw' here; perturbed copies carry a ';'-separated gene list instead
#   type    - 'initial' (this table no longer contains 'target' cells)
is_initial_type = df['standardized_cell_type'] == initial_cell_type
fb_df = (
    df[is_initial_type]
    .reset_index(drop=True)
    .assign(
        cell_id=lambda frame: [f"cell_{i+1}" for i in range(len(frame))],
        recipe='raw',
        type='initial',
    )
)

print("\nFiltered Dataset:")
print(f"  - Number of samples: {fb_df.shape[0]:,}")
print(f"  - Number of columns: {fb_df.shape[1]:,}")

# Per-type counts, sorted by label for readability
print("\nCell Type Distribution (Filtered):")
print(fb_df['standardized_cell_type'].value_counts().sort_index())

# Back to a Dataset
fb_data = Dataset.from_pandas(fb_df)
print(f"\nDataset converted back: {fb_data}")

###
### Format raw data, print messages to check ^


Original Dataset:
  - Number of samples: 214,715
  - Number of columns: 8

Filtered Dataset:
  - Number of samples: 15,308
  - Number of columns: 11

Cell Type Distribution (Filtered):
standardized_cell_type
Fibroblast    15308
Name: count, dtype: int64

Dataset converted back: Dataset({
    features: ['input_ids', 'cell_type', 'dataset', 'length', 'ignore', 'standardized_cell_type', 'broad_type', '__index_level_0__', 'cell_id', 'recipe', 'type'],
    num_rows: 15308
})


In [6]:
# ### Get the embeddings into an Anndata object v
# ###

# reload(gtu)
# torch.cuda.empty_cache()

# fb_embs = gtu.extract_embedding_in_mem(
#     model, 
#     fb_data, 
#     layer_to_quant=-1,
#     forward_batch_size=100,
# )
# print(f"{fb_embs.shape=}")

# # translate into an anndata object
# fb_adata = gtu.embedding_to_adata(fb_embs)
# fb_adata.obs = fb_df.copy()
# fb_adata.obs.head()


# ###
# ### Get the embeddings into an Anndata object ^


## Running, for all sublists

### Get a df of the initial cells ( i have a pen...)

In [7]:

### Perform the perturbation v
###

# Translate the recipe genes to tokens once. The mapping does not depend on the
# cell, so resolving it inside the per-row lambda repeated the same table
# lookup for every cell.
perturbation_tokens = get_tokens_list(gene_list)

# Perturbed copy of every initial cell: same metadata, new recipe/type labels,
# and the perturbation tokens moved to the front of input_ids
perturb = fb_df.copy()
recipe = ";".join(gene_list)
perturb['recipe'] = recipe
perturb['type'] = 'reprogrammed'
perturb['input_ids'] = perturb['input_ids'].apply(
    lambda x: add_perturbations_to_cell(x, perturbation_tokens)
)

# Stack the initial cells first and their perturbed copies second, with a fresh 0..n-1 index
reprogramming_df = pd.concat([fb_df, perturb], ignore_index=True)

print(f"{reprogramming_df.shape=}")
reprogramming_df.sample(10)

###
### Perform the perturbation ^


reprogramming_df.shape=(30616, 11)


Unnamed: 0,input_ids,cell_type,dataset,length,ignore,standardized_cell_type,broad_type,__index_level_0__,cell_id,recipe,type
19500,"[14409, 11599, 12698, 5806, 10804, 5675, 15641...",fibroblast,TS_Vasculature,2048,fibroblast,Fibroblast,fibroblast,121244,cell_4193,GATA2;GFI1B;FOS;STAT5A;REL;FOSB;IKZF1;RUNX3;ME...,reprogrammed
12557,"[9009, 16916, 18367, 376, 14127, 454, 2592, 47...",Fibroblasts,TS_Fat,2048,Fibroblasts,Fibroblast,fibroblast,166062,cell_12558,raw,initial
27075,"[14409, 11599, 12698, 5806, 10804, 5675, 15641...",Fibroblasts,TS_Fat,2048,Fibroblasts,Fibroblast,fibroblast,164323,cell_11768,GATA2;GFI1B;FOS;STAT5A;REL;FOSB;IKZF1;RUNX3;ME...,reprogrammed
24693,"[14409, 11599, 12698, 5806, 10804, 5675, 15641...",Fibroblasts,TS_Fat,2048,Fibroblasts,Fibroblast,fibroblast,159284,cell_9386,GATA2;GFI1B;FOS;STAT5A;REL;FOSB;IKZF1;RUNX3;ME...,reprogrammed
4903,"[10945, 2649, 6252, 5675, 10292, 5080, 8578, 3...",fibroblast,TS_Vasculature,1372,fibroblast,Fibroblast,fibroblast,123172,cell_4904,raw,initial
19160,"[14409, 11599, 12698, 5806, 10804, 5675, 15641...",fibroblast,TS_Vasculature,2048,fibroblast,Fibroblast,fibroblast,120350,cell_3853,GATA2;GFI1B;FOS;STAT5A;REL;FOSB;IKZF1;RUNX3;ME...,reprogrammed
28400,"[14409, 11599, 12698, 5806, 10804, 5675, 15641...",Fibroblasts,TS_Fat,2048,Fibroblasts,Fibroblast,fibroblast,167173,cell_13093,GATA2;GFI1B;FOS;STAT5A;REL;FOSB;IKZF1;RUNX3;ME...,reprogrammed
16221,"[14409, 11599, 12698, 5806, 10804, 5675, 15641...",fibroblast,TS_Vasculature,2048,fibroblast,Fibroblast,fibroblast,112309,cell_914,GATA2;GFI1B;FOS;STAT5A;REL;FOSB;IKZF1;RUNX3;ME...,reprogrammed
6490,"[16916, 16876, 1732, 19437, 12198, 4923, 10973...",Fibroblasts,TS_Fat,2048,Fibroblasts,Fibroblast,fibroblast,153069,cell_6491,raw,initial
6588,"[16916, 877, 303, 8346, 9009, 9103, 1404, 1029...",Fibroblasts,TS_Fat,2048,Fibroblasts,Fibroblast,fibroblast,153280,cell_6589,raw,initial


In [9]:
# Re-import the local helper module so recent edits to it are picked up, and
# release cached GPU memory before the long embedding pass
reload(gtu)
torch.cuda.empty_cache()

# Wrap the combined (initial + reprogrammed) table as a Dataset for the helper below
reprogramming_data = Dataset.from_pandas(reprogramming_df)

# Compute the embeddings with the loaded model; the printed shape has one row per cell.
# NOTE(review): layer_to_quant=-1 presumably selects the last layer — confirm in geneformer_utils
reprogramming_embs = gtu.extract_embedding_in_mem(
    model, 
    reprogramming_data, 
    layer_to_quant=-1,
    forward_batch_size=100,
)
print(f"{reprogramming_embs.shape=}")

# translate into an anndata object and plot
# NOTE(review): assigning obs assumes the embedding rows are in the same order
# as the rows of reprogramming_df — confirm against extract_embedding_in_mem
reprogramming_adata = gtu.embedding_to_adata(reprogramming_embs)
reprogramming_adata.obs = reprogramming_df.copy()

# PCA / neighbours / UMAP are currently disabled:
# sc.tl.pca(reprogramming_adata, n_comps=25)
# sc.pp.neighbors(reprogramming_adata, n_neighbors=200)
# sc.tl.umap(reprogramming_adata, 
#            min_dist=0.75,
#           )

# Preview the per-cell metadata as the cell output
reprogramming_adata.obs.head()





100%|██████████| 307/307 [19:27<00:00,  3.80s/it]

reprogramming_embs.shape=(30616, 512)





Unnamed: 0,input_ids,cell_type,dataset,length,ignore,standardized_cell_type,broad_type,__index_level_0__,cell_id,recipe,type
0,"[16345, 9009, 13048, 6489, 10292, 9508, 12698,...",fibroblast,TS_Vasculature,2048,fibroblast,Fibroblast,fibroblast,109770,cell_1,raw,initial
1,"[12119, 9190, 16876, 3396, 8654, 5298, 4692, 1...",fibroblast,TS_Vasculature,2048,fibroblast,Fibroblast,fibroblast,109771,cell_2,raw,initial
2,"[3878, 9009, 4115, 1950, 376, 16281, 16916, 14...",fibroblast,TS_Vasculature,2048,fibroblast,Fibroblast,fibroblast,109774,cell_3,raw,initial
3,"[3878, 16916, 18367, 9009, 1950, 12119, 16876,...",fibroblast,TS_Vasculature,2048,fibroblast,Fibroblast,fibroblast,109776,cell_4,raw,initial
4,"[6196, 16916, 10920, 11828, 8133, 16166, 8163,...",fibroblast,TS_Vasculature,2048,fibroblast,Fibroblast,fibroblast,109777,cell_5,raw,initial


In [10]:
# Display the embedding matrix (one row per cell) as a quick sanity check
reprogramming_adata.X

array([[-0.02931164,  3.703356  ,  1.1786633 , ...,  0.5361397 ,
        -0.20584026, -0.69771415],
       [ 0.1409362 ,  2.17202   ,  1.195514  , ...,  0.5794728 ,
        -0.19384083, -0.4654613 ],
       [ 0.15692821,  3.1697617 ,  1.6777157 , ..., -0.2944637 ,
        -0.08776124, -0.47131673],
       ...,
       [ 0.01560115,  3.1300998 ,  1.5892096 , ...,  0.14690934,
        -0.6328564 , -0.67557067],
       [-0.0322585 ,  3.0563653 ,  1.1636658 , ...,  0.0259171 ,
        -0.39277261, -0.45292988],
       [ 0.15233386,  3.4132004 ,  1.2214968 , ...,  0.16492385,
        -0.50729525, -0.7311505 ]], dtype=float32)

In [None]:
## Testing string naming
from datetime import datetime

# Timestamp taken at execution time, formatted as YYYY-MM-DD_HH-MM-SS
now = datetime.now()
date_time_str = now.strftime("%Y-%m-%d_%H-%M-%S")

# Job identifier embedded in the file name
jobNumber = 10
filename = f"{date_time_str}_job_number_{jobNumber}"

# Join directory and file name with a path separator. Plain string
# concatenation produced '/home/oliven/test_trash<timestamp>...', i.e. a file
# next to the directory instead of inside it.
filepath = os.path.join('/home/oliven/test_trash', filename + '.h5ad')
reprogramming_adata.write(filepath)

print(filepath)

In [None]:
from datetime import datetime

# Requires reprogramming_adata (an AnnData object) from an earlier cell

# Job identifier that becomes part of the file name
jobNumber = 10

# Timestamp taken at execution time, formatted as YYYY-MM-DD_HH-MM-SS
now = datetime.now()
date_time_str = now.strftime("%Y-%m-%d_%H-%M-%S")

# File name layout: <timestamp>_job_number_<jobNumber>.h5ad
filename = f"{date_time_str}_job_number_{jobNumber}.h5ad"

# Destination inside the output directory
filepath = f'/home/oliven/test_trash/{filename}'

# Save the AnnData object, then echo where it was written
reprogramming_adata.write(filepath)
print(filepath)


In [None]:
from datetime import datetime

# reprogramming_adata must already exist as an AnnData object when this cell runs

# Job identifier that becomes part of the file name
jobNumber = 10

# Capture the current time once and render it as YYYY-MM-DD_HH-MM-SS
now = datetime.now()
date_time_str = now.strftime("%Y-%m-%d_%H-%M-%S")

# "_job_number_" sits between the timestamp and the job number
filename = f"{date_time_str}_job_number_{jobNumber}.h5ad"

# Full path of the output file
filepath = f'/home/oliven/test_trash/{filename}'

# Check and ensure all data is of correct type
def check_and_convert(data):
    """Return a copy of ``data`` in which every non-string column is cast to ``str``.

    Parameters
    ----------
    data : pandas.DataFrame or any other object
        Table to convert. Anything that is not a DataFrame is returned as is.

    Returns
    -------
    pandas.DataFrame or the original object
        A new DataFrame; the input is left untouched, so the caller keeps the
        original values unless it explicitly re-assigns the result.
    """
    if not isinstance(data, pd.DataFrame):
        return data

    converted = data.copy()
    for col in converted.columns:
        column = converted[col]
        if column.dtype == object:
            # Inspect the values themselves: an object column can hold lists,
            # arrays or NaN, and only counts as a string column when every
            # value is a str
            all_strings = column.map(lambda value: isinstance(value, str)).all()
        else:
            all_strings = pd.api.types.is_string_dtype(column)
        if not all_strings:
            converted[col] = column.astype(str)
    return converted

# Cast every non-string metadata column to str before writing.
# NOTE: this replaces obs/var on the in-memory object too, so numeric and list
# columns are strings from here on.
reprogramming_adata.obs = check_and_convert(reprogramming_adata.obs)
reprogramming_adata.var = check_and_convert(reprogramming_adata.var)

# Write the AnnData object to file
try:
    reprogramming_adata.write(filepath)
except Exception as e:
    # Report the failure, then re-raise: swallowing it let the cell finish and
    # print the path as if the file had been written
    print(f"Error occurred: {e}")
    raise
else:
    print(f"File successfully written to {filepath}")

print(filepath)


In [None]:
# Deliberate stop for "Run All": everything below only runs when executed by hand.
# An explicit exception states the intent; a top-level `break` stops the run
# only by being a SyntaxError.
raise RuntimeError("Intentional stop: run the cells below manually if needed.")

In [None]:
# Per-class views of the cell metadata
obs = reprogramming_adata.obs
initial = obs[obs['type'] == 'initial']
target = obs[obs['type'] == 'target']
repro = obs[obs['type'] == 'reprogrammed']

# Row positions of each class inside reprogramming_adata.X (and therefore D).
# Positions are used instead of obs index labels because the labels are not
# guaranteed to be integer row numbers.
initial_rows = np.flatnonzero((obs['type'] == 'initial').to_numpy())
target_rows = np.flatnonzero((obs['type'] == 'target').to_numpy())
repro_rows = np.flatnonzero((obs['type'] == 'reprogrammed').to_numpy())


def _mean_block_distance(dist, rows, cols):
    """Mean of ``dist`` over every (row, col) pair; NaN when either side is empty."""
    if len(rows) == 0 or len(cols) == 0:
        return float('nan')
    # np.ix_ selects the full rows x cols block; dist[rows, cols] would pair the
    # two index arrays element by element instead
    return dist[np.ix_(rows, cols)].mean()


# precompute all distances
# NOTE: D is a dense n_cells x n_cells matrix, so memory grows quadratically
metric = 'cosine'
D = squareform(pdist(reprogramming_adata.X, metric=metric))
print(f"{D.shape=}")

inital_to_target = _mean_block_distance(D, initial_rows, target_rows)
print(f"{inital_to_target=:.4f}")

# Reprogrammed row positions grouped by recipe (groups are sorted by recipe)
recipe_groups = pd.Series(repro_rows).groupby(obs['recipe'].to_numpy()[repro_rows])
n_recipes = recipe_groups.ngroups

result = []

for i, (recipe, rows) in enumerate(recipe_groups):

    # progress message every 25 recipes
    if i % 25 == 0:
        print(f"Recipe {i}/{n_recipes}...")

    recipe_rows = rows.to_numpy()

    # compute group to initial, averaged over all cell pairs
    recipe_to_initial = _mean_block_distance(D, recipe_rows, initial_rows)

    # compute group to target, averaged over all cell pairs
    recipe_to_target = _mean_block_distance(D, recipe_rows, target_rows)

    row = {
        'recipe' : recipe,
        'recipe_to_initial' : recipe_to_initial,
        'recipe_to_target' : recipe_to_target,
        'recipe_diff' : inital_to_target - recipe_to_target,
    }
    result.append(row)

result = pd.DataFrame(result)
result.head(10)