In [1]:
import torch
import esm
import gc
import pandas as pd
import numpy as np
import tqdm
from cuml import PCA


#### Transformer ESM Features

To convert amino acid sequences (i.e. proteins) into meaningful features, we use embeddings from a state-of-the-art protein transformer: Facebook's pretrained ESM (Evolutionary Scale Modeling) model, which is described in its research paper and published on GitHub (facebookresearch/esm). Kaggleqrdl provided a starter notebook (https://www.kaggle.com/code/kaggleqrdl/esm-quick-start-lb237). In version 15+, we also extract mutation probabilities and mutation entropy from ESM!


In [2]:
# Name of the dataset build whose outputs folder this notebook reads from and writes to.
DATASET_NAME = "all_v2_2"

# Input (with AlphaFold paths) and output (with ESM features) CSVs share the same outputs folder.
INPUT_DATASET = "../data/main_dataset_creation/outputs/{}/dataset_with_alphafold_paths.csv".format(DATASET_NAME)
OUTPUT_DATASET = "../data/main_dataset_creation/outputs/{}/dataset_with_esm_features.csv".format(DATASET_NAME)

# Sequences longer than this are not sent to the GPU (out of memory w/ the 3070 after this).
MAX_CUDA_SEQ_LEN = 500


In [3]:
# Setup adapted from https://www.kaggle.com/code/kaggleqrdl/esm-quick-start-lb237

# Maps each of the 20 common amino acids to its row index (0-19); the
# probability matrices built later in this notebook are indexed with it.
token_map = {residue: idx for idx, residue in enumerate("LAGVSERTIDPKQNFYMHWC")}

# Pretrained ESM-2 model and the converter that turns (label, sequence)
# pairs into token batches.
t_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = alphabet.get_batch_converter()

# Prefer the GPU when one is available; assign torch.device('cpu') here
# instead to force CPU execution.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

t_model.eval()  # disables dropout for deterministic results
t_model.to(device)
print(device)

cuda


#### Embeddings

We input each train and test wildtype into our transformer and extract the last hidden layer's activations. For each protein, this has shape (1, len_protein_seq, 1280); the token axis additionally carries the special start/end tokens, which is why the logits below are sliced with `1:-1`. We save the full embeddings and the pooled (mean over the token axis) embeddings for later use. Additionally, we save the amino acid predictions of the MLM pretraining task, from which we derive mutation probability and mutation entropy. The logits have shape (1, len_protein_seq, 33), but we keep only the 20 common amino acids and store the probabilities as (20, len_protein_seq).


In [4]:
# Load the input dataset; later cells read its `sequence` and
# `mutation_position` columns.
df = pd.read_csv(filepath_or_buffer=INPUT_DATASET)

In [5]:
# TRAIN AND TEST WILDTYPES
from scipy.special import softmax
from scipy.stats import entropy

PCA_CT = 16  # random sample size per protein to fit PCA with (sampling is currently disabled below)
all_sequences = df.sequence.unique()

# Pooled (mean over tokens) embedding per unique sequence, row i <-> all_sequences[i].
# The array is allocated with one extra trailing row that this cell never writes,
# and rows of sequences skipped for length also stay all-zero.
all_seq_embed_pool = np.zeros((len(all_sequences)+1, 1280))
# Per-sequence results. Both lists are kept index-aligned with all_sequences
# (entry i belongs to all_sequences[i]); skipped sequences hold None.
all_seq_embed_local = []
all_seq_prob = []
# Embedding vectors at every mutated position, stacked at the end to fit the local PCA.
all_seq_embed_by_position = []
seq_to_embed = {}  # {"MV...": (i, [list of all positions])}
seq_to_big_for_cuda = []

# EXTRACT TRANSFORMER EMBEDDINGS FOR TRAIN AND TEST WILDTYPES
print('Extracting embeddings from proteins...')
# enumerate() hides the length from tqdm, so pass it explicitly to get a real
# progress bar with an ETA.
for i, seq in tqdm.tqdm(enumerate(all_sequences), total=len(all_sequences)):
    # EXTRACT EMBEDDINGS, MUTATION PROBABILITIES, ENTROPY

    if len(seq) > MAX_CUDA_SEQ_LEN:
        # if the protein is too big, don't try (we will do it with a cpu later)
        seq_to_big_for_cuda.append(seq)
        # Keep the per-sequence lists aligned with `i`: without a placeholder,
        # every sequence after a skipped one would sit at list index < i while
        # seq_to_embed and all_seq_embed_pool keep using i.
        all_seq_embed_local.append(None)
        all_seq_prob.append(None)
        continue

    data = [("protein1", seq)]
    batch_labels, batch_strs, batch_tokens = batch_converter(data)
    batch_tokens = batch_tokens.to(device)
    with torch.no_grad():
        results = t_model(batch_tokens, repr_layers=[33])
    # Keep vocabulary rows 4..23 (the 20 amino acids of token_map) and drop the
    # first/last token column; the result is (20, len(seq)).
    logits = (results['logits'].detach().cpu().numpy()[0, ].T)[4:24, 1:-1]
    all_seq_prob.append(softmax(logits, axis=0))
    results = results["representations"][33].detach().cpu().numpy()

    # SAVE EMBEDDINGS
    all_seq_embed_local.append(results)
    all_seq_embed_pool[i, ] = np.mean(results[0, :, :], axis=0)

    # TEMPORARILY SAVE LOCAL MUTATION EMBEDDINGS
    mutation_positions = df.loc[df.sequence == seq,
                                'mutation_position'].unique().astype(int)
    # update seq_to_embed mapping
    seq_to_embed[seq] = (i, mutation_positions)
    # NOTE(review): the token axis of `results` still includes the leading
    # special token (the logits slice above removes it), so index j addresses
    # residue j only if mutation_position is 1-based -- confirm against the dataset.
    for j in mutation_positions:
        all_seq_embed_by_position.append(results[0, j, :])

    # Free GPU memory before the next sequence.
    del batch_tokens, results
    gc.collect()
    torch.cuda.empty_cache()

all_seq_embed_by_position = np.stack(all_seq_embed_by_position)


Extracting embeddings from proteins...


477it [00:44, 10.66it/s]


#### RAPIDS PCA

The transformer embeddings have dimension 1280. Since we only have a few thousand rows of train data, that is too many features to include all of them in our XGB model. Furthermore, we want to use local, pooling, and delta embeddings, which would be 3x1280 features. To prevent our model from overfitting as a result of the "curse of dimensionality", we reduce the dimension of the embeddings using RAPIDS PCA.


In [6]:
# Pooled embeddings: fit the PCA only on rows that actually hold an embedding.
# all_seq_embed_pool is zero-initialised with one more row than there are
# sequences, and rows of sequences skipped for length are never written, so
# fitting on the raw matrix would treat those all-zero rows as real proteins
# and skew the mean and the principal components.
pool_features = all_seq_embed_pool.astype('float32')
has_embedding = np.any(pool_features != 0, axis=1)
pca_pool = PCA(n_components=32)
pca_pool.fit(pool_features[has_embedding])
# Transform every row (placeholders included) so that pca_embeds keeps the
# same shape and row indexing as all_seq_embed_pool.
pca_embeds = pca_pool.transform(pool_features)

# Local embeddings: fit on the vectors collected at the mutated positions.
pca_local = PCA(n_components=16)
pca_local.fit(all_seq_embed_by_position.astype('float32'))


PCA()

In [11]:
# Sanity check: expect one row per row of all_seq_embed_pool and 32 PCA components.
pca_embeds.shape


(478, 32)

In [None]:
def _mean_xyz(atoms):
    """Return the mean (x, y, z) coordinate of the rows in ``atoms`` as an array."""
    return np.array(
        [atoms.x_coord.mean(), atoms.y_coord.mean(), atoms.z_coord.mean()])


def get_new_row(atom_df, j, row):
    """Build the feature dictionary for one point mutation.

    Args:
        atom_df: the PDB file's atom table. Columns read here: residue_number,
            residue_name, atom_name, b_factor, x_coord, y_coord, z_coord.
        j: residue number of the mutated residue in ``atom_df`` (train CSV
            position plus offset).
        row: row from the downloaded train CSV. Fields read here: mutant_seq,
            wildtype, mutation, position, sequence, PDB, CIF, ddG, dTm, source.

    Returns:
        A dict of features (insertion order defines the column order), or
        None when ``atom_df`` has no residue ``j`` or when the surface-area
        lookup for that residue fails.

    NOTE(review): besides t_model / batch_converter / device / token_map /
    pca_local / pca_pool / pca_embeds, this function reads notebook globals
    that are not defined in the cells above it: aa_map, aa_map_2, aa_props,
    PROPS, USE_B_COLUMN, sub_mat_b100, sub_mat_b80, sub_mat_b60, sub_mat_b40,
    sub_mat_demask, all_pdb_embed_local, all_pdb_prob and pdb_map. The
    embedding cell of this notebook produces all_seq_embed_local /
    all_seq_prob / seq_to_embed instead -- confirm the all_pdb_* names and
    pdb_map exist before calling this.
    """
    backbone_atoms = ['N', 'H', 'CA', 'O']

    tmp = atom_df.loc[(atom_df.residue_number == j)].reset_index(drop=True)
    if len(tmp) == 0:
        # Residue not present in the structure: nothing to engineer.
        return None
    prev = atom_df.loc[(atom_df.residue_number == j-1)].reset_index(drop=True)
    post = atom_df.loc[(atom_df.residue_number == j+1)].reset_index(drop=True)

    # GET MUTANT EMBEDDINGS
    data = [("protein1", row.mutant_seq)]
    batch_labels, batch_strs, batch_tokens = batch_converter(data)
    batch_tokens = batch_tokens.to(device)
    with torch.no_grad():
        output = t_model(batch_tokens, repr_layers=[33])
    results = output["representations"][33].cpu().numpy()
    # Release the GPU tensors right away so that every return path below
    # (including the early return on a failed surface-area lookup) frees them.
    del batch_tokens, output
    gc.collect()
    torch.cuda.empty_cache()
    mutant_local = pca_local.transform(results[:1, row.position, :])[0, ]
    mutant_pool = np.mean(results[:1, :, :], axis=1)
    mutant_pool = pca_pool.transform(mutant_pool)[0, ]

    # MUTATION AND POSITION
    dd = {}
    dd['WT'] = row.wildtype
    dd['WT2'] = tmp.residue_name.map(aa_map)[0]
    dd['MUT'] = row.mutation
    dd['position'] = row.position
    dd['relative_position'] = row.position / len(row.sequence)

    # B_FACTOR
    if USE_B_COLUMN:
        dd['b_factor'] = tmp.b_factor.mean()

    # AMINO ACID PROPERTIES AND DELTAS
    for c in PROPS:
        dd[f'{c}_1'] = aa_props.loc[row.wildtype, c]
        dd[f'{c}_2'] = aa_props.loc[row.mutation, c]
        dd[f'{c}_delta'] = dd[f'{c}_2']-dd[f'{c}_1']

    # SUBSTITUTION MATRICES
    dd['blosum100'] = sub_mat_b100[(row.wildtype, row.mutation)]
    dd['blosum80'] = sub_mat_b80[(row.wildtype, row.mutation)]
    dd['blosum60'] = sub_mat_b60[(row.wildtype, row.mutation)]
    dd['blosum40'] = sub_mat_b40[(row.wildtype, row.mutation)]
    dd['demask'] = sub_mat_demask[(row.wildtype, row.mutation)]

    # PREVIOUS AND POST AMINO ACID INFO ('X' / -999 mark a missing neighbor)
    if len(prev) > 0:
        dd['prev'] = prev.residue_name.map(aa_map)[0]
        if USE_B_COLUMN:
            dd['b_factor_prev'] = prev.b_factor.mean()
    else:
        dd['prev'] = 'X'
        if USE_B_COLUMN:
            dd['b_factor_prev'] = -999

    if len(post) > 0:
        dd['post'] = post.residue_name.map(aa_map)[0]
        if USE_B_COLUMN:
            dd['b_factor_post'] = post.b_factor.mean()
    else:
        dd['post'] = 'X'
        if USE_B_COLUMN:
            dd['b_factor_post'] = -999

    # Backbone-atom centroid of the mutated residue, shared by the angle and
    # the 3D-location features.
    c_tmp = _mean_xyz(tmp.loc[tmp.atom_name.isin(backbone_atoms)])

    # ANGLE BETWEEN MUTATION AND NEIGHBORS
    if len(prev) > 0 and len(post) > 0:
        c_prev = _mean_xyz(prev.loc[prev.atom_name.isin(backbone_atoms)])
        c_post = _mean_xyz(post.loc[post.atom_name.isin(backbone_atoms)])
        vec_a = c_prev - c_tmp
        vec_b = c_post - c_tmp
        norm_a = np.sqrt(vec_a.dot(vec_a))
        norm_b = np.sqrt(vec_b.dot(vec_b))
        dd['cos_angle'] = vec_a.dot(vec_b)/norm_a/norm_b
    else:
        # Sentinel outside the valid cosine range [-1, 1].
        dd['cos_angle'] = -2

    # 3D LOCATION OF MUTATION: squared distance between the residue's backbone
    # centroid and the backbone centroid of the whole structure.
    centroid1 = _mean_xyz(atom_df.loc[atom_df.atom_name.isin(backbone_atoms)])
    dist = c_tmp - centroid1
    dd['location3d'] = dist.dot(dist)

    # TRANSFORMER ESM EMBEDDINGS
    wt_local = pca_local.transform(
        all_pdb_embed_local[pdb_map[row.PDB]][:1, row.position, :])[0, ]
    wt_pool = pca_embeds[pdb_map[row.PDB], ]
    # 32 and 16 must match n_components of pca_pool and pca_local respectively.
    for kk in range(32):
        dd[f'pca_pool_{kk}'] = mutant_pool[kk] - wt_pool[kk]
        if kk >= 16:
            continue
        dd[f'pca_wt_{kk}'] = wt_local[kk]
        dd[f'pca_mutant_{kk}'] = mutant_local[kk]
        dd[f'pca_local_{kk}'] = mutant_local[kk] - wt_local[kk]

    # TRANSFORMER MUTATION PROBS AND ENTROPY
    # The probability matrix is indexed [amino acid row, position - 1].
    dd['mut_prob'] = all_pdb_prob[pdb_map[row.PDB]
                                  ][token_map[dd['MUT']], dd['position']-1]
    dd['mut_entropy'] = entropy(
        all_pdb_prob[pdb_map[row.PDB]][:, dd['position']-1])

    # SURFACE AREA FEATURES
    PATH = '../input/nesp-kaggle-train-surface-area/'
    # A missing CIF read from CSV is NaN, which is truthy; test for it
    # explicitly so such rows fall through to the PDB-based file names instead
    # of looking for 'nan-model_v3.csv'.
    if pd.notna(row.CIF) and row.CIF:
        nm = f'{row.CIF}-model_v3.csv'
    elif row.PDB != 'kaggle':
        PATH = '../input/nesp-jin-external-surface-area/'
        nm = f'{row.PDB}.csv'
    else:
        nm = 'wildtype_structure_prediction_af2_SASA.csv'
    try:
        area = pd.read_csv(f'{PATH}{nm}')
        rw = area.loc[area.Residue_number == j].iloc[0]
        dd['sa_total'] = rw.Total
        dd['sa_apolar'] = rw.Apolar
        dd['sa_backbone'] = rw.Backbone
        dd['sa_sidechain'] = rw.Sidechain
        dd['sa_ratio'] = rw.Ratio
        dd['sa_in/out'] = -1
        if rw['In/Out'] == 'i':
            dd['sa_in/out'] = 1
        elif rw['In/Out'] == 'o':
            dd['sa_in/out'] = 0
    except Exception:
        # Best effort: a missing file, residue or column drops this row, but
        # KeyboardInterrupt / SystemExit are no longer swallowed.
        print('### NEED SURFACE AREA for PDB:',
              row.PDB, 'residue_number:', j)
        return None

    # LABEL ENCODE AMINO ACIDS
    dd['AA1'] = aa_map_2[dd['WT']]
    dd['AA2'] = aa_map_2[dd['MUT']]
    dd['AA3'] = aa_map_2[dd['prev']]
    dd['AA4'] = aa_map_2[dd['post']]

    # TARGETS AND SOURCES
    dd['ddG'] = row.ddG
    dd['dTm'] = row.dTm
    dd['pdb'] = row.PDB
    dd['source'] = row.source

    return dd
