<a href="https://colab.research.google.com/github/Saputoa21/Machine-Translation/blob/main/seq2seq_NMT_MTMA2025s_v2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Seq2seq NMT with RNN


[Neural Machine Translation by Jointly Learning to Align and Translate](https://arxiv.org/abs/1409.0473)

**NOTE:**

-  use clean bpe data
-  use a piece of training data during coding or low in credits

You have to implement:

- Encoder
- Attention (Bahdanau)
- training loop
- extra: BLEU model selection

Goal:

- Loss in training, validation and test




In [2]:
!pip install torch==2.3.0
!pip install torchtext==0.18

Collecting torch==2.3.0
  Downloading torch-2.3.0-cp311-cp311-manylinux1_x86_64.whl.metadata (26 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch==2.3.0)
  Downloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch==2.3.0)
  Downloading nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch==2.3.0)
  Downloading nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch==2.3.0)
  Downloading nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.1.3.1 (from torch==2.3.0)
  Downloading nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.0.2.54 (from torch==2.3.0)
  Downloading nvidia_cufft_cu12-11.0.2.54-py3-none-manylin

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
import random
import time

In [None]:
#if you dont have bpe data use sacremoese tokenizer
#!pip install sacremoses

In [4]:
#which libraries are we using!!?
!pip freeze > requirements.txt

In [5]:
!cat requirements.txt

absl-py==1.4.0
accelerate==1.6.0
aiohappyeyeballs==2.6.1
aiohttp==3.11.15
aiosignal==1.3.2
alabaster==1.0.0
albucore==0.0.24
albumentations==2.0.6
ale-py==0.11.0
altair==5.5.0
annotated-types==0.7.0
antlr4-python3-runtime==4.9.3
anyio==4.9.0
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
array_record==0.7.2
arviz==0.21.0
astropy==7.0.2
astropy-iers-data==0.2025.5.12.0.38.29
astunparse==1.6.3
atpublic==5.1
attrs==25.3.0
audioread==3.0.1
autograd==1.8.0
babel==2.17.0
backcall==0.2.0
backports.tarfile==1.2.0
beautifulsoup4==4.13.4
betterproto==2.0.0b6
bigframes==2.4.0
bigquery-magics==0.9.0
bleach==6.2.0
blinker==1.9.0
blis==1.3.0
blobfile==3.0.0
blosc2==3.3.2
bokeh==3.7.3
Bottleneck==1.4.2
bqplot==0.12.44
branca==0.8.1
build==1.2.2.post1
CacheControl==0.14.3
cachetools==5.5.2
catalogue==2.0.10
certifi==2025.4.26
cffi==1.17.1
chardet==5.2.0
charset-normalizer==3.4.2
chex==0.1.89
clarabel==0.10.0
click==8.2.0
cloudpathlib==0.21.0
cloudpickle==3.1.1
cmake==3.31.6
cmdstanpy==1.2.5
colorcet

In [6]:
SEED = 42 #to reproduce the traning

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

In [7]:
import torchtext
dir(torchtext)

['_CACHE_DIR',
 '_TEXT_BUCKET',
 '_TORCHTEXT_DEPRECATION_MSG',
 '_WARN',
 '__all__',
 '__builtins__',
 '__cached__',
 '__doc__',
 '__file__',
 '__loader__',
 '__name__',
 '__package__',
 '__path__',
 '__spec__',
 '__version__',
 '_extension',
 '_get_torch_home',
 '_internal',
 '_torchtext',
 'git_version',
 'os',
 'version']

In [8]:
import torchtext
import torch
from torchtext.data.utils import get_tokenizer
from collections import Counter
from torchtext.vocab import vocab
#from torchtext.utils import download_from_url, extract_archive
import io

#0=en 1=de
# soure and target data
#NOTE: USE clean bpe data!

#for this task I used the clean bpe files (train 500k, test 500, dev 500)
#sorting the path to the files en-de
train_filepaths = ['train.en-de.clean.bpe.en', 'train.en-de.clean.bpe.de']
val_filepaths = ['dev.en-de.clean.bpe.en', 'dev.en-de.clean.bpe.de']
test_filepaths = ['test.en-de.clean.bpe.en', 'test.en-de.clean.bpe.de']

#TODO use clean bpe tokenized data!!!

#de_tokenizer = get_tokenizer('moses', language='de') for cases without bpe data
#en_tokenizer = get_tokenizer('moses', language='en')

de_tokenizer = None #None as there are tokenized data
en_tokenizer = None

def build_vocab(filepath, tokenizer=None):
  counter = Counter()
  with io.open(filepath, encoding="utf8") as f: #go to a file
    for string_ in f: #for each line in a file
      #counter.update(tokenizer(string_))
      counter.update(string_.split()) #split each line by space, i.e. creating tokens. Counter counts the tokens to create a database of the vocab
  return vocab(counter, specials=['<unk>', '<pad>', '<bos>', '<eos>']) #vocav from torchtext containing the tokens and their counts, as well as specioal ones

#Vocab
en_vocab = build_vocab(train_filepaths[0], en_tokenizer) #create a vocab
de_vocab = build_vocab(train_filepaths[1], de_tokenizer)

print(dir(en_vocab))
en_vocab.set_default_index(en_vocab['<unk>']) #set <unk> as a default token
de_vocab.set_default_index(de_vocab['<unk>'])

def data_process(filepaths):
  raw_en_iter = iter(io.open(filepaths[0], encoding="utf8"))
  raw_de_iter = iter(io.open(filepaths[1], encoding="utf8"))
  data = []
  for (raw_en, raw_de) in zip(raw_en_iter, raw_de_iter): #for each sentence in the source and target
    en_tensor_ = torch.tensor([en_vocab[token] for token in raw_en.split()], #en_tokenizer(raw_en) #transforms tokens into tensors
                            dtype=torch.long)
    de_tensor_ = torch.tensor([de_vocab[token] for token in raw_de.split()], #de_tokenizer(raw_de)
                            dtype=torch.long)
    data.append((en_tensor_, de_tensor_)) #append to a list
  return data

#pre-process
train_data = data_process(train_filepaths)
val_data = data_process(val_filepaths)
test_data = data_process(test_filepaths)



['T_destination', '__annotations__', '__call__', '__class__', '__contains__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattr__', '__getattribute__', '__getitem__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__jit_unused_properties__', '__le__', '__len__', '__lt__', '__module__', '__ne__', '__new__', '__prepare_scriptable__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__setstate__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', '_apply', '_backward_hooks', '_backward_pre_hooks', '_buffers', '_call_impl', '_compiled_call_impl', '_forward_hooks', '_forward_hooks_always_called', '_forward_hooks_with_kwargs', '_forward_pre_hooks', '_forward_pre_hooks_with_kwargs', '_get_backward_hooks', '_get_backward_pre_hooks', '_get_name', '_is_full_backward_hook', '_load_from_state_dict', '_load_state_dict_post_hooks', '_load_state_dict_pre_hooks', '_maybe_warn_non_full_backward_hook', '_modules', '

In [9]:
#NOTE: if you are low on credits or testing only use a piece of the data e.g. 20K segments
train_data = train_data[:20000]

In [10]:
len(train_data)
len(val_data)
len(test_data)

467

In [11]:
len(en_vocab)
len(de_vocab)

42161

Define the device.

In [12]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


Create the iterators.

In [13]:
BATCH_SIZE = 8
PAD_IDX = de_vocab['<pad>'] #find the indecies of the special tokens
BOS_IDX = de_vocab['<bos>']
EOS_IDX = de_vocab['<eos>']

from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

def generate_batch(data_batch):
    en_batch, de_batch = [], []
    for (en_item, de_item) in data_batch: #each sentences in a batch
        de_batch.append(torch.cat([torch.tensor([BOS_IDX]), de_item, torch.tensor([EOS_IDX])], dim=0)) #add a source sentence from the beginning (BOS) till the end (EOS)
        en_batch.append(torch.cat([torch.tensor([BOS_IDX]), en_item, torch.tensor([EOS_IDX])], dim=0))
    de_batch = pad_sequence(de_batch, padding_value=PAD_IDX, batch_first=True) #add padding to sentences if they are of different langth
    en_batch = pad_sequence(en_batch, padding_value=PAD_IDX, batch_first=True)
    return en_batch, de_batch

train_iter = DataLoader(train_data, batch_size=BATCH_SIZE,
                        shuffle=True, collate_fn=generate_batch) #shuffle samples after each epoch
valid_iter = DataLoader(val_data, batch_size=BATCH_SIZE,
                        shuffle=False, collate_fn=generate_batch)
test_iter = DataLoader(test_data, batch_size=BATCH_SIZE,
                       shuffle=False, collate_fn=generate_batch)

## Building the Seq2Seq Model

### Encoder




In [18]:
class Encoder(nn.Module):
    def __init__(self, input_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout):
        super().__init__()

        self.embedding = nn.Embedding(input_dim, emb_dim) #size of vocab (50k from clean bpe data), embedding dimensionality (256 from below)

        #[YOUR CODE] GRU(embeding size, encoder hidden size) NOTE: bidirectional batch_first
        # parameters for GRU in Pytorch
        # torch.nn.GRU(input_size, hidden_size, num_layers=1, bias=True, batch_first=False, dropout=0.0, bidirectional=False, device=None, dtype=None)
        self.rnn = nn.GRU(input_size=emb_dim, hidden_size=enc_hid_dim, batch_first=True, bidirectional=True)

        #[YOUR CODE] linear(encoder hidden size * 2, decoder hidden size)
        # times 2 as we have a bidirectional model
        self.fc = nn.Linear(enc_hid_dim * 2, dec_hid_dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, src):

        #[B=8, S= longest sentence]
        embedded = self.embedding(src) #embed the source

        #[B, S, E] print(embedded.size())
        # the matrix grow by the additional dimension of the embedding size (256)

        embedded = self.dropout(embedded) #regularize the embeddings

        #[B, S, E]
        #[YOUR CODE] rrn(embeddings)
        outputs, hidden = self.rnn(embedded)

        #[B, S, H*2]

        #h[n layers * num directions, batch size, hid dim]

        #[forward_1, backward_1, forward_2, backward_2, ...]

        #[-2, :, : ] last state forward RNN
        #[-1, :, : ] last state backward RNN
        print('hid', hidden.size())

        #[YOUR CODE] last state forward RNN
        h1 = hidden[-2, :, : ]

        #[YOUR CODE] last state backward RNN
        h2 = hidden[-1, :, : ]

        #https://pytorch.org/docs/main/generated/torch.cat.html

        #[YOUR CODE] concatenate h1 amd h2 on seq dim
        h_cat = torch.cat([h1, h2], dim=1)

        #[YOUR CODE] tanh(linear(hidden_cat))
        hidden = torch.tanh(self.fc(h_cat))

        #[B, S, H*2]

        return outputs, hidden # outputs = vector for each sentence, hidden = concatenated vector for forward and backward

# Attention

## Luong Attention




Formula: score(target hidden state, each source hidden stare) = target hidden state transposed * W * eachsource hidden state)

In [22]:
class LuongAttention(nn.Module):
    def __init__(self, enc_hid_dim, dec_hid_dim):
        super().__init__()

        self.attn = nn.Linear((enc_hid_dim * 2) + dec_hid_dim, dec_hid_dim)
        self.v = nn.Linear(dec_hid_dim, 1, bias = False)

    def forward(self, hidden, encoder_outputs): #keys, query  hidden = target

        #[batch size, dec hid dim]
        #[src len, batch size, enc hid dim * 2]

        batch_size = encoder_outputs.shape[0] #to see the size
        src_len = encoder_outputs.shape[1]

        #x times decoder hidden state for the size of the src_len
        hidden = hidden.unsqueeze(1).repeat(1, src_len, 1) #repatn the hidden state over source

        #[batch size, src len, dec hid dim]

        scores = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim = 2))) #concateneted in the embedding dimension
        # attn = linear layer
        # tanh = non-linear

        #[batch size, src len, dec hid dim]

        attention = self.v(scores).squeeze(2) # v is W in the formula

        #[batch size, src len]

        return F.softmax(attention, dim=1)

Formula: score(target hidden state, each source hidden stare) = Va transposed * tanh(W1 * target hidden state + W2 * each source hidden state)

In [21]:
class BahdanauAttention(nn.Module):
    def __init__(self, enc_hid_dim, dec_hid_dim):
        super(BahdanauAttention, self).__init__()
        self.Wa = nn.Linear(dec_hid_dim, dec_hid_dim, bias=False)
        self.Ua = nn.Linear(enc_hid_dim * 2, dec_hid_dim, bias=False)
        self.Va = nn.Linear(dec_hid_dim, 1, bias=False)

    def forward(self, hidden, encoder_outputs): #keys, query
        #[YOUR CODE]
        batch_size = encoder_outputs.shape(0)
        src_len = encoder_outputs.shape(1)

        #x times decoder hidden state for src_len
        #[YOUR CODE]
        hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)  # pytorch - Returns a new tensor with a dimension of size one inserted at the specified position.

        """
        >>> x = torch.tensor([1, 2, 3, 4])
        >>> torch.unsqueeze(x, 0)
        tensor([[ 1,  2,  3,  4]])
        >>> torch.unsqueeze(x, 1)
        tensor([[ 1],
                [ 2],
                [ 3],
                [ 4]])
        """

        #[YOUR CODE] Va(tanh(Wa(hidden) + Ua(encoder outputs)))
        scores = self.Va(torch.tanh(self.Wa(hidden) + self.Ua(encoder_outputs)))

        scores = scores.squeeze(2) #Returns a tensor with all specified dimensions of input of size 1 removed.

        """
        For example, if input is of shape: (A×1×B×C×1×D) then the input.squeeze() will be of shape: (A×B×C×D).
        When dim is given, a squeeze operation is done only in the given dimension(s). If input is of shape:
        (A×1×B), squeeze(input, 0) leaves the tensor unchanged, but squeeze(input, 1) will squeeze the tensor to the shape (A×B).
        """

        #[YOUR CODE] softmax(scores, dim seq)
        weights =  torch.softmax(scores, dim=1)

        return weights

### Decoder



In [23]:
class Decoder(nn.Module):
    def __init__(self, output_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout, attention):
        super().__init__()

        self.output_dim = output_dim

        self.attention = attention

        self.embedding = nn.Embedding(output_dim, emb_dim)

        self.rnn = nn.GRU((enc_hid_dim * 2) + emb_dim, dec_hid_dim, batch_first=True)

        self.fc_out = nn.Linear((enc_hid_dim * 2) + dec_hid_dim + emb_dim, output_dim)

        self.dropout = nn.Dropout(dropout)

        #attention
        self.attn = nn.Linear((enc_hid_dim * 2) + dec_hid_dim, dec_hid_dim)
        self.v = nn.Linear(dec_hid_dim, 1, bias = False)

    def forward(self, input, hidden, encoder_outputs):

        #[batch size]
        #[batch size, dec hid dim]
        #[src len, batch size, enc hid dim * 2]


        input = input.unsqueeze(0)

        #[1, batch size]

        embedded = self.dropout(self.embedding(input))

        #[1, batch size, emb dim]

        a = self.attention(hidden, encoder_outputs)

        #[batch size, src len]

        a = a.unsqueeze(1)

        #[batch size, 1, src len]

        #[batch size, src len, enc hid dim * 2]

        weighted = torch.bmm(a, encoder_outputs)

        #[batch size, 1, enc hid dim * 2]

        weighted = weighted.permute(1, 0, 2)

        #[1, batch size, enc hid dim * 2]

        rnn_input = torch.cat((embedded, weighted), dim = 2)

        #[1, batch size, (enc hid dim * 2) + emb dim]
        #[B, 1, (enc hid dim * 2) + emb dim]

        rnn_input = rnn_input.permute(1, 0, 2)
        hidden = hidden.unsqueeze(0)

        output, hidden = self.rnn(rnn_input, hidden)

        #[seq len, batch size, dec hid dim * n directions]
        #[n layers * n directions, batch size, dec hid dim]

        #[1, batch size, dec hid dim]
        #[1, batch size, dec hid dim]

        embedded = embedded.squeeze(0)
        output = output.squeeze(1)
        weighted = weighted.squeeze(0)

        prediction = self.fc_out(torch.cat((output, weighted, embedded), dim = 1))

        #[batch size, output dim]

        return prediction, hidden.squeeze(0)

### Seq2Seq




In [24]:
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device):
        super().__init__()

        self.encoder = encoder
        self.decoder = decoder
        self.device = device


    def forward(self, src, trg, teacher_forcing_ratio = 0.5):

        #[src len, batch size]
        #[trg len, batch size]
        #teacher_forcing_ratio is probability to use teacher forcing
        # 0.75 teacher forcing 75% of the time

        batch_size = src.shape[0]
        trg_len = trg.shape[1]
        trg_vocab_size = self.decoder.output_dim


        outputs = torch.zeros(batch_size, trg_len, trg_vocab_size).to(self.device)

        #encoder_outputs is all hidden states of the input sequence
        #hidden is the final forward and backward hidden states, passed through a linear layer
        encoder_outputs, hidden = self.encoder(src)

        #first input to the decoder is the <sos> tokens
        input = trg[:,0]
        # unroll RNN
        for t in range(1, trg_len):

            #insert input token embedding, previous hidden state and all encoder hidden states

            output, hidden = self.decoder(input, hidden, encoder_outputs)

            #predictions
            outputs[:, t] = output

            #teacher forcing
            teacher_force = random.random() < teacher_forcing_ratio

            #greedy search
            top1 = output.argmax(1)

            #if teacher forcing, use gold token as next input
            #if not, use predicted token
            input = trg[:, t] if teacher_force else top1

        return outputs

## Training the Seq2Seq Model



In [25]:
INPUT_DIM = len(en_vocab)
OUTPUT_DIM = len(de_vocab)
ENC_EMB_DIM = 256
DEC_EMB_DIM = 256
ENC_HID_DIM = 512
DEC_HID_DIM = 512
ENC_DROPOUT = 0.5
DEC_DROPOUT = 0.2

attn = LuongAttention(ENC_HID_DIM, DEC_HID_DIM)
#[YOUR CODE] Bahdanau att
#attn = BahdanauAttention(ENC_HID_DIM, DEC_HID_DIM)
enc = Encoder(INPUT_DIM, ENC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, ENC_DROPOUT)
dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, DEC_DROPOUT, attn)

model = Seq2Seq(enc, dec, device).to(device)

In [26]:
print(len(en_vocab)) #BPE size 16k approx
print(len(de_vocab))

36926
42161


In [27]:
def init_weights(m):
    for name, param in m.named_parameters():
        if 'weight' in name:
            nn.init.normal_(param.data, mean=0, std=0.01)
        else:
            nn.init.constant_(param.data, 0)

model.apply(init_weights)

Seq2Seq(
  (encoder): Encoder(
    (embedding): Embedding(36926, 256)
    (rnn): GRU(256, 512, batch_first=True, bidirectional=True)
    (fc): Linear(in_features=1024, out_features=512, bias=True)
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (decoder): Decoder(
    (attention): LuongAttention(
      (attn): Linear(in_features=1536, out_features=512, bias=True)
      (v): Linear(in_features=512, out_features=1, bias=False)
    )
    (embedding): Embedding(42161, 256)
    (rnn): GRU(1280, 512, batch_first=True)
    (fc_out): Linear(in_features=1792, out_features=42161, bias=True)
    (dropout): Dropout(p=0.2, inplace=False)
    (attn): Linear(in_features=1536, out_features=512, bias=True)
    (v): Linear(in_features=512, out_features=1, bias=False)
  )
)

In [28]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 103,061,681 trainable parameters


We create an optimizer.

In [29]:
optimizer = optim.Adam(model.parameters(), lr=1e-3)
#[YOUR CODE] Adam lr 1e-3

We initialize the loss function.

In [30]:
TRG_PAD_IDX = de_vocab['<pad>'] #TODO
#https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
criterion = nn.CrossEntropyLoss(ignore_index=TRG_PAD_IDX)

#[YOUR CODE] loss add ignore_index idx
#parameters
#torch.nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean', label_smoothing=0.0)

In [31]:
print(TRG_PAD_IDX)

1


In [32]:
def train(model, iterator, optimizer, criterion, clip):

    #[YOUR CODE] set model train model.train()
    model.train()

    epoch_loss = 0

    for (src, trg) in tqdm(iterator):

        #[YOUR CODE] to gpu
        src = src.to(device)
        trg = trg.to(device)

        #[YOUR CODE] optimizer zero grad
        optimizer.zero_grad()

        #[YOUR CODE] model()
        output = model(src, trg)

        #[trg len, batch size]
        #[trg len, batch size, output dim]

        output = output.permute(1, 0, 2)

        output_dim = output.shape[-1]
        trg = trg.permute(1, 0)

        output = output[1:].reshape(-1, output_dim)
        trg = trg[1:].reshape(-1)

        #[(trg len - 1) * batch size]
        #[(trg len - 1) * batch size, output dim]

        #[YOUR CODE] criterion()
        loss = criterion(output, trg)

        #[YOUR CODE] loss backward
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

        #[YOUR CODE] optimizer step
        optimizer.step()

        epoch_loss += loss.item()

    return epoch_loss / len(iterator)

In [33]:
def evaluate(model, iterator, criterion):

    #[YOUR CODE] set model test/eval  model.eval()
    model.eval()

    epoch_loss = 0

    with torch.no_grad():

        for (src, trg) in iterator:
            #[YOUR CODE] to device
            src = src.to(device)
            trg = trg.to(device)

            #[YOUR CODE] model() turn off teacher forcing 0
            output = model(src, trg, teacher_force = 0)

            #trg = [trg len, batch size]
            #output = [trg len, batch size, output dim]

            output = output.permute(1, 0, 2)
            output_dim = output.shape[-1]
            trg = trg.permute(1, 0)
            output = output[1:].reshape(-1, output_dim)
            trg = trg[1:].reshape(-1)

            #trg = [(trg len - 1) * batch size]
            #output = [(trg len - 1) * batch size, output dim]

            #[YOUR CODE] criterion()
            loss = criterion(output, trg)

            epoch_loss += loss.item()

    return epoch_loss / len(iterator)

In [None]:
N_EPOCHS = 5
CLIP = 1

best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):

    train_loss = train(model, train_iter, optimizer, criterion, CLIP)
    valid_loss = evaluate(model, valid_iter, criterion)

    if valid_loss < best_valid_loss:
      #[YOUR CODE] extra add BLEU model selection
      best_valid_loss = valid_loss
      #torch.save(model.state_dict(), 'model.pt')

    print(f'Epoch: {epoch+1:02}')
    print(f'\tTrain Loss: {train_loss:.3f}\tTrain PPL: {np.exp(train_loss):7.3f}')
    print(f'\t Validation Loss: {valid_loss:.3f}\tValidation PPL: {np.exp(valid_loss):7.3f}')

  0%|          | 0/2500 [00:00<?, ?it/s]

hid torch.Size([2, 8, 512])


  0%|          | 1/2500 [00:01<56:54,  1.37s/it]

hid torch.Size([2, 8, 512])


  0%|          | 2/2500 [00:01<36:13,  1.15it/s]

hid torch.Size([2, 8, 512])


  0%|          | 3/2500 [00:02<29:44,  1.40it/s]

hid torch.Size([2, 8, 512])


  0%|          | 4/2500 [00:02<24:58,  1.67it/s]

hid torch.Size([2, 8, 512])


  0%|          | 5/2500 [00:03<25:10,  1.65it/s]

hid torch.Size([2, 8, 512])


  0%|          | 6/2500 [00:03<23:58,  1.73it/s]

hid torch.Size([2, 8, 512])


  0%|          | 7/2500 [00:04<23:48,  1.75it/s]

hid torch.Size([2, 8, 512])


  0%|          | 8/2500 [00:05<24:22,  1.70it/s]

hid torch.Size([2, 8, 512])


  0%|          | 9/2500 [00:05<23:55,  1.74it/s]

hid torch.Size([2, 8, 512])


  0%|          | 10/2500 [00:06<22:31,  1.84it/s]

hid torch.Size([2, 8, 512])


  0%|          | 11/2500 [00:07<28:54,  1.43it/s]

hid torch.Size([2, 8, 512])


  0%|          | 12/2500 [00:08<30:25,  1.36it/s]

hid torch.Size([2, 8, 512])


  1%|          | 13/2500 [00:08<28:31,  1.45it/s]

hid torch.Size([2, 8, 512])


  1%|          | 14/2500 [00:09<26:21,  1.57it/s]

hid torch.Size([2, 8, 512])


  1%|          | 15/2500 [00:09<27:21,  1.51it/s]

hid torch.Size([2, 8, 512])


  1%|          | 16/2500 [00:10<32:44,  1.26it/s]

hid torch.Size([2, 8, 512])


  1%|          | 17/2500 [00:12<36:59,  1.12it/s]

hid torch.Size([2, 8, 512])


  1%|          | 18/2500 [00:13<37:27,  1.10it/s]

hid torch.Size([2, 8, 512])


  1%|          | 19/2500 [00:13<35:21,  1.17it/s]

hid torch.Size([2, 8, 512])


  1%|          | 20/2500 [00:14<38:42,  1.07it/s]

hid torch.Size([2, 8, 512])


  1%|          | 21/2500 [00:15<35:22,  1.17it/s]

hid torch.Size([2, 8, 512])


  1%|          | 22/2500 [00:16<34:51,  1.18it/s]

hid torch.Size([2, 8, 512])


  1%|          | 23/2500 [00:16<30:37,  1.35it/s]

hid torch.Size([2, 8, 512])


  1%|          | 24/2500 [00:17<26:30,  1.56it/s]

hid torch.Size([2, 8, 512])


  1%|          | 25/2500 [00:17<25:56,  1.59it/s]

hid torch.Size([2, 8, 512])


  1%|          | 26/2500 [00:18<27:51,  1.48it/s]

hid torch.Size([2, 8, 512])


  1%|          | 27/2500 [00:19<27:03,  1.52it/s]

hid torch.Size([2, 8, 512])


  1%|          | 28/2500 [00:19<26:25,  1.56it/s]

hid torch.Size([2, 8, 512])


  1%|          | 29/2500 [00:20<25:55,  1.59it/s]

hid torch.Size([2, 8, 512])


  1%|          | 30/2500 [00:20<23:24,  1.76it/s]

hid torch.Size([2, 8, 512])


  1%|          | 31/2500 [00:21<22:40,  1.82it/s]

hid torch.Size([2, 8, 512])


  1%|▏         | 32/2500 [00:22<24:42,  1.67it/s]

hid torch.Size([2, 8, 512])


  1%|▏         | 33/2500 [00:23<28:26,  1.45it/s]

hid torch.Size([2, 8, 512])


  1%|▏         | 34/2500 [00:23<27:23,  1.50it/s]

hid torch.Size([2, 8, 512])


  1%|▏         | 35/2500 [00:24<26:16,  1.56it/s]

hid torch.Size([2, 8, 512])


  1%|▏         | 36/2500 [00:24<26:23,  1.56it/s]

hid torch.Size([2, 8, 512])


  1%|▏         | 37/2500 [00:25<24:14,  1.69it/s]

hid torch.Size([2, 8, 512])


  2%|▏         | 38/2500 [00:25<23:19,  1.76it/s]

hid torch.Size([2, 8, 512])


  2%|▏         | 39/2500 [00:26<26:45,  1.53it/s]

hid torch.Size([2, 8, 512])


  2%|▏         | 40/2500 [00:27<29:16,  1.40it/s]

hid torch.Size([2, 8, 512])


  2%|▏         | 41/2500 [00:28<25:57,  1.58it/s]

hid torch.Size([2, 8, 512])


  2%|▏         | 42/2500 [00:28<27:23,  1.50it/s]

hid torch.Size([2, 8, 512])


  2%|▏         | 43/2500 [00:29<27:50,  1.47it/s]

hid torch.Size([2, 8, 512])


  2%|▏         | 44/2500 [00:29<25:02,  1.63it/s]

hid torch.Size([2, 8, 512])


  2%|▏         | 45/2500 [00:30<22:35,  1.81it/s]

hid torch.Size([2, 8, 512])


  2%|▏         | 46/2500 [00:31<26:38,  1.53it/s]

hid torch.Size([2, 8, 512])


  2%|▏         | 47/2500 [00:31<24:10,  1.69it/s]

hid torch.Size([2, 8, 512])


  2%|▏         | 48/2500 [00:32<27:05,  1.51it/s]

hid torch.Size([2, 8, 512])


  2%|▏         | 49/2500 [00:33<26:23,  1.55it/s]

hid torch.Size([2, 8, 512])


  2%|▏         | 50/2500 [00:33<23:51,  1.71it/s]

hid torch.Size([2, 8, 512])


  2%|▏         | 51/2500 [00:34<23:08,  1.76it/s]

hid torch.Size([2, 8, 512])


  2%|▏         | 52/2500 [00:34<22:55,  1.78it/s]

hid torch.Size([2, 8, 512])


  2%|▏         | 53/2500 [00:35<23:50,  1.71it/s]

hid torch.Size([2, 8, 512])


  2%|▏         | 54/2500 [00:35<24:04,  1.69it/s]

hid torch.Size([2, 8, 512])


  2%|▏         | 55/2500 [00:36<23:42,  1.72it/s]

hid torch.Size([2, 8, 512])


  2%|▏         | 56/2500 [00:37<27:50,  1.46it/s]

hid torch.Size([2, 8, 512])


  2%|▏         | 57/2500 [00:37<24:23,  1.67it/s]

hid torch.Size([2, 8, 512])


  2%|▏         | 58/2500 [00:38<25:37,  1.59it/s]

hid torch.Size([2, 8, 512])


  2%|▏         | 59/2500 [00:39<25:29,  1.60it/s]

hid torch.Size([2, 8, 512])


  2%|▏         | 60/2500 [00:39<27:53,  1.46it/s]

hid torch.Size([2, 8, 512])


  2%|▏         | 61/2500 [00:40<29:13,  1.39it/s]

hid torch.Size([2, 8, 512])


  2%|▏         | 62/2500 [00:41<28:39,  1.42it/s]

hid torch.Size([2, 8, 512])


  3%|▎         | 63/2500 [00:42<35:33,  1.14it/s]

hid torch.Size([2, 8, 512])


  3%|▎         | 64/2500 [00:43<32:23,  1.25it/s]

hid torch.Size([2, 8, 512])


  3%|▎         | 65/2500 [00:43<30:19,  1.34it/s]

hid torch.Size([2, 8, 512])


  3%|▎         | 66/2500 [00:44<26:13,  1.55it/s]

hid torch.Size([2, 8, 512])


  3%|▎         | 67/2500 [00:44<23:20,  1.74it/s]

hid torch.Size([2, 8, 512])


  3%|▎         | 68/2500 [00:45<27:10,  1.49it/s]

hid torch.Size([2, 8, 512])


  3%|▎         | 69/2500 [00:46<27:28,  1.47it/s]

hid torch.Size([2, 8, 512])


  3%|▎         | 70/2500 [00:46<24:55,  1.63it/s]

hid torch.Size([2, 8, 512])


  3%|▎         | 71/2500 [00:47<28:36,  1.42it/s]

hid torch.Size([2, 8, 512])


  3%|▎         | 72/2500 [00:48<27:01,  1.50it/s]

hid torch.Size([2, 8, 512])


  3%|▎         | 73/2500 [00:48<22:48,  1.77it/s]

hid torch.Size([2, 8, 512])


  3%|▎         | 74/2500 [00:48<20:36,  1.96it/s]

hid torch.Size([2, 8, 512])


  3%|▎         | 75/2500 [00:49<25:43,  1.57it/s]

hid torch.Size([2, 8, 512])


  3%|▎         | 76/2500 [00:50<24:40,  1.64it/s]

hid torch.Size([2, 8, 512])


  3%|▎         | 77/2500 [00:50<23:13,  1.74it/s]

hid torch.Size([2, 8, 512])


  3%|▎         | 78/2500 [00:51<22:46,  1.77it/s]

hid torch.Size([2, 8, 512])


  3%|▎         | 79/2500 [00:52<22:20,  1.81it/s]

hid torch.Size([2, 8, 512])


  3%|▎         | 80/2500 [00:52<26:41,  1.51it/s]

hid torch.Size([2, 8, 512])


  3%|▎         | 81/2500 [00:53<26:23,  1.53it/s]

hid torch.Size([2, 8, 512])


  3%|▎         | 82/2500 [00:54<25:47,  1.56it/s]

hid torch.Size([2, 8, 512])


  3%|▎         | 83/2500 [00:54<23:46,  1.69it/s]

hid torch.Size([2, 8, 512])


  3%|▎         | 84/2500 [00:55<24:28,  1.64it/s]

hid torch.Size([2, 8, 512])


  3%|▎         | 85/2500 [00:55<25:36,  1.57it/s]

hid torch.Size([2, 8, 512])


  3%|▎         | 86/2500 [00:56<24:44,  1.63it/s]

hid torch.Size([2, 8, 512])


  3%|▎         | 87/2500 [00:57<25:52,  1.55it/s]

hid torch.Size([2, 8, 512])


  4%|▎         | 88/2500 [00:57<23:58,  1.68it/s]

hid torch.Size([2, 8, 512])


  4%|▎         | 89/2500 [00:58<23:33,  1.71it/s]

hid torch.Size([2, 8, 512])


  4%|▎         | 90/2500 [00:58<23:03,  1.74it/s]

hid torch.Size([2, 8, 512])


  4%|▎         | 91/2500 [00:59<22:54,  1.75it/s]

hid torch.Size([2, 8, 512])


  4%|▎         | 92/2500 [00:59<20:48,  1.93it/s]

hid torch.Size([2, 8, 512])


  4%|▎         | 93/2500 [01:00<23:36,  1.70it/s]

hid torch.Size([2, 8, 512])


  4%|▍         | 94/2500 [01:01<27:55,  1.44it/s]

hid torch.Size([2, 8, 512])


  4%|▍         | 95/2500 [01:02<36:19,  1.10it/s]

hid torch.Size([2, 8, 512])


  4%|▍         | 96/2500 [01:03<29:07,  1.38it/s]

hid torch.Size([2, 8, 512])


  4%|▍         | 97/2500 [01:03<25:52,  1.55it/s]

hid torch.Size([2, 8, 512])


  4%|▍         | 98/2500 [01:04<24:48,  1.61it/s]

hid torch.Size([2, 8, 512])


  4%|▍         | 99/2500 [01:04<24:23,  1.64it/s]

hid torch.Size([2, 8, 512])


  4%|▍         | 100/2500 [01:05<25:50,  1.55it/s]

hid torch.Size([2, 8, 512])


  4%|▍         | 101/2500 [01:05<22:22,  1.79it/s]

hid torch.Size([2, 8, 512])


  4%|▍         | 102/2500 [01:06<21:10,  1.89it/s]

hid torch.Size([2, 8, 512])


  4%|▍         | 103/2500 [01:07<24:44,  1.61it/s]

hid torch.Size([2, 8, 512])


  4%|▍         | 104/2500 [01:07<25:10,  1.59it/s]

hid torch.Size([2, 8, 512])


  4%|▍         | 105/2500 [01:08<24:43,  1.61it/s]

hid torch.Size([2, 8, 512])


  4%|▍         | 106/2500 [01:09<25:44,  1.55it/s]

hid torch.Size([2, 8, 512])


  4%|▍         | 107/2500 [01:09<22:29,  1.77it/s]

hid torch.Size([2, 8, 512])


  4%|▍         | 108/2500 [01:10<22:41,  1.76it/s]

hid torch.Size([2, 8, 512])


  4%|▍         | 109/2500 [01:10<22:11,  1.80it/s]

hid torch.Size([2, 8, 512])


  4%|▍         | 110/2500 [01:11<23:11,  1.72it/s]

hid torch.Size([2, 8, 512])


  4%|▍         | 111/2500 [01:11<22:09,  1.80it/s]

hid torch.Size([2, 8, 512])


  4%|▍         | 112/2500 [01:12<22:47,  1.75it/s]

hid torch.Size([2, 8, 512])


  5%|▍         | 113/2500 [01:12<21:03,  1.89it/s]

hid torch.Size([2, 8, 512])


  5%|▍         | 114/2500 [01:13<20:51,  1.91it/s]

hid torch.Size([2, 8, 512])


  5%|▍         | 115/2500 [01:13<18:53,  2.10it/s]

hid torch.Size([2, 8, 512])


  5%|▍         | 116/2500 [01:14<18:43,  2.12it/s]

hid torch.Size([2, 8, 512])


  5%|▍         | 117/2500 [01:14<21:35,  1.84it/s]

hid torch.Size([2, 8, 512])


  5%|▍         | 118/2500 [01:15<21:37,  1.84it/s]

hid torch.Size([2, 8, 512])


  5%|▍         | 119/2500 [01:15<20:46,  1.91it/s]

hid torch.Size([2, 8, 512])


  5%|▍         | 120/2500 [01:16<21:05,  1.88it/s]

hid torch.Size([2, 8, 512])


  5%|▍         | 121/2500 [01:16<19:02,  2.08it/s]

hid torch.Size([2, 8, 512])


  5%|▍         | 122/2500 [01:17<18:02,  2.20it/s]

hid torch.Size([2, 8, 512])


  5%|▍         | 123/2500 [01:17<20:57,  1.89it/s]

hid torch.Size([2, 8, 512])


  5%|▍         | 124/2500 [01:18<22:19,  1.77it/s]

hid torch.Size([2, 8, 512])


  5%|▌         | 125/2500 [01:19<22:13,  1.78it/s]

hid torch.Size([2, 8, 512])


  5%|▌         | 126/2500 [01:19<21:54,  1.81it/s]

hid torch.Size([2, 8, 512])


  5%|▌         | 127/2500 [01:19<19:08,  2.07it/s]

hid torch.Size([2, 8, 512])


  5%|▌         | 128/2500 [01:20<19:30,  2.03it/s]

hid torch.Size([2, 8, 512])


  5%|▌         | 129/2500 [01:20<19:36,  2.02it/s]

hid torch.Size([2, 8, 512])


  5%|▌         | 130/2500 [01:21<23:08,  1.71it/s]

hid torch.Size([2, 8, 512])


  5%|▌         | 131/2500 [01:22<23:11,  1.70it/s]

hid torch.Size([2, 8, 512])


  5%|▌         | 132/2500 [01:23<24:28,  1.61it/s]

hid torch.Size([2, 8, 512])


  5%|▌         | 133/2500 [01:23<25:32,  1.54it/s]

hid torch.Size([2, 8, 512])


  5%|▌         | 134/2500 [01:24<28:57,  1.36it/s]

hid torch.Size([2, 8, 512])


  5%|▌         | 135/2500 [01:25<30:46,  1.28it/s]

hid torch.Size([2, 8, 512])


  5%|▌         | 136/2500 [01:26<33:35,  1.17it/s]

hid torch.Size([2, 8, 512])


  5%|▌         | 137/2500 [01:27<29:53,  1.32it/s]

hid torch.Size([2, 8, 512])


  6%|▌         | 138/2500 [01:27<26:39,  1.48it/s]

hid torch.Size([2, 8, 512])


  6%|▌         | 139/2500 [01:28<25:26,  1.55it/s]

hid torch.Size([2, 8, 512])


  6%|▌         | 140/2500 [01:28<24:48,  1.59it/s]

hid torch.Size([2, 8, 512])


  6%|▌         | 141/2500 [01:29<24:32,  1.60it/s]

hid torch.Size([2, 8, 512])


  6%|▌         | 142/2500 [01:30<28:30,  1.38it/s]

hid torch.Size([2, 8, 512])


  6%|▌         | 143/2500 [01:30<26:16,  1.50it/s]

hid torch.Size([2, 8, 512])


  6%|▌         | 144/2500 [01:31<27:05,  1.45it/s]

hid torch.Size([2, 8, 512])


  6%|▌         | 145/2500 [01:32<26:57,  1.46it/s]

hid torch.Size([2, 8, 512])


  6%|▌         | 146/2500 [01:33<28:27,  1.38it/s]

hid torch.Size([2, 8, 512])


  6%|▌         | 147/2500 [01:33<28:21,  1.38it/s]

hid torch.Size([2, 8, 512])


  6%|▌         | 148/2500 [01:34<23:09,  1.69it/s]

hid torch.Size([2, 8, 512])


  6%|▌         | 149/2500 [01:34<22:39,  1.73it/s]

hid torch.Size([2, 8, 512])


  6%|▌         | 150/2500 [01:35<23:13,  1.69it/s]

hid torch.Size([2, 8, 512])


  6%|▌         | 151/2500 [01:36<30:41,  1.28it/s]

hid torch.Size([2, 8, 512])


  6%|▌         | 152/2500 [01:37<30:32,  1.28it/s]

hid torch.Size([2, 8, 512])


  6%|▌         | 153/2500 [01:37<28:17,  1.38it/s]

hid torch.Size([2, 8, 512])


  6%|▌         | 154/2500 [01:38<26:30,  1.48it/s]

hid torch.Size([2, 8, 512])


  6%|▌         | 155/2500 [01:38<24:22,  1.60it/s]

hid torch.Size([2, 8, 512])


  6%|▌         | 156/2500 [01:39<23:16,  1.68it/s]

hid torch.Size([2, 8, 512])


  6%|▋         | 157/2500 [01:40<22:13,  1.76it/s]

hid torch.Size([2, 8, 512])


  6%|▋         | 158/2500 [01:40<20:41,  1.89it/s]

hid torch.Size([2, 8, 512])


  6%|▋         | 159/2500 [01:41<21:26,  1.82it/s]

hid torch.Size([2, 8, 512])


  6%|▋         | 160/2500 [01:41<24:27,  1.59it/s]

hid torch.Size([2, 8, 512])


  6%|▋         | 161/2500 [01:42<22:50,  1.71it/s]

hid torch.Size([2, 8, 512])


  6%|▋         | 162/2500 [01:43<25:26,  1.53it/s]

hid torch.Size([2, 8, 512])


  7%|▋         | 163/2500 [01:43<23:08,  1.68it/s]

hid torch.Size([2, 8, 512])


  7%|▋         | 164/2500 [01:44<21:29,  1.81it/s]

hid torch.Size([2, 8, 512])


  7%|▋         | 165/2500 [01:44<21:00,  1.85it/s]

hid torch.Size([2, 8, 512])


  7%|▋         | 166/2500 [01:45<23:49,  1.63it/s]

hid torch.Size([2, 8, 512])


  7%|▋         | 167/2500 [01:46<24:54,  1.56it/s]

hid torch.Size([2, 8, 512])


  7%|▋         | 168/2500 [01:47<30:50,  1.26it/s]

hid torch.Size([2, 8, 512])


  7%|▋         | 169/2500 [01:47<27:55,  1.39it/s]

hid torch.Size([2, 8, 512])


  7%|▋         | 170/2500 [01:48<25:01,  1.55it/s]

hid torch.Size([2, 8, 512])


  7%|▋         | 171/2500 [01:49<30:39,  1.27it/s]

hid torch.Size([2, 8, 512])


  7%|▋         | 172/2500 [01:49<27:11,  1.43it/s]

hid torch.Size([2, 8, 512])


  7%|▋         | 173/2500 [01:50<26:09,  1.48it/s]

hid torch.Size([2, 8, 512])


  7%|▋         | 174/2500 [01:51<27:04,  1.43it/s]

hid torch.Size([2, 8, 512])


  7%|▋         | 175/2500 [01:51<24:41,  1.57it/s]

hid torch.Size([2, 8, 512])


  7%|▋         | 176/2500 [01:52<28:00,  1.38it/s]

hid torch.Size([2, 8, 512])


  7%|▋         | 177/2500 [01:53<26:55,  1.44it/s]

hid torch.Size([2, 8, 512])


  7%|▋         | 178/2500 [01:53<25:13,  1.53it/s]

hid torch.Size([2, 8, 512])


  7%|▋         | 179/2500 [01:54<26:28,  1.46it/s]

hid torch.Size([2, 8, 512])


  7%|▋         | 180/2500 [01:55<24:27,  1.58it/s]

hid torch.Size([2, 8, 512])


  7%|▋         | 181/2500 [01:56<30:29,  1.27it/s]

hid torch.Size([2, 8, 512])


  7%|▋         | 182/2500 [01:56<27:51,  1.39it/s]

hid torch.Size([2, 8, 512])


  7%|▋         | 183/2500 [01:57<30:11,  1.28it/s]

hid torch.Size([2, 8, 512])


  7%|▋         | 184/2500 [01:58<26:33,  1.45it/s]

hid torch.Size([2, 8, 512])


  7%|▋         | 185/2500 [01:58<24:06,  1.60it/s]

hid torch.Size([2, 8, 512])


  7%|▋         | 186/2500 [01:59<29:10,  1.32it/s]

hid torch.Size([2, 8, 512])


  7%|▋         | 187/2500 [02:00<26:59,  1.43it/s]

hid torch.Size([2, 8, 512])


  8%|▊         | 188/2500 [02:01<29:17,  1.32it/s]

hid torch.Size([2, 8, 512])


  8%|▊         | 189/2500 [02:01<26:40,  1.44it/s]

hid torch.Size([2, 8, 512])


  8%|▊         | 190/2500 [02:02<26:16,  1.47it/s]

hid torch.Size([2, 8, 512])


  8%|▊         | 191/2500 [02:02<23:19,  1.65it/s]

hid torch.Size([2, 8, 512])


  8%|▊         | 192/2500 [02:03<23:43,  1.62it/s]

hid torch.Size([2, 8, 512])


  8%|▊         | 193/2500 [02:04<24:06,  1.60it/s]

hid torch.Size([2, 8, 512])


  8%|▊         | 194/2500 [02:04<24:18,  1.58it/s]

hid torch.Size([2, 8, 512])


  8%|▊         | 195/2500 [02:05<23:23,  1.64it/s]

hid torch.Size([2, 8, 512])


  8%|▊         | 196/2500 [02:06<24:57,  1.54it/s]

hid torch.Size([2, 8, 512])


  8%|▊         | 197/2500 [02:06<22:56,  1.67it/s]

hid torch.Size([2, 8, 512])


  8%|▊         | 198/2500 [02:07<22:15,  1.72it/s]

hid torch.Size([2, 8, 512])


  8%|▊         | 199/2500 [02:07<22:13,  1.73it/s]

hid torch.Size([2, 8, 512])


  8%|▊         | 200/2500 [02:08<20:41,  1.85it/s]

hid torch.Size([2, 8, 512])


  8%|▊         | 201/2500 [02:08<19:56,  1.92it/s]

hid torch.Size([2, 8, 512])


  8%|▊         | 202/2500 [02:09<23:37,  1.62it/s]

hid torch.Size([2, 8, 512])


  8%|▊         | 203/2500 [02:09<21:35,  1.77it/s]

hid torch.Size([2, 8, 512])


  8%|▊         | 204/2500 [02:10<23:57,  1.60it/s]

hid torch.Size([2, 8, 512])


  8%|▊         | 205/2500 [02:11<20:51,  1.83it/s]

hid torch.Size([2, 8, 512])


  8%|▊         | 206/2500 [02:11<20:18,  1.88it/s]

hid torch.Size([2, 8, 512])


  8%|▊         | 207/2500 [02:12<22:47,  1.68it/s]

hid torch.Size([2, 8, 512])


  8%|▊         | 208/2500 [02:12<22:58,  1.66it/s]

hid torch.Size([2, 8, 512])


  8%|▊         | 209/2500 [02:13<26:14,  1.46it/s]

hid torch.Size([2, 8, 512])


  8%|▊         | 210/2500 [02:14<26:07,  1.46it/s]

hid torch.Size([2, 8, 512])


  8%|▊         | 211/2500 [02:15<30:23,  1.26it/s]

hid torch.Size([2, 8, 512])


  8%|▊         | 212/2500 [02:16<28:12,  1.35it/s]

hid torch.Size([2, 8, 512])


  9%|▊         | 213/2500 [02:16<23:35,  1.62it/s]

hid torch.Size([2, 8, 512])


  9%|▊         | 214/2500 [02:16<21:46,  1.75it/s]

hid torch.Size([2, 8, 512])


  9%|▊         | 215/2500 [02:17<21:08,  1.80it/s]

hid torch.Size([2, 8, 512])


  9%|▊         | 216/2500 [02:18<22:53,  1.66it/s]

hid torch.Size([2, 8, 512])


  9%|▊         | 217/2500 [02:18<22:20,  1.70it/s]

hid torch.Size([2, 8, 512])


  9%|▊         | 218/2500 [02:19<22:38,  1.68it/s]

hid torch.Size([2, 8, 512])


  9%|▉         | 219/2500 [02:19<22:42,  1.67it/s]

hid torch.Size([2, 8, 512])


  9%|▉         | 220/2500 [02:20<26:16,  1.45it/s]

hid torch.Size([2, 8, 512])


  9%|▉         | 221/2500 [02:21<24:20,  1.56it/s]

hid torch.Size([2, 8, 512])


  9%|▉         | 222/2500 [02:21<24:32,  1.55it/s]

hid torch.Size([2, 8, 512])


  9%|▉         | 223/2500 [02:22<22:19,  1.70it/s]

hid torch.Size([2, 8, 512])


  9%|▉         | 224/2500 [02:22<21:17,  1.78it/s]

hid torch.Size([2, 8, 512])


  9%|▉         | 225/2500 [02:23<21:07,  1.80it/s]

hid torch.Size([2, 8, 512])


  9%|▉         | 226/2500 [02:24<21:42,  1.75it/s]

hid torch.Size([2, 8, 512])


  9%|▉         | 227/2500 [02:24<23:17,  1.63it/s]

hid torch.Size([2, 8, 512])


  9%|▉         | 228/2500 [02:25<22:36,  1.67it/s]

hid torch.Size([2, 8, 512])


  9%|▉         | 229/2500 [02:26<23:09,  1.63it/s]

hid torch.Size([2, 8, 512])


  9%|▉         | 230/2500 [02:26<21:32,  1.76it/s]

hid torch.Size([2, 8, 512])


  9%|▉         | 231/2500 [02:27<21:50,  1.73it/s]

hid torch.Size([2, 8, 512])


  9%|▉         | 232/2500 [02:27<20:48,  1.82it/s]

hid torch.Size([2, 8, 512])


  9%|▉         | 233/2500 [02:28<19:46,  1.91it/s]

hid torch.Size([2, 8, 512])


  9%|▉         | 234/2500 [02:28<19:38,  1.92it/s]

hid torch.Size([2, 8, 512])


  9%|▉         | 235/2500 [02:29<23:44,  1.59it/s]

hid torch.Size([2, 8, 512])


  9%|▉         | 236/2500 [02:30<23:29,  1.61it/s]

hid torch.Size([2, 8, 512])


  9%|▉         | 237/2500 [02:30<21:08,  1.78it/s]

hid torch.Size([2, 8, 512])


 10%|▉         | 238/2500 [02:31<24:24,  1.54it/s]

hid torch.Size([2, 8, 512])


 10%|▉         | 239/2500 [02:32<26:34,  1.42it/s]

hid torch.Size([2, 8, 512])


 10%|▉         | 240/2500 [02:32<24:46,  1.52it/s]

hid torch.Size([2, 8, 512])


 10%|▉         | 241/2500 [02:33<24:20,  1.55it/s]

hid torch.Size([2, 8, 512])


 10%|▉         | 242/2500 [02:33<24:06,  1.56it/s]

hid torch.Size([2, 8, 512])


 10%|▉         | 243/2500 [02:34<24:41,  1.52it/s]

hid torch.Size([2, 8, 512])


 10%|▉         | 244/2500 [02:35<25:39,  1.46it/s]

hid torch.Size([2, 8, 512])


 10%|▉         | 245/2500 [02:36<25:11,  1.49it/s]

hid torch.Size([2, 8, 512])


 10%|▉         | 246/2500 [02:36<25:07,  1.50it/s]

hid torch.Size([2, 8, 512])


 10%|▉         | 247/2500 [02:37<24:48,  1.51it/s]

hid torch.Size([2, 8, 512])


 10%|▉         | 248/2500 [02:37<21:54,  1.71it/s]

hid torch.Size([2, 8, 512])


 10%|▉         | 249/2500 [02:38<22:20,  1.68it/s]

hid torch.Size([2, 8, 512])


 10%|█         | 250/2500 [02:39<28:56,  1.30it/s]

hid torch.Size([2, 8, 512])


 10%|█         | 251/2500 [02:40<26:48,  1.40it/s]

hid torch.Size([2, 8, 512])


 10%|█         | 252/2500 [02:40<22:57,  1.63it/s]

hid torch.Size([2, 8, 512])


 10%|█         | 253/2500 [02:40<21:27,  1.74it/s]

hid torch.Size([2, 8, 512])


 10%|█         | 254/2500 [02:41<20:52,  1.79it/s]

hid torch.Size([2, 8, 512])


 10%|█         | 255/2500 [02:42<21:40,  1.73it/s]

hid torch.Size([2, 8, 512])


 10%|█         | 256/2500 [02:42<21:02,  1.78it/s]

hid torch.Size([2, 8, 512])


 10%|█         | 257/2500 [02:43<20:56,  1.79it/s]

hid torch.Size([2, 8, 512])


 10%|█         | 258/2500 [02:43<22:40,  1.65it/s]

hid torch.Size([2, 8, 512])


 10%|█         | 259/2500 [02:44<23:00,  1.62it/s]

hid torch.Size([2, 8, 512])


 10%|█         | 260/2500 [02:45<25:08,  1.48it/s]

hid torch.Size([2, 8, 512])


 10%|█         | 261/2500 [02:45<24:37,  1.51it/s]

hid torch.Size([2, 8, 512])


 10%|█         | 262/2500 [02:47<32:09,  1.16it/s]

hid torch.Size([2, 8, 512])


 11%|█         | 263/2500 [02:47<28:22,  1.31it/s]

hid torch.Size([2, 8, 512])


 11%|█         | 264/2500 [02:48<25:27,  1.46it/s]

hid torch.Size([2, 8, 512])


 11%|█         | 265/2500 [02:48<22:58,  1.62it/s]

hid torch.Size([2, 8, 512])


 11%|█         | 266/2500 [02:49<25:02,  1.49it/s]

hid torch.Size([2, 8, 512])


 11%|█         | 267/2500 [02:50<23:53,  1.56it/s]

hid torch.Size([2, 8, 512])


 11%|█         | 268/2500 [02:50<25:18,  1.47it/s]

hid torch.Size([2, 8, 512])


 11%|█         | 269/2500 [02:51<22:40,  1.64it/s]

hid torch.Size([2, 8, 512])


 11%|█         | 270/2500 [02:52<23:05,  1.61it/s]

hid torch.Size([2, 8, 512])


 11%|█         | 271/2500 [02:52<26:52,  1.38it/s]

hid torch.Size([2, 8, 512])


 11%|█         | 272/2500 [02:53<27:12,  1.36it/s]

hid torch.Size([2, 8, 512])


 11%|█         | 273/2500 [02:54<24:24,  1.52it/s]

hid torch.Size([2, 8, 512])


 11%|█         | 274/2500 [02:54<22:08,  1.68it/s]

hid torch.Size([2, 8, 512])


 11%|█         | 275/2500 [02:55<20:50,  1.78it/s]

hid torch.Size([2, 8, 512])


 11%|█         | 276/2500 [02:55<19:01,  1.95it/s]

hid torch.Size([2, 8, 512])


 11%|█         | 277/2500 [02:56<18:45,  1.97it/s]

hid torch.Size([2, 8, 512])


 11%|█         | 278/2500 [02:56<18:35,  1.99it/s]

hid torch.Size([2, 8, 512])


 11%|█         | 279/2500 [02:56<17:38,  2.10it/s]

hid torch.Size([2, 8, 512])


 11%|█         | 280/2500 [02:57<18:45,  1.97it/s]

hid torch.Size([2, 8, 512])


 11%|█         | 281/2500 [02:58<24:15,  1.52it/s]

hid torch.Size([2, 8, 512])


 11%|█▏        | 282/2500 [02:59<24:01,  1.54it/s]

hid torch.Size([2, 8, 512])


 11%|█▏        | 283/2500 [02:59<25:15,  1.46it/s]

hid torch.Size([2, 8, 512])


 11%|█▏        | 284/2500 [03:00<22:57,  1.61it/s]

hid torch.Size([2, 8, 512])


 11%|█▏        | 285/2500 [03:00<21:29,  1.72it/s]

hid torch.Size([2, 8, 512])


 11%|█▏        | 286/2500 [03:01<20:47,  1.78it/s]

hid torch.Size([2, 8, 512])


 11%|█▏        | 287/2500 [03:02<21:59,  1.68it/s]

hid torch.Size([2, 8, 512])


 12%|█▏        | 288/2500 [03:02<23:09,  1.59it/s]

hid torch.Size([2, 8, 512])


 12%|█▏        | 289/2500 [03:03<21:14,  1.73it/s]

hid torch.Size([2, 8, 512])


 12%|█▏        | 290/2500 [03:03<22:55,  1.61it/s]

hid torch.Size([2, 8, 512])


 12%|█▏        | 291/2500 [03:04<21:55,  1.68it/s]

hid torch.Size([2, 8, 512])


 12%|█▏        | 292/2500 [03:05<22:09,  1.66it/s]

hid torch.Size([2, 8, 512])


 12%|█▏        | 293/2500 [03:05<22:34,  1.63it/s]

hid torch.Size([2, 8, 512])


 12%|█▏        | 294/2500 [03:06<25:54,  1.42it/s]

hid torch.Size([2, 8, 512])


 12%|█▏        | 295/2500 [03:07<26:29,  1.39it/s]

hid torch.Size([2, 8, 512])


 12%|█▏        | 296/2500 [03:08<25:11,  1.46it/s]

hid torch.Size([2, 8, 512])


 12%|█▏        | 297/2500 [03:08<22:28,  1.63it/s]

hid torch.Size([2, 8, 512])


 12%|█▏        | 298/2500 [03:09<22:13,  1.65it/s]

hid torch.Size([2, 8, 512])


 12%|█▏        | 299/2500 [03:10<29:16,  1.25it/s]

hid torch.Size([2, 8, 512])


 12%|█▏        | 300/2500 [03:10<26:19,  1.39it/s]

hid torch.Size([2, 8, 512])


 12%|█▏        | 301/2500 [03:11<27:30,  1.33it/s]

hid torch.Size([2, 8, 512])


 12%|█▏        | 302/2500 [03:12<26:42,  1.37it/s]

hid torch.Size([2, 8, 512])


 12%|█▏        | 303/2500 [03:12<24:31,  1.49it/s]

hid torch.Size([2, 8, 512])


 12%|█▏        | 304/2500 [03:14<30:19,  1.21it/s]

hid torch.Size([2, 8, 512])


 12%|█▏        | 305/2500 [03:14<26:02,  1.40it/s]

hid torch.Size([2, 8, 512])


 12%|█▏        | 306/2500 [03:15<23:49,  1.53it/s]

hid torch.Size([2, 8, 512])


 12%|█▏        | 307/2500 [03:16<29:01,  1.26it/s]

hid torch.Size([2, 8, 512])


 12%|█▏        | 308/2500 [03:16<28:54,  1.26it/s]

hid torch.Size([2, 8, 512])


 12%|█▏        | 309/2500 [03:17<24:58,  1.46it/s]

hid torch.Size([2, 8, 512])


 12%|█▏        | 310/2500 [03:18<24:41,  1.48it/s]

hid torch.Size([2, 8, 512])


 12%|█▏        | 311/2500 [03:18<24:24,  1.49it/s]

hid torch.Size([2, 8, 512])


 12%|█▏        | 312/2500 [03:19<21:11,  1.72it/s]

hid torch.Size([2, 8, 512])


 13%|█▎        | 313/2500 [03:19<20:10,  1.81it/s]

hid torch.Size([2, 8, 512])


 13%|█▎        | 314/2500 [03:20<19:21,  1.88it/s]

hid torch.Size([2, 8, 512])


 13%|█▎        | 315/2500 [03:20<19:38,  1.85it/s]

hid torch.Size([2, 8, 512])


 13%|█▎        | 316/2500 [03:21<21:19,  1.71it/s]

hid torch.Size([2, 8, 512])


 13%|█▎        | 317/2500 [03:22<25:44,  1.41it/s]

hid torch.Size([2, 8, 512])


 13%|█▎        | 318/2500 [03:22<23:49,  1.53it/s]

hid torch.Size([2, 8, 512])


 13%|█▎        | 319/2500 [03:23<25:57,  1.40it/s]

hid torch.Size([2, 8, 512])


 13%|█▎        | 320/2500 [03:24<26:11,  1.39it/s]

hid torch.Size([2, 8, 512])


 13%|█▎        | 321/2500 [03:24<23:49,  1.52it/s]

hid torch.Size([2, 8, 512])


 13%|█▎        | 322/2500 [03:25<23:34,  1.54it/s]

hid torch.Size([2, 8, 512])


 13%|█▎        | 323/2500 [03:26<22:12,  1.63it/s]

hid torch.Size([2, 8, 512])


 13%|█▎        | 324/2500 [03:26<20:40,  1.75it/s]

hid torch.Size([2, 8, 512])


 13%|█▎        | 325/2500 [03:27<21:31,  1.68it/s]

hid torch.Size([2, 8, 512])


 13%|█▎        | 326/2500 [03:27<21:43,  1.67it/s]

hid torch.Size([2, 8, 512])


 13%|█▎        | 327/2500 [03:28<25:35,  1.42it/s]

hid torch.Size([2, 8, 512])


 13%|█▎        | 328/2500 [03:29<22:52,  1.58it/s]

hid torch.Size([2, 8, 512])


 13%|█▎        | 329/2500 [03:29<21:21,  1.69it/s]

hid torch.Size([2, 8, 512])


 13%|█▎        | 330/2500 [03:30<21:25,  1.69it/s]

hid torch.Size([2, 8, 512])


 13%|█▎        | 331/2500 [03:30<21:19,  1.69it/s]

hid torch.Size([2, 8, 512])


 13%|█▎        | 332/2500 [03:31<20:46,  1.74it/s]

hid torch.Size([2, 8, 512])


 13%|█▎        | 333/2500 [03:32<20:57,  1.72it/s]

hid torch.Size([2, 8, 512])


 13%|█▎        | 334/2500 [03:32<19:44,  1.83it/s]

hid torch.Size([2, 8, 512])


 13%|█▎        | 335/2500 [03:33<21:04,  1.71it/s]

hid torch.Size([2, 8, 512])


 13%|█▎        | 336/2500 [03:33<23:24,  1.54it/s]

hid torch.Size([2, 8, 512])


 13%|█▎        | 337/2500 [03:34<22:18,  1.62it/s]

hid torch.Size([2, 8, 512])


 14%|█▎        | 338/2500 [03:35<22:37,  1.59it/s]

hid torch.Size([2, 8, 512])


 14%|█▎        | 339/2500 [03:35<21:14,  1.70it/s]

hid torch.Size([2, 8, 512])


 14%|█▎        | 340/2500 [03:36<21:28,  1.68it/s]

hid torch.Size([2, 8, 512])


 14%|█▎        | 341/2500 [03:36<21:56,  1.64it/s]

hid torch.Size([2, 8, 512])


 14%|█▎        | 342/2500 [03:37<22:30,  1.60it/s]

hid torch.Size([2, 8, 512])


 14%|█▎        | 343/2500 [03:38<25:00,  1.44it/s]

hid torch.Size([2, 8, 512])


 14%|█▍        | 344/2500 [03:38<22:24,  1.60it/s]

hid torch.Size([2, 8, 512])


 14%|█▍        | 345/2500 [03:39<20:59,  1.71it/s]

hid torch.Size([2, 8, 512])


 14%|█▍        | 346/2500 [03:40<21:43,  1.65it/s]

hid torch.Size([2, 8, 512])


 14%|█▍        | 347/2500 [03:41<26:33,  1.35it/s]

hid torch.Size([2, 8, 512])


 14%|█▍        | 348/2500 [03:41<25:46,  1.39it/s]

hid torch.Size([2, 8, 512])


 14%|█▍        | 349/2500 [03:42<23:41,  1.51it/s]

hid torch.Size([2, 8, 512])


 14%|█▍        | 350/2500 [03:43<26:06,  1.37it/s]

hid torch.Size([2, 8, 512])


 14%|█▍        | 351/2500 [03:43<24:37,  1.45it/s]

hid torch.Size([2, 8, 512])


 14%|█▍        | 352/2500 [03:44<26:43,  1.34it/s]

hid torch.Size([2, 8, 512])


 14%|█▍        | 353/2500 [03:45<27:59,  1.28it/s]

hid torch.Size([2, 8, 512])


 14%|█▍        | 354/2500 [03:45<24:05,  1.48it/s]

hid torch.Size([2, 8, 512])


 14%|█▍        | 355/2500 [03:46<25:21,  1.41it/s]

hid torch.Size([2, 8, 512])


 14%|█▍        | 356/2500 [03:47<28:24,  1.26it/s]

hid torch.Size([2, 8, 512])


 14%|█▍        | 357/2500 [03:48<28:01,  1.27it/s]

hid torch.Size([2, 8, 512])


 14%|█▍        | 358/2500 [03:49<28:29,  1.25it/s]

hid torch.Size([2, 8, 512])


 14%|█▍        | 359/2500 [03:49<26:01,  1.37it/s]

hid torch.Size([2, 8, 512])


 14%|█▍        | 360/2500 [03:50<23:25,  1.52it/s]

hid torch.Size([2, 8, 512])


 14%|█▍        | 361/2500 [03:50<21:30,  1.66it/s]

hid torch.Size([2, 8, 512])


 14%|█▍        | 362/2500 [03:51<21:08,  1.69it/s]

hid torch.Size([2, 8, 512])


 15%|█▍        | 363/2500 [03:52<23:07,  1.54it/s]

hid torch.Size([2, 8, 512])


 15%|█▍        | 364/2500 [03:52<21:55,  1.62it/s]

hid torch.Size([2, 8, 512])


 15%|█▍        | 365/2500 [03:53<27:49,  1.28it/s]

hid torch.Size([2, 8, 512])


 15%|█▍        | 366/2500 [03:54<25:31,  1.39it/s]

hid torch.Size([2, 8, 512])


 15%|█▍        | 367/2500 [03:55<25:04,  1.42it/s]

hid torch.Size([2, 8, 512])


 15%|█▍        | 368/2500 [03:55<24:17,  1.46it/s]

hid torch.Size([2, 8, 512])


 15%|█▍        | 369/2500 [03:56<23:27,  1.51it/s]

hid torch.Size([2, 8, 512])


 15%|█▍        | 370/2500 [03:57<25:09,  1.41it/s]

hid torch.Size([2, 8, 512])


 15%|█▍        | 371/2500 [03:58<25:45,  1.38it/s]

hid torch.Size([2, 8, 512])


 15%|█▍        | 372/2500 [03:58<26:31,  1.34it/s]

hid torch.Size([2, 8, 512])


 15%|█▍        | 373/2500 [04:00<32:36,  1.09it/s]

hid torch.Size([2, 8, 512])


 15%|█▍        | 374/2500 [04:00<27:05,  1.31it/s]

hid torch.Size([2, 8, 512])


 15%|█▌        | 375/2500 [04:01<24:31,  1.44it/s]

hid torch.Size([2, 8, 512])


 15%|█▌        | 376/2500 [04:01<21:00,  1.68it/s]

hid torch.Size([2, 8, 512])


 15%|█▌        | 377/2500 [04:02<21:41,  1.63it/s]

hid torch.Size([2, 8, 512])


 15%|█▌        | 378/2500 [04:02<20:14,  1.75it/s]

hid torch.Size([2, 8, 512])


 15%|█▌        | 379/2500 [04:03<20:00,  1.77it/s]

hid torch.Size([2, 8, 512])


 15%|█▌        | 380/2500 [04:03<19:14,  1.84it/s]

hid torch.Size([2, 8, 512])


 15%|█▌        | 381/2500 [04:04<23:37,  1.49it/s]

hid torch.Size([2, 8, 512])


 15%|█▌        | 382/2500 [04:05<27:04,  1.30it/s]

hid torch.Size([2, 8, 512])


 15%|█▌        | 383/2500 [04:06<27:31,  1.28it/s]

hid torch.Size([2, 8, 512])


 15%|█▌        | 384/2500 [04:06<25:11,  1.40it/s]

hid torch.Size([2, 8, 512])


 15%|█▌        | 385/2500 [04:07<25:28,  1.38it/s]

hid torch.Size([2, 8, 512])


 15%|█▌        | 386/2500 [04:08<23:50,  1.48it/s]

hid torch.Size([2, 8, 512])


 15%|█▌        | 387/2500 [04:08<21:49,  1.61it/s]

hid torch.Size([2, 8, 512])


 16%|█▌        | 388/2500 [04:09<24:15,  1.45it/s]

hid torch.Size([2, 8, 512])


 16%|█▌        | 389/2500 [04:09<20:50,  1.69it/s]

hid torch.Size([2, 8, 512])


 16%|█▌        | 390/2500 [04:10<20:33,  1.71it/s]

hid torch.Size([2, 8, 512])


 16%|█▌        | 391/2500 [04:11<20:10,  1.74it/s]

hid torch.Size([2, 8, 512])


 16%|█▌        | 392/2500 [04:11<20:14,  1.74it/s]

hid torch.Size([2, 8, 512])


 16%|█▌        | 393/2500 [04:12<19:53,  1.77it/s]

hid torch.Size([2, 8, 512])


 16%|█▌        | 394/2500 [04:12<21:54,  1.60it/s]

hid torch.Size([2, 8, 512])


 16%|█▌        | 395/2500 [04:14<26:34,  1.32it/s]

hid torch.Size([2, 8, 512])


 16%|█▌        | 396/2500 [04:14<23:51,  1.47it/s]

hid torch.Size([2, 8, 512])


 16%|█▌        | 397/2500 [04:15<24:04,  1.46it/s]

hid torch.Size([2, 8, 512])


 16%|█▌        | 398/2500 [04:15<24:05,  1.45it/s]

hid torch.Size([2, 8, 512])


 16%|█▌        | 399/2500 [04:16<22:58,  1.52it/s]

hid torch.Size([2, 8, 512])


 16%|█▌        | 400/2500 [04:17<23:03,  1.52it/s]

hid torch.Size([2, 8, 512])


 16%|█▌        | 401/2500 [04:17<22:33,  1.55it/s]

hid torch.Size([2, 8, 512])


 16%|█▌        | 402/2500 [04:18<26:14,  1.33it/s]

hid torch.Size([2, 8, 512])


 16%|█▌        | 403/2500 [04:19<26:02,  1.34it/s]

hid torch.Size([2, 8, 512])


 16%|█▌        | 404/2500 [04:20<24:23,  1.43it/s]

hid torch.Size([2, 8, 512])


 16%|█▌        | 405/2500 [04:21<28:06,  1.24it/s]

hid torch.Size([2, 8, 512])


 16%|█▌        | 406/2500 [04:21<25:37,  1.36it/s]

hid torch.Size([2, 8, 512])


 16%|█▋        | 407/2500 [04:22<21:28,  1.62it/s]

hid torch.Size([2, 8, 512])


 16%|█▋        | 408/2500 [04:22<23:41,  1.47it/s]

hid torch.Size([2, 8, 512])


 16%|█▋        | 409/2500 [04:23<24:58,  1.39it/s]

hid torch.Size([2, 8, 512])


 16%|█▋        | 410/2500 [04:24<23:40,  1.47it/s]

hid torch.Size([2, 8, 512])


 16%|█▋        | 411/2500 [04:24<22:08,  1.57it/s]

hid torch.Size([2, 8, 512])


 16%|█▋        | 412/2500 [04:25<20:25,  1.70it/s]

hid torch.Size([2, 8, 512])


 17%|█▋        | 413/2500 [04:25<20:30,  1.70it/s]

hid torch.Size([2, 8, 512])


 17%|█▋        | 414/2500 [04:26<20:47,  1.67it/s]

hid torch.Size([2, 8, 512])


 17%|█▋        | 415/2500 [04:27<24:58,  1.39it/s]

hid torch.Size([2, 8, 512])


 17%|█▋        | 416/2500 [04:27<20:45,  1.67it/s]

hid torch.Size([2, 8, 512])


 17%|█▋        | 417/2500 [04:28<23:43,  1.46it/s]

hid torch.Size([2, 8, 512])


 17%|█▋        | 418/2500 [04:29<23:14,  1.49it/s]

hid torch.Size([2, 8, 512])


 17%|█▋        | 419/2500 [04:30<24:23,  1.42it/s]

hid torch.Size([2, 8, 512])


 17%|█▋        | 420/2500 [04:31<30:24,  1.14it/s]

hid torch.Size([2, 8, 512])


 17%|█▋        | 421/2500 [04:32<31:39,  1.09it/s]

hid torch.Size([2, 8, 512])


 17%|█▋        | 422/2500 [04:32<26:50,  1.29it/s]

hid torch.Size([2, 8, 512])


 17%|█▋        | 423/2500 [04:33<24:41,  1.40it/s]

hid torch.Size([2, 8, 512])


 17%|█▋        | 424/2500 [04:33<22:11,  1.56it/s]

hid torch.Size([2, 8, 512])


 17%|█▋        | 425/2500 [04:34<19:01,  1.82it/s]

hid torch.Size([2, 8, 512])


 17%|█▋        | 426/2500 [04:34<17:27,  1.98it/s]

hid torch.Size([2, 8, 512])


 17%|█▋        | 427/2500 [04:35<19:49,  1.74it/s]

hid torch.Size([2, 8, 512])


 17%|█▋        | 428/2500 [04:36<21:56,  1.57it/s]

hid torch.Size([2, 8, 512])


 17%|█▋        | 429/2500 [04:36<20:15,  1.70it/s]

hid torch.Size([2, 8, 512])


 17%|█▋        | 430/2500 [04:37<19:16,  1.79it/s]

hid torch.Size([2, 8, 512])


 17%|█▋        | 431/2500 [04:38<24:03,  1.43it/s]

hid torch.Size([2, 8, 512])


 17%|█▋        | 432/2500 [04:38<21:59,  1.57it/s]

hid torch.Size([2, 8, 512])


 17%|█▋        | 433/2500 [04:39<23:26,  1.47it/s]

hid torch.Size([2, 8, 512])


 17%|█▋        | 434/2500 [04:39<21:18,  1.62it/s]

hid torch.Size([2, 8, 512])


 17%|█▋        | 435/2500 [04:40<24:31,  1.40it/s]

hid torch.Size([2, 8, 512])


 17%|█▋        | 436/2500 [04:41<20:54,  1.65it/s]

hid torch.Size([2, 8, 512])


 17%|█▋        | 437/2500 [04:41<21:02,  1.63it/s]

hid torch.Size([2, 8, 512])


 18%|█▊        | 438/2500 [04:42<20:13,  1.70it/s]

hid torch.Size([2, 8, 512])


 18%|█▊        | 439/2500 [04:42<19:03,  1.80it/s]

hid torch.Size([2, 8, 512])


 18%|█▊        | 440/2500 [04:43<17:24,  1.97it/s]

hid torch.Size([2, 8, 512])


 18%|█▊        | 441/2500 [04:43<18:04,  1.90it/s]

hid torch.Size([2, 8, 512])


 18%|█▊        | 442/2500 [04:44<17:16,  1.99it/s]

hid torch.Size([2, 8, 512])


 18%|█▊        | 443/2500 [04:44<19:22,  1.77it/s]

hid torch.Size([2, 8, 512])


 18%|█▊        | 444/2500 [04:45<23:25,  1.46it/s]

hid torch.Size([2, 8, 512])


 18%|█▊        | 445/2500 [04:46<20:14,  1.69it/s]

hid torch.Size([2, 8, 512])


 18%|█▊        | 446/2500 [04:46<20:07,  1.70it/s]

hid torch.Size([2, 8, 512])


 18%|█▊        | 447/2500 [04:47<18:05,  1.89it/s]

hid torch.Size([2, 8, 512])


 18%|█▊        | 448/2500 [04:47<17:10,  1.99it/s]

hid torch.Size([2, 8, 512])


 18%|█▊        | 449/2500 [04:48<18:09,  1.88it/s]

hid torch.Size([2, 8, 512])


 18%|█▊        | 450/2500 [04:48<17:41,  1.93it/s]

hid torch.Size([2, 8, 512])


 18%|█▊        | 451/2500 [04:50<25:26,  1.34it/s]

hid torch.Size([2, 8, 512])


 18%|█▊        | 452/2500 [04:50<25:13,  1.35it/s]

hid torch.Size([2, 8, 512])


 18%|█▊        | 453/2500 [04:51<23:48,  1.43it/s]

hid torch.Size([2, 8, 512])


 18%|█▊        | 454/2500 [04:51<22:33,  1.51it/s]

hid torch.Size([2, 8, 512])


 18%|█▊        | 455/2500 [04:52<24:43,  1.38it/s]

hid torch.Size([2, 8, 512])


 18%|█▊        | 456/2500 [04:53<23:49,  1.43it/s]

hid torch.Size([2, 8, 512])


 18%|█▊        | 457/2500 [04:54<22:21,  1.52it/s]

hid torch.Size([2, 8, 512])


 18%|█▊        | 458/2500 [04:54<20:46,  1.64it/s]

hid torch.Size([2, 8, 512])


 18%|█▊        | 459/2500 [04:55<23:29,  1.45it/s]

hid torch.Size([2, 8, 512])


 18%|█▊        | 460/2500 [04:55<21:33,  1.58it/s]

hid torch.Size([2, 8, 512])


 18%|█▊        | 461/2500 [04:56<19:49,  1.71it/s]

hid torch.Size([2, 8, 512])


 18%|█▊        | 462/2500 [04:56<19:49,  1.71it/s]

hid torch.Size([2, 8, 512])


 19%|█▊        | 463/2500 [04:57<21:02,  1.61it/s]

hid torch.Size([2, 8, 512])


 19%|█▊        | 464/2500 [04:58<19:48,  1.71it/s]

hid torch.Size([2, 8, 512])


 19%|█▊        | 465/2500 [04:58<20:36,  1.65it/s]

hid torch.Size([2, 8, 512])


 19%|█▊        | 466/2500 [04:59<21:46,  1.56it/s]

hid torch.Size([2, 8, 512])


 19%|█▊        | 467/2500 [05:00<20:13,  1.68it/s]

hid torch.Size([2, 8, 512])


 19%|█▊        | 468/2500 [05:00<21:58,  1.54it/s]

hid torch.Size([2, 8, 512])


 19%|█▉        | 469/2500 [05:01<20:50,  1.62it/s]

hid torch.Size([2, 8, 512])


 19%|█▉        | 470/2500 [05:01<19:17,  1.75it/s]

hid torch.Size([2, 8, 512])


 19%|█▉        | 471/2500 [05:02<21:24,  1.58it/s]

hid torch.Size([2, 8, 512])


 19%|█▉        | 472/2500 [05:03<23:24,  1.44it/s]

hid torch.Size([2, 8, 512])


 19%|█▉        | 473/2500 [05:03<20:48,  1.62it/s]

hid torch.Size([2, 8, 512])


 19%|█▉        | 474/2500 [05:04<17:58,  1.88it/s]

hid torch.Size([2, 8, 512])


 19%|█▉        | 475/2500 [05:05<27:34,  1.22it/s]

hid torch.Size([2, 8, 512])


 19%|█▉        | 476/2500 [05:06<27:49,  1.21it/s]

hid torch.Size([2, 8, 512])


 19%|█▉        | 477/2500 [05:07<28:30,  1.18it/s]

hid torch.Size([2, 8, 512])


 19%|█▉        | 478/2500 [05:08<27:23,  1.23it/s]

hid torch.Size([2, 8, 512])


 19%|█▉        | 479/2500 [05:08<23:30,  1.43it/s]

hid torch.Size([2, 8, 512])


 19%|█▉        | 480/2500 [05:09<25:04,  1.34it/s]

hid torch.Size([2, 8, 512])


 19%|█▉        | 481/2500 [05:09<21:04,  1.60it/s]

hid torch.Size([2, 8, 512])


 19%|█▉        | 482/2500 [05:10<20:48,  1.62it/s]

hid torch.Size([2, 8, 512])


 19%|█▉        | 483/2500 [05:10<20:52,  1.61it/s]

hid torch.Size([2, 8, 512])


 19%|█▉        | 484/2500 [05:11<20:42,  1.62it/s]

hid torch.Size([2, 8, 512])


 19%|█▉        | 485/2500 [05:12<18:40,  1.80it/s]

hid torch.Size([2, 8, 512])


 19%|█▉        | 486/2500 [05:12<18:13,  1.84it/s]

hid torch.Size([2, 8, 512])


 19%|█▉        | 487/2500 [05:13<20:16,  1.65it/s]

hid torch.Size([2, 8, 512])


 20%|█▉        | 488/2500 [05:13<20:26,  1.64it/s]

hid torch.Size([2, 8, 512])


 20%|█▉        | 489/2500 [05:14<24:35,  1.36it/s]

hid torch.Size([2, 8, 512])


 20%|█▉        | 490/2500 [05:15<23:33,  1.42it/s]

hid torch.Size([2, 8, 512])


 20%|█▉        | 491/2500 [05:15<20:17,  1.65it/s]

hid torch.Size([2, 8, 512])


 20%|█▉        | 492/2500 [05:16<19:55,  1.68it/s]

hid torch.Size([2, 8, 512])


 20%|█▉        | 493/2500 [05:16<17:57,  1.86it/s]

hid torch.Size([2, 8, 512])


 20%|█▉        | 494/2500 [05:17<19:41,  1.70it/s]

hid torch.Size([2, 8, 512])


 20%|█▉        | 495/2500 [05:18<18:29,  1.81it/s]

hid torch.Size([2, 8, 512])


 20%|█▉        | 496/2500 [05:18<18:32,  1.80it/s]

hid torch.Size([2, 8, 512])


 20%|█▉        | 497/2500 [05:19<17:39,  1.89it/s]

hid torch.Size([2, 8, 512])


 20%|█▉        | 498/2500 [05:19<19:22,  1.72it/s]

hid torch.Size([2, 8, 512])


 20%|█▉        | 499/2500 [05:20<20:26,  1.63it/s]

hid torch.Size([2, 8, 512])


 20%|██        | 500/2500 [05:21<19:47,  1.68it/s]

hid torch.Size([2, 8, 512])


 20%|██        | 501/2500 [05:21<22:50,  1.46it/s]

hid torch.Size([2, 8, 512])


 20%|██        | 502/2500 [05:22<24:50,  1.34it/s]

hid torch.Size([2, 8, 512])


 20%|██        | 503/2500 [05:23<22:42,  1.47it/s]

hid torch.Size([2, 8, 512])


 20%|██        | 504/2500 [05:23<21:08,  1.57it/s]

hid torch.Size([2, 8, 512])


 20%|██        | 505/2500 [05:24<23:23,  1.42it/s]

hid torch.Size([2, 8, 512])


 20%|██        | 506/2500 [05:25<22:17,  1.49it/s]

hid torch.Size([2, 8, 512])


 20%|██        | 507/2500 [05:25<21:47,  1.52it/s]

hid torch.Size([2, 8, 512])


 20%|██        | 508/2500 [05:26<22:52,  1.45it/s]

hid torch.Size([2, 8, 512])


 20%|██        | 509/2500 [05:27<23:22,  1.42it/s]

hid torch.Size([2, 8, 512])


 20%|██        | 510/2500 [05:28<22:28,  1.48it/s]

hid torch.Size([2, 8, 512])


 20%|██        | 511/2500 [05:28<24:16,  1.37it/s]

hid torch.Size([2, 8, 512])


 20%|██        | 512/2500 [05:29<25:45,  1.29it/s]

hid torch.Size([2, 8, 512])


 21%|██        | 513/2500 [05:30<25:18,  1.31it/s]

hid torch.Size([2, 8, 512])


 21%|██        | 514/2500 [05:31<23:25,  1.41it/s]

hid torch.Size([2, 8, 512])


 21%|██        | 515/2500 [05:31<22:28,  1.47it/s]

hid torch.Size([2, 8, 512])


 21%|██        | 516/2500 [05:32<18:54,  1.75it/s]

hid torch.Size([2, 8, 512])


 21%|██        | 517/2500 [05:32<19:40,  1.68it/s]

hid torch.Size([2, 8, 512])


 21%|██        | 518/2500 [05:33<19:49,  1.67it/s]

hid torch.Size([2, 8, 512])


 21%|██        | 519/2500 [05:33<18:26,  1.79it/s]

hid torch.Size([2, 8, 512])


 21%|██        | 520/2500 [05:34<18:19,  1.80it/s]

hid torch.Size([2, 8, 512])


 21%|██        | 521/2500 [05:34<18:14,  1.81it/s]

hid torch.Size([2, 8, 512])


 21%|██        | 522/2500 [05:35<18:19,  1.80it/s]

hid torch.Size([2, 8, 512])


 21%|██        | 523/2500 [05:36<19:22,  1.70it/s]

hid torch.Size([2, 8, 512])


 21%|██        | 524/2500 [05:36<18:16,  1.80it/s]

hid torch.Size([2, 8, 512])


 21%|██        | 525/2500 [05:37<21:38,  1.52it/s]

hid torch.Size([2, 8, 512])


 21%|██        | 526/2500 [05:38<21:03,  1.56it/s]

hid torch.Size([2, 8, 512])


 21%|██        | 527/2500 [05:38<20:19,  1.62it/s]

hid torch.Size([2, 8, 512])


 21%|██        | 528/2500 [05:39<20:23,  1.61it/s]

hid torch.Size([2, 8, 512])


 21%|██        | 529/2500 [05:40<24:17,  1.35it/s]

hid torch.Size([2, 8, 512])


 21%|██        | 530/2500 [05:40<22:01,  1.49it/s]

hid torch.Size([2, 8, 512])


 21%|██        | 531/2500 [05:41<20:35,  1.59it/s]

hid torch.Size([2, 8, 512])


 21%|██▏       | 532/2500 [05:41<19:54,  1.65it/s]

hid torch.Size([2, 8, 512])


 21%|██▏       | 533/2500 [05:42<20:31,  1.60it/s]

hid torch.Size([2, 8, 512])


 21%|██▏       | 534/2500 [05:43<23:06,  1.42it/s]

hid torch.Size([2, 8, 512])


 21%|██▏       | 535/2500 [05:43<21:11,  1.55it/s]

hid torch.Size([2, 8, 512])


 21%|██▏       | 536/2500 [05:44<20:50,  1.57it/s]

hid torch.Size([2, 8, 512])


 21%|██▏       | 537/2500 [05:45<24:24,  1.34it/s]

hid torch.Size([2, 8, 512])


 22%|██▏       | 538/2500 [05:46<23:13,  1.41it/s]

hid torch.Size([2, 8, 512])


 22%|██▏       | 539/2500 [05:47<24:27,  1.34it/s]

hid torch.Size([2, 8, 512])


 22%|██▏       | 540/2500 [05:47<23:48,  1.37it/s]

hid torch.Size([2, 8, 512])


 22%|██▏       | 541/2500 [05:48<21:28,  1.52it/s]

hid torch.Size([2, 8, 512])


 22%|██▏       | 542/2500 [05:48<21:10,  1.54it/s]

hid torch.Size([2, 8, 512])


 22%|██▏       | 543/2500 [05:49<20:55,  1.56it/s]

hid torch.Size([2, 8, 512])


 22%|██▏       | 544/2500 [05:49<18:55,  1.72it/s]

hid torch.Size([2, 8, 512])


 22%|██▏       | 545/2500 [05:50<18:23,  1.77it/s]

hid torch.Size([2, 8, 512])


 22%|██▏       | 546/2500 [05:50<16:26,  1.98it/s]

hid torch.Size([2, 8, 512])


 22%|██▏       | 547/2500 [05:51<21:32,  1.51it/s]

hid torch.Size([2, 8, 512])


 22%|██▏       | 548/2500 [05:52<21:55,  1.48it/s]

hid torch.Size([2, 8, 512])


 22%|██▏       | 549/2500 [05:53<20:15,  1.61it/s]

hid torch.Size([2, 8, 512])


 22%|██▏       | 550/2500 [05:53<19:27,  1.67it/s]

hid torch.Size([2, 8, 512])


 22%|██▏       | 551/2500 [05:54<20:58,  1.55it/s]

hid torch.Size([2, 8, 512])


 22%|██▏       | 552/2500 [05:54<20:40,  1.57it/s]

hid torch.Size([2, 8, 512])


 22%|██▏       | 553/2500 [05:55<22:50,  1.42it/s]

hid torch.Size([2, 8, 512])


 22%|██▏       | 554/2500 [05:56<22:53,  1.42it/s]

hid torch.Size([2, 8, 512])


 22%|██▏       | 555/2500 [05:57<25:37,  1.26it/s]

hid torch.Size([2, 8, 512])


 22%|██▏       | 556/2500 [05:58<24:53,  1.30it/s]

hid torch.Size([2, 8, 512])


 22%|██▏       | 557/2500 [05:58<22:34,  1.43it/s]

hid torch.Size([2, 8, 512])


 22%|██▏       | 558/2500 [05:59<19:27,  1.66it/s]

hid torch.Size([2, 8, 512])


 22%|██▏       | 559/2500 [06:00<31:05,  1.04it/s]

In [None]:
#NOTE: load model from file
#model.load_state_dict(torch.load('model.pt'))

test_loss = evaluate(model, test_iter, criterion)

print(f'\tTest Loss: {test_loss:.3f}\tTest PPL: {np.exp(test_loss):7.3f}')

In [None]:
#clean mem
del model
del train_iter
del valid_iter
del test_iter
torch.cuda.empty_cache()

BLEU added

In [None]:
from torchtext.data.metrics import bleu_score

def calculate_bleu(model, iterator, src_vocab, trg_vocab, device):
    model.eval()
    trgs = []
    pred_trgs = []

    with torch.no_grad():
        for src, trg in iterator:
            src = src.to(device)
            trg = trg.to(device)

            output = model(src, trg, 0)  # no teacher forcing
            output_dim = output.shape[-1]

            output = output.argmax(-1)  # [trg_len, batch_size]

            for i in range(trg.shape[1]):
                trg_sent = [trg_vocab.lookup_tokens(trg[1:, i])]
                pred_sent = trg_vocab.lookup_tokens(output[1:, i])

                # remove padding and EOS tokens from predictions
                pred_sent = [token for token in pred_sent if token not in ['<pad>', '<eos>']]

                trgs.append(trg_sent)
                pred_trgs.append(pred_sent)

    return bleu_score(pred_trgs, trgs)

In [None]:
N_EPOCHS = 5
CLIP = 1

best_valid_loss = float('inf')
best_valid_bleu = 0.0

for epoch in range(N_EPOCHS):

    train_loss = train(model, train_iter, optimizer, criterion, CLIP, device)
    valid_loss = evaluate(model, valid_iter, criterion, device)
    valid_bleu = calculate_bleu(model, valid_iter, SRC.vocab, TRG.vocab, device)

    # Check both BLEU and validation loss for model saving (you can pick either)
    if valid_bleu > best_valid_bleu:
        best_valid_bleu = valid_bleu
        torch.save(model.state_dict(), 'model_best_bleu.pt')

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'model_best_loss.pt')

    print(f'Epoch: {epoch+1:02}')
    print(f'\tTrain Loss: {train_loss:.3f}\tTrain PPL: {np.exp(train_loss):7.3f}')
    print(f'\tValidation Loss: {valid_loss:.3f}\tValidation PPL: {np.exp(valid_loss):7.3f}')
    print(f'\tValidation BLEU Score: {valid_bleu*100:.2f}')

In [None]:
#NOTE: load model from file
#model.load_state_dict(torch.load('model.pt'))

test_loss = evaluate(model, test_iter, criterion)

print(f'\tTest Loss: {test_loss:.3f}\tTest PPL: {np.exp(test_loss):7.3f}')

In [None]:
#clean mem
del model
del train_iter
del valid_iter
del test_iter
torch.cuda.empty_cache()