
<h3>Extracting protein sequence features using the pretrained XLNet model</h3>

<b>1. Load the necessary libraries, including Hugging Face Transformers</b>

In [1]:
!pip install -q transformers sentencepiece


In [2]:
import tensorflow as tf
from transformers import TFXLNetModel, XLNetTokenizer
import re
import numpy as np

<b>2. Load the vocabulary and XLNet Model</b>

In [3]:
tokenizer = XLNetTokenizer.from_pretrained("Rostlab/prot_xlnet", do_lower_case=False)


In [4]:
xlnet_mem_len = 512

In [5]:
model = TFXLNetModel.from_pretrained("Rostlab/prot_xlnet", mem_len=xlnet_mem_len, from_pt=True)


Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFXLNetModel: ['lm_loss.bias', 'lm_loss.weight']
- This IS expected if you are initializing TFXLNetModel from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFXLNetModel from a PyTorch model that you expect to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model).
All the weights of TFXLNetModel were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFXLNetModel for predictions without further training.


<b>3. Create or load sequences, and map rarely occurring amino acids (U, Z, O, B) and unknown residues (X) to the tokenizer's unknown token</b>

In [6]:
sequences_Example = ["A E T C Z A O","S K T Z P"]

In [7]:
sequences_Example = [re.sub(r"[UZOBX]", "<unk>", sequence) for sequence in sequences_Example]
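The example sequences above are already written with a space between residues, which is the input format this notebook feeds to the tokenizer. Raw sequences (for instance read from a FASTA file) usually are not, so a small helper can insert the spaces and apply the same rare-residue mapping in one step. The sketch below is illustrative only: `prepare_sequence` is a hypothetical name, and it assumes one-letter amino-acid codes.

```python
import re

# Same character class as the substitution used in this notebook.
RARE_RESIDUES = re.compile(r"[UZOBX]")

def prepare_sequence(raw: str) -> str:
    """Hypothetical helper: space-separate residues, then map rare ones to <unk>."""
    compact = "".join(raw.split()).upper()  # drop line breaks/spaces, normalise case
    spaced = " ".join(compact)              # "AETC" -> "A E T C"
    # Substitute only after spacing, otherwise "<unk>" itself would be split apart.
    return RARE_RESIDUES.sub("<unk>", spaced)

print(prepare_sequence("AETCZAO"))  # A E T C <unk> A <unk>
```

Doing the substitution last is the important design choice: joining characters after inserting `<unk>` would turn the token into `< u n k >`.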

<b>4. Tokenize and encode the sequences, and load them onto the GPU if possible</b>

In [8]:
ids = tokenizer.batch_encode_plus(sequences_Example, add_special_tokens=True, padding=True, return_tensors="tf")

In [9]:
input_ids = ids['input_ids']
attention_mask = ids['attention_mask']
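XLNet's tokenizer differs from BERT's in two ways that matter for the later steps: it appends its special tokens (`<sep>` followed by `<cls>`) at the end of each sequence, and it pads on the left. The NumPy sketch below reproduces that layout for illustration; `PAD_ID`, `SEP_ID`, `CLS_ID` and the residue ids are made-up placeholders, not the real ProtXLNet vocabulary, and `left_pad_batch` is a hypothetical helper (the real tokenizer does all of this internally).

```python
import numpy as np

# Placeholder ids for illustration only; real ids come from the tokenizer vocabulary.
PAD_ID, SEP_ID, CLS_ID = 0, 1, 2

def left_pad_batch(token_lists):
    """Append <sep><cls> to each sequence and left-pad the batch, XLNet-style."""
    with_special = [list(tokens) + [SEP_ID, CLS_ID] for tokens in token_lists]
    max_len = max(len(tokens) for tokens in with_special)
    input_ids = np.full((len(with_special), max_len), PAD_ID, dtype=np.int64)
    attention_mask = np.zeros((len(with_special), max_len), dtype=np.int64)
    for row, tokens in enumerate(with_special):
        start = max_len - len(tokens)  # padding occupies the left-hand positions
        input_ids[row, start:] = tokens
        attention_mask[row, start:] = 1
    return input_ids, attention_mask

ids_demo, mask_demo = left_pad_batch([[10, 11, 12, 13], [20, 21]])
print(ids_demo)
print(mask_demo)
```

In this toy batch the shorter sequence starts with two padding positions whose mask is 0, and every row ends with the two special tokens, which is the layout step 6 relies on when trimming.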

<b>5. Extract the sequence features and move them to the CPU if needed</b>

In [10]:
output = model(input_ids, attention_mask=attention_mask, mems=None)
embedding = output.last_hidden_state
memory = output.mems
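`output.mems` holds one cached hidden-state tensor per layer, and the `mem_len` argument passed to `from_pretrained` (512 here) caps how many past positions each cache keeps. The NumPy sketch below is a simplified rendering of how the Hugging Face XLNet implementation updates a layer's memory: the states entering the layer are concatenated after the previous memory along the sequence axis (the first axis, since these tensors are time-major) and only the last `mem_len` positions are kept. The name `update_memory` and the toy shapes are assumptions for illustration; the real code also handles options such as `reuse_len`, which are ignored here.

```python
import numpy as np

def update_memory(prev_mem, curr_out, mem_len):
    """Simplified sketch of XLNet's per-layer memory update (time-major arrays).

    prev_mem: (mem_positions, batch, hidden) array, or None for the first segment
    curr_out: (seq_len, batch, hidden) hidden states entering the layer
    mem_len:  positive maximum number of positions to keep
    """
    if prev_mem is None:
        new_mem = curr_out
    else:
        new_mem = np.concatenate([prev_mem, curr_out], axis=0)
    return new_mem[-mem_len:]  # keep only the most recent mem_len positions

batch, hidden, mem_len = 2, 4, 5
mem = None
for segment_len in (3, 3, 3):  # three consecutive segments of 3 positions each
    segment = np.random.rand(segment_len, batch, hidden)
    mem = update_memory(mem, segment, mem_len)
    print(mem.shape)  # (3, 2, 4), then (5, 2, 4), then (5, 2, 4)
```

Passing `mems=None`, as this notebook does, corresponds to the first branch: with no earlier segment, each layer's cache simply starts from the current hidden states.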

In [11]:
embedding = np.asarray(embedding)

In [12]:
attention_mask = np.asarray(attention_mask)

<b>6. Remove the padding and the special tokens added by the XLNet tokenizer</b>

In [13]:
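features = []
for seq_num in range(len(embedding)):
    # Number of real tokens in this sequence (residues plus the two special tokens)
    seq_len = (attention_mask[seq_num] == 1).sum()
    padded_seq_len = len(attention_mask[seq_num])
    # XLNet pads on the left and appends <sep><cls> at the end,
    # so keep only the residue positions in between
    seq_emd = embedding[seq_num][padded_seq_len-seq_len:padded_seq_len-2]
    features.append(seq_emd)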
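Because XLNet pads on the left and appends `<sep>` and `<cls>` at the end, the slice `[padded_seq_len-seq_len:padded_seq_len-2]` should keep exactly one row per residue. The self-contained check below applies the same slice to a toy batch; the random "embeddings", the hand-written left-padded mask and the helper name `trim_features` are stand-ins for illustration, not real model output.

```python
import numpy as np

NUM_SPECIAL = 2  # <sep> and <cls>, appended after the residues

def trim_features(embedding, attention_mask, num_special=NUM_SPECIAL):
    """Drop left padding and the trailing special tokens from each sequence."""
    features = []
    for emb, mask in zip(embedding, attention_mask):
        seq_len = int((mask == 1).sum())  # real tokens, including special tokens
        padded_len = mask.shape[0]
        features.append(emb[padded_len - seq_len : padded_len - num_special])
    return features

# Toy batch: sequence 0 has 4 residues, sequence 1 has 2 (left-padded by 2).
toy_mask = np.array([[1, 1, 1, 1, 1, 1],
                     [0, 0, 1, 1, 1, 1]])
toy_embedding = np.random.rand(2, 6, 8).astype(np.float32)

for f in trim_features(toy_embedding, toy_mask):
    print(f.shape)  # (4, 8), then (2, 8)
```

If the slice were taken from the start of each row instead (as is common for right-padded BERT-style inputs), the shorter sequence would wrongly keep its padding positions.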

In [14]:
print(features)

[array([[ 0.48745227, -0.77087843,  0.990017  , ..., -0.37356126,
        -1.0589763 ,  0.95599836],
       [ 0.21325974, -0.54789084,  0.6115488 , ..., -0.05471374,
        -0.8787888 ,  0.24645172],
       [ 0.3789134 , -0.63965833,  0.672244  , ..., -0.14891191,
        -0.67695737,  0.3459818 ],
       ...,
       [ 0.09265004, -0.6810122 ,  0.5181865 , ..., -0.32756284,
        -0.56731236, -0.1644126 ],
       [-0.08533108, -0.7438239 ,  0.29890803, ..., -0.24376416,
        -0.11153103, -0.7260992 ],
       [-0.48225644, -0.8381694 ,  0.08214393, ..., -0.25196418,
        -0.03577495, -0.5348944 ]], dtype=float32), array([[ 1.0403047 , -0.9504923 ,  0.33534396, ..., -0.23747206,
        -0.27550477,  0.47948766],
       [ 0.50853604, -0.9508453 ,  1.023512  , ..., -0.05892992,
        -0.9752786 ,  0.08713093],
       [ 0.6562658 , -0.7540612 ,  0.43234774, ...,  0.36707556,
        -0.7063242 , -0.5553413 ],
       [ 0.5903383 ,  0.14151943,  0.29578954, ...,  0.20601095,
     