## in silico perturbation in deletion mode to determine genes whose deletion in the dilated cardiomyopathy (dcm) state significantly shifts the embedding towards non-failing (nf) state

## Imports

In [1]:
from geneformer import InSilicoPerturber
from geneformer import InSilicoPerturberStats
from geneformer import EmbExtractor

import os
import gc
import torch

  def twobit_to_dna(twobit: int, size: int) -> str:
  def dna_to_twobit(dna: str) -> int:
  def twobit_1hamming(twobit: int, size: int) -> List[int]:
2024-01-19 07:57:28.285395: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-01-19 07:57:28.285439: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-01-19 07:57:28.285457: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-01-19 07:57:28.292412: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following i

## To clear CUDA memory in PyTorch

In [None]:
gc.collect()
torch.cuda.empty_cache()
with torch.no_grad():
    torch.cuda.empty_cache()

## Path settings

In [None]:
path_to_model = "/home/domino/geneformer_workflow/Geneformer/fine_tuned_models/geneformer-6L-30M_CellClassifier_cardiomyopathies_220224"

path_to_input_data = "/home/domino/geneformer_workflow/input/data/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset"

path_to_output_directory = "/home/domino/geneformer_workflow/results/in_silico_perturbation/"
os.makedirs(path_to_output_directory, exist_ok = True)

output_prefix = "in_silico_perturbation_human_dcm_hcm_nf"


## Obtain start, goal, and alt embedding positions

In [3]:
# this function was changed to be separate from perturb_data
# to avoid repeating calcuations when parallelizing perturb_data
cell_states_to_model={"state_key": "disease", 
                      "start_state": "dcm", 
                      "goal_state": "nf", 
                      "alt_states": ["hcm"]}

filter_data_dict={"cell_type":["Cardiomyocyte1","Cardiomyocyte2","Cardiomyocyte3"]}

embex = EmbExtractor(model_type="CellClassifier",
                     num_classes=3,
                     filter_data=filter_data_dict,
                     max_ncells=50,#1000,
                     emb_layer=0,
                     summary_stat="exact_mean",
                     forward_batch_size=4,#256,
                     nproc=16)

state_embs_dict = embex.get_state_embs(cell_states_to_model,
                                       path_to_model,
                                       path_to_input_data,
                                       path_to_output_directory,
                                       output_prefix)

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

In [10]:
state_embs_dict

{'dcm': tensor([-8.3769e-01,  1.0012e+00, -3.1126e-02,  1.4559e-01, -3.9568e-01,
         -8.0872e-01,  9.1786e-01, -1.7633e-01,  3.3798e-01, -5.4110e-02,
         -7.7939e-02,  9.3652e-01, -5.4399e-01,  4.2775e-01, -9.8639e-01,
         -7.8825e-01,  2.0219e+00, -3.0223e-01,  3.3856e-01,  7.0369e-01,
         -8.8809e-01,  1.5726e-01,  1.2491e-01,  3.6452e-01,  4.3313e-01,
         -6.0097e-01, -5.8703e-01, -5.1814e-01, -1.6179e-01, -6.7949e-01,
         -6.5791e-01,  1.0861e-01, -7.7618e-01, -9.7280e-01, -2.1697e-01,
         -4.0138e-01,  7.0656e-01, -4.5728e-03,  6.8355e-02, -1.5676e-01,
         -4.5434e-01, -3.0257e-01,  4.9416e-01, -3.2642e-01,  1.0418e-02,
         -8.8666e-01, -7.5822e-02, -1.7595e+00,  3.5294e-01, -1.0471e+00,
          1.1386e+00, -2.7888e-01, -4.4763e-01, -6.3617e-01,  3.5590e-02,
         -3.0585e-01, -1.0241e-01, -3.7912e-01,  8.0731e-01, -8.1154e-01,
          6.2536e-01, -1.9132e-01,  5.5784e-01,  4.6323e-02, -1.5198e+00,
         -5.9301e-01,  1.1929e+

## Initialize in silico perturber<br>

  - Bellow settings will work on Amazon EC2 G5.x4large instances 
  - Large `forward_batch_size` will need stronger GPUs; otherwise, causes CUDA OutOfMemory Error. Changing the batch size may affect the training efficacy.
  - Large `max_ncell` will increase the script's run time  
  

In [5]:
isp = InSilicoPerturber(perturb_type="delete",
                        perturb_rank_shift=None,
                        genes_to_perturb="all",
                        combos=0,
                        anchor_gene=None,
                        model_type="CellClassifier",
                        num_classes=3,
                        emb_mode="cell",
                        cell_emb_style="mean_pool",
                        filter_data=filter_data_dict,
                        cell_states_to_model=cell_states_to_model,
                        state_embs_dict=state_embs_dict,
                        max_ncells=50,#2000,
                        emb_layer=0,
                        forward_batch_size=1,#400,
                        nproc=16)

## Outputs intermediate files from in silico perturbation

In [6]:
isp.perturb_data(path_to_model,
                 path_to_input_data,
                 path_to_output_directory,
                 output_prefix)

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Map (num_proc=16):   0%|          | 0/2048 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/2048 [00:00<?, ? examples/s]

  0%|          | 0/2048 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Map (num_proc=16):   0%|          | 0/2048 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/2048 [00:00<?, ? examples/s]

  0%|          | 0/2048 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Map (num_proc=16):   0%|          | 0/2048 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/2048 [00:00<?, ? examples/s]

  0%|          | 0/2048 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Map (num_proc=16):   0%|          | 0/2048 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/2048 [00:00<?, ? examples/s]

  0%|          | 0/2048 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Map (num_proc=16):   0%|          | 0/2048 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/2048 [00:00<?, ? examples/s]

  0%|          | 0/2048 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Map (num_proc=16):   0%|          | 0/2048 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/2048 [00:00<?, ? examples/s]

  0%|          | 0/2048 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Map (num_proc=16):   0%|          | 0/2048 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/2048 [00:00<?, ? examples/s]

  0%|          | 0/2048 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Map (num_proc=16):   0%|          | 0/2048 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/2048 [00:00<?, ? examples/s]

  0%|          | 0/2048 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Map (num_proc=16):   0%|          | 0/2048 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/2048 [00:00<?, ? examples/s]

  0%|          | 0/2048 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Map (num_proc=16):   0%|          | 0/2048 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/2048 [00:00<?, ? examples/s]

  0%|          | 0/2048 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Map (num_proc=16):   0%|          | 0/2048 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/2048 [00:00<?, ? examples/s]

  0%|          | 0/2048 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Map (num_proc=16):   0%|          | 0/2048 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/2048 [00:00<?, ? examples/s]

  0%|          | 0/2048 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Map (num_proc=16):   0%|          | 0/2048 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/2048 [00:00<?, ? examples/s]

  0%|          | 0/2048 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Map (num_proc=16):   0%|          | 0/2048 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/2048 [00:00<?, ? examples/s]

  0%|          | 0/2048 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Map (num_proc=16):   0%|          | 0/2048 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/2048 [00:00<?, ? examples/s]

  0%|          | 0/2048 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Map (num_proc=16):   0%|          | 0/2048 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/2048 [00:00<?, ? examples/s]

  0%|          | 0/2048 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Map (num_proc=16):   0%|          | 0/2048 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/2048 [00:00<?, ? examples/s]

  0%|          | 0/2048 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Map (num_proc=16):   0%|          | 0/2048 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/2048 [00:00<?, ? examples/s]

  0%|          | 0/2048 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Map (num_proc=16):   0%|          | 0/2048 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/2048 [00:00<?, ? examples/s]

  0%|          | 0/2048 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Map (num_proc=16):   0%|          | 0/2048 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/2048 [00:00<?, ? examples/s]

  0%|          | 0/2048 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Map (num_proc=16):   0%|          | 0/2048 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/2048 [00:00<?, ? examples/s]

  0%|          | 0/2048 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Map (num_proc=16):   0%|          | 0/2048 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/2048 [00:00<?, ? examples/s]

  0%|          | 0/2048 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Map (num_proc=16):   0%|          | 0/2048 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/2048 [00:00<?, ? examples/s]

  0%|          | 0/2048 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Map (num_proc=16):   0%|          | 0/2048 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/2048 [00:00<?, ? examples/s]

  0%|          | 0/2048 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Map (num_proc=16):   0%|          | 0/2048 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/2048 [00:00<?, ? examples/s]

  0%|          | 0/2048 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Map (num_proc=16):   0%|          | 0/2048 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/2048 [00:00<?, ? examples/s]

  0%|          | 0/2048 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Map (num_proc=16):   0%|          | 0/2048 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/2048 [00:00<?, ? examples/s]

  0%|          | 0/2048 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Map (num_proc=16):   0%|          | 0/2048 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/2048 [00:00<?, ? examples/s]

  0%|          | 0/2048 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Map (num_proc=16):   0%|          | 0/2048 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/2048 [00:00<?, ? examples/s]

  0%|          | 0/2048 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Map (num_proc=16):   0%|          | 0/2048 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/2048 [00:00<?, ? examples/s]

  0%|          | 0/2048 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Map (num_proc=16):   0%|          | 0/2048 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/2048 [00:00<?, ? examples/s]

  0%|          | 0/2048 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Map (num_proc=16):   0%|          | 0/2048 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/2048 [00:00<?, ? examples/s]

  0%|          | 0/2048 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Map (num_proc=16):   0%|          | 0/2048 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/2048 [00:00<?, ? examples/s]

  0%|          | 0/2048 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Map (num_proc=16):   0%|          | 0/1850 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/1850 [00:00<?, ? examples/s]

  0%|          | 0/1850 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Map (num_proc=16):   0%|          | 0/1789 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/1789 [00:00<?, ? examples/s]

  0%|          | 0/1789 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Map (num_proc=16):   0%|          | 0/1761 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/1761 [00:00<?, ? examples/s]

  0%|          | 0/1761 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Map (num_proc=16):   0%|          | 0/1470 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/1470 [00:00<?, ? examples/s]

  0%|          | 0/1470 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Map (num_proc=16):   0%|          | 0/1448 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/1448 [00:00<?, ? examples/s]

  0%|          | 0/1448 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Map (num_proc=16):   0%|          | 0/1432 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/1432 [00:00<?, ? examples/s]

  0%|          | 0/1432 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Map (num_proc=16):   0%|          | 0/1352 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/1352 [00:00<?, ? examples/s]

  0%|          | 0/1352 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Map (num_proc=16):   0%|          | 0/1134 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/1134 [00:00<?, ? examples/s]

  0%|          | 0/1134 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Map (num_proc=16):   0%|          | 0/976 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/976 [00:00<?, ? examples/s]

  0%|          | 0/976 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Map (num_proc=16):   0%|          | 0/841 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/841 [00:00<?, ? examples/s]

  0%|          | 0/841 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Map (num_proc=16):   0%|          | 0/757 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/757 [00:00<?, ? examples/s]

  0%|          | 0/757 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Map (num_proc=16):   0%|          | 0/716 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/716 [00:00<?, ? examples/s]

  0%|          | 0/716 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Map (num_proc=16):   0%|          | 0/689 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/689 [00:00<?, ? examples/s]

  0%|          | 0/689 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Map (num_proc=16):   0%|          | 0/670 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/670 [00:00<?, ? examples/s]

  0%|          | 0/670 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Map (num_proc=16):   0%|          | 0/643 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/643 [00:00<?, ? examples/s]

  0%|          | 0/643 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Map (num_proc=16):   0%|          | 0/609 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/609 [00:00<?, ? examples/s]

  0%|          | 0/609 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Map (num_proc=16):   0%|          | 0/404 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/404 [00:00<?, ? examples/s]

  0%|          | 0/404 [00:00<?, ?it/s]

## In silico perturber stats generator

In [7]:
ispstats = InSilicoPerturberStats(mode="goal_state_shift",
                                  genes_perturbed="all",
                                  combos=0,
                                  anchor_gene=None,
                                  cell_states_to_model={"disease":(["dcm"],["nf"],["hcm"])},
                                  token_dictionary_file='/home/domino/geneformer_workflow/Geneformer/geneformer/token_dictionary.pkl',
                                  gene_name_id_dictionary_file='/home/domino/geneformer_workflow/Geneformer/geneformer/gene_name_id_dict.pkl'
)

The single value dictionary for cell_states_to_model will be replaced with a dictionary with named keys for start, goal, and alternate states. Please specify state_key, start_state, goal_state, and alt_states in the cell_states_to_model dictionary for future use. For example, cell_states_to_model={'state_key': 'disease', 'start_state': 'dcm', 'goal_state': 'nf', 'alt_states': ['hcm', 'other1', 'other2']}


## Extracts data from intermediate files and processes stats to output in final .csv

In [8]:
ispstats.get_stats(input_data_directory = path_to_output_directory,
                   null_dist_data_directory = None,
                   output_directory = path_to_output_directory,
                   output_prefix = output_prefix)

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/10334 [00:00<?, ?it/s]

  0%|          | 0/10334 [00:00<?, ?it/s]

  cos_sims_full_df = pd.concat([cos_sims_full_df, cos_sims_df_i])
