In [1]:
import h5py
import pickle
import pandas as pd

import torch

import numpy as np
%matplotlib inline

from translator import SignalTranslator
torch.cuda.synchronize()

with open('../outputs/ctable_token.pkl', 'rb') as f:
    ctable = pickle.load(f)

In [2]:
torch.cuda.device(0)

<torch.cuda.device at 0x7f2cf5b2e978>

In [3]:
# Helper functions for tokenizing new inputs
alphabet = ' .$ACDEFGHIKLMNPQRSTUVWXYZ'
max_len_in = 107 # max length of prot seq (105 aa) + 2 for tokens
max_len_out = 72
n_chars = len(alphabet)

with open('../data/ctable_copies/ctable_token_master.pkl', 'rb') as f:
    ctable = pickle.load(f)

def encode(seqs, max_len, ctable):
    if ctable.one_hot:
        X = np.zeros((len(seqs), max_len, n_chars))
    else:
        X = np.zeros((len(seqs), max_len))
    seqs = ['$' + seq + '.' for seq in seqs]
    seqs = [seq + ' ' * ((max_len) - len(seq))for seq in seqs]
    for i, seq in enumerate(seqs):
        X[i] = ctable.encode(seq, max_len)
    return X

def to_h5py(seqs, fname, ctable):
    chunksize = 500
    with h5py.File(fname, 'w') as f:
        if ctable.one_hot:
            print('true')
            X = f.create_dataset('X', (len(seqs), max_len_in, n_chars))
        else:
            X = f.create_dataset('X', (len(seqs), max_len_in))          
        for i in range(0, len(seqs), chunksize):
            X[i:i + chunksize, :] = encode([seq for seq in seqs[i:i+chunksize]], max_len_in, ctable)
        left = len(seqs) % chunksize
        if left > 0:
            X[-left:, :] = encode([seq for seq in seqs[-left:]], max_len_in, ctable)

In [4]:
# Load sample input, convert to h5py file for generator
df = pd.read_excel('../data/example_test_input.xlsx', engine='openpyxl')
test_seqs = df['protein_sequence'].values
test_seqs = [s[:100] for s in test_seqs]
test_filename = ('../data/example_test_tokens.hdf5')
to_h5py(test_seqs, test_filename, ctable)

In [5]:
# Load a Model Checkpoint
chkpt_name = 'SIM99_550_12500_64_6_5_0.1_64_100_0.0001_-0.03_99'
chkpt = "../outputs/models/model_checkpoints/" + chkpt_name + ".chkpt"
clf = SignalTranslator.load_model(chkpt)

Namespace(cuda=False, d_inner_hid=1100, d_k=64, d_model=550, d_v=64, d_word_vec=550, dropout=0.1, embs_share_weight=True, max_token_seq_len=107, n_head=5, n_layers=6, proj_share_weight=True, src_vocab_size=27, tgt_vocab_size=27) Namespace(beam_size=1, ctable=<tools.CharacterTable object at 0x7f2cedbd9e48>, max_trans_length=72, n_best=1) Namespace(d_model=None, decay_power=-0.03, lr_max=0.0001, n_warmup_steps=12500, optim=<class 'torch.optim.adam.Adam'>)
position_encoding
position_encoding
Initiated Transformer with 27403200 parameters.


In [6]:
# test_gen_data = []
# Generate SPs for Proteins
file = h5py.File(test_filename)
training_data = SignalTranslator.generator_from_h5_noy(file, 64, shuffle=False, use_cuda=False)
src = next(training_data) # src is prot sequence, tgt is signal pep
file.close()
clf_outputs  = clf.translate_batch(src, 5)
decoded, all_hyp, all_scores, enc_outputs, dec_outputs, enc_slf_attns, dec_slf_attns, dec_enc_attn = clf_outputs

for src, dec in zip(src[0], decoded):
#     print(ctable.decode(src.data.cpu().numpy())[:]) # prot sequence from Zach's excel
    print(dec) # model's predictions
    print()
    
    input_seq = ctable.decode(src.data.cpu().numpy())[:]
    output_seq = dec

  result = self.forward(*input, **kwargs)
  out = self.model.prob_projection(dec_output)


MKKPLGKIVASTALLISVAFSSSIASA.

MIKKIPLKTIAVMALSGCTFFVNG.

MRLAKIAGLTASLLFSLWGALA.

MAILVLLFLLAVEINS.

MKMRTGKKGFLSILLAFLLVITSIPFTLVDVEA.

MKLFTATIAVLGAVSATAHA.

MKDLFRLIALLSCCLALFPLTFA.

MKSLLLTAFAAGTALA.                                                       

MNKLFYLFMLGLAAFA.                                                       

MKLKIVFAVAAIAPVLHS.

MKFTQAVLSLLGSAATALA.

MNIRLGALLAGLLLSAMASAVFA.

MKKSLISFLALGLLFGSAFA.

MKFLSIVLLIVGLAYG.

MFKFVLVLSVLAALASARA.

MLTFHRIIRKGWMFLLAFLLTALLFCPTGQPAKA.

MNLKILFALALGVCLAA.

MVSNKRVLALSALFGCCSLASA.

MKNFATLSAVLAGATALA.                                                     

MIRLKRLLAGLLLPLFVTAFG.

MKKTGFIGKTLALVIAAGMAGTAAFA.

MKKTILALALLGSLAA.                                                       

MTLKTTITLFFAALSANAAFA.

MKFQDLTLVLSLSTALA.

MKLLTSFVLIGALAFA.                                                       

MGIQKKVSILVAGLFMATAFATA.

MKKTAAIAALAGLSFAGMAHA.

MVASLWSSILPVLAFLWADLSAGA.

MKKRLLIASVALGSLFSFCA.

MARA.

MKFLILLITLGAIAATALA.

MKL