## V0_2 Updates
1. Properly handle the train/val/test split using the split produced by GEARS.
2. Add ESM-2-based protein embeddings for the actions (perturbations).

In [1]:
import json
import numpy as np
import pandas as pd
import scanpy as sc
import gseapy as gp
from tqdm import tqdm
from pathlib import Path
from gears import PertData, GEARS
from pathlib import Path
from Bio import Entrez, SeqIO
import mygene
import gzip
from transformers import AutoTokenizer, AutoModel
import torch

In [2]:
# Root directory for everything this notebook reads and writes.
# NOTE(review): hardcoded absolute path — adjust for your machine.
data_dir = Path('/home/ubuntu')
# Parent directory of the per-split .npz shards written by write_shards.
tok_dir = data_dir / 'training'
# Split names (not referenced again in the cells visible here).
splits= ['train','val','test']

# Directory that holds the UniProt reference-proteome FASTA archive.
protein_dir = data_dir / 'uniprot'
# Directory for perturbation sequences, embeddings and the name -> id map.
pert_dir = data_dir / 'pert_embd'

In [3]:
# Tunable constants for preprocessing and sharding.
chunk_size = 50000        # How many cells per file
n_pathways = 1024          # Number of pathway "tokens" per cell
# Requested number of highly variable genes. The run recorded in this notebook
# reports a final feature space of 5000 genes, i.e. fewer than requested here.
n_genes = 8192 # 2**13
count_normalize_target = 1e4 # normalize each cell to this count
# Used as the PertData sub-directory name and inside every shard file name.
dataset_name = 'k562e'

## Data Download

In [4]:
# Create the GEARS data wrapper rooted at data_dir / dataset_name and load the
# 'replogle_k562_essential' dataset (downloaded on first use, see log below).
pert_data = PertData(data_dir / dataset_name) 
pert_data.load(data_name='replogle_k562_essential')

Downloading...
100%|██████████████████████████████████████████████████████████████████████| 9.46M/9.46M [00:00<00:00, 134MiB/s]
Downloading...
100%|███████████████████████████████████████████████████████████████████████| 670M/670M [00:25<00:00, 26.5MiB/s]
Extracting zip file...
Done!
Downloading...
100%|███████████████████████████████████████████████████████████████████████| 559k/559k [00:00<00:00, 30.2MiB/s]
These perturbations are not in the GO graph and their perturbation can thus not be predicted
['C7orf26+ctrl' 'C14orf178+ctrl' 'RPS10-NUDT3+ctrl' 'SEM1+ctrl' 'FAU+ctrl']
Creating pyg object for each cell in the data...
Creating dataset file...
100%|███████████████████████████████████████████████████████████████████████| 1088/1088 [07:54<00:00,  2.29it/s]
Done!
Saving new dataset pyg object at /home/ubuntu/k562e/replogle_k562_essential/data_pyg/cell_graphs.pkl
Done!


In [5]:
# Build the GEARS 'simulation' split with a fixed seed; the per-split condition
# lists are read later from pert_data.set2conditions.
pert_data.prepare_split(split='simulation', seed=1) 
# AnnData object that all following processing cells operate on.
adata = pert_data.adata 

Creating new splits....
Saving new splits at /home/ubuntu/k562e/replogle_k562_essential/splits/replogle_k562_essential_simulation_1_0.75.pkl
Simulation split test composition:
combo_seen0:0
combo_seen1:0
combo_seen2:0
unseen_single:272
Done!


## Data Processing

In [6]:
# Per-cell sum over all genes, stored log1p-transformed as an obs column so it
# is captured before the expression matrix is normalized.
per_cell_sum = adata.X.sum(axis=1)
total_counts = np.asarray(per_cell_sum).reshape(-1)
adata.obs['log_total_counts'] = np.log1p(total_counts)

In [7]:
# Scale every cell to count_normalize_target, log1p-transform, then keep the
# top n_genes highly variable genes (subset=True drops the others from adata).
# NOTE(review): this assumes adata.X still holds raw counts. The recorded run
# ends with 5000 genes although n_genes is 8192, so the loaded dataset may
# already be pre-filtered / pre-normalized — confirm before normalizing again.
sc.pp.normalize_total(adata, target_sum=count_normalize_target)
sc.pp.log1p(adata)
sc.pp.highly_variable_genes(adata, n_top_genes=n_genes, subset=True)

In [8]:
# Freeze the gene order of the feature space and persist it for downstream use.
genes = adata.var['gene_name'].tolist()
print(f'Final Feature Space: {len(genes)} Genes')

gene_names_path = data_dir / 'gene_names.json'
with open(gene_names_path, 'w') as out_file:
    json.dump(genes, out_file)

Final Feature Space: 5000 Genes


## Perturbation Handling

In [9]:
def clean_gears_name(name):
    """Convert a GEARS condition label into a plain perturbation name.

    'Gene+ctrl' becomes 'Gene', the bare label 'ctrl' becomes 'control', and
    any other label is returned unchanged.
    """
    if name == 'ctrl':
        return 'control'
    if not name.endswith('+ctrl'):
        return name
    return name.replace('+ctrl', '')

In [10]:
# Collect the cleaned perturbation names of all non-control conditions.
raw_conditions = adata.obs['condition'].unique()
all_perts = [clean_gears_name(cond) for cond in raw_conditions if cond != 'ctrl']
unique_perts = sorted(set(all_perts))
len(all_perts), len(unique_perts)

(1087, 1087)

**Uniprot Cache**

In [11]:
# Human reference proteome (UP000005640); download it once into protein_dir.
all_prots = 'https://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/reference_proteomes/Eukaryota/UP000005640/UP000005640_9606.fasta.gz'
#! wget -P {protein_dir} {all_prots}

# Select the archive by name rather than "first directory entry": iterdir()
# order is arbitrary and the directory may contain unrelated files.
fasta_candidates = sorted(protein_dir.glob('*.fasta.gz'))
if not fasta_candidates:
    raise FileNotFoundError(
        f'No *.fasta.gz file found in {protein_dir}; download {all_prots} first'
    )
fasta = fasta_candidates[0]

In [12]:
# Build a gene-symbol -> protein-sequence lookup from the reference proteome.
uniprot_db = {}
# Symbols whose stored sequence came from a reviewed (Swiss-Prot) entry.
reviewed_symbols = set()
with gzip.open(fasta, 'rt') as handle:
    for record in SeqIO.parse(handle, 'fasta'):
        # Header: >sp|P12345|GENE_HUMAN Description GN=Symbol PE=1 ...
        desc = record.description
        if 'GN=' not in desc:
            continue

        # Symbol is the first whitespace-delimited token after GN=
        symbol = desc.split('GN=')[1].split(' ')[0]

        # 'sp|' marks a reviewed Swiss-Prot entry, 'tr|' an unreviewed TrEMBL one.
        is_reviewed = record.id.startswith('sp|')
        if symbol in reviewed_symbols and not is_reviewed:
            # Several entries can share one symbol; never let an unreviewed
            # entry replace a reviewed sequence that is already stored.
            continue

        uniprot_db[symbol] = str(record.seq)
        if is_reviewed:
            reviewed_symbols.add(symbol)

**Normalize Gene Perturbation Names**

In [13]:
# Look up every perturbation name at MyGene.info, matching on both official
# symbols and aliases, and ask only for the 'symbol' field.
# One query can yield several hits (the log below reports duplicate hits), so
# len(mg_results) can exceed len(unique_perts).
mg = mygene.MyGeneInfo()
mg_results = mg.querymany(unique_perts, scopes='symbol,alias', fields='symbol', species='human')
len(mg_results),mg_results[:2]

Input sequence provided is already in string format. No operation performed
Input sequence provided is already in string format. No operation performed
85 input query terms found dup hits:	[('ACTB', 2), ('ALDOA', 2), ('ALG2', 2), ('ARPC4', 2), ('ATP1A1', 3), ('ATR', 4), ('BAP1', 3), ('BDP


(1184,
 [{'query': 'AAMP', '_id': '14', '_score': 17.595827, 'symbol': 'AAMP'},
  {'query': 'AARS', '_id': '16', '_score': 16.989502, 'symbol': 'AARS1'}])

In [14]:
# perturbation name -> protein sequence
final_map = {}
# perturbation names for which no sequence was found in uniprot_db
missing_genes = []

In [15]:
# Resolve every perturbation name to a protein sequence.
# MyGene can return several hits per query (alias matches). Letting each hit
# overwrite the previous one means an arbitrary, possibly unrelated gene wins.
# Instead keep the best hit per query: an exact symbol match first, otherwise
# the hit with the highest MyGene `_score`.
best_rank = {}  # query -> (is_exact_symbol_match, score) of the stored hit
for res in mg_results:
    original_query = res['query']

    if 'symbol' in res:
        modern_symbol = res['symbol']
    else:
        print(f'failed to find {res}')
        modern_symbol = original_query

    if modern_symbol not in uniprot_db:
        continue

    rank = (modern_symbol == original_query, res.get('_score', 0.0))
    if original_query not in best_rank or rank > best_rank[original_query]:
        best_rank[original_query] = rank
        final_map[original_query] = uniprot_db[modern_symbol]

# Queries that ended up without a sequence, de-duplicated in first-seen order
missing_genes = list(dict.fromkeys(
    res['query'] for res in mg_results if not final_map.get(res['query'])
))

missing_genes

['NEPRO', 'SMN2']

**Search NCBI for missing protein sequences**

In [16]:
# Fallback: fetch a RefSeq protein from NCBI for every still-missing gene.
# NOTE(review): placeholder contact address — replace with a real one.
Entrez.email = 'test@test.com'
results = {}
# Iterate over a copy because resolved genes are removed from missing_genes.
for gene in tqdm(missing_genes.copy()):
    try:
        # Specific search for RefSeq proteins in Humans
        term = f'{gene}[Gene Name] AND Homo sapiens[Organism] AND srcdb_refseq[PROP]'

        # 1. Search (always release the network handle, even on parse errors)
        search_handle = Entrez.esearch(db="protein", term=term, retmax=1)
        try:
            record = Entrez.read(search_handle)
        finally:
            search_handle.close()

        if record['IdList']:
            # 2. Fetch
            uid = record['IdList'][0]
            fetch_handle = Entrez.efetch(db="protein", id=uid, rettype="fasta", retmode="text")
            try:
                seq_record = SeqIO.read(fetch_handle, "fasta")
            finally:
                fetch_handle.close()
            results[gene] = str(seq_record.seq)
            missing_genes.remove(gene)
        else:
            # No record at all: store the one-residue placeholder "M" so the
            # gene still receives an embedding row.
            results[gene] = "M" # True Ghost Gene
    except Exception as e:
        # Best effort: a failed lookup falls back to the placeholder as well.
        print(f'Failed Entrez lookup for {gene}: {e}')
        results[gene] = "M"

final_map.update(results)

missing_genes

100%|█████████████████████████████████████████████████████████████████████████████| 2/2 [00:01<00:00,  1.53it/s]


[]

In [17]:
# Summary: how many perturbations received a real (non-placeholder) sequence.
total_perts = len(unique_perts)
found = sum(1 for seq in final_map.values() if seq != 'M')
f'Initial: {total_perts} | Found {found} | Missing {total_perts - found}'

'Initial: 1087 | Found 1087 | Missing 0'

**Saving Perturbations**

In [18]:
# Persist the name -> sequence map. Create the output directory first so this
# cell also works on a fresh machine where pert_dir does not exist yet.
pert_dir.mkdir(parents=True, exist_ok=True)
with open(pert_dir /'perturbation_seq.json', 'w') as f:
    json.dump({str(k): str(v) for k, v in final_map.items()}, f)

**Get Model For Embeddings**
For protein embeddings we use [ESM-2](https://huggingface.co/facebook/esm2_t6_8M_UR50D) (the `esm2_t6_8M_UR50D` checkpoint) for now, since it is a common workhorse model.

In [19]:
model_name = 'facebook/esm2_t6_8M_UR50D'
def get_device():
    """Return 'cuda' when a CUDA device is available, otherwise 'cpu'.

    Seeds the CUDA random generator with 1337 when CUDA is used and prints
    the selected device name.
    """
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.manual_seed(1337)
    device = 'cuda' if use_cuda else 'cpu'
    print(f'using {device}')
    return device

DEVICE = get_device()

using cuda


In [20]:
# Load the tokenizer and encoder weights for model_name.
# The "pooler ... newly initialized" warning printed below concerns the pooler
# weights only; the embedding cell reads last_hidden_state, not the pooler output.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

tokenizer_config.json:   0%|          | 0.00/95.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/93.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/775 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/31.4M [00:00<?, ?B/s]

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [21]:
# Switch the model to evaluation mode and move its weights to DEVICE.
model.eval()
model.to(DEVICE)

EsmModel(
  (embeddings): EsmEmbeddings(
    (word_embeddings): Embedding(33, 320, padding_idx=1)
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (encoder): EsmEncoder(
    (layer): ModuleList(
      (0-5): 6 x EsmLayer(
        (attention): EsmAttention(
          (self): EsmSelfAttention(
            (query): Linear(in_features=320, out_features=320, bias=True)
            (key): Linear(in_features=320, out_features=320, bias=True)
            (value): Linear(in_features=320, out_features=320, bias=True)
            (rotary_embeddings): RotaryEmbedding()
          )
          (output): EsmSelfOutput(
            (dense): Linear(in_features=320, out_features=320, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (LayerNorm): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
        )
        (intermediate): EsmIntermediate(
          (dense): Linear(in_features=320, out_features=1280, bias=True)
        )
        (output): EsmOutput(
        

**Get Embeddings**

In [22]:
# Alphabetical gene order defines the embedding row ids.
sorted_genes = sorted(final_map)

embeddings = []   # one pooled vector per gene, in sorted_genes order
pert_to_id = {}   # gene name -> row index

In [23]:
# Embed every perturbation's protein sequence with the ESM-2 encoder.
# Start from empty containers so re-running this cell does not append
# duplicate rows to `embeddings`.
embeddings.clear()
pert_to_id.clear()

with torch.no_grad():
    for idx, gene in enumerate(tqdm(sorted_genes)):
        seq = final_map[gene]

        # Map Gene Name -> Integer ID (row index in the embedding bank)
        pert_to_id[gene] = idx

        # Tokenize; max_length counts tokens, including the special
        # start/end tokens the tokenizer adds around the residues.
        inputs = tokenizer(seq, return_tensors="pt", truncation=True, max_length=1024)
        inputs = {k: v.to(DEVICE) for k, v in inputs.items()}

        # Forward Pass
        outputs = model(**inputs)

        # Mean Pooling: Compress [Seq_Len, 320] -> [320]
        # Average over the amino-acid positions only: the first and last
        # token are the special start/end markers, not residues.
        token_states = outputs.last_hidden_state[0]
        residue_states = token_states[1:-1]
        if residue_states.shape[0] == 0:
            # Nothing between the special tokens; avoid a NaN mean.
            residue_states = token_states
        phys_vector = residue_states.mean(dim=0).cpu().numpy()
        embeddings.append(phys_vector)

100%|██████████████████████████████████████████████████████████████████████| 1087/1087 [00:08<00:00, 133.17it/s]


In [24]:
# Reserve an id for the unperturbed 'control' action and give it a matching
# all-zero row. Without the extra row the id equals len(embeddings), i.e. it
# points one past the last row of the embedding bank saved from `embeddings`.
# The guard keeps the cell idempotent when it is re-run.
if 'control' not in pert_to_id:
    pert_to_id['control'] = len(embeddings)
    embeddings.append(np.zeros_like(embeddings[0]))

**Save Pert Mapping/Embeddings**

In [25]:
# Stack the per-gene vectors row-wise into a single array (one row per entry of `embeddings`).
final_bank = np.vstack(embeddings)

In [26]:
# Save the embedding bank into pert_dir (the directory must already exist).
np.save(pert_dir / 'action_embeddings_esm2.npy', final_bank)

In [27]:
# Write the name -> id mapping that accompanies the embedding bank.
pert_to_id_path = pert_dir / 'pert_to_id.json'
with open(pert_to_id_path, 'w') as out_file:
    json.dump(pert_to_id, out_file)

## Pathway Mask

In [28]:
# Load the human DSigDB gene-set library via gseapy; it is used below as a
# mapping of gene-set name -> collection of gene names.
gs_res = gp.get_library(name='DSigDB', organism='Human')

In [29]:
# Keep gene sets with 80..1400 members and take the first n_pathways of them.
valid_pathways = {}
for gs_name, gs_genes in gs_res.items():
    if 80 <= len(gs_genes) <= 1400:
        valid_pathways[gs_name] = gs_genes
pathway_names = list(valid_pathways)[:n_pathways]

# Save Pathway Names
pathway_names_path = data_dir / 'pathway_names.json'
with open(pathway_names_path, 'w') as out_file:
    json.dump(pathway_names, out_file)

In [30]:
# Gene x pathway membership matrix (all zeros for now) and a name -> row lookup.
mask_shape = (len(genes), len(pathway_names))
binary_mask = np.zeros(mask_shape, dtype=np.float32)
gene_to_idx = dict(zip(genes, range(len(genes))))

In [31]:
# Mark pathway membership column by column and report near-empty pathways.
for p_idx, p_name in enumerate(pathway_names):
    # Rows of all pathway genes that are part of the feature space
    member_rows = [gene_to_idx[g] for g in valid_pathways[p_name] if g in gene_to_idx]
    for row in member_rows:
        binary_mask[row, p_idx] = 1.0

    hit_count = len(member_rows)
    if hit_count <= 1:
        print(f'pathway {p_name} had {hit_count} gene hits')

In [32]:
# Persist the gene x pathway mask (rows follow `genes`, columns follow `pathway_names`).
np.save(data_dir / 'binary_pathway_mask.npy', binary_mask)


## Prepare Controls

In [33]:
# Positional indices of all unperturbed ('ctrl') cells.
control_mask = adata.obs['condition'] == 'ctrl'
control_indices = np.nonzero(control_mask.to_numpy())[0]

In [34]:
# Format: List of (Gene_Vector, Total_Count_Scalar)
control_bank = {
    'X': adata.X[control_indices].toarray().astype(np.float32),
    'total': adata.obs['log_total_counts'].values[control_indices].astype(np.float32)
}

In [35]:
# Sanity check: number of control cells available for pairing.
print(f'Found {len(control_indices)} control cells.')

Found 10691 control cells.


## Shard Save

In [36]:
# Split name ('train' / 'val' / 'test') -> conditions belonging to that split.
split_map = pert_data.set2conditions 

In [37]:
def _save_shard(path, buffer):
    """Write the buffered control/case pairs to ``path`` as one .npz archive."""
    np.savez(
        path,
        control=np.array(buffer['control_x']),
        control_total=np.array(buffer['control_total']),
        case=np.array(buffer['case_x']),
        case_total=np.array(buffer['case_total']),
        action_ids=np.array(buffer['action_ids'], dtype=np.int16)
    )


def write_shards(split_name, condition_list, ds_name, seed=None):
    """
    Iterates through cells belonging to the given conditions,
    pairs them with random controls, and saves .npz shards.

    Parameters
    ----------
    split_name : str
        Split label; used as the sub-directory of ``tok_dir`` and in file names.
    condition_list : iterable of str
        Conditions whose cells belong to this split ('ctrl' is always excluded).
    ds_name : str
        Dataset label embedded in every shard file name.
    seed : int or None, optional
        When given, shuffling and control sampling use a private
        ``RandomState(seed)`` and are reproducible. ``None`` (default) keeps
        using the global NumPy random state.

    Each shard holds at most ``chunk_size`` pairs under the keys ``control``,
    ``control_total``, ``case``, ``case_total`` and ``action_ids``.
    """
    print(f'Split: {split_name.upper()}')

    # np.random and RandomState expose the same shuffle/randint API.
    rng = np.random if seed is None else np.random.RandomState(seed)

    # Filter cells belonging to these perturbations
    # Note: We exclude 'ctrl' from the 'Treated' side of the pair
    mask = adata.obs['condition'].isin(condition_list) & (adata.obs['condition'] != 'ctrl')
    indices = np.where(mask)[0]

    # Shuffle for randomness
    rng.shuffle(indices)

    # Make sure the target directory exists before the first shard is written.
    save_path = tok_dir / split_name
    save_path.mkdir(parents=True, exist_ok=True)

    # Hoist per-cell lookups out of the loop (plain array indexing is much
    # cheaper than repeated pandas .iloc access).
    log_totals = adata.obs['log_total_counts'].values
    conditions = adata.obs['condition'].to_numpy()
    n_controls = len(control_bank['X'])

    buffer_keys = ('control_x', 'control_total', 'case_x', 'case_total', 'action_ids')
    buffer = {key: [] for key in buffer_keys}
    shard_count = 0

    for idx in tqdm(indices):
        # 1. Get Case Data
        case_x = adata.X[idx].toarray().flatten().astype(np.float32)
        case_tot = log_totals[idx].astype(np.float32)
        pert_name = conditions[idx]

        # 2. Get Random Control Pair
        # Controls are sampled from the global pool, without batch matching.
        # (Improvement: dictionary mapping batch_id -> control_indices)
        rand_idx = rng.randint(n_controls)

        # 3. Add to Buffer
        buffer['control_x'].append(control_bank['X'][rand_idx])
        buffer['control_total'].append(control_bank['total'][rand_idx])
        buffer['case_x'].append(case_x)
        buffer['case_total'].append(case_tot)
        buffer['action_ids'].append(pert_to_id[clean_gears_name(pert_name)])

        # 4. Save if buffer full
        if len(buffer['case_x']) >= chunk_size:
            _save_shard(save_path / f'shard_{ds_name}_{split_name}_{shard_count:04d}.npz', buffer)
            buffer = {key: [] for key in buffer_keys}
            shard_count += 1

    # Save leftovers
    if len(buffer['case_x']) > 0:
        _save_shard(save_path / f'shard_{ds_name}_{split_name}_{shard_count:04d}.npz', buffer)

In [38]:
# Write the shards for the training conditions.
write_shards('train', split_map['train'], dataset_name)

Split: TRAIN


100%|████████████████████████████████████████████████████████████████| 101682/101682 [00:06<00:00, 14758.92it/s]


In [39]:
# Write the shards for the validation conditions.
write_shards('val', split_map['val'], dataset_name)

Split: VAL


100%|██████████████████████████████████████████████████████████████████| 11044/11044 [00:00<00:00, 21644.71it/s]


In [40]:
# Write the shards for the held-out test conditions.
write_shards('test', split_map['test'], dataset_name)

Split: TEST


100%|██████████████████████████████████████████████████████████████████| 38829/38829 [00:01<00:00, 21438.19it/s]
