# 13_diffPerturbGene_QC

In [1]:
######## Load modules ########
from __future__ import annotations #default now for name.error issue
import os
## ensure model trains on one GPU
os.environ["CUDA_VISIBLE_DEVICES"] = "0" 
os.environ["WORLD_SIZE"] = "1"
import pickle
from datetime import datetime
import numpy as np
import random
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
SEED = 1234
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
import scipy
from tqdm import tqdm

# Cell2Sentence imports
import cell2sentence as cs
# from cell2sentence.utils import benchmark_expression_conversion, reconstruct_expression_from_cell_sentence
from cell2sentence import utils
from cell2sentence.tasks import embed_cells, predict_cell_types_of_data
from cell2sentence.prompt_formatter import get_cell_sentence_str, PromptFormatter #for custom prompt

# Hugging Face
from transformers import TrainingArguments, AutoModelForCausalLM
from datasets import Dataset # Arrow

# Single-cell libraries
import scanpy as sc
import anndata as ad
from collections import Counter, defaultdict #count table

sc.set_figure_params(dpi=300, color_map="viridis_r", facecolor="white", )
sc.settings.verbosity = 3  # verbosity: errors (0), warnings (1), info (2), hints (3)
sc.logging.print_header()

# Define device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

  from .autonotebook import tqdm as notebook_tqdm


scanpy==1.9.8 anndata==0.9.2 umap==0.5.7 numpy==1.24.4 scipy==1.10.1 pandas==2.0.3 scikit-learn==1.3.2 statsmodels==0.14.1 pynndescent==0.5.13
cpu


<br>

`PerturbationPromptFormatter` 

1. control cells = number of ‘non-target’ (11742); unique perturbations is found by finding the `key` length (2393)

2. When creating the control-perturbed pairs -> 245670 pairs (created by every perturb samples within perturb type, e.g.:
    ```
    pert_to_indices = {
      "BRCA1": [3, 9, 10, 25], #3 samples
      "TP53":  [4, 5], #2 samples
      "MYC":   [7, 8, 12]
    }
    ```

3. train-val-test total number = total number of pairs (245670)
 

In [3]:
######## Re-load Perturbation Data ########
DATA_PATH = "/ix/ccdg/storage3/til177/ParkLab/Project_C2S_scFM/code/tutorials/data_PerturbSeq/GSE264667_jurkat_processed.h5ad"
adata = ad.read_h5ad(DATA_PATH)
print(adata)             # check dimension
# print(adata.var_names)   # will be used to create the __cell sentences__
# print(adata.obs.columns) # check colnames (for next step)
print(adata.X.max())     # check max value (log10 transformation expects a maximum value somewhere around 3 or 4)

target_gene_counter = Counter(adata.obs['target_gene'])
print(f"{len(target_gene_counter)} unique perturbations")
print(target_gene_counter.most_common(20)) #('non-targeting', 11742)!


######## AnnData -> Arrow ########
# We'll keep all relevant columns for our new task
adata_obs_cols_to_keep = ['batch_var','cell_type','target_gene','gene_id','mitopercent','UMI_count']

# Create Arrow dataset and vocabulary
arrow_ds, vocabulary = cs.CSData.adata_to_arrow(
    adata=adata, 
    random_state=SEED, 
    sentence_delimiter=' ',
    label_col_names=adata_obs_cols_to_keep
)
print(arrow_ds)

# # Check single-cell info
# sample_idx = 0
# print(arrow_ds[sample_idx])
# ##   Check cell sentence length
# print(len(arrow_ds[sample_idx]["cell_sentence"].split(" ")))  # Cell 0 has 4016 nonzero expressed genes
# ##   Check feature info
# print(type(vocabulary))
# print(len(vocabulary))
print(list(vocabulary.items())[:10]) #fist 10, also contains the number of cells 'that gene' was expressed in.


######## Custom Prompt Formatting for Perturbation Prediction ########

#  The input provides the {control cell} and the {perturbation}, asking for the {perturbed result}.
custom_input_prompt_template = """Given the following cell sentence of {num_genes} expressed genes representing a cell's basal state, predict the cell sentence after applying the perturbation: {perturbation_name}.
Control cell sentence: {control_cell_sentence}.

Perturbed cell sentence:"""

# The answer is simply the target cell sentence.
answer_template = "{perturbed_cell_sentence}."


#### Create PerturbationPromptFormatter (format_hf_ds outputs: formatted HF Arrow DataSet) ####
class PerturbationPromptFormatter(PromptFormatter):
    def __init__(self,
        task_name,
        input_prompt,
        answer_template,
        top_k_genes, 
        perturbation_col='target_gene',
        control_label='non-targeting'
    ):
        """
        Initializes the custom prompt formatter.

        Args:
            task_name (str): The name for this task.
            input_prompt (str): The template for the model's input.
            answer_template (str): The template for the model's expected response.
            top_k_genes (int): The number of top genes to include in the cell sentence.
            perturbation_col (str): The column name in the dataset that contains perturbation info.
            control_label (str): The label used to identify control cells in the perturbation_col.
        """
        super().__init__()
        self.task_name = task_name
        self.input_prompt = input_prompt
        self.answer_template = answer_template
        self.top_k_genes = top_k_genes
        self.perturbation_col = perturbation_col
        self.control_label = control_label
        assert top_k_genes > 0, "'top_k_genes' must be an integer > 0"

    def format_hf_ds(self, hf_ds):
        """
        Custom formatting function for perturbation prediction. This function creates pairs of
        (control, perturbed) cells by sampling from a global pool of control cells.
        """
        model_inputs_list = []
        responses_list = []
        
        # 1. Separate all cells into a global control pool and a dict of perturbed cells
        control_indices = []
        pert_to_indices = defaultdict(list)

        print("Grouping cells by perturbation and identifying global controls...")
        for i, sample in enumerate(hf_ds):
            if sample[self.perturbation_col] == self.control_label:
                control_indices.append(i)
            else:
                pert_to_indices[sample[self.perturbation_col]].append(i)

            # For each cell (sample) in the dataset:
            # If it's a control cell (e.g., target_gene == 'non-targeting'): add its index to control_indices
            # If it's perturbed (e.g., target_gene == 'BRCA1'): add its index to the pert_to_indices dictionary under that perturbation name
        
        assert len(control_indices) > 0, "No control cells found. Cannot create pairs."
        print(f"Found {len(control_indices)} control cells.")
        print(f"Found {len(pert_to_indices)} unique perturbations.")

        # 2. Create prompt-response pairs by iterating through perturbed cells
        print("Creating control-perturbed pairs...")
        for pert_name, perturbed_indices in tqdm(pert_to_indices.items()):
            for perturbed_idx in perturbed_indices:
                # Pair each perturbed cell with a random control cell from the global pool
                control_idx = random.choice(control_indices)
                
                control_sample = hf_ds[control_idx]
                perturbed_sample = hf_ds[perturbed_idx]

                # Format control cell sentence
                control_sentence, num_genes_str = get_cell_sentence_str(#https://github.com/vandijklab/cell2sentence/blob/a6efaf079f98491d4723ced44b929936b94368aa/src/cell2sentence/prompt_formatter.py#L31
                    control_sample,
                    num_genes=self.top_k_genes  #list slicing would NOT give out-of-bounds error -> can use total-gene-length from adata
                                                #cap, but may be > actual nnz genes
                )
                # Compute the true number of genes actually used in the sentence
                num_genes_str = str(len(control_sentence.split(" ")))

                # Format perturbed cell sentence
                perturbed_sentence, _ = get_cell_sentence_str(
                    perturbed_sample,
                    num_genes=self.top_k_genes
                )
                
                #### Matches the template fstring ####
                # Format the model input string using the perturbation name
                model_input_str = self.input_prompt.format(
                    num_genes=num_genes_str,
                    perturbation_name=pert_name,
                    control_cell_sentence=control_sentence
                )
                # Format the response string
                response_str = self.answer_template.format(
                    perturbed_cell_sentence=perturbed_sentence
                )

                model_inputs_list.append(model_input_str)
                responses_list.append(response_str)

        # Create the final Hugging Face Dataset
        ds_split_dict = {
            "sample_type": [self.task_name] * len(model_inputs_list),
            "model_input": model_inputs_list,
            "response": responses_list,
        }
        ds = Dataset.from_dict(ds_split_dict)
        return ds

#### Test run to see the results (this will automatically done in `csmodel.fine_tune()`)
# Initiate the formatter
task_name = "perturbation_prediction"
prompt_formatter = PerturbationPromptFormatter(
    task_name=task_name,
    input_prompt=custom_input_prompt_template,
    answer_template=answer_template,
    top_k_genes=200 # Using top 200 genes for this example.
)
# Format the dataset
formatted_ds = prompt_formatter.format_hf_ds(arrow_ds)
print(formatted_ds)
# print(type(formatted_ds)) #datasets.arrow_dataset.Dataset
# Inspect a formatted sample
print("--- Formatted Sample ---")
print("#----Model input:----#")
print(formatted_ds[0]["model_input"], "\n")
print("#----Response:----#")
print(formatted_ds[0]["response"])

AnnData object with n_obs × n_vars = 257412 × 8882
    obs: 'batch_var', 'cell_type', 'target_gene', 'gene_id', 'mitopercent', 'UMI_count', 'n_genes', 'n_genes_by_counts', 'total_counts', 'total_counts_mt', 'pct_counts_mt'
    var: 'n_cells', 'mt', 'n_cells_by_counts', 'mean_counts', 'pct_dropout_by_counts', 'total_counts'
    uns: 'batch_var_colors', 'log1p', 'neighbors', 'pca', 'umap'
    obsm: 'X_pca', 'X_umap'
    varm: 'PCs'
    obsp: 'connectivities', 'distances'
3.3689852
2394 unique perturbations
[('non-targeting', 11742), ('TFAM', 2506), ('SLC1A5', 1697), ('GFM1', 1349), ('GTF3C4', 1266), ('PSMB5', 1218), ('MRPL36', 1167), ('PPP6C', 1151), ('NBPF12', 1120), ('MRPL35', 977), ('POGLUT3', 970), ('TARDBP', 816), ('MRPL34', 789), ('CCDC6', 776), ('BCAR1', 766), ('GTF2E2', 739), ('GAB2', 673), ('TRNT1', 656), ('HSD17B10', 648), ('THAP1', 623)]


100%|██████████| 257412/257412 [01:09<00:00, 3721.72it/s]


Dataset({
    features: ['cell_name', 'cell_sentence', 'batch_var', 'cell_type', 'target_gene', 'gene_id', 'mitopercent', 'UMI_count'],
    num_rows: 257412
})
[('LINC01409', 34288), ('LINC01128', 60189), ('NOC2L', 142586), ('HES4', 182061), ('ISG15', 227707), ('TNFRSF4', 22688), ('SDF4', 123657), ('B3GALT6', 94370), ('UBE2J2', 140413), ('ACAP3', 45948)]
Grouping cells by perturbation and identifying global controls...
Found 11742 control cells.
Found 2393 unique perturbations.
Creating control-perturbed pairs...


100%|██████████| 2393/2393 [00:43<00:00, 55.08it/s] 


Dataset({
    features: ['sample_type', 'model_input', 'response'],
    num_rows: 245670
})
--- Formatted Sample ---
#----Model input:----#
Given the following cell sentence of 200 expressed genes representing a cell's basal state, predict the cell sentence after applying the perturbation: NELFE.
Control cell sentence: TMSB4X MT-CO3 MT-CO2 MT-CO1 RPL13 RPS3A MT-ATP6 EEF1A1 PTMA RPL10 RPS4X RPS2 ACTB RPS6 RPS18 RPL29 RPL19 MT-ND4 RPL11 HSPD1 RPLP1 RPL6 MT-ND1 RPL37 RPS8 RPL15 RPS7 B2M MALAT1 RPS23 RPS27A RPL3 NCL RPS12 RPS3 HSP90AA1 HSP90AB1 RPS19 RPS24 RPL28 RPL18 RPL37A RPL32 HMGB1 MT-ND2 TUBA1B ACTG1 RPL9 RPLP0 MT-ND5 RPL41 RPL17 RPS16 YBX1 GAPDH RPL30 RPL13A RPS14 RACK1 RPL7 RPS5 MIF HNRNPA1 SOX4 RPS13 RPL18A RPS26 NPM1 H3F3A RPL21 SET RPL7A STMN1 TPT1 RPL14 RPL39 RPL24 UBE2S RPL8 H2AFZ RPL26 NACA RPL10A SNHG29 PFN1 RPL34 HINT1 CHCHD2 CENPF RPL5 RPS11 RPSA CDK6 H3F3B RPS21 HNRNPA3 MT-CYB RPS17 YWHAZ BTF3 SERF2 RPL35 HNRNPA2B1 CFL1 RPS15A RPL23A UBA52 TUBB LDHB HNRNPU RPS20 FTL FTH1 

In [4]:
######## Generating predictions with the Finetuned Model ########
final_ckpt_path = "/ix/ccdg/storage3/til177/ParkLab/Project_C2S_scFM/code/tutorials/data_PerturbSeq/finetunedModel_2025-12-09-16_44_46_testFinetune_perturbation_prediction/checkpoint-500"
save_dir        = "/ix/ccdg/storage3/til177/ParkLab/Project_C2S_scFM/code/tutorials/data_PerturbSeq"
print(final_ckpt_path)
print(save_dir)

#### Load the finetuned model & Save the best checkpoint as new CSModel ####
finetuned_model = cs.CSModel(
    model_name_or_path=final_ckpt_path, # Path is updated after finetuning
    save_dir=save_dir,
    save_name="reload_perturbation_predictor_finetuned_final"
)
print(finetuned_model.save_path)

# Loading the final finetuned model checkpoint into a regular Hugging Face AutoModelForCausalLM and moving it to GPU/CPU
final_model = AutoModelForCausalLM.from_pretrained(
    finetuned_model.save_path,
    cache_dir=os.path.join(save_dir, ".cache"), #where to store / reuse model files
    trust_remote_code=True
).to(device)
print(final_model)

# Load dataset split (done in finetune() function, saved to output directory)
ft_dir="/ix/ccdg/storage3/til177/ParkLab/Project_C2S_scFM/code/tutorials/data_PerturbSeq/finetunedModel_2025-12-09-16_44_46_testFinetune_perturbation_prediction/"
with open(os.path.join(ft_dir, 'data_split_indices_dict.pkl'), 'rb') as f:
    data_split_indices_dict = pickle.load(f)
print(data_split_indices_dict.keys())
print(len(data_split_indices_dict['train']))
print(len(data_split_indices_dict['val']))
print(len(data_split_indices_dict['test']))

/ix/ccdg/storage3/til177/ParkLab/Project_C2S_scFM/code/tutorials/data_PerturbSeq/finetunedModel_2025-12-09-16_44_46_testFinetune_perturbation_prediction/checkpoint-500
/ix/ccdg/storage3/til177/ParkLab/Project_C2S_scFM/code/tutorials/data_PerturbSeq
Using device: cpu
/ix/ccdg/storage3/til177/ParkLab/Project_C2S_scFM/code/tutorials/data_PerturbSeq/reload_perturbation_predictor_finetuned_final
GPTNeoXForCausalLM(
  (gpt_neox): GPTNeoXModel(
    (embed_in): Embedding(50304, 2048)
    (emb_dropout): Dropout(p=0.0, inplace=False)
    (layers): ModuleList(
      (0-15): 16 x GPTNeoXLayer(
        (input_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
        (post_attention_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
        (post_attention_dropout): Dropout(p=0.0, inplace=False)
        (post_mlp_dropout): Dropout(p=0.0, inplace=False)
        (attention): GPTNeoXSdpaAttention(
          (rotary_emb): GPTNeoXRotaryEmbedding()
          (query_key_val

## Check all unique perturbations (genes) in the __test set__

In [23]:
# Select all test samples
formatted_test_ds = formatted_ds.select(data_split_indices_dict['test'])
print(len(formatted_test_ds))
print(formatted_test_ds)

#### Find all the perturbations/genes ####
import re

unique_perturbs = []
for sample_idx, formatted_row in enumerate(formatted_test_ds):
    inference_prompt = formatted_test_ds[sample_idx]['model_input']
    match = re.search(r'perturbation:\s*(.*?)\.', inference_prompt) #only match once for everything between
    if match:
        result = match.group(1)
        unique_perturbs.append(result)
        
print("\nAll unique perturbs and occurence:")
print(Counter(unique_perturbs))

# Create the unique set
unique_perturbs = set(unique_perturbs)
print(len(unique_perturbs)) #2339

24567
Dataset({
    features: ['sample_type', 'model_input', 'response'],
    num_rows: 24567
})

All unique perturbs and occurence:
Counter({'TFAM': 260, 'SLC1A5': 170, 'GTF3C4': 143, 'GFM1': 126, 'NBPF12': 120, 'PSMB5': 119, 'MRPL36': 107, 'PPP6C': 107, 'TARDBP': 92, 'MRPL35': 92, 'BCAR1': 84, 'MRPL34': 83, 'POGLUT3': 80, 'CCDC6': 73, 'GAB2': 73, 'GTF2E2': 71, 'ZDHHC7': 66, 'HSD17B10': 62, 'THAP1': 60, 'CCDC78': 59, 'TRNT1': 59, 'TWF1': 56, 'C7orf26': 56, 'TMEM214': 54, 'ZNF236': 53, 'ZBTB17': 53, 'PPP1R37': 49, 'INTS14': 49, 'KRT17': 48, 'PPP2R1A': 46, 'FAM136A': 45, 'DDX55': 44, 'ANAPC15': 44, 'ZNHIT3': 44, 'EPS8L1': 44, 'CLOCK': 43, 'FBXO42': 43, 'EIF4B': 42, 'INTS13': 42, 'SHQ1': 41, 'TSR1': 41, 'ZNF718': 40, 'TRAPPC11': 39, 'TMEM242': 38, 'MIS18BP1': 38, 'TFDP1': 38, 'ADAT3': 38, 'HSCB': 37, 'SMARCB1': 37, 'GPS1': 37, 'PPP2CA': 37, 'DNM1': 37, 'IK': 36, 'PSTK': 36, 'ESPN': 36, 'EIF2B1': 35, 'SLC35G2': 34, 'JAZF1': 34, 'LAMTOR1': 33, 'PMPCB': 33, 'POLR1C': 33, 'MCM3': 33, 'TPT1':

- In total 2339 unique perturbations/genes in the test set (which makes sense when there are many duplicates before train-eval-test split

## Check if 'head' and 'tail' genes are in the test samples

>Check the genes in "well-studied “head” genes (e.g., TP53, EGFR, MKI67, ribosomal RPL/RPS; interferon ISG15/IFI6) acquire rich, stable embeddings, while rare or lineage-restricted “tail” genes (e.g., tuft-cell POU2F3, mTEC AIRE, hair-cell ATOH1".

In [50]:
# set to list
ls_unique_perturbs = list(unique_perturbs)
ls_unique_perturbs_upper = [gene.upper() for gene in ls_unique_perturbs]

# convert adata target gene column to all uppercase
ls_adata_target_gene = adata.obs['target_gene'].tolist()
ls_adata_target_gene_upper = [gene.upper() for gene in ls_adata_target_gene]


# create gene sets
head = ['TP53', 'EGFR', 'MKI67', 'RPL', 'RPS', 'ISG15', 'IFI6']
tail = ['POU2F3', 'AIRE', 'ATOH1']

for head_gene in head:
    print(f"(test samples)     : {head_gene} -> {head_gene.upper() in ls_unique_perturbs_upper}")
    print(f"(original all data): {head_gene} -> {head_gene.upper() in ls_adata_target_gene_upper}\n")
for tail_gene in tail:
    print(f"(test samples)     : {tail_gene} -> {tail_gene.upper() in ls_unique_perturbs_upper}")
    print(f"(original all data): {tail_gene} -> {tail_gene.upper() in ls_adata_target_gene_upper}\n")

(test samples)     : TP53 -> False
(original all data): TP53 -> False

(test samples)     : EGFR -> False
(original all data): EGFR -> False

(test samples)     : MKI67 -> False
(original all data): MKI67 -> False

(test samples)     : RPL -> False
(original all data): RPL -> False

(test samples)     : RPS -> False
(original all data): RPS -> False

(test samples)     : ISG15 -> False
(original all data): ISG15 -> False

(test samples)     : IFI6 -> False
(original all data): IFI6 -> False

(test samples)     : POU2F3 -> False
(original all data): POU2F3 -> False

(test samples)     : AIRE -> False
(original all data): AIRE -> False

(test samples)     : ATOH1 -> False
(original all data): ATOH1 -> False



In [46]:
# check if in the original data
'NELFE' in adata.obs['target_gene'].tolist()

True

## *Save the unique perturbs/genes in the test set

In [63]:
arr_unique_perturbs_testset = np.array(list(unique_perturbs))
# Save
np.save(os.path.join(ft_dir,"unique_perturbs_testset_script13.npy"), arr_unique_perturbs_testset)

# # Later, load
# arr_unique_perturbs_testset = np.load('/ix/ccdg/storage3/til177/ParkLab/Project_C2S_scFM/code/tutorials/data_PerturbSeq/finetunedModel_2025-12-09-16_44_46_testFinetune_perturbation_prediction/unique_perturbs_testset_script13.npy')
# print(arr_unique_perturbs_testset)