In [1]:
import torch
import esm
import gc
import pandas as pd
import numpy as np
import tqdm
from cuml import PCA


In [2]:
# Dataset to process; input and output CSVs live side by side in its output folder.
DATASET_NAME = "all_v2_2"
DATASET_DIR = f"../data/main_dataset_creation/outputs/{DATASET_NAME}"
INPUT_DATASET = f"{DATASET_DIR}/dataset_with_alphafold_paths.csv"
OUTPUT_DATASET = f"{DATASET_DIR}/dataset_with_esm_features.csv"

# Longest sequence (in residues) sent through the model; longer ones ran
# out of memory on the 3070 and are skipped.
MAX_CUDA_SEQ_LEN = 6000


In [3]:
# Adapted from https://www.kaggle.com/code/kaggleqrdl/esm-quick-start-lb237

# Column index of each of the 20 standard amino acids: L->0, A->1, ..., C->19.
# NOTE(review): presumably this order matches the ESM alphabet rows 4:24 sliced
# out of the logits in the embedding cell — confirm before reordering.
token_map = {aa: idx for idx, aa in enumerate("LAGVSERTIDPKQNFYMHWC")}

# ESM-2 (33 layers, 650M parameters) plus its tokenizer.
t_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = alphabet.get_batch_converter()

# Inference only: eval() disables dropout so outputs are deterministic.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
t_model = t_model.eval().to(device)
print(device)

cuda


#### Embeddings

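We feed each unique train and test wildtype sequence into the ESM-2 transformer and extract the activations of its last hidden layer (layer 33). For each protein this has shape (1, len_protein_seq + 2, 1280): the two extra positions are the start and end tokens the model adds around the residues. We save both the full per-token embeddings and a mean-pooled embedding for later use. Additionally, we save the amino-acid predictions from the masked-language-model (MLM) pretraining head, from which mutation probability and mutation entropy can be derived. The raw logits have shape (1, len_protein_seq + 2, 33); we keep only the 20 common amino acids and the residue positions, giving a (20, len_protein_seq) array after a softmax over the amino-acid axis. Sequences longer than `MAX_CUDA_SEQ_LEN` are skipped because they do not fit in GPU memory.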


In [5]:
# Load the dataset (with AlphaFold paths) and list its columns.
df = pd.read_csv(INPUT_DATASET)
list(df.columns)

['uniprot',
 'wild_aa',
 'mutated_chain',
 'mutation_position',
 'mutated_aa',
 'pH',
 'sequence',
 'length',
 'chain_start',
 'chain_end',
 'AlphaFoldDB',
 'Tm',
 'ddG',
 'dTm',
 'dataset_source',
 'infos_found',
 'alphafold_path']

In [6]:
# TRAIN AND TEST WILDTYPES
# Run every unique sequence through ESM-2 and collect embeddings and
# amino-acid probabilities.
from scipy.special import softmax
from scipy.stats import entropy
PCA_CT = 16  # per-protein sample size for fitting PCA (sampling is currently not applied below)
all_sequences = df.sequence.unique()

# Per-sequence outputs are index-aligned with `all_sequences`. Entry i always
# refers to all_sequences[i]; skipped sequences get None (lists) or an
# all-zero row (pool).
#   all_pdb_embed_pool[i]  mean of the layer-33 representation over all token positions
#                          (the extra trailing row is kept from the original layout)
#   all_pdb_embed_local[i] full layer-33 representation array
#   all_pdb_prob[i]        softmax over logit rows 4:24, first/last token positions dropped
all_pdb_embed_pool = np.zeros((len(all_sequences)+1, 1280))
all_pdb_embed_local = []
all_pdb_embed_tmp = []  # layer-33 embedding at each distinct mutation position (stacked at the end)
all_pdb_prob = []
skipped_sequence_idx = []  # indices into all_sequences longer than MAX_CUDA_SEQ_LEN

# Distinct mutation positions per sequence, in order of appearance. Computed once
# instead of re-filtering the whole frame on every loop iteration.
positions_by_sequence = df.groupby('sequence', sort=False)['mutation_position'].unique()

# EXTRACT TRANSFORMER EMBEDDINGS FOR TRAIN AND TEST WILDTYPES
print('Extracting embeddings from proteins...')
for i, seq in tqdm.tqdm(enumerate(all_sequences), total=len(all_sequences)):
    if len(seq) > MAX_CUDA_SEQ_LEN:
        # Too long to fit in GPU memory. Record the skip and append placeholders
        # so list index i keeps pointing at all_sequences[i].
        skipped_sequence_idx.append(i)
        all_pdb_embed_local.append(None)
        all_pdb_prob.append(None)
        continue

    # EXTRACT EMBEDDINGS, MUTATION PROBABILITIES
    data = [("protein1", seq)]
    batch_labels, batch_strs, batch_tokens = batch_converter(data)
    batch_tokens = batch_tokens.to(device)
    with torch.no_grad():
        results = t_model(batch_tokens, repr_layers=[33])
    # Transpose the single batch item to (vocab, tokens). Keep vocab rows 4:24,
    # presumably the 20 amino acids in token_map order (TODO confirm), and drop
    # the first and last token positions.
    logits = (results['logits'].detach().cpu().numpy()[0, ].T)[4:24, 1:-1]
    all_pdb_prob.append(softmax(logits, axis=0))
    results = results["representations"][33].detach().cpu().numpy()

    # SAVE EMBEDDINGS
    all_pdb_embed_local.append(results)
    # NOTE(review): this mean includes the first/last token positions that the
    # logits slice above drops — confirm that is intended before changing it.
    all_pdb_embed_pool[i, ] = np.mean(results[0, :, :], axis=0)

    # TEMPORARILY SAVE LOCAL MUTATION EMBEDDINGS
    # Token 0 is not a residue (see the [1:-1] slice above), so index j selects
    # residue j only if mutation_position is 1-based — TODO confirm convention.
    for j in positions_by_sequence.loc[seq]:
        if pd.isna(j):
            continue  # rows without a mutation position would make int() raise
        all_pdb_embed_tmp.append(results[0, int(j), :])

    # Release GPU/host memory before the next (possibly long) sequence.
    del batch_tokens, results
    gc.collect()
    torch.cuda.empty_cache()

if skipped_sequence_idx:
    print(f'Skipped {len(skipped_sequence_idx)} sequence(s) longer than {MAX_CUDA_SEQ_LEN} residues')
all_pdb_embed_tmp = np.stack(all_pdb_embed_tmp)


Extracting embeddings from proteins...


477it [01:17,  6.14it/s]
