In [61]:
from Bio import SeqIO
import numpy as np
import torch.nn as nn
import torch
cos = nn.CosineSimilarity(dim=0, eps=1e-6)
import torch.nn.functional as F
from Bio import AlignIO
import pandas as pd
import torch
import esm
import matplotlib.pyplot as plt
import seaborn as sns
from Bio.PDB import PDBParser, PPBuilder
import warnings
from Bio.PDB.PDBExceptions import PDBConstructionWarning
from Bio.SeqUtils import IUPACData
from collections import defaultdict

In [2]:
base_dir = '/Users/williamharrigan/Desktop/Github/contact_site_classifier/attention_classifier/data_files/'
fasta_file = base_dir + 'rcsb_pdb_3KYN.fasta'
pdb_filename = base_dir + '3kyn.pdb'

In [3]:
## This code is not important right now, will be when adding multiple sequences

# seqs = SeqIO.to_dict(SeqIO.parse(full_len_sequences, "fasta"))
# for k,v in seqs.items():
#     seqs[k] = str(v.seq)
    
# print(list(seqs.keys())[:5])

In [4]:
def simple_aa(three_letter_code):
    return IUPACData.protein_letters_3to1.get(three_letter_code.capitalize())

## Generate ESM-2 Embedding

In [5]:
# 1. Load the ESM Model
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
model.eval()

ESM2(
  (embed_tokens): Embedding(33, 1280, padding_idx=1)
  (layers): ModuleList(
    (0-32): 33 x TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (rot_emb): RotaryEmbedding()
      )
      (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      (fc1): Linear(in_features=1280, out_features=5120, bias=True)
      (fc2): Linear(in_features=5120, out_features=1280, bias=True)
      (final_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
    )
  )
  (contact_head): ContactPredictionHead(
    (regression): Linear(in_features=660, out_features=1, bias=True)
    (activation): Sigmoid()
  )
  (emb_layer_norm_after): LayerNorm((1280,), eps=1

In [6]:
# Set model to use cuda GPU
if torch.cuda.is_available():
    model = model.cuda()

In [7]:
# 2. Prepare sequence input
# Load sequence file

seq_chains = []

for record in SeqIO.parse(fasta_file, "fasta"):
    seq_chains.append(str(record.seq))

pdb_id = record.id.split('|')[0].split('_')[0]
protein_sequence = ''.join(seq_chains)
print(pdb_id)
print('Sequence: ', protein_sequence[:15], '\n')

# Index protein sequence as sequence 0 (next sequence would be indexed as 1)
data = [(0, protein_sequence)]
print('Data: ', data, '\n')

# Prepare variables to input sequence into ESM-2 model 
batch_converter = alphabet.get_batch_converter()
batch_labels, batch_strs, batch_tokens = batch_converter(data)
batch_tokens = batch_tokens.cuda() if torch.cuda.is_available() else batch_tokens

print('batch_tokens: ', '\n\n', batch_tokens)

3KYN
Sequence:  MSHSMRYFSAAVSRP 

Data:  [(0, 'MSHSMRYFSAAVSRPGRGEPRFIAMGYVDDTQFVRFDSDSASPRMEPRAPWVEQEGPEYWEEETRNTKAHAQTDRMNLQTLRGYYNQSEASSHTLQWMIGCDLGSDGRLIRGYERYAYDGKDYLALNEDLRSWTAADTAAQISKRKCEAANVAEQRRAYLEGTCVEWLHRYLENGKEMLQRADPPKTHVTHHPVFDYEATLRCWALGFYPAEIILTWQRDGEDQTQDVELVETRPAGDGTFQKWAAVVVPSGEEQRYTCHVQHEGLPEPLMLRWKMIQRTPKIQVYSRHPAENGKSNFLNCYVSGFHPSDIEVDLLKNGERIEKVEHSDLSFSKDWSFYLLYYTEFTPTEKDEYACRVNHVTLSQPKIVKWDRDMKGPPAALTL')] 

batch_tokens:  

 tensor([[ 0, 20,  8, 21,  8, 20, 10, 19, 18,  8,  5,  5,  7,  8, 10, 14,  6, 10,
          6,  9, 14, 10, 18, 12,  5, 20,  6, 19,  7, 13, 13, 11, 16, 18,  7, 10,
         18, 13,  8, 13,  8,  5,  8, 14, 10, 20,  9, 14, 10,  5, 14, 22,  7,  9,
         16,  9,  6, 14,  9, 19, 22,  9,  9,  9, 11, 10, 17, 11, 15,  5, 21,  5,
         16, 11, 13, 10, 20, 17,  4, 16, 11,  4, 10,  6, 19, 19, 17, 16,  8,  9,
          5,  8,  8, 21, 11,  4, 16, 22, 20, 12,  6, 23, 13,  4,  6,  8, 13,  6,
         10,  4, 12, 10,  6, 19,  9, 10, 19,  5, 19, 13,  6

In [8]:
# 4. Input prepared sequence information into model and output as results (contact predictions are included in embedding)
with torch.no_grad():
    results = model(batch_tokens, repr_layers=[33], return_contacts=True)

print('Model outputs: ', results.keys())

Model outputs:  dict_keys(['logits', 'representations', 'attentions', 'contacts'])


## Double Checking Indexes for Embedding Sequence vs PDB Sequence

In [14]:
## Here I am double checking amino acid signatures from PDB chain to the sequence used in embedding generation

residue_position = 0

for i in protein_model:
    print(i)
    print(residue_position)
    for residue in i:
        try:
            print(residue.id[1], simple_aa(residue.resname),protein_sequence[residue_position], residue['CA'])
            residue_position+=1
        except:
            continue
#         print(residue1.id[1], residue1.resname)



<Chain id=A>
0
1 M M <Atom CA>
2 S S <Atom CA>
3 H H <Atom CA>
4 S S <Atom CA>
5 M M <Atom CA>
6 R R <Atom CA>
7 Y Y <Atom CA>
8 F F <Atom CA>
9 S S <Atom CA>
10 A A <Atom CA>
11 A A <Atom CA>
12 V V <Atom CA>
13 S S <Atom CA>
14 R R <Atom CA>
15 P P <Atom CA>
16 G G <Atom CA>
17 R R <Atom CA>
18 G G <Atom CA>
19 E E <Atom CA>
20 P P <Atom CA>
21 R R <Atom CA>
22 F F <Atom CA>
23 I I <Atom CA>
24 A A <Atom CA>
25 M M <Atom CA>
26 G G <Atom CA>
27 Y Y <Atom CA>
28 V V <Atom CA>
29 D D <Atom CA>
30 D D <Atom CA>
31 T T <Atom CA>
32 Q Q <Atom CA>
33 F F <Atom CA>
34 V V <Atom CA>
35 R R <Atom CA>
36 F F <Atom CA>
37 D D <Atom CA>
38 S S <Atom CA>
39 D D <Atom CA>
40 S S <Atom CA>
41 A A <Atom CA>
42 S S <Atom CA>
43 P P <Atom CA>
44 R R <Atom CA>
45 M M <Atom CA>
46 E E <Atom CA>
47 P P <Atom CA>
48 R R <Atom CA>
49 A A <Atom CA>
50 P P <Atom CA>
51 W W <Atom CA>
52 V V <Atom CA>
53 E E <Atom CA>
54 Q Q <Atom CA>
55 E E <Atom CA>
56 G G <Atom CA>
57 P P <Atom CA>
58 E E <Atom CA>
59 Y Y <

## Calculating Expected Contact Sites from PDB structure file

In [28]:
# Turned off warnings for discontinuous data structures

warnings.simplefilter('ignore', PDBConstructionWarning)

# Load the structure from locally saved file
parser = PDBParser()
structure = parser.get_structure(pdb_id, f"{pdb_filename}")

# Extract desired protein structure from PDB structure (typically only 1 structure to choose from)
protein_structure = structure[0]

In [41]:
# Choose protein chain to compare amino acid residues
# chain = protein_structure['A']

# Checking how many amino acids are compared
count = 0

for chain in protein_structure:
    print(chain)
    for residue1 in chain:
        for residue2 in chain:
            if residue1 != residue2:
                # compute distance between CA atoms
                try:
                    distance = abs(residue1['CA'] - residue2['CA'])
                except KeyError:
    #                 print(f"No CA for {residue2.id[1], residue2.resname}")
                    continue
                if distance < 5:
                    if residue1.id[1] - residue2.id[1] > 2 or residue2.id[1] - residue1.id[1] > 2 :
                        print(residue1.id[1], residue1.resname, residue2.id[1], residue2.resname, distance)
                        count+=1
                    else:
                        continue

## Initialize Contact Dictionaries for Training Random Forest Classifier

In [79]:
contacts = defaultdict(dict)
non_contacts = defaultdict(dict)

# Mapping of chain IDs to starting index
chain_to_index = {'A': -1, 'B': 275, 'P': 375}

count = 0

for chain in protein_structure:
    print(chain.id)
    index = chain_to_index.get(chain.id, 0)  # Default to 0 if chain ID not found

    for residue1 in chain:
        for residue2 in chain:
            if residue1 != residue2:
                try:
                    # Calculate Alpha-Carbon Distance
                    distance = abs(residue1['CA'] - residue2['CA'])
                except KeyError:
                    continue

                if distance < 5:
                    # Calculating Distance of Amino Acids in Sequence (Can't be next to each other to be considered contacts)
                    diff = abs(residue1.id[1] - residue2.id[1])
                    if diff > 2:
                        res1_index = residue1.id[1] + index
                        res2_index = residue2.id[1] + index

            # Saving Residues, Amino Acid Signatures, Distance and Contact (T/F) Into Dicts for RF training
                        
                        print(res1_index, residue1.resname, res2_index, residue2.resname, distance, protein_sequence[res1_index], protein_sequence[res2_index])
                        contacts[(res1_index, res2_index)] = {
                            'aa1': protein_sequence[res1_index],
                            'aa2': protein_sequence[res2_index], 
                            'dist': distance,
                            'contact':True
                        }
                        count += 1
                    else:
                        continue
                else:
                    non_contacts[(res1_index, res2_index)] = {
                        'aa1': protein_sequence[res1_index],
                        'aa2': protein_sequence[res2_index], 
                        'dist': distance,
                        'contact':False
                    }
print(f"Total contacts: {count}")


A
1 SER 103 GLY 4.9678664 S G
3 SER 101 ASP 4.415316 S D
4 MET 27 VAL 4.433877 M V
5 ARG 99 GLY 4.228853 R G
6 TYR 25 GLY 4.294602 Y G
7 PHE 97 MET 4.5094566 F M
8 SER 23 ALA 4.2879615 S A
9 ALA 95 GLN 4.3907695 A Q
10 ALA 21 PHE 4.4628735 A F
11 VAL 93 THR 4.5524282 V T
12 SER 19 PRO 4.034341 S P
19 PRO 12 SER 4.034341 P S
21 PHE 10 ALA 4.4628735 F A
22 ILE 36 ASP 4.5875163 I D
23 ALA 8 SER 4.2879615 A S
24 MET 34 ARG 4.385258 M R
25 GLY 6 TYR 4.294602 G Y
26 TYR 31 GLN 4.385709 Y Q
27 VAL 4 MET 4.433877 V M
31 GLN 26 TYR 4.385709 Q Y
32 PHE 47 ARG 4.304177 F R
32 PHE 48 ALA 4.588333 F A
33 VAL 46 PRO 4.845175 V P
34 ARG 24 MET 4.385258 R M
35 PHE 44 MET 4.460783 F M
36 ASP 22 ILE 4.5875163 D I
44 MET 35 PHE 4.460783 M F
44 MET 63 THR 4.930895 M T
46 PRO 33 VAL 4.845175 P V
47 ARG 32 PHE 4.304177 R F
48 ALA 32 PHE 4.588333 A F
61 GLU 64 ARG 4.8181634 E R
63 THR 44 MET 4.930895 T M
64 ARG 61 GLU 4.8181634 R E
66 THR 69 HIS 4.8471193 T H
67 LYS 70 ALA 4.987451 K A
69 HIS 66 THR 4.847119

In [78]:
print(list(contacts)[:5])
print(list(non_contacts)[:5])

[(1, 103), (3, 101), (4, 27), (5, 99), (6, 25)]
[(369, 354), (1, 103), (3, 101), (4, 27), (5, 99)]


## Implementing Random Forest Classifier (Currently for only 1 sequence)