<a href="https://colab.research.google.com/github/agemagician/ProtTrans/blob/master/Generate/ProtXLNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<h3>Generate protein sequences using the ProtXLNet pretrained model</h3>

<b>1. Load necessary libraries, including Hugging Face Transformers</b>

In [1]:
!pip install -q transformers

[K     |████████████████████████████████| 778kB 2.8MB/s 
[K     |████████████████████████████████| 3.0MB 10.0MB/s 
[K     |████████████████████████████████| 890kB 29.5MB/s 
[K     |████████████████████████████████| 1.1MB 39.9MB/s 
[?25h  Building wheel for sacremoses (setup.py) ... [?25l[?25hdone


In [2]:
import torch
from transformers import XLNetLMHeadModel, XLNetTokenizer,pipeline
import re
import os
import requests
from tqdm.auto import tqdm

<b>2. Set the URL location of ProtXLNet and the vocabulary file</b>

In [3]:
# Remote locations of the three files needed below:
# the model weights, the model configuration and the tokenizer model.
modelUrl = "https://www.dropbox.com/s/z0i0z01d2wm19ap/pytorch_model.bin?dl=1"
configUrl = "https://www.dropbox.com/s/to876ivj48wylkj/config.json?dl=1"
tokenizerUrl = "https://www.dropbox.com/s/mvypdtedpuz0yxg/spm_model.model?dl=1"

<b>3. Download ProtXLNet models and vocabulary files</b>

In [4]:
# Folder (relative to the working directory) that receives every downloaded file.
downloadFolderPath = "models/ProtXLNet/"

In [5]:
modelFolderPath = downloadFolderPath

# Local destinations of the weights, the config and the tokenizer model,
# all placed directly inside the model folder.
modelFilePath, configFilePath, tokenizerFilePath = (
    os.path.join(modelFolderPath, fileName)
    for fileName in ('pytorch_model.bin', 'config.json', 'spm_model.model')
)

In [6]:
# Create the model folder together with any missing parents.
# exist_ok=True makes this a no-op when the folder is already there and avoids
# the race between a separate existence check and the creation.
os.makedirs(modelFolderPath, exist_ok=True)

In [7]:
def download_file(url, filename):
    """Stream ``url`` into ``filename`` while showing a byte-level progress bar.

    The payload is first written to ``<filename>.part`` and only moved to
    ``filename`` once the transfer has finished. An interrupted or failed
    download therefore never leaves a file under the final name that a later
    ``os.path.exists`` check would mistake for a complete download.

    Args:
        url: HTTP(S) address to fetch.
        filename: Destination path; also used as the progress-bar label.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
    """
    partial_path = str(filename) + '.part'
    # Context-manage the response so the streamed connection is always released.
    with requests.get(url, stream=True) as response:
        # Fail loudly instead of saving an error page under the model's file name.
        response.raise_for_status()
        total_bytes = int(response.headers.get('content-length', 0))
        # The file gets its own ``with`` so it is flushed and closed
        # deterministically, independent of the progress-bar wrapper.
        with open(partial_path, "wb") as raw_file:
            with tqdm.wrapattr(raw_file, "write", miniters=1,
                               total=total_bytes, desc=filename) as fout:
                for chunk in response.iter_content(chunk_size=4096):
                    fout.write(chunk)
    # Expose the file under its final name only after all bytes were written.
    os.replace(partial_path, filename)

In [8]:
# Fetch each artifact only when it is not already present on disk,
# in the order: weights, config, tokenizer model.
for sourceUrl, targetPath in (
    (modelUrl, modelFilePath),
    (configUrl, configFilePath),
    (tokenizerUrl, tokenizerFilePath),
):
    if not os.path.exists(targetPath):
        download_file(sourceUrl, targetPath)

HBox(children=(FloatProgress(value=0.0, description='models/ProtXLNet/pytorch_model.bin', max=1637757076.0, st…




HBox(children=(FloatProgress(value=0.0, description='models/ProtXLNet/config.json', max=1351.0, style=Progress…




HBox(children=(FloatProgress(value=0.0, description='models/ProtXLNet/spm_model.model', max=238192.0, style=Pr…




<b>4. Load the vocabulary and ProtXLNet Model</b>

In [9]:
# Build the tokenizer from the downloaded tokenizer model file; lower-casing is
# switched off so the upper-case residue letters are kept as written.
tokenizer = XLNetTokenizer(vocab_file=tokenizerFilePath, do_lower_case=False)

In [10]:
# Load the language-model head variant from the local folder that holds the
# downloaded config.json and pytorch_model.bin.
model = XLNetLMHeadModel.from_pretrained(modelFolderPath)

<b>5. Load the model into the GPU if available and switch to inference mode</b>

In [11]:
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [12]:
# Move the model to the selected device and switch it to evaluation mode
# in a single chained expression.
model = model.to(device).eval()

<b>6. Create or load sequences and map rarely occurring amino acids (U,Z,O,B) to (X)</b>

In [13]:
# Example prompt: residue letters separated by single spaces. It contains Z and O
# so the mapping step in the next cell has something to replace.
sequences_Example = 'A E T C Z A O'

In [14]:
# Replace every U, Z, O or B character in the example with X.
sequences_Example = re.compile(r"[UZOB]").sub("X", sequences_Example)

<b>7. Tokenize, encode sequences and load them into the GPU if possible</b>

In [15]:
# Turn the example sequence into a list of token ids; add_special_tokens=False
# asks the tokenizer not to add any special tokens to the prompt.
ids = tokenizer.encode(sequences_Example, add_special_tokens=False)

In [16]:
# Wrap the ids in a tensor, add a leading dimension of size 1 and move the
# result to the selected device.
input_ids = torch.unsqueeze(torch.tensor(ids), dim=0).to(device)

<b>8. Generate Protein Sequence</b>

In [17]:
# Options passed to model.generate() in the next cell.
max_length = 100            # forwarded as max_length
temperature = 1.0           # forwarded as temperature
k = 0                       # forwarded as top_k
p = 0.9                     # forwarded as top_p
repetition_penalty = 1.0    # forwarded as repetition_penalty
num_return_sequences = 3    # forwarded as num_return_sequences

# Generation below samples (do_sample=True). Fixing the seed makes
# "Restart & Run All" reproduce the same sequences. Re-run this cell before the
# generate cell to repeat a result; change `seed` to draw different samples.
seed = 42
torch.manual_seed(seed)

In [18]:
# Collect the sampling options in one place, then hand them to generate()
# as keyword arguments.
generation_kwargs = dict(
    input_ids=input_ids,
    max_length=max_length,
    temperature=temperature,
    top_k=k,
    top_p=p,
    repetition_penalty=repetition_penalty,
    do_sample=True,
    num_return_sequences=num_return_sequences,
)
output_ids = model.generate(**generation_kwargs)

In [19]:
def _space_separated(text):
    """Return ``text`` with its non-whitespace characters separated by single spaces."""
    # Put a space between all characters, then split/join to collapse
    # whatever whitespace runs result from that.
    return " ".join(" ".join(text).split())

# Decode each generated id sequence and normalise it to one character per token.
output_sequences = [_space_separated(tokenizer.decode(output_id)) for output_id in output_ids]

In [20]:
# Print the heading (followed by a blank line) and then one generated sequence per line.
print('Generated Sequences\n', *output_sequences, sep='\n')

Generated Sequences

A E T C X A X A A E T C L A K A A E T C L A K A A E T C L A K A A E T C L A K A A E T C L A K A A E T C L A K A A E T C L A K A A E T C L A K A A E T C L A K A A E T C L A K A A E T C L A K A A E T C
A E T C X A X L V S M T E A K G K G P I P R V A V G E E E D V T I L R G R M V D A D L K M D F D A Q R G V L A R K V V E Q I D S L T E V F L E V R I Q T S L D D R G S L G Q K K M V G F V G P A V M F A
A E T C X A X S G K R A D A R L T T A F F A L A V I G P E L D S T P A I A P L S A L G E E E D K G M T Q A G D F R L E T L I L R L V L V F S P V S A L A S S K P A R S I A L D G A S R R G L S P D G D D
