<a href="https://colab.research.google.com/github/agemagician/ProtTrans/blob/master/Embedding/PyTorch/Advanced/ProtT5-XL-BFD.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<h3>Extracting protein sequence features using the pretrained ProtT5-XL-BFD model</h3>

**1. Load the necessary libraries, including Hugging Face transformers**

In [1]:
!pip install -q transformers


In [2]:
import torch
from transformers import T5Model, T5Tokenizer
import re
import numpy as np
import gc

<b>2. Load the vocabulary and the ProtT5-XL-BFD model</b>

In [3]:
tokenizer = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_bfd", do_lower_case=False )


In [4]:
model = T5Model.from_pretrained("Rostlab/prot_t5_xl_bfd")


Some weights of the model checkpoint at Rostlab/prot_t5_xl_bfd were not used when initializing T5Model: ['lm_head.weight']
- This IS expected if you are initializing T5Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing T5Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [5]:
gc.collect()

1083

<b>3. Load the model onto the GPU if available and switch to inference mode</b>

In [6]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [7]:
model = model.to(device)
model = model.eval()

<b>4. Create or load sequences and map rarely occurring amino acids (U, Z, O, B) to X</b>

In [8]:
sequences_Example = ["A E T C Z A O","S K T Z P"]

In [9]:
sequences_Example = [re.sub(r"[UZOB]", "X", sequence) for sequence in sequences_Example]
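The example sequences above are already written with a space between residues. As a minimal sketch, assuming your own sequences arrive as plain strings without spaces, the helper below (a hypothetical name, `prepare_sequence`, not part of the original notebook) applies the same U/Z/O/B to X mapping and produces the space-separated form used by `sequences_Example`.

```python
import re

def prepare_sequence(raw_sequence):
    """Map rare residues to X and separate residues with single spaces.

    Hypothetical helper for illustration; it mirrors the regex used in the
    notebook cell and the spacing of the example sequences.
    """
    # Remove any existing whitespace so spaced and unspaced input behave alike.
    compact = re.sub(r"\s+", "", raw_sequence)
    # Map the rarely occurring residues U, Z, O and B to X.
    compact = re.sub(r"[UZOB]", "X", compact)
    # Insert a single space between residues.
    return " ".join(compact)

print(prepare_sequence("AETCZAO"))    # A E T C X A X
print(prepare_sequence("S K T Z P"))  # S K T X P
```

The result can be placed directly in a list and passed to the tokenizer in the next step.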

<b>5. Tokenize and encode the sequences, then load them onto the GPU if possible</b>

In [10]:
ids = tokenizer.batch_encode_plus(sequences_Example, add_special_tokens=True, padding=True)

In [11]:
input_ids = torch.tensor(ids['input_ids']).to(device)
attention_mask = torch.tensor(ids['attention_mask']).to(device)
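To make the effect of `padding=True` concrete, here is a toy sketch in plain Python. It is not the tokenizer's implementation: the function name `pad_batch`, the token ids, and the choice of `0` as the padding id are all assumptions made for illustration. It shows how sequences of different lengths end up as equal-length rows with a matching attention mask, where `1` marks a real token and `0` marks padding.

```python
def pad_batch(token_id_lists, pad_id=0):
    """Right-pad lists of token ids to equal length and build attention masks.

    Illustrative only; pad_id=0 is an assumed value.
    """
    max_len = max(len(ids) for ids in token_id_lists)
    input_ids, attention_mask = [], []
    for ids in token_id_lists:
        n_pad = max_len - len(ids)
        # Real tokens first, padding appended on the right.
        input_ids.append(list(ids) + [pad_id] * n_pad)
        # 1 for real tokens, 0 for padding positions.
        attention_mask.append([1] * len(ids) + [0] * n_pad)
    return input_ids, attention_mask

# Made-up token ids; the final 1 in each row stands in for the </s> token.
toy_ids = [[3, 9, 11, 22, 23, 3, 23, 1], [7, 14, 11, 23, 13, 1]]
padded, mask = pad_batch(toy_ids)
print(padded[1])  # [7, 14, 11, 23, 13, 1, 0, 0]
print(mask[1])    # [1, 1, 1, 1, 1, 1, 0, 0]
```

The attention mask is what later allows the padded positions to be identified and removed from the extracted features.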

<b>6. Extract the sequences' features and move them to the CPU if needed</b>

In [12]:
with torch.no_grad():
    embedding = model(input_ids=input_ids, attention_mask=attention_mask, decoder_input_ids=input_ids)

In [13]:
# For feature extraction, we recommend using the encoder embedding
encoder_embedding = embedding[2].cpu().numpy()
decoder_embedding = embedding[0].cpu().numpy()

<b>7. Remove padding (\<pad\>) and special tokens (\</s\>) that are added by the ProtT5-XL-BFD model</b>

In [14]:
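features = []
for seq_num in range(len(encoder_embedding)):
    # Number of real tokens (residues plus the trailing </s>) in this sequence
    seq_len = (attention_mask[seq_num] == 1).sum().item()
    # Keep the residue embeddings; drop the final </s> and any padding after it
    seq_emd = encoder_embedding[seq_num][:seq_len - 1]
    features.append(seq_emd)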
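The slicing logic in this cell can be checked on a small toy batch without loading the model. The sketch below assumes padding is appended on the right and that each sequence ends with exactly one `</s>` token, as in the cells above. The function name `strip_special_tokens` and the one-dimensional toy embeddings are made up for illustration.

```python
def strip_special_tokens(embeddings, attention_mask):
    """Return per-residue embeddings without the trailing </s> and padding.

    Works on nested lists as shown here; the same slicing applies to NumPy arrays.
    Assumes right-side padding and one </s> token at the end of each sequence.
    """
    features = []
    for seq_emb, seq_mask in zip(embeddings, attention_mask):
        # Count the real tokens: residues plus the final </s>.
        seq_len = int(sum(seq_mask))
        # Keep everything before the </s> token.
        features.append(seq_emb[:seq_len - 1])
    return features

# Toy batch: two sequences, four token positions, embedding size 1.
toy_embeddings = [
    [[0.1], [0.2], [0.3], [0.9]],  # 3 residues + </s>
    [[0.4], [0.5], [0.9], [0.0]],  # 2 residues + </s> + <pad>
]
toy_mask = [[1, 1, 1, 1], [1, 1, 1, 0]]

per_residue = strip_special_tokens(toy_embeddings, toy_mask)
print([len(f) for f in per_residue])  # [3, 2]
```

Each entry of the result has one row per residue, so its length matches the length of the corresponding input sequence.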

In [15]:
print(features)

[array([[ 0.29069245, -0.20804921, -0.21934746, ...,  0.338466  ,
         0.4603143 , -0.16942453],
       [ 0.20362847, -0.12995915, -0.1713986 , ...,  0.10642274,
        -0.37738234,  0.05401197],
       [ 0.03736701, -0.08506602, -0.23616499, ...,  0.25152597,
         0.1333861 ,  0.03063781],
       ...,
       [ 0.47787982, -0.23029418, -0.10682343, ...,  0.4057172 ,
         0.5251372 ,  0.19175273],
       [ 0.07990415,  0.07094704, -0.08554687, ...,  0.2494753 ,
         0.36538357, -0.45506278],
       [ 0.09110187,  0.17046836,  0.4207918 , ...,  0.25626653,
         0.02010932, -0.11016774]], dtype=float32), array([[ 0.28857464, -0.11107825, -0.13360332, ..., -0.06594545,
         0.00528727, -0.21770114],
       [ 0.13953976, -0.12703423,  0.06635726, ..., -0.02377458,
        -0.28750974,  0.09930851],
       [ 0.22406411,  0.0280987 ,  0.04245762, ...,  0.14432053,
         0.2020918 , -0.22417578],
       [ 0.5806171 , -0.16730028, -0.14636922, ...,  0.3601615 ,
     