In [63]:
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"]="0"
import torch
torch.cuda.is_available()

False

In [64]:
from __future__ import print_function
from __future__ import division
import os
GPU = False
if GPU:
    os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
    os.environ["CUDA_VISIBLE_DEVICES"]="0"
import sys
import torch
import tqdm
import pdb
import numpy as np
import platform
import hashlib
import pytorch_transformer
import re
import argparse
import tensorflow as tf
from tensorflow.python import pywrap_tensorflow
import torch.nn.functional as F
#from torch.utils.tensorboard import SummaryWriter
from transformProtein import transformProtein
from ProteinDataset_uid import ProteinDataset
from torch.utils.data import Dataset, DataLoader
import pickle
import time
import matplotlib.pyplot as plt

load_model_path = 'ckpt/' # just the folder itself

seq_length = 511
embedding_dim = 1280
num_layers = 36
vocab_loc = 'mapping_files/vocab.txt'

use_py3 = platform.python_version()[0] == '3'
vocab = open(vocab_loc).readlines() if not use_py3 else open(vocab_loc, encoding='utf-8').read().split('\n')[:-1]
vocab = list(map(lambda x: x.split(' ')[0], vocab))
vocab_size = len(vocab)
print('-----vocab size',vocab_size,'------')

class TiedEmbeddingSoftmax(torch.nn.Module):

    def __init__(self, vocab_size=vocab_size, embedding_size=embedding_dim, **kwargs):
        super(TiedEmbeddingSoftmax, self).__init__()
        self.w = torch.nn.Parameter(torch.normal(0., 1e-2, size=(vocab_size, embedding_size)))
        self.b = torch.nn.Parameter(torch.zeros(vocab_size))

    def forward(self, inputs, embed=True):
        if embed:
            return torch.nn.functional.embedding(inputs, self.w)
        else:
            return torch.tensordot(inputs, self.w.t(), 1) + self.b

class CTRLmodel(torch.nn.Module):
    def __init__(self):
        super(CTRLmodel,self).__init__()
        self.tied_embedding_softmax = TiedEmbeddingSoftmax()
        self.encoder = pytorch_transformer.Encoder()

    def forward(self, inputs):
        x = self.tied_embedding_softmax(inputs, embed = True)
        # print(x.shape)
        x = self.encoder(x)
        x = self.tied_embedding_softmax(x, embed = False)
        return x

    def loadCheckpoint(self, model_path, num_layers):
        if os.path.exists(model_path):
            print('Found PyTorch checkpoint at ', model_path)
            print('Loading instead of converting from TensorFlow')
            if GPU:
                checkpoint = torch.load(model_path, map_location=torch.device('cuda:0'))
            else: 
                checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
            self.tied_embedding_softmax.load_state_dict({
                'w': checkpoint.pop('tied_embedding_softmax.w', None),
                'b': checkpoint.pop('tied_embedding_softmax.b', None)
            })
            self.encoder.load_state_dict({key.replace("encoder.", ""): value for key, value in checkpoint.items()})
            if GPU:
                self.tied_embedding_softmax.to('cuda')
                self.encoder.to('cuda')
        else:
            print('Could not find PyTorch checkpoint')
            sys.exit()

model = CTRLmodel()
print('model initialized')

curr_model_path = load_model_path+'pretrain_progen_full.pth'
reader = model.loadCheckpoint(model_path=curr_model_path, num_layers = num_layers)
if GPU: 
    model = model.cuda()
print('previous checkpoint loaded')

optimizer = torch.optim.Adam(model.parameters()) #lr, betas

model.eval()

with open(os.path.join('mapping_files/','taxa_to_lineage.p'),'rb') as handle:
    taxa_to_lineage = pickle.load(handle)
with open('mapping_files/taxa_to_ctrl_idx.p','rb') as handle:
    taxa_to_ctrl_idx = pickle.load(handle)
with open('mapping_files/kw_to_ctrl_idx.p','rb') as handle:
    kw_to_ctrl_idx = pickle.load(handle)
with open('mapping_files/aa_to_ctrl_idx.p','rb') as handle:
    aa_to_ctrl_idx = pickle.load(handle)
    
with open('mapping_files/kw_to_name.p2','rb') as handle:
    kw_to_name = pickle.load(handle)
#with open('mapping_files/taxid_to_name.p2','rb') as handle:
#    taxid_to_name = pickle.load(handle)
    
def flipdict(my_map):
    return {v: k for k, v in my_map.items()}
ctrl_idx_to_aa = flipdict(aa_to_ctrl_idx)
ctrl_idx_to_kw = flipdict(kw_to_ctrl_idx)
ctrl_idx_to_taxa = flipdict(taxa_to_ctrl_idx)

def predict_fn(inputs):
    with torch.no_grad():
        if GPU:
            inputs = torch.tensor(inputs).cuda()
        else: 
            inputs = torch.tensor(inputs)
        output = model(inputs)
        output = output[:,:,-26:-1] # remove non-AA token logits
        return output

-----vocab size 129407 ------
MODEL SIZE: 
1280
model initialized
Found PyTorch checkpoint at  ckpt/pretrain_progen_full.pth
Loading instead of converting from TensorFlow
previous checkpoint loaded


In [66]:
# paths to the saved .p files
query = "PF16754"
random_selection_file = os.path.join("data", "random_selection_" + query + ".p")
data_file = os.path.join("data", "filtered_data_" + query + ".p")

# Reload random_selection from the .p file
random_selection = False
with open(random_selection_file, "rb") as file:
    random_selection = pickle.load(file)

# Reload filtered_data from the .p file
#filtered_data = False
#with open(data_file, "rb") as file:
#    filtered_data = pickle.load(file)

# Print or use the reloaded datasets as needed
print("Random Selection:")
for entry in random_selection:
    print(entry["sequence"])
    break

#print("---")

#print("Filtered Data:")
#for entry in filtered_data:
#    print(entry["sequence"])'''


Random Selection:
{'metadata': {'accession': 'A0A6I1N9T3', 'id': 'A0A6I1N9T3_9PSED', 'source_organism': {'taxId': '2651296', 'scientificName': 'Pseudomonas sp. MWU12-2323', 'fullName': 'Pseudomonas sp. MWU12-2323'}, 'name': 'Pesticin C-terminal domain-containing protein', 'description': None, 'length': 211, 'sequence': 'MITIMMNFHNYKIITHQKPGQITFNAEGNDIPGSPYYSRHIHWPGNDLSGVTIGRGYDMGTRSQSEIHNHMLAAGIPHAQASRLAEAAGLKGSQAAQFVNNYRTSIGDITHQQEQALFDLIYPFYIDRAIANYNKWTENLAGRQPWESLHPIIRDILVDFVYQGFTAGPNPMNAGMKNNFSELISYIENTPAISQYEPGRQRANYLRKYQQ', 'proteome': None, 'gene': 'GC387_04435', 'go_terms': [{'identifier': 'GO:0003796', 'name': 'lysozyme activity', 'category': {'code': 'F', 'name': 'molecular_function'}}], 'protein_evidence': 4, 'source_database': 'unreviewed', 'is_fragment': False, 'ida_accession': 'bee0eb4d73215280428d1e3207c1319e119aec22', 'counters': {'domain_architectures': 504, 'entries': 5, 'isoforms': 0, 'proteomes': 0, 'sets': 2, 'structures': 0, 'taxa': 1, 'dbEntries': {'cdd': 1, 'pfa

<p>To evaluate the generated sequence using teacher forcing and 3-grams and 5-grams:

1. Generate the Sequence: Use your transformer model with teacher forcing to generate the complete sequence. 
At each time step, provide the true input sequence (+ 1 versus previous step), as input to the model.

2. Extract n-grams: Once you have the generated sequence, extract the 3-grams and 5-grams from it. 
With slide a window of size 3 or 5 along the generated sequence and extract the corresponding n-grams.

3. Compare with Test 3-grams and 5-grams: Retrieve the test 3-grams and 5-grams from your test dataset. 
These are the n-grams that you want to compare the generated sequence against.

4. Calculate Metrics: Compare the generated n-grams with the test n-grams to evaluate their similarity. 
You can use various metrics such as precision, recall, or the SAE (Sum of Absolute Errors) to quantify 
the similarity or dissimilarity.

- For precision and recall, you can calculate how many of the generated n-grams match the test
n-grams and divide it by the total number of generated or test n-grams.

- For the SAE, you can calculate the absolute difference between the frequency of each n-gram 
in the generated sequence and the frequency of the same n-gram in the test sequence. 
Sum up these absolute differences to obtain the final SAE value.
</p>

In [68]:
# Da implementare
def top_p_sampling():
    pass

In [69]:
def teacher_forcing_generation(input_sequence, penalty, topk):
    key_len = 0
    res = ""
    for i in range(1, len(input_sequence)):
        iteration_input_prefix = input_sequence[:i]
        seed_seq = [aa_to_ctrl_idx[ii] for ii in iteration_input_prefix]
        generate_num = len(iteration_input_prefix) + 1 # how many tokens to generate
        padded_text = seed_seq + [0] * (generate_num - len(seed_seq))
        tokens_generated = np.tile(padded_text, (1,1))
        for token in range(len(seed_seq)-1, generate_num-1):
            # print(tokens_generated[:, :seq_length].shape)
            prompt_logits = predict_fn(tokens_generated[:, :seq_length]).squeeze()
            _token = token if token < seq_length else -1
            if GPU:
                prompt_logits = prompt_logits.gpu().detach().numpy()
            else:
                prompt_logits = prompt_logits.cpu().detach().numpy()
                
            if penalty>0:
                penalized_so_far = set()
                # variable token_flag for first amminoacids (to count them if they are less that 4)
                if token >= key_len + 3:
                    token_flag = 3  
                elif token - key_len - 3 <= 0:
                    token_flag = 0
                else:
                    token_flag = token
                #print(key_len)
                #print(token)
                #print(token_flag)
                for _ in range(token-token_flag,token+1):
                    generated_token = tokens_generated[0][_] - (vocab_size-26) # added
                    if generated_token in penalized_so_far:
                        continue
                    penalized_so_far.add(generated_token)
                    prompt_logits[_token][generated_token] /= penalty
            
            # compute probabilities from logits
            prompt_probs = np.exp(prompt_logits[_token])
            prompt_probs = prompt_probs / sum(prompt_probs)
            pruned_list = np.argsort(prompt_probs)[::-1]
            
            if topk==1:
                idx = pruned_list[0]
            else:
                pruned_list = pruned_list[:topk]
                chosen_idx = torch.distributions.categorical.Categorical(logits=torch.tensor(np.expand_dims(prompt_logits[_token][pruned_list],0))).sample().numpy()[0]
                idx = pruned_list[chosen_idx]
            # assign the token for generation
            idx += (vocab_size-26) # added to convert 0 AA to original ctrl idx
            tokens_generated[0][token+1] = idx
        tokens_generated_so_far = tokens_generated[0].squeeze()[:token+2]
        tokens_generated_so_far = tokens_generated_so_far[(tokens_generated_so_far>=(vocab_size-26)) & (tokens_generated_so_far<(vocab_size-1))]
        tokens_generated_so_far = ''.join([ctrl_idx_to_aa[c] for c in tokens_generated_so_far])
        query = tokens_generated_so_far[len(seed_seq):]
        res += query
    return res


In [88]:
import pickle
print("on data: ", query)
penalty = 1.2
topk = 3
predicted = []
true_value = []
for entry in random_selection:
    input_seq = entry['sequence']['metadata']['sequence']
    print("Input: ", input_seq)
    res = teacher_forcing_generation(input_seq, penalty, topk)
    print("Res: ", res)
    true_value.append(input_seq)
    predicted.append(input_seq[0] + res)
    break


# Create a directory named "ID_test_data" in the current working directory if it doesn't exist
data_dir = "ID_test_data"
os.makedirs(data_dir, exist_ok=True)

# Save ID_test_data as a .p file
predicted_data_file = os.path.join(data_dir, "predicted_data_" + query + ".p")
with open(predicted_data_file, "wb") as file:
    pickle.dump(predicted, file)
    
# Save ID_test_data as a .p file
true_data_file = os.path.join(data_dir, "true_data_" + query + ".p")
with open(true_data_file, "wb") as file:
    pickle.dump(true_value, file)

on data:  PF16754
Input:  MITIMMNFHNYKIITHQKPGQITFNAEGNDIPGSPYYSRHIHWPGNDLSGVTIGRGYDMGTRSQSEIHNHMLAAGIPHAQASRLAEAAGLKGSQAAQFVNNYRTSIGDITHQQEQALFDLIYPFYIDRAIANYNKWTENLAGRQPWESLHPIIRDILVDFVYQGFTAGPNPMNAGMKNNFSELISYIENTPAISQYEPGRQRANYLRKYQQ
Res:  AKARSLKSLIIISSFKNNSIIVSSKKIKVYSLLLLTVINLESGLSISTGNTLSSSTLVSTVLILILALLLGSIRITVEGVLELIGQIGLSPAVVVLLLLDAGALGELLSAALAALLRDQGAADGRDGVLSLGRDAQDAPSDKSILSKAADDTDEVANRAPKQNHLASYLSALGLSLSIPGGLIAAVKKGGDGNLLLALQIALILLSEIID


<h3>Evaluations</h3>

In [101]:
# Da implementare compute_perplexity
from nltk import ngrams
from sklearn.metrics import accuracy_score
import math
import blosum as bl
#  BLOSUM62 is the matrix built using sequences with less than 62% similarity 
# (sequences with ≥ 62% identity were clustered)

def compute_hard_accuracy(true_string, predicted_string):
    # percentage of predictions that exactly match the ground truth labels
    total_characters = len(true_string)
    assert total_characters == len(predicted_string)
    # Convert strings to lists of characters
    true_list = list(true_string)
    predicted_list = list(predicted_string)
    # Compute hard accuracy using accuracy_score function
    hard_accuracy = accuracy_score(true_list, predicted_list)
    return hard_accuracy

def compute_all_hard_accuracies(true_sequences, predicted_sequences):
    num_seq = len(true_sequences)
    assert num_seq == len(predicted_sequences)
    accuracy_all = []
    for i in range(num_seq):
        tmp_acc = compute_hard_accuracy(true_sequences[i], predicted_sequences[i])
        accuracy_all.append(tmp_acc)
    result = sum(accuracy_all) / num_seq
    return result

def compute_soft_accuracies(true_sequences, predicted_sequences, threshold = 0):
    # soft accuracy takes into account partial matches or similarities 
    # between the predictions and the ground truth labels
    # we need to define what is a partial match!!
    bl_matrix = bl.BLOSUM(62)
    num_seq = len(true_sequences)
    assert num_seq == len(predicted_sequences)
    accuracy_all = []
    for i in range(num_seq):
        tmp_acc = calculate_soft_accuracy(true_sequences[i], predicted_sequences[i], bl_matrix, threshold)
        accuracy_all.append(tmp_acc)
    result = sum(accuracy_all) / num_seq
    return result

def calculate_soft_accuracy(true_sequence, predicted_sequence, bl_matrix, threshold = 0):
    total_characters = len(true_sequence)
    assert total_characters == len(predicted_sequence)
    partial_matches = 0
    for true_aa, predicted_aa in zip(true_sequence, predicted_sequence):
        similarity_score = bl_matrix[true_aa][predicted_aa]
        # print(similarity_score, true_aa, predicted_aa)
        if similarity_score >= threshold:
            partial_matches += 1
    soft_accuracy = partial_matches / total_characters
    return soft_accuracy


def compute_ngrams(sequence, n):
    return list(ngrams(sequence, n))

def compute_perplexity(true_sequence, predicted_sequence, n=3):
    true_ngrams = compute_ngrams(true_sequence, n)
    predicted_ngrams = compute_ngrams(predicted_sequence, n)
    total_ngrams = len(true_ngrams)
    assert total_ngrams == len(predicted_ngrams)
    
    total_log_probability = 0.0
    print(predicted_ngrams)
    print(true_ngrams)ù
    for ngram in predicted_ngrams:
        if ngram in true_ngrams:
            print("we")
            probability = 1.0 / total_ngrams
            total_log_probability += math.log(probability)
        else:
            total_log_probability += float("-inf")  # Set the log probability to negative infinity
    
    perplexity = math.exp(-total_log_probability / total_ngrams)
    return perplexity


In [102]:
# DEVELOPMENT
from nltk import ngrams

def compute_ngrams(sequence, n):
    return list(ngrams(sequence, n))

# Example usage
true_sequence =      "MSRILLL"
predicted_sequence = "MSLLLLL"
perplexity = compute_perplexity(true_sequence, predicted_sequence, 3)
print("perplexity:", perplexity)


we
we
we
perplexity: inf


In [75]:
# DEVELOPMENT
true_sequences =      ['LS', 'LL', 'ASDFGHJ']
predicted_sequences = ['MS', 'HO', 'ASDFGHJ']
compute_soft_accuracies(true_sequences, predicted_sequences)

0.6666666666666666

In [76]:
# DEVELOPMENT
true_sequences =      ['LS', 'LLPL', 'ASDFGHJ'] # List of true sequences MITIMMNFHN
predicted_sequences = ['MS', 'LHKO', 'ASDFGHJ']  # List of predicted sequences

# Compute hard accuracy
accuracy = compute_all_hard_accuracies(true_sequences, predicted_sequences)

# Compute perplexity
perplexity = compute_perplexity(true_sequences, predicted_sequences, n=3)

# Compute soft accuracy
soft_accuracy = compute_soft_accuracies(true_sequences, predicted_sequences)

print(perplexity, accuracy, soft_accuracy)

None 0.5833333333333334 0.75


In [92]:
# paths to the saved .p files
query = "PF16754"
predicted_data_file = os.path.join("ID_test_data", "predicted_data_" + query + ".p")
true_data_file = os.path.join("ID_test_data", "true_data_" + query + ".p")

# Reload predicted_data from the .p file
predicted_data = False
with open(predicted_data_file, "rb") as file:
    predicted_data = pickle.load(file)

# Reload true_data from the .p file
true_data = False
with open(true_data_file, "rb") as file:
    true_data = pickle.load(file)

# Compute hard accuracy
accuracy = compute_all_hard_accuracies(true_data, predicted_data)

# Compute perplexity
perplexity = compute_perplexity(true_data, predicted_data, n=3)

# Compute soft accuracy
soft_accuracy = compute_soft_accuracies(true_data, predicted_data)

print(perplexity, accuracy, soft_accuracy)

None 0.08530805687203792 0.36018957345971564


In [90]:
print(true_data)
print(predicted_data)

['MITIMMNFHNYKIITHQKPGQITFNAEGNDIPGSPYYSRHIHWPGNDLSGVTIGRGYDMGTRSQSEIHNHMLAAGIPHAQASRLAEAAGLKGSQAAQFVNNYRTSIGDITHQQEQALFDLIYPFYIDRAIANYNKWTENLAGRQPWESLHPIIRDILVDFVYQGFTAGPNPMNAGMKNNFSELISYIENTPAISQYEPGRQRANYLRKYQQ']
['MAKARSLKSLIIISSFKNNSIIVSSKKIKVYSLLLLTVINLESGLSISTGNTLSSSTLVSTVLILILALLLGSIRITVEGVLELIGQIGLSPAVVVLLLLDAGALGELLSAALAALLRDQGAADGRDGVLSLGRDAQDAPSDKSILSKAADDTDEVANRAPKQNHLASYLSALGLSLSIPGGLIAAVKKGGDGNLLLALQIALILLSEIID']


<h3>Old Progen Generation Code:</h3>

In [34]:
print(kw_to_name[9])
taxid = 9606 # homo sapiens taxonomy id from NCBI: https://www.ncbi.nlm.nih.gov/taxonomy
tax_lineage = taxa_to_lineage[taxid] # make lineage in ncbi ids
print(tax_lineage)
tax_lineage = [taxa_to_ctrl_idx[ite] for ite in tax_lineage] # now translated as ctrl code indices
print(tax_lineage)

kw_lineage = [677,9] # UniprotKB keywords from https://www.uniprot.org/docs/keywlist
print(kw_lineage)
kw_lineage = [kw_to_ctrl_idx[ite] for ite in kw_lineage] # now translated to ctrl code indices
print(kw_lineage)


example_seq = 'YMIQEEEWDRDLLLDPAWEKQQRKTFTAWCNSHLRKAGTQIENIEEDFRNGLKLMLLLEVISGERLPKPDRGKMRFHKIANVNKALDYIASKGVKLVSIGAEEIVDGNVKMTLGMIWTIILRFAIQDISVEETSAKEGLLLWCQRKTAPYRNVNIQNFHTSWKDGLGLCALIHRHRPDLIDYSKLNKDDPIGNINLAMEIAEKHLDIPKMLDAEDIVNTPKPDERAIMTYVSCFYHAFAGAEQAETAANRICKVLAVNQENERLMEEYERLASELLEWIRRTIPWLENRTPAATMQAMQKKLEDFRDYRRKHKPPKVQEKCQLEINFNTLQTKLRISNRPAFMPSEGKMVSDIAGAWQRLEQAEKGYEEWLLNEIRRLERLEHLAEKFRQKASTHETWAYGKEQILLQKDYESASLTEVRALLRKHEAFESDLAAHQDRVEQIAAIAQELNELDYHDAVNVNDRCQKICDQWDRLGTLTQKRREALERMEKLLETIDQLHLEFAKRAAPFNNWMEGAMEDLQDMFIVHSIEEIQSLITAHEQFKATLPEADGERQSIMAIQNEVEKVIQSYNIRISSSNPYSTVTMDELRTKWDKVKQLVPIRDQSLQEELARQHANERLRRQFAAQANAIGPWIQNKMEEIARSSIQITGALEDQMNQLKQYEHNIINYKNNIDKLEGDHQLIQEALVFDNKHTNYTMEHIRVGWELLLTTIARTINEVETQILTRDAKGITQEQMNEFRASFNHFDRRKNGLMDHEDFRACLISMGYDLGEAEFARIMTLVDPNGQGTVTFQSFIDFMTRETADTDTAEQVIASFRILASDKPYILAEELRRELPPDQAQYCIKRMPAYSGPGSVPGALDYAAFSSALYGESDL'
prefix = example_seq[:3]
# prefix = ""
print("Prefix: ", prefix)
ref = example_seq[0:20]

print("Ref: ", ref)
penalty = 1.2
topk = 3

seed_seq = [aa_to_ctrl_idx[ii] for ii in prefix]
print("seed sequence: ", seed_seq)
# generate_num = len(kw_lineage+tax_lineage)+len(prefix+ref)
key_len = len(kw_lineage+tax_lineage)
generate_num = key_len +len(prefix+ref)
seq_length = min(generate_num, 511)

text = tax_lineage + kw_lineage + seed_seq
padded_text = text + [0] * (generate_num - len(text))
print(padded_text)
tokens_generated = np.tile(padded_text, (1,1))

#i = 1
for token in range(len(text)-1, generate_num-1):
    #print("counter: ", i)
    #i += 1
    print(tokens_generated[:, :seq_length].shape)
    prompt_logits = predict_fn(tokens_generated[:, :seq_length]).squeeze()
    _token = token if token < seq_length else -1
    if GPU:
        prompt_logits = prompt_logits.cpu().detach().numpy()
    else:
        prompt_logits = prompt_logits.gpu().detach().numpy()

    if penalty>0:
        penalized_so_far = set()
        # variable token_flag for first amminoacids (to count them if they are less that 4)
        if token >= key_len + 3:
            token_flag = 3  
        elif token - key_len - 3 <= 0:
            token_flag = 0
        else:
            token_flag = token
        #print(key_len)
        #print(token)
        #print(token_flag)
        for _ in range(token-token_flag,token+1):
            generated_token = tokens_generated[0][_] - (vocab_size-26) # added
            if generated_token in penalized_so_far:
                continue
            penalized_so_far.add(generated_token)
            prompt_logits[_token][generated_token] /= penalty

    # compute probabilities from logits
    prompt_probs = np.exp(prompt_logits[_token])
    prompt_probs = prompt_probs / sum(prompt_probs)
    pruned_list = np.argsort(prompt_probs)[::-1]

    if topk==1:
        idx = pruned_list[0]
    else:
        pruned_list = pruned_list[:topk]
        chosen_idx = torch.distributions.categorical.Categorical(logits=torch.tensor(np.expand_dims(prompt_logits[_token][pruned_list],0))).sample().numpy()[0]
        idx = pruned_list[chosen_idx]

    # assign the token for generation
    idx += (vocab_size-26) # added to convert 0 AA to original ctrl idx
    tokens_generated[0][token+1] = idx


tokens_generated_so_far = tokens_generated[0].squeeze()[:token+2]
tokens_generated_so_far = tokens_generated_so_far[(tokens_generated_so_far>=(vocab_size-26)) & (tokens_generated_so_far<(vocab_size-1))]
tokens_generated_so_far = ''.join([ctrl_idx_to_aa[c] for c in tokens_generated_so_far])

query = tokens_generated_so_far[len(seed_seq):]

print(prefix)
print(query)

Actin-binding
[33208, 7711, 40674, 9443, 9604, 9605, 9606]
[11177, 5756, 14034, 6957, 7068, 7069, 7070]
[677, 9]
[46, 258]
Prefix:  YMI
Ref:  YMIQEEEWDRDLLLDPAWEK
seed sequence:  [129404, 129392, 129389]
[11177, 5756, 14034, 6957, 7068, 7069, 7070, 46, 258, 129404, 129392, 129389, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
(1, 32)


RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got CPUBoolType instead (while checking arguments for embedding)