In [1]:
from pathlib import Path
from gears import PertData, GEARS
import scanpy as sc
import json
import gzip
import mygene
from Bio import Entrez, SeqIO
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel
import torch
import numpy as np

#import pandas as pd
# import gseapy as gp


In [2]:
data_dir = Path('/Users/djemec/data/jepa/v0_3')
raw_dir = data_dir / 'raw_files'
pert_dir = data_dir / 'pert_embd'
protein_dir = data_dir / 'uniprot'


#protein_dir = data_dir / 'uniprot'
#train_dir = data_dir / 'training'
#pt_dir = data_dir / 'pretraining'
splits= ['train','val','test']


n_genes = 8192 # 2**13
count_normalize_target = 1e4

**Breakdown of perturbation**

* Replogle K562 Ess - CRISPRi (Repression)
* Replogle K562 GW - CRISPRi (Repression)
* Replogle RPE1 - CRISPRi (Repression)
* Adamson - CRISPRi (Repression)
* Norman - CRISPRa (Activation) + Combinatorial
* SciPlex 3 - Drug Treatment

Because of this breakdown, for now we'll just use the CRISPRi datasets.   

In [3]:
datasets= {
    'k562e':raw_dir / 'k562e',
    'k562gw':raw_dir /'k562gw'/'replogle_k562_gw_expanded_8k.h5ad',
    'rep1e':raw_dir /'rep1e'/'perturb_processed.h5ad',
    'adamson':raw_dir /'adamson'/'perturb_processed.h5ad'
}
total_cells = 0

## Extract Anchor File - K562 essential
we'll build out 4 important variables from our key dataset.  k562 will be the GEARS dataset we test against and so we need to be extra careful with data leakage on this set. 
1. `all_genes` - all genes with expression we'll have in our model
2. `anchor_genes` - minimum set of genes based on k562-essential. We have to have these to be able to do validation
3. `all_perts`- all perturbations we've seen that we'll train on minus the forbidden perturbations
4. `forbidden_perturbations` - perturbations we cannot include from the k562 dataset for our validation. If we do, we'll "leak" for our validation

We'll also track how many cells we're training on in general.  note that this will include all the cells in test/train/val so we'll actually only train on a subset. 

In [4]:
pert_data = PertData(datasets['k562e']) 
pert_data.load(data_name='replogle_k562_essential')

Found local copy...
Found local copy...
Found local copy...
These perturbations are not in the GO graph and their perturbation can thus not be predicted
['C7orf26+ctrl' 'C14orf178+ctrl' 'RPS10-NUDT3+ctrl' 'SEM1+ctrl' 'FAU+ctrl']
Local copy of pyg dataset is detected. Loading...
Done!


In [5]:
pert_data.prepare_split(split='simulation', seed=1) 
adata = pert_data.adata 

Local copy of split is detected. Loading...
Simulation split test composition:
combo_seen0:0
combo_seen1:0
combo_seen2:0
unseen_single:272
Done!


here1


In [6]:
def clean_gears_name(name):
    # GEARS format is 'Gene+ctrl' -> We want 'Gene'
    if name.endswith('+ctrl'):
        return name.replace('+ctrl', '')
    if name.startswith('ctrl+'):
        return name.replace('ctrl+', '')
    elif name == 'ctrl':
        return 'control'
    return str.strip(name)

In [7]:
hold_out_perts = set(pert_data.set2conditions['test'])
hold_out_perts = set([clean_gears_name(i) for i in hold_out_perts])
len(hold_out_perts)

272

In [8]:
sc.pp.normalize_total(adata, target_sum=count_normalize_target)
sc.pp.log1p(adata)
sc.pp.highly_variable_genes(adata, n_top_genes=n_genes, subset=True)

In [9]:
anchor_genes = adata.var.gene_name.tolist()
all_genes = sorted(list(set(anchor_genes)))
len(anchor_genes), len(all_genes)

(5000, 4999)

In [10]:
all_perts = adata.obs['condition'].unique()
all_perts = [clean_gears_name(p) for p in all_perts if p != 'ctrl']
all_perts = sorted(list(set(all_perts)))
len(all_perts)

1087

In [11]:
k562e_total_cells = adata.n_obs
total_cells += k562e_total_cells
k562e_total_cells, total_cells

(162246, 162246)

## Add K562 genome wide
we'd previously already filtered down the K562 genome wide dataset based on forbidden perturbations and our target gene count so for this dataset it should be a matter of reading the file and updating our variables.  we'll be adding to:
1. `all_genes` until we reach our limit
2. `all_perts` since this is a k562 dataset we'll make sure to not add any forbidden perturbations

In [12]:
adata_gw = sc.read_h5ad(datasets['k562gw'], backed='r')

In [13]:
k652gw_total_cells = adata_gw.n_obs
total_cells += k652gw_total_cells
k652gw_total_cells, total_cells

(1943010, 2105256)

**Unify Genes** Now we'll merge the genes we extracted with the anchor_genes.  Since we preprocessed the genome-wide dataset we know that this will come up to exactly our target gene count.  in the future we'd have to be careful to only add up to the target with each dataset to not explode our model size. 

In [14]:
gw_genes = adata_gw.var_names.tolist()
gw_genes = sorted(list(set(gw_genes)))
len(gw_genes), gw_genes[:10]

(8192,
 ['A1BG',
  'AAAS',
  'AACS',
  'AAGAB',
  'AAK1',
  'AAMDC',
  'AAMP',
  'AAR2',
  'AARS',
  'AARS2'])

In [15]:
all_genes = set(all_genes).union(set(gw_genes))
len(all_genes), len(all_genes) <= n_genes

(8192, True)

**Unify Perturbations** 

Now we'll merge the perturbations we extracted with the previous perturbations.  We need to make sure to flag perturbations we don't want 

In [16]:
gw_perts = adata_gw.obs['perturbation'].unique()
gw_perts = [clean_gears_name(p) for p in gw_perts if p != 'control']
gw_perts = sorted(list(set(gw_perts)))
len(gw_perts)

9594

check that no forbidden perturbations leaked before merging the perturbations in. 

In [17]:
set(hold_out_perts) & set(gw_perts)

set()

In [18]:
all_perts = set(all_perts).union(set(gw_perts))
len(all_perts)

9869

In [19]:
f'total cells: {total_cells} | total genes {len(all_genes)} | total perturbations {len(all_perts)}'

'total cells: 2105256 | total genes 8192 | total perturbations 9869'

## rep1e Dataset

In [20]:
ds_adata = sc.read_h5ad(datasets['rep1e'], backed='r')

In [21]:
ds_total_cells = ds_adata.n_obs
total_cells += ds_total_cells
ds_total_cells, total_cells

(162733, 2267989)

In [22]:
ds_genes = ds_adata.var_names.tolist()
ds_genes = sorted(list(set(gw_genes)))
len(ds_genes), ds_genes[:10]

(8192,
 ['A1BG',
  'AAAS',
  'AACS',
  'AAGAB',
  'AAK1',
  'AAMDC',
  'AAMP',
  'AAR2',
  'AARS',
  'AARS2'])

In [23]:
if len(set(all_genes).union(set(ds_genes))) <= n_genes:
    all_genes = set(all_genes).union(set(ds_genes))
    print('including all genes')
else:
    skipped_genes = set(ds_genes) - set(all_genes)
    print(f'skipping {skipped_genes}')

including all genes


In [24]:
ds_perts = ds_adata.obs['condition'].unique()
ds_perts = [clean_gears_name(p) for p in ds_perts if p != 'ctrl' and p != 'control']
ds_perts = sorted(list(set(ds_perts)))
remaining_ds_perts = sorted(list(set(ds_perts) - set(hold_out_perts)))
len(ds_perts), len(remaining_ds_perts)

(1543, 1329)

In [25]:
all_perts = set(all_perts).union(set(remaining_ds_perts))
len(all_perts)

9869

In [26]:
f'total cells: {total_cells} | total genes {len(all_genes)} | total perturbations {len(all_perts)}'

'total cells: 2267989 | total genes 8192 | total perturbations 9869'

## adamson Dataset

In [27]:
ds_adata = sc.read_h5ad(datasets['adamson'], backed='r')

In [28]:
ds_total_cells = ds_adata.n_obs
total_cells += ds_total_cells
ds_total_cells, total_cells

(68603, 2336592)

In [29]:
ds_genes = ds_adata.var_names.tolist()
ds_genes = sorted(list(set(gw_genes)))
len(ds_genes), ds_genes[:10]

(8192,
 ['A1BG',
  'AAAS',
  'AACS',
  'AAGAB',
  'AAK1',
  'AAMDC',
  'AAMP',
  'AAR2',
  'AARS',
  'AARS2'])

In [30]:
if len(set(all_genes).union(set(ds_genes))) <= n_genes:
    all_genes = set(all_genes).union(set(ds_genes))
    print('including all genes')
else:
    skipped_genes = set(ds_genes) - set(all_genes)
    print(f'skipping {skipped_genes}')

including all genes


In [31]:
ds_perts = ds_adata.obs['condition'].unique()
ds_perts = [clean_gears_name(p) for p in ds_perts if p != 'ctrl' and p != 'control']
ds_perts = sorted(list(set(ds_perts)))
remaining_ds_perts = sorted(list(set(ds_perts) - set(hold_out_perts)))
len(ds_perts), len(remaining_ds_perts)

(86, 75)

In [32]:
all_perts = set(all_perts).union(set(remaining_ds_perts))
len(all_perts)

9875

In [33]:
f'total cells: {total_cells} | total genes {len(all_genes)} | total perturbations {len(all_perts)}'

'total cells: 2336592 | total genes 8192 | total perturbations 9875'

## Convert Perts to Embeddings

In [34]:
len(list(all_perts)),list(all_perts)[:2]

(9875, ['DOT1L', 'NR3C1'])

**load uniprot cache**

In [35]:
all_prots = 'https://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/reference_proteomes/Eukaryota/UP000005640/UP000005640_9606.fasta.gz'
#! wget -P {protein_dir} {all_prots}
fasta = [i for i in protein_dir.iterdir()][0]
fasta

PosixPath('/Users/djemec/data/jepa/v0_3/uniprot/UP000005640_9606.fasta.gz')

In [36]:
uniprot_db = {}
with gzip.open(fasta, 'rt') as handle:
    for record in SeqIO.parse(handle, 'fasta'):
        # Header: >sp|P12345|GENE_HUMAN Description GN=Symbol PE=1 ...
        desc = record.description
        
        # Robustly extract Gene Name (GN)
        if 'GN=' in desc:
            parts = desc.split('GN=')
            if len(parts) > 1:
                # Symbol is usually the first string after GN=, before the next space
                symbol = parts[1].split(' ')[0]
                uniprot_db[symbol] = str(record.seq)

**Normalize Gene Perturbation Names**

In [37]:
mg = mygene.MyGeneInfo()
mg_results = mg.querymany(all_perts, scopes='symbol,alias', fields='symbol', species='human')
len(mg_results),mg_results[:2]

Input sequence provided is already in string format. No operation performed
Input sequence provided is already in string format. No operation performed
557 input query terms found dup hits:	[('MEMO1', 2), ('CT45A5', 4), ('SMAD1', 2), ('LIG1', 2), ('LEPR', 2), ('MAGEA6', 2), ('PLD1', 2), ('
2 input query terms found no hit:	['AC015871.1', 'AC118549.1']


(10522,
 [{'query': 'DOT1L', '_id': '84444', '_score': 17.593918, 'symbol': 'DOT1L'},
  {'query': 'NR3C1', '_id': '2908', '_score': 17.616997, 'symbol': 'NR3C1'}])

In [38]:
final_map = {}
missing_genes = []

In [39]:
for res in mg_results:
    original_query = res['query']
    
    if 'symbol' in res:
        modern_symbol = res['symbol']
    else:
        print(f'failed to find {res}')
        modern_symbol = original_query

    if modern_symbol in uniprot_db:
        final_map[original_query] = uniprot_db[modern_symbol]
    else:
        missing_genes.append(original_query)

# final cleanup of duplicates
missing_genes = [g for g in missing_genes if not final_map.get(g)]

missing_genes

failed to find {'query': 'AC015871.1', 'notfound': True}
failed to find {'query': 'AC118549.1', 'notfound': True}


['HIST1H4B',
 'GAS8',
 'FUT11',
 'B3GNTL1',
 'UROS',
 'TRMT12',
 'C12orf43',
 'C18orf21',
 'C18orf21',
 'AC015871.1',
 'ALG1L',
 'TMEM241',
 'HIST1H2AK',
 'TMCO3',
 'HBA2',
 'AC118549.1',
 'RPL17-C18orf32',
 'MAGEA9B',
 'HIST1H3J',
 'ZNF720',
 'C11orf58',
 'C22orf46',
 'MFSD5',
 'NEPRO',
 'HIST1H3F',
 'MFSD3',
 'FAM210A',
 'HIST1H4H',
 'RGPD6',
 'MFSD10',
 'FAM210B',
 'HIST1H2AI',
 'MFSD4B',
 'HIST1H4I',
 'HIST1H4J',
 'NME1-NME2',
 'CCDC169-SOHLH2',
 'HIST1H2AM',
 'C1orf56',
 'SLC35C2',
 'NEDD8-MDP1',
 'C11orf54',
 'RBAK-RBAKDN',
 'HIST1H2AL',
 'C21orf91',
 'H3F3B',
 'FUT10',
 'HIST2H4A',
 'C1orf43',
 'NDUFA4',
 'NDUFA4',
 'RBM14-RBM4',
 'C19orf48',
 'ILVBL',
 'RPS10-NUDT3',
 'ST20-MTHFS',
 'HIST1H2AE',
 'SERF1B',
 'HIST2H3D',
 'SMN2',
 'HIST1H4D',
 'HIST1H3H',
 'C5orf15',
 'C1orf35',
 'EFCAB2',
 'HIST1H4E',
 'ARPC4-TTLL3',
 'KRBA1',
 'C10orf88',
 'HIST1H4C',
 'MFSD9',
 'MDM1',
 'C2orf88',
 'FAM156B']

**Search NCBI for missing protein sequences**

In [40]:
Entrez.email = 'test@test.com'
results = {}
for gene in tqdm(missing_genes.copy()):
    try:
        # Specific search for RefSeq proteins in Humans
        term = f'{gene}[All Fields] AND Homo sapiens[Organism] AND srcdb_refseq[PROP]'
        
        # 1. Search
        search_handle = Entrez.esearch(db="protein", term=term, retmax=1)
        record = Entrez.read(search_handle)
        
        if record['IdList']:
            # 2. Fetch
            uid = record['IdList'][0]
            fetch_handle = Entrez.efetch(db="protein", id=uid, rettype="fasta", retmode="text")
            seq_record = SeqIO.read(fetch_handle, "fasta")
            results[gene] = str(seq_record.seq)
            missing_genes.remove(gene)
        else:
            results[gene] = "M" # True Ghost Gene
    except Exception as e:
        print(f'Failed Entrez lookup for {gene}: {e}')
        results[gene] = "M"
        
final_map.update(results)

missing_genes

100%|███████████████████████████████████████████████████████████████████████████| 74/74 [01:03<00:00,  1.16it/s]


['AC015871.1', 'ALG1L', 'AC118549.1', 'C22orf46', 'C19orf48']

**Manually handle unmapped**

In [41]:
clone_map = {
    'AC118549.1': 'ZZZ3',
    'AC015871.1': 'LINC01587', # Will fail protein search (Correctly)
}
results = {}
for gene in tqdm(missing_genes.copy()):
    
    term = clone_map.get(gene, gene)
    try:
        # Specific search for RefSeq proteins in Humans
        term = f'{term}[All Fields] AND Homo sapiens[Organism]'
        
        # 1. Search
        search_handle = Entrez.esearch(db="protein", term=term, retmax=1)
        record = Entrez.read(search_handle)
        
        if record['IdList']:
            # 2. Fetch
            uid = record['IdList'][0]
            fetch_handle = Entrez.efetch(db="protein", id=uid, rettype="fasta", retmode="text")
            seq_record = SeqIO.read(fetch_handle, "fasta")
            results[gene] = str(seq_record.seq)
            missing_genes.remove(gene)
        else:
            results[gene] = "M" # True Ghost Gene
    except Exception as e:
        print(f'Failed Entrez lookup for {gene}: {e}')
        results[gene] = "M"
        
final_map.update(results)

missing_genes

100%|█████████████████████████████████████████████████████████████████████████████| 5/5 [00:03<00:00,  1.27it/s]


[]

**remove final missing pieces for now**

In [42]:
len(all_perts)

9875

In [43]:
all_perts -= set(missing_genes)

for gene in missing_genes:
    del final_map[gene]

len(all_perts), len(final_map.keys())

(9875, 9875)

In [44]:
# check that we have protein sequences
[g for g in final_map.keys() if final_map[g] == 'M']
# check all hold_out perts still encoded
[i for i in list(hold_out_perts) if i not in all_perts]

[]

In [45]:
with open(pert_dir /'perturbation_seq.json', 'w') as f:
    json.dump({str(k): str(v) for k, v in final_map.items()}, f)

**Get Model For Embeddings**
for protein embeddings, we'll use the [ESM-2](https://huggingface.co/facebook/esm2_t6_8M_UR50D) for now given it's been a common workhorse. 

In [46]:
model_name = 'facebook/esm2_t6_8M_UR50D' 
def get_device():
    device = 'cpu'
    if torch.cuda.is_available():
        torch.cuda.manual_seed(1337)
        device = 'cuda'
    print(f'using {device}')
    return device

DEVICE = get_device()

using cpu


In [47]:
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [48]:
model.eval()
model.to(DEVICE)

EsmModel(
  (embeddings): EsmEmbeddings(
    (word_embeddings): Embedding(33, 320, padding_idx=1)
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (encoder): EsmEncoder(
    (layer): ModuleList(
      (0-5): 6 x EsmLayer(
        (attention): EsmAttention(
          (self): EsmSelfAttention(
            (query): Linear(in_features=320, out_features=320, bias=True)
            (key): Linear(in_features=320, out_features=320, bias=True)
            (value): Linear(in_features=320, out_features=320, bias=True)
            (rotary_embeddings): RotaryEmbedding()
          )
          (output): EsmSelfOutput(
            (dense): Linear(in_features=320, out_features=320, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (LayerNorm): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
        )
        (intermediate): EsmIntermediate(
          (dense): Linear(in_features=320, out_features=1280, bias=True)
        )
        (output): EsmOutput(
        

**Get Embeddings**

In [49]:
sorted_genes = sorted(list(final_map.keys()))
    
embeddings = []
pert_to_id = {}

In [50]:
with torch.no_grad():
    for idx, gene in enumerate(tqdm(sorted_genes)):
        seq = final_map[gene]
        
        # Map Gene Name -> Integer ID
        pert_to_id[gene] = idx
        
        # Tokenize
        # Truncate to 1024 AA (covers 95% of human proteins)
        inputs = tokenizer(seq, return_tensors="pt", truncation=True, max_length=1024)
        inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
        
        # Forward Pass
        outputs = model(**inputs)
        
        # Mean Pooling: Compress [Seq_Len, 320] -> [320]
        # We take the mean of all amino acids to get the "whole protein" vector
        phys_vector = outputs.last_hidden_state[0].mean(dim=0).cpu().numpy()
        embeddings.append(phys_vector)

100%|███████████████████████████████████████████████████████████████████████| 9875/9875 [05:42<00:00, 28.82it/s]


In [51]:
pert_to_id['control'] = len(embeddings)

**Save Pert Mapping/Embeddings**

In [52]:
final_bank = np.vstack(embeddings)

In [53]:
np.save(pert_dir / 'action_embeddings_esm2.npy', final_bank)

In [54]:
with open(pert_dir / 'pert_to_id.json', 'w') as f:
        json.dump(pert_to_id, f)

## Save Genes and perturbations

In [55]:
genes_and_perts = {
    'all_genes':list(all_genes),
    'anchor_genes': list(anchor_genes),
    'all_perts': list(all_perts),
    'hold_out_perts': list(hold_out_perts)
}

with open(raw_dir / 'gene_and_perts.json', 'w') as f:
    json.dump(genes_and_perts, f)