# v0.3 Updates
Merging multiple datasets together. Perturbations were already turned into embeddings. Because some datasets lack expression counts for some of the genes in our global gene map, we must distinguish a gene that was not measured (missing) from one that was measured as 0; every shard therefore stores the dataset's `valid_mask` alongside the expression values.

In [1]:
import json
import numpy as np
import pandas as pd
import scanpy as sc
from tqdm import tqdm
from pathlib import Path
import gc

In [2]:
# Every v0.3 input and output lives under this single root directory.
data_dir = Path('/Users/djemec/data/jepa/v0_3')
raw_dir = data_dir.joinpath('raw_files')            # raw .h5ad files and gene_and_perts.json
pert_dir = data_dir.joinpath('pert_embd')           # holds pert_to_id.json
training_dir = data_dir.joinpath('training')        # paired case/control shards, one sub-dir per split
pretraining_dir = data_dir.joinpath('pretraining')  # unpaired expression shards, one sub-dir per split
splits = ['train', 'val', 'test']

In [3]:
# Dataset key -> raw .h5ad file; the key is also used in every shard filename.
datasets = dict(
    rep1e=raw_dir / 'rep1e' / 'perturb_processed.h5ad',
    adamson=raw_dir / 'adamson' / 'perturb_processed.h5ad',
    k562gw=raw_dir / 'k562gw' / 'replogle_k562_gw_expanded_8k.h5ad',
    k562e=raw_dir / 'k562e' / 'ReplogleWeissman2022_K562_essential.h5ad',
)
# Datasets opened in backed (read-only, on-disk) mode instead of loaded into memory.
big_ds = ['k562gw']

In [4]:
# Rows per paired-training shard (also the row chunk used for reads) and per pretraining shard.
chunk_size = 10_000
pt_chunk_size = 15_000
# Each cell is scaled to this total count before log1p (size_factor = target / total_counts).
count_normalize_target = 1e4
# Fraction of each train / pretraining pool held back for validation.
val_split_pct = 0.05
# Running cell totals, accumulated across all datasets.
train_cells = test_cells = val_cells = 0
pt_train_cells = pt_val_cells = 0

**Load the previously identified genes and perturbations**

In [5]:
# gene_and_perts.json holds the global gene list and the perturbations held out for testing.
# JSON is UTF-8 by specification, so decode it explicitly rather than with the locale encoding.
with open(raw_dir / 'gene_and_perts.json', 'r', encoding='utf-8') as f:
    meta = json.load(f)
all_genes = meta['all_genes']
hold_out_perts = set(meta['hold_out_perts'])

n_genes = len(all_genes)
n_genes

8192

In [6]:
# pert_to_id.json maps each perturbation name to the integer action id stored in the shards.
# JSON is UTF-8 by specification, so decode it explicitly rather than with the locale encoding.
with open(pert_dir / 'pert_to_id.json', 'r', encoding='utf-8') as f:
    pert_to_id = json.load(f)
n_perts = len(pert_to_id)
n_perts, len(hold_out_perts)

(9876, 272)

**Create the global gene index map**

In [7]:
# Global gene name -> column in the n_genes-wide space that every dataset is scattered into.
gene_to_id = dict(zip(all_genes, range(len(all_genes))))
n_genes = len(all_genes)

In [8]:
def clean_gears_name(name):
    """Convert a GEARS-style condition label into a bare perturbation name.

    GEARS writes single-gene perturbations as ``'GENE+ctrl'`` (or
    ``'ctrl+GENE'``) and unperturbed cells as ``'ctrl'``. This maps:

    * ``'GENE+ctrl'`` / ``'ctrl+GENE'`` -> ``'GENE'``
    * ``'ctrl'`` -> ``'control'``
    * anything else -> the label with surrounding whitespace removed.

    Whitespace is stripped before matching, so padded labels are treated the
    same as clean ones. Only the leading/trailing ``ctrl`` affix is removed;
    other ``+ctrl`` substrings inside the label are left untouched.
    """
    name = name.strip()
    if name.endswith('+ctrl'):
        return name.removesuffix('+ctrl')
    if name.startswith('ctrl+'):
        return name.removeprefix('ctrl+')
    if name == 'ctrl':
        return 'control'
    return name

## Save file shards

In [9]:
def save_shard_split(ds_key, split_name, indices, adata, ctrl_bank_X, ctrl_bank_totals, 
                     local_indices, global_indices, dataset_valid_mask,
                     batch_col, batch_to_ctrl_indices, condition_col):
    """Write paired case/control shards for one split of one dataset.

    The cells in ``indices`` are processed in ``chunk_size``-row chunks. For
    each chunk, this function:

    * normalizes the cells as ``log1p(X * size_factor)``;
    * scatters them into the global ``n_genes``-wide gene space (genes this
      dataset does not measure stay 0 and are 0 in ``valid_mask``);
    * pairs each cell with a randomly drawn control, from the same batch
      when that batch has controls;
    * shuffles the rows and writes them to
      ``training_dir / split_name / shard_{ds_key}_{split_name}_{NNNN}.npz``.

    Args:
        ds_key: dataset name used in the shard filenames.
        split_name: split label ('train', 'val', 'test'); also the sub-directory.
        indices: row positions in ``adata`` belonging to this split.
        adata: AnnData whose ``obs`` has ``size_factor``, ``log_total_counts``,
            ``condition_col`` and, if given, ``batch_col``.
        ctrl_bank_X: control profiles already normalized into global gene space.
        ctrl_bank_totals: ``log_total_counts`` for each row of ``ctrl_bank_X``.
        local_indices, global_indices: paired column positions mapping dataset
            genes onto the global gene space.
        dataset_valid_mask: 1.0 where the dataset measures a global gene.
        batch_col: obs column used to batch-match controls, or None to draw
            from the whole control bank.
        batch_to_ctrl_indices: batch id -> row positions in ``ctrl_bank_X``.
        condition_col: obs column holding the perturbation label.

    Raises:
        ValueError: ``indices`` is non-empty but the control bank is empty.
    """
    if len(indices) == 0:
        return
    n_ctrl = len(ctrl_bank_X)
    if n_ctrl == 0:
        # Fail clearly instead of with numpy's "high <= 0" from randint below.
        raise ValueError(
            f'{ds_key}/{split_name}: control bank is empty, '
            f'cannot pair {len(indices)} cells with controls'
        )

    save_path = training_dir / split_name
    save_path.mkdir(parents=True, exist_ok=True)

    # Read rows in ascending order (presumably so backed/h5py reads see
    # increasing indices); rows are shuffled within each shard before saving.
    indices = np.sort(indices)
    saved_mask = dataset_valid_mask.astype(np.int8)
    meta_conds = adata.obs[condition_col].values
    meta_sf = adata.obs['size_factor'].values
    meta_log_tot = adata.obs['log_total_counts'].values
    meta_batches = adata.obs[batch_col].values if batch_col else None

    chunk_starts = range(0, len(indices), chunk_size)
    for shard_count, start in enumerate(tqdm(chunk_starts, desc=f'Processing {split_name}')):
        idx_chunk = indices[start:start + chunk_size]
        n_rows = len(idx_chunk)

        # 1. Read and normalize the case cells.
        raw_batch = adata.X[idx_chunk]
        if hasattr(raw_batch, 'toarray'):
            raw_batch = raw_batch.toarray()  # sparse -> dense
        sf_chunk = meta_sf[idx_chunk].reshape(-1, 1)
        norm_batch = np.log1p(raw_batch * sf_chunk).astype(np.float32)

        # 2. Scatter into the global gene space; unmeasured genes stay 0.
        case_global = np.zeros((n_rows, n_genes), dtype=np.float32)
        case_global[:, global_indices] = norm_batch[:, local_indices]

        # 3. Per-cell metadata.
        cond_chunk = meta_conds[idx_chunk]
        case_tot_chunk = meta_log_tot[idx_chunk].astype(np.float32)

        # 4. One control per case cell, batch-matched when possible.
        if batch_col:
            batch_ctrl_indices = np.zeros(n_rows, dtype=int)
            unique_batches, inverse = np.unique(meta_batches[idx_chunk], return_inverse=True)
            for b_idx, b_id in enumerate(unique_batches):
                mask = inverse == b_idx
                count = np.count_nonzero(mask)
                candidates = batch_to_ctrl_indices.get(b_id, [])
                if len(candidates) > 0:
                    batch_ctrl_indices[mask] = np.random.choice(candidates, size=count)
                else:
                    # No controls recorded for this batch: fall back to the whole bank.
                    batch_ctrl_indices[mask] = np.random.randint(n_ctrl, size=count)
        else:
            batch_ctrl_indices = np.random.randint(n_ctrl, size=n_rows)

        ctrl_vecs = ctrl_bank_X[batch_ctrl_indices]
        ctrl_tots = ctrl_bank_totals[batch_ctrl_indices]
        # int16 holds action ids up to 32767.
        act_ids = np.array([pert_to_id[clean_gears_name(c)] for c in cond_chunk], dtype=np.int16)

        # 5. Shuffle rows within the shard, then save.
        perm = np.random.permutation(n_rows)
        np.savez(
            save_path / f'shard_{ds_key}_{split_name}_{shard_count:04d}.npz',
            control=ctrl_vecs[perm],
            control_total=ctrl_tots[perm],
            case=case_global[perm],
            case_total=case_tot_chunk[perm],
            action_ids=act_ids[perm],
            valid_mask=saved_mask
        )




In [10]:
def save_pretrain_shard(ds_key, split_name, indices, adata, 
                        local_indices, global_indices, dataset_valid_mask):
    """Write unpaired expression shards for one pretraining split of one dataset.

    The cells in ``indices`` are processed in ``pt_chunk_size``-row chunks.
    Each chunk is normalized as ``log1p(X * size_factor)``, scattered into the
    global ``n_genes``-wide gene space (unmeasured genes stay 0), shuffled,
    and written to
    ``pretraining_dir / split_name / pt_shard_{ds_key}_{split_name}_{NNNN}.npz``.

    Args:
        ds_key: dataset name used in the shard filenames.
        split_name: split label ('train' or 'val'); also the sub-directory.
        indices: row positions in ``adata`` belonging to this split.
        adata: AnnData whose ``obs`` has ``size_factor`` and ``log_total_counts``.
        local_indices, global_indices: paired column positions mapping dataset
            genes onto the global gene space.
        dataset_valid_mask: 1.0 where the dataset measures a global gene.
    """
    # Same early exit as the paired-shard writer: nothing to write, touch nothing.
    if len(indices) == 0:
        return

    save_path = pretraining_dir / split_name
    save_path.mkdir(parents=True, exist_ok=True)

    # Ascending reads; rows are shuffled within each shard before saving.
    indices = np.sort(indices)
    saved_mask = dataset_valid_mask.astype(np.int8)
    meta_sf = adata.obs['size_factor'].values
    meta_log_tot = adata.obs['log_total_counts'].values

    chunk_starts = range(0, len(indices), pt_chunk_size)
    for shard_count, start in enumerate(tqdm(chunk_starts, desc=f'PT {split_name}')):
        idx_chunk = indices[start:start + pt_chunk_size]
        n_rows = len(idx_chunk)

        # 1. Read (sparse -> dense).
        raw_batch = adata.X[idx_chunk]
        if hasattr(raw_batch, 'toarray'):
            raw_batch = raw_batch.toarray()

        # 2. Library-size normalize and log-transform.
        sf_chunk = meta_sf[idx_chunk].reshape(-1, 1)
        norm_batch = np.log1p(raw_batch * sf_chunk).astype(np.float32)

        # 3. Scatter into the global gene space; unmeasured genes stay 0.
        x_global = np.zeros((n_rows, n_genes), dtype=np.float32)
        x_global[:, global_indices] = norm_batch[:, local_indices]

        # 4. Per-cell metadata.
        tot_chunk = meta_log_tot[idx_chunk].astype(np.float32)

        # 5. Shuffle rows within the shard, then save.
        perm = np.random.permutation(n_rows)
        np.savez(
            save_path / f'pt_shard_{ds_key}_{split_name}_{shard_count:04d}.npz',
            x=x_global[perm],
            total=tot_chunk[perm],
            valid_mask=saved_mask
        )

## Process datasets


In [11]:
def _row_metrics(block):
    """Return ``(row_sums, n_positive)`` for a 2-D dense or scipy-sparse block.

    ``row_sums`` is each row's sum and ``n_positive`` each row's number of
    entries strictly greater than 0; both are flat 1-D ndarrays.
    """
    row_sums = np.asarray(block.sum(axis=1)).ravel()
    n_positive = np.asarray((block > 0).sum(axis=1)).ravel()
    return row_sums, n_positive


def _load_with_metrics(ds_key, ds_path):
    """Read one dataset and compute per-cell total counts and detected genes.

    Datasets named in ``big_ds`` are opened backed (read-only, on disk). All
    others are loaded into memory and, when sparse, converted to CSR for fast
    row slicing. On both paths the metrics are computed in ``chunk_size``-row
    chunks with the same definition of a detected gene (value > 0), which
    keeps temporary memory bounded.

    Returns:
        ``(adata, total_counts, n_genes_per_cell)``.
    """
    backed = ds_key in big_ds
    if backed:
        adata = sc.read_h5ad(ds_path, backed='r')
    else:
        adata = sc.read_h5ad(ds_path)
        if hasattr(adata.X, 'tocsr'):
            print('Converting to CSR for fast row-slicing')
            adata.X = adata.X.tocsr()

    sums, positives = [], []
    # Progress bar only for the slow, disk-backed datasets.
    for start in tqdm(range(0, adata.n_obs, chunk_size), desc="Metrics", disable=not backed):
        chunk_sums, chunk_pos = _row_metrics(adata.X[start:start + chunk_size])
        sums.append(chunk_sums)
        positives.append(chunk_pos)

    if not sums:  # zero cells: np.concatenate would fail on an empty list
        return adata, np.zeros(0), np.zeros(0, dtype=np.int64)
    return adata, np.concatenate(sums), np.concatenate(positives)


def _align_genes(ds_key, adata):
    """Map this dataset's gene columns onto the global gene list ``gene_to_id``.

    Gene names come from ``var['gene_name']`` when present, else ``var_names``.

    Returns:
        local_indices: column positions in ``adata`` of genes found globally.
        global_indices: matching positions in the global ``n_genes`` space.
        dataset_valid_mask: float32 vector of length ``n_genes``; 1.0 where the
            dataset measures the global gene, 0.0 where it is missing.

    Raises:
        ValueError: no gene overlaps the global list.
    """
    if 'gene_name' in adata.var.columns:
        local_genes = adata.var['gene_name'].tolist()
    else:
        local_genes = adata.var_names.tolist()
    print(f'genes {local_genes[:5]}')

    pairs = [(local_i, gene_to_id[gene]) for local_i, gene in enumerate(local_genes)
             if gene in gene_to_id]
    print(f'overlapping genes {len(pairs)} | skipped genes {len(local_genes) - len(pairs)}')
    if not pairs:
        raise ValueError(f'{ds_key}: no genes overlap the global gene list')

    local_indices = np.array([local_i for local_i, _ in pairs])
    global_indices = np.array([global_i for _, global_i in pairs])
    dataset_valid_mask = np.zeros(n_genes, dtype=np.float32)
    dataset_valid_mask[global_indices] = 1.0
    return local_indices, global_indices, dataset_valid_mask


def _first_column(columns, candidates):
    """Return the first name in ``candidates`` that is in ``columns``, else None."""
    return next((col for col in candidates if col in columns), None)


def _build_control_bank(adata, ctrl_indices, local_indices, global_indices, batch_col):
    """Normalize the control cells into global gene space and group them by batch.

    Returns:
        ctrl_bank_X: float32 ``(len(ctrl_indices), n_genes)`` holding
            ``log1p(X * size_factor)``; 0 for genes the dataset lacks.
        ctrl_bank_totals: float32 ``log_total_counts`` of each control.
        batch_to_ctrl_indices: batch id -> list of row positions in the bank;
            empty when ``batch_col`` is falsy.
    """
    if len(ctrl_indices) == 0:
        # Avoid an empty fancy-index read on a (possibly backed) matrix.
        return np.zeros((0, n_genes), dtype=np.float32), np.zeros(0, dtype=np.float32), {}

    raw_ctrls = adata.X[ctrl_indices]
    if hasattr(raw_ctrls, 'toarray'):
        raw_ctrls = raw_ctrls.toarray()
    ctrl_sf = adata.obs['size_factor'].to_numpy()[ctrl_indices].reshape(-1, 1)
    norm_ctrls = np.log1p(raw_ctrls * ctrl_sf).astype(np.float32)

    ctrl_bank_X = np.zeros((len(ctrl_indices), n_genes), dtype=np.float32)
    ctrl_bank_X[:, global_indices] = norm_ctrls[:, local_indices]
    ctrl_bank_totals = adata.obs['log_total_counts'].to_numpy()[ctrl_indices].astype(np.float32)

    batch_to_ctrl_indices = {}
    if batch_col:
        ctrl_batches = adata.obs[batch_col].iloc[ctrl_indices].values
        for bank_pos, batch_id in enumerate(ctrl_batches):
            batch_to_ctrl_indices.setdefault(batch_id, []).append(bank_pos)
    return ctrl_bank_X, ctrl_bank_totals, batch_to_ctrl_indices


def _categorize_cells(conditions, ctrl_mask, keep_mask):
    """Route kept cells to the train pool, the held-out test set and the pretraining pool.

    Args:
        conditions: raw condition label of every row (1-D array).
        ctrl_mask: per-row bool, True for kept control cells.
        keep_mask: per-row bool, True for cells that passed filtering.

    Routing uses ``clean_gears_name`` of each label, computed once per
    distinct label rather than once per cell:
      * controls -> pretraining only
      * labels not in ``pert_to_id`` -> pretraining only (counted as unknown)
      * labels in ``hold_out_perts`` -> test only
      * all other labels -> train pool and pretraining

    Returns:
        ``(train_pool, test_indices, pt_pool, n_unknown)``; the index arrays
        hold ascending row positions.

    Raises:
        ValueError: a kept non-control cell has a missing (NaN) label.
    """
    candidates = np.flatnonzero(keep_mask & ~ctrl_mask)
    codes, labels = pd.factorize(conditions[candidates])
    if (codes < 0).any():
        raise ValueError('missing condition label on a non-control cell')
    cleaned = [clean_gears_name(label) for label in labels]
    known = np.array([c in pert_to_id for c in cleaned], dtype=bool)[codes]
    held = np.array([c in hold_out_perts for c in cleaned], dtype=bool)[codes]

    unknown = candidates[~known]
    test_indices = candidates[known & held]
    train_pool = candidates[known & ~held]

    pt_mask = ctrl_mask.copy()
    pt_mask[unknown] = True
    pt_mask[train_pool] = True
    return train_pool, test_indices, np.flatnonzero(pt_mask), len(unknown)


def process_dataset(ds_key, ds_path):
    """Turn one raw ``.h5ad`` dataset into paired training and pretraining shards.

    Steps:
      1. Load the data and compute per-cell ``total_counts``, ``n_genes``,
         ``log_total_counts`` and ``size_factor``
         (= ``count_normalize_target / total_counts``).
      2. Keep cells with at least 200 detected genes.
      3. Align the dataset's genes to the global gene list.
      4. Pick the batch and condition columns and build a normalized control bank.
      5. Route cells. Controls and unknown perturbations go to pretraining only.
         Held-out perturbations go to ``test``. All remaining cells go to
         ``train``/``val`` and also to pretraining.
      6. Write shards via ``save_shard_split`` / ``save_pretrain_shard``.

    Filtering uses row-index arrays on the parent AnnData instead of creating
    a view. anndata re-subsets the parent matrix on every ``.X`` access of a
    view; for a backed file that means reading the whole filtered matrix for
    each chunked read.

    NOTE(review): normalization computes ``log1p(X * size_factor)``, which
    assumes ``adata.X`` holds raw counts — confirm this for every input file.

    Returns:
        ``(n_train, n_val, n_test, n_pt_train, n_pt_val)`` cell counts written.

    Raises:
        KeyError: no recognised condition column in ``adata.obs``.
        ValueError: no gene overlaps the global list, or a kept non-control
            cell has a missing condition label.
    """
    print(f'Processing {ds_key}')

    adata, total_counts, n_genes_per_cell = _load_with_metrics(ds_key, ds_path)

    # Per-cell metrics and normalization factors.
    adata.obs['total_counts'] = total_counts
    adata.obs['n_genes'] = n_genes_per_cell
    adata.obs['log_total_counts'] = np.log1p(adata.obs['total_counts'])
    adata.obs['size_factor'] = count_normalize_target / adata.obs['total_counts']

    # Low-complexity filter, kept as a row mask (see docstring: no view).
    keep_mask = adata.obs['n_genes'].to_numpy() >= 200

    local_indices, global_indices, dataset_valid_mask = _align_genes(ds_key, adata)

    batch_col = _first_column(adata.obs.columns,
                              ['batch', 'batch_id', 'gem_group', 'sequencing_batch'])
    if batch_col:
        print(f'Batch column: {batch_col}')
    else:
        print('No batch column. using global pool')

    condition_col = _first_column(adata.obs.columns, ['perturbation', 'condition'])
    if condition_col is None:
        raise KeyError(f"{ds_key}: no condition column (expected 'perturbation' or 'condition')")
    print(f'Condition column: {condition_col}')

    # Control bank from kept control cells.
    cntrl_names = ['ctrl', 'control', 'non-targeting']
    ctrl_mask = adata.obs[condition_col].isin(cntrl_names).to_numpy(dtype=bool) & keep_mask
    ctrl_indices = np.flatnonzero(ctrl_mask)
    if len(ctrl_indices) == 0:
        print('!!!!!no control')
    ctrl_bank_X, ctrl_bank_totals, batch_to_ctrl_indices = _build_control_bank(
        adata, ctrl_indices, local_indices, global_indices, batch_col)

    # Categorize cells.
    train_pool, test_indices, pt_pool, n_unknown = _categorize_cells(
        adata.obs[condition_col].to_numpy(), ctrl_mask, keep_mask)
    # Double quotes outside so the nested expressions also parse before Python 3.12.
    print(f"Trainable: {len(train_pool)} | Test {len(test_indices)} | Skipped {n_unknown}")
    print(f'Pretraining: {len(pt_pool)}')

    # Train / val split of the trainable pool.
    np.random.shuffle(train_pool)
    num_val = int(len(train_pool) * val_split_pct)
    val_indices = train_pool[:num_val]
    train_indices = train_pool[num_val:]

    for split_name, split_indices in (('train', train_indices),
                                      ('val', val_indices),
                                      ('test', test_indices)):
        save_shard_split(ds_key, split_name, split_indices, adata, ctrl_bank_X, ctrl_bank_totals,
                         local_indices, global_indices, dataset_valid_mask,
                         batch_col, batch_to_ctrl_indices, condition_col)

    # Pretraining split, reusing the same validation fraction.
    np.random.shuffle(pt_pool)
    num_pt_val = int(len(pt_pool) * val_split_pct)
    pt_val_indices = pt_pool[:num_pt_val]
    pt_train_indices = pt_pool[num_pt_val:]

    for split_name, split_indices in (('train', pt_train_indices), ('val', pt_val_indices)):
        save_pretrain_shard(ds_key, split_name, split_indices, adata,
                            local_indices, global_indices, dataset_valid_mask)

    counts = (len(train_indices), len(val_indices), len(test_indices),
              len(pt_train_indices), len(pt_val_indices))

    # Release the backed file handle explicitly, then the large arrays.
    if adata.isbacked:
        adata.file.close()
    del adata
    del ctrl_bank_X
    gc.collect()

    print('\n')
    return counts
    
    

In [None]:
# Shard every dataset and accumulate the per-split cell totals.
for ds_name, ds_file in datasets.items():
    n_train, n_val, n_test, n_pt_train, n_pt_val = process_dataset(ds_name, ds_file)
    train_cells += n_train
    val_cells += n_val
    test_cells += n_test
    pt_train_cells += n_pt_train
    pt_val_cells += n_pt_val

train_cells, val_cells, test_cells

Processing rep1e
Converting to CSR for fast row-slicing
genes ['PLEKHN1', 'HES4', 'ISG15', 'AGRN', 'B3GALT6']
overlapping genes 3739 | skipped genes 1261
No batch column. using global pool
Condition column: condition
Trainable: 133225 | Test 18023 | Skipped 0
Pretraining: 144710


Processing train: 100%|█████████████████████████████████████████████████████████| 13/13 [00:12<00:00,  1.07it/s]
Processing val: 100%|█████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  1.35it/s]
Processing test: 100%|████████████████████████████████████████████████████████████| 2/2 [00:02<00:00,  1.00s/it]
PT train: 100%|█████████████████████████████████████████████████████████████████| 10/10 [00:11<00:00,  1.12s/it]
PT val: 100%|█████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  1.47it/s]




Processing adamson
Converting to CSR for fast row-slicing
genes ['AP006222.2', 'RP11-54O7.16', 'RP11-54O7.1', 'RP11-54O7.3', 'SAMD11']
overlapping genes 2089 | skipped genes 2971
No batch column. using global pool
Condition column: condition
Trainable: 39462 | Test 4878 | Skipped 0
Pretraining: 63721


Processing train: 100%|███████████████████████████████████████████████████████████| 4/4 [00:02<00:00,  1.66it/s]
Processing val: 100%|█████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  6.41it/s]
Processing test: 100%|████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  3.16it/s]
PT train: 100%|███████████████████████████████████████████████████████████████████| 5/5 [00:02<00:00,  1.69it/s]
PT val: 100%|█████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  5.17it/s]




Processing k562gw


Metrics: 100%|████████████████████████████████████████████████████████████████| 195/195 [00:39<00:00,  4.95it/s]


genes ['A1BG', 'AAAS', 'AACS', 'AAGAB', 'AAK1']
overlapping genes 8192 | skipped genes 0
Batch column: batch
Condition column: perturbation
