In [1]:
import json
import numpy as np
import pandas as pd
import scanpy as sc
import gseapy as gp
from tqdm import tqdm
from pathlib import Path
from gears import PertData, GEARS
from pathlib import Path

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
data_dir = Path('/Users/djemec/data/jepa/v0_2')
eval_dir = data_dir / 'tokenized'
splits= ['train','val','test']
dataset_name = 'norman'

In [3]:
chunk_size = 25000        # How many cells per file
n_pathways = 1024          # Number of pathway "tokens" per cell
n_genes = 8192 # 2**13
count_normalize_target = 1e4 # normalize each cell to this count

## Data Download

In [4]:
eval_dir / dataset_name

PosixPath('/Users/djemec/data/jepa/v0_2/tokenized/norman')

In [5]:
pert_data = PertData(eval_dir / dataset_name) 
pert_data.load(data_name='norman')

Downloading...
100%|███████████████████████████████████████████████████████████████████████████████████████| 9.46M/9.46M [00:00<00:00, 12.9MiB/s]
Downloading...
100%|█████████████████████████████████████████████████████████████████████████████████████████| 169M/169M [00:23<00:00, 7.06MiB/s]
Extracting zip file...
Done!
Downloading...
100%|█████████████████████████████████████████████████████████████████████████████████████████| 559k/559k [00:00<00:00, 2.48MiB/s]
These perturbations are not in the GO graph and their perturbation can thus not be predicted
['RHOXF2BB+ctrl' 'LYL1+IER5L' 'ctrl+IER5L' 'KIAA1804+ctrl' 'IER5L+ctrl'
 'RHOXF2BB+ZBTB25' 'RHOXF2BB+SET']
Creating pyg object for each cell in the data...
Creating dataset file...
100%|███████████████████████████████████████████████████████████████████████████████████████████| 277/277 [01:02<00:00,  4.44it/s]
Done!
Saving new dataset pyg object at /Users/djemec/data/jepa/v0_2/tokenized/norman/norman/data_pyg/cell_graphs.pkl
Done!


In [6]:
pert_data.prepare_split(split='simulation', seed=1) 
adata = pert_data.adata 

Creating new splits....
Saving new splits at /Users/djemec/data/jepa/v0_2/tokenized/norman/norman/splits/norman_simulation_1_0.75.pkl
Simulation split test composition:
combo_seen0:9
combo_seen1:43
combo_seen2:19
unseen_single:36
Done!


## Data Processing

In [7]:
total_counts = np.array(adata.X.sum(axis=1)).flatten()
adata.obs['log_total_counts'] = np.log1p(total_counts)

In [8]:
sc.pp.normalize_total(adata, target_sum=count_normalize_target)
sc.pp.log1p(adata)
sc.pp.highly_variable_genes(adata, n_top_genes=n_genes, subset=True)

In [9]:
genes = adata.var.gene_name.tolist()
print(f'Final Feature Space: {len(genes)} Genes')

with open(eval_dir / 'gene_names.json', 'w') as f:
    json.dump(genes, f)

Final Feature Space: 5035 Genes


## Prepare Controls

In [10]:
control_mask = adata.obs['condition'] == 'ctrl'
control_indices = np.where(control_mask)[0]

In [11]:
# Format: List of (Gene_Vector, Total_Count_Scalar)
control_bank = {
    'X': adata.X[control_indices].toarray().astype(np.float32),
    'total': adata.obs['log_total_counts'].values[control_indices].astype(np.float32)
}

In [12]:
print(f'Found {len(control_indices)} control cells.')

Found 7353 control cells.


## Perturbations

In [13]:
original_pert_map = data_dir / 'perturbation_map.json'
with open(original_pert_map, 'r') as f:
    training_map = json.load(f)

# invert map to find ID:gene name 
max_pert_id = max([v for k, v in training_map.items()])
max_pert_id

1807

In [14]:
def clean_gears_name(name):
    # GEARS format is 'Gene+ctrl' -> We want 'Gene'
    if name.endswith('+ctrl'):
        return name.replace('+ctrl', '')
    elif name == 'ctrl':
        return 'control'
    return name

In [15]:
all_perts = adata.obs['condition'].unique()
overlap = 0
new = 0
id_to_pert = {v:k for k, v in training_map.items()}
for a in all_perts:
    clean_name = clean_gears_name(a)
    if clean_name in training_map:
        overlap += 1
    else:
        max_pert_id += 1
        new += 1
        id_to_pert[max_pert_id] = clean_name
f'overlap {overlap} | new additions {new}'

'overlap 6 | new additions 271'

In [16]:
pert_to_id = {v: k for k, v in id_to_pert.items()}

In [17]:
with open(data_dir/'perturbation_map.json', 'w') as f:
    json.dump({str(k): int(v) for k, v in pert_to_id.items()}, f)

## Shard Save

In [18]:
split_map = pert_data.set2conditions 

In [19]:
def write_shards(split_name, condition_list, ds_name):
    """
    Iterates through cells belonging to the given conditions, 
    pairs them with random controls, and saves .npz shards.
    """
    print(f'Split: {split_name.upper()}')
    
    # Filter cells belonging to these perturbations
    # Note: We exclude 'ctrl' from the 'Treated' side of the pair
    mask = adata.obs['condition'].isin(condition_list) & (adata.obs['condition'] != 'ctrl')
    indices = np.where(mask)[0]
    
    # Shuffle for randomness
    np.random.shuffle(indices)
    
    # Buffer for current shard
    buffer = {
        'control_x': [], 
        'control_total': [],
        'case_x': [], 
        'case_total': [],
        'action_ids': []
    }
    
    shard_count = 0
    save_path = eval_dir / split_name
    
    for idx in tqdm(indices):
        # 1. Get Case Data
        case_x = adata.X[idx].toarray().flatten().astype(np.float32)
        case_tot = adata.obs['log_total_counts'].iloc[idx].astype(np.float32)
        pert_name = adata.obs['condition'].iloc[idx]
        
        # 2. Get Random Control Pair
        # Ideally we match batch, but Replogle K562 is often batch-corrected or single batch.
        # For simplicity/speed here, we sample global control.
        # (Improvement: dictionary mapping batch_id -> control_indices)
        rand_idx = np.random.randint(len(control_bank['X']))
        ctrl_x = control_bank['X'][rand_idx]
        ctrl_tot = control_bank['total'][rand_idx]
        
        # 3. Add to Buffer
        buffer['control_x'].append(ctrl_x)
        buffer['control_total'].append(ctrl_tot)
        buffer['case_x'].append(case_x)
        buffer['case_total'].append(case_tot)
        buffer['action_ids'].append(pert_to_id[clean_gears_name(pert_name)])
        
        # 4. Save if buffer full
        if len(buffer['case_x']) >= chunk_size:
            np.savez(
                save_path / f'shard_{split_name}_{shard_count:04d}.npz',
                control=np.array(buffer['control_x']),
                control_total=np.array(buffer['control_total']),
                case=np.array(buffer['case_x']),
                case_total=np.array(buffer['case_total']),
                action_ids=np.array(buffer['action_ids'], dtype=np.int16)
            )
            # Reset
            buffer = {k: [] for k in buffer}
            shard_count += 1
            
    # Save leftovers
    if len(buffer['case_x']) > 0:
        np.savez(
            save_path / f'shard_{ds_name}_{split_name}_{shard_count:04d}.npz',
            control=np.array(buffer['control_x']),
            control_total=np.array(buffer['control_total']),
            case=np.array(buffer['case_x']),
            case_total=np.array(buffer['case_total']),
            action_ids=np.array(buffer['action_ids'], dtype=np.int16)
        )

In [20]:
write_shards('train', split_map['train'], dataset_name)

Split: TRAIN


100%|████████████████████████████████████████████████████████████████████████████████████| 42496/42496 [00:01<00:00, 32743.71it/s]


In [21]:
write_shards('val', split_map['val'], dataset_name)

Split: VAL


100%|████████████████████████████████████████████████████████████████████████████████████| 10754/10754 [00:00<00:00, 46413.62it/s]


In [22]:
write_shards('test', split_map['test'], dataset_name)

Split: TEST


100%|████████████████████████████████████████████████████████████████████████████████████| 28754/28754 [00:00<00:00, 29609.94it/s]
