In [15]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence

import sys
sys.path.append('..')

from nmt.datasets import Vocab, batch_iter
from nmt.networks import CharEmbedding, Encoder

from typing import List, Tuple

In [17]:
# Toy parallel corpus to exercise the model: every target sentence is its
# source sentence wrapped in <s> ... </s> boundary markers.

sentences_words_src = [
    ['Human:', 'What', 'do', 'we', 'want?'],
    ['Computer:', 'Natural', 'language', 'processing!'],
    ['Human:', 'When', 'do', 'we', 'want', 'it?'],
    ['Computer:', 'When', 'do', 'we', 'want', 'what?'],
]

sentences_words_tgt = [['<s>', *words, '</s>'] for words in sentences_words_src]

In [18]:
# Build source and target vocabularies (token- and character-level, per the
# log below) from the toy corpus.
vocab = Vocab.build(sentences_words_src, sentences_words_tgt)

Initializing source vocab
Vocab Store: Tokens [size=17],                 Characters [size=97]
Initializing target vocab
Vocab Store: Tokens [size=17],                 Characters [size=97]


In [25]:
# Pair every source sentence with its target and draw a single batch of 4
# (the whole toy corpus).
# NOTE(review): shuffle=True and no seed is set in this notebook, so the
# order printed below may differ between runs.
data = [(src, tgt) for src, tgt in zip(sentences_words_src, sentences_words_tgt)]
data_generator = batch_iter(data=data, batch_size=4, shuffle=True)
batch_src, batch_tgt = next(data_generator)
print(batch_src)

[['Human:', 'When', 'do', 'we', 'want', 'it?'], ['Computer:', 'When', 'do', 'we', 'want', 'what?'], ['Human:', 'What', 'do', 'we', 'want?'], ['Computer:', 'Natural', 'language', 'processing!']]


In [26]:
# Word count of each source sentence in the batch; passed to the encoder
# alongside the padded embeddings.
source_length = list(map(len, batch_src))
print(source_length)

[6, 6, 5, 4]


In [46]:
# Character-index tensors for both sides of the batch (tokens=False).
# NOTE: despite its name, `token_tensors_tgt` also holds *character* indices.
# Per the recorded output the layout looks like (sent_len, batch, word_len) —
# assumption, confirm against Vocab.to_tensor.
char_tensors_src = vocab.src.to_tensor(batch_src, tokens=False)
token_tensors_tgt = vocab.tgt.to_tensor(batch_tgt, tokens=False)
# Fixed: previously printed the undefined name `char_tensors` (NameError on a
# fresh kernel).
print(f"src char tensor size = {char_tensors_src.size()}; tgt char tensor size = {token_tensors_tgt.size()}")

src char tensor size = torch.Size([6, 4, 21]); tgt char tensor size = torch.Size([8, 4, 21])


In [48]:
# input_size must equal the embedding_dim (300) of the CharEmbedding output
# fed to this encoder. The 2048-wide encoder outputs recorded below suggest
# the encoder is bidirectional (2 x hidden_size) — assumption, confirm in
# nmt.networks.Encoder.
encoder = Encoder(input_size=300, hidden_size=1024, num_layers=2)

In [49]:
# Separate character-level embedders for the source and target vocabularies.
# Each turns the char indices of a word into one 300-dim vector (see the
# recorded output shapes).
c_embedding = CharEmbedding(
    num_embeddings=vocab.src.length(tokens=False),
    char_embedding_dim=50,
    embedding_dim=300,
    char_padding_idx=vocab.src.pad_char_idx,
)
char_embedding_src = c_embedding(char_tensors_src)
t_embedding = CharEmbedding(
    num_embeddings=vocab.tgt.length(tokens=False),
    char_embedding_dim=50,
    embedding_dim=300,
    char_padding_idx=vocab.tgt.pad_char_idx,
)
target_embedding = t_embedding(token_tensors_tgt)
# Fixed: previously printed the undefined name `char_embedding` (NameError on
# a fresh kernel).
print(char_embedding_src.size(), target_embedding.size())

torch.Size([6, 4, 300]) torch.Size([8, 4, 300])


In [50]:
# Encode the embedded source batch. `source_length` carries each sentence's
# unpadded length. The encoder returns per-position states plus a
# (hidden, cell) pair.
char_enc_hidden, (char_hidden, char_cell) = encoder(
    char_embedding_src,
    source_length,
)
tuple(t.shape for t in (char_enc_hidden, char_hidden, char_cell))

(torch.Size([4, 6, 2048]), torch.Size([4, 1024]), torch.Size([4, 1024]))

In [51]:
class Attention(nn.Module):
    """Multiplicative attention over encoder states.

    Scores each source position by the dot product between its projected
    encoder state and the current decoder hidden state. It normalizes the
    scores with a softmax over source positions and returns the
    weight-averaged (unprojected) encoder states.
    """

    def forward(self, enc_hidden: torch.Tensor,
                enc_projection: torch.Tensor,
                dec_hidden_t: torch.Tensor,
                enc_masks: torch.Tensor = None) -> torch.Tensor:
        """Compute the attention context for a single decoder step.

        Args:
            enc_hidden: encoder states, (batch, src_len, enc_dim).
            enc_projection: encoder states projected to the decoder size,
                (batch, src_len, dec_dim).
            dec_hidden_t: decoder hidden state for this step, (batch, dec_dim).
            enc_masks: optional (batch, src_len) mask. Nonzero / True entries
                (e.g. padding) get zero attention weight. A row whose
                positions are all masked produces NaN weights.

        Returns:
            Context vector of shape (batch, enc_dim).
        """
        # (batch, src_len, dec_dim) @ (batch, dec_dim, 1) -> (batch, src_len)
        score_t = enc_projection.bmm(dec_hidden_t.unsqueeze(dim=2)).squeeze(dim=2)

        # `is not None`: truth-testing a multi-element tensor raises
        # RuntimeError, so the old `if enc_masks:` could never apply a mask.
        # Out-of-place masked_fill avoids mutating `.data` behind autograd.
        if enc_masks is not None:
            score_t = score_t.masked_fill(enc_masks.bool(), float('-inf'))

        # (batch, 1, src_len) attention weights
        alpha_t = F.softmax(score_t, dim=1).unsqueeze(dim=1)

        # (batch, 1, src_len) @ (batch, src_len, enc_dim) -> (batch, enc_dim)
        return alpha_t.bmm(enc_hidden).squeeze(dim=1)

In [52]:
# Standalone instance for the shape check in the following cells; Decoder
# creates its own Attention internally.
attention = Attention()

In [53]:
# Project the 2048-dim encoder states down to the 1024-dim decoder hidden
# size, so Attention can score every source position with a bmm against the
# decoder hidden state.
encoder_proj = nn.Linear(in_features=2048, out_features=1024, bias=False)
enc_projection = encoder_proj(char_enc_hidden)

In [54]:
# Shape check: one context vector per batch element, as wide as the encoder
# states (last dim of char_enc_hidden).
attention(char_enc_hidden, enc_projection, char_hidden).shape

torch.Size([4, 2048])

In [63]:
class Decoder(nn.Module):
    """LSTMCell decoder with attention and combined-output feedback.

    At each target step, the decoder concatenates the input embedding with
    the previous combined output ``o_prev`` and advances an ``nn.LSTMCell``.
    It then combines the new hidden state with an attention context over the
    encoder states: ``o_t = dropout(tanh(W [a_t; h_t]))``.
    """

    def __init__(self, input_size: int,
                 hidden_size: int,
                 bias: bool = True,
                 dropout_prob: float = 0.3):
        """
        Args:
            input_size: LSTMCell input size, i.e. target embedding size plus
                ``hidden_size`` (the fed-back ``o_prev``).
            hidden_size: LSTMCell hidden size and size of each combined output.
            bias: whether the LSTMCell uses bias terms.
            dropout_prob: dropout probability applied to each combined output.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.attention = Attention()
        self.decoder = nn.LSTMCell(
            input_size=input_size,
            hidden_size=hidden_size,
            bias=bias
        )
        # [a_t; h_t] has 3 * hidden_size features. This assumes the attention
        # context a_t is 2 * hidden_size wide (e.g. a bidirectional encoder).
        self.combined_projection = nn.Linear(
            in_features=hidden_size * 3,
            out_features=hidden_size,
            bias=False
        )
        self.dropout = nn.Dropout(p=dropout_prob)

    def forward(self, output: torch.Tensor,
                enc_hidden: torch.Tensor,
                enc_projection: torch.Tensor,
                dec_init_state: Tuple[torch.Tensor, torch.Tensor],
                enc_masks: torch.Tensor = None,
                device: torch.device = 'cpu') -> torch.Tensor:
        """Run the decoder over every target time step.

        Args:
            output: target-side input embeddings, (tgt_len, batch, emb_dim),
                consumed one step at a time along dim 0.
            enc_hidden: encoder states, (batch, src_len, 2 * hidden_size).
            enc_projection: ``enc_hidden`` projected to
                (batch, src_len, hidden_size).
            dec_init_state: initial ``(h_0, c_0)``, each (batch, hidden_size).
            enc_masks: optional (batch, src_len) mask forwarded to attention.
            device: device for the initial ``o_prev``. It must match the
                device of the other inputs, or ``torch.cat`` fails.

        Returns:
            Stacked combined outputs, (tgt_len, batch, hidden_size).
        """
        batch_size = enc_hidden.size(0)
        # o_0 is zeros; use enc_hidden's dtype so torch.cat / LSTMCell don't
        # see mixed precisions.
        o_prev = torch.zeros(batch_size, self.hidden_size,
                             device=device, dtype=enc_hidden.dtype)
        dec_state = dec_init_state
        combined_outputs = []

        for y_t in output.unbind(dim=0):
            ybar_t = torch.cat([y_t, o_prev], dim=1)

            dec_state = self.decoder(ybar_t, dec_state)
            dec_hidden = dec_state[0]

            a_t = self.attention(enc_hidden, enc_projection, dec_hidden, enc_masks)

            u_t = torch.cat([a_t, dec_hidden], dim=1)
            o_t = self.dropout(torch.tanh(self.combined_projection(u_t)))

            combined_outputs.append(o_t)
            o_prev = o_t

        if not combined_outputs:
            # Zero target steps: torch.stack would raise on an empty list.
            return torch.zeros(0, batch_size, self.hidden_size,
                               device=device, dtype=enc_hidden.dtype)
        return torch.stack(combined_outputs, dim=0)

In [64]:
# The LSTMCell input is [y_t; o_prev]. That is the 300-dim target embedding
# (embedding_dim of t_embedding) plus the 1024-dim previous combined output
# (= hidden_size).
decoder = Decoder(
    input_size=1024+300,
    hidden_size=1024
)


In [90]:
# Drop the last target time step before feeding the decoder (presumably the
# teacher-forcing input: the final step is only predicted, never fed).
# Guarded so re-running this cell doesn't trim again. It only slices while
# target_embedding still has one step per row of token_tensors_tgt.
if target_embedding.size(0) == token_tensors_tgt.size(0):
    target_embedding = target_embedding[:-1]

In [91]:
# Should be one step shorter along dim 0 than token_tensors_tgt.
target_embedding.shape

torch.Size([7, 4, 300])

In [92]:
# NOTE(review): no enc_masks is passed here, yet the source sentences have
# different lengths (source_length). Attention can therefore also put weight
# on padded encoder positions. TODO: build a (batch, src_len) padding mask
# from source_length and pass it as enc_masks, after checking that Attention
# accepts a multi-element tensor mask.
outputs = decoder(target_embedding, char_enc_hidden, enc_projection, (char_hidden, char_cell))

In [93]:
# One combined output per decoder step: (steps, batch, hidden_size).
outputs.shape

torch.Size([7, 4, 1024])

In [94]:
# Project each combined decoder output onto the target token vocabulary, then
# take log-probabilities over the last (vocabulary) dimension.
# in_features is read from the decoder instead of a hard-coded 1024, so the
# layer stays consistent if the decoder's hidden size changes.
target_layer = nn.Linear(
    in_features=decoder.hidden_size,
    out_features=len(vocab.tgt),
    bias=False,
)
P = F.log_softmax(target_layer(outputs), dim=-1)

In [95]:
# Log-probabilities per step, batch element and target token:
# (steps, batch, len(vocab.tgt)).
P.shape

torch.Size([7, 4, 17])