In [1]:
import json
import numpy as np
import pandas as pd
import scanpy as sc
import gseapy as gp
from tqdm import tqdm
from pathlib import Path
from gears import PertData, GEARS
from pathlib import Path

In [4]:
# Filesystem layout: every artifact produced by this notebook lives under data_dir.
data_dir = Path('/Users/djemec/data/jepa/v0_2')
# Names of the dataset partitions handled further down.
splits = ['train', 'val', 'test']
# Sub-folder of data_dir reserved for training shards.
tok_dir = data_dir.joinpath('training')

In [5]:
# Tunable constants shared by the processing and sharding cells.
dataset_name = 'k562e'             # tag embedded in shard file names and the GEARS folder
chunk_size = 25_000                # maximum number of cells written to one shard file
n_pathways = 2 ** 10               # upper bound on pathway columns kept (1024)
n_genes = 2 ** 13                  # requested number of highly variable genes (8192)
count_normalize_target = 10_000.0  # per-cell total used for count normalization

## Data Download

In [6]:
# Create the GEARS PertData container rooted at <data_dir>/<dataset_name> and load
# the 'replogle_k562_essential' dataset into it. In the recorded run a local copy
# was found, so nothing was downloaded.
pert_data = PertData(data_dir / dataset_name) 
pert_data.load(data_name='replogle_k562_essential')

Found local copy...
Found local copy...
Found local copy...
These perturbations are not in the GO graph and their perturbation can thus not be predicted
['C7orf26+ctrl' 'C14orf178+ctrl' 'RPS10-NUDT3+ctrl' 'SEM1+ctrl' 'FAU+ctrl']
Local copy of pyg dataset is detected. Loading...
Done!


In [7]:
# Build (or load a cached) 'simulation' split with a fixed seed so the split is
# the same on every run; the per-split condition lists are read later from
# pert_data.set2conditions.
pert_data.prepare_split(split='simulation', seed=1) 
# Handle to the dataset's AnnData object; every later cell reads or modifies it.
adata = pert_data.adata 

Local copy of split is detected. Loading...
Simulation split test composition:
combo_seen0:0
combo_seen1:0
combo_seen2:0
unseen_single:272
Done!


here1


## Data Processing

In [8]:
# Per-cell sum over genes of adata.X as it stands before the normalization cell,
# flattened to 1-D. Stored on obs as log1p so it can be saved alongside each cell.
# NOTE(review): this must run before normalize_total/log1p, otherwise the sums no
# longer reflect the original matrix.
total_counts = np.asarray(adata.X.sum(axis=1)).ravel()
adata.obs['log_total_counts'] = np.log1p(total_counts)

In [9]:
# Expression preprocessing, in this exact order:
#   1. scale every cell to `count_normalize_target` total counts,
#   2. log1p-transform,
#   3. request the top `n_genes` highly variable genes; subset=True asks scanpy to
#      drop the remaining genes from `adata`.
# NOTE(review): the calls receive `adata` itself and their return values are not
# used, so they presumably modify it in place; re-running this cell would then
# process already-transformed data. Confirm before re-executing out of order.
# The recorded run later reports 5000 genes, i.e. fewer than n_genes (8192).
sc.pp.normalize_total(adata, target_sum=count_normalize_target)
sc.pp.log1p(adata)
sc.pp.highly_variable_genes(adata, n_top_genes=n_genes, subset=True)

In [11]:
# Ordered gene symbols of the retained feature space; the position in this list
# is the row index used by the pathway mask below.
genes = adata.var['gene_name'].tolist()
print(f'Final Feature Space: {len(genes)} Genes')

# Persist the gene order so downstream consumers can map columns back to symbols.
(data_dir / 'gene_names.json').write_text(json.dumps(genes))

Final Feature Space: 5000 Genes


## Pathway Mask

In [12]:
# Fetch the 'DSigDB' gene-set library for Human through gseapy. The next cell
# treats the result as a mapping of gene-set name -> collection of gene symbols
# (it iterates .items() and takes len() of each value).
gs_res = gp.get_library(name='DSigDB', organism='Human')

In [13]:
# Keep only gene sets whose size lies in [80, 1400] (both bounds inclusive),
# preserving the library's iteration order.
valid_pathways = dict(
    entry for entry in gs_res.items() if 80 <= len(entry[1]) <= 1400
)
# The first n_pathways surviving sets become the pathway columns.
pathway_names = list(valid_pathways)[:n_pathways]
# Persist the column order of the pathway mask.
(data_dir / 'pathway_names.json').write_text(json.dumps(pathway_names))

In [14]:
# Lookup from gene symbol to its row in the mask (a repeated symbol keeps its last position).
gene_to_idx = dict(zip(genes, range(len(genes))))
# genes x pathways membership matrix, all zeros until the fill loop runs.
binary_mask = np.zeros(shape=(len(genes), len(pathway_names)), dtype=np.float32)

In [15]:
# Fill one mask column per pathway: a row is set to 1.0 when that gene belongs to
# the pathway and is part of the retained feature space.
for p_idx, p_name in enumerate(pathway_names):
    matched_rows = [
        gene_to_idx[symbol]
        for symbol in valid_pathways[p_name]
        if symbol in gene_to_idx
    ]
    if matched_rows:
        binary_mask[matched_rows, p_idx] = 1.0

    # Report pathways that are (nearly) empty after intersecting with our genes.
    hit_count = len(matched_rows)
    if hit_count <= 1:
        print(f'pathway {p_name} had {hit_count} gene hits')

In [16]:
# Persist the genes x pathways float32 mask next to the gene/pathway name files,
# which record its row and column order.
np.save(data_dir / 'binary_pathway_mask.npy', binary_mask)


## Prepare Controls

In [17]:
# Boolean flag per cell: True where the condition label is exactly 'ctrl'.
control_mask = adata.obs['condition'].eq('ctrl')
# Positional (row) indices of those control cells within adata.
control_indices = np.flatnonzero(control_mask.to_numpy())

In [18]:
# In-memory pool of control cells, stored as two row-aligned float32 arrays:
#   'X'     -> dense expression rows of the control cells
#   'total' -> their log_total_counts values
# Row i of 'X' and entry i of 'total' describe the same control cell.
control_bank = dict(
    X=adata.X[control_indices].toarray().astype('float32'),
    total=adata.obs['log_total_counts'].to_numpy()[control_indices].astype('float32'),
)

In [19]:
# Sanity check: how many control cells are available for pairing
# (the recorded run found 10691).
print(f'Found {len(control_indices)} control cells.')

Found 10691 control cells.


## Perturbations

In [20]:
def clean_gears_name(name):
    """
    Convert a GEARS condition label into the name used as a perturbation key.

    - 'ctrl'       -> 'control'
    - 'Gene+ctrl'  -> 'Gene'   (only the trailing '+ctrl' marker is removed)
    - anything else is returned unchanged.

    Args:
        name: Condition label as a string.

    Returns:
        The cleaned label as a string.
    """
    suffix = '+ctrl'
    if name == 'ctrl':
        return 'control'
    if name.endswith(suffix):
        # Slice off the suffix instead of str.replace, which would also delete
        # any earlier '+ctrl' occurrence inside the label.
        return name[:-len(suffix)]
    return name

In [21]:
# Distinct condition labels in order of first appearance.
all_perts = adata.obs['condition'].unique()
# Cleaned label -> integer id, where the id is the label's position in all_perts.
# If two labels clean to the same key, the later id wins.
pert_to_id = dict(zip(map(clean_gears_name, all_perts), range(len(all_perts))))

In [22]:
# Persist the label -> id mapping; keys and values are coerced to plain str/int
# so numpy scalar types cannot break JSON serialization.
(data_dir / 'perturbation_map.json').write_text(
    json.dumps({str(label): int(pert_id) for label, pert_id in pert_to_id.items()})
)

## Shard Save

In [23]:
# Per-split condition lists produced by prepare_split; indexed below with
# 'train', 'val' and 'test'.
split_map = pert_data.set2conditions 

In [24]:
def write_shards(split_name, condition_list, ds_name, out_dir=None, seed=None):
    """
    Pair every perturbed cell of a split with a randomly drawn control cell and
    save the pairs as .npz shards holding at most `chunk_size` cells each.

    Args:
        split_name: Split label; used as the output sub-folder and in file names.
        condition_list: Values of adata.obs['condition'] that belong to the split.
            Cells labelled 'ctrl' are always excluded from the case side.
        ds_name: Dataset tag embedded in every shard file name.
        out_dir: Root folder for the shards. Defaults to the notebook-level
            `tok_dir`. Shards are written to <out_dir>/<split_name>/.
        seed: Optional integer. When given, shuffling and control sampling use a
            private RandomState so the output is reproducible. When None the
            global np.random state is used.

    Raises:
        ValueError: if there are cells to pair but the control bank is empty.

    Each shard contains the arrays `control`, `control_total`, `case`,
    `case_total` and `action_ids` (int16), all aligned row by row.
    """
    print(f'Split: {split_name.upper()}')

    # The previous version wrote to a global `eval_dir` that no cell in this
    # notebook defines, so it failed on a fresh kernel. `tok_dir` is the only
    # shard directory configured here, hence the default.
    save_path = Path(tok_dir if out_dir is None else out_dir) / split_name
    # np.savez does not create missing folders.
    save_path.mkdir(parents=True, exist_ok=True)

    # Both the np.random module and RandomState expose shuffle() and randint().
    rng = np.random if seed is None else np.random.RandomState(seed)

    # Select cells of the requested conditions, never using 'ctrl' as a case.
    conditions = adata.obs['condition']
    mask = conditions.isin(condition_list) & (conditions != 'ctrl')
    indices = np.where(mask)[0]

    # Shuffle so that each shard mixes perturbations.
    rng.shuffle(indices)

    n_controls = len(control_bank['X'])
    if len(indices) > 0 and n_controls == 0:
        raise ValueError('control_bank is empty; cannot pair perturbed cells with controls')

    # Hoisted out of the loop: per-row pandas .iloc lookups are slow.
    cond_values = conditions.to_numpy()
    log_totals = adata.obs['log_total_counts'].to_numpy().astype(np.float32)

    buffer = {
        'control_x': [],
        'control_total': [],
        'case_x': [],
        'case_total': [],
        'action_ids': []
    }
    shard_count = 0

    def _flush(shard_idx):
        # Write the buffered pairs as one shard, then empty the buffer in place.
        np.savez(
            save_path / f'shard_{ds_name}_{split_name}_{shard_idx:04d}.npz',
            control=np.array(buffer['control_x']),
            control_total=np.array(buffer['control_total']),
            case=np.array(buffer['case_x']),
            case_total=np.array(buffer['case_total']),
            action_ids=np.array(buffer['action_ids'], dtype=np.int16)
        )
        for values in buffer.values():
            values.clear()

    for idx in tqdm(indices):
        # 1. Case (perturbed) cell: dense expression row, log total, condition label.
        case_x = adata.X[idx].toarray().flatten().astype(np.float32)
        case_tot = log_totals[idx]
        pert_name = cond_values[idx]

        # 2. Random control partner, drawn uniformly from the whole control bank
        #    (no batch matching).
        rand_idx = rng.randint(n_controls)

        # 3. Buffer the pair.
        buffer['control_x'].append(control_bank['X'][rand_idx])
        buffer['control_total'].append(control_bank['total'][rand_idx])
        buffer['case_x'].append(case_x)
        buffer['case_total'].append(case_tot)
        buffer['action_ids'].append(pert_to_id[clean_gears_name(pert_name)])

        # 4. Write a shard as soon as the buffer is full.
        if len(buffer['case_x']) >= chunk_size:
            _flush(shard_count)
            shard_count += 1

    # Final, partially filled shard.
    if buffer['case_x']:
        _flush(shard_count)

In [25]:
# Write the training shards (101682 perturbed cells in the recorded run).
write_shards('train', split_map['train'], dataset_name)

Split: TRAIN


100%|██████████████████████████████████████████████████████████████████████████████████| 101682/101682 [00:04<00:00, 22299.84it/s]


In [26]:
# Write the validation shards (11044 perturbed cells in the recorded run).
write_shards('val', split_map['val'], dataset_name)

Split: VAL


100%|████████████████████████████████████████████████████████████████████████████████████| 11044/11044 [00:00<00:00, 37751.85it/s]


In [27]:
# Write the test shards (38829 perturbed cells in the recorded run).
write_shards('test', split_map['test'], dataset_name)

Split: TEST


100%|████████████████████████████████████████████████████████████████████████████████████| 38829/38829 [00:01<00:00, 26543.83it/s]
