In [1]:
import torch
import esm
import gc
import pandas as pd
import numpy as np
import tqdm
from cuml import PCA
from biopandas.pdb import PandasPdb
from scipy.special import softmax
from scipy.stats import entropy


#### Transformer ESM Features

In order to convert amino acid sequences aka proteins into meaningful features, we will use embeddings from SOTA protein transformer. We use Facebook"s pretrained protein transformer ESM (Evolutionary Scale Modeling) with research paper here and GitHub here. Kaggleqrdl provided a starter notebook here. In version 15+, we also extract mutation probabilties and mutation entropy from ESM!


In [1]:
DATASET_NAME = "all_v3"
INPUT_DATASET = f"./data/main_dataset_creation/outputs/{DATASET_NAME}/dataset_with_alphafold_paths.csv"
OUTPUT_DATASET = f"./data/main_dataset_creation/outputs/{DATASET_NAME}/dataset_with_esm_features.csv"
MAX_CUDA_SEQ_LEN = 7000 # out of memory w/ the 3070 after this
PCA_CT = 16  # random sample size per protein to fit PCA with
SUBSET_DUPLICATES_NO_PH = ["uniprot", "wild_aa", "mutation_position",
                           "mutated_aa", "sequence"]
ONLY_DDG = True


In [3]:
# https://www.kaggle.com/code/kaggleqrdl/esm-quick-start-lb237

np.random.seed(42)

token_map = {'L': 0, 'A': 1, 'G': 2, 'V': 3, 'S': 4, 'E': 5, 'R': 6, 'T': 7, 'I': 8, 'D': 9, 'P': 10,
             'K': 11, 'Q': 12, 'N': 13, 'F': 14, 'Y': 15, 'M': 16, 'H': 17, 'W': 18, 'C': 19}
t_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = alphabet.get_batch_converter()
t_model.eval()  # disables dropout for deterministic results
print("loaded model")

loaded model


#### Embeddings

We input each train and test wildtype into our transformer and extract the last hidden layers activations. For each protein, this has shape (1, len_protein_seq, 1280). We will save the full embeddings and the pooled embeddings for use later. Additionally we will save the MLM pretrain task amino acid prediction which indicates mutation probability and mutation entropy. This has shape (1, len_protein_seq, 33) but we extract to (len_protein_seq, 20) where 20 is number of common amino acids.


In [4]:
df = pd.read_csv(INPUT_DATASET)
print(len(df))
df.drop_duplicates(subset=SUBSET_DUPLICATES_NO_PH, inplace=True)
print(len(df))
# df.columns.tolist()
new_columns = [f"esm_pca_pool_{k}" for k in range(32)]
new_columns += [f"esm_pca_wild_{k}" for k in range(16)]
new_columns += [f"esm_pca_mutant_{k}" for k in range(16)]
new_columns += [f"esm_pca_local_{k}" for k in range(16)]
new_columns += ["esm_mutation_probability", "esm_mutation_entropy"]

for col in new_columns:
    df[col] = np.nan

if ONLY_DDG:
    print(len(df))
    df = df[~(df.ddG.isna())]
    print(len(df))


checking coherence between 477 pairs of sequence-atom(pdb) files


10it [00:00, 21.49it/s]

error for ./data/main_dataset_creation/3D_structures/alphafold/P28335.pdb at position 22: C instead of S


82it [00:03, 26.07it/s]

error for ./data/main_dataset_creation/3D_structures/alphafold/P00749.pdb at position 140: P instead of L


477it [00:21, 21.70it/s]

found 2 non coherent sequence-atom(pdb) pairs
11262
8734





In [5]:
def extract_embeddings(all_sequences, embeddings, t_model, device):
    # EXTRACT TRANSFORMER EMBEDDINGS FOR TRAIN AND TEST WILDTYPES
    print("Extracting embeddings from proteins...")
    
    all_seq_embed_pool = embeddings["all_seq_embed_pool"]
    all_seq_embed_local = embeddings["all_seq_embed_local"]
    all_seq_embed_by_position = embeddings["all_seq_embed_by_position"]
    all_seq_prob = embeddings["all_seq_prob"]
    
    sequences_too_big_for_cuda = []

    for i, seq in tqdm.tqdm(enumerate(all_sequences)):
        # EXTRACT EMBEDDINGS, MUTATION PROBABILITIES, ENTROPY

        # check the device is coherent with protein length
        if (str(device) == "cuda" and len(seq) > MAX_CUDA_SEQ_LEN):
            # if the protein is too big, don't try (we will do it with a cpu later)
            sequences_too_big_for_cuda.append(seq)
            continue
        elif (str(device) == "cpu" and len(seq) <= MAX_CUDA_SEQ_LEN):
            continue

        data = [("_", seq)]
        batch_labels, batch_strs, batch_tokens = batch_converter(data)
        batch_tokens = batch_tokens.to(device)
        with torch.no_grad():
            results = t_model(batch_tokens, repr_layers=[33])
        # go from 33 to 20 (1 per amino acid)
        logits = (results["logits"].detach().cpu().numpy()[0, ].T)[4:24, 1:-1]
        all_seq_prob[i] = softmax(logits, axis=0)
        results = results["representations"][33].detach().cpu().numpy()

        # SAVE EMBEDDINGS
        all_seq_embed_local[i] = results
        all_seq_embed_pool[i, ] = np.mean(results[0, :, :], axis=0)

        # TEMPORARILY SAVE LOCAL MUTATION EMBEDDINGS
        mutation_positions = df.loc[df.sequence == seq,
                                    "mutation_position"].unique().astype(int)

        # the goal here is to fit the pca on the concat of all embeddings,
        # therefore if one protein has 1000 single mutation it will appear 1000 times
        # and we will overfit the pca to this protein
        # => we choose max PCA_CT single mutations
        if len(mutation_positions) > PCA_CT:
            mutation_positions = np.random.choice(
                mutation_positions, PCA_CT, replace=False)
        for j in mutation_positions:
            all_seq_embed_by_position[i] = results[0, j+1, :]

        del batch_tokens, results
        gc.collect()
        torch.cuda.empty_cache()

    embeddings = {
        "all_seq_embed_pool": all_seq_embed_pool,
        "all_seq_embed_local": all_seq_embed_local,
        "all_seq_embed_by_position": all_seq_embed_by_position,
        "all_seq_prob": all_seq_prob,
    }

    return embeddings, sequences_too_big_for_cuda



In [6]:
all_sequences = df.sequence.unique()
embeddings = {
    "all_seq_embed_pool": np.zeros((len(all_sequences), 1280)),
    "all_seq_embed_local": [None]*len(all_sequences),
    "all_seq_embed_by_position": [None]*len(all_sequences),
    "all_seq_prob": [None]*len(all_sequences),
}

# first we do 'small' proteins with cuda
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
t_model.to(device)
print(device)
embeddings, sequences_too_big_for_cuda = extract_embeddings(
    all_sequences, embeddings, t_model, device)

# then we do the biggest proteins with cpu, if needed
if len(sequences_too_big_for_cuda) > 0:
    device = torch.device("cpu")
    t_model.to(device)
    print(device)
    print(len(sequences_too_big_for_cuda))
    embeddings, sequences_too_big_for_cuda = extract_embeddings(
        all_sequences, embeddings, t_model, device)



cuda
Extracting embeddings from proteins...


471it [01:20,  5.83it/s]


cpu
0
Extracting embeddings from proteins...


471it [00:00, 380199.61it/s]


#### RAPIDS PCA

The transformer embeddings have dimension 1280. Since we only have a few thousand rows of train data, that is too many features to include all of them in our XGB model. Furthermore, we want to use local, pooling, and delta embeddings. Which would be 3x1280. To prevent our model from overfitting as a result of the "curse of dimensionality", we reduce the dimension of embeddings using RAPIDS PCA.


In [7]:
# set sequence_to_embed_mapping
sequence_to_embed_mapping = {seq: i for i, seq in enumerate(all_sequences)}
# create stack
all_seq_embed_by_position = np.stack(
    embeddings.pop("all_seq_embed_by_position"))
pca_pool = PCA(n_components=32)
pca_embeds = pca_pool.fit_transform(
    embeddings.pop("all_seq_embed_pool").astype("float32"))
pca_local = PCA(n_components=16)
pca_local.fit(all_seq_embed_by_position.astype("float32"))

# we delete all_seq_embed_by_position: we only used it to fit the pca_local
del all_seq_embed_by_position
_ = gc.collect()


In [8]:
print(pca_embeds.shape)
print(len(sequences_too_big_for_cuda))


(471, 32)
0


In [11]:
def add_embbeddings(row, sequence_to_embed_mapping, embeddings, t_model, device, errors):
    try:
        ##################
        # ROW - IS ROW FROM DOWNLOADED TRAIN CSV
        ##################
        # pdb_map = {x: y for x, y in zip(all_pdb, range(len(all_pdb)))}
        atom_df = PandasPdb().read_pdb(row.alphafold_path)
        atom_df = atom_df.df['ATOM']

        residue_atoms = atom_df.loc[(
            atom_df.residue_number == row.mutation_position)].reset_index(drop=True)

        # FEATURE ENGINEER
        if len(residue_atoms) > 0:

            # check the device is coherent with protein length
            if (str(device) == "cuda" and len(row.sequence) > MAX_CUDA_SEQ_LEN):
                # if the protein is too big, don't try (we will do it with a cpu later)
                return row
            elif (str(device) == "cpu" and len(row.sequence) <= MAX_CUDA_SEQ_LEN):
                return row
                
            # GET MUTANT EMBEDDINGS
            mutated_sequence = (row.sequence[:row.mutation_position] +
                                row.mutated_aa+row.sequence[row.mutation_position+1:])
            data = [("_", mutated_sequence)]
            batch_labels, batch_strs, batch_tokens = batch_converter(data)
            batch_tokens = batch_tokens.to(device)
            with torch.no_grad():
                results = t_model(batch_tokens, repr_layers=[33])
            results = results["representations"][33].cpu().numpy()
            mutant_local = pca_local.transform(
                results[:1, row.mutation_position+1, :])[0, ]
            mutant_pool = np.mean(results[:1, :, :], axis=1)
            mutant_pool = pca_pool.transform(mutant_pool)[0, ]

            # TRANSFORMER ESM EMBEDDINGS
            wild_local = pca_local.transform(
                embeddings["all_seq_embed_local"][sequence_to_embed_mapping[row.sequence]][:1, row.mutation_position+1, :])[0, ]
            wild_pool = pca_embeds[sequence_to_embed_mapping[row.sequence], ]
            for k in range(32):
                row[f"esm_pca_pool_{k}"] = mutant_pool[k] - wild_pool[k]
                if k >= 16:
                    continue
                row[f"esm_pca_wild_{k}"] = wild_local[k]
                row[f"esm_pca_mutant_{k}"] = mutant_local[k]
                row[f"esm_pca_local_{k}"] = mutant_local[k] - wild_local[k]

            # TRANSFORMER MUTATION PROBS AND ENTROPY
            row["esm_mutation_probability"] = embeddings["all_seq_prob"][sequence_to_embed_mapping[row.sequence]
                                        ][token_map[row.mutated_aa], row.mutation_position+1]
            row["esm_mutation_entropy"] = entropy(
                embeddings["all_seq_prob"][sequence_to_embed_mapping[row.sequence]][:, row.mutation_position+1])

            del batch_tokens, results, mutant_local, mutant_pool, wild_local, wild_pool
            gc.collect()
            torch.cuda.empty_cache()
    except Exception as e:
        errors.append(
            f"error occured for {row.uniprot} {row.mutation_position} {row.mutated_aa}")


    return row


In [12]:
df.mutation_position = df.mutation_position.astype(int)
errors = []
# first we do 'small' proteins with cuda
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
t_model.to(device)
print(device)
df = df.apply(lambda row: add_embbeddings(
    row, sequence_to_embed_mapping, embeddings, t_model, device, errors), axis=1)

if len(sequences_too_big_for_cuda)>0:
    # then we do the biggest proteins with cpu if needed
    device = torch.device("cpu")
    t_model.to(device)
    print(device)
    print(len(sequences_too_big_for_cuda))
    df = df.apply(lambda row: add_embbeddings(
        row, sequence_to_embed_mapping, embeddings, t_model, device, errors), axis=1)


cuda
cpu
0


In [13]:
print(errors)


['error occured for P00648 156 A', 'error occured for P00644 230 G', 'error occured for P41016 65 A', 'error occured for P60848 96 A', 'error occured for P60848 96 E', 'error occured for P23540 190 A']


In [14]:
df.to_csv(OUTPUT_DATASET, index=False)
print(df.isna().sum())
df.head(10)


#### Embeddings for SUBMISSION

Unnamed: 0,uniprot,wild_aa,mutated_chain,mutation_position,mutated_aa,pH,sequence,length,chain_start,chain_end,...,esm_pca_local_8,esm_pca_local_9,esm_pca_local_10,esm_pca_local_11,esm_pca_local_12,esm_pca_local_13,esm_pca_local_14,esm_pca_local_15,esm_mutation_probability,esm_mutation_entropy
33,P06654,M,A,0,A,6.5,MEKEKKVKYFLRKSAFGLASVSAAFLVGSTVFAVDSPIEDTPIIRN...,448.0,33.0,416.0,...,,,,,,,,,,
34,P06654,M,A,0,D,6.5,MEKEKKVKYFLRKSAFGLASVSAAFLVGSTVFAVDSPIEDTPIIRN...,448.0,33.0,416.0,...,,,,,,,,,,
35,P06654,M,A,0,E,6.5,MEKEKKVKYFLRKSAFGLASVSAAFLVGSTVFAVDSPIEDTPIIRN...,448.0,33.0,416.0,...,,,,,,,,,,
36,P06654,M,A,0,F,6.5,MEKEKKVKYFLRKSAFGLASVSAAFLVGSTVFAVDSPIEDTPIIRN...,448.0,33.0,416.0,...,,,,,,,,,,
37,P06654,M,A,0,G,6.5,MEKEKKVKYFLRKSAFGLASVSAAFLVGSTVFAVDSPIEDTPIIRN...,448.0,33.0,416.0,...,,,,,,,,,,
38,P06654,M,A,0,H,6.5,MEKEKKVKYFLRKSAFGLASVSAAFLVGSTVFAVDSPIEDTPIIRN...,448.0,33.0,416.0,...,,,,,,,,,,
39,P06654,M,A,0,I,6.5,MEKEKKVKYFLRKSAFGLASVSAAFLVGSTVFAVDSPIEDTPIIRN...,448.0,33.0,416.0,...,,,,,,,,,,
40,P06654,M,A,0,K,6.5,MEKEKKVKYFLRKSAFGLASVSAAFLVGSTVFAVDSPIEDTPIIRN...,448.0,33.0,416.0,...,,,,,,,,,,
41,P06654,M,A,0,L,6.5,MEKEKKVKYFLRKSAFGLASVSAAFLVGSTVFAVDSPIEDTPIIRN...,448.0,33.0,416.0,...,,,,,,,,,,
42,P06654,T,A,227,A,6.5,MEKEKKVKYFLRKSAFGLASVSAAFLVGSTVFAVDSPIEDTPIIRN...,448.0,33.0,416.0,...,0.704942,0.08935,-0.832198,0.975513,-0.80409,0.781593,0.383693,-0.447136,7.2e-05,0.020758
