https://github.com/hunkim/PyTorchZeroToAll/blob/master/14_2_seq2seq_att.py

https://github.com/bentrevett/pytorch-seq2seq/blob/master/3%20-%20Neural%20Machine%20Translation%20by%20Jointly%20Learning%20to%20Align%20and%20Translate.ipynb

https://github.com/graykode/nlp-tutorial/blob/master/4-2.Seq2Seq(Attention)/Seq2Seq(Attention).ipynb

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torchtext.legacy. datasets import Multi30k
from torchtext.legacy.data import Field, BucketIterator

In [41]:
class Encoder(nn.Module):
    """Bidirectional single-layer GRU encoder over padded batches of token ids.

    Args:
        n_tokens: Vocabulary size of the embedding table.
        n_inputs: Embedding dimension (input size of the GRU).
        n_hiddens: Hidden size of each GRU direction.
        padding_idx: Token id used for padding. Its embedding is kept at zero
            and it is ignored when measuring sequence lengths.
        batch_first: If True, inputs are ``(batch, seq)``; otherwise
            ``(seq, batch)``.
    """

    def __init__(self, n_tokens, n_inputs, n_hiddens, padding_idx, batch_first=True):
        super().__init__()
        self.n_hiddens = n_hiddens
        self.batch_first = batch_first
        self.embedding = nn.Embedding(n_tokens, n_inputs, padding_idx=padding_idx)
        self.bidirectional_gru = nn.GRU(
            n_inputs, n_hiddens, batch_first=batch_first, bidirectional=True
        )

    def init_hidden(self, batch_size):
        """Return a zero initial state of shape ``(2, batch_size, n_hiddens)``.

        The tensor is created from a model parameter so that it shares the
        parameter's dtype and device.
        """
        weight = next(self.parameters())
        return weight.new_zeros(2, batch_size, self.n_hiddens)

    def _sequence_lengths(self, x):
        """Return a CPU LongTensor with the unpadded length of every sequence.

        The length is the 1-based position of the last non-padding token, so
        only trailing padding is trimmed.
        """
        time_dim = 1 if self.batch_first else 0
        n_steps = x.size(time_dim)
        # nn.Embedding stores the normalised (non-negative) padding index.
        pad = self.embedding.padding_idx
        if pad is None:
            # No padding token configured: every sequence uses all steps.
            return torch.full((x.size(1 - time_dim),), n_steps, dtype=torch.long)
        positions = torch.arange(1, n_steps + 1, device=x.device)
        if not self.batch_first:
            positions = positions.unsqueeze(1)  # broadcast along the batch dim
        last_token = ((x != pad).long() * positions).max(dim=time_dim).values
        # pack_padded_sequence rejects zero lengths; an all-padding row is
        # encoded as a single padding step (whose embedding is zero).
        # Lengths must live on the CPU for pack_padded_sequence.
        return last_token.clamp(min=1).cpu()

    def forward(self, x):
        """Encode a padded batch of token ids.

        Args:
            x: Integer tensor of token ids, ``(batch, seq)`` when
                ``batch_first`` is True, else ``(seq, batch)``.

        Returns:
            A pair ``(output, hidden)``:

            * ``output`` is the ``(padded_output, lengths)`` tuple produced by
              ``pad_packed_sequence``; ``padded_output`` has ``2 * n_hiddens``
              features and is padded to the longest sequence in the batch.
            * ``hidden`` has shape ``(batch, 2 * n_hiddens)``: the two
              per-direction final states concatenated on the feature axis.
        """
        batch_size = x.size(0 if self.batch_first else 1)
        lengths = self._sequence_lengths(x)
        hidden = self.init_hidden(batch_size)
        embedded = self.embedding(x)
        # enforce_sorted=False lets batches arrive in any length order; the
        # original order is restored for both the outputs and the hidden state.
        packed = nn.utils.rnn.pack_padded_sequence(
            embedded, lengths, batch_first=self.batch_first, enforce_sorted=False
        )
        packed_output, hidden = self.bidirectional_gru(packed, hidden)
        # NOTE(review): pad_packed_sequence returns (tensor, lengths). The tuple
        # is returned unchanged to keep the existing return contract; callers
        # that want only the tensor must take output[0] - confirm intent.
        output = nn.utils.rnn.pad_packed_sequence(
            packed_output, batch_first=self.batch_first
        )
        hidden = torch.cat([hidden[0], hidden[1]], dim=-1)
        return output, hidden



In [8]:
# Small throwaway encoder for interactive experiments: vocabulary, embedding
# and hidden sizes are all 10, and token id 0 is the padding index.
encoder = Encoder(n_tokens=10, n_inputs=10, n_hiddens=10, padding_idx=0)

In [38]:
# Build a (2, 10, 10) zero tensor from the encoder's first parameter.
next(iter(encoder.parameters())).new_zeros((2, 10, 10))

tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.,