In [21]:
import pandas as pd
from icecream import ic
import ast
import torch
from typing import List, Tuple, Dict, Optional, Set
from enum import Enum
import os
import pickle
import gc
import pandas as pd
from Bio import SeqIO
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import esm
from icecream import ic
import random
from collections import Counter
import requests
from urllib.parse import quote
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
import sys

sys.path.append('/workspace/protein_lm/tokenizer')

from tokenizer import EsmTokenizer, PTMTokenizer

sys.path.append('/workspace/protein_lm/modeling/scripts')
from infer import PTMMamba

In [22]:
# Load the residue-level table. The three label columns hold list literals as
# text in the CSV, so each one is parsed back into a Python list.
residue_seqs_df = pd.read_csv('/workspace/protein_lm/evaluation/binding_site_prediction/data/residue_seqs.csv')
for label_column in ('aligned_labels_with_gaps', 'aligned_labels', 'labels'):
    residue_seqs_df[label_column] = residue_seqs_df[label_column].apply(ast.literal_eval)

residue_seqs_df.head()

Unnamed: 0,AC_ID,wt_seq,ptm_seq,pdb_id_with_chain_name,aligned_labels_with_gaps,aligned_labels,labels,ppbs_seq_alignment,ptm_seq_alignment,is_test,is_val,is_train
0,C4YMW2,MSTNKITFLLNWEAAPYHIPVYLANIKGYFKDENLDIAILEPSNPS...,MSTNKITFLLNWEAAPYHIPVYLANIKGYFKDENLDIAILEPSNPS...,4esw_A,"[1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, ...","[1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, ...","[0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, ...",GSHMSTNKITFLLNWEAAPYHIPVYLANIKGYFKDENLDIAILEPS...,---MSTNKITFLLNWEAAPYHIPVYLANIKGYFKDENLDIAILEPS...,0,1,0
1,P97291,MPERLAETLMDLWTPLIILWITLPSCVYTAPMNQAHVLTTGSPLEL...,MPERLAETLMDLWTPLIILWITLPSCVYTAPMNQAHVLTTGSPLEL...,1zxk_A,"[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, ...",------------------------S---------------------...,MPERLAETLMDLWTPLIILWITLPSCVYTAPMNQAHVLTTGSPLEL...,0,0,1
2,P93114,MAVPMDTISGPWGNNGGNFWSFRPVNKINQIVISYGGGGNNPIALT...,M<N-acetylalanine>VPMDTISGPWGNNGGNFWSFRPVNKINQ...,1ouw_B,"[-1, -1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1...","[0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, ...","[1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, ...",--VPMDTISGPWGNNGGNFWSFRPVNKINQIVISYGGGGNNPIALT...,MAVPMDTISGPWGNNGGNFWSFRPVNKINQIVISYGGGGNNPIALT...,1,0,0
3,P42262,MQKIMHISVLLSPVLWGLIFGVSSNSIQIGGLFPRGADQEYSAFRV...,MQKIMHISVLLSPVLWGLIFGVSSNSIQIGGLFPRGADQEYSAFRV...,2xhd_A,"[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",------------------------N---------------------...,MQKIMHISVLLSPVLWGLIFGVSSNSIQIGGLFPRGADQEYSAFRV...,0,0,1
4,O75208,MAAAAVSGALGRAGWRLLQLRCLPVARCRQALVPRAFHASAVGLRS...,MAAAAVSGALGRAGWRLLQLRCLPVARCRQALVPRAFHASAVGLRS...,4rhp_B,"[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",----------------------------------------------...,MAAAAVSGALGRAGWRLLQLRCLPVARCRQALVPRAFHASAVGLRS...,0,0,1


In [31]:
embeddings_dir = '/workspace/protein_lm/evaluation/binding_site_prediction/data/embeddings'

# For every sequence model, add one embedding-path column and one
# padding-mask-path column, both keyed by the protein's AC_ID.
for model_tag in ('esm_650m', 'mamba_with_ptms', 'mamba_without_ptms'):
    residue_seqs_df[f'{model_tag}_embedding_path'] = residue_seqs_df['AC_ID'].apply(
        lambda ac_id, tag=model_tag: f'{embeddings_dir}/{ac_id}_{tag}.pt'
    )
    residue_seqs_df[f'{model_tag}_embedding_padding_mask_path'] = residue_seqs_df['AC_ID'].apply(
        lambda ac_id, tag=model_tag: f'{embeddings_dir}/{ac_id}_{tag}_padding_mask.pt'
    )

# The PPBS embedding file is keyed by PDB id + chain rather than by AC_ID.
residue_seqs_df['ppbs_embedding_path'] = residue_seqs_df['pdb_id_with_chain_name'].apply(
    lambda chain_id: f'{embeddings_dir}/{chain_id}_ppbs_embedding_path.pt'
)

residue_seqs_df.head()

Unnamed: 0,AC_ID,wt_seq,ptm_seq,pdb_id_with_chain_name,aligned_labels_with_gaps,aligned_labels,labels,ppbs_seq_alignment,ptm_seq_alignment,is_test,is_val,is_train,esm_650m_embedding_path,esm_650m_embedding_padding_mask_path,mamba_with_ptms_embedding_path,mamba_with_ptms_embedding_padding_mask_path,mamba_without_ptms_embedding_path,mamba_without_ptms_embedding_padding_mask_path,ppbs_embedding_path
0,C4YMW2,MSTNKITFLLNWEAAPYHIPVYLANIKGYFKDENLDIAILEPSNPS...,MSTNKITFLLNWEAAPYHIPVYLANIKGYFKDENLDIAILEPSNPS...,4esw_A,"[1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, ...","[1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, ...","[0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, ...",GSHMSTNKITFLLNWEAAPYHIPVYLANIKGYFKDENLDIAILEPS...,---MSTNKITFLLNWEAAPYHIPVYLANIKGYFKDENLDIAILEPS...,0,1,0,/workspace/protein_lm/evaluation/binding_site_...,/workspace/protein_lm/evaluation/binding_site_...,/workspace/protein_lm/evaluation/binding_site_...,/workspace/protein_lm/evaluation/binding_site_...,/workspace/protein_lm/evaluation/binding_site_...,/workspace/protein_lm/evaluation/binding_site_...,/workspace/protein_lm/evaluation/binding_site_...
1,P97291,MPERLAETLMDLWTPLIILWITLPSCVYTAPMNQAHVLTTGSPLEL...,MPERLAETLMDLWTPLIILWITLPSCVYTAPMNQAHVLTTGSPLEL...,1zxk_A,"[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, ...",------------------------S---------------------...,MPERLAETLMDLWTPLIILWITLPSCVYTAPMNQAHVLTTGSPLEL...,0,0,1,/workspace/protein_lm/evaluation/binding_site_...,/workspace/protein_lm/evaluation/binding_site_...,/workspace/protein_lm/evaluation/binding_site_...,/workspace/protein_lm/evaluation/binding_site_...,/workspace/protein_lm/evaluation/binding_site_...,/workspace/protein_lm/evaluation/binding_site_...,/workspace/protein_lm/evaluation/binding_site_...
2,P93114,MAVPMDTISGPWGNNGGNFWSFRPVNKINQIVISYGGGGNNPIALT...,M<N-acetylalanine>VPMDTISGPWGNNGGNFWSFRPVNKINQ...,1ouw_B,"[-1, -1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1...","[0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, ...","[1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, ...",--VPMDTISGPWGNNGGNFWSFRPVNKINQIVISYGGGGNNPIALT...,MAVPMDTISGPWGNNGGNFWSFRPVNKINQIVISYGGGGNNPIALT...,1,0,0,/workspace/protein_lm/evaluation/binding_site_...,/workspace/protein_lm/evaluation/binding_site_...,/workspace/protein_lm/evaluation/binding_site_...,/workspace/protein_lm/evaluation/binding_site_...,/workspace/protein_lm/evaluation/binding_site_...,/workspace/protein_lm/evaluation/binding_site_...,/workspace/protein_lm/evaluation/binding_site_...
3,P42262,MQKIMHISVLLSPVLWGLIFGVSSNSIQIGGLFPRGADQEYSAFRV...,MQKIMHISVLLSPVLWGLIFGVSSNSIQIGGLFPRGADQEYSAFRV...,2xhd_A,"[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",------------------------N---------------------...,MQKIMHISVLLSPVLWGLIFGVSSNSIQIGGLFPRGADQEYSAFRV...,0,0,1,/workspace/protein_lm/evaluation/binding_site_...,/workspace/protein_lm/evaluation/binding_site_...,/workspace/protein_lm/evaluation/binding_site_...,/workspace/protein_lm/evaluation/binding_site_...,/workspace/protein_lm/evaluation/binding_site_...,/workspace/protein_lm/evaluation/binding_site_...,/workspace/protein_lm/evaluation/binding_site_...
4,O75208,MAAAAVSGALGRAGWRLLQLRCLPVARCRQALVPRAFHASAVGLRS...,MAAAAVSGALGRAGWRLLQLRCLPVARCRQALVPRAFHASAVGLRS...,4rhp_B,"[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",----------------------------------------------...,MAAAAVSGALGRAGWRLLQLRCLPVARCRQALVPRAFHASAVGLRS...,0,0,1,/workspace/protein_lm/evaluation/binding_site_...,/workspace/protein_lm/evaluation/binding_site_...,/workspace/protein_lm/evaluation/binding_site_...,/workspace/protein_lm/evaluation/binding_site_...,/workspace/protein_lm/evaluation/binding_site_...,/workspace/protein_lm/evaluation/binding_site_...,/workspace/protein_lm/evaluation/binding_site_...


In [51]:
# Write the current table to the processed CSV; index=False keeps the row
# index out of the file.
residue_seqs_df.to_csv(
    '/workspace/protein_lm/evaluation/binding_site_prediction/data/residue_seqs_processed.csv',
    index=False,
)

## Create ESM650M Embeddings

In [25]:
# NOTE(review): presumably an upper bound on sequence length, but no cell visible
# in this section reads it — confirm it is still needed or remove it.
max_sequence_len = 1000

In [26]:
# Choose the target device first: GPU index 6 when CUDA is available, else CPU.
device = torch.device("cuda:6" if torch.cuda.is_available() else "cpu")

# Load the pretrained ESM-2 (t33, 650M) model with its alphabet, build the batch
# converter, then switch the model to eval mode and move it to the device.
esm_650m_model, esm_650m_alphabet = esm.pretrained.esm2_t33_650M_UR50D()
esm_650m_batch_converter = esm_650m_alphabet.get_batch_converter()
esm_650m_model = esm_650m_model.eval().to(device)

ic(device)
ic(torch.cuda.is_available())

ic| device: device(type='cuda', index=6)
ic| torch.cuda.is_available(): True


True

In [8]:
class ESMProteinDataset(Dataset):
    """Dataset yielding ``(AC_ID, token tensor)`` pairs for wild-type sequences.

    Args:
        batch_converter: Callable that takes a list of ``(label, sequence)``
            pairs and returns a 3-tuple whose last element is a batch of
            token tensors (only that element is used here).
        device: Device the token tensor of each item is moved to.
        seqs_df: DataFrame with at least the ``AC_ID`` and ``wt_seq`` columns.
    """

    def __init__(self,
                batch_converter,
                device,
                seqs_df):
        self.batch_converter = batch_converter
        self.device = device
        self.df = seqs_df

    def __len__(self):
        """Return the number of rows in the backing DataFrame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Tokenize row ``idx`` and return ``(ac_id, tokens)``.

        The converter is called with a single-item batch, so element 0 of the
        returned token batch is the tokenized sequence for this row.
        """
        row = self.df.iloc[idx]
        ac_id = row['AC_ID']
        wt_seq = row['wt_seq']
        _, _, batch_tokens = self.batch_converter([(ac_id, wt_seq)])
        # Use the device this dataset was constructed with rather than whatever
        # the notebook-global `device` happens to be bound to at iteration time.
        return ac_id, batch_tokens[0].to(self.device)

# Seed torch's global RNG before building the shuffling loader.
torch.manual_seed(0)
esm_650m_dataset = ESMProteinDataset(
    batch_converter=esm_650m_batch_converter,
    device=device,
    seqs_df=residue_seqs_df,
)
esm_650m_loader = DataLoader(dataset=esm_650m_dataset, batch_size=1, shuffle=True)

In [9]:
# Map each AC_ID to its positional row index in residue_seqs_df
# (a repeated AC_ID keeps the position of its last occurrence).
ac_id_to_index = dict(zip(residue_seqs_df['AC_ID'], range(len(residue_seqs_df))))

In [13]:
# Run ESM-2 over every protein and write the layer-33 token representations
# to that protein's `esm_650m_embedding_path`.
with torch.no_grad():
	for protein_ids, batch_tokens in tqdm(esm_650m_loader):
		results = esm_650m_model(batch_tokens, repr_layers=[33])
		token_embeddings = results["representations"][33]
		for i, protein_id in enumerate(protein_ids):
			# Look up the output path through the AC_ID -> row-position map.
			target_path = residue_seqs_df.iloc[ac_id_to_index[protein_id]]['esm_650m_embedding_path']
			with open(target_path, 'wb') as f:
				torch.save(token_embeddings[i].cpu(), f)

100%|██████████| 3050/3050 [02:46<00:00, 18.28it/s]


In [14]:
# Move the ESM-2 model to CPU. As the last expression in the cell, the returned
# module is also what the cell displays.
esm_650m_model.cpu()

ESM2(
  (embed_tokens): Embedding(33, 1280, padding_idx=1)
  (layers): ModuleList(
    (0-32): 33 x TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (rot_emb): RotaryEmbedding()
      )
      (self_attn_layer_norm): ESM1bLayerNorm(torch.Size([1280]), eps=1e-05, elementwise_affine=True)
      (fc1): Linear(in_features=1280, out_features=5120, bias=True)
      (fc2): Linear(in_features=5120, out_features=1280, bias=True)
      (final_layer_norm): ESM1bLayerNorm(torch.Size([1280]), eps=1e-05, elementwise_affine=True)
    )
  )
  (contact_head): ContactPredictionHead(
    (regression): Linear(in_features=660, out_features=1, bias=True)
    (activation): Sigmoid()
  )
  (emb_layer_norm_

## Create PTMMamba Embeddings with PTM sequences

In [16]:
class MambaProteinDataset(Dataset):
    """Dataset yielding ``(AC_ID, sequence string)`` pairs.

    The sequence comes from the ``ptm_seq`` column when ``use_ptms`` is truthy
    and from ``wt_seq`` otherwise. ``device`` is only stored on the instance.
    """

    def __init__(self,
                device,
                seqs_df,
                use_ptms=True):
        self.device = device
        self.df = seqs_df
        self.use_ptms = use_ptms

    def __len__(self):
        """Return the number of rows in the backing DataFrame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(ac_id, seq)`` for the row at position ``idx``."""
        record = self.df.iloc[idx]
        protein_id = record['AC_ID']
        seq_column = 'ptm_seq' if self.use_ptms else 'wt_seq'
        return protein_id, record[seq_column]

torch.manual_seed(0)
# Rebinds the notebook-level `device`: GPU index 0 when CUDA is available, else CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
mamba_with_ptms_dataset = MambaProteinDataset(device=device, seqs_df=residue_seqs_df)
# One protein per batch, in DataFrame order.
mamba_with_ptms_loader = DataLoader(
    dataset=mamba_with_ptms_dataset,
    batch_size=1,
    shuffle=False,
)

In [17]:
# Build the PTM-Mamba inference wrapper from the checkpoint file on `device`.
model_checkpoint_path = '/workspace/ckpt/bi_mamba-esm-ptm_token_input/best.ckpt'
mamba = PTMMamba(
    ckpt_path=model_checkpoint_path,
    device=device,
)

<All keys matched successfully>


In [18]:
# The tokenizer instance is kept so the `ptm_tokenizer` name stays defined for
# any other cell that uses it. The model call below takes the raw sequence
# string, so the loop no longer tokenizes each sequence only to discard the result.
ptm_tokenizer = PTMTokenizer()

# Embed every PTM-annotated sequence with PTM-Mamba and save the hidden states
# to that protein's `mamba_with_ptms_embedding_path`.
with torch.no_grad():
    for protein_id, ptm_seq in tqdm(mamba_with_ptms_loader):
        # The loader yields one-element batches (batch_size=1); unwrap them.
        ptm_seq = ptm_seq[0]
        protein_id = protein_id[0]
        index = ac_id_to_index[protein_id]
        mamba_with_ptms_embedding_path = residue_seqs_df.iloc[index]['mamba_with_ptms_embedding_path']
        output = mamba(ptm_seq)
        # Remove dimension 0 when it has size 1 (no-op otherwise).
        embedding = output.hidden_states.squeeze(dim=0)
        with open(mamba_with_ptms_embedding_path, 'wb') as f:
            torch.save(embedding.cpu(), f)

100%|██████████| 3050/3050 [04:14<00:00, 11.97it/s]


In [19]:
# Same dataset as the PTM run, but reading the wild-type sequence column.
mamba_without_ptms_dataset = MambaProteinDataset(
    device=device,
    seqs_df=residue_seqs_df,
    use_ptms=False,
)
mamba_without_ptms_loader = DataLoader(
    dataset=mamba_without_ptms_dataset,
    batch_size=1,
    shuffle=False,
)

In [20]:
# The tokenizer instance is kept so the `ptm_tokenizer` name stays defined for
# any other cell that uses it. The model call below takes the raw sequence
# string, so the loop no longer tokenizes each sequence only to discard the result.
ptm_tokenizer = PTMTokenizer()

# Embed every wild-type sequence with PTM-Mamba and save the hidden states to
# that protein's `mamba_without_ptms_embedding_path`.
with torch.no_grad():
    for protein_id, wt_seq in tqdm(mamba_without_ptms_loader):
        # The loader yields one-element batches (batch_size=1); unwrap them.
        wt_seq = wt_seq[0]
        protein_id = protein_id[0]
        index = ac_id_to_index[protein_id]
        mamba_without_ptms_embedding_path = residue_seqs_df.iloc[index]['mamba_without_ptms_embedding_path']
        output = mamba(wt_seq)
        # Remove dimension 0 when it has size 1 (no-op otherwise).
        embedding = output.hidden_states.squeeze(dim=0)
        with open(mamba_without_ptms_embedding_path, 'wb') as f:
            torch.save(embedding.cpu(), f)

100%|██████████| 3050/3050 [04:16<00:00, 11.88it/s]


## Create PPBS ESM Embeddings

In [54]:
class PPBSESMDataset(Dataset):
	"""Dataset yielding tokenized, de-gapped PPBS chain sequences.

	Args:
		batch_converter: Callable that takes a list of ``(label, sequence)``
			pairs and returns a 3-tuple whose last element is a batch of
			token tensors (only that element is used here).
		device: Device the token tensor of each item is moved to.
		seqs_df: DataFrame with at least the ``pdb_id_with_chain_name``,
			``ppbs_embedding_path`` and ``ppbs_seq_alignment`` columns.
	"""

	def __init__(self, batch_converter, device, seqs_df):
		self.df = seqs_df
		self.batch_converter = batch_converter
		self.device = device

	def __len__(self):
		"""Return the number of rows in the backing DataFrame."""
		return len(self.df)

	def __getitem__(self, idx):
		"""Return ``(pdb_id_with_chain_name, tokens, ppbs_embedding_path, ppbs_seq)``.

		``ppbs_seq`` is the aligned PPBS sequence with every ``-`` removed.
		"""
		row = self.df.iloc[idx]
		pdb_id_with_chain_name = row['pdb_id_with_chain_name']
		ppbs_embedding_path = row['ppbs_embedding_path']
		ppbs_seq = row['ppbs_seq_alignment']
		# Strip alignment gap characters before tokenizing.
		ppbs_seq = ppbs_seq.replace('-', '')
		_, _, batch_tokens = self.batch_converter([(pdb_id_with_chain_name, ppbs_seq)])
		# Use the device this dataset was constructed with rather than whatever
		# the notebook-global `device` happens to be bound to at iteration time.
		return pdb_id_with_chain_name, batch_tokens[0].to(self.device), ppbs_embedding_path, ppbs_seq
	
# Reuses the ESM-2 batch converter to tokenize the PPBS chain sequences.
ppbs_dataset = PPBSESMDataset(
	batch_converter=esm_650m_batch_converter,
	device=device,
	seqs_df=residue_seqs_df,
)
ppbs_loader = DataLoader(dataset=ppbs_dataset, batch_size=1, shuffle=True)

# pdb_id, batch_tokens, ppbs_embedding_path = next(iter(ppbs_dataset))

# ic(pdb_id)
# ic(batch_tokens)
# ic(ppbs_embedding_path)

In [61]:
# Embed each de-gapped PPBS chain sequence with ESM-2 and save the layer-33
# representation to that row's `ppbs_embedding_path`.
# Tokens are moved to whichever device the model currently lives on, so this cell
# does not depend on the notebook-level `device` still matching the model.
model_device = next(esm_650m_model.parameters()).device
with torch.no_grad():
	for pdb_ids, batch_tokens, ppbs_embedding_paths, ppbs_seq in tqdm(ppbs_loader):
		results = esm_650m_model(batch_tokens.to(model_device), repr_layers=[33])
		token_embeddings = results["representations"][33]
		# Index the batch explicitly; the position used to come from a variable `i`
		# left over from an earlier cell's loop.
		for i, ppbs_embedding_path in enumerate(ppbs_embedding_paths):
			# Drop the first and last token positions — presumably the BOS/EOS
			# tokens added by the ESM batch converter; TODO confirm.
			embedding = token_embeddings[i][1:-1]
			with open(ppbs_embedding_path, 'wb') as f:
				torch.save(embedding.cpu(), f)

100%|██████████| 3050/3050 [02:02<00:00, 24.88it/s]
