In [1]:
from __future__ import print_function
from __future__ import division
import os
import torch
import sys
import tqdm
import pdb
import numpy as np
import platform
import hashlib
import pytorch_transformer
import re
import argparse
import torch.nn.functional as F
from transformProtein import transformProtein
from ProteinDataset_uid import ProteinDataset
from torch.utils.data import Dataset, DataLoader
import pickle
import time
import matplotlib.pyplot as plt
GPU = torch.cuda.is_available()
#if GPU:
#    torch.cuda.empty_cache()
# if you don't want to use GPU (but you have it) modify in pytorch_transformer.py module the related variable
print(GPU)

True


In [2]:
load_model_path = 'ckpt/training_ckpt_4/' # just the folder itself  
curr_model_path = load_model_path+'model_only_state_dict_v0Last_lr0001.pth'
seq_length = 511
embedding_dim = 1280
num_layers = 36
# GENERATION parameters
temperature = 0.9
penalty = 0
top_p = 0.5
np.random.seed(1337)
torch.manual_seed(1337)


vocab_loc = 'mapping_files/vocab.txt'
use_py3 = platform.python_version()[0] == '3'
vocab = open(vocab_loc).readlines() if not use_py3 else open(vocab_loc, encoding='utf-8').read().split('\n')[:-1]
vocab = list(map(lambda x: x.split(' ')[0], vocab))
vocab_size = len(vocab)
print('-----vocab size',vocab_size,'------')

-----vocab size 129407 ------


In [3]:
class TiedEmbeddingSoftmax(torch.nn.Module):
    def __init__(self, vocab_size=vocab_size, embedding_size=embedding_dim, **kwargs):
        super(TiedEmbeddingSoftmax, self).__init__()
        self.w = torch.nn.Parameter(torch.normal(0., 1e-2, size=(vocab_size, embedding_size)))
        self.b = torch.nn.Parameter(torch.zeros(vocab_size))

    def forward(self, inputs, embed=True):
        if embed:
            return torch.nn.functional.embedding(inputs, self.w)
        else:
            return torch.tensordot(inputs, self.w.t(), 1) + self.b

class CTRLmodel(torch.nn.Module):
    def __init__(self):
        super(CTRLmodel,self).__init__()
        self.tied_embedding_softmax = TiedEmbeddingSoftmax()
        self.encoder = pytorch_transformer.Encoder()

    def forward(self, inputs):
        x = self.tied_embedding_softmax(inputs, embed=True)
        x = self.encoder(x)
        x = self.tied_embedding_softmax(x, embed=False)
        return x

    def loadCheckpoint(self, model_path, num_layers):
        if os.path.exists(model_path):
            print('Found PyTorch checkpoint at ', model_path)
            print('Loading checkpoint...')
            checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
            if 'epoch' in checkpoint.keys():
                print(checkpoint.keys())
                checkpoint = checkpoint['model_state_dict']
            self.tied_embedding_softmax.load_state_dict({
                'w': checkpoint.pop('tied_embedding_softmax.w', None),
                'b': checkpoint.pop('tied_embedding_softmax.b', None)
            })
            self.encoder.load_state_dict({key.replace("encoder.", ""): value for key, value in checkpoint.items()})
        else:
            print('Could not find PyTorch checkpoint')
            sys.exit()

model = CTRLmodel()
print('model initialized')
print('loading model from: ', curr_model_path)
reader = model.loadCheckpoint(model_path=curr_model_path, num_layers = num_layers)
print('previous checkpoint loaded')
if GPU:
    model = model.cuda()
    print('previous checkpoint loaded in GPU')
optimizer = torch.optim.Adam(model.parameters()) #lr, betas
model.eval()

with open(os.path.join('mapping_files/','taxa_to_lineage.p'),'rb') as handle:
    taxa_to_lineage = pickle.load(handle)
with open('mapping_files/taxa_to_ctrl_idx.p','rb') as handle:
    taxa_to_ctrl_idx = pickle.load(handle)
with open('mapping_files/kw_to_ctrl_idx.p','rb') as handle:
    kw_to_ctrl_idx = pickle.load(handle)
with open('mapping_files/aa_to_ctrl_idx.p','rb') as handle:
    aa_to_ctrl_idx = pickle.load(handle)
with open('mapping_files/kw_to_name.p2','rb') as handle:
    kw_to_name = pickle.load(handle)
    
def flipdict(my_map):
    return {v: k for k, v in my_map.items()}
ctrl_idx_to_aa = flipdict(aa_to_ctrl_idx)
ctrl_idx_to_kw = flipdict(kw_to_ctrl_idx)
ctrl_idx_to_taxa = flipdict(taxa_to_ctrl_idx)

def predict_fn(inputs):
    with torch.no_grad():
        inputs = torch.tensor(inputs)
        if GPU:
            inputs = inputs.cuda()            
        output = model(inputs)
        stop_token = output[:, :, 1] # the stop token logits
        output = output[:,:,-26:-1] # remove non-AA token logits
        return output, stop_token
        

MODEL SIZE: 
1280
model initialized
loading model from:  ckpt/training_ckpt_4/model_only_state_dict_v0Last_lr0001.pth
Found PyTorch checkpoint at  ckpt/training_ckpt_4/model_only_state_dict_v0Last_lr0001.pth
Loading checkpoint...
previous checkpoint loaded
previous checkpoint loaded in GPU


In [4]:
def save_sequences(predicted, predicted_stopped):
    data_dir = "Generation_PF00959_model_lr_00001_LAST"
    query = 'After_0.50_percent_penalty_0_key_only_lyso'
    os.makedirs(data_dir, exist_ok=True)
    
    predicted_data_file = os.path.join(data_dir, "predicted_data_" + query + ".p")
    with open(predicted_data_file, "wb") as file:
        pickle.dump(predicted, file)

    predicted_stopped_data_file = os.path.join(data_dir, "predicted_stopped_data_" + query + ".p")
    with open(predicted_stopped_data_file, "wb") as file:
        pickle.dump(predicted_stopped, file)
        

In [5]:
def generation_complete_sequence(input_sequence, after_n, tax_lineage):
    res = ""
    res_stopped = []
    # tokens_prob = []
    key_len = len(tax_lineage) # len(kw_lineage+tax_lineage)
    i = after_n # if we have a sequence of amminoacids in input, we can add some in input as seed sequence
    iteration_input_prefix = input_sequence[:i]
    seed_seq = [aa_to_ctrl_idx[ii] for ii in iteration_input_prefix]
    generate_num = key_len + (len(input_sequence)) # how many tokens to generate
    if generate_num < 511:
        generate_num = 511
    seq_length = min(generate_num, 511)
    text = tax_lineage + seed_seq # tax_lineage + kw_lineage + seed_seq
    padded_text = text + [0] * (generate_num - len(text))
    tokens_generated = np.tile(padded_text, (1,1))
    for token in range(len(text)-1, generate_num-1):
        
        # prediction
        prompt_logits, stop_token = predict_fn(tokens_generated[:, :seq_length])
        prompt_logits = prompt_logits.squeeze()  / (temperature if temperature>0 else 1.)
        stop_token = stop_token.squeeze()  / (temperature if temperature>0 else 1.)
        
        _token = token if token < seq_length else -1
        prompt_logits = prompt_logits.cpu().detach().numpy()
        stop_token = stop_token.cpu().detach().numpy()
        
        # penalty
        if (penalty>0) and (token >= key_len + 3):
            penalized_so_far = set()
            for _ in range(token-3,token+1):
                generated_token = tokens_generated[0][_] - (vocab_size-26) # added
                if generated_token in penalized_so_far:
                    continue
                penalized_so_far.add(generated_token)
                prompt_logits[_token][generated_token] /= penalty  

        # compute probabilities from logits
        prompt_probs = np.exp(prompt_logits[_token])
        prompt_probs = prompt_probs / sum(prompt_probs)

        # ESTRARRE TOKEN 1: the stop token, softmax con le probabilità degli amminoacidi
        logits_and_stop = np.concatenate((prompt_logits[_token], [stop_token[_token]]))
        logits_and_stop_prob = np.exp(logits_and_stop)
        logits_and_stop_prob = logits_and_stop_prob / sum(logits_and_stop_prob)

        if logits_and_stop_prob[-1] >= 0.50:
            tokens_generated_stopped = tokens_generated[0][len(seed_seq) + key_len:_token + 1]
            tokens_generated_stopped = ''.join([ctrl_idx_to_aa[c] for c in tokens_generated_stopped])
            res_stopped.append(tokens_generated_stopped)
        
        pruned_list = np.argsort(prompt_probs)[::-1]
        # tokens_prob.append([prompt_probs.tolist()])

        if top_p==1:
            idx = pruned_list[0]
        else:
            # Sort the probabilities
            sorted_probs, sorted_indices = torch.sort(torch.tensor(prompt_probs), descending=True)
            # Calculate cumulative probs
            cum_probs = torch.cumsum(sorted_probs, dim=0)
            # Get the set of tokens whose cumulative probability is less than or equal to p
            valid_indices = sorted_indices[cum_probs <= top_p]
            # If no token's cumulative probability is less than the threshold, just select the top token
            if valid_indices.size(0) == 0:
                valid_indices = sorted_indices[:1]
            # Sample from the valid indices
            idx = valid_indices[torch.randint(0, valid_indices.size(0), (1,))].item()
        # assign the token for generation
        idx += (vocab_size-26) # added to convert 0 AA to original ctrl idx
        tokens_generated[0][token+1] = idx
        
    tokens_generated = tokens_generated[0][len(seed_seq) + key_len:]
    tokens_generated = ''.join([ctrl_idx_to_aa[c] for c in tokens_generated])
    return tokens_generated, res_stopped


In [6]:
predicted = []
predicted_stopped = []
# true_tokens_index_in_probs_all = []
# tokens_probs_all = []
# true_value = []
seq_number = 0
#for entry in random_selection:
while seq_number <= 1500:
    # input_seq = entry['sequence']['metadata']['sequence']
    # FOR TESTIG the code:
    # input_seq = input_seq[0:24]
    input_seq = ''
    
    # if no tax keys:
    tax_lineage = [0]

    offset = 0
    print('generating sequences...')
    start_time = time.time()
    # res, tokens_prob = generation_complete_sequence(input_seq, offset, tax_lineage)
    res, tokens_generated_stopped = generation_complete_sequence(input_seq, offset, tax_lineage)
    end_time = time.time()
    elapsed_time = end_time - start_time
    print(f"Done. Time taken: {elapsed_time} seconds.")

    # print("Input: ", input_seq)
    # print("Res: ", res)
    # true_value.append(input_seq)
    # true_tokens_index_in_prob = []
    # true_tokens_index_in_probs_all.append(true_tokens_index_in_prob)
    # tokens_probs_all.append(tokens_prob)
    # print('input_offset should be nothing: ', input_seq[:offset])
    predicted.append(input_seq[:offset] + res)
    if tokens_generated_stopped:
        for tmp_seq in tokens_generated_stopped:
            predicted_stopped.append(input_seq[:offset] + tmp_seq)

    # FOR TESTIG:
    #break
    seq_number += 1
    print('hello, this is seq_number', seq_number)
    print('hello, this is seq_number module: ', seq_number%10)
    if (seq_number%10) == 0:
        print('the model has generated: ', seq_number , ' sequences')
        print('saving sequences')
        save_sequences(predicted, predicted_stopped)
        
print('GENERATION ENDED')
    

generating sequences...
Done. Time taken: 61.839526653289795 seconds.
hello, this is seq_number 1
hello, this is seq_number module:  1
generating sequences...
Done. Time taken: 60.806084871292114 seconds.
hello, this is seq_number 2
hello, this is seq_number module:  2
generating sequences...
Done. Time taken: 60.35079526901245 seconds.
hello, this is seq_number 3
hello, this is seq_number module:  3
generating sequences...
Done. Time taken: 60.46475267410278 seconds.
hello, this is seq_number 4
hello, this is seq_number module:  4
generating sequences...


KeyboardInterrupt: 