In [1]:
from __future__ import print_function
from __future__ import division
import os
#os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
#os.environ["CUDA_VISIBLE_DEVICES"]="0"
import sys
import torch
import tqdm
import pdb
import numpy as np
import platform
import hashlib
import pytorch_transformer
import re
import argparse
import tensorflow as tf
from tensorflow.python import pywrap_tensorflow
import torch.nn.functional as F
#from torch.utils.tensorboard import SummaryWriter
from transformProtein import transformProtein
from ProteinDataset_uid import ProteinDataset
from torch.utils.data import Dataset, DataLoader
import pickle
import time
import matplotlib.pyplot as plt

load_model_path = 'ckpt/'  # checkpoint folder only; the file name is appended where the model is loaded

# Model / data configuration used by the definitions below.
seq_length = 511
embedding_dim = 1280
num_layers = 36
vocab_loc = 'mapping_files/vocab.txt'

use_py3 = platform.python_version()[0] == '3'
# Read the vocabulary inside `with` blocks so the file handle is always closed.
if use_py3:
    with open(vocab_loc, encoding='utf-8') as vocab_file:
        # [:-1] drops the element that follows the final newline.
        vocab = vocab_file.read().split('\n')[:-1]
else:
    # Python 2's built-in open() has no `encoding` argument.
    with open(vocab_loc) as vocab_file:
        vocab = vocab_file.readlines()
# Keep only the first space-separated field of every line.
vocab = [line.split(' ')[0] for line in vocab]
vocab_size = len(vocab)
print('-----vocab size',vocab_size,'------')

class TiedEmbeddingSoftmax(torch.nn.Module):
  """Embedding table whose weight matrix is reused for the output projection.

  Parameters held:
    w: (vocab_size, embedding_size) weights, drawn from N(0, 1e-2).
    b: (vocab_size,) bias, initialised to zero.
  """

  def __init__(self, vocab_size=vocab_size, embedding_size=embedding_dim, **kwargs):
    # Extra keyword arguments are accepted and ignored.
    super(TiedEmbeddingSoftmax, self).__init__()
    initial_weights = torch.normal(0., 1e-2, size=(vocab_size, embedding_size))
    self.w = torch.nn.Parameter(initial_weights)
    self.b = torch.nn.Parameter(torch.zeros(vocab_size))

  def forward(self, inputs, embed=True):
    """Embed token ids (``embed=True``) or project hidden states onto the vocabulary.

    With ``embed=True`` the rows of ``w`` selected by ``inputs`` are returned.
    Otherwise the last axis of ``inputs`` is contracted with ``w`` transposed
    and the bias ``b`` is added.
    """
    if not embed:
      projection = self.w.t()
      return torch.tensordot(inputs, projection, 1) + self.b
    return torch.nn.functional.embedding(inputs, self.w)

class CTRLmodel(torch.nn.Module):
  """Tied embedding/softmax layer wrapped around ``pytorch_transformer.Encoder``."""

  def __init__(self):
    super(CTRLmodel,self).__init__()
    self.tied_embedding_softmax = TiedEmbeddingSoftmax()
    self.encoder = pytorch_transformer.Encoder()

  def forward(self, inputs):
    """Embed ``inputs``, run the encoder, then project with the tied weights.

    Returns the result of ``tied_embedding_softmax(..., embed=False)`` applied
    to the encoder output.
    """
    x = self.tied_embedding_softmax(inputs, embed = True)
    x = self.encoder(x)
    x = self.tied_embedding_softmax(x, embed = False)
    return x

  def loadCheckpoint(self, model_path, num_layers):
    """Load weights from the PyTorch checkpoint file at ``model_path``.

    The checkpoint is read onto the CPU, its ``tied_embedding_softmax.w`` /
    ``tied_embedding_softmax.b`` entries are loaded into the tied layer, every
    remaining entry is loaded into the encoder (with ``"encoder."`` removed
    from the key), and both sub-modules are then moved to CUDA.

    Args:
      model_path: path of the ``.pth`` checkpoint file.
      num_layers: accepted for backward compatibility; not used.

    Returns:
      None.

    Raises:
      SystemExit: with a message (non-zero exit status) when ``model_path``
        does not exist.
    """
    if not os.path.exists(model_path):
      print('Could not find PyTorch checkpoint')
      # A string argument gives a non-zero exit status, so a missing
      # checkpoint is not reported to the caller/shell as success.
      sys.exit('Could not find PyTorch checkpoint at ' + str(model_path))

    print('Found PyTorch checkpoint at ', model_path)
    print('Loading instead of converting from TensorFlow')
    # NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
    checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
    # Pop the tied-layer tensors first so that only encoder entries remain.
    self.tied_embedding_softmax.load_state_dict({
      'w': checkpoint.pop('tied_embedding_softmax.w', None),
      'b': checkpoint.pop('tied_embedding_softmax.b', None)
    })
    self.encoder.load_state_dict({key.replace("encoder.", ""): value for key, value in checkpoint.items()})
    self.tied_embedding_softmax.to('cuda')
    self.encoder.to('cuda')

# Build the network with freshly initialised weights.
model = CTRLmodel()
print('model initialized')

# Full checkpoint path = checkpoint folder + file name.
curr_model_path = load_model_path+'pretrain_progen_full.pth'
# loadCheckpoint has no return statement, so `reader` is always None.
reader = model.loadCheckpoint(model_path=curr_model_path, num_layers = num_layers)
model = model.cuda()
print('previous checkpoint loaded')

# NOTE(review): the optimizer is not used by any cell visible in this notebook.
optimizer = torch.optim.Adam(model.parameters()) # Adam defaults for lr and betas

# Switch the module to evaluation mode.
model.eval()

def _load_pickle(path):
    """Unpickle and return the object stored in the file at ``path``."""
    with open(path, 'rb') as handle:
        return pickle.load(handle)

# NOTE: pickle.load can execute arbitrary code; only load mapping files from a trusted source.
taxa_to_lineage = _load_pickle(os.path.join('mapping_files/','taxa_to_lineage.p'))
taxa_to_ctrl_idx = _load_pickle('mapping_files/taxa_to_ctrl_idx.p')
kw_to_ctrl_idx = _load_pickle('mapping_files/kw_to_ctrl_idx.p')
aa_to_ctrl_idx = _load_pickle('mapping_files/aa_to_ctrl_idx.p')

kw_to_name = _load_pickle('mapping_files/kw_to_name.p2')
#with open('mapping_files/taxid_to_name.p2','rb') as handle:
#    taxid_to_name = pickle.load(handle)
    
def flipdict(my_map):
    """Return a new dict that maps every value of ``my_map`` back to its key.

    When several keys share the same value, the key visited last wins.
    """
    inverted = {}
    for key, value in my_map.items():
        inverted[value] = key
    return inverted
# Inverse lookups (index -> amino acid / keyword / taxon), built by swapping
# the keys and values of the mappings loaded above.
ctrl_idx_to_aa = flipdict(aa_to_ctrl_idx)
ctrl_idx_to_kw = flipdict(kw_to_ctrl_idx)
ctrl_idx_to_taxa = flipdict(taxa_to_ctrl_idx)

def predict_fn(inputs):
    """Run the global ``model`` on ``inputs`` with gradient tracking disabled.

    ``inputs`` is converted with ``torch.tensor`` and moved to the GPU. Only
    the slice ``[-26:-1]`` of the last axis of the model output is returned.
    """
    token_ids = torch.tensor(inputs).cuda()
    with torch.no_grad():
        logits = model(token_ids)
    # remove non-AA token logits
    return logits[:, :, -26:-1]

-----vocab size 129407 ------
MODEL SIZE: 
1280
model initialized
Found PyTorch checkpoint at  ckpt/pretrain_progen_full.pth
Loading instead of converting from TensorFlow
previous checkpoint loaded


In [11]:
# NOTE(review): pickle and os are already imported in the first cell; these
# repeated imports are redundant.
import pickle
import os

# paths to the saved .p files
random_selection_file = os.path.join("data", "random_selection_PF16754.p")
data_file = os.path.join("data", "filtered_data_PF16754.p")

# Reload random_selection from the .p file
# NOTE: pickle.load can execute arbitrary code; only load trusted files.
# The False placeholder is overwritten by the load just below.
random_selection = False
with open(random_selection_file, "rb") as file:
    random_selection = pickle.load(file)

# Reload filtered_data from the .p file (currently disabled, so `data_file` is unused)
#filtered_data = False
#with open(data_file, "rb") as file:
#    filtered_data = pickle.load(file)

# NOTE(review): the triple-quoted block below is a bare string expression, not
# a comment; as the last expression of the cell it is echoed as the cell output.
''' Print or use the reloaded datasets as needed
print("Random Selection:")
for entry in random_selection:
    print(entry["sequence"])

print("---")

print("Filtered Data:")
for entry in filtered_data:
    print(entry["sequence"])'''


' Print or use the reloaded datasets as needed\nprint("Random Selection:")\nfor entry in random_selection:\n    print(entry["sequence"])\n\nprint("---")\n\nprint("Filtered Data:")\nfor entry in filtered_data:\n    print(entry["sequence"])'

In [None]:
'''To evaluate the generated sequence using teacher forcing and 3-grams and 5-grams:

1. Generate the Sequence: Use your transformer model with teacher forcing to generate the complete sequence.
At each time step, feed the model the true sequence up to that step (one token longer than at the previous step).

2. Extract n-grams: Once you have the generated sequence, extract the 3-grams and 5-grams from it.
Slide a window of size 3 or 5 along the generated sequence and extract the corresponding n-grams.

3. Compare with Test 3-grams and 5-grams: Retrieve the test 3-grams and 5-grams from your test dataset. 
These are the n-grams that you want to compare the generated sequence against.

4. Calculate Metrics: Compare the generated n-grams with the test n-grams to evaluate their similarity. 
You can use various metrics such as precision, recall, or the SAE (Sum of Absolute Errors) to quantify 
the similarity or dissimilarity.

- For precision and recall, you can calculate how many of the generated n-grams match the test
n-grams and divide it by the total number of generated or test n-grams.

- For the SAE, you can calculate the absolute difference between the frequency of each n-gram 
in the generated sequence and the frequency of the same n-gram in the test sequence. 
Sum up these absolute differences to obtain the final SAE value.
'''

In [None]:
# Greedy (top-1), tag-free generation seeded with the first three residues of
# `input_seq`. Everything this cell needs besides the model, `predict_fn` and
# the mapping dicts is defined here, so it does not depend on variables set by
# other generation cells.
input_seq = 'YMIQEEEWDRDLLLDPAWEKQQRKTFTAWCNSHLRKAG'
prefix = input_seq[:3]
# Reference continuation: the part of input_seq after the prefix. Only its
# length is used, to decide how many tokens to generate.
# NOTE(review): assumed intent -- confirm that the remainder of input_seq is the reference.
ref = input_seq[len(prefix):]
seed_seq = [aa_to_ctrl_idx[ii] for ii in prefix]
penalty = 0
topk = 1
key_len = 0  # no taxonomy / keyword tokens precede the residues in this cell
text = seed_seq  # full prompt = seed residues only
aa_offset = vocab_size-26  # index of the first of the last 26 vocabulary entries
generate_num = key_len +len(prefix+ref)
seq_length = min(generate_num, 511)
padded_text = text + [0] * (generate_num - len(text))
print(padded_text)
tokens_generated = np.tile(padded_text, (1,1))

for token in range(len(text)-1, generate_num-1):
    # The batch holds a single sequence; take it explicitly instead of
    # squeezing, which would also drop a length-1 sequence axis.
    prompt_logits = predict_fn(tokens_generated[:, :seq_length])[0]
    _token = token if token < seq_length else -1
    prompt_logits = prompt_logits.cpu().detach().numpy()
    step_logits = prompt_logits[_token]

    # compute probabilities from logits (max-shifted softmax to avoid overflow)
    prompt_probs = np.exp(step_logits - np.max(step_logits))
    prompt_probs = prompt_probs / np.sum(prompt_probs)
    pruned_list = np.argsort(prompt_probs)[::-1]

    if topk==1:
        idx = pruned_list[0]
    else:
        pruned_list = pruned_list[:topk]
        chosen_idx = torch.distributions.categorical.Categorical(logits=torch.tensor(np.expand_dims(step_logits[pruned_list],0))).sample().numpy()[0]
        idx = pruned_list[chosen_idx]

    # assign the token for generation
    idx += aa_offset # added to convert 0 AA to original ctrl idx
    tokens_generated[0][token+1] = idx

# Slice by generate_num rather than by the loop variable, so this also works
# when the loop body never runs.
tokens_generated_so_far = tokens_generated[0][:generate_num]
# Keep only tokens in [vocab_size-26, vocab_size-1) and map them back to letters.
tokens_generated_so_far = tokens_generated_so_far[(tokens_generated_so_far>=aa_offset) & (tokens_generated_so_far<(vocab_size-1))]
tokens_generated_so_far = ''.join([ctrl_idx_to_aa[c] for c in tokens_generated_so_far])

# Everything after the seed residues is the generated continuation.
query = tokens_generated_so_far[len(seed_seq):]
print(prefix)
print(query)

In [3]:
# Conditioned generation: taxonomy lineage + keyword tokens, followed by a
# three-residue prefix, then top-k sampling with a repetition penalty.
print(kw_to_name[9])
taxid = 9606 # homo sapiens taxonomy id from NCBI: https://www.ncbi.nlm.nih.gov/taxonomy
tax_lineage = taxa_to_lineage[taxid] # make lineage in ncbi ids
print(tax_lineage)
tax_lineage = [taxa_to_ctrl_idx[ite] for ite in tax_lineage] # now translated as ctrl code indices
print(tax_lineage)

kw_lineage = [677,9] # UniprotKB keywords from https://www.uniprot.org/docs/keywlist
print(kw_lineage)
kw_lineage = [kw_to_ctrl_idx[ite] for ite in kw_lineage] # now translated to ctrl code indices
print(kw_lineage)


example_seq = 'YMIQEEEWDRDLLLDPAWEKQQRKTFTAWCNSHLRKAGTQIENIEEDFRNGLKLMLLLEVISGERLPKPDRGKMRFHKIANVNKALDYIASKGVKLVSIGAEEIVDGNVKMTLGMIWTIILRFAIQDISVEETSAKEGLLLWCQRKTAPYRNVNIQNFHTSWKDGLGLCALIHRHRPDLIDYSKLNKDDPIGNINLAMEIAEKHLDIPKMLDAEDIVNTPKPDERAIMTYVSCFYHAFAGAEQAETAANRICKVLAVNQENERLMEEYERLASELLEWIRRTIPWLENRTPAATMQAMQKKLEDFRDYRRKHKPPKVQEKCQLEINFNTLQTKLRISNRPAFMPSEGKMVSDIAGAWQRLEQAEKGYEEWLLNEIRRLERLEHLAEKFRQKASTHETWAYGKEQILLQKDYESASLTEVRALLRKHEAFESDLAAHQDRVEQIAAIAQELNELDYHDAVNVNDRCQKICDQWDRLGTLTQKRREALERMEKLLETIDQLHLEFAKRAAPFNNWMEGAMEDLQDMFIVHSIEEIQSLITAHEQFKATLPEADGERQSIMAIQNEVEKVIQSYNIRISSSNPYSTVTMDELRTKWDKVKQLVPIRDQSLQEELARQHANERLRRQFAAQANAIGPWIQNKMEEIARSSIQITGALEDQMNQLKQYEHNIINYKNNIDKLEGDHQLIQEALVFDNKHTNYTMEHIRVGWELLLTTIARTINEVETQILTRDAKGITQEQMNEFRASFNHFDRRKNGLMDHEDFRACLISMGYDLGEAEFARIMTLVDPNGQGTVTFQSFIDFMTRETADTDTAEQVIASFRILASDKPYILAEELRRELPPDQAQYCIKRMPAYSGPGSVPGALDYAAFSSALYGESDL'
prefix = example_seq[:3]
# prefix = ""
print("Prefix: ", prefix)
# Only the length of `ref` is used below, to decide how many residues to generate.
ref = example_seq[150:155]

print("Ref: ", ref)
penalty = 1.2
topk = 3

seed_seq = [aa_to_ctrl_idx[ii] for ii in prefix]
print("seed sequence: ", seed_seq)
key_len = len(kw_lineage+tax_lineage)  # number of conditioning tokens before the residues
aa_offset = vocab_size-26  # index of the first of the last 26 vocabulary entries
generate_num = key_len +len(prefix+ref)
seq_length = min(generate_num, 511)

text = tax_lineage + kw_lineage + seed_seq
padded_text = text + [0] * (generate_num - len(text))
print(padded_text)
tokens_generated = np.tile(padded_text, (1,1))

for token in range(len(text)-1, generate_num-1):
    # The batch holds a single sequence; take it explicitly instead of
    # squeezing, which would also drop a length-1 sequence axis.
    prompt_logits = predict_fn(tokens_generated[:, :seq_length])[0]
    _token = token if token < seq_length else -1
    prompt_logits = prompt_logits.cpu().detach().numpy()
    step_logits = prompt_logits[_token]  # view: edits below modify prompt_logits in place

    if penalty>0:
        penalized_so_far = set()
        # Penalise the residues in the last four positions. While fewer than
        # four residues exist, use only those that do: the window never
        # reaches back into the conditioning tokens (positions < key_len).
        window_start = max(key_len, token-3)
        for position in range(window_start, token+1):
            generated_token = tokens_generated[0][position] - aa_offset # added
            if generated_token in penalized_so_far:
                continue
            penalized_so_far.add(generated_token)
            # Skip tokens that have no column in the sliced logits.
            if 0 <= generated_token < step_logits.shape[0]:
                step_logits[generated_token] /= penalty

    # compute probabilities from logits (max-shifted softmax to avoid overflow)
    prompt_probs = np.exp(step_logits - np.max(step_logits))
    prompt_probs = prompt_probs / np.sum(prompt_probs)
    pruned_list = np.argsort(prompt_probs)[::-1]

    if topk==1:
        idx = pruned_list[0]
    else:
        # Sample among the top-k candidates.
        # NOTE(review): no random seed is set, so results differ between runs.
        pruned_list = pruned_list[:topk]
        chosen_idx = torch.distributions.categorical.Categorical(logits=torch.tensor(np.expand_dims(step_logits[pruned_list],0))).sample().numpy()[0]
        idx = pruned_list[chosen_idx]

    # assign the token for generation
    idx += aa_offset # added to convert 0 AA to original ctrl idx
    tokens_generated[0][token+1] = idx


# Slice by generate_num rather than by the loop variable, so this also works
# when the loop body never runs.
tokens_generated_so_far = tokens_generated[0][:generate_num]
# Keep only tokens in [vocab_size-26, vocab_size-1) and map them back to letters.
tokens_generated_so_far = tokens_generated_so_far[(tokens_generated_so_far>=aa_offset) & (tokens_generated_so_far<(vocab_size-1))]
tokens_generated_so_far = ''.join([ctrl_idx_to_aa[c] for c in tokens_generated_so_far])

# Everything after the seed residues is the generated continuation.
query = tokens_generated_so_far[len(seed_seq):]

print(prefix)
print(query)

[33208, 7711, 40674, 9443, 9604, 9605, 9606]
[11177, 5756, 14034, 6957, 7068, 7069, 7070]
[677, 9]
[46, 258]


YMI
