<a href="https://colab.research.google.com/github/agemagician/ProtTrans/blob/master/Embedding/Advanced/ProtBert_BFD.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<h3>Extracting protein sequences' features using the ProtBert-BFD pretrained model</h3>

<b>1. Load the necessary libraries, including Hugging Face transformers</b>

In [1]:
!pip install -q transformers

[K     |████████████████████████████████| 778kB 2.8MB/s 
[K     |████████████████████████████████| 890kB 16.4MB/s 
[K     |████████████████████████████████| 1.1MB 18.0MB/s 
[K     |████████████████████████████████| 3.0MB 25.0MB/s 
[?25h  Building wheel for sacremoses (setup.py) ... [?25l[?25hdone


In [2]:
import torch
from transformers import BertModel, BertTokenizer
import re
import os
import requests
from tqdm.auto import tqdm

<b>2. Set the URL locations of the ProtBert-BFD model, its configuration and the vocabulary file</b>

In [3]:
# Remote locations of the three artefacts used below: the PyTorch weights,
# the model configuration and the tokenizer vocabulary.
# `dl=1` requests a direct download rather than the Dropbox preview page.
modelUrl = 'https://www.dropbox.com/s/luv2r115bumo90z/pytorch_model.bin?dl=1'
# Note: the remote file is named bert_config.json but is stored locally as config.json.
configUrl = 'https://www.dropbox.com/s/33en5mbl4wf27om/bert_config.json?dl=1'
vocabUrl = 'https://www.dropbox.com/s/tffddoqfubkfcsw/vocab.txt?dl=1'

<b>3. Download the ProtBert-BFD model, configuration and vocabulary files</b>

In [4]:
# Local folder (relative to the working directory) that receives the downloaded files.
downloadFolderPath = 'models/ProtBert-BFD/'

In [5]:
# The model is loaded from the same folder the files are downloaded into.
modelFolderPath = downloadFolderPath

# Local destinations for the weights, the configuration and the vocabulary.
modelFilePath, configFilePath, vocabFilePath = (
    os.path.join(modelFolderPath, file_name)
    for file_name in ('pytorch_model.bin', 'config.json', 'vocab.txt')
)

In [6]:
# exist_ok=True keeps the cell safe to re-run and avoids the race between
# checking for the folder and creating it.
os.makedirs(modelFolderPath, exist_ok=True)

In [7]:
def download_file(url, filename):
  """Stream `url` to the local path `filename`, showing a tqdm progress bar.

  The payload is first written to `filename + '.part'` and renamed to
  `filename` only after the transfer has finished, so an interrupted download
  never leaves a truncated file that a later `os.path.exists(filename)` check
  would mistake for a complete one.

  Args:
    url: HTTP(S) address of the file to fetch.
    filename: Destination path; its parent directory must already exist.

  Raises:
    requests.HTTPError: if the server answers with a 4xx/5xx status.
  """
  partial_path = filename + '.part'
  # Using the response as a context manager releases the connection on exit.
  with requests.get(url, stream=True) as response:
    # Fail loudly instead of saving an error page under the model's file name.
    response.raise_for_status()
    # A missing Content-Length header yields total=0, as in the original code.
    total_bytes = int(response.headers.get('content-length', 0))
    # The file gets its own `with` so it is flushed and closed deterministically;
    # the wrapattr context only manages the progress bar.
    with open(partial_path, "wb") as raw_file, \
         tqdm.wrapattr(raw_file, "write", miniters=1,
                       total=total_bytes, desc=filename) as fout:
      for chunk in response.iter_content(chunk_size=4096):
        fout.write(chunk)
  # Publish the finished file under its final name in one step.
  os.replace(partial_path, filename)

In [8]:
# Fetch each artefact only when it is not already present on disk.
for file_url, file_path in (
    (modelUrl, modelFilePath),
    (configUrl, configFilePath),
    (vocabUrl, vocabFilePath),
):
    if os.path.exists(file_path):
        continue
    download_file(file_url, file_path)

HBox(children=(FloatProgress(value=0.0, description='models/ProtBert-BFD/pytorch_model.bin', max=1684058277.0,…




HBox(children=(FloatProgress(value=0.0, description='models/ProtBert-BFD/config.json', max=313.0, style=Progre…




HBox(children=(FloatProgress(value=0.0, description='models/ProtBert-BFD/vocab.txt', max=81.0, style=ProgressS…




<b>4. Load the vocabulary and ProtBert-BFD Model</b>

In [9]:
# Build the tokenizer from the downloaded vocabulary; lower-casing is disabled
# so the upper-case residue letters are kept as written.
tokenizer = BertTokenizer(vocab_file=vocabFilePath, do_lower_case=False)

In [10]:
# from_pretrained is pointed at the local folder that now holds the downloaded
# config.json and pytorch_model.bin.
model = BertModel.from_pretrained(modelFolderPath)

<b>5. Load the model onto the GPU if available and switch to inference mode</b>

In [11]:
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [12]:
# Move the weights to the selected device and switch to evaluation (inference) mode.
model = model.to(device).eval()

<b>6. Create or load sequences and map rarely occurring amino acids (U, Z, O, B) to (X)</b>

In [13]:
# Two toy protein sequences; residues are written as single letters separated
# by spaces (presumably so each residue becomes one token; confirm against vocab.txt).
sequences_Example = ["A E T C Z A O","S K T Z P"]

In [14]:
# Replace every U, Z, O or B character with X in each sequence.
# The result is bound to the same name; re-running the cell gives the same list.
sequences_Example = [re.sub(r"[UZOB]", "X", sequence) for sequence in sequences_Example]

<b>7. Tokenize and encode the sequences, and load them onto the GPU if possible</b>

In [15]:
# Encode the whole batch, adding the special tokens around each sequence.
# `padding=True` pads every sequence to the longest one in the batch. It replaces
# the deprecated `pad_to_max_length=True`, which behaved the same way when no
# max_length was given.
ids = tokenizer.batch_encode_plus(sequences_Example, add_special_tokens=True, padding=True)

In [16]:
# Turn the encoded id lists and the matching attention masks into tensors on
# the device that holds the model.
input_ids, attention_mask = (
    torch.tensor(ids[field]).to(device)
    for field in ('input_ids', 'attention_mask')
)

<b>8. Extract the sequences' features and move them to the CPU if needed</b>

In [17]:
# Gradient tracking is switched off because this is a forward pass only.
# [0] takes the first element of the model output; for BertModel this should be
# the per-token hidden states of the last layer (TODO confirm for the installed
# transformers version).
with torch.no_grad():
    embedding = model(input_ids=input_ids,attention_mask=attention_mask)[0]

In [18]:
# Copy the result to host memory and convert it to a NumPy array.
# NOTE: `embedding` is rebound from a torch tensor to an ndarray here, so this
# cell cannot be re-run on its own (an ndarray has no .cpu()); re-run the
# forward-pass cell first.
embedding = embedding.cpu().numpy()

<b>9. Remove the padding ([PAD]) and special tokens ([CLS], [SEP]) that are added by the ProtBert-BFD tokenizer</b>

In [19]:
# For every sequence, count the positions whose attention mask is 1, then keep
# rows 1 .. count-2 of its embedding. This drops the first attended position
# ([CLS]), the last attended position ([SEP]) and all padding after it.
features = [
    token_embeddings[1:(mask == 1).sum() - 1]
    for token_embeddings, mask in zip(embedding, attention_mask)
]

In [20]:
# Show the list of per-sequence arrays (one row per residue); NumPy abbreviates
# long arrays with "..." in the printed form.
print(features)

[array([[ 0.05551133, -0.10461304, -0.03253962, ...,  0.05091606,
         0.04318975,  0.10181108],
       [ 0.13895561, -0.046583  ,  0.02193631, ...,  0.06942613,
         0.14762992,  0.06503808],
       [ 0.14610603, -0.08092842, -0.12500416, ..., -0.03651231,
         0.02485525,  0.07977536],
       ...,
       [ 0.02349902, -0.01549769, -0.05685329, ..., -0.01342281,
         0.01704315,  0.06431052],
       [ 0.08129995, -0.1092955 , -0.03022903, ...,  0.08717731,
         0.02061446,  0.05156654],
       [ 0.06197417, -0.06417818, -0.02039655, ..., -0.02796507,
         0.0884005 ,  0.07532689]], dtype=float32), array([[-0.06304268, -0.23687428, -0.07115868, ..., -0.03852162,
        -0.00322069, -0.05244054],
       [ 0.01905588, -0.105173  , -0.02930211, ..., -0.00238627,
        -0.09289714,  0.02722595],
       [ 0.07721861, -0.1703198 , -0.13987812, ..., -0.08390203,
         0.03587941, -0.01317161],
       [ 0.00872737, -0.1771819 , -0.05856298, ..., -0.09918059,
     