In [21]:
import os
import torch
import numpy as np
import pickle
import time
#os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
#os.environ["CUDA_VISIBLE_DEVICES"]="0"
from model_manager import CTRLmodel
from tokenizer import Tokenizer
from generation_manager import GeneratorManager

In [5]:
GPU = torch.cuda.is_available()
# if you don't want to use GPU (but you have it aviable) modify in pytorch_transformer.py module the related variable
print(GPU)

True


In [6]:
tokenizer = Tokenizer()
print('tokenizer loaded')

curr_model_path = 'ckpt/training_ckpt_4/model_only_state_dict_v0Last_lr0001.pth' # model checkpoint
print('loading model from:', curr_model_path)

model = CTRLmodel()
reader = model.loadCheckpoint(model_path=curr_model_path)

# Saving only state_dict:
#torch.save(model.state_dict(), model_state_dict_path)
#model_state_dict_path = load_model_path + 'model_only_state_dict_v0Last.pth'
#print('Model state dict saved')
            
if GPU:
    model = model.cuda()
    print('previous checkpoint loaded in GPU')
else: 
    print('previous checkpoint loaded in CPU')

optimizer = torch.optim.Adam(model.parameters()) #lr, betas
model.eval()
print('model ready for inference')

tokenizer loaded
loading model from: ckpt/training_ckpt_4/model_only_state_dict_v0Last_lr0001.pth
MODEL SIZE: 
1280
Found PyTorch checkpoint at  ckpt/training_ckpt_4/model_only_state_dict_v0Last_lr0001.pth
Loading instead of converting from TensorFlow
previous checkpoint loaded in GPU
model ready for inference


In [7]:
def save_sequences(predicted, predicted_stopped, molechule_name_and_params):
    data_dir = "Generation_PF00959_model_lr_00001_phage_specific"
    query = molechule_name_and_params
    os.makedirs(data_dir, exist_ok=True)
    
    predicted_data_file = os.path.join(data_dir, "predicted_data_" + query + ".p")
    with open(predicted_data_file, "wb") as file:
        pickle.dump(predicted, file)

    predicted_stopped_data_file = os.path.join(data_dir, "predicted_stopped_data_" + query + ".p")
    with open(predicted_stopped_data_file, "wb") as file:
        pickle.dump(predicted_stopped, file)
        

In [11]:
# GENERATION parameters
temperature = 0.9
penalty = 0
top_p = 0.5
np.random.seed(1337)
torch.manual_seed(1337)

def predict_fn(inputs):
    with torch.no_grad():
        inputs = torch.tensor(inputs)
        if GPU:
            inputs = inputs.cuda()            
        output = model(inputs)
        stop_token = output[:, :, 1] # the stop token logits
        output = output[:,:,-26:-1] # remove non-AA token logits
        return output, stop_token
        
generator = GeneratorManager(predict_fn, penalty=penalty, top_p=top_p, temperature=temperature)

In [12]:
def generate_and_save(top_p, seq_number_to_generate, protein_base, protein, protein_name):
    predicted = []
    predicted_stopped = []
    seq_number = 0
    input_seq = protein
    tax_lineage = [0]
    offset = int(protein_base * len(input_seq))
    # Adjust for zero-based indexing
    offset = offset - 1 if offset > 0 else 0
    #print(offset)
    print('generating sequences...')
    while seq_number < seq_number_to_generate:
        start_time = time.time()
        res, tokens_generated_stopped = generator.generation_complete_sequence(input_seq, offset, tax_lineage)
        end_time = time.time()
        elapsed_time = end_time - start_time
        print(f"Done. Time taken: {elapsed_time} seconds.")
        predicted.append(input_seq[:offset] + res)
        if tokens_generated_stopped:
            for tmp_seq in tokens_generated_stopped:
                predicted_stopped.append(input_seq[:offset] + tmp_seq)
        seq_number += 1
        print('hello, this is seq_number', seq_number)

        # FOR TESTING
        break
        
        if ((seq_number%10) == 0) or (seq_number == (seq_number_to_generate-1)):
            # print('the model has generated: ', seq_number , ' sequences')
            print('saving sequences')
            description = protein_name + '_top_p_' + str(top_p) + 'seed_percentage_' + str(protein_base)
            save_sequences(predicted, predicted_stopped, description)
    print('GENERATION BATCH ENDED')
    

In [13]:
def read_fasta(file_path):
    """
    Reads a FASTA file with a single protein sequence.
    Args:
    file_path (str): The path to the FASTA file.
    Returns:
    tuple: A tuple containing the identifier and the protein sequence as strings.
    """
    with open(file_path, 'r') as file:
        lines = file.readlines()
        identifier = lines[0].strip()[4:21]  # Remove '>' and any trailing newline character
        sequence = ''.join(line.strip() for line in lines[1:])  # Concatenate the remaining lines
        return identifier, sequence
    

In [18]:
sequences_and_identifiers = []
fasta_dir = 'results_generation/FASTA_lysozymes_of_interest/'
for file in os.listdir(fasta_dir):
    sequences_and_identifiers.append((read_fasta(fasta_dir + file)))

In [19]:
for i, seq in sequences_and_identifiers:
    print(i)
    print(len(seq))
    print(seq)

Q37875|ENLYS_BPP1
185
MKGKTAAGGGAICAIAVMITIVMGNGNVRTNQAGLELIGNAEGCRRDPYMCPAGVWTDGIGNTHGVTPGVRKTDQQIAADWEKNILIAERCINQHFRGKDMPDNAFSAMTSAAFNMGCNSLRTYYSKARGMRVETSIHKWAQKGEWVNMCNHLPDFVNSNGVPLRGLKIRREKERQLCLTGLVNE
P03706|ENLYS_LAMB
158
MVEINNQRKAFLDMLAWSEGTDNGRQKTRNHGYDVIVGGELFTDYSDHPRKLVTLNPKLKSTGAGRYQLLSRWWDAYRKQLGLKDFSPKSQDAVALQQIKERGALPMIDRGDIRQAIDRCSNIWASLPGAGYGQFEHKADSLIAKFKEAGGTVREIDV
P78285|LYSD_ECOLI
165
MPPSLRKAVAAAIGGGAIAIASVLITGPSGNDGLEGVSYIPYKDIVGVWTVCHGHTGKDIMLGKTYTKAECKALLNKDLATVARQINPYIKVDIPETTRGALYSFVYNVGAGNFRTSTLLRKINQGDIKGACDQLRRWTYAGGKQWKGLMTRREIEREVCLWGQQ
P00720|ENLYS_BPT4
164
MNIFEMLRIDERLRLKIYKDTEGYYTIGIGHLLTKSPSLNAAKSELDKAIGRNCNGVITKDEAEKLFNQDVDAAVRGILRNAKLKPVYDSLDAVRRCALINMVFQMGETGVAGFTNSLRMLQQKRWDEAAVNLAKSIWYNQTPNRAKRVITTFRTGTWDAYKNL


In [22]:
top_ps = [0.25, 0.5, 0.75]
after_percentage = [0.25, 0.5, 0.75]
number_of_sequences_to_generate = 150 # 1 è un test: poi 200
for protein_name, protein in sequences_and_identifiers:
    for after_p in after_percentage:
        for p in top_ps:
            generate_and_save(p, number_of_sequences_to_generate, after_p, protein, protein_name)
            print(p, number_of_sequences_to_generate, after_p, protein, protein_name)
            print('------')

generating sequences...
Done. Time taken: 41.221362590789795 seconds.
hello, this is seq_number 1
GENERATION BATCH ENDED
0.25 150 0.25 MKGKTAAGGGAICAIAVMITIVMGNGNVRTNQAGLELIGNAEGCRRDPYMCPAGVWTDGIGNTHGVTPGVRKTDQQIAADWEKNILIAERCINQHFRGKDMPDNAFSAMTSAAFNMGCNSLRTYYSKARGMRVETSIHKWAQKGEWVNMCNHLPDFVNSNGVPLRGLKIRREKERQLCLTGLVNE Q37875|ENLYS_BPP1
------
generating sequences...
Done. Time taken: 41.494739294052124 seconds.
hello, this is seq_number 1
GENERATION BATCH ENDED
0.5 150 0.25 MKGKTAAGGGAICAIAVMITIVMGNGNVRTNQAGLELIGNAEGCRRDPYMCPAGVWTDGIGNTHGVTPGVRKTDQQIAADWEKNILIAERCINQHFRGKDMPDNAFSAMTSAAFNMGCNSLRTYYSKARGMRVETSIHKWAQKGEWVNMCNHLPDFVNSNGVPLRGLKIRREKERQLCLTGLVNE Q37875|ENLYS_BPP1
------
generating sequences...
Done. Time taken: 41.818920612335205 seconds.
hello, this is seq_number 1
GENERATION BATCH ENDED
0.75 150 0.25 MKGKTAAGGGAICAIAVMITIVMGNGNVRTNQAGLELIGNAEGCRRDPYMCPAGVWTDGIGNTHGVTPGVRKTDQQIAADWEKNILIAERCINQHFRGKDMPDNAFSAMTSAAFNMGCNSLRTYYSKARGMRVETSIHKWAQKGEWVNMCNHLPDFVNSNGVPLRGLKIRREKE

KeyboardInterrupt: 