<a href="https://colab.research.google.com/github/Saputoa21/Machine-Translation/blob/main/seq2seq_NMT_MTMA2025s_v2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Seq2seq NMT with RNN


[Neural Machine Translation by Jointly Learning to Align and Translate](https://arxiv.org/abs/1409.0473)

**NOTE:**

-  use clean bpe data
-  use a piece of training data during coding or low in credits

You have to implement:

- Encoder
- Attention (Bahdanau)
- training loop
- extra: BLEU model selection

Goal:

- Loss in training, validation and test




In [1]:
!pip install torch==2.3.0
!pip install torchtext==0.18

Collecting torch==2.3.0
  Downloading torch-2.3.0-cp311-cp311-manylinux1_x86_64.whl.metadata (26 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch==2.3.0)
  Downloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch==2.3.0)
  Downloading nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch==2.3.0)
  Downloading nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch==2.3.0)
  Downloading nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.1.3.1 (from torch==2.3.0)
  Downloading nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.0.2.54 (from torch==2.3.0)
  Downloading nvidia_cufft_cu12-11.0.2.54-py3-none-manylin

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
import random
import time

In [3]:
#if you dont have bpe data use sacremoese tokenizer
#!pip install sacremoses

In [4]:
#which libraries are we using!!?
!pip freeze > requirements.txt

In [5]:
!cat requirements.txt

absl-py==1.4.0
accelerate==1.6.0
aiohappyeyeballs==2.6.1
aiohttp==3.11.15
aiosignal==1.3.2
alabaster==1.0.0
albucore==0.0.24
albumentations==2.0.6
ale-py==0.11.0
altair==5.5.0
annotated-types==0.7.0
antlr4-python3-runtime==4.9.3
anyio==4.9.0
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
array_record==0.7.2
arviz==0.21.0
astropy==7.0.2
astropy-iers-data==0.2025.5.12.0.38.29
astunparse==1.6.3
atpublic==5.1
attrs==25.3.0
audioread==3.0.1
autograd==1.8.0
babel==2.17.0
backcall==0.2.0
backports.tarfile==1.2.0
beautifulsoup4==4.13.4
betterproto==2.0.0b6
bigframes==2.4.0
bigquery-magics==0.9.0
bleach==6.2.0
blinker==1.9.0
blis==1.3.0
blobfile==3.0.0
blosc2==3.3.2
bokeh==3.7.3
Bottleneck==1.4.2
bqplot==0.12.44
branca==0.8.1
build==1.2.2.post1
CacheControl==0.14.3
cachetools==5.5.2
catalogue==2.0.10
certifi==2025.4.26
cffi==1.17.1
chardet==5.2.0
charset-normalizer==3.4.2
chex==0.1.89
clarabel==0.10.0
click==8.2.0
cloudpathlib==0.21.0
cloudpickle==3.1.1
cmake==3.31.6
cmdstanpy==1.2.5
colorcet

In [6]:
SEED = 42 #to reproduce the traning

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

In [7]:
import torchtext
dir(torchtext)

['_CACHE_DIR',
 '_TEXT_BUCKET',
 '_TORCHTEXT_DEPRECATION_MSG',
 '_WARN',
 '__all__',
 '__builtins__',
 '__cached__',
 '__doc__',
 '__file__',
 '__loader__',
 '__name__',
 '__package__',
 '__path__',
 '__spec__',
 '__version__',
 '_extension',
 '_get_torch_home',
 '_internal',
 '_torchtext',
 'git_version',
 'os',
 'version']

In [8]:
import torchtext
import torch
from torchtext.data.utils import get_tokenizer
from collections import Counter
from torchtext.vocab import vocab
#from torchtext.utils import download_from_url, extract_archive
import io

#0=en 1=de
# soure and target data
#NOTE: USE clean bpe data!

#for this task I used the clean bpe files (train 500k, test 500, dev 500)
#sorting the path to the files en-de
train_filepaths = ['train.en-de.clean.bpe.en', 'train.en-de.clean.bpe.de']
val_filepaths = ['dev.en-de.clean.bpe.en', 'dev.en-de.clean.bpe.de']
test_filepaths = ['test.en-de.clean.bpe.en', 'test.en-de.clean.bpe.de']

#TODO use clean bpe tokenized data!!!

#de_tokenizer = get_tokenizer('moses', language='de') for cases without bpe data
#en_tokenizer = get_tokenizer('moses', language='en')

de_tokenizer = None #None as there are tokenized data
en_tokenizer = None

def build_vocab(filepath, tokenizer=None):
  counter = Counter()
  with io.open(filepath, encoding="utf8") as f: #go to a file
    for string_ in f: #for each line in a file
      #counter.update(tokenizer(string_))
      counter.update(string_.split()) #split each line by space, i.e. creating tokens. Counter counts the tokens to create a database of the vocab
  return vocab(counter, specials=['<unk>', '<pad>', '<bos>', '<eos>']) #vocav from torchtext containing the tokens and their counts, as well as specioal ones

#Vocab
en_vocab = build_vocab(train_filepaths[0], en_tokenizer) #create a vocab
de_vocab = build_vocab(train_filepaths[1], de_tokenizer)

print(dir(en_vocab))
en_vocab.set_default_index(en_vocab['<unk>']) #set <unk> as a default token
de_vocab.set_default_index(de_vocab['<unk>'])

def data_process(filepaths):
  raw_en_iter = iter(io.open(filepaths[0], encoding="utf8"))
  raw_de_iter = iter(io.open(filepaths[1], encoding="utf8"))
  data = []
  for (raw_en, raw_de) in zip(raw_en_iter, raw_de_iter): #for each sentence in the source and target
    en_tensor_ = torch.tensor([en_vocab[token] for token in raw_en.split()], #en_tokenizer(raw_en) #transforms tokens into tensors
                            dtype=torch.long)
    de_tensor_ = torch.tensor([de_vocab[token] for token in raw_de.split()], #de_tokenizer(raw_de)
                            dtype=torch.long)
    data.append((en_tensor_, de_tensor_)) #append to a list
  return data

#pre-process
train_data = data_process(train_filepaths)
val_data = data_process(val_filepaths)
test_data = data_process(test_filepaths)



['T_destination', '__annotations__', '__call__', '__class__', '__contains__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattr__', '__getattribute__', '__getitem__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__jit_unused_properties__', '__le__', '__len__', '__lt__', '__module__', '__ne__', '__new__', '__prepare_scriptable__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__setstate__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', '_apply', '_backward_hooks', '_backward_pre_hooks', '_buffers', '_call_impl', '_compiled_call_impl', '_forward_hooks', '_forward_hooks_always_called', '_forward_hooks_with_kwargs', '_forward_pre_hooks', '_forward_pre_hooks_with_kwargs', '_get_backward_hooks', '_get_backward_pre_hooks', '_get_name', '_is_full_backward_hook', '_load_from_state_dict', '_load_state_dict_post_hooks', '_load_state_dict_pre_hooks', '_maybe_warn_non_full_backward_hook', '_modules', '

In [14]:
#NOTE: if you are low on credits or testing only use a piece of the data e.g. 20K segments
train_data = train_data[:1000]

In [15]:
len(train_data)

1000

In [16]:
len(val_data)

466

In [17]:
len(test_data)

467

In [18]:
len(en_vocab)

36926

In [19]:
len(de_vocab)

42161

Define the device.

In [20]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


Create the iterators.

In [21]:
BATCH_SIZE = 10
PAD_IDX = de_vocab['<pad>'] #find the indecies of the special tokens
BOS_IDX = de_vocab['<bos>']
EOS_IDX = de_vocab['<eos>']

from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

def generate_batch(data_batch):
    en_batch, de_batch = [], []
    for (en_item, de_item) in data_batch: #each sentences in a batch
        de_batch.append(torch.cat([torch.tensor([BOS_IDX]), de_item, torch.tensor([EOS_IDX])], dim=0)) #add a source sentence from the beginning (BOS) till the end (EOS)
        en_batch.append(torch.cat([torch.tensor([BOS_IDX]), en_item, torch.tensor([EOS_IDX])], dim=0))
    de_batch = pad_sequence(de_batch, padding_value=PAD_IDX, batch_first=True) #add padding to sentences if they are of different langth
    en_batch = pad_sequence(en_batch, padding_value=PAD_IDX, batch_first=True)
    return en_batch, de_batch

train_iter = DataLoader(train_data, batch_size=BATCH_SIZE,
                        shuffle=True, collate_fn=generate_batch) #shuffle samples after each epoch
valid_iter = DataLoader(val_data, batch_size=BATCH_SIZE,
                        shuffle=False, collate_fn=generate_batch)
test_iter = DataLoader(test_data, batch_size=BATCH_SIZE,
                       shuffle=False, collate_fn=generate_batch)

## Building the Seq2Seq Model

### Encoder




In [22]:
class Encoder(nn.Module):
    def __init__(self, input_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout):
        super().__init__()

        self.embedding = nn.Embedding(input_dim, emb_dim) #size of vocab (50k from clean bpe data), embedding dimensionality (256 from below)

        #[YOUR CODE] GRU(embeding size, encoder hidden size) NOTE: bidirectional batch_first
        # parameters for GRU in Pytorch
        # torch.nn.GRU(input_size, hidden_size, num_layers=1, bias=True, batch_first=False, dropout=0.0, bidirectional=False, device=None, dtype=None)
        self.rnn = nn.GRU(input_size=emb_dim, hidden_size=enc_hid_dim, batch_first=True, bidirectional=True)

        #[YOUR CODE] linear(encoder hidden size * 2, decoder hidden size)
        # times 2 as we have a bidirectional model
        self.fc = nn.Linear(enc_hid_dim * 2, dec_hid_dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, src):

        #[B=8, S= longest sentence]
        embedded = self.embedding(src) #embed the source

        #[B, S, E] print(embedded.size())
        # the matrix grow by the additional dimension of the embedding size (256)

        embedded = self.dropout(embedded) #regularize the embeddings

        #[B, S, E]
        #[YOUR CODE] rrn(embeddings)
        outputs, hidden = self.rnn(embedded)

        #[B, S, H*2]

        #h[n layers * num directions, batch size, hid dim]

        #[forward_1, backward_1, forward_2, backward_2, ...]

        #[-2, :, : ] last state forward RNN
        #[-1, :, : ] last state backward RNN
        print('hid', hidden.size())

        #[YOUR CODE] last state forward RNN
        h1 = hidden[-2, :, : ]

        #[YOUR CODE] last state backward RNN
        h2 = hidden[-1, :, : ]

        #https://pytorch.org/docs/main/generated/torch.cat.html

        #[YOUR CODE] concatenate h1 amd h2 on seq dim
        h_cat = torch.cat([h1, h2], dim=1)

        #[YOUR CODE] tanh(linear(hidden_cat))
        hidden = torch.tanh(self.fc(h_cat))

        #[B, S, H*2]

        return outputs, hidden # outputs = vector for each sentence, hidden = concatenated vector for forward and backward

# Attention

## Luong Attention




Formula: score(target hidden state, each source hidden stare) = target hidden state transposed * W * eachsource hidden state)

In [23]:
class LuongAttention(nn.Module):
    def __init__(self, enc_hid_dim, dec_hid_dim):
        super().__init__()

        self.attn = nn.Linear((enc_hid_dim * 2) + dec_hid_dim, dec_hid_dim)
        self.v = nn.Linear(dec_hid_dim, 1, bias = False)

    def forward(self, hidden, encoder_outputs): #keys, query  hidden = target

        #[batch size, dec hid dim]
        #[src len, batch size, enc hid dim * 2]

        batch_size = encoder_outputs.shape[0] #to see the size
        src_len = encoder_outputs.shape[1]

        #x times decoder hidden state for the size of the src_len
        hidden = hidden.unsqueeze(1).repeat(1, src_len, 1) #repatn the hidden state over source

        #[batch size, src len, dec hid dim]

        scores = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim = 2))) #concateneted in the embedding dimension
        # attn = linear layer
        # tanh = non-linear

        #[batch size, src len, dec hid dim]

        attention = self.v(scores).squeeze(2) # v is W in the formula

        #[batch size, src len]

        return F.softmax(attention, dim=1)

Formula: score(target hidden state, each source hidden stare) = Va transposed * tanh(W1 * target hidden state + W2 * each source hidden state)

In [24]:
class BahdanauAttention(nn.Module):
    def __init__(self, enc_hid_dim, dec_hid_dim):
        super(BahdanauAttention, self).__init__()
        self.Wa = nn.Linear(dec_hid_dim, dec_hid_dim, bias=False)
        self.Ua = nn.Linear(enc_hid_dim * 2, dec_hid_dim, bias=False)
        self.Va = nn.Linear(dec_hid_dim, 1, bias=False)

    def forward(self, hidden, encoder_outputs): #keys, query
        #[YOUR CODE]
        batch_size = encoder_outputs.shape(0)
        src_len = encoder_outputs.shape(1)

        #x times decoder hidden state for src_len
        #[YOUR CODE]
        hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)  # pytorch - Returns a new tensor with a dimension of size one inserted at the specified position.

        """
        >>> x = torch.tensor([1, 2, 3, 4])
        >>> torch.unsqueeze(x, 0)
        tensor([[ 1,  2,  3,  4]])
        >>> torch.unsqueeze(x, 1)
        tensor([[ 1],
                [ 2],
                [ 3],
                [ 4]])
        """

        #[YOUR CODE] Va(tanh(Wa(hidden) + Ua(encoder outputs)))
        scores = self.Va(torch.tanh(self.Wa(hidden) + self.Ua(encoder_outputs)))

        scores = scores.squeeze(2) #Returns a tensor with all specified dimensions of input of size 1 removed.

        """
        For example, if input is of shape: (A×1×B×C×1×D) then the input.squeeze() will be of shape: (A×B×C×D).
        When dim is given, a squeeze operation is done only in the given dimension(s). If input is of shape:
        (A×1×B), squeeze(input, 0) leaves the tensor unchanged, but squeeze(input, 1) will squeeze the tensor to the shape (A×B).
        """

        #[YOUR CODE] softmax(scores, dim seq)
        weights =  torch.softmax(scores, dim=1)

        return weights

### Decoder



In [25]:
class Decoder(nn.Module):
    def __init__(self, output_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout, attention):
        super().__init__()

        self.output_dim = output_dim

        self.attention = attention

        self.embedding = nn.Embedding(output_dim, emb_dim)

        self.rnn = nn.GRU((enc_hid_dim * 2) + emb_dim, dec_hid_dim, batch_first=True)

        self.fc_out = nn.Linear((enc_hid_dim * 2) + dec_hid_dim + emb_dim, output_dim)

        self.dropout = nn.Dropout(dropout)

        #attention
        self.attn = nn.Linear((enc_hid_dim * 2) + dec_hid_dim, dec_hid_dim)
        self.v = nn.Linear(dec_hid_dim, 1, bias = False)

    def forward(self, input, hidden, encoder_outputs):

        #[batch size]
        #[batch size, dec hid dim]
        #[src len, batch size, enc hid dim * 2]

        input = input.unsqueeze(0)

        #[1, batch size]

        embedded = self.dropout(self.embedding(input))

        #[1, batch size, emb dim]

        a = self.attention(hidden, encoder_outputs)

        #[batch size, src len]

        a = a.unsqueeze(1)

        #[batch size, 1, src len]

        #[batch size, src len, enc hid dim * 2]

        weighted = torch.bmm(a, encoder_outputs)

        #[batch size, 1, enc hid dim * 2]

        weighted = weighted.permute(1, 0, 2)

        #[1, batch size, enc hid dim * 2]

        rnn_input = torch.cat((embedded, weighted), dim = 2)

        #[1, batch size, (enc hid dim * 2) + emb dim]
        #[B, 1, (enc hid dim * 2) + emb dim]

        rnn_input = rnn_input.permute(1, 0, 2)
        hidden = hidden.unsqueeze(0)

        output, hidden = self.rnn(rnn_input, hidden)

        #[seq len, batch size, dec hid dim * n directions]
        #[n layers * n directions, batch size, dec hid dim]

        #[1, batch size, dec hid dim]
        #[1, batch size, dec hid dim]

        embedded = embedded.squeeze(0)
        output = output.squeeze(1)
        weighted = weighted.squeeze(0)

        prediction = self.fc_out(torch.cat((output, weighted, embedded), dim = 1))

        #[batch size, output dim]

        return prediction, hidden.squeeze(0)

### Seq2Seq




In [26]:
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device):
        super().__init__()

        self.encoder = encoder
        self.decoder = decoder
        self.device = device


    def forward(self, src, trg, teacher_forcing_ratio = 0.5):

        #[src len, batch size]
        #[trg len, batch size]
        #teacher_forcing_ratio is probability to use teacher forcing
        # 0.75 teacher forcing 75% of the time

        batch_size = src.shape[0]
        trg_len = trg.shape[1]
        trg_vocab_size = self.decoder.output_dim


        outputs = torch.zeros(batch_size, trg_len, trg_vocab_size).to(self.device)

        #encoder_outputs is all hidden states of the input sequence
        #hidden is the final forward and backward hidden states, passed through a linear layer
        encoder_outputs, hidden = self.encoder(src)

        #first input to the decoder is the <sos> tokens
        input = trg[:,0]
        # unroll RNN
        for t in range(1, trg_len):

            #insert input token embedding, previous hidden state and all encoder hidden states

            output, hidden = self.decoder(input, hidden, encoder_outputs)

            #predictions
            outputs[:, t] = output

            #teacher forcing
            teacher_force = random.random() < teacher_forcing_ratio

            #greedy search
            top1 = output.argmax(1)

            #if teacher forcing, use gold token as next input
            #if not, use predicted token
            input = trg[:, t] if teacher_force else top1

        return outputs

## Training the Seq2Seq Model

Firstly, I have been training first with tranining data size = 1000, batch size =10, num of epoch = 3 and Luong attention


In [40]:
INPUT_DIM = len(en_vocab)
OUTPUT_DIM = len(de_vocab)
ENC_EMB_DIM = 256
DEC_EMB_DIM = 256
ENC_HID_DIM = 512
DEC_HID_DIM = 512
ENC_DROPOUT = 0.5
DEC_DROPOUT = 0.2

attn = LuongAttention(ENC_HID_DIM, DEC_HID_DIM)
#[YOUR CODE] Bahdanau attn
attn_bahdanau = BahdanauAttention(ENC_HID_DIM, DEC_HID_DIM)
enc = Encoder(INPUT_DIM, ENC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, ENC_DROPOUT)
dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, DEC_DROPOUT, attn)

model = Seq2Seq(enc, dec, device).to(device)

In [30]:
print(len(en_vocab)) #BPE size 16k approx
print(len(de_vocab))

36926
42161


In [31]:
def init_weights(m):
    for name, param in m.named_parameters():
        if 'weight' in name:
            nn.init.normal_(param.data, mean=0, std=0.01)
        else:
            nn.init.constant_(param.data, 0)

model.apply(init_weights)

Seq2Seq(
  (encoder): Encoder(
    (embedding): Embedding(36926, 256)
    (rnn): GRU(256, 512, batch_first=True, bidirectional=True)
    (fc): Linear(in_features=1024, out_features=512, bias=True)
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (decoder): Decoder(
    (attention): LuongAttention(
      (attn): Linear(in_features=1536, out_features=512, bias=True)
      (v): Linear(in_features=512, out_features=1, bias=False)
    )
    (embedding): Embedding(42161, 256)
    (rnn): GRU(1280, 512, batch_first=True)
    (fc_out): Linear(in_features=1792, out_features=42161, bias=True)
    (dropout): Dropout(p=0.2, inplace=False)
    (attn): Linear(in_features=1536, out_features=512, bias=True)
    (v): Linear(in_features=512, out_features=1, bias=False)
  )
)

In [32]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 103,061,681 trainable parameters


We create an optimizer.

In [33]:
optimizer = optim.Adam(model.parameters(), lr=1e-3)
#[YOUR CODE] Adam lr 1e-3

We initialize the loss function.

In [35]:
TRG_PAD_IDX = de_vocab['<pad>'] #TODO
#https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
criterion = nn.CrossEntropyLoss(ignore_index=TRG_PAD_IDX)

#[YOUR CODE] loss add ignore_index idx
#parameters
#torch.nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean', label_smoothing=0.0)

In [36]:
print(TRG_PAD_IDX)

1


In [37]:
def train(model, iterator, optimizer, criterion, clip):

    #[YOUR CODE] set model train model.train()
    model.train()

    epoch_loss = 0

    for (src, trg) in tqdm(iterator):

        #[YOUR CODE] to gpu
        src = src.to(device)
        trg = trg.to(device)

        #[YOUR CODE] optimizer zero grad
        optimizer.zero_grad()

        #[YOUR CODE] model()
        output = model(src, trg)

        #[trg len, batch size]
        #[trg len, batch size, output dim]

        output = output.permute(1, 0, 2)

        output_dim = output.shape[-1]
        trg = trg.permute(1, 0)

        output = output[1:].reshape(-1, output_dim)
        trg = trg[1:].reshape(-1)

        #[(trg len - 1) * batch size]
        #[(trg len - 1) * batch size, output dim]

        #[YOUR CODE] criterion()
        loss = criterion(output, trg)

        #[YOUR CODE] loss backward
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

        #[YOUR CODE] optimizer step
        optimizer.step()

        epoch_loss += loss.item()

    return epoch_loss / len(iterator)

In [38]:
def evaluate(model, iterator, criterion):

    #[YOUR CODE] set model test/eval  model.eval()
    model.eval()

    epoch_loss = 0

    with torch.no_grad():

        for (src, trg) in iterator:
            #[YOUR CODE] to device
            src = src.to(device)
            trg = trg.to(device)

            #[YOUR CODE] model() turn off teacher forcing 0
            output = model(src, trg, 0)

            #trg = [trg len, batch size]
            #output = [trg len, batch size, output dim]

            output = output.permute(1, 0, 2)
            output_dim = output.shape[-1]
            trg = trg.permute(1, 0)
            output = output[1:].reshape(-1, output_dim)
            trg = trg[1:].reshape(-1)

            #trg = [(trg len - 1) * batch size]
            #output = [(trg len - 1) * batch size, output dim]

            #[YOUR CODE] criterion()
            loss = criterion(output, trg)

            epoch_loss += loss.item()

    return epoch_loss / len(iterator)

In [42]:
N_EPOCHS = 3
CLIP = 1

epoch_info_list = []

best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):

    train_loss = train(model, train_iter, optimizer, criterion, CLIP)
    valid_loss = evaluate(model, valid_iter, criterion)

    if valid_loss < best_valid_loss:
      #[YOUR CODE] extra add BLEU model selection
      best_valid_loss = valid_loss
      torch.save(model.state_dict(), 'model.pt')

    print(f'Epoch: {epoch+1:02}')
    print(f'\tTrain Loss: {train_loss:.3f}\tTrain PPL: {np.exp(train_loss):7.3f}')
    print(f'\t Validation Loss: {valid_loss:.3f}\tValidation PPL: {np.exp(valid_loss):7.3f}')

    epoch_info_list.append(f'Epoch: {epoch+1:02}, Train Loss: {train_loss:.3f}, Train PPL: {np.exp(train_loss):7.3f}, Validation Loss:{valid_loss:.3f}, Validation PPL: {np.exp(valid_loss):7.3f}')


  0%|          | 0/100 [00:00<?, ?it/s]

hid torch.Size([2, 10, 512])


  1%|          | 1/100 [00:00<01:15,  1.32it/s]

hid torch.Size([2, 10, 512])


  2%|▏         | 2/100 [00:01<00:53,  1.83it/s]

hid torch.Size([2, 10, 512])


  3%|▎         | 3/100 [00:02<01:16,  1.27it/s]

hid torch.Size([2, 10, 512])


  4%|▍         | 4/100 [00:03<01:20,  1.19it/s]

hid torch.Size([2, 10, 512])


  5%|▌         | 5/100 [00:03<01:13,  1.29it/s]

hid torch.Size([2, 10, 512])


  6%|▌         | 6/100 [00:04<01:04,  1.45it/s]

hid torch.Size([2, 10, 512])


  7%|▋         | 7/100 [00:04<00:59,  1.58it/s]

hid torch.Size([2, 10, 512])


  8%|▊         | 8/100 [00:05<00:55,  1.65it/s]

hid torch.Size([2, 10, 512])


  9%|▉         | 9/100 [00:05<00:52,  1.73it/s]

hid torch.Size([2, 10, 512])


 10%|█         | 10/100 [00:06<00:49,  1.82it/s]

hid torch.Size([2, 10, 512])


 11%|█         | 11/100 [00:06<00:46,  1.91it/s]

hid torch.Size([2, 10, 512])


 12%|█▏        | 12/100 [00:08<01:02,  1.40it/s]

hid torch.Size([2, 10, 512])


 13%|█▎        | 13/100 [00:08<01:03,  1.37it/s]

hid torch.Size([2, 10, 512])


 14%|█▍        | 14/100 [00:09<01:02,  1.37it/s]

hid torch.Size([2, 10, 512])


 15%|█▌        | 15/100 [00:10<00:55,  1.52it/s]

hid torch.Size([2, 10, 512])


 16%|█▌        | 16/100 [00:10<00:52,  1.60it/s]

hid torch.Size([2, 10, 512])


 17%|█▋        | 17/100 [00:11<00:51,  1.61it/s]

hid torch.Size([2, 10, 512])


 18%|█▊        | 18/100 [00:11<00:50,  1.63it/s]

hid torch.Size([2, 10, 512])


 19%|█▉        | 19/100 [00:12<00:47,  1.69it/s]

hid torch.Size([2, 10, 512])


 20%|██        | 20/100 [00:13<00:57,  1.39it/s]

hid torch.Size([2, 10, 512])


 21%|██        | 21/100 [00:13<00:49,  1.60it/s]

hid torch.Size([2, 10, 512])


 22%|██▏       | 22/100 [00:14<00:59,  1.31it/s]

hid torch.Size([2, 10, 512])


 23%|██▎       | 23/100 [00:15<00:59,  1.29it/s]

hid torch.Size([2, 10, 512])


 24%|██▍       | 24/100 [00:16<01:00,  1.25it/s]

hid torch.Size([2, 10, 512])


 25%|██▌       | 25/100 [00:17<00:59,  1.27it/s]

hid torch.Size([2, 10, 512])


 26%|██▌       | 26/100 [00:17<00:53,  1.38it/s]

hid torch.Size([2, 10, 512])


 27%|██▋       | 27/100 [00:18<00:55,  1.32it/s]

hid torch.Size([2, 10, 512])


 28%|██▊       | 28/100 [00:19<00:47,  1.51it/s]

hid torch.Size([2, 10, 512])


 29%|██▉       | 29/100 [00:19<00:48,  1.48it/s]

hid torch.Size([2, 10, 512])


 30%|███       | 30/100 [00:20<00:45,  1.54it/s]

hid torch.Size([2, 10, 512])


 31%|███       | 31/100 [00:21<00:45,  1.52it/s]

hid torch.Size([2, 10, 512])


 32%|███▏      | 32/100 [00:22<01:00,  1.13it/s]

hid torch.Size([2, 10, 512])


 33%|███▎      | 33/100 [00:23<00:59,  1.12it/s]

hid torch.Size([2, 10, 512])


 34%|███▍      | 34/100 [00:24<00:57,  1.15it/s]

hid torch.Size([2, 10, 512])


 35%|███▌      | 35/100 [00:24<00:50,  1.30it/s]

hid torch.Size([2, 10, 512])


 36%|███▌      | 36/100 [00:26<00:59,  1.07it/s]

hid torch.Size([2, 10, 512])


 37%|███▋      | 37/100 [00:26<00:55,  1.14it/s]

hid torch.Size([2, 10, 512])


 38%|███▊      | 38/100 [00:27<00:51,  1.21it/s]

hid torch.Size([2, 10, 512])


 39%|███▉      | 39/100 [00:28<00:45,  1.34it/s]

hid torch.Size([2, 10, 512])


 40%|████      | 40/100 [00:28<00:38,  1.55it/s]

hid torch.Size([2, 10, 512])


 41%|████      | 41/100 [00:28<00:34,  1.69it/s]

hid torch.Size([2, 10, 512])


 42%|████▏     | 42/100 [00:29<00:33,  1.75it/s]

hid torch.Size([2, 10, 512])


 43%|████▎     | 43/100 [00:30<00:33,  1.71it/s]

hid torch.Size([2, 10, 512])


 44%|████▍     | 44/100 [00:30<00:33,  1.70it/s]

hid torch.Size([2, 10, 512])


 45%|████▌     | 45/100 [00:31<00:30,  1.79it/s]

hid torch.Size([2, 10, 512])


 46%|████▌     | 46/100 [00:31<00:28,  1.86it/s]

hid torch.Size([2, 10, 512])


 47%|████▋     | 47/100 [00:32<00:29,  1.82it/s]

hid torch.Size([2, 10, 512])


 48%|████▊     | 48/100 [00:32<00:31,  1.64it/s]

hid torch.Size([2, 10, 512])


 49%|████▉     | 49/100 [00:33<00:29,  1.76it/s]

hid torch.Size([2, 10, 512])


 50%|█████     | 50/100 [00:33<00:27,  1.81it/s]

hid torch.Size([2, 10, 512])


 51%|█████     | 51/100 [00:34<00:27,  1.78it/s]

hid torch.Size([2, 10, 512])


 52%|█████▏    | 52/100 [00:35<00:29,  1.64it/s]

hid torch.Size([2, 10, 512])


 53%|█████▎    | 53/100 [00:36<00:33,  1.40it/s]

hid torch.Size([2, 10, 512])


 54%|█████▍    | 54/100 [00:36<00:28,  1.59it/s]

hid torch.Size([2, 10, 512])


 55%|█████▌    | 55/100 [00:37<00:29,  1.53it/s]

hid torch.Size([2, 10, 512])


 56%|█████▌    | 56/100 [00:38<00:28,  1.54it/s]

hid torch.Size([2, 10, 512])


 57%|█████▋    | 57/100 [00:38<00:25,  1.72it/s]

hid torch.Size([2, 10, 512])


 58%|█████▊    | 58/100 [00:39<00:30,  1.37it/s]

hid torch.Size([2, 10, 512])


 59%|█████▉    | 59/100 [00:39<00:26,  1.55it/s]

hid torch.Size([2, 10, 512])


 60%|██████    | 60/100 [00:40<00:27,  1.47it/s]

hid torch.Size([2, 10, 512])


 61%|██████    | 61/100 [00:41<00:29,  1.32it/s]

hid torch.Size([2, 10, 512])


 62%|██████▏   | 62/100 [00:43<00:35,  1.06it/s]

hid torch.Size([2, 10, 512])


 63%|██████▎   | 63/100 [00:43<00:31,  1.17it/s]

hid torch.Size([2, 10, 512])


 64%|██████▍   | 64/100 [00:44<00:27,  1.29it/s]

hid torch.Size([2, 10, 512])


 65%|██████▌   | 65/100 [00:44<00:24,  1.41it/s]

hid torch.Size([2, 10, 512])


 66%|██████▌   | 66/100 [00:45<00:23,  1.45it/s]

hid torch.Size([2, 10, 512])


 67%|██████▋   | 67/100 [00:46<00:21,  1.51it/s]

hid torch.Size([2, 10, 512])


 68%|██████▊   | 68/100 [00:47<00:25,  1.27it/s]

hid torch.Size([2, 10, 512])


 69%|██████▉   | 69/100 [00:48<00:25,  1.21it/s]

hid torch.Size([2, 10, 512])


 70%|███████   | 70/100 [00:48<00:23,  1.26it/s]

hid torch.Size([2, 10, 512])


 71%|███████   | 71/100 [00:49<00:21,  1.33it/s]

hid torch.Size([2, 10, 512])


 72%|███████▏  | 72/100 [00:50<00:25,  1.11it/s]

hid torch.Size([2, 10, 512])


 73%|███████▎  | 73/100 [00:51<00:24,  1.09it/s]

hid torch.Size([2, 10, 512])


 74%|███████▍  | 74/100 [00:52<00:19,  1.32it/s]

hid torch.Size([2, 10, 512])


 75%|███████▌  | 75/100 [00:52<00:16,  1.48it/s]

hid torch.Size([2, 10, 512])


 76%|███████▌  | 76/100 [00:53<00:15,  1.53it/s]

hid torch.Size([2, 10, 512])


 77%|███████▋  | 77/100 [00:53<00:13,  1.69it/s]

hid torch.Size([2, 10, 512])


 78%|███████▊  | 78/100 [00:54<00:12,  1.74it/s]

hid torch.Size([2, 10, 512])


 79%|███████▉  | 79/100 [00:54<00:14,  1.50it/s]

hid torch.Size([2, 10, 512])


 80%|████████  | 80/100 [00:56<00:16,  1.24it/s]

hid torch.Size([2, 10, 512])


 81%|████████  | 81/100 [00:56<00:12,  1.50it/s]

hid torch.Size([2, 10, 512])


 82%|████████▏ | 82/100 [00:57<00:11,  1.57it/s]

hid torch.Size([2, 10, 512])


 83%|████████▎ | 83/100 [00:57<00:10,  1.59it/s]

hid torch.Size([2, 10, 512])


 84%|████████▍ | 84/100 [00:58<00:10,  1.48it/s]

hid torch.Size([2, 10, 512])


 85%|████████▌ | 85/100 [00:59<00:10,  1.39it/s]

hid torch.Size([2, 10, 512])


 86%|████████▌ | 86/100 [00:59<00:09,  1.54it/s]

hid torch.Size([2, 10, 512])


 87%|████████▋ | 87/100 [01:00<00:08,  1.49it/s]

hid torch.Size([2, 10, 512])


 88%|████████▊ | 88/100 [01:00<00:07,  1.59it/s]

hid torch.Size([2, 10, 512])


 89%|████████▉ | 89/100 [01:01<00:06,  1.70it/s]

hid torch.Size([2, 10, 512])


 90%|█████████ | 90/100 [01:02<00:07,  1.39it/s]

hid torch.Size([2, 10, 512])


 91%|█████████ | 91/100 [01:03<00:06,  1.46it/s]

hid torch.Size([2, 10, 512])


 92%|█████████▏| 92/100 [01:03<00:05,  1.37it/s]

hid torch.Size([2, 10, 512])


 93%|█████████▎| 93/100 [01:04<00:04,  1.46it/s]

hid torch.Size([2, 10, 512])


 94%|█████████▍| 94/100 [01:05<00:03,  1.53it/s]

hid torch.Size([2, 10, 512])


 95%|█████████▌| 95/100 [01:05<00:03,  1.61it/s]

hid torch.Size([2, 10, 512])


 96%|█████████▌| 96/100 [01:06<00:02,  1.66it/s]

hid torch.Size([2, 10, 512])


 97%|█████████▋| 97/100 [01:07<00:02,  1.48it/s]

hid torch.Size([2, 10, 512])


 98%|█████████▊| 98/100 [01:07<00:01,  1.65it/s]

hid torch.Size([2, 10, 512])


 99%|█████████▉| 99/100 [01:08<00:00,  1.28it/s]

hid torch.Size([2, 10, 512])


100%|██████████| 100/100 [01:09<00:00,  1.45it/s]


hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size

  0%|          | 0/100 [00:00<?, ?it/s]

hid torch.Size([2, 10, 512])


  1%|          | 1/100 [00:00<01:08,  1.44it/s]

hid torch.Size([2, 10, 512])


  2%|▏         | 2/100 [00:01<01:18,  1.25it/s]

hid torch.Size([2, 10, 512])


  3%|▎         | 3/100 [00:02<01:07,  1.45it/s]

hid torch.Size([2, 10, 512])


  4%|▍         | 4/100 [00:03<01:16,  1.25it/s]

hid torch.Size([2, 10, 512])


  5%|▌         | 5/100 [00:03<01:10,  1.34it/s]

hid torch.Size([2, 10, 512])


  6%|▌         | 6/100 [00:04<01:14,  1.26it/s]

hid torch.Size([2, 10, 512])


  7%|▋         | 7/100 [00:05<01:11,  1.30it/s]

hid torch.Size([2, 10, 512])


  8%|▊         | 8/100 [00:06<01:14,  1.24it/s]

hid torch.Size([2, 10, 512])


  9%|▉         | 9/100 [00:06<01:06,  1.38it/s]

hid torch.Size([2, 10, 512])


 10%|█         | 10/100 [00:07<01:01,  1.46it/s]

hid torch.Size([2, 10, 512])


 11%|█         | 11/100 [00:07<00:58,  1.51it/s]

hid torch.Size([2, 10, 512])


 12%|█▏        | 12/100 [00:08<00:56,  1.55it/s]

hid torch.Size([2, 10, 512])


 13%|█▎        | 13/100 [00:09<00:53,  1.63it/s]

hid torch.Size([2, 10, 512])


 14%|█▍        | 14/100 [00:09<00:48,  1.76it/s]

hid torch.Size([2, 10, 512])


 15%|█▌        | 15/100 [00:10<01:03,  1.34it/s]

hid torch.Size([2, 10, 512])


 16%|█▌        | 16/100 [00:11<01:10,  1.19it/s]

hid torch.Size([2, 10, 512])


 17%|█▋        | 17/100 [00:12<01:01,  1.35it/s]

hid torch.Size([2, 10, 512])


 18%|█▊        | 18/100 [00:12<00:56,  1.44it/s]

hid torch.Size([2, 10, 512])


 19%|█▉        | 19/100 [00:13<00:50,  1.62it/s]

hid torch.Size([2, 10, 512])


 20%|██        | 20/100 [00:13<00:45,  1.77it/s]

hid torch.Size([2, 10, 512])


 21%|██        | 21/100 [00:14<00:43,  1.81it/s]

hid torch.Size([2, 10, 512])


 22%|██▏       | 22/100 [00:14<00:37,  2.06it/s]

hid torch.Size([2, 10, 512])


 23%|██▎       | 23/100 [00:15<00:35,  2.14it/s]

hid torch.Size([2, 10, 512])


 24%|██▍       | 24/100 [00:16<00:54,  1.39it/s]

hid torch.Size([2, 10, 512])


 25%|██▌       | 25/100 [00:16<00:46,  1.60it/s]

hid torch.Size([2, 10, 512])


 26%|██▌       | 26/100 [00:17<00:44,  1.68it/s]

hid torch.Size([2, 10, 512])


 27%|██▋       | 27/100 [00:17<00:42,  1.72it/s]

hid torch.Size([2, 10, 512])


 28%|██▊       | 28/100 [00:18<00:41,  1.73it/s]

hid torch.Size([2, 10, 512])


 29%|██▉       | 29/100 [00:18<00:40,  1.77it/s]

hid torch.Size([2, 10, 512])


 30%|███       | 30/100 [00:19<00:39,  1.79it/s]

hid torch.Size([2, 10, 512])


 31%|███       | 31/100 [00:20<00:37,  1.84it/s]

hid torch.Size([2, 10, 512])


 32%|███▏      | 32/100 [00:21<00:47,  1.43it/s]

hid torch.Size([2, 10, 512])


 33%|███▎      | 33/100 [00:22<00:55,  1.20it/s]

hid torch.Size([2, 10, 512])


 34%|███▍      | 34/100 [00:22<00:46,  1.43it/s]

hid torch.Size([2, 10, 512])


 35%|███▌      | 35/100 [00:23<00:56,  1.14it/s]

hid torch.Size([2, 10, 512])


 36%|███▌      | 36/100 [00:24<00:45,  1.40it/s]

hid torch.Size([2, 10, 512])


 37%|███▋      | 37/100 [00:25<00:45,  1.37it/s]

hid torch.Size([2, 10, 512])


 38%|███▊      | 38/100 [00:25<00:41,  1.50it/s]

hid torch.Size([2, 10, 512])


 39%|███▉      | 39/100 [00:26<00:39,  1.54it/s]

hid torch.Size([2, 10, 512])


 40%|████      | 40/100 [00:26<00:36,  1.65it/s]

hid torch.Size([2, 10, 512])


 41%|████      | 41/100 [00:27<00:35,  1.67it/s]

hid torch.Size([2, 10, 512])


 42%|████▏     | 42/100 [00:27<00:35,  1.62it/s]

hid torch.Size([2, 10, 512])


 43%|████▎     | 43/100 [00:28<00:34,  1.65it/s]

hid torch.Size([2, 10, 512])


 44%|████▍     | 44/100 [00:28<00:29,  1.90it/s]

hid torch.Size([2, 10, 512])


 45%|████▌     | 45/100 [00:29<00:28,  1.90it/s]

hid torch.Size([2, 10, 512])


 46%|████▌     | 46/100 [00:30<00:32,  1.68it/s]

hid torch.Size([2, 10, 512])


 47%|████▋     | 47/100 [00:30<00:31,  1.68it/s]

hid torch.Size([2, 10, 512])


 48%|████▊     | 48/100 [00:31<00:29,  1.75it/s]

hid torch.Size([2, 10, 512])


 49%|████▉     | 49/100 [00:31<00:29,  1.75it/s]

hid torch.Size([2, 10, 512])


 50%|█████     | 50/100 [00:32<00:27,  1.84it/s]

hid torch.Size([2, 10, 512])


 51%|█████     | 51/100 [00:32<00:24,  1.99it/s]

hid torch.Size([2, 10, 512])


 52%|█████▏    | 52/100 [00:33<00:23,  2.00it/s]

hid torch.Size([2, 10, 512])


 53%|█████▎    | 53/100 [00:33<00:24,  1.91it/s]

hid torch.Size([2, 10, 512])


 54%|█████▍    | 54/100 [00:34<00:23,  1.93it/s]

hid torch.Size([2, 10, 512])


 55%|█████▌    | 55/100 [00:35<00:27,  1.65it/s]

hid torch.Size([2, 10, 512])


 56%|█████▌    | 56/100 [00:36<00:36,  1.21it/s]

hid torch.Size([2, 10, 512])


 57%|█████▋    | 57/100 [00:37<00:37,  1.15it/s]

hid torch.Size([2, 10, 512])


 58%|█████▊    | 58/100 [00:37<00:32,  1.30it/s]

hid torch.Size([2, 10, 512])


 59%|█████▉    | 59/100 [00:38<00:28,  1.44it/s]

hid torch.Size([2, 10, 512])


 60%|██████    | 60/100 [00:39<00:26,  1.49it/s]

hid torch.Size([2, 10, 512])


 61%|██████    | 61/100 [00:39<00:27,  1.42it/s]

hid torch.Size([2, 10, 512])


 62%|██████▏   | 62/100 [00:40<00:25,  1.51it/s]

hid torch.Size([2, 10, 512])


 63%|██████▎   | 63/100 [00:40<00:22,  1.67it/s]

hid torch.Size([2, 10, 512])


 64%|██████▍   | 64/100 [00:41<00:26,  1.37it/s]

hid torch.Size([2, 10, 512])


 65%|██████▌   | 65/100 [00:42<00:26,  1.30it/s]

hid torch.Size([2, 10, 512])


 66%|██████▌   | 66/100 [00:43<00:29,  1.17it/s]

hid torch.Size([2, 10, 512])


 67%|██████▋   | 67/100 [00:44<00:27,  1.21it/s]

hid torch.Size([2, 10, 512])


 68%|██████▊   | 68/100 [00:45<00:26,  1.22it/s]

hid torch.Size([2, 10, 512])


 69%|██████▉   | 69/100 [00:46<00:25,  1.20it/s]

hid torch.Size([2, 10, 512])


 70%|███████   | 70/100 [00:46<00:21,  1.37it/s]

hid torch.Size([2, 10, 512])


 71%|███████   | 71/100 [00:47<00:21,  1.35it/s]

hid torch.Size([2, 10, 512])


 72%|███████▏  | 72/100 [00:47<00:18,  1.51it/s]

hid torch.Size([2, 10, 512])


 73%|███████▎  | 73/100 [00:48<00:19,  1.37it/s]

hid torch.Size([2, 10, 512])


 74%|███████▍  | 74/100 [00:49<00:18,  1.37it/s]

hid torch.Size([2, 10, 512])


 75%|███████▌  | 75/100 [00:50<00:16,  1.48it/s]

hid torch.Size([2, 10, 512])


 76%|███████▌  | 76/100 [00:51<00:18,  1.27it/s]

hid torch.Size([2, 10, 512])


 77%|███████▋  | 77/100 [00:51<00:18,  1.24it/s]

hid torch.Size([2, 10, 512])


 78%|███████▊  | 78/100 [00:52<00:14,  1.48it/s]

hid torch.Size([2, 10, 512])


 79%|███████▉  | 79/100 [00:52<00:12,  1.63it/s]

hid torch.Size([2, 10, 512])


 80%|████████  | 80/100 [00:53<00:10,  1.91it/s]

hid torch.Size([2, 10, 512])


 81%|████████  | 81/100 [00:53<00:11,  1.69it/s]

hid torch.Size([2, 10, 512])


 82%|████████▏ | 82/100 [00:54<00:10,  1.78it/s]

hid torch.Size([2, 10, 512])


 83%|████████▎ | 83/100 [00:54<00:09,  1.81it/s]

hid torch.Size([2, 10, 512])


 84%|████████▍ | 84/100 [00:55<00:09,  1.65it/s]

hid torch.Size([2, 10, 512])


 85%|████████▌ | 85/100 [00:56<00:08,  1.67it/s]

hid torch.Size([2, 10, 512])


 86%|████████▌ | 86/100 [00:56<00:09,  1.56it/s]

hid torch.Size([2, 10, 512])


 87%|████████▋ | 87/100 [00:57<00:07,  1.64it/s]

hid torch.Size([2, 10, 512])


 88%|████████▊ | 88/100 [00:57<00:06,  1.84it/s]

hid torch.Size([2, 10, 512])


 89%|████████▉ | 89/100 [00:58<00:06,  1.69it/s]

hid torch.Size([2, 10, 512])


 90%|█████████ | 90/100 [00:59<00:07,  1.32it/s]

hid torch.Size([2, 10, 512])


 91%|█████████ | 91/100 [01:00<00:06,  1.33it/s]

hid torch.Size([2, 10, 512])


 92%|█████████▏| 92/100 [01:01<00:07,  1.05it/s]

hid torch.Size([2, 10, 512])


 93%|█████████▎| 93/100 [01:02<00:05,  1.30it/s]

hid torch.Size([2, 10, 512])


 94%|█████████▍| 94/100 [01:02<00:04,  1.38it/s]

hid torch.Size([2, 10, 512])


 95%|█████████▌| 95/100 [01:03<00:03,  1.59it/s]

hid torch.Size([2, 10, 512])


 96%|█████████▌| 96/100 [01:03<00:02,  1.52it/s]

hid torch.Size([2, 10, 512])


 97%|█████████▋| 97/100 [01:04<00:02,  1.50it/s]

hid torch.Size([2, 10, 512])


 98%|█████████▊| 98/100 [01:05<00:01,  1.30it/s]

hid torch.Size([2, 10, 512])


 99%|█████████▉| 99/100 [01:06<00:00,  1.40it/s]

hid torch.Size([2, 10, 512])


100%|██████████| 100/100 [01:06<00:00,  1.50it/s]


hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size

  0%|          | 0/100 [00:00<?, ?it/s]

hid torch.Size([2, 10, 512])


  1%|          | 1/100 [00:01<01:57,  1.18s/it]

hid torch.Size([2, 10, 512])


  2%|▏         | 2/100 [00:02<01:37,  1.00it/s]

hid torch.Size([2, 10, 512])


  3%|▎         | 3/100 [00:02<01:25,  1.14it/s]

hid torch.Size([2, 10, 512])


  4%|▍         | 4/100 [00:03<01:16,  1.26it/s]

hid torch.Size([2, 10, 512])


  5%|▌         | 5/100 [00:04<01:23,  1.13it/s]

hid torch.Size([2, 10, 512])


  6%|▌         | 6/100 [00:05<01:25,  1.10it/s]

hid torch.Size([2, 10, 512])


  7%|▋         | 7/100 [00:05<01:12,  1.28it/s]

hid torch.Size([2, 10, 512])


  8%|▊         | 8/100 [00:06<01:04,  1.43it/s]

hid torch.Size([2, 10, 512])


  9%|▉         | 9/100 [00:06<00:57,  1.58it/s]

hid torch.Size([2, 10, 512])


 10%|█         | 10/100 [00:07<00:53,  1.68it/s]

hid torch.Size([2, 10, 512])


 11%|█         | 11/100 [00:08<01:12,  1.23it/s]

hid torch.Size([2, 10, 512])


 12%|█▏        | 12/100 [00:09<01:01,  1.44it/s]

hid torch.Size([2, 10, 512])


 13%|█▎        | 13/100 [00:09<01:00,  1.43it/s]

hid torch.Size([2, 10, 512])


 14%|█▍        | 14/100 [00:10<01:02,  1.37it/s]

hid torch.Size([2, 10, 512])


 15%|█▌        | 15/100 [00:11<00:59,  1.42it/s]

hid torch.Size([2, 10, 512])


 16%|█▌        | 16/100 [00:11<00:55,  1.52it/s]

hid torch.Size([2, 10, 512])


 17%|█▋        | 17/100 [00:12<00:56,  1.48it/s]

hid torch.Size([2, 10, 512])


 18%|█▊        | 18/100 [00:13<00:53,  1.54it/s]

hid torch.Size([2, 10, 512])


 19%|█▉        | 19/100 [00:13<00:48,  1.65it/s]

hid torch.Size([2, 10, 512])


 20%|██        | 20/100 [00:14<00:55,  1.44it/s]

hid torch.Size([2, 10, 512])


 21%|██        | 21/100 [00:15<00:51,  1.53it/s]

hid torch.Size([2, 10, 512])


 22%|██▏       | 22/100 [00:15<00:49,  1.57it/s]

hid torch.Size([2, 10, 512])


 23%|██▎       | 23/100 [00:16<00:44,  1.75it/s]

hid torch.Size([2, 10, 512])


 24%|██▍       | 24/100 [00:17<00:54,  1.39it/s]

hid torch.Size([2, 10, 512])


 25%|██▌       | 25/100 [00:17<00:47,  1.60it/s]

hid torch.Size([2, 10, 512])


 26%|██▌       | 26/100 [00:18<00:41,  1.77it/s]

hid torch.Size([2, 10, 512])


 27%|██▋       | 27/100 [00:18<00:44,  1.63it/s]

hid torch.Size([2, 10, 512])


 28%|██▊       | 28/100 [00:19<00:50,  1.44it/s]

hid torch.Size([2, 10, 512])


 29%|██▉       | 29/100 [00:21<01:02,  1.13it/s]

hid torch.Size([2, 10, 512])


 30%|███       | 30/100 [00:21<01:02,  1.13it/s]

hid torch.Size([2, 10, 512])


 31%|███       | 31/100 [00:22<00:56,  1.22it/s]

hid torch.Size([2, 10, 512])


 32%|███▏      | 32/100 [00:22<00:47,  1.44it/s]

hid torch.Size([2, 10, 512])


 33%|███▎      | 33/100 [00:23<00:42,  1.59it/s]

hid torch.Size([2, 10, 512])


 34%|███▍      | 34/100 [00:24<00:45,  1.44it/s]

hid torch.Size([2, 10, 512])


 35%|███▌      | 35/100 [00:24<00:42,  1.54it/s]

hid torch.Size([2, 10, 512])


 36%|███▌      | 36/100 [00:25<00:40,  1.59it/s]

hid torch.Size([2, 10, 512])


 37%|███▋      | 37/100 [00:25<00:37,  1.67it/s]

hid torch.Size([2, 10, 512])


 38%|███▊      | 38/100 [00:26<00:35,  1.75it/s]

hid torch.Size([2, 10, 512])


 39%|███▉      | 39/100 [00:26<00:32,  1.89it/s]

hid torch.Size([2, 10, 512])


 40%|████      | 40/100 [00:28<00:42,  1.41it/s]

hid torch.Size([2, 10, 512])


 41%|████      | 41/100 [00:28<00:40,  1.45it/s]

hid torch.Size([2, 10, 512])


 42%|████▏     | 42/100 [00:29<00:37,  1.53it/s]

hid torch.Size([2, 10, 512])


 43%|████▎     | 43/100 [00:30<00:39,  1.43it/s]

hid torch.Size([2, 10, 512])


 44%|████▍     | 44/100 [00:31<00:46,  1.19it/s]

hid torch.Size([2, 10, 512])


 45%|████▌     | 45/100 [00:32<00:49,  1.11it/s]

hid torch.Size([2, 10, 512])


 46%|████▌     | 46/100 [00:32<00:45,  1.20it/s]

hid torch.Size([2, 10, 512])


 47%|████▋     | 47/100 [00:33<00:42,  1.25it/s]

hid torch.Size([2, 10, 512])


 48%|████▊     | 48/100 [00:34<00:40,  1.27it/s]

hid torch.Size([2, 10, 512])


 49%|████▉     | 49/100 [00:35<00:38,  1.31it/s]

hid torch.Size([2, 10, 512])


 50%|█████     | 50/100 [00:35<00:33,  1.51it/s]

hid torch.Size([2, 10, 512])


 51%|█████     | 51/100 [00:37<00:44,  1.11it/s]

hid torch.Size([2, 10, 512])


 52%|█████▏    | 52/100 [00:37<00:38,  1.23it/s]

hid torch.Size([2, 10, 512])


 53%|█████▎    | 53/100 [00:38<00:33,  1.42it/s]

hid torch.Size([2, 10, 512])


 54%|█████▍    | 54/100 [00:38<00:31,  1.46it/s]

hid torch.Size([2, 10, 512])


 55%|█████▌    | 55/100 [00:39<00:28,  1.58it/s]

hid torch.Size([2, 10, 512])


 56%|█████▌    | 56/100 [00:39<00:29,  1.48it/s]

hid torch.Size([2, 10, 512])


 57%|█████▋    | 57/100 [00:40<00:29,  1.48it/s]

hid torch.Size([2, 10, 512])


 58%|█████▊    | 58/100 [00:41<00:30,  1.39it/s]

hid torch.Size([2, 10, 512])


 59%|█████▉    | 59/100 [00:42<00:34,  1.17it/s]

hid torch.Size([2, 10, 512])


 60%|██████    | 60/100 [00:43<00:30,  1.31it/s]

hid torch.Size([2, 10, 512])


 61%|██████    | 61/100 [00:44<00:31,  1.25it/s]

hid torch.Size([2, 10, 512])


 62%|██████▏   | 62/100 [00:44<00:27,  1.41it/s]

hid torch.Size([2, 10, 512])


 63%|██████▎   | 63/100 [00:45<00:25,  1.47it/s]

hid torch.Size([2, 10, 512])


 64%|██████▍   | 64/100 [00:45<00:23,  1.50it/s]

hid torch.Size([2, 10, 512])


 65%|██████▌   | 65/100 [00:46<00:21,  1.64it/s]

hid torch.Size([2, 10, 512])


 66%|██████▌   | 66/100 [00:47<00:21,  1.56it/s]

hid torch.Size([2, 10, 512])


 67%|██████▋   | 67/100 [00:47<00:20,  1.62it/s]

hid torch.Size([2, 10, 512])


 68%|██████▊   | 68/100 [00:48<00:21,  1.50it/s]

hid torch.Size([2, 10, 512])


 69%|██████▉   | 69/100 [00:49<00:21,  1.44it/s]

hid torch.Size([2, 10, 512])


 70%|███████   | 70/100 [00:49<00:18,  1.58it/s]

hid torch.Size([2, 10, 512])


 71%|███████   | 71/100 [00:50<00:21,  1.37it/s]

hid torch.Size([2, 10, 512])


 72%|███████▏  | 72/100 [00:51<00:21,  1.31it/s]

hid torch.Size([2, 10, 512])


 73%|███████▎  | 73/100 [00:51<00:18,  1.49it/s]

hid torch.Size([2, 10, 512])


 74%|███████▍  | 74/100 [00:52<00:18,  1.44it/s]

hid torch.Size([2, 10, 512])


 75%|███████▌  | 75/100 [00:53<00:15,  1.58it/s]

hid torch.Size([2, 10, 512])


 76%|███████▌  | 76/100 [00:53<00:13,  1.72it/s]

hid torch.Size([2, 10, 512])


 77%|███████▋  | 77/100 [00:54<00:16,  1.37it/s]

hid torch.Size([2, 10, 512])


 78%|███████▊  | 78/100 [00:55<00:14,  1.47it/s]

hid torch.Size([2, 10, 512])


 79%|███████▉  | 79/100 [00:55<00:12,  1.65it/s]

hid torch.Size([2, 10, 512])


 80%|████████  | 80/100 [00:56<00:11,  1.73it/s]

hid torch.Size([2, 10, 512])


 81%|████████  | 81/100 [00:56<00:11,  1.60it/s]

hid torch.Size([2, 10, 512])


 82%|████████▏ | 82/100 [00:57<00:13,  1.34it/s]

hid torch.Size([2, 10, 512])


 83%|████████▎ | 83/100 [00:58<00:12,  1.40it/s]

hid torch.Size([2, 10, 512])


 84%|████████▍ | 84/100 [00:59<00:10,  1.46it/s]

hid torch.Size([2, 10, 512])


 85%|████████▌ | 85/100 [00:59<00:09,  1.60it/s]

hid torch.Size([2, 10, 512])


 86%|████████▌ | 86/100 [01:00<00:08,  1.69it/s]

hid torch.Size([2, 10, 512])


 87%|████████▋ | 87/100 [01:00<00:07,  1.65it/s]

hid torch.Size([2, 10, 512])


 88%|████████▊ | 88/100 [01:01<00:07,  1.69it/s]

hid torch.Size([2, 10, 512])


 89%|████████▉ | 89/100 [01:01<00:05,  1.86it/s]

hid torch.Size([2, 10, 512])


 90%|█████████ | 90/100 [01:02<00:05,  1.78it/s]

hid torch.Size([2, 10, 512])


 91%|█████████ | 91/100 [01:03<00:05,  1.63it/s]

hid torch.Size([2, 10, 512])


 92%|█████████▏| 92/100 [01:03<00:04,  1.76it/s]

hid torch.Size([2, 10, 512])


 93%|█████████▎| 93/100 [01:04<00:04,  1.55it/s]

hid torch.Size([2, 10, 512])


 94%|█████████▍| 94/100 [01:04<00:03,  1.75it/s]

hid torch.Size([2, 10, 512])


 95%|█████████▌| 95/100 [01:05<00:03,  1.52it/s]

hid torch.Size([2, 10, 512])


 96%|█████████▌| 96/100 [01:06<00:02,  1.56it/s]

hid torch.Size([2, 10, 512])


 97%|█████████▋| 97/100 [01:07<00:02,  1.46it/s]

hid torch.Size([2, 10, 512])


 98%|█████████▊| 98/100 [01:07<00:01,  1.42it/s]

hid torch.Size([2, 10, 512])


 99%|█████████▉| 99/100 [01:08<00:00,  1.36it/s]

hid torch.Size([2, 10, 512])


100%|██████████| 100/100 [01:09<00:00,  1.44it/s]


hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size

In [44]:
for epoch_info in epoch_info_list:
  print(epoch_info)

Epoch: 01, Train Loss: 10.688, Train PPL: 43816.560, Validation Loss:10.685, Validation PPL: 43682.609
Epoch: 02, Train Loss: 10.688, Train PPL: 43815.816, Validation Loss:10.685, Validation PPL: 43682.609
Epoch: 03, Train Loss: 10.688, Train PPL: 43819.184, Validation Loss:10.685, Validation PPL: 43682.609


In [45]:
#NOTE: load model from file
model.load_state_dict(torch.load('model.pt'))

test_loss = evaluate(model, test_iter, criterion)

print(f'\tTest Loss: {test_loss:.3f}\tTest PPL: {np.exp(test_loss):7.3f}')

hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size([2, 10, 512])
hid torch.Size

In [46]:
#clean mem
del model
del train_iter
del valid_iter
del test_iter
torch.cuda.empty_cache()

BLEU added

Seconadly, I have been training the model with BLEU score addded, with tranining data size = 1000, batch size =10, num of epoch = 3 and Luong attention


In [47]:
from torchtext.data.metrics import bleu_score

def calculate_bleu(model, iterator, src_vocab, trg_vocab, device):
    model.eval()
    trgs = []
    pred_trgs = []

    with torch.no_grad():
        for src, trg in iterator:
            src = src.to(device)
            trg = trg.to(device)

            output = model(src, trg, 0)  # no teacher forcing
            output_dim = output.shape[-1]

            output = output.argmax(-1)  # [trg_len, batch_size] finds the index of the most likely token across the last dimension (output_dim)

            for i in range(trg.shape[1]):
                trg_sent = [trg_vocab.lookup_tokens(trg[1:, i])]
                pred_sent = trg_vocab.lookup_tokens(output[1:, i])

                """
                def lookup_tokens(self, indices: List[int]) -> List[str]:r
                Args:indices: The `indices` used to lookup their corresponding`tokens`.
                Returns: The `tokens` associated with `indices`.
                Raises: RuntimeError: If an index within `indices` is not int range [0, itos.size()).
                     return self.vocab.lookup_tokens(indices)
                """

                # remove padding and EOS tokens from predictions
                pred_sent = [token for token in pred_sent if token not in ['<pad>', '<eos>']]

                trgs.append(trg_sent)
                pred_trgs.append(pred_sent)

    return bleu_score(pred_trgs, trgs)

In [48]:
N_EPOCHS = 3
CLIP = 1

best_valid_loss = float('inf')
best_valid_bleu = 0.0

epoch_info_list_bleu = []

for epoch in range(N_EPOCHS):

    train_loss = train(model, train_iter, optimizer, criterion, CLIP, device)
    valid_loss = evaluate(model, valid_iter, criterion, device)
    valid_bleu = calculate_bleu(model, valid_iter, SRC.vocab, TRG.vocab, device)

    # Check both BLEU and validation loss for model saving (you can pick either)
    if valid_bleu > best_valid_bleu:
        best_valid_bleu = valid_bleu
        torch.save(model.state_dict(), 'model_best_bleu.pt')

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'model_best_loss.pt')

    print(f'Epoch: {epoch+1:02}')
    print(f'\tTrain Loss: {train_loss:.3f}\tTrain PPL: {np.exp(train_loss):7.3f}')
    print(f'\tValidation Loss: {valid_loss:.3f}\tValidation PPL: {np.exp(valid_loss):7.3f}')
    print(f'\tValidation BLEU Score: {valid_bleu*100:.2f}')


    epoch_info_list_bleu.append(f'Epoch: {epoch+1:02}, Train Loss: {train_loss:.3f}, Train PPL: {np.exp(train_loss):7.3f}, Validation Loss:{valid_loss:.3f}, Validation PPL: {np.exp(valid_loss):7.3f}, Validation BLEU Score: {valid_bleu*100:.2f}')

NameError: name 'model' is not defined

In [None]:
#NOTE: load model from file
#model.load_state_dict(torch.load('model.pt'))

test_loss = evaluate(model, test_iter, criterion)

print(f'\tTest Loss: {test_loss:.3f}\tTest PPL: {np.exp(test_loss):7.3f}')

In [None]:
#clean mem
del model
del train_iter
del valid_iter
del test_iter
torch.cuda.empty_cache()