# CS 287, Homework 3: Neural Machine Translation

In [5]:
import torch
from torch.nn.utils import clip_grad_norm_
torch.__version__
from common import *
## Setup
import torch.nn.functional as F

#!pip install --upgrade pip
#!pip install -q numpy

#!pip install -q torch torchtext spacy opt_einsum
#!pip install -qU git+https://github.com/harvardnlp/namedtensor
#!python -m spacy download en
#!python -m spacy download de

# Torch
import torch.nn as nn
import torch
# Text text processing library and methods for pretrained word embeddings
from torchtext import data, datasets
# Named Tensor wrappers
from namedtensor import ntorch, NamedTensor
from namedtensor.text import NamedField
import numpy as np
%reload_ext autoreload
%autoreload 2

In [6]:
# split raw data into tokens
import spacy
spacy_de = spacy.load('de')
spacy_en = spacy.load('en')

def tokenize_de(text):
    return [tok.text for tok in spacy_de.tokenizer(text)]

def tokenize_en(text):
    return [tok.text for tok in spacy_en.tokenizer(text)]

# add beginning-of-sentence and end-of-sentence tokens to target
BOS_WORD = '<s>'
EOS_WORD = '</s>'
DE = NamedField(names=('srcSeqlen',), tokenize=tokenize_de)
EN = NamedField(names=('trgSeqlen',), tokenize=tokenize_en,
                init_token = BOS_WORD, eos_token = EOS_WORD) # only target needs BOS/EOS

# download dataset of 200K pairs of sentences
# start with MAXLEN = 20
MAX_LEN = 20
train, val, test = datasets.IWSLT.splits(exts=('.de', '.en'), fields=(DE, EN), 
                                         filter_pred=lambda x: len(vars(x)['src']) <= MAX_LEN and 
                                         len(vars(x)['trg']) <= MAX_LEN)
print(train.fields)
print(len(train))
print(vars(train[0]))

# WHAT DOES THIS DO?
'''src = open("valid.src", "w")
trg = open("valid.trg", "w")
for example in val:
    print(" ".join(example.src), file=src)
    print(" ".join(example.trg), file=trg)
src.close()
trg.close()'''

# build vocab, convert words to indices
MIN_FREQ = 5
DE.build_vocab(train.src, min_freq=MIN_FREQ)
EN.build_vocab(train.trg, min_freq=MIN_FREQ)
print(DE.vocab.freqs.most_common(10))
print("Size of German vocab", len(DE.vocab))
print(EN.vocab.freqs.most_common(10))
print("Size of English vocab", len(EN.vocab))
print(EN.vocab.stoi["<s>"], EN.vocab.stoi["</s>"])

print(EN.vocab.stoi["<pad>"], EN.vocab.stoi["<unk>"])

{'src': <namedtensor.text.torch_text.NamedField object at 0x7fecc05cb438>, 'trg': <namedtensor.text.torch_text.NamedField object at 0x7fecb98336a0>}
119076
{'src': ['David', 'Gallo', ':', 'Das', 'ist', 'Bill', 'Lange', '.', 'Ich', 'bin', 'Dave', 'Gallo', '.'], 'trg': ['David', 'Gallo', ':', 'This', 'is', 'Bill', 'Lange', '.', 'I', "'m", 'Dave', 'Gallo', '.']}
[('.', 113253), (',', 67237), ('ist', 24189), ('die', 23778), ('das', 17102), ('der', 15727), ('und', 15622), ('Sie', 15085), ('es', 13197), ('ich', 12946)]
Size of German vocab 13353
[('.', 113433), (',', 59512), ('the', 46029), ('to', 29177), ('a', 27548), ('of', 26794), ('I', 24887), ('is', 21775), ("'s", 20630), ('that', 19814)]
Size of English vocab 11560
2 3
1 0


In [7]:
# split data into batches
BATCH_SIZE = 32
device = torch.device('cuda:0')
train_iter, val_iter = data.BucketIterator.splits((train, val), batch_size=BATCH_SIZE, device=device,
                                                  repeat=False, sort_key=lambda x: len(x.src))

## Sequence to Sequence Learning with Neural Networks

- English to French translation, $p \left( y_1, \dots, y_{T'} \ | \ x_1, \dots, x_T \right) = \prod_{t = 1}^{T'} p \left( y_t \ | \ v, y_1, \dots, y_{t-1} \right)$
- Each sentence ends in '<EOS\>', out-of-vocab words denoted '<UNK\>'
- Model specs: 
    * Input vocabulary of 160,000 and output vocabulary of 80,000
    * Deep LSTM to map (encode) input sequence to fixed-len vector
    * Another deep LSTM to translate (decode) fixed-len vector to output sequence
    * 4 layers per LSTM, 1000 cells per layer, 1000-dimensional word embeddings, softmax over 80,000 words
    * Reversing order of words in source (but not target) improved performance
        * Each word in the source is far from its corresponding word in the target (large minimal time lag); reversing the source reduces the minimal time lag, thereby allowing backprop to establish communication between source and target more easily
- Training specs:
    * Initialize all LSTM params $\sim Unif[-0.08,0.08]$
    * SGD w/o momentum, lr = 0.7
        * After 5 epochs, halve the lr every half-epoch
        * Train for 7.5 epochs
    * Batch size = 128; divide gradient by batch size (denoted $g$)
    * Hard constraint gradient norm; if $s = ||g||_2 > 5$, set $s = 5$
    * Make sure all sentences within a minibatch are roughly the same length
- Objective: $max \frac{1}{|S|} \sum_{(T,S) \in \mathcal{S}} log \ p(T \ | \ S)$, where $\mathcal{S}$ is the training set
- Prediction: $\hat{T} = argmax \ p(T \ | \ S)$ via beam search, where beam size $B \in {1,2}$

In [8]:
context_size = 500
num_layers = 2
attn_context2trg = attn_RNNet_batched(input_size=len(EN.vocab),hidden_size=context_size,num_layers=num_layers)
attn_context2trg = attn_context2trg.cuda()
attn_context2trg_optimizer = torch.optim.Adam(attn_context2trg.parameters(), lr=1e-3)

seq2context = SequenceModel(len(DE.vocab),context_size,num_layers=num_layers)
seq2context_optimizer = torch.optim.Adam(seq2context.parameters(), lr=1e-3)
seq2context = seq2context.cuda()



scheduler_c2t = torch.optim.lr_scheduler.ReduceLROnPlateau(attn_context2trg_optimizer, mode="min", patience=4)
scheduler_s2c = torch.optim.lr_scheduler.ReduceLROnPlateau(seq2context_optimizer, mode="min", patience=4)



In [None]:
best_ppl = 1e8
for e in range(0,300):
    attn_training_loop(e,train_iter,seq2context,attn_context2trg,seq2context_optimizer,attn_context2trg_optimizer)
    ppl = attn_validation_loop(e,val_iter,seq2context,attn_context2trg,scheduler_c2t,scheduler_s2c,BATCH_SIZE=32,context_size=500)
    if ppl < best_ppl:
        torch.save(seq2context.state_dict(),'best_seq2seq_withattn_seq2context.pt')
        torch.save(attn_context2trg.state_dict(),'best_seq2seq_withattn_context2trg.pt')
        best_ppl = ppl
        print('Wrote model!')

Epoch: 0, Batch: 0, Loss: 196.80746459960938
Epoch: 0, Batch: 500, Loss: 57.9072151184082
Epoch: 0, Batch: 1000, Loss: 59.319637298583984
Epoch: 0, Batch: 1500, Loss: 46.97551345825195
Epoch: 0, Batch: 2000, Loss: 52.368019104003906
Epoch: 0, Batch: 2500, Loss: 42.604759216308594
Epoch: 0, Batch: 3000, Loss: 46.63115692138672
Epoch: 0, Batch: 3500, Loss: 36.59870147705078
Epoch: 0, Validation PPL: 7.792147636413574
Wrote model!
Epoch: 1, Batch: 0, Loss: 37.57067108154297
Epoch: 1, Batch: 500, Loss: 41.15217971801758
Epoch: 1, Batch: 1000, Loss: 38.78003692626953
Epoch: 1, Batch: 1500, Loss: 37.293521881103516
Epoch: 1, Batch: 2000, Loss: 37.09842300415039
Epoch: 1, Batch: 2500, Loss: 40.06744384765625
Epoch: 1, Batch: 3000, Loss: 40.234317779541016
Epoch: 1, Batch: 3500, Loss: 41.856075286865234
Epoch: 1, Validation PPL: 5.6893157958984375
Wrote model!
Epoch: 2, Batch: 0, Loss: 37.55695343017578
Epoch: 2, Batch: 500, Loss: 37.20197677612305
Epoch: 2, Batch: 1000, Loss: 38.4419937133789

In [18]:
ppl = attn_validation_loop(e,val_iter,seq2context,attn_context2trg,scheduler_c2t,scheduler_s2c,BATCH_SIZE=32,context_size=500)

Epoch: 65, Validation PPL: 4.2662811279296875


In [12]:
#best_ppl = 1e8
for e in range(0,300):
    attn_training_split_loop(0,train_iter,seq2context,attn_context2trg,seq2context_optimizer,attn_context2trg_optimizer,EN=EN)
    ppl = attn_validation_loop(e,val_iter,seq2context,attn_context2trg,scheduler_c2t,scheduler_s2c,BATCH_SIZE=32,context_size=500)
    if ppl < best_ppl:
        torch.save(seq2context.state_dict(),'best_seq2seq_withattn_seq2context_splittrain2.pt')
        torch.save(attn_context2trg.state_dict(),'best_seq2seq_withattn_context2trg_splittrain2.pt')
        best_ppl = ppl
        print('Wrote model!')

Epoch: 0, Batch: 0, Loss: 40.098819732666016
Epoch: 0, Batch: 500, Loss: 44.26935577392578
Epoch: 0, Batch: 1000, Loss: 36.21977233886719
Epoch: 0, Batch: 1500, Loss: 54.8524169921875
Epoch: 0, Batch: 2000, Loss: 44.83747482299805
Epoch: 0, Batch: 2500, Loss: 47.17257308959961
Epoch: 0, Batch: 3000, Loss: 30.97538948059082
Epoch: 0, Batch: 3500, Loss: 54.668338775634766
Epoch: 0, Validation PPL: 1.0001386404037476
Wrote model!
Epoch: 0, Batch: 0, Loss: 34.754852294921875
Epoch: 0, Batch: 500, Loss: 44.461830139160156
Epoch: 0, Batch: 1000, Loss: 52.99005126953125


KeyboardInterrupt: 

In [13]:
ppl = attn_validation_loop(0,val_iter,seq2context,attn_context2trg,scheduler_c2t,scheduler_s2c,BATCH_SIZE=32,context_size=500)

Epoch: 0, Validation PPL: 15.696435928344727


In [9]:
seq2context.train()
attn_context2trg.train()
for ix,batch in enumerate(train_iter):
        src = batch.src.values.transpose(0,1)
        src = reverse_sequence(src)
        trg = batch.trg.values.transpose(0,1)
        break
        if trg.shape[0] == BATCH_SIZE:
        
            seq2context_optimizer.zero_grad()
            attn_context2trg_optimizer.zero_grad()
        
            encoder_outputs, encoder_hidden = seq2context(src)
            loss = 0
            decoder_context = torch.zeros(BATCH_SIZE, context_size, device='cuda') # 32 x 500
            decoder_hidden = encoder_hidden
            sentence = []
            for j in range(trg.shape[1] - 1):
                word_input = trg[:,j]
                decoder_output, decoder_context, decoder_hidden, decoder_attention = attn_context2trg(word_input, decoder_context, decoder_hidden, encoder_outputs)
                #print(decoder_output.shape, trg[i,j+1].view(-1).shape)
                loss += criterion_train(decoder_output, trg[:,j+1])
                
                if np.mod(ix,100) == 0:
                    sentence.extend([torch.argmax(decoder_output[0,:],dim=0)])
                
            loss.backward()
            seq2context_optimizer.step()
            attn_context2trg_optimizer.step()
        
            if np.mod(ix,500) == 0:
                print('Epoch: {}, Batch: {}, Loss: {}'.format(e, ix, loss.cpu().detach()/BATCH_SIZE))
                #print([EN.vocab.itos[i] for i in sentence])
                #print([EN.vocab.itos[i] for i in trg[0,:]])
    

In [229]:

context_size = 500
num_layers = 2
max_len = 3
attn_seq2context = SequenceModel(len(DE.vocab),context_size,num_layers=num_layers)
state_dict = torch.load('best_seq2seq_withattn_seq2context.pt')
attn_seq2context.load_state_dict(state_dict)
attn_seq2context = attn_seq2context.cuda()

attn_context2trg = attn_RNNet_batched(input_size=len(EN.vocab),hidden_size=context_size,num_layers=num_layers)
state_dict = torch.load('best_seq2seq_withattn_context2trg.pt')
attn_context2trg.load_state_dict(state_dict)
attn_context2trg = attn_context2trg.cuda()

def beam_search(seq2context, attn_context2trg, BEAM_WIDTH = 2, BATCH_SIZE=32, max_len=3):
    top_p = {}
    top_s = {}
    stopped = torch.zeros(BATCH_SIZE,BEAM_WIDTH)==1
    items = []
    for i in range(BATCH_SIZE):
        top_p[i] = []
        top_s[i] = []
        items.append(i)
    
    encoder_outputs, encoder_hidden = seq2context(src)
    decoder_context = torch.zeros(BATCH_SIZE, context_size, device='cuda') # 32 x 500
    decoder_hidden = encoder_hidden
    word_input = trg[:,0]
    decoder_output, decoder_context, decoder_hidden, decoder_attention = attn_context2trg(word_input, decoder_context, decoder_hidden, encoder_outputs)

    next_words = torch.argsort(lsm2(decoder_output),dim=1, descending=True)[:,0:BEAM_WIDTH]
    p_words_init = torch.stack([torch.index_select(decoder_output[i,:],-1,next_words[i,:]) for i in range(BATCH_SIZE)])
    p_words_running = torch.stack([p_words_init[:,b].repeat(1,2) for b in range(BEAM_WIDTH)]).view(BEAM_WIDTH**2,BATCH_SIZE).transpose(0,1)
    next_words = next_words.transpose(0,1).flatten()
    
    top_p.update(dict(zip(items, p_words_init)))
    top_s.update(dict(zip(items, next_words)))
    
    encoder_outputs = encoder_outputs.repeat(BEAM_WIDTH,1,1)
    decoder_hidden = tuple([h.repeat(1,BEAM_WIDTH,1) for h in decoder_hidden])
    decoder_context = decoder_context.repeat(2,1)
    for j in range(max_len):
            decoder_output, decoder_context, decoder_hidden, decoder_attention = attn_context2trg(next_words, decoder_context, decoder_hidden, encoder_outputs)
            args = torch.argsort(lsm2(decoder_output),dim=1, descending=True)[:,0:BEAM_WIDTH]
            next_words = torch.cat([args[BATCH_SIZE*(b):BATCH_SIZE*(b+1),:] for b in range(BEAM_WIDTH)],dim=1)
            p_words = torch.stack([torch.index_select(decoder_output[i,:],-1,next_words[i,:]) for i in range(BATCH_SIZE)])
            p_words_running += p_words
            p_words_norm = p_words_running/(j+1) 
            word_selector = torch.argsort(p_words_norm,dim=1,descending=True)[:,:BEAM_WIDTH]
            next_words = torch.stack([torch.index_select(next_words[s,:],0,word_selector[s,:]) for s in range(BATCH_SIZE)]).transpose(0,1).flatten()
            break
            torch.index_select(top_s.values, beam_indicator)
            top_s.update(dict(zip(items, )))
            
            beam_indicator = word_selector>=2
            indexs = torch.zeros(BATCH_SIZE,2,device='cuda')
            for i in range(BATCH_SIZE):
                indexs[i,:] += i+(BATCH_SIZE*beam_indicator[i,:].float())
            indexs = indexs.long()
            indexs = indexs.transpose(0,1).flatten()
            decoder_hidden = tuple([torch.index_select(h,1,indexs) for h in decoder_hidden])
            decoder_context = torch.index_select(decoder_context,0,indexs)
            
            
            top_p.update(dict(zip(items, )))

In [231]:
next_words.shape
        

torch.Size([64])

In [306]:

top_p = {}
top_s = {}
stopped = torch.zeros(BATCH_SIZE,BEAM_WIDTH)==1
items = []
for i in range(BATCH_SIZE):
    top_p[i] = []
    top_s[i] = []
    items.append(i)

encoder_outputs, encoder_hidden = seq2context(src)
decoder_context = torch.zeros(BATCH_SIZE, context_size, device='cuda') # 32 x 500
decoder_hidden = encoder_hidden
word_input = trg[:,0]
decoder_output, decoder_context, decoder_hidden, decoder_attention = attn_context2trg(word_input, decoder_context, decoder_hidden, encoder_outputs)

next_words = torch.argsort(lsm2(decoder_output),dim=1, descending=True)[:,0:BEAM_WIDTH]
p_words_init = torch.stack([torch.index_select(decoder_output[i,:],-1,next_words[i,:]) for i in range(BATCH_SIZE)])
p_words_running = torch.stack([p_words_init[:,b].repeat(1,2) for b in range(BEAM_WIDTH)]).view(BEAM_WIDTH**2,BATCH_SIZE).transpose(0,1)


update = []
for ix,p in enumerate(next_words):
    update.append([torch.stack([trg[0,0]]+([next_words[ix,b]])) for b in range(BEAM_WIDTH)])

top_p.update(dict(zip(items, p_words_init)))
top_s.update(dict(zip(items, update)))
next_words = next_words.transpose(0,1).flatten()

encoder_outputs = encoder_outputs.repeat(BEAM_WIDTH,1,1)
decoder_hidden = tuple([h.repeat(1,BEAM_WIDTH,1) for h in decoder_hidden])
decoder_context = decoder_context.repeat(2,1)
for j in range(max_len):
        break
        decoder_output, decoder_context, decoder_hidden, decoder_attention = attn_context2trg(next_words, decoder_context, decoder_hidden, encoder_outputs)
        args = torch.argsort(lsm2(decoder_output),dim=1, descending=True)[:,0:BEAM_WIDTH]
        next_words = torch.cat([args[BATCH_SIZE*(b):BATCH_SIZE*(b+1),:] for b in range(BEAM_WIDTH)],dim=1)
        p_words = torch.stack([torch.index_select(decoder_output[i,:],-1,next_words[i,:]) for i in range(BATCH_SIZE)])
        p_words_running += p_words
        p_words_norm = p_words_running/(j+1) 
        word_selector = torch.argsort(p_words_norm,dim=1,descending=True)[:,:BEAM_WIDTH]
        prev_words = torch.stack(list(top_s.values()))
        words=([torch.index_select(prev_words[s,:],0, beam_indicator[s,:].long()) for s in range(BATCH_SIZE)])

        update = []
        for ix,p in enumerate(words):
            update.append([torch.stack([p[b]]+([next_words[ix,b]])) for b in range(BEAM_WIDTH)])
        top_s.update(dict(zip(items, update)))
        next_words = torch.stack([torch.index_select(next_words[s,:],0,word_selector[s,:]) for s in range(BATCH_SIZE)]).transpose(0,1).flatten()
        

In [304]:
next_words.shape

torch.Size([64])

In [291]:
update = []
for ix,p in enumerate(words):
    update.append([torch.stack([p[b]]+([next_words[ix,b]])) for b in range(BEAM_WIDTH)])

In [307]:
update

[[tensor([2, 0], device='cuda:0'), tensor([ 2, 14], device='cuda:0')],
 [tensor([2, 0], device='cuda:0'), tensor([ 2, 24], device='cuda:0')],
 [tensor([2, 0], device='cuda:0'), tensor([ 2, 27], device='cuda:0')],
 [tensor([2, 0], device='cuda:0'), tensor([ 2, 14], device='cuda:0')],
 [tensor([2, 0], device='cuda:0'), tensor([ 2, 14], device='cuda:0')],
 [tensor([2, 0], device='cuda:0'), tensor([ 2, 27], device='cuda:0')],
 [tensor([2, 0], device='cuda:0'), tensor([ 2, 42], device='cuda:0')],
 [tensor([ 2, 27], device='cuda:0'), tensor([ 2, 24], device='cuda:0')],
 [tensor([2, 0], device='cuda:0'), tensor([ 2, 24], device='cuda:0')],
 [tensor([2, 0], device='cuda:0'), tensor([ 2, 24], device='cuda:0')],
 [tensor([2, 0], device='cuda:0'), tensor([ 2, 52], device='cuda:0')],
 [tensor([ 2, 14], device='cuda:0'), tensor([ 2, 27], device='cuda:0')],
 [tensor([2, 0], device='cuda:0'), tensor([ 2, 27], device='cuda:0')],
 [tensor([2, 0], device='cuda:0'), tensor([ 2, 24], device='cuda:0')],
 [

tensor([0, 5], device='cuda:0')

In [184]:
beam_indicator = torch.argsort(p_words_running,dim=1,descending=True)[:,:BEAM_WIDTH]>=2
indexs = torch.zeros(BATCH_SIZE,2,device='cuda')
for i in range(BATCH_SIZE):
    indexs[i,:] += i+(BATCH_SIZE*beam_indicator[i,:].float())

In [190]:
indexs = indexs.long()
indexs = indexs.transpose(0,1).flatten()
decoder_hidden = tuple([torch.index_select(h,1,indexs) for h in decoder_hidden])
decoder_context = torch.index_select(decoder_context,0,indexs)

torch.Size([64, 500])

In [116]:
next_words = torch.cat([args[BATCH_SIZE*(b):BATCH_SIZE*(b+1),:] for b in range(BEAM_WIDTH)],dim=1)

In [117]:
p_words = torch.stack([torch.index_select(decoder_output[i,:],-1,next_words[i,:]) for i in range(BATCH_SIZE)])

In [225]:
dc = dict()

In [228]:
items = []
for i in range(10):
    items.append(i)
    dc[i] = []
x = np.random.rand(10)
dc.update(dict(zip(items, x)))
dc

{0: 0.12301008936547375,
 1: 0.1743470876801082,
 2: 0.024813375393648696,
 3: 0.4313473508338772,
 4: 0.7169799087697334,
 5: 0.2156059142994965,
 6: 0.07298992879668831,
 7: 0.8894616281344004,
 8: 0.3090598195280607,
 9: 0.6069338103608437}

## Submission

In [None]:
# load test set
sentences = []
for i, l in enumerate(open("source_test.txt"), 1):
    sentences.append(re.split(' ', l))