In [None]:
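# Interactive SLURM session used to run this notebook (GPU partition, 1 GPU, 24 CPUs, 128 GB RAM):
#   srun -p gpu --gres=gpu:1 --cpus-per-task=24 --mem=128G --time=4200 --pty /bin/bash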
import sys
import os
sys.path.append(os.getcwd())
from geneformer import InSilicoPerturber
from geneformer import InSilicoPerturberStats
from geneformer import EmbExtractor

In [None]:
# ---- Run configuration ------------------------------------------------------
# Root storage directory (replace with your own); each run uses the
# sub-directory named after `output_prefix`.
storage_dir = '/mnt/vstor/***/***'

# Pretrained Geneformer checkpoint that is passed to the classifier as the
# starting model.
vanilla_model = "/home/***/Geneformer/gf-12L-95M-i4096"

# Sample processed in this run. This is one example out of four; only this
# name was changed between runs.
output_prefix = "Donor_AJKQ118_IEL_TCRab_CD8ab_TRM"

# For how the input data was prepared for Geneformer, see expanded_DEG.ipynb.

In [None]:
from geneformer import TranscriptomeTokenizer

# Keep both per-cell annotations through tokenization, under unchanged names.
custom_attrs = {attr: attr for attr in ("top10_or_not", "activation")}
tk = TranscriptomeTokenizer(custom_attrs, nproc=15)

# The .h5ad input is read from the run directory and the result is written
# back into that same directory with the prefix "tokenized".
sample_dir = f"{storage_dir}/{output_prefix}"
tk.tokenize_data(sample_dir, sample_dir, "tokenized", file_format="h5ad")

In [None]:
from geneformer import Classifier

# Every classifier input and output of this run lives in one directory
# (trailing slash kept so the resulting paths are unchanged).
clf_io_dir = f"{storage_dir}/{output_prefix}/"

# Cell-level classifier on the "top10_or_not" label, using all of its states,
# with a 60/20/20 train/valid/test split and a single cross-validation split.
classifier_settings = dict(
    classifier="cell",
    cell_state_dict={"state_key": "top10_or_not", "states": "all"},
    max_ncells=None,
    freeze_layers=6,
    num_crossval_splits=1,
    split_sizes={"train": 0.6, "valid": 0.2, "test": 0.2},
    forward_batch_size=150,
    nproc=47,
)
cc = Classifier(**classifier_settings)

# Label the tokenized dataset; the outputs are written under `output_prefix`
# and read back below (labeled train set and id->class dictionary).
cc.prepare_data(
    input_data_file=f"{clf_io_dir}tokenized.dataset",
    output_directory=clf_io_dir,
    output_prefix=output_prefix,
)

# Fine-tune starting from the pretrained checkpoint and evaluate.
all_metrics = cc.validate(
    model_directory=vanilla_model,
    prepared_input_data_file=f"{clf_io_dir}{output_prefix}_labeled_train.dataset",
    id_class_dict_file=f"{clf_io_dir}{output_prefix}_id_class_dict.pkl",
    output_directory=clf_io_dir,
    output_prefix=output_prefix,
    predict_eval=True,
)


In [None]:
# Directory of the fine-tuned cell classifier (split "ksplit1") that the
# following cells load as `model`.
# NOTE(review): some Geneformer versions prefix this directory name with a
# date stamp; confirm the path matches what cc.validate actually created.
finetuned_root = f"{storage_dir}/{output_prefix}/geneformer_cellClassifier_{output_prefix}"
model = f"{finetuned_root}/ksplit1/"

In [None]:
# Hyperparameter search: 30 trials on the labeled training set, with results
# kept apart from the main run in a "hyparam_test" sub-directory.
hyperopt_dir = f"{storage_dir}/{output_prefix}/hyparam_test"

# Nothing earlier in the notebook creates this nested directory, so create it
# up front instead of relying on the library to do so.
os.makedirs(hyperopt_dir, exist_ok=True)

# NOTE(review): the search starts from the already fine-tuned `model`, not
# from `vanilla_model` — confirm this is intended.
# NOTE(review): this rebinds `all_metrics` from the earlier validation cell;
# the value returned when n_hyperopt_trials > 0 may differ (possibly None).
all_metrics = cc.validate(
    model_directory=model,
    prepared_input_data_file=f"{storage_dir}/{output_prefix}/{output_prefix}_labeled_train.dataset",
    id_class_dict_file=f"{storage_dir}/{output_prefix}/{output_prefix}_id_class_dict.pkl",
    output_directory=hyperopt_dir,
    output_prefix=output_prefix,
    n_hyperopt_trials=30,
    predict_eval=True,
)

In [None]:
# Shared output location and file prefix for this cell.
run_dir = f"{storage_dir}/{output_prefix}/"
eval_prefix = f"{output_prefix}top10_or_not"

# ---- Cell embeddings from the fine-tuned classifier -------------------------
# Up to 1000 cells, labeled by "top10_or_not" for the plot.
extractor_settings = dict(
    model_type="CellClassifier",
    num_classes=2,
    max_ncells=1000,
    emb_layer=-1,
    emb_label=["top10_or_not"],
    labels_to_plot=["top10_or_not"],
    forward_batch_size=128,
    nproc=80,
)
embex = EmbExtractor(**extractor_settings)

embs = embex.extract_embs(
    model,
    f"{run_dir}tokenized.dataset",
    run_dir,
    "top10_or_not_embeddings_labeled",
)

embex.plot_embs(
    embs=embs,
    plot_style="heatmap",
    output_directory=run_dir,
    output_prefix="embeddings_heatmap",
)

# ---- Held-out evaluation of the saved classifier ----------------------------
all_metrics_test = cc.evaluate_saved_model(
    model_directory=model,
    id_class_dict_file=f"{run_dir}{output_prefix}_id_class_dict.pkl",
    test_data_file=f"{run_dir}{output_prefix}_labeled_test.dataset",
    output_directory=run_dir,
    output_prefix=eval_prefix,
)

# Confusion matrix of the test predictions, saved next to the metrics.
confusion_by_model = {"Geneformer": all_metrics_test["conf_matrix"]}
cc.plot_conf_mat(
    conf_mat_dict=confusion_by_model,
    output_directory=run_dir,
    output_prefix=eval_prefix,
)

In [None]:
# States for the in silico perturbation: move cells whose "top10_or_not" is
# "False" (start) towards "True" (goal).
# NOTE(review): the state values are strings; confirm they match how
# "top10_or_not" is stored in the tokenized dataset.
cell_states_to_model = {
    "state_key": "top10_or_not",
    "start_state": "False",
    "goal_state": "True",
}

# Extractor used only to compute the reference embedding of each state.
# emb_layer and emb_mode must match the InSilicoPerturber that consumes
# `state_embs_dict` (it is built with emb_layer=0 and emb_mode="cls").
# The original used emb_layer=-1 here, so perturbed-cell embeddings from one
# layer were compared against state embeddings from a different layer.
embex = EmbExtractor(
    model_type="CellClassifier",
    num_classes=2,
    emb_mode="cls",
    max_ncells=1000,
    emb_layer=0,
    # A summary statistic is what collapses the cells of a state into one
    # embedding. It is kept even though the author has no other use for it;
    # presumably get_state_embs requires it — verify against the docs.
    summary_stat="exact_mean",
    forward_batch_size=128,
    nproc=80,
)

# One embedding per modeled state, saved in the run directory as "state_emb".
state_embs_dict = embex.get_state_embs(
    cell_states_to_model,
    model,
    f"{storage_dir}/{output_prefix}/tokenized.dataset",
    f"{storage_dir}/{output_prefix}",
    "state_emb",
)

In [None]:
# In silico overexpression of every gene, one at a time (no combinations, no
# anchor gene), scored against the start/goal state embeddings from above.
perturber_settings = {
    "perturb_type": "overexpress",
    "genes_to_perturb": "all",
    "combos": 0,
    "anchor_gene": None,
    "model_type": "CellClassifier",
    "num_classes": 2,
    "emb_mode": "cls",
    "cell_states_to_model": cell_states_to_model,
    "state_embs_dict": state_embs_dict,
    "max_ncells": 1000,
    "emb_layer": 0,
    "forward_batch_size": 158,
    "nproc": 80,
}
isp = InSilicoPerturber(**perturber_settings)

# Results are written into the run directory with the prefix "T_expand".
perturb_dir = f"{storage_dir}/{output_prefix}"
isp.perturb_data(
    model,
    f"{perturb_dir}/tokenized.dataset",
    perturb_dir,
    "T_expand",
)

In [None]:
# Summarize the raw in silico perturbation results as goal-state-shift
# statistics, using the same gene set / combos / anchor settings as the
# perturbation itself.
ispstats = InSilicoPerturberStats(
    mode="goal_state_shift",
    genes_perturbed="all",
    combos=0,
    anchor_gene=None,
    cell_states_to_model=cell_states_to_model,
)

# perturb_data wrote its results to "{storage_dir}/{output_prefix}", so the
# stats must read from that directory. The original pointed at the bare
# `storage_dir`, where no perturbation output of this run exists.
# The summary is written per run as well; otherwise each run would write the
# same "T_expand" result into the shared root and overwrite the previous one.
isp_results_dir = f"{storage_dir}/{output_prefix}"
ispstats.get_stats(
    isp_results_dir,  # input: raw perturbation results
    None,             # no null-distribution data supplied
    isp_results_dir,  # output: summary statistics
    "T_expand",
)