<a href="https://colab.research.google.com/github/agemagician/ProtTrans/blob/master/Generate/ProtXLNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<h3> Generate protein sequences using the ProtXLNet pretrained model </h3>

<b>1. Load the necessary libraries, including Hugging Face transformers</b>

In [1]:
import torch
from transformers import XLNetLMHeadModel, XLNetTokenizer
import re
import os
import requests
import pandas as pd
from tqdm.auto import tqdm
from Bio import SeqIO

<b>2. Load the vocabulary and the ProtXLNet model</b>

In [2]:
tokenizer = XLNetTokenizer.from_pretrained("Rostlab/prot_xlnet", do_lower_case=False, sep_token = '')

In [3]:
model = XLNetLMHeadModel.from_pretrained("Rostlab/prot_xlnet")

<b>3. Load the model onto the GPU if available and switch to inference mode</b>

In [4]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda', index=0)

In [5]:
model = model.to(device)
model = model.eval()

<b>4. Create or load sequences and map rarely occurring amino acids (U, Z, O, B) to (unk)</b>

Example sequences: these were taken from the list of sequences used to train the model, chosen at random among entries carrying the positive label for sequences that attack tumors.

In [6]:
sequences_Example = ['AAVALLPAVLLALLAPQLGKKKHRRRPSKKKRHW', "RRPKGRGKRRREKQRPSDKPRR", "KGRGKRRREKQRPSDAPAA", "AGRGARAAEAQRPSDKPRR", "KGRGARRREKQRPSDKPRR", "NGRKISLDLRAPLYKKIIKKLLES", "KGRGKRRREKQRPCDKPRR",
                     "KGRGKRAAEKQAPSDKPRR"]

In [7]:
# str() serializes the whole list, so brackets, quotes, and commas stay in the prompt string
sequences_Example = re.sub(r"[UZOB]", "<unk>", str(sequences_Example))

In [8]:
sequences_Example

"['AAVALLPAVLLALLAPQLGKKKHRRRPSKKKRHW', 'RRPKGRGKRRREKQRPSDKPRR', 'KGRGKRRREKQRPSDAPAA', 'AGRGARAAEAQRPSDKPRR', 'KGRGARRREKQRPSDKPRR', 'NGRKISLDLRAPLYKKIIKKLLES', 'KGRGKRRREKQRPCDKPRR', 'KGRGKRAAEKQAPSDKPRR']"

<b>5. Tokenize and encode the sequences, then load them onto the GPU if possible</b>

In [9]:
ids = tokenizer.encode(sequences_Example, add_special_tokens=False)

In [10]:
input_ids = torch.tensor(ids).unsqueeze(0).to(device)
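The `encode` call in cell 9 receives the string form of the whole list, so list punctuation is tokenized along with the residues. A minimal sketch of a per-sequence alternative follows. The helper name `prepare_sequence` is hypothetical, and the space-separated residue format is an assumption based on how the ProtTrans model cards present their inputs; it is not something this notebook does.

```python
import re

def prepare_sequence(seq: str) -> str:
    """Space-separate residues and map rare amino acids (U, Z, O, B) to <unk>."""
    spaced = " ".join(seq.strip().upper())
    # The substitution runs on one sequence at a time, so no list punctuation is involved.
    return re.sub(r"[UZOB]", "<unk>", spaced)

examples = ["RRPKGRGKRRREKQRPSDKPRR", "AETCZAO"]  # second entry is a made-up test case
prepared = [prepare_sequence(s) for s in examples]
print(prepared[1])
```

Each prepared string could then be encoded on its own, which keeps one prompt per sequence instead of a single prompt built from the list's text.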

<b>6. Generate protein sequences</b>

In [12]:
max_length = 17  # average length of anticancer sequences is 17.014
temperature = 3.0  # higher values give more variety but less accuracy
k = 5
p = 1
repetition_penalty = 5.0
num_return_sequences = 3000
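To make the effect of `temperature` and `k` concrete, here is a small standalone sketch of temperature scaling followed by top-k filtering on a toy logit vector. It only illustrates the idea; the function `sampling_distribution` and the logit values are made up for this example, and the real `generate` call also applies `top_p` and `repetition_penalty`, which are not modeled here.

```python
import math

def sampling_distribution(logits, temperature=1.0, top_k=0):
    """Probabilities a sampler would draw from after temperature scaling and top-k filtering."""
    scaled = [x / temperature for x in logits]
    if top_k > 0:
        # Keep only the top_k highest scores (ties at the cutoff are all kept).
        cutoff = sorted(scaled, reverse=True)[min(top_k, len(scaled)) - 1]
        scaled = [x if x >= cutoff else float("-inf") for x in scaled]
    peak = max(scaled)
    weights = [math.exp(x - peak) for x in scaled]  # exp(-inf) evaluates to 0.0
    total = sum(weights)
    return [w / total for w in weights]

toy_logits = [4.0, 2.0, 1.0, 0.5, 0.1, -1.0]  # illustrative values only
sharp = sampling_distribution(toy_logits, temperature=1.0, top_k=5)
flat = sampling_distribution(toy_logits, temperature=3.0, top_k=5)
print(round(sharp[0], 3), round(flat[0], 3))
```

With the higher temperature the top candidate receives a smaller share of the probability mass, so sampling spreads over more of the five retained tokens.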

In [None]:
output_ids = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        temperature=temperature,
        top_k=k,
        top_p=p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        num_return_sequences=num_return_sequences,
    )

This is a friendly reminder - the current text generation call will exceed the model's predefined maximum length (-1). Depending on the model, you may observe exceptions, performance degradation, or nothing at all.


In [16]:
output_sequences = ["".join(tokenizer.decode(output_id).split()) for output_id in output_ids]

In [17]:
print('Generated Sequences\n')
for output_sequence in output_sequences:
  print(output_sequence)

Generated Sequences

XDHDNEMDQDHAHVQVH
XDEDNKHYNDNNHTHVD
XYKEHVQYTGSAQANGS
XDKENVHVYVCIFLIIF
XDYMKNDDEDVEIVVCF
XEINLFFCICKHYY<s>PK
XDNVIYYEKDSVYYVQP
XYYIHKNDNDFGGVVVG
XDKDEYYHVDENACIQQ
XNDVVEKIDGAVEKMVV
XEYKEDKNILHHQHRRW
XVVYFYMHVDVVIVYYT
XDKEHVYVYVCIFVFLF
XENKILVAYSCRMMIVT
XDEHVNIYMYFCLSWFM
XVVYIHVYV</s><pad><pad><pad><pad><pad><pad><pad>
XEKEIDVNEKNAYDVHQ
XNHDQEHYRKQEEVEEA
XDEKVVDAEMGHYYIVV
XNKEYYDECVEYDYKYH
XDKDEHAHVDKGTYADV
XVYMFMHEHVAHAQQHD
XDKNEHVYVDVIMCMFM
XYDVMWRRHGAMWLPLW
XDKEYAAPPVINKHQEQ
XEDEYVDINECTYTDID
XYYHEKVFFCCCSRGAT
XNEKDIHEKYVQDMCLF
XDYEHENVYVCMSVSMF
XVYYLHIDIVHYYLEFN
XNEKYDDHDGNHVIINK
XYVCYMCFMCYIYNDGS
XNKDHVYYDGADFVHVD
XDYMNK<mask>DYMYRPEQMM
XYYKTQIHVNVCVFVFM
XYKNHTSHCRINQSPLY
XNKEDHNVQGAQAHIPL
XVIVFMLNVVLHVAIVA
XYICTCVRYKYYYHYY.
XNKYDEDGVFHHRHQKN
XVIMIYIINLNINVFVC
XNEYKDVHKCEFQHVYT
XYIKDVYMHATSTHLPQ
XYITNHVQAAAGGSSTT
XYKHEQACWRPEVVVTS
XNKHEYDYHVHEDYIPT
XDHEVLNSHGVDLISAR
XNEKEHEDEYVNEHGHA
XDYVELDAYTEQDPCIN
XDHDENDKNYGNNGVSH
XYINKQHENTCWGSSTA
XVYMVVVAAGTTSNDGV
XDHDQ
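Several of the decoded strings above still contain special tokens such as `<s>`, `</s>`, `<pad>`, and `<mask>`. One option is to pass `skip_special_tokens=True` to `tokenizer.decode`; another is to strip them afterwards, as in the sketch below. The helper `strip_special_tokens` is a hypothetical name, and the pattern simply assumes that every special token is wrapped in angle brackets.

```python
import re

# Assumption: special tokens always look like <...>, e.g. <s>, </s>, <pad>, <mask>.
ANGLE_TOKEN = re.compile(r"<[^<>]+>")

def strip_special_tokens(decoded: str) -> str:
    """Remove angle-bracket tokens from a decoded sequence string."""
    return ANGLE_TOKEN.sub("", decoded)

samples = ["XEINLFFCICKHYY<s>PK", "XVVYIHVYV</s><pad><pad>", "XDYMNK<mask>DYMYRPEQMM"]
for sample in samples:
    print(strip_special_tokens(sample))
```

Stripping shortens the affected sequences, so a later length check would still be needed to decide whether to keep them.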

In [18]:
import sys
sys.path.insert(0, '../')
from src.seq_cleanup import clean_seq

In [19]:
filtered_df = clean_seq(output_sequences)
filtered_df

0       DHDNEMDQDHAHVQVH
1       DEDNKHYNDNNHTHVD
2       YKEHVQYTGSAQANGS
3       DKENVHVYVCIFLIIF
4       DYMKNDDEDVEIVVCF
              ...       
2431    DYEDNAYGQHSYGHNH
2432    DKDEDKGKNEDMSQDV
2433    NHEKEYGDHVEQHDRV
2434    YKNEDEVVGRCYCCCF
2435    NEDNKVNAYDDHGGRG
Name: 0, Length: 2436, dtype: object
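The source of `clean_seq` is not shown in this notebook. Judging only from the output above (the leading `X` is gone and the displayed rows are 16 residues long), a filter of roughly the following shape would produce similar results. Everything here is an assumption for illustration: the name `filter_generated`, the fixed length of 16, the restriction to the 20 standard amino acids, and the removal of exact duplicates.

```python
STANDARD_AA = set("ACDEFGHIKLMNPQRSTVWY")

def filter_generated(seqs, expected_length=16):
    """Hypothetical stand-in for clean_seq: keep unique, fixed-length, standard-residue sequences."""
    kept, seen = [], set()
    for seq in seqs:
        body = seq[1:] if seq.startswith("X") else seq  # drop the leading X
        if len(body) != expected_length:
            continue  # truncated or padded outputs
        if not set(body) <= STANDARD_AA:
            continue  # leftover tokens, punctuation, or rare residues
        if body in seen:
            continue  # exact duplicate
        seen.add(body)
        kept.append(body)
    return kept

raw = ["XDHDNEMDQDHAHVQVH", "XEINLFFCICKHYY<s>PK", "XYICTCVRYKYYYHYY.", "XDHDNEMDQDHAHVQVH", "XDHDQ"]
print(filter_generated(raw))
```

Wrapping the returned list in `pd.Series` would give an object comparable to the `filtered_df` displayed above.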

In [20]:
with open('../data/processed/generated_seqs.fasta','w', encoding='UTF8') as f:
    
    for i in range(len(filtered_df)):
        f.write(f">{i}\n{filtered_df[i]}\n")
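The loop in cell 20 indexes `filtered_df[i]`, which works as long as the Series keeps its default integer index. A sketch that avoids that dependency by enumerating the values, together with a small reader for checking the result, could look like this. The helpers `write_fasta` and `read_fasta` are hypothetical names, and an in-memory buffer stands in for the real file.

```python
import io

def write_fasta(seqs, handle):
    """Write sequences as FASTA records with running integer headers; return the record count."""
    count = 0
    for i, seq in enumerate(seqs):
        handle.write(f">{i}\n{seq}\n")
        count += 1
    return count

def read_fasta(handle):
    """Parse FASTA text into a {header: sequence} dict."""
    records, header = {}, None
    for line in handle:
        line = line.strip()
        if not line:
            continue
        if line.startswith(">"):
            header = line[1:]
            records[header] = ""
        elif header is not None:
            records[header] += line
    return records

buffer = io.StringIO()
written = write_fasta(["DHDNEMDQDHAHVQVH", "DEDNKHYNDNNHTHVD"], buffer)
buffer.seek(0)
records = read_fasta(buffer)
print(written, records)
```

Reading the file back before handing it to an external tool is a cheap way to confirm that the record count matches the number of filtered sequences.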

In [21]:
import subprocess

# Define the WSL command that runs CD-HIT
command = ["wsl", "cd-hit", "-i", "../data/processed/generated_seqs.fasta", "-o", "../data/processed/generated_seqs_cd_hit.txt", "-c", "0.99"]

# Run the command
result = subprocess.run(command, capture_output=True, text=True)

# Print the output and errors
print("STDOUT:", result.stdout)
print("STDERR:", result.stderr)


Program: CD-HIT, V4.8.1 (+OpenMP), Aug 20 2021, 08:39:56
Command: cd-hit -i ../data/processed/generated_seqs.fasta -o
         ../data/processed/generated_seqs_cd_hit.txt -c 0.99

Started: Wed Feb 19 18:56:47 2025
                            Output                              
----------------------------------------------------------------
total seq: 2436
longest and shortest : 16 and 16
Total letters: 38976
Sequences have been sorted

Approximated minimal memory consumption:
Sequence        : 0M
Buffer          : 1 X 10M = 10M
Table           : 1 X 65M = 65M
Miscellaneous   : 0M
Total           : 76M

Table limit with the given memory limit:
Max number of representatives: 4000000
Max number of word counting entries: 90466733


comparing sequences from          0  to       2436
..
     2436  finished       2436  clusters

Approximated maximum memory consumption: 76M
writing new database
writing clustering information
program completed !

Total CPU time 0.17

STDERR: w s l :   A   l o

In [23]:
from preprocess import pfeature_process

In [24]:
cd_hit_path = '../data/processed/generated_seqs_cd_hit.txt'
generated_seqs = pfeature_process(cd_hit_path,'../data/processed/generated_seqs_pfeature.csv')