In [66]:
from Bio import SeqIO
import numpy as np
import torch.nn as nn
import torch
cos = nn.CosineSimilarity(dim=0, eps=1e-6)
import torch.nn.functional as F
from Bio import AlignIO
import pandas as pd
import torch
import esm
import matplotlib.pyplot as plt
import seaborn as sns
from Bio.PDB import PDBParser, PPBuilder
import warnings
from Bio.PDB.PDBExceptions import PDBConstructionWarning


In [3]:
# Base directories for the input data on the local machine.
desktop = '/Users/williamharrigan/Desktop/'
mahdi = '/Users/williamharrigan/Desktop/mahdi_meeting_12_6/'

# Input files and the embedding directory, each built from one of the bases above.
full_len_sequences = f"{desktop}mi_1834_seqs.fasta"
clusters_75 = f"{mahdi}mi_1834_seqs_cdhit_75.fasta"
trunc_sequences = f"{desktop}trunc_P17693_PF00129.fasta"
embed_dir = f"{desktop}mi_1834_embeds/"

In [4]:
# Map every record id in the full-length FASTA file to its sequence as a plain string.
# SeqIO.to_dict is kept so that duplicate record ids still raise instead of being overwritten.
seqs = {
    record_id: str(record.seq)
    for record_id, record in SeqIO.to_dict(SeqIO.parse(full_len_sequences, "fasta")).items()
}

print(list(seqs.keys())[:5])

['P17693', 'G3HQN3_CRIGR', 'F4NBY0_HUMAN', 'Q860N9_HORSE', 'W6ANG0_MACMU']


In [2]:
# 1. Load the ESM Model
# esm2_t33_650M_UR50D() returns the pretrained network together with its alphabet.
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
# Put the module in evaluation mode. eval() is the last expression of the cell,
# so its return value (the module) is what the notebook displays below.
model.eval()

ESM2(
  (embed_tokens): Embedding(33, 1280, padding_idx=1)
  (layers): ModuleList(
    (0-32): 33 x TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (rot_emb): RotaryEmbedding()
      )
      (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      (fc1): Linear(in_features=1280, out_features=5120, bias=True)
      (fc2): Linear(in_features=5120, out_features=1280, bias=True)
      (final_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
    )
  )
  (contact_head): ContactPredictionHead(
    (regression): Linear(in_features=660, out_features=1, bias=True)
    (activation): Sigmoid()
  )
  (emb_layer_norm_after): LayerNorm((1280,), eps=1

In [5]:
# Single-GPU setup: move the model to CUDA when it is available, otherwise leave it unchanged.
model = model.cuda() if torch.cuda.is_available() else model

In [6]:
# 2. Prepare the Input
# Build a one-item batch (label 0) from the P17693 sequence and tokenise it
# with the alphabet's batch converter.
batch_converter = alphabet.get_batch_converter()
protein_sequence = seqs['P17693']
data = [(0, protein_sequence)]
batch_labels, batch_strs, batch_tokens = batch_converter(data)

# Move the tokens to the GPU only when CUDA is available.
if torch.cuda.is_available():
    batch_tokens = batch_tokens.cuda()

In [7]:
# 4. Extract Attention
# Run one forward pass with gradient tracking disabled. The call asks for the
# layer-33 representations (repr_layers=[33]) and sets return_contacts=True;
# everything the model returns is kept in `results`.
with torch.no_grad():
    results = model(batch_tokens, repr_layers=[33], return_contacts=True)
        

In [96]:
# Silence Biopython's PDB construction warnings for the parse below.
warnings.simplefilter('ignore', PDBConstructionWarning)

# Parse the 3kyn structure from the local PDB file.
pdb_id = '3kyn'
pdb_filename = "/Users/williamharrigan/Downloads/3kyn.pdb"
parser = PDBParser()
structure = parser.get_structure(pdb_id, pdb_filename)

# Keep only the first model of the structure.
# NOTE(review): this rebinds `model`, which the ESM cell earlier assigned to the
# ESM2 network. After this cell runs, `model(batch_tokens, ...)` no longer works
# until the ESM cell is re-run; the cells below read `model` as the PDB model.
model = structure[0]

In [93]:
# Scan every ordered pair of distinct residues in chain A and report those whose
# CA atoms are less than 5 apart while their residue numbers differ by more than 2.
# Each qualifying pair is printed in both orders, and `count` tallies the printed lines.
chain = model['A']
count = 0
for residue1 in chain:
    for residue2 in chain:
        if residue1 != residue2:
            try:
                # Subtracting two atoms yields their distance.
                distance = abs(residue1['CA'] - residue2['CA'])
            except KeyError:
                # One of the two residues has no CA atom; skip this pair.
                continue
            seq_separation = abs(residue1.id[1] - residue2.id[1])
            if distance < 5 and seq_separation > 2:
                print(residue1.id[1], residue1.resname, residue2.id[1], residue2.resname, distance)
                count += 1

2 SER 104 GLY 4.9678664
4 SER 102 ASP 4.415316
5 MET 28 VAL 4.433877
6 ARG 100 GLY 4.228853
7 TYR 26 GLY 4.294602
8 PHE 98 MET 4.5094566
9 SER 24 ALA 4.2879615
10 ALA 96 GLN 4.3907695
11 ALA 22 PHE 4.4628735
12 VAL 94 THR 4.5524282
13 SER 20 PRO 4.034341
20 PRO 13 SER 4.034341
22 PHE 11 ALA 4.4628735
23 ILE 37 ASP 4.5875163
24 ALA 9 SER 4.2879615
25 MET 35 ARG 4.385258
26 GLY 7 TYR 4.294602
27 TYR 32 GLN 4.385709
28 VAL 5 MET 4.433877
32 GLN 27 TYR 4.385709
33 PHE 48 ARG 4.304177
33 PHE 49 ALA 4.588333
34 VAL 47 PRO 4.845175
35 ARG 25 MET 4.385258
36 PHE 45 MET 4.460783
37 ASP 23 ILE 4.5875163
45 MET 36 PHE 4.460783
45 MET 64 THR 4.930895
47 PRO 34 VAL 4.845175
48 ARG 33 PHE 4.304177
49 ALA 33 PHE 4.588333
62 GLU 65 ARG 4.8181634
64 THR 45 MET 4.930895
65 ARG 62 GLU 4.8181634
67 THR 70 HIS 4.8471193
68 LYS 71 ALA 4.987451
70 HIS 67 THR 4.8471193
70 HIS 73 THR 4.8729544
71 ALA 68 LYS 4.987451
73 THR 70 HIS 4.8729544
74 ASP 77 ASN 4.6174245
77 ASN 74 ASP 4.6174245
77 ASN 80 THR 4.9075685

In [99]:
# Print the repr of every residue in the selected chain, one per line.
for residue in chain:
    print(residue)

<Residue MET het=  resseq=1 icode= >
<Residue SER het=  resseq=2 icode= >
<Residue HIS het=  resseq=3 icode= >
<Residue SER het=  resseq=4 icode= >
<Residue MET het=  resseq=5 icode= >
<Residue ARG het=  resseq=6 icode= >
<Residue TYR het=  resseq=7 icode= >
<Residue PHE het=  resseq=8 icode= >
<Residue SER het=  resseq=9 icode= >
<Residue ALA het=  resseq=10 icode= >
<Residue ALA het=  resseq=11 icode= >
<Residue VAL het=  resseq=12 icode= >
<Residue SER het=  resseq=13 icode= >
<Residue ARG het=  resseq=14 icode= >
<Residue PRO het=  resseq=15 icode= >
<Residue GLY het=  resseq=16 icode= >
<Residue ARG het=  resseq=17 icode= >
<Residue GLY het=  resseq=18 icode= >
<Residue GLU het=  resseq=19 icode= >
<Residue PRO het=  resseq=20 icode= >
<Residue ARG het=  resseq=21 icode= >
<Residue PHE het=  resseq=22 icode= >
<Residue ILE het=  resseq=23 icode= >
<Residue ALA het=  resseq=24 icode= >
<Residue MET het=  resseq=25 icode= >
<Residue GLY het=  resseq=26 icode= >
<Residue TYR het=  re

In [169]:
# Read every record of the 3KYN FASTA file and store their sequences, joined
# end to end in file order, under the key '3KYN'.
file = '/Users/williamharrigan/Downloads/rcsb_pdb_3KYN.fasta'
sequence = [str(record.seq) for record in SeqIO.parse(file, "fasta")]
sequence_dict = {'3KYN': ''.join(sequence)}

In [177]:
# Show the character at 0-based index 275 of the joined 3KYN sequence.
sequence_dict['3KYN'][275]

'M'

In [157]:
from Bio.SeqUtils import IUPACData
def simple_aa(three_letter_code):
    """Translate a three-letter residue name into its one-letter code.

    The name is capitalised (e.g. 'SER' -> 'Ser') before being looked up in
    IUPACData.protein_letters_3to1. Returns None when the name is not in that table.
    """
    lookup_key = three_letter_code.capitalize()
    return IUPACData.protein_letters_3to1.get(lookup_key)

In [178]:
# Walk every chain of the PDB model and print, for each residue that has a CA
# atom, its residue number, its one-letter code from the structure, the
# character at the matching position of the joined FASTA sequence, and the CA atom.
# `count` is the running index into the FASTA sequence and carries over from
# one chain to the next.
count = 0

# Looked up once, outside the loop, so a missing '3KYN' entry fails loudly
# instead of being mistaken for a residue without a CA atom.
reference_seq = sequence_dict['3KYN']

for pdb_chain in model:
    print(pdb_chain)
    print(count)
    for residue1 in pdb_chain:
        try:
            ca_atom = residue1['CA']
            expected_letter = reference_seq[count]
        except (KeyError, IndexError):
            # KeyError: the residue has no CA atom.
            # IndexError: the structure has more CA-bearing residues than the
            # FASTA sequence has characters.
            # Both were silently skipped before; any other exception now propagates
            # instead of being hidden by a bare `except:`.
            continue
        print(residue1.id[1], simple_aa(residue1.resname), expected_letter, ca_atom)
        count += 1



<Chain id=A>
0
1 M M <Atom CA>
2 S S <Atom CA>
3 H H <Atom CA>
4 S S <Atom CA>
5 M M <Atom CA>
6 R R <Atom CA>
7 Y Y <Atom CA>
8 F F <Atom CA>
9 S S <Atom CA>
10 A A <Atom CA>
11 A A <Atom CA>
12 V V <Atom CA>
13 S S <Atom CA>
14 R R <Atom CA>
15 P P <Atom CA>
16 G G <Atom CA>
17 R R <Atom CA>
18 G G <Atom CA>
19 E E <Atom CA>
20 P P <Atom CA>
21 R R <Atom CA>
22 F F <Atom CA>
23 I I <Atom CA>
24 A A <Atom CA>
25 M M <Atom CA>
26 G G <Atom CA>
27 Y Y <Atom CA>
28 V V <Atom CA>
29 D D <Atom CA>
30 D D <Atom CA>
31 T T <Atom CA>
32 Q Q <Atom CA>
33 F F <Atom CA>
34 V V <Atom CA>
35 R R <Atom CA>
36 F F <Atom CA>
37 D D <Atom CA>
38 S S <Atom CA>
39 D D <Atom CA>
40 S S <Atom CA>
41 A A <Atom CA>
42 S S <Atom CA>
43 P P <Atom CA>
44 R R <Atom CA>
45 M M <Atom CA>
46 E E <Atom CA>
47 P P <Atom CA>
48 R R <Atom CA>
49 A A <Atom CA>
50 P P <Atom CA>
51 W W <Atom CA>
52 V V <Atom CA>
53 E E <Atom CA>
54 Q Q <Atom CA>
55 E E <Atom CA>
56 G G <Atom CA>
57 P P <Atom CA>
58 E E <Atom CA>
59 Y Y <