In [3]:
from pathlib import Path
import scanpy as sc
import scperturb
import pandas as pd
import numpy as np
import os
import requests

In [7]:
# Root directory holding the raw and cleaned .h5ad files. It can be overridden
# with the AC_MODEL_DATA_DIR environment variable; when the variable is unset
# the original location is used, so existing setups keep working unchanged.
data_dir = Path(os.environ.get('AC_MODEL_DATA_DIR', '/home/ubuntu/data/ac_model'))

# Raw dataset as downloaded
raw_path = data_dir / 'replogle_k562.h5ad'
# QC-filtered copy written by the "Write Data" section
clean_path = data_dir / 'cleaned_replogle_k562.h5ad'

## Data
We download [Replogle 2022 (K562 Essential)](https://virtualcellmodels.cziscience.com/dataset/k562-essential-perturb-seq) via scPerturb; [more datasets](https://projects.sanderlab.org/scperturb/datavzrd/scPerturb_vzrd_v1/dataset_info/index_1.html) are available from the same source.

The original dataset is part of a large-scale genotype-phenotype map developed by Replogle et al. in 2022. The dataset specifically includes gene expression profiles from the human chronic myeloid leukemia cell line (K562 cells) after genetic perturbation of essential genes using CRISPR interference. The dataset was processed to benchmark models performing genetic perturbation prediction tasks.

In [8]:
# Raw data download (disabled; run once if raw_path does not exist yet).
# -O writes the download to raw_path itself; -P would instead treat raw_path as
# a directory prefix. The URL is quoted so the shell does not interpret the '?'.
#!wget -O {raw_path} "https://zenodo.org/record/7041849/files/ReplogleWeissman2022_K562_essential.h5ad?download=1"

In [9]:
# Read the raw dataset from raw_path (no `backed` mode is requested here).
adata = sc.read_h5ad(raw_path)

### Filtering the data

In [10]:
# 1. Filter cells with low counts (standard QC)
sc.pp.filter_cells(adata, min_genes=200)
sc.pp.filter_genes(adata, min_cells=3)

### (SKIPPED) Normalize Data
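Normally we would log-normalize the counts for the input, since our model will try to predict these values in latent space. This step is currently skipped (the cell below is commented out), so the data written out by this notebook is not normalized.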

In [11]:
# Intentionally disabled (see the "(SKIPPED)" heading): adata is not normalized by this notebook.
#sc.pp.normalize_total(adata, target_sum=1e4)
#sc.pp.log1p(adata)

### Perturbation Label Inspection
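We need to map 'perturbation' (str) to a dense vector for the Action Conditioner. We typically see a format of "GeneA_Knockdown"; in this dataset, however, the labels are bare gene symbols (e.g. `KARS`), as the output below shows.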

In [12]:
# List the per-cell metadata columns to locate the perturbation label column.
adata.obs.columns

Index(['batch', 'gene', 'gene_id', 'transcript', 'gene_transcript', 'guide_id',
       'percent_mito', 'UMI_count', 'z_gemgroup_UMI', 'core_scale_factor',
       'core_adjusted_UMI_count', 'disease', 'cancer', 'cell_line', 'sex',
       'age', 'perturbation', 'organism', 'perturbation_type', 'tissue_type',
       'ncounts', 'ngenes', 'nperts', 'percent_ribo', 'n_genes'],
      dtype='object')

In [13]:
# Distinct perturbation labels present in the dataset
set(adata.obs['perturbation'])

{'KARS',
 'NUP214',
 'RRP7A',
 'CNOT2',
 'CENPA',
 'UBE2M',
 'RNF103',
 'SACM1L',
 'SF3A2',
 'RPL24',
 'ALG11',
 'POLR3B',
 'ESF1',
 'NAA50',
 'SRFBP1',
 'RAE1',
 'GTF3C6',
 'H3F3A',
 'MRPL49',
 'MRPL36',
 'GTF2H2C',
 'RPL39',
 'NAA38',
 'SMN2',
 'NBPF3',
 'FOXD4',
 'THOC6',
 'DDOST',
 'SPRTN',
 'BANF1',
 'NOP58',
 'PLK1',
 'EARS2',
 'ZNF658',
 'HCRTR1',
 'PCID2',
 'MFN2',
 'TACC3',
 'DCTN2',
 'CENPT',
 'TIMM22',
 'FGFR1OP',
 'RAD21',
 'CNN2',
 'MCM7',
 'ILF2',
 'GTPBP4',
 'PSMB3',
 'EIF3J',
 'NUP35',
 'BNIP1',
 'OGT',
 'EIF3M',
 'ZNF763',
 'RNF40',
 'CAMLG',
 'MYCBP',
 'NUP98',
 'RPL10A',
 'CNOT1',
 'TBC1D3',
 'CTNNBL1',
 'LCE1E',
 'HSF1',
 'WDR43',
 'ATP6V1H',
 'HIST1H2BN',
 'SCD',
 'GET3',
 'HMGN2',
 'ANAPC13',
 'CACTIN',
 'NRBP1',
 'PSME2',
 'SMG5',
 'MRPL19',
 'POLR1B',
 'PCF11',
 'USP19',
 'RBBP6',
 'COX17',
 'EIF2B5',
 'GAPDH',
 'ZNF131',
 'RPL35A',
 'SART1',
 'MTOR',
 'UQCRQ',
 'POP1',
 'DONSON',
 'MED28',
 'TOPBP1',
 'NDUFAB1',
 'ABCB7',
 'SEC61G',
 'TAF1B',
 'SNRNP25',
 'VPS5

### Write Data

In [14]:
# Report the dimensions of the filtered dataset (cells x genes)
print(f'Data Loaded: {adata.n_obs} cells x {adata.n_vars} genes')

Data Loaded: 310385 cells x 8563 genes


In [15]:
# Persist the QC-filtered dataset as .h5ad at clean_path
adata.write_h5ad(clean_path)

# (DO NOT RUN) Tokenize Data

In [None]:
import gseapy as gp
from tqdm import tqdm
import json

In [None]:
# Directory that receives the per-shard .npy files
tokenized_dir = data_dir / 'tokenized'
# Number of cells written to each shard file
chunk_size = 10_000
# Upper bound on pathway "tokens" per cell; capped to the library size further down
N_PATHWAYS = 1_024
# Pathway activity is clamped to [0, QUANTIZATION_MAX] before quantization
QUANTIZATION_MAX = 20.0

### Initiation

In [None]:
# Placeholder; replaced by the gene -> pathway projection matrix once the mask is built below.
gene_mask = None

In [None]:
# Uncomment to run the tokenization section on its own (fresh kernel): this
# re-opens the cleaned file written above, in backed read-only mode.
#adata = sc.read_h5ad(clean_path, backed='r')

In [None]:
# Gene names in dataset (column) order; a gene's position here is its row in the mask
genes = list(adata.var_names)

### Building Pathway Mask
Creates the Gene -> Pathway projection matrix.

Fetch Hallmark (H) or Canonical Pathways (C2)

In [None]:
# Fetch the gene-set library through gseapy. Later cells iterate it as a
# mapping of pathway name -> collection of gene names.
gs_res = gp.get_library(name='MSigDB_Hallmark_2020', organism='Human')
gs_res

Filter top pathways (or just take first N)

In [None]:
# Never select more pathways than the library provides: cap N_PATHWAYS at the
# library size, then keep that many pathway names in library order.
library_pathways = list(gs_res.keys())
len_gs_pathways = len(library_pathways)
N_PATHWAYS = min(N_PATHWAYS, len_gs_pathways)
selected_pathways = library_pathways[:N_PATHWAYS]
N_PATHWAYS, selected_pathways

#### Building mask

In [None]:
# Projection matrix with one row per gene and one column per selected pathway,
# initialised to all zeros.
mask_shape = (len(genes), len(selected_pathways))
mask = np.zeros(mask_shape, dtype=np.float32)
mask.shape

In [None]:
# Lookup from gene name to its row index in the mask
gene_to_idx = dict(zip(genes, range(len(genes))))

In [None]:
# Fill one mask column per selected pathway. Iterating selected_pathways (instead
# of all of gs_res with a manual break) ties column p_idx directly to the pathway
# order that is stored in the metadata.
for p_idx, pathway in enumerate(selected_pathways):
    # Row indices of this pathway's genes that exist in our dataset. Building a
    # set de-duplicates repeated gene names, so the divisor below always equals
    # the number of non-zero rows in the column.
    hit_rows = sorted({gene_to_idx[gene] for gene in gs_res[pathway] if gene in gene_to_idx})

    # Start from a clean column so that re-running this cell is idempotent
    mask[:, p_idx] = 0.0

    if hit_rows:
        # Normalize column: each hit ends up as 1 / n_hits, so X @ mask yields the
        # "average activity" of the pathway's genes. This prevents pathways with
        # 200 genes from having huge values vs pathways with 10 genes.
        mask[hit_rows, p_idx] = 1.0
        mask[:, p_idx] /= len(hit_rows)

gene_mask = mask

In [None]:
# Display the finished projection matrix as a quick sanity check
gene_mask

#### Metadata Save

In [None]:
# Information needed to interpret the shards: pathway order (one per column),
# the number of genes in the dataset, and the clamp used for quantization.
meta = {}
meta["pathways"] = selected_pathways
meta["n_genes"] = len(genes)
meta["quantization_max"] = QUANTIZATION_MAX
meta

In [None]:
# Store the tokenization metadata as JSON next to the datasets
metadata_path = data_dir / 'tokenize_metadata.json'
with metadata_path.open('w') as meta_file:
    json.dump(meta, meta_file)

### Write Tokenized Shards

In [None]:
n_cells = adata.shape[0]
# Ceiling division: one shard per started block of chunk_size cells. A plain
# `n_cells // chunk_size + 1` would report one shard too many whenever n_cells
# is an exact multiple of chunk_size.
n_chunks = -(-n_cells // chunk_size)
n_cells, n_chunks

In [None]:
# Write one quantized pathway-activity shard (plus its perturbation labels) per
# block of chunk_size cells.

# np.save does not create missing directories, so make sure the target exists
tokenized_dir.mkdir(parents=True, exist_ok=True)

# Loop-invariant quantization constants: map [0, QUANTIZATION_MAX] -> [0, 2^32 - 1]
uint32_max = np.iinfo(np.uint32).max
scale_factor = uint32_max / QUANTIZATION_MAX

# The range knows its own length, so the progress bar total is always exact
shard_starts = range(0, n_cells, chunk_size)

for i in tqdm(shard_starts, total=len(shard_starts)):
    # 1. Load Chunk into Memory (Dense)
    # Slicing a backed AnnData loads that slice into memory
    end = min(i + chunk_size, n_cells)
    chunk = adata[i:end]

    # Get raw expression (check if sparse or dense)
    X = chunk.X
    if hasattr(X, 'toarray'):
        X = X.toarray()  # Convert sparse to dense for matmul

    # 2. Project to Pathway Space [Batch, Genes] @ [Genes, Pathways]
    # Result: [Batch, Pathways] (Continuous Floats)
    pathway_activity = np.dot(X, gene_mask)

    # 3. Quantize to Fixed Point (uint32)
    # Clamp outliers
    pathway_activity = np.clip(pathway_activity, 0, QUANTIZATION_MAX)

    # Scale in float64: 2^32 - 1 is not representable in float32 (it rounds up
    # to 2^32), so a float32 product for a fully clamped value would overflow
    # the uint32 cast. The second clip guards against rounding in the product.
    scaled = pathway_activity.astype(np.float64) * scale_factor
    quantized = np.clip(scaled, 0, uint32_max).astype(np.uint32)

    # 4. Save the tokens for this shard
    shard_idx = i // chunk_size
    save_path = tokenized_dir / f'shard_{shard_idx:04d}.npy'
    np.save(save_path, quantized)

    # Optional: Save corresponding 'perturbation' labels for this shard.
    # They are stored as a fixed-width unicode array rather than an object
    # array, so the file is not pickled and loads with a plain np.load.
    # NOTE(review): a missing label would be stored as the string 'nan'.
    if 'perturbation' in chunk.obs:
        labels = chunk.obs['perturbation'].astype(str).to_numpy(dtype=str)
        np.save(save_path.with_stem(save_path.stem + '_labels'), labels)