# ESM Embeddings for NetsurfP

This notebook will try using embeddings from facebook researchs esm pre-trained model that uses a deep representation of 650M Unirep50 sequences.

https://github.com/facebookresearch/esm

## Libraries and pre-trained model

**Download pretrained model** 

**Load libraries**

In [1]:
import torch
import math
import esm
import numpy as np
import pdb

#pip install https://github.com/facebookresearch/esm

## Data

The data is loaded and converted back to aminoacids for the embedding

In [2]:
data_dir = "../../data/nsp2/training_data/"

datasets = [
    ("train_hhblits", np.load(data_dir + "Train_HHblits_small.npz")),
    #("CB513_hhblits", np.load(dir_path + "CB513_HHblits.npz")),
    #("TS115_hhblits", np.load(dir_path + "TS115_HHblits.npz")),
    ("CASP12_HHblits", np.load(data_dir + "CASP12_HHblits.npz")),
    
    #("Train_MMseqs", np.load(dir_path + "Train_MMseqs.npz")),
    #("CB513_MMseqs", np.load(dir_path + "CB513_MMseqs.npz")),
    #("TS115_MMseqs", np.load(dir_path + "TS115_MMseqs.npz")),
    #("CASP12_MMseqs", np.load(dir_path + "CASP12_MMseqs.npz")),
]

A function converts sparse encoding back to amino acid sequence.

In [4]:
def sparse_to_sequence(dataset):
    data = []

    aa_decode = np.array(["N","A","C","D","E","F","G","H","I","K","L","M","N","P","Q","R","S","T","V","W","Y"])

    # get the amino acid encoding and apply decode mask
    for seq_id in range(dataset.shape[0]):
        seq_mask = dataset[seq_id, dataset[seq_id, :, 50] == 1, :20]
        aa_idx = np.argmax(seq_mask, axis=1)

        aa_sequence = str()
        for idx in aa_idx:
            aa_sequence += aa_decode[idx]

        # store decoded sequence
        data.append(("protein" + str(seq_id), aa_sequence))
    
    # remove later
    data.append(("protein_x", "N"*1632))
        
    return data

Display first sequence to check conversion

In [5]:
sparse_to_sequence(datasets[0][1]['data'])[0]

('protein0',
 'LISNWHNIPQPHRETIRGERQPKDDQKFKHDTPNNHKRQTFCFSPCMKRFNDINTPTITINKNCNPEDTTGRKNIVIQPSKFPGCERNFDFKWSGLINKQNCDCQKRNKGRTWTCPVCVDQTLFCFDQPERSKIRSTDNHVNFHINSDNNTRDDEFKNNEKNCPHGETGRPDKKRQWNCKCNIFQDQNHNICKFNTEKTFHFFIKRCFGQGCTQNNCWCCVRSNRDKFGNFKMFCHKTVMNTKDCNEDKRRLFHQTCNCSKIGPKNKSFCDCQKDKDVGPNKKQFDLNPSHFFFHFPRQKSLKKKPKNGHFPTPNFTVNNNTQDRTNRKK')

## ESM1b Embedding

The pre-trained model is instantiated with the batch converter that converts the sequences to batch tokens.

In [6]:
# load 34 layer model
model_path = "../../models/esm1b_t33_650M_UR50S.pt"

model, alphabet = esm.pretrained.load_model_and_alphabet_local(model_path)
batch_converter = alphabet.get_batch_converter()



In [7]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Try to move model to GPU if possible
try:
    model = model.to(device)
except RuntimeError:
    device = 'cpu'
    model = model.to(device)
    
model = model.eval()

Convert the sequences to batch tokens then to token embeddings: https://github.com/facebookresearch/esm/issues/21

Since the pretrained models only is able todo embeddings with max 1024 residues, then an overlap method is implemented to concatenate the longer sequences. Thus, if a sequence is larger than 1024 residues, then the next iter of residues are concatenated by overlapping X residues which then takes X/2 from both of the concatenating sequences.

In [8]:
def overlap_method(data, max_embedding = 1024, offset = 200, batches = 30):
    # prepare data
    batch_labels, batch_strs, batch_tokens = batch_converter(data)

    # calculate how many iterations for the overlaps
    sequence_iter = math.floor(batch_tokens.shape[1]/max_embedding)
    sequence_iter = math.floor((max_embedding*(sequence_iter+1)+offset*sequence_iter)/max_embedding)
    
    # extract sequences in batches to avoid RAM issues
    result = torch.tensor([])
    
    for batch in range(0, batch_tokens.shape[0], batches):
        # extract per-residue embeddings (on CPU)
        batch_result = None
        for i in range(sequence_iter):
            embedding = None
            with torch.no_grad():
                if i == 0:
                    # no overlap on first iter
                    tmp_embedding = batch_tokens[batch:batch+batches, i:max_embedding+i]
                    embedding = model(tmp_embedding.to(device), repr_layers=[33])["representations"][33].cpu()
                else:
                    tmp_embedding = batch_tokens[batch:batch+batches, (max_embedding*i-offset*i):(max_embedding*(i+1)-offset*i)]
                    embedding = model(tmp_embedding.to(device), repr_layers=[33])["representations"][33].cpu()
                    
                    # concatenate by overlap for > max_embedding sequences
                    overlap = int(offset/2)
                    embedding = torch.cat([batch_result[:, :-overlap, :], embedding[:, overlap:, :]], dim=1)
                    
                batch_result = embedding
            
        # concatenate finished sequences
        result = torch.cat([result, batch_result], dim=0)
            
        print("Finished embedding batch {} of {}".format(batch+batches, batch_tokens.shape[0]))

    # add extrapolated zeros
    for idx_seq in range(len(batch_strs)):
        result[idx_seq, len(batch_strs[idx_seq]):, :] = 0
        
    return result

**Save embeddings**

The embeddings are merged with the labels from the original datasets

In [9]:
def add_embedding(name, dataset):
    # create embedding
    result = overlap_method(sparse_to_sequence(dataset['data']))
    
    dataset = torch.tensor(dataset['data'])
    
    #remove start and end token
    result = result[:-1, 1:result.shape[1]-1, :]
    
    #merge labels from original dataset and save
    result = torch.cat([dataset, result], dim=2).numpy()
    np.savez_compressed(data_dir.replace("nsp2", "nsp3") + "esm1b_" + name + ".npz", data=result)
    
    print(name + " saved")

**Add embeddings to datasets and save to file**

In [None]:
for name, data in datasets:
    add_embedding(name, data)