<h3> Extracting protein sequences' features using a pretrained Bert model </h3>

<b>1. Load necessary libraries, including Hugging Face transformers</b>

In [1]:
import torch
from transformers import BertModel, BertTokenizer
import re

<b>2. Set the file location of the Bert model and the vocabulary file</b>

In [2]:
# Directory that is passed to `BertModel.from_pretrained`.
# The trailing slash is required: the vocabulary path below is built by plain concatenation.
modelPath = '/media/agemagician/Disk2/share_files/summit/uniref100/bert/30_layers/'
# The vocabulary file sits inside the model directory; derive its path from `modelPath`
# so the two locations cannot drift apart when the model directory is changed.
vocabPath = modelPath + 'vocab.txt'

<b>3. Load the vocabulary and the Bert model</b>

In [3]:
# Build the tokenizer from the vocabulary file; lower-casing is disabled so that
# the upper-case amino-acid letters in the input are left as they are.
vocab = BertTokenizer(vocab_file=vocabPath, do_lower_case=False)

In [4]:
# Load the pretrained Bert weights and configuration from the model directory.
model = BertModel.from_pretrained(pretrained_model_name_or_path=modelPath)

<b>4. Load the model onto the GPU if available and switch to inference mode</b>

In [5]:
# Use the first CUDA device when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [6]:
# Move the model to the selected device and put it in evaluation (inference) mode.
model = model.to(device).eval()

<b>5. Create or load sequences and map rarely occurring amino acids (U, Z, O, B) to (X)</b>

In [7]:
# Example protein sequences: one string per sequence, residues separated by single spaces.
sequences_Example = [
    "A E T C Z A O",
    "S K T Z P",
]

In [8]:
# Replace every rarely occurring residue letter (U, Z, O, B) with the placeholder X.
rare_residues = re.compile(r"[UZOB]")
sequences_Example = [rare_residues.sub("X", seq) for seq in sequences_Example]

<b>6. Tokenize and encode the sequences, and load them onto the GPU if possible</b>

In [9]:
# Tokenize and encode the whole batch with special tokens added, padding every
# sequence to the length of the longest one in the batch.
# `padding=True` replaces the deprecated `pad_to_max_length=True` argument
# (no `max_length` is given, so both mean "pad to the longest sequence").
ids = vocab.batch_encode_plus(sequences_Example, add_special_tokens=True, padding=True)

In [10]:
# Convert the encoded token ids and the attention mask to tensors on the target device.
input_ids, attention_mask = (
    torch.tensor(ids[key]).to(device) for key in ('input_ids', 'attention_mask')
)

<b>7. Extract the sequences' features and load them onto the CPU if needed</b>

In [11]:
# Forward pass with gradient tracking disabled; only the first element of the
# model output is kept as the embedding.
with torch.no_grad():
    model_output = model(attention_mask=attention_mask, input_ids=input_ids)
    embedding = model_output[0]

In [12]:
# Copy the embeddings to host memory as a NumPy array.
# The guard keeps this cell re-runnable: once `embedding` is already a NumPy
# array it has no `.cpu()` method, so the conversion must only run on a tensor.
if torch.is_tensor(embedding):
    embedding = embedding.cpu().numpy()

<b>8. Remove padding ([PAD]) and special tokens ([CLS], [SEP]) that are added by the Bert model</b>

In [13]:
# Number of attended (non-padding) positions per sequence, computed once for the
# whole batch and converted to plain Python ints so they can serve directly as
# NumPy slice bounds (instead of a 0-dim tensor produced on every iteration).
seq_lens = (attention_mask == 1).sum(dim=1).tolist()

# For each sequence keep positions 1 .. seq_len-2: this drops the first attended
# position and the last attended position (the special tokens added during
# encoding) together with everything after it (the padding).
features = [seq_emd[1:seq_len - 1] for seq_emd, seq_len in zip(embedding, seq_lens)]

In [14]:
# Show the extracted features: a list with one array per input sequence.
print(features)

[array([[ 0.3434632 , -0.11853541,  0.11670393, ...,  0.12733918,
        -0.33663505, -0.4125227 ],
       [ 0.37897065, -0.2309679 ,  0.00771705, ..., -0.44364253,
        -0.56715494, -0.63721484],
       [ 0.28404582, -0.5655926 ,  0.06742072, ..., -0.80574137,
        -0.64692605, -0.66759783],
       ...,
       [ 0.38815227, -0.2500954 ,  0.1672174 , ..., -0.53784174,
        -0.37415683, -0.5270342 ],
       [ 0.69124544, -0.44460136,  0.00450774, ..., -0.3540218 ,
        -0.46115145, -0.7682186 ],
       [ 0.59799904, -0.23531727,  0.09535453, ..., -0.5865146 ,
        -0.61789775, -0.80608463]], dtype=float32), array([[ 0.36907712, -0.5154992 ,  0.2521861 , ..., -0.5068797 ,
        -0.5616515 , -0.6157087 ],
       [ 0.32650244, -0.5830976 ,  0.19571821, ..., -0.57145095,
        -0.25643727, -0.47911468],
       [ 0.0239465 , -0.10554147,  0.27909502, ..., -0.5566634 ,
        -0.3446969 , -0.42556283],
       [ 0.467527  , -0.51940775,  0.0586937 , ..., -0.28577682,
     