### Setup

In [1]:
# Notebook setup
import os
import yaml
import pickle
os.chdir('../../')
import logging
import numpy as np
import pandas as pd
import scanpy as sc
logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s'
    )
%load_ext autoreload
%autoreload 0

# assign number of cpus to use as data loaders
n_cpus = 10
seed = 42

# define label in .obs to classify
cls_label = 'cls_label'
batch_key = 'dataset'



In [2]:
# Set working directory
work_dir = '/home/xlv0877/proj_home/dl/'
# Set output directory for training
data_dir = os.path.join(work_dir, 'data')
# Set name
name = 'model_set.h5ad'
# Set output dir
output_dir = '/home/xlv0877/proj_home/dl/data/merge/8a2a33dfa4d9a069/'
metaset_path = os.path.join(output_dir, 'perturb_metaset_w_prot_emb.h5ad')
emb_p = '/home/xlv0877/p32655/projects/perturbation_prediction/ExPert/resources/gene_embeddings/GenePT_emebdding_v2/GenePT_gene_protein_embedding_model_3_text.pickle'
essential_genes_p = '/home/xlv0877/p32655/projects/perturbation_prediction/ExPert/resources/gene_embeddings/essential_genes.tsv'
config_p = os.path.join(output_dir, 'config.yaml')
with open(config_p, 'r') as f:
    config = yaml.safe_load(f)
filtered_dir = os.path.join(config['cache_dir'], 'filtered')
# Read class embedding
with open(emb_p, 'rb') as file:
    cls_emb = pd.DataFrame(pickle.load(file)).T
# Load list of essential genes
essential_genes = pd.read_csv(essential_genes_p, sep='\t', low_memory=False)
# Subset embedding to essential genes
essential_genes = cls_emb.index.intersection(essential_genes.gene.str.upper())

### Load dataset

In [32]:
# Read dataset
adata = sc.read(metaset_path)

In [33]:
# Use broad cell typing if information is available
if 'celltype_broad' in adata.obs.columns:
    adata.obs['celltype'] = adata.obs['celltype_broad'].values
# define all labels to classify on
if 'perturbation_direction' in adata.obs.columns:
    logging.info('Using perturbation direction to classify')
    cls_labels = ['perturbation_direction', 'perturbation']
else:
    cls_labels = ['celltype', 'perturbation_type', 'perturbation']
# define label in .obs to classify
cls_label = 'cls_label'
batch_key = 'dataset'
# create classification label
adata.obs['cls_label'] = adata.obs[cls_labels].agg(';'.join, axis=1)
# add status label (control or perturbed)
adata.obs['is_ctrl'] = adata.obs['perturbation'] == 'control'

2025-09-05 12:24:51,111 - INFO - Using perturbation direction to classify


In [34]:
# Filter perturbations for essential genes
target_train_perturbations = essential_genes.intersection(adata.obs.perturbation.unique())

### Filter dataset

#### Split into train and test dataset

In [35]:
# Focus training on genome-wide studies and use rest for testing
d = pd.read_csv('resources/datasets/meta/training.csv')
training_datasets = d['publication index'].astype(str) + d['dataset index'].fillna('').apply(lambda x: f'_{x}' if x else '')
d = pd.read_csv('resources/datasets/meta/testing.csv')
testing_datasets = d['publication index'].astype(str) + d['dataset index'].fillna('').apply(lambda x: f'_{x}' if x else '')
# Split training and testing
train_adata = adata[adata.obs.dataset.isin(training_datasets)].copy()
test_adata = adata[~adata.obs.dataset.isin(training_datasets)].copy()
logging.info(f'Training adata: {train_adata.shape}')
logging.info(f'Testing adata: {test_adata.shape}')
del adata

2025-09-05 12:44:44,354 - INFO - Training adata: (819165, 5783)
2025-09-05 12:44:44,392 - INFO - Testing adata: (794560, 5783)


### Check out DEGs for datasets

In [6]:
min_degs = 2
signif_perturbations = {}
deg_file_suffix = '_deg_mask.csv'
for file in os.listdir(filtered_dir):
    if file.endswith(deg_file_suffix):
        bn = file.replace(deg_file_suffix, '')
        if bn in train_adata.obs.dataset.unique():
            logging.info(f'Extracting DEGs for {bn}')
            fn = os.path.join(filtered_dir, file)
            dm = pd.read_csv(fn, index_col=0)
            # Filter for adata genes
            dm = dm.T[train_adata.var_names].copy()
            # Filter perturbations for min degs in training data
            signif_perturbations[bn] = set(dm.index[(dm.sum(axis=1)>=min_degs)])

2025-08-21 10:16:08,971 - INFO - Extracting DEGs for XAtlas2025_HEK293T
2025-08-21 10:16:11,662 - INFO - Extracting DEGs for ReplogleWeissman2022_K562_gwps
2025-08-21 10:16:12,340 - INFO - Extracting DEGs for ReplogleWeissman2022_rpe1
2025-08-21 10:16:13,058 - INFO - Extracting DEGs for XAtlas2025_HCT116
2025-08-21 10:16:15,712 - INFO - Extracting DEGs for ReplogleWeissman2025_Jurkat
2025-08-21 10:16:16,400 - INFO - Extracting DEGs for ReplogleWeissman2025_HepG2
2025-08-21 10:16:17,147 - INFO - Extracting DEGs for ReplogleWeissman2022_K562_essential


In [8]:
# Filter for dominant perturbations only
selected_perturbations = set.intersection(*signif_perturbations.values())
adata = train_adata[train_adata.obs.perturbation.isin(selected_perturbations)].copy()

In [36]:
adata = train_adata

In [37]:
adata.obs.perturbation.nunique()

1655

#### Mixscale score

In [15]:
# Filter for mixscale score
mst = 2
ms_mask = (adata.obs['mixscale_score'].abs() > mst) | (adata.obs['perturbation']=='control')
adata._inplace_subset_obs(ms_mask)
cpp = adata.obs.perturbation.value_counts()
logging.info(f'Filtering for mixscale score > {mst}, mean number of cells / perturbation: {cpp[1:].mean()}')
logging.info(f'Filtered for {np.sum(ms_mask)}/{len(ms_mask)} cells')

2025-09-05 11:51:56,744 - INFO - Filtering for mixscale score > 2, mean number of cells / perturbation: 289.21765417170496
2025-09-05 11:51:56,746 - INFO - Filtered for 530176/819165 cells


In [16]:
adata.obs.perturbation.nunique()

1655

In [44]:
# Subset shared perturbations to essential genes
target_mask = adata.obs.perturbation.isin(target_train_perturbations)
ctrl_mask = adata.obs.perturbation=='control'
adata._inplace_subset_obs(target_mask | ctrl_mask)

#### Subset class embeddings to testing embeddings

In [60]:
training_genes = adata.obs.perturbation.unique()
cls_emb_genes = sorted(list((set(essential_genes) | set(training_genes)) & set(cls_emb.index)))
cls_emb_filtered = cls_emb.loc[cls_emb_genes].copy()
cls_emb_filtered.index = 'neg;' + cls_emb_filtered.index
# Change columns to string
cls_emb_filtered.columns = 'dim_' + cls_emb_filtered.columns.astype(str)

In [61]:
# Add control group to embedding
ctrl_emb = cls_emb_filtered.mean(axis=0)
ctrl_emb = pd.DataFrame(ctrl_emb).T
ctrl_emb.index = ['control']
cls_emb_filtered = pd.concat((cls_emb_filtered, ctrl_emb), axis=0)

In [62]:
# Add gene embedding to adata
adata.uns['cls_embedding'] = cls_emb_filtered

#### Remove perturbations that do not have an embedding

In [63]:
logging.info(f'Found {adata.obs.cls_label.nunique()} classes.')
# Either focus on class embedding or gene embedding, should work either way
if 'cls_embedding' in adata.uns:
    embedding_mask = adata.uns['cls_embedding'].sum(axis=1)!=0
    embedding_mask = adata.obs.cls_label.isin(embedding_mask[embedding_mask].index)
elif 'gene_embedding' in adata.obsm:
    # Select cells that have a match in the embedding
    embedding_mask = adata.obsm['gene_embedding'].sum(axis=1).A1!=0
else:
    embedding_mask = None
    logging.info('No embeddings found.')
if embedding_mask is not None:
    adata._inplace_subset_obs((embedding_mask | adata.obs.is_ctrl))
    # Check number of unique perturbations to classify
    logging.info(f'Initializing dataset with {adata.obs.cls_label.nunique()} classes.')

2025-09-05 12:53:26,787 - INFO - Found 557 classes.
2025-09-05 12:53:31,817 - INFO - Initializing dataset with 557 classes.


#### Remove control cells

In [24]:
logging.info('Removing control groups.')
adata._inplace_subset_obs(adata.obs.perturbation!='control')
logging.info(f'Model set shape: {adata.shape[0]} cells x {adata.shape[1]} genes.')

2025-09-05 11:54:00,645 - INFO - Removing control groups.
2025-09-05 11:54:02,878 - INFO - Model set shape: 149663 cells x 5783 genes.


#### Annotate HVGs and filter (optional)

In [27]:
# Select pre-calculated hvgs
hvg_dir = os.path.join(os.path.dirname(metaset_path), 'hvg')
hvgs = set()
for p in os.listdir(hvg_dir):
    idx = p.split('_hvgs')[0]
    fp = os.path.join(hvg_dir, p)
    _hvg = pd.read_csv(fp, index_col=0)
    hvgs.update(_hvg[_hvg.highly_variable==1].index)
adata.var['highly_variable'] = adata.var.index.isin(hvgs)

#### Filter for minimum amount of cells / perturbation

In [64]:
# remove perturbations with less than minimum amount of cells
min_cells = 100
cpp = adata.obs.groupby(cls_labels[:-1], observed=True)[cls_labels[-1]].value_counts()
valid_perturbations = cpp[cpp >= min_cells].reset_index()[cls_labels[-1]].unique().tolist()
logging.info(f'Found {len(valid_perturbations)}/{cpp.shape[0]} perturbations with >= {min_cells} cells in each cell type.')
adata._inplace_subset_obs(adata.obs[cls_labels[-1]].isin(valid_perturbations))

2025-09-05 12:56:13,020 - INFO - Found 516/557 perturbations with >= 100 cells in each cell type.


#### Remove empty cells or genes

In [65]:
# Remove genes with no counts (zero-padded)
sc.pp.filter_genes(adata, min_counts=1)
sc.pp.filter_cells(adata, min_counts=1)

In [66]:
# Manual save for full dataset
o = os.path.join(data_dir, 'eg_full_model_set_w_ctrl.h5ad')
logging.info(f'Saving train model set to {o}')
adata.write_h5ad(o)

2025-09-05 12:56:37,391 - INFO - Saving train model set to /home/xlv0877/proj_home/dl/data/eg_full_model_set_w_ctrl.h5ad
... storing 'cls_label' as categorical


In [36]:
# Subset perturbations
N = 100
sort_key = 'hm'
#sort_key = 'cpp'
summary = (adata.obs.groupby('perturbation', observed=True)
 .apply(lambda x: pd.Series({'ms': np.log(np.abs(x.mixscale_score).mean()), 'cpp': np.log(x.shape[0]), 'dpp': x.dataset.nunique()})).sort_values('cpp', ascending=False)
).reset_index()
# Compute harmonic mean between number of cells and mean mixscale score
summary['hm'] = (3*summary.ms*summary.dpp*summary.cpp) / (summary.ms+summary.dpp+summary.cpp)
summary['flagged'] = False
summary.loc[summary[sort_key].sort_values(ascending=False).head(N).index,'flagged'] = True
subset_mask = adata.obs.perturbation.isin(summary[summary.flagged].perturbation)
logging.info(f'Filtered adata for top {N} perturbations, got {subset_mask.sum()} cells.')
small_adata = adata[subset_mask].copy()

  .apply(lambda x: pd.Series({'ms': np.log(np.abs(x.mixscale_score).mean()), 'cpp': np.log(x.shape[0]), 'dpp': x.dataset.nunique()})).sort_values('cpp', ascending=False)
2025-08-19 06:29:32,427 - INFO - Filtered adata for top 100 perturbations, got 55811 cells.


### Save training data

In [37]:
o_small = os.path.join(data_dir, 'model_set.h5ad')
logging.info(f'Saving train subset model set to {o_small}')
small_adata.write_h5ad(o_small)
o = os.path.join(data_dir, 'full_model_set.h5ad')
logging.info(f'Saving train model set to {o}')
adata.write_h5ad(o)
test_p = os.path.join(data_dir, 'test_set.h5ad')
logging.info(f'Saving test model set to {test_p}')
test_adata.write_h5ad(test_p)

2025-08-19 06:29:43,782 - INFO - Saving train subset model set to /home/xlv0877/proj_home/dl/data/model_set.h5ad
... storing 'cls_label' as categorical
2025-08-19 06:29:46,927 - INFO - Saving train model set to /home/xlv0877/proj_home/dl/data/full_model_set.h5ad
... storing 'cls_label' as categorical
2025-08-19 06:29:55,430 - INFO - Saving test model set to /home/xlv0877/proj_home/dl/data/test_set.h5ad
... storing 'cls_label' as categorical


In [16]:
# Manual save for full dataset
o = os.path.join(data_dir, 'full_model_set.h5ad')
logging.info(f'Saving train model set to {o}')
adata.write_h5ad(o)

2025-08-30 07:51:14,733 - INFO - Saving train model set to /home/xlv0877/proj_home/dl/data/full_model_set.h5ad
... storing 'cls_label' as categorical


In [50]:
set.union(*signif_perturbations.values()).intersection(paper_perturbations.perturbation.values)

{'DNMT1', 'NAE1', 'SMARCE1', 'UBE2M', 'VHL'}

In [44]:
# Check overlap with paper perturbations
pp = '/home/xlv0877/p32655/projects/perturbation_prediction/benchmark/chef/paper_perturbations.txt'
paper_perturbations = pd.read_csv(pp, header=None)
paper_perturbations.columns = ['perturbation']
paper_perturbations

Unnamed: 0,perturbation
0,BAML2
1,BRPF1
2,CABIN1
3,CASP3
4,CASP6
5,CHAMP1
6,COA7
7,COL9A3
8,DNMT1
9,EMSY


In [58]:
test.obs.perturbation.value_counts()

perturbation
control    9871
TFAM       2536
SLC1A5     1719
GFM1       1355
GTF3C4     1270
           ... 
AURKB         3
NUP93         2
RPL7A         2
SEC13         2
NUP155        1
Name: count, Length: 1821, dtype: int64