In [1]:
from __future__ import print_function
from __future__ import division
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"]="0"
import sys
import torch
import tqdm
import pdb
import numpy as np
import platform
import hashlib
import pytorch_transformer
import re
import argparse
import tensorflow as tf
from tensorflow.python import pywrap_tensorflow
import torch.nn.functional as F
#from torch.utils.tensorboard import SummaryWriter
from transformProtein import transformProtein
from ProteinDataset_uid import ProteinDataset
from torch.utils.data import Dataset, DataLoader
import pickle
import time
import matplotlib.pyplot as plt

load_model_path = 'ckpt/' # just the folder itself

seq_length = 511
embedding_dim = 1280
num_layers = 36
vocab_loc = 'mapping_files/vocab.txt'

# Python 2's open() has no `encoding` argument, so the two branches read the
# vocab file differently. sys.version_info is more robust than inspecting the
# first character of platform.python_version().
use_py3 = sys.version_info[0] >= 3
if use_py3:
    with open(vocab_loc, encoding='utf-8') as vocab_file:
        # `[:-1]` drops whatever follows the final '\n' (an empty string when
        # the file ends with a newline).
        # NOTE(review): if the file lacks a trailing newline this drops the last
        # token; left unchanged because the `vocab_size - 26` offsets used for
        # amino-acid tokens later in the notebook depend on vocab_size.
        vocab = vocab_file.read().split('\n')[:-1]
else:
    with open(vocab_loc) as vocab_file:
        vocab = vocab_file.readlines()
# The token is the first space-separated field of each line; the rest is ignored.
vocab = [line.split(' ')[0] for line in vocab]
vocab_size = len(vocab)
print('-----vocab size',vocab_size,'------')

class TiedEmbeddingSoftmax(torch.nn.Module):
  """Token embedding whose weight matrix is reused as the output projection.

  A single parameter `w` of shape (vocab_size, embedding_size) serves both
  directions: looking up vectors for token ids, and mapping hidden vectors
  back to one logit per vocabulary entry (plus the bias `b`).
  Extra keyword arguments are accepted and ignored.
  """

  def __init__(self, vocab_size=vocab_size, embedding_size=embedding_dim, **kwargs):
    super(TiedEmbeddingSoftmax, self).__init__()
    # Small random start values; CTRLmodel.loadCheckpoint in this file
    # replaces both parameters with checkpoint weights.
    initial_w = torch.normal(0., 1e-2, size=(vocab_size, embedding_size))
    self.w = torch.nn.Parameter(initial_w)
    self.b = torch.nn.Parameter(torch.zeros(vocab_size))

  def forward(self, inputs, embed=True):
    """Embed token ids (embed=True) or project hidden states to logits (embed=False).

    embed=True:  rows of `w` are gathered for each id in `inputs`.
    embed=False: the last axis of `inputs` is contracted with the embedding
                 axis of `w`, then `b` is added.
    """
    if not embed:
      return torch.tensordot(inputs, self.w.t(), 1) + self.b
    return torch.nn.functional.embedding(inputs, self.w)

class CTRLmodel(torch.nn.Module):
  """CTRL-style model: tied token embedding -> transformer encoder -> tied softmax.

  One TiedEmbeddingSoftmax instance both embeds input token ids and projects
  the encoder output back onto the vocabulary, so input and output share `w`.
  """

  def __init__(self):
    super(CTRLmodel,self).__init__()
    self.tied_embedding_softmax = TiedEmbeddingSoftmax()
    self.encoder = pytorch_transformer.Encoder()

  def forward(self, inputs):
    """Map integer token ids to vocabulary logits.

    Args:
      inputs: integer token-id tensor (presumably (batch, seq_len) -- TODO
        confirm against pytorch_transformer.Encoder).
    Returns:
      Logits with one entry per vocabulary item along the last axis.
    """
    x = self.tied_embedding_softmax(inputs, embed = True)
    x = self.encoder(x)
    x = self.tied_embedding_softmax(x, embed = False)
    return x

  def loadCheckpoint(self, model_path, num_layers, device='cuda'):
    """Load weights from a cached PyTorch checkpoint, or convert a TF checkpoint.

    The PyTorch cache file is named after the MD5 hex digest of `model_path`
    and is looked up (and written) relative to the current working directory.
    If it exists it is loaded directly. Otherwise the TensorFlow checkpoint at
    `model_path` is converted, and the result is saved to the cache file so
    later runs can skip the conversion.

    Args:
      model_path: TensorFlow checkpoint prefix (e.g. 'ckpt/model.ckpt-1000000').
      num_layers: number of transformer layers (`layer0` .. `layer{n-1}`) in
        `self.encoder`.
      device: device the loaded parameters are placed on. Defaults to 'cuda'.

    Returns:
      None; the model is updated in place.
    """
    pytorch_model_hash = hashlib.md5(model_path.encode('utf-8')).hexdigest()

    if os.path.exists(pytorch_model_hash):
      print('Found PyTorch checkpoint @', pytorch_model_hash)
      print('Loading instead of converting from TensorFlow')
      checkpoint = torch.load(pytorch_model_hash, map_location=device)
      self.tied_embedding_softmax.load_state_dict(checkpoint['softmax'])
      self.encoder.load_state_dict(checkpoint['encoder'])

      self.tied_embedding_softmax.to(device)
      self.encoder.to(device)
      return

    print('Could not find PyTorch checkpoint')
    print('Converting weights and will store the PyTorch checkpoint as ', pytorch_model_hash)
    self._convert_tf_checkpoint(model_path, num_layers, device)

    # Cache the converted weights in the same {'softmax', 'encoder'} layout the
    # branch above reads; without this the slow conversion repeats every run.
    try:
      torch.save({'softmax': self.tied_embedding_softmax.state_dict(),
                  'encoder': self.encoder.state_dict()},
                 pytorch_model_hash)
    except (IOError, OSError) as err:
      # The weights are already loaded; a failed cache write only costs time
      # on the next run, so report it instead of aborting.
      print('Could not store PyTorch checkpoint:', err)

  def _convert_tf_checkpoint(self, model_path, num_layers, device):
    """Copy every model weight from the TensorFlow checkpoint at `model_path` into self."""
    reader = pywrap_tensorflow.NewCheckpointReader(model_path)

    # The shared embedding/softmax tensors are copied as-is (no transpose).
    self.tied_embedding_softmax.w = torch.nn.Parameter(torch.tensor(reader.get_tensor('w')).to(device))
    self.tied_embedding_softmax.b = torch.nn.Parameter(torch.tensor(reader.get_tensor('b')).to(device))

    # Skip optimizer slot variables; only model weights are copied.
    list_of_variables = [name for name in reader.get_variable_to_shape_map().keys()
                         if 'Adagrad' not in name and 'Adam' not in name]

    def str2parameter(name):
      # .t() transposes 2-D kernels (presumably TF (in, out) -> torch (out, in));
      # 1-D tensors such as biases and layernorm params are returned unchanged.
      return torch.nn.Parameter(torch.tensor(reader.get_tensor(name)).t().to(device))

    # The encoder's final layernorm is stored under index 2*num_layers
    # (presumably after the two per-layer norms of every layer).
    final_norm = 'encoder/layer_normalization_' + str(int(num_layers*2))
    self.encoder.layernorm.weight = str2parameter(final_norm + '/gamma')
    self.encoder.layernorm.bias = str2parameter(final_norm + '/beta')

    for i in tqdm.tqdm(range(num_layers)):
      # Layer 0's variables are expected under 'layer/', later ones under 'layer_<i>/'.
      scope = 'layer/' if i == 0 else 'layer_' + str(i) + '/'
      layer_variables = sorted(name for name in list_of_variables if scope in name)

      current_layer = getattr(self.encoder, 'layer' + str(i))
      mha = current_layer.multi_head_attention
      # Sorted TF variable names are matched to these targets by position, so
      # this order must follow the lexicographic order of the checkpoint names.
      targets = (
          (current_layer.layernorm1, 'bias'), (current_layer.layernorm1, 'weight'),
          (current_layer.layernorm2, 'bias'), (current_layer.layernorm2, 'weight'),
          (mha.Wq, 'bias'), (mha.Wq, 'weight'),
          (mha.Wk, 'bias'), (mha.Wk, 'weight'),
          (mha.Wv, 'bias'), (mha.Wv, 'weight'),
          (mha.dense, 'bias'), (mha.dense, 'weight'),
          (current_layer.ffn[0], 'bias'), (current_layer.ffn[0], 'weight'),
          (current_layer.ffn[2], 'bias'), (current_layer.ffn[2], 'weight'),
      )
      if len(layer_variables) < len(targets):
        raise ValueError('Expected at least %d checkpoint variables for layer %d, found %d'
                         % (len(targets), i, len(layer_variables)))
      for (module, attr), name in zip(targets, layer_variables):
        setattr(module, attr, str2parameter(name))

model = CTRLmodel()
print('model initialized')

# TF checkpoint prefix is <load_model_path>model.ckpt-<step>.
ckptnum = '1000000'
curr_model_path = '{}model.ckpt-{}'.format(load_model_path, ckptnum)
# loadCheckpoint updates `model` in place and has no return value, so `reader` is None.
reader = model.loadCheckpoint(model_path=curr_model_path, num_layers = num_layers)
print('previous checkpoint loaded')

model = model.cuda()
# Inference only from here on (disables dropout-style training behavior).
model.eval()
# NOTE(review): the optimizer is not used anywhere in the visible notebook.
optimizer = torch.optim.Adam(model.parameters()) #lr, betas

def _load_pickle(path):
    """Return the object unpickled from the file at *path*."""
    # NOTE(review): pickle.load can execute arbitrary code -- only load trusted mapping files.
    with open(path, 'rb') as handle:
        return pickle.load(handle)

# NCBI taxonomy id -> lineage, plus the "<id> -> CTRL control-code index" maps.
taxa_to_lineage = _load_pickle(os.path.join('mapping_files/','taxa_to_lineage.p'))
taxa_to_ctrl_idx = _load_pickle('mapping_files/taxa_to_ctrl_idx.p')
kw_to_ctrl_idx = _load_pickle('mapping_files/kw_to_ctrl_idx.p')
aa_to_ctrl_idx = _load_pickle('mapping_files/aa_to_ctrl_idx.p')

# Presumably human-readable names for keyword and taxonomy ids.
kw_to_name = _load_pickle('mapping_files/kw_to_name.p2')
taxid_to_name = _load_pickle('mapping_files/taxid_to_name.p2')
    
def flipdict(my_map):
    """Return the inverse of *my_map*: each value becomes a key mapping to its key.

    When several keys share a value, the key iterated last wins.
    """
    inverted = {}
    for key, value in my_map.items():
        inverted[value] = key
    return inverted
# Reverse lookups: CTRL control-code index -> amino acid / keyword id / taxon id.
ctrl_idx_to_aa, ctrl_idx_to_kw, ctrl_idx_to_taxa = (
    flipdict(forward_map)
    for forward_map in (aa_to_ctrl_idx, kw_to_ctrl_idx, taxa_to_ctrl_idx)
)

-----vocab size 129407 ------
model initialized
Could not find PyTorch checkpoint
Converting weights and will store the PyTorch checkpoint as  b2f624153dd999ec4146b962f290e7bf


100%|██████████| 36/36 [00:37<00:00,  1.03s/it]


previous checkpoint loaded


In [None]:
def predict_fn(inputs):
    """Run the global `model` on token ids and return only the amino-acid logits.

    Args:
        inputs: array-like of integer token ids; it is converted with
            torch.tensor and moved to the GPU.
    Returns:
        The model output restricted to vocabulary positions -26..-2 along the
        last axis (25 entries); all other token logits are dropped.
    """
    with torch.no_grad():
        token_ids = torch.tensor(inputs).cuda()
        logits = model(token_ids)
        # keep only the amino-acid slice of the vocabulary
        return logits[:, :, -26:-1]

In [3]:
taxid = 9606 # homo sapiens taxonomy id from NCBI: https://www.ncbi.nlm.nih.gov/taxonomy
ncbi_lineage = taxa_to_lineage[taxid]  # lineage as NCBI taxonomy ids
print(ncbi_lineage)
tax_lineage = [taxa_to_ctrl_idx[taxon] for taxon in ncbi_lineage]  # as CTRL control-code indices
print(tax_lineage)

uniprot_kws = [677,9] # UniprotKB keywords from https://www.uniprot.org/docs/keywlist
print(uniprot_kws)
kw_lineage = [kw_to_ctrl_idx[kw] for kw in uniprot_kws]  # as CTRL control-code indices
print(kw_lineage)

[33208, 7711, 40674, 9443, 9604, 9605, 9606]
[11177, 5756, 14034, 6957, 7068, 7069, 7070]
[677, 9]
[46, 258]


In [4]:
# Example protein: the first 150 residues are the prompt, and the next 50 are
# kept as the reference continuation that is printed next to the generated one.
example_seq = 'YMIQEEEWDRDLLLDPAWEKQQRKTFTAWCNSHLRKAGTQIENIEEDFRNGLKLMLLLEVISGERLPKPDRGKMRFHKIANVNKALDYIASKGVKLVSIGAEEIVDGNVKMTLGMIWTIILRFAIQDISVEETSAKEGLLLWCQRKTAPYRNVNIQNFHTSWKDGLGLCALIHRHRPDLIDYSKLNKDDPIGNINLAMEIAEKHLDIPKMLDAEDIVNTPKPDERAIMTYVSCFYHAFAGAEQAETAANRICKVLAVNQENERLMEEYERLASELLEWIRRTIPWLENRTPAATMQAMQKKLEDFRDYRRKHKPPKVQEKCQLEINFNTLQTKLRISNRPAFMPSEGKMVSDIAGAWQRLEQAEKGYEEWLLNEIRRLERLEHLAEKFRQKASTHETWAYGKEQILLQKDYESASLTEVRALLRKHEAFESDLAAHQDRVEQIAAIAQELNELDYHDAVNVNDRCQKICDQWDRLGTLTQKRREALERMEKLLETIDQLHLEFAKRAAPFNNWMEGAMEDLQDMFIVHSIEEIQSLITAHEQFKATLPEADGERQSIMAIQNEVEKVIQSYNIRISSSNPYSTVTMDELRTKWDKVKQLVPIRDQSLQEELARQHANERLRRQFAAQANAIGPWIQNKMEEIARSSIQITGALEDQMNQLKQYEHNIINYKNNIDKLEGDHQLIQEALVFDNKHTNYTMEHIRVGWELLLTTIARTINEVETQILTRDAKGITQEQMNEFRASFNHFDRRKNGLMDHEDFRACLISMGYDLGEAEFARIMTLVDPNGQGTVTFQSFIDFMTRETADTDTAEQVIASFRILASDKPYILAEELRRELPPDQAQYCIKRMPAYSGPGSVPGALDYAAFSSALYGESDL'
prefix = example_seq[:150]  # prompt residues fed to the model
ref = example_seq[150:200]  # reference continuation (50 residues); its length sets how many residues are generated
penalty = 1.2  # repetition penalty applied to amino acids seen in the last few positions
topk = 1  # 1 = greedy decoding; >1 samples among the top-k amino acids

# Prompt = taxonomy control codes + keyword control codes + amino-acid prefix.
seed_seq = [aa_to_ctrl_idx[ii] for ii in prefix]
generate_num = len(kw_lineage+tax_lineage)+len(prefix+ref)
# The model is fed at most 511 tokens at a time.
seq_length = min(generate_num, 511)

# predict_fn keeps logits[..., -26:-1]: 25 amino-acid logits whose positions
# 0..24 correspond to vocabulary ids aa_offset .. aa_offset+24.
aa_offset = vocab_size - 26
num_aa = 25

text = tax_lineage+kw_lineage+seed_seq
padded_text = text + [0] * (generate_num - len(text))
tokens_generated = np.tile(padded_text, (1,1))

for token in range(len(text)-1, generate_num-1):
    # Predict position token+1. While the sequence fits in the context, run on
    # the fixed prefix window and read row `token`. After that, slide the
    # window so it ends at `token` and read its last row; a fixed prefix window
    # would keep predicting the same position.
    if token < seq_length:
        window = tokens_generated[:, :seq_length]
        _token = token
    else:
        window = tokens_generated[:, token+1-seq_length:token+1]
        _token = -1
    prompt_logits = predict_fn(window).squeeze()
    prompt_logits = prompt_logits.cpu().detach().numpy()

    if penalty>0:
        penalized_so_far = set()
        # Penalize amino acids at positions token-3..token (clamped at 0 so
        # negative indices cannot wrap to the end of the sequence).
        for pos in range(max(token-3, 0), token+1):
            generated_token = tokens_generated[0][pos] - aa_offset
            # Only amino-acid tokens have a logit here; skip control codes.
            if not 0 <= generated_token < num_aa or generated_token in penalized_so_far:
                continue
            penalized_so_far.add(generated_token)
            logit = prompt_logits[_token][generated_token]
            # Dividing a negative logit by penalty>1 would raise it; scale in
            # the direction that moves it the same way as positive logits.
            prompt_logits[_token][generated_token] = logit / penalty if logit > 0 else logit * penalty

    # Probabilities only determine the ranking; subtracting the max logit keeps
    # np.exp from overflowing.
    row = prompt_logits[_token]
    prompt_probs = np.exp(row - np.max(row))
    prompt_probs = prompt_probs / np.sum(prompt_probs)
    pruned_list = np.argsort(prompt_probs)[::-1]

    if topk==1:
        idx = pruned_list[0]
    else:
        pruned_list = pruned_list[:topk]
        chosen_idx = torch.distributions.categorical.Categorical(
            logits=torch.tensor(np.expand_dims(row[pruned_list], 0))).sample().numpy()[0]
        idx = pruned_list[chosen_idx]

    # Convert the amino-acid position back to its vocabulary id.
    idx += aa_offset
    tokens_generated[0][token+1] = idx


# Every position up to generate_num-1 is now filled. Using generate_num instead
# of the loop variable also works when the loop body never runs (empty `ref`).
tokens_generated_so_far = tokens_generated[0][:generate_num]
is_aa = (tokens_generated_so_far >= aa_offset) & (tokens_generated_so_far < vocab_size-1)
tokens_generated_so_far = tokens_generated_so_far[is_aa]
tokens_generated_so_far = ''.join([ctrl_idx_to_aa[c] for c in tokens_generated_so_far])

# Drop the prompt residues so `query` holds only newly generated amino acids.
query = tokens_generated_so_far[len(seed_seq):]

In [5]:
# Reference continuation on the first line, model generation on the second.
print(ref, query, sep='\n')

RNVNIQNFHTSWKDGLGLCALIHRHRPDLIDYSKLNKDDPIGNINLAMEI
RNVNIQNFHTSWKDGLALNALIHRHRPDLIDYAKLRKDDPIGNLNTAFEV
