In [1]:
import sys

import os
import json
import time
import argparse

import pandas as pd
import numpy as np


import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset

from torchinfo import summary

from sklearn.metrics import f1_score, roc_auc_score, roc_curve, auc, precision_recall_curve
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import StandardScaler

import wandb

sys.path.append('./../../../src/')

from utils import *
from utils_torch import * 
from tqdm import tqdm

In [5]:
# Read MAPP and create flanked peptides
MAPP_df = pd.read_csv('./../../../data/PG/MAPP_AP_benchmark.csv', sep='\t')
MAPP_df['y'] = [1 if i=='HIT' else 0 for i in MAPP_df['HIT/DECOY']]
max_flank_len = 5

for flank in range(max_flank_len+1):
    MAPP_df[str(flank)+'flank_peptide'] = [n[::-1][:flank]+p+c[:flank] 
                                           for n,p,c in zip(MAPP_df['n_flank'],MAPP_df['peptide'],MAPP_df['c_flank'])]

MAPP_df 

Unnamed: 0,peptide,Gene,n_flank,c_flank,length,AP with flanks score,AP without flanks score,HIT/DECOY,y,0flank_peptide,1flank_peptide,2flank_peptide,3flank_peptide,4flank_peptide,5flank_peptide
0,FKDTDYKRH,NT5C2,RTSVD,KDTDY,9,0.084120,0.010553,HIT,1,FKDTDYKRH,DFKDTDYKRHK,DVFKDTDYKRHKD,DVSFKDTDYKRHKDT,DVSTFKDTDYKRHKDTD,DVSTRFKDTDYKRHKDTDY
1,TQIMFETF,ACTB,NREKM,QIMFE,8,0.377213,0.489819,HIT,1,TQIMFETF,MTQIMFETFQ,MKTQIMFETFQI,MKETQIMFETFQIM,MKERTQIMFETFQIMF,MKERNTQIMFETFQIMFE
2,KKSQIFSTASD,HSPA5,TVVPT,KSQIF,11,0.014379,0.029527,HIT,1,KKSQIFSTASD,TKKSQIFSTASDK,TPKKSQIFSTASDKS,TPVKKSQIFSTASDKSQ,TPVVKKSQIFSTASDKSQI,TPVVTKKSQIFSTASDKSQIF
3,KTNHLVTVE,PDHB,EASVM,TNHLV,9,0.120193,0.130120,HIT,1,KTNHLVTVE,MKTNHLVTVET,MVKTNHLVTVETN,MVSKTNHLVTVETNH,MVSAKTNHLVTVETNHL,MVSAEKTNHLVTVETNHLV
4,SDLQLDRISVY,TUBB,TYHGD,DLQLD,11,0.089185,0.069398,HIT,1,SDLQLDRISVY,DSDLQLDRISVYD,DGSDLQLDRISVYDL,DGHSDLQLDRISVYDLQ,DGHYSDLQLDRISVYDLQL,DGHYTSDLQLDRISVYDLQLD
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3169,VRGLLAGRVPP,GPAT2,VGACA,RGLLA,11,0.412522,0.119909,DECOY,0,VRGLLAGRVPP,AVRGLLAGRVPPR,ACVRGLLAGRVPPRG,ACAVRGLLAGRVPPRGL,ACAGVRGLLAGRVPPRGLL,ACAGVVRGLLAGRVPPRGLLA
3170,AQKKDGKKRKR,HIST1H2BI,KAVTK,QKKDG,11,0.065142,0.153942,DECOY,0,AQKKDGKKRKR,KAQKKDGKKRKRQ,KTAQKKDGKKRKRQK,KTVAQKKDGKKRKRQKK,KTVAAQKKDGKKRKRQKKD,KTVAKAQKKDGKKRKRQKKDG
3171,LKQKRMYEQ,CHMP5,KALRV,KQKRM,9,0.072208,0.009945,DECOY,0,LKQKRMYEQ,VLKQKRMYEQK,VRLKQKRMYEQKQ,VRLLKQKRMYEQKQK,VRLALKQKRMYEQKQKR,VRLAKLKQKRMYEQKQKRM
3172,SLVIGSSTLFS,XRCC6,VYPEE,LVIGS,11,0.001822,0.003727,DECOY,0,SLVIGSSTLFS,ESLVIGSSTLFSL,EESLVIGSSTLFSLV,EEPSLVIGSSTLFSLVI,EEPYSLVIGSSTLFSLVIG,EEPYVSLVIGSSTLFSLVIGS


In [3]:
# load model
model, alphabet, model_state = esm.pretrained._load_model_and_alphabet_core_v2(torch.load('./../../../../../esm_models/esm2_t33_650M_UR50D.pt'))
batch_converter = alphabet.get_batch_converter()

model.eval()  # disables dropout for deterministic results



ESM2(
  (embed_tokens): Embedding(33, 1280, padding_idx=1)
  (layers): ModuleList(
    (0-32): 33 x TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (rot_emb): RotaryEmbedding()
      )
      (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      (fc1): Linear(in_features=1280, out_features=5120, bias=True)
      (fc2): Linear(in_features=5120, out_features=1280, bias=True)
      (final_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
    )
  )
  (contact_head): ContactPredictionHead(
    (regression): Linear(in_features=660, out_features=1, bias=True)
    (activation): Sigmoid()
  )
  (emb_layer_norm_after): LayerNorm((1280,), eps=1

In [38]:
# Create embeddings
for flank in range(max_flank_len):
    peptides_ls = MAPP_df[str(flank)+'flank_peptide'].to_list()
    embedding_dict = {}

    for range_sec in tqdm(range(0, len(peptides_ls), 100)):
        peptides = peptides_ls[range_sec:range_sec+100]

        # Prepare data 
        data = [(peptide, peptide) for i,peptide in enumerate(peptides)]
        batch_labels, batch_strs, batch_tokens = batch_converter(data)

        embeddings = get_esm_embedding(model, batch_tokens, mean=True)

        for idx, t in enumerate(embeddings):
            embedding_dict[batch_labels[idx]] = t
            
    MAPP_embedded_df = pd.DataFrame(embedding_dict).T.reset_index().rename(columns={'index':'peptide'})
    MAPP_embedded_df['y'] = MAPP_df['y']
    MAPP_embedded_df.to_csv('./../../../data/PG/esm1b/MAPP/MAPP_'+str(flank)+'flank.csv')

    print(len(embedding_dict))

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [06:18<00:00, 11.82s/it]


3174


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [07:27<00:00, 13.97s/it]


3174


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [08:18<00:00, 15.58s/it]


3174


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [09:23<00:00, 17.61s/it]


3174


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [10:32<00:00, 19.77s/it]


3174
