<a href="https://colab.research.google.com/github/agemagician/ProtTrans/blob/master/Generate/ProtXLNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<h3> Generate protein sequences using the ProtXLNet pretrained model </h3>

<b>1. Load the necessary libraries, including Hugging Face transformers</b>

In [5]:
import torch
from transformers import XLNetLMHeadModel, XLNetTokenizer
import re
import os
import requests
import pandas as pd
from tqdm.auto import tqdm
from Bio import SeqIO

<b>2. Load the vocabulary and the ProtXLNet model</b>

In [6]:
tokenizer = XLNetTokenizer.from_pretrained("Rostlab/prot_xlnet", do_lower_case=False)

In [7]:
model = XLNetLMHeadModel.from_pretrained("Rostlab/prot_xlnet")

<b>3. Load the model onto the GPU if available and switch to inference mode</b>

In [8]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda', index=0)

In [9]:
model = model.to(device)
model = model.eval()

<b>4. Create or load sequences and map rarely occurring amino acids (U, Z, O, B) to the unknown token (unk)</b>

Example sequence: it was taken from the list of sequences used to train the model, chosen at random from among those carrying the positive label for tumor-attacking sequences.

In [10]:
sequences_Example = 'AAVALLPAVLLALLAPQLGKKKHRRRPSKKKRHW' 

In [11]:
sequences_Example = re.sub(r"[UZOB]", "<unk>", sequences_Example) 
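ProtTrans examples pass ProtXLNet one residue per whitespace-separated token (e.g. "A E T C Z A O"), whereas the cell above passes the sequence unspaced. Every decoded output further down begins with a single X instead of repeating the 34-residue prompt, which suggests the prompt was not split into one token per residue. The sketch below assumes space-separated input is what the tokenizer expects; `prepare_sequence` is a hypothetical helper, not part of the notebook or of transformers. It spaces the residues before substituting rare ones, so the multi-character `<unk>` string is not itself split into characters.

```python
import re

# Rare / ambiguous residues that the notebook maps to the unknown token.
RARE_AA = re.compile(r"[UZOB]")

def prepare_sequence(seq: str, unk_token: str = "<unk>") -> str:
    """Hypothetical helper: space-separate residues, then replace rare ones.

    Spacing happens first; substituting first and then spacing would turn
    "<unk>" into "< u n k >".
    """
    spaced = " ".join(seq.strip().upper())
    return RARE_AA.sub(unk_token, spaced)

print(prepare_sequence("AAVALLPAV"))  # A A V A L L P A V
print(prepare_sequence("AETCZAO"))    # A E T C <unk> A <unk>
```

If spaced input is adopted, the prompt becomes roughly 34 tokens long, so `max_length=17` would be shorter than the prompt itself; recent transformers versions accept `max_new_tokens` in `generate()` to bound only the newly generated part.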

<b>5. Tokenize and encode the sequences and load them onto the GPU if possible</b>

In [12]:
ids = tokenizer.encode(sequences_Example, add_special_tokens=False)

In [13]:
input_ids = torch.tensor(ids).unsqueeze(0).to(device)

<b>6. Generate protein sequences</b>

In [14]:
max_length = 17  # average length of anticancer sequences is 17.014; generate() counts the prompt tokens toward max_length
temperature = 3.0
k = 0
p = 1.0
repetition_penalty = 1.0
num_return_sequences = 2500
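These settings use `temperature=3.0`, while `top_k=0`, `top_p=1.0` and `repetition_penalty=1.0` leave the top-k filter, the nucleus (top-p) filter and the repetition penalty switched off in transformers, so each token is drawn from the full, temperature-flattened distribution. The pure-Python sketch below illustrates one sampling step under those semantics. It is a simplified illustration of the standard technique, not the transformers implementation; `softmax_with_temperature` and `sample_next_token` are hypothetical names, and the repetition penalty is omitted.

```python
import math
import random

def softmax_with_temperature(logits, temperature=1.0):
    """Divide logits by the temperature, then apply a numerically stable softmax."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max so exp() cannot overflow
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def sample_next_token(logits, temperature=1.0, top_k=0, top_p=1.0, rng=random):
    """Sample one token index; top_k=0 and top_p=1.0 disable those filters."""
    probs = softmax_with_temperature(logits, temperature)
    # Token indices ordered from most to least probable.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    # Top-k: keep only the k most probable tokens.
    if top_k > 0:
        order = order[:top_k]
    # Top-p: keep the smallest prefix whose cumulative probability reaches top_p.
    if top_p < 1.0:
        kept, cumulative = [], 0.0
        for i in order:
            kept.append(i)
            cumulative += probs[i]
            if cumulative >= top_p:
                break
        order = kept
    # random.choices renormalises the surviving weights itself.
    return rng.choices(order, weights=[probs[i] for i in order], k=1)[0]

print([round(p, 3) for p in softmax_with_temperature([2.0, 1.0, 0.0], 1.0)])  # [0.665, 0.245, 0.09]
print([round(p, 3) for p in softmax_with_temperature([2.0, 1.0, 0.0], 3.0)])  # [0.448, 0.321, 0.23]
```

At `temperature=3.0` the gap between likely and unlikely tokens shrinks considerably, which plausibly contributes to special tokens such as `<sep>` and `<mask>` and non-residue characters appearing in the raw generated sequences.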

In [15]:
output_ids = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        temperature=temperature,
        top_k=k,
        top_p=p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        num_return_sequences=num_return_sequences,
    )

This is a friendly reminder - the current text generation call will exceed the model's predefined maximum length (-1). Depending on the model, you may observe exceptions, performance degradation, or nothing at all.


In [20]:
# decode each generated id sequence and remove the spaces between residues
output_sequences = ["".join(tokenizer.decode(output_id).split()) for output_id in output_ids]

In [21]:
type(output_sequences)

list

In [22]:
print('Generated Sequences\n')
for output_sequence in output_sequences:
  print(output_sequence)

Generated Sequences

XLDLDDWYTVDRDAMSM
XKEAKEGATEWCPIVIN
XIHID(IYHIPDHN<cls>HH
XIYMYQNPQADYQKTVV
XNNDKCCEIISHNXMFM
XYYIENVMHVAMPMYYK
XYIFYNYYHC<sep>IYTTFY
XDPAMEFDNAEIIDDDD
XRLILLRGRLTFYRRFW
XYMMMDEMCQMYPTSQA
XY<sep>SHGHQXQHNXTITM
XDDDKDDEYMKMLKKKK
XMIQQQLQIRKLEMSKS
X–MVWSTVSLAVMSSRL
XLERGNRVEKRWCCCSR
XTCLQQEGRGRGRGQGQ
XNRQRNSNGVAMSGTAT
XVYV<eod>GMAPMAMGGAQF
XDENNESWCVDEHYSCW
XL<pad>XMDDHNYMNMDMCD
XEKQQVEQLQAGLEHAG
XKNKKDKKLNQVWKCYV
XMKLDD"KFLFDDHSFF
XYYYYIFAMSAIAHITH
XYYPQCRFMRXHPQPQP
XEYDYIPPPHHHRNKNI
XEYYAQIVITTKININV
XKQMNEMNEVMNTNINQ
XPDY)<sep>GYHDDYPVFFN
XYVYEMELAHAYENENA
XDRQDDVRGDHADGRPR
XYAADCDMRCDCDMHCN
XYYKKDDSLHYTDHTHQ
XEEEYKRDHYMFMFHMM
XYNITVVCYDITXA-FD
XHHRTRTRAAAAGVVAV
XSMLSSFAPP<mask>SIDNYY
XMDLCLVMDLWME.HPE
XKVKEAKV<sep>KIMDECFF
XDFKDVFDDTVQDDKTP
XYKYNCHPHQQ<eod>D–GYP
XMEGGHGDHEGGHSMMI
XEDGEEFLLLPP<mask>FFGP
XNETNHVKVNPVATSLN
XDDLLLSSDNDLIFIST
XMXHIMKIVGVTQQ£TS
XYNQKKNGNCNSCKKNL
XEMDDMIGNNGMQGDGV
XYFY<sep>K<sep>YDKWIGSKME
XYDSNMCFDSNMSICFD
XIXTDFNXIYDQLFGYN
XKVEIIMAQGK

In [23]:
import sys
sys.path.insert(0, '../')
from src.seq_cleanup import clean_seq

In [24]:
filtered_df = clean_seq(output_sequences)
filtered_df

0       LDLDDWYTVDRDAMSM
1       KEAKEGATEWCPIVIN
2       IYMYQNPQADYQKTVV
3       YYIENVMHVAMPMYYK
4       DPAMEFDNAEIIDDDD
              ...       
1360    YVYMMYYMYMVRMCHD
1361    DNKHYYDYDTKFNYVV
1362    WEHEQQHDNQDDGKDN
1363    YYCIMNKMTDKHFFAA
1364    ANADVAANVDMPYGYG
Name: 0, Length: 1365, dtype: object
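The source of `clean_seq` is not shown here. Comparing the first rows of `output_sequences` with `filtered_df` suggests it drops the leading X and discards any sequence containing a character outside the 20 standard amino acids, reducing 2500 sequences to 1365. The sketch below is a hypothetical stand-in reproducing that observed behaviour, not the project's actual function; `filter_generated` and `STANDARD_AA` are illustrative names.

```python
# The 20 standard amino acids; anything else (X, <sep>, punctuation) is rejected.
STANDARD_AA = set("ACDEFGHIKLMNPQRSTVWY")

def filter_generated(sequences, prompt_prefix="X"):
    """Hypothetical filter: strip the decoded prompt, keep only clean sequences."""
    kept = []
    for seq in sequences:
        # Remove the decoded prompt that precedes the newly generated residues.
        if seq.startswith(prompt_prefix):
            seq = seq[len(prompt_prefix):]
        # Keep the sequence only if it is non-empty and fully standard.
        if seq and set(seq) <= STANDARD_AA:
            kept.append(seq)
    return kept

# The first eight generated sequences printed above.
raw = [
    "XLDLDDWYTVDRDAMSM", "XKEAKEGATEWCPIVIN", "XIHID(IYHIPDHN<cls>HH",
    "XIYMYQNPQADYQKTVV", "XNNDKCCEIISHNXMFM", "XYYIENVMHVAMPMYYK",
    "XYIFYNYYHC<sep>IYTTFY", "XDPAMEFDNAEIIDDDD",
]
print(filter_generated(raw))
```

On these rows the sketch yields the same five sequences as rows 0 to 4 of `filtered_df`.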

In [26]:
filtered_df.index

RangeIndex(start=0, stop=1365, step=1)

In [25]:
# filtered_df is a pandas Series (name 0), not a DataFrame, so iterate with .items()
with open('../data/processed/generated_seqs.fasta', 'w', encoding='UTF8') as f:
    for idx, seq in filtered_df.items():
        # one FASTA record per sequence: the Series index as header, then the sequence
        f.write(f">{idx}\n{seq}\n")
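A quick sanity check after writing the FASTA file is to read it back and confirm that the number of records equals `len(filtered_df)`. `Bio.SeqIO`, imported in the first cell, offers `SeqIO.parse(path, "fasta")` for this; the sketch below is a dependency-free alternative using hypothetical helpers `write_fasta` and `read_fasta`, shown on an in-memory buffer rather than the real file.

```python
import io

def write_fasta(handle, records):
    """Write (identifier, sequence) pairs as two-line FASTA records."""
    for ident, seq in records:
        handle.write(f">{ident}\n{seq}\n")

def read_fasta(handle):
    """Parse FASTA text into (identifier, sequence) pairs; multi-line sequences are joined."""
    records, ident, chunks = [], None, []
    for line in handle:
        line = line.strip()
        if not line:
            continue
        if line.startswith(">"):
            if ident is not None:
                records.append((ident, "".join(chunks)))
            ident, chunks = line[1:], []
        else:
            chunks.append(line)
    if ident is not None:
        records.append((ident, "".join(chunks)))
    return records

buf = io.StringIO()
write_fasta(buf, enumerate(["LDLDDWYTVDRDAMSM", "KEAKEGATEWCPIVIN"]))
buf.seek(0)
print(read_fasta(buf))  # [('0', 'LDLDDWYTVDRDAMSM'), ('1', 'KEAKEGATEWCPIVIN')]
```

With the real file, opening `'../data/processed/generated_seqs.fasta'` and passing the handle to `read_fasta` should return 1365 records if every sequence was written.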