In [4]:
from pathlib import Path
import scanpy as sc
import scperturb
import pandas as pd
import numpy as np
import os
import requests

In [31]:
data_dir = Path('/Users/djemec/data/jepa')

raw_path = data_dir / 'replogle_k562.h5ad'
clean_path = data_dir / 'cleaned_replogle_k562.h5ad'

## Data
Downloading [Replogle 2022 (K562 Essential)](https://virtualcellmodels.cziscience.com/dataset/k562-essential-perturb-seq) via scPerturb: [more datasets](https://projects.sanderlab.org/scperturb/datavzrd/scPerturb_vzrd_v1/dataset_info/index_1.html)

The original dataset is part of a large-scale genotype-phenotype map developed by Replogle et. al. in 2022. The dataset specifically includes gene expression profiles from the human chronic myeloid leukemia cell line (K562 cells) after genetic perturbation of essential genes using CRISPR interference. The dataset was processed to benchmark models performing genetic perturbation prediction tasks.

In [5]:
#raw data download
#!wget -P {raw_path} https://zenodo.org/record/7041849/files/ReplogleWeissman2022_K562_essential.h5ad?download=1 

--2025-12-13 20:46:25--  https://zenodo.org/record/7041849/files/ReplogleWeissman2022_K562_essential.h5ad?download=1
Resolving zenodo.org (zenodo.org)... 2001:1458:d00:24::100:245, 2001:1458:d00:17::100:1d9, 2001:1458:d00:61::100:427, ...
Connecting to zenodo.org (zenodo.org)|2001:1458:d00:24::100:245|:443... connected.
HTTP request sent, awaiting response... 301 MOVED PERMANENTLY
Location: /records/7041849/files/ReplogleWeissman2022_K562_essential.h5ad [following]
--2025-12-13 20:46:26--  https://zenodo.org/records/7041849/files/ReplogleWeissman2022_K562_essential.h5ad
Reusing existing connection to [zenodo.org]:443.
HTTP request sent, awaiting response... 200 OK
Length: 1546729675 (1.4G) [application/octet-stream]
Saving to: ‘/Users/djemec/data/jepa/replogle_k562.h5ad/ReplogleWeissman2022_K562_essential.h5ad?download=1’


2025-12-13 20:53:24 (3.53 MB/s) - ‘/Users/djemec/data/jepa/replogle_k562.h5ad/ReplogleWeissman2022_K562_essential.h5ad?download=1’ saved [1546729675/1546729675]



In [7]:
adata = sc.read_h5ad(raw_path)

### Filtering the data

In [9]:
# 1. Filter cells with low counts (standard QC)
sc.pp.filter_cells(adata, min_genes=200)
sc.pp.filter_genes(adata, min_cells=3)

### Normalize Data
We log-normalize the counts for the input since our model will try to predict these values in latent space.

In [10]:
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)

### Perturbation Label Inspection
We need to map 'perturbation' (str) to a dense vector for the Action Conditioner.  We Typically see a format of: "GeneA_Knockdown"

In [28]:
adata.obs.columns

Index(['batch', 'gene', 'gene_id', 'transcript', 'gene_transcript', 'guide_id',
       'percent_mito', 'UMI_count', 'z_gemgroup_UMI', 'core_scale_factor',
       'core_adjusted_UMI_count', 'disease', 'cancer', 'cell_line', 'sex',
       'age', 'perturbation', 'organism', 'perturbation_type', 'tissue_type',
       'ncounts', 'ngenes', 'nperts', 'percent_ribo', 'n_genes'],
      dtype='object')

In [29]:
set(adata.obs.perturbation)

{'MCRS1',
 'ELP2',
 'PHB2',
 'PSMD6',
 'ARMC6',
 'TRAPPC3',
 'OR1E2',
 'DLD',
 'LARS2',
 'UBE2D3',
 'DYNC1H1',
 'PKM',
 'SKA1',
 'OR2A1',
 'SNIP1',
 'URM1',
 'CDC45',
 'UQCRFS1',
 'GUCD1',
 'PALB2',
 'RPL27',
 'TBCD',
 'SPDYE2',
 'ATP6V0D1',
 'KPNB1',
 'RPL39',
 'DMAP1',
 'SEC61B',
 'MED18',
 'ZC3H18',
 'POLR2K',
 'MCM10',
 'TWISTNB',
 'GLMN',
 'PSMC5',
 'ALG1L',
 'ANAPC2',
 'PSME1',
 'RPUSD4',
 'ZNF506',
 'RTCB',
 'PDCD6IP',
 'RPL23',
 'PRELID3B',
 'LARS',
 'SCD',
 'CDK9',
 'TRIM49C',
 'DNMT1',
 'CPSF4',
 'USP5',
 'HCFC1',
 'HIST2H2BE',
 'DDX21',
 'FAM32A',
 'COPE',
 'TRAPPC4',
 'SLC35G6',
 'NBAS',
 'CEP152',
 'POLR3E',
 'RFT1',
 'NFRKB',
 'XPO1',
 'ARHGAP11B',
 'LURAP1',
 'TCOF1',
 'TRA2B',
 'SURF6',
 'CUL7',
 'CNOT2',
 'HAPLN2',
 'CCDC137',
 'MLST8',
 'CHEK1',
 'PRPF8',
 'PSMA1',
 'NPLOC4',
 'RPS15A',
 'PWP1',
 'TSSK3',
 'SYF2',
 'WDR1',
 'HDAC7',
 'OGT',
 'MICOS10',
 'TAF10',
 'SEM1',
 'METTL3',
 'BRF1',
 'PFDN1',
 'YJEFN3',
 'RPTOR',
 'MRPS18B',
 'DR1',
 'UBAP1',
 'YY1',
 'NOP16',

### Write Data

In [30]:
print(f'Data Loaded: {adata.shape[0]} cells x {adata.shape[1]} genes')

Data Loaded: 310385 cells x 8563 genes


In [32]:
adata.write(clean_path)

# Tokenize Data

In [33]:
import gseapy as gp
from tqdm import tqdm
import json

In [79]:
tokenized_dir = data_dir / 'tokenized'
chunk_size = 10000        # How many cells per .npy file
N_PATHWAYS = 1024          # Number of pathway "tokens" per cell
QUANTIZATION_MAX = 20.0   # Cap for outlier clamping (log-normalized expression rarely exceeds this)

### Initiation

In [56]:
gene_mask = None

In [57]:
#uncomment if processing independently
#adata = sc.read_h5ad(clean_path, backed='r')

In [58]:
genes = adata.var_names.tolist()

### Building Pathway Mask
Creates the Gene -> Pathway projection matrix.

Fetch Hallmark (H) or Canonical Pathways (C2)

In [59]:
gs_res = gp.get_library(name='MSigDB_Hallmark_2020', organism='Human')
gs_res

{'TNF-alpha Signaling via NF-kB': ['MARCKS',
  'IL23A',
  'NINJ1',
  'TNFSF9',
  'SIK1',
  'ATF3',
  'SERPINE1',
  'MYC',
  'HES1',
  'CCN1',
  'CCNL1',
  'EGR1',
  'EGR2',
  'EGR3',
  'JAG1',
  'ABCA1',
  'GADD45B',
  'GADD45A',
  'KLF10',
  'PLK2',
  'EIF1',
  'EHD1',
  'FOSL2',
  'FOSL1',
  'GPR183',
  'PLPP3',
  'IFIT2',
  'ICAM1',
  'ZC3H12A',
  'IER2',
  'IL12B',
  'IER5',
  'JUNB',
  'IER3',
  'STAT5A',
  'DUSP5',
  'EDN1',
  'DUSP4',
  'JUN',
  'DUSP1',
  'DUSP2',
  'TSC22D1',
  'CCL20',
  'SPHK1',
  'LIF',
  'IL18',
  'TUBB2A',
  'RHOB',
  'VEGFA',
  'IL1A',
  'PTPRE',
  'TLR2',
  'IL1B',
  'BHLHE40',
  'CLCF1',
  'ID2',
  'REL',
  'FJX1',
  'SGK1',
  'BTG3',
  'BTG2',
  'BTG1',
  'SDC4',
  'LITAF',
  'AREG',
  'SOCS3',
  'PANX1',
  'RIPK2',
  'NFIL3',
  'SERPINB2',
  'GCH1',
  'IFNGR2',
  'G0S2',
  'FOS',
  'F3',
  'SERPINB8',
  'SPSB1',
  'FOSB',
  'PER1',
  'F2RL1',
  'HBEGF',
  'CD44',
  'TRIP10',
  'CDKN1A',
  'PTGER4',
  'PTGS2',
  'IFIH1',
  'NAMPT',
  'OLR1',
  'ICOSLG

Filter top pathways (or just take first N)

In [71]:
len_gs_pathways = len(list(gs_res.keys()))
if N_PATHWAYS > len_gs_pathways:
    N_PATHWAYS = len_gs_pathways
    selected_pathways = list(gs_res.keys())
else:
    selected_pathways = list(gs_res.keys())[:N_PATHWAYS]
N_PATHWAYS, selected_pathways

(50,
 ['TNF-alpha Signaling via NF-kB',
  'Hypoxia',
  'Cholesterol Homeostasis',
  'Mitotic Spindle',
  'Wnt-beta Catenin Signaling',
  'TGF-beta Signaling',
  'IL-6/JAK/STAT3 Signaling',
  'DNA Repair',
  'G2-M Checkpoint',
  'Apoptosis',
  'Notch Signaling',
  'Adipogenesis',
  'Estrogen Response Early',
  'Estrogen Response Late',
  'Androgen Response',
  'Myogenesis',
  'Protein Secretion',
  'Interferon Alpha Response',
  'Interferon Gamma Response',
  'Apical Junction',
  'Apical Surface',
  'Hedgehog Signaling',
  'Complement',
  'Unfolded Protein Response',
  'PI3K/AKT/mTOR  Signaling',
  'mTORC1 Signaling',
  'E2F Targets',
  'Myc Targets V1',
  'Myc Targets V2',
  'Epithelial Mesenchymal Transition',
  'Inflammatory Response',
  'Xenobiotic Metabolism',
  'Fatty Acid Metabolism',
  'Oxidative Phosphorylation',
  'Glycolysis',
  'Reactive Oxygen Species Pathway',
  'p53 Pathway',
  'UV Response Up',
  'UV Response Dn',
  'Angiogenesis',
  'heme Metabolism',
  'Coagulation',
 

#### Building mask

In [72]:
mask = np.zeros((len(genes), len(selected_pathways)), dtype=np.float32)
mask.shape

(8563, 50)

In [73]:
gene_to_idx = {gene: i for i, gene in enumerate(genes)}

In [74]:
for p_idx, (pathway, pathway_genes) in enumerate(gs_res.items()):
    if p_idx >= N_PATHWAYS: break
    
    # Count how many genes from this pathway are actually in our dataset
    hit_count = 0
    for gene in pathway_genes:
        if gene in gene_to_idx:
            mask[gene_to_idx[gene], p_idx] = 1.0
            hit_count += 1
    
    # Normalize column: Divide by number of genes found to get "Average Activity"
    # This prevents pathways with 200 genes from having huge values vs pathways with 10 genes
    if hit_count > 0:
        mask[:, p_idx] /= hit_count
        
gene_mask = mask

In [75]:
gene_mask

array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], shape=(8563, 50), dtype=float32)

#### Metadata Save

In [76]:
meta = {
            "pathways": selected_pathways,
            "n_genes": len(genes),
            "quantization_max": QUANTIZATION_MAX
        }
meta

{'pathways': ['TNF-alpha Signaling via NF-kB',
  'Hypoxia',
  'Cholesterol Homeostasis',
  'Mitotic Spindle',
  'Wnt-beta Catenin Signaling',
  'TGF-beta Signaling',
  'IL-6/JAK/STAT3 Signaling',
  'DNA Repair',
  'G2-M Checkpoint',
  'Apoptosis',
  'Notch Signaling',
  'Adipogenesis',
  'Estrogen Response Early',
  'Estrogen Response Late',
  'Androgen Response',
  'Myogenesis',
  'Protein Secretion',
  'Interferon Alpha Response',
  'Interferon Gamma Response',
  'Apical Junction',
  'Apical Surface',
  'Hedgehog Signaling',
  'Complement',
  'Unfolded Protein Response',
  'PI3K/AKT/mTOR  Signaling',
  'mTORC1 Signaling',
  'E2F Targets',
  'Myc Targets V1',
  'Myc Targets V2',
  'Epithelial Mesenchymal Transition',
  'Inflammatory Response',
  'Xenobiotic Metabolism',
  'Fatty Acid Metabolism',
  'Oxidative Phosphorylation',
  'Glycolysis',
  'Reactive Oxygen Species Pathway',
  'p53 Pathway',
  'UV Response Up',
  'UV Response Dn',
  'Angiogenesis',
  'heme Metabolism',
  'Coagulat

In [77]:
with open(data_dir/'tokenize_metadata.json', 'w') as f:
            json.dump(meta, f)

### Write Tokenized Shards

In [80]:
n_cells = adata.shape[0]
n_chunks = (n_cells // chunk_size) + 1
n_cells, n_chunks

(310385, 32)

In [82]:
for i in tqdm(range(0, n_cells, chunk_size), total=n_chunks):
    # 1. Load Chunk into Memory (Dense)
    # Slicing a backed AnnData loads that slice into memory
    end = min(i + chunk_size, n_cells)
    chunk = adata[i:end]
    
    # Get raw expression (check if sparse or dense)
    X = chunk.X
    if hasattr(X, 'toarray'): 
        X = X.toarray() # Convert sparse to dense for matmul
    
    # 2. Project to Pathway Space [Batch, Genes] @ [Genes, Pathways]
    # Result: [Batch, Pathways] (Continuous Floats)
    pathway_activity = np.dot(X, gene_mask)
    
    # 3. Quantize to Fixed Point (uint32)
    # We map range [0, QUANTIZATION_MAX] -> [0, 2^32 - 1]
    
    # Clamp outliers
    pathway_activity = np.clip(pathway_activity, 0, QUANTIZATION_MAX)
    
    # Scale
    scale_factor = (2**32 - 1) / QUANTIZATION_MAX
    quantized = (pathway_activity * scale_factor).astype(np.uint32)
    
    # 4. Save Perturbation Metadata (Actions) if needed
    # We save the obs (metadata) separately or alongside
    # For this MVP, let's just save the tokens
    shard_idx = i // chunk_size
    save_path = tokenized_dir / f'shard_{shard_idx:04d}.npy'
    np.save(save_path, quantized)
    
    # Optional: Save corresponding 'perturbation' labels for this shard
    if 'perturbation' in chunk.obs:
        labels = chunk.obs['perturbation'].values
        np.save(save_path.with_stem(save_path.stem + '_labels'), labels)

100%|███████████████████████████████████████████| 32/32 [00:04<00:00,  6.77it/s]
