# CS 287, Homework 3: Neural Machine Translation

In [1]:
from common import *

import torch
from torch.nn.utils import clip_grad_norm_
import torch.nn as nn
import torch.nn.functional as F
from torchtext import data, datasets
from namedtensor import ntorch, NamedTensor
from namedtensor.text import NamedField
import numpy as np
import re

%reload_ext autoreload
%autoreload 2

In [2]:
# split raw data into tokens
import spacy
spacy_de = spacy.load('de')
spacy_en = spacy.load('en')

def tokenize_de(text):
    return [tok.text for tok in spacy_de.tokenizer(text)]

def tokenize_en(text):
    return [tok.text for tok in spacy_en.tokenizer(text)]

# add beginning-of-sentence and end-of-sentence tokens to target
BOS_WORD = '<s>'
EOS_WORD = '</s>'
DE = NamedField(names=('srcSeqlen',), tokenize=tokenize_de)
EN = NamedField(names=('trgSeqlen',), tokenize=tokenize_en,
                init_token = BOS_WORD, eos_token = EOS_WORD) # only target needs BOS/EOS

# download dataset of 200K pairs of sentences
# start with MAXLEN = 20
MAX_LEN = 20
train, val, test = datasets.IWSLT.splits(exts=('.de', '.en'), fields=(DE, EN), 
                                         filter_pred=lambda x: len(vars(x)['src']) <= MAX_LEN and 
                                         len(vars(x)['trg']) <= MAX_LEN)
#print(train.fields)
#print(len(train))
#print(vars(train[0]))

# build vocab, convert words to indices
MIN_FREQ = 5
DE.build_vocab(train.src, min_freq=MIN_FREQ)
EN.build_vocab(train.trg, min_freq=MIN_FREQ)
#print(DE.vocab.freqs.most_common(10))
#print("Size of German vocab", len(DE.vocab))
#print(EN.vocab.freqs.most_common(10))
#print("Size of English vocab", len(EN.vocab))
print(EN.vocab.stoi["<s>"], EN.vocab.stoi["</s>"])
print(EN.vocab.stoi["<pad>"], EN.vocab.stoi["<unk>"])
print(DE.vocab.stoi["<pad>"], DE.vocab.stoi["<unk>"])

2 3
1 0
1 0


In [3]:
# split data into batches
BATCH_SIZE = 32
device = torch.device('cuda:0')
train_iter, val_iter = data.BucketIterator.splits((train, val), batch_size=BATCH_SIZE, device=device,
                                                  repeat=False, sort_key=lambda x: len(x.src))

## Sequence to Sequence Learning with Neural Networks

In [4]:
context_size = 500
num_layers = 2

BEAM_WIDTH = 5
max_len = 20

attn_seq2context = SequenceModel(len(DE.vocab), context_size, num_layers=num_layers)
state_dict = torch.load('best_seq2seq_seq2context.pt')
attn_seq2context.load_state_dict(state_dict)
attn_seq2context = attn_seq2context.cuda()

attn_context2trg = RNNet(input_size=len(EN.vocab), hidden_size=context_size, num_layers=num_layers)
state_dict = torch.load('best_seq2seq_context2trg.pt')
attn_context2trg.load_state_dict(state_dict)
attn_context2trg = attn_context2trg.cuda()

In [5]:
# load test set
sentences = []
for i, l in enumerate(open("source_test.txt")):
    sentences.append(re.split(' ', l))

max_sent_len = 0
for i in range(len(sentences)):
    sentences[i][-1] = '.'
    if len(sentences[i]) > max_sent_len:
        max_sent_len = len(sentences[i])

batch = torch.tensor([], device='cuda')
for b in range(len(sentences)):
    m = nn.ConstantPad1d((0, max_sent_len - len(sentences[b])), EN.vocab.stoi['<pad>'])
    src = m(torch.tensor([DE.vocab.stoi[i] for i in sentences[b]], device='cuda').unsqueeze(0)).float()
    batch = torch.cat((batch,src), dim=0)
batch_rev = reverse_sequence(batch).long()

In [6]:
BATCH_SIZE=32
batch_rev_data = torch.utils.data.TensorDataset(batch_rev)
batch_rev_data_loader = torch.utils.data.DataLoader(batch_rev_data, batch_size=BATCH_SIZE, shuffle=False)

In [31]:
def beam_search_noattn(src, attn_seq2context, attn_context2trg, BEAM_WIDTH = 2, BATCH_SIZE=32, max_len=3,context_size=500,EN=None):
    top_p = {}
    top_s = {}
    items = []
    for i in range(BATCH_SIZE):
        top_p[i] = []
        top_s[i] = []
        items.append(i)

    encoder_outputs, encoder_hidden = attn_seq2context(src)
    #decoder_context = torch.zeros(BATCH_SIZE, context_size, device='cuda') # 32 x 500
    decoder_hidden = encoder_hidden
    #word_input = (torch.zeros(BATCH_SIZE, device='cuda') + EN.vocab.stoi['<s>']).long()
    #decoder_output, decoder_context, decoder_hidden, decoder_attention = attn_context2trg(word_input, decoder_context, decoder_hidden, encoder_outputs)
    word_input = (torch.zeros((BATCH_SIZE,1), device='cuda') + EN.vocab.stoi['<s>']).long()
    decoder_output, decoder_hidden = attn_context2trg(word_input, decoder_hidden)
    decoder_output = decoder_output[:,-1,:]
    
    next_words = torch.argsort(lsm2(decoder_output), dim=1, descending=True)[:,0:BEAM_WIDTH].detach()
    p_words_init = torch.stack([torch.index_select(decoder_output[i,:],-1,next_words[i,:]) for i in range(BATCH_SIZE)]).detach()
    p_words_running = torch.stack([p_words_init[:,b].repeat(1,BEAM_WIDTH) for b in range(BEAM_WIDTH)]).view(BEAM_WIDTH**2,BATCH_SIZE).transpose(0,1)

    update = []
    for ix,p in enumerate(next_words):
        update.append([torch.stack(([torch.tensor(EN.vocab.stoi['<s>'], device='cuda')])+([next_words[ix,b]])) for b in range(BEAM_WIDTH)])

    next_words = torch.argsort(lsm2(decoder_output),dim=1, descending=True)[:,0:BEAM_WIDTH].detach()
    p_words_init = torch.stack([torch.index_select(decoder_output[i,:],-1,next_words[i,:]) for i in range(BATCH_SIZE)]).detach()
    p_words_running = torch.stack([p_words_init[:,b].repeat(1,BEAM_WIDTH) for b in range(BEAM_WIDTH)]).view(BEAM_WIDTH**2,BATCH_SIZE).transpose(0,1)
    update = []
    for ix,p in enumerate(next_words):
        update.append([torch.stack(([torch.tensor(EN.vocab.stoi['<s>'], device='cuda')])+([next_words[ix,b]])) for b in range(BEAM_WIDTH)])
    top_s.update(dict(zip(items, update)))
    top_p.update(dict(zip(items, p_words_init)))

    next_words = next_words.transpose(0,1).flatten().long()
    next_words = next_words.view(-1,1)
    #encoder_outputs = encoder_outputs.repeat(BEAM_WIDTH,1,1)
    #decoder_context = decoder_context.repeat(BEAM_WIDTH,1)
    decoder_hidden = tuple([h.repeat(1,BEAM_WIDTH,1) for h in decoder_hidden])
    mask = torch.zeros(BATCH_SIZE,BEAM_WIDTH,dtype=torch.uint8).cuda()
    
    #print(next_words.shape)
    for j in range(max_len-1):
            #decoder_output, decoder_context, decoder_hidden, decoder_attention = attn_context2trg(next_words, decoder_context, decoder_hidden, encoder_outputs)
            decoder_output, decoder_hidden = attn_context2trg(next_words, decoder_hidden)
            decoder_output = decoder_output[:,-1,:]
            
            args = torch.argsort(lsm2(decoder_output),dim=1, descending=True)[:,0:BEAM_WIDTH].detach()
            next_words = torch.cat([args[BATCH_SIZE*(b):BATCH_SIZE*(b+1),:] for b in range(BEAM_WIDTH)],dim=1).detach()
            p_words = torch.stack([torch.index_select(decoder_output[i,:],-1,next_words[i,:]) for i in range(BATCH_SIZE)])
            p_words_running += p_words.detach()
            p_words_norm = p_words_running/(j+1)

            word_selector = torch.argsort(p_words_norm,dim=1,descending=True)[:,:BEAM_WIDTH]        
            beam_indicator = (word_selector/BEAM_WIDTH).float().long() #word_selector>=2
            
            prev_words = list(top_s.values())
            words = [torch.stack([prev_words[s][i] for i in beam_indicator[s,:]]) for s in range(BATCH_SIZE)]
            update = []
            for ix,p in enumerate(words):
                update.append([torch.cat((p[b],next_words[ix,b].unsqueeze(0))) for b in range(BEAM_WIDTH)])
            top_s.update(dict(zip(items, update)))
            mask += torch.stack([next_words[s,:].index_select(0,word_selector[s,:]) for s in range(BATCH_SIZE)]) == 3
            update_p = torch.stack([torch.index_select(p_words_running[b], 0, word_selector[b]) for b in range(BATCH_SIZE)])
            top_p.update(dict(zip(items, update_p)))
            p_words_running = torch.stack([update_p[:,b].repeat(1,BEAM_WIDTH) for b in range(BEAM_WIDTH)]).view(BEAM_WIDTH**2,BATCH_SIZE).transpose(0,1)

            indexs = torch.zeros(BATCH_SIZE,BEAM_WIDTH,device='cuda')
            for i in range(BATCH_SIZE):
                indexs[i,:] += i+(BATCH_SIZE*beam_indicator[i,:].float())
            indexs = indexs.long()
            indexs = indexs.transpose(0,1).flatten()
            #print(indexs.shape)
            decoder_hidden = tuple([torch.index_select(h,1,indexs) for h in decoder_hidden])
            #decoder_context = torch.index_select(decoder_context,0,indexs)
            next_words = torch.stack([torch.index_select(next_words[s,:],0,word_selector[s,:]) for s in range(BATCH_SIZE)]).transpose(0,1).flatten().long()
            next_words = next_words.view(-1,1)
    return top_s

In [40]:
with open("pred_bleu_update.txt", "w+") as f:
    for b in batch_rev_data_loader:  
        src = b[0].long()
        top_s = beam_search_noattn(src, attn_seq2context, attn_context2trg, BEAM_WIDTH=BEAM_WIDTH, BATCH_SIZE=BATCH_SIZE, max_len=max_len, context_size=context_size, EN=EN)
        for i in range(BATCH_SIZE):
            stop_idx = torch.tensor([max_len], device='cuda') if (top_s[i][0].long() == 3).nonzero().size()[0]==0 else min((top_s[i][0].long() == 3).nonzero())
            f.write("%s\n"%(" ".join([EN.vocab.itos[j] for j in top_s[i][0][1:stop_idx].long()])))

In [34]:
top_s[0][0]

tensor([   2,  247,   10,   23, 3277,   17,   50,    0,    5,   10,   23,   50,
         132,    0,    4,    3,    4,    3,   12,   46,  319], device='cuda:0')

- English to French translation, $p \left( y_1, \dots, y_{T'} \ | \ x_1, \dots, x_T \right) = \prod_{t = 1}^{T'} p \left( y_t \ | \ v, y_1, \dots, y_{t-1} \right)$
- Each sentence ends in '<EOS\>', out-of-vocab words denoted '<UNK\>'
- Model specs: 
    * Input vocabulary of 160,000 and output vocabulary of 80,000
    * Deep LSTM to map (encode) input sequence to fixed-len vector
    * Another deep LSTM to translate (decode) fixed-len vector to output sequence
    * 4 layers per LSTM, 1000 cells per layer, 1000-dimensional word embeddings, softmax over 80,000 words
    * Reversing order of words in source (but not target) improved performance
        * Each word in the source is far from its corresponding word in the target (large minimal time lag); reversing the source reduces the minimal time lag, thereby allowing backprop to establish communication between source and target more easily
- Training specs:
    * Initialize all LSTM params $\sim Unif[-0.08,0.08]$
    * SGD w/o momentum, lr = 0.7
        * After 5 epochs, halve the lr every half-epoch
        * Train for 7.5 epochs
    * Batch size = 128; divide gradient by batch size (denoted $g$)
    * Hard constraint gradient norm; if $s = ||g||_2 > 5$, set $s = 5$
    * Make sure all sentences within a minibatch are roughly the same length
- Objective: $max \frac{1}{|S|} \sum_{(T,S) \in \mathcal{S}} log \ p(T \ | \ S)$, where $\mathcal{S}$ is the training set
- Prediction: $\hat{T} = argmax \ p(T \ | \ S)$ via beam search, where beam size $B \in {1,2}$

In [30]:
'''class SequenceModel(nn.Module):
    def __init__(self, src_vocab_size, context_size, num_layers, weight_init = 0.08):
        super(SequenceModel, self).__init__()
        # embedding
        self.embedding = nn.Embedding(src_vocab_size, context_size)
        # language summarization
        self.lstm = nn.LSTM(input_size=context_size, hidden_size=context_size, num_layers=num_layers, batch_first=True)
        for p in self.lstm.parameters():
            torch.nn.init.uniform_(p, a=weight_init, b=weight_init)

    def forward(self, inputs, h0=None):
        # embed the words 
        embedded = self.embedding(inputs)
        # summarize context
        context, hidden = self.lstm(embedded,h0)
        return context, hidden
    
class LanguageModel(nn.Module):
    def __init__(self, target_vocab_size, hidden_size, context_size, num_layers, weight_init = 0.08):
        super(LanguageModel, self).__init__()
        # context is batch_size x seq_len x context_size
        # context to hidden
        self.embedding = nn.Embedding(target_vocab_size, hidden_size)
        # hidden to hidden 
        self.lstm = nn.LSTM(input_size=hidden_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=True)
        # decode hidden state for y_t
        for p in self.lstm.parameters():
            torch.nn.init.uniform_(p, a=weight_init, b=weight_init)
            
        self.translate = nn.Linear(hidden_size, target_vocab_size)

    def forward(self, inputs, h0=None):
        # embed the trg words
        embedded = self.embedding(inputs)
        # setting hidden state to context at t=0
        # otherwise context = prev hidden state
        output, hidden = self.lstm(embedded, h0)
        output = self.translate(output)
        return output,hidden'''

'class SequenceModel(nn.Module):\n    def __init__(self, src_vocab_size, context_size, num_layers, weight_init = 0.08):\n        super(SequenceModel, self).__init__()\n        # embedding\n        self.embedding = nn.Embedding(src_vocab_size, context_size)\n        # language summarization\n        self.lstm = nn.LSTM(input_size=context_size, hidden_size=context_size, num_layers=num_layers, batch_first=True)\n        for p in self.lstm.parameters():\n            torch.nn.init.uniform_(p, a=weight_init, b=weight_init)\n\n    def forward(self, inputs, h0=None):\n        # embed the words \n        embedded = self.embedding(inputs)\n        # summarize context\n        context, hidden = self.lstm(embedded,h0)\n        return context, hidden\n    \nclass LanguageModel(nn.Module):\n    def __init__(self, target_vocab_size, hidden_size, context_size, num_layers, weight_init = 0.08):\n        super(LanguageModel, self).__init__()\n        # context is batch_size x seq_len x context_size\n    

In [31]:
'''def repackage_hidden(h):
    return tuple(v.detach() for v in h)
def repackage_layer(hidden_s2c,hidden=100):
    return tuple([hidden_s2c[0][-1].detach().view(1,BATCH_SIZE,hidden),hidden_s2c[1][-1].detach().view(1,BATCH_SIZE,hidden)])
def reverse_sequence(src):
    length = list(src.shape)[1]
    idx = torch.linspace(length-1, 0, steps=length).long()
    rev_src = src[:,idx]
    return rev_src
'''

'def repackage_hidden(h):\n    return tuple(v.detach() for v in h)\ndef repackage_layer(hidden_s2c,hidden=100):\n    return tuple([hidden_s2c[0][-1].detach().view(1,BATCH_SIZE,hidden),hidden_s2c[1][-1].detach().view(1,BATCH_SIZE,hidden)])\ndef reverse_sequence(src):\n    length = list(src.shape)[1]\n    idx = torch.linspace(length-1, 0, steps=length).long()\n    rev_src = src[:,idx]\n    return rev_src'

In [32]:
'''context_size = 500
num_layers = 1
seq2context = SequenceModel(len(DE.vocab),context_size,num_layers)
context2trg = LanguageModel(len(EN.vocab),hidden_size=context_size,context_size=context_size,num_layers=num_layers)
seq2context,context2trg = seq2context.cuda(),context2trg.cuda()
seq2context_optimizer = torch.optim.Adam(seq2context.parameters(), lr=1e-5)
context2trg_optimizer = torch.optim.Adam(context2trg.parameters(), lr=1e-5)
criterion = nn.CrossEntropyLoss(reduction='none')'''

"context_size = 500\nnum_layers = 1\nseq2context = SequenceModel(len(DE.vocab),context_size,num_layers)\ncontext2trg = LanguageModel(len(EN.vocab),hidden_size=context_size,context_size=context_size,num_layers=num_layers)\nseq2context,context2trg = seq2context.cuda(),context2trg.cuda()\nseq2context_optimizer = torch.optim.Adam(seq2context.parameters(), lr=1e-5)\ncontext2trg_optimizer = torch.optim.Adam(context2trg.parameters(), lr=1e-5)\ncriterion = nn.CrossEntropyLoss(reduction='none')"

In [33]:
'''def training_loop(e=0):
    seq2context.train()
    context2trg.train()
    h0 = None
    for ix,batch in enumerate(train_iter):
        seq2context_optimizer.zero_grad()
        context2trg_optimizer.zero_grad()
        
        src = batch.src.values.transpose(0,1)
        src = reverse_sequence(src)
        trg = batch.trg.values.transpose(0,1)
        if src.shape[0]!=BATCH_SIZE:
            break
        else:
            # generate hidden state for decoder
            context, hidden_s2c = seq2context(src,h0)
            hidden = repackage_layer(hidden_s2c,context_size)
            output, hidden_lm = context2trg(trg[:,:-1],hidden)
            loss = criterion(output.transpose(2,1),trg[:,1:])
            mask = trg[:,1:]!=1
            loss = loss[mask].sum()
            #clip_grad_norm_(seq2context.parameters(), max_norm=5)
            #clip_grad_norm_(context2trg.parameters(), max_norm=5)
            loss.backward()
            seq2context_optimizer.step()
            context2trg_optimizer.step()
        if np.mod(ix,100) == 0:
            var = torch.var(torch.argmax(lsm(output).cpu().detach(),2).float())
            print('Epoch: {}, Batch: {}, loss: {}, var: {},'.format(e, ix, loss.cpu().detach()/BATCH_SIZE, var))
    loss = 0
    for b in iter(val_iter):
        src = b.src.values.transpose(0,1)
        src = reverse_sequence(src)
        trg = b.trg.values.transpose(0,1)
        if src.shape[0]!=BATCH_SIZE:
            break
        else:
            # generate hidden state for decoder
            context, hidden_s2c = seq2context(src,h0)
            hidden = repackage_layer(hidden_s2c,context_size)
            output, hidden_lm = context2trg(trg[:,:-1],hidden)
            bloss = criterion(output.transpose(2,1),trg[:,1:])
            mask = trg[:,1:]!=1
            loss += bloss[mask].sum()
    print('Epoch: {}, loss: {}, var: {},'.format(e, loss.cpu().detach()/(BATCH_SIZE*len(val_iter))))'''

"def training_loop(e=0):\n    seq2context.train()\n    context2trg.train()\n    h0 = None\n    for ix,batch in enumerate(train_iter):\n        seq2context_optimizer.zero_grad()\n        context2trg_optimizer.zero_grad()\n        \n        src = batch.src.values.transpose(0,1)\n        src = reverse_sequence(src)\n        trg = batch.trg.values.transpose(0,1)\n        if src.shape[0]!=BATCH_SIZE:\n            break\n        else:\n            # generate hidden state for decoder\n            context, hidden_s2c = seq2context(src,h0)\n            hidden = repackage_layer(hidden_s2c,context_size)\n            output, hidden_lm = context2trg(trg[:,:-1],hidden)\n            loss = criterion(output.transpose(2,1),trg[:,1:])\n            mask = trg[:,1:]!=1\n            loss = loss[mask].sum()\n            #clip_grad_norm_(seq2context.parameters(), max_norm=5)\n            #clip_grad_norm_(context2trg.parameters(), max_norm=5)\n            loss.backward()\n            seq2context_optimizer.ste

In [34]:
'''for e in range(2):
    training_loop(e)
    #training_loop(e,train_iter,seq2context,context2trg,seq2context_optimizer,context2trg_optimizer,BATCH_SIZE)
    #validation_loop(e,val_iter,seq2context,context2trg,BATCH_SIZE)'''

'for e in range(2):\n    training_loop(e)\n    #training_loop(e,train_iter,seq2context,context2trg,seq2context_optimizer,context2trg_optimizer,BATCH_SIZE)\n    #validation_loop(e,val_iter,seq2context,context2trg,BATCH_SIZE)'

In [35]:
'''for ix,batch in enumerate(train_iter):
    src = batch.src.values.transpose(0,1)
    trg = batch.trg.values.transpose(0,1)
    break

h0 = None
context, hidden_s2c = seq2context(reverse_sequence(src),h0)
hidden = repackage_layer(hidden_s2c,context_size)
output, hidden_lm = context2trg(trg[:,:-1],hidden)

[EN.vocab.itos[i] for i in torch.argmax(lsm(output),2)[30,:]]'''

'for ix,batch in enumerate(train_iter):\n    src = batch.src.values.transpose(0,1)\n    trg = batch.trg.values.transpose(0,1)\n    break\n\nh0 = None\ncontext, hidden_s2c = seq2context(reverse_sequence(src),h0)\nhidden = repackage_layer(hidden_s2c,context_size)\noutput, hidden_lm = context2trg(trg[:,:-1],hidden)\n\n[EN.vocab.itos[i] for i in torch.argmax(lsm(output),2)[30,:]]'

In [36]:
'''# define attention-based encoder-decoder model
class attn_RNNet_batched(torch.nn.Module):

    def __init__(self, input_size, hidden_size, num_layers, dropout=0.5, weight_init=0.05):
        super(attn_RNNet_batched, self).__init__()
        self.emb = torch.nn.Sequential(torch.nn.Embedding(input_size, hidden_size), torch.nn.Dropout(dropout))
        self.rnn = torch.nn.LSTM(input_size=2*hidden_size, hidden_size=hidden_size, num_layers=num_layers, bias=True, batch_first=True, dropout=dropout)
        self.lnr = torch.nn.Sequential(torch.nn.Dropout(dropout), torch.nn.Linear(2*hidden_size, input_size))
    
        for f in self.parameters():
            torch.nn.init.uniform_(f, a=-weight_init, b=weight_init)

    def forward(self, word_input, last_context, last_hidden, encoder_outputs):
        word_embedded = self.emb(word_input)
        rnn_input = torch.cat([word_embedded, last_context], 1).unsqueeze(1) # batch x 1 x hiddenx2
        rnn_output, hidden = self.rnn(rnn_input, last_hidden)
        attn_weights = rnn_output.bmm(encoder_outputs.transpose(1,2))# batch x src_seqlen x 1
        context = attn_weights.bmm(encoder_outputs)
        rnn_output = rnn_output.squeeze(1)
        context = context.squeeze(1)
        output = self.lnr(torch.cat((rnn_output, context), 1))
        # prediction, last_context, last_hidden, weights for vis
        return output, context, hidden, attn_weights '''

'# define attention-based encoder-decoder model\nclass attn_RNNet_batched(torch.nn.Module):\n\n    def __init__(self, input_size, hidden_size, num_layers, dropout=0.5, weight_init=0.05):\n        super(attn_RNNet_batched, self).__init__()\n        self.emb = torch.nn.Sequential(torch.nn.Embedding(input_size, hidden_size), torch.nn.Dropout(dropout))\n        self.rnn = torch.nn.LSTM(input_size=2*hidden_size, hidden_size=hidden_size, num_layers=num_layers, bias=True, batch_first=True, dropout=dropout)\n        self.lnr = torch.nn.Sequential(torch.nn.Dropout(dropout), torch.nn.Linear(2*hidden_size, input_size))\n    \n        for f in self.parameters():\n            torch.nn.init.uniform_(f, a=-weight_init, b=weight_init)\n\n    def forward(self, word_input, last_context, last_hidden, encoder_outputs):\n        word_embedded = self.emb(word_input)\n        rnn_input = torch.cat([word_embedded, last_context], 1).unsqueeze(1) # batch x 1 x hiddenx2\n        rnn_output, hidden = self.rnn(rnn

In [37]:
'''# initialize model
context_size = 500
num_layers = 2
attn_context2trg = attn_RNNet_batched(input_size=len(EN.vocab),hidden_size=context_size,num_layers=num_layers)
attn_context2trg = attn_context2trg.cuda()
seq2context = SequenceModel(len(DE.vocab),context_size,num_layers=num_layers)
seq2context = seq2context.cuda()'''

'# initialize model\ncontext_size = 500\nnum_layers = 2\nattn_context2trg = attn_RNNet_batched(input_size=len(EN.vocab),hidden_size=context_size,num_layers=num_layers)\nattn_context2trg = attn_context2trg.cuda()\nseq2context = SequenceModel(len(DE.vocab),context_size,num_layers=num_layers)\nseq2context = seq2context.cuda()'

In [38]:
'''# prep for training
attn_context2trg_optimizer = torch.optim.Adam(attn_context2trg.parameters(), lr=1e-3)
seq2context_optimizer = torch.optim.Adam(seq2context.parameters(), lr=1e-3)
criterion_train = nn.CrossEntropyLoss(reduction='sum')'''

"# prep for training\nattn_context2trg_optimizer = torch.optim.Adam(attn_context2trg.parameters(), lr=1e-3)\nseq2context_optimizer = torch.optim.Adam(seq2context.parameters(), lr=1e-3)\ncriterion_train = nn.CrossEntropyLoss(reduction='sum')"

In [39]:
'''def attn_training_loop(e=0):
    for ix,batch in enumerate(train_iter):
        src = batch.src.values.transpose(0,1)
        src = reverse_sequence(src)
        trg = batch.trg.values.transpose(0,1)
        if trg.shape[0] == BATCH_SIZE:
        
            seq2context_optimizer.zero_grad()
            attn_context2trg_optimizer.zero_grad()
        
            encoder_outputs, encoder_hidden = seq2context(src)
            loss = 0
            decoder_context = torch.zeros(BATCH_SIZE, context_size, device='cuda') # 32 x 500
            decoder_hidden = encoder_hidden
            sentence = []
            for j in range(trg.shape[1] - 1):
                word_input = trg[:,j]
                decoder_output, decoder_context, decoder_hidden, decoder_attention = attn_context2trg(word_input, decoder_context, decoder_hidden, encoder_outputs)
                #print(decoder_output.shape, trg[i,j+1].view(-1).shape)
                loss += criterion_train(decoder_output, trg[:,j+1])
                
                if np.mod(ix,100) == 0:
                    sentence.extend([torch.argmax(decoder_output[0,:],dim=0)])
                
            loss.backward()
            seq2context_optimizer.step()
            attn_context2trg_optimizer.step()
        
            if np.mod(ix,100) == 0:
                print('Epoch: {}, Batch: {}, Loss: {}'.format(e, ix, loss.cpu().detach()/BATCH_SIZE))
                print([EN.vocab.itos[i] for i in sentence])
                print([EN.vocab.itos[i] for i in trg[0,:]])'''

"def attn_training_loop(e=0):\n    for ix,batch in enumerate(train_iter):\n        src = batch.src.values.transpose(0,1)\n        src = reverse_sequence(src)\n        trg = batch.trg.values.transpose(0,1)\n        if trg.shape[0] == BATCH_SIZE:\n        \n            seq2context_optimizer.zero_grad()\n            attn_context2trg_optimizer.zero_grad()\n        \n            encoder_outputs, encoder_hidden = seq2context(src)\n            loss = 0\n            decoder_context = torch.zeros(BATCH_SIZE, context_size, device='cuda') # 32 x 500\n            decoder_hidden = encoder_hidden\n            sentence = []\n            for j in range(trg.shape[1] - 1):\n                word_input = trg[:,j]\n                decoder_output, decoder_context, decoder_hidden, decoder_attention = attn_context2trg(word_input, decoder_context, decoder_hidden, encoder_outputs)\n                #print(decoder_output.shape, trg[i,j+1].view(-1).shape)\n                loss += criterion_train(decoder_output, tr

In [40]:
'''for e in range(10):
    attn_training_loop(e)'''

'for e in range(10):\n    attn_training_loop(e)'

## Beam Search for common.py

In [50]:
'''def beamsearch(seq2context, context2trg, context_size, src, beam_width, max_len, output_width=1, alpha=1, padding=False):
    '''
    run beam search and return top predictions
        - seq2context: encoder model
        - context2trg: decoder model
        - context_size: hidden size
        - src: tensor of source sentences
        - beam_width: beam search width
        - max_len: maximum length for predictions
        - output_width: number of predictions to return per sentence <= beam_width
        - alpha: string length discount rate; e.g., normalizing factor = 1/(T^alpha)
        - padding: pad predictions to max_len
    '''
    # set up
    START_TKN = EN.vocab.stoi["<s>"]
    END_TKN = EN.vocab.stoi["</s>"]
    BEAM_WIDTH = beam_width
    lsm = nn.LogSoftmax(dim=1)
    
    # run forward pass of encoder once
    encoder_outputs, encoder_hidden = seq2context(src)
    decoder_context = torch.zeros(BATCH_SIZE, context_size, device='cuda') # 32 x 500
    decoder_hidden = encoder_hidden
    
    # prepare for beam search
    b_string = torch.zeros((BATCH_SIZE, max_len, BEAM_WIDTH), device='cuda') # stores the top BEAM_WIDTH strings
    b_string[:,0,:] = START_TKN
    b_probs = {} # stores the top BEAM_WIDTH probs
    '''
    b_probs key = tuple(batch idx, beam idx)
    b_probs val = [cum log prob, length]
    '''
    done = {} # stores the finished strings
    '''
    done key = batch idx
    done val = [str, cum log prob, length]
    '''
    predictions = {} # stores the top output_width predictions
    for b in range(BATCH_SIZE):
        done[b] = []
        predictions[b] = []
        for c in range(BEAM_WIDTH):
            b_probs[(b, c)] = [0, 1]

    # loop through target sequence max len
    for i in range(1,max_len):
        if i == 1: # if predicting the word following <s>, take top BEAM_WIDTH preds
            word_input = b_string[:,i-1,0].long()
            decoder_output, decoder_context, decoder_hidden, decoder_attention = context2trg(word_input, 
                                                                                             decoder_context, 
                                                                                             decoder_hidden, 
                                                                                             encoder_outputs)
            logprobs = lsm(decoder_output.detach()) # BATCH_SIZE x VOCAB_SIZE
            toppreds = torch.argsort(logprobs, dim=1, descending=True)[:,0:BEAM_WIDTH] # BATCH_SIZE x BEAM_WIDTH
            b_string[:,i,:] = toppreds
            for b in range(BATCH_SIZE):
                for c in range(BEAM_WIDTH):
                    b_probs[tuple((b,c))][0] += logprobs[b, toppreds[b,c]]
                    b_probs[tuple((b,c))][1] += 1
        else: # if predicting the word for positions 2+, compare top BEAM_WIDTH preds for each of BEAM_WIDTH strings
            curr_probs = {} # temporary storage
            curr_string = torch.zeros(BATCH_SIZE, i+1, BEAM_WIDTH) # temporary storage

            for j in range(BEAM_WIDTH):
                word_input = b_string[:,i-1,j].long()
                decoder_output, decoder_context, decoder_hidden, decoder_attention = context2trg(word_input, 
                                                                                                 decoder_context, 
                                                                                                 decoder_hidden, 
                                                                                                 encoder_outputs)
                logprobs = lsm(decoder_output.detach()) # unsorted log probs
                sortedpreds = torch.argsort(logprobs, dim=1, descending=True) # sorted words
                toppreds = sortedpreds[:,0:BEAM_WIDTH] # top words

                # check if any top preds are </s>
                for b in range(BATCH_SIZE):
                    if END_TKN in toppreds[b,:]: # if </s> in top preds
                        # track finished strings
                        done_string = torch.cat((b_string[b,0:i,j],torch.tensor([END_TKN], device='cuda').float()))
                        done_prob = b_probs[tuple((b,j))][0] + logprobs[b,END_TKN]
                        done[b].append([done_string, done_prob, done_string.shape[0]])
                        # replace </s> with 4th best pred
                        done_idx = (toppreds[b,:] == END_TKN).nonzero()
                        toppreds[b,done_idx] = sortedpreds[b,BEAM_WIDTH]

                if j == 0: # if preds are from first beam, take top BEAM_WIDTH preds (temporarily)
                    for b in range(BATCH_SIZE):
                        for c in range(BEAM_WIDTH):
                            new_b_prob = b_probs[tuple((b,j))][0] + logprobs[b,toppreds[b,c]]
                            curr_probs[tuple((b,c))] = new_b_prob # set top prob
                            curr_string[b,0:i,c] = b_string[b,0:i,j] # set sentence
                            curr_string[b,i,c] = toppreds[b,c] # set top word
                else: # if preds are from subsequent beams, compare to existing
                    for b in range(BATCH_SIZE):
                        for c in range(BEAM_WIDTH): # proposed strings
                            replaced = False
                            for d in range(BEAM_WIDTH): # existing strings
                                new_b_prob = b_probs[tuple((b,j))][0] + logprobs[b,toppreds[b,c]]
                                if new_b_prob > curr_probs[tuple((b,d))] and not replaced:
                                    curr_probs[tuple((b,d))] = new_b_prob # update top prob
                                    curr_string[b,0:i,d] = b_string[b,0:i,j] # update sentence
                                    curr_string[b,i,d] = toppreds[b,c] # update top word
                                    replaced = True                        
            b_string[:,0:i+1,:] = curr_string
            # update top strings, probs
            for b in range(BATCH_SIZE):
                for c in range(BEAM_WIDTH):
                    b_probs[tuple((b,c))][0] = curr_probs[tuple((b,c))]
                    b_probs[tuple((b,c))][1] += 1

    K = output_width
    for b in range(BATCH_SIZE):
        if len(done[b]) < K:
            gap = K - len(done[b])
            probs = torch.tensor([b_probs[tuple((b,j))][0] for j in range(BEAM_WIDTH)], device='cuda')
            for c in torch.argsort(probs, descending=True)[0:gap]:  
                d = c.item()
                done_string = b_string[b,:,d].long()
                #print(b_probs[tuple((b,d))])
                done_prob = b_probs[tuple((b,d))][0]
                done_len = b_probs[tuple((b,d))][1]
                done[b].append([done_string, done_prob, done_len])
                        
    for b in range(BATCH_SIZE):
        normalized_probs = torch.tensor([], device='cuda')
        for sentence in range(len(done[b])):
            normalized = torch.tensor([done[b][sentence][1]/done[b][sentence][2]**alpha], device='cuda')
            normalized_probs = torch.cat((normalized_probs,normalized),0)
        top = torch.argsort(normalized_probs, descending=True)[0:K]
        for k in range(K):
            best = done[b][top[k]]
            if padding:
                m = nn.ConstantPad1d((0, max_len - best[2]), EN.vocab.stoi['<pad>'])
                predictions[b].append(m(best[0].long()))
            else:
                predictions[b].append(best[0].long())
            #print([EN.vocab.itos[i] for i in best[0].long()])
    
    return predictions'''

In [51]:
'''context_size = 500
num_layers = 2

attn_seq2context = SequenceModel(len(DE.vocab),context_size,num_layers=num_layers)
#state_dict = torch.load('best_seq2seq_withattn_seq2context.pt')
#attn_seq2context.load_state_dict(state_dict)
attn_seq2context = attn_seq2context.cuda()

attn_context2trg = attn_RNNet_batched(input_size=len(EN.vocab),hidden_size=context_size,num_layers=num_layers)
#state_dict = torch.load('best_seq2seq_withattn_context2trg.pt')
#attn_context2trg.load_state_dict(state_dict)
attn_context2trg = attn_context2trg.cuda()'''

In [None]:
'''# run beam search on one batch
it = iter(train_iter)
batch = next(it)
src = batch.src.values.transpose(0,1)
src = reverse_sequence(src)
beam_width = 3
max_len = src.shape[1] # restrict target sentence length to source sentence length
beamsearch(attn_seq2context, attn_context2trg, context_size, src, beam_width, max_len, padding=True)'''

## Beam Search

In [193]:
'''# run forward pass of encoder once
encoder_outputs, encoder_hidden = attn_seq2context(src)
decoder_context = torch.zeros(BATCH_SIZE, context_size, device='cuda') # 32 x 500
decoder_hidden = encoder_hidden'''

In [194]:
'''# prepare for beam search
START_TKN = EN.vocab.stoi["<s>"]
END_TKN = EN.vocab.stoi["</s>"]
BEAM_WIDTH = 3
lsm = nn.LogSoftmax(dim=1)

b_string = torch.zeros((BATCH_SIZE, max_len, BEAM_WIDTH), device='cuda')
b_string[:,0,:] = START_TKN

b_probs = {}
# b_probs key = tuple(batch idx, beam idx)
# b_probs val = [cum log prob, length]
done = {} # stores the finished strings
# done key = batch idx
# done val = [str, cum log prob, length]
predictions = {} # stores the top BEAM_WIDTH predictions
for b in range(BATCH_SIZE):
    done[b] = []
    predictions[b] = []
    for c in range(BEAM_WIDTH):
        b_probs[(b, c)] = [0, 1]'''

In [195]:
'''# loop through target sequence max len
for i in range(1,max_len):
    if i == 1: # if predicting the word following <s>, take top BEAM_WIDTH preds
        word_input = b_string[:,i-1,0].long()
        #print(word_input)
        decoder_output, decoder_context, decoder_hidden, decoder_attention = attn_context2trg(word_input, 
                                                                                                  decoder_context, 
                                                                                                  decoder_hidden, 
                                                                                                  encoder_outputs)
        logprobs = lsm(decoder_output.detach()) # BATCH_SIZE x VOCAB_SIZE
        #print(logprobs[0,:])
        toppreds = torch.argsort(logprobs, dim=1, descending=True)[:,0:BEAM_WIDTH] # BATCH_SIZE x BEAM_WIDTH
        #print(toppreds[0,:])
        #print(logprobs[0,:][toppreds[0,:]])
        b_string[:,i,:] = toppreds
        for b in range(BATCH_SIZE):
            for c in range(BEAM_WIDTH):
                b_probs[tuple((b,c))][0] += logprobs[b, toppreds[b,c]]
                b_probs[tuple((b,c))][1] += 1
    else: # if predicting the word for positions 2+, compare top BEAM_WIDTH preds for each of BEAM_WIDTH strings
        # temporary storage
        curr_probs = {}
        curr_string = torch.zeros(BATCH_SIZE, i+1, BEAM_WIDTH)

        for j in range(BEAM_WIDTH):
            word_input = b_string[:,i-1,j].long()
            decoder_output, decoder_context, decoder_hidden, decoder_attention = attn_context2trg(word_input, 
                                                                                                      decoder_context, 
                                                                                                      decoder_hidden, 
                                                                                                      encoder_outputs)
            logprobs = lsm(decoder_output.detach()) # unsorted log probs
            sortedpreds = torch.argsort(logprobs, dim=1, descending=True) # sorted words
            toppreds = sortedpreds[:,0:BEAM_WIDTH] # top words
            
            # check if any top preds are </s>
            for b in range(BATCH_SIZE):
                if END_TKN in toppreds[b,:]: # if </s> in top preds
                    # track finished strings
                    done_string = torch.cat((b_string[b,0:i,j],torch.tensor([END_TKN], device='cuda').float()))
                    done_prob = b_probs[tuple((b,j))][0] + logprobs[b,END_TKN]
                    done[b].append([done_string, done_prob, done_string.shape[0]])
                    # replace </s> with 4th best pred
                    done_idx = (toppreds[b,:] == END_TKN).nonzero()
                    toppreds[b,done_idx] = sortedpreds[b,BEAM_WIDTH]
               
            if j == 0: # if preds are from first beam, take top BEAM_WIDTH preds (temporarily)
                for b in range(BATCH_SIZE):
                    for c in range(BEAM_WIDTH):
                        new_b_prob = b_probs[tuple((b,j))][0] + logprobs[b,toppreds[b,c]]
                        curr_probs[tuple((b,c))] = new_b_prob # set top prob
                        curr_string[b,0:i,c] = b_string[b,0:i,j] # set sentence
                        curr_string[b,i,c] = toppreds[b,c] # set top word
            else:
                for b in range(BATCH_SIZE):
                    for c in range(BEAM_WIDTH): # proposed strings
                        replaced = False
                        for d in range(BEAM_WIDTH): # existing strings
                            new_b_prob = b_probs[tuple((b,j))][0] + logprobs[b,toppreds[b,c]]
                            if new_b_prob > curr_probs[tuple((b,d))] and not replaced:
                                curr_probs[tuple((b,d))] = new_b_prob # update top prob
                                curr_string[b,0:i,d] = b_string[b,0:i,j] # update sentence
                                curr_string[b,i,d] = toppreds[b,c] # update top word
                                replaced = True
        #print(b_string[:,0:i+2,:].shape, curr_string.shape)                        
        b_string[:,0:i+1,:] = curr_string
        for b in range(BATCH_SIZE):
            for c in range(BEAM_WIDTH):
                b_probs[tuple((b,c))][0] = curr_probs[tuple((b,c))]
                b_probs[tuple((b,c))][1] += 1'''

In [147]:
'''for b in range(BATCH_SIZE):
    for c in range(BEAM_WIDTH):
        if b_string[b,-1,c] == END_TKN:
            done_string = b_string[b,:,c]
            done_prob = b_probs[tuple((b,c))][0]
            done_len = b_probs[tuple((b,c))][1]
            done[b].append([done_string, done_prob, done_len])'''

In [196]:
'''alpha = 0.7
K = 3
for b in range(BATCH_SIZE):
    normalized_probs = torch.tensor([], device='cuda')
    for sentence in range(len(done[b])):
        normalized = torch.tensor([done[b][sentence][1]/done[b][sentence][2]**alpha], device='cuda')
        normalized_probs = torch.cat((normalized_probs,normalized),0)
    top = torch.argsort(normalized_probs, descending=True)[0:K]
    for k in range(K):
        best = done[b][top[k]]
        if padding:
            m = nn.ConstantPad1d((0, max_len - best[2]), EN.vocab.stoi['<pad>'])
            predictions[b].append(m(best[0].long()))
        else:
            predictions[b].append(best[0].long())
        print([EN.vocab.itos[i] for i in best[0].long()])'''

['<s>', 'The', 'problem', 'was', 'very', 'expensive', 'and', 'it', '.', '</s>']
['<s>', 'The', 'problem', 'is', 'it', ',', 'and', 'use', 'it', 'goes', 'very', 'difficult', '.', '</s>']
['<s>', 'The', 'problem', 'is', 'it', ',', 'and', 'use', 'it', '.', '</s>']
['<s>', 'And', 'the', 'result', 'is', 'the', '<unk>', '.', '</s>']
['<s>', 'And', 'the', 'outcome', 'is', '<unk>', '.', '</s>']
['<s>', 'And', 'the', 'result', 'is', 'the', '<unk>', '</s>']
['<s>', 'So', 'ca', "n't", 'do', 'that', "'s", '<unk>', '.', '</s>']
['<s>', 'So', 'ca', "n't", 'do', 'that', "'s", 'way', '.', '</s>']
['<s>', 'So', 'ca', "n't", 'do', 'that', "'s", 'difference', '.', '</s>']
['<s>', 'Of', 'course', 'we', '<unk>', '.', '</s>']
['<s>', 'Of', 'course', 'we', 'do', 'dawn', '.', '</s>']
['<s>', 'Of', 'course', 'we', 'do', 'dawn', 'are', "n't", '.', '</s>']
['<s>', 'And', 'it', 'agreed', ',', 'and', 'everyone', '.', '</s>']
['<s>', 'And', 'it', 'agreed', ',', 'and', 'agreed', '.', '</s>']
['<s>', 'And', 'it', 'agr

In [178]:
encoder_outputs, encoder_hidden = attn_seq2context(src)
decoder_context = torch.zeros(BATCH_SIZE, context_size, device='cuda') # 32 x 500
decoder_hidden = encoder_hidden

sentence = []
trg = batch.trg.values.transpose(0,1)

b = 0
for j in range(trg.shape[1] - 1):
    word_input = trg[:,j]
    decoder_output, decoder_context, decoder_hidden, decoder_attention = attn_context2trg(word_input, decoder_context, decoder_hidden, encoder_outputs)
    sentence.extend([torch.argmax(decoder_output.detach()[b,:],dim=0)])
print([EN.vocab.itos[i] for i in sentence])
print([EN.vocab.itos[i] for i in trg[b,:]])

['So', 'if', 'none', 'of', 'that', 'things', 'is', '?', 'problem', '?', '</s>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>']
['<s>', 'What', 'if', 'none', 'of', 'these', 'things', 'is', 'the', 'problem', '?', '</s>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>']


In [179]:
for j in range(len(done[b])):
    print([EN.vocab.itos[i] for i in done[b][j][0].long()])
for j in range(BEAM_WIDTH):
    print([EN.vocab.itos[i] for i in b_string[b,:,j].long()])
print([EN.vocab.itos[i] for i in trg[b,:].long()])

['<s>', 'What', 'if', 'neither', 'is', '?', '</s>']
['<s>', 'What', 'if', 'neither', 'is', 'the', 'problem', '</s>']
['<s>', 'What', 'if', 'neither', 'are', 'the', 'issue', '</s>']
['<s>', 'What', 'if', 'neither', 'are', 'the', 'problem', '</s>']
['<s>', 'What', 'if', 'neither', 'is', 'the', 'problem', '?', '</s>']
['<s>', 'What', 'if', 'neither', 'are', 'the', 'problem', '?', '</s>']
['<s>', 'What', 'if', 'neither', 'are', 'the', 'issue', 'is', 'that', '?', '</s>']
['<s>', 'What', 'if', 'neither', 'are', 'the', 'issue', 'is', "n't", 'the', 'problem', '</s>']
['<s>', 'What', 'if', 'neither', 'are', 'the', 'issue', 'is', 'that', 'the', 'problem', '?', '</s>']
['<s>', 'What', 'if', 'neither', 'are', 'the', 'issue', 'is', "n't", 'the', 'problem', 'is', 'that', '</s>']
['<s>', 'What', 'if', 'neither', 'are', 'the', 'issue', 'is', 'that', 'the', 'problem', 'is', 'that', '</s>']
['<s>', 'What', 'if', 'neither', 'are', 'the', 'issue', 'is', "n't", 'the', 'problem', 'is', 'that', '?', '</s>']


In [107]:
'''# check out prediction
for b in range(BATCH_SIZE):
    print('predictions:')
    for j in range(BEAM_WIDTH):
        print([EN.vocab.itos[i] for i in b_string[b,:,j].long()])
    print('actual:')
    print([EN.vocab.itos[i] for i in trg[b,:].long()])'''

predictions:
['<s>', 'Once', 'the', 'parasites', ',', 'nothing', 'to', 'raise', 'the', 'parasites', '.', 'Okay', ',', 'right', ',', 'the', '<unk>', 'did']
['<s>', 'Once', 'the', 'parasites', ',', 'nothing', 'to', 'raise', 'the', 'parasites', '.', 'Okay', '?', 'the', 'parasites', ',', 'the', '<unk>']
['<s>', 'Once', 'the', 'parasites', ',', 'nothing', 'to', 'raise', 'the', 'parasites', '.', 'Okay', '?', '<unk>', '<unk>', ',', 'the', 'report']
['<s>', 'Once', 'the', 'parasites', ',', 'nothing', 'to', 'raise', 'the', 'parasites', '.', 'Okay', '?', '<unk>', '<unk>', ',', 'the', '<unk>']
['<s>', 'Once', 'the', 'parasites', ',', 'nothing', 'to', 'raise', 'the', 'parasites', '.', 'Okay', '?', 'the', 'parasites', ',', 'you', 'have']
actual:
['<s>', 'Once', 'the', 'parasites', 'get', 'in', ',', 'the', 'hosts', 'do', "n't", 'get', 'a', 'say', '.', '</s>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>']
predictions:
['<s>', 'And', 'nobody', 'can', 'can', 'do', 'aging', '<pad>', '<pad>', '<p

In [None]:
# attention:
# https://github.com/spro/practical-pytorch/blob/master/seq2seq-translation/seq2seq-translation.ipynb

# normalization:
# https://medium.com/machine-learning-bites/deeplearning-series-sequence-to-sequence-architectures-4c4ca89e5654

## Submission

In [65]:
# load test set
sentences = []
for i, l in enumerate(open("source_test.txt")):
  sentences.append(re.split(' ', l))

In [66]:
max_sent_len = 0
for i in range(len(sentences)):
    if len(sentences[i]) > max_sent_len:
        max_sent_len = len(sentences[i])

In [67]:
batch = torch.tensor([], device='cuda')
for b in range(len(sentences)):
    m = nn.ConstantPad1d((0, max_sent_len - len(sentences[b])), EN.vocab.stoi['<pad>'])
    src = m(torch.tensor([DE.vocab.stoi[i] for i in sentences[b]], device='cuda').unsqueeze(0)).float()
    batch = torch.cat((batch,src), dim=0)
batch_rev = reverse_sequence(batch)

In [104]:
context_size = 500
num_layers = 2

attn_seq2context = SequenceModel(len(DE.vocab),context_size,num_layers=num_layers)
state_dict = torch.load('best_seq2seq_withattn_seq2context.pt')
attn_seq2context.load_state_dict(state_dict)
attn_seq2context = attn_seq2context.cuda()

attn_context2trg = attn_RNNet_batched(input_size=len(EN.vocab),hidden_size=context_size,num_layers=num_layers)
state_dict = torch.load('best_seq2seq_withattn_context2trg.pt')
attn_context2trg.load_state_dict(state_dict)
attn_context2trg = attn_context2trg.cuda()

In [102]:
context_size = 500
num_layers = 2

seq2context = SequenceModel(len(DE.vocab),context_size,num_layers=num_layers)
state_dict = torch.load('best_seq2seq_seq2context.pt')
seq2context.load_state_dict(state_dict)
seq2context = seq2context.cuda()

context2trg = RNNet(input_size=len(EN.vocab),hidden_size=context_size,num_layers=num_layers,weight_tie=True)
state_dict = torch.load('best_seq2seq_context2trg.pt')
context2trg.load_state_dict(state_dict)
context2trg = context2trg.cuda()

In [107]:
model = 's2s_attn'
src = batch_rev.long()
beam_width = 3
max_len = 20
output_width = 1
alpha = 1
batch_size = len(sentences)
predictions = beamsearch(model, attn_seq2context, attn_context2trg, context_size, src, beam_width, max_len, output_width=output_width, alpha=alpha, BATCH_SIZE=batch_size, padding=False, EN=EN)

position: 1
position: 2
beam: 0
beam: 1
beam: 2
position: 3
beam: 0
beam: 1
beam: 2
position: 4
beam: 0
beam: 1
beam: 2
position: 5
beam: 0
beam: 1
beam: 2
position: 6
beam: 0
beam: 1
beam: 2
position: 7
beam: 0
beam: 1
beam: 2
position: 8
beam: 0
beam: 1
beam: 2
position: 9
beam: 0
beam: 1
beam: 2
position: 10
beam: 0
beam: 1
beam: 2
position: 11
beam: 0
beam: 1
beam: 2
position: 12
beam: 0
beam: 1
beam: 2
position: 13
beam: 0
beam: 1
beam: 2
position: 14
beam: 0
beam: 1
beam: 2
position: 15
beam: 0
beam: 1
beam: 2
position: 16
beam: 0
beam: 1
beam: 2
position: 17
beam: 0
beam: 1
beam: 2
position: 18
beam: 0
beam: 1
beam: 2
position: 19
batch: 0
['<s>', 'When', 'I', 'was', 'my', '<unk>', ',', 'I', 'had', 'my', 'first', '<unk>', '.', '</s>']
batch: 1
['<s>', 'I', 'was', '<unk>', ',', '<unk>', '.', '</s>']
batch: 2
['<s>', 'She', 'was', 'a', '<unk>', 'named', 'Alex', '.', '</s>']
batch: 3
['<s>', 'And', 'when', 'I', 'hear', 'I', 'happy', '.', '</s>']
batch: 4
['<s>', 'My', '<unk>', 'as'

['<s>', '"', 'hired', ',', '"', 'and', 'much', 'better', '.', '</s>']
batch: 108
['<s>', 'It', 'with', 'a', 'new', 'culture', ',', '<unk>', '.', '</s>']
batch: 109
['<s>', 'This', 'is', '<unk>', '<unk>', '.', '</s>']
batch: 110
['<s>', 'It', 'never', 'make', 'sense', 'of', 'the', 'world', '.', '</s>']
batch: 111
['<s>', 'So', 'a', 'map', 'says', 'about', 'personal', '<unk>', '.', '</s>']
batch: 112
['<s>', 'The', '<unk>', 'is', 'not', 'new', '.', '</s>']
batch: 113
['<s>', 'In', 'the', 'world', 'is', 'a', '<unk>', '.', '</s>']
batch: 114
['<s>', 'There', 'are', '<unk>', '<unk>', 'and', '<unk>', '.', '</s>']
batch: 115
['<s>', 'And', '<unk>', 'is', '<unk>', 'understanding', '.', '</s>']
batch: 116
['<s>', 'And', 'I', 'started', 'to', 'read', 'both', 'simultaneously', '.', '</s>']
batch: 117
['<s>', 'In', 'addition', 'was', 'political', 'and', '<unk>', '.', '</s>']
batch: 118
['<s>', '<unk>', 'was', ',', 'liberals', 'and', 'interesting', '<unk>', '.', '</s>']
batch: 119
['<s>', '"', '<un

['<s>', 'I', 'was', 'not', 'alone', ',', 'I', 'went', '.', '</s>']
batch: 225
['<s>', 'And', 'I', '<unk>', '<unk>', 'and', 'the', 'snow', '.', '</s>']
batch: 226
['<s>', 'It', 'was', 'a', 'boring', 'privilege', ',', '<unk>', '.', '</s>']
batch: 227
['<s>', 'Everybody', 'was', '<unk>', '.', '</s>']
batch: 228
['<s>', 'It', "'s", 'hard', 'to', '<unk>', '.', '</s>']
batch: 229
['<s>', 'Then', 'there', 'is', '<unk>', '.', '</s>']
batch: 230
['<s>', 'You', "'ve", '<unk>', 'all', 'day', '.', '</s>']
batch: 231
['<s>', 'Everybody', 'wants', 'to', 'stop', 'enough', '.', '</s>']
batch: 232
['<s>', 'This', 'is', 'really', 'bad', '.', '</s>']
batch: 233
['<s>', 'I', 'learned', 'by', 'the', '<unk>', '.', '</s>']
batch: 234
['<s>', 'He', 'does', "n't", '<unk>', '.', '</s>']
batch: 235
['<s>', 'It', "'s", 'kind', 'of', '<unk>', '.', '</s>']
batch: 236
['<s>', 'He', 'has', 'to', 'be', 'always', '.', '</s>']
batch: 237
['<s>', 'And', 'then', 'it', 'is', '.', '</s>']
batch: 238
['<s>', 'They', "'re", '

['<s>', 'Sometimes', ',', 'not', 'know', 'the', 'issues', '.', '</s>']
batch: 357
['<s>', 'We', 'fly', 'to', 'the', 'Mars', ',', '</s>']
batch: 358
['<s>', 'NASA', 'does', 'a', 'plan', 'plan', '.', '</s>']
batch: 359
['<s>', 'But', ',', 'is', ',', 'and', 'therefore', 'going', 'to', 'happen', '.', '</s>']
batch: 360
['<s>', 'Another', 'time', 'we', 'do', "n't", 'solve', 'our', 'political', 'systems', '.', '</s>']
batch: 361
['<s>', 'gas', 'machines', 'as', 'solar', 'oil', 'as', 'a', 'barrel', '.', '</s>']
batch: 362
['<s>', 'We', 'need', '<unk>', 'the', '<unk>', '<unk>', '.', '</s>']
batch: 363
['<s>', 'Sometimes', 'there', 'are', '<unk>', ',', 'but', 'not', 'so', '<unk>', ':', '</s>']
batch: 364
['<s>', 'It', 'was', 'a', 'long', '-', 'failure', '.', '</s>']
batch: 365
['<s>', 'Ultimately', ',', 'we', 'get', '<unk>', 'because', 'we', 'do', "n't", 'understand', '.', '</s>']
batch: 366
['<s>', '<unk>', 'problems', 'are', '<unk>', '.', '</s>']
batch: 367
['<s>', 'It', "'s", 'not', 'the', '

['<s>', 'That', "'s", '<unk>', '.', '</s>']
batch: 471
['<s>', '<unk>', 'because', 'they', 'simple', 'ideas', ',', 'ideas', '.', '</s>']
batch: 472
['<s>', '<unk>', '<unk>', '.', '</s>']
batch: 473
['<s>', 'We', 'are', 'better', 'things', '.', '</s>']
batch: 474
['<s>', 'And', 'we', 'are', '<unk>', '.', '</s>']
batch: 475
['<s>', '<unk>', ',', 'they', 'all', '<unk>', '.', '</s>']
batch: 476
['<s>', 'Actually', ',', 'it', "'s", 'not', 'necessarily', 'a', 'deep', '.', '</s>']
batch: 477
['<s>', 'They', 'work', 'because', 'they', '.', '</s>']
batch: 478
['<s>', 'But', 'it', "'s", 'not', '<unk>', '.', '</s>']
batch: 479
['<s>', 'This', 'is', 'what', 'we', '<unk>', 'in', 'the', 'jungle', '.', '</s>']
batch: 480
['<s>', 'I', 'learned', 'that', 'in', 'of', 'us', 'has', '<unk>', '.', '</s>']
batch: 481
['<s>', '<unk>', '<unk>', 'in', 'New', 'York', '<unk>', '.', '</s>']
batch: 482
['<s>', 'As', 'every', 'student', 'did', 'a', '<unk>', '.', '</s>']
batch: 483
['<s>', 'I', "'d", '<unk>', 'on', '

['<s>', 'Many', 'of', 'systems', 'have', 'the', 'effect', '.', '</s>']
batch: 578
['<s>', 'They', 'do', "n't", 'work', '.', '</s>']
batch: 579
['<s>', '<unk>', 'can', 'probably', ',', 'like', '.', '</s>']
batch: 580
['<s>', 'So', ',', 'to', 'the', '<unk>', '.', '</s>']
batch: 581
['<s>', 'And', 'now', 'to', 'the', 'task', '.', '</s>']
batch: 582
['<s>', 'These', '<unk>', 'of', 'describing', 'the', '<unk>', '<unk>', '.', '</s>']
batch: 583
['<s>', 'Because', 'it', "'s", '<unk>', 'and', '<unk>', '.', '</s>']
batch: 584
['<s>', 'We', 'can', 'do', 'this', 'is', '.', '</s>']
batch: 585
['<s>', 'We', 'can', 'give', 'a', '<unk>', '.', '</s>']
batch: 586
['<s>', 'So', 'two', 'people', 'who', '<unk>', '.', '</s>']
batch: 587
['<s>', 'But', 'ultimately', ',', 'to', 'others', '.', '</s>']
batch: 588
['<s>', 'You', 'ca', "n't", 'build', 'what', 'you', 'you', 'do', '.', '</s>']
batch: 589
['<s>', 'You', "'ve", 'to', 'trust', 'that', 'you', 'do', '.', '</s>']
batch: 590
['<s>', 'So', 'you', 'have', 

['<s>', 'There', 'was', 'a', 'young', ',', '<unk>', '.', '</s>']
batch: 690
['<s>', 'And', 'he', 'asked', 'in', 'his', 'high', 'school', '.', '</s>']
batch: 691
['<s>', 'There', 'was', 'no', 'work', '.', '</s>']
batch: 692
['<s>', 'They', 'just', 'people', 'like', 'him', '.', '</s>']
batch: 693
['<s>', 'But', 'this', 'story', 'a', 'different', 'story', '.', '</s>']
batch: 694
['<s>', 'In', '<unk>', ',', '"', 'A', 'are', 'the', 'streets', '.', '</s>']
batch: 695
['<s>', 'And', 'then', 'recognized', '.', '</s>']
batch: 696
['<s>', 'It', '<unk>', '.', '</s>']
batch: 697
['<s>', 'It', 'started', 'to', 'rock', ',', 'which', 'could', "n't", 'afford', '.', '</s>']
batch: 698
['<s>', 'How', 'how', 'this', 'story', 'from', '.', '</s>']
batch: 699
['<s>', 'What', "'s", 'difference', '.', '</s>']
batch: 700
['<s>', 'I', 'think', 'it', "'s", 'new', '.', '</s>']
batch: 701
['<s>', 'It', "'s", '<unk>', 'and', 'I', 'can', 'be', '<unk>', '.', '</s>']
batch: 702
['<s>', 'It', "'s", 'hard', 'to', 'creat

In [None]:
'''stored_logprobs = torch.zeros((32,3,100))
stored_logprobs[0,1,1] = 5
stored_logprobs[0,2,1] = 10
probs = torch.argsort(stored_logprobs[0].view(-1), descending=True)
temp = torch.zeros((2,300))
sorted_logprobs = torch.tensor(divmod(probs.numpy(), 100), device='cuda')
gap = 2
for c in sorted_logprobs.transpose(0,1)[0:gap]:
    print(c)
sorted_logprobs.transpose(0,1).shape'''

In [96]:
# for kaggle
def escape(l):
    return l.replace("\"", "<quote>").replace(",", "<comma>")

In [99]:
# for kaggle
with open("pred_kaggle.txt", "w") as f:
    f.write("Id,Predicted\n")
    for i, l in enumerate(open("source_test.txt")):
        preds = ['|'.join([EN.vocab.itos[k] for k in predictions[i][j][1:]]) for j in range(output_width)]
        f.write("%d,%s"%(i, escape(" ".join(preds))))
        f.write('\n')

In [94]:
# for bleu
with open("pred_bleu.txt", "w") as f:
    for i, l in enumerate(open("source_test.txt")):
        f.write("%s\n"%(" ".join([EN.vocab.itos[k] for k in predictions[i][0][1:-1]])))