### ESM Encode

Use ESM models to encode protein sequences taken from custom PDB entries and to extract language-model embeddings. Also try co-folding many sequences in bulk.

In [20]:
## Run once cell

%load_ext autoreload
%autoreload 2

import os

# Move from the notebook's folder to its parent exactly once per kernel.
# `os.chdir('..')` is relative, so re-running this cell unguarded would climb
# one more directory level each time; the flag below makes the cell idempotent.
if not globals().get("_CWD_MOVED_TO_PARENT", False):
    os.chdir('..')
    _CWD_MOVED_TO_PARENT = True

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
## Main import cell

import matplotlib.pyplot as plt
import requests


import esm
import torch
import numpy as np
from Bio import SeqIO
from Bio.PDB import PDBList, PDBParser



In [3]:
%%time



def fetch_pdb_sequences(pdb_ids, output_fasta):
    """Download PDB entries and write their chain sequences to one FASTA file.

    Every entry is downloaded into ``./pdb_files`` (legacy PDB format) and
    parsed; each protein chain of the entry's first model becomes one FASTA
    record named ``>{pdb_id}_{chain_id}`` holding a one-letter sequence.

    Args:
        pdb_ids: Iterable of PDB identifiers, e.g. ``["1AIR", "1HIN"]``.
        output_fasta: Path of the FASTA file to write (overwritten if present).

    Notes:
        * Residue names are converted from three-letter to one-letter codes;
          names the converter does not know are written as ``X``.
        * Residues that are not amino acids (waters, ligands, nucleotides) are
          skipped, and chains without any amino acid produce no record.
        * Only the first model is used, so multi-model entries do not yield
          repeated records with identical headers.
        * The sequence is built from residues present in the coordinate
          section, so residues missing from the structure are not included.
    """
    # Local imports: both helpers ship with Biopython, which this notebook
    # already depends on for PDBList / PDBParser.
    from Bio.PDB.Polypeptide import is_aa
    from Bio.SeqUtils import seq1

    pdbl = PDBList()
    parser = PDBParser(QUIET=True)  # one parser can be reused for every entry
    records = []
    for pdb_id in pdb_ids:
        pdb_file = pdbl.retrieve_pdb_file(pdb_id, pdir='./pdb_files', file_format='pdb')
        structure = parser.get_structure(pdb_id, pdb_file)
        first_model = next(iter(structure), None)
        if first_model is None:
            continue
        for chain in first_model:
            # Convert residue by residue so that an unusually short residue
            # name can never shift the three-letter reading frame.
            sequence = "".join(
                seq1(f"{residue.get_resname():<3s}")
                for residue in chain
                if is_aa(residue)
            )
            if sequence:
                records.append(f">{pdb_id}_{chain.id}\n{sequence}")
    with open(output_fasta, 'w') as f:
        # Terminate the last record with a newline as FASTA readers expect.
        f.write("".join(f"{record}\n" for record in records))

# PDB entries to download, and the single FASTA file that collects the
# sequences of all their chains.
pdb_ids = ["1AIR", "1HIN", "1DAB"]
output_fasta = "sequences.fasta"
# Downloads the structures into ./pdb_files and (over)writes `output_fasta`.
fetch_pdb_sequences(pdb_ids, output_fasta)


Downloading PDB structure '1AIR'...
Downloading PDB structure '1HIN'...
Downloading PDB structure '1DAB'...
CPU times: user 406 ms, sys: 16.9 ms, total: 423 ms
Wall time: 1.9 s


In [7]:
# Scratch check: extending a list in place appends the other list's items.
lst = [1, 2, 3]
lst2 = [4, 6]
lst.extend(lst2)
lst

[1, 2, 3, 4, 6]

### ESMEmbeddings

Here is how to get sequence representations from ESM:

In [10]:
%%time

from typing import List, Tuple

# Loader callable for the ESM-2 checkpoint. It is deliberately NOT called here:
# ESMEmbeddings.__init__ invokes it to obtain the (model, alphabet) pair.
esm_model = esm.pretrained.esm2_t33_650M_UR50D

class ESMEmbeddings:
    """Wrapper around an ESM language model that collects sequence embeddings.

    Attributes:
        model, alphabet, batch_converter: Objects obtained from the loader.
        repr_layer: Index of the layer whose representations are extracted.
        data: Every ``(label, sequence)`` pair embedded so far, across calls.
        representations: One mean-pooled embedding per entry in ``data``.
        batch_data, batch_lens, results: Inputs, token counts and raw model
            output of the most recent ``get_embeddings`` call only.
    """

    def __init__(self, esm_model, repr_layer=33):
        """Load the model and prepare empty result stores.

        Args:
            esm_model: Zero-argument callable returning ``(model, alphabet)``,
                e.g. ``esm.pretrained.esm2_t33_650M_UR50D``.
            repr_layer: Layer to read representations from. Defaults to 33,
                the value previously hard-coded for the 33-layer checkpoint.
        """
        # Load ESM-2 model
        self.model, self.alphabet = esm_model()
        self.batch_converter = self.alphabet.get_batch_converter()
        self.model.eval()  # disables dropout for deterministic results
        self.repr_layer = repr_layer

        # Accumulated over every call to get_embeddings
        self.data = []
        self.representations = []

        # State of the most recent batch only (consumed by show_attention)
        self.batch_data = []
        self.batch_lens = None
        self.results = None

    def get_embeddings(self, data: List[Tuple]):
        """Embed a batch of sequences and return one vector per sequence.

        Args:
            data: List of ``(label, sequence)`` tuples.

        Returns:
            A list with one tensor per input: the mean of the per-token
            representations of layer ``repr_layer``, excluding the first and
            last non-padding token. An empty input returns an empty list and
            leaves all stored state untouched.
        """
        if not data:
            return []

        model, alphabet, batch_converter = self.model, self.alphabet, self.batch_converter
        batch_labels, batch_strs, batch_tokens = batch_converter(data)
        # Number of non-padding tokens per sequence
        batch_lens = (batch_tokens != alphabet.padding_idx).sum(1)

        # Extract per-residue representations (on CPU)
        with torch.no_grad():
            results = model(batch_tokens, repr_layers=[self.repr_layer], return_contacts=True)
        token_representations = results["representations"][self.repr_layer]

        # Generate per-sequence representations via averaging
        # NOTE: token 0 is always a beginning-of-sequence token, so the first residue is token 1.
        sequence_representations = [
            token_representations[i, 1 : tokens_len - 1].mean(0)
            for i, tokens_len in enumerate(batch_lens)
        ]

        # Make accessible by object. The accumulated stores grow with every
        # call; the batch_* / results attributes describe this call only.
        self.data.extend(data)
        self.representations.extend(sequence_representations)
        self.batch_data = list(data)
        self.batch_lens = batch_lens
        self.results = results

        return sequence_representations

    def show_attention(self):
        """Plot the predicted contact map of each sequence in the latest batch.

        Raises:
            RuntimeError: If ``get_embeddings`` has not produced a batch yet.
        """
        if self.results is None or self.batch_lens is None:
            raise RuntimeError("Call get_embeddings() before show_attention().")

        # Pair the contact maps with the batch that produced them. Zipping
        # against the accumulated self.data would mislabel the plots as soon
        # as get_embeddings has been called more than once.
        for (_, seq), tokens_len, attention_contacts in zip(
            self.batch_data, self.batch_lens, self.results["contacts"]
        ):
            # tokens_len counts the begin/end tokens as well (see the slice in
            # get_embeddings), so the residue count is two less.
            # NOTE(review): assumes the contact map has one row/column per
            # residue, i.e. without the begin/end tokens - confirm against
            # the ESM contact head.
            seq_len = int(tokens_len) - 2
            plt.matshow(attention_contacts[:seq_len, :seq_len])
            plt.title(seq)
            plt.show()

# Instantiating calls the loader and therefore loads the model.
# The original note estimates about 3 min for this cell from scratch; the
# recorded run below reports roughly 52 s of wall time.
esm_embeddings = ESMEmbeddings(esm_model)

CPU times: user 14.6 s, sys: 7.05 s, total: 21.7 s
Wall time: 52.2 s


In [11]:

# Prepare data (first 2 sequences from ESMStructuralSplitDataset superfamily / 4)
# Each entry is a (label, sequence) tuple, the format get_embeddings expects.
# The last two entries contain "<mask>" (presumably ESM's mask token - confirm);
# the final one is also written with spaces between residues.
data = [
    ("protein1", "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"),
    ("protein2", "KALTARQQEVFDLIRDHISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"),
    ("protein2 with mask","KALTARQQEVFDLIRD<mask>ISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"),
    ("protein3",  "K A <mask> I S Q"),
]

# One mean-pooled embedding per entry of `data`; the results are also stored
# on the `esm_embeddings` object.
representations = esm_embeddings.get_embeddings(data)

# # This is a long output but potentially useful
# esm_embeddings.show_attention()


### Structure Prediction and PLDDT

Here is how to evaluate sequences by folding them with ESMFold and computing the mean pLDDT (the model's per-residue confidence score):

In [21]:
import biotite.structure.io as bsio


# Loader callable for the ESMFold checkpoint; it is called inside
# ESMPredictions.__init__, not here.
esm_predictions_model = esm.pretrained.esmfold_v1

class ESMPredictions:
    """Wrapper around a structure-prediction model exposing ``infer_pdb``.

    Attributes:
        device: Device the model was moved to (CUDA when available, else CPU).
        model: The loaded model in eval mode on ``device``.
    """

    def __init__(self, esm_model):
        """Load the model and move it to the best available device.

        Args:
            esm_model: Zero-argument callable returning the model,
                e.g. ``esm.pretrained.esmfold_v1``.
        """
        model = esm_model()
        # Prefer the GPU but fall back to CPU so that construction does not
        # fail on machines without CUDA.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = model.eval().to(self.device)

    def single_sequence(self, sequence, save_to="result"):
        """Predict the structure of a single sequence and score it.

        Args:
            sequence: Amino-acid sequence passed to the model's ``infer_pdb``.
            save_to: Output path without extension; the prediction is written
                to ``f"{save_to}.pdb"``. Pass ``""`` or ``None`` to skip
                writing a file.

        Returns:
            ``(struct, pLDDT)`` when a file is written: the structure re-loaded
            from that file and the mean of its B-factor column. When
            ``save_to`` is empty, the raw output of ``infer_pdb`` is returned
            instead and no score is computed.
        """
        # Optionally, set a chunk size for axial attention to reduce memory:
        # lower sizes need less memory at the cost of slower prediction.
        # self.model.set_chunk_size(128)

        # Multimer prediction can be done with chains separated by ':'

        with torch.no_grad():
            output = self.model.infer_pdb(sequence)

        # Treat None like "" so that a missing path never creates "None.pdb".
        if not save_to:
            return output

        pdb_path = f"{save_to}.pdb"
        with open(pdb_path, "w") as f:
            f.write(output)
        struct = bsio.load_structure(pdb_path, extra_fields=["b_factor"])
        pLDDT = struct.b_factor.mean()  # this will be the pLDDT

        return struct, pLDDT

# Instantiating calls the loader and therefore loads the folding model.
# NOTE(review): the recorded run of this cell failed with
# "ModuleNotFoundError: No module named 'openfold'", so that package must be
# installed in the environment before this cell can succeed.
esm_predictions = ESMPredictions(esm_predictions_model)

ModuleNotFoundError: No module named 'openfold'

In [None]:
# Reference value noted by the author: 88.3 (presumably the pLDDT obtained for
# this sequence in an earlier run - confirm).
sequence = "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"

# With the default save_to="result" the prediction is written to result.pdb
# and the call returns the loaded structure together with its mean pLDDT.
struct, pLDDT = esm_predictions.single_sequence(sequence)
print(pLDDT)