In [1]:
import os
import shutil
import json
import numpy as np
import pandas as pd
import scanpy as sc
import gseapy as gp
from tqdm import tqdm
from pathlib import Path
from gears import PertData, GEARS
from pathlib import Path

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Root data directory. Defaults to the original location but can be redirected
# with the DECODER_DATA_DIR environment variable so the notebook is runnable on
# machines with a different filesystem layout.
data_dir = Path(os.environ.get('DECODER_DATA_DIR', '/home/ubuntu/data/decoder'))
eval_dir = data_dir / 'tokenized'   # all artifacts written by this notebook go here
splits = ['train', 'val', 'test']

chunk_size = 20000             # How many cells per shard file
n_pathways = 1024              # Number of pathway "tokens" per cell
n_genes = 4096                 # Number of highly variable genes to keep
count_normalize_target = 1e4   # Normalize each cell to this total count

# Download Data

We'll start by downloading the GEARS version of the Replogle K562 data. Because we're benchmarking, we want to use exactly the same data and splits as the SOTA models.

In [3]:
# Construct the GEARS PertData handle with eval_dir (presumably its data/cache
# directory — TODO confirm) and load the 'replogle_k562_essential' dataset into it.
pert_data = PertData(eval_dir) 
pert_data.load(data_name='replogle_k562_essential')

Found local copy...
Found local copy...
Found local copy...
These perturbations are not in the GO graph and their perturbation can thus not be predicted
['C7orf26+ctrl' 'C14orf178+ctrl' 'RPS10-NUDT3+ctrl' 'SEM1+ctrl' 'FAU+ctrl']
Local copy of pyg dataset is detected. Loading...
Done!


**Create Splits**
Now we'll create the same splits that GEARS used: 
* Train: Seen perturbations
* Val: Seen perturbations (held out cells)
* Test: Unseen perturbations (The real challenge)

In [4]:
# Prepare the 'simulation' split with a fixed seed so the split is reproducible.
pert_data.prepare_split(split='simulation', seed=1) 
# NOTE: `adata` is a reference to pert_data.adata, not a copy — in-place
# preprocessing applied to `adata` below also changes pert_data.adata.
adata = pert_data.adata 

Local copy of split is detected. Loading...
Simulation split test composition:
combo_seen0:0
combo_seen1:0
combo_seen2:0
unseen_single:272
Done!


here1


## Data Preprocessing
Now we need to do the same preprocessing we did before for our model. 

**Total Counts**

We'll start by adding the log of total counts. 

In [5]:
# Sum each cell's row of adata.X, collapse the result to a 1-D vector, and
# store its log1p as a per-cell covariate in adata.obs.
total_counts = np.asarray(adata.X.sum(axis=1)).ravel()
adata.obs['log_total_counts'] = np.log1p(total_counts)

**Count Normalization and filtering** 
Now we'll normalize our counts to account for differing read depths. After normalizing, we'll take the log to compress order-of-magnitude differences and then do our gene-based filtering.

In [6]:
# Scale every cell to `count_normalize_target` total counts, apply log1p, then
# keep only the `n_genes` most variable genes. The return values are discarded:
# each call updates `adata` in place, so the order of these three lines matters.
# NOTE(review): this assumes adata.X still holds raw counts here — confirm that
# the GEARS-provided matrix is not already normalized / log-transformed.
sc.pp.normalize_total(adata, target_sum=count_normalize_target)
sc.pp.log1p(adata)
sc.pp.highly_variable_genes(adata, n_top_genes=n_genes, subset=True)

Let's see the gene list that we ended up keeping. Also we'll save the list. 

In [7]:
# Names of the genes that survived filtering, in adata.var order.
genes = adata.var['gene_name'].tolist()
print(f'Final Feature Space: {len(genes)} Genes')

# Persist the gene list so the feature order can be recovered later.
with (eval_dir / 'gene_names.json').open('w') as f:
    json.dump(genes, f)

Final Feature Space: 4096 Genes


## Pathway Mask

In [8]:
# We use DSigDB or Reactome to get the gene sets
gs_res = gp.get_library(name='DSigDB', organism='Human')

In [9]:
# Keep gene sets containing between 80 and 1400 genes (inclusive), then take
# the first `n_pathways` of them in the library's iteration order.
valid_pathways = {}
for set_name, set_genes in gs_res.items():
    if 80 <= len(set_genes) <= 1400:
        valid_pathways[set_name] = set_genes
pathway_names = list(valid_pathways)[:n_pathways]

# Save the selected names so mask columns can be mapped back to gene sets.
with (eval_dir / 'pathway_names.json').open('w') as f:
    json.dump(pathway_names, f)

In [10]:
# Membership matrix: one row per gene, one column per selected pathway,
# initialised to all zeros.
mask_shape = (len(genes), len(pathway_names))
binary_mask = np.zeros(mask_shape, dtype=np.float32)

# Gene name -> row index lookup (a repeated name resolves to its last position).
gene_to_idx = dict(zip(genes, range(len(genes))))

In [11]:
for p_idx, p_name in enumerate(pathway_names):
    # Row indices of this pathway's genes that are present in our feature space.
    hit_rows = [gene_to_idx[g] for g in valid_pathways[p_name] if g in gene_to_idx]
    for row in hit_rows:
        binary_mask[row, p_idx] = 1.0

    # Report pathways that are (almost) empty after intersecting with our genes.
    hit_count = len(hit_rows)
    if hit_count <= 1:
        print(f'pathway {p_name} had {hit_count} gene hits')

In [12]:
# Persist the gene x pathway membership mask next to the other artifacts in eval_dir.
np.save(eval_dir / 'binary_pathway_mask.npy', binary_mask)


## Prepare Controls 

In [13]:
# Boolean flag per cell: True where the condition label is exactly 'ctrl'.
control_mask = adata.obs['condition'] == 'ctrl'
# Positional (integer) indices of the control cells.
control_indices = np.flatnonzero(control_mask)

In [14]:
# Format: List of (Gene_Vector, Total_Count_Scalar)
control_bank = {
    'X': adata.X[control_indices].toarray().astype(np.float32),
    'total': adata.obs['log_total_counts'].values[control_indices].astype(np.float32)
}

In [15]:
# Sanity check: report how many control cells are available for pairing.
print(f'Found {len(control_indices)} control cells.')

Found 10691 control cells.


## Generate Shards

In [16]:
# Let's extract the official splits from the GEARS object
# split_dict = {'train': [pert1, pert2...], 'val': [...], 'test': [...]}
split_map = pert_data.set2conditions 

**Perturbation Map:** We need to load our training perturbation map so that we reuse the existing ID for any overlapping perturbation and generate new IDs only for new ones.

Also, the perturbation names are formatted slightly differently, so we'll need to convert the GEARS condition names to our naming scheme.

In [19]:
# Load the perturbation-name -> ID map produced for training so that
# overlapping perturbations keep the IDs they already have.
original_pert_map = eval_dir / 'perturbation_map_original.json'
with original_pert_map.open('r') as f:
    training_map = json.load(f)

# Highest ID currently in use; new perturbations are numbered above it.
max_pert_id = max(training_map.values())
max_pert_id

2057

In [20]:
def clean_gears_name(name):
    """Convert a GEARS condition label to the naming used by our perturbation map.

    Args:
        name: GEARS condition string, e.g. 'GENE+ctrl' or 'ctrl'.

    Returns:
        'control' for the bare 'ctrl' label; the label without its trailing
        '+ctrl' for single-gene perturbations; otherwise `name` unchanged.
    """
    suffix = '+ctrl'
    if name == 'ctrl':
        return 'control'
    if name.endswith(suffix):
        # Strip only the trailing suffix. str.replace would also delete any
        # '+ctrl' occurring earlier in the name.
        return name[:-len(suffix)]
    return name

In [21]:
# Give every condition in the GEARS data an ID: names already present in
# training_map keep their training ID, unseen names get fresh IDs above the
# training map's maximum.
all_perts = adata.obs['condition'].unique()
id_to_pert = {v: k for k, v in training_map.items()}

# Start counting from the training map itself rather than from the running
# max_pert_id variable, so re-running this cell always yields the same IDs.
next_pert_id = max(training_map.values())
new_names = set()
overlap = 0
new = 0
for a in all_perts:
    clean_name = clean_gears_name(a)
    if clean_name in training_map:
        overlap += 1
    elif clean_name not in new_names:
        # Two raw labels that clean to the same new name must share one ID.
        next_pert_id += 1
        new += 1
        new_names.add(clean_name)
        id_to_pert[next_pert_id] = clean_name

# Keep max_pert_id in sync with the highest ID now in use.
max_pert_id = next_pert_id
f'overlap {overlap} | new additions {new}'

'overlap 1088 | new additions 0'

In [22]:
# Flip ID -> name back into name -> ID, which is the direction used for lookups.
pert_to_id = {pert_name: pert_id for pert_id, pert_name in id_to_pert.items()}

# Write the combined map with plain str keys and int values so it is JSON-serialisable.
with (eval_dir / 'perturbation_map.json').open('w') as f:
    json.dump({str(pert_name): int(pert_id) for pert_name, pert_id in pert_to_id.items()}, f)

In [23]:
def _save_shard(save_path, split_name, shard_count, buffer):
    """Write the buffered cells to `save_path` as one numbered .npz shard."""
    np.savez(
        save_path / f'shard_{split_name}_{shard_count:04d}.npz',
        control=np.array(buffer['control_x']),
        control_total=np.array(buffer['control_total']),
        case=np.array(buffer['case_x']),
        case_total=np.array(buffer['case_total']),
        # Stored as int16, which can represent IDs up to 32767.
        action_ids=np.array(buffer['action_ids'], dtype=np.int16)
    )


def write_shards(split_name, condition_list, seed=None):
    """
    Iterates through cells belonging to the given conditions,
    pairs each with a randomly drawn control cell, and saves .npz shards
    of at most `chunk_size` cells under `eval_dir / split_name`.

    Reads the notebook-level `adata`, `control_bank`, `pert_to_id`,
    `chunk_size` and `eval_dir`.

    Args:
        split_name: Name of the split; used for the output sub-directory and
            in the shard file names.
        condition_list: Condition labels whose cells belong to this split.
            Cells labelled 'ctrl' are always excluded from the perturbed side.
        seed: Optional integer seed. When given, the shuffle and the control
            sampling are reproducible; when None (default) the global NumPy
            random state is used.

    Raises:
        ValueError: If there are cells to write but `control_bank` is empty.
        KeyError: If a cell's cleaned condition name is missing from `pert_to_id`.
    """
    print(f'Split: {split_name.upper()}')

    # The np.random module and a RandomState instance expose the same
    # shuffle/randint API, so the unseeded path behaves as it always did.
    rng = np.random if seed is None else np.random.RandomState(seed)

    # Make sure the output directory exists before the first shard is written.
    save_path = eval_dir / split_name
    save_path.mkdir(parents=True, exist_ok=True)

    # Look the obs columns up once instead of on every loop iteration.
    conditions = adata.obs['condition']
    log_totals = adata.obs['log_total_counts']

    # Filter cells belonging to these perturbations
    # Note: We exclude 'ctrl' from the 'Treated' side of the pair
    mask = conditions.isin(condition_list) & (conditions != 'ctrl')
    indices = np.where(mask)[0]

    # Shuffle so each shard holds a random mix of perturbations
    rng.shuffle(indices)

    n_controls = len(control_bank['X'])
    if len(indices) > 0 and n_controls == 0:
        raise ValueError('control_bank is empty; cannot pair perturbed cells with controls')

    # Buffer for current shard
    buffer = {
        'control_x': [],
        'control_total': [],
        'case_x': [],
        'case_total': [],
        'action_ids': []
    }
    shard_count = 0

    for idx in tqdm(indices):
        # 1. Get Case Data
        case_x = adata.X[idx].toarray().flatten().astype(np.float32)
        case_tot = log_totals.iloc[idx].astype(np.float32)
        pert_name = conditions.iloc[idx]

        # 2. Get Random Control Pair
        # Controls are sampled globally (no batch matching).
        # (Improvement: dictionary mapping batch_id -> control_indices)
        rand_idx = rng.randint(n_controls)
        ctrl_x = control_bank['X'][rand_idx]
        ctrl_tot = control_bank['total'][rand_idx]

        # 3. Add to Buffer
        buffer['control_x'].append(ctrl_x)
        buffer['control_total'].append(ctrl_tot)
        buffer['case_x'].append(case_x)
        buffer['case_total'].append(case_tot)
        buffer['action_ids'].append(pert_to_id[clean_gears_name(pert_name)])

        # 4. Save if buffer full
        if len(buffer['case_x']) >= chunk_size:
            _save_shard(save_path, split_name, shard_count, buffer)
            # Reset
            buffer = {k: [] for k in buffer}
            shard_count += 1

    # Save leftovers
    if len(buffer['case_x']) > 0:
        _save_shard(save_path, split_name, shard_count, buffer)

In [24]:
# Write shards for the conditions assigned to the 'train' split.
write_shards('train', split_map['train'])

Split: TRAIN


100%|████████████████████████████████████████████████████████████████| 101682/101682 [00:06<00:00, 15660.38it/s]


In [25]:
# Write shards for the conditions assigned to the 'val' split.
write_shards('val', split_map['val'])

Split: VAL


100%|██████████████████████████████████████████████████████████████████| 11044/11044 [00:00<00:00, 23390.93it/s]


In [26]:
# Write shards for the 'test' split (the unseen perturbations).
write_shards('test', split_map['test'])

Split: TEST


100%|██████████████████████████████████████████████████████████████████| 38829/38829 [00:02<00:00, 18542.89it/s]
