<a href="https://colab.research.google.com/github/DavoodSZ1993/Dive_into_Deep_Learning/blob/main/10_7_Encoder_Decoder_Machine_Translation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install d2l==1.0.0-alpha1.post0 --quiet

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m93.0/93.0 KB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m121.0/121.0 KB[0m [31m7.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m83.6/83.6 KB[0m [31m6.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m37.5 MB/s[0m eta [36m0:00:00[0m
[?25h

## 10.7 Encoder-Decoder Seq2Seq for Machine Translation

In [2]:
import collections
import math
import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l



### 10.7.2 Encoding

In [3]:
def init_seq2seq(module):
  """Initialize weights for Seq2Seq."""
  if type(module) == nn.Linear:
    nn.init.xavier_uniform_(module.weight)
  if type(module) == nn.GRU:
    for param in module._flat_weights_names:
      if "weight" in param:
        nn.init.xavier_uniform(module._parameters[param])

In [4]:
class Seq2SeqEncoder(d2l.Encoder):
  """The RNN encoder for sequence to sequence learning."""
  def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, 
               dropout=0):
    super().__init__()
    self.embedding = nn.Embedding(vocab_size, embed_size)
    self.rnn = d2l.GRU(embed_size, num_hiddens, num_layers, dropout)
    self.apply(init_seq2seq)

  def forward(self, X, *args):
    # X shape: (batch_size, num_steps)
    embs = self.embedding(X.t().type(torch.int64))
    outputs, state = self.rnn(embs)
    return outputs, state

In [5]:
vocab_size, embed_size, num_hiddens, num_layers = 10, 8, 16, 2
batch_size, num_steps = 4, 9

encoder = Seq2SeqEncoder(vocab_size, embed_size, num_hiddens, num_layers)
X = torch.zeros((batch_size, num_steps))
enc_outputs, enc_state = encoder(X)

d2l.check_shape(enc_outputs, (num_steps, batch_size, num_hiddens))

  nn.init.xavier_uniform(module._parameters[param])


In [6]:
d2l.check_shape(enc_state, (num_layers, batch_size, num_hiddens))

### 10.7.3 Decoder

In [7]:
class Seq2SeqDecoder(d2l.Decoder):
  def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,
               dropout=0):
    super().__init__()
    self.embedding = nn.Embedding(vocab_size, embed_size)
    self.rnn = d2l.GRU(embed_size+num_hiddens, num_hiddens, 
                       num_layers, dropout)
    self.dense = nn.LazyLinear(vocab_size)
    self.apply(init_seq2seq)

  def init_state(self, enc_all_outputs, *args):
    return enc_all_outputs

  def forward(self, X, state):                                                   # X shape: (batch_size, num_steps)
    embs = self.embedding(X.t().type(torch.int32))                               # embs shape: (num_steps, batch_size, embed_size)
    enc_output, hidden_state = state
    context = enc_output[-1]                                                     # context shape: (batch_size, num_hiddens)
    context = context.repeat(embs.shape[0], 1, 1)                                # Broadcast context to (num_steps, batch_size, num_hiddens)
    embs_and_context = torch.cat((embs, context), -1)                            # concat at the feature dimension
    outputs, hidden_state = self.rnn(embs_and_context, hidden_state)
    outputs = self.dense(outputs).swapaxes(0, 1)
    return outputs, [enc_output, hidden_state]                                   # output shape: (batch_size, num_steps, vocab_size)
                                                                                 # hidden_state shape: (num_layers, batch_size, num_hiddens)

In [8]:
decoder = Seq2SeqDecoder(vocab_size, embed_size, num_hiddens, num_layers)
state = decoder.init_state(encoder(X))
dec_outputs, state = decoder(X, state)
d2l.check_shape(dec_outputs, (batch_size, num_steps, vocab_size))
d2l.check_shape(state[1], (num_layers, batch_size, num_hiddens))

  nn.init.xavier_uniform(module._parameters[param])


### 10.7.4 Encoder-Decoder for Sequence to Sequence Learning

In [9]:
class Seq2Seq(d2l.EncoderDecoder):
  def __init__(self, encoder, decoder, tgt_pad, lr):
    super().__init__()
    self.save_hyperparameters()

  def validation_step(self, batch):
    Y_hat = self(*batch[:-1])
    self.plot('loss', self.loss(Y_hat, batch[-1]), train=False)

  def configure_optimizers(self):
    return torch.optim.Adam(self.parameters(), lr=self.lr)