In [1]:
# For tips on running notebooks in Google Colab, see
# https://pytorch.org/tutorials/beginner/colab
%matplotlib inline

NLP From Scratch: Translation with a Sequence to Sequence Network and Attention
===============================================================================

**Author**: [Sean Robertson](https://github.com/spro)

This is the third and final tutorial on doing "NLP From Scratch",
where we write our own classes and functions to preprocess the data to
do our NLP modeling tasks. We hope after you complete this tutorial that
you'll proceed to learn how `torchtext` can handle much of
this preprocessing for you in the three tutorials immediately following
this one.

In this project we will be teaching a neural network to translate from
French to English.

``` {.sourceCode .sh}
[KEY: > input, = target, < output]

> il est en train de peindre un tableau .
= he is painting a picture .
< he is painting a picture .

> pourquoi ne pas essayer ce vin delicieux ?
= why not try that delicious wine ?
< why not try that delicious wine ?

> elle n est pas poete mais romanciere .
= she is not a poet but a novelist .
< she not not a poet but a novelist .

> vous etes trop maigre .
= you re too skinny .
< you re all alone .
```

... to varying degrees of success.

This is made possible by the simple but powerful idea of the [sequence
to sequence network](https://arxiv.org/abs/1409.3215), in which two
recurrent neural networks work together to transform one sequence to
another. An encoder network condenses an input sequence into a vector,
and a decoder network unfolds that vector into a new sequence.

![](https://pytorch.org/tutorials/_static/img/seq-seq-images/seq2seq.png)

To improve upon this model we'll use an [attention
mechanism](https://arxiv.org/abs/1409.0473), which lets the decoder
learn to focus over a specific range of the input sequence.

**Recommended Reading:**

I assume you have at least installed PyTorch, know Python, and
understand Tensors:

-   <https://pytorch.org/> For installation instructions
-   `/beginner/deep_learning_60min_blitz`
    to get started with PyTorch in general
-   `/beginner/pytorch_with_examples` for
    a wide and deep overview
-   `/beginner/former_torchies_tutorial`
    if you are a former Lua Torch user

It would also be useful to know about Sequence to Sequence networks and
how they work:

-   [Learning Phrase Representations using RNN Encoder-Decoder for
    Statistical Machine Translation](https://arxiv.org/abs/1406.1078)
-   [Sequence to Sequence Learning with Neural
    Networks](https://arxiv.org/abs/1409.3215)
-   [Neural Machine Translation by Jointly Learning to Align and
    Translate](https://arxiv.org/abs/1409.0473)
-   [A Neural Conversational Model](https://arxiv.org/abs/1506.05869)

You will also find the previous tutorials on
`/intermediate/char_rnn_classification_tutorial` and
`/intermediate/char_rnn_generation_tutorial` helpful, as those
concepts are very similar to the Encoder
and Decoder models, respectively.

**Requirements**


In [33]:
from __future__ import unicode_literals, print_function, division
from io import open
import unicodedata
import re
import random

import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F

import numpy as np
from torch.utils.data import TensorDataset, DataLoader, RandomSampler

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Loading data files
==================

The data for this project is a set of many thousands of English to
French translation pairs.

[This question on Open Data Stack
Exchange](https://opendata.stackexchange.com/questions/3888/dataset-of-sentences-translated-into-many-languages)
pointed me to the open translation site <https://tatoeba.org/> which has
downloads available at <https://tatoeba.org/eng/downloads> - and better
yet, someone did the extra work of splitting language pairs into
individual text files here: <https://www.manythings.org/anki/>

The English to French pairs are too big to include in the repository, so
download to `data/eng-fra.txt` before continuing. The file is a tab
separated list of translation pairs:

``` {.sourceCode .sh}
I am cold.    J'ai froid.
```

<div style="background-color: #54c7ec; color: #fff; font-weight: 700; padding-left: 10px; padding-top: 5px; padding-bottom: 5px"><strong>NOTE:</strong></div>
<div style="background-color: #f3f4f7; padding-left: 10px; padding-top: 10px; padding-bottom: 10px; padding-right: 10px">
<p>Download the data from <a href="https://download.pytorch.org/tutorial/data.zip">here</a> and extract it to the current directory.</p>
</div>


Similar to the character encoding used in the character-level RNN
tutorials, we will be representing each word in a language as a one-hot
vector, or giant vector of zeros except for a single one (at the index
of the word). Compared to the dozens of characters that might exist in a
language, there are many many more words, so the encoding vector is much
larger. We will however cheat a bit and trim the data to only use a few
thousand words per language.
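
As a small illustration of the one-hot idea, the sketch below builds such a vector with plain Python lists. The helper name `one_hot` and the toy vocabulary size are assumptions made for this example; the models in this tutorial pass word indexes to `nn.Embedding` instead of materializing these vectors.

```python
def one_hot(index, vocab_size):
    """Return a list of zeros with a single 1 at position `index`."""
    if not 0 <= index < vocab_size:
        raise ValueError("index out of range")
    vec = [0] * vocab_size
    vec[index] = 1
    return vec

# Toy vocabulary of 5 words; the word with index 3 is encoded as:
print(one_hot(3, 5))  # [0, 0, 0, 1, 0]
```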

![](https://pytorch.org/tutorials/_static/img/seq-seq-images/word-encoding.png)


We\'ll need a unique index per word to use as the inputs and targets of
the networks later. To keep track of all this we will use a helper class
called `Lang` which has word → index (`word2index`) and index → word
(`index2word`) dictionaries, as well as a count of each word
`word2count` which will be used to replace rare words later.


In [34]:
SOS_token = 0
EOS_token = 1

class Lang:
    def __init__(self, name):
        self.name = name
        self.word2index = {}
        self.word2count = {}
        self.index2word = {0: "SOS", 1: "EOS"}
        self.n_words = 2  # Count SOS and EOS

    def addSentence(self, sentence):
        for word in sentence.split(' '):
            self.addWord(word)

    def addWord(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.word2count[word] = 1
            self.index2word[self.n_words] = word
            self.n_words += 1
        else:
            self.word2count[word] += 1
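
To see how the bookkeeping evolves, here is a condensed stand-in for the helper class above. The name `TinyLang` and the method `add_sentence` are hypothetical and exist only so the example is self-contained; the logic is intended to mirror `Lang.addSentence` and `Lang.addWord`.

```python
class TinyLang:
    """Condensed stand-in for the `Lang` helper, for illustration only."""
    def __init__(self):
        self.word2index = {}
        self.word2count = {}
        self.index2word = {0: "SOS", 1: "EOS"}
        self.n_words = 2  # indexes 0 and 1 are reserved for SOS and EOS

    def add_sentence(self, sentence):
        for word in sentence.split(' '):
            if word not in self.word2index:
                self.word2index[word] = self.n_words
                self.word2count[word] = 1
                self.index2word[self.n_words] = word
                self.n_words += 1
            else:
                self.word2count[word] += 1

lang = TinyLang()
lang.add_sentence("he is painting a picture")
lang.add_sentence("he is not a poet")

print(lang.word2index["he"])   # 2, the first index after SOS and EOS
print(lang.word2count["a"])    # 2, seen once in each sentence
print(lang.n_words)            # 9 = 2 special tokens + 7 distinct words
```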

The files are all in Unicode. To simplify, we will turn Unicode
characters to ASCII, make everything lowercase, and trim most
punctuation.


In [35]:
# Turn a Unicode string to plain ASCII, thanks to
# https://stackoverflow.com/a/518232/2809427
def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
    )

# Lowercase, trim, and remove non-letter characters
def normalizeString(s):
    s = unicodeToAscii(s.lower().strip())
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z!?]+", r" ", s)
    return s.strip()
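
A quick check of what the normalization does to sample strings may help. The block below repeats the two functions under snake_case names (an assumption, so the example runs on its own) and applies them to two sentences. Note that with this character class, `!` and `?` survive as separate tokens while `.` is dropped.

```python
import re
import unicodedata

def unicode_to_ascii(s):
    # Decompose accented characters and drop the combining marks
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
    )

def normalize_string(s):
    s = unicode_to_ascii(s.lower().strip())
    s = re.sub(r"([.!?])", r" \1", s)      # put a space before . ! ?
    s = re.sub(r"[^a-zA-Z!?]+", r" ", s)   # anything but letters, ! and ? becomes one space
    return s.strip()

print(normalize_string("Vous êtes trop maigre!"))  # vous etes trop maigre !
print(normalize_string("J'ai froid."))             # j ai froid
```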

To read the data file we will split the file into lines, and then split
lines into pairs. The files are all English → Other Language, so to
translate from Other Language → English the `reverse` flag reverses
the pairs.


In [36]:
def readLangs(lang1, lang2, reverse=False):
    print("Reading lines...")

    # Read the file and split into lines
    lines = open('data/%s-%s.txt' % (lang1, lang2), encoding='utf-8').\
        read().strip().split('\n')

    # Split every line into pairs and normalize
    pairs = [[normalizeString(s) for s in l.split('\t')] for l in lines]

    # Reverse pairs, make Lang instances
    if reverse:
        pairs = [list(reversed(p)) for p in pairs]
        input_lang = Lang(lang2)
        output_lang = Lang(lang1)
    else:
        input_lang = Lang(lang1)
        output_lang = Lang(lang2)

    return input_lang, output_lang, pairs
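
The splitting and reversing steps can be tried on in-memory strings without touching the data file. In this sketch the function name `split_pairs` and the sample lines are assumptions, and normalization is left out to keep the focus on the pair order.

```python
def split_pairs(lines, reverse=False):
    """Split tab-separated lines into [source, target] pairs."""
    pairs = [line.split('\t') for line in lines]
    if reverse:
        pairs = [list(reversed(p)) for p in pairs]
    return pairs

lines = ["I am cold.\tJ'ai froid.", "Go.\tVa !"]
print(split_pairs(lines))                 # English first
print(split_pairs(lines, reverse=True))   # other language first
```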

Since there are a *lot* of example sentences and we want to train
something quickly, we'll trim the data set to only relatively short and
simple sentences. Here each sentence must have fewer than `MAX_LENGTH`
(10) words, counting ending punctuation. The `eng_prefixes` check, which
would keep only sentences that translate to the form "I am" or "He is"
etc. (accounting for apostrophes replaced earlier), is commented out in
`filterPair` below, so only the length filter is active.


In [37]:
MAX_LENGTH = 10

eng_prefixes = (
    "i am ", "i m ",
    "he is", "he s ",
    "she is", "she s ",
    "you are", "you re ",
    "we are", "we re ",
    "they are", "they re "
)

def filterPair(p):
    return len(p[0].split(' ')) < MAX_LENGTH and \
        len(p[1].split(' ')) < MAX_LENGTH # and \
        # p[1].startswith(eng_prefixes)


def filterPairs(pairs):
    return [pair for pair in pairs if filterPair(pair)]
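
The difference between the length-only filter and the prefix filter is easy to see on two sample pairs. The function `keep_pair` and its `require_prefix` flag are hypothetical names introduced for this sketch, and the prefix tuple is shortened.

```python
MAX_LENGTH = 10
eng_prefixes = ("i am ", "i m ", "he is", "he s ")  # shortened list for the example

def keep_pair(pair, require_prefix=False):
    short_enough = (len(pair[0].split(' ')) < MAX_LENGTH and
                    len(pair[1].split(' ')) < MAX_LENGTH)
    if not require_prefix:
        return short_enough
    # str.startswith accepts a tuple and returns True if any prefix matches
    return short_enough and pair[1].startswith(eng_prefixes)

pairs = [
    ["hacemos esto cada lunes", "we do this every monday"],
    ["el es pintor", "he is a painter"],
]
print([keep_pair(p) for p in pairs])                       # [True, True]
print([keep_pair(p, require_prefix=True) for p in pairs])  # [False, True]
```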

The full process for preparing the data is:

-   Read text file and split into lines, split lines into pairs
-   Normalize text, filter by length and content
-   Make word lists from sentences in pairs


In [38]:
def prepareData(lang1, lang2, reverse=False):
    input_lang, output_lang, pairs = readLangs(lang1, lang2, reverse)
    print("Read %s sentence pairs" % len(pairs))
    pairs = filterPairs(pairs)
    print("Trimmed to %s sentence pairs" % len(pairs))
    print("Counting words...")
    for pair in pairs:
        input_lang.addSentence(pair[0])
        output_lang.addSentence(pair[1])
    print("Counted words:")
    print(input_lang.name, input_lang.n_words)
    print(output_lang.name, output_lang.n_words)
    return input_lang, output_lang, pairs

input_lang, output_lang, pairs = prepareData('eng', 'spa', True)
print(random.choice(pairs))

Reading lines...
Read 118121 sentence pairs
Trimmed to 99195 sentence pairs
Counting words...
Counted words:
spa 21337
eng 11307
['hacemos esto cada lunes', 'we do this every monday']


The Seq2Seq Model
=================

A Recurrent Neural Network, or RNN, is a network that operates on a
sequence and uses its own output as input for subsequent steps.

A [Sequence to Sequence network](https://arxiv.org/abs/1409.3215), or
seq2seq network, or [Encoder Decoder
network](https://arxiv.org/pdf/1406.1078v3.pdf), is a model consisting
of two RNNs called the encoder and decoder. The encoder reads an input
sequence and outputs a single vector, and the decoder reads that vector
to produce an output sequence.

![](https://pytorch.org/tutorials/_static/img/seq-seq-images/seq2seq.png)

Unlike sequence prediction with a single RNN, where every input
corresponds to an output, the seq2seq model frees us from sequence
length and order, which makes it ideal for translation between two
languages.

Consider the sentence `Je ne suis pas le chat noir` →
`I am not the black cat`. Most of the words in the input sentence have a
direct translation in the output sentence, but are in slightly different
orders, e.g. `chat noir` and `black cat`. Because of the `ne/pas`
construction there is also one more word in the input sentence. It would
be difficult to produce a correct translation directly from the sequence
of input words.

With a seq2seq model the encoder creates a single vector which, in the
ideal case, encodes the "meaning" of the input sequence: a single
point in some N-dimensional space of sentences.


The Encoder
===========

The encoder of a seq2seq network is an RNN that outputs some value for
every word from the input sentence. For every input word the encoder
outputs a vector and a hidden state, and uses the hidden state for the
next input word.

![](https://pytorch.org/tutorials/_static/img/seq-seq-images/encoder-network.png)


In [39]:
class EncoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size, dropout_p=0.1):
        super(EncoderRNN, self).__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, input):
        embedded = self.dropout(self.embedding(input))
        output, hidden = self.gru(embedded)
        return output, hidden

The Decoder
===========

The decoder is another RNN that takes the encoder output vector(s) and
outputs a sequence of words to create the translation.


Simple Decoder
==============

In the simplest seq2seq decoder we use only the last output of the encoder.
This last output is sometimes called the *context vector* as it encodes
context from the entire sequence. This context vector is used as the
initial hidden state of the decoder.

At every step of decoding, the decoder is given an input token and
hidden state. The initial input token is the start-of-string `<SOS>`
token, and the first hidden state is the context vector (the encoder\'s
last hidden state).

![](https://pytorch.org/tutorials/_static/img/seq-seq-images/decoder-network.png)


In [40]:
class DecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size):
        super(DecoderRNN, self).__init__()
        self.embedding = nn.Embedding(output_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.out = nn.Linear(hidden_size, output_size)

    def forward(self, encoder_outputs, encoder_hidden, target_tensor=None):
        batch_size = encoder_outputs.size(0)
        decoder_input = torch.empty(batch_size, 1, dtype=torch.long, device=device).fill_(SOS_token)
        decoder_hidden = encoder_hidden
        decoder_outputs = []

        for i in range(MAX_LENGTH):
            decoder_output, decoder_hidden  = self.forward_step(decoder_input, decoder_hidden)
            decoder_outputs.append(decoder_output)

            if target_tensor is not None:
                # Teacher forcing: Feed the target as the next input
                decoder_input = target_tensor[:, i].unsqueeze(1) # Teacher forcing
            else:
                # Without teacher forcing: use its own predictions as the next input
                _, topi = decoder_output.topk(1)
                decoder_input = topi.squeeze(-1).detach()  # detach from history as input

        decoder_outputs = torch.cat(decoder_outputs, dim=1)
        decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)
        return decoder_outputs, decoder_hidden, None # We return `None` for consistency in the training loop

    def forward_step(self, input, hidden):
        output = self.embedding(input)
        output = F.relu(output)
        output, hidden = self.gru(output, hidden)
        output = self.out(output)
        return output, hidden

I encourage you to train and observe the results of this model, but to
save space we\'ll be going straight for the gold and introducing the
Attention Mechanism.


Attention Decoder
=================

If only the context vector is passed between the encoder and decoder,
that single vector carries the burden of encoding the entire sentence.

Attention allows the decoder network to "focus" on a different part of
the encoder's outputs for every step of the decoder's own outputs.
First we calculate a set of *attention weights*. These will be
multiplied by the encoder output vectors to create a weighted
combination. The result (called `context` in the code) should
contain information about that specific part of the input sequence, and
thus help the decoder choose the right output words.

![](https://i.imgur.com/1152PYf.png)

Calculating the attention weights is done with another feed-forward
layer `attn`, using the decoder\'s input and hidden state as inputs.
Because there are sentences of all sizes in the training data, to
actually create and train this layer we have to choose a maximum
sentence length (input length, for encoder outputs) that it can apply
to. Sentences of the maximum length will use all the attention weights,
while shorter sentences will only use the first few.

![](https://pytorch.org/tutorials/_static/img/seq-seq-images/attention-decoder-network.png)

Bahdanau attention, also known as additive attention, is a commonly used
attention mechanism in sequence-to-sequence models, particularly in
neural machine translation tasks. It was introduced by Bahdanau et al.
in their paper titled [Neural Machine Translation by Jointly Learning to
Align and Translate](https://arxiv.org/pdf/1409.0473.pdf). This
attention mechanism employs a learned alignment model to compute
attention scores between the encoder and decoder hidden states. It
utilizes a feed-forward neural network to calculate alignment scores.

However, there are alternative attention mechanisms available, such as
Luong attention, which computes attention scores by taking the dot
product between the decoder hidden state and the encoder hidden states.
It does not involve the non-linear transformation used in Bahdanau
attention.

In this tutorial, we will be using Bahdanau attention. However, it would
be a valuable exercise to explore modifying the attention mechanism to
use Luong attention.
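
As a starting point for that exercise, the sketch below shows the arithmetic of dot-product scoring on plain Python lists: score each encoder state by its dot product with the decoder state, apply a softmax, and take the weighted sum. The function name `dot_product_attention` and the toy vectors are assumptions; this is not a drop-in replacement for the `BahdanauAttention` module, which works on batched tensors.

```python
import math

def dot_product_attention(query, keys):
    """query: list of floats; keys: list of equally sized lists of floats."""
    # Score each key by its dot product with the query
    scores = [sum(q * k for q, k in zip(query, key)) for key in keys]
    # Softmax over the scores (shifted by the max for numerical stability)
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    # Context vector: weighted sum of the keys
    dim = len(query)
    context = [sum(w * key[d] for w, key in zip(weights, keys)) for d in range(dim)]
    return context, weights

query = [1.0, 0.0]
keys = [[2.0, 0.0], [0.0, 2.0], [0.0, 0.0]]
context, weights = dot_product_attention(query, keys)
print(weights)   # the first key gets the largest weight
print(context)
```

Unlike the additive form, this scoring has no learned parameters of its own, which is one reason it is a common first variation to try.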


In [41]:
class BahdanauAttention(nn.Module):
    def __init__(self, hidden_size):
        super(BahdanauAttention, self).__init__()
        self.Wa = nn.Linear(hidden_size, hidden_size)
        self.Ua = nn.Linear(hidden_size, hidden_size)
        self.Va = nn.Linear(hidden_size, 1)

    def forward(self, query, keys):
        scores = self.Va(torch.tanh(self.Wa(query) + self.Ua(keys)))
        scores = scores.squeeze(2).unsqueeze(1)

        weights = F.softmax(scores, dim=-1)
        context = torch.bmm(weights, keys)

        return context, weights

class AttnDecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size, dropout_p=0.1):
        super(AttnDecoderRNN, self).__init__()
        self.embedding = nn.Embedding(output_size, hidden_size)
        self.attention = BahdanauAttention(hidden_size)
        self.gru = nn.GRU(2 * hidden_size, hidden_size, batch_first=True)
        self.out = nn.Linear(hidden_size, output_size)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, encoder_outputs, encoder_hidden, target_tensor=None):
        batch_size = encoder_outputs.size(0)
        decoder_input = torch.empty(batch_size, 1, dtype=torch.long, device=device).fill_(SOS_token)
        decoder_hidden = encoder_hidden
        decoder_outputs = []
        attentions = []

        for i in range(MAX_LENGTH):
            decoder_output, decoder_hidden, attn_weights = self.forward_step(
                decoder_input, decoder_hidden, encoder_outputs
            )
            decoder_outputs.append(decoder_output)
            attentions.append(attn_weights)

            if target_tensor is not None:
                # Teacher forcing: Feed the target as the next input
                decoder_input = target_tensor[:, i].unsqueeze(1) # Teacher forcing
            else:
                # Without teacher forcing: use its own predictions as the next input
                _, topi = decoder_output.topk(1)
                decoder_input = topi.squeeze(-1).detach()  # detach from history as input

        decoder_outputs = torch.cat(decoder_outputs, dim=1)
        decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)
        attentions = torch.cat(attentions, dim=1)

        return decoder_outputs, decoder_hidden, attentions


    def forward_step(self, input, hidden, encoder_outputs):
        embedded =  self.dropout(self.embedding(input))

        query = hidden.permute(1, 0, 2)
        context, attn_weights = self.attention(query, encoder_outputs)
        input_gru = torch.cat((embedded, context), dim=2)

        output, hidden = self.gru(input_gru, hidden)
        output = self.out(output)

        return output, hidden, attn_weights

<div style="background-color: #54c7ec; color: #fff; font-weight: 700; padding-left: 10px; padding-top: 5px; padding-bottom: 5px"><strong>NOTE:</strong></div>
<div style="background-color: #f3f4f7; padding-left: 10px; padding-top: 10px; padding-bottom: 10px; padding-right: 10px">
<p>There are other forms of attention that work around the length limitation by using a relative position approach. Read about "local attention" in <a href="https://arxiv.org/abs/1508.04025">Effective Approaches to Attention-based Neural Machine Translation</a>.</p>
</div>

Training
========

Preparing Training Data
-----------------------

To train, for each pair we will need an input tensor (indexes of the
words in the input sentence) and target tensor (indexes of the words in
the target sentence). While creating these vectors we will append the
EOS token to both sequences.


In [42]:
def indexesFromSentence(lang, sentence):
    return [lang.word2index[word] for word in sentence.split(' ')]

def tensorFromSentence(lang, sentence):
    indexes = indexesFromSentence(lang, sentence)
    indexes.append(EOS_token)
    return torch.tensor(indexes, dtype=torch.long, device=device).view(1, -1)

def tensorsFromPair(pair):
    input_tensor = tensorFromSentence(input_lang, pair[0])
    target_tensor = tensorFromSentence(output_lang, pair[1])
    return (input_tensor, target_tensor)

def get_dataloader(batch_size, language='spa'):
    input_lang, output_lang, pairs = prepareData('eng', language, True)

    n = len(pairs)
    input_ids = np.zeros((n, MAX_LENGTH), dtype=np.int32)
    target_ids = np.zeros((n, MAX_LENGTH), dtype=np.int32)

    for idx, (inp, tgt) in enumerate(pairs):
        inp_ids = indexesFromSentence(input_lang, inp)
        tgt_ids = indexesFromSentence(output_lang, tgt)
        inp_ids.append(EOS_token)
        tgt_ids.append(EOS_token)
        input_ids[idx, :len(inp_ids)] = inp_ids
        target_ids[idx, :len(tgt_ids)] = tgt_ids

    train_data = TensorDataset(torch.LongTensor(input_ids).to(device),
                               torch.LongTensor(target_ids).to(device))

    train_sampler = RandomSampler(train_data)
    train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)
    return input_lang, output_lang, train_dataloader, pairs
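
The row layout that `get_dataloader` builds can be illustrated on a single sentence. The helper `pad_with_eos` is a hypothetical name for this sketch. Because the arrays start as zeros, the padding value coincides with `SOS_token` (0) in this notebook.

```python
MAX_LENGTH = 10
EOS_token = 1

def pad_with_eos(indexes, max_length=MAX_LENGTH, pad_value=0):
    """Append EOS and right-pad with `pad_value` up to `max_length`."""
    ids = list(indexes) + [EOS_token]
    if len(ids) > max_length:
        raise ValueError("sentence too long for max_length")
    return ids + [pad_value] * (max_length - len(ids))

row = pad_with_eos([5, 12, 7])
print(row)  # [5, 12, 7, 1, 0, 0, 0, 0, 0, 0]
```

Since `filterPair` keeps only sentences with fewer than `MAX_LENGTH` words, every sentence plus its EOS token fits in one row.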

Training the Model
==================

To train we run the input sentence through the encoder, and keep track
of every output and the latest hidden state. Then the decoder is given
the `<SOS>` token as its first input, and the last hidden state of the
encoder as its first hidden state.

\"Teacher forcing\" is the concept of using the real target outputs as
each next input, instead of using the decoder\'s guess as the next
input. Using teacher forcing causes it to converge faster but [when the
trained network is exploited, it may exhibit
instability](http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.378.4095&rep=rep1&type=pdf).

You can observe outputs of teacher-forced networks that read with
coherent grammar but wander far from the correct translation.
Intuitively, the network has learned to represent the output grammar and
can "pick up" the meaning once the teacher tells it the first few words,
but it has not properly learned how to create the sentence from the
translation in the first place.

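Because of the freedom PyTorch's autograd gives us, we can randomly
choose to use teacher forcing or not with a simple if statement. The
decoders above apply teacher forcing whenever `target_tensor` is passed
to `forward`; a `teacher_forcing_ratio` probability can be introduced to
mix the two modes, and turning it up uses more teacher forcing.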
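
A minimal sketch of that random choice for a single decoding step, assuming a `teacher_forcing_ratio` between 0 and 1. The function name `next_decoder_input` and the integer token ids are illustrative; in the decoders above the same decision would be made on tensors inside the loop.

```python
import random

def next_decoder_input(target_token, predicted_token, teacher_forcing_ratio, rng=random):
    """Pick the next decoder input for one decoding step."""
    use_teacher_forcing = rng.random() < teacher_forcing_ratio
    return target_token if use_teacher_forcing else predicted_token

rng = random.Random(0)
# With ratio 1.0 the target is always fed back; with 0.0 the prediction always is.
print(next_decoder_input(7, 3, 1.0, rng))  # 7
print(next_decoder_input(7, 3, 0.0, rng))  # 3
```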


In [43]:
def train_epoch(dataloader, encoder, decoder, encoder_optimizer,
          decoder_optimizer, criterion):

    total_loss = 0
    for data in dataloader:
        input_tensor, target_tensor = data

        encoder_optimizer.zero_grad()
        decoder_optimizer.zero_grad()

        encoder_outputs, encoder_hidden = encoder(input_tensor)
        decoder_outputs, _, _ = decoder(encoder_outputs, encoder_hidden, target_tensor)

        loss = criterion(
            decoder_outputs.view(-1, decoder_outputs.size(-1)),
            target_tensor.view(-1)
        )
        loss.backward()

        encoder_optimizer.step()
        decoder_optimizer.step()

        total_loss += loss.item()

    return total_loss / len(dataloader)

This is a helper function to print time elapsed and estimated time
remaining given the current time and progress %.


In [44]:
import time
import math

def asMinutes(s):
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)

def timeSince(since, percent):
    now = time.time()
    s = now - since
    es = s / (percent)
    rs = es - s
    return '%s (- %s)' % (asMinutes(s), asMinutes(rs))
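
The arithmetic behind the estimate can be checked with fixed numbers instead of the wall clock. The names `as_minutes` and `remaining` are assumptions for this sketch; `remaining` isolates the calculation that `timeSince` performs after reading the current time.

```python
import math

def as_minutes(seconds):
    m = math.floor(seconds / 60)
    seconds -= m * 60
    return '%dm %ds' % (m, seconds)

def remaining(elapsed, fraction_done):
    """Estimate remaining seconds from elapsed seconds and progress in (0, 1]."""
    estimated_total = elapsed / fraction_done
    return estimated_total - elapsed

print(as_minutes(125))                  # 2m 5s
print(as_minutes(remaining(30, 0.25)))  # 1m 30s
```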

In [45]:
class ClientUpdate(object):
    def __init__(self, train_dataloader, learning_rate, epochs, sch_flag):
        self.train_loader = train_dataloader
        self.learning_rate = learning_rate
        self.epochs = epochs
        self.sch_flag = sch_flag

    def train(self, encoder, decoder):

        criterion = nn.CrossEntropyLoss()
        # optimizer = torch.optim.SGD(model.parameters(), lr=self.learning_rate, momentum=0.95, weight_decay = 5e-4)
        encoder_optimizer = optim.Adam(encoder.parameters(), lr=self.learning_rate)
        decoder_optimizer = optim.Adam(decoder.parameters(), lr=self.learning_rate)
        # if self.sch_flag == True:
        #    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5)
        # my_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.99)
        e_loss = []
        for epoch in range(1, self.epochs + 1):
            loss = train_epoch(self.train_loader, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion)
            e_loss.append(loss)

            # print_loss_total += loss
            # plot_loss_total += loss

            # if epoch % print_every == 0:
            #     print_loss_avg = print_loss_total / print_every
            #     print_loss_total = 0
            #     print('%s (%d %d%%) %.4f' % (timeSince(start, epoch / n_epochs),
            #                                 epoch, epoch / n_epochs * 100, print_loss_avg))

            # if epoch % plot_every == 0:
            #     plot_loss_avg = plot_loss_total / plot_every
            #     plot_losses.append(plot_loss_avg)
            #     plot_loss_total = 0

            total_loss = sum(e_loss) / len(e_loss)

        return encoder.state_dict(), decoder.state_dict(), total_loss, torch.tensor(len(self.train_loader.sampler))

In [46]:
class ClientFinetune(object):
    def __init__(self, train_dataloader, learning_rate, epochs, sch_flag):
        self.train_loader = train_dataloader
        self.learning_rate = learning_rate
        self.epochs = epochs
        self.sch_flag = sch_flag

    def train(self, encoder, decoder):

        criterion = nn.CrossEntropyLoss()
        # optimizer = torch.optim.SGD(model.parameters(), lr=self.learning_rate, momentum=0.95, weight_decay = 5e-4)
        encoder_optimizer = optim.Adam(encoder.parameters(), lr=self.learning_rate)
        decoder_optimizer = optim.Adam(decoder.parameters(), lr=self.learning_rate)
        # if self.sch_flag == True:
        #    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5)
        # my_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.99)
        e_loss = []
        for epoch in range(1, self.epochs + 1):
            loss = train_epoch(self.train_loader, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion)
            e_loss.append(loss)

            # print_loss_total += loss
            # plot_loss_total += loss

            # if epoch % print_every == 0:
            #     print_loss_avg = print_loss_total / print_every
            #     print_loss_total = 0
            #     print('%s (%d %d%%) %.4f' % (timeSince(start, epoch / n_epochs),
            #                                 epoch, epoch / n_epochs * 100, print_loss_avg))

            # if epoch % plot_every == 0:
            #     plot_loss_avg = plot_loss_total / plot_every
            #     plot_losses.append(plot_loss_avg)
            #     plot_loss_total = 0

            total_loss = sum(e_loss) / len(e_loss)

        return encoder.state_dict(), decoder.state_dict(), total_loss

The whole training process looks like this:

-   Start a timer
-   Initialize optimizers and criterion
-   Create set of training pairs
-   Start empty losses array for plotting

Then we call `train_epoch` once per epoch and occasionally print the
progress (% of epochs, time so far, estimated time) and average loss.


In [47]:
from torchtext.data.metrics import bleu_score
def evaluateBleu(encoder, decoder, input_lang, output_lang, pairs, n=10, verbose=False):
    references = []
    candidates = []

    for _ in range(n):
        pair = random.choice(pairs)
        output_words, _ = evaluate(encoder, decoder, pair[0], input_lang, output_lang)
        output_sentence = ' '.join(output_words).split(' ')
        if verbose:
            print('>', pair[0])
            print('=', pair[1])
            print('<', ' '.join(output_sentence))
            print('')

        # Store the reference and candidate sentences for BLEU calculation
        references.append([pair[1].split(' ')])
        candidates.append(output_sentence)

    # Calculate the BLEU score
    score = bleu_score(candidates, references)
    return score
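
BLEU is built from clipped (modified) n-gram precisions, combined over several values of n together with a brevity penalty. The sketch below computes only that precision for one candidate and one reference, as a way to build intuition for the score. The function `clipped_precision` is a hypothetical helper and not part of `torchtext`.

```python
from collections import Counter

def clipped_precision(candidate, reference, n=1):
    """Modified n-gram precision of `candidate` against a single `reference`."""
    def ngrams(tokens):
        return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))

    cand_counts = ngrams(candidate)
    ref_counts = ngrams(reference)
    total = sum(cand_counts.values())
    if total == 0:
        return 0.0
    # Each candidate n-gram is credited at most as often as it occurs in the reference
    matched = sum(min(count, ref_counts[gram]) for gram, count in cand_counts.items())
    return matched / total

candidate = "she not not a poet".split(' ')
reference = "she is not a poet".split(' ')
print(clipped_precision(candidate, reference, n=1))  # 0.8
print(clipped_precision(candidate, reference, n=2))  # 0.5
```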

In [48]:
import csv
def train(train_dataloader, encoder, decoder, n_epochs, input, output, pairs, filename=None, learning_rate=0.001,
               print_every=100, plot_every=100):
    start = time.time()
    plot_losses = []
    print_loss_total = 0  # Reset every print_every
    plot_loss_total = 0  # Reset every plot_every
    best_bleu = 0

    encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate)
    criterion = nn.NLLLoss()

    for epoch in range(1, n_epochs + 1):
        loss = train_epoch(train_dataloader, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion)
        print_loss_total += loss
        plot_loss_total += loss

        bleu = evaluateBleu(encoder, decoder, input, output, pairs, n=30)

        if best_bleu < bleu:
            best_bleu = bleu

        if epoch % print_every == 0:
            print_loss_avg = print_loss_total / print_every
            print_loss_total = 0
            print('%s (%d %d%%) %.4f' % (timeSince(start, epoch / n_epochs),
                                        epoch, epoch / n_epochs * 100, print_loss_avg))

        if epoch % plot_every == 0:
            plot_loss_avg = plot_loss_total / plot_every
            plot_losses.append(plot_loss_avg)
            plot_loss_total = 0


        if filename is not None:
            # newline='' lets the csv module control line endings itself
            with open(filename, 'a', newline='') as f:
                # create the csv writer
                writer = csv.writer(f)

                # write a row to the csv file
                writer.writerow([epoch, loss, bleu, best_bleu])

    showPlot(plot_losses)

Plotting results
================

Plotting is done with matplotlib, using the array of loss values
`plot_losses` saved while training.


In [49]:
import matplotlib.pyplot as plt
plt.switch_backend('agg')
import matplotlib.ticker as ticker
import numpy as np

def showPlot(points):
    # plt.subplots() already creates the figure, so no separate plt.figure() call is needed
    fig, ax = plt.subplots()
    # this locator puts ticks at regular intervals
    loc = ticker.MultipleLocator(base=0.2)
    ax.yaxis.set_major_locator(loc)
    plt.plot(points)

Evaluation
==========

Evaluation is mostly the same as training, but there are no targets so
we simply feed the decoder\'s predictions back to itself for each step.
Every time it predicts a word we add it to the output string, and if it
predicts the EOS token we stop there. We also store the decoder\'s
attention outputs for display later.


In [50]:
def evaluate(encoder, decoder, sentence, input_lang, output_lang):
    with torch.no_grad():
        input_tensor = tensorFromSentence(input_lang, sentence)

        encoder_outputs, encoder_hidden = encoder(input_tensor)
        decoder_outputs, decoder_hidden, decoder_attn = decoder(encoder_outputs, encoder_hidden)

        _, topi = decoder_outputs.topk(1)
        decoded_ids = topi.squeeze()

        decoded_words = []
        for idx in decoded_ids:
            if idx.item() == EOS_token:
                decoded_words.append('<EOS>')
                break
            decoded_words.append(output_lang.index2word[idx.item()])
    return decoded_words, decoder_attn

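To track quality across all languages at once, we average the score that
`evaluateBleu` returns for each encoder/decoder pair. The function also
prints the unaveraged sum, which is the bare number that appears before
each round\'s summary line in the training log: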


In [51]:
def evaluateMultipleBleu(encoders, decoders, input_output_langs, n=10):
    bleu_sum = 0
    lang_length = len(encoders)
    for i in range(lang_length):
        bleu_sum += evaluateBleu(encoders[i], decoders[i], input_output_langs[i][0], input_output_langs[i][1], input_output_langs[i][2], n=n)
    print(bleu_sum)
    return bleu_sum / lang_length
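
Since the function prints the sum but returns the mean, the two numbers logged for each round differ by a factor equal to the number of languages. A small sketch with made-up per-language scores (the names `macro_average` and `scores` are illustrative, not part of the notebook):

```python
import math

def macro_average(scores):
    """Unweighted mean of per-language scores."""
    return sum(scores) / len(scores)

# Made-up per-language BLEU scores, for illustration only
scores = [0.30, 0.45, 0.15]
total = sum(scores)
mean = macro_average(scores)
print(total, mean)

# The mean is the sum divided by the number of languages
assert math.isclose(mean * len(scores), total)
```

In the training log later in this notebook, for example, round 4 prints 0.0720... followed by a BLEU of 0.0240..., a ratio of three that matches the three languages.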
        

## Federated Learning
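
Each communication round of federated averaging starts by choosing which clients take part. A minimal sketch of that step, assuming the same meaning of `C` (fraction of clients per round) and `K` (total number of clients) as in the function below; it uses the standard library's `random.sample` rather than NumPy, and the helper name `sample_clients` is hypothetical:

```python
import random

def sample_clients(K, C, rng=random):
    """Pick max(int(C * K), 1) distinct client indices out of range(K)."""
    m = max(int(C * K), 1)
    return sorted(rng.sample(range(K), m))

rng = random.Random(0)  # seeded only to make the example repeatable
print(sample_clients(K=3, C=1.0, rng=rng))    # all three clients
print(sample_clients(K=10, C=0.25, rng=rng))  # two of ten clients
print(sample_clients(K=10, C=0.01, rng=rng))  # at least one client is always chosen
```

With `C=1.0` and `K=3`, the settings used for the run in this notebook, every client trains in every round.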

In [61]:
import copy
from tqdm import tqdm
import csv

def training(encoders, decoders, input_output_lang, rounds, lr, ds, C, K, E, filename=None, batch_size=None, hidden_size=None, cifar_data_test = None,
             test_batch_size = None, classes_test = None, sch_flag = None, weighting=False, lexical_weights=None):
    """
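    Server-side loop of the Federated Averaging (FedAvg) algorithm.

    Each client is one language pair with its own encoder and decoder. Only
    state-dict entries whose key does not contain 'embedding' (and, for the
    decoder, 'out') are shared and averaged; the other layers stay local.

    Params:
      - encoders:          Dict of per-client encoders, keyed by client index
      - decoders:          Dict of per-client decoders, keyed by client index
      - input_output_lang: Dict of per-client (input_lang, output_lang, pairs) tuples
      - rounds:            Number of communication rounds
      - lr:                Learning rate used for client update training
      - ds:                Dict of per-client training dataloaders
      - C:                 Fraction of clients randomly chosen to perform computation on each round
      - K:                 Total number of clients
      - E:                 Number of training passes each client makes over its local dataset per round
      - filename:          Optional CSV file that receives the settings and per-round metrics
      - batch_size, hidden_size: Only written to the CSV settings row
      - sch_flag:          Passed through to ClientUpdate
      - weighting:         If True, weight each client by its share of the training pairs
      - lexical_weights:   Optional per-client weights that rescale the client weights
                           (only has an effect when weighting is True)
      - cifar_data_test, test_batch_size, classes_test: Unused
    Returns:
      - global_encoder_weights, global_decoder_weights: The averaged shared weights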
    """

    # global model weights
    global_encoder_weights = {key: value for key, value in encoders[0].state_dict().items() if 'embedding' not in key}
    global_decoder_weights = {key: value for key, value in decoders[0].state_dict().items() if 'embedding' not in key and 'out' not in key}

    # training loss
    train_loss = []
    test_loss = []
    test_accuracy = []
    best_bleu = 0
    # measure time
    start = time.time()

    lex_weighting = False
    if lexical_weights is not None:
        lex_weighting = True

    if filename is not None:
        with open(filename, 'a') as f:
            # create the csv writer
            writer = csv.writer(f)

            # write the settings header and the values used for this run
            writer.writerow(['Rounds', 'Learning Rate', 'Client Fraction', 'Client Number', 'Local Epochs', 'Batch Size', 'Hidden Size', 'Weighting', 'Lexical Weighting'])
            writer.writerow([rounds, lr, C, K, E, batch_size, hidden_size, weighting, lex_weighting])

    for curr_round in range(1, rounds + 1):
        w_encoder, w_decoder, local_loss, num_pairs, round_lex_weight = [], [], [], [], []
        # Number of clients participating in this round (at least one)
        m = max(int(C * K), 1)
        # Sample m of the K clients uniformly, without replacement
        S_t = np.random.choice(range(K), m, replace=False)
        # Run local training on each selected client
        for k in tqdm(S_t):
            # Compute a local update
            local_update = ClientUpdate(train_dataloader=ds[k], learning_rate=lr, epochs=E,
                                        sch_flag=sch_flag)
            # Load the current global (shared) weights into this client's models before local training
            
            e_og = encoders[k].state_dict()
            d_og = decoders[k].state_dict()
            e_og.update(global_encoder_weights)
            d_og.update(global_decoder_weights)
            encoders[k].load_state_dict(e_og)
            decoders[k].load_state_dict(d_og)
            
            encoder_weights, decoder_weights, loss, num = local_update.train(encoders[k], decoders[k])

            w_encoder.append({key: value for key, value in copy.deepcopy(encoder_weights).items() if 'embedding' not in key})
            w_decoder.append({key: value for key, value in copy.deepcopy(decoder_weights).items() if 'embedding' not in key and 'out' not in key})
            local_loss.append(copy.deepcopy(loss))
            num_pairs.append(num)
            if lex_weighting:
                round_lex_weight.append(lexical_weights[k])
        # lr = 0.999*lr
        # updating the global weights
        weights_avg_e = copy.deepcopy(w_encoder[0])

        weights_sum = sum(num_pairs)
        client_weights = [weight / weights_sum for weight in num_pairs]
        if lex_weighting:
            # Rescale each client's weight by (100 - lexical weight) / (sum of lexical weights) * K
            inverse_lex = [100 - weight for weight in round_lex_weight]
            sum_lex = sum(round_lex_weight)
            client_weights = [weight * (inv / sum_lex) * K for weight, inv in zip(client_weights, inverse_lex)]

        for k in weights_avg_e.keys():
            if weighting:
                weights_avg_e[k] *= client_weights[0]
                for i in range(1, len(w_encoder)):
                    weights_avg_e[k] += w_encoder[i][k] * client_weights[i]
                
            else:
                for i in range(1, len(w_encoder)):
                    weights_avg_e[k] += w_encoder[i][k]
                
                weights_avg_e[k] = torch.div(weights_avg_e[k], len(w_encoder))

        global_encoder_weights = weights_avg_e

        weights_avg_d = copy.deepcopy(w_decoder[0])
        for k in weights_avg_d.keys():
            if weighting:
                weights_avg_d[k] *= client_weights[0]
                for i in range(1, len(w_decoder)):
                    weights_avg_d[k] += w_decoder[i][k] * client_weights[i]
                
            else:
                for i in range(1, len(w_decoder)):
                    weights_avg_d[k] += w_decoder[i][k]

                weights_avg_d[k] = torch.div(weights_avg_d[k], len(w_decoder))

        global_decoder_weights = weights_avg_d
        

        # if curr_round == 200:
        #     lr = lr / 2
        #     E = E - 1

        # if curr_round == 300:
        #     lr = lr / 2
        #     E = E - 2

        # if curr_round == 400:
        #     lr = lr / 5
        #     E = E - 3

        # move the updated weights to our model state dict
        # encoder.load_state_dict(global_encoder_weights)
        # decoder.load_state_dict(global_decoder_weights)

        # loss
        loss_avg = sum(local_loss) / len(local_loss)
        # print('Round: {}... \tAverage Loss: {}'.format(curr_round, round(loss_avg, 3)), lr)
        train_loss.append(loss_avg)

        # t_accuracy, t_loss = testing(model, cifar_data_test, test_batch_size, criterion, num_classes, classes_test)
        # test_accuracy.append(t_accuracy)
        # test_loss.append(t_loss)

        # if best_accuracy < t_accuracy:
        #     best_accuracy = t_accuracy
        # # torch.save(model.state_dict(), plt_title)
        # print(curr_round, loss_avg, t_loss, test_accuracy[0], best_accuracy)
        # # print('best_accuracy:', best_accuracy, '---Round:', curr_round, '---lr', lr, '----localEpocs--', E)
        bleu = evaluateMultipleBleu(encoders, decoders, input_output_lang, n=30)

        if best_bleu < bleu:
            best_bleu = bleu

        if filename is not None:
            with open(filename, 'a') as f:
                # create the csv writer
                writer = csv.writer(f)

                # write a row to the csv file
                writer.writerow([curr_round, loss_avg, bleu, best_bleu])
        print(f"Round {curr_round} >> Loss: {loss_avg}, BLEU:{bleu}")

    end = time.time()
    print("Training Done!")
    print("Total time taken to Train: {}".format(end - start))

    return global_encoder_weights, global_decoder_weights
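
The aggregation step can be illustrated on plain Python numbers. The sketch below is not the function above; it is a simplified stand-in in which each client's state is a dict of lists of floats instead of tensors, and the names `shared_keys` and `fedavg` are hypothetical. It shows the two ideas the function relies on: filtering out vocabulary-specific keys, and averaging the rest either uniformly or weighted by each client's number of training pairs.

```python
def shared_keys(state, excluded=('embedding',)):
    """Keep only entries whose key contains none of the excluded substrings."""
    return {k: v for k, v in state.items()
            if not any(name in k for name in excluded)}

def fedavg(states, num_pairs=None):
    """Average a list of state dicts key by key.

    With num_pairs, client i gets weight num_pairs[i] / sum(num_pairs);
    otherwise every client gets weight 1 / len(states).
    """
    if num_pairs is None:
        weights = [1.0 / len(states)] * len(states)
    else:
        total = sum(num_pairs)
        weights = [n / total for n in num_pairs]
    averaged = {}
    for key in states[0]:
        length = len(states[0][key])
        averaged[key] = [
            sum(w * s[key][j] for w, s in zip(weights, states))
            for j in range(length)
        ]
    return averaged

# Two toy clients, for illustration only
client_a = {'embedding.weight': [9.0], 'gru.weight': [1.0, 2.0]}
client_b = {'embedding.weight': [7.0], 'gru.weight': [3.0, 6.0]}
states = [shared_keys(client_a), shared_keys(client_b)]

print(fedavg(states))                    # uniform average
print(fedavg(states, num_pairs=[3, 1]))  # weighted by training pairs
```

Excluding the embedding layers (and, for the decoder, the output layer) matters here because each language has its own vocabulary size, so those tensors have different shapes across clients and cannot be averaged element-wise.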

Training and Evaluating
=======================

With all these helper functions in place (it looks like extra work, but
it makes it easier to run multiple experiments) we can actually
initialize a network and start training.

Remember that the input sentences were heavily filtered. For this small
dataset we can use relatively small networks of 128 hidden nodes and a
single GRU layer. After about 40 minutes on a MacBook CPU we\'ll get
some reasonable results.

<div style="background-color: #54c7ec; color: #fff; font-weight: 700; padding-left: 10px; padding-top: 5px; padding-bottom: 5px"><strong>NOTE:</strong></div>
<div style="background-color: #f3f4f7; padding-left: 10px; padding-top: 10px; padding-bottom: 10px; padding-right: 10px">
<p>If you run this notebook you can train, interrupt the kernel, evaluate, and continue training later. Comment out the lines where the encoder and decoder are initialized and run <code>trainIters</code> again.</p>
</div>


In [53]:
hidden_size = 128
batch_size = 32
# input_lang, output_lang, train_dataloader = get_dataloader(batch_size, language='spa')

In [65]:
data_dict = {}
encoders = {}
decoders = {}
lexical_weight_dict = {}
input_output_lang = {}
langs = ['bem', 'kin', 'lug']
langs = sorted(langs)
weights = [42, 6, 43] # This needs to line up with the sorted langs
K = len(langs)
for i in range(K):
    input_lang, output_lang, train_dataloader, pair = get_dataloader(batch_size, language=langs[i])
    data_dict[i] = train_dataloader
    encoders[i] = EncoderRNN(input_lang.n_words, hidden_size).to(device)
    decoders[i] = AttnDecoderRNN(hidden_size, output_lang.n_words).to(device)
    input_output_lang[i] = (input_lang, output_lang, pair)
    lexical_weight_dict[i] = weights[i]

Reading lines...
Read 82370 sentence pairs
Trimmed to 25719 sentence pairs
Counting words...
Counted words:
bem 19481
eng 5092
Reading lines...
Read 55667 sentence pairs
Trimmed to 19616 sentence pairs
Counting words...
Counted words:
kin 36940
eng 11235
Reading lines...
Read 15022 sentence pairs
Trimmed to 8447 sentence pairs
Counting words...
Counted words:
lug 9684
eng 5677
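
The comment on `weights` in the cell above flags a fragile spot: the list is matched to `langs` by position only, so reordering one without the other silently assigns weights to the wrong language. One way to make the pairing explicit, sketched here with a hypothetical dict `lexical_weight_by_lang` whose values are copied from that cell, is to key the weights by language code and derive the positional dict from the sorted list:

```python
langs = sorted(['lug', 'bem', 'kin'])  # the order of the input list no longer matters

# Values taken from the `weights` list above, paired with the sorted language codes
lexical_weight_by_lang = {'bem': 42, 'kin': 6, 'lug': 43}

# Fail early if a language has no weight
missing = [lang for lang in langs if lang not in lexical_weight_by_lang]
if missing:
    raise KeyError(f"no lexical weight for: {missing}")

# Same structure as lexical_weight_dict: client index -> weight
lexical_weight_dict = {i: lexical_weight_by_lang[lang] for i, lang in enumerate(langs)}
print(lexical_weight_dict)
```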


In [28]:
# st = {key: value for key, value in encoders[0].state_dict().items() if 'embedding' not in key}
# st2 = {key: value for key, value in decoders[0].state_dict().items() if 'embedding' not in key and 'out' not in key}
# encoders[1].state_dict().update(st)
# decoders[1].state_dict().update(st2)
# new = (encoders[0].state_dict())
# new.update(global_encoder_weights)
# encoders[0].load_state_dict(new)

In [66]:
import os
# train(train_dataloader, encoder, decoder, 80, print_every=5, plot_every=5)

filename=None
save = True
if save:
    num = 1
    filename = f"GLOBAL|{'_'.join(langs)}||{num}"
    while os.path.isfile(filename):
        print('Name is taken...trying again...')
        num += 1
        filename = f"GLOBAL|{'_'.join(langs)}||{num}"

Name is taken...trying again...
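
The search for an unused file name can be wrapped in a small helper so that the name pattern is written only once. A minimal sketch, assuming a hypothetical helper `next_free_name` and a temporary directory for the demonstration; it uses `_` as the separator instead of `|`, because `|` is not allowed in file names on Windows:

```python
import os
import tempfile

def next_free_name(directory, prefix, langs):
    """Return the first '<prefix>_<langs>_<num>' path that does not exist yet."""
    num = 1
    while True:
        candidate = os.path.join(directory, f"{prefix}_{'_'.join(langs)}_{num}")
        if not os.path.isfile(candidate):
            return candidate
        num += 1

demo_dir = tempfile.mkdtemp()
first = next_free_name(demo_dir, 'GLOBAL', ['bem', 'kin', 'lug'])
open(first, 'w').close()  # pretend an earlier run already wrote this file
second = next_free_name(demo_dir, 'GLOBAL', ['bem', 'kin', 'lug'])
print(os.path.basename(first), os.path.basename(second))
```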


In [67]:
meta_encoder_weights, meta_decoder_weights = training(encoders, decoders, input_output_lang, 200,
                                                    lr=0.001, ds=data_dict, C=1.0, K=K, E=1,
                                                    filename=filename, batch_size=batch_size, hidden_size=hidden_size,
                                                    weighting=False, lexical_weights=None)

100%|██████████| 3/3 [00:13<00:00,  4.61s/it]


0.0
Round 1 >> Loss: 4.154258994680615, BLEU:0.0


100%|██████████| 3/3 [00:13<00:00,  4.56s/it]


0.0
Round 2 >> Loss: 3.5634053175566636, BLEU:0.0


100%|██████████| 3/3 [00:13<00:00,  4.51s/it]


0.0
Round 3 >> Loss: 3.3251208230573464, BLEU:0.0


100%|██████████| 3/3 [00:13<00:00,  4.41s/it]


0.07206841558218002
Round 4 >> Loss: 3.14820963270534, BLEU:0.02402280519406001


100%|██████████| 3/3 [00:13<00:00,  4.47s/it]


0.15032929182052612
Round 5 >> Loss: 3.004931197832138, BLEU:0.05010976394017538


100%|██████████| 3/3 [00:13<00:00,  4.44s/it]


0.10174161940813065
Round 6 >> Loss: 2.8824475461584513, BLEU:0.03391387313604355


100%|██████████| 3/3 [00:13<00:00,  4.48s/it]


0.19765104353427887
Round 7 >> Loss: 2.775644513709166, BLEU:0.06588368117809296


100%|██████████| 3/3 [00:13<00:00,  4.53s/it]


0.1378864198923111
Round 8 >> Loss: 2.681013109015449, BLEU:0.0459621399641037


100%|██████████| 3/3 [00:13<00:00,  4.39s/it]


0.23776908963918686
Round 9 >> Loss: 2.5937926146150327, BLEU:0.07925636321306229


100%|██████████| 3/3 [00:13<00:00,  4.45s/it]


0.22895295172929764
Round 10 >> Loss: 2.5179197602810106, BLEU:0.07631765057643254


100%|██████████| 3/3 [00:13<00:00,  4.36s/it]


0.29714250192046165
Round 11 >> Loss: 2.447408608308917, BLEU:0.09904750064015388


100%|██████████| 3/3 [00:12<00:00,  4.23s/it]


0.30893929302692413
Round 12 >> Loss: 2.384284437885277, BLEU:0.10297976434230804


100%|██████████| 3/3 [00:13<00:00,  4.38s/it]


0.23390096426010132
Round 13 >> Loss: 2.322389679062572, BLEU:0.07796698808670044


100%|██████████| 3/3 [00:13<00:00,  4.40s/it]


0.280300859361887
Round 14 >> Loss: 2.2687531105360903, BLEU:0.09343361978729565


100%|██████████| 3/3 [00:12<00:00,  4.16s/it]


0.21356157958507538
Round 15 >> Loss: 2.2197909244985334, BLEU:0.07118719319502513


100%|██████████| 3/3 [00:12<00:00,  4.21s/it]


0.2970089502632618
Round 16 >> Loss: 2.1744311355357566, BLEU:0.09900298342108727


100%|██████████| 3/3 [00:13<00:00,  4.36s/it]


0.17609157785773277
Round 17 >> Loss: 2.1325854548753185, BLEU:0.058697192619244255


100%|██████████| 3/3 [00:12<00:00,  4.18s/it]


0.3805975764989853
Round 18 >> Loss: 2.097009011288366, BLEU:0.1268658588329951


100%|██████████| 3/3 [00:12<00:00,  4.22s/it]


0.2673579230904579
Round 19 >> Loss: 2.061505996298807, BLEU:0.0891193076968193


100%|██████████| 3/3 [00:12<00:00,  4.18s/it]


0.3501417487859726
Round 20 >> Loss: 2.0257732145783263, BLEU:0.11671391626199086


100%|██████████| 3/3 [00:12<00:00,  4.19s/it]


0.27727100998163223
Round 21 >> Loss: 1.994999399772107, BLEU:0.09242366999387741


100%|██████████| 3/3 [00:13<00:00,  4.37s/it]


0.24665245413780212
Round 22 >> Loss: 1.961896968030856, BLEU:0.08221748471260071


100%|██████████| 3/3 [00:13<00:00,  4.35s/it]


0.2753332108259201
Round 23 >> Loss: 1.9315491784119923, BLEU:0.09177773694197337


100%|██████████| 3/3 [00:13<00:00,  4.35s/it]


0.37682120501995087
Round 24 >> Loss: 1.9032974969802279, BLEU:0.1256070683399836


100%|██████████| 3/3 [00:13<00:00,  4.40s/it]


0.43633951991796494
Round 25 >> Loss: 1.874918256668325, BLEU:0.14544650663932165


100%|██████████| 3/3 [00:13<00:00,  4.49s/it]


0.5462937578558922
Round 26 >> Loss: 1.849519595260091, BLEU:0.1820979192852974


100%|██████████| 3/3 [00:13<00:00,  4.39s/it]


0.3800605535507202
Round 27 >> Loss: 1.8258708971453739, BLEU:0.1266868511835734


100%|██████████| 3/3 [00:13<00:00,  4.41s/it]


0.41133422777056694
Round 28 >> Loss: 1.803059675565617, BLEU:0.13711140925685564


100%|██████████| 3/3 [00:13<00:00,  4.41s/it]


0.4183943346142769
Round 29 >> Loss: 1.7791910208283097, BLEU:0.13946477820475897


100%|██████████| 3/3 [00:13<00:00,  4.42s/it]


0.47452616691589355
Round 30 >> Loss: 1.755223791762101, BLEU:0.1581753889719645


100%|██████████| 3/3 [00:13<00:00,  4.40s/it]


0.30791930109262466
Round 31 >> Loss: 1.734743890989483, BLEU:0.10263976703087489


100%|██████████| 3/3 [00:13<00:00,  4.36s/it]


0.5093814358115196
Round 32 >> Loss: 1.7147348417984205, BLEU:0.16979381193717322


100%|██████████| 3/3 [00:13<00:00,  4.40s/it]


0.5216292887926102
Round 33 >> Loss: 1.691600655170981, BLEU:0.1738764295975367


100%|██████████| 3/3 [00:13<00:00,  4.41s/it]


0.3459993004798889
Round 34 >> Loss: 1.6707334429505571, BLEU:0.11533310015996297


100%|██████████| 3/3 [00:13<00:00,  4.39s/it]


0.377533745020628
Round 35 >> Loss: 1.6500505946643715, BLEU:0.12584458167354265


100%|██████████| 3/3 [00:13<00:00,  4.54s/it]


0.30803655833005905
Round 36 >> Loss: 1.6290949072811383, BLEU:0.10267885277668636


100%|██████████| 3/3 [00:13<00:00,  4.62s/it]


0.3901531845331192
Round 37 >> Loss: 1.6096985602691598, BLEU:0.13005106151103973


100%|██████████| 3/3 [00:13<00:00,  4.56s/it]


0.4848855659365654
Round 38 >> Loss: 1.5924246770351622, BLEU:0.16162852197885513


100%|██████████| 3/3 [00:13<00:00,  4.46s/it]


0.5312415063381195
Round 39 >> Loss: 1.57259375389718, BLEU:0.1770805021127065


100%|██████████| 3/3 [00:13<00:00,  4.41s/it]


0.38504116237163544
Round 40 >> Loss: 1.5532908760438826, BLEU:0.12834705412387848


100%|██████████| 3/3 [00:12<00:00,  4.18s/it]


0.5379738658666611
Round 41 >> Loss: 1.5338528810759404, BLEU:0.17932462195555368


100%|██████████| 3/3 [00:13<00:00,  4.46s/it]


0.5713023319840431
Round 42 >> Loss: 1.5148809318272842, BLEU:0.19043411066134772


100%|██████████| 3/3 [00:12<00:00,  4.22s/it]


0.4362957924604416
Round 43 >> Loss: 1.4957656315252832, BLEU:0.1454319308201472


100%|██████████| 3/3 [00:13<00:00,  4.45s/it]


0.5159809365868568
Round 44 >> Loss: 1.4806146865483016, BLEU:0.17199364552895227


100%|██████████| 3/3 [00:13<00:00,  4.39s/it]


0.5329416394233704
Round 45 >> Loss: 1.4629869481153273, BLEU:0.17764721314112344


100%|██████████| 3/3 [00:13<00:00,  4.34s/it]


0.4906551390886307
Round 46 >> Loss: 1.4464421255424613, BLEU:0.16355171302954355


100%|██████████| 3/3 [00:13<00:00,  4.38s/it]


0.6109640002250671
Round 47 >> Loss: 1.4292701967634354, BLEU:0.20365466674168906


100%|██████████| 3/3 [00:12<00:00,  4.33s/it]


0.6118936464190483
Round 48 >> Loss: 1.4137231797602867, BLEU:0.20396454880634943


100%|██████████| 3/3 [00:13<00:00,  4.36s/it]


0.7089030295610428
Round 49 >> Loss: 1.3980832399382106, BLEU:0.23630100985368094


100%|██████████| 3/3 [00:12<00:00,  4.27s/it]


0.6444536671042442
Round 50 >> Loss: 1.3826400434720714, BLEU:0.21481788903474808


100%|██████████| 3/3 [00:13<00:00,  4.38s/it]


0.6334114447236061
Round 51 >> Loss: 1.3667342647478435, BLEU:0.21113714824120203


100%|██████████| 3/3 [00:13<00:00,  4.48s/it]


0.6156483963131905
Round 52 >> Loss: 1.3527823102948828, BLEU:0.20521613210439682


100%|██████████| 3/3 [00:13<00:00,  4.37s/it]


0.48277726024389267
Round 53 >> Loss: 1.3387668390632645, BLEU:0.1609257534146309


100%|██████████| 3/3 [00:12<00:00,  4.26s/it]


0.8007004708051682
Round 54 >> Loss: 1.3237282569999458, BLEU:0.26690015693505603


100%|██████████| 3/3 [00:13<00:00,  4.38s/it]


0.6934418827295303
Round 55 >> Loss: 1.3086825323456301, BLEU:0.2311472942431768


100%|██████████| 3/3 [00:13<00:00,  4.40s/it]


0.7011527717113495
Round 56 >> Loss: 1.2952087264705174, BLEU:0.23371759057044983


100%|██████████| 3/3 [00:13<00:00,  4.42s/it]


0.712659977376461
Round 57 >> Loss: 1.2821388551435102, BLEU:0.23755332579215369


100%|██████████| 3/3 [00:12<00:00,  4.29s/it]


0.7659626007080078
Round 58 >> Loss: 1.2670466052164957, BLEU:0.25532086690266925


100%|██████████| 3/3 [00:12<00:00,  4.04s/it]


0.7880936861038208
Round 59 >> Loss: 1.254955002931095, BLEU:0.26269789536794025


100%|██████████| 3/3 [00:12<00:00,  4.16s/it]


0.6646741926670074
Round 60 >> Loss: 1.241565229756123, BLEU:0.22155806422233582


100%|██████████| 3/3 [00:12<00:00,  4.12s/it]


0.8260807245969772
Round 61 >> Loss: 1.2295223801709536, BLEU:0.27536024153232574


100%|██████████| 3/3 [00:12<00:00,  4.15s/it]


0.8681934028863907
Round 62 >> Loss: 1.219076601222369, BLEU:0.28939780096213025


100%|██████████| 3/3 [00:12<00:00,  4.15s/it]


0.9753270447254181
Round 63 >> Loss: 1.2070834626767193, BLEU:0.3251090149084727


100%|██████████| 3/3 [00:12<00:00,  4.16s/it]


0.7675100266933441
Round 64 >> Loss: 1.1954530398088647, BLEU:0.25583667556444806


100%|██████████| 3/3 [00:12<00:00,  4.23s/it]


0.904606282711029
Round 65 >> Loss: 1.1835973310232781, BLEU:0.301535427570343


100%|██████████| 3/3 [00:13<00:00,  4.40s/it]


0.95340196788311
Round 66 >> Loss: 1.1734435682651554, BLEU:0.3178006559610367


100%|██████████| 3/3 [00:12<00:00,  4.21s/it]


0.8086266815662384
Round 67 >> Loss: 1.1633121816417389, BLEU:0.26954222718874615


100%|██████████| 3/3 [00:12<00:00,  4.16s/it]


0.8749358803033829
Round 68 >> Loss: 1.1510734738061708, BLEU:0.29164529343446094


100%|██████████| 3/3 [00:12<00:00,  4.14s/it]


0.7238003313541412
Round 69 >> Loss: 1.1410479416100785, BLEU:0.2412667771180471


100%|██████████| 3/3 [00:12<00:00,  4.17s/it]


0.927946463227272
Round 70 >> Loss: 1.1299842336493826, BLEU:0.309315487742424


100%|██████████| 3/3 [00:12<00:00,  4.32s/it]


0.9438005685806274
Round 71 >> Loss: 1.1207575586353347, BLEU:0.3146001895268758


100%|██████████| 3/3 [00:13<00:00,  4.42s/it]


0.8535228818655014
Round 72 >> Loss: 1.1112100364662008, BLEU:0.2845076272885005


100%|██████████| 3/3 [00:13<00:00,  4.36s/it]


0.8897435367107391
Round 73 >> Loss: 1.10317496413466, BLEU:0.2965811789035797


100%|██████████| 3/3 [00:13<00:00,  4.43s/it]


1.0064494609832764
Round 74 >> Loss: 1.0944418414588106, BLEU:0.3354831536610921


100%|██████████| 3/3 [00:12<00:00,  4.32s/it]


0.8934302181005478
Round 75 >> Loss: 1.0839960314245858, BLEU:0.2978100727001826


100%|██████████| 3/3 [00:13<00:00,  4.45s/it]


1.0001176595687866
Round 76 >> Loss: 1.075538008057287, BLEU:0.3333725531895955


100%|██████████| 3/3 [00:13<00:00,  4.40s/it]


0.926210418343544
Round 77 >> Loss: 1.0673991820969635, BLEU:0.30873680611451465


100%|██████████| 3/3 [00:12<00:00,  4.31s/it]


0.8523820340633392
Round 78 >> Loss: 1.0559571399876273, BLEU:0.2841273446877797


100%|██████████| 3/3 [00:13<00:00,  4.45s/it]


0.8777753859758377
Round 79 >> Loss: 1.0452314130662892, BLEU:0.29259179532527924


100%|██████████| 3/3 [00:12<00:00,  4.17s/it]


0.9708753526210785
Round 80 >> Loss: 1.0395077367784495, BLEU:0.3236251175403595


100%|██████████| 3/3 [00:13<00:00,  4.42s/it]


0.9331072270870209
Round 81 >> Loss: 1.0319207331239975, BLEU:0.3110357423623403


100%|██████████| 3/3 [00:12<00:00,  4.29s/it]


0.7873105853796005
Round 82 >> Loss: 1.0239402075726203, BLEU:0.2624368617932002


100%|██████████| 3/3 [00:12<00:00,  4.23s/it]


0.8785918205976486
Round 83 >> Loss: 1.0148053477662289, BLEU:0.2928639401992162


100%|██████████| 3/3 [00:13<00:00,  4.37s/it]


0.8767098635435104
Round 84 >> Loss: 1.0060940027775522, BLEU:0.29223662118117016


100%|██████████| 3/3 [00:13<00:00,  4.41s/it]


0.8757958561182022
Round 85 >> Loss: 0.9986357040896804, BLEU:0.29193195203940075


100%|██████████| 3/3 [00:12<00:00,  4.29s/it]


0.8765198737382889
Round 86 >> Loss: 0.9908841286126028, BLEU:0.2921732912460963


100%|██████████| 3/3 [00:12<00:00,  4.23s/it]


1.125479370355606
Round 87 >> Loss: 0.9825982172616565, BLEU:0.37515979011853534


100%|██████████| 3/3 [00:13<00:00,  4.39s/it]


1.0213446766138077
Round 88 >> Loss: 0.9759943522469898, BLEU:0.3404482255379359


100%|██████████| 3/3 [00:12<00:00,  4.22s/it]


0.9773608446121216
Round 89 >> Loss: 0.968526058664191, BLEU:0.3257869482040405


100%|██████████| 3/3 [00:12<00:00,  4.18s/it]


0.9155798554420471
Round 90 >> Loss: 0.9626471682106298, BLEU:0.30519328514734906


100%|██████████| 3/3 [00:13<00:00,  4.37s/it]


1.083462506532669
Round 91 >> Loss: 0.9550973850864467, BLEU:0.361154168844223


100%|██████████| 3/3 [00:13<00:00,  4.37s/it]


1.0584691762924194
Round 92 >> Loss: 0.947424490011926, BLEU:0.35282305876413983


100%|██████████| 3/3 [00:12<00:00,  4.33s/it]


1.1685327291488647
Round 93 >> Loss: 0.9404159206439541, BLEU:0.38951090971628827


100%|██████████| 3/3 [00:12<00:00,  4.18s/it]


1.07528617978096
Round 94 >> Loss: 0.9322246683732448, BLEU:0.3584287265936534


100%|██████████| 3/3 [00:12<00:00,  4.22s/it]


1.003876805305481
Round 95 >> Loss: 0.9262455777295012, BLEU:0.33462560176849365


100%|██████████| 3/3 [00:13<00:00,  4.38s/it]


1.0084208995103836
Round 96 >> Loss: 0.9201188719370063, BLEU:0.33614029983679455


100%|██████████| 3/3 [00:12<00:00,  4.33s/it]


1.0948201417922974
Round 97 >> Loss: 0.9132997296583554, BLEU:0.3649400472640991


100%|██████████| 3/3 [00:12<00:00,  4.18s/it]


1.0317341089248657
Round 98 >> Loss: 0.9071559557710742, BLEU:0.3439113696416219


100%|██████████| 3/3 [00:12<00:00,  4.26s/it]


1.2490203082561493
Round 99 >> Loss: 0.901762547741305, BLEU:0.41634010275204975


100%|██████████| 3/3 [00:12<00:00,  4.20s/it]


1.1026754081249237
Round 100 >> Loss: 0.8945115678527921, BLEU:0.36755846937497455


100%|██████████| 3/3 [00:13<00:00,  4.37s/it]


1.2047436833381653
Round 101 >> Loss: 0.8877584712732748, BLEU:0.4015812277793884


100%|██████████| 3/3 [00:12<00:00,  4.19s/it]


1.1002349257469177
Round 102 >> Loss: 0.882561145164602, BLEU:0.3667449752489726


100%|██████████| 3/3 [00:12<00:00,  4.19s/it]


1.0316630601882935
Round 103 >> Loss: 0.8757003933905242, BLEU:0.34388768672943115


100%|██████████| 3/3 [00:12<00:00,  4.26s/it]


0.9904205054044724
Round 104 >> Loss: 0.8693890904902792, BLEU:0.33014016846815747


100%|██████████| 3/3 [00:13<00:00,  4.40s/it]


1.0652219355106354
Round 105 >> Loss: 0.8651183964430166, BLEU:0.3550739785035451


100%|██████████| 3/3 [00:13<00:00,  4.40s/it]


1.1620137095451355
Round 106 >> Loss: 0.860335876819922, BLEU:0.38733790318171185


100%|██████████| 3/3 [00:13<00:00,  4.37s/it]


1.2507379055023193
Round 107 >> Loss: 0.8541258230126932, BLEU:0.41691263516743976


100%|██████████| 3/3 [00:12<00:00,  4.30s/it]


0.9456072598695755
Round 108 >> Loss: 0.8469979126311961, BLEU:0.31520241995652515


100%|██████████| 3/3 [00:13<00:00,  4.35s/it]


1.1341761648654938
Round 109 >> Loss: 0.8423212249978446, BLEU:0.37805872162183124


100%|██████████| 3/3 [00:12<00:00,  4.18s/it]


1.100398600101471
Round 110 >> Loss: 0.8368100759234219, BLEU:0.366799533367157


100%|██████████| 3/3 [00:12<00:00,  4.32s/it]


1.0760428458452225
Round 111 >> Loss: 0.8320127526546787, BLEU:0.35868094861507416


100%|██████████| 3/3 [00:13<00:00,  4.46s/it]


1.1169656217098236
Round 112 >> Loss: 0.8269159098047553, BLEU:0.37232187390327454


100%|██████████| 3/3 [00:13<00:00,  4.41s/it]


1.0551834404468536
Round 113 >> Loss: 0.8228232178869712, BLEU:0.35172781348228455


100%|██████████| 3/3 [00:13<00:00,  4.57s/it]


0.9810720384120941
Round 114 >> Loss: 0.8183175112333356, BLEU:0.32702401280403137


100%|██████████| 3/3 [00:13<00:00,  4.63s/it]


1.3543799221515656
Round 115 >> Loss: 0.8133193101215141, BLEU:0.45145997405052185


100%|██████████| 3/3 [00:13<00:00,  4.53s/it]


1.2005079835653305
Round 116 >> Loss: 0.8083143560379668, BLEU:0.40016932785511017


100%|██████████| 3/3 [00:13<00:00,  4.45s/it]


1.2451832294464111
Round 117 >> Loss: 0.8030562186949582, BLEU:0.415061076482137


100%|██████████| 3/3 [00:13<00:00,  4.35s/it]


1.0321531742811203
Round 118 >> Loss: 0.7993097142051079, BLEU:0.3440510580937068


100%|██████████| 3/3 [00:13<00:00,  4.37s/it]


1.0579871237277985
Round 119 >> Loss: 0.7943022177249777, BLEU:0.3526623745759328


100%|██████████| 3/3 [00:13<00:00,  4.35s/it]


1.082994244992733
Round 120 >> Loss: 0.7894016074762494, BLEU:0.36099808166424435


100%|██████████| 3/3 [00:12<00:00,  4.28s/it]


1.1603872776031494
Round 121 >> Loss: 0.7856086954416256, BLEU:0.3867957592010498


100%|██████████| 3/3 [00:13<00:00,  4.35s/it]


1.1671167612075806
Round 122 >> Loss: 0.7813477551743192, BLEU:0.38903892040252686


100%|██████████| 3/3 [00:13<00:00,  4.41s/it]


1.212837666273117
Round 123 >> Loss: 0.7786332163751823, BLEU:0.404279222091039


100%|██████████| 3/3 [00:13<00:00,  4.41s/it]


1.0194089710712433
Round 124 >> Loss: 0.773261352989446, BLEU:0.3398029903570811


100%|██████████| 3/3 [00:13<00:00,  4.47s/it]


1.1222720742225647
Round 125 >> Loss: 0.7692277835620551, BLEU:0.37409069140752155


100%|██████████| 3/3 [00:12<00:00,  4.29s/it]


1.2325644195079803
Round 126 >> Loss: 0.764007422051229, BLEU:0.4108548065026601


100%|██████████| 3/3 [00:13<00:00,  4.36s/it]


1.1468941569328308
Round 127 >> Loss: 0.7610271925110869, BLEU:0.3822980523109436


100%|██████████| 3/3 [00:12<00:00,  4.26s/it]


1.1459170877933502
Round 128 >> Loss: 0.7563849160805495, BLEU:0.3819723625977834


100%|██████████| 3/3 [00:12<00:00,  4.23s/it]


1.11928191781044
Round 129 >> Loss: 0.7542403246873911, BLEU:0.37309397260348004


100%|██████████| 3/3 [00:12<00:00,  4.30s/it]


1.2714652568101883
Round 130 >> Loss: 0.7494052728403852, BLEU:0.42382175227006275


100%|██████████| 3/3 [00:13<00:00,  4.37s/it]


1.2617101073265076
Round 131 >> Loss: 0.7483155805899605, BLEU:0.4205700357755025


100%|██████████| 3/3 [00:13<00:00,  4.36s/it]


1.1704885959625244
Round 132 >> Loss: 0.7434315904718091, BLEU:0.3901628653208415


100%|██████████| 3/3 [00:12<00:00,  4.24s/it]


1.2660982608795166
Round 133 >> Loss: 0.7402643539246213, BLEU:0.42203275362650555


100%|██████████| 3/3 [00:12<00:00,  4.19s/it]


1.0935034155845642
Round 134 >> Loss: 0.7352465977513635, BLEU:0.36450113852818805


100%|██████████| 3/3 [00:12<00:00,  4.20s/it]


1.204692155122757
Round 135 >> Loss: 0.7298085735277686, BLEU:0.40156405170758563


100%|██████████| 3/3 [00:12<00:00,  4.20s/it]


1.3010224103927612
Round 136 >> Loss: 0.7270732128242235, BLEU:0.4336741367975871


100%|██████████| 3/3 [00:12<00:00,  4.16s/it]


1.20051908493042
Round 137 >> Loss: 0.7242178373457596, BLEU:0.40017302831013996


100%|██████████| 3/3 [00:12<00:00,  4.15s/it]


1.2273023426532745
Round 138 >> Loss: 0.7214916090191279, BLEU:0.40910078088442486


100%|██████████| 3/3 [00:12<00:00,  4.22s/it]


1.1295656114816666
Round 139 >> Loss: 0.7172215301562096, BLEU:0.37652187049388885


100%|██████████| 3/3 [00:12<00:00,  4.15s/it]


1.1945732235908508
Round 140 >> Loss: 0.7165307669186053, BLEU:0.39819107453028363


100%|██████████| 3/3 [00:12<00:00,  4.18s/it]


1.290600061416626
Round 141 >> Loss: 0.7137845679299543, BLEU:0.4302000204722087


100%|██████████| 3/3 [00:12<00:00,  4.23s/it]


1.195813536643982
Round 142 >> Loss: 0.7108970705431655, BLEU:0.39860451221466064


100%|██████████| 3/3 [00:12<00:00,  4.18s/it]


1.2502388060092926
Round 143 >> Loss: 0.7075980588413584, BLEU:0.4167462686697642


100%|██████████| 3/3 [00:13<00:00,  4.39s/it]


1.2238361835479736
Round 144 >> Loss: 0.7060982341064891, BLEU:0.4079453945159912


100%|██████████| 3/3 [00:13<00:00,  4.41s/it]


1.2344114482402802
Round 145 >> Loss: 0.7032295404728934, BLEU:0.41147048274676007


100%|██████████| 3/3 [00:12<00:00,  4.32s/it]


1.0368470251560211
Round 146 >> Loss: 0.7012296702502262, BLEU:0.345615675052007


100%|██████████| 3/3 [00:12<00:00,  4.26s/it]


1.526424080133438
Round 147 >> Loss: 0.6975645693281419, BLEU:0.508808026711146


100%|██████████| 3/3 [00:13<00:00,  4.40s/it]


1.1605948954820633
Round 148 >> Loss: 0.6952745255910102, BLEU:0.38686496516068775


100%|██████████| 3/3 [00:13<00:00,  4.38s/it]


1.2973019480705261
Round 149 >> Loss: 0.6935383513745738, BLEU:0.43243398269017536


100%|██████████| 3/3 [00:13<00:00,  4.38s/it]


1.3192798793315887
Round 150 >> Loss: 0.6901626490379158, BLEU:0.43975995977719623


100%|██████████| 3/3 [00:13<00:00,  4.37s/it]


1.2054048776626587
Round 151 >> Loss: 0.6869521813532896, BLEU:0.4018016258875529


100%|██████████| 3/3 [00:12<00:00,  4.30s/it]


1.2589316368103027
Round 152 >> Loss: 0.6839669688544495, BLEU:0.4196438789367676


100%|██████████| 3/3 [00:13<00:00,  4.46s/it]


1.338636338710785
Round 153 >> Loss: 0.6803454437836479, BLEU:0.44621211290359497


100%|██████████| 3/3 [00:13<00:00,  4.41s/it]


1.3624325096607208
Round 154 >> Loss: 0.6777802619746608, BLEU:0.4541441698869069


100%|██████████| 3/3 [00:12<00:00,  4.30s/it]


1.1579988598823547
Round 155 >> Loss: 0.6765711575618023, BLEU:0.3859996199607849


100%|██████████| 3/3 [00:13<00:00,  4.35s/it]


1.1749482452869415
Round 156 >> Loss: 0.672878921290132, BLEU:0.39164941509564716


100%|██████████| 3/3 [00:13<00:00,  4.40s/it]


1.2298601865768433
Round 157 >> Loss: 0.6716540657541126, BLEU:0.40995339552561444


100%|██████████| 3/3 [00:13<00:00,  4.34s/it]


1.2141910195350647
Round 158 >> Loss: 0.6692953151985234, BLEU:0.40473033984502155


100%|██████████| 3/3 [00:13<00:00,  4.38s/it]


1.400650292634964
Round 159 >> Loss: 0.6671288272154016, BLEU:0.46688343087832135


100%|██████████| 3/3 [00:12<00:00,  4.24s/it]


1.2605484127998352
Round 160 >> Loss: 0.6645289627253114, BLEU:0.42018280426661175


100%|██████████| 3/3 [00:12<00:00,  4.27s/it]


1.3193577826023102
Round 161 >> Loss: 0.6616895967289392, BLEU:0.4397859275341034


100%|██████████| 3/3 [00:12<00:00,  4.33s/it]


1.1667108535766602
Round 162 >> Loss: 0.6597185119117549, BLEU:0.3889036178588867


100%|██████████| 3/3 [00:13<00:00,  4.35s/it]


1.2500471770763397
Round 163 >> Loss: 0.6568582954712777, BLEU:0.4166823923587799


100%|██████████| 3/3 [00:13<00:00,  4.36s/it]


0.9874024242162704
Round 164 >> Loss: 0.6559659492115085, BLEU:0.32913414140542346


100%|██████████| 3/3 [00:13<00:00,  4.40s/it]


1.217838853597641
Round 165 >> Loss: 0.654414980125401, BLEU:0.405946284532547


100%|██████████| 3/3 [00:13<00:00,  4.35s/it]


1.3691973388195038
Round 166 >> Loss: 0.6516718499740678, BLEU:0.4563991129398346


100%|██████████| 3/3 [00:13<00:00,  4.39s/it]


1.1252538561820984
Round 167 >> Loss: 0.649032362783766, BLEU:0.37508461872736615


100%|██████████| 3/3 [00:12<00:00,  4.21s/it]


1.2736103236675262
Round 168 >> Loss: 0.648096186936361, BLEU:0.4245367745558421


100%|██████████| 3/3 [00:12<00:00,  4.24s/it]


1.3757652342319489
Round 169 >> Loss: 0.6455655994455767, BLEU:0.4585884114106496


100%|██████████| 3/3 [00:12<00:00,  4.16s/it]


1.1664002239704132
Round 170 >> Loss: 0.6445275175092363, BLEU:0.3888000746568044


100%|██████████| 3/3 [00:12<00:00,  4.15s/it]


1.42849463224411
Round 171 >> Loss: 0.641874601839132, BLEU:0.47616487741470337


100%|██████████| 3/3 [00:13<00:00,  4.37s/it]


1.2861995697021484
Round 172 >> Loss: 0.6399113471881809, BLEU:0.4287331899007161


100%|██████████| 3/3 [00:13<00:00,  4.41s/it]


1.438774436712265
Round 173 >> Loss: 0.6373423834939561, BLEU:0.4795914789040883


100%|██████████| 3/3 [00:13<00:00,  4.39s/it]


1.2690860033035278
Round 174 >> Loss: 0.6361373026677571, BLEU:0.4230286677678426


100%|██████████| 3/3 [00:12<00:00,  4.28s/it]


1.4027728736400604
Round 175 >> Loss: 0.633976538847899, BLEU:0.46759095788002014


100%|██████████| 3/3 [00:13<00:00,  4.46s/it]


1.178048849105835
Round 176 >> Loss: 0.6344256750140168, BLEU:0.392682949701945


100%|██████████| 3/3 [00:13<00:00,  4.41s/it]


1.4190013408660889
Round 177 >> Loss: 0.633228899979886, BLEU:0.473000446955363


100%|██████████| 3/3 [00:13<00:00,  4.36s/it]


1.3001258671283722
Round 178 >> Loss: 0.6303598006702992, BLEU:0.4333752890427907


100%|██████████| 3/3 [00:13<00:00,  4.37s/it]


1.3837292194366455
Round 179 >> Loss: 0.6276368805075242, BLEU:0.4612430731455485


100%|██████████| 3/3 [00:12<00:00,  4.33s/it]


1.309497982263565
Round 180 >> Loss: 0.6270415914020538, BLEU:0.43649932742118835


100%|██████████| 3/3 [00:13<00:00,  4.41s/it]


1.4569251239299774
Round 181 >> Loss: 0.6257753792329313, BLEU:0.4856417079766591


100%|██████████| 3/3 [00:13<00:00,  4.35s/it]


1.4838815033435822
Round 182 >> Loss: 0.623297419861938, BLEU:0.49462716778119403


100%|██████████| 3/3 [00:12<00:00,  4.27s/it]


1.1707751154899597
Round 183 >> Loss: 0.6221548981291098, BLEU:0.3902583718299866


100%|██████████| 3/3 [00:12<00:00,  4.12s/it]


1.362702637910843
Round 184 >> Loss: 0.6219181088869291, BLEU:0.45423421263694763


100%|██████████| 3/3 [00:12<00:00,  4.18s/it]


1.3596770465373993
Round 185 >> Loss: 0.6196373785410444, BLEU:0.4532256821791331


100%|██████████| 3/3 [00:12<00:00,  4.17s/it]


1.4458313286304474
Round 186 >> Loss: 0.6189681550760463, BLEU:0.4819437762101491


100%|██████████| 3/3 [00:12<00:00,  4.16s/it]


1.2977842688560486
Round 187 >> Loss: 0.6168749837615496, BLEU:0.43259475628534955


100%|██████████| 3/3 [00:12<00:00,  4.14s/it]


1.340678483247757
Round 188 >> Loss: 0.6139197457643669, BLEU:0.4468928277492523


100%|██████████| 3/3 [00:12<00:00,  4.21s/it]


1.444653183221817
Round 189 >> Loss: 0.6125841309590961, BLEU:0.481551061073939


100%|██████████| 3/3 [00:13<00:00,  4.41s/it]


1.4744000732898712
Round 190 >> Loss: 0.6103820935984388, BLEU:0.4914666910966237


100%|██████████| 3/3 [00:13<00:00,  4.39s/it]


1.2680860757827759
Round 191 >> Loss: 0.6094246746805156, BLEU:0.4226953585942586


100%|██████████| 3/3 [00:13<00:00,  4.38s/it]


1.4182467460632324
Round 192 >> Loss: 0.6074136043826951, BLEU:0.4727489153544108


100%|██████████| 3/3 [00:13<00:00,  4.38s/it]


1.5457370579242706
Round 193 >> Loss: 0.6067988354181706, BLEU:0.5152456859747568


100%|██████████| 3/3 [00:12<00:00,  4.31s/it]


1.5117853283882141
Round 194 >> Loss: 0.6038495909424655, BLEU:0.5039284427960714


100%|██████████| 3/3 [00:12<00:00,  4.28s/it]


1.259720265865326
Round 195 >> Loss: 0.6029963767012381, BLEU:0.41990675528844196


100%|██████████| 3/3 [00:13<00:00,  4.38s/it]


1.4224842488765717
Round 196 >> Loss: 0.6013502069120974, BLEU:0.47416141629219055


100%|██████████| 3/3 [00:13<00:00,  4.41s/it]


1.3686367571353912
Round 197 >> Loss: 0.5993404489277193, BLEU:0.45621225237846375


100%|██████████| 3/3 [00:13<00:00,  4.38s/it]


1.2489708364009857
Round 198 >> Loss: 0.5982805302846973, BLEU:0.4163236121336619


100%|██████████| 3/3 [00:12<00:00,  4.22s/it]


1.406801849603653
Round 199 >> Loss: 0.5970047263027658, BLEU:0.46893394986788434


100%|██████████| 3/3 [00:12<00:00,  4.19s/it]


1.331119954586029
Round 200 >> Loss: 0.5949186246255557, BLEU:0.44370665152867633
Training Done!
Total time taken to Train: 2640.5865247249603


In [None]:
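input_lang, output_lang, train_dataloader, pairs = get_dataloader(batch_size, language='kir_test')
# BLEU evaluation of the meta-learned encoder and decoder weights on the 'kir_test' pairs
evaluateBleu(meta_encoder_weights, meta_decoder_weights, input_lang, output_lang, pairs, n=30)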

In [68]:
import numpy as np
import torch
from torch.utils.data import DataLoader, Subset

def limited_data_loader(original_dataloader, num_samples, random=True):
    """Return a DataLoader over num_samples items of the original loader's dataset."""
    # Get the original dataset from the DataLoader
    dataset = original_dataloader.dataset

    # Ensure the original dataset has at least the requested number of samples
    assert len(dataset) >= num_samples, "The original dataset has fewer samples than requested"

    # Pick the indices: a random draw without replacement, or the first num_samples elements
    if random:
        indices = np.random.choice(len(dataset), num_samples, replace=False).tolist()
    else:
        indices = list(range(num_samples))

    subset = Subset(dataset, indices)

    # Create a new DataLoader from this subset, reusing the original batch size and worker count.
    # shuffle=False, so batches follow the order of the selected indices.
    new_dataloader = DataLoader(subset, batch_size=original_dataloader.batch_size, shuffle=False, num_workers=original_dataloader.num_workers)

    return new_dataloader
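
As a rough illustration of how this kind of subsetting behaves, the sketch below applies the same `Subset` plus `DataLoader` pattern to a toy `TensorDataset`. The toy tensors, the helper name `sample_subset_loader`, and the fixed seed are assumptions made for this example only; they are not part of the notebook.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset

def sample_subset_loader(loader, num_samples, seed=0):
    """Hypothetical helper: draw num_samples distinct items from loader.dataset."""
    rng = np.random.default_rng(seed)
    # Sample without replacement and convert to plain Python ints for indexing.
    indices = rng.choice(len(loader.dataset), num_samples, replace=False).tolist()
    return DataLoader(Subset(loader.dataset, indices),
                      batch_size=loader.batch_size, shuffle=False)

# Toy stand-in for tokenized sentence pairs: 10 "sentences" of length 4.
inputs = torch.arange(40).reshape(10, 4)
targets = torch.arange(40, 80).reshape(10, 4)
full_loader = DataLoader(TensorDataset(inputs, targets), batch_size=2)

few_shot_loader = sample_subset_loader(full_loader, num_samples=5)
num_items = len(few_shot_loader.dataset)  # number of sampled pairs
num_batches = len(few_shot_loader)        # batches of size 2, last one partial
```

Because sampling is done without replacement, each selected pair appears once per pass over the loader, which is what a few-shot run would presumably want.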


In [69]:
import os

def personalize(lang, rounds, encoder_weights=None, decoder_weights=None, sample=None, save=False):
    """Train a fresh encoder/decoder on `lang`, optionally warm-started from provided weights."""
    input_lang, output_lang, train_dataloader, pairs = get_dataloader(batch_size, language=lang)

    encoder = EncoderRNN(input_lang.n_words, hidden_size).to(device)
    decoder = AttnDecoderRNN(hidden_size, output_lang.n_words).to(device)

    # Warm start: overwrite the freshly initialized parameters with any provided weights.
    # Each set of weights is checked separately so that passing only one of them does not fail.
    is_FL = 'false'
    if encoder_weights is not None:
        e_og = encoder.state_dict()
        e_og.update(encoder_weights)
        encoder.load_state_dict(e_og)
        is_FL = 'true'
    if decoder_weights is not None:
        d_og = decoder.state_dict()
        d_og.update(decoder_weights)
        decoder.load_state_dict(d_og)
        is_FL = 'true'

    # Train on the full dataset unless a sample size is given
    if sample is None:
        sample = 'full'
    else:
        train_dataloader = limited_data_loader(train_dataloader, sample)

    # Pick an unused output filename by incrementing the trailing run number
    filename = None
    if save:
        num = 1
        filename = f"P|{lang}_{sample}-shot_FL-{is_FL}_epoch{rounds}||{num}"
        while os.path.isfile(filename):
            print('Name is taken...trying again...')
            num += 1
            filename = f"P|{lang}_{sample}-shot_FL-{is_FL}_epoch{rounds}||{num}"
    train(train_dataloader, encoder, decoder, rounds, print_every=5, plot_every=5, filename=filename, input=input_lang, output=output_lang, pairs=pairs)
    return encoder, decoder
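
The warm-start step in `personalize` merges a dictionary of weights into a model's own `state_dict()` before loading it back. A minimal sketch of that pattern on a toy `nn.Linear` layer is shown below; the toy models and the choice to share only the `weight` entry are assumptions for illustration, not a description of what `meta_encoder_weights` actually contains.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-ins for a shared ("meta") model and a freshly initialized model.
meta_model = nn.Linear(4, 3)
fresh_model = nn.Linear(4, 3)

# Assume only the weight matrix is shared; the bias stays local.
shared_weights = {"weight": meta_model.state_dict()["weight"].clone()}

# Start from the fresh model's own parameters, overwrite the shared entries,
# then load the merged dictionary so every expected key is still present.
merged = fresh_model.state_dict()
original_bias = merged["bias"].clone()
merged.update(shared_weights)
fresh_model.load_state_dict(merged)

weights_match = torch.equal(fresh_model.weight, meta_model.weight)
bias_unchanged = torch.equal(fresh_model.bias, original_bias)
```

Merging into the existing state dictionary first means a partial set of weights can be loaded with the default strict key check, although tensor shapes must still match the target model.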

In [None]:
for i in range(1):
    personalize('kir', 100, save=True, encoder_weights=meta_encoder_weights, decoder_weights=meta_decoder_weights, sample=None)

In [54]:
for i in range(3):
    personalize('cat', 100, save=True, encoder_weights=meta_encoder_weights, decoder_weights=meta_decoder_weights, sample=None)

Reading lines...
Read 1375 sentence pairs
Trimmed to 90 sentence pairs
Counting words...
Counted words:
cat 213
eng 176
0m 0s (- 0m 7s) (5 5%) 3.9873
0m 0s (- 0m 6s) (10 10%) 2.6431
0m 1s (- 0m 6s) (15 15%) 2.1026
0m 1s (- 0m 6s) (20 20%) 1.7559
0m 1s (- 0m 5s) (25 25%) 1.4590
0m 2s (- 0m 5s) (30 30%) 1.2076
0m 2s (- 0m 4s) (35 35%) 0.9888
0m 3s (- 0m 4s) (40 40%) 0.8012
0m 3s (- 0m 4s) (45 45%) 0.6408
0m 3s (- 0m 3s) (50 50%) 0.5096
0m 4s (- 0m 3s) (55 55%) 0.4079
0m 4s (- 0m 3s) (60 60%) 0.3207
0m 4s (- 0m 2s) (65 65%) 0.2568
0m 5s (- 0m 2s) (70 70%) 0.2096
0m 5s (- 0m 1s) (75 75%) 0.1744
0m 6s (- 0m 1s) (80 80%) 0.1451
0m 6s (- 0m 1s) (85 85%) 0.1231
0m 6s (- 0m 0s) (90 90%) 0.1075
0m 7s (- 0m 0s) (95 95%) 0.0930
0m 7s (- 0m 0s) (100 100%) 0.0823
Reading lines...
Read 1375 sentence pairs
Trimmed to 90 sentence pairs
Counting words...
Counted words:
cat 213
eng 176
Name is taken...trying again...
0m 0s (- 0m 7s) (5 5%) 3.9118
0m 0s (- 0m 6s) (10 10%) 2.6210
0m 1s (- 0m 6s) (15 15%) 2