In [192]:
from tqdm import tqdm
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from torch.nn import functional as F

import torch
import random

In [193]:
# Toy problem configuration: ids 0-2 are the special tokens defined below;
# the hard-coded sequences only use ids >= 3.
vocab_size = 100
pad_id, sos_id, eos_id = 0, 1, 2

# Ten parallel (source, target) pairs of varying length, without special tokens.
src_data = [
    [3, 77, 56, 26, 3, 55, 12, 36, 31],
    [58, 20, 65, 46, 26, 10, 76, 44],
    [58, 17, 8],
    [59],
    [29, 3, 52, 74, 73, 51, 39, 75, 19],
    [41, 55, 77, 21, 52, 92, 97, 69, 54, 14, 93],
    [39, 47, 96, 68, 55, 16, 90, 45, 89, 84, 19, 22, 32, 99, 5],
    [75, 34, 17, 3, 86, 88],
    [63, 39, 5, 35, 67, 56, 68, 89, 55, 66],
    [12, 40, 69, 39, 49],
]

trg_data = [
    [75, 13, 22, 77, 89, 21, 13, 86, 95],
    [79, 14, 91, 41, 32, 79, 88, 34, 8, 68, 32, 77, 58, 7, 9, 87],
    [85, 8, 50, 30],
    [47, 30],
    [8, 85, 87, 77, 47, 21, 23, 98, 83, 4, 47, 97, 40, 43, 70, 8, 65, 71, 69, 88],
    [32, 37, 31, 77, 38, 93, 45, 74, 47, 54, 31, 18],
    [37, 14, 49, 24, 93, 37, 54, 51, 39, 84],
    [16, 98, 68, 57, 55, 46, 66, 85, 18],
    [20, 70, 14, 6, 58, 90, 30, 17, 91, 18, 90],
    [37, 93, 98, 13, 45, 28, 89, 72, 70],
]

In [194]:
# Wrap every target as <sos> ... <eos>. Not idempotent: re-running this cell
# adds a second pair of markers.
trg_data = [[sos_id, *seq, eos_id] for seq in trg_data]

In [195]:
def padding(data, pad_value=None):
    """Right-pad integer sequences to a common length.

    Args:
        data: list of token-id sequences. It is NOT modified; padded copies
            are returned, so re-running the calling cell is safe.
        pad_value: id used for padding; defaults to the notebook-level ``pad_id``.

    Returns:
        (padded, valid_lens, max_len): ``padded`` is a new list of lists, each
        of length ``max_len``; ``valid_lens[i]`` is the unpadded length of
        ``data[i]``; ``max_len`` is the longest input length (0 for empty ``data``).
    """
    if pad_value is None:
        pad_value = pad_id

    # `default=0` avoids the ValueError that max() raises on an empty batch.
    max_len = max((len(seq) for seq in data), default=0)
    print(f"Maximum sequence length: {max_len}")

    padded, valid_lens = [], []
    for seq in tqdm(data):
        valid_lens.append(len(seq))
        # Build a fresh list instead of overwriting data[i] in place.
        padded.append(list(seq) + [pad_value] * (max_len - len(seq)))

    return padded, valid_lens, max_len

In [196]:
# Pad source and target independently; each call prints its own maximum length.
(src_data, src_lens, src_max_len) = padding(src_data)
(trg_data, trg_lens, trg_max_len) = padding(trg_data)

Maximum sequence length: 15


100%|██████████| 10/10 [00:00<00:00, 28591.03it/s]


Maximum sequence length: 22


100%|██████████| 10/10 [00:00<00:00, 38944.33it/s]


In [197]:
# Batch-first int64 tensors: ids are (num_seqs, max_len), lengths are (num_seqs,).
src_batch = torch.tensor(src_data, dtype=torch.long)
src_batch_lens = torch.tensor(src_lens, dtype=torch.long)
trg_batch = torch.tensor(trg_data, dtype=torch.long)
trg_batch_lens = torch.tensor(trg_lens, dtype=torch.long)

In [198]:
# The encoder packs sequences with pack_padded_sequence's default
# enforce_sorted=True, so order the whole batch by decreasing source length
# and apply the same permutation to every per-example tensor.
src_batch_lens, sorted_idx = torch.sort(src_batch_lens, descending=True)
src_batch, trg_batch, trg_batch_lens = [
    batch_tensor[sorted_idx] for batch_tensor in (src_batch, trg_batch, trg_batch_lens)
]

In [199]:
embedding_size = 256
hidden_size = 512
num_layers = 2
num_dirs = 2    # num_dirs > 1 makes the encoder GRU bidirectional
dropout = 0.1   # nn.GRU applies it between stacked layers

# Seed every RNG the notebook draws from: torch (parameter init) and `random`
# (teacher-forcing decisions), so Restart & Run All reproduces the outputs.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

In [200]:
class Encoder(nn.Module):
    """GRU encoder over batch-first, right-padded token ids.

    Layer sizes are read from the notebook-level hyper-parameters
    (``vocab_size``, ``embedding_size``, ``hidden_size``, ``num_layers``,
    ``num_dirs``, ``dropout``) at construction time. ``forward`` only consults
    the GRU's own configuration, so later edits to those globals cannot
    desynchronise an existing instance.
    """

    def __init__(self):
        super(Encoder, self).__init__()

        self.embedding = nn.Embedding(vocab_size, embedding_size)
        self.gru = nn.GRU(
            input_size = embedding_size,
            hidden_size = hidden_size,
            num_layers = num_layers,
            bidirectional = num_dirs > 1,
            dropout = dropout
        )
        # Projects (num_dirs * hidden_size) features back to hidden_size. The same
        # layer is reused for the per-step outputs AND the final hidden state
        # (shared weights) -- NOTE(review): confirm the sharing is intentional.
        self.linear = nn.Linear(num_dirs*hidden_size, hidden_size)

    def forward(self, batch, batch_lens):
        """Encode a right-padded batch.

        Args:
            batch: (B, S) token ids, batch-first.
            batch_lens: true length of each row (tensor or list of ints), sorted
                in decreasing order -- ``pack_padded_sequence`` is called with its
                default ``enforce_sorted=True``.

        Returns:
            outputs: (max(batch_lens), B, hidden_size) time-major, tanh-projected
                GRU outputs. Padded steps leave ``pad_packed_sequence`` as zeros and
                are then mapped to ``tanh(linear.bias)``, so they are generally not
                zero; anything attending over them should mask them.
            hidden: (1, B, hidden_size) projection of the last layer's final state(s).
        """
        batch_emb = self.embedding(batch)          # (B, S, E)
        batch_emb = batch_emb.transpose(0, 1)      # (S, B, E): the GRU is not batch_first

        # pack_padded_sequence requires CPU lengths when they are given as a tensor.
        lengths = batch_lens.cpu() if torch.is_tensor(batch_lens) else batch_lens
        packed_input = pack_padded_sequence(batch_emb, lengths)

        dirs = 2 if self.gru.bidirectional else 1
        # Zero initial state on the same device/dtype as the embeddings.
        h_0 = batch_emb.new_zeros((self.gru.num_layers * dirs, batch.shape[0], self.gru.hidden_size))
        packed_outputs, h_n = self.gru(packed_input, h_0)
        outputs = pad_packed_sequence(packed_outputs)[0]   # (S', B, dirs * H)
        outputs = torch.tanh(self.linear(outputs))

        # h_n is (num_layers * dirs, B, H), layer-major: for the last layer the
        # forward state is h_n[-2] and the backward state h_n[-1] when bidirectional;
        # with a single direction the last layer's state is just h_n[-1].
        if dirs == 2:
            last_state = torch.cat((h_n[-2], h_n[-1]), dim=-1)
        else:
            last_state = h_n[-1]
        hidden = torch.tanh(self.linear(last_state)).unsqueeze(0)

        return outputs, hidden

In [201]:
# Instantiate the encoder; its layer sizes come from the hyper-parameter cell above.
encoder = Encoder()

In [202]:
class DotAttention(nn.Module):
    """Parameter-free dot-product attention of a decoder state over encoder outputs."""

    def __init__(self):
        super().__init__()

    def forward(self, decoder_hidden, encoder_outputs, mask=None):
        """Attend with the decoder state as query and the encoder outputs as keys/values.

        Args:
            decoder_hidden: (1, B, H) current decoder hidden state.
            encoder_outputs: (S, B, H) time-major encoder outputs.
            mask: optional bool tensor (B, S); True marks positions that may be
                attended to, False positions get zero weight. ``None`` (default)
                attends to every position. A row with no True entry yields NaN.

        Returns:
            attn_values: (B, H) attention-weighted sum of the encoder outputs.
            attn_scores: (B, S) attention weights; each row sums to 1.
        """
        query = decoder_hidden.squeeze(0)             # (B, H)
        key = encoder_outputs.transpose(0, 1)         # (B, S, H)

        # Batched dot products without materialising a (B, S, H) product tensor.
        energy = torch.bmm(key, query.unsqueeze(2)).squeeze(2)   # (B, S)
        if mask is not None:
            energy = energy.masked_fill(~mask, float('-inf'))

        attn_scores = F.softmax(energy, dim=-1)
        attn_values = torch.bmm(attn_scores.unsqueeze(1), key).squeeze(1)   # (B, H)

        return attn_values, attn_scores
    

In [203]:
# Parameter-free attention module; handed to the decoder in a later cell.
dot_attn = DotAttention()

In [204]:
class Decoder(nn.Module):
    """Single-layer GRU decoder that advances one time step per call.

    Each step's vocabulary logits are computed from the GRU output concatenated
    with the attention context over the encoder outputs.
    """

    def __init__(self, attention):
        super().__init__()
        # Sub-module creation order is kept fixed: it determines the order of
        # random parameter initialisation and of the state_dict keys.
        self.embedding = nn.Embedding(vocab_size, embedding_size)
        self.attention = attention
        self.rnn = nn.GRU(embedding_size, hidden_size)
        # Input is [GRU output ; attention context], hence 2 * hidden_size features.
        self.output_linear = nn.Linear(2 * hidden_size, vocab_size)

    def forward(self, batch, encoder_outputs, hidden):
        """Run one decoding step.

        Args:
            batch: (B,) token ids fed at this step.
            encoder_outputs: encoder outputs, passed through to the attention module.
            hidden: previous decoder hidden state for the GRU.

        Returns:
            (logits, new_hidden): logits of shape (B, vocab_size) and the updated
            GRU hidden state.
        """
        step_input = self.embedding(batch).unsqueeze(0)       # sequence length of 1
        rnn_out, new_hidden = self.rnn(step_input, hidden)
        context, _ = self.attention(new_hidden, encoder_outputs)
        features = torch.cat([rnn_out, context.unsqueeze(0)], dim=-1)
        logits = self.output_linear(features).squeeze(0)
        return logits, new_hidden

In [205]:
# Decoder that attends over the encoder outputs with the dot-product attention above.
decoder = Decoder(dot_attn)

In [206]:
class Seq2seq(nn.Module):
    """Encoder-decoder wrapper that decodes a whole target batch step by step.

    ``forward`` returns logits of shape (trg_len, B, vocab_size). Row ``t`` holds
    the prediction for ``trg_batch[:, t]``; row 0 has no prediction (column 0 is
    only ever used as the first decoder input) and is left as zeros.
    """

    def __init__(self, encoder, decoder):
        super(Seq2seq, self).__init__()

        self.encoder = encoder
        self.decoder = decoder

    def forward(self, src_batch, src_batch_lens, trg_batch, teacher_forcing_prob=0.5):
        """Encode the source batch, then decode ``trg_batch.shape[1] - 1`` steps.

        Args:
            src_batch: (B, S) source token ids, passed to the encoder.
            src_batch_lens: source lengths, passed to the encoder unchanged.
            trg_batch: (B, T) target token ids; column 0 is the first decoder input.
            teacher_forcing_prob: probability, drawn independently at every step,
                of feeding the ground-truth token ``trg_batch[:, t]`` as the next
                input instead of the model's own argmax prediction.

        Returns:
            (T, B, vocab_size) tensor of decoder logits; index 0 is all zeros.
        """
        encoder_outputs, hidden = self.encoder(src_batch, src_batch_lens)

        batch_size = src_batch.shape[0]
        # Decode exactly as many steps as this target batch has, rather than
        # relying on a notebook-level maximum length.
        trg_len = trg_batch.shape[1]
        # Allocated on the encoder's device/dtype so the module also works off-CPU.
        outputs = encoder_outputs.new_zeros((trg_len, batch_size, vocab_size))

        input_ids = trg_batch[:, 0]
        for t in range(1, trg_len):
            decoder_outputs, hidden = self.decoder(input_ids, encoder_outputs, hidden)
            outputs[t] = decoder_outputs
            top_ids = decoder_outputs.argmax(dim=-1)

            # Feed the ground truth with probability teacher_forcing_prob.
            use_ground_truth = random.random() < teacher_forcing_prob
            input_ids = trg_batch[:, t] if use_ground_truth else top_ids

        return outputs


In [207]:
# Wire the encoder and attention decoder into the full model.
seq2seq = Seq2seq(encoder, decoder)

In [209]:
# One forward pass with the default teacher_forcing_prob; show the logits and their shape.
outputs = seq2seq(src_batch, src_batch_lens, trg_batch)

print(outputs, outputs.shape, sep="\n")

tensor([[[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],

        [[-0.0131,  0.0897,  0.0474,  ...,  0.0270,  0.1594, -0.0348],
         [-0.0163,  0.0730,  0.0749,  ...,  0.0136,  0.1549,  0.0019],
         [ 0.0076,  0.1346,  0.0858,  ...,  0.0441,  0.1306, -0.0108],
         ...,
         [-0.0280,  0.0961,  0.0580,  ...,  0.0255,  0.2035, -0.0104],
         [-0.0219,  0.1460,  0.0539,  ...,  0.0374,  0.1733, -0.0457],
         [-0.0228,  0.1014,  0.0735,  ...,  0.0275,  0.1726, -0.0381]],

        [[ 0.0907,  0.0176,  0.0850,  ...,  0.1600, -0.0279, -0.0637],
         [ 0.0764,  0.0901,  0.0688,  ...,  0