<a href="https://colab.research.google.com/github/agemagician/Prot-Transformers/blob/master/Embedding/Advanced/Bert.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<h3> Extracting protein sequences' features using the pretrained ProtBert model </h3>

<b>1. Load necessary libraries, including Hugging Face transformers</b>

In [3]:
# Install Hugging Face transformers, plus gdown built from its GitHub repository
# (gdown is used below to fetch the model files from Google Drive).
# Neither install is pinned to a version.
!pip install -q transformers
!pip install -q git+https://github.com/wkentaro/gdown.git

  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
    Preparing wheel metadata ... [?25l[?25hdone
  Building wheel for gdown (PEP 517) ... [?25l[?25hdone


In [4]:
import torch
from transformers import BertModel, BertTokenizer
import re
import os
import gdown

<b>2. Set the URL locations of the ProtBert model, its config, and the vocabulary file</b>

In [5]:
# Google Drive direct-download links for the three ProtBert artifacts, in the
# order: model weights, model config, tokenizer vocabulary.
modelUrl, configUrl, vocabUrl = (
    'https://drive.google.com/uc?export=download&id=1mLuVhMuGTSSfVkK1rrmBSTC5yiCpGy--',
    'https://drive.google.com/uc?export=download&id=1hg30JtXz6Okl0esJnMC2J_9TNbx5gJpl',
    'https://drive.google.com/uc?export=download&id=15eFspbhoF5uUZ6xKKTAIGOunXYDeFruL',
)

<b>3. Download the ProtBert model, config, and vocabulary files</b>

In [6]:
# Directory (relative to the current working directory) that holds the ProtBert files.
downloadFolderPath = 'models/ProtBert/'

In [7]:
# The model is loaded straight from the download directory, so both names
# refer to the same folder.
modelFolderPath = downloadFolderPath

# Full paths of the weights, config and vocabulary files inside that folder.
modelFilePath, configFilePath, vocabFilePath = (
    os.path.join(modelFolderPath, artifact)
    for artifact in ('pytorch_model.bin', 'config.json', 'vocab.txt')
)

In [8]:
# Create the model directory (and any missing parents). exist_ok=True keeps the
# cell safe to re-run and avoids the check-then-create race of a separate
# os.path.exists() test.
os.makedirs(modelFolderPath, exist_ok=True)

In [9]:
def download_file(url, filename, max_retries=3):
    """Download ``url`` to ``filename`` with gdown unless the file already exists.

    Args:
        url: Address passed to ``gdown.download``.
        filename: Destination path; if it already exists nothing is downloaded.
        max_retries: Maximum number of download attempts before giving up.

    Raises:
        RuntimeError: If ``filename`` still does not exist after
            ``max_retries`` attempts.
    """
    # Bounded retry: an unconditional ``while not exists`` loop never ends when
    # the download keeps failing without creating the file.
    for _ in range(max_retries):
        if os.path.exists(filename):
            return
        gdown.download(url, filename, quiet=False)
    if not os.path.exists(filename):
        raise RuntimeError(
            'Failed to download {} to {} after {} attempt(s)'.format(
                url, filename, max_retries
            )
        )

In [10]:
# Fetch whichever of the three ProtBert files are not on disk yet.
for source_url, target_path in (
    (modelUrl, modelFilePath),
    (configUrl, configFilePath),
    (vocabUrl, vocabFilePath),
):
    if not os.path.exists(target_path):
        download_file(source_url, target_path)

Downloading...
From: https://drive.google.com/uc?export=download&id=1mLuVhMuGTSSfVkK1rrmBSTC5yiCpGy--
To: /content/models/ProtBert/pytorch_model.bin
1.68GB [00:19, 85.1MB/s]
Downloading...
From: https://drive.google.com/uc?export=download&id=1hg30JtXz6Okl0esJnMC2J_9TNbx5gJpl
To: /content/models/ProtBert/config.json
100%|██████████| 313/313 [00:00<00:00, 468kB/s]
Downloading...
From: https://drive.google.com/uc?export=download&id=15eFspbhoF5uUZ6xKKTAIGOunXYDeFruL
To: /content/models/ProtBert/vocab.txt
100%|██████████| 81.0/81.0 [00:00<00:00, 133kB/s]


<b>4. Load the vocabulary and ProtBert Model</b>

In [11]:
# Build the tokenizer from the downloaded vocabulary file. do_lower_case=False
# is passed so the input is not lower-cased (the example sequences use
# upper-case residue letters).
tokenizer = BertTokenizer(vocabFilePath, do_lower_case=False )

In [12]:
# Load ProtBert from the local folder holding the downloaded config.json and
# pytorch_model.bin.
model = BertModel.from_pretrained(modelFolderPath)

<b>5. Load the model onto the GPU if available and switch to inference mode</b>

In [13]:
# Use the first CUDA device when one is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [14]:
# Move the model to the selected device, then switch it to evaluation
# (inference) mode; both calls hand back the model, so they can be chained.
model = model.to(device).eval()

<b>6. Create or load sequences and map rarely occurring amino acids (U, Z, O, B) to (X)</b>

In [15]:
# Example input: each protein is one string of space-separated, single-letter residues.
sequences_Example = ["A E T C Z A O","S K T Z P"]

In [16]:
# Replace every U, Z, O or B residue with X in each sequence.
sequences_Example = list(
    map(lambda seq: re.sub(r"[UZOB]", "X", seq), sequences_Example)
)

<b>7. Tokenize and encode the sequences, then load them onto the GPU if possible</b>

In [17]:
# Encode the whole batch with special tokens added, padding every sequence to
# the longest one in the batch. `padding=True` is the replacement for the
# deprecated `pad_to_max_length=True` argument.
ids = tokenizer.batch_encode_plus(sequences_Example, add_special_tokens=True, padding=True)

In [18]:
# Turn the encoded token ids and the attention mask into tensors on the target device.
input_ids, attention_mask = (
    torch.tensor(ids[field]).to(device)
    for field in ('input_ids', 'attention_mask')
)

<b>8. Extract the sequences' features and move them to the CPU if needed</b>

In [19]:
# Forward pass with gradient tracking disabled; the first element of the model
# output is kept as the embedding.
with torch.no_grad():
    model_output = model(input_ids=input_ids, attention_mask=attention_mask)
embedding = model_output[0]

In [20]:
# Copy the embeddings to host memory and convert them to a NumPy array.
# NOTE: this rebinds `embedding` from a tensor to an ndarray, so re-running
# this cell on its own fails (ndarrays have no .cpu()).
embedding = embedding.cpu().numpy()

<b>9. Remove padding ([PAD]) and the special tokens ([CLS], [SEP]) that are added by the BERT tokenizer</b>

In [21]:
# Number of attended (non-padding) tokens per sequence, computed once for the
# whole batch and converted to plain Python ints so the NumPy slice below does
# not rely on implicit tensor-to-index coercion of a tensor living on `device`.
seq_lens = (attention_mask == 1).sum(dim=1).tolist()

# Keep positions 1 .. seq_len-2 of each sequence: this drops the first token
# ([CLS]), the last attended token ([SEP]) and any padding after it.
features = [
    embedding[seq_num][1:seq_len - 1]
    for seq_num, seq_len in enumerate(seq_lens)
]

In [22]:
# Show the per-sequence feature arrays (NumPy abbreviates large arrays when printed).
print(features)

[array([[ 0.09233951,  0.13914399, -0.05241546, ..., -0.13947977,
        -0.0428006 ,  0.07431364],
       [ 0.11505969,  0.01998261, -0.08627468, ..., -0.00946233,
        -0.1872724 ,  0.13169006],
       [ 0.04012022,  0.05467906, -0.0424584 , ...,  0.02054804,
        -0.03819017,  0.1011811 ],
       ...,
       [ 0.08853199,  0.0228724 , -0.05121122, ..., -0.0673336 ,
        -0.03013714,  0.04913732],
       [ 0.10792391,  0.09771117, -0.05832333, ..., -0.12765658,
        -0.06493297,  0.12894136],
       [ 0.0545795 ,  0.03635893, -0.07821195, ..., -0.03016043,
        -0.06015311,  0.08903892]], dtype=float32), array([[ 0.0880546 ,  0.10885102,  0.06212992, ...,  0.0144938 ,
        -0.03327068, -0.0362049 ],
       [-0.03332657,  0.03442192,  0.0809955 , ...,  0.0008738 ,
        -0.02733596,  0.05783575],
       [ 0.07896441,  0.10343076, -0.01508527, ...,  0.06711898,
        -0.01733149,  0.12912118],
       [ 0.01805276,  0.07072446, -0.03629201, ..., -0.03547218,
     