# CS 287, Homework 3: Neural Machine Translation

In [8]:
import torch
from torch.nn.utils import clip_grad_norm_
torch.__version__
from common import *
## Setup
import torch.nn.functional as F

#!pip install --upgrade pip
#!pip install -q numpy

#!pip install -q torch torchtext spacy opt_einsum
#!pip install -qU git+https://github.com/harvardnlp/namedtensor
#!python -m spacy download en
#!python -m spacy download de

# Torch
import torch.nn as nn
import torch
# Text text processing library and methods for pretrained word embeddings
from torchtext import data, datasets
# Named Tensor wrappers
from namedtensor import ntorch, NamedTensor
from namedtensor.text import NamedField
import numpy as np
%reload_ext autoreload
%autoreload 2

In [2]:
# split raw data into tokens
import spacy
spacy_de = spacy.load('de')
spacy_en = spacy.load('en')

def tokenize_de(text):
    return [tok.text for tok in spacy_de.tokenizer(text)]

def tokenize_en(text):
    return [tok.text for tok in spacy_en.tokenizer(text)]

# add beginning-of-sentence and end-of-sentence tokens to target
BOS_WORD = '<s>'
EOS_WORD = '</s>'
DE = NamedField(names=('srcSeqlen',), tokenize=tokenize_de)
EN = NamedField(names=('trgSeqlen',), tokenize=tokenize_en,
                init_token = BOS_WORD, eos_token = EOS_WORD) # only target needs BOS/EOS

# download dataset of 200K pairs of sentences
# start with MAXLEN = 20
MAX_LEN = 20
train, val, test = datasets.IWSLT.splits(exts=('.de', '.en'), fields=(DE, EN), 
                                         filter_pred=lambda x: len(vars(x)['src']) <= MAX_LEN and 
                                         len(vars(x)['trg']) <= MAX_LEN)
print(train.fields)
print(len(train))
print(vars(train[0]))

# WHAT DOES THIS DO?
'''src = open("valid.src", "w")
trg = open("valid.trg", "w")
for example in val:
    print(" ".join(example.src), file=src)
    print(" ".join(example.trg), file=trg)
src.close()
trg.close()'''

# build vocab, convert words to indices
MIN_FREQ = 5
DE.build_vocab(train.src, min_freq=MIN_FREQ)
EN.build_vocab(train.trg, min_freq=MIN_FREQ)
print(DE.vocab.freqs.most_common(10))
print("Size of German vocab", len(DE.vocab))
print(EN.vocab.freqs.most_common(10))
print("Size of English vocab", len(EN.vocab))
print(EN.vocab.stoi["<s>"], EN.vocab.stoi["</s>"])

print(EN.vocab.stoi["<pad>"], EN.vocab.stoi["<unk>"])

{'src': <namedtensor.text.torch_text.NamedField object at 0x7f0210038b38>, 'trg': <namedtensor.text.torch_text.NamedField object at 0x7f02097ef5c0>}
119076
{'src': ['David', 'Gallo', ':', 'Das', 'ist', 'Bill', 'Lange', '.', 'Ich', 'bin', 'Dave', 'Gallo', '.'], 'trg': ['David', 'Gallo', ':', 'This', 'is', 'Bill', 'Lange', '.', 'I', "'m", 'Dave', 'Gallo', '.']}
[('.', 113253), (',', 67237), ('ist', 24189), ('die', 23778), ('das', 17102), ('der', 15727), ('und', 15622), ('Sie', 15085), ('es', 13197), ('ich', 12946)]
Size of German vocab 13353
[('.', 113433), (',', 59512), ('the', 46029), ('to', 29177), ('a', 27548), ('of', 26794), ('I', 24887), ('is', 21775), ("'s", 20630), ('that', 19814)]
Size of English vocab 11560
2 3
1 0


In [3]:
# split data into batches
BATCH_SIZE = 32
device = torch.device('cuda:0')
train_iter, val_iter = data.BucketIterator.splits((train, val), batch_size=BATCH_SIZE, device=device,
                                                  repeat=False, sort_key=lambda x: len(x.src))

## Sequence to Sequence Learning with Neural Networks

- English to French translation, $p \left( y_1, \dots, y_{T'} \ | \ x_1, \dots, x_T \right) = \prod_{t = 1}^{T'} p \left( y_t \ | \ v, y_1, \dots, y_{t-1} \right)$
- Each sentence ends in '<EOS\>', out-of-vocab words denoted '<UNK\>'
- Model specs: 
    * Input vocabulary of 160,000 and output vocabulary of 80,000
    * Deep LSTM to map (encode) input sequence to fixed-len vector
    * Another deep LSTM to translate (decode) fixed-len vector to output sequence
    * 4 layers per LSTM, 1000 cells per layer, 1000-dimensional word embeddings, softmax over 80,000 words
    * Reversing order of words in source (but not target) improved performance
        * Each word in the source is far from its corresponding word in the target (large minimal time lag); reversing the source reduces the minimal time lag, thereby allowing backprop to establish communication between source and target more easily
- Training specs:
    * Initialize all LSTM params $\sim Unif[-0.08,0.08]$
    * SGD w/o momentum, lr = 0.7
        * After 5 epochs, halve the lr every half-epoch
        * Train for 7.5 epochs
    * Batch size = 128; divide gradient by batch size (denoted $g$)
    * Hard constraint gradient norm; if $s = ||g||_2 > 5$, set $s = 5$
    * Make sure all sentences within a minibatch are roughly the same length
- Objective: $max \frac{1}{|S|} \sum_{(T,S) \in \mathcal{S}} log \ p(T \ | \ S)$, where $\mathcal{S}$ is the training set
- Prediction: $\hat{T} = argmax \ p(T \ | \ S)$ via beam search, where beam size $B \in {1,2}$

In [4]:
context_size = 500
num_layers = 2
attn_context2trg = attn_RNNet_batched(input_size=len(EN.vocab),hidden_size=context_size,num_layers=num_layers)
attn_context2trg = attn_context2trg.cuda()
attn_context2trg_optimizer = torch.optim.Adam(attn_context2trg.parameters(), lr=1e-3)

seq2context = SequenceModel(len(DE.vocab),context_size,num_layers=num_layers)
seq2context_optimizer = torch.optim.Adam(seq2context.parameters(), lr=1e-3)
seq2context = seq2context.cuda()



scheduler_c2t = torch.optim.lr_scheduler.ReduceLROnPlateau(attn_context2trg_optimizer, mode="min", patience=4)
scheduler_s2c = torch.optim.lr_scheduler.ReduceLROnPlateau(seq2context_optimizer, mode="min", patience=4)



In [None]:
best_ppl = 1e8
for e in range(0,300):
    attn_training_loop(e,train_iter,seq2context,attn_context2trg,seq2context_optimizer,attn_context2trg_optimizer)
    ppl = attn_validation_loop(e,val_iter,seq2context,attn_context2trg,scheduler_c2t,scheduler_s2c,BATCH_SIZE=32,context_size=500)
    if ppl < best_ppl:
        torch.save(seq2context.state_dict(),'best_seq2seq_withattn_seq2context.pt')
        torch.save(attn_context2trg.state_dict(),'best_seq2seq_withattn_context2trg.pt')
        best_ppl = ppl
        print('Wrote model!')

Epoch: 0, Batch: 0, Loss: 196.80746459960938
Epoch: 0, Batch: 500, Loss: 57.9072151184082
Epoch: 0, Batch: 1000, Loss: 59.319637298583984
Epoch: 0, Batch: 1500, Loss: 46.97551345825195
Epoch: 0, Batch: 2000, Loss: 52.368019104003906
Epoch: 0, Batch: 2500, Loss: 42.604759216308594
Epoch: 0, Batch: 3000, Loss: 46.63115692138672
Epoch: 0, Batch: 3500, Loss: 36.59870147705078
Epoch: 0, Validation PPL: 7.792147636413574
Wrote model!
Epoch: 1, Batch: 0, Loss: 37.57067108154297
Epoch: 1, Batch: 500, Loss: 41.15217971801758
Epoch: 1, Batch: 1000, Loss: 38.78003692626953
Epoch: 1, Batch: 1500, Loss: 37.293521881103516
Epoch: 1, Batch: 2000, Loss: 37.09842300415039
Epoch: 1, Batch: 2500, Loss: 40.06744384765625
Epoch: 1, Batch: 3000, Loss: 40.234317779541016
Epoch: 1, Batch: 3500, Loss: 41.856075286865234
Epoch: 1, Validation PPL: 5.6893157958984375
Wrote model!
Epoch: 2, Batch: 0, Loss: 37.55695343017578
Epoch: 2, Batch: 500, Loss: 37.20197677612305
Epoch: 2, Batch: 1000, Loss: 38.4419937133789

In [18]:
ppl = attn_validation_loop(e,val_iter,seq2context,attn_context2trg,scheduler_c2t,scheduler_s2c,BATCH_SIZE=32,context_size=500)

Epoch: 65, Validation PPL: 4.2662811279296875


In [5]:
best_ppl = 1e8
for e in range(0,300):
    attn_training_split_loop(0,train_iter,seq2context,attn_context2trg,seq2context_optimizer,attn_context2trg_optimizer,EN=EN)
    ppl = attn_validation_loop(e,val_iter,seq2context,attn_context2trg,scheduler_c2t,scheduler_s2c,BATCH_SIZE=32,context_size=500)
    if ppl < best_ppl:
        torch.save(seq2context.state_dict(),'best_seq2seq_withattn_seq2context_splittrain.pt')
        torch.save(attn_context2trg.state_dict(),'best_seq2seq_withattn_context2trg_splittrain.pt')
        best_ppl = ppl
        print('Wrote model!')

Epoch: 0, Batch: 0, Loss: 196.8906707763672
Epoch: 0, Batch: 500, Loss: 70.20809936523438
Epoch: 0, Batch: 1000, Loss: 48.192955017089844
Epoch: 0, Batch: 1500, Loss: 52.265480041503906
Epoch: 0, Batch: 2000, Loss: 57.046905517578125
Epoch: 0, Batch: 2500, Loss: 69.36911010742188
Epoch: 0, Batch: 3000, Loss: 68.65656280517578
Epoch: 0, Batch: 3500, Loss: 50.92197799682617
Epoch: 0, Validation PPL: 11.415091514587402
Wrote model!
Epoch: 0, Batch: 0, Loss: 64.52406311035156
Epoch: 0, Batch: 500, Loss: 57.46894836425781
Epoch: 0, Batch: 1000, Loss: 55.048831939697266
Epoch: 0, Batch: 1500, Loss: 43.473289489746094
Epoch: 0, Batch: 2000, Loss: 58.344329833984375
Epoch: 0, Batch: 2500, Loss: 43.607215881347656
Epoch: 0, Batch: 3000, Loss: 65.991943359375
Epoch: 0, Batch: 3500, Loss: 44.39083480834961
Epoch: 1, Validation PPL: 8.097634315490723
Wrote model!
Epoch: 0, Batch: 0, Loss: 60.247406005859375
Epoch: 0, Batch: 500, Loss: 53.51862335205078
Epoch: 0, Batch: 1000, Loss: 45.7724151611328

Epoch: 0, Batch: 1500, Loss: 22.260120391845703
Epoch: 0, Batch: 2000, Loss: 24.752552032470703
Epoch: 0, Batch: 2500, Loss: 37.87808609008789
Epoch: 0, Batch: 3000, Loss: 43.870948791503906
Epoch: 0, Batch: 3500, Loss: 29.658802032470703
Epoch: 19, Validation PPL: 5.444913387298584
Epoch: 0, Batch: 0, Loss: 38.40983581542969
Epoch: 0, Batch: 500, Loss: 28.45709991455078
Epoch: 0, Batch: 1000, Loss: 41.88131332397461
Epoch: 0, Batch: 1500, Loss: 36.70486068725586
Epoch: 0, Batch: 2000, Loss: 40.18235397338867
Epoch: 0, Batch: 2500, Loss: 41.98257064819336
Epoch: 0, Batch: 3000, Loss: 18.77935028076172
Epoch: 0, Batch: 3500, Loss: 49.40693664550781
Epoch: 20, Validation PPL: 5.447747230529785
Epoch: 0, Batch: 0, Loss: 37.98321533203125
Epoch: 0, Batch: 500, Loss: 39.4747200012207
Epoch: 0, Batch: 1000, Loss: 29.31151008605957
Epoch: 0, Batch: 1500, Loss: 39.950992584228516
Epoch: 0, Batch: 2000, Loss: 23.07223129272461
Epoch: 0, Batch: 2500, Loss: 26.125091552734375
Epoch: 0, Batch: 300

Epoch: 38, Validation PPL: 5.463284492492676
Epoch: 0, Batch: 0, Loss: 38.857749938964844
Epoch: 0, Batch: 500, Loss: 41.44780731201172
Epoch: 0, Batch: 1000, Loss: 23.618331909179688
Epoch: 0, Batch: 1500, Loss: 28.234294891357422
Epoch: 0, Batch: 2000, Loss: 42.85017013549805
Epoch: 0, Batch: 2500, Loss: 26.964160919189453
Epoch: 0, Batch: 3000, Loss: 43.33356857299805
Epoch: 0, Batch: 3500, Loss: 39.65460968017578
Epoch: 39, Validation PPL: 5.463293552398682
Epoch: 0, Batch: 0, Loss: 36.394622802734375
Epoch: 0, Batch: 500, Loss: 39.83744812011719
Epoch: 0, Batch: 1000, Loss: 20.397340774536133
Epoch: 0, Batch: 1500, Loss: 28.511592864990234
Epoch: 0, Batch: 2000, Loss: 24.351829528808594
Epoch: 0, Batch: 2500, Loss: 27.15172004699707
Epoch: 0, Batch: 3000, Loss: 26.600584030151367
Epoch: 0, Batch: 3500, Loss: 22.412168502807617
Epoch: 40, Validation PPL: 5.463318347930908
Epoch: 0, Batch: 0, Loss: 31.601530075073242
Epoch: 0, Batch: 500, Loss: 23.55372428894043
Epoch: 0, Batch: 100

Epoch: 0, Batch: 2000, Loss: 42.66079330444336
Epoch: 0, Batch: 2500, Loss: 36.82075881958008
Epoch: 0, Batch: 3000, Loss: 26.00669288635254
Epoch: 0, Batch: 3500, Loss: 45.179630279541016
Epoch: 58, Validation PPL: 5.463516712188721
Epoch: 0, Batch: 0, Loss: 41.242591857910156
Epoch: 0, Batch: 500, Loss: 24.52005386352539
Epoch: 0, Batch: 1000, Loss: 38.66619873046875
Epoch: 0, Batch: 1500, Loss: 39.9073600769043
Epoch: 0, Batch: 2000, Loss: 38.79574203491211
Epoch: 0, Batch: 2500, Loss: 26.69204330444336
Epoch: 0, Batch: 3000, Loss: 40.750770568847656
Epoch: 0, Batch: 3500, Loss: 22.576255798339844
Epoch: 59, Validation PPL: 5.463543891906738
Epoch: 0, Batch: 0, Loss: 34.08090591430664
Epoch: 0, Batch: 500, Loss: 41.23442459106445
Epoch: 0, Batch: 1000, Loss: 45.269752502441406
Epoch: 0, Batch: 1500, Loss: 35.664302825927734
Epoch: 0, Batch: 2000, Loss: 42.96604919433594
Epoch: 0, Batch: 2500, Loss: 34.56273651123047
Epoch: 0, Batch: 3000, Loss: 37.08980178833008
Epoch: 0, Batch: 350

Epoch: 0, Batch: 0, Loss: 37.66118621826172
Epoch: 0, Batch: 500, Loss: 45.74074935913086
Epoch: 0, Batch: 1000, Loss: 35.84641647338867
Epoch: 0, Batch: 1500, Loss: 22.201576232910156
Epoch: 0, Batch: 2000, Loss: 36.429080963134766
Epoch: 0, Batch: 2500, Loss: 23.765439987182617
Epoch: 0, Batch: 3000, Loss: 33.25324630737305
Epoch: 0, Batch: 3500, Loss: 24.549882888793945
Epoch: 78, Validation PPL: 5.4638166427612305
Epoch: 0, Batch: 0, Loss: 25.900331497192383
Epoch: 0, Batch: 500, Loss: 26.18463897705078
Epoch: 0, Batch: 1000, Loss: 44.317527770996094
Epoch: 0, Batch: 1500, Loss: 25.33626937866211
Epoch: 0, Batch: 2000, Loss: 46.25902557373047
Epoch: 0, Batch: 2500, Loss: 25.8304386138916
Epoch: 0, Batch: 3000, Loss: 28.175928115844727
Epoch: 0, Batch: 3500, Loss: 34.304325103759766
Epoch: 79, Validation PPL: 5.463858127593994
Epoch: 0, Batch: 0, Loss: 23.23537254333496
Epoch: 0, Batch: 500, Loss: 25.295164108276367
Epoch: 0, Batch: 1000, Loss: 27.01877212524414
Epoch: 0, Batch: 150

KeyboardInterrupt: 

In [13]:
ppl = attn_validation_loop(0,val_iter,seq2context,attn_context2trg,scheduler_c2t,scheduler_s2c,BATCH_SIZE=32,context_size=500)

Epoch: 0, Validation PPL: 1.0001404285430908


## Submission

In [None]:
# load test set
sentences = []
for i, l in enumerate(open("source_test.txt"), 1):
    sentences.append(re.split(' ', l))