## ErePOC Using Tutorial

Test data in this tutorial contains one pair of HEM-binding proteins, and one pair of ADP-binding proteins.  
  
Data Format:
- PDB_ID - Protein's PDB ID
- Ligand - Name of pocket's binding ligand
- Sequence - Protein's sequence
- Pocket Positions - Residue indices of the pocket in the protein sequence

In [1]:
import pandas as pd
import numpy as np
import torch

import esm

from model.utils import Net_embed_MLP

### Generate Pockets' ESM-2 Data from Protein Sequences and Pocket Positions

In [15]:
# Load example data file
file_path = "./Example-Dataset/4-samples.csv"
sample_data = pd.read_csv(file_path)

pdb_id_list = sample_data['PDB_ID'].tolist()              # PDB ID List
ligand_list = sample_data['Ligand'].tolist()              # Binding Ligand List
sequence_list = sample_data['Sequence'].tolist()          # Protein Sequence List
pocket_list = sample_data['Pocket Positions'].tolist()    # Pocket Position List

print("PDB ID List: ", pdb_id_list)
print("Ligand List: ", ligand_list)

PDB ID List:  ['101m', '19hc', '8shq', '8igw']
Ligand List:  ['HEM', 'HEM', 'ADP', 'ADP']


In [11]:
# Load ESM-2 model
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = alphabet.get_batch_converter()
model.eval()  # disables dropout for deterministic results

ESM2(
  (embed_tokens): Embedding(33, 1280, padding_idx=1)
  (layers): ModuleList(
    (0-32): 33 x TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (rot_emb): RotaryEmbedding()
      )
      (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      (fc1): Linear(in_features=1280, out_features=5120, bias=True)
      (fc2): Linear(in_features=5120, out_features=1280, bias=True)
      (final_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
    )
  )
  (contact_head): ContactPredictionHead(
    (regression): Linear(in_features=660, out_features=1, bias=True)
    (activation): Sigmoid()
  )
  (emb_layer_norm_after): LayerNorm((1280,), eps=1

In [37]:
# Generate ESM-2 Representations
esm2_representations = []

for index, seq in enumerate(sequence_list):
    pdb_id = pdb_id_list[index]
    int_list = np.asarray([int(num) for num in pocket_list[index].split(',')])

    # tuple: (pdb_id, sequence)
    tuple_data = (pdb_id, seq)
    batch_labels, batch_strs, batch_tokens = batch_converter([tuple_data])
    # Extract per-residue representations (on CPU)
    with torch.no_grad():
        results = model(batch_tokens, repr_layers=[33], return_contacts=True)
    token_representations = results["representations"][33]
    
    # Generate pockets' representations
    concatenated_tensor = torch.stack(list(token_representations[0, int_list]), dim=0)
    esm2_representations.append(concatenated_tensor.mean(0))

In [40]:
esm2_representations = torch.from_numpy(np.asarray(esm2_representations))
print("ESM-2 Data Shape: ", esm2_representations.shape, ", Pocket's Representation Shape: ", esm2_representations[0].shape)

ESM-2 Data Shape:  torch.Size([4, 1280]) , Pocket's Representation Shape:  torch.Size([1280])


### Generate Pockets' ErePOC Representations

In [41]:
# Load ErePOC Model
ErePOC = Net_embed_MLP(input_dim=esm2_representations[0].shape[0],hidden_dim=512,out_dim=256, drop_prob=0.1)
# No CUDA Version
ErePOC.load_state_dict(torch.load("esm2-mlp-best-epoch-275-normAndzerograd.pt", map_location='cpu'))

ErePOC.eval()
with torch.no_grad():
    # Get ErePOC Repr.
    ErePOC_repr_array = ErePOC(esm2_representations)
print("ErePOC Repr. Array Shape: ", ErePOC_repr_array.shape)

  ErePOC.load_state_dict(torch.load("esm2-mlp-best-epoch-275-normAndzerograd.pt", map_location='cpu'))


ErePOC Repr. Array Shape:  torch.Size([4, 256])


### Calculate Cosine Similarity of Different Pockets

In [42]:
from sklearn.metrics.pairwise import cosine_similarity

# Cosine Similarity: the HEM pair
HEM_pair_sim = cosine_similarity([ErePOC_repr_array[0]], [ErePOC_repr_array[1]])[0][0]
# Cosine Similarity: the ADP pair
ADP_pair_sim = cosine_similarity([ErePOC_repr_array[2]], [ErePOC_repr_array[3]])[0][0]

# Cosine Similairty: HEM and ADP
HEM_ADP_sim = cosine_similarity([ErePOC_repr_array[0]], [ErePOC_repr_array[2]])[0][0]

print("HEM pair similarity: ", HEM_pair_sim)
print("ADP pair similarity: ", ADP_pair_sim)

print("HEM and ADP similarity: ", HEM_ADP_sim)

HEM pair similarity:  0.9990303
ADP pair similarity:  0.9912592
HEM and ADP similarity:  0.07724705
