In [1]:
# Generate ESM embedding (an embedding method of protein sequences)
from esm import FastaBatchedDataset, pretrained 
import torch
import pandas as pd
from Bio.PDB import PDBParser

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Map three-letter residue names (as they appear in PDB files) to one-letter amino-acid codes.
# Besides the 20 standard amino acids this also covers:
#   MSE (selenomethionine: methionine with its sulfur replaced by selenium) -> 'M'
#   PYL (pyrrolysine) -> 'O', SEC (selenocysteine) -> 'U'
#   ambiguity codes ASX -> 'B', GLX -> 'Z', XLE -> 'J' and the unknown residue XAA -> 'X'
three_to_one = dict(
    ALA='A', ARG='R', ASN='N', ASP='D', CYS='C',
    GLN='Q', GLU='E', GLY='G', HIS='H', ILE='I',
    LEU='L', LYS='K', MET='M', MSE='M', PHE='F',
    PRO='P', PYL='O', SER='S', SEC='U', THR='T',
    TRP='W', TYR='Y', VAL='V', ASX='B', GLX='Z',
    XAA='X', XLE='J',
)

In [3]:
# set the "nan" elements as None in a list
def set_nones(l):
    """Return a new list where every element whose ``str()`` is ``'nan'`` is replaced by None.

    pandas reads empty CSV cells as float NaN; converting them to None lets later
    code use ``is None`` checks. The test is textual, so a literal string ``'nan'``
    is replaced as well. All other elements are kept unchanged and in order.
    """
    cleaned = []
    for item in l:
        cleaned.append(None if str(item) == 'nan' else item)
    return cleaned

In [4]:
# get the protein residue sequences from a pdb file
def get_sequences_from_pdbfile(file_path):
    """Extract one-letter amino-acid sequences from the first model of a PDB file.

    Residues named 'HOH' (water) are skipped, and a residue is kept only when it
    contains all three backbone atoms CA, N and C. Residue names missing from the
    module-level ``three_to_one`` table are written as '-' and reported with a print.

    Args:
        file_path: path of the PDB file, parsed with Biopython's ``PDBParser``.

    Returns:
        The per-chain sequences joined by ':' (in chain iteration order), or None
        if the first model contains no chains.
    """
    biopython_parser = PDBParser()
    # only the first model (index 0) of the structure is used
    structure = biopython_parser.get_structure('random_id', file_path)[0]

    chain_sequences = []
    for chain in structure:
        residue_letters = []
        for residue in chain:
            resname = residue.get_resname()
            # water can never be an amino acid; skip it before inspecting atoms
            if resname == 'HOH':
                continue
            # a residue counts as an amino acid only if its backbone atoms CA, N and C are all present
            atom_names = {atom.name for atom in residue}
            if not {'CA', 'N', 'C'} <= atom_names:
                continue
            # explicit membership test instead of catching every Exception around the lookup
            if resname in three_to_one:
                residue_letters.append(three_to_one[resname])
            else:
                residue_letters.append('-')
                print("encountered unknown AA: ", resname, ' in the complex. Replacing it with a dash - .')
        chain_sequences.append(''.join(residue_letters))

    # keep the original contract: None when there are no chains at all
    if not chain_sequences:
        return None
    # sequences of different chains are separated by a ":"
    return ':'.join(chain_sequences)

In [5]:
# get protein sequences from multiple protein pdb files
def get_sequences(protein_files, protein_sequences):
    """Return one complete protein sequence string per complex.

    For every index, a non-None entry of ``protein_files`` is parsed with
    ``get_sequences_from_pdbfile`` (chains joined by ':'); when the file entry is
    None, the entry of ``protein_sequences`` at the same index is used as-is.

    Args:
        protein_files: list of PDB file paths, with None where no file is given.
        protein_sequences: list of sequence strings (or None), indexed like protein_files.

    Returns:
        A list with one entry per element of ``protein_files``.
    """
    return [
        protein_sequences[idx] if path is None else get_sequences_from_pdbfile(path)
        for idx, path in enumerate(protein_files)
    ]

In [6]:
# 1 Load the pretrained ESM-2 model ("esm2_t33_650M_UR50D") together with its token alphabet.
# NOTE(review): the weights are presumably fetched from the Internet unless already cached locally.
model_location = "esm2_t33_650M_UR50D"
model, alphabet = pretrained.load_model_and_alphabet(model_location)
# inference only: switch to evaluation mode, then move to the GPU when one is available
model.eval()
model = model.cuda() if torch.cuda.is_available() else model

In [7]:
# show the module hierarchy (embedding, transformer layers, ...) of the loaded ESM network
print(str(model))

ESM2(
  (embed_tokens): Embedding(33, 1280, padding_idx=1)
  (layers): ModuleList(
    (0): TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (rot_emb): RotaryEmbedding()
      )
      (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      (fc1): Linear(in_features=1280, out_features=5120, bias=True)
      (fc2): Linear(in_features=5120, out_features=1280, bias=True)
      (final_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
    )
    (1): TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (v_proj): Linear(in_features=1280, out_features=1280, bia

In [8]:
# 2 get protein names, sequences and ligand descriptions
example_csv_path = "data/protein_ligand_example_csv.csv"
# read a csv file containing paths, complex names and protein sequences of complexes, as a dataframe
df = pd.read_csv(example_csv_path)
complex_name_list = set_nones(df['complex_name'].tolist())
# complexes without a name get a positional fallback name
complex_name_list = [name if name is not None else f"complex_{i}" for i, name in enumerate(complex_name_list)]
protein_path_list = set_nones(df['protein_path'].tolist())
protein_sequence_list = set_nones(df['protein_sequence'].tolist())
ligand_description_list = set_nones(df['ligand_description'].tolist())
# every complex needs a PDB file or a raw sequence; fail here with a clear message
# instead of a later AttributeError when a None sequence is split into chains
complexes_without_protein = [
    name
    for name, path, seq in zip(complex_name_list, protein_path_list, protein_sequence_list)
    if path is None and seq is None
]
if complexes_without_protein:
    raise ValueError(f"No protein_path or protein_sequence given for: {complexes_without_protein}")
protein_sequences = get_sequences(protein_files=protein_path_list, protein_sequences=protein_sequence_list)

In [9]:
# raw dataframe columns, one list per line; empty CSV cells appear as nan
print(*(df[column].tolist() for column in ('complex_name', 'protein_path', 'protein_sequence', 'ligand_description')), sep='\n')

[nan, nan]
['data/1a0q/1a0q_protein_processed.pdb', 'data/1a0q/1a0q_protein_processed.pdb']
[nan, nan]
['data/1a0q/1a0q_ligand.sdf', 'COc(cc1)ccc1C#N']


In [10]:
# comparison: after "set_nones" conversion of the above lists, nan is replaced by None
print(complex_name_list, protein_path_list, protein_sequence_list, ligand_description_list, sep='\n')

['complex_0', 'complex_1']
['data/1a0q/1a0q_protein_processed.pdb', 'data/1a0q/1a0q_protein_processed.pdb']
[None, None]
['data/1a0q/1a0q_ligand.sdf', 'COc(cc1)ccc1C#N']


In [11]:
# see the protein sequences, where each amino acid is represented by a character
# (one string per complex; chains of the same complex are separated by ':', and for
#  sequences parsed from a PDB file '-' marks a residue name missing from three_to_one)
protein_sequences

['IELTQSPSSLSASLGGKVTITCKASQDIKKYIGWYQHKPGKQPRLLIHYTSTLLPGIPSRFRGSGSGRDYSFSISNLEPEDIATYYCLQYYNLRTFGGGTKLEIKRADAAPTVSIFPPSSEQLTSGGASVVCFLNNFYSKDINVKWKIDGSERQNGVLNSWTDQDSKDSTYSMSSTLTLTKDEYERHNSYTCEATHKTSTSPIVKSFNRNE:VQLQESDAELVKPGASVKISCKASGYTFTDHVIHWVKQKPEQGLEWIGYISPGNGDIKYNEKFKGKATLTADKSSSTAYMQLNSLTSEDSAVYLCKRGYYVDYWGQGTTLTVSSAKTTPPSVYPLAPSMVTLGCLVKGYFPEPVTVTWNSGSLSSGVHTFPAVLQSDLYTLSSSVTVPSSTWPSETVTCNVAHPASSTKVDKKIE',
 'IELTQSPSSLSASLGGKVTITCKASQDIKKYIGWYQHKPGKQPRLLIHYTSTLLPGIPSRFRGSGSGRDYSFSISNLEPEDIATYYCLQYYNLRTFGGGTKLEIKRADAAPTVSIFPPSSEQLTSGGASVVCFLNNFYSKDINVKWKIDGSERQNGVLNSWTDQDSKDSTYSMSSTLTLTKDEYERHNSYTCEATHKTSTSPIVKSFNRNE:VQLQESDAELVKPGASVKISCKASGYTFTDHVIHWVKQKPEQGLEWIGYISPGNGDIKYNEKFKGKATLTADKSSSTAYMQLNSLTSEDSAVYLCKRGYYVDYWGQGTTLTVSSAKTTPPSVYPLAPSMVTLGCLVKGYFPEPVTVTWNSGSLSSGVHTFPAVLQSDLYTLSSSVTVPSSTWPSETVTCNVAHPASSTKVDKKIE']

In [12]:
# length of every complex's sequence string; this counts the ':' chain separators too
# (for the example data: 205 + 211 residues + 1 separator = 417)
tuple(len(seq) for seq in protein_sequences)

(417, 417)

In [13]:
# 3 Protein amino acids: from single-letter characters to information-rich tensor embeddings
# To represent the proteins in an information-rich form,
# we use the pretrained ESM model to convert each protein sequence
# from a string of letters
# into per-residue embedding tensors.

In [14]:
# 3.1 Convert the protein labels and sequences into a form recognized by the pretrained ESM language model:
# every chain becomes its own entry, labelled "<complex name>_chain_<chain index>".
labels, sequences = [], []
for i, full_sequence in enumerate(protein_sequences):
    if full_sequence is None:
        # neither a protein_path nor a protein_sequence was provided for this complex
        raise ValueError(f"No protein sequence available for complex {complex_name_list[i]!r}")
    chains = full_sequence.split(':')
    sequences.extend(chains)
    # f-string formatting (rather than `+`) also works when pandas parsed complex names as numbers
    labels.extend(f"{complex_name_list[i]}_chain_{j}" for j in range(len(chains)))

In [15]:
print(str(labels))  # standardized labels: "<complex name>_chain_<chain index>"

['complex_0_chain_0', 'complex_0_chain_1', 'complex_1_chain_0', 'complex_1_chain_1']


In [16]:
def compute_ESM_embeddings(model, alphabet, labels, sequences, toks_per_batch=4096,
                           repr_layer=33, truncation_seq_length=1022):
    """Compute per-residue ESM embeddings for labelled sequences.

    Sequences are grouped by ``FastaBatchedDataset`` into batches of at most
    ``toks_per_batch`` tokens, tokenized by the alphabet's batch converter (which
    truncates to ``truncation_seq_length``), and run through ``model`` without
    gradient tracking. Input tokens are moved to the device of the model's parameters.

    Args:
        model: a loaded ESM model exposing ``num_layers`` (e.g. from
            ``pretrained.load_model_and_alphabet``).
        alphabet: the alphabet returned together with ``model``.
        labels: one label per sequence; used as keys of the result. Duplicate
            labels overwrite each other.
        sequences: amino-acid strings, one per label.
        toks_per_batch: token budget per batch passed to ``get_batch_indices``.
        repr_layer: layer whose representation is returned; negative values count
            from the end (-1 is the last layer). Defaults to 33 as before.
        truncation_seq_length: maximum number of residues embedded per sequence.

    Returns:
        dict mapping each label to a CPU tensor of shape
        (min(len(sequence), truncation_seq_length), embedding_dim). Token position 0
        (presumably the prepended start token) is dropped.

    Raises:
        ValueError: if ``repr_layer`` is outside ``[-(model.num_layers + 1), model.num_layers]``.
    """
    num_layers = model.num_layers
    # explicit check instead of `assert`, which is stripped under `python -O`
    if not -(num_layers + 1) <= repr_layer <= num_layers:
        raise ValueError(
            f"repr_layer must be in [{-(num_layers + 1)}, {num_layers}], got {repr_layer}"
        )
    # normalize negative indices into 0..num_layers
    layer = repr_layer % (num_layers + 1)

    dataset = FastaBatchedDataset(labels, sequences)
    batches = dataset.get_batch_indices(toks_per_batch, extra_toks_per_seq=1)
    data_loader = torch.utils.data.DataLoader(
        dataset, collate_fn=alphabet.get_batch_converter(truncation_seq_length), batch_sampler=batches
    )
    # send inputs wherever the model lives (works for CPU models even when CUDA exists)
    device = next(model.parameters()).device

    embeddings = {}
    with torch.no_grad():
        for batch_idx, (batch_labels, strs, toks) in enumerate(data_loader):
            print(f"Processing {batch_idx + 1} of {len(batches)} batches ({toks.size(0)} sequences)")
            toks = toks.to(device=device, non_blocking=True)

            out = model(toks, repr_layers=[layer], return_contacts=False)
            # use the requested layer rather than a hard-coded index
            representation = out["representations"][layer].to(device="cpu")

            for i, label in enumerate(batch_labels):
                truncate_len = min(truncation_seq_length, len(strs[i]))
                # skip position 0 and keep one row per (possibly truncated) residue
                embeddings[label] = representation[i, 1: truncate_len + 1].clone()
    return embeddings

In [17]:
# run ESM over every chain; the result maps each chain label to its per-residue embedding tensor
lm_embeddings = compute_ESM_embeddings(model=model, alphabet=alphabet, labels=labels, sequences=sequences)

Processing 1 of 1 batches (4 sequences)


In [24]:
# each residue is embedded as a vector of 1280 elements: one (num_residues, 1280) tensor per chain label
for embedding in lm_embeddings.values():
    print(embedding.shape)

torch.Size([205, 1280])
torch.Size([205, 1280])
torch.Size([211, 1280])
torch.Size([211, 1280])
