In [1]:
import torch
import pandas as pd
from geneformer import InSilicoPerturber
from geneformer import InSilicoPerturberStats
from geneformer import EmbExtractor
from geneformer import in_silico_perturber_stats as stats

  from .autonotebook import tqdm as notebook_tqdm


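### In silico perturbation in deletion mode to determine genes whose deletion in the dilated cardiomyopathy (dcm) state significantly shifts the embedding towards the non-failing (nf) state

**Note:** the cells below are configured differently from this heading: they use `perturb_type="overexpress"` with start state `"FB"`, goal state `"HSC"`, and alternate state `"iHSC"`.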

In [2]:
# Release cached GPU memory before loading the model.
torch.cuda.empty_cache()

# First obtain the start, goal, and alt state embedding positions.
# This step is separate from perturb_data to avoid repeating
# calculations when parallelizing perturb_data.
cell_states_to_model = {"state_key": "cell_type",
                        "start_state": "FB",
                        "goal_state": "HSC",
                        "alt_states": ["iHSC"]}

# The extractor settings must agree with the InSilicoPerturber that consumes
# `state_embs_dict`: same model type, same num_classes, and the same embedding
# layer, otherwise the state embeddings and the perturbed-cell embeddings are
# taken from different layers and their comparison is not meaningful.
embex = EmbExtractor(model_type="Pretrained",
                     num_classes=0,  # pretrained model (matches the perturber cell)
                     max_ncells=50000,
                     emb_layer=-1,  # matches emb_layer of the perturber cell
                     summary_stat="exact_mean",
                     forward_batch_size=20,
                     nproc=16)

# NOTE(review): the recorded run of this cell failed with
# "OSError: Incorrect path_or_model_id" for `model` -- confirm that this
# directory exists and is readable from the machine running the kernel.
model = "/nfs/turbo/umms-indikar/shared/projects/HSC/data/geneformer/geneformer-12L-30M/"
data_path = "/scratch/indikar_root/indikar1/cstansbu/HSC/geneformer_inputs/iHSC.dataset"
outpath = "/scratch/indikar_root/indikar1/cstansbu/geneformer"

# Arguments: states to model, model directory, input dataset,
# output directory, output file prefix.
state_embs_dict = embex.get_state_embs(cell_states_to_model,
                                       model,
                                       data_path,
                                       outpath,
                                       "preturb")

print('done')

Filter (num_proc=16): 100%|██████████| 54347/54347 [00:15<00:00, 3603.81 examples/s]


OSError: Incorrect path_or_model_id: '/nfs/turbo/umms-indikar/shared/projects/HSC/data/geneformer/geneformer-12L-30M/'. Please provide either the path to a local folder or the repo_id of a model on the Hub.

In [None]:
# break

In [None]:
# Ensembl gene IDs of candidate genes of interest.
# NOTE: `genes` is not passed to the perturber below, which perturbs 'all' genes.
genes = [
    'ENSG00000162924',  # REL
    'ENSG00000179348',  # GATA2
    'ENSG00000165702',  # GFI1B
    'ENSG00000170345',  # FOS
    'ENSG00000126561',  # STAT5A
]

# Configure the in silico perturbation: overexpress every gene, one at a time,
# and measure shifts relative to the precomputed state embeddings.
isp = InSilicoPerturber(
    perturb_type="overexpress",
    perturb_rank_shift=None,
    genes_to_perturb='all',
    combos=0,  # 0: perturb genes individually; 1: perturb in pairs
    anchor_gene=None,
    model_type="Pretrained",
    num_classes=0,  # pretrained model
    emb_mode="cell",
    cell_emb_style="mean_pool",
    cell_states_to_model=cell_states_to_model,
    state_embs_dict=state_embs_dict,
    max_ncells=3000,
    emb_layer=-1,
    forward_batch_size=10,
    nproc=1,
)

In [None]:
# Deliberate stop point: halts "Run All" before the perturbation cells below.
# A bare `break` outside a loop is a SyntaxError, so raise an explicit,
# self-describing error instead.
raise RuntimeError("Intentional stop: run the perturbation cells below manually.")

In [None]:
# Run the in silico perturbation; this writes intermediate files to `outpath`
# using the "preturb" output prefix.
isp.perturb_data(
    model,
    data_path,
    outpath,
    "preturb",
)

In [None]:
# Statistics object for the perturbation results: evaluates the shift towards
# the goal state for all perturbed genes, using the same cell-state definition
# as the perturbation itself.
ispstats = InSilicoPerturberStats(
    mode="goal_state_shift",
    genes_perturbed='all',
    combos=0,
    anchor_gene=None,
    token_dictionary_file="/nfs/turbo/umms-indikar/shared/projects/HSC/data/geneformer/token_dictionary.pkl",
    gene_name_id_dictionary_file="/nfs/turbo/umms-indikar/shared/projects/HSC/data/geneformer/geneformer/gene_name_id_dict.pkl",
    cell_states_to_model=cell_states_to_model,
)

In [None]:
# Extract data from the intermediate perturbation files and process the
# statistics into the final .csv output.

# Directory holding the intermediate files written by the perturbation step.
perturb_path = "/scratch/indikar_root/indikar1/cstansbu/geneformer/"

# Output prefix for the statistics; the results are read back from
# f"{result_path}.csv" in the next cell.
result_path = "/scratch/indikar_root/indikar1/cstansbu/geneformer/test"

ispstats.get_stats(
    perturb_path,
    None,  # no null distribution directory is supplied
    outpath,
    result_path,
)

In [None]:
# Load the statistics table written by get_stats and preview the first rows.
df = pd.read_csv(result_path + ".csv")
df.head()
