# CS 287, Homework 3: Neural Machine Translation

In [3]:
from common import *

import torch
from torch.nn.utils import clip_grad_norm_
import torch.nn as nn
import torch.nn.functional as F
from torchtext import data, datasets
from namedtensor import ntorch, NamedTensor
from namedtensor.text import NamedField
import numpy as np

%reload_ext autoreload
%autoreload 2

In [4]:
# split raw data into tokens
import spacy
spacy_de = spacy.load('de')
spacy_en = spacy.load('en')

def tokenize_de(text):
    return [tok.text for tok in spacy_de.tokenizer(text)]

def tokenize_en(text):
    return [tok.text for tok in spacy_en.tokenizer(text)]

# add beginning-of-sentence and end-of-sentence tokens to target
BOS_WORD = '<s>'
EOS_WORD = '</s>'
DE = NamedField(names=('srcSeqlen',), tokenize=tokenize_de)
EN = NamedField(names=('trgSeqlen',), tokenize=tokenize_en,
                init_token = BOS_WORD, eos_token = EOS_WORD) # only target needs BOS/EOS

# download dataset of 200K pairs of sentences
# start with MAXLEN = 20
MAX_LEN = 20
train, val, test = datasets.IWSLT.splits(exts=('.de', '.en'), fields=(DE, EN), 
                                         filter_pred=lambda x: len(vars(x)['src']) <= MAX_LEN and 
                                         len(vars(x)['trg']) <= MAX_LEN)
#print(train.fields)
#print(len(train))
#print(vars(train[0]))

# build vocab, convert words to indices
MIN_FREQ = 5
DE.build_vocab(train.src, min_freq=MIN_FREQ)
EN.build_vocab(train.trg, min_freq=MIN_FREQ)
#print(DE.vocab.freqs.most_common(10))
#print("Size of German vocab", len(DE.vocab))
#print(EN.vocab.freqs.most_common(10))
#print("Size of English vocab", len(EN.vocab))
print(EN.vocab.stoi["<s>"], EN.vocab.stoi["</s>"])
print(EN.vocab.stoi["<pad>"], EN.vocab.stoi["<unk>"])
print(DE.vocab.stoi["<pad>"], DE.vocab.stoi["<unk>"])

2 3
1 0
1 0


In [5]:
# split data into batches
BATCH_SIZE = 32
device = torch.device('cuda:0')
train_iter, val_iter = data.BucketIterator.splits((train, val), batch_size=BATCH_SIZE, device=device,
                                                  repeat=False, sort_key=lambda x: len(x.src))

## Sequence to Sequence Learning with Neural Networks

- English to French translation, $p \left( y_1, \dots, y_{T'} \ | \ x_1, \dots, x_T \right) = \prod_{t = 1}^{T'} p \left( y_t \ | \ v, y_1, \dots, y_{t-1} \right)$
- Each sentence ends in '<EOS\>', out-of-vocab words denoted '<UNK\>'
- Model specs: 
    * Input vocabulary of 160,000 and output vocabulary of 80,000
    * Deep LSTM to map (encode) input sequence to fixed-len vector
    * Another deep LSTM to translate (decode) fixed-len vector to output sequence
    * 4 layers per LSTM, 1000 cells per layer, 1000-dimensional word embeddings, softmax over 80,000 words
    * Reversing order of words in source (but not target) improved performance
        * Each word in the source is far from its corresponding word in the target (large minimal time lag); reversing the source reduces the minimal time lag, thereby allowing backprop to establish communication between source and target more easily
- Training specs:
    * Initialize all LSTM params $\sim Unif[-0.08,0.08]$
    * SGD w/o momentum, lr = 0.7
        * After 5 epochs, halve the lr every half-epoch
        * Train for 7.5 epochs
    * Batch size = 128; divide gradient by batch size (denoted $g$)
    * Hard constraint gradient norm; if $s = ||g||_2 > 5$, set $s = 5$
    * Make sure all sentences within a minibatch are roughly the same length
- Objective: $max \frac{1}{|S|} \sum_{(T,S) \in \mathcal{S}} log \ p(T \ | \ S)$, where $\mathcal{S}$ is the training set
- Prediction: $\hat{T} = argmax \ p(T \ | \ S)$ via beam search, where beam size $B \in {1,2}$

In [None]:
'''class SequenceModel(nn.Module):
    def __init__(self, src_vocab_size, context_size, num_layers, weight_init = 0.08):
        super(SequenceModel, self).__init__()
        # embedding
        self.embedding = nn.Embedding(src_vocab_size, context_size)
        # language summarization
        self.lstm = nn.LSTM(input_size=context_size, hidden_size=context_size, num_layers=num_layers, batch_first=True)
        for p in self.lstm.parameters():
            torch.nn.init.uniform_(p, a=weight_init, b=weight_init)

    def forward(self, inputs, h0=None):
        # embed the words 
        embedded = self.embedding(inputs)
        # summarize context
        context, hidden = self.lstm(embedded,h0)
        return context, hidden
    
class LanguageModel(nn.Module):
    def __init__(self, target_vocab_size, hidden_size, context_size, num_layers, weight_init = 0.08):
        super(LanguageModel, self).__init__()
        # context is batch_size x seq_len x context_size
        # context to hidden
        self.embedding = nn.Embedding(target_vocab_size, hidden_size)
        # hidden to hidden 
        self.lstm = nn.LSTM(input_size=hidden_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=True)
        # decode hidden state for y_t
        for p in self.lstm.parameters():
            torch.nn.init.uniform_(p, a=weight_init, b=weight_init)
            
        self.translate = nn.Linear(hidden_size, target_vocab_size)

    def forward(self, inputs, h0=None):
        # embed the trg words
        embedded = self.embedding(inputs)
        # setting hidden state to context at t=0
        # otherwise context = prev hidden state
        output, hidden = self.lstm(embedded, h0)
        output = self.translate(output)
        return output,hidden'''

In [None]:
'''def repackage_hidden(h):
    return tuple(v.detach() for v in h)
def repackage_layer(hidden_s2c,hidden=100):
    return tuple([hidden_s2c[0][-1].detach().view(1,BATCH_SIZE,hidden),hidden_s2c[1][-1].detach().view(1,BATCH_SIZE,hidden)])
def reverse_sequence(src):
    length = list(src.shape)[1]
    idx = torch.linspace(length-1, 0, steps=length).long()
    rev_src = src[:,idx]
    return rev_src'''

In [None]:
'''context_size = 500
num_layers = 1
seq2context = SequenceModel(len(DE.vocab),context_size,num_layers)
context2trg = LanguageModel(len(EN.vocab),hidden_size=context_size,context_size=context_size,num_layers=num_layers)
seq2context,context2trg = seq2context.cuda(),context2trg.cuda()
seq2context_optimizer = torch.optim.Adam(seq2context.parameters(), lr=1e-5)
context2trg_optimizer = torch.optim.Adam(context2trg.parameters(), lr=1e-5)
criterion = nn.CrossEntropyLoss(reduction='none')'''

In [None]:
'''def training_loop(e=0):
    seq2context.train()
    context2trg.train()
    h0 = None
    for ix,batch in enumerate(train_iter):
        seq2context_optimizer.zero_grad()
        context2trg_optimizer.zero_grad()
        
        src = batch.src.values.transpose(0,1)
        src = reverse_sequence(src)
        trg = batch.trg.values.transpose(0,1)
        if src.shape[0]!=BATCH_SIZE:
            break
        else:
            # generate hidden state for decoder
            context, hidden_s2c = seq2context(src,h0)
            hidden = repackage_layer(hidden_s2c,context_size)
            output, hidden_lm = context2trg(trg[:,:-1],hidden)
            loss = criterion(output.transpose(2,1),trg[:,1:])
            mask = trg[:,1:]!=1
            loss = loss[mask].sum()
            #clip_grad_norm_(seq2context.parameters(), max_norm=5)
            #clip_grad_norm_(context2trg.parameters(), max_norm=5)
            loss.backward()
            seq2context_optimizer.step()
            context2trg_optimizer.step()
        if np.mod(ix,100) == 0:
            var = torch.var(torch.argmax(lsm(output).cpu().detach(),2).float())
            print('Epoch: {}, Batch: {}, loss: {}, var: {},'.format(e, ix, loss.cpu().detach()/BATCH_SIZE, var))
    loss = 0
    for b in iter(val_iter):
        src = b.src.values.transpose(0,1)
        src = reverse_sequence(src)
        trg = b.trg.values.transpose(0,1)
        if src.shape[0]!=BATCH_SIZE:
            break
        else:
            # generate hidden state for decoder
            context, hidden_s2c = seq2context(src,h0)
            hidden = repackage_layer(hidden_s2c,context_size)
            output, hidden_lm = context2trg(trg[:,:-1],hidden)
            bloss = criterion(output.transpose(2,1),trg[:,1:])
            mask = trg[:,1:]!=1
            loss += bloss[mask].sum()
    print('Epoch: {}, loss: {}, var: {},'.format(e, loss.cpu().detach()/(BATCH_SIZE*len(val_iter))))'''

In [None]:
'''for e in range(2):
    training_loop(e)
    #training_loop(e,train_iter,seq2context,context2trg,seq2context_optimizer,context2trg_optimizer,BATCH_SIZE)
    #validation_loop(e,val_iter,seq2context,context2trg,BATCH_SIZE)'''

In [None]:
'''for ix,batch in enumerate(train_iter):
    src = batch.src.values.transpose(0,1)
    trg = batch.trg.values.transpose(0,1)
    break

h0 = None
context, hidden_s2c = seq2context(reverse_sequence(src),h0)
hidden = repackage_layer(hidden_s2c,context_size)
output, hidden_lm = context2trg(trg[:,:-1],hidden)

[EN.vocab.itos[i] for i in torch.argmax(lsm(output),2)[30,:]]'''

## Beam Search

In [125]:
# define attention-based encoder-decoder model
class attn_RNNet_batched(torch.nn.Module):

    def __init__(self, input_size, hidden_size, num_layers, dropout=0.5, weight_init=0.05):
        super(attn_RNNet_batched, self).__init__()
        self.emb = torch.nn.Sequential(torch.nn.Embedding(input_size, hidden_size), torch.nn.Dropout(dropout))
        self.rnn = torch.nn.LSTM(input_size=2*hidden_size, hidden_size=hidden_size, num_layers=num_layers, bias=True, batch_first=True, dropout=dropout)
        self.lnr = torch.nn.Sequential(torch.nn.Dropout(dropout), torch.nn.Linear(2*hidden_size, input_size))
    
        for f in self.parameters():
            torch.nn.init.uniform_(f, a=-weight_init, b=weight_init)

    def forward(self, word_input, last_context, last_hidden, encoder_outputs):
        word_embedded = self.emb(word_input)
        rnn_input = torch.cat([word_embedded, last_context], 1).unsqueeze(1) # batch x 1 x hiddenx2
        rnn_output, hidden = self.rnn(rnn_input, last_hidden)
        attn_weights = rnn_output.bmm(encoder_outputs.transpose(1,2))# batch x src_seqlen x 1
        context = attn_weights.bmm(encoder_outputs)
        rnn_output = rnn_output.squeeze(1)
        context = context.squeeze(1)
        output = self.lnr(torch.cat((rnn_output, context), 1))
        # prediction, last_context, last_hidden, weights for vis
        return output, context, hidden, attn_weights 

In [126]:
# initialize model
context_size = 500
num_layers = 2
attn_context2trg = attn_RNNet_batched(input_size=len(EN.vocab),hidden_size=context_size,num_layers=num_layers)
attn_context2trg = attn_context2trg.cuda()
seq2context = SequenceModel(len(DE.vocab),context_size,num_layers=num_layers)
seq2context = seq2context.cuda()

In [127]:
# prep for training
attn_context2trg_optimizer = torch.optim.Adam(attn_context2trg.parameters(), lr=1e-3)
seq2context_optimizer = torch.optim.Adam(seq2context.parameters(), lr=1e-3)
criterion_train = nn.CrossEntropyLoss(reduction='sum')

In [128]:
def attn_training_loop(e=0):
    for ix,batch in enumerate(train_iter):
        src = batch.src.values.transpose(0,1)
        src = reverse_sequence(src)
        trg = batch.trg.values.transpose(0,1)
        if trg.shape[0] == BATCH_SIZE:
        
            seq2context_optimizer.zero_grad()
            attn_context2trg_optimizer.zero_grad()
        
            encoder_outputs, encoder_hidden = seq2context(src)
            loss = 0
            decoder_context = torch.zeros(BATCH_SIZE, context_size, device='cuda') # 32 x 500
            decoder_hidden = encoder_hidden
            sentence = []
            for j in range(trg.shape[1] - 1):
                word_input = trg[:,j]
                decoder_output, decoder_context, decoder_hidden, decoder_attention = attn_context2trg(word_input, decoder_context, decoder_hidden, encoder_outputs)
                #print(decoder_output.shape, trg[i,j+1].view(-1).shape)
                loss += criterion_train(decoder_output, trg[:,j+1])
                
                if np.mod(ix,100) == 0:
                    sentence.extend([torch.argmax(decoder_output[0,:],dim=0)])
                
            loss.backward()
            seq2context_optimizer.step()
            attn_context2trg_optimizer.step()
        
            if np.mod(ix,100) == 0:
                print('Epoch: {}, Batch: {}, Loss: {}'.format(e, ix, loss.cpu().detach()/BATCH_SIZE))
                print([EN.vocab.itos[i] for i in sentence])
                print([EN.vocab.itos[i] for i in trg[0,:]])

In [None]:
for e in range(10,11):
    attn_training_loop(e)

In [137]:
# run beam search on one batch
# run forward pass of encoder once
it = iter(train_iter)
batch = next(it)
src = batch.src.values.transpose(0,1)
src = reverse_sequence(src)
max_len = src.shape[1] # restrict target sentence length to source sentence length
encoder_outputs, encoder_hidden = seq2context(src)
decoder_context = torch.zeros(BATCH_SIZE, context_size, device='cuda') # 32 x 500
decoder_hidden = encoder_hidden

In [138]:
# prepare for beam search
START_TKN = EN.vocab.stoi["<s>"]
END_TKN = EN.vocab.stoi["</s>"]
BEAM_WIDTH = 3
lsm = nn.LogSoftmax(dim=1)

predictions = []
b_string = torch.zeros((BATCH_SIZE, max_len, BEAM_WIDTH), device='cuda')
b_string[:,0,:] = START_TKN

b_probs = {}
# b_probs key = tuple(batch idx, beam idx)
# b_probs val = [cum log prob, length]
for b in range(BATCH_SIZE):
    for c in range(BEAM_WIDTH):
        b_probs[(b, c)] = [0, 1]

In [139]:
# loop through target sequence max len
for i in range(max_len - 1):
    if i == 0: # if predicting the word following <s>, take top BEAM_WIDTH preds
        word_input = b_string[:,i,0].long()
        decoder_output, decoder_context, decoder_hidden, decoder_attention = attn_context2trg(word_input, 
                                                                                                  decoder_context, 
                                                                                                  decoder_hidden, 
                                                                                                  encoder_outputs)
        logprobs = lsm(decoder_output.detach())
        toppreds = torch.argsort(logprobs, dim=1, descending=True)[:,0:BEAM_WIDTH]
        b_string[:,i+1,:] = toppreds
        for b in range(BATCH_SIZE):
            for c in range(BEAM_WIDTH):
                b_probs[tuple((b,c))][0] += logprobs[b, c]
                b_probs[tuple((b,c))][1] += 1
    else: # if predicting the word for positions 2+, compare top BEAM_WIDTH preds for each of BEAM_WIDTH strings
        # temporary storage
        curr_probs = {}
        curr_string = torch.zeros(BATCH_SIZE, i+2, BEAM_WIDTH)

        for j in range(BEAM_WIDTH):
            word_input = b_string[:,i,j].long()
            decoder_output, decoder_context, decoder_hidden, decoder_attention = attn_context2trg(word_input, 
                                                                                                      decoder_context, 
                                                                                                      decoder_hidden, 
                                                                                                      encoder_outputs)
            logprobs = lsm(decoder_output.detach())
            toppreds = torch.argsort(logprobs, dim=1, descending=True)[:,0:BEAM_WIDTH]

            if j == 0: # if preds are from first beam, take top BEAM_WIDTH preds (temporarily)
                for b in range(BATCH_SIZE):
                    for c in range(BEAM_WIDTH):
                        new_b_prob = b_probs[tuple((b,j))][0] + logprobs[b,c]
                        curr_probs[tuple((b,c))] = new_b_prob # set top prob
                        curr_string[b,0:i+1,c] = b_string[b,0:i+1,j] # set sentence
                        curr_string[b,i+1,c] = toppreds[b,c] # set top word
            else:
                for b in range(BATCH_SIZE):
                    for c in range(BEAM_WIDTH): # proposed strings
                        replaced = False
                        for d in range(BEAM_WIDTH): # existing strings
                            new_b_prob = b_probs[tuple((b,j))][0] + logprobs[b,c]
                            if new_b_prob > curr_probs[tuple((b,d))] and not replaced:
                                curr_probs[tuple((b,d))] = new_b_prob # update top prob
                                curr_string[b,0:i+1,d] = b_string[b,0:i+1,j] # update sentence
                                curr_string[b,i+1,d] = toppreds[b,c] # update top word
                                replaced = True
        #print(b_string[:,0:i+2,:].shape, curr_string.shape)                        
        b_string[:,0:i+2,:] = curr_string
        for b in range(BATCH_SIZE):
            for c in range(BEAM_WIDTH):
                b_probs[tuple((b,c))][0] = curr_probs[tuple((b,c))]
                b_probs[tuple((b,c))][1] += 1

In [140]:
trg = batch.trg.values.transpose(0,1)
# check out prediction
for b in range(BATCH_SIZE):
    print('predictions:')
    for j in range(BEAM_WIDTH):
        print([EN.vocab.itos[i] for i in b_string[b,:,j].long()])
    print('actual:')
    print([EN.vocab.itos[i] for i in trg[b,:].long()])

predictions:
['<s>', 'And', 'we', 'run', 'in', 'areas', 'where', 'we', 'are', 'far', 'do', 'like', 'well', '.', '</s>', '<pad>', 'will', 'be', 'a', 'matter']
['<s>', 'And', 'we', 'run', 'in', 'areas', 'where', 'we', 'are', 'far', 'do', 'like', 'well', '.', '</s>', 'do', 'in', 'all', 'of', 'health']
['<s>', 'And', 'we', 'run', 'in', 'areas', 'where', 'we', 'are', 'far', 'do', 'like', 'well', '.', '</s>', '<pad>', 'will', 'do', "n't", '<unk>']
actual:
['<s>', 'And', 'we', 'run', 'clinics', 'in', 'these', 'very', 'remote', 'regions', 'where', 'there', "'s", 'no', 'medical', 'care', 'whatsoever', '.', '</s>', '<pad>', '<pad>', '<pad>']
predictions:
['<s>', 'When', 'I', 'been', 'mentioned', 'immediately', 'his', 'eye', '.', '</s>', '</s>', 'mentioned', ',', '</s>', ')', '</s>', '</s>', 'lettuces', 'fleet', 'of']
['<s>', 'When', 'I', 'been', 'mentioned', 'immediately', 'his', 'eye', '.', '</s>', '</s>', 'mentioned', ',', '</s>', ')', '</s>', '</s>', 'lettuces', 'fleet', 'was']
['<s>', 'When'

In [None]:
# attention:
# https://github.com/spro/practical-pytorch/blob/master/seq2seq-translation/seq2seq-translation.ipynb

# normalization:
# https://medium.com/machine-learning-bites/deeplearning-series-sequence-to-sequence-architectures-4c4ca89e5654

## Submission

In [None]:
'''# load test set
sentences = []
for i, l in enumerate(open("source_test.txt"), 1):
  sentences.append(re.split(' ', l))'''