## V0_4 Updates
1. Change Perturbation Encoding

In [1]:
import json
import numpy as np
import pandas as pd
import scanpy as sc

from tqdm import tqdm
from pathlib import Path
from gears import PertData
from pathlib import Path
from Bio import Entrez, SeqIO
import mygene
import gzip
from transformers import AutoTokenizer, AutoModel, AutoModelForMaskedLM
import torch
from sklearn.model_selection import train_test_split


In [2]:
data_dir = Path('/home/ubuntu')
tok_dir = data_dir / 'training'
raw_dir = data_dir / 'raw_data'
protein_dir = data_dir / 'uniprot'
pert_dir = data_dir / 'pert_embd'

splits= ['train', 'val', 'test']

crispri_seq = raw_dir / 'crispri_seq.csv'
dataset_name = 'k562e'

In [4]:
chunk_size = 20000        # How many cells per file
n_genes = 8192 # 2**13
count_normalize_target = 1e4 

In [5]:
def get_device():
    device = 'cpu'
    if torch.cuda.is_available():
        torch.cuda.manual_seed(1337)
        device = 'cuda'
    print(f'using {device}')
    return device

device = get_device()

using cuda


In [6]:
# Constants for BioJEPA
modality_to_id = {
    'protein': 0,
    'chemical': 1,
    'dna': 2
}

mode_to_id = {
    'crispri': 0,
    'crispra': 1,
    'overexpression': 2,
    'knockout': 3,
    'inhibitor': 4,
    'agonist': 5,
    'degrader': 6,
    'binder': 7,
    'control': 8,
    'unknown': 9
}

## Data Download

In [7]:
pert_data = PertData(data_dir / dataset_name) 
pert_data.load(data_name='replogle_k562_essential')

Downloading...
100%|███████████████████████████████████████████████████████████████| 9.46M/9.46M [00:00<00:00, 120MiB/s]
Downloading...
100%|████████████████████████████████████████████████████████████████| 670M/670M [00:25<00:00, 26.4MiB/s]
Extracting zip file...
Done!
Downloading...
100%|████████████████████████████████████████████████████████████████| 559k/559k [00:00<00:00, 23.7MiB/s]
These perturbations are not in the GO graph and their perturbation can thus not be predicted
['C7orf26+ctrl' 'C14orf178+ctrl' 'RPS10-NUDT3+ctrl' 'SEM1+ctrl' 'FAU+ctrl']
Creating pyg object for each cell in the data...
Creating dataset file...
100%|████████████████████████████████████████████████████████████████| 1088/1088 [07:56<00:00,  2.28it/s]
Done!
Saving new dataset pyg object at /home/ubuntu/k562e/replogle_k562_essential/data_pyg/cell_graphs.pkl
Done!


In [8]:
pert_data.prepare_split(split='simulation', seed=1) 
adata = pert_data.adata 

Creating new splits....
Saving new splits at /home/ubuntu/k562e/replogle_k562_essential/splits/replogle_k562_essential_simulation_1_0.75.pkl
Simulation split test composition:
combo_seen0:0
combo_seen1:0
combo_seen2:0
unseen_single:272
Done!


In [9]:
adata.obs

Unnamed: 0_level_0,condition,cell_type,cov_drug_dose_name,dose_val,control,condition_name,split
cell_barcode,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
AAACCCAAGAAGCCAC-34,UBL5+ctrl,K562,K562_UBL5+ctrl_1+1,1+1,0,K562_UBL5+ctrl_1+1,train
AAAGGATTCTCTCGAC-42,UBL5+ctrl,K562,K562_UBL5+ctrl_1+1,1+1,0,K562_UBL5+ctrl_1+1,train
AACGGGAGTAATGATG-25,UBL5+ctrl,K562,K562_UBL5+ctrl_1+1,1+1,0,K562_UBL5+ctrl_1+1,train
AAGAACAAGCTAGATA-35,UBL5+ctrl,K562,K562_UBL5+ctrl_1+1,1+1,0,K562_UBL5+ctrl_1+1,train
AAGACTCTCTATTGTC-33,UBL5+ctrl,K562,K562_UBL5+ctrl_1+1,1+1,0,K562_UBL5+ctrl_1+1,train
...,...,...,...,...,...,...,...
TTATTGCCACGTGAGA-26,RPS2+ctrl,K562,K562_RPS2+ctrl_1+1,1+1,0,K562_RPS2+ctrl_1+1,train
TTGTTGTTCATGGATC-47,RPS2+ctrl,K562,K562_RPS2+ctrl_1+1,1+1,0,K562_RPS2+ctrl_1+1,train
TTTCACATCTCTTCAA-41,RPS2+ctrl,K562,K562_RPS2+ctrl_1+1,1+1,0,K562_RPS2+ctrl_1+1,train
TTTCAGTGTAGAGTTA-18,RPS2+ctrl,K562,K562_RPS2+ctrl_1+1,1+1,0,K562_RPS2+ctrl_1+1,train


## Add in sgRNA sequence

In [10]:
seq_df = pd.read_csv(crispri_seq)
seq_df.set_index('cell_barcode', inplace=True)
seq_df.head()

Unnamed: 0_level_0,gene,gem_group,gene_id,UMI_count,sgID_AB,guide_seq_a,guide_seq_b
cell_barcode,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
AAACCCAAGAAATCCA-27,NAF1,27,ENSG00000145414,11438.0,NAF1_+_164087918.23-P1P2|NAF1_-_164087674.23-P1P2,GGAGCCGTGAGCTTGTCCAG,GCCGCGACGGCGTTCAGAAC
AAACCCAAGAACTTCC-31,BUB1,31,ENSG00000169679,5342.0,BUB1_-_111435363.23-P1P2|BUB1_-_111435372.23-P1P2,GGACAAGCGCCGGGCCTCAG,GCGGGCCTCAGCGGAACCCA
AAACCCAAGAAGCCAC-34,UBL5,34,ENSG00000198258,17305.0,UBL5_-_9938639.23-P1P2|UBL5_+_9938801.23-P1P2,GGGTGAGGAGCTGGTGGCGT,GCCCAGGGCCGCGAACCCCG
AAACCCAAGAATAGTC-43,C9orf16,43,ENSG00000171159,30244.0,C9orf16_+_130922603.23-P1P2|C9orf16_+_13092264...,GGCCGGCGCCGGATGGAAGG,GGCCGCGCGACGATGGAACG
AAACCCAAGACAGCGT-28,TIMM9,28,ENSG00000100575,8407.0,TIMM9_-_58893843.23-P1P2|TIMM9_-_58893848.23-P1P2,GGGGACGGTTGAGCCTTGGG,GGGTTGAGCCTTGGGAGGGA


In [11]:
common_idx = adata.obs.index.intersection(seq_df.index)
adata = adata[common_idx]
seq_df = seq_df.loc[common_idx]

In [12]:
cols_to_add = ['gene','gene_id','gem_group', 'sgID_AB', 'guide_seq_a', 'guide_seq_b']
for c in cols_to_add:
    if c in seq_df.columns:
        adata.obs[c] = seq_df[c].values

In [13]:
adata.obs.head()

Unnamed: 0_level_0,condition,cell_type,cov_drug_dose_name,dose_val,control,condition_name,split,gene,gene_id,gem_group,sgID_AB,guide_seq_a,guide_seq_b
cell_barcode,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1
AAACCCAAGAAGCCAC-34,UBL5+ctrl,K562,K562_UBL5+ctrl_1+1,1+1,0,K562_UBL5+ctrl_1+1,train,UBL5,ENSG00000198258,34,UBL5_-_9938639.23-P1P2|UBL5_+_9938801.23-P1P2,GGGTGAGGAGCTGGTGGCGT,GCCCAGGGCCGCGAACCCCG
AAAGGATTCTCTCGAC-42,UBL5+ctrl,K562,K562_UBL5+ctrl_1+1,1+1,0,K562_UBL5+ctrl_1+1,train,UBL5,ENSG00000198258,42,UBL5_-_9938639.23-P1P2|UBL5_+_9938801.23-P1P2,GGGTGAGGAGCTGGTGGCGT,GCCCAGGGCCGCGAACCCCG
AACGGGAGTAATGATG-25,UBL5+ctrl,K562,K562_UBL5+ctrl_1+1,1+1,0,K562_UBL5+ctrl_1+1,train,UBL5,ENSG00000198258,25,UBL5_-_9938639.23-P1P2|UBL5_+_9938801.23-P1P2,GGGTGAGGAGCTGGTGGCGT,GCCCAGGGCCGCGAACCCCG
AAGAACAAGCTAGATA-35,UBL5+ctrl,K562,K562_UBL5+ctrl_1+1,1+1,0,K562_UBL5+ctrl_1+1,train,UBL5,ENSG00000198258,35,UBL5_-_9938639.23-P1P2|UBL5_+_9938801.23-P1P2,GGGTGAGGAGCTGGTGGCGT,GCCCAGGGCCGCGAACCCCG
AAGACTCTCTATTGTC-33,UBL5+ctrl,K562,K562_UBL5+ctrl_1+1,1+1,0,K562_UBL5+ctrl_1+1,train,UBL5,ENSG00000198258,33,UBL5_-_9938639.23-P1P2|UBL5_+_9938801.23-P1P2,GGGTGAGGAGCTGGTGGCGT,GCCCAGGGCCGCGAACCCCG


**Check that all non-control rows have been updated**

In [14]:
is_case = adata.obs['condition'] != 'ctrl'
cols_to_check = ['guide_seq_a', 'guide_seq_b']

In [15]:
has_missing_seq = (
    adata.obs[cols_to_check].isnull() | 
    (adata.obs[cols_to_check] == '')
).any(axis=1)

len(adata.obs[is_case & has_missing_seq])

0

## Data Processing

In [16]:
total_counts = np.array(adata.X.sum(axis=1)).flatten()
adata.obs['log_total_counts'] = np.log1p(total_counts)

In [17]:
sc.pp.normalize_total(adata, target_sum=count_normalize_target)
sc.pp.log1p(adata)
sc.pp.highly_variable_genes(adata, n_top_genes=n_genes, subset=True)

In [18]:
genes = adata.var.gene_name.tolist()
print(f'Final Feature Space: {len(genes)} Genes')

with open(data_dir / 'gene_names.json', 'w') as f:
    json.dump(genes, f)

Final Feature Space: 5000 Genes


## Perturbation Handling

In [19]:
def clean_gears_name(name):
    # GEARS format is 'Gene+ctrl' -> We want 'Gene'
    if name.endswith('+ctrl'):
        return name.replace('+ctrl', '')
    if name.startswith('ctrl+'):
        return name.replace('ctrl+', '')
    elif name == 'ctrl':
        return 'control'
    return str.strip(name)

In [20]:
all_perts = adata.obs['condition'].unique()
all_perts = [clean_gears_name(p) for p in all_perts if p != 'ctrl']
unique_conditions = sorted(list(set(all_perts)))
len(all_perts), len(unique_conditions)

(1087, 1087)

**Extract ensembl IDs**

In [21]:
target_ids = adata.obs['gene_id'].unique().tolist()
target_ids = [t for t in target_ids if isinstance(t, str) and t.startswith('ENSG')]
len(target_ids)

1087

**Query to get uniprot IDs**

In [22]:
mg = mygene.MyGeneInfo()
mg_results = mg.querymany(target_ids, scopes='ensembl.gene,symbol,alias', fields='uniprot,symbol', species='human')

Input sequence provided is already in string format. No operation performed
Input sequence provided is already in string format. No operation performed


In [23]:
ensg_to_acc = {}
for res in mg_results:
    query_ensg = res['query']
    if 'uniprot' in res:
        # Prefer Swiss-Prot (Reviewed) -> TrEMBL (Unreviewed)
        if 'Swiss-Prot' in res['uniprot']:
            val = res['uniprot']['Swiss-Prot']
            # Handle list vs string
            ensg_to_acc[query_ensg] = val[0] if isinstance(val, list) else val
        elif 'TrEMBL' in res['uniprot']:
            val = res['uniprot']['TrEMBL']
            ensg_to_acc[query_ensg] = val[0] if isinstance(val, list) else val

**Read uniprot FASTA**

In [24]:
all_prots = 'https://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/reference_proteomes/Eukaryota/UP000005640/UP000005640_9606.fasta.gz'
#! wget -P {protein_dir} {all_prots}

fasta = [i for i in protein_dir.iterdir()][0]
needed_accessions = set(ensg_to_acc.values())
acc_to_seq = {}
fasta

PosixPath('/home/ubuntu/uniprot/UP000005640_9606.fasta.gz')

In [25]:
with gzip.open(fasta, 'rt') as handle:
    for record in SeqIO.parse(handle, 'fasta'):
        # Header format: >sp|P12345|GENE_HUMAN ...
        # We split by '|' to get P12345 safely
        header_parts = record.id.split('|')
        
        # Check if valid format >db|Accession|...
        if len(header_parts) >= 2:
            accession = header_parts[1]
            
            # Only keep if linked to our genes
            if accession in needed_accessions:
                acc_to_seq[accession] = str(record.seq)

In [26]:
id_to_seq = {}
missing_targets = []

In [27]:
for ensg in target_ids:
    acc = ensg_to_acc.get(ensg)
    if acc and acc in acc_to_seq:
        id_to_seq[ensg] = acc_to_seq[acc]
    else:
        missing_targets.append(ensg)

In [28]:
missing_targets

['ENSG00000196531']

**Search NCBI for missing genes**

In [29]:
Entrez.email = 'gptomics@gmail.com'
results = {}

In [30]:
for ensg_id in tqdm(missing_targets.copy()):
    try:
        search_handle = Entrez.esearch(db="gene", term=ensg_id, retmax=1)
        search_record = Entrez.read(search_handle)
        
        if search_record['IdList']:
            ncbi_gene_id = search_record['IdList'][0]
            
            # Step 2: Link Gene -> Protein (RefSeq)
            # This finds the "representative" protein for this gene
            link_handle = Entrez.elink(dbfrom="gene", db="protein", id=ncbi_gene_id, linkname="gene_protein_refseq")
            link_record = Entrez.read(link_handle)
            
            # Check if links exist
            if link_record and link_record[0]['LinkSetDb']:
                # Get the first linked Protein ID
                protein_id = link_record[0]['LinkSetDb'][0]['Link'][0]['Id']
                
                # Step 3: Fetch Sequence
                fetch_handle = Entrez.efetch(db="protein", id=protein_id, rettype="fasta", retmode="text")
                seq_record = SeqIO.read(fetch_handle, "fasta")
                
                results[ensg_id] = str(seq_record.seq)
                missing_targets.remove(ensg_id)
            else:
                results[ensg_id] = 'M' # Found gene, but no RefSeq protein linked
        else:
            results[ensg_id] = 'M' 
    except Exception as e:
        print(f'Failed Entrez lookup for {ensg_id}: {e}')
        results[ensg_id] = 'M'
id_to_seq.update(results)

100%|██████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  1.05it/s]


In [31]:
found = len([v for v in id_to_seq.values() if v != 'M'])

f'Initial: {len(target_ids)} | Found {found} | Missing {len(missing_targets)}'

'Initial: 1087 | Found 1087 | Missing 0'

**Saving Perturbations**

In [32]:
with open(pert_dir /'perturbation_seq.json', 'w') as f:
    json.dump({str(k): str(v) for k, v in id_to_seq.items()}, f)

### Get Embeddings for Perturbations

**Get Protein Embeddings**
for protein embeddings, we'll use the [ESM-2](https://huggingface.co/facebook/esm2_t6_8M_UR50D) for now given it's been a common workhorse. 

In [33]:
prot_model_name = 'facebook/esm2_t6_8M_UR50D' 

In [34]:
prot_tokenizer = AutoTokenizer.from_pretrained(prot_model_name)
prot_model = AutoModel.from_pretrained(prot_model_name)
prot_model.to(device).eval()

tokenizer_config.json:   0%|          | 0.00/95.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/93.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/775 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/31.4M [00:00<?, ?B/s]

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


EsmModel(
  (embeddings): EsmEmbeddings(
    (word_embeddings): Embedding(33, 320, padding_idx=1)
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (encoder): EsmEncoder(
    (layer): ModuleList(
      (0-5): 6 x EsmLayer(
        (attention): EsmAttention(
          (self): EsmSelfAttention(
            (query): Linear(in_features=320, out_features=320, bias=True)
            (key): Linear(in_features=320, out_features=320, bias=True)
            (value): Linear(in_features=320, out_features=320, bias=True)
            (rotary_embeddings): RotaryEmbedding()
          )
          (output): EsmSelfOutput(
            (dense): Linear(in_features=320, out_features=320, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (LayerNorm): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
        )
        (intermediate): EsmIntermediate(
          (dense): Linear(in_features=320, out_features=1280, bias=True)
        )
        (output): EsmOutput(
        

**Get Embeddings**

In [35]:
prot_vectors = []
prot_index_map = {}

In [36]:
with torch.no_grad():
    for idx, (ensg, seq) in enumerate(tqdm(id_to_seq.items())):
        inputs = prot_tokenizer(seq, return_tensors='pt', truncation=True, max_length=1024).to(device)
        out = prot_model(**inputs)
        emb = out.last_hidden_state[0].mean(dim=0).cpu().numpy()
        
        prot_vectors.append(emb)
        prot_index_map[ensg] = idx # Mapping is now ENSG -> Index

100%|███████████████████████████████████████████████████████████████| 1087/1087 [00:08<00:00, 126.62it/s]


**Save Protein Mapping/Embeddings**

In [37]:
final_bank = np.vstack(prot_vectors)

In [38]:
np.save(pert_dir / 'action_embeddings_esm2.npy', final_bank)

In [39]:
with open(pert_dir / 'anchor_to_id.json', 'w') as f:
    json.dump(prot_index_map, f)

**Get sgRNA embeddings**

In [40]:
dna_model_name = 'InstaDeepAI/NTv3_650M_pre' 

In [42]:
dna_tokenizer = AutoTokenizer.from_pretrained(dna_model_name, trust_remote_code=True)
dna_model = AutoModelForMaskedLM.from_pretrained(dna_model_name, trust_remote_code=True)
dna_model.to(device).eval()

tokenizer_config.json:   0%|          | 0.00/1.48k [00:00<?, ?B/s]

tokenization_ntv3.py:   0%|          | 0.00/7.85k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/InstaDeepAI/ntv3_base_model:
- tokenization_ntv3.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


vocab.json:   0%|          | 0.00/138 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/149 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.75k [00:00<?, ?B/s]

configuration_ntv3_pretrained.py:   0%|          | 0.00/8.09k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/InstaDeepAI/ntv3_base_model:
- configuration_ntv3_pretrained.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


modeling_ntv3_pretrained.py:   0%|          | 0.00/35.2k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/InstaDeepAI/ntv3_base_model:
- modeling_ntv3_pretrained.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


model.safetensors:   0%|          | 0.00/2.61G [00:00<?, ?B/s]

NTv3PreTrained(
  (core): Core(
    (embed_layer): Embedding(11, 16)
    (stem): Stem(
      (conv): Conv1d(16, 1536, kernel_size=(15,), stride=(1,), padding=same)
    )
    (conv_tower_blocks): ModuleList(
      (0-6): 7 x ConvTowerBlock(
        (conv): ConvBlock(
          (conv): Conv1d(1536, 1536, kernel_size=(5,), stride=(1,), padding=same)
          (layer_norm): LayerNormFP32((np.int64(1536),), eps=1e-05, elementwise_affine=True)
        )
        (res_conv): ResidualConvBlock(
          (conv_block): ConvBlock(
            (conv): Conv1d(1536, 1536, kernel_size=(1,), stride=(1,), padding=same)
            (layer_norm): LayerNormFP32((np.int64(1536),), eps=1e-05, elementwise_affine=True)
          )
        )
        (avg_pool): AvgPool1d(kernel_size=(2,), stride=(2,), padding=(0,))
      )
    )
    (transformer_blocks): ModuleList(
      (0-11): 12 x SelfAttentionBlock(
        (self_attention_layer_norm): LayerNormFP32((1536,), eps=1e-05, elementwise_affine=True)
        (fi

In [43]:
guide_pairs = adata.obs[['sgID_AB', 'guide_seq_a', 'guide_seq_b']].drop_duplicates().dropna()
dna_vectors = []
dna_index_map = {}

In [44]:
def get_ntv3_embedding(sequence, tokenizer, model, device):
    # Tokenize with PADDING to 128
    inputs = tokenizer(
        [sequence], 
        return_tensors="pt",
        add_special_tokens=False, 
        padding="max_length",    
        max_length=128,          
    )
    
    # Manual Mask Generation (Fixes KeyError)
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
    attention_mask = (inputs['input_ids'] != pad_id).long()
    inputs = {k: v.to(device) for k, v in inputs.items() if k != 'attention_mask'} 
    attention_mask = attention_mask.to(device)
    
    # Forward Pass
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)
        
    # Masked Mean Pooling
    token_embeddings = outputs.hidden_states[-1] # Shape: [1, 128, 1536]
    
    # Expand mask to match embedding dimensions 
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    
    # Sum valid tokens / Count valid tokens
    sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
    
    # Avoid division by zero
    sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    
    # Result is [1, 1536] -> squeeze to [1536]
    return (sum_embeddings / sum_mask).squeeze(0)

In [45]:
with torch.no_grad():
    for i, (barcode, row) in enumerate(tqdm(guide_pairs.iterrows(), total=len(guide_pairs))):
        sg_id = row['sgID_AB']
        seq_a = row['guide_seq_a'].upper()
        seq_b = row['guide_seq_b'].upper()
        
        # Embed A & B using the helper
        emb_a = get_ntv3_embedding(seq_a, dna_tokenizer, dna_model, device)
        emb_b = get_ntv3_embedding(seq_b, dna_tokenizer, dna_model, device)
        
        # Average (still on GPU)
        combined_tensor = (emb_a + emb_b) / 2
        
        # Move to CPU/Numpy for storage
        combined = combined_tensor.cpu().numpy()
        
        dna_vectors.append(combined)
        dna_index_map[sg_id] = i

100%|████████████████████████████████████████████████████████████████| 1250/1250 [00:53<00:00, 23.17it/s]


**Save sgRNA embeddings**

In [46]:
np.save(pert_dir / 'input_embeddings_dna.npy', np.vstack(dna_vectors))
with open(pert_dir / 'input_to_id.json', 'w') as f:
    json.dump(dna_index_map, f)

## Create Pertrubation Dataset

we end up with a longer list than targets since we have different cripri targets to the same genes

In [47]:
valid_pairs = []
pert_metadata = {} 

In [48]:
for i, (barcode, row) in enumerate(tqdm(guide_pairs.iterrows(), total=len(guide_pairs))):
    sg_id = row['sgID_AB']
    matching = adata.obs[adata.obs['sgID_AB'] == sg_id]
    if len(matching) == 0: continue
    
    target_ensg = matching['gene_id'].iloc[0]
    
    if target_ensg in prot_index_map and sg_id in dna_index_map:
        this_mod = modality_to_id['dna']
        this_mode = mode_to_id['crispri']
        
        # 1. Add to Alignment Training Pairs
        valid_pairs.append((
            dna_index_map[sg_id], 
            prot_index_map[target_ensg], 
            this_mod, 
            this_mode
        ))
        
        # 2. Register for Sharding Step
        pert_metadata[sg_id] = (this_mod, this_mode)
        
len(valid_pairs), len(target_ids)

100%|███████████████████████████████████████████████████████████████| 1250/1250 [00:05<00:00, 219.97it/s]


(1153, 1087)

In [49]:
def save_alignment(pairs, name):
    if not pairs: return
    inp, anch, mod, mode = zip(*pairs)
    np.savez(pert_dir / name / f'pert_pairs_cripri_{name}.npz', 
             input_idx=np.array(inp, dtype=np.int32), 
             anchor_idx=np.array(anch, dtype=np.int32),
             modality=np.array(mod, dtype=np.int8), 
             mode=np.array(mode, dtype=np.int8))

In [50]:
train_pairs, val_pairs = train_test_split(valid_pairs, test_size=0.1, random_state=42)
save_alignment(train_pairs, 'train')
save_alignment(val_pairs, 'val')

## Prepare Controls

In [51]:
control_groups = {}
ctrl_inds = np.where(adata.obs['condition'] == 'ctrl')[0]



In [52]:
for idx in ctrl_inds:
    grp = adata.obs.iloc[idx]['gem_group']
    if grp not in control_groups: control_groups[grp] = {'X': [], 'total': []}
    control_groups[grp]['X'].append(adata.X[idx].toarray().flatten())
    control_groups[grp]['total'].append(adata.obs['log_total_counts'].iloc[idx])

In [53]:
for g in control_groups:
    control_groups[g]['X'] = np.array(control_groups[g]['X'], dtype=np.float32)
    control_groups[g]['total'] = np.array(control_groups[g]['total'], dtype=np.float32)

In [54]:
print(f'Found {len(control_groups.keys())} control cells.')

Found 48 control cells.


## Shard Save

In [55]:
def write_shards(split_name, conditions, ds_name):
    """
    Iterates through cells belonging to the given conditions, 
    pairs them with random controls, and saves .npz shards.
    """
    print(f'Split: {split_name.upper()}')
    
    # Filter cells belonging to these perturbations
    mask = adata.obs['condition'].isin(conditions) & (adata.obs['condition'] != 'ctrl')
    indices = np.where(mask)[0]
    np.random.shuffle(indices)
    
    # Buffer for current shard
    buffer = {k: [] for k in ['control_x', 'control_total', 'case_x', 'case_total', 'pert_idx', 'pert_mod', 'pert_mode']}
    
    shard_count = 0
    save_path = tok_dir / split_name
    
    for idx in tqdm(indices):
        # get group/pert info
        grp = adata.obs.iloc[idx]['gem_group']
        sg_id = adata.obs.iloc[idx]['sgID_AB']

        # validate exists in controls and perturbations
        if grp not in control_groups or sg_id not in pert_metadata: 
            print(f'skipping {idx} |grp {grp} |sg_id {sg_id}')
            continue
        
        # extract perturbation metadata
        p_mod, p_mode = pert_metadata[sg_id]

        # Sample Control
        c_bank = control_groups[grp]
        ri = np.random.randint(len(c_bank['X']))
        
        # 3. Add to Buffer
        buffer['control_x'].append(c_bank['X'][ri])
        buffer['control_total'].append(c_bank['total'][ri])
        buffer['case_x'].append(adata.X[idx].toarray().flatten())
        buffer['case_total'].append(adata.obs['log_total_counts'].iloc[idx])
        buffer['pert_idx'].append(dna_index_map[sg_id])
        buffer['pert_mod'].append(p_mod)  
        buffer['pert_mode'].append(p_mode)
        
        # 4. Save if buffer full
        if len(buffer['case_x']) >= chunk_size:
            np.savez(save_path/ f'shard_{ds_name}_{split_name}_{shard_count:04d}.npz',
                     control=np.array(buffer['control_x'], dtype=np.float32),
                     control_total=np.array(buffer['control_total'], dtype=np.float32),
                     case=np.array(buffer['case_x'], dtype=np.float32),
                     case_total=np.array(buffer['case_total'], dtype=np.float32),
                     pert_idx=np.array(buffer['pert_idx'], dtype=np.int32),
                     pert_modality=np.array(buffer['pert_mod'], dtype=np.int8),
                     pert_mode=np.array(buffer['pert_mode'], dtype=np.int8))
            buffer = {k: [] for k in buffer}
            shard_count += 1
            
    # Save leftovers
    if buffer['case_x']:
         np.savez(save_path / f'shard_{ds_name}_{split_name}_{shard_count:04d}.npz',
                     control=np.array(buffer['control_x'], dtype=np.float32),
                     control_total=np.array(buffer['control_total'], dtype=np.float32),
                     case=np.array(buffer['case_x'], dtype=np.float32),
                     case_total=np.array(buffer['case_total'], dtype=np.float32),
                     pert_idx=np.array(buffer['pert_idx'], dtype=np.int32),
                     pert_modality=np.array(buffer['pert_mod'], dtype=np.int8),
                     pert_mode=np.array(buffer['pert_mode'], dtype=np.int8))

In [56]:
split_map = pert_data.set2conditions 

In [57]:
write_shards('train', split_map['train'], dataset_name)

Split: TRAIN


100%|██████████████████████████████████████████████████████████| 101682/101682 [00:36<00:00, 2814.64it/s]


In [58]:
write_shards('val', split_map['val'], dataset_name)

Split: VAL


100%|████████████████████████████████████████████████████████████| 11044/11044 [00:03<00:00, 3009.56it/s]


In [59]:
write_shards('test', split_map['test'], dataset_name)

Split: TEST


100%|████████████████████████████████████████████████████████████| 38829/38829 [00:13<00:00, 2900.80it/s]


## Write Pretraining

In [60]:
pt_chunk_size = 50000     
pt_tok_dir = data_dir / 'pretraining'

In [61]:
def write_pt_shards(split_name, condition_list, ds_name):
    """
    Iterates through cells belonging to the given conditions, 
    pairs them with random controls, and saves .npz shards.
    """
    print(f'Split: {split_name.upper()}')
    
    # Filter cells belonging to these split
    mask = adata.obs['condition'].isin(condition_list)
    indices = np.where(mask)[0]
    
    # Shuffle for randomness
    np.random.shuffle(indices)
    
    # Buffer for current shard
    buffer = {
        'x': [], 
        'total': []
    }
    
    shard_count = 0
    save_path = pt_tok_dir / split_name
    
    for idx in tqdm(indices):
        # 1. Get Data
        x = adata.X[idx].toarray().flatten().astype(np.float32)
        total = adata.obs['log_total_counts'].iloc[idx].astype(np.float32)
        
        # 3. Add to Buffer
        buffer['x'].append(x)
        buffer['total'].append(total)
        
        # 4. Save if buffer full
        if len(buffer['x']) >= pt_chunk_size:
            np.savez(
                save_path / f'pt_shard_{ds_name}_{split_name}_{shard_count:04d}.npz',
                x=np.array(buffer['x']),
                total=np.array(buffer['total'])
            )
            # Reset
            buffer = {k: [] for k in buffer}
            shard_count += 1
            
    # Save leftovers
    if len(buffer['x']) > 0:
        np.savez(
                save_path / f'pt_shard_{ds_name}_{split_name}_{shard_count:04d}.npz',
                x=np.array(buffer['x']),
                total=np.array(buffer['total'])
        )

In [62]:
write_pt_shards('train', split_map['train'], dataset_name)

Split: TRAIN


100%|█████████████████████████████████████████████████████████| 112373/112373 [00:04<00:00, 23088.77it/s]


In [63]:
write_pt_shards('val', split_map['val'], dataset_name)

Split: VAL


100%|███████████████████████████████████████████████████████████| 11044/11044 [00:00<00:00, 30547.30it/s]
