In [1]:
import json
import os
from pathlib import Path

import numpy as np
import pandas as pd
import scanpy as sc
import gseapy as gp
from gears import PertData
from tqdm import tqdm
import scipy.sparse as sp
import anndata as ad

In [2]:
# Root of the v0_2 data tree. Set the JEPA_DATA_DIR environment variable to point
# elsewhere instead of editing this machine-specific default path.
data_dir = Path(os.environ.get('JEPA_DATA_DIR', '/Users/djemec/data/jepa/v0_2'))
tok_dir = data_dir / 'tokenized'
splits = ['train', 'val', 'test']

# Scratch directory for the per-chunk shards written while filtering k562gw.
temp_dir = tok_dir / 'temp_shards'

# Input per dataset: 'k562e' is a directory handed to GEARS PertData; the others
# point straight at a perturb_processed.h5ad file.
datasets = {
    'k562e': tok_dir / 'k562e',
    'k562gw': tok_dir / 'k562gw' / 'perturb_processed.h5ad',
    'norman': tok_dir / 'norman' / 'perturb_processed.h5ad',
    'rep1e': tok_dir / 'rep1e' / 'perturb_processed.h5ad',
    'sciplex': tok_dir / 'sciplex' / 'perturb_processed.h5ad',
    'adamson': tok_dir / 'adamson' / 'perturb_processed.h5ad',
}

# Target size of the gene vocabulary built below.
n_genes = 8192

## Key Dataset
Our model is evaluated against K562 essential, so we have to be really careful with these datasets and make sure its test perturbations don't leak into the training data. So we'll treat it specially.

In [3]:
# K562 essential is the evaluation dataset: load it through GEARS so its gene set
# and split (prepared below) define what the other datasets must not leak.
key_ds = 'k562e'
pert_data = PertData(datasets[key_ds]) 
pert_data.load(data_name='replogle_k562_essential')

Found local copy...
Found local copy...
Found local copy...
These perturbations are not in the GO graph and their perturbation can thus not be predicted
['C7orf26+ctrl' 'C14orf178+ctrl' 'RPS10-NUDT3+ctrl' 'SEM1+ctrl' 'FAU+ctrl']
Local copy of pyg dataset is detected. Loading...
Done!


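Create the split (GEARS `simulation` split with a fixed seed, so the held-out test perturbations are reproducible).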

In [4]:
# GEARS 'simulation' split; the fixed seed keeps the test perturbation set (used as the
# leakage blocklist further down) identical across runs.
pert_data.prepare_split(split='simulation', seed=1) 

Local copy of split is detected. Loading...
Simulation split test composition:
combo_seen0:0
combo_seen1:0
combo_seen2:0
unseen_single:272
Done!


here1


**Get mandatory (anchor) genes** — every gene in K562 essential is always kept in the vocabulary.

In [5]:
# Gene names of the K562-essential dataset; all of them are mandatory vocabulary entries.
anchor_genes = pert_data.adata.var['gene_name'].to_list()
len(anchor_genes), anchor_genes[:10]

(5000,
 ['LINC01409',
  'HES4',
  'ISG15',
  'B3GALT6',
  'ACAP3',
  'MXRA8',
  'CCNL2',
  'MRPL20-AS1',
  'MRPL20',
  'ATAD3B'])

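**Forbidden perturbations: the K562 essential test set** (cells carrying these perturbations are removed from the genome-wide data below)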

In [8]:
def clean_gears_name(name):
    """Convert a GEARS condition label into a bare perturbation name.

    GEARS encodes a single-gene perturbation as ``'GENE+ctrl'`` and the control
    condition as ``'ctrl'``. This returns ``'GENE'`` for the former and
    ``'control'`` for the latter. Any other label (e.g. a two-gene combo
    ``'A+B'``) is returned unchanged apart from whitespace stripping.

    Whitespace is stripped *before* the checks, so ``' GENE+ctrl '`` is still
    recognised. Otherwise it would come back as ``'GENE+ctrl'``, fail to match
    the cleaned test-set names, and slip past the leakage filter.

    Args:
        name: GEARS condition label (a string).

    Returns:
        The cleaned perturbation name.
    """
    label = name.strip()
    if label == 'ctrl':
        return 'control'
    if label.endswith('+ctrl'):
        # Slice off only the trailing suffix; str.replace would drop every
        # occurrence of '+ctrl' in the label.
        return label[:-len('+ctrl')]
    if label.startswith('ctrl+'):
        # The same single-gene perturbation written with the control first.
        return label[len('ctrl+'):]
    return label

In [9]:
# Perturbations held out in the K562-essential test split, normalised to bare gene
# names; cells carrying any of these are dropped from the genome-wide data later.
# NOTE(review): only the 'test' split is blocked; 'val' perturbations are not filtered.
test_perturbations = {clean_gears_name(cond) for cond in pert_data.set2conditions['test']}
len(test_perturbations), test_perturbations

(272,
 {'ACTR8',
  'ALG13',
  'ANKRD11',
  'AP2S1',
  'ARGLU1',
  'ARIH1',
  'ARPC4',
  'ATP6AP1',
  'BGLAP',
  'BMS1',
  'BRIX1',
  'BYSL',
  'C1QBP',
  'CCDC86',
  'CDC123',
  'CDC23',
  'CDC40',
  'CENPC',
  'CENPE',
  'CHERP',
  'CHMP6',
  'CNIH4',
  'COG2',
  'CPSF3',
  'CPSF4',
  'CSTF1',
  'CSTF3',
  'CTR9',
  'CUL1',
  'CWC15',
  'CWC22',
  'DAD1',
  'DBR1',
  'DDOST',
  'DDX41',
  'DDX49',
  'DDX5',
  'DDX52',
  'DDX54',
  'DERL2',
  'DHX37',
  'DLD',
  'DNAJC8',
  'DNTTIP2',
  'DONSON',
  'DPH2',
  'DYNC1H1',
  'DYNLRB1',
  'E4F1',
  'EBNA1BP2',
  'ECD',
  'EIF2B3',
  'EIF2S1',
  'EIF3E',
  'EIF3F',
  'EIF3G',
  'EIF3J',
  'EIF3M',
  'EIF5A',
  'ELP5',
  'EXOC7',
  'EXOSC10',
  'EXOSC2',
  'FAM50A',
  'FASTKD5',
  'FCF1',
  'FTSJ3',
  'GFER',
  'GFM1',
  'GNL3',
  'GNL3L',
  'GPKOW',
  'GRPEL1',
  'GTF2E1',
  'GTF2E2',
  'HAUS5',
  'HEATR1',
  'HGS',
  'HIRA',
  'HNRNPH1',
  'HNRNPM',
  'HNRNPU',
  'HSCB',
  'HSD17B12',
  'HSPA8',
  'HSPE1',
  'IGBP1',
  'ISCU',
  'KRI1',
  '

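**Build streamed expression variance**

The genome-wide dataset (~2M cells) is opened in backed mode and streamed in chunks to compute per-gene variance without loading the whole matrix into memory.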

In [10]:
# Genome-wide K562 screen, opened read-only in backed mode so the matrix stays on disk.
gw_ds = 'k562gw'
gw_path = datasets[gw_ds]
adata_gw = ad.read_h5ad(gw_path, backed='r')

In [11]:
# Number of cells and gene order; gw_genes labels the per-gene statistics below.
total_cells = adata_gw.shape[0]
gw_genes = list(adata_gw.var_names)
total_cells, len(gw_genes), gw_genes[:10]

(1989578,
 8248,
 ['LINC01409',
  'LINC01128',
  'NOC2L',
  'KLHL17',
  'HES4',
  'ISG15',
  'SDF4',
  'B3GALT6',
  'UBE2J2',
  'ACAP3'])

In [12]:
# Running per-gene sums for Var[X] = E[X^2] - (E[X])^2, computed column-wise.
CHUNK_SIZE = 50000  # cells per streamed slice
n_cells_seen = 0
sum_x = np.zeros(adata_gw.n_vars, dtype=np.float64)
sum_sq_x = np.zeros_like(sum_x)

In [13]:
# Stream the backed genome-wide matrix in CHUNK_SIZE-row slices and accumulate
# per-gene sum(x) and sum(x^2). Each slice is cast to float64 before summing:
# if X is stored as float32 or an integer type, summing/squaring in that dtype
# loses precision or can overflow, and the later E[X^2] - E[X]^2 subtraction
# amplifies that error.
for start in tqdm(range(0, total_cells, CHUNK_SIZE), desc="Calculating Stats"):
    chunk = adata_gw[start : start + CHUNK_SIZE]
    X_chunk = chunk.X  # reads this slice from the backed file (sparse or dense)

    if sp.issparse(X_chunk):
        X64 = X_chunk.astype(np.float64)
        sum_x += np.asarray(X64.sum(axis=0)).ravel()
        # .power(2) squares only the stored entries, so the matrix stays sparse.
        sum_sq_x += np.asarray(X64.power(2).sum(axis=0)).ravel()
    else:
        X64 = np.asarray(X_chunk, dtype=np.float64)
        sum_x += X64.sum(axis=0)
        sum_sq_x += np.square(X64).sum(axis=0)

    # The last slice may be shorter than CHUNK_SIZE, so count the rows actually read.
    n_cells_seen += chunk.n_obs


Calculating Stats: 100%|██████████████████████████████████████████████████████████████████████████| 40/40 [02:09<00:00,  3.24s/it]


In [14]:
# Per-gene variance from the streamed sums: Var[X] = E[X^2] - (E[X])^2.
assert n_cells_seen > 0, "no cells were accumulated; cannot compute variance"
mean_x = sum_x / n_cells_seen
mean_sq_x = sum_sq_x / n_cells_seen
# The subtraction suffers from floating-point cancellation: genes with (near-)constant
# expression can come out as tiny negative numbers. Clip those to 0 so they rank as
# zero-variance rather than below it.
variance = np.clip(mean_sq_x - mean_x ** 2, 0.0, None)

In [17]:
# Sorted, de-duplicated gene names of the genome-wide dataset.
# NOTE(review): no later cell shown here reads gw_genes_unique.
gw_genes_unique = sorted(set(adata_gw.var_names))

In [19]:
# Rank genes by variance (highest first). For repeated gene names, keep only the
# highest-variance copy, i.e. the first one after sorting.
gene_stats = (
    pd.DataFrame({'var': variance}, index=gw_genes)
    .sort_values('var', ascending=False)
)
is_repeat = gene_stats.index.duplicated(keep='first')
gene_stats = gene_stats.loc[~is_repeat]

In [22]:
# Build the vocabulary: every anchor gene, topped up with the highest-variance
# genome-wide genes that are not anchors (gene_stats is sorted by variance).
# Build the anchor set once instead of re-creating it for every candidate inside
# the comprehension (that was ~8k x 5k work).
anchor_set = set(anchor_genes)
candidates = [g for g in gene_stats.index if g not in anchor_set]
final_vocab = list(anchor_genes)
# NOTE(review): anchor_genes may repeat a gene name, in which case this counts
# fewer free vocabulary slots than actually exist.
slots = n_genes - len(final_vocab)
slots

3192

In [23]:
# Top up the vocabulary to exactly n_genes with the best remaining candidates.
# Free slots are counted against the *unique* names already in final_vocab:
# anchor_genes can repeat a gene name, and a hand-tuned `slots + 1` only works
# when there is exactly one repeat. Recomputing from final_vocab also makes this
# cell safe to re-run.
vocab_unique = list(dict.fromkeys(final_vocab))
n_missing = max(n_genes - len(vocab_unique), 0)
final_vocab = sorted(set(vocab_unique).union(candidates[:n_missing]))
assert len(final_vocab) == n_genes, (
    f"final_vocab has {len(final_vocab)} genes, expected {n_genes}"
)
len(final_vocab)

8192

In [26]:
# Validate the vocabulary. First an order-preserving de-duplication, then a hard
# size check, so a bad vocabulary fails here rather than after hours of sharding.
final_vocab = list(dict.fromkeys(final_vocab))
assert len(final_vocab) == n_genes, (
    f"final_vocab has {len(final_vocab)} unique genes, expected {n_genes}"
)
len(set(final_vocab))

8192

In [27]:
# Persist the sorted gene vocabulary as a JSON list of names.
vocab_path = tok_dir / 'gene_names.json'
with vocab_path.open('w') as fh:
    json.dump(final_vocab, fh)

**Filter GW: drop cells with leaked (K562 essential test) perturbations and reindex genes to the final vocabulary**

In [28]:
# Close the backed-mode HDF5 handle explicitly. `del` only drops this name, and
# other objects (e.g. the last `chunk` view from the stats loop) may still
# reference the AnnData and keep the file open.
adata_gw.file.close()
del adata_gw

In [32]:
# Re-open the genome-wide dataset (backed, read-only) for the sharding pass.
adata_gw = ad.read_h5ad(gw_path, backed='r')
total_cells = adata_gw.shape[0]

In [33]:
# Paths of the shard files written so far, and the index of the next shard.
shard_paths, shard_idx = [], 0

In [34]:
def _reindex_columns_sparse(X, src_names, target_names):
    """Return a CSR copy of ``X`` whose columns follow ``target_names``.

    Columns of ``X`` (labelled by ``src_names``) that appear in ``target_names``
    are moved to their target position. Target names absent from ``src_names``
    become all-zero columns, and source columns not in ``target_names`` are
    dropped. This matches ``DataFrame.reindex(columns=target_names, fill_value=0)``
    but never densifies the matrix.

    ``target_names`` must be unique (``Index.get_indexer`` raises otherwise).
    """
    X_csr = sp.csr_matrix(X)
    target_pos = pd.Index(target_names).get_indexer(pd.Index(src_names))  # -1 = not in vocab
    src_keep = np.flatnonzero(target_pos >= 0)
    X_kept = X_csr[:, src_keep].tocoo()
    X_out = sp.csr_matrix(
        (X_kept.data, (X_kept.row, target_pos[src_keep][X_kept.col])),
        shape=(X_csr.shape[0], len(target_names)),
        dtype=X_csr.dtype,
    )
    X_out.eliminate_zeros()  # match a dense -> csr conversion, which stores no zeros
    return X_out


# Start from a clean slate so re-running this cell doesn't append duplicate paths.
shard_paths = []
shard_idx = 0
temp_dir.mkdir(parents=True, exist_ok=True)

for start in tqdm(range(0, total_cells, CHUNK_SIZE), desc="Sharding"):
    # 1. Load this slice of the backed file into memory.
    chunk = adata_gw[start : start + CHUNK_SIZE].to_memory()
    # NOTE(review): obs names are only made unique within a chunk, not across shards.
    chunk.obs_names_make_unique()
    # Repeated gene names get suffixed; only the first copy keeps the plain name,
    # so only that copy maps into the vocabulary below.
    chunk.var_names_make_unique()

    # 2. Normalise perturbation labels to the same form as test_perturbations.
    chunk.obs['perturbation_name'] = [clean_gears_name(p) for p in chunk.obs['perturbation']]

    # 3. Drop every cell whose perturbation is in the K562-essential test set.
    mask_keep = ~chunk.obs['perturbation_name'].isin(test_perturbations).to_numpy()
    if not mask_keep.any():
        continue  # nothing left in this slice
    chunk = chunk[mask_keep].copy()

    # 4. Reindex genes to final_vocab (missing genes -> all-zero columns) while
    #    staying sparse. The previous to_df() + reindex densified
    #    ~CHUNK_SIZE x n_vars values per slice.
    X_vocab = _reindex_columns_sparse(chunk.X, chunk.var_names, final_vocab)

    # 5. Build the shard; var carries only the vocabulary gene names.
    chunk_final = sc.AnnData(
        X_vocab,
        obs=chunk.obs,
        var=pd.DataFrame(index=pd.Index(final_vocab)),
    )

    # 6. Save the shard.
    save_path = temp_dir / f'shard_{shard_idx}.h5ad'
    chunk_final.write(save_path)
    shard_paths.append(save_path)
    shard_idx += 1

Sharding: 100%|███████████████████████████████████████████████████████████████████████████████████| 40/40 [08:23<00:00, 12.58s/it]


**Merge shards**

In [35]:
# Destination of the merged, leakage-filtered genome-wide dataset.
target_file = tok_dir.joinpath('k562gw', 'replogle_k562_gw_expanded_8k.h5ad')
shard_paths

[PosixPath('/Users/djemec/data/jepa/v0_2/tokenized/temp_shards/shard_0.h5ad'),
 PosixPath('/Users/djemec/data/jepa/v0_2/tokenized/temp_shards/shard_1.h5ad'),
 PosixPath('/Users/djemec/data/jepa/v0_2/tokenized/temp_shards/shard_2.h5ad'),
 PosixPath('/Users/djemec/data/jepa/v0_2/tokenized/temp_shards/shard_3.h5ad'),
 PosixPath('/Users/djemec/data/jepa/v0_2/tokenized/temp_shards/shard_4.h5ad'),
 PosixPath('/Users/djemec/data/jepa/v0_2/tokenized/temp_shards/shard_5.h5ad'),
 PosixPath('/Users/djemec/data/jepa/v0_2/tokenized/temp_shards/shard_6.h5ad'),
 PosixPath('/Users/djemec/data/jepa/v0_2/tokenized/temp_shards/shard_7.h5ad'),
 PosixPath('/Users/djemec/data/jepa/v0_2/tokenized/temp_shards/shard_8.h5ad'),
 PosixPath('/Users/djemec/data/jepa/v0_2/tokenized/temp_shards/shard_9.h5ad'),
 PosixPath('/Users/djemec/data/jepa/v0_2/tokenized/temp_shards/shard_10.h5ad'),
 PosixPath('/Users/djemec/data/jepa/v0_2/tokenized/temp_shards/shard_11.h5ad'),
 PosixPath('/Users/djemec/data/jepa/v0_2/tokenized

In [36]:
# Close the backed HDF5 handle explicitly before dropping the large objects, so
# the source file is not left open during the on-disk merge.
adata_gw.file.close()
del adata_gw, pert_data

In [38]:
# Concatenate the shard files into a single .h5ad at target_file without loading
# them all into memory.
ad.experimental.concat_on_disk(
    in_files=shard_paths,
    out_file=target_file,
)