In [1]:
import json
import numpy as np
import pandas as pd
import scanpy as sc
import gseapy as gp
from tqdm import tqdm
from pathlib import Path
from gears import PertData, GEARS
from pathlib import Path

In [2]:
# Root folder for every artifact this notebook reads or writes.
data_dir = Path('/Users/djemec/data/jepa/v0_2')
# 'training' sub-folder underneath the root.
tok_dir = data_dir / 'training'
# Names of the dataset partitions.
splits = ['train', 'val', 'test']

In [3]:
# Tunable preprocessing parameters.
chunk_size = 25_000             # cells written into each shard file
n_pathways = 1024               # pathway "tokens" kept per cell
n_genes = 2 ** 13               # 8192 highly variable genes requested
count_normalize_target = 1e4    # per-cell total count targeted by normalization
dataset_name = 'k562e'          # tag used for the data sub-folder and shard filenames

## Data Download

In [4]:
# Create the GEARS PertData container rooted at <data_dir>/<dataset_name>
# and load the 'replogle_k562_essential' dataset into it.
pert_data = PertData(data_dir / dataset_name) 
pert_data.load(data_name='replogle_k562_essential')

Found local copy...
Found local copy...
Found local copy...
These perturbations are not in the GO graph and their perturbation can thus not be predicted
['C7orf26+ctrl' 'C14orf178+ctrl' 'RPS10-NUDT3+ctrl' 'SEM1+ctrl' 'FAU+ctrl']
Local copy of pyg dataset is detected. Loading...
Done!


In [5]:
# Prepare the 'simulation' split with a fixed seed so the partition is reproducible.
pert_data.prepare_split(split='simulation', seed=1) 
# Keep a handle on the AnnData object; later cells add columns to it and
# pass it to scanpy preprocessing functions.
adata = pert_data.adata 

Local copy of split is detected. Loading...
Simulation split test composition:
combo_seen0:0
combo_seen1:0
combo_seen2:0
unseen_single:272
Done!


here1


In [17]:
# Inspect the per-gene annotation table; its `gene_name` column is what the
# feature list is built from further down.
adata.var

Unnamed: 0_level_0,gene_name,chr,start,end,class,strand,length,in_matrix,mean,std,cv,fano,highly_variable,means,dispersions,dispersions_norm
gene_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1
ENSG00000237491,LINC01409,chr1,778747,810065,gene_version10,+,31318,True,0.137594,0.380048,2.762105,1.049733,True,0.130939,0.222407,0.028718
ENSG00000188290,HES4,chr1,998962,1000172,gene_version10,-,1210,True,0.249577,0.561933,2.251540,1.265214,True,0.205869,0.322631,0.715487
ENSG00000187608,ISG15,chr1,1001138,1014540,gene_version10,+,13402,True,0.377373,0.787623,2.087120,1.643865,True,0.335591,0.757568,3.695832
ENSG00000176022,B3GALT6,chr1,1232237,1235041,gene_version7,+,2804,True,0.315492,0.603217,1.911989,1.153345,True,0.251509,0.187828,-0.208232
ENSG00000131584,ACAP3,chr1,1292390,1309609,gene_version19,-,17219,True,0.146009,0.391124,2.678769,1.047732,True,0.133338,0.198733,-0.133505
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
ENSG00000198695,MT-ND6,chrM,14149,14673,gene_version2,-,524,True,9.036310,6.860383,0.759202,5.208415,True,2.169510,1.539535,1.636096
ENSG00000278704,BX004987.1,GL000009.2,56140,58376,gene_version1,-,2236,True,0.241213,0.507266,2.102976,1.066768,True,0.215032,0.235268,0.116845
ENSG00000274847,MAFIP,GL000194.1,53594,115055,gene_version1,-,61461,True,0.127525,0.361556,2.835168,1.025072,True,0.116269,0.183864,-0.235394
ENSG00000278384,AL354822.1,GL000218.1,51867,54893,gene_version1,-,3026,True,0.248814,0.516552,2.076062,1.072394,True,0.227783,0.282330,0.439332


## Data Processing

In [None]:
# Sum every cell's row of adata.X, flatten to 1-D, and keep log1p of that sum
# as a per-cell covariate.
# NOTE(review): this sums whatever adata.X holds at this point. The column name
# suggests raw counts -- confirm the loaded matrix has not already been
# normalized / log-transformed upstream.
total_counts = np.asarray(adata.X.sum(axis=1)).ravel()
adata.obs['log_total_counts'] = np.log1p(total_counts)

In [None]:
# 1) scale each cell to `count_normalize_target` total counts,
# 2) apply log1p,
# 3) select the top `n_genes` highly variable genes (subset=True asks scanpy
#    to restrict `adata` to the selected genes).
# NOTE(review): these calls operate on `adata` directly, so running this cell
# twice would normalize / log-transform already transformed values -- restart
# the kernel before re-running.
sc.pp.normalize_total(adata, target_sum=count_normalize_target)
sc.pp.log1p(adata)
sc.pp.highly_variable_genes(adata, n_top_genes=n_genes, subset=True)

In [None]:
# Feature order follows the row order of adata.var; the list is stored as JSON
# alongside the other artifacts in data_dir.
genes = adata.var['gene_name'].tolist()
print(f'Final Feature Space: {len(genes)} Genes')

with (data_dir / 'gene_names.json').open('w') as f:
    f.write(json.dumps(genes))

## Pathway Mask

In [None]:
# Fetch the 'DSigDB' gene-set library for Human via gseapy. The next cell
# iterates it with .items() and filters on the length of each value.
gs_res = gp.get_library(name='DSigDB', organism='Human')

In [None]:
# Keep only gene sets whose size lies in [80, 1400].
valid_pathways = {
    set_name: members
    for set_name, members in gs_res.items()
    if 80 <= len(members) <= 1400
}
# First `n_pathways` surviving sets, in library iteration order.
pathway_names = list(valid_pathways)[:n_pathways]
# Persist the chosen pathway names.
with (data_dir / 'pathway_names.json').open('w') as f:
    f.write(json.dumps(pathway_names))

In [None]:
# Gene name -> row index lookup (if a name occurs twice, the last index wins).
gene_to_idx = dict(zip(genes, range(len(genes))))
# Zero-initialised (genes x pathways) membership matrix.
binary_mask = np.zeros((len(genes), len(pathway_names)), dtype=np.float32)

In [None]:
# Fill the gene x pathway membership matrix: entry [gene, pathway] becomes 1.0
# when a gene of the feature space is listed in that pathway.
for p_idx, p_name in enumerate(pathway_names):
    hit_count = 0
    genes_in_pathway = valid_pathways[p_name]
    for g in genes_in_pathway:
        # Pathway genes absent from the feature space are ignored.
        if g in gene_to_idx:
            binary_mask[gene_to_idx[g], p_idx] = 1.0
            hit_count += 1

    # Report pathways that barely overlap the feature space.
    if hit_count <= 1:
        print(f'pathway {p_name} had {hit_count} gene hits')

In [None]:
# Persist the float32 mask of shape (len(genes), len(pathway_names)).
np.save(data_dir / 'binary_pathway_mask.npy', binary_mask)


## Prepare Controls

In [None]:
# Boolean selector for unperturbed cells, plus their positional row indices.
control_mask = adata.obs['condition'] == 'ctrl'
control_indices = np.flatnonzero(control_mask)

In [None]:
# In-memory bank of all control cells, stored as a dict of two aligned arrays:
#   'X'     -> dense float32 expression rows, one per control cell
#   'total' -> float32 'log_total_counts' value of the same cells
# Row i of 'X' and element i of 'total' describe the same control cell.
control_bank = {
    'X': adata.X[control_indices].toarray().astype(np.float32),
    'total': adata.obs['log_total_counts'].values[control_indices].astype(np.float32)
}

In [None]:
# Sanity check: sampling a control further down requires at least one control cell.
print(f'Found {len(control_indices)} control cells.')

## Perturbations

In [None]:
def clean_gears_name(name):
    """Convert a GEARS condition label into a plain perturbation name.

    'ctrl' becomes 'control'; a label ending in '+ctrl' has every '+ctrl'
    occurrence removed ('Gene+ctrl' -> 'Gene'); anything else is returned
    unchanged.
    """
    if name == 'ctrl':
        return 'control'
    if not name.endswith('+ctrl'):
        return name
    return name.replace('+ctrl', '')

In [None]:
# Distinct condition labels, in order of first appearance in adata.obs.
all_perts = adata.obs['condition'].unique()
# Cleaned perturbation name -> integer id (position in `all_perts`).
pert_to_id = {
    clean_gears_name(label): position
    for position, label in enumerate(all_perts)
}

In [None]:
# Store the name -> id table as JSON, coercing keys to str and ids to plain int
# so numpy / pandas scalar types do not break serialization.
with (data_dir / 'perturbation_map.json').open('w') as f:
    f.write(json.dumps({str(name): int(idx) for name, idx in pert_to_id.items()}))

## Shard Save

In [None]:
# Split name -> conditions assigned to that split; indexed below with
# 'train', 'val' and 'test'.
split_map = pert_data.set2conditions 

In [None]:
def write_shards(split_name, condition_list, ds_name, out_dir=None, seed=None):
    """Pair perturbed cells with random controls and write them as .npz shards.

    Selects every cell whose condition is in `condition_list` (excluding
    'ctrl'), shuffles them, pairs each one with a control drawn uniformly with
    replacement from `control_bank`, and writes the pairs in files of at most
    `chunk_size` cells named ``shard_{ds_name}_{split_name}_{NNNN}.npz`` under
    ``<out_dir>/<split_name>``.

    Each shard holds the arrays `control`, `control_total`, `case`,
    `case_total` (float32) and `action_ids` (int16).

    Parameters
    ----------
    split_name : str
        Split label; used for the sub-directory and in the shard filename.
    condition_list : iterable of str
        Condition labels (as found in ``adata.obs['condition']``) to include.
    ds_name : str
        Dataset tag embedded in the shard filename.
    out_dir : path-like, optional
        Root output directory. Defaults to the notebook-level `tok_dir`.
        NOTE(review): the previous version read an `eval_dir` global that no
        cell defines, so it raised NameError on a fresh kernel -- confirm
        `tok_dir` is the intended default location.
    seed : int, optional
        When given, shuffling and control sampling use a dedicated
        ``np.random.RandomState(seed)`` and are reproducible. When omitted,
        the global ``np.random`` state is used as before.

    Raises
    ------
    ValueError
        If a perturbation id does not fit into int16.
    """
    print(f'Split: {split_name.upper()}')

    conditions = adata.obs['condition']
    # 'ctrl' cells only ever appear on the control side of a pair.
    mask = conditions.isin(condition_list) & (conditions != 'ctrl')
    indices = np.where(mask)[0]

    # RandomState exposes the same shuffle/randint API as the np.random module.
    rng = np.random if seed is None else np.random.RandomState(seed)
    rng.shuffle(indices)

    save_path = Path(tok_dir if out_dir is None else out_dir) / split_name
    # np.savez does not create missing directories.
    save_path.mkdir(parents=True, exist_ok=True)

    log_totals = adata.obs['log_total_counts'].values
    n_controls = len(control_bank['X'])
    int16_max = np.iinfo(np.int16).max

    # One iteration per shard; the final shard simply holds the leftovers.
    shard_starts = range(0, len(indices), chunk_size)
    for shard_count, start in enumerate(tqdm(shard_starts)):
        chunk = indices[start:start + chunk_size]

        # Global control sampling (no batch matching), one control per case cell.
        ctrl_rows = rng.randint(n_controls, size=len(chunk))

        action_ids = np.array(
            [pert_to_id[clean_gears_name(p)] for p in conditions.iloc[chunk]]
        )
        if action_ids.max() > int16_max:
            raise ValueError(
                f'perturbation id {action_ids.max()} does not fit into int16'
            )

        np.savez(
            save_path / f'shard_{ds_name}_{split_name}_{shard_count:04d}.npz',
            control=control_bank['X'][ctrl_rows],
            control_total=control_bank['total'][ctrl_rows],
            # Cast while still sparse so the dense copy is allocated as float32.
            case=adata.X[chunk].astype(np.float32).toarray(),
            case_total=log_totals[chunk].astype(np.float32),
            action_ids=action_ids.astype(np.int16),
        )

In [None]:
# Write shards for the conditions assigned to the 'train' split.
write_shards('train', split_map['train'], dataset_name)

In [None]:
# Write shards for the conditions assigned to the 'val' split.
write_shards('val', split_map['val'], dataset_name)

In [None]:
# Write shards for the conditions assigned to the 'test' split.
write_shards('test', split_map['test'], dataset_name)