In [12]:
import torch
import torch.nn as nn
from torch.nn import functional as F

from torch.autograd import Variable

import pandas as pd
import numpy as np

import pdb
import torch.optim as optim
from tqdm import tqdm_notebook as tqdm

from torch.nn.utils import clip_grad_norm

# import torch.backends.cudnn as cudnn
# cudnn.benchmark = True

import math

use_cuda = True

In [62]:

class EncoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_gru_layers=1, bidirectional = True, dropout_p = 0.2):
        super(EncoderRNN, self).__init__()
        
        self.h_layers = num_gru_layers
        self.hidden_size = hidden_size
        self.num_gru_layers = num_gru_layers
        self.h_layers = num_gru_layers
        
        self.bidirectional = bidirectional
        
        self.lin_in = hidden_size
        if bidirectional:
#             self.hidden_size = self.hidden_size* 2
            self.h_layers = self.h_layers * 2
            self.lin_in = hidden_size * 2
        
        self.droput = nn.Dropout(dropout_p)
        
        self.gru = nn.GRU(input_size, hidden_size, num_layers = num_gru_layers, bidirectional = bidirectional, dropout = dropout_p)
#         self.out = nn.Sequential(
#                                  nn.BatchNorm1d(self.lin_in),
#                                  nn.ReLU(True),
#                                  nn.Linear(self.lin_in, self.lin_in),
#                                  nn.BatchNorm1d(self.lin_in),
#                                  nn.ReLU(True),
#                                  nn.Linear(self.lin_in, input_size),
#                                  torch.nn.LogSoftmax(dim = 1)
#                                 )
        
    def forward(self, input, hidden):
        
#         output = self.droput(input)
        
        output, hidden = self.gru(input, hidden)        
#         output = torch.cat([torch.unsqueeze(self.out(out_batch),0) for out_batch in output],0)

        return output, hidden

    def initHidden(self, batch_size = 1):
        result = Variable(torch.zeros(self.h_layers, batch_size, self.hidden_size))
        
        return result
#         if use_cuda:
#             return result.cuda()
#         else:
#             return result
  
class Attention(nn.Module):
    def __init__(self, attn_in_size, hidden_size):
        super(Attention, self).__init__()
        self.hidden_size = hidden_size
#         self.attn = nn.Linear(attn_in_size*2, hidden_size)
        
        self.attn = nn.Sequential(
                                    nn.Linear(attn_in_size*2, hidden_size)
                                )
        
        self.v = nn.Parameter(torch.rand(hidden_size))
        stdv = 1. / math.sqrt(self.v.size(0))
        self.v.data.uniform_(-stdv, stdv)

    def forward(self, hidden, encoder_outputs):
        timestep = encoder_outputs.size(0)
        h = hidden.repeat(timestep, 1, 1).transpose(0, 1)
        encoder_outputs = encoder_outputs.transpose(0, 1)  # [B*T*H]
        attn_energies = self.score(h, encoder_outputs)
        return F.softmax(attn_energies, dim=1).unsqueeze(1)

    def score(self, hidden, encoder_outputs):
        # [B*T*2H]->[B*T*H]
#         pdb.set_trace()
        
#         energy = torch.cat([torch.unsqueeze(self.attn(out_batch),0) for out_batch in torch.cat([hidden, encoder_outputs], 2)],0)
        
        energy = self.attn(torch.cat([hidden, encoder_outputs], 2))
        
        energy = energy.transpose(1, 2)  # [B*H*T]
        v = self.v.repeat(encoder_outputs.size(0), 1).unsqueeze(1)  # [B*1*H]
        energy = torch.bmm(v, energy)  # [B*1*T]
        return energy.squeeze(1)  # [B*T]


class DecoderRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, n_layers=1,  bidirectional = True, dropout_p=0.2):
        super(DecoderRNN, self).__init__()
        
        output_size = embed_size
        
        self.bidirectional = bidirectional
        
        self.embed_size = embed_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        
        
        self.hidden_in = hidden_size
        if bidirectional:
            self.hidden_in = hidden_size * 2

        self.embed = nn.Embedding(output_size, embed_size)
        self.dropout = nn.Dropout(dropout_p, inplace=True)
        self.attention = Attention(self.hidden_in, hidden_size)
        self.gru = nn.GRU(self.hidden_in + embed_size, hidden_size, n_layers, bidirectional = bidirectional)
#         self.out = nn.Linear(self.hidden_in * 2, output_size)
        
        self.out = nn.Sequential(
                                    nn.Linear(self.hidden_in * 2, output_size)
                                )

    def forward(self, embedded, last_hidden, encoder_outputs):
        # Get the embedding of the current input word (last output word)
#         embedded = self.embed(input).unsqueeze(0)  # (1,B,N)
#         embedded = self.dropout(input)
        # Calculate attention weights and apply to encoder outputs

        
        h_in = last_hidden[-1]
        if self.bidirectional:
            h_in = torch.cat([last_hidden[-2], last_hidden[-1]], 1)
        
        attn_weights = self.attention(h_in, encoder_outputs)
        context = attn_weights.bmm(encoder_outputs.transpose(0, 1))  # (B,1,N)
        context = context.transpose(0, 1)  # (1,B,N)
        # Combine embedded input word and attended context, run through RNN
        rnn_input = torch.cat([embedded, context], 2)
        output, hidden = self.gru(rnn_input, last_hidden)
        
#         output = output.squeeze(0)  # (1,B,N) -> (B,N)
#         context = context.squeeze(0)
        
        output = self.out(torch.cat([output, context], 2))
        output = F.log_softmax(output, dim=2)
        return output, hidden, attn_weights

In [None]:
import os
from decimal import Decimal

import seq2loc.utils as utils

from seq2loc.data.datasets import PaddedSequenceDataset, SequenceDataset



GPU_id = 0
LR = 0.001
N_EPOCHS = 500
hidden_size = 128
batch_size = 64
num_gru_layers = 3
teacher_forcing_ratio = 0.5
log_number = 50
loss_thresh = 1.5E-4
N_LETTERS = utils.n_letters()

grad_clip = 10.0

bidirectional=True

ds = PaddedSequenceDataset(SequenceDataset('./data/uniprot.tsv', max_seq_len = 15), GPU_id = GPU_id)


criterion = torch.nn.NLLLoss()

enc = EncoderRNN(N_LETTERS, hidden_size, num_gru_layers, bidirectional=bidirectional).cuda(GPU_id)
dec = DecoderRNN(N_LETTERS, hidden_size, num_gru_layers, bidirectional=bidirectional).cuda(GPU_id)


opt = optim.Adam([{'params':enc.parameters()}, {'params':dec.parameters()}], lr = LR)

# dec = enc

losses = np.array([])

for _ in range(N_EPOCHS):

    
    loss_thresh = 8.12E-04 / np.log(ds.sequenceDataset.max_seq_len)
    epoch_inds = utils.get_epoch_inds(len(ds), batch_size)
    pbar = tqdm(epoch_inds)

    epoch_losses = list()
    
    for batch in pbar:
        opt.zero_grad()

        x, x_inds = ds[batch]        
#         x = (x - 1) * log_number #set to log(one-hot)

        hidden = enc.initHidden(batch_size).cuda(GPU_id)

        enc_out, hidden = enc(x, hidden)
        
        #input the stop character to the stream    
        out = Variable(utils.stopChar(batch_size)).cuda(GPU_id)

        loss = Variable(torch.zeros(1,1,1).cuda(GPU_id))

        out_list = list()
        for i in range(x.shape[0]):

            out, hidden, attn_weights = dec(out.detach(), hidden, enc_out)    
            
            out_list += [out]

            loss += criterion(torch.squeeze(out,0), x_inds[i])
    
            if teacher_forcing_ratio < torch.rand(1):
                out = torch.unsqueeze(x[i],0)
            else:
#                 pdb.set_trace()
                out = utils.indicesToTensor(torch.max(out,2)[1].cpu(), ndims = N_LETTERS)
                out = Variable(out).cuda(GPU_id)
                


        loss = loss/(x.shape[0]*x.shape[1]*x.shape[2])
        loss.backward()
        
        clip_grad_norm(enc.parameters(), grad_clip)
        clip_grad_norm(dec.parameters(), grad_clip)

        opt.step()

        losses_np = np.squeeze(loss.detach().cpu().numpy())
        
        epoch_losses += [losses_np]
    #     t.set_description('GEN %i' % i)
        pbar.set_description('%.4E' % Decimal(str(losses_np)))
        
    losses = np.hstack([losses, np.stack(epoch_losses)])
    
    if np.mean(epoch_losses) < loss_thresh:
        ds.sequenceDataset.max_seq_len += 5
        
        if batch_size > 2:
            batch_size -= 1
    
    pbar.set_description('%.4E' % Decimal(str(np.mean(epoch_losses))))
    
    print(''.join(utils.tensorToChar(torch.cat(out_list, 0))[:,0]))
    print(''.join(utils.tensorToChar(x)[:,0]))
    
    if '.' in ''.join(utils.tensorToChar(x)[:,0]):
        pdb.set_trace()

HBox(children=(IntProgress(value=0, max=315), HTML(value='')))



AAALLLLLLLLLLLL
MATVLLALLVYLGAL


HBox(children=(IntProgress(value=0, max=315), HTML(value='')))


YVLCTVLLAVLLLVA
YVLCTVLLALAVLLA


HBox(children=(IntProgress(value=0, max=315), HTML(value='')))


DMYSGIIRRLLKLAV
MDYSRIIERLLKLAV


HBox(children=(IntProgress(value=0, max=315), HTML(value='')))


WRALHPLLLLLLLFP
WRALHPLLLLLLLFP


HBox(children=(IntProgress(value=0, max=320), HTML(value='')))


KSKNVFLKNNLKKIGDGGVS
KSKNVFLKNNLLKIGDFGVS


HBox(children=(IntProgress(value=0, max=320), HTML(value='')))


QALIDALLEEDGKKLLCVSS
QALIDACLEEDGKLYLCVSS


HBox(children=(IntProgress(value=0, max=325), HTML(value='')))


QYLRLSHNELADSGIPGNSSNVSSL
QYLRLSHNELADSGIPGNSFNVSSL


HBox(children=(IntProgress(value=0, max=325), HTML(value='')))


VSSCERGLVKVWHIAMAQLVKTLSG
VSSCERGLVKVWHIAMAQLVKTLSG


HBox(children=(IntProgress(value=0, max=325), HTML(value='')))


SFDSSPTSSTDGGSSYGLDSGFCTI
SFDSSPTSSTDGHSSYGLDSGFCTI


HBox(children=(IntProgress(value=0, max=331), HTML(value='')))


SVGRPSPASSGRRESGPPGRRHEHSQHPQS
SVGRPSPLASGRRESGAPHRRHEHSPHPQS


HBox(children=(IntProgress(value=0, max=331), HTML(value='')))


IRGKIRLRQASWIIRGGTEADYMLHNVQVI
IRGKIRLRQASWIIRGGTEADYQLHNVQVI


HBox(children=(IntProgress(value=0, max=331), HTML(value='')))


AEPCGKGHRDCNSPGSFRCECKTGYYGDGI
AEPCGKGHRCVNSPGSFRCECKTGYYFDGI


HBox(children=(IntProgress(value=0, max=336), HTML(value='')))


VLLSEEEIQQTCEMLQQCEEEFIIIISGGKPLVVE
VLLSEEEIQQTCEMLQQCKEEFINDISGGKPLEVE


HBox(children=(IntProgress(value=0, max=336), HTML(value='')))


VRGFLEKESAAVSRPLNPFTAKALSGTSPDVDQPG
VRGFLEKESAIVSRPLNPFTAKALSGTSPDDVQPG


HBox(children=(IntProgress(value=0, max=336), HTML(value='')))


VVLSFSRIAIILPANNSFGPLIISLGRTVKDIFKF
VVLSFSRIAYILPANESFGPLQISLGRTVKDIFKF


HBox(children=(IntProgress(value=0, max=336), HTML(value='')))


YSFDLMLSITSIFHLCSVAIDRFYAICIYYLLSTK
YSFDLMLSITSIFHLCSVAIDRFYAICYPLLYSTK


HBox(children=(IntProgress(value=0, max=342), HTML(value='')))


HSVEEELTSVIGINKKIPPFISKGEIMNEWCFFTCLVSFS
HSVMEELTSVIGINMKIPPFISKGEIMNEWFHFTCLVSFS


HBox(children=(IntProgress(value=0, max=342), HTML(value='')))


RYLVTSLILVVTMAILCCSMQDCVRSKPPLWLLGLVTISL
RYLVTSLILVVTMAILCCSMQDCVRSKPWLGLLGLVTISL


HBox(children=(IntProgress(value=0, max=342), HTML(value='')))


RRGRPGLKGQEGPPGAPGIRTGIQGLKGDQGEPGPSGNPG
RRGRPGLKGEQGEPGAPGIRTGIQGLKGDQGEPGPSGNPG


HBox(children=(IntProgress(value=0, max=342), HTML(value='')))


PTLAPAASVAAAFQFQLLVMQPCGAADEAAAPGSGVGAGK
PTLAPASVAAAASQFTLLVMQPCAGQDEAAAPGGSVGAGK


HBox(children=(IntProgress(value=0, max=342), HTML(value='')))


PFIVLCHPDTIRSITAASAAIAPKDNLFFRFLKPWLGEGI
PFIVLCHPDTIRSITNASAAIAPKDNLFIRFLKPWLGEGI


HBox(children=(IntProgress(value=0, max=348), HTML(value='')))


EKRLEEEQLLAEEEDDDLKETTDLRKIAAQLLQQEQKNRILNHST
EKRLEKEQLLAEEEDDDLKEVTDLRKIAAQLLQQEQKNRILNHST


HBox(children=(IntProgress(value=0, max=348), HTML(value='')))


GVLENWIWQMVAALQSKAPQPVNVVLVDNITLAHHHYTIAVRNTR
GVLENWIWQMVAALKSQPAQPVNVGLVDWITLAHDHYTIAVRNTR


HBox(children=(IntProgress(value=0, max=348), HTML(value='')))


ISRENGGSSSILYYRPFEKLRMSDDGGIRNLYLDFGGPEGEDTMD
ISRENGGSSSILYRYPFERLKMSADDGIRNLYLDFGGPEGELTMD


HBox(children=(IntProgress(value=0, max=348), HTML(value='')))


ERNEILTEEQNFSQDVTLNSLVSEAFVRFFVELGGHYSLMMTVTE
ERNEILTQEQNFSQDVTLNSLVSEAFVRFFVELVGHYSLNMTVTE


HBox(children=(IntProgress(value=0, max=348), HTML(value='')))


ASLSGEREFKTPTISLKETIGKYSDDHEMRENVYHRKIISWFGDS
ASLSGEREFKTPTISLKETIGKYSDDHEMRNEVYHRKIISWFGDS


HBox(children=(IntProgress(value=0, max=348), HTML(value='')))


PHAAAAAAAAAAAAVEASSPWSGSAVGMAGSPQQPPQPPPPPPQG
PHAAAAAAAAAAAAVEASSPWSGSAVGMAGSPQQPPQPPPPPPQG


HBox(children=(IntProgress(value=0, max=354), HTML(value='')))


AEAVRPKTPPVVIKSQQTKKEDEEEISTSPGVSSFVDSAFCADDLDQEDL
AEAVRPKTPPVVIKSQLKTQEDEEEISTSPGVSEFVSDAFDACNLNQEDL


HBox(children=(IntProgress(value=0, max=354), HTML(value='')))


SCGSNCGQSSSCAPVYCRRTCYYPTTVCLPGCLNQSGGSNCCQCPCCRPC
SCGSSCGQSSSCAPVYCRRTCYYPTTVCLPGCLNQSCGSNCCQPCCRPAC


HBox(children=(IntProgress(value=0, max=354), HTML(value='')))


ENRHGGGLTGLNKAETAAKHGEAQVKIWRRSYDVPPPPMEPDHPFYSNIS
NERHYGGLTGLNKAETAAKHGEAQVKIWRRSYDVPPPPMEPDHPFYSNIS


HBox(children=(IntProgress(value=0, max=354), HTML(value='')))


QPSNLLPQRGLGAPLPAETAHTPQSPNDRSLYLSPKSSSASSSLHARQSP
QPSNLLPQRGLGAPLPAETAHTQPSPNDRSLYLSPKSSSASSSLHARQSP


HBox(children=(IntProgress(value=0, max=354), HTML(value='')))


AELEVRVAAVVDTHLEEAGGGPEPTRNGVDPPPRRAAASVPPGSTRLLLP
AELVERVAAIDVTHLEEADGGPEPTRNGVDPPPRARAASVIPGSTSRLLP


HBox(children=(IntProgress(value=0, max=354), HTML(value='')))


GCGRLLRGLLAAPAATSWSRLPARGFREVVETEQGKTTIIEGRITATPKE
GCGRLLRGLLAGPAATSWSRLPARGFREVVETQEGKTTIIEGRITATPKE


HBox(children=(IntProgress(value=0, max=354), HTML(value='')))


ETANLDGETNIKIRQGLSHTAMDQTRDVLMKLSGTIECEGPNRHLYDFTG
ETANLDGETNLKIRQGLSHTADMQTREVLMKLSGTIECEGPNRHLYDFTG


HBox(children=(IntProgress(value=0, max=354), HTML(value='')))


VCVCQHNTAGPNCERCAPFYNNRPWRPAAGQDAHECQRCDCGGHSETCHF
VCVCQHNTAGPNCERCAPFYNNRPWRPAEGQDAHECQRCDCNGHSETCHF


HBox(children=(IntProgress(value=0, max=360), HTML(value='')))

In [None]:
import matplotlib.pyplot as plt


plt.plot(losses)
plt.show()

In [12]:
dec.train(False)
enc.train(False)


x, x_inds = ds[epoch_inds[0]]
        
x = (x - 1) * log_number

hidden = enc.initHidden(batch_size).cuda(GPU_id)

#     pbar.set_description(str(x.shape[0]))

enc_out, hidden = enc(x, hidden)

#input the stop character to the stream    
out = (Variable(utils.stopChar(batch_size)).cuda(GPU_id)-1) * log_number

loss = Variable(torch.zeros(1,1,1).cuda(GPU_id))

out_list = list()
#     pdb.set_trace()
for i in range(x.shape[0]):

    out, hidden = dec(out, hidden, enc_out)    

    out_list += [out]

    loss += criterion(torch.squeeze(out,0), x_inds[i])

x_hat = utils.tensorToChar(torch.cat(out_list, 0))
x = utils.tensorToChar(x)

print(''.join(x_hat[:,2]))
print(''.join(x[:,2]))

enc.train(True)
dec.train(True)

UnboundLocalError: local variable 'embedded' referenced before assignment

In [25]:
np.log(np.exp(3.219125824868201)/28)

-0.11307868530700314

In [71]:
enc.train(False)
dec.train(False)

x_tmp, _ = ds[[np.random.randint(len(ds))]]

# x = torch.unsqueeze(x[:,0,:],1)
batch_size_tmp = x_tmp.shape[1]

hidden = enc.initHidden(batch_size_tmp).cuda(GPU_id)
out, hidden = enc(x_tmp, hidden)

#input the stop character to the stream    
out = Variable(stopChar(batch_size_tmp)).cuda(GPU_id)


#     pdb.set_trace()
out_chars = list()

for i in range(x_tmp.shape[0]):

    out, hidden = dec(out, hidden) 
    
    out_chars += [tensorToChar(out)[0,0]]
    
enc.train(True)
dec.train(True)

print(''.join(np.hstack(tensorToChar(x_tmp))))
print(''.join(out_chars))

MSLMVVSMACVGFFLLEGPW
MMHCTPCLLLMMMMMMMMMM


In [51]:
3E-4*np.log(15)

0.0008124150603306629

In [41]:
6.77e-05/np.log(15)

2.499953655676149e-05