<h3>Extracting protein sequences' features using the Albert pretrained model</h3>

<b>1. Load the necessary libraries, including Hugging Face transformers</b>

In [1]:
import torch
from transformers import AlbertModel, AlbertTokenizer
import re

<b>2. Set the file location of Albert and the vocabulary file</b>

In [2]:
# Directory holding the pretrained Albert model, and the vocabulary model file
# that is stored inside that same directory.
modelPath = '/media/agemagician/Disk2/projects/bio_google_tpu_project/ahmed/models_pytorch/albert'
vocabPath = modelPath + '/albert_vocab_model.model'

<b>3. Load the vocabulary and the Albert model</b>

In [3]:
# Build the tokenizer from the vocabulary file. Casing is preserved
# (do_lower_case=False) because the sequences in this notebook are upper-case.
vocab = AlbertTokenizer(
    vocab_file=vocabPath,
    do_lower_case=False,
)

In [4]:
# Load the pretrained Albert model from the local model directory.
model = AlbertModel.from_pretrained(modelPath)

<b>4. Load the model onto the GPU if available and switch to inference mode</b>

In [5]:
# Prefer the first CUDA device when one is available; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [6]:
# Move the model to the selected device and switch it to evaluation mode in one
# chained call (both methods return the module itself).
model = model.to(device).eval()

<b>5. Create or load sequences and map rarely occurring amino acids (U, Z, O, B) to (X)</b>

In [7]:
# Two toy protein sequences; residues are written as space-separated letters.
sequences_Example = [
    "A E T C U Z A O",
    "S K T Z P",
]

In [8]:
# Replace every rare amino-acid code (U, Z, O, B) with the placeholder X.
rare_residue = re.compile(r"[UZOB]")
sequences_Example = [rare_residue.sub("X", seq) for seq in sequences_Example]

<b>6. Tokenize and encode the sequences, and load them onto the GPU if possible</b>

In [9]:
# Tokenize and encode the whole batch at once, adding the special tokens.
# `padding=True` pads every sequence to the longest one in the batch; it replaces
# the deprecated `pad_to_max_length=True` argument.
ids = vocab.batch_encode_plus(sequences_Example, add_special_tokens=True, padding=True)

In [10]:
# Turn the token ids and the attention mask into tensors on the target device.
input_ids, attention_mask = (
    torch.tensor(ids[field]).to(device) for field in ('input_ids', 'attention_mask')
)

<b>7. Extract the sequences' features and load them onto the CPU if needed</b>

In [11]:
# Forward pass with gradient tracking disabled; the first element of the model
# output is kept as the embedding.
with torch.no_grad():
    model_output = model(input_ids=input_ids, attention_mask=attention_mask)
    embedding = model_output[0]

In [12]:
# Copy the embeddings to host memory as a NumPy array. The guard keeps this cell
# idempotent: on a re-run `embedding` is already an array, which has no `.cpu()`.
if torch.is_tensor(embedding):
    embedding = embedding.cpu().numpy()

<b>8. Remove padding ([PAD]) and special tokens ([CLS], [SEP]) that are added by the Albert model</b>

In [13]:
# Strip the special tokens and the padding from each sequence's embedding.
# The number of attended positions per sequence is the count of 1s in its
# attention-mask row. The counts are computed once and converted to plain Python
# ints so that NumPy slicing never receives a (possibly CUDA) tensor as a bound.
seq_lens = (attention_mask == 1).sum(dim=1).tolist()
features = [
    # Skip position 0 ([CLS]) and the last attended position ([SEP]);
    # everything from seq_len onwards is padding.
    seq_emd[1:seq_len - 1]
    for seq_emd, seq_len in zip(embedding, seq_lens)
]

In [14]:
# Show the extracted feature arrays (one array per input sequence).
print(features)

[array([[-0.08361891,  0.03106401, -0.04620098, ..., -0.10921817,
        -0.03090221,  0.0826586 ],
       [-0.34378687, -0.01706827,  0.10808687, ..., -0.2354121 ,
         0.16600646,  0.16465808],
       [-0.19779804,  0.03071388, -0.08432098, ..., -0.18379787,
         0.11854444,  0.15428072],
       ...,
       [-0.1822614 , -0.20597872, -0.12547323, ..., -0.17655548,
         0.14813331, -0.24152306],
       [-0.09552477, -0.21021341, -0.12658282, ..., -0.19666438,
         0.10713978, -0.05355697],
       [-0.10423692, -0.20334783, -0.11877518, ..., -0.10915546,
         0.12062257, -0.208226  ]], dtype=float32), array([[-0.12822573,  0.10716011, -0.01129843, ..., -0.11677988,
         0.0390483 ,  0.10009992],
       [-0.36793005,  0.03902509, -0.11569419, ..., -0.22101341,
         0.10926705,  0.13324389],
       [-0.26249614,  0.11969303, -0.05417374, ..., -0.20998071,
         0.11928424,  0.21757752],
       [-0.03641684, -0.08373144,  0.0571348 , ..., -0.12186526,
     