<a href="https://colab.research.google.com/github/Saputoa21/Machine-Translation/blob/main/seq2seq_NMT_MTMA2025s_v2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Seq2seq NMT with RNN


[Neural Machine Translation by Jointly Learning to Align and Translate](https://arxiv.org/abs/1409.0473)

**NOTE:**

-  use clean bpe data
-  use a piece of training data during coding or low in credits

You have to implement:

- Encoder
- Attention (Bahdanau)
- training loop
- extra: BLEU model selection

Goal:

- Loss in training, validation and test




In [2]:
!pip install torch==2.3.0
!pip install torchtext==0.18

Collecting torch==2.3.0
  Downloading torch-2.3.0-cp311-cp311-manylinux1_x86_64.whl.metadata (26 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch==2.3.0)
  Downloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch==2.3.0)
  Downloading nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch==2.3.0)
  Downloading nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch==2.3.0)
  Downloading nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.1.3.1 (from torch==2.3.0)
  Downloading nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.0.2.54 (from torch==2.3.0)
  Downloading nvidia_cufft_cu12-11.0.2.54-py3-none-manylin

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
import random
import time

In [4]:
#if you dont have bpe data use sacremoese tokenizer
#!pip install sacremoses

In [5]:
#which libraries are we using!!?
!pip freeze > requirements.txt

In [6]:
!cat requirements.txt

absl-py==1.4.0
accelerate==1.6.0
aiohappyeyeballs==2.6.1
aiohttp==3.11.15
aiosignal==1.3.2
alabaster==1.0.0
albucore==0.0.24
albumentations==2.0.6
ale-py==0.11.0
altair==5.5.0
annotated-types==0.7.0
antlr4-python3-runtime==4.9.3
anyio==4.9.0
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
array_record==0.7.2
arviz==0.21.0
astropy==7.0.2
astropy-iers-data==0.2025.5.12.0.38.29
astunparse==1.6.3
atpublic==5.1
attrs==25.3.0
audioread==3.0.1
autograd==1.8.0
babel==2.17.0
backcall==0.2.0
backports.tarfile==1.2.0
beautifulsoup4==4.13.4
betterproto==2.0.0b6
bigframes==2.4.0
bigquery-magics==0.9.0
bleach==6.2.0
blinker==1.9.0
blis==1.3.0
blobfile==3.0.0
blosc2==3.3.2
bokeh==3.7.3
Bottleneck==1.4.2
bqplot==0.12.44
branca==0.8.1
build==1.2.2.post1
CacheControl==0.14.3
cachetools==5.5.2
catalogue==2.0.10
certifi==2025.4.26
cffi==1.17.1
chardet==5.2.0
charset-normalizer==3.4.2
chex==0.1.89
clarabel==0.10.0
click==8.2.0
cloudpathlib==0.21.0
cloudpickle==3.1.1
cmake==3.31.6
cmdstanpy==1.2.5
colorcet

In [7]:
SEED = 42 #to reproduce the traning

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

In [8]:
import torchtext
dir(torchtext)

['_CACHE_DIR',
 '_TEXT_BUCKET',
 '_TORCHTEXT_DEPRECATION_MSG',
 '_WARN',
 '__all__',
 '__builtins__',
 '__cached__',
 '__doc__',
 '__file__',
 '__loader__',
 '__name__',
 '__package__',
 '__path__',
 '__spec__',
 '__version__',
 '_extension',
 '_get_torch_home',
 '_internal',
 '_torchtext',
 'git_version',
 'os',
 'version']

In [11]:
import torchtext
import torch
from torchtext.data.utils import get_tokenizer
from collections import Counter
from torchtext.vocab import vocab
#from torchtext.utils import download_from_url, extract_archive
import io

#0=en 1=de
# soure and target data
#NOTE: USE clean bpe data!

#for this task I used the clean bpe files (train 500k, test 500, dev 500)
#sorting the path to the files en-de
train_filepaths = ['train.en-de.clean.bpe.en', 'train.en-de.clean.bpe.de']
val_filepaths = ['dev.en-de.clean.bpe.en', 'dev.en-de.clean.bpe.de']
test_filepaths = ['test.en-de.clean.bpe.en', 'test.en-de.clean.bpe.de']

#TODO use clean bpe tokenized data!!!

#de_tokenizer = get_tokenizer('moses', language='de') for cases without bpe data
#en_tokenizer = get_tokenizer('moses', language='en')

de_tokenizer = None #None as there are tokenized data
en_tokenizer = None

def build_vocab(filepath, tokenizer=None):
  counter = Counter()
  with io.open(filepath, encoding="utf8") as f: #go to a file
    for string_ in f: #for each line in a file
      #counter.update(tokenizer(string_))
      counter.update(string_.split()) #split each line by space, i.e. creating tokens. Counter counts the tokens to create a database of the vocab
  return vocab(counter, specials=['<unk>', '<pad>', '<bos>', '<eos>']) #vocav from torchtext containing the tokens and their counts, as well as specioal ones

#Vocab
en_vocab = build_vocab(train_filepaths[0], en_tokenizer) #create a vocab
de_vocab = build_vocab(train_filepaths[1], de_tokenizer)

print(dir(en_vocab))
en_vocab.set_default_index(en_vocab['<unk>']) #set <unk> as a default token
de_vocab.set_default_index(de_vocab['<unk>'])

def data_process(filepaths):
  raw_en_iter = iter(io.open(filepaths[0], encoding="utf8"))
  raw_de_iter = iter(io.open(filepaths[1], encoding="utf8"))
  data = []
  for (raw_en, raw_de) in zip(raw_en_iter, raw_de_iter): #for each sentence in the source and target
    en_tensor_ = torch.tensor([en_vocab[token] for token in raw_en.split()], #en_tokenizer(raw_en) #transforms tokens into tensors
                            dtype=torch.long)
    de_tensor_ = torch.tensor([de_vocab[token] for token in raw_de.split()], #de_tokenizer(raw_de)
                            dtype=torch.long)
    data.append((en_tensor_, de_tensor_)) #append to a list
  return data

#pre-process
train_data = data_process(train_filepaths)
val_data = data_process(val_filepaths)
test_data = data_process(test_filepaths)

['T_destination', '__annotations__', '__call__', '__class__', '__contains__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattr__', '__getattribute__', '__getitem__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__jit_unused_properties__', '__le__', '__len__', '__lt__', '__module__', '__ne__', '__new__', '__prepare_scriptable__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__setstate__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', '_apply', '_backward_hooks', '_backward_pre_hooks', '_buffers', '_call_impl', '_compiled_call_impl', '_forward_hooks', '_forward_hooks_always_called', '_forward_hooks_with_kwargs', '_forward_pre_hooks', '_forward_pre_hooks_with_kwargs', '_get_backward_hooks', '_get_backward_pre_hooks', '_get_name', '_is_full_backward_hook', '_load_from_state_dict', '_load_state_dict_post_hooks', '_load_state_dict_pre_hooks', '_maybe_warn_non_full_backward_hook', '_modules', '

In [12]:
#NOTE: if you are low on credits or testing only use a piece of the data e.g. 20K segments
train_data = train_data[:1000]

In [13]:
len(train_data)

1000

In [14]:
len(val_data)

466

In [15]:
len(test_data)

467

In [16]:
len(en_vocab)

36926

In [17]:
len(de_vocab)

42161

In [18]:
#a small look at the data
train_data[0]

(tensor([ 4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21,
         22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39,
         40, 41, 23, 42, 26, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 45, 40,
         54, 55, 56, 57, 40, 58, 59, 60, 36, 61, 38, 62, 63]),
 tensor([ 4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21,
         22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39,
         40, 41, 42, 43, 15, 44, 45, 46, 47, 48, 49, 50, 29, 51, 52, 32, 34, 53,
         54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 47, 65, 66, 67, 68, 47, 69,
         70, 71, 15, 72, 45, 73, 74]))

In [19]:
en_vocab

Vocab()

Define the device.

In [20]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cpu


Create the iterators.

In [21]:
BATCH_SIZE = 10
PAD_IDX = de_vocab['<pad>'] #find the indecies of the special tokens
BOS_IDX = de_vocab['<bos>']
EOS_IDX = de_vocab['<eos>']

from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

def generate_batch(data_batch):
    en_batch, de_batch = [], []
    for (en_item, de_item) in data_batch: #each sentences in a batch
        de_batch.append(torch.cat([torch.tensor([BOS_IDX]), de_item, torch.tensor([EOS_IDX])], dim=0)) #add a source sentence from the beginning (BOS) till the end (EOS)
        en_batch.append(torch.cat([torch.tensor([BOS_IDX]), en_item, torch.tensor([EOS_IDX])], dim=0))
    de_batch = pad_sequence(de_batch, padding_value=PAD_IDX, batch_first=True) #add padding to sentences if they are of different langth
    en_batch = pad_sequence(en_batch, padding_value=PAD_IDX, batch_first=True)
    return en_batch, de_batch

train_iter = DataLoader(train_data, batch_size=BATCH_SIZE,
                        shuffle=True, collate_fn=generate_batch) #shuffle samples after each epoch
valid_iter = DataLoader(val_data, batch_size=BATCH_SIZE,
                        shuffle=False, collate_fn=generate_batch)
test_iter = DataLoader(test_data, batch_size=BATCH_SIZE,
                       shuffle=False, collate_fn=generate_batch)

## Building the Seq2Seq Model

### Encoder




In [22]:
class Encoder(nn.Module):
    def __init__(self, input_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout):
        super().__init__()

        self.embedding = nn.Embedding(input_dim, emb_dim) #size of vocab (50k from clean bpe data), embedding dimensionality (256 from below)

        #[YOUR CODE] GRU(embeding size, encoder hidden size) NOTE: bidirectional batch_first
        # parameters for GRU in Pytorch
        # torch.nn.GRU(input_size, hidden_size, num_layers=1, bias=True, batch_first=False, dropout=0.0, bidirectional=False, device=None, dtype=None)
        self.rnn = nn.GRU(input_size=emb_dim, hidden_size=enc_hid_dim, batch_first=True, bidirectional=True)

        #[YOUR CODE] linear(encoder hidden size * 2, decoder hidden size)
        # times 2 as we have a bidirectional model
        self.fc = nn.Linear(enc_hid_dim * 2, dec_hid_dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, src):

        #[B=8, S= longest sentence]
        embedded = self.embedding(src) #embed the source

        #[B, S, E] print(embedded.size())
        # the matrix grow by the additional dimension of the embedding size (256)

        embedded = self.dropout(embedded) #regularize the embeddings

        #[B, S, E]
        #[YOUR CODE] rrn(embeddings)
        outputs, hidden = self.rnn(embedded)

        #[B, S, H*2]

        #h[n layers * num directions, batch size, hid dim]

        #[forward_1, backward_1, forward_2, backward_2, ...]

        #[-2, :, : ] last state forward RNN
        #[-1, :, : ] last state backward RNN
        print('hid', hidden.size())

        #[YOUR CODE] last state forward RNN
        h1 = hidden[-2, :, : ]

        #[YOUR CODE] last state backward RNN
        h2 = hidden[-1, :, : ]

        #https://pytorch.org/docs/main/generated/torch.cat.html

        #[YOUR CODE] concatenate h1 amd h2 on seq dim
        h_cat = torch.cat([h1, h2], dim=1)

        #[YOUR CODE] tanh(linear(hidden_cat))
        hidden = torch.tanh(self.fc(h_cat))

        #[B, S, H*2]

        return outputs, hidden # outputs = vector for each sentence, hidden = concatenated vector for forward and backward

## Attention

### Luong Attention




Formula: score(target hidden state, each source hidden stare) = target hidden state transposed * W * eachsource hidden state)

In [23]:
class LuongAttention(nn.Module):
    def __init__(self, enc_hid_dim, dec_hid_dim):
        super().__init__()

        self.attn = nn.Linear((enc_hid_dim * 2) + dec_hid_dim, dec_hid_dim)
        self.v = nn.Linear(dec_hid_dim, 1, bias = False)

    def forward(self, hidden, encoder_outputs): #keys, query  hidden = target

        #[batch size, dec hid dim]
        #[src len, batch size, enc hid dim * 2]

        batch_size = encoder_outputs.shape[0] #to see the size
        src_len = encoder_outputs.shape[1]

        #x times decoder hidden state for the size of the src_len
        hidden = hidden.unsqueeze(1).repeat(1, src_len, 1) #repatn the hidden state over source

        #[batch size, src len, dec hid dim]

        scores = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim = 2))) #concateneted in the embedding dimension
        # attn = linear layer
        # tanh = non-linear

        #[batch size, src len, dec hid dim]

        attention = self.v(scores).squeeze(2) # v is W in the formula

        #[batch size, src len]

        return F.softmax(attention, dim=1)

### Bahdanau Attention

Formula: score(target hidden state, each source hidden stare) = Va transposed * tanh(W1 * target hidden state + W2 * each source hidden state)

In [40]:
class BahdanauAttention(nn.Module):
    def __init__(self, enc_hid_dim, dec_hid_dim):
        super(BahdanauAttention, self).__init__()
        self.Wa = nn.Linear(dec_hid_dim, dec_hid_dim, bias=False)
        self.Ua = nn.Linear(enc_hid_dim * 2, dec_hid_dim, bias=False)
        self.Va = nn.Linear(dec_hid_dim, 1, bias=False)

    def forward(self, hidden, encoder_outputs): #keys, query
        #[YOUR CODE]
        batch_size = encoder_outputs.shape[0]
        src_len = encoder_outputs.shape[1]

        #x times decoder hidden state for src_len
        #[YOUR CODE]
        hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)  # pytorch - Returns a new tensor with a dimension of size one inserted at the specified position.

        """
        >>> x = torch.tensor([1, 2, 3, 4])
        >>> torch.unsqueeze(x, 0)
        tensor([[ 1,  2,  3,  4]])
        >>> torch.unsqueeze(x, 1)
        tensor([[ 1],
                [ 2],
                [ 3],
                [ 4]])
        """

        #[YOUR CODE] Va(tanh(Wa(hidden) + Ua(encoder outputs)))
        scores = self.Va(torch.tanh(self.Wa(hidden) + self.Ua(encoder_outputs)))

        scores = scores.squeeze(2) #Returns a tensor with all specified dimensions of input of size 1 removed.

        """
        For example, if input is of shape: (A×1×B×C×1×D) then the input.squeeze() will be of shape: (A×B×C×D).
        When dim is given, a squeeze operation is done only in the given dimension(s). If input is of shape:
        (A×1×B), squeeze(input, 0) leaves the tensor unchanged, but squeeze(input, 1) will squeeze the tensor to the shape (A×B).
        """

        #[YOUR CODE] softmax(scores, dim seq)
        weights =  torch.softmax(scores, dim=1)

        return weights

### Decoder



In [25]:
class Decoder(nn.Module):
    def __init__(self, output_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout, attention):
        super().__init__()

        self.output_dim = output_dim

        self.attention = attention

        self.embedding = nn.Embedding(output_dim, emb_dim)

        self.rnn = nn.GRU((enc_hid_dim * 2) + emb_dim, dec_hid_dim, batch_first=True)

        self.fc_out = nn.Linear((enc_hid_dim * 2) + dec_hid_dim + emb_dim, output_dim)

        self.dropout = nn.Dropout(dropout)

        #attention
        self.attn = nn.Linear((enc_hid_dim * 2) + dec_hid_dim, dec_hid_dim)
        self.v = nn.Linear(dec_hid_dim, 1, bias = False)

    def forward(self, input, hidden, encoder_outputs):

        #[batch size]
        #[batch size, dec hid dim]
        #[src len, batch size, enc hid dim * 2]

        input = input.unsqueeze(0)

        #[1, batch size]

        embedded = self.dropout(self.embedding(input))

        #[1, batch size, emb dim]

        a = self.attention(hidden, encoder_outputs)

        #[batch size, src len]

        a = a.unsqueeze(1)

        #[batch size, 1, src len]

        #[batch size, src len, enc hid dim * 2]

        weighted = torch.bmm(a, encoder_outputs)

        #[batch size, 1, enc hid dim * 2]

        weighted = weighted.permute(1, 0, 2)

        #[1, batch size, enc hid dim * 2]

        rnn_input = torch.cat((embedded, weighted), dim = 2)

        #[1, batch size, (enc hid dim * 2) + emb dim]
        #[B, 1, (enc hid dim * 2) + emb dim]

        rnn_input = rnn_input.permute(1, 0, 2)
        hidden = hidden.unsqueeze(0)

        output, hidden = self.rnn(rnn_input, hidden)

        #[seq len, batch size, dec hid dim * n directions]
        #[n layers * n directions, batch size, dec hid dim]

        #[1, batch size, dec hid dim]
        #[1, batch size, dec hid dim]

        embedded = embedded.squeeze(0)
        output = output.squeeze(1)
        weighted = weighted.squeeze(0)

        prediction = self.fc_out(torch.cat((output, weighted, embedded), dim = 1))

        #[batch size, output dim]

        return prediction, hidden.squeeze(0)

### Seq2Seq




In [26]:
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device):
        super().__init__()

        self.encoder = encoder
        self.decoder = decoder
        self.device = device


    def forward(self, src, trg, teacher_forcing_ratio = 0.5):

        #[src len, batch size]
        #[trg len, batch size]
        #teacher_forcing_ratio is probability to use teacher forcing
        # 0.75 teacher forcing 75% of the time

        batch_size = src.shape[0]
        trg_len = trg.shape[1]
        trg_vocab_size = self.decoder.output_dim


        outputs = torch.zeros(batch_size, trg_len, trg_vocab_size).to(self.device)

        #encoder_outputs is all hidden states of the input sequence
        #hidden is the final forward and backward hidden states, passed through a linear layer
        encoder_outputs, hidden = self.encoder(src)

        #first input to the decoder is the <sos> tokens
        input = trg[:,0]
        # unroll RNN
        for t in range(1, trg_len):

            #insert input token embedding, previous hidden state and all encoder hidden states

            output, hidden = self.decoder(input, hidden, encoder_outputs)

            #predictions
            outputs[:, t] = output

            #teacher forcing
            teacher_force = random.random() < teacher_forcing_ratio

            #greedy search
            top1 = output.argmax(1)

            #if teacher forcing, use gold token as next input
            #if not, use predicted token
            input = trg[:, t] if teacher_force else top1

        return outputs

## Training the Seq2Seq Model

Firstly, I have been training first with tranining data size = 1000, batch size =10, num of epoch = 3 and Luong attention


### Luong attention

In [40]:
INPUT_DIM = len(en_vocab)
OUTPUT_DIM = len(de_vocab)
ENC_EMB_DIM = 256
DEC_EMB_DIM = 256
ENC_HID_DIM = 512
DEC_HID_DIM = 512
ENC_DROPOUT = 0.5
DEC_DROPOUT = 0.2

attn = LuongAttention(ENC_HID_DIM, DEC_HID_DIM)
#[YOUR CODE] Bahdanau attn
attn_bahdanau = BahdanauAttention(ENC_HID_DIM, DEC_HID_DIM)
enc = Encoder(INPUT_DIM, ENC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, ENC_DROPOUT)
dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, DEC_DROPOUT, attn)

model = Seq2Seq(enc, dec, device).to(device)

In [30]:
print(len(en_vocab)) #BPE size 16k approx
print(len(de_vocab))

36926
42161


In [31]:
def init_weights(m):
    for name, param in m.named_parameters():
        if 'weight' in name:
            nn.init.normal_(param.data, mean=0, std=0.01)
        else:
            nn.init.constant_(param.data, 0)

model.apply(init_weights)

Seq2Seq(
  (encoder): Encoder(
    (embedding): Embedding(36926, 256)
    (rnn): GRU(256, 512, batch_first=True, bidirectional=True)
    (fc): Linear(in_features=1024, out_features=512, bias=True)
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (decoder): Decoder(
    (attention): LuongAttention(
      (attn): Linear(in_features=1536, out_features=512, bias=True)
      (v): Linear(in_features=512, out_features=1, bias=False)
    )
    (embedding): Embedding(42161, 256)
    (rnn): GRU(1280, 512, batch_first=True)
    (fc_out): Linear(in_features=1792, out_features=42161, bias=True)
    (dropout): Dropout(p=0.2, inplace=False)
    (attn): Linear(in_features=1536, out_features=512, bias=True)
    (v): Linear(in_features=512, out_features=1, bias=False)
  )
)

In [32]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 103,061,681 trainable parameters


We create an optimizer.

In [33]:
optimizer = optim.Adam(model.parameters(), lr=1e-3)
#[YOUR CODE] Adam lr 1e-3

We initialize the loss function.

In [35]:
TRG_PAD_IDX = de_vocab['<pad>'] #TODO
#https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
criterion = nn.CrossEntropyLoss(ignore_index=TRG_PAD_IDX)

#[YOUR CODE] loss add ignore_index idx
#parameters
#torch.nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean', label_smoothing=0.0)

In [36]:
print(TRG_PAD_IDX)

1


In [36]:
def train(model, iterator, optimizer, criterion, clip):

    #[YOUR CODE] set model train model.train()
    model.train()

    epoch_loss = 0

    for (src, trg) in tqdm(iterator):

        #[YOUR CODE] to gpu
        src = src.to(device)
        trg = trg.to(device)

        #[YOUR CODE] optimizer zero grad
        optimizer.zero_grad()

        #[YOUR CODE] model()
        output = model(src, trg)

        #[trg len, batch size]
        #[trg len, batch size, output dim]

        output = output.permute(1, 0, 2)

        output_dim = output.shape[-1]
        trg = trg.permute(1, 0)

        output = output[1:].reshape(-1, output_dim)
        trg = trg[1:].reshape(-1)

        #[(trg len - 1) * batch size]
        #[(trg len - 1) * batch size, output dim]

        #[YOUR CODE] criterion()
        loss = criterion(output, trg)

        #[YOUR CODE] loss backward
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

        #[YOUR CODE] optimizer step
        optimizer.step()

        epoch_loss += loss.item()

    return epoch_loss / len(iterator)

In [37]:
def evaluate(model, iterator, criterion):

    #[YOUR CODE] set model test/eval  model.eval()
    model.eval()

    epoch_loss = 0

    with torch.no_grad():

        for (src, trg) in iterator:
            #[YOUR CODE] to device
            src = src.to(device)
            trg = trg.to(device)

            #[YOUR CODE] model() turn off teacher forcing 0
            output = model(src, trg, 0)

            #trg = [trg len, batch size]
            #output = [trg len, batch size, output dim]

            output = output.permute(1, 0, 2)
            output_dim = output.shape[-1]
            trg = trg.permute(1, 0)
            output = output[1:].reshape(-1, output_dim)
            trg = trg[1:].reshape(-1)

            #trg = [(trg len - 1) * batch size]
            #output = [(trg len - 1) * batch size, output dim]

            #[YOUR CODE] criterion()
            loss = criterion(output, trg)

            epoch_loss += loss.item()

    return epoch_loss / len(iterator)

In [42]:
N_EPOCHS = 3
CLIP = 1

epoch_info_list = []

best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):

    train_loss = train(model, train_iter, optimizer, criterion, CLIP)
    valid_loss = evaluate(model, valid_iter, criterion)

    if valid_loss < best_valid_loss:
      #[YOUR CODE] extra add BLEU model selection
      best_valid_loss = valid_loss
      torch.save(model.state_dict(), 'model.pt')

    print(f'Epoch: {epoch+1:02}')
    print(f'\tTrain Loss: {train_loss:.3f}\tTrain PPL: {np.exp(train_loss):7.3f}')
    print(f'\t Validation Loss: {valid_loss:.3f}\tValidation PPL: {np.exp(valid_loss):7.3f}')

    epoch_info_list.append(f'Epoch: {epoch+1:02}, Train Loss: {train_loss:.3f}, Train PPL: {np.exp(train_loss):7.3f}, Validation Loss:{valid_loss:.3f}, Validation PPL: {np.exp(valid_loss):7.3f}')


  0%|          | 0/100 [00:00<?, ?it/s]

hid torch.Size([2, 10, 512])


  1%|          | 1/100 [00:00<01:15,  1.32it/s]

hid torch.Size([2, 10, 512])


  2%|▏         | 2/100 [00:01<00:53,  1.83it/s]

hid torch.Size([2, 10, 512])


  3%|▎         | 3/100 [00:02<01:16,  1.27it/s]

hid torch.Size([2, 10, 512])


  4%|▍         | 4/100 [00:03<01:20,  1.19it/s]

hid torch.Size([2, 10, 512])


  5%|▌         | 5/100 [00:03<01:13,  1.29it/s]

hid torch.Size([2, 10, 512])


  6%|▌         | 6/100 [00:04<01:04,  1.45it/s]

hid torch.Size([2, 10, 512])


  7%|▋         | 7/100 [00:04<00:59,  1.58it/s]

hid torch.Size([2, 10, 512])


  8%|▊         | 8/100 [00:05<00:55,  1.65it/s]

hid torch.Size([2, 10, 512])


  9%|▉         | 9/100 [00:05<00:52,  1.73it/s]

hid torch.Size([2, 10, 512])


 10%|█         | 10/100 [00:06<00:49,  1.82it/s]

hid torch.Size([2, 10, 512])


 11%|█         | 11/100 [00:06<00:46,  1.91it/s]

hid torch.Size([2, 10, 512])


 12%|█▏        | 12/100 [00:08<01:02,  1.40it/s]

hid torch.Size([2, 10, 512])


 13%|█▎        | 13/100 [00:08<01:03,  1.37it/s]

hid torch.Size([2, 10, 512])


 14%|█▍        | 14/100 [00:09<01:02,  1.37it/s]

hid torch.Size([2, 10, 512])


 15%|█▌        | 15/100 [00:10<00:55,  1.52it/s]

hid torch.Size([2, 10, 512])


 16%|█▌        | 16/100 [00:10<00:52,  1.60it/s]

hid torch.Size([2, 10, 512])


 17%|█▋        | 17/100 [00:11<00:51,  1.61it/s]

hid torch.Size([2, 10, 512])


 18%|█▊        | 18/100 [00:11<00:50,  1.63it/s]

hid torch.Size([2, 10, 512])


 19%|█▉        | 19/100 [00:12<00:47,  1.69it/s]

hid torch.Size([2, 10, 512])


 20%|██        | 20/100 [00:13<00:57,  1.39it/s]

hid torch.Size([2, 10, 512])


 21%|██        | 21/100 [00:13<00:49,  1.60it/s]

hid torch.Size([2, 10, 512])


 22%|██▏       | 22/100 [00:14<00:59,  1.31it/s]

hid torch.Size([2, 10, 512])


 23%|██▎       | 23/100 [00:15<00:59,  1.29it/s]

hid torch.Size([2, 10, 512])


 24%|██▍       | 24/100 [00:16<01:00,  1.25it/s]

hid torch.Size([2, 10, 512])


 25%|██▌       | 25/100 [00:17<00:59,  1.27it/s]

hid torch.Size([2, 10, 512])


 26%|██▌       | 26/100 [00:17<00:53,  1.38it/s]

hid torch.Size([2, 10, 512])


 27%|██▋       | 27/100 [00:18<00:55,  1.32it/s]

hid torch.Size([2, 10, 512])


 28%|██▊       | 28/100 [00:19<00:47,  1.51it/s]

hid torch.Size([2, 10, 512])


 29%|██▉       | 29/100 [00:19<00:48,  1.48it/s]

hid torch.Size([2, 10, 512])


 30%|███       | 30/100 [00:20<00:45,  1.54it/s]

hid torch.Size([2, 10, 512])


 31%|███       | 31/100 [00:21<00:45,  1.52it/s]

hid torch.Size([2, 10, 512])


 32%|███▏      | 32/100 [00:22<01:00,  1.13it/s]

hid torch.Size([2, 10, 512])


 33%|███▎      | 33/100 [00:23<00:59,  1.12it/s]

hid torch.Size([2, 10, 512])


 34%|███▍      | 34/100 [00:24<00:57,  1.15it/s]

hid torch.Size([2, 10, 512])


 35%|███▌      | 35/100 [00:24<00:50,  1.30it/s]

hid torch.Size([2, 10, 512])


 36%|███▌      | 36/100 [00:26<00:59,  1.07it/s]

hid torch.Size([2, 10, 512])


 37%|███▋      | 37/100 [00:26<00:55,  1.14it/s]

hid torch.Size([2, 10, 512])


 38%|███▊      | 38/100 [00:27<00:51,  1.21it/s]

hid torch.Size([2, 10, 512])


 39%|███▉      | 39/100 [00:28<00:45,  1.34it/s]

hid torch.Size([2, 10, 512])


 40%|████      | 40/100 [00:28<00:38,  1.55it/s]

hid torch.Size([2, 10, 512])


 41%|████      | 41/100 [00:28<00:34,  1.69it/s]

hid torch.Size([2, 10, 512])


 42%|████▏     | 42/100 [00:29<00:33,  1.75it/s]

hid torch.Size([2, 10, 512])


 43%|████▎     | 43/100 [00:30<00:33,  1.71it/s]

hid torch.Size([2, 10, 512])


 44%|████▍     | 44/100 [00:30<00:33,  1.70it/s]

hid torch.Size([2, 10, 512])


 45%|████▌     | 45/100 [00:31<00:30,  1.79it/s]

hid torch.Size([2, 10, 512])


 46%|████▌     | 46/100 [00:31<00:28,  1.86it/s]

hid torch.Size([2, 10, 512])


 47%|████▋     | 47/100 [00:32<00:29,  1.82it/s]

hid torch.Size([2, 10, 512])


 48%|████▊     | 48/100 [00:32<00:31,  1.64it/s]

hid torch.Size([2, 10, 512])


 49%|████▉     | 49/100 [00:33<00:29,  1.76it/s]

hid torch.Size([2, 10, 512])


 50%|█████     | 50/100 [00:33<00:27,  1.81it/s]

hid torch.Size([2, 10, 512])


 51%|█████     | 51/100 [00:34<00:27,  1.78it/s]

hid torch.Size([2, 10, 512])


 52%|█████▏    | 52/100 [00:35<00:29,  1.64it/s]

hid torch.Size([2, 10, 512])


 53%|█████▎    | 53/100 [00:36<00:33,  1.40it/s]

hid torch.Size([2, 10, 512])


 54%|█████▍    | 54/100 [00:36<00:28,  1.59it/s]

hid torch.Size([2, 10, 512])


 55%|█████▌    | 55/100 [00:37<00:29,  1.53it/s]

hid torch.Size([2, 10, 512])


 56%|█████▌    | 56/100 [00:38<00:28,  1.54it/s]

hid torch.Size([2, 10, 512])


 57%|█████▋    | 57/100 [00:38<00:25,  1.72it/s]

hid torch.Size([2, 10, 512])


 58%|█████▊    | 58/100 [00:39<00:30,  1.37it/s]

hid torch.Size([2, 10, 512])


 59%|█████▉    | 59/100 [00:39<00:26,  1.55it/s]

hid torch.Size([2, 10, 512])


 60%|██████    | 60/100 [00:40<00:27,  1.47it/s]

hid torch.Size([2, 10, 512])


 61%|██████    | 61/100 [00:41<00:29,  1.32it/s]

hid torch.Size([2, 10, 512])


 62%|██████▏   | 62/100 [00:43<00:35,  1.06it/s]

hid torch.Size([2, 10, 512])


 63%|██████▎   | 63/100 [00:43<00:31,  1.17it/s]

hid torch.Size([2, 10, 512])


 64%|██████▍   | 64/100 [00:44<00:27,  1.29it/s]

hid torch.Size([2, 10, 512])


 65%|██████▌   | 65/100 [00:44<00:24,  1.41it/s]

hid torch.Size([2, 10, 512])


 66%|██████▌   | 66/100 [00:45<00:23,  1.45it/s]

hid torch.Size([2, 10, 512])


 67%|██████▋   | 67/100 [00:46<00:21,  1.51it/s]

hid torch.Size([2, 10, 512])


 68%|██████▊   | 68/100 [00:47<00:25,  1.27it/s]

hid torch.Size([2, 10, 512])


 69%|██████▉   | 69/100 [00:48<00:25,  1.21it/s]

hid torch.Size([2, 10, 512])


 70%|███████   | 70/100 [00:48<00:23,  1.26it/s]

hid torch.Size([2, 10, 512])


 71%|███████   | 71/100 [00:49<00:21,  1.33it/s]

hid torch.Size([2, 10, 512])


 72%|███████▏  | 72/100 [00:50<00:25,  1.11it/s]

hid torch.Size([2, 10, 512])


 73%|███████▎  | 73/100 [00:51<00:24,  1.09it/s]

hid torch.Size([2, 10, 512])


 74%|███████▍  | 74/100 [00:52<00:19,  1.32it/s]

hid torch.Size([2, 10, 512])


 75%|███████▌  | 75/100 [00:52<00:16,  1.48it/s]

hid torch.Size([2, 10, 512])


 76%|███████▌  | 76/100 [00:53<00:15,  1.53it/s]

hid torch.Size([2, 10, 512])


 77%|███████▋  | 77/100 [00:53<00:13,  1.69it/s]

hid torch.Size([2, 10, 512])


 78%|███████▊  | 78/100 [00:54<00:12,  1.74it/s]

hid torch.Size([2, 10, 512])


 79%|███████▉  | 79/100 [00:54<00:14,  1.50it/s]

hid torch.Size([2, 10, 512])


 80%|████████  | 80/100 [00:56<00:16,  1.24it/s]

hid torch.Size([2, 10, 512])


 81%|████████  | 81/100 [00:56<00:12,  1.50it/s]

hid torch.Size([2, 10, 512])


 82%|████████▏ | 82/100 [00:57<00:11,  1.57it/s]

hid torch.Size([2, 10, 512])


 83%|████████▎ | 83/100 [00:57<00:10,  1.59it/s]

hid torch.Size([2, 10, 512])


 84%|████████▍ | 84/100 [00:58<00:10,  1.48it/s]

hid torch.Size([2, 10, 512])


 85%|████████▌ | 85/100 [00:59<00:10,  1.39it/s]

hid torch.Size([2, 10, 512])


 86%|████████▌ | 86/100 [00:59<00:09,  1.54it/s]

hid torch.Size([2, 10, 512])


 87%|████████▋ | 87/100 [01:00<00:08,  1.49it/s]

hid torch.Size([2, 10, 512])


 88%|████████▊ | 88/100 [01:00<00:07,  1.59it/s]

hid torch.Size([2, 10, 512])


 89%|████████▉ | 89/100 [01:01<00:06,  1.70it/s]

hid torch.Size([2, 10, 512])


 90%|█████████ | 90/100 [01:02<00:07,  1.39it/s]

hid torch.Size([2, 10, 512])


 91%|█████████ | 91/100 [01:03<00:06,  1.46it/s]

hid torch.Size([2, 10, 512])


 92%|█████████▏| 92/100 [01:03<00:05,  1.37it/s]

hid torch.Size([2, 10, 512])


 93%|█████████▎| 93/100 [01:04<00:04,  1.46it/s]

hid torch.Size([2, 10, 512])


 94%|█████████▍| 94/100 [01:05<00:03,  1.53it/s]

hid torch.Size([2, 10, 512])


 95%|█████████▌| 95/100 [01:05<00:03,  1.61it/s]

hid torch.Size([2, 10, 512])


 96%|█████████▌| 96/100 [01:06<00:02,  1.66it/s]

hid torch.Size([2, 10, 512])


 97%|█████████▋| 97/100 [01:07<00:02,  1.48it/s]

hid torch.Size([2, 10, 512])


 98%|█████████▊| 98/100 [01:07<00:01,  1.65it/s]

hid torch.Size([2, 10, 512])


 99%|█████████▉| 99/100 [01:08<00:00,  1.28it/s]

hid torch.Size([2, 10, 512])


100%|██████████| 100/100 [01:09<00:00,  1.45it/s]


hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size

  0%|          | 0/100 [00:00<?, ?it/s]

hid torch.Size([2, 10, 512])


  1%|          | 1/100 [00:00<01:08,  1.44it/s]

hid torch.Size([2, 10, 512])


  2%|▏         | 2/100 [00:01<01:18,  1.25it/s]

hid torch.Size([2, 10, 512])


  3%|▎         | 3/100 [00:02<01:07,  1.45it/s]

hid torch.Size([2, 10, 512])


  4%|▍         | 4/100 [00:03<01:16,  1.25it/s]

hid torch.Size([2, 10, 512])


  5%|▌         | 5/100 [00:03<01:10,  1.34it/s]

hid torch.Size([2, 10, 512])


  6%|▌         | 6/100 [00:04<01:14,  1.26it/s]

hid torch.Size([2, 10, 512])


  7%|▋         | 7/100 [00:05<01:11,  1.30it/s]

hid torch.Size([2, 10, 512])


  8%|▊         | 8/100 [00:06<01:14,  1.24it/s]

hid torch.Size([2, 10, 512])


  9%|▉         | 9/100 [00:06<01:06,  1.38it/s]

hid torch.Size([2, 10, 512])


 10%|█         | 10/100 [00:07<01:01,  1.46it/s]

hid torch.Size([2, 10, 512])


 11%|█         | 11/100 [00:07<00:58,  1.51it/s]

hid torch.Size([2, 10, 512])


 12%|█▏        | 12/100 [00:08<00:56,  1.55it/s]

hid torch.Size([2, 10, 512])


 13%|█▎        | 13/100 [00:09<00:53,  1.63it/s]

hid torch.Size([2, 10, 512])


 14%|█▍        | 14/100 [00:09<00:48,  1.76it/s]

hid torch.Size([2, 10, 512])


 15%|█▌        | 15/100 [00:10<01:03,  1.34it/s]

hid torch.Size([2, 10, 512])


 16%|█▌        | 16/100 [00:11<01:10,  1.19it/s]

hid torch.Size([2, 10, 512])


 17%|█▋        | 17/100 [00:12<01:01,  1.35it/s]

hid torch.Size([2, 10, 512])


 18%|█▊        | 18/100 [00:12<00:56,  1.44it/s]

hid torch.Size([2, 10, 512])


 19%|█▉        | 19/100 [00:13<00:50,  1.62it/s]

hid torch.Size([2, 10, 512])


 20%|██        | 20/100 [00:13<00:45,  1.77it/s]

hid torch.Size([2, 10, 512])


 21%|██        | 21/100 [00:14<00:43,  1.81it/s]

hid torch.Size([2, 10, 512])


 22%|██▏       | 22/100 [00:14<00:37,  2.06it/s]

hid torch.Size([2, 10, 512])


 23%|██▎       | 23/100 [00:15<00:35,  2.14it/s]

hid torch.Size([2, 10, 512])


 24%|██▍       | 24/100 [00:16<00:54,  1.39it/s]

hid torch.Size([2, 10, 512])


 25%|██▌       | 25/100 [00:16<00:46,  1.60it/s]

hid torch.Size([2, 10, 512])


 26%|██▌       | 26/100 [00:17<00:44,  1.68it/s]

hid torch.Size([2, 10, 512])


 27%|██▋       | 27/100 [00:17<00:42,  1.72it/s]

hid torch.Size([2, 10, 512])


 28%|██▊       | 28/100 [00:18<00:41,  1.73it/s]

hid torch.Size([2, 10, 512])


 29%|██▉       | 29/100 [00:18<00:40,  1.77it/s]

hid torch.Size([2, 10, 512])


 30%|███       | 30/100 [00:19<00:39,  1.79it/s]

hid torch.Size([2, 10, 512])


 31%|███       | 31/100 [00:20<00:37,  1.84it/s]

hid torch.Size([2, 10, 512])


 32%|███▏      | 32/100 [00:21<00:47,  1.43it/s]

hid torch.Size([2, 10, 512])


 33%|███▎      | 33/100 [00:22<00:55,  1.20it/s]

hid torch.Size([2, 10, 512])


 34%|███▍      | 34/100 [00:22<00:46,  1.43it/s]

hid torch.Size([2, 10, 512])


 35%|███▌      | 35/100 [00:23<00:56,  1.14it/s]

hid torch.Size([2, 10, 512])


 36%|███▌      | 36/100 [00:24<00:45,  1.40it/s]

hid torch.Size([2, 10, 512])


 37%|███▋      | 37/100 [00:25<00:45,  1.37it/s]

hid torch.Size([2, 10, 512])


 38%|███▊      | 38/100 [00:25<00:41,  1.50it/s]

hid torch.Size([2, 10, 512])


 39%|███▉      | 39/100 [00:26<00:39,  1.54it/s]

hid torch.Size([2, 10, 512])


 40%|████      | 40/100 [00:26<00:36,  1.65it/s]

hid torch.Size([2, 10, 512])


 41%|████      | 41/100 [00:27<00:35,  1.67it/s]

hid torch.Size([2, 10, 512])


 42%|████▏     | 42/100 [00:27<00:35,  1.62it/s]

hid torch.Size([2, 10, 512])


 43%|████▎     | 43/100 [00:28<00:34,  1.65it/s]

hid torch.Size([2, 10, 512])


 44%|████▍     | 44/100 [00:28<00:29,  1.90it/s]

hid torch.Size([2, 10, 512])


 45%|████▌     | 45/100 [00:29<00:28,  1.90it/s]

hid torch.Size([2, 10, 512])


 46%|████▌     | 46/100 [00:30<00:32,  1.68it/s]

hid torch.Size([2, 10, 512])


 47%|████▋     | 47/100 [00:30<00:31,  1.68it/s]

hid torch.Size([2, 10, 512])


 48%|████▊     | 48/100 [00:31<00:29,  1.75it/s]

hid torch.Size([2, 10, 512])


 49%|████▉     | 49/100 [00:31<00:29,  1.75it/s]

hid torch.Size([2, 10, 512])


 50%|█████     | 50/100 [00:32<00:27,  1.84it/s]

hid torch.Size([2, 10, 512])


 51%|█████     | 51/100 [00:32<00:24,  1.99it/s]

hid torch.Size([2, 10, 512])


 52%|█████▏    | 52/100 [00:33<00:23,  2.00it/s]

hid torch.Size([2, 10, 512])


 53%|█████▎    | 53/100 [00:33<00:24,  1.91it/s]

hid torch.Size([2, 10, 512])


 54%|█████▍    | 54/100 [00:34<00:23,  1.93it/s]

hid torch.Size([2, 10, 512])


 55%|█████▌    | 55/100 [00:35<00:27,  1.65it/s]

hid torch.Size([2, 10, 512])


 56%|█████▌    | 56/100 [00:36<00:36,  1.21it/s]

hid torch.Size([2, 10, 512])


 57%|█████▋    | 57/100 [00:37<00:37,  1.15it/s]

hid torch.Size([2, 10, 512])


 58%|█████▊    | 58/100 [00:37<00:32,  1.30it/s]

hid torch.Size([2, 10, 512])


 59%|█████▉    | 59/100 [00:38<00:28,  1.44it/s]

hid torch.Size([2, 10, 512])


 60%|██████    | 60/100 [00:39<00:26,  1.49it/s]

hid torch.Size([2, 10, 512])


 61%|██████    | 61/100 [00:39<00:27,  1.42it/s]

hid torch.Size([2, 10, 512])


 62%|██████▏   | 62/100 [00:40<00:25,  1.51it/s]

hid torch.Size([2, 10, 512])


 63%|██████▎   | 63/100 [00:40<00:22,  1.67it/s]

hid torch.Size([2, 10, 512])


 64%|██████▍   | 64/100 [00:41<00:26,  1.37it/s]

hid torch.Size([2, 10, 512])


 65%|██████▌   | 65/100 [00:42<00:26,  1.30it/s]

hid torch.Size([2, 10, 512])


 66%|██████▌   | 66/100 [00:43<00:29,  1.17it/s]

hid torch.Size([2, 10, 512])


 67%|██████▋   | 67/100 [00:44<00:27,  1.21it/s]

hid torch.Size([2, 10, 512])


 68%|██████▊   | 68/100 [00:45<00:26,  1.22it/s]

hid torch.Size([2, 10, 512])


 69%|██████▉   | 69/100 [00:46<00:25,  1.20it/s]

hid torch.Size([2, 10, 512])


 70%|███████   | 70/100 [00:46<00:21,  1.37it/s]

hid torch.Size([2, 10, 512])


 71%|███████   | 71/100 [00:47<00:21,  1.35it/s]

hid torch.Size([2, 10, 512])


 72%|███████▏  | 72/100 [00:47<00:18,  1.51it/s]

hid torch.Size([2, 10, 512])


 73%|███████▎  | 73/100 [00:48<00:19,  1.37it/s]

hid torch.Size([2, 10, 512])


 74%|███████▍  | 74/100 [00:49<00:18,  1.37it/s]

hid torch.Size([2, 10, 512])


 75%|███████▌  | 75/100 [00:50<00:16,  1.48it/s]

hid torch.Size([2, 10, 512])


 76%|███████▌  | 76/100 [00:51<00:18,  1.27it/s]

hid torch.Size([2, 10, 512])


 77%|███████▋  | 77/100 [00:51<00:18,  1.24it/s]

hid torch.Size([2, 10, 512])


 78%|███████▊  | 78/100 [00:52<00:14,  1.48it/s]

hid torch.Size([2, 10, 512])


 79%|███████▉  | 79/100 [00:52<00:12,  1.63it/s]

hid torch.Size([2, 10, 512])


 80%|████████  | 80/100 [00:53<00:10,  1.91it/s]

hid torch.Size([2, 10, 512])


 81%|████████  | 81/100 [00:53<00:11,  1.69it/s]

hid torch.Size([2, 10, 512])


 82%|████████▏ | 82/100 [00:54<00:10,  1.78it/s]

hid torch.Size([2, 10, 512])


 83%|████████▎ | 83/100 [00:54<00:09,  1.81it/s]

hid torch.Size([2, 10, 512])


 84%|████████▍ | 84/100 [00:55<00:09,  1.65it/s]

hid torch.Size([2, 10, 512])


 85%|████████▌ | 85/100 [00:56<00:08,  1.67it/s]

hid torch.Size([2, 10, 512])


 86%|████████▌ | 86/100 [00:56<00:09,  1.56it/s]

hid torch.Size([2, 10, 512])


 87%|████████▋ | 87/100 [00:57<00:07,  1.64it/s]

hid torch.Size([2, 10, 512])


 88%|████████▊ | 88/100 [00:57<00:06,  1.84it/s]

hid torch.Size([2, 10, 512])


 89%|████████▉ | 89/100 [00:58<00:06,  1.69it/s]

hid torch.Size([2, 10, 512])


 90%|█████████ | 90/100 [00:59<00:07,  1.32it/s]

hid torch.Size([2, 10, 512])


 91%|█████████ | 91/100 [01:00<00:06,  1.33it/s]

hid torch.Size([2, 10, 512])


 92%|█████████▏| 92/100 [01:01<00:07,  1.05it/s]

hid torch.Size([2, 10, 512])


 93%|█████████▎| 93/100 [01:02<00:05,  1.30it/s]

hid torch.Size([2, 10, 512])


 94%|█████████▍| 94/100 [01:02<00:04,  1.38it/s]

hid torch.Size([2, 10, 512])


 95%|█████████▌| 95/100 [01:03<00:03,  1.59it/s]

hid torch.Size([2, 10, 512])


 96%|█████████▌| 96/100 [01:03<00:02,  1.52it/s]

hid torch.Size([2, 10, 512])


 97%|█████████▋| 97/100 [01:04<00:02,  1.50it/s]

hid torch.Size([2, 10, 512])


 98%|█████████▊| 98/100 [01:05<00:01,  1.30it/s]

hid torch.Size([2, 10, 512])


 99%|█████████▉| 99/100 [01:06<00:00,  1.40it/s]

hid torch.Size([2, 10, 512])


100%|██████████| 100/100 [01:06<00:00,  1.50it/s]


hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size

  0%|          | 0/100 [00:00<?, ?it/s]

hid torch.Size([2, 10, 512])


  1%|          | 1/100 [00:01<01:57,  1.18s/it]

hid torch.Size([2, 10, 512])


  2%|▏         | 2/100 [00:02<01:37,  1.00it/s]

hid torch.Size([2, 10, 512])


  3%|▎         | 3/100 [00:02<01:25,  1.14it/s]

hid torch.Size([2, 10, 512])


  4%|▍         | 4/100 [00:03<01:16,  1.26it/s]

hid torch.Size([2, 10, 512])


  5%|▌         | 5/100 [00:04<01:23,  1.13it/s]

hid torch.Size([2, 10, 512])


  6%|▌         | 6/100 [00:05<01:25,  1.10it/s]

hid torch.Size([2, 10, 512])


  7%|▋         | 7/100 [00:05<01:12,  1.28it/s]

hid torch.Size([2, 10, 512])


  8%|▊         | 8/100 [00:06<01:04,  1.43it/s]

hid torch.Size([2, 10, 512])


  9%|▉         | 9/100 [00:06<00:57,  1.58it/s]

hid torch.Size([2, 10, 512])


 10%|█         | 10/100 [00:07<00:53,  1.68it/s]

hid torch.Size([2, 10, 512])


 11%|█         | 11/100 [00:08<01:12,  1.23it/s]

hid torch.Size([2, 10, 512])


 12%|█▏        | 12/100 [00:09<01:01,  1.44it/s]

hid torch.Size([2, 10, 512])


 13%|█▎        | 13/100 [00:09<01:00,  1.43it/s]

hid torch.Size([2, 10, 512])


 14%|█▍        | 14/100 [00:10<01:02,  1.37it/s]

hid torch.Size([2, 10, 512])


 15%|█▌        | 15/100 [00:11<00:59,  1.42it/s]

hid torch.Size([2, 10, 512])


 16%|█▌        | 16/100 [00:11<00:55,  1.52it/s]

hid torch.Size([2, 10, 512])


 17%|█▋        | 17/100 [00:12<00:56,  1.48it/s]

hid torch.Size([2, 10, 512])


 18%|█▊        | 18/100 [00:13<00:53,  1.54it/s]

hid torch.Size([2, 10, 512])


 19%|█▉        | 19/100 [00:13<00:48,  1.65it/s]

hid torch.Size([2, 10, 512])


 20%|██        | 20/100 [00:14<00:55,  1.44it/s]

hid torch.Size([2, 10, 512])


 21%|██        | 21/100 [00:15<00:51,  1.53it/s]

hid torch.Size([2, 10, 512])


 22%|██▏       | 22/100 [00:15<00:49,  1.57it/s]

hid torch.Size([2, 10, 512])


 23%|██▎       | 23/100 [00:16<00:44,  1.75it/s]

hid torch.Size([2, 10, 512])


 24%|██▍       | 24/100 [00:17<00:54,  1.39it/s]

hid torch.Size([2, 10, 512])


 25%|██▌       | 25/100 [00:17<00:47,  1.60it/s]

hid torch.Size([2, 10, 512])


 26%|██▌       | 26/100 [00:18<00:41,  1.77it/s]

hid torch.Size([2, 10, 512])


 27%|██▋       | 27/100 [00:18<00:44,  1.63it/s]

hid torch.Size([2, 10, 512])


 28%|██▊       | 28/100 [00:19<00:50,  1.44it/s]

hid torch.Size([2, 10, 512])


 29%|██▉       | 29/100 [00:21<01:02,  1.13it/s]

hid torch.Size([2, 10, 512])


 30%|███       | 30/100 [00:21<01:02,  1.13it/s]

hid torch.Size([2, 10, 512])


 31%|███       | 31/100 [00:22<00:56,  1.22it/s]

hid torch.Size([2, 10, 512])


 32%|███▏      | 32/100 [00:22<00:47,  1.44it/s]

hid torch.Size([2, 10, 512])


 33%|███▎      | 33/100 [00:23<00:42,  1.59it/s]

hid torch.Size([2, 10, 512])


 34%|███▍      | 34/100 [00:24<00:45,  1.44it/s]

hid torch.Size([2, 10, 512])


 35%|███▌      | 35/100 [00:24<00:42,  1.54it/s]

hid torch.Size([2, 10, 512])


 36%|███▌      | 36/100 [00:25<00:40,  1.59it/s]

hid torch.Size([2, 10, 512])


 37%|███▋      | 37/100 [00:25<00:37,  1.67it/s]

hid torch.Size([2, 10, 512])


 38%|███▊      | 38/100 [00:26<00:35,  1.75it/s]

hid torch.Size([2, 10, 512])


 39%|███▉      | 39/100 [00:26<00:32,  1.89it/s]

hid torch.Size([2, 10, 512])


 40%|████      | 40/100 [00:28<00:42,  1.41it/s]

hid torch.Size([2, 10, 512])


 41%|████      | 41/100 [00:28<00:40,  1.45it/s]

hid torch.Size([2, 10, 512])


 42%|████▏     | 42/100 [00:29<00:37,  1.53it/s]

hid torch.Size([2, 10, 512])


 43%|████▎     | 43/100 [00:30<00:39,  1.43it/s]

hid torch.Size([2, 10, 512])


 44%|████▍     | 44/100 [00:31<00:46,  1.19it/s]

hid torch.Size([2, 10, 512])


 45%|████▌     | 45/100 [00:32<00:49,  1.11it/s]

hid torch.Size([2, 10, 512])


 46%|████▌     | 46/100 [00:32<00:45,  1.20it/s]

hid torch.Size([2, 10, 512])


 47%|████▋     | 47/100 [00:33<00:42,  1.25it/s]

hid torch.Size([2, 10, 512])


 48%|████▊     | 48/100 [00:34<00:40,  1.27it/s]

hid torch.Size([2, 10, 512])


 49%|████▉     | 49/100 [00:35<00:38,  1.31it/s]

hid torch.Size([2, 10, 512])


 50%|█████     | 50/100 [00:35<00:33,  1.51it/s]

hid torch.Size([2, 10, 512])


 51%|█████     | 51/100 [00:37<00:44,  1.11it/s]

hid torch.Size([2, 10, 512])


 52%|█████▏    | 52/100 [00:37<00:38,  1.23it/s]

hid torch.Size([2, 10, 512])


 53%|█████▎    | 53/100 [00:38<00:33,  1.42it/s]

hid torch.Size([2, 10, 512])


 54%|█████▍    | 54/100 [00:38<00:31,  1.46it/s]

hid torch.Size([2, 10, 512])


 55%|█████▌    | 55/100 [00:39<00:28,  1.58it/s]

hid torch.Size([2, 10, 512])


 56%|█████▌    | 56/100 [00:39<00:29,  1.48it/s]

hid torch.Size([2, 10, 512])


 57%|█████▋    | 57/100 [00:40<00:29,  1.48it/s]

hid torch.Size([2, 10, 512])


 58%|█████▊    | 58/100 [00:41<00:30,  1.39it/s]

hid torch.Size([2, 10, 512])


 59%|█████▉    | 59/100 [00:42<00:34,  1.17it/s]

hid torch.Size([2, 10, 512])


 60%|██████    | 60/100 [00:43<00:30,  1.31it/s]

hid torch.Size([2, 10, 512])


 61%|██████    | 61/100 [00:44<00:31,  1.25it/s]

hid torch.Size([2, 10, 512])


 62%|██████▏   | 62/100 [00:44<00:27,  1.41it/s]

hid torch.Size([2, 10, 512])


 63%|██████▎   | 63/100 [00:45<00:25,  1.47it/s]

hid torch.Size([2, 10, 512])


 64%|██████▍   | 64/100 [00:45<00:23,  1.50it/s]

hid torch.Size([2, 10, 512])


 65%|██████▌   | 65/100 [00:46<00:21,  1.64it/s]

hid torch.Size([2, 10, 512])


 66%|██████▌   | 66/100 [00:47<00:21,  1.56it/s]

hid torch.Size([2, 10, 512])


 67%|██████▋   | 67/100 [00:47<00:20,  1.62it/s]

hid torch.Size([2, 10, 512])


 68%|██████▊   | 68/100 [00:48<00:21,  1.50it/s]

hid torch.Size([2, 10, 512])


 69%|██████▉   | 69/100 [00:49<00:21,  1.44it/s]

hid torch.Size([2, 10, 512])


 70%|███████   | 70/100 [00:49<00:18,  1.58it/s]

hid torch.Size([2, 10, 512])


 71%|███████   | 71/100 [00:50<00:21,  1.37it/s]

hid torch.Size([2, 10, 512])


 72%|███████▏  | 72/100 [00:51<00:21,  1.31it/s]

hid torch.Size([2, 10, 512])


 73%|███████▎  | 73/100 [00:51<00:18,  1.49it/s]

hid torch.Size([2, 10, 512])


 74%|███████▍  | 74/100 [00:52<00:18,  1.44it/s]

hid torch.Size([2, 10, 512])


 75%|███████▌  | 75/100 [00:53<00:15,  1.58it/s]

hid torch.Size([2, 10, 512])


 76%|███████▌  | 76/100 [00:53<00:13,  1.72it/s]

hid torch.Size([2, 10, 512])


 77%|███████▋  | 77/100 [00:54<00:16,  1.37it/s]

hid torch.Size([2, 10, 512])


 78%|███████▊  | 78/100 [00:55<00:14,  1.47it/s]

hid torch.Size([2, 10, 512])


 79%|███████▉  | 79/100 [00:55<00:12,  1.65it/s]

hid torch.Size([2, 10, 512])


 80%|████████  | 80/100 [00:56<00:11,  1.73it/s]

hid torch.Size([2, 10, 512])


 81%|████████  | 81/100 [00:56<00:11,  1.60it/s]

hid torch.Size([2, 10, 512])


 82%|████████▏ | 82/100 [00:57<00:13,  1.34it/s]

hid torch.Size([2, 10, 512])


 83%|████████▎ | 83/100 [00:58<00:12,  1.40it/s]

hid torch.Size([2, 10, 512])


 84%|████████▍ | 84/100 [00:59<00:10,  1.46it/s]

hid torch.Size([2, 10, 512])


 85%|████████▌ | 85/100 [00:59<00:09,  1.60it/s]

hid torch.Size([2, 10, 512])


 86%|████████▌ | 86/100 [01:00<00:08,  1.69it/s]

hid torch.Size([2, 10, 512])


 87%|████████▋ | 87/100 [01:00<00:07,  1.65it/s]

hid torch.Size([2, 10, 512])


 88%|████████▊ | 88/100 [01:01<00:07,  1.69it/s]

hid torch.Size([2, 10, 512])


 89%|████████▉ | 89/100 [01:01<00:05,  1.86it/s]

hid torch.Size([2, 10, 512])


 90%|█████████ | 90/100 [01:02<00:05,  1.78it/s]

hid torch.Size([2, 10, 512])


 91%|█████████ | 91/100 [01:03<00:05,  1.63it/s]

hid torch.Size([2, 10, 512])


 92%|█████████▏| 92/100 [01:03<00:04,  1.76it/s]

hid torch.Size([2, 10, 512])


 93%|█████████▎| 93/100 [01:04<00:04,  1.55it/s]

hid torch.Size([2, 10, 512])


 94%|█████████▍| 94/100 [01:04<00:03,  1.75it/s]

hid torch.Size([2, 10, 512])


 95%|█████████▌| 95/100 [01:05<00:03,  1.52it/s]

hid torch.Size([2, 10, 512])


 96%|█████████▌| 96/100 [01:06<00:02,  1.56it/s]

hid torch.Size([2, 10, 512])


 97%|█████████▋| 97/100 [01:07<00:02,  1.46it/s]

hid torch.Size([2, 10, 512])


 98%|█████████▊| 98/100 [01:07<00:01,  1.42it/s]

hid torch.Size([2, 10, 512])


 99%|█████████▉| 99/100 [01:08<00:00,  1.36it/s]

hid torch.Size([2, 10, 512])


100%|██████████| 100/100 [01:09<00:00,  1.44it/s]


hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size

In [44]:
for epoch_info in epoch_info_list:
  print(epoch_info)

Epoch: 01, Train Loss: 10.688, Train PPL: 43816.560, Validation Loss:10.685, Validation PPL: 43682.609
Epoch: 02, Train Loss: 10.688, Train PPL: 43815.816, Validation Loss:10.685, Validation PPL: 43682.609
Epoch: 03, Train Loss: 10.688, Train PPL: 43819.184, Validation Loss:10.685, Validation PPL: 43682.609


In [45]:
#NOTE: load model from file
model.load_state_dict(torch.load('model.pt'))

test_loss = evaluate(model, test_iter, criterion)

print(f'\tTest Loss: {test_loss:.3f}\tTest PPL: {np.exp(test_loss):7.3f}')

hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size

In [46]:
#clean mem
del model
del train_iter
del valid_iter
del test_iter
torch.cuda.empty_cache()

### Bahdanau attention

In [27]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cpu


In [28]:
BATCH_SIZE = 10
PAD_IDX = de_vocab['<pad>'] #find the indecies of the special tokens
BOS_IDX = de_vocab['<bos>']
EOS_IDX = de_vocab['<eos>']

from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

def generate_batch(data_batch):
    en_batch, de_batch = [], []
    for (en_item, de_item) in data_batch: #each sentences in a batch
        de_batch.append(torch.cat([torch.tensor([BOS_IDX]), de_item, torch.tensor([EOS_IDX])], dim=0)) #add a source sentence from the beginning (BOS) till the end (EOS)
        en_batch.append(torch.cat([torch.tensor([BOS_IDX]), en_item, torch.tensor([EOS_IDX])], dim=0))
    de_batch = pad_sequence(de_batch, padding_value=PAD_IDX, batch_first=True) #add padding to sentences if they are of different langth
    en_batch = pad_sequence(en_batch, padding_value=PAD_IDX, batch_first=True)
    return en_batch, de_batch

train_iter = DataLoader(train_data, batch_size=BATCH_SIZE,
                        shuffle=True, collate_fn=generate_batch) #shuffle samples after each epoch
valid_iter = DataLoader(val_data, batch_size=BATCH_SIZE,
                        shuffle=False, collate_fn=generate_batch)
test_iter = DataLoader(test_data, batch_size=BATCH_SIZE,
                       shuffle=False, collate_fn=generate_batch)

In [42]:
INPUT_DIM = len(en_vocab)
OUTPUT_DIM = len(de_vocab)
ENC_EMB_DIM = 256
DEC_EMB_DIM = 256
ENC_HID_DIM = 512
DEC_HID_DIM = 512
ENC_DROPOUT = 0.5
DEC_DROPOUT = 0.2

#attn = LuongAttention(ENC_HID_DIM, DEC_HID_DIM)
#[YOUR CODE] Bahdanau attn
attn_bahdanau = BahdanauAttention(ENC_HID_DIM, DEC_HID_DIM)
enc = Encoder(INPUT_DIM, ENC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, ENC_DROPOUT)
dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, DEC_DROPOUT, attn_bahdanau)

model = Seq2Seq(enc, dec, device).to(device)

In [43]:
def init_weights(m):
    for name, param in m.named_parameters():
        if 'weight' in name:
            nn.init.normal_(param.data, mean=0, std=0.01)
        else:
            nn.init.constant_(param.data, 0)

model.apply(init_weights)

Seq2Seq(
  (encoder): Encoder(
    (embedding): Embedding(36926, 256)
    (rnn): GRU(256, 512, batch_first=True, bidirectional=True)
    (fc): Linear(in_features=1024, out_features=512, bias=True)
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (decoder): Decoder(
    (attention): BahdanauAttention(
      (Wa): Linear(in_features=512, out_features=512, bias=False)
      (Ua): Linear(in_features=1024, out_features=512, bias=False)
      (Va): Linear(in_features=512, out_features=1, bias=False)
    )
    (embedding): Embedding(42161, 256)
    (rnn): GRU(1280, 512, batch_first=True)
    (fc_out): Linear(in_features=1792, out_features=42161, bias=True)
    (dropout): Dropout(p=0.2, inplace=False)
    (attn): Linear(in_features=1536, out_features=512, bias=True)
    (v): Linear(in_features=512, out_features=1, bias=False)
  )
)

In [44]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 103,061,169 trainable parameters


In [33]:
optimizer = optim.Adam(model.parameters(), lr=1e-3)
#[YOUR CODE] Adam lr 1e-3

In [34]:
TRG_PAD_IDX = de_vocab['<pad>'] #TODO
#https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
criterion = nn.CrossEntropyLoss(ignore_index=TRG_PAD_IDX)

#[YOUR CODE] loss add ignore_index idx
#parameters
#torch.nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean', label_smoothing=0.0)

In [45]:
N_EPOCHS = 3
CLIP = 1

epoch_info_list = []

best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):

    train_loss = train(model, train_iter, optimizer, criterion, CLIP)
    valid_loss = evaluate(model, valid_iter, criterion)

    if valid_loss < best_valid_loss:
      #[YOUR CODE] extra add BLEU model selection
      best_valid_loss = valid_loss
      torch.save(model.state_dict(), 'model.pt')

    print(f'Epoch: {epoch+1:02}')
    print(f'\tTrain Loss: {train_loss:.3f}\tTrain PPL: {np.exp(train_loss):7.3f}')
    print(f'\t Validation Loss: {valid_loss:.3f}\tValidation PPL: {np.exp(valid_loss):7.3f}')

    epoch_info_list.append(f'Epoch: {epoch+1:02}, Train Loss: {train_loss:.3f}, Train PPL: {np.exp(train_loss):7.3f}, Validation Loss:{valid_loss:.3f}, Validation PPL: {np.exp(valid_loss):7.3f}')


  0%|          | 0/100 [00:00<?, ?it/s]

hid torch.Size([2, 10, 512])


  1%|          | 1/100 [00:21<35:42, 21.64s/it]

hid torch.Size([2, 10, 512])


  2%|▏         | 2/100 [00:45<37:19, 22.85s/it]

hid torch.Size([2, 10, 512])


  3%|▎         | 3/100 [01:09<38:05, 23.56s/it]

hid torch.Size([2, 10, 512])


  4%|▍         | 4/100 [01:43<44:26, 27.78s/it]

hid torch.Size([2, 10, 512])


  5%|▌         | 5/100 [02:18<47:46, 30.17s/it]

hid torch.Size([2, 10, 512])


  6%|▌         | 6/100 [02:44<45:03, 28.76s/it]

hid torch.Size([2, 10, 512])


  7%|▋         | 7/100 [03:11<43:35, 28.12s/it]

hid torch.Size([2, 10, 512])


  8%|▊         | 8/100 [03:46<46:28, 30.31s/it]

hid torch.Size([2, 10, 512])


  9%|▉         | 9/100 [04:00<38:19, 25.27s/it]

hid torch.Size([2, 10, 512])


 10%|█         | 10/100 [04:17<34:08, 22.76s/it]

hid torch.Size([2, 10, 512])


 11%|█         | 11/100 [04:48<37:24, 25.22s/it]

hid torch.Size([2, 10, 512])


 12%|█▏        | 12/100 [05:19<39:44, 27.10s/it]

hid torch.Size([2, 10, 512])


 13%|█▎        | 13/100 [05:50<41:06, 28.36s/it]

hid torch.Size([2, 10, 512])


 14%|█▍        | 14/100 [06:15<39:10, 27.34s/it]

hid torch.Size([2, 10, 512])


 15%|█▌        | 15/100 [06:35<35:15, 24.89s/it]

hid torch.Size([2, 10, 512])


 16%|█▌        | 16/100 [07:02<35:59, 25.71s/it]

hid torch.Size([2, 10, 512])


 17%|█▋        | 17/100 [07:21<32:35, 23.56s/it]

hid torch.Size([2, 10, 512])


 18%|█▊        | 18/100 [08:19<46:16, 33.86s/it]

hid torch.Size([2, 10, 512])


 19%|█▉        | 19/100 [08:51<45:02, 33.37s/it]

hid torch.Size([2, 10, 512])


 20%|██        | 20/100 [09:31<47:06, 35.33s/it]

hid torch.Size([2, 10, 512])


 21%|██        | 21/100 [09:54<41:50, 31.78s/it]

hid torch.Size([2, 10, 512])


 22%|██▏       | 22/100 [10:41<47:15, 36.35s/it]

hid torch.Size([2, 10, 512])


 23%|██▎       | 23/100 [11:24<48:54, 38.11s/it]

hid torch.Size([2, 10, 512])


 24%|██▍       | 24/100 [11:54<45:17, 35.75s/it]

hid torch.Size([2, 10, 512])


 25%|██▌       | 25/100 [12:36<47:03, 37.65s/it]

hid torch.Size([2, 10, 512])


 26%|██▌       | 26/100 [13:27<51:34, 41.82s/it]

hid torch.Size([2, 10, 512])


 27%|██▋       | 27/100 [13:47<42:45, 35.14s/it]

hid torch.Size([2, 10, 512])


 28%|██▊       | 28/100 [14:15<39:26, 32.86s/it]

hid torch.Size([2, 10, 512])


 29%|██▉       | 29/100 [15:01<43:50, 37.05s/it]

hid torch.Size([2, 10, 512])


 30%|███       | 30/100 [16:02<51:28, 44.13s/it]

hid torch.Size([2, 10, 512])


 31%|███       | 31/100 [16:55<53:41, 46.69s/it]

hid torch.Size([2, 10, 512])


 32%|███▏      | 32/100 [17:12<42:58, 37.92s/it]

hid torch.Size([2, 10, 512])


 33%|███▎      | 33/100 [17:49<41:54, 37.52s/it]

hid torch.Size([2, 10, 512])


 34%|███▍      | 34/100 [18:22<39:48, 36.18s/it]

hid torch.Size([2, 10, 512])


 35%|███▌      | 35/100 [18:44<34:48, 32.13s/it]

hid torch.Size([2, 10, 512])


 36%|███▌      | 36/100 [19:30<38:40, 36.25s/it]

hid torch.Size([2, 10, 512])


 37%|███▋      | 37/100 [20:13<39:57, 38.05s/it]

hid torch.Size([2, 10, 512])


 38%|███▊      | 38/100 [21:06<44:09, 42.74s/it]

hid torch.Size([2, 10, 512])


 39%|███▉      | 39/100 [21:29<37:13, 36.62s/it]

hid torch.Size([2, 10, 512])


 40%|████      | 40/100 [22:36<45:50, 45.85s/it]

hid torch.Size([2, 10, 512])


 41%|████      | 41/100 [23:28<46:45, 47.56s/it]

hid torch.Size([2, 10, 512])


 42%|████▏     | 42/100 [24:05<42:55, 44.40s/it]

hid torch.Size([2, 10, 512])


 43%|████▎     | 43/100 [24:35<38:10, 40.19s/it]

hid torch.Size([2, 10, 512])


 44%|████▍     | 44/100 [25:11<36:23, 38.98s/it]

hid torch.Size([2, 10, 512])


 45%|████▌     | 45/100 [25:40<32:55, 35.92s/it]

hid torch.Size([2, 10, 512])


 46%|████▌     | 46/100 [26:08<30:11, 33.54s/it]

hid torch.Size([2, 10, 512])


 47%|████▋     | 47/100 [26:46<30:49, 34.89s/it]

hid torch.Size([2, 10, 512])


 48%|████▊     | 48/100 [27:06<26:16, 30.33s/it]

hid torch.Size([2, 10, 512])


 49%|████▉     | 49/100 [27:36<25:52, 30.44s/it]

hid torch.Size([2, 10, 512])


 50%|█████     | 50/100 [28:12<26:47, 32.15s/it]

hid torch.Size([2, 10, 512])


 51%|█████     | 51/100 [28:55<28:42, 35.16s/it]

hid torch.Size([2, 10, 512])


 52%|█████▏    | 52/100 [29:22<26:21, 32.94s/it]

hid torch.Size([2, 10, 512])


 53%|█████▎    | 53/100 [29:48<24:01, 30.67s/it]

hid torch.Size([2, 10, 512])


 54%|█████▍    | 54/100 [30:09<21:16, 27.75s/it]

hid torch.Size([2, 10, 512])


 55%|█████▌    | 55/100 [30:50<23:50, 31.78s/it]

hid torch.Size([2, 10, 512])


 56%|█████▌    | 56/100 [31:18<22:33, 30.77s/it]

hid torch.Size([2, 10, 512])


 57%|█████▋    | 57/100 [31:53<23:00, 32.11s/it]

hid torch.Size([2, 10, 512])


 58%|█████▊    | 58/100 [32:16<20:25, 29.18s/it]

hid torch.Size([2, 10, 512])


 59%|█████▉    | 59/100 [32:55<21:53, 32.04s/it]

hid torch.Size([2, 10, 512])


 60%|██████    | 60/100 [33:20<19:57, 29.95s/it]

hid torch.Size([2, 10, 512])


 61%|██████    | 61/100 [33:44<18:23, 28.29s/it]

hid torch.Size([2, 10, 512])


 62%|██████▏   | 62/100 [34:14<18:13, 28.77s/it]

hid torch.Size([2, 10, 512])


 63%|██████▎   | 63/100 [34:34<16:03, 26.05s/it]

hid torch.Size([2, 10, 512])


 64%|██████▍   | 64/100 [34:50<13:55, 23.20s/it]

hid torch.Size([2, 10, 512])


 65%|██████▌   | 65/100 [35:31<16:33, 28.38s/it]

hid torch.Size([2, 10, 512])


 66%|██████▌   | 66/100 [35:52<14:49, 26.17s/it]

hid torch.Size([2, 10, 512])


 67%|██████▋   | 67/100 [36:25<15:34, 28.32s/it]

hid torch.Size([2, 10, 512])


 68%|██████▊   | 68/100 [37:01<16:22, 30.70s/it]

hid torch.Size([2, 10, 512])


 69%|██████▉   | 69/100 [37:21<14:08, 27.37s/it]

hid torch.Size([2, 10, 512])


 70%|███████   | 70/100 [37:43<12:50, 25.68s/it]

hid torch.Size([2, 10, 512])


 71%|███████   | 71/100 [38:09<12:27, 25.77s/it]

hid torch.Size([2, 10, 512])


 72%|███████▏  | 72/100 [38:32<11:38, 24.93s/it]

hid torch.Size([2, 10, 512])


 73%|███████▎  | 73/100 [39:09<12:57, 28.81s/it]

hid torch.Size([2, 10, 512])


 74%|███████▍  | 74/100 [39:28<11:06, 25.63s/it]

hid torch.Size([2, 10, 512])


 75%|███████▌  | 75/100 [39:56<10:58, 26.32s/it]

hid torch.Size([2, 10, 512])


 76%|███████▌  | 76/100 [40:47<13:33, 33.91s/it]

hid torch.Size([2, 10, 512])


 77%|███████▋  | 77/100 [41:47<15:57, 41.63s/it]

hid torch.Size([2, 10, 512])


 78%|███████▊  | 78/100 [42:08<12:59, 35.44s/it]

hid torch.Size([2, 10, 512])


 79%|███████▉  | 79/100 [42:32<11:11, 31.99s/it]

hid torch.Size([2, 10, 512])


 80%|████████  | 80/100 [42:57<09:57, 29.88s/it]

hid torch.Size([2, 10, 512])


 81%|████████  | 81/100 [44:10<13:36, 42.96s/it]

hid torch.Size([2, 10, 512])


 82%|████████▏ | 82/100 [44:30<10:47, 35.98s/it]

hid torch.Size([2, 10, 512])


 83%|████████▎ | 83/100 [44:56<09:21, 33.02s/it]

hid torch.Size([2, 10, 512])


 84%|████████▍ | 84/100 [46:07<11:49, 44.33s/it]

hid torch.Size([2, 10, 512])


 85%|████████▌ | 85/100 [46:37<10:00, 40.03s/it]

hid torch.Size([2, 10, 512])


 86%|████████▌ | 86/100 [47:02<08:19, 35.70s/it]

hid torch.Size([2, 10, 512])


 87%|████████▋ | 87/100 [47:23<06:45, 31.19s/it]

hid torch.Size([2, 10, 512])


 88%|████████▊ | 88/100 [47:49<05:54, 29.52s/it]

hid torch.Size([2, 10, 512])


 89%|████████▉ | 89/100 [48:27<05:55, 32.28s/it]

hid torch.Size([2, 10, 512])


 90%|█████████ | 90/100 [49:08<05:46, 34.65s/it]

hid torch.Size([2, 10, 512])


 91%|█████████ | 91/100 [49:49<05:30, 36.69s/it]

hid torch.Size([2, 10, 512])


 92%|█████████▏| 92/100 [50:09<04:12, 31.62s/it]

hid torch.Size([2, 10, 512])


 93%|█████████▎| 93/100 [50:33<03:24, 29.26s/it]

hid torch.Size([2, 10, 512])


 94%|█████████▍| 94/100 [51:09<03:08, 31.44s/it]

hid torch.Size([2, 10, 512])


 95%|█████████▌| 95/100 [52:01<03:07, 37.51s/it]

hid torch.Size([2, 10, 512])


 96%|█████████▌| 96/100 [52:56<02:51, 42.90s/it]

hid torch.Size([2, 10, 512])


 97%|█████████▋| 97/100 [53:34<02:04, 41.48s/it]

hid torch.Size([2, 10, 512])


 98%|█████████▊| 98/100 [53:56<01:10, 35.48s/it]

hid torch.Size([2, 10, 512])


 99%|█████████▉| 99/100 [54:20<00:32, 32.22s/it]

hid torch.Size([2, 10, 512])


100%|██████████| 100/100 [54:53<00:00, 32.93s/it]


hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size

  0%|          | 0/100 [00:00<?, ?it/s]

hid torch.Size([2, 10, 512])


  1%|          | 1/100 [00:15<25:52, 15.68s/it]

hid torch.Size([2, 10, 512])


  2%|▏         | 2/100 [00:56<49:44, 30.46s/it]

hid torch.Size([2, 10, 512])


  3%|▎         | 3/100 [01:10<36:59, 22.88s/it]

hid torch.Size([2, 10, 512])


  4%|▍         | 4/100 [01:28<33:20, 20.83s/it]

hid torch.Size([2, 10, 512])


  5%|▌         | 5/100 [01:40<28:17, 17.87s/it]

hid torch.Size([2, 10, 512])


  6%|▌         | 6/100 [01:55<26:25, 16.86s/it]

hid torch.Size([2, 10, 512])


  7%|▋         | 7/100 [02:44<42:26, 27.38s/it]

hid torch.Size([2, 10, 512])


  8%|▊         | 8/100 [03:24<48:06, 31.37s/it]

hid torch.Size([2, 10, 512])


  9%|▉         | 9/100 [03:39<39:45, 26.21s/it]

hid torch.Size([2, 10, 512])


 10%|█         | 10/100 [04:03<38:32, 25.69s/it]

hid torch.Size([2, 10, 512])


 11%|█         | 11/100 [04:18<32:56, 22.21s/it]

hid torch.Size([2, 10, 512])


 12%|█▏        | 12/100 [04:41<32:55, 22.45s/it]

hid torch.Size([2, 10, 512])


 13%|█▎        | 13/100 [04:58<30:22, 20.95s/it]

hid torch.Size([2, 10, 512])


 14%|█▍        | 14/100 [05:19<30:04, 20.98s/it]

hid torch.Size([2, 10, 512])


 15%|█▌        | 15/100 [05:43<30:46, 21.72s/it]

hid torch.Size([2, 10, 512])


 16%|█▌        | 16/100 [05:58<27:33, 19.68s/it]

hid torch.Size([2, 10, 512])


 17%|█▋        | 17/100 [06:12<24:59, 18.07s/it]

hid torch.Size([2, 10, 512])


 18%|█▊        | 18/100 [07:07<39:44, 29.07s/it]

hid torch.Size([2, 10, 512])


 19%|█▉        | 19/100 [07:38<39:59, 29.62s/it]

hid torch.Size([2, 10, 512])


 20%|██        | 20/100 [08:16<42:57, 32.22s/it]

hid torch.Size([2, 10, 512])


 21%|██        | 21/100 [08:30<35:10, 26.71s/it]

hid torch.Size([2, 10, 512])


 22%|██▏       | 22/100 [09:01<36:34, 28.13s/it]

hid torch.Size([2, 10, 512])


 23%|██▎       | 23/100 [09:19<32:18, 25.17s/it]

hid torch.Size([2, 10, 512])


 24%|██▍       | 24/100 [09:39<29:48, 23.54s/it]

hid torch.Size([2, 10, 512])


 25%|██▌       | 25/100 [10:07<30:57, 24.77s/it]

hid torch.Size([2, 10, 512])


 26%|██▌       | 26/100 [10:21<26:48, 21.74s/it]

hid torch.Size([2, 10, 512])


 27%|██▋       | 27/100 [11:09<35:48, 29.43s/it]

hid torch.Size([2, 10, 512])


 28%|██▊       | 28/100 [11:27<31:09, 25.97s/it]

hid torch.Size([2, 10, 512])


 29%|██▉       | 29/100 [12:05<35:14, 29.78s/it]

hid torch.Size([2, 10, 512])


 30%|███       | 30/100 [12:32<33:28, 28.70s/it]

hid torch.Size([2, 10, 512])


 31%|███       | 31/100 [12:54<30:41, 26.69s/it]

hid torch.Size([2, 10, 512])


 32%|███▏      | 32/100 [13:11<27:02, 23.86s/it]

hid torch.Size([2, 10, 512])


 33%|███▎      | 33/100 [13:43<29:32, 26.45s/it]

hid torch.Size([2, 10, 512])


 34%|███▍      | 34/100 [14:20<32:18, 29.38s/it]

hid torch.Size([2, 10, 512])


 35%|███▌      | 35/100 [14:49<31:59, 29.54s/it]

hid torch.Size([2, 10, 512])


 36%|███▌      | 36/100 [15:18<31:17, 29.34s/it]

hid torch.Size([2, 10, 512])


 37%|███▋      | 37/100 [15:41<28:39, 27.30s/it]

hid torch.Size([2, 10, 512])


 38%|███▊      | 38/100 [16:03<26:45, 25.89s/it]

hid torch.Size([2, 10, 512])


 39%|███▉      | 39/100 [16:44<30:56, 30.44s/it]

hid torch.Size([2, 10, 512])


 40%|████      | 40/100 [17:03<26:57, 26.96s/it]

hid torch.Size([2, 10, 512])


 41%|████      | 41/100 [17:29<26:10, 26.61s/it]

hid torch.Size([2, 10, 512])


 42%|████▏     | 42/100 [17:48<23:32, 24.36s/it]

hid torch.Size([2, 10, 512])


 43%|████▎     | 43/100 [18:13<23:08, 24.35s/it]

hid torch.Size([2, 10, 512])


 44%|████▍     | 44/100 [18:38<23:09, 24.80s/it]

hid torch.Size([2, 10, 512])


 45%|████▌     | 45/100 [19:07<23:38, 25.79s/it]

hid torch.Size([2, 10, 512])


 46%|████▌     | 46/100 [19:45<26:38, 29.60s/it]

hid torch.Size([2, 10, 512])


 47%|████▋     | 47/100 [20:07<24:07, 27.31s/it]

hid torch.Size([2, 10, 512])


 48%|████▊     | 48/100 [20:27<21:50, 25.19s/it]

hid torch.Size([2, 10, 512])


 49%|████▉     | 49/100 [21:32<31:37, 37.20s/it]

hid torch.Size([2, 10, 512])


 50%|█████     | 50/100 [22:30<36:06, 43.33s/it]

hid torch.Size([2, 10, 512])


 51%|█████     | 51/100 [22:50<29:39, 36.32s/it]

hid torch.Size([2, 10, 512])


 52%|█████▏    | 52/100 [23:41<32:27, 40.57s/it]

hid torch.Size([2, 10, 512])


 53%|█████▎    | 53/100 [24:22<31:55, 40.76s/it]

hid torch.Size([2, 10, 512])


 54%|█████▍    | 54/100 [24:47<27:39, 36.08s/it]

hid torch.Size([2, 10, 512])


 55%|█████▌    | 55/100 [25:11<24:23, 32.53s/it]

hid torch.Size([2, 10, 512])


 56%|█████▌    | 56/100 [25:39<22:46, 31.06s/it]

hid torch.Size([2, 10, 512])


 57%|█████▋    | 57/100 [26:08<21:52, 30.53s/it]

hid torch.Size([2, 10, 512])


 58%|█████▊    | 58/100 [26:37<20:56, 29.91s/it]

hid torch.Size([2, 10, 512])


 59%|█████▉    | 59/100 [26:54<17:58, 26.30s/it]

hid torch.Size([2, 10, 512])


 60%|██████    | 60/100 [28:00<25:22, 38.07s/it]

hid torch.Size([2, 10, 512])


 61%|██████    | 61/100 [28:33<23:44, 36.52s/it]

hid torch.Size([2, 10, 512])


 62%|██████▏   | 62/100 [29:08<22:51, 36.09s/it]

hid torch.Size([2, 10, 512])


 63%|██████▎   | 63/100 [29:45<22:31, 36.52s/it]

hid torch.Size([2, 10, 512])


 64%|██████▍   | 64/100 [30:56<28:01, 46.71s/it]

hid torch.Size([2, 10, 512])


 65%|██████▌   | 65/100 [31:19<23:09, 39.71s/it]

hid torch.Size([2, 10, 512])


 66%|██████▌   | 66/100 [31:46<20:14, 35.71s/it]

hid torch.Size([2, 10, 512])


 67%|██████▋   | 67/100 [32:11<18:00, 32.73s/it]

hid torch.Size([2, 10, 512])


 68%|██████▊   | 68/100 [33:00<20:00, 37.50s/it]

hid torch.Size([2, 10, 512])


 69%|██████▉   | 69/100 [33:20<16:38, 32.21s/it]

hid torch.Size([2, 10, 512])


 70%|███████   | 70/100 [33:42<14:33, 29.13s/it]

hid torch.Size([2, 10, 512])


 71%|███████   | 71/100 [34:07<13:29, 27.92s/it]

hid torch.Size([2, 10, 512])


 72%|███████▏  | 72/100 [34:56<16:01, 34.32s/it]

hid torch.Size([2, 10, 512])


 73%|███████▎  | 73/100 [35:33<15:43, 34.93s/it]

hid torch.Size([2, 10, 512])


 74%|███████▍  | 74/100 [35:58<13:51, 31.97s/it]

hid torch.Size([2, 10, 512])


 75%|███████▌  | 75/100 [36:27<12:57, 31.08s/it]

hid torch.Size([2, 10, 512])


 76%|███████▌  | 76/100 [37:04<13:12, 33.01s/it]

hid torch.Size([2, 10, 512])


 77%|███████▋  | 77/100 [37:37<12:34, 32.82s/it]

hid torch.Size([2, 10, 512])


 78%|███████▊  | 78/100 [38:11<12:13, 33.35s/it]

hid torch.Size([2, 10, 512])


 79%|███████▉  | 79/100 [39:08<14:09, 40.43s/it]

hid torch.Size([2, 10, 512])


 80%|████████  | 80/100 [39:49<13:29, 40.48s/it]

hid torch.Size([2, 10, 512])


 81%|████████  | 81/100 [40:29<12:45, 40.29s/it]

hid torch.Size([2, 10, 512])


 82%|████████▏ | 82/100 [41:00<11:19, 37.74s/it]

hid torch.Size([2, 10, 512])


 83%|████████▎ | 83/100 [41:21<09:11, 32.47s/it]

hid torch.Size([2, 10, 512])


 84%|████████▍ | 84/100 [41:47<08:11, 30.73s/it]

hid torch.Size([2, 10, 512])


 85%|████████▌ | 85/100 [42:08<06:58, 27.90s/it]

hid torch.Size([2, 10, 512])


 86%|████████▌ | 86/100 [42:45<07:07, 30.51s/it]

hid torch.Size([2, 10, 512])


 87%|████████▋ | 87/100 [43:18<06:47, 31.37s/it]

hid torch.Size([2, 10, 512])


 88%|████████▊ | 88/100 [43:44<05:53, 29.50s/it]

hid torch.Size([2, 10, 512])


 89%|████████▉ | 89/100 [44:24<05:59, 32.73s/it]

hid torch.Size([2, 10, 512])


 90%|█████████ | 90/100 [44:47<04:57, 29.80s/it]

hid torch.Size([2, 10, 512])


 91%|█████████ | 91/100 [45:18<04:30, 30.09s/it]

hid torch.Size([2, 10, 512])


 92%|█████████▏| 92/100 [46:18<05:12, 39.11s/it]

hid torch.Size([2, 10, 512])


 93%|█████████▎| 93/100 [46:39<03:56, 33.75s/it]

hid torch.Size([2, 10, 512])


 94%|█████████▍| 94/100 [47:03<03:04, 30.80s/it]

hid torch.Size([2, 10, 512])


 95%|█████████▌| 95/100 [47:35<02:35, 31.10s/it]

hid torch.Size([2, 10, 512])


 96%|█████████▌| 96/100 [47:57<01:53, 28.45s/it]

hid torch.Size([2, 10, 512])


 97%|█████████▋| 97/100 [48:29<01:28, 29.56s/it]

hid torch.Size([2, 10, 512])


 98%|█████████▊| 98/100 [49:14<01:08, 34.16s/it]

hid torch.Size([2, 10, 512])


 99%|█████████▉| 99/100 [49:50<00:34, 34.72s/it]

hid torch.Size([2, 10, 512])


100%|██████████| 100/100 [50:08<00:00, 30.08s/it]


hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size

  0%|          | 0/100 [00:00<?, ?it/s]

hid torch.Size([2, 10, 512])


  1%|          | 1/100 [00:50<1:24:02, 50.93s/it]

hid torch.Size([2, 10, 512])


  2%|▏         | 2/100 [01:40<1:21:33, 49.93s/it]

hid torch.Size([2, 10, 512])


  3%|▎         | 3/100 [02:04<1:01:42, 38.17s/it]

hid torch.Size([2, 10, 512])


  4%|▍         | 4/100 [02:41<1:00:10, 37.61s/it]

hid torch.Size([2, 10, 512])


  5%|▌         | 5/100 [03:10<55:00, 34.75s/it]  

hid torch.Size([2, 10, 512])


  6%|▌         | 6/100 [03:36<49:48, 31.79s/it]

hid torch.Size([2, 10, 512])


  7%|▋         | 7/100 [04:05<47:38, 30.73s/it]

hid torch.Size([2, 10, 512])


  8%|▊         | 8/100 [04:26<42:37, 27.80s/it]

hid torch.Size([2, 10, 512])


  9%|▉         | 9/100 [04:51<40:46, 26.88s/it]

hid torch.Size([2, 10, 512])


 10%|█         | 10/100 [05:40<50:19, 33.55s/it]

hid torch.Size([2, 10, 512])


 11%|█         | 11/100 [06:03<44:53, 30.26s/it]

hid torch.Size([2, 10, 512])


 12%|█▏        | 12/100 [06:29<42:52, 29.23s/it]

hid torch.Size([2, 10, 512])


 13%|█▎        | 13/100 [07:04<44:36, 30.76s/it]

hid torch.Size([2, 10, 512])


 14%|█▍        | 14/100 [07:30<41:57, 29.27s/it]

hid torch.Size([2, 10, 512])


 15%|█▌        | 15/100 [07:46<36:00, 25.42s/it]

hid torch.Size([2, 10, 512])


 16%|█▌        | 16/100 [08:08<34:08, 24.39s/it]

hid torch.Size([2, 10, 512])


 17%|█▋        | 17/100 [08:36<35:26, 25.62s/it]

hid torch.Size([2, 10, 512])


 18%|█▊        | 18/100 [09:10<38:07, 27.90s/it]

hid torch.Size([2, 10, 512])


 19%|█▉        | 19/100 [10:06<49:04, 36.35s/it]

hid torch.Size([2, 10, 512])


 20%|██        | 20/100 [10:36<45:50, 34.38s/it]

hid torch.Size([2, 10, 512])


 21%|██        | 21/100 [11:00<41:33, 31.56s/it]

hid torch.Size([2, 10, 512])


 22%|██▏       | 22/100 [11:27<39:02, 30.03s/it]

hid torch.Size([2, 10, 512])


 23%|██▎       | 23/100 [11:52<36:38, 28.55s/it]

hid torch.Size([2, 10, 512])


 24%|██▍       | 24/100 [12:59<50:37, 39.97s/it]

hid torch.Size([2, 10, 512])


 25%|██▌       | 25/100 [13:17<41:57, 33.57s/it]

hid torch.Size([2, 10, 512])


 26%|██▌       | 26/100 [13:51<41:18, 33.50s/it]

hid torch.Size([2, 10, 512])


 27%|██▋       | 27/100 [14:26<41:17, 33.94s/it]

hid torch.Size([2, 10, 512])


 28%|██▊       | 28/100 [15:02<41:27, 34.55s/it]

hid torch.Size([2, 10, 512])


 29%|██▉       | 29/100 [15:26<37:09, 31.40s/it]

hid torch.Size([2, 10, 512])


 30%|███       | 30/100 [16:22<45:14, 38.78s/it]

hid torch.Size([2, 10, 512])


 31%|███       | 31/100 [16:47<39:50, 34.64s/it]

hid torch.Size([2, 10, 512])


 32%|███▏      | 32/100 [17:09<35:02, 30.92s/it]

hid torch.Size([2, 10, 512])


 33%|███▎      | 33/100 [17:43<35:36, 31.89s/it]

hid torch.Size([2, 10, 512])


 34%|███▍      | 34/100 [18:10<33:33, 30.51s/it]

hid torch.Size([2, 10, 512])


 35%|███▌      | 35/100 [18:30<29:40, 27.39s/it]

hid torch.Size([2, 10, 512])


 36%|███▌      | 36/100 [19:04<31:21, 29.40s/it]

hid torch.Size([2, 10, 512])


 37%|███▋      | 37/100 [19:29<29:19, 27.94s/it]

hid torch.Size([2, 10, 512])


 38%|███▊      | 38/100 [19:46<25:31, 24.70s/it]

hid torch.Size([2, 10, 512])


 39%|███▉      | 39/100 [20:10<24:50, 24.44s/it]

hid torch.Size([2, 10, 512])


 40%|████      | 40/100 [20:40<26:06, 26.10s/it]

hid torch.Size([2, 10, 512])


 41%|████      | 41/100 [21:05<25:14, 25.66s/it]

hid torch.Size([2, 10, 512])


 42%|████▏     | 42/100 [21:37<26:51, 27.79s/it]

hid torch.Size([2, 10, 512])


 43%|████▎     | 43/100 [22:33<34:12, 36.01s/it]

hid torch.Size([2, 10, 512])


 44%|████▍     | 44/100 [23:09<33:50, 36.26s/it]

hid torch.Size([2, 10, 512])


 45%|████▌     | 45/100 [23:32<29:28, 32.15s/it]

hid torch.Size([2, 10, 512])


 46%|████▌     | 46/100 [23:58<27:10, 30.20s/it]

hid torch.Size([2, 10, 512])


 47%|████▋     | 47/100 [24:19<24:27, 27.69s/it]

hid torch.Size([2, 10, 512])


 48%|████▊     | 48/100 [25:02<27:56, 32.23s/it]

hid torch.Size([2, 10, 512])


 49%|████▉     | 49/100 [25:26<25:13, 29.69s/it]

hid torch.Size([2, 10, 512])


 50%|█████     | 50/100 [25:58<25:21, 30.43s/it]

hid torch.Size([2, 10, 512])


 51%|█████     | 51/100 [26:29<24:54, 30.50s/it]

hid torch.Size([2, 10, 512])


 52%|█████▏    | 52/100 [27:26<30:53, 38.62s/it]

hid torch.Size([2, 10, 512])


 53%|█████▎    | 53/100 [27:57<28:27, 36.34s/it]

hid torch.Size([2, 10, 512])


 54%|█████▍    | 54/100 [28:50<31:29, 41.07s/it]

hid torch.Size([2, 10, 512])


 55%|█████▌    | 55/100 [29:24<29:23, 39.20s/it]

hid torch.Size([2, 10, 512])


 56%|█████▌    | 56/100 [29:50<25:39, 35.00s/it]

hid torch.Size([2, 10, 512])


 57%|█████▋    | 57/100 [30:12<22:20, 31.17s/it]

hid torch.Size([2, 10, 512])


 58%|█████▊    | 58/100 [30:36<20:16, 28.95s/it]

hid torch.Size([2, 10, 512])


 59%|█████▉    | 59/100 [31:16<22:12, 32.50s/it]

hid torch.Size([2, 10, 512])


 60%|██████    | 60/100 [31:36<19:06, 28.67s/it]

hid torch.Size([2, 10, 512])


 61%|██████    | 61/100 [32:16<20:46, 31.95s/it]

hid torch.Size([2, 10, 512])


 62%|██████▏   | 62/100 [33:03<23:13, 36.68s/it]

hid torch.Size([2, 10, 512])


 63%|██████▎   | 63/100 [33:39<22:24, 36.34s/it]

hid torch.Size([2, 10, 512])


 64%|██████▍   | 64/100 [34:36<25:33, 42.58s/it]

hid torch.Size([2, 10, 512])


 65%|██████▌   | 65/100 [34:57<21:02, 36.06s/it]

hid torch.Size([2, 10, 512])


 66%|██████▌   | 66/100 [35:36<20:55, 36.93s/it]

hid torch.Size([2, 10, 512])


 67%|██████▋   | 67/100 [36:25<22:23, 40.71s/it]

hid torch.Size([2, 10, 512])


 68%|██████▊   | 68/100 [37:01<20:55, 39.25s/it]

hid torch.Size([2, 10, 512])


 69%|██████▉   | 69/100 [37:42<20:33, 39.80s/it]

hid torch.Size([2, 10, 512])


 70%|███████   | 70/100 [38:23<20:02, 40.09s/it]

hid torch.Size([2, 10, 512])


 71%|███████   | 71/100 [38:50<17:26, 36.08s/it]

hid torch.Size([2, 10, 512])


 72%|███████▏  | 72/100 [39:06<14:02, 30.07s/it]

hid torch.Size([2, 10, 512])


 73%|███████▎  | 73/100 [39:30<12:46, 28.38s/it]

hid torch.Size([2, 10, 512])


 74%|███████▍  | 74/100 [39:52<11:23, 26.30s/it]

hid torch.Size([2, 10, 512])


 75%|███████▌  | 75/100 [40:19<11:05, 26.63s/it]

hid torch.Size([2, 10, 512])


 76%|███████▌  | 76/100 [41:04<12:48, 32.03s/it]

hid torch.Size([2, 10, 512])


 77%|███████▋  | 77/100 [41:25<11:00, 28.72s/it]

hid torch.Size([2, 10, 512])


 78%|███████▊  | 78/100 [42:02<11:27, 31.25s/it]

hid torch.Size([2, 10, 512])


 79%|███████▉  | 79/100 [42:34<11:04, 31.63s/it]

hid torch.Size([2, 10, 512])


 80%|████████  | 80/100 [43:28<12:46, 38.32s/it]

hid torch.Size([2, 10, 512])


 81%|████████  | 81/100 [44:41<15:23, 48.59s/it]

hid torch.Size([2, 10, 512])


 82%|████████▏ | 82/100 [45:10<12:49, 42.75s/it]

hid torch.Size([2, 10, 512])


 83%|████████▎ | 83/100 [45:48<11:41, 41.24s/it]

hid torch.Size([2, 10, 512])


 84%|████████▍ | 84/100 [46:40<11:53, 44.58s/it]

hid torch.Size([2, 10, 512])


 85%|████████▌ | 85/100 [47:20<10:48, 43.25s/it]

hid torch.Size([2, 10, 512])


 86%|████████▌ | 86/100 [47:44<08:44, 37.43s/it]

hid torch.Size([2, 10, 512])


 87%|████████▋ | 87/100 [48:08<07:13, 33.33s/it]

hid torch.Size([2, 10, 512])


 88%|████████▊ | 88/100 [48:22<05:29, 27.50s/it]

hid torch.Size([2, 10, 512])


 89%|████████▉ | 89/100 [48:45<04:47, 26.14s/it]

hid torch.Size([2, 10, 512])


 90%|█████████ | 90/100 [49:07<04:08, 24.83s/it]

hid torch.Size([2, 10, 512])


 91%|█████████ | 91/100 [49:25<03:27, 23.03s/it]

hid torch.Size([2, 10, 512])


 92%|█████████▏| 92/100 [50:06<03:46, 28.30s/it]

hid torch.Size([2, 10, 512])


 93%|█████████▎| 93/100 [50:36<03:22, 28.88s/it]

hid torch.Size([2, 10, 512])


 94%|█████████▍| 94/100 [50:58<02:39, 26.62s/it]

hid torch.Size([2, 10, 512])


 95%|█████████▌| 95/100 [51:23<02:11, 26.24s/it]

hid torch.Size([2, 10, 512])


 96%|█████████▌| 96/100 [51:52<01:47, 27.00s/it]

hid torch.Size([2, 10, 512])


 97%|█████████▋| 97/100 [52:34<01:34, 31.59s/it]

hid torch.Size([2, 10, 512])


 98%|█████████▊| 98/100 [53:39<01:23, 41.55s/it]

hid torch.Size([2, 10, 512])


 99%|█████████▉| 99/100 [54:29<00:44, 44.18s/it]

hid torch.Size([2, 10, 512])


100%|██████████| 100/100 [54:52<00:00, 32.92s/it]


hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size

In [46]:
for epoch_info in epoch_info_list:
  print(epoch_info)

Epoch: 01, Train Loss: 10.649, Train PPL: 42160.120, Validation Loss:10.649, Validation PPL: 42159.501
Epoch: 02, Train Loss: 10.649, Train PPL: 42160.473, Validation Loss:10.649, Validation PPL: 42159.501
Epoch: 03, Train Loss: 10.649, Train PPL: 42160.641, Validation Loss:10.649, Validation PPL: 42159.501


In [47]:
#NOTE: load model from file
model.load_state_dict(torch.load('model.pt'))

test_loss = evaluate(model, test_iter, criterion)

print(f'\tTest Loss: {test_loss:.3f}\tTest PPL: {np.exp(test_loss):7.3f}')

hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size

In [None]:
#clean mem
del model
del train_iter
del valid_iter
del test_iter
torch.cuda.empty_cache()

### BLEU added

Seconadly, I have been training the model with BLEU score addded, with tranining data size = 1000, batch size =10, num of epoch = 3 and Luong attention

In [70]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [71]:
BATCH_SIZE = 10
PAD_IDX = de_vocab['<pad>'] #find the indecies of the special tokens
BOS_IDX = de_vocab['<bos>']
EOS_IDX = de_vocab['<eos>']

from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

def generate_batch(data_batch):
    en_batch, de_batch = [], []
    for (en_item, de_item) in data_batch: #each sentences in a batch
        de_batch.append(torch.cat([torch.tensor([BOS_IDX]), de_item, torch.tensor([EOS_IDX])], dim=0)) #add a source sentence from the beginning (BOS) till the end (EOS)
        en_batch.append(torch.cat([torch.tensor([BOS_IDX]), en_item, torch.tensor([EOS_IDX])], dim=0))
    de_batch = pad_sequence(de_batch, padding_value=PAD_IDX, batch_first=True) #add padding to sentences if they are of different langth
    en_batch = pad_sequence(en_batch, padding_value=PAD_IDX, batch_first=True)
    return en_batch, de_batch

train_iter = DataLoader(train_data, batch_size=BATCH_SIZE,
                        shuffle=True, collate_fn=generate_batch) #shuffle samples after each epoch
valid_iter = DataLoader(val_data, batch_size=BATCH_SIZE,
                        shuffle=False, collate_fn=generate_batch)
test_iter = DataLoader(test_data, batch_size=BATCH_SIZE,
                       shuffle=False, collate_fn=generate_batch)

In [58]:
INPUT_DIM = len(en_vocab)
OUTPUT_DIM = len(de_vocab)
ENC_EMB_DIM = 256
DEC_EMB_DIM = 256
ENC_HID_DIM = 512
DEC_HID_DIM = 512
ENC_DROPOUT = 0.5
DEC_DROPOUT = 0.2

attn = LuongAttention(ENC_HID_DIM, DEC_HID_DIM)
#[YOUR CODE] Bahdanau attn
attn_bahdanau = BahdanauAttention(ENC_HID_DIM, DEC_HID_DIM)
enc = Encoder(INPUT_DIM, ENC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, ENC_DROPOUT)
dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, DEC_DROPOUT, attn)

model = Seq2Seq(enc, dec, device).to(device)

In [59]:
def init_weights(m):
    for name, param in m.named_parameters():
        if 'weight' in name:
            nn.init.normal_(param.data, mean=0, std=0.01)
        else:
            nn.init.constant_(param.data, 0)

model.apply(init_weights)

Seq2Seq(
  (encoder): Encoder(
    (embedding): Embedding(36926, 256)
    (rnn): GRU(256, 512, batch_first=True, bidirectional=True)
    (fc): Linear(in_features=1024, out_features=512, bias=True)
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (decoder): Decoder(
    (attention): LuongAttention(
      (attn): Linear(in_features=1536, out_features=512, bias=True)
      (v): Linear(in_features=512, out_features=1, bias=False)
    )
    (embedding): Embedding(42161, 256)
    (rnn): GRU(1280, 512, batch_first=True)
    (fc_out): Linear(in_features=1792, out_features=42161, bias=True)
    (dropout): Dropout(p=0.2, inplace=False)
    (attn): Linear(in_features=1536, out_features=512, bias=True)
    (v): Linear(in_features=512, out_features=1, bias=False)
  )
)

In [60]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 103,061,681 trainable parameters


In [61]:
optimizer = optim.Adam(model.parameters(), lr=1e-3)
#[YOUR CODE] Adam lr 1e-3

In [62]:
TRG_PAD_IDX = de_vocab['<pad>'] #TODO
#https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
criterion = nn.CrossEntropyLoss(ignore_index=TRG_PAD_IDX)

#[YOUR CODE] loss add ignore_index idx
#parameters
#torch.nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean', label_smoothing=0.0)

I used the BLEU fro here: https://docs.pytorch.org/text/stable/data_metrics.html

* candidate_corpus – an iterable of candidate translations. Each translation is an iterable of tokens

* references_corpus – an iterable of iterables of reference translations. Each translation is an iterable of tokens

* max_n – the maximum n-gram we want to use. E.g. if max_n=3, we will use unigrams, bigrams and trigrams

* weights – a list of weights used for each n-gram category (uniform by default)

In [103]:
from torchtext.data.metrics import bleu_score
import torch

def calculate_bleu(model, data_iter, de_vocab):
    """
    Args:
        model: your Seq2Seq model
        data_iter: DataLoader yielding (src, trg) batches
        de_vocab: target vocabulary (with .get_itos())
        device: cuda or cpu
    """
    model.eval()
    trgs = []
    preds = []

    itos = de_vocab.get_itos()

    with torch.no_grad():
        for en_batch, de_batch in data_iter:
            en_batch, de_batch = en_batch.to(device), de_batch.to(device)

            output = model(en_batch, de_batch, 0)  # no teacher forcing
            output_tokens = output.argmax(2)  # [batch_size, trg_len]

            for i in range(de_batch.shape[0]):
                pred_ids = output_tokens[i].tolist()
                trg_ids = de_batch[i].tolist()

                # Convert to tokens
                pred_tokens = [itos[idx] for idx in pred_ids]
                trg_tokens = [itos[idx] for idx in trg_ids]

                # # Clean up tokens
                # if '<eos>' in pred_tokens:
                #     pred_tokens = pred_tokens[:pred_tokens.index('<eos>')]
                # if '<bos>' in pred_tokens:
                #     pred_tokens.remove('<bos>')

                # if '<eos>' in trg_tokens:
                #     trg_tokens = trg_tokens[1:trg_tokens.index('<eos>')]  # skip <bos>
                # else:
                #     trg_tokens = trg_tokens[1:]  # skip <bos>

                preds.append(pred_tokens)
                trgs.append([trg_tokens])  # Note: torchtext expects list of refs

    return bleu_score(preds, trgs)

In [104]:
N_EPOCHS = 3
CLIP = 1

epoch_info_list = []
best_valid_loss = float('inf')
best_bleu = 0.0

for epoch in range(N_EPOCHS):

    train_loss = train(model, train_iter, optimizer, criterion, CLIP)
    valid_loss = evaluate(model, valid_iter, criterion)
    bleu = calculate_bleu(model, valid_iter, de_vocab)

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'model_loss.pt')

    if bleu > best_bleu:
        best_bleu = bleu
        torch.save(model.state_dict(), 'model_bleu.pt')

    print(f'Epoch: {epoch+1:02}')
    print(f'\tTrain Loss: {train_loss:.3f}\tTrain PPL: {np.exp(train_loss):7.3f}')
    print(f'\tValidation Loss: {valid_loss:.3f}\tValidation PPL: {np.exp(valid_loss):7.3f}')
    print(f'\tValidation BLEU: {bleu*100:.2f}')

    epoch_info_list.append(f'Epoch: {epoch+1:02}, Train Loss: {train_loss:.3f}, Train PPL: {np.exp(train_loss):7.3f}, Validation Loss: {valid_loss:.3f}, Validation PPL: {np.exp(valid_loss):7.3f}, Validation BLEU: {bleu*100:.2f}')

  0%|          | 0/100 [00:00<?, ?it/s]

hid torch.Size([2, 10, 512])


  1%|          | 1/100 [00:00<01:35,  1.03it/s]

hid torch.Size([2, 10, 512])


  2%|▏         | 2/100 [00:01<01:21,  1.20it/s]

hid torch.Size([2, 10, 512])


  3%|▎         | 3/100 [00:02<01:02,  1.56it/s]

hid torch.Size([2, 10, 512])


  4%|▍         | 4/100 [00:02<00:55,  1.72it/s]

hid torch.Size([2, 10, 512])


  5%|▌         | 5/100 [00:03<00:58,  1.64it/s]

hid torch.Size([2, 10, 512])


  6%|▌         | 6/100 [00:04<01:03,  1.49it/s]

hid torch.Size([2, 10, 512])


  7%|▋         | 7/100 [00:04<01:05,  1.41it/s]

hid torch.Size([2, 10, 512])


  8%|▊         | 8/100 [00:05<01:08,  1.35it/s]

hid torch.Size([2, 10, 512])


  9%|▉         | 9/100 [00:06<01:04,  1.40it/s]

hid torch.Size([2, 10, 512])


 10%|█         | 10/100 [00:07<01:11,  1.26it/s]

hid torch.Size([2, 10, 512])


 11%|█         | 11/100 [00:08<01:14,  1.19it/s]

hid torch.Size([2, 10, 512])


 12%|█▏        | 12/100 [00:08<01:06,  1.33it/s]

hid torch.Size([2, 10, 512])


 13%|█▎        | 13/100 [00:09<01:17,  1.13it/s]

hid torch.Size([2, 10, 512])


 14%|█▍        | 14/100 [00:10<01:09,  1.25it/s]

hid torch.Size([2, 10, 512])


 15%|█▌        | 15/100 [00:11<01:00,  1.41it/s]

hid torch.Size([2, 10, 512])


 16%|█▌        | 16/100 [00:11<00:49,  1.69it/s]

hid torch.Size([2, 10, 512])


 17%|█▋        | 17/100 [00:11<00:43,  1.90it/s]

hid torch.Size([2, 10, 512])


 18%|█▊        | 18/100 [00:12<00:46,  1.78it/s]

hid torch.Size([2, 10, 512])


 19%|█▉        | 19/100 [00:13<00:50,  1.60it/s]

hid torch.Size([2, 10, 512])


 20%|██        | 20/100 [00:13<00:47,  1.68it/s]

hid torch.Size([2, 10, 512])


 21%|██        | 21/100 [00:14<00:48,  1.63it/s]

hid torch.Size([2, 10, 512])


 22%|██▏       | 22/100 [00:14<00:47,  1.64it/s]

hid torch.Size([2, 10, 512])


 23%|██▎       | 23/100 [00:15<00:45,  1.69it/s]

hid torch.Size([2, 10, 512])


 24%|██▍       | 24/100 [00:16<00:55,  1.37it/s]

hid torch.Size([2, 10, 512])


 25%|██▌       | 25/100 [00:17<00:51,  1.45it/s]

hid torch.Size([2, 10, 512])


 26%|██▌       | 26/100 [00:18<00:56,  1.31it/s]

hid torch.Size([2, 10, 512])


 27%|██▋       | 27/100 [00:19<00:59,  1.24it/s]

hid torch.Size([2, 10, 512])


 28%|██▊       | 28/100 [00:19<00:53,  1.34it/s]

hid torch.Size([2, 10, 512])


 29%|██▉       | 29/100 [00:21<01:08,  1.04it/s]

hid torch.Size([2, 10, 512])


 30%|███       | 30/100 [00:22<01:06,  1.06it/s]

hid torch.Size([2, 10, 512])


 31%|███       | 31/100 [00:22<01:03,  1.08it/s]

hid torch.Size([2, 10, 512])


 32%|███▏      | 32/100 [00:23<00:56,  1.21it/s]

hid torch.Size([2, 10, 512])


 33%|███▎      | 33/100 [00:23<00:48,  1.38it/s]

hid torch.Size([2, 10, 512])


 34%|███▍      | 34/100 [00:24<00:46,  1.42it/s]

hid torch.Size([2, 10, 512])


 35%|███▌      | 35/100 [00:25<00:45,  1.42it/s]

hid torch.Size([2, 10, 512])


 36%|███▌      | 36/100 [00:26<00:44,  1.42it/s]

hid torch.Size([2, 10, 512])


 37%|███▋      | 37/100 [00:26<00:47,  1.34it/s]

hid torch.Size([2, 10, 512])


 38%|███▊      | 38/100 [00:27<00:45,  1.35it/s]

hid torch.Size([2, 10, 512])


 39%|███▉      | 39/100 [00:28<00:47,  1.30it/s]

hid torch.Size([2, 10, 512])


 40%|████      | 40/100 [00:29<00:46,  1.30it/s]

hid torch.Size([2, 10, 512])


 41%|████      | 41/100 [00:29<00:39,  1.49it/s]

hid torch.Size([2, 10, 512])


 42%|████▏     | 42/100 [00:30<00:36,  1.57it/s]

hid torch.Size([2, 10, 512])


 43%|████▎     | 43/100 [00:30<00:35,  1.62it/s]

hid torch.Size([2, 10, 512])


 44%|████▍     | 44/100 [00:31<00:33,  1.69it/s]

hid torch.Size([2, 10, 512])


 45%|████▌     | 45/100 [00:31<00:33,  1.65it/s]

hid torch.Size([2, 10, 512])


 46%|████▌     | 46/100 [00:32<00:32,  1.68it/s]

hid torch.Size([2, 10, 512])


 47%|████▋     | 47/100 [00:33<00:35,  1.48it/s]

hid torch.Size([2, 10, 512])


 48%|████▊     | 48/100 [00:34<00:43,  1.20it/s]

hid torch.Size([2, 10, 512])


 49%|████▉     | 49/100 [00:35<00:41,  1.22it/s]

hid torch.Size([2, 10, 512])


 50%|█████     | 50/100 [00:35<00:37,  1.35it/s]

hid torch.Size([2, 10, 512])


 51%|█████     | 51/100 [00:36<00:33,  1.44it/s]

hid torch.Size([2, 10, 512])


 52%|█████▏    | 52/100 [00:36<00:30,  1.59it/s]

hid torch.Size([2, 10, 512])


 53%|█████▎    | 53/100 [00:38<00:39,  1.18it/s]

hid torch.Size([2, 10, 512])


 54%|█████▍    | 54/100 [00:39<00:38,  1.20it/s]

hid torch.Size([2, 10, 512])


 55%|█████▌    | 55/100 [00:39<00:33,  1.34it/s]

hid torch.Size([2, 10, 512])


 56%|█████▌    | 56/100 [00:40<00:32,  1.35it/s]

hid torch.Size([2, 10, 512])


 57%|█████▋    | 57/100 [00:41<00:33,  1.30it/s]

hid torch.Size([2, 10, 512])


 58%|█████▊    | 58/100 [00:41<00:29,  1.42it/s]

hid torch.Size([2, 10, 512])


 59%|█████▉    | 59/100 [00:42<00:28,  1.44it/s]

hid torch.Size([2, 10, 512])


 60%|██████    | 60/100 [00:43<00:27,  1.46it/s]

hid torch.Size([2, 10, 512])


 61%|██████    | 61/100 [00:44<00:32,  1.20it/s]

hid torch.Size([2, 10, 512])


 62%|██████▏   | 62/100 [00:44<00:28,  1.32it/s]

hid torch.Size([2, 10, 512])


 63%|██████▎   | 63/100 [00:46<00:34,  1.07it/s]

hid torch.Size([2, 10, 512])


 64%|██████▍   | 64/100 [00:46<00:31,  1.14it/s]

hid torch.Size([2, 10, 512])


 65%|██████▌   | 65/100 [00:47<00:25,  1.35it/s]

hid torch.Size([2, 10, 512])


 66%|██████▌   | 66/100 [00:48<00:25,  1.35it/s]

hid torch.Size([2, 10, 512])


 67%|██████▋   | 67/100 [00:48<00:23,  1.43it/s]

hid torch.Size([2, 10, 512])


 68%|██████▊   | 68/100 [00:49<00:23,  1.35it/s]

hid torch.Size([2, 10, 512])


 69%|██████▉   | 69/100 [00:50<00:25,  1.20it/s]

hid torch.Size([2, 10, 512])


 70%|███████   | 70/100 [00:51<00:22,  1.31it/s]

hid torch.Size([2, 10, 512])


 71%|███████   | 71/100 [00:51<00:18,  1.54it/s]

hid torch.Size([2, 10, 512])


 72%|███████▏  | 72/100 [00:52<00:19,  1.46it/s]

hid torch.Size([2, 10, 512])


 73%|███████▎  | 73/100 [00:52<00:16,  1.61it/s]

hid torch.Size([2, 10, 512])


 74%|███████▍  | 74/100 [00:53<00:14,  1.78it/s]

hid torch.Size([2, 10, 512])


 75%|███████▌  | 75/100 [00:53<00:13,  1.82it/s]

hid torch.Size([2, 10, 512])


 76%|███████▌  | 76/100 [00:54<00:12,  1.92it/s]

hid torch.Size([2, 10, 512])


 77%|███████▋  | 77/100 [00:55<00:13,  1.66it/s]

hid torch.Size([2, 10, 512])


 78%|███████▊  | 78/100 [00:56<00:16,  1.34it/s]

hid torch.Size([2, 10, 512])


 79%|███████▉  | 79/100 [00:56<00:12,  1.63it/s]

hid torch.Size([2, 10, 512])


 80%|████████  | 80/100 [00:57<00:12,  1.63it/s]

hid torch.Size([2, 10, 512])


 81%|████████  | 81/100 [00:57<00:12,  1.58it/s]

hid torch.Size([2, 10, 512])


 82%|████████▏ | 82/100 [00:58<00:10,  1.75it/s]

hid torch.Size([2, 10, 512])


 83%|████████▎ | 83/100 [00:58<00:09,  1.88it/s]

hid torch.Size([2, 10, 512])


 84%|████████▍ | 84/100 [00:59<00:08,  1.79it/s]

hid torch.Size([2, 10, 512])


 85%|████████▌ | 85/100 [00:59<00:08,  1.74it/s]

hid torch.Size([2, 10, 512])


 86%|████████▌ | 86/100 [01:00<00:08,  1.74it/s]

hid torch.Size([2, 10, 512])


 87%|████████▋ | 87/100 [01:00<00:07,  1.77it/s]

hid torch.Size([2, 10, 512])


 88%|████████▊ | 88/100 [01:01<00:07,  1.64it/s]

hid torch.Size([2, 10, 512])


 89%|████████▉ | 89/100 [01:02<00:06,  1.72it/s]

hid torch.Size([2, 10, 512])


 90%|█████████ | 90/100 [01:03<00:07,  1.37it/s]

hid torch.Size([2, 10, 512])


 91%|█████████ | 91/100 [01:03<00:06,  1.48it/s]

hid torch.Size([2, 10, 512])


 92%|█████████▏| 92/100 [01:04<00:05,  1.59it/s]

hid torch.Size([2, 10, 512])


 93%|█████████▎| 93/100 [01:04<00:03,  1.79it/s]

hid torch.Size([2, 10, 512])


 94%|█████████▍| 94/100 [01:05<00:03,  1.85it/s]

hid torch.Size([2, 10, 512])


 95%|█████████▌| 95/100 [01:06<00:03,  1.36it/s]

hid torch.Size([2, 10, 512])


 96%|█████████▌| 96/100 [01:07<00:02,  1.44it/s]

hid torch.Size([2, 10, 512])


 97%|█████████▋| 97/100 [01:07<00:02,  1.47it/s]

hid torch.Size([2, 10, 512])


 98%|█████████▊| 98/100 [01:08<00:01,  1.56it/s]

hid torch.Size([2, 10, 512])


 99%|█████████▉| 99/100 [01:08<00:00,  1.64it/s]

hid torch.Size([2, 10, 512])


100%|██████████| 100/100 [01:09<00:00,  1.44it/s]


hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size

  0%|          | 0/100 [00:00<?, ?it/s]

hid torch.Size([2, 10, 512])


  1%|          | 1/100 [00:00<01:27,  1.14it/s]

hid torch.Size([2, 10, 512])


  2%|▏         | 2/100 [00:01<01:27,  1.13it/s]

hid torch.Size([2, 10, 512])


  3%|▎         | 3/100 [00:02<01:39,  1.02s/it]

hid torch.Size([2, 10, 512])


  4%|▍         | 4/100 [00:04<01:39,  1.03s/it]

hid torch.Size([2, 10, 512])


  5%|▌         | 5/100 [00:04<01:17,  1.22it/s]

hid torch.Size([2, 10, 512])


  6%|▌         | 6/100 [00:05<01:20,  1.17it/s]

hid torch.Size([2, 10, 512])


  7%|▋         | 7/100 [00:05<01:07,  1.38it/s]

hid torch.Size([2, 10, 512])


  8%|▊         | 8/100 [00:07<01:24,  1.09it/s]

hid torch.Size([2, 10, 512])


  9%|▉         | 9/100 [00:07<01:09,  1.31it/s]

hid torch.Size([2, 10, 512])


 10%|█         | 10/100 [00:08<01:04,  1.40it/s]

hid torch.Size([2, 10, 512])


 11%|█         | 11/100 [00:09<01:08,  1.31it/s]

hid torch.Size([2, 10, 512])


 12%|█▏        | 12/100 [00:09<01:09,  1.27it/s]

hid torch.Size([2, 10, 512])


 13%|█▎        | 13/100 [00:10<01:04,  1.35it/s]

hid torch.Size([2, 10, 512])


 14%|█▍        | 14/100 [00:10<00:55,  1.54it/s]

hid torch.Size([2, 10, 512])


 15%|█▌        | 15/100 [00:11<00:58,  1.45it/s]

hid torch.Size([2, 10, 512])


 16%|█▌        | 16/100 [00:12<01:00,  1.39it/s]

hid torch.Size([2, 10, 512])


 17%|█▋        | 17/100 [00:13<00:55,  1.48it/s]

hid torch.Size([2, 10, 512])


 18%|█▊        | 18/100 [00:13<00:56,  1.45it/s]

hid torch.Size([2, 10, 512])


 19%|█▉        | 19/100 [00:14<00:48,  1.67it/s]

hid torch.Size([2, 10, 512])


 20%|██        | 20/100 [00:14<00:42,  1.88it/s]

hid torch.Size([2, 10, 512])


 21%|██        | 21/100 [00:15<00:42,  1.85it/s]

hid torch.Size([2, 10, 512])


 22%|██▏       | 22/100 [00:15<00:45,  1.71it/s]

hid torch.Size([2, 10, 512])


 23%|██▎       | 23/100 [00:16<00:40,  1.91it/s]

hid torch.Size([2, 10, 512])


 24%|██▍       | 24/100 [00:16<00:39,  1.95it/s]

hid torch.Size([2, 10, 512])


 25%|██▌       | 25/100 [00:17<00:41,  1.82it/s]

hid torch.Size([2, 10, 512])


 26%|██▌       | 26/100 [00:18<00:45,  1.64it/s]

hid torch.Size([2, 10, 512])


 27%|██▋       | 27/100 [00:18<00:48,  1.51it/s]

hid torch.Size([2, 10, 512])


 28%|██▊       | 28/100 [00:19<00:49,  1.45it/s]

hid torch.Size([2, 10, 512])


 29%|██▉       | 29/100 [00:20<00:48,  1.48it/s]

hid torch.Size([2, 10, 512])


 30%|███       | 30/100 [00:21<00:59,  1.19it/s]

hid torch.Size([2, 10, 512])


 31%|███       | 31/100 [00:22<00:54,  1.26it/s]

hid torch.Size([2, 10, 512])


 32%|███▏      | 32/100 [00:22<00:48,  1.40it/s]

hid torch.Size([2, 10, 512])


 33%|███▎      | 33/100 [00:23<00:57,  1.17it/s]

hid torch.Size([2, 10, 512])


 34%|███▍      | 34/100 [00:24<00:47,  1.38it/s]

hid torch.Size([2, 10, 512])


 35%|███▌      | 35/100 [00:24<00:42,  1.53it/s]

hid torch.Size([2, 10, 512])


 36%|███▌      | 36/100 [00:25<00:45,  1.41it/s]

hid torch.Size([2, 10, 512])


 37%|███▋      | 37/100 [00:26<00:42,  1.49it/s]

hid torch.Size([2, 10, 512])


 38%|███▊      | 38/100 [00:26<00:40,  1.54it/s]

hid torch.Size([2, 10, 512])


 39%|███▉      | 39/100 [00:27<00:34,  1.75it/s]

hid torch.Size([2, 10, 512])


 40%|████      | 40/100 [00:28<00:43,  1.38it/s]

hid torch.Size([2, 10, 512])


 41%|████      | 41/100 [00:29<00:43,  1.37it/s]

hid torch.Size([2, 10, 512])


 42%|████▏     | 42/100 [00:29<00:37,  1.54it/s]

hid torch.Size([2, 10, 512])


 43%|████▎     | 43/100 [00:30<00:41,  1.37it/s]

hid torch.Size([2, 10, 512])


 44%|████▍     | 44/100 [00:30<00:37,  1.51it/s]

hid torch.Size([2, 10, 512])


 45%|████▌     | 45/100 [00:31<00:38,  1.43it/s]

hid torch.Size([2, 10, 512])


 46%|████▌     | 46/100 [00:32<00:35,  1.54it/s]

hid torch.Size([2, 10, 512])


 47%|████▋     | 47/100 [00:32<00:32,  1.62it/s]

hid torch.Size([2, 10, 512])


 48%|████▊     | 48/100 [00:34<00:41,  1.26it/s]

hid torch.Size([2, 10, 512])


 49%|████▉     | 49/100 [00:34<00:41,  1.23it/s]

hid torch.Size([2, 10, 512])


 50%|█████     | 50/100 [00:36<00:50,  1.01s/it]

hid torch.Size([2, 10, 512])


 51%|█████     | 51/100 [00:36<00:41,  1.17it/s]

hid torch.Size([2, 10, 512])


 52%|█████▏    | 52/100 [00:37<00:39,  1.22it/s]

hid torch.Size([2, 10, 512])


 53%|█████▎    | 53/100 [00:38<00:39,  1.19it/s]

hid torch.Size([2, 10, 512])


 54%|█████▍    | 54/100 [00:39<00:42,  1.09it/s]

hid torch.Size([2, 10, 512])


 55%|█████▌    | 55/100 [00:39<00:34,  1.30it/s]

hid torch.Size([2, 10, 512])


 56%|█████▌    | 56/100 [00:40<00:32,  1.36it/s]

hid torch.Size([2, 10, 512])


 57%|█████▋    | 57/100 [00:41<00:28,  1.52it/s]

hid torch.Size([2, 10, 512])


 58%|█████▊    | 58/100 [00:41<00:26,  1.60it/s]

hid torch.Size([2, 10, 512])


 59%|█████▉    | 59/100 [00:42<00:28,  1.42it/s]

hid torch.Size([2, 10, 512])


 60%|██████    | 60/100 [00:43<00:28,  1.42it/s]

hid torch.Size([2, 10, 512])


 61%|██████    | 61/100 [00:43<00:26,  1.46it/s]

hid torch.Size([2, 10, 512])


 62%|██████▏   | 62/100 [00:44<00:26,  1.46it/s]

hid torch.Size([2, 10, 512])


 63%|██████▎   | 63/100 [00:45<00:29,  1.27it/s]

hid torch.Size([2, 10, 512])


 64%|██████▍   | 64/100 [00:46<00:28,  1.27it/s]

hid torch.Size([2, 10, 512])


 65%|██████▌   | 65/100 [00:46<00:23,  1.47it/s]

hid torch.Size([2, 10, 512])


 66%|██████▌   | 66/100 [00:47<00:27,  1.24it/s]

hid torch.Size([2, 10, 512])


 67%|██████▋   | 67/100 [00:48<00:23,  1.38it/s]

hid torch.Size([2, 10, 512])


 68%|██████▊   | 68/100 [00:49<00:21,  1.47it/s]

hid torch.Size([2, 10, 512])


 69%|██████▉   | 69/100 [00:49<00:22,  1.41it/s]

hid torch.Size([2, 10, 512])


 70%|███████   | 70/100 [00:50<00:21,  1.38it/s]

hid torch.Size([2, 10, 512])


 71%|███████   | 71/100 [00:51<00:20,  1.44it/s]

hid torch.Size([2, 10, 512])


 72%|███████▏  | 72/100 [00:51<00:18,  1.54it/s]

hid torch.Size([2, 10, 512])


 73%|███████▎  | 73/100 [00:52<00:17,  1.57it/s]

hid torch.Size([2, 10, 512])


 74%|███████▍  | 74/100 [00:53<00:19,  1.31it/s]

hid torch.Size([2, 10, 512])


 75%|███████▌  | 75/100 [00:54<00:17,  1.39it/s]

hid torch.Size([2, 10, 512])


 76%|███████▌  | 76/100 [00:54<00:15,  1.52it/s]

hid torch.Size([2, 10, 512])


 77%|███████▋  | 77/100 [00:55<00:16,  1.36it/s]

hid torch.Size([2, 10, 512])


 78%|███████▊  | 78/100 [00:56<00:17,  1.23it/s]

hid torch.Size([2, 10, 512])


 79%|███████▉  | 79/100 [00:57<00:20,  1.02it/s]

hid torch.Size([2, 10, 512])


 80%|████████  | 80/100 [00:58<00:18,  1.08it/s]

hid torch.Size([2, 10, 512])


 81%|████████  | 81/100 [00:59<00:14,  1.29it/s]

hid torch.Size([2, 10, 512])


 82%|████████▏ | 82/100 [00:59<00:12,  1.40it/s]

hid torch.Size([2, 10, 512])


 83%|████████▎ | 83/100 [01:00<00:12,  1.39it/s]

hid torch.Size([2, 10, 512])


 84%|████████▍ | 84/100 [01:01<00:13,  1.20it/s]

hid torch.Size([2, 10, 512])


 85%|████████▌ | 85/100 [01:01<00:11,  1.34it/s]

hid torch.Size([2, 10, 512])


 86%|████████▌ | 86/100 [01:02<00:09,  1.45it/s]

hid torch.Size([2, 10, 512])


 87%|████████▋ | 87/100 [01:03<00:08,  1.51it/s]

hid torch.Size([2, 10, 512])


 88%|████████▊ | 88/100 [01:04<00:08,  1.37it/s]

hid torch.Size([2, 10, 512])


 89%|████████▉ | 89/100 [01:04<00:08,  1.34it/s]

hid torch.Size([2, 10, 512])


 90%|█████████ | 90/100 [01:05<00:07,  1.33it/s]

hid torch.Size([2, 10, 512])


 91%|█████████ | 91/100 [01:05<00:05,  1.53it/s]

hid torch.Size([2, 10, 512])


 92%|█████████▏| 92/100 [01:06<00:04,  1.66it/s]

hid torch.Size([2, 10, 512])


 93%|█████████▎| 93/100 [01:06<00:03,  1.82it/s]

hid torch.Size([2, 10, 512])


 94%|█████████▍| 94/100 [01:07<00:03,  1.64it/s]

hid torch.Size([2, 10, 512])


 95%|█████████▌| 95/100 [01:08<00:03,  1.59it/s]

hid torch.Size([2, 10, 512])


 96%|█████████▌| 96/100 [01:09<00:02,  1.47it/s]

hid torch.Size([2, 10, 512])


 97%|█████████▋| 97/100 [01:09<00:01,  1.52it/s]

hid torch.Size([2, 10, 512])


 98%|█████████▊| 98/100 [01:10<00:01,  1.68it/s]

hid torch.Size([2, 10, 512])


 99%|█████████▉| 99/100 [01:10<00:00,  1.70it/s]

hid torch.Size([2, 10, 512])


100%|██████████| 100/100 [01:11<00:00,  1.40it/s]


hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size

  0%|          | 0/100 [00:00<?, ?it/s]

hid torch.Size([2, 10, 512])


  1%|          | 1/100 [00:00<00:54,  1.83it/s]

hid torch.Size([2, 10, 512])


  2%|▏         | 2/100 [00:01<00:54,  1.78it/s]

hid torch.Size([2, 10, 512])


  3%|▎         | 3/100 [00:01<01:08,  1.42it/s]

hid torch.Size([2, 10, 512])


  4%|▍         | 4/100 [00:02<00:57,  1.66it/s]

hid torch.Size([2, 10, 512])


  5%|▌         | 5/100 [00:03<01:13,  1.30it/s]

hid torch.Size([2, 10, 512])


  6%|▌         | 6/100 [00:04<01:16,  1.23it/s]

hid torch.Size([2, 10, 512])


  7%|▋         | 7/100 [00:04<01:08,  1.37it/s]

hid torch.Size([2, 10, 512])


  8%|▊         | 8/100 [00:05<01:03,  1.45it/s]

hid torch.Size([2, 10, 512])


  9%|▉         | 9/100 [00:06<01:00,  1.50it/s]

hid torch.Size([2, 10, 512])


 10%|█         | 10/100 [00:07<01:08,  1.31it/s]

hid torch.Size([2, 10, 512])


 11%|█         | 11/100 [00:07<01:05,  1.36it/s]

hid torch.Size([2, 10, 512])


 12%|█▏        | 12/100 [00:08<01:02,  1.40it/s]

hid torch.Size([2, 10, 512])


 13%|█▎        | 13/100 [00:09<00:56,  1.54it/s]

hid torch.Size([2, 10, 512])


 14%|█▍        | 14/100 [00:09<00:54,  1.57it/s]

hid torch.Size([2, 10, 512])


 15%|█▌        | 15/100 [00:10<00:51,  1.65it/s]

hid torch.Size([2, 10, 512])


 16%|█▌        | 16/100 [00:10<00:55,  1.51it/s]

hid torch.Size([2, 10, 512])


 17%|█▋        | 17/100 [00:11<00:56,  1.47it/s]

hid torch.Size([2, 10, 512])


 18%|█▊        | 18/100 [00:12<00:53,  1.53it/s]

hid torch.Size([2, 10, 512])


 19%|█▉        | 19/100 [00:13<01:05,  1.23it/s]

hid torch.Size([2, 10, 512])


 20%|██        | 20/100 [00:13<00:57,  1.39it/s]

hid torch.Size([2, 10, 512])


 21%|██        | 21/100 [00:14<01:01,  1.29it/s]

hid torch.Size([2, 10, 512])


 22%|██▏       | 22/100 [00:16<01:10,  1.10it/s]

hid torch.Size([2, 10, 512])


 23%|██▎       | 23/100 [00:17<01:14,  1.03it/s]

hid torch.Size([2, 10, 512])


 24%|██▍       | 24/100 [00:17<01:09,  1.09it/s]

hid torch.Size([2, 10, 512])


 25%|██▌       | 25/100 [00:18<01:01,  1.22it/s]

hid torch.Size([2, 10, 512])


 26%|██▌       | 26/100 [00:19<00:59,  1.25it/s]

hid torch.Size([2, 10, 512])


 27%|██▋       | 27/100 [00:20<00:59,  1.23it/s]

hid torch.Size([2, 10, 512])


 28%|██▊       | 28/100 [00:20<00:53,  1.35it/s]

hid torch.Size([2, 10, 512])


 29%|██▉       | 29/100 [00:21<00:59,  1.20it/s]

hid torch.Size([2, 10, 512])


 30%|███       | 30/100 [00:22<00:56,  1.24it/s]

hid torch.Size([2, 10, 512])


 31%|███       | 31/100 [00:23<00:49,  1.38it/s]

hid torch.Size([2, 10, 512])


 32%|███▏      | 32/100 [00:23<00:47,  1.43it/s]

hid torch.Size([2, 10, 512])


 33%|███▎      | 33/100 [00:24<00:44,  1.49it/s]

hid torch.Size([2, 10, 512])


 34%|███▍      | 34/100 [00:24<00:42,  1.56it/s]

hid torch.Size([2, 10, 512])


 35%|███▌      | 35/100 [00:26<00:55,  1.17it/s]

hid torch.Size([2, 10, 512])


 36%|███▌      | 36/100 [00:26<00:47,  1.34it/s]

hid torch.Size([2, 10, 512])


 37%|███▋      | 37/100 [00:27<00:53,  1.18it/s]

hid torch.Size([2, 10, 512])


 38%|███▊      | 38/100 [00:28<00:49,  1.26it/s]

hid torch.Size([2, 10, 512])


 39%|███▉      | 39/100 [00:29<00:48,  1.25it/s]

hid torch.Size([2, 10, 512])


 40%|████      | 40/100 [00:29<00:44,  1.35it/s]

hid torch.Size([2, 10, 512])


 41%|████      | 41/100 [00:30<00:47,  1.23it/s]

hid torch.Size([2, 10, 512])


 42%|████▏     | 42/100 [00:31<00:45,  1.27it/s]

hid torch.Size([2, 10, 512])


 43%|████▎     | 43/100 [00:32<00:39,  1.43it/s]

hid torch.Size([2, 10, 512])


 44%|████▍     | 44/100 [00:32<00:42,  1.31it/s]

hid torch.Size([2, 10, 512])


 45%|████▌     | 45/100 [00:33<00:39,  1.39it/s]

hid torch.Size([2, 10, 512])


 46%|████▌     | 46/100 [00:34<00:34,  1.57it/s]

hid torch.Size([2, 10, 512])


 47%|████▋     | 47/100 [00:34<00:35,  1.50it/s]

hid torch.Size([2, 10, 512])


 48%|████▊     | 48/100 [00:35<00:31,  1.67it/s]

hid torch.Size([2, 10, 512])


 49%|████▉     | 49/100 [00:35<00:31,  1.62it/s]

hid torch.Size([2, 10, 512])


 50%|█████     | 50/100 [00:36<00:29,  1.67it/s]

hid torch.Size([2, 10, 512])


 51%|█████     | 51/100 [00:36<00:27,  1.77it/s]

hid torch.Size([2, 10, 512])


 52%|█████▏    | 52/100 [00:37<00:28,  1.70it/s]

hid torch.Size([2, 10, 512])


 53%|█████▎    | 53/100 [00:38<00:30,  1.55it/s]

hid torch.Size([2, 10, 512])


 54%|█████▍    | 54/100 [00:38<00:28,  1.61it/s]

hid torch.Size([2, 10, 512])


 55%|█████▌    | 55/100 [00:39<00:24,  1.80it/s]

hid torch.Size([2, 10, 512])


 56%|█████▌    | 56/100 [00:40<00:27,  1.62it/s]

hid torch.Size([2, 10, 512])


 57%|█████▋    | 57/100 [00:40<00:25,  1.70it/s]

hid torch.Size([2, 10, 512])


 58%|█████▊    | 58/100 [00:41<00:31,  1.33it/s]

hid torch.Size([2, 10, 512])


 59%|█████▉    | 59/100 [00:42<00:27,  1.49it/s]

hid torch.Size([2, 10, 512])


 60%|██████    | 60/100 [00:42<00:24,  1.60it/s]

hid torch.Size([2, 10, 512])


 61%|██████    | 61/100 [00:43<00:29,  1.31it/s]

hid torch.Size([2, 10, 512])


 62%|██████▏   | 62/100 [00:44<00:25,  1.52it/s]

hid torch.Size([2, 10, 512])


 63%|██████▎   | 63/100 [00:44<00:23,  1.54it/s]

hid torch.Size([2, 10, 512])


 64%|██████▍   | 64/100 [00:45<00:24,  1.46it/s]

hid torch.Size([2, 10, 512])


 65%|██████▌   | 65/100 [00:46<00:23,  1.47it/s]

hid torch.Size([2, 10, 512])


 66%|██████▌   | 66/100 [00:47<00:25,  1.32it/s]

hid torch.Size([2, 10, 512])


 67%|██████▋   | 67/100 [00:47<00:22,  1.46it/s]

hid torch.Size([2, 10, 512])


 68%|██████▊   | 68/100 [00:48<00:22,  1.39it/s]

hid torch.Size([2, 10, 512])


 69%|██████▉   | 69/100 [00:49<00:20,  1.52it/s]

hid torch.Size([2, 10, 512])


 70%|███████   | 70/100 [00:49<00:17,  1.70it/s]

hid torch.Size([2, 10, 512])


 71%|███████   | 71/100 [00:50<00:16,  1.73it/s]

hid torch.Size([2, 10, 512])


 72%|███████▏  | 72/100 [00:50<00:16,  1.69it/s]

hid torch.Size([2, 10, 512])


 73%|███████▎  | 73/100 [00:51<00:18,  1.47it/s]

hid torch.Size([2, 10, 512])


 74%|███████▍  | 74/100 [00:52<00:18,  1.39it/s]

hid torch.Size([2, 10, 512])


 75%|███████▌  | 75/100 [00:53<00:18,  1.39it/s]

hid torch.Size([2, 10, 512])


 76%|███████▌  | 76/100 [00:53<00:16,  1.42it/s]

hid torch.Size([2, 10, 512])


 77%|███████▋  | 77/100 [00:54<00:15,  1.49it/s]

hid torch.Size([2, 10, 512])


 78%|███████▊  | 78/100 [00:55<00:15,  1.43it/s]

hid torch.Size([2, 10, 512])


 79%|███████▉  | 79/100 [00:55<00:14,  1.44it/s]

hid torch.Size([2, 10, 512])


 80%|████████  | 80/100 [00:57<00:17,  1.17it/s]

hid torch.Size([2, 10, 512])


 81%|████████  | 81/100 [00:57<00:14,  1.34it/s]

hid torch.Size([2, 10, 512])


 82%|████████▏ | 82/100 [00:58<00:12,  1.39it/s]

hid torch.Size([2, 10, 512])


 83%|████████▎ | 83/100 [00:59<00:14,  1.16it/s]

hid torch.Size([2, 10, 512])


 84%|████████▍ | 84/100 [00:59<00:11,  1.37it/s]

hid torch.Size([2, 10, 512])


 85%|████████▌ | 85/100 [01:00<00:11,  1.31it/s]

hid torch.Size([2, 10, 512])


 86%|████████▌ | 86/100 [01:01<00:09,  1.40it/s]

hid torch.Size([2, 10, 512])


 87%|████████▋ | 87/100 [01:01<00:08,  1.48it/s]

hid torch.Size([2, 10, 512])


 88%|████████▊ | 88/100 [01:02<00:08,  1.38it/s]

hid torch.Size([2, 10, 512])


 89%|████████▉ | 89/100 [01:03<00:07,  1.50it/s]

hid torch.Size([2, 10, 512])


 90%|█████████ | 90/100 [01:04<00:07,  1.35it/s]

hid torch.Size([2, 10, 512])


 91%|█████████ | 91/100 [01:04<00:06,  1.36it/s]

hid torch.Size([2, 10, 512])


 92%|█████████▏| 92/100 [01:05<00:06,  1.29it/s]

hid torch.Size([2, 10, 512])


 93%|█████████▎| 93/100 [01:07<00:06,  1.04it/s]

hid torch.Size([2, 10, 512])


 94%|█████████▍| 94/100 [01:07<00:04,  1.28it/s]

hid torch.Size([2, 10, 512])


 95%|█████████▌| 95/100 [01:07<00:03,  1.49it/s]

hid torch.Size([2, 10, 512])


 96%|█████████▌| 96/100 [01:08<00:02,  1.38it/s]

hid torch.Size([2, 10, 512])


 97%|█████████▋| 97/100 [01:09<00:01,  1.51it/s]

hid torch.Size([2, 10, 512])


 98%|█████████▊| 98/100 [01:09<00:01,  1.57it/s]

hid torch.Size([2, 10, 512])


 99%|█████████▉| 99/100 [01:11<00:00,  1.14it/s]

hid torch.Size([2, 10, 512])


100%|██████████| 100/100 [01:12<00:00,  1.39it/s]


hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size

In [105]:
for i in epoch_info_list:
  print(i)

Epoch: 01, Train Loss: 4.760, Train PPL: 116.754, Validation Loss: 11.217, Validation PPL: 74372.007, Validation BLEU: 0.00
Epoch: 02, Train Loss: 4.111, Train PPL:  61.033, Validation Loss: 11.739, Validation PPL: 125399.615, Validation BLEU: 0.00
Epoch: 03, Train Loss: 3.542, Train PPL:  34.525, Validation Loss: 12.014, Validation PPL: 165126.857, Validation BLEU: 0.00


In [106]:
#NOTE: load model from file
model.load_state_dict(torch.load('model_loss.pt'))
# model.load_state_dict(torch.load('model_bleu.pt'))

test_loss = evaluate(model, test_iter, criterion)

print(f'\tTest Loss: {test_loss:.3f}\tTest PPL: {np.exp(test_loss):7.3f}')

hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size

In [108]:
#clean mem
del model
del train_iter
del valid_iter
del test_iter
torch.cuda.empty_cache()