In [1]:
import torch
import pandas as pd
from geneformer import InSilicoPerturber
from geneformer import InSilicoPerturberStats
from geneformer import EmbExtractor
from geneformer import in_silico_perturber_stats as stats

  from .autonotebook import tqdm as notebook_tqdm


### in silico perturbation in deletion mode to determine genes whose deletion in the dilated cardiomyopathy (dcm) state significantly shifts the embedding towards non-failing (nf) state

In [2]:
torch.cuda.empty_cache()

# first obtain start, goal, and alt embedding positions
# this function was changed to be separate from perturb_data
# to avoid repeating calcuations when parallelizing perturb_data
cell_states_to_model={"state_key": "cell_type", 
                      "start_state": "FB", 
                      "goal_state": "HSC", 
                      "alt_states": ["iHSC"]}

embex = EmbExtractor(model_type="Pretrained",
                     num_classes=3,
                     max_ncells=1000,
                     emb_layer=0,
                     summary_stat="exact_mean",
                     forward_batch_size=20,
                     nproc=1)


model = "/nfs/turbo/umms-indikar/shared/projects/geneformer/geneformer-12L-30M/"
data_path = "/scratch/indikar_root/indikar1/cstansbu/HSC/geneformer_inputs/iHSC.dataset"
outpath = "/scratch/indikar_root/indikar1/cstansbu/geneformer"

state_embs_dict = embex.get_state_embs(cell_states_to_model,
                                       model,
                                       data_path,
                                       outpath,
                                       "preturb")

print('done')

FileNotFoundError: Directory /scratch/indikar_root/indikar1/cstansbu/HSC/geneformer_inputs/iHSC.dataset not found

In [3]:
# break

In [4]:
genes = [
    'ENSG00000162924', # GATA2 
    'ENSG00000179348', # GFI1B
    'ENSG00000165702', # FOS
    'ENSG00000170345', # STAT5A
    'ENSG00000126561', # REL
]

isp = InSilicoPerturber(perturb_type="overexpress",
                        perturb_rank_shift=None,
                        genes_to_perturb=genes,
                        combos=0, # individually, or in pairs `1'
                        anchor_gene=None,
                        model_type="Pretrained",
                        num_classes=0, # pretrained model
                        emb_mode="cell",
                        cell_emb_style="mean_pool",
                        cell_states_to_model=cell_states_to_model,
                        state_embs_dict=state_embs_dict,
                        max_ncells=1000,
                        emb_layer=-1,
                        forward_batch_size=10,
                        nproc=1)

In [5]:
# break

In [6]:
# outputs intermediate files from in silico perturbation
isp.perturb_data(model,
                 data_path,
                 outpath,
                 "preturb")

Map: 100%|██████████| 1000/1000 [00:18<00:00, 54.23 examples/s]
Flattening the indices: 100%|██████████| 1000/1000 [00:00<00:00, 53639.72 examples/s]
Map: 100%|██████████| 1000/1000 [00:01<00:00, 869.96 examples/s]
  0%|          | 0/100 [00:00<?, ?it/s]
  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:00<00:00,  1.60it/s][A
                                             [A
  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:00<00:00,  1.64it/s][A
  1%|          | 1/100 [00:01<02:05,  1.27s/it]A
  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:00<00:00,  1.60it/s][A
                                             [A
  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:00<00:00,  1.63it/s][A
  2%|▏         | 2/100 [00:02<02:03,  1.26s/it]A
  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:00<00:00,  1.60it/s][A
                                             [A
  0%|          | 0/1 [00:00<?, ?it/s][A
100%|████

In [7]:
# break

In [8]:
token_dictionary_file = "/nfs/turbo/umms-indikar/shared/projects/geneformer/token_dictionary.pkl"
gene_name_id_dictionary_file = "/nfs/turbo/umms-indikar/shared/projects/geneformer/geneformer/gene_name_id_dict.pkl"


ispstats = InSilicoPerturberStats(mode="goal_state_shift",
                                  genes_perturbed=genes,
                                  combos=0,
                                  anchor_gene=None,
                                  token_dictionary_file=token_dictionary_file,
                                  gene_name_id_dictionary_file=gene_name_id_dictionary_file,
                                  cell_states_to_model=cell_states_to_model)

In [9]:
# extracts data from intermediate files and processes stats to output in final .csv

result_path = "/scratch/indikar_root/indikar1/cstansbu/geneformer/test"

perturb_path = "/scratch/indikar_root/indikar1/cstansbu/geneformer/"
ispstats.get_stats(perturb_path,
                   None,
                   outpath,
                   result_path)

100%|██████████| 4/4 [00:00<00:00, 445.79it/s]


In [10]:
df = pd.read_csv(f"{result_path}.csv")
df.head()


Unnamed: 0.1,Unnamed: 0,Shift_to_goal_end,Shift_to_alt_end_iHSC
0,0,0.0022,0.002093
