# CS 287, Homework 3: Neural Machine Translation

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_

from common import *

import re
from torchtext import data, datasets
from namedtensor import ntorch, NamedTensor
from namedtensor.text import NamedField
import numpy as np

%reload_ext autoreload
%autoreload 2

In [2]:
# split raw data into tokens
import spacy
spacy_de = spacy.load('de')
spacy_en = spacy.load('en')

def tokenize_de(text):
    return [tok.text for tok in spacy_de.tokenizer(text)]

def tokenize_en(text):
    return [tok.text for tok in spacy_en.tokenizer(text)]

# add beginning-of-sentence and end-of-sentence tokens to target
BOS_WORD = '<s>'
EOS_WORD = '</s>'
DE = NamedField(names=('srcSeqlen',), tokenize=tokenize_de)
EN = NamedField(names=('trgSeqlen',), tokenize=tokenize_en,
                init_token = BOS_WORD, eos_token = EOS_WORD) # only target needs BOS/EOS

# download dataset of 200K pairs of sentences
# start with MAXLEN = 20
MAX_LEN = 20
train, val, test = datasets.IWSLT.splits(exts=('.de', '.en'), fields=(DE, EN), 
                                         filter_pred=lambda x: len(vars(x)['src']) <= MAX_LEN and 
                                         len(vars(x)['trg']) <= MAX_LEN)

# build vocab, convert words to indices
MIN_FREQ = 5
DE.build_vocab(train.src, min_freq=MIN_FREQ)
EN.build_vocab(train.trg, min_freq=MIN_FREQ)
print("Size of German vocab", len(DE.vocab))
print("Size of English vocab", len(EN.vocab))
print(EN.vocab.stoi["<s>"], EN.vocab.stoi["</s>"])
print(EN.vocab.stoi["<pad>"], EN.vocab.stoi["<unk>"])

Size of German vocab 13353
Size of English vocab 11560
2 3
1 0


In [5]:
# split data into batches
BATCH_SIZE = 32
device = torch.device('cuda:0')
train_iter, val_iter = data.BucketIterator.splits((train, val), batch_size=BATCH_SIZE, device=device,
                                                  repeat=False, sort_key=lambda x: len(x.src))

In [3]:
context_size = 1000
num_layers = 4

BEAM_WIDTH = 50
max_len = 3

attn_seq2context = SequenceModel(len(DE.vocab),context_size,num_layers=num_layers)
state_dict = torch.load('best_seq2seq_withattn_seq2context_big_network.pt')
attn_seq2context.load_state_dict(state_dict)
attn_seq2context = attn_seq2context.cuda()

attn_context2trg = attn_RNNet_batched(input_size=len(EN.vocab),hidden_size=context_size,num_layers=num_layers)
state_dict = torch.load('best_seq2seq_withattn_context2trg_big_network.pt')
attn_context2trg.load_state_dict(state_dict)
attn_context2trg = attn_context2trg.cuda()

In [4]:
# load test set
sentences = []
for i, l in enumerate(open("source_test.txt")):
    sentences.append(re.split(' ', l))

max_sent_len = 0
for i in range(len(sentences)):
    if len(sentences[i]) > max_sent_len:
        max_sent_len = len(sentences[i])

batch = torch.tensor([], device='cuda')
for b in range(len(sentences)):
    m = nn.ConstantPad1d((0, max_sent_len - len(sentences[b])), EN.vocab.stoi['<pad>'])
    src = m(torch.tensor([DE.vocab.stoi[i] for i in sentences[b]], device='cuda').unsqueeze(0)).float()
    batch = torch.cat((batch,src), dim=0)
batch_rev = reverse_sequence(batch)

In [5]:
batch_rev_data = torch.utils.data.TensorDataset(batch_rev)

In [6]:
BATCH_SIZE = 50
batch_rev_data_loader = torch.utils.data.DataLoader(batch_rev_data, batch_size=BATCH_SIZE, shuffle=False)
#BATCH_SIZE = src.shape[0]

In [50]:
def beam_search_first3(src, attn_seq2context, attn_context2trg, BEAM_WIDTH = 2, BATCH_SIZE=32, max_len=3,context_size=500,EN=None):
    top_p = {}
    top_s = {}
    stopped = torch.zeros((BATCH_SIZE,BEAM_WIDTH), device='cuda') #==1
    items = []
    for i in range(BATCH_SIZE):
        top_p[i] = []
        top_s[i] = []
        items.append(i)



    encoder_outputs, encoder_hidden = attn_seq2context(src)
    decoder_context = torch.zeros(BATCH_SIZE, context_size, device='cuda') # 32 x 500
    decoder_hidden = encoder_hidden
    outputs = []
    word_input = (torch.zeros(BATCH_SIZE, device='cuda') + EN.vocab.stoi['<s>']).long()
    decoder_output, decoder_context, decoder_hidden, decoder_attention = attn_context2trg(word_input, decoder_context, decoder_hidden, encoder_outputs)
    outputs.append(decoder_output)
    next_words = torch.argsort(lsm2(decoder_output),dim=1, descending=True)[:,0:BEAM_WIDTH].detach()
    p_words_init = torch.stack([torch.index_select(decoder_output[i,:],-1,next_words[i,:]) for i in range(BATCH_SIZE)]).detach()
    p_words_running = torch.stack([p_words_init[:,b].repeat(1,BEAM_WIDTH) for b in range(BEAM_WIDTH)]).view(BEAM_WIDTH**2,BATCH_SIZE).transpose(0,1)

    update = []
    for ix,p in enumerate(next_words):
        update.append([torch.stack(([torch.tensor(EN.vocab.stoi['<s>'], device='cuda')])+([next_words[ix,b]])) for b in range(BEAM_WIDTH)])

    next_words = torch.argsort(lsm2(decoder_output),dim=1, descending=True)[:,0:BEAM_WIDTH].detach()
    p_words_init = torch.stack([torch.index_select(decoder_output[i,:],-1,next_words[i,:]) for i in range(BATCH_SIZE)]).detach()
    p_words_running = torch.stack([p_words_init[:,b].repeat(1,BEAM_WIDTH) for b in range(BEAM_WIDTH)]).view(BEAM_WIDTH**2,BATCH_SIZE).transpose(0,1)
    update = []
    for ix,p in enumerate(next_words):
        update.append([torch.stack(([torch.tensor(EN.vocab.stoi['<s>'], device='cuda')])+([next_words[ix,b]])) for b in range(BEAM_WIDTH)])
    top_s.update(dict(zip(items, update)))
    top_p.update(dict(zip(items, p_words_init)))

    next_words = next_words.transpose(0,1).flatten().long()
    encoder_outputs = encoder_outputs.repeat(BEAM_WIDTH,1,1)
    decoder_hidden = tuple([h.repeat(1,BEAM_WIDTH,1) for h in decoder_hidden])
    decoder_context = decoder_context.repeat(BEAM_WIDTH,1)
    
    for j in range(2):
            decoder_output, decoder_context, decoder_hidden, decoder_attention = attn_context2trg(next_words, decoder_context, decoder_hidden, encoder_outputs)
            args = torch.argsort(lsm2(decoder_output),dim=1, descending=True)[:,0:BEAM_WIDTH].detach()
            next_words = torch.cat([args[BATCH_SIZE*(b):BATCH_SIZE*(b+1),:] for b in range(BEAM_WIDTH)],dim=1).detach()
            p_words = torch.stack([torch.index_select(decoder_output[i,:],-1,next_words[i,:]) for i in range(BATCH_SIZE)])
            p_words_running += p_words.detach()
            p_words_norm = p_words_running/(j+1)

            word_selector = torch.argsort(p_words_norm,dim=1,descending=True)[:,:BEAM_WIDTH]        
            beam_indicator = (word_selector/BEAM_WIDTH).float().long() #word_selector>=2

            prev_words = list(top_s.values())
            words = [torch.stack([prev_words[s][i] for i in beam_indicator[s,:]]) for s in range(BATCH_SIZE)]
            #update = []
            #for ix,p in enumerate(words):
            #    update.append([torch.cat((p[b],next_words[ix,b].unsqueeze(0))) for b in range(BEAM_WIDTH)])
            #top_s.update(dict(zip(items, update)))

            #update_p = torch.stack([torch.index_select(p_words_running[b], 0, word_selector[b]) for b in range(BATCH_SIZE)])
            #top_p.update(dict(zip(items, update_p)))

            indexs = torch.zeros(BATCH_SIZE,BEAM_WIDTH,device='cuda')
            for i in range(BATCH_SIZE):
                indexs[i,:] += i+(BATCH_SIZE*beam_indicator[i,:].float())
            indexs = indexs.long()
            indexs = indexs.transpose(0,1).flatten()
            decoder_hidden = tuple([torch.index_select(h,1,indexs) for h in decoder_hidden])
            decoder_context = torch.index_select(decoder_context,0,indexs)
            decoder_output = torch.index_select(decoder_output,0,indexs)
            outputs.append(decoder_output)
            next_words = torch.stack([torch.index_select(next_words[s,:],0,word_selector[s,:]) for s in range(BATCH_SIZE)]).transpose(0,1).flatten().long()
            
    return outputs

## Submission

In [51]:
outputs = beam_search_first3(src, attn_seq2context, attn_context2trg, BEAM_WIDTH = 8, BATCH_SIZE=BATCH_SIZE, max_len=3,EN=EN)
preds = torch.stack([i[:BATCH_SIZE,:] for i in outputs]).transpose(0,1).transpose(1,2)

torch.Size([100, 11560, 3])

In [7]:
def ints_to_sentences(list_of_phrases):
    sentences = []
    for phrase in list_of_phrases:
        sentences.append(" ".join(["|".join([EN.vocab.itos[w] for w in phrase])]))
    return sentences

In [45]:
' '.join(ints_to_sentences(list_of_phrases))

"I|was|an I|was|a I|was|going I|was|having I|was|talking <unk>|'ve|watching I|was|in I|was|looking I|was|running I|was|writing I|do|responsible I|miss|what I|was|wearing I|see|on A|'s|there I|was|supposed I|did|showing I|'m|doing I|was|willing I|is|taking I|was|holding And|want|putting I|think|entering I|just|teaching I|was|saying I|'m|the I|'ll|making I|miss|<unk> I|'m|able I|guess|wondering I|was|flying I|always|giving I|<unk>|losing I|was|just I|first|so I|have|trying I|was|creating I|was|reading I|would|like I|'d|standing I|get|seeing I|wrote|curious I|was|using I|call|playing I|live|building And|feel|leaving I|was|experiencing I|'m|thinking And|want|several I|was|here I|do|one I|was|hanging <unk>|'ve|then I|am|eating I|always|launching I|'m|, I|read|asking I|was|leading One|felt|to I|was|somebody And|feel|listening Well|go|this I|<unk>|sending I|negotiate|actually I|was|following <unk>|'ve|calling <unk>|like|being Voice|kind|offering I|do|glad I|wrote|that I|was|my I|look|some I|'

In [24]:
preds = " ".join("|".join([EN.vocab.itos[j] for j in top_s[0][0][1:]]) for )
preds

'<unk>|I|<unk>'

In [8]:
# for kaggle
def escape(l):
    return l.replace("\"", "<quote>").replace(",", "<comma>")

# for kaggle
line_counter = 0
with open("pred_kaggle_best_model_no-split_large.txt", "w+") as f:
    f.write("Id,Predicted\n")
    for b in batch_rev_data_loader:
        
        src = b[0].long()
        top_s = beam_search(src, attn_seq2context, attn_context2trg, BEAM_WIDTH = 100, BATCH_SIZE=BATCH_SIZE, max_len=3,context_size=context_size,EN=EN)
        lists_per_sample = [torch.stack(top_s[i])[:,1:] for i in range(BATCH_SIZE)]
        for i in range(BATCH_SIZE):
            list_of_phrases = lists_per_sample[i]
            f.write("%d,%s"%(line_counter,escape(' '.join(ints_to_sentences(list_of_phrases)))))
            f.write('\n')
            line_counter+=1

In [29]:
[DE.vocab.itos[i] for i in reverse_sequence(src.long())[0]]

['<pad>',
 '<pad>',
 '<pad>',
 '<pad>',
 '<unk>',
 '"',
 '<unk>',
 '"',
 'für',
 'steht',
 ',',
 'hinausgeht',
 'es',
 'wo',
 ',',
 'sagt',
 'der',
 ',',
 'Mund',
 'Ein']

In [None]:
# for bleu
with open("pred_bleu.txt", "w") as f:
    for i, l in enumerate(open("source_test.txt")):
        stop_idx = torch.tensor([max_len], device='cuda') if (top_s[i][0] == 3).nonzero().size()[0]==0 else (top_s[i][0] == 3).nonzero()
        f.write("%s\n"%(" ".join([EN.vocab.itos[j] for j in top_s[i][0][1:stop_idx]])))