In [1]:
# For tips on running notebooks in Google Colab, see
# https://pytorch.org/tutorials/beginner/colab
%matplotlib inline

NLP From Scratch: Translation with a Sequence to Sequence Network and Attention
===============================================================================

**Author**: [Sean Robertson](https://github.com/spro)

This is the third and final tutorial on doing \"NLP From Scratch\",
where we write our own classes and functions to preprocess the data to
do our NLP modeling tasks. We hope after you complete this tutorial that
you\'ll proceed to learn how [torchtext]{.title-ref} can handle much of
this preprocessing for you in the three tutorials immediately following
this one.

In this project we will be teaching a neural network to translate from
French to English.

``` {.sourceCode .sh}
[KEY: > input, = target, < output]

> il est en train de peindre un tableau .
= he is painting a picture .
< he is painting a picture .

> pourquoi ne pas essayer ce vin delicieux ?
= why not try that delicious wine ?
< why not try that delicious wine ?

> elle n est pas poete mais romanciere .
= she is not a poet but a novelist .
< she not not a poet but a novelist .

> vous etes trop maigre .
= you re too skinny .
< you re all alone .
```

\... to varying degrees of success.

This is made possible by the simple but powerful idea of the [sequence
to sequence network](https://arxiv.org/abs/1409.3215), in which two
recurrent neural networks work together to transform one sequence to
another. An encoder network condenses an input sequence into a vector,
and a decoder network unfolds that vector into a new sequence.

![](https://pytorch.org/tutorials/_static/img/seq-seq-images/seq2seq.png)

To improve upon this model we\'ll use an [attention
mechanism](https://arxiv.org/abs/1409.0473), which lets the decoder
learn to focus over a specific range of the input sequence.

**Recommended Reading:**

I assume you have at least installed PyTorch, know Python, and
understand Tensors:

-   <https://pytorch.org/> For installation instructions
-   `/beginner/deep_learning_60min_blitz`{.interpreted-text role="doc"}
    to get started with PyTorch in general
-   `/beginner/pytorch_with_examples`{.interpreted-text role="doc"} for
    a wide and deep overview
-   `/beginner/former_torchies_tutorial`{.interpreted-text role="doc"}
    if you are former Lua Torch user

It would also be useful to know about Sequence to Sequence networks and
how they work:

-   [Learning Phrase Representations using RNN Encoder-Decoder for
    Statistical Machine Translation](https://arxiv.org/abs/1406.1078)
-   [Sequence to Sequence Learning with Neural
    Networks](https://arxiv.org/abs/1409.3215)
-   [Neural Machine Translation by Jointly Learning to Align and
    Translate](https://arxiv.org/abs/1409.0473)
-   [A Neural Conversational Model](https://arxiv.org/abs/1506.05869)

You will also find the previous tutorials on
`/intermediate/char_rnn_classification_tutorial`{.interpreted-text
role="doc"} and
`/intermediate/char_rnn_generation_tutorial`{.interpreted-text
role="doc"} helpful as those concepts are very similar to the Encoder
and Decoder models, respectively.

**Requirements**


In [1]:
from __future__ import unicode_literals, print_function, division
from io import open
import unicodedata
import re
import random

import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F

import numpy as np
from torch.utils.data import TensorDataset, DataLoader, RandomSampler

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  from .autonotebook import tqdm as notebook_tqdm


Loading data files
==================

The data for this project is a set of many thousands of English to
French translation pairs.

[This question on Open Data Stack
Exchange](https://opendata.stackexchange.com/questions/3888/dataset-of-sentences-translated-into-many-languages)
pointed me to the open translation site <https://tatoeba.org/> which has
downloads available at <https://tatoeba.org/eng/downloads> - and better
yet, someone did the extra work of splitting language pairs into
individual text files here: <https://www.manythings.org/anki/>

The English to French pairs are too big to include in the repository, so
download to `data/eng-fra.txt` before continuing. The file is a tab
separated list of translation pairs:

``` {.sourceCode .sh}
I am cold.    J'ai froid.
```

<div style="background-color: #54c7ec; color: #fff; font-weight: 700; padding-left: 10px; padding-top: 5px; padding-bottom: 5px"><strong>NOTE:</strong></div>
<div style="background-color: #f3f4f7; padding-left: 10px; padding-top: 10px; padding-bottom: 10px; padding-right: 10px">
<p>Download the data from<a href="https://download.pytorch.org/tutorial/data.zip">here</a>and extract it to the current directory.</p>
</div>


Similar to the character encoding used in the character-level RNN
tutorials, we will be representing each word in a language as a one-hot
vector, or giant vector of zeros except for a single one (at the index
of the word). Compared to the dozens of characters that might exist in a
language, there are many many more words, so the encoding vector is much
larger. We will however cheat a bit and trim the data to only use a few
thousand words per language.

![](https://pytorch.org/tutorials/_static/img/seq-seq-images/word-encoding.png)


We\'ll need a unique index per word to use as the inputs and targets of
the networks later. To keep track of all this we will use a helper class
called `Lang` which has word → index (`word2index`) and index → word
(`index2word`) dictionaries, as well as a count of each word
`word2count` which will be used to replace rare words later.


In [2]:
SOS_token = 0
EOS_token = 1

class Lang:
    def __init__(self, name):
        self.name = name
        self.word2index = {}
        self.word2count = {}
        self.index2word = {0: "SOS", 1: "EOS"}
        self.n_words = 2  # Count SOS and EOS

    def addSentence(self, sentence):
        for word in sentence.split(' '):
            self.addWord(word)

    def addWord(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.word2count[word] = 1
            self.index2word[self.n_words] = word
            self.n_words += 1
        else:
            self.word2count[word] += 1

The files are all in Unicode, to simplify we will turn Unicode
characters to ASCII, make everything lowercase, and trim most
punctuation.


In [3]:
# Turn a Unicode string to plain ASCII, thanks to
# https://stackoverflow.com/a/518232/2809427
def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
    )

# Lowercase, trim, and remove non-letter characters
def normalizeString(s):
    s = unicodeToAscii(s.lower().strip())
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z!?]+", r" ", s)
    return s.strip()

To read the data file we will split the file into lines, and then split
lines into pairs. The files are all English → Other Language, so if we
want to translate from Other Language → English I added the `reverse`
flag to reverse the pairs.


In [4]:
def readLangs(lang1, lang2, reverse=False):
    print("Reading lines...")

    # Read the file and split into lines
    lines = open('data/%s-%s.txt' % (lang1, lang2), encoding='utf-8').\
        read().strip().split('\n')

    # Split every line into pairs and normalize
    pairs = [[normalizeString(s) for s in l.split('\t')] for l in lines]

    # Reverse pairs, make Lang instances
    if reverse:
        pairs = [list(reversed(p)) for p in pairs]
        input_lang = Lang(lang2)
        output_lang = Lang(lang1)
    else:
        input_lang = Lang(lang1)
        output_lang = Lang(lang2)

    return input_lang, output_lang, pairs

Since there are a *lot* of example sentences and we want to train
something quickly, we\'ll trim the data set to only relatively short and
simple sentences. Here the maximum length is 10 words (that includes
ending punctuation) and we\'re filtering to sentences that translate to
the form \"I am\" or \"He is\" etc. (accounting for apostrophes replaced
earlier).


In [5]:
MAX_LENGTH = 10

eng_prefixes = (
    "i am ", "i m ",
    "he is", "he s ",
    "she is", "she s ",
    "you are", "you re ",
    "we are", "we re ",
    "they are", "they re "
)

def filterPair(p):
    return len(p[0].split(' ')) < MAX_LENGTH and \
        len(p[1].split(' ')) < MAX_LENGTH and \
        p[1].startswith(eng_prefixes)


def filterPairs(pairs):
    return [pair for pair in pairs if filterPair(pair)]

The full process for preparing the data is:

-   Read text file and split into lines, split lines into pairs
-   Normalize text, filter by length and content
-   Make word lists from sentences in pairs


In [6]:
def prepareData(lang1, lang2, reverse=False):
    input_lang, output_lang, pairs = readLangs(lang1, lang2, reverse)
    print("Read %s sentence pairs" % len(pairs))
    pairs = filterPairs(pairs)
    print("Trimmed to %s sentence pairs" % len(pairs))
    print("Counting words...")
    for pair in pairs:
        input_lang.addSentence(pair[0])
        output_lang.addSentence(pair[1])
    print("Counted words:")
    print(input_lang.name, input_lang.n_words)
    print(output_lang.name, output_lang.n_words)
    return input_lang, output_lang, pairs

input_lang, output_lang, pairs = prepareData('eng', 'spa', True)
print(random.choice(pairs))

Reading lines...
Read 118121 sentence pairs
Trimmed to 7865 sentence pairs
Counting words...
Counted words:
spa 4027
eng 2809
['eres libre', 'you re free']


The Seq2Seq Model
=================

A Recurrent Neural Network, or RNN, is a network that operates on a
sequence and uses its own output as input for subsequent steps.

A [Sequence to Sequence network](https://arxiv.org/abs/1409.3215), or
seq2seq network, or [Encoder Decoder
network](https://arxiv.org/pdf/1406.1078v3.pdf), is a model consisting
of two RNNs called the encoder and decoder. The encoder reads an input
sequence and outputs a single vector, and the decoder reads that vector
to produce an output sequence.

![](https://pytorch.org/tutorials/_static/img/seq-seq-images/seq2seq.png)

Unlike sequence prediction with a single RNN, where every input
corresponds to an output, the seq2seq model frees us from sequence
length and order, which makes it ideal for translation between two
languages.

Consider the sentence `Je ne suis pas le chat noir` →
`I am not the black cat`. Most of the words in the input sentence have a
direct translation in the output sentence, but are in slightly different
orders, e.g. `chat noir` and `black cat`. Because of the `ne/pas`
construction there is also one more word in the input sentence. It would
be difficult to produce a correct translation directly from the sequence
of input words.

With a seq2seq model the encoder creates a single vector which, in the
ideal case, encodes the \"meaning\" of the input sequence into a single
vector --- a single point in some N dimensional space of sentences.


The Encoder
===========

The encoder of a seq2seq network is a RNN that outputs some value for
every word from the input sentence. For every input word the encoder
outputs a vector and a hidden state, and uses the hidden state for the
next input word.

![](https://pytorch.org/tutorials/_static/img/seq-seq-images/encoder-network.png)


In [7]:
class EncoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size, dropout_p=0.1):
        super(EncoderRNN, self).__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, input):
        embedded = self.dropout(self.embedding(input))
        output, hidden = self.gru(embedded)
        return output, hidden

The Decoder
===========

The decoder is another RNN that takes the encoder output vector(s) and
outputs a sequence of words to create the translation.


Simple Decoder
==============

In the simplest seq2seq decoder we use only last output of the encoder.
This last output is sometimes called the *context vector* as it encodes
context from the entire sequence. This context vector is used as the
initial hidden state of the decoder.

At every step of decoding, the decoder is given an input token and
hidden state. The initial input token is the start-of-string `<SOS>`
token, and the first hidden state is the context vector (the encoder\'s
last hidden state).

![](https://pytorch.org/tutorials/_static/img/seq-seq-images/decoder-network.png)


In [8]:
class DecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size):
        super(DecoderRNN, self).__init__()
        self.embedding = nn.Embedding(output_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.out = nn.Linear(hidden_size, output_size)

    def forward(self, encoder_outputs, encoder_hidden, target_tensor=None):
        batch_size = encoder_outputs.size(0)
        decoder_input = torch.empty(batch_size, 1, dtype=torch.long, device=device).fill_(SOS_token)
        decoder_hidden = encoder_hidden
        decoder_outputs = []

        for i in range(MAX_LENGTH):
            decoder_output, decoder_hidden  = self.forward_step(decoder_input, decoder_hidden)
            decoder_outputs.append(decoder_output)

            if target_tensor is not None:
                # Teacher forcing: Feed the target as the next input
                decoder_input = target_tensor[:, i].unsqueeze(1) # Teacher forcing
            else:
                # Without teacher forcing: use its own predictions as the next input
                _, topi = decoder_output.topk(1)
                decoder_input = topi.squeeze(-1).detach()  # detach from history as input

        decoder_outputs = torch.cat(decoder_outputs, dim=1)
        decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)
        return decoder_outputs, decoder_hidden, None # We return `None` for consistency in the training loop

    def forward_step(self, input, hidden):
        output = self.embedding(input)
        output = F.relu(output)
        output, hidden = self.gru(output, hidden)
        output = self.out(output)
        return output, hidden

I encourage you to train and observe the results of this model, but to
save space we\'ll be going straight for the gold and introducing the
Attention Mechanism.


Attention Decoder
=================

If only the context vector is passed between the encoder and decoder,
that single vector carries the burden of encoding the entire sentence.

Attention allows the decoder network to \"focus\" on a different part of
the encoder\'s outputs for every step of the decoder\'s own outputs.
First we calculate a set of *attention weights*. These will be
multiplied by the encoder output vectors to create a weighted
combination. The result (called `attn_applied` in the code) should
contain information about that specific part of the input sequence, and
thus help the decoder choose the right output words.

![](https://i.imgur.com/1152PYf.png)

Calculating the attention weights is done with another feed-forward
layer `attn`, using the decoder\'s input and hidden state as inputs.
Because there are sentences of all sizes in the training data, to
actually create and train this layer we have to choose a maximum
sentence length (input length, for encoder outputs) that it can apply
to. Sentences of the maximum length will use all the attention weights,
while shorter sentences will only use the first few.

![](https://pytorch.org/tutorials/_static/img/seq-seq-images/attention-decoder-network.png)

Bahdanau attention, also known as additive attention, is a commonly used
attention mechanism in sequence-to-sequence models, particularly in
neural machine translation tasks. It was introduced by Bahdanau et al.
in their paper titled [Neural Machine Translation by Jointly Learning to
Align and Translate](https://arxiv.org/pdf/1409.0473.pdf). This
attention mechanism employs a learned alignment model to compute
attention scores between the encoder and decoder hidden states. It
utilizes a feed-forward neural network to calculate alignment scores.

However, there are alternative attention mechanisms available, such as
Luong attention, which computes attention scores by taking the dot
product between the decoder hidden state and the encoder hidden states.
It does not involve the non-linear transformation used in Bahdanau
attention.

In this tutorial, we will be using Bahdanau attention. However, it would
be a valuable exercise to explore modifying the attention mechanism to
use Luong attention.


In [9]:
class BahdanauAttention(nn.Module):
    def __init__(self, hidden_size):
        super(BahdanauAttention, self).__init__()
        self.Wa = nn.Linear(hidden_size, hidden_size)
        self.Ua = nn.Linear(hidden_size, hidden_size)
        self.Va = nn.Linear(hidden_size, 1)

    def forward(self, query, keys):
        scores = self.Va(torch.tanh(self.Wa(query) + self.Ua(keys)))
        scores = scores.squeeze(2).unsqueeze(1)

        weights = F.softmax(scores, dim=-1)
        context = torch.bmm(weights, keys)

        return context, weights

class AttnDecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size, dropout_p=0.1):
        super(AttnDecoderRNN, self).__init__()
        self.embedding = nn.Embedding(output_size, hidden_size)
        self.attention = BahdanauAttention(hidden_size)
        self.gru = nn.GRU(2 * hidden_size, hidden_size, batch_first=True)
        self.out = nn.Linear(hidden_size, output_size)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, encoder_outputs, encoder_hidden, target_tensor=None):
        batch_size = encoder_outputs.size(0)
        decoder_input = torch.empty(batch_size, 1, dtype=torch.long, device=device).fill_(SOS_token)
        decoder_hidden = encoder_hidden
        decoder_outputs = []
        attentions = []

        for i in range(MAX_LENGTH):
            decoder_output, decoder_hidden, attn_weights = self.forward_step(
                decoder_input, decoder_hidden, encoder_outputs
            )
            decoder_outputs.append(decoder_output)
            attentions.append(attn_weights)

            if target_tensor is not None:
                # Teacher forcing: Feed the target as the next input
                decoder_input = target_tensor[:, i].unsqueeze(1) # Teacher forcing
            else:
                # Without teacher forcing: use its own predictions as the next input
                _, topi = decoder_output.topk(1)
                decoder_input = topi.squeeze(-1).detach()  # detach from history as input

        decoder_outputs = torch.cat(decoder_outputs, dim=1)
        decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)
        attentions = torch.cat(attentions, dim=1)

        return decoder_outputs, decoder_hidden, attentions


    def forward_step(self, input, hidden, encoder_outputs):
        embedded =  self.dropout(self.embedding(input))

        query = hidden.permute(1, 0, 2)
        context, attn_weights = self.attention(query, encoder_outputs)
        input_gru = torch.cat((embedded, context), dim=2)

        output, hidden = self.gru(input_gru, hidden)
        output = self.out(output)

        return output, hidden, attn_weights

<div style="background-color: #54c7ec; color: #fff; font-weight: 700; padding-left: 10px; padding-top: 5px; padding-bottom: 5px"><strong>NOTE:</strong></div>
<div style="background-color: #f3f4f7; padding-left: 10px; padding-top: 10px; padding-bottom: 10px; padding-right: 10px">
<p>There are other forms of attention that work around the lengthlimitation by using a relative position approach. Read about "localattention" in <a href="https://arxiv.org/abs/1508.04025">Effective Approaches to Attention-based Neural MachineTranslation</a>.</p>
</div>

Training
========

Preparing Training Data
-----------------------

To train, for each pair we will need an input tensor (indexes of the
words in the input sentence) and target tensor (indexes of the words in
the target sentence). While creating these vectors we will append the
EOS token to both sequences.


In [10]:
def indexesFromSentence(lang, sentence):
    return [lang.word2index[word] for word in sentence.split(' ')]

def tensorFromSentence(lang, sentence):
    indexes = indexesFromSentence(lang, sentence)
    indexes.append(EOS_token)
    return torch.tensor(indexes, dtype=torch.long, device=device).view(1, -1)

def tensorsFromPair(pair):
    input_tensor = tensorFromSentence(input_lang, pair[0])
    target_tensor = tensorFromSentence(output_lang, pair[1])
    return (input_tensor, target_tensor)

def get_dataloader(batch_size, language='spa'):
    input_lang, output_lang, pairs = prepareData('eng', language, True)

    n = len(pairs)
    input_ids = np.zeros((n, MAX_LENGTH), dtype=np.int32)
    target_ids = np.zeros((n, MAX_LENGTH), dtype=np.int32)

    for idx, (inp, tgt) in enumerate(pairs):
        inp_ids = indexesFromSentence(input_lang, inp)
        tgt_ids = indexesFromSentence(output_lang, tgt)
        inp_ids.append(EOS_token)
        tgt_ids.append(EOS_token)
        input_ids[idx, :len(inp_ids)] = inp_ids
        target_ids[idx, :len(tgt_ids)] = tgt_ids

    train_data = TensorDataset(torch.LongTensor(input_ids).to(device),
                               torch.LongTensor(target_ids).to(device))

    train_sampler = RandomSampler(train_data)
    train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)
    return input_lang, output_lang, train_dataloader, pairs

Training the Model
==================

To train we run the input sentence through the encoder, and keep track
of every output and the latest hidden state. Then the decoder is given
the `<SOS>` token as its first input, and the last hidden state of the
encoder as its first hidden state.

\"Teacher forcing\" is the concept of using the real target outputs as
each next input, instead of using the decoder\'s guess as the next
input. Using teacher forcing causes it to converge faster but [when the
trained network is exploited, it may exhibit
instability](http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.378.4095&rep=rep1&type=pdf).

You can observe outputs of teacher-forced networks that read with
coherent grammar but wander far from the correct translation
-intuitively it has learned to represent the output grammar and can
\"pick up\" the meaning once the teacher tells it the first few words,
but it has not properly learned how to create the sentence from the
translation in the first place.

Because of the freedom PyTorch\'s autograd gives us, we can randomly
choose to use teacher forcing or not with a simple if statement. Turn
`teacher_forcing_ratio` up to use more of it.


In [11]:
def train_epoch(dataloader, encoder, decoder, encoder_optimizer,
          decoder_optimizer, criterion):

    total_loss = 0
    for data in dataloader:
        input_tensor, target_tensor = data

        encoder_optimizer.zero_grad()
        decoder_optimizer.zero_grad()

        encoder_outputs, encoder_hidden = encoder(input_tensor)
        decoder_outputs, _, _ = decoder(encoder_outputs, encoder_hidden, target_tensor)

        loss = criterion(
            decoder_outputs.view(-1, decoder_outputs.size(-1)),
            target_tensor.view(-1)
        )
        loss.backward()

        encoder_optimizer.step()
        decoder_optimizer.step()

        total_loss += loss.item()

    return total_loss / len(dataloader)

This is a helper function to print time elapsed and estimated time
remaining given the current time and progress %.


In [12]:
import time
import math

def asMinutes(s):
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)

def timeSince(since, percent):
    now = time.time()
    s = now - since
    es = s / (percent)
    rs = es - s
    return '%s (- %s)' % (asMinutes(s), asMinutes(rs))

In [13]:
class ClientUpdate(object):
    def __init__(self, train_dataloader, learning_rate, epochs, sch_flag):
        self.train_loader = train_dataloader
        self.learning_rate = learning_rate
        self.epochs = epochs
        self.sch_flag = sch_flag

    def train(self, encoder, decoder):

        criterion = nn.CrossEntropyLoss()
        # optimizer = torch.optim.SGD(model.parameters(), lr=self.learning_rate, momentum=0.95, weight_decay = 5e-4)
        encoder_optimizer = optim.Adam(encoder.parameters(), lr=self.learning_rate)
        decoder_optimizer = optim.Adam(decoder.parameters(), lr=self.learning_rate)
        # if self.sch_flag == True:
        #    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5)
        # my_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.99)
        e_loss = []
        for epoch in range(1, self.epochs + 1):
            loss = train_epoch(self.train_loader, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion)
            e_loss.append(loss)

            # print_loss_total += loss
            # plot_loss_total += loss

            # if epoch % print_every == 0:
            #     print_loss_avg = print_loss_total / print_every
            #     print_loss_total = 0
            #     print('%s (%d %d%%) %.4f' % (timeSince(start, epoch / n_epochs),
            #                                 epoch, epoch / n_epochs * 100, print_loss_avg))

            # if epoch % plot_every == 0:
            #     plot_loss_avg = plot_loss_total / plot_every
            #     plot_losses.append(plot_loss_avg)
            #     plot_loss_total = 0

            total_loss = sum(e_loss) / len(e_loss)

        return encoder.state_dict(), decoder.state_dict(), total_loss

The whole training process looks like this:

-   Start a timer
-   Initialize optimizers and criterion
-   Create set of training pairs
-   Start empty losses array for plotting

Then we call `train` many times and occasionally print the progress (%
of examples, time so far, estimated time) and average loss.


In [14]:
from torchtext.data.metrics import bleu_score
def evaluateBleu(encoder, decoder, input_lang, output_lang, pairs, n=10, verbose=False):
    references = []
    candidates = []

    for _ in range(n):
        pair = random.choice(pairs)
        output_words, _ = evaluate(encoder, decoder, pair[0], input_lang, output_lang)
        output_sentence = ' '.join(output_words).split(' ')
        if verbose:
            print('>', pair[0])
            print('=', pair[1])
            print('<', ' '.join(output_sentence))
            print('')

        # Store the reference and candidate sentences for BLEU calculation
        references.append([pair[1].split(' ')])
        candidates.append(output_sentence)

    # Calculate the BLEU score
    score = bleu_score(candidates, references)
    return score

In [15]:
import csv
def train(train_dataloader, encoder, decoder, n_epochs, input, output, pairs, filename=None, learning_rate=0.001,
               print_every=100, plot_every=100):
    start = time.time()
    plot_losses = []
    print_loss_total = 0  # Reset every print_every
    plot_loss_total = 0  # Reset every plot_every
    best_bleu = 0

    encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate)
    criterion = nn.NLLLoss()

    for epoch in range(1, n_epochs + 1):
        loss = train_epoch(train_dataloader, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion)
        print_loss_total += loss
        plot_loss_total += loss

        bleu = evaluateBleu(encoder, decoder, input, output, pairs, n=20)

        if best_bleu < bleu:
            best_bleu = bleu

        if epoch % print_every == 0:
            print_loss_avg = print_loss_total / print_every
            print_loss_total = 0
            print('%s (%d %d%%) %.4f' % (timeSince(start, epoch / n_epochs),
                                        epoch, epoch / n_epochs * 100, print_loss_avg))

        if epoch % plot_every == 0:
            plot_loss_avg = plot_loss_total / plot_every
            plot_losses.append(plot_loss_avg)
            plot_loss_total = 0


        if filename is not None:
            with open(filename, 'a') as f:
                # create the csv writer
                writer = csv.writer(f)

                # write a row to the csv file
                writer.writerow([epoch, loss, bleu, best_bleu])

    showPlot(plot_losses)

Plotting results
================

Plotting is done with matplotlib, using the array of loss values
`plot_losses` saved while training.


In [16]:
import matplotlib.pyplot as plt
plt.switch_backend('agg')
import matplotlib.ticker as ticker
import numpy as np

def showPlot(points):
    plt.figure()
    fig, ax = plt.subplots()
    # this locator puts ticks at regular intervals
    loc = ticker.MultipleLocator(base=0.2)
    ax.yaxis.set_major_locator(loc)
    plt.plot(points)

Evaluation
==========

Evaluation is mostly the same as training, but there are no targets so
we simply feed the decoder\'s predictions back to itself for each step.
Every time it predicts a word we add it to the output string, and if it
predicts the EOS token we stop there. We also store the decoder\'s
attention outputs for display later.


In [17]:
def evaluate(encoder, decoder, sentence, input_lang, output_lang):
    with torch.no_grad():
        input_tensor = tensorFromSentence(input_lang, sentence)

        encoder_outputs, encoder_hidden = encoder(input_tensor)
        decoder_outputs, decoder_hidden, decoder_attn = decoder(encoder_outputs, encoder_hidden)

        _, topi = decoder_outputs.topk(1)
        decoded_ids = topi.squeeze()

        decoded_words = []
        for idx in decoded_ids:
            if idx.item() == EOS_token:
                decoded_words.append('<EOS>')
                break
            decoded_words.append(output_lang.index2word[idx.item()])
    return decoded_words, decoder_attn

We can evaluate random sentences from the training set and print out the
input, target, and output to make some subjective quality judgements:


In [18]:
def evaluateMultipleBleu(encoders, decoders, input_output_langs, n=10):
    bleu_sum = 0
    lang_length = len(encoders)
    for i in range(lang_length):
        bleu_sum += evaluateBleu(encoders[i], decoders[i], input_output_langs[i][0], input_output_langs[i][1], input_output_langs[i][2], n=n)
    print(bleu_sum)
    return bleu_sum / lang_length
        

## Federated Learning

In [19]:
import copy
from tqdm import tqdm
import csv

def training(encoders, decoders, input_output_lang, rounds, lr, ds, C, K, E, filename=None, batch_size=None, hidden_size=None, cifar_data_test = None,
             test_batch_size = None, classes_test = None, sch_flag = None):
    """
    Function implements the Federated Averaging Algorithm from the FedAvg paper.
    Specifically, this function is used for the server side training and weight update

    Params:
      - model:           PyTorch model to train
      - rounds:          Number of communication rounds for the client update
      - batch_size:      Batch size for client update training
      - lr:              Learning rate used for client update training
      - ds:              Dataset used for training
      - data_dict:       Type of data partition used for training (IID or non-IID)
      - C:               Fraction of clients randomly chosen to perform computation on each round
      - K:               Total number of clients
      - E:               Number of training passes each client makes over its local dataset per round
      - tb_writer_name:  Directory name to save the tensorboard logs
    Returns:
      - model:           Trained model on the server
    """

    # global model weights
    global_encoder_weights = {key: value for key, value in encoders[0].state_dict().items() if 'embedding' not in key}
    global_decoder_weights = {key: value for key, value in decoders[0].state_dict().items() if 'embedding' not in key and 'out' not in key}

    # training loss
    train_loss = []
    test_loss = []
    test_accuracy = []
    best_bleu = 0
    # measure time
    start = time.time()

    if filename is not None:
            with open(filename, 'a') as f:
                # create the csv writer
                writer = csv.writer(f)

                # write a row to the csv file
                writer.writerow(['Rounds', 'Learning Rate', 'Client Fraction', 'Client Number', 'Local Epochs', 'Batch Size', 'Hidden Size'])
                writer.writerow([rounds, lr, C, K, E, batch_size, hidden_size])

    for curr_round in range(1, rounds + 1):
        w_encoder, w_decoder, local_loss = [], [], []
        # Retrieve the number of clients participating in the current training
        m = max(int(C * K), 1)
        # Sample a subset of K clients according with the value defined before
        S_t = np.random.choice(range(K), m, replace=False)
        # For the selected clients start a local training
        for k in tqdm(S_t):
            # Compute a local update
            local_update = ClientUpdate(train_dataloader=ds[k], learning_rate=lr, epochs=E,
                                        sch_flag=sch_flag)
            # Update means retrieve the values of the network weights
            e_og = encoders[k].state_dict()
            d_og = decoders[k].state_dict()
            e_og.update(global_encoder_weights)
            d_og.update(global_decoder_weights)
            encoders[k].load_state_dict(e_og)
            decoders[k].load_state_dict(d_og)
            
            encoder_weights, decoder_weights, loss = local_update.train(encoders[k], decoders[k])

            w_encoder.append({key: value for key, value in copy.deepcopy(encoder_weights).items() if 'embedding' not in key})
            w_decoder.append({key: value for key, value in copy.deepcopy(decoder_weights).items() if 'embedding' not in key and 'out' not in key})
            local_loss.append(copy.deepcopy(loss))
        # lr = 0.999*lr
        # updating the global weights
        weights_avg_e = copy.deepcopy(w_encoder[0])
        for k in weights_avg_e.keys():
            for i in range(1, len(w_encoder)):
                weights_avg_e[k] += w_encoder[i][k]

            weights_avg_e[k] = torch.div(weights_avg_e[k], len(w_encoder))

        global_encoder_weights = weights_avg_e

        weights_avg_d = copy.deepcopy(w_decoder[0])
        for k in weights_avg_d.keys():
            for i in range(1, len(w_decoder)):
                weights_avg_d[k] += w_decoder[i][k]

            weights_avg_d[k] = torch.div(weights_avg_d[k], len(w_decoder))

        global_decoder_weights = weights_avg_d
        

        # if curr_round == 200:
        #     lr = lr / 2
        #     E = E - 1

        # if curr_round == 300:
        #     lr = lr / 2
        #     E = E - 2

        # if curr_round == 400:
        #     lr = lr / 5
        #     E = E - 3

        # move the updated weights to our model state dict
        # encoder.load_state_dict(global_encoder_weights)
        # decoder.load_state_dict(global_decoder_weights)

        # loss
        loss_avg = sum(local_loss) / len(local_loss)
        # print('Round: {}... \tAverage Loss: {}'.format(curr_round, round(loss_avg, 3)), lr)
        train_loss.append(loss_avg)

        # t_accuracy, t_loss = testing(model, cifar_data_test, test_batch_size, criterion, num_classes, classes_test)
        # test_accuracy.append(t_accuracy)
        # test_loss.append(t_loss)

        # if best_accuracy < t_accuracy:
        #     best_accuracy = t_accuracy
        # # torch.save(model.state_dict(), plt_title)
        # print(curr_round, loss_avg, t_loss, test_accuracy[0], best_accuracy)
        # # print('best_accuracy:', best_accuracy, '---Round:', curr_round, '---lr', lr, '----localEpocs--', E)
        bleu = evaluateMultipleBleu(encoders, decoders, input_output_lang)

        if best_bleu < bleu:
            best_bleu = bleu

        if filename is not None:
            with open(filename, 'a') as f:
                # create the csv writer
                writer = csv.writer(f)

                # write a row to the csv file
                writer.writerow([curr_round, loss_avg, bleu, best_bleu])
        print(f"Round {curr_round} >> Loss: {loss_avg}, BLEU:{bleu}")

    end = time.time()
    print("Training Done!")
    print("Total time taken to Train: {}".format(end - start))

    return global_encoder_weights, global_decoder_weights

Training and Evaluating
=======================

With all these helper functions in place (it looks like extra work, but
it makes it easier to run multiple experiments) we can actually
initialize a network and start training.

Remember that the input sentences were heavily filtered. For this small
dataset we can use relatively small networks of 256 hidden nodes and a
single GRU layer. After about 40 minutes on a MacBook CPU we\'ll get
some reasonable results.

<div style="background-color: #54c7ec; color: #fff; font-weight: 700; padding-left: 10px; padding-top: 5px; padding-bottom: 5px"><strong>NOTE:</strong></div>
<div style="background-color: #f3f4f7; padding-left: 10px; padding-top: 10px; padding-bottom: 10px; padding-right: 10px">
<p>If you run this notebook you can train, interrupt the kernel,evaluate, and continue training later. Comment out the lines where theencoder and decoder are initialized and run <code>trainIters</code> again.</p>
</div>


In [20]:
hidden_size = 128
batch_size = 32
# input_lang, output_lang, train_dataloader = get_dataloader(batch_size, language='spa')

In [21]:
data_dict = {}
encoders = {}
decoders = {}
input_output_lang = {}
langs = ['spa', 'fra', 'pol', 'deu', 'swe']
langs = sorted(langs)
K = len(langs)
for i in range(K):
    input_lang, output_lang, train_dataloader, pair = get_dataloader(batch_size, language=langs[i])
    data_dict[i] = train_dataloader
    encoders[i] = EncoderRNN(input_lang.n_words, hidden_size).to(device)
    decoders[i] = AttnDecoderRNN(hidden_size, output_lang.n_words).to(device)
    input_output_lang[i] = (input_lang, output_lang, pair)

Reading lines...
Read 277891 sentence pairs
Trimmed to 15860 sentence pairs
Counting words...
Counted words:
deu 5739
eng 3605
Reading lines...
Read 135842 sentence pairs
Trimmed to 11445 sentence pairs
Counting words...
Counted words:
fra 4601
eng 2991
Reading lines...
Read 49943 sentence pairs
Trimmed to 3613 sentence pairs
Counting words...
Counted words:
pol 3070
eng 1969
Reading lines...
Read 118121 sentence pairs
Trimmed to 7865 sentence pairs
Counting words...
Counted words:
spa 4027
eng 2809
Reading lines...
Read 25525 sentence pairs
Trimmed to 1751 sentence pairs
Counting words...
Counted words:
swe 1404
eng 1207


In [67]:
# st = {key: value for key, value in encoders[0].state_dict().items() if 'embedding' not in key}
# st2 = {key: value for key, value in decoders[0].state_dict().items() if 'embedding' not in key and 'out' not in key}
# encoders[1].state_dict().update(st)
# decoders[1].state_dict().update(st2)
# new = (encoders[0].state_dict())
# new.update(global_encoder_weights)
# encoders[0].load_state_dict(new)

In [22]:
import os
# train(train_dataloader, encoder, decoder, 80, print_every=5, plot_every=5)

filename=None
save = True
if save:
    num = 1
    filename = f"GLOBAL|{'_'.join(langs)}||{num}"
    while os.path.isfile(filename):
        print('Name is taken...trying again...')
        num += 1
        filename = f"{'_'.join(langs)}||{num}"

In [23]:
meta_encoder_weights, meta_decoder_weights = training(encoders, decoders, input_output_lang, 100, lr=0.001, ds=data_dict, C=1.0, K=K, E=1, filename=filename, batch_size=batch_size, hidden_size=hidden_size)

100%|██████████| 5/5 [00:09<00:00,  1.94s/it]


0.0
Round 1 >> Loss: 2.860247704758139, BLEU:0.0


100%|██████████| 5/5 [00:09<00:00,  1.85s/it]


0.12003516405820847
Round 2 >> Loss: 2.0962874488191146, BLEU:0.024007032811641692


100%|██████████| 5/5 [00:09<00:00,  1.89s/it]


0.2172812521457672
Round 3 >> Loss: 1.8446885250128031, BLEU:0.043456250429153444


100%|██████████| 5/5 [00:09<00:00,  1.89s/it]


0.16653580963611603
Round 4 >> Loss: 1.6896010309535252, BLEU:0.0333071619272232


100%|██████████| 5/5 [00:09<00:00,  1.89s/it]


0.5991720631718636
Round 5 >> Loss: 1.5702157657681615, BLEU:0.11983441263437271


100%|██████████| 5/5 [00:09<00:00,  1.87s/it]


0.4078499525785446
Round 6 >> Loss: 1.4716662326046503, BLEU:0.08156999051570893


100%|██████████| 5/5 [00:09<00:00,  1.87s/it]


0.46548230946063995
Round 7 >> Loss: 1.3851936830733662, BLEU:0.09309646189212799


100%|██████████| 5/5 [00:09<00:00,  1.88s/it]


0.6661649122834206
Round 8 >> Loss: 1.3094417741509332, BLEU:0.1332329824566841


100%|██████████| 5/5 [00:09<00:00,  1.90s/it]


0.9243349432945251
Round 9 >> Loss: 1.2434921789312843, BLEU:0.18486698865890502


100%|██████████| 5/5 [00:09<00:00,  1.87s/it]


0.7466794177889824
Round 10 >> Loss: 1.1829259463290578, BLEU:0.14933588355779648


100%|██████████| 5/5 [00:09<00:00,  1.89s/it]


1.1198234558105469
Round 11 >> Loss: 1.1278113208221954, BLEU:0.22396469116210938


100%|██████████| 5/5 [00:09<00:00,  1.87s/it]


1.359891839325428
Round 12 >> Loss: 1.0807575325854901, BLEU:0.2719783678650856


100%|██████████| 5/5 [00:09<00:00,  1.89s/it]


1.0678547620773315
Round 13 >> Loss: 1.0350084033043945, BLEU:0.2135709524154663


100%|██████████| 5/5 [00:09<00:00,  1.86s/it]


1.2234430760145187
Round 14 >> Loss: 0.9952742615521337, BLEU:0.24468861520290375


100%|██████████| 5/5 [00:09<00:00,  1.87s/it]


1.277858093380928
Round 15 >> Loss: 0.9588192857248904, BLEU:0.2555716186761856


100%|██████████| 5/5 [00:09<00:00,  1.87s/it]


1.3925724476575851
Round 16 >> Loss: 0.9245949544967372, BLEU:0.27851448953151703


100%|██████████| 5/5 [00:09<00:00,  1.84s/it]


1.6480442136526108
Round 17 >> Loss: 0.8938286121684487, BLEU:0.32960884273052216


100%|██████████| 5/5 [00:09<00:00,  1.89s/it]


1.2695788890123367
Round 18 >> Loss: 0.8657247240420165, BLEU:0.2539157778024673


100%|██████████| 5/5 [00:09<00:00,  1.87s/it]


1.4712837487459183
Round 19 >> Loss: 0.8383091285122937, BLEU:0.29425674974918364


100%|██████████| 5/5 [00:09<00:00,  1.84s/it]


1.120694786310196
Round 20 >> Loss: 0.812188473556321, BLEU:0.2241389572620392


100%|██████████| 5/5 [00:09<00:00,  1.88s/it]


1.4305199906229973
Round 21 >> Loss: 0.7894802263654178, BLEU:0.28610399812459947


100%|██████████| 5/5 [00:09<00:00,  1.81s/it]


1.5659805834293365
Round 22 >> Loss: 0.7676493593222836, BLEU:0.31319611668586733


100%|██████████| 5/5 [00:09<00:00,  1.81s/it]


1.635405033826828
Round 23 >> Loss: 0.7468930996742911, BLEU:0.3270810067653656


100%|██████████| 5/5 [00:08<00:00,  1.79s/it]


1.2960965186357498
Round 24 >> Loss: 0.7263768365984127, BLEU:0.25921930372714996


100%|██████████| 5/5 [00:08<00:00,  1.78s/it]


1.7311849296092987
Round 25 >> Loss: 0.7084508432730455, BLEU:0.34623698592185975


100%|██████████| 5/5 [00:09<00:00,  1.81s/it]


1.845588579773903
Round 26 >> Loss: 0.6904358284323718, BLEU:0.3691177159547806


100%|██████████| 5/5 [00:09<00:00,  1.88s/it]


1.4199815690517426
Round 27 >> Loss: 0.6755409004331716, BLEU:0.2839963138103485


100%|██████████| 5/5 [00:09<00:00,  1.88s/it]


1.977981299161911
Round 28 >> Loss: 0.66054702093714, BLEU:0.3955962598323822


100%|██████████| 5/5 [00:09<00:00,  1.89s/it]


1.5764877051115036
Round 29 >> Loss: 0.6446918086473874, BLEU:0.31529754102230073


100%|██████████| 5/5 [00:09<00:00,  1.90s/it]


1.7350122183561325
Round 30 >> Loss: 0.6320594414631795, BLEU:0.3470024436712265


100%|██████████| 5/5 [00:09<00:00,  1.89s/it]


1.7306435108184814
Round 31 >> Loss: 0.6190059102463988, BLEU:0.3461287021636963


100%|██████████| 5/5 [00:09<00:00,  1.88s/it]


2.1875995993614197
Round 32 >> Loss: 0.6053399207231748, BLEU:0.43751991987228395


100%|██████████| 5/5 [00:09<00:00,  1.88s/it]


2.1490239799022675
Round 33 >> Loss: 0.5931671347910499, BLEU:0.4298047959804535


100%|██████████| 5/5 [00:09<00:00,  1.83s/it]


1.7716743797063828
Round 34 >> Loss: 0.5822015777469798, BLEU:0.35433487594127655


100%|██████████| 5/5 [00:09<00:00,  1.87s/it]


2.5026821196079254
Round 35 >> Loss: 0.5715246468207449, BLEU:0.500536423921585


100%|██████████| 5/5 [00:09<00:00,  1.88s/it]


1.5535879880189896
Round 36 >> Loss: 0.5608136264098358, BLEU:0.3107175976037979


100%|██████████| 5/5 [00:09<00:00,  1.81s/it]


2.241326168179512
Round 37 >> Loss: 0.5507961302464601, BLEU:0.4482652336359024


100%|██████████| 5/5 [00:09<00:00,  1.80s/it]


2.3870969116687775
Round 38 >> Loss: 0.5418581171089041, BLEU:0.47741938233375547


100%|██████████| 5/5 [00:08<00:00,  1.80s/it]


1.9530484676361084
Round 39 >> Loss: 0.5332602318661402, BLEU:0.3906096935272217


100%|██████████| 5/5 [00:08<00:00,  1.79s/it]


2.0893848538398743
Round 40 >> Loss: 0.5246706234655523, BLEU:0.41787697076797486


100%|██████████| 5/5 [00:08<00:00,  1.78s/it]


2.664024442434311
Round 41 >> Loss: 0.5162413607668539, BLEU:0.5328048884868621


100%|██████████| 5/5 [00:08<00:00,  1.79s/it]


2.458829164505005
Round 42 >> Loss: 0.5087843907988423, BLEU:0.49176583290100095


100%|██████████| 5/5 [00:08<00:00,  1.79s/it]


2.534795820713043
Round 43 >> Loss: 0.5015260671145649, BLEU:0.5069591641426087


100%|██████████| 5/5 [00:08<00:00,  1.79s/it]


2.059453248977661
Round 44 >> Loss: 0.4932866443202192, BLEU:0.41189064979553225


100%|██████████| 5/5 [00:09<00:00,  1.92s/it]


2.1823181807994843
Round 45 >> Loss: 0.48636721955098705, BLEU:0.43646363615989686


100%|██████████| 5/5 [00:09<00:00,  1.88s/it]


2.1634169816970825
Round 46 >> Loss: 0.47964601642069776, BLEU:0.4326833963394165


100%|██████████| 5/5 [00:09<00:00,  1.88s/it]


2.4025538861751556
Round 47 >> Loss: 0.47414758148645014, BLEU:0.4805107772350311


100%|██████████| 5/5 [00:09<00:00,  1.88s/it]


2.028943732380867
Round 48 >> Loss: 0.4669696942343121, BLEU:0.4057887464761734


100%|██████████| 5/5 [00:09<00:00,  1.88s/it]


2.324966996908188
Round 49 >> Loss: 0.46160121859884473, BLEU:0.4649933993816376


100%|██████████| 5/5 [00:09<00:00,  1.85s/it]


2.450753331184387
Round 50 >> Loss: 0.45572953690711693, BLEU:0.4901506662368774


100%|██████████| 5/5 [00:09<00:00,  1.87s/it]


2.1560243368148804
Round 51 >> Loss: 0.44987870777594063, BLEU:0.4312048673629761


100%|██████████| 5/5 [00:09<00:00,  1.87s/it]


2.1569567918777466
Round 52 >> Loss: 0.44501926531121416, BLEU:0.4313913583755493


100%|██████████| 5/5 [00:09<00:00,  1.86s/it]


1.783563882112503
Round 53 >> Loss: 0.44000362788082353, BLEU:0.3567127764225006


100%|██████████| 5/5 [00:09<00:00,  1.89s/it]


2.065151661634445
Round 54 >> Loss: 0.43470346935243515, BLEU:0.413030332326889


100%|██████████| 5/5 [00:09<00:00,  1.89s/it]


2.2805392146110535
Round 55 >> Loss: 0.4295052635767145, BLEU:0.4561078429222107


100%|██████████| 5/5 [00:09<00:00,  1.85s/it]


2.2208003401756287
Round 56 >> Loss: 0.42496680066170595, BLEU:0.44416006803512575


100%|██████████| 5/5 [00:08<00:00,  1.80s/it]


2.4873069524765015
Round 57 >> Loss: 0.41998054551544917, BLEU:0.4974613904953003


100%|██████████| 5/5 [00:09<00:00,  1.83s/it]


2.6670438647270203
Round 58 >> Loss: 0.4149770117718753, BLEU:0.533408772945404


100%|██████████| 5/5 [00:09<00:00,  1.88s/it]


2.250459849834442
Round 59 >> Loss: 0.4105789167631035, BLEU:0.45009196996688844


100%|██████████| 5/5 [00:09<00:00,  1.88s/it]


2.165394216775894
Round 60 >> Loss: 0.40707939122693027, BLEU:0.4330788433551788


100%|██████████| 5/5 [00:09<00:00,  1.89s/it]


2.7133621275424957
Round 61 >> Loss: 0.40213377687913016, BLEU:0.5426724255084991


100%|██████████| 5/5 [00:09<00:00,  1.88s/it]


2.49016135931015
Round 62 >> Loss: 0.3978944658945328, BLEU:0.49803227186203003


100%|██████████| 5/5 [00:09<00:00,  1.87s/it]


2.2347859144210815
Round 63 >> Loss: 0.39404746345905006, BLEU:0.4469571828842163


100%|██████████| 5/5 [00:09<00:00,  1.86s/it]


2.4218453764915466
Round 64 >> Loss: 0.38925301866353823, BLEU:0.4843690752983093


100%|██████████| 5/5 [00:09<00:00,  1.87s/it]


2.292235016822815
Round 65 >> Loss: 0.3855831080438785, BLEU:0.458447003364563


100%|██████████| 5/5 [00:09<00:00,  1.88s/it]


2.447896331548691
Round 66 >> Loss: 0.38251286651304955, BLEU:0.48957926630973814


100%|██████████| 5/5 [00:09<00:00,  1.85s/it]


2.6700838208198547
Round 67 >> Loss: 0.37769334819887945, BLEU:0.5340167641639709


100%|██████████| 5/5 [00:09<00:00,  1.80s/it]


2.666776418685913
Round 68 >> Loss: 0.37427642751992796, BLEU:0.5333552837371827


100%|██████████| 5/5 [00:08<00:00,  1.80s/it]


2.2292197197675705
Round 69 >> Loss: 0.37145986771590267, BLEU:0.4458439439535141


100%|██████████| 5/5 [00:08<00:00,  1.79s/it]


2.1857365667819977
Round 70 >> Loss: 0.3685115839870883, BLEU:0.43714731335639956


100%|██████████| 5/5 [00:09<00:00,  1.89s/it]


2.442729264497757
Round 71 >> Loss: 0.36510725201992644, BLEU:0.4885458528995514


100%|██████████| 5/5 [00:09<00:00,  1.83s/it]


2.560518890619278
Round 72 >> Loss: 0.3609593212020745, BLEU:0.5121037781238555


100%|██████████| 5/5 [00:09<00:00,  1.90s/it]


2.2620344161987305
Round 73 >> Loss: 0.35860359286857274, BLEU:0.4524068832397461


100%|██████████| 5/5 [00:09<00:00,  1.88s/it]


2.5941266119480133
Round 74 >> Loss: 0.35500970061336423, BLEU:0.5188253223896027


100%|██████████| 5/5 [00:09<00:00,  1.88s/it]


2.5631908178329468
Round 75 >> Loss: 0.35230240069246477, BLEU:0.5126381635665893


100%|██████████| 5/5 [00:09<00:00,  1.83s/it]


2.6065914630889893
Round 76 >> Loss: 0.3494252308881388, BLEU:0.5213182926177978


100%|██████████| 5/5 [00:09<00:00,  1.82s/it]


2.3242297172546387
Round 77 >> Loss: 0.3461621203157483, BLEU:0.4648459434509277


100%|██████████| 5/5 [00:09<00:00,  1.82s/it]


2.5261765718460083
Round 78 >> Loss: 0.3425756552790645, BLEU:0.5052353143692017


100%|██████████| 5/5 [00:09<00:00,  1.89s/it]


2.314158469438553
Round 79 >> Loss: 0.338858136989528, BLEU:0.46283169388771056


100%|██████████| 5/5 [00:09<00:00,  1.92s/it]


2.6117401719093323
Round 80 >> Loss: 0.33682856951855955, BLEU:0.5223480343818665


100%|██████████| 5/5 [00:08<00:00,  1.78s/it]


2.912062257528305
Round 81 >> Loss: 0.33471298787692405, BLEU:0.582412451505661


100%|██████████| 5/5 [00:09<00:00,  1.80s/it]


2.5463851392269135
Round 82 >> Loss: 0.33096151418574155, BLEU:0.5092770278453826


100%|██████████| 5/5 [00:09<00:00,  1.80s/it]


2.6259197294712067
Round 83 >> Loss: 0.32835439857391746, BLEU:0.5251839458942413


100%|██████████| 5/5 [00:08<00:00,  1.80s/it]


2.3479614555835724
Round 84 >> Loss: 0.3243347395231312, BLEU:0.46959229111671447


100%|██████████| 5/5 [00:09<00:00,  1.84s/it]


2.5247846245765686
Round 85 >> Loss: 0.3226205880218346, BLEU:0.5049569249153137


100%|██████████| 5/5 [00:08<00:00,  1.79s/it]


2.2515049278736115
Round 86 >> Loss: 0.3196336605417584, BLEU:0.4503009855747223


100%|██████████| 5/5 [00:09<00:00,  1.83s/it]


2.083428382873535
Round 87 >> Loss: 0.31712727112457095, BLEU:0.416685676574707


100%|██████████| 5/5 [00:09<00:00,  1.89s/it]


2.85760435461998
Round 88 >> Loss: 0.3153467477452382, BLEU:0.571520870923996


100%|██████████| 5/5 [00:09<00:00,  1.87s/it]


2.615035116672516
Round 89 >> Loss: 0.3133765977680657, BLEU:0.5230070233345032


100%|██████████| 5/5 [00:09<00:00,  1.87s/it]


2.5102259814739227
Round 90 >> Loss: 0.3113851276855954, BLEU:0.5020451962947845


100%|██████████| 5/5 [00:09<00:00,  1.88s/it]


3.150819718837738
Round 91 >> Loss: 0.3087155836238547, BLEU:0.6301639437675476


100%|██████████| 5/5 [00:09<00:00,  1.91s/it]


2.7309306263923645
Round 92 >> Loss: 0.3055168909504001, BLEU:0.5461861252784729


100%|██████████| 5/5 [00:09<00:00,  1.83s/it]


2.3715451657772064
Round 93 >> Loss: 0.3025994972123599, BLEU:0.47430903315544126


100%|██████████| 5/5 [00:08<00:00,  1.80s/it]


2.702192783355713
Round 94 >> Loss: 0.30051708956894146, BLEU:0.5404385566711426


100%|██████████| 5/5 [00:09<00:00,  1.80s/it]


2.7935028672218323
Round 95 >> Loss: 0.29813768395070134, BLEU:0.5587005734443664


100%|██████████| 5/5 [00:09<00:00,  1.81s/it]


2.2949805855751038
Round 96 >> Loss: 0.2961758112020165, BLEU:0.45899611711502075


100%|██████████| 5/5 [00:08<00:00,  1.79s/it]


2.3619384169578552
Round 97 >> Loss: 0.2939030854549721, BLEU:0.47238768339157106


100%|██████████| 5/5 [00:08<00:00,  1.79s/it]


2.692322790622711
Round 98 >> Loss: 0.2915829201258407, BLEU:0.5384645581245422


100%|██████████| 5/5 [00:09<00:00,  1.84s/it]


2.2503625452518463
Round 99 >> Loss: 0.28926876557383957, BLEU:0.4500725090503693


100%|██████████| 5/5 [00:09<00:00,  1.83s/it]

3.1285168528556824
Round 100 >> Loss: 0.28712221835746965, BLEU:0.6257033705711365
Training Done!
Total time taken to Train: 937.1776456832886





In [24]:
import torch
from torch.utils.data import DataLoader, Subset

def limited_data_loader(original_dataloader, num_samples, random=True):
    # Get the original dataset from the DataLoader
    dataset = original_dataloader.dataset
    
    # Ensure the original dataset is larger than the requested number of samples
    assert len(dataset) >= num_samples, "The original dataset has fewer samples than requested"
    
    # Create a Subset of the original dataset limited to the first num_samples elements
    if random:
        indices = np.random.choice(len(dataset), num_samples, replace=False)
    else:
        indices = range(num_samples)

    subset = Subset(dataset, indices)
    
    # Create a new DataLoader from this subset with the same parameters as the original DataLoader
    new_dataloader = DataLoader(subset, batch_size=original_dataloader.batch_size, shuffle=False, num_workers=original_dataloader.num_workers)
    
    return new_dataloader


In [25]:
def personalize(lang, rounds, encoder_weights=None, decoder_weights=None, sample=None, save=False):
    input_lang, output_lang, train_dataloader, pairs = get_dataloader(batch_size, language=lang)

    encoder = EncoderRNN(input_lang.n_words, hidden_size).to(device)
    decoder = AttnDecoderRNN(hidden_size, output_lang.n_words).to(device)

    is_FL = 'false'

    if encoder_weights is not None:
        e_og = encoder.state_dict()
        d_og = decoder.state_dict()
        e_og.update(encoder_weights)
        d_og.update(decoder_weights)
        encoder.load_state_dict(e_og)
        decoder.load_state_dict(d_og)
        is_FL = 'true'

    
    
    if sample is None:
        sample = 'full'
    else:
        train_dataloader = limited_data_loader(train_dataloader, sample)

    filename=None
    if save:
        num = 1
        filename = f"P|{lang}_{sample}-shot_FL-{is_FL}_epoch{rounds}||{num}"
        while os.path.isfile(filename):
            print('Name is taken...trying again...')
            num += 1
            filename = f"P|{lang}_{sample}-shot_FL-{is_FL}_epoch{rounds}||{num}"
    train(train_dataloader, encoder, decoder, rounds, print_every=5, plot_every=5, filename=filename, input=input_lang, output=output_lang, pairs=pairs)
    return encoder, decoder

In [30]:
for i in range(1):
    personalize('spa', 100, save=True)

Reading lines...
Read 118121 sentence pairs
Trimmed to 7865 sentence pairs
Counting words...
Counted words:
spa 4027
eng 2809
0m 9s (- 2m 56s) (5 5%) 1.8923
0m 18s (- 2m 49s) (10 10%) 0.9286
0m 28s (- 2m 39s) (15 15%) 0.5192
0m 37s (- 2m 28s) (20 20%) 0.2980
0m 46s (- 2m 19s) (25 25%) 0.1823
0m 56s (- 2m 10s) (30 30%) 0.1237
1m 5s (- 2m 1s) (35 35%) 0.0905
1m 14s (- 1m 52s) (40 40%) 0.0718
1m 24s (- 1m 42s) (45 45%) 0.0583
1m 33s (- 1m 33s) (50 50%) 0.0506
1m 42s (- 1m 24s) (55 55%) 0.0449
1m 52s (- 1m 14s) (60 60%) 0.0413
2m 1s (- 1m 5s) (65 65%) 0.0384
2m 10s (- 0m 56s) (70 70%) 0.0360
2m 20s (- 0m 46s) (75 75%) 0.0343
2m 29s (- 0m 37s) (80 80%) 0.0320
2m 39s (- 0m 28s) (85 85%) 0.0319
2m 48s (- 0m 18s) (90 90%) 0.0304
2m 57s (- 0m 9s) (95 95%) 0.0298
3m 7s (- 0m 0s) (100 100%) 0.0289


In [32]:
for i in range(3):
    personalize('fin', 100, save=True, encoder_weights=meta_encoder_weights, decoder_weights=meta_decoder_weights, sample=100)

Reading lines...
Read 72258 sentence pairs
Trimmed to 5005 sentence pairs
Counting words...
Counted words:
fin 3686
eng 1971
0m 0s (- 0m 5s) (5 5%) 6.8846
0m 0s (- 0m 4s) (10 10%) 5.4794
0m 0s (- 0m 4s) (15 15%) 4.3577
0m 1s (- 0m 4s) (20 20%) 3.4238
0m 1s (- 0m 4s) (25 25%) 2.7352
0m 1s (- 0m 3s) (30 30%) 2.2419
0m 1s (- 0m 3s) (35 35%) 1.8946
0m 2s (- 0m 3s) (40 40%) 1.6374
0m 2s (- 0m 3s) (45 45%) 1.4471
0m 2s (- 0m 2s) (50 50%) 1.2972
0m 3s (- 0m 2s) (55 55%) 1.1612
0m 3s (- 0m 2s) (60 60%) 1.0403
0m 3s (- 0m 1s) (65 65%) 0.9236
0m 3s (- 0m 1s) (70 70%) 0.8336
0m 4s (- 0m 1s) (75 75%) 0.7473
0m 4s (- 0m 1s) (80 80%) 0.6673
0m 4s (- 0m 0s) (85 85%) 0.5893
0m 4s (- 0m 0s) (90 90%) 0.5238
0m 5s (- 0m 0s) (95 95%) 0.4710
0m 5s (- 0m 0s) (100 100%) 0.4165
Reading lines...
Read 72258 sentence pairs
Trimmed to 5005 sentence pairs
Counting words...
Counted words:
fin 3686
eng 1971
Name is taken...trying again...
0m 0s (- 0m 5s) (5 5%) 7.0616
0m 0s (- 0m 4s) (10 10%) 5.7325
0m 0s (- 0m 4s) 

Set dropout layers to `eval` mode


Visualizing Attention
=====================

A useful property of the attention mechanism is its highly interpretable
outputs. Because it is used to weight specific encoder outputs of the
input sequence, we can imagine looking where the network is focused most
at each time step.

You could simply run `plt.matshow(attentions)` to see attention output
displayed as a matrix. For a better viewing experience we will do the
extra work of adding axes and labels:


In [22]:
def showAttention(input_sentence, output_words, attentions):
    fig = plt.figure()
    ax = fig.add_subplot(111)
    cax = ax.matshow(attentions.cpu().numpy(), cmap='bone')
    fig.colorbar(cax)

    # Set up axes
    ax.set_xticklabels([''] + input_sentence.split(' ') +
                       ['<EOS>'], rotation=90)
    ax.set_yticklabels([''] + output_words)

    # Show label at every tick
    ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))

    plt.show()


def evaluateAndShowAttention(input_sentence):
    output_words, attentions = evaluate(encoder, decoder, input_sentence, input_lang, output_lang)
    print('input =', input_sentence)
    print('output =', ' '.join(output_words))
    showAttention(input_sentence, output_words, attentions[0, :len(output_words), :])


evaluateAndShowAttention('il n est pas aussi grand que son pere')

evaluateAndShowAttention('je suis trop fatigue pour conduire')

evaluateAndShowAttention('je suis desole si c est une question idiote')

evaluateAndShowAttention('je suis reellement fiere de vous')

input = il n est pas aussi grand que son pere
output = he is not as tall as his father <EOS>
input = je suis trop fatigue pour conduire
output = i m too tired to drive drive <EOS>
input = je suis desole si c est une question idiote
output = i m sorry if this is a stupid question <EOS>
input = je suis reellement fiere de vous
output = i m really proud of you are <EOS>


  ax.set_xticklabels([''] + input_sentence.split(' ') +
  ax.set_yticklabels([''] + output_words)
  plt.show()


Exercises
=========

-   Try with a different dataset
    -   Another language pair
    -   Human → Machine (e.g. IOT commands)
    -   Chat → Response
    -   Question → Answer
-   Replace the embeddings with pretrained word embeddings such as
    `word2vec` or `GloVe`
-   Try with more layers, more hidden units, and more sentences. Compare
    the training time and results.
-   If you use a translation file where pairs have two of the same
    phrase (`I am test \t I am test`), you can use this as an
    autoencoder. Try this:
    -   Train as an autoencoder
    -   Save only the Encoder network
    -   Train a new Decoder for translation from there
