In [20]:
import gseapy as gp
from tqdm import tqdm
import json
from pathlib import Path
import scanpy as sc
import numpy as np
import collections
import pandas as pd

In [2]:
data_dir = Path('/Users/djemec/data/jepa')
clean_path = data_dir / 'cleaned_replogle_k562.h5ad'
tokenized_dir = data_dir / 'tokenized'
chunk_size = 10000        # How many cells per .npy file
n_pathways = 1024          # Number of pathway "tokens" per cell
QUANTIZATION_MAX = 20.0   # Cap for outlier clamping (log-normalized expression rarely exceeds this)

### Initiation and read data

In [3]:
COL_PERTURBATION = 'perturbation'  # The gene being knocked down
COL_CONTROL_LABEL = 'control' # The label string for control cells
COL_BATCH = 'batch' # Batch column (if available) to match controls/treated within same batch

In [4]:
adata = sc.read_h5ad(clean_path)

In [5]:
genes = adata.var_names.tolist()
        
# State storage
gene_mask = None
pathway_names = []

In [6]:
# Prepare Perturbation Dictionary (Str -> Int ID)
all_perturbations = adata.obs[COL_PERTURBATION].unique()
pert_to_id = {p: i for i, p in enumerate(all_perturbations)}

In [7]:
with open(data_dir/'perturbation_map.json', 'w') as f:
            json.dump({str(k): int(v) for k, v in pert_to_id.items()}, f)


### Build Pathway Mask

Using Reactome 2024 pathway dataset since it's updated frequently and  it pairs well with the Replogle et al. (2022) K562 Perturb-seq dataset

In [8]:
gs_res = gp.get_library(name='Reactome_Pathways_2024', organism='Human')
gs_res

{'2-LTR Circle Formation': ['REV',
  'XRCC5',
  'XRCC6',
  'XRCC4',
  'PSIP1',
  'LIG4',
  'HMGA1',
  'GAG',
  'GAG-POL',
  'VIF',
  'VPR',
  'BANF1',
  'VPU'],
 'A Tetrasaccharide Linker Sequence Is Required for GAG Synthesis': ['B3GAT3',
  'B3GAT2',
  'B3GAT1',
  'BGN',
  'B3GALT6',
  'DCN',
  'VCAN',
  'BCAN',
  'HSPG2',
  'SDC1',
  'SDC2',
  'CSPG5',
  'CSPG4',
  'B4GALT7',
  'AGRN',
  'NCAN',
  'SDC3',
  'XYLT1',
  'SDC4',
  'XYLT2',
  'GPC2',
  'GPC1',
  'GPC4',
  'GPC3',
  'GPC6',
  'GPC5'],
 'ABC Transporter Disorders': ['ABCG8',
  'ABCC2',
  'KCNJ11',
  'ABCC8',
  'ABCC9',
  'ABCC6',
  'PSMC6',
  'ABCA12',
  'PSMC4',
  'PSMC5',
  'APOA1',
  'PSMC2',
  'PSMC3',
  'PSMC1',
  'RNF185',
  'ABCG5',
  'PSMD11',
  'ABCB4',
  'PSMD13',
  'PSMD12',
  'PSMD14',
  'SEL1L',
  'ABCB6',
  'RNF5',
  'PSMD7',
  'PSMD8',
  'PSMD6',
  'ERLEC1',
  'PSMD3',
  'PSMD1',
  'PSMD2',
  'RPS27A',
  'ABCA3',
  'ABCA1',
  'ADRM1',
  'ERLIN2',
  'ERLIN1',
  'PSMA6',
  'LMBRD1',
  'PSMA7',
  'PSMA4',
  'AB

In [9]:
# Sort pathways by size (information content) or just take first N
# Let's filter for pathways with 10-500 genes to avoid too small sets

valid_pathways = {k: v for k, v in gs_res.items() if 10 <= len(v) <= 500}
selected_keys = list(valid_pathways.keys())[:n_pathways]
pathway_names = selected_keys
len(pathway_names), pathway_names

(1024,
 ['2-LTR Circle Formation',
  'A Tetrasaccharide Linker Sequence Is Required for GAG Synthesis',
  'ABC Transporter Disorders',
  'ABC Transporters in Lipid Homeostasis',
  'ABC-family Proteins Mediated Transport',
  'ADORA2B Mediated Anti-Inflammatory Cytokines Production',
  'ADP Signalling Through P2Y Purinoceptor 1',
  'ADP Signalling Through P2Y Purinoceptor 12',
  'AKT Phosphorylates Targets in the Cytosol',
  'AKT Phosphorylates Targets in the Nucleus',
  'ALK Mutants Bind TKIs',
  'APC Truncation Mutants Have Impaired AXIN Binding',
  'APC-Cdc20 Mediated Degradation of Nek2A',
  'APC C-mediated Degradation of Cell Cycle Proteins',
  'APC C Cdc20 Mediated Degradation of Cyclin B',
  'APC C Cdc20 Mediated Degradation of Securin',
  'APC C Cdc20 Mediated Degradation of Mitotic Proteins',
  'APC C Cdh1 Mediated Degradation of Cdc20 and Other APC C Cdh1 Targets in Late Mitosis Early G1',
  'APC Cdc20 Mediated Degradation of Cell Cycle Proteins Prior to Satisfation of the Chec

### Build Mask

In [10]:
# Build Mask [Genes, Pathways]
mask = np.zeros((len(genes), len(selected_keys)), dtype=np.float32)
gene_to_idx = {gene: i for i, gene in enumerate(genes)}

In [11]:
for p_idx, p_name in enumerate(selected_keys):
    genes_in_pathway = valid_pathways[p_name]
    hit_count = 0
    for gene in genes_in_pathway:
        if gene in gene_to_idx:
            mask[gene_to_idx[gene], p_idx] = 1.0
            hit_count += 1
    
    # Normalize
    if hit_count > 0:
        mask[:, p_idx] /= hit_count
        
gene_mask = mask

In [12]:
# Save Pathway Names for analysis later
with open(data_dir /'pathway_names.json', 'w') as f:
    json.dump(selected_keys, f)

## Tokenize

1/ Caching Control Cells into RAM.

In [13]:
obs = adata.obs
obs

Unnamed: 0_level_0,batch,gene,gene_id,transcript,gene_transcript,guide_id,percent_mito,UMI_count,z_gemgroup_UMI,core_scale_factor,...,age,perturbation,organism,perturbation_type,tissue_type,ncounts,ngenes,nperts,percent_ribo,n_genes
cell_barcode,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
AAACCCAAGAAATCCA-27,27,NAF1,ENSG00000145414,P1P2,5449_NAF1_P1P2_ENSG00000145414,NAF1_+_164087918.23-P1P2|NAF1_-_164087674.23-P1P2,0.112083,11438.0,0.013047,0.813253,...,53,NAF1,human,CRISPR,cell_line,11324.0,3332,1,0.225362,3332
AAACCCAAGAACTTCC-31,31,BUB1,ENSG00000169679,P1P2,935_BUB1_P1P2_ENSG00000169679,BUB1_-_111435363.23-P1P2|BUB1_-_111435372.23-P1P2,0.179895,5342.0,-1.522247,0.844107,...,53,BUB1,human,CRISPR,cell_line,5257.0,2192,1,0.129732,2192
AAACCCAAGAAGCCAC-34,34,UBL5,ENSG00000198258,P1P2,9534_UBL5_P1P2_ENSG00000198258,UBL5_-_9938639.23-P1P2|UBL5_+_9938801.23-P1P2,0.105287,17305.0,0.384157,1.091537,...,53,UBL5,human,CRISPR,cell_line,17135.0,4002,1,0.236825,4002
AAACCCAAGAATAGTC-43,43,C9orf16,ENSG00000171159,P1P2,1131_C9orf16_P1P2_ENSG00000171159,C9orf16_+_130922603.23-P1P2|C9orf16_+_13092264...,0.099359,30244.0,3.721912,0.948277,...,53,C9orf16,human,CRISPR,cell_line,29717.0,5358,1,0.246828,5358
AAACCCAAGACAGCGT-28,28,TIMM9,ENSG00000100575,P1P2,8927_TIMM9_P1P2_ENSG00000100575,TIMM9_-_58893843.23-P1P2|TIMM9_-_58893848.23-P1P2,0.137623,8407.0,-0.975371,0.868942,...,53,TIMM9,human,CRISPR,cell_line,8261.0,2944,1,0.183392,2944
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
TTTGTTGTCTGTCGTC-45,45,ATP6V1D,ENSG00000100554,P1P2,682_ATP6V1D_P1P2_ENSG00000100554,ATP6V1D_+_67826485.23-P1P2|ATP6V1D_+_67826497....,0.100272,18350.0,0.428227,1.115052,...,53,ATP6V1D,human,CRISPR,cell_line,18044.0,4300,1,0.253547,4300
TTTGTTGTCTGTCTCG-27,27,CNOT3,ENSG00000088038,P1P2,1718_CNOT3_P1P2_ENSG00000088038,CNOT3_+_54641532.23-P1P2|CNOT3_-_54641691.23-P1P2,0.093876,8671.0,-0.633593,0.813253,...,53,CNOT3,human,CRISPR,cell_line,8510.0,3158,1,0.183196,3158
TTTGTTGTCTGTGCGG-44,44,METTL3,ENSG00000165819,P1P2,5004_METTL3_P1P2_ENSG00000165819,METTL3_+_21979431.23-P1P2|METTL3_-_21979084.23...,0.107983,20568.0,1.054624,0.973352,...,53,METTL3,human,CRISPR,cell_line,20355.0,4247,1,0.256595,4247
TTTGTTGTCTTGCAGA-14,14,RPL5,ENSG00000122406,P1P2,7475_RPL5_P1P2_ENSG00000122406,RPL5_+_93297664.23-P1P2|RPL5_-_93297968.23-P1P2,0.128225,23568.0,1.676254,1.050055,...,53,RPL5,human,CRISPR,cell_line,23169.0,4690,1,0.210713,4690


**Find control indices**

In [19]:
is_control = obs[COL_PERTURBATION].astype(str).str.contains(COL_CONTROL_LABEL, case=False, na=False)
control_indices = np.where(is_control)[0]
is_control[is_control] # shows only true

cell_barcode
AAACCCAAGCGAGGAG-47    True
AAACCCAAGCGTCTGC-27    True
AAACCCAAGGAGGGTG-47    True
AAACCCAAGTACCCTA-20    True
AAACCCAAGTGTTCAC-3     True
                       ... 
TTTGTTGGTCGGCCTA-1     True
TTTGTTGGTCGTCTCT-6     True
TTTGTTGGTGCTATTG-36    True
TTTGTTGTCGCAACAT-25    True
TTTGTTGTCTCCATAT-23    True
Name: perturbation, Length: 10691, dtype: bool

**Process Controls**

In [30]:
def quantize(continuous_matrix):
    '''Helper to convert float -> uint32'''
    clamped = np.clip(continuous_matrix, 0, QUANTIZATION_MAX)
    scale_factor = (2**32 - 1) / QUANTIZATION_MAX
    return (clamped * scale_factor).astype(np.uint32)

In [21]:
control_cache = collections.defaultdict(list)

In [23]:
batch_size = 10000

In [31]:
for i in tqdm(range(0, len(control_indices), batch_size), desc='Tokenizing Controls'):
    idx_chunk = control_indices[i : i+batch_size]
    
    # Load Raw Genes
    # Note: Indexing backed anndata with list is slower, but acceptable for 200k cells
    # Optimization: If sorted, it's faster.
    raw_chunk = adata[idx_chunk].X
    raw_chunk = raw_chunk.toarray()
    
    # Tokenize
    tokens_chunk = np.dot(raw_chunk, gene_mask) #(N, Pathways)
    tokens_quant = quantize(tokens_chunk)
    
    # Group by Batch
    batch_ids = adata.obs.iloc[idx_chunk][COL_BATCH].values if COL_BATCH in adata.obs else np.zeros(len(idx_chunk))
    
    for b_id, tok in zip(batch_ids, tokens_quant):
        control_cache[b_id].append(tok)


Tokenizing Controls: 100%|ââââââââââââââââââââââââ| 2/2 [00:00<00:00,  6.63it/s]


In [33]:
# Convert lists to numpy arrays for fast sampling
for k in control_cache:
    control_cache[k] = np.array(control_cache[k])

In [34]:
print(f'Controls Cached. Batches found: {len(control_cache.keys())}')

Controls Cached. Batches found: 48


### Stream Treated Cells & Pair

In [36]:
treated_indices = np.where(~is_control)[0]
n_treated = len(treated_indices)
n_treated

299694

In [37]:
buffer_control = []
buffer_treated = []
buffer_actions = []

shard_counter = 0

In [38]:
for i in tqdm(range(0, n_treated, chunk_size), desc='Processing Pairs'):
    idx_chunk = treated_indices[i : i+chunk_size]
    
    # Load Treated Data
    chunk_adata = adata[idx_chunk]
    X_treated = chunk_adata.X
    if hasattr(X_treated, "toarray"): X_treated = X_treated.toarray()
    
    # Tokenize Treated
    tokens_treated = np.dot(X_treated, gene_mask)
    tokens_treated_q = quantize(tokens_treated)
    
    # Get Metadata
    batch_ids = chunk_adata.obs[COL_BATCH].values if COL_BATCH in obs else np.zeros(len(idx_chunk))
    pert_names = chunk_adata.obs[COL_PERTURBATION].values
    pert_ids = [pert_to_id[p] for p in pert_names]
    
    # Find Pairs
    matched_controls = []
    for j, b_id in enumerate(batch_ids):
        # Try to find control from same batch
        if b_id in control_cache:
            pool = control_cache[b_id]
        else:
            # Fallback to any control if batch missing
            pool = control_cache[list(control_cache.keys())[0]]
        
        # Random Sample
        rand_idx = np.random.randint(len(pool))
        matched_controls.append(pool[rand_idx])
    
    matched_controls = np.array(matched_controls) # [Chunk, Pathways]
    
    # Save to buffers
    # Structure: [Control_Tokens (P), Treated_Tokens (P), Action_ID (1)]
    # We save separate arrays to the shard for clarity
    
    np.savez_compressed(
        tokenized_dir/ f'shard_{shard_counter:04d}.npz',
        control=matched_controls,
        treated=tokens_treated_q,
        action_ids=np.array(pert_ids, dtype=np.uint16)
    )
    shard_counter += 1

Processing Pairs: 100%|âââââââââââââââââââââââââ| 30/30 [01:16<00:00,  2.54s/it]
