### Load modules and download the pre-trained ESM-2 model

In [4]:
import torch
import esm

# ESM-2 checkpoint used for the embeddings: the smallest variant
# (6 layers, 8 million parameters). Larger alternatives to swap in:
#   esm.pretrained.esm2_t36_3B_UR50D()
#   esm.pretrained.esm2_t48_15B_UR50D()
model, alphabet = esm.pretrained.esm2_t6_8M_UR50D()
model.eval()  # inference mode: disables dropout for deterministic results
batch_converter = alphabet.get_batch_converter()

import pandas as pd
import numpy as np

### Read the 567 sdAb sequences from the FASTA file

In [5]:
def read_fasta(fp):
    """Yield ``(header, sequence)`` pairs from an iterable of FASTA lines.

    Parameters
    ----------
    fp : iterable of str
        Lines of a FASTA file, e.g. an open text file handle.

    Yields
    ------
    tuple of (str, str)
        The header line exactly as read (trailing whitespace removed, leading
        ``>`` kept) and the record's sequence lines joined into one string.

    Lines that appear before the first header are discarded.
    """
    header = None
    pieces = []
    for raw_line in fp:
        text = raw_line.rstrip()
        if not text.startswith(">"):
            # Sequence line: accumulate it for the current record.
            pieces.append(text)
            continue
        # A new header closes the record collected so far, if any.
        if header:
            yield (header, "".join(pieces))
        header = text
        pieces = []
    # Emit the last record, which has no following header to trigger it.
    if header:
        yield (header, "".join(pieces))

# Load every (header, sequence) record of the FASTA file into a list.
with open('../data/sdabs.fasta') as fasta_file:
    data = list(read_fasta(fasta_file))

### Generate the per-sequence embeddings

In [7]:
# Embed the sequences in fixed-size batches and mean-pool each one into a
# single vector. `sequence_representations_list` ends up holding one list per
# batch; each inner list contains one 1-D tensor per sequence in that batch.
sequence_representations_list = []
chunk_size = 25
# Index of the final transformer layer, read from the loaded model so it always
# matches the checkpoint (6 for esm2_t6_8M, 36 for the 3B and 48 for the 15B model).
repr_layer = model.num_layers
for i in range(0, len(data), chunk_size):
    chunk = data[i:i+chunk_size]
    # Progress marker: exclusive end offset of this batch. On the last batch it
    # can exceed len(data).
    print(i+chunk_size)
    batch_labels, batch_strs, batch_tokens = batch_converter(chunk)
    # Count of non-padding tokens in each sequence of the batch.
    batch_lens = (batch_tokens != alphabet.padding_idx).sum(1)

    # Extract per-residue representations (on CPU).
    # Contact prediction is not requested: only the hidden states are used below.
    with torch.no_grad():
        results = model(batch_tokens, repr_layers=[repr_layer])
    token_representations = results["representations"][repr_layer]

    # Generate per-sequence representations via averaging
    # NOTE: token 0 is always a beginning-of-sequence token, so the first residue is token 1.
    # The slice also drops the last non-padding token (presumably the
    # end-of-sequence token appended by the batch converter).
    sequence_representations = []
    # A dedicated index name keeps the outer batch offset `i` from being shadowed.
    for seq_idx, tokens_len in enumerate(batch_lens):
        sequence_representations.append(token_representations[seq_idx, 1 : tokens_len - 1].mean(0))

    sequence_representations_list.append(sequence_representations)

25
50
75
100
125
150
175
200
225
250
275
300
325
350
375
400
425
450
475
500
525
550
575


In [8]:
# Flatten the per-batch lists into one list of per-sequence embedding tensors,
# preserving the original sequence order.
flat_list = []
for batch_representations in sequence_representations_list:
    flat_list.extend(batch_representations)

In [9]:
# Stack the per-sequence vectors along a new leading axis (one row per
# sequence), then convert the result to a NumPy array.
stacked_embeddings = torch.stack(flat_list, dim=0)
X = stacked_embeddings.cpu().detach().numpy()
X.shape

(567, 320)

### Save the embeddings to a CSV file

In [10]:
# Write the embedding matrix as comma-separated text, one row per sequence.
np.savetxt(
    fname="../data/sdab_protein_embeddings_t6.csv",
    X=X,
    delimiter=",",
)

# Output names used previously for other runs / model sizes:
#   sdab_data_master_list_t6.csv
#   sdab_data_master_list_650M.csv
#   sdab_data_master_list_3B.csv
#   sdab_data_master_list_15B.csv