# TransProt Embeddings for NetsurfP

This notebook will try using embeddings from ProtTrans that provides state of the art pretrained language models for proteins.

https://github.com/agemagician/ProtTrans

**Load libraries**

In [1]:
import torch
from transformers import T5EncoderModel, T5Tokenizer
import re
import numpy as np
import gc
import pdb

#pip install sentencepiece
#pip install transformers

## Data

The data is loaded and converted back to aminoacids for the embedding

In [2]:
dir_path = "/home/projects/ht3_aim/people/erikie/nsp/data/nsp2/training_data/"

datasets = [
    ("train_hhblits", np.load(dir_path + "Train_HHblits.npz")),
    ("CB513_hhblits", np.load(dir_path + "CB513_HHblits.npz")),
    ("TS115_hhblits", np.load(dir_path + "TS115_HHblits.npz")),
    ("CASP12_HHblits", np.load(dir_path + "CASP12_HHblits.npz")),
    
    ("Train_MMseqs", np.load(dir_path + "Train_MMseqs.npz")),
    ("CB513_MMseqs", np.load(dir_path + "CB513_MMseqs.npz")),
    ("TS115_MMseqs", np.load(dir_path + "TS115_MMseqs.npz")),
    ("CASP12_MMseqs", np.load(dir_path + "CASP12_MMseqs.npz")),
]

A function converts sparse encoding back to amino acid sequence.

In [3]:
def sparse_to_sequence(dataset):
    data = []

    aa_decode = np.array(["N","A","C","D","E","F","G","H","I","K","L","M","N","P","Q","R","S","T","V","W","Y"])

    dataset = dataset['data']
    # get the amino acid encoding and apply decode mask
    for seq_id in range(dataset.shape[0]):
        seq_mask = dataset[seq_id, dataset[seq_id, :, 50] == 1, :20]
        aa_idx = np.argmax(seq_mask, axis=1)

        aa_sequence = str()
        for idx in aa_idx:
            aa_sequence += aa_decode[idx]

        # store decoded sequence
        data.append(aa_sequence)
        
    # map rarely occured amino acids (U,Z,O,B) to (X)
    data = [" ".join(re.sub(r"[UZOB]", "X", sequence)) for sequence in data]
    
    return data

Display part of first sequence to check conversion

In [4]:
sparse_to_sequence(datasets[0][1])[0][:100]

'L I S N W H N I P Q P H R E T I R G E R Q P K D D Q K F K H D T P N N H K R Q T F C F S P C M K R F '

## ProtTrans Embedding

The pre-trained model is instantiated and a tokenizer to convert the sequence to input data for the model

In [None]:
# load tokenizer and model
tokenizer = T5Tokenizer.from_pretrained("/home/projects/ht3_aim/people/erikie/nsp/models/prot_t5_xl_bfd", do_lower_case=False)
model = T5EncoderModel.from_pretrained("/home/projects/ht3_aim/people/erikie/nsp/models/prot_t5_xl_bfd")

Load the model to GPU if possible and switch to eval mode to not train on the weights of the model

In [None]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Try to move model to GPU if possible
try:
    model = model.to(device)
except RuntimeError:
    device = 'cpu'
    model = model.to(device)
    
model = model.eval()

In [None]:
def embeddings(data, batches = 5):
    result = torch.tensor([])
    
    # tokenize the sequences
    ids = tokenizer.batch_encode_plus(data, add_special_tokens=True, padding=True)

    input_ids = torch.tensor(ids['input_ids'])
    attention_mask = torch.tensor(ids['attention_mask'])
    
    for batch in range(0, len(data), batches):

        # extract the embeddings
        with torch.no_grad():
            embedding = model(input_ids = input_ids[batch:batch+batches].to(device),
                              attention_mask = attention_mask[batch:batch+batches].to(device))

        # keep the last hidden state and move to CPU
        embedding = embedding.last_hidden_state.cpu()
        
        result = torch.cat([result, embedding], dim=0)
        
    # add extrapolated zeros
    for idx_seq in range(len(data)):
        result[idx_seq, len(data[idx_seq]):, :] = 0
        
    return result

**Save embeddings**

The embeddings are merged with the labels from the original datasets

In [None]:
def add_embedding(name, dataset):
    
    # create embedding
    result = embeddings(sparse_to_sequence(dataset))
    
    dataset = torch.tensor(dataset['data'])
    result = result[:, 1:result.shape[1]+1, :]
    
    #merge labels from original dataset and save
    result = torch.cat([dataset, result], dim=2)
    np.savez_compressed(dir_path + "prot_t5_xl_bfd_" + name + ".npz", data=result.numpy())
    
    print(name + " saved")

**Add embeddings to datasets and save to file**

In [None]:
for name, data in datasets:
    add_embedding(name, data)