### Computing ESM-2 embeddings with internet access:

In [None]:
import torch
import esm # install via pip install fair-esm
import numpy as np

# The first time this loader is called, it downloads the model weights from the internet.
model, alphabet = esm.pretrained.esm2_t6_8M_UR50D()
batch_converter = alphabet.get_batch_converter()

# Use the GPU when one is available. The model has to live on the same device as
# the token tensors that are passed to it; otherwise the forward pass fails with a
# device-mismatch error on CUDA machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval();  # inference mode; the trailing ';' suppresses the model repr in the notebook output

# This function takes a protein sequence and returns the mean embedding of the protein
def get_esm_embedding(model, alphabet, sequence, repr_layer=6):
    """Compute a fixed-size embedding for one protein by averaging token representations.

    Args:
        model: ESM model as returned by one of the ``esm.pretrained`` loaders.
        alphabet: Alphabet belonging to ``model``; its batch converter tokenizes the sequence.
        sequence: Amino-acid sequence as a non-empty string.
        repr_layer: Index of the layer whose representations are averaged.
            Defaults to 6, which matches the 6-layer esm2_t6_8M_UR50D model used here.

    Returns:
        A 1-D numpy array: the mean over token positions (first and last token
        excluded) of the requested layer's representations.

    Raises:
        ValueError: If ``sequence`` is empty (the mean over zero positions would be NaN).
    """
    if not sequence:
        raise ValueError("sequence must be a non-empty string")

    # Build the converter from the alphabet argument instead of relying on a
    # notebook-level global, so the function works with any model/alphabet pair.
    batch_converter = alphabet.get_batch_converter()
    _, _, batch_tokens = batch_converter([("P", sequence)])

    # Send the tokens to wherever the model's weights actually are; this avoids a
    # device mismatch when the model was not moved to the GPU (or was moved elsewhere).
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    with torch.no_grad():
        results = model(batch_tokens.to(model_device), repr_layers=[repr_layer], return_contacts=False)
    token_representations = results["representations"][repr_layer].cpu().numpy()

    # We don't want to use the first and last token (presumably the start/end
    # special tokens added by the batch converter), only the positions in between.
    return np.mean(token_representations[0, 1:-1, :], axis=0)

# Example: embed a short peptide and print the dimensionality of its embedding.
sequence = "MAGLQKQK"
embedding = get_esm_embedding(model=model, alphabet=alphabet, sequence=sequence)
print(embedding.shape)

### Computing ESM-2 embeddings without internet access (on HPC):

You first need to download the parameter files of the ESM-2 model ("esm2_t6_8M_UR50D-contact-regression.pt" and "esm2_t6_8M_UR50D.pt") from data/worksheet4 and store both files in the same directory on the HPC. The following code shows how to load the model from a local file without downloading it from the internet:

In [None]:
# Imported again here so that this offline section also runs on a fresh kernel,
# without first executing the online cell above (which needs internet access).
import esm
import numpy as np
import torch

model, alphabet = esm.pretrained.load_model_and_alphabet_local(model_location = "/path_to_parameter_weights/esm2_t6_8M_UR50D.pt")
batch_converter = alphabet.get_batch_converter()

# Use the GPU when one is available. The model has to live on the same device as
# the token tensors that are passed to it; otherwise the forward pass fails with a
# device-mismatch error on CUDA machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval();  # inference mode; the trailing ';' suppresses the model repr in the notebook output

In [None]:
# The following code is the same as in the section above for the downloaded model; only the model-loading step differs:

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# This function takes a protein sequence and returns the mean embedding of the protein
def get_esm_embedding(model, alphabet, sequence, repr_layer=6):
    """Compute a fixed-size embedding for one protein by averaging token representations.

    Args:
        model: ESM model as returned by one of the ``esm.pretrained`` loaders.
        alphabet: Alphabet belonging to ``model``; its batch converter tokenizes the sequence.
        sequence: Amino-acid sequence as a non-empty string.
        repr_layer: Index of the layer whose representations are averaged.
            Defaults to 6, which matches the 6-layer esm2_t6_8M_UR50D model used here.

    Returns:
        A 1-D numpy array: the mean over token positions (first and last token
        excluded) of the requested layer's representations.

    Raises:
        ValueError: If ``sequence`` is empty (the mean over zero positions would be NaN).
    """
    if not sequence:
        raise ValueError("sequence must be a non-empty string")

    # Build the converter from the alphabet argument instead of relying on a
    # notebook-level global, so the function works with any model/alphabet pair.
    batch_converter = alphabet.get_batch_converter()
    _, _, batch_tokens = batch_converter([("P", sequence)])

    # Send the tokens to wherever the model's weights actually are; this avoids a
    # device mismatch when the model was not moved to the GPU (or was moved elsewhere).
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    with torch.no_grad():
        results = model(batch_tokens.to(model_device), repr_layers=[repr_layer], return_contacts=False)
    token_representations = results["representations"][repr_layer].cpu().numpy()

    # We don't want to use the first and last token (presumably the start/end
    # special tokens added by the batch converter), only the positions in between.
    return np.mean(token_representations[0, 1:-1, :], axis=0)


# Define the example sequence here as well: the only other definition is in the
# online cell, which cannot run to completion on a machine without internet access.
sequence = "MAGLQKQK"
embedding = get_esm_embedding(model, alphabet, sequence)
print(embedding.shape)