In [1]:
import json
import numpy as np
import pandas as pd
import scanpy as sc
import gseapy as gp
from tqdm import tqdm
from pathlib import Path
from gears import PertData, GEARS
from pathlib import Path

In [2]:
data_dir = Path('/home/ubuntu')
tok_dir = data_dir / 'pretraining'
splits= ['train','val','test']

In [3]:
chunk_size = 75000        # How many cells per file
n_pathways = 1024          # Number of pathway "tokens" per cell
n_genes = 8192 # 2**13
count_normalize_target = 1e4 # normalize each cell to this count
dataset_name = 'k562e'

## Data Download

In [4]:
pert_data = PertData(data_dir / dataset_name) 
pert_data.load(data_name='replogle_k562_essential')

Found local copy...
Found local copy...
Found local copy...
These perturbations are not in the GO graph and their perturbation can thus not be predicted
['C7orf26+ctrl' 'C14orf178+ctrl' 'RPS10-NUDT3+ctrl' 'SEM1+ctrl' 'FAU+ctrl']
Local copy of pyg dataset is detected. Loading...
Done!


In [5]:
pert_data.prepare_split(split='simulation', seed=1) 
adata = pert_data.adata 

Local copy of split is detected. Loading...
Simulation split test composition:
combo_seen0:0
combo_seen1:0
combo_seen2:0
unseen_single:272
Done!


here1


## Data Processing

In [6]:
total_counts = np.array(adata.X.sum(axis=1)).flatten()
adata.obs['log_total_counts'] = np.log1p(total_counts)

In [7]:
sc.pp.normalize_total(adata, target_sum=count_normalize_target)
sc.pp.log1p(adata)
sc.pp.highly_variable_genes(adata, n_top_genes=n_genes, subset=True)

**Load Known Gene Map**

In [8]:
gene_dir = data_dir / 'gene_names.json'
with open(gene_dir, 'r') as f:
    genes = json.load(f)
len(genes)

5000

In [9]:
genes == adata.var.gene_name.tolist()

True

## Prepare Data

## Shard Save

In [10]:
split_map = pert_data.set2conditions 

In [11]:
def write_shards(split_name, condition_list, ds_name):
    """
    Iterates through cells belonging to the given conditions, 
    pairs them with random controls, and saves .npz shards.
    """
    print(f'Split: {split_name.upper()}')
    
    # Filter cells belonging to these split
    mask = adata.obs['condition'].isin(condition_list)
    indices = np.where(mask)[0]
    
    # Shuffle for randomness
    np.random.shuffle(indices)
    
    # Buffer for current shard
    buffer = {
        'x': [], 
        'total': []
    }
    
    shard_count = 0
    save_path = tok_dir / split_name
    
    for idx in tqdm(indices):
        # 1. Get Data
        x = adata.X[idx].toarray().flatten().astype(np.float32)
        total = adata.obs['log_total_counts'].iloc[idx].astype(np.float32)
        
        # 3. Add to Buffer
        buffer['x'].append(x)
        buffer['total'].append(total)
        
        # 4. Save if buffer full
        if len(buffer['x']) >= chunk_size:
            np.savez(
                save_path / f'pt_shard_{ds_name}_{split_name}_{shard_count:04d}.npz',
                x=np.array(buffer['x']),
                total=np.array(buffer['total'])
            )
            # Reset
            buffer = {k: [] for k in buffer}
            shard_count += 1
            
    # Save leftovers
    if len(buffer['x']) > 0:
        np.savez(
                save_path / f'pt_shard_{ds_name}_{split_name}_{shard_count:04d}.npz',
                x=np.array(buffer['x']),
                total=np.array(buffer['total'])
        )

In [12]:
write_shards('train', split_map['train'], dataset_name)

Split: TRAIN


100%|████████████████████████████████████████████████████████████████| 112373/112373 [00:04<00:00, 23475.55it/s]


In [13]:
write_shards('val', split_map['val'], dataset_name)

Split: VAL


100%|██████████████████████████████████████████████████████████████████| 11044/11044 [00:00<00:00, 27423.93it/s]


In [14]:
write_shards('test', split_map['test'], dataset_name)

Split: TEST


100%|██████████████████████████████████████████████████████████████████| 38829/38829 [00:01<00:00, 28955.39it/s]
