<a href="https://colab.research.google.com/github/agemagician/ProtTrans/blob/master/Embedding/PyTorch/Advanced/ProtXLNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<h3>Extracting protein sequences' features using the ProtXLNet pretrained model</h3>

<b>1. Load the necessary libraries, including Hugging Face transformers</b>

In [1]:
!pip install -q transformers sentencepiece

In [2]:
import torch
from transformers import XLNetModel, XLNetTokenizer
import re
import os
import requests
from tqdm.auto import tqdm

<b>2. Load the vocabulary and the ProtXLNet model</b>

In [3]:
tokenizer = XLNetTokenizer.from_pretrained("Rostlab/prot_xlnet", do_lower_case=False)

In [4]:
# mem_len: number of tokens whose hidden states XLNet caches for reuse in later forward passes
xlnet_mem_len = 512

In [5]:
model = XLNetModel.from_pretrained("Rostlab/prot_xlnet", mem_len=xlnet_mem_len)

Some weights of the model checkpoint at Rostlab/prot_xlnet were not used when initializing XLNetModel: ['lm_loss.weight', 'lm_loss.bias']
- This IS expected if you are initializing XLNetModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLNetModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


<b>3. Load the model onto the GPU if available and switch to inference mode</b>

In [6]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [7]:
model = model.to(device)
model = model.eval()

<b>4. Create or load sequences and map the rare or ambiguous amino acids (U, Z, O, B), as well as the unknown residue X, to the unknown token (unk)</b>

In [8]:
sequences_Example = ["A E T C Z A O","S K T Z P"]

In [9]:
sequences_Example = [re.sub(r"[UZOBX]", "<unk>", sequence) for sequence in sequences_Example]
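ProtXLNet expects residues separated by single spaces, as in `sequences_Example`. If your sequences arrive as plain strings (for example, read from a FASTA file), the sketch below applies the same preprocessing. It assumes uppercase or lowercase one-letter amino-acid codes, and `prepare_sequence` is a hypothetical helper name, not part of the ProtTrans code.

```python
import re

def prepare_sequence(raw_seq: str) -> str:
    """Hypothetical helper: space-separate residues and map rare ones to <unk>."""
    # Drop whitespace and line breaks (e.g. from wrapped FASTA lines), normalize case
    compact = "".join(raw_seq.split()).upper()
    # Put a single space between residues, matching the format of sequences_Example
    spaced = " ".join(compact)
    # Map rare/ambiguous residues U, Z, O, B and unknown X to <unk>
    return re.sub(r"[UZOBX]", "<unk>", spaced)

print(prepare_sequence("AETCZAO"))   # A E T C <unk> A <unk>
print(prepare_sequence("SKTZ\nP"))   # S K T <unk> P
```

Spacing is done before the substitution so that each `<unk>` replaces exactly one residue position.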

<b>5. Tokenize and encode the sequences, then move them onto the GPU if possible</b>

In [10]:
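# padding=True pads each sequence to the longest one in the batch (replaces the deprecated pad_to_max_length)
ids = tokenizer.batch_encode_plus(sequences_Example, add_special_tokens=True, padding=True)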



In [11]:
input_ids = torch.tensor(ids['input_ids']).to(device)
attention_mask = torch.tensor(ids['attention_mask']).to(device)
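The XLNet tokenizer appends its special tokens `<sep>` and `<cls>` to the end of each sequence and pads on the left, so in this batch the shorter sequence is preceded by padding positions whose attention-mask entries are 0. The toy sketch below reproduces that layout in plain Python to make the structure of `ids['input_ids']` and `ids['attention_mask']` concrete. The token IDs and the helper `left_pad_batch` are made up for illustration; they are not the ProtXLNet vocabulary or a transformers function.

```python
from typing import List, Tuple

# Made-up IDs for illustration only; not the real ProtXLNet vocabulary
PAD, SEP, CLS = 0, 4, 3

def left_pad_batch(batch: List[List[int]]) -> Tuple[List[List[int]], List[List[int]]]:
    """Hypothetical helper: append <sep>, <cls>, then left-pad to the longest sequence."""
    with_special = [seq + [SEP, CLS] for seq in batch]
    longest = max(len(seq) for seq in with_special)
    input_ids, attention_mask = [], []
    for seq in with_special:
        n_pad = longest - len(seq)
        # Padding goes on the LEFT, mirroring the XLNet tokenizer's left-side padding
        input_ids.append([PAD] * n_pad + seq)
        attention_mask.append([0] * n_pad + [1] * len(seq))
    return input_ids, attention_mask

# 7 and 5 residues, like "A E T C Z A O" and "S K T Z P"
ids, mask = left_pad_batch([[10, 11, 12, 13, 2, 10, 2], [14, 15, 12, 2, 16]])
print(ids[1])   # [0, 0, 14, 15, 12, 2, 16, 4, 3]
print(mask[1])  # [0, 0, 1, 1, 1, 1, 1, 1, 1]
```

Because the real tokens always end at the last position, the residues of every sequence can be located from the mask sum alone, which is what step 7 relies on.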

<b>6. Extract the sequences' features and move them to the CPU if needed</b>

In [12]:
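with torch.no_grad():
    # No previous memory is passed (mems=None); the returned mems can be reused in a later call
    output = model(input_ids=input_ids, attention_mask=attention_mask, mems=None)
    embedding = output.last_hidden_state  # shape: (batch, padded_len, hidden_size)
    mems = output.mems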

In [13]:
embedding = embedding.cpu().numpy()

<b>7. Remove the padding tokens and the special tokens (sep, cls) added by the XLNet tokenizer: padding sits on the left and the two special tokens at the end of each sequence</b>

In [14]:
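features = []
for seq_num in range(len(embedding)):
    # Number of real tokens (residues + <sep> + <cls>); XLNet pads on the left
    seq_len = int((attention_mask[seq_num] == 1).sum())
    padded_seq_len = len(attention_mask[seq_num])
    # Skip the left padding and drop the trailing <sep> and <cls> tokens
    seq_emd = embedding[seq_num][padded_seq_len - seq_len:padded_seq_len - 2]
    features.append(seq_emd)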
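To see why the slice `[padded_seq_len-seq_len:padded_seq_len-2]` keeps exactly the residues, the sketch below applies it to a toy array with the same left-padded layout (two padding positions, five residues, then `<sep>` and `<cls>`). It also shows one common way to turn per-residue features into a single per-protein vector by averaging over residues; that pooling step is an optional addition, not part of this notebook, and the array sizes are arbitrary.

```python
import numpy as np

hidden = 4  # toy hidden size; the real model's is much larger
# Toy left-padded layout: 2 pads, 5 residues, then <sep>, <cls> (9 positions)
mask = np.array([0, 0, 1, 1, 1, 1, 1, 1, 1])
# Fill each position's vector with its index so we can see which rows survive
emb = np.repeat(np.arange(len(mask), dtype=np.float32)[:, None], hidden, axis=1)

seq_len = int(mask.sum())        # 7 = 5 residues + <sep> + <cls>
padded_seq_len = len(mask)       # 9
# Skip the left padding, then drop the trailing <sep> and <cls>
residue_emb = emb[padded_seq_len - seq_len:padded_seq_len - 2]
print(residue_emb.shape)         # (5, 4)
print(residue_emb[:, 0])         # [2. 3. 4. 5. 6.]

# Optional: mean-pool residue vectors into one per-protein vector
protein_emb = residue_emb.mean(axis=0)
print(protein_emb)               # [4. 4. 4. 4.]
```

Mean pooling is only one choice; per-residue features (as kept in `features`) are what residue-level tasks need.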

In [15]:
print(features)

[array([[ 0.4874513 , -0.7708785 ,  0.9900176 , ..., -0.37356144,
        -1.0589752 ,  0.9559978 ],
       [ 0.21325906, -0.54788995,  0.611549  , ..., -0.05471378,
        -0.8787889 ,  0.24645127],
       [ 0.37891322, -0.63965786,  0.67224425, ..., -0.14891209,
        -0.6769571 ,  0.34598204],
       ...,
       [ 0.09265018, -0.68101174,  0.5181867 , ..., -0.32756284,
        -0.56731117, -0.1644125 ],
       [-0.08533233, -0.7438228 ,  0.29890773, ..., -0.24376427,
        -0.1115306 , -0.7260989 ],
       [-0.48225865, -0.8381691 ,  0.08214396, ..., -0.25196335,
        -0.03577377, -0.5348947 ]], dtype=float32), array([[ 1.040305  , -0.95049226,  0.33534572, ..., -0.23747356,
        -0.2755075 ,  0.4794891 ],
       [ 0.50853604, -0.95084566,  1.0235122 , ..., -0.05893029,
        -0.97527987,  0.08713222],
       [ 0.65626603, -0.7540609 ,  0.43234858, ...,  0.367075  ,
        -0.70632464, -0.55534136],
       [ 0.5903382 ,  0.14152074,  0.295789  , ...,  0.2060117 ,
     