In [1]:
import torch
from geneformer import InSilicoPerturber
from geneformer import InSilicoPerturberStats
from geneformer import EmbExtractor

  from .autonotebook import tqdm as notebook_tqdm


### In silico perturbation in deletion mode to determine genes whose deletion in the start state (FB) significantly shifts the embedding towards the goal state (HSC)

In [2]:
# Free cached, unused GPU memory before loading the model.
torch.cuda.empty_cache()

# First obtain the start, goal, and alt state embedding positions.
# This step is kept separate from perturb_data to avoid repeating
# calculations when parallelizing perturb_data.
cell_states_to_model = {
    "state_key": "cell_type",   # dataset attribute that holds the state labels
    "start_state": "FB",        # state the perturbed cells start in
    "goal_state": "HSC",        # state we want perturbations to shift towards
    "alt_states": ["iHSC"],     # additional state(s) to compare against
}

# The checkpoint below is loaded as model_type="Pretrained", which is not a
# classifier, so num_classes is 0 (it is only meaningful for classifier models).
embex = EmbExtractor(model_type="Pretrained",
                     num_classes=0,
                     max_ncells=1000,
                     emb_layer=0,
                     summary_stat="exact_mean",
                     forward_batch_size=20,
                     nproc=16)


# Locations shared by the later cells: model checkpoint, tokenized input
# dataset, and the directory that receives all outputs.
model = "/nfs/turbo/umms-indikar/shared/projects/HSC/data/geneformer/geneformer-12L-30M/"
data_path = "/scratch/indikar_root/indikar1/cstansbu/HSC/geneformer_inputs/iHSC.dataset"
outpath = "/scratch/indikar_root/indikar1/cstansbu/geneformer"

# NOTE(review): the output prefix "preturb" looks like a typo for "perturb";
# it is left unchanged because the later cells pass the same prefix.
state_embs_dict = embex.get_state_embs(cell_states_to_model,
                                       model,
                                       data_path,
                                       outpath,
                                       "preturb")

print('done')

100%|██████████| 50/50 [00:41<00:00,  1.20it/s]
100%|██████████| 50/50 [00:43<00:00,  1.16it/s]
100%|██████████| 50/50 [00:41<00:00,  1.20it/s]

done





In [3]:
# break

In [4]:
# Configure the in silico perturbation: delete every gene ("all"), one at a
# time (combos=0, no anchor gene), and measure the shift of the mean-pooled
# cell embedding relative to the state embeddings in state_embs_dict.
# model_type="Pretrained" is not a classifier, so num_classes is 0
# (it is only meaningful for classifier models).
isp = InSilicoPerturber(perturb_type="delete",
                        perturb_rank_shift=None,
                        genes_to_perturb="all",
                        combos=0,
                        anchor_gene=None,
                        model_type="Pretrained",
                        num_classes=0,
                        emb_mode="cell",
                        cell_emb_style="mean_pool",
                        cell_states_to_model=cell_states_to_model,
                        state_embs_dict=state_embs_dict,
                        max_ncells=2000,
                        emb_layer=0,
                        forward_batch_size=20,
                        nproc=16)

In [5]:
# break

In [None]:
# Run the in silico perturbation over the input dataset; the intermediate
# files it produces are written to `outpath` using the "preturb" prefix.
isp.perturb_data(
    model,
    data_path,
    outpath,
    "preturb",
)

  0%|          | 0/2000 [00:00<?, ?it/s]
  0%|          | 0/1 [00:00<?, ?it/s][A
                                     [A
Map (num_proc=16):   0%|          | 0/2048 [00:00<?, ? examples/s][A
Map (num_proc=16):   4%|▎         | 74/2048 [00:00<00:06, 292.35 examples/s][A
Map (num_proc=16):  10%|▉         | 203/2048 [00:00<00:03, 604.52 examples/s][A
Map (num_proc=16):  19%|█▉        | 384/2048 [00:00<00:01, 930.30 examples/s][A
Map (num_proc=16):  29%|██▊       | 585/2048 [00:00<00:01, 1053.25 examples/s][A
Map (num_proc=16):  38%|███▊      | 768/2048 [00:00<00:01, 1201.79 examples/s][A
Map (num_proc=16):  47%|████▋     | 970/2048 [00:00<00:00, 1202.47 examples/s][A
Map (num_proc=16):  56%|█████▋    | 1152/2048 [00:01<00:00, 1304.38 examples/s][A
Map (num_proc=16):  66%|██████▌   | 1352/2048 [00:01<00:00, 1271.45 examples/s][A
Map (num_proc=16):  75%|███████▌  | 1536/2048 [00:01<00:00, 1294.64 examples/s][A
Map (num_proc=16):  88%|████████▊ | 1792/2048 [00:01<00:00, 1411.40 ex

In [None]:
# Statistics pass for the perturbation run: score each perturbed gene by how
# it shifts cells with respect to the states in cell_states_to_model.
# The gene/combination settings mirror those given to InSilicoPerturber.
ispstats = InSilicoPerturberStats(
    mode="goal_state_shift",
    genes_perturbed="all",
    combos=0,
    anchor_gene=None,
    cell_states_to_model=cell_states_to_model,
)

In [None]:
# Extracts data from the intermediate files and processes stats into the final .csv.
# The first argument must be the directory holding the intermediate files that
# perturb_data wrote (outpath), not the tokenized input dataset (data_path).
ispstats.get_stats(outpath,    # directory containing the intermediate perturbation files
                   None,       # no null-distribution directory is used here
                   outpath,    # directory that receives the final stats .csv
                   "preturb")