<a href="https://colab.research.google.com/github/agemagician/ProtTrans/blob/master/Embedding/PyTorch/Basic/ProtXLNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<h3> Extracting protein sequences' features using ProtXLNet pretrained-model <h3>

<b>1. Load necessry libraries including huggingface transformers<b>

In [1]:
!pip install -q transformers sentencepiece

[K     |████████████████████████████████| 2.6 MB 17.9 MB/s 
[K     |████████████████████████████████| 1.2 MB 56.8 MB/s 
[K     |████████████████████████████████| 895 kB 60.5 MB/s 
[K     |████████████████████████████████| 636 kB 67.8 MB/s 
[K     |████████████████████████████████| 3.3 MB 54.7 MB/s 
[?25h

In [2]:
import torch
from transformers import XLNetTokenizer, AutoModel, pipeline
import re
import numpy as np
import os
import requests
from tqdm.auto import tqdm

<b>2. Load the vocabulary and ProtXLNet Model<b>

In [3]:
tokenizer = XLNetTokenizer.from_pretrained("Rostlab/prot_xlnet", do_lower_case=False)

Downloading:   0%|          | 0.00/238k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.35k [00:00<?, ?B/s]

In [4]:
model = AutoModel.from_pretrained("Rostlab/prot_xlnet")

Downloading:   0%|          | 0.00/1.64G [00:00<?, ?B/s]

Some weights of the model checkpoint at Rostlab/prot_xlnet were not used when initializing XLNetModel: ['lm_loss.bias', 'lm_loss.weight']
- This IS expected if you are initializing XLNetModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLNetModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


<b>3. Load the model into the GPU if avilabile<b>

In [5]:
fe = pipeline('feature-extraction', model=model, tokenizer=tokenizer,device=0)

<b>4. Create or load sequences and map rarely occured amino acids (U,Z,O,B) to (unk)<b>

In [6]:
sequences_Example = ["A E T C Z A O","S K T Z P"]

In [7]:
sequences_Example = [re.sub(r"[UZOBX]", "<unk>", sequence) for sequence in sequences_Example]

<b>5. Extracting sequences' features and covert the output to numpy if needed<b>

In [8]:
embedding = fe(sequences_Example)

In [9]:
embedding = np.array(embedding)

In [10]:
print(embedding)

[[[ 5.69580197e-01 -8.12229991e-01  1.51267636e+00 ... -3.47371846e-01
   -1.97737467e+00  1.02282596e+00]
  [ 2.76625063e-02 -6.71196759e-01  9.98871863e-01 ...  7.27692693e-02
   -1.62625885e+00 -8.44621472e-03]
  [ 2.20986292e-01 -5.26815653e-01  6.64871275e-01 ...  4.78142016e-02
   -1.39787090e+00  3.08236480e-01]
  ...
  [-3.64926666e-01 -8.19322467e-01  4.81531680e-01 ...  2.35715330e-01
   -6.73881948e-01 -1.06030154e+00]
  [ 4.51355964e-01 -8.96943450e-01  4.00961459e-01 ... -1.93732351e-01
   -5.60827017e-01 -2.78551280e-01]
  [ 3.18279088e-01 -1.61193037e+00  4.94405985e-01 ... -2.51359075e-01
   -1.32739067e-01 -1.23081710e-02]]

 [[ 1.91230476e-01  1.84453130e-02 -1.82836526e-03 ... -4.36504066e-01
    2.18435768e-02 -1.59097388e-01]
  [ 2.63837665e-01 -6.02969602e-02 -1.12777390e-02 ... -2.28306949e-01
   -3.21158886e-01  1.10596269e-01]
  [ 8.65128040e-01 -1.61870003e-01 -1.75776944e-01 ...  3.56552243e-01
   -2.34118134e-01  4.93945815e-02]
  ...
  [ 4.84942645e-01  6.7

<b>Optional: Remove padding ([PAD]) and special tokens ([CLS],[SEP]) that is added by ProtXLNet model<b>

In [11]:
features = [] 
for seq_num in range(len(embedding)):
    seq_len = len(sequences_Example[seq_num].replace(" ", ""))
    padded_seq_len = len(embedding[seq_num])
    start_Idx = padded_seq_len-seq_len-2
    end_Idx = padded_seq_len-2
    seq_emd = embedding[seq_num][start_Idx:end_Idx]
    features.append(seq_emd)

In [12]:
print(features)

[array([[ 0.02766251, -0.67119676,  0.99887186, ...,  0.07276927,
        -1.62625885, -0.00844621],
       [ 0.22098629, -0.52681565,  0.66487128, ...,  0.0478142 ,
        -1.3978709 ,  0.30823648],
       [ 0.98757803, -1.0321219 ,  0.99680513, ..., -0.338559  ,
        -1.51521862,  1.05237138],
       [ 0.70799828, -0.664361  ,  0.85833752, ..., -0.02473303,
        -1.51670814, -0.21759628],
       [-0.14213684, -0.86483902,  0.8144275 , ..., -0.32999068,
        -0.23385352, -1.71954966],
       [-0.36492667, -0.81932247,  0.48153168, ...,  0.23571533,
        -0.67388195, -1.06030154]]), array([], shape=(0, 1024), dtype=float64)]
