In [2]:
import random
import sys
from typing import List, Optional, Tuple

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence

sys.path.append('..')  # make the local `nmt` package importable

from nmt.datasets import Vocab, batch_iter
from nmt.networks import CharEmbedding, Encoder
from nmt.layers import Attention

In [3]:
# Toy parallel corpus to exercise the model with.
# Every target sentence is its source sentence wrapped in <s> ... </s> markers.
sentences_words_src = [
    'Human: What do we want?'.split(),
    'Computer: Natural language processing!'.split(),
    'Human: When do we want it?'.split(),
    'Computer: When do we want what?'.split(),
]

sentences_words_tgt = [['<s>', *words, '</s>'] for words in sentences_words_src]

In [4]:
# Build the source and target vocabularies (word tokens and characters) from the toy corpus.
vocab = Vocab.build(sentences_words_src, sentences_words_tgt)

Initializing source vocab
Vocab Store: Tokens [size=17],                 Characters [size=97]
Initializing target vocab
Vocab Store: Tokens [size=17],                 Characters [size=97]


In [5]:
# Fix every RNG before anything stochastic runs: the shuffled batch below and the
# random weight initialisation of the Encoder / Attention / Decoder / Linear layers
# in later cells. Without this, Restart & Run All produces a different batch order
# and different numbers on every run.
# NOTE(review): which RNG batch_iter uses to shuffle isn't visible here, so Python's
# `random`, NumPy and torch are all seeded.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

data = list(zip(sentences_words_src, sentences_words_tgt))
data_generator = batch_iter(
    data=data,
    batch_size=4,
    shuffle=True
)
batch_src, batch_tgt = next(data_generator)
print(batch_src)

[['Computer:', 'When', 'do', 'we', 'want', 'what?'], ['Human:', 'When', 'do', 'we', 'want', 'it?'], ['Human:', 'What', 'do', 'we', 'want?'], ['Computer:', 'Natural', 'language', 'processing!']]


In [6]:
# Number of word tokens in each source sentence of the batch.
source_length = list(map(len, batch_src))
print(source_length)

[6, 6, 5, 4]


In [7]:
# Character-index tensors (tokens=False) for the source and target side of the batch.
char_tensors_src, char_tensors_tgt = (
    side_vocab.to_tensor(sents, tokens=False)
    for side_vocab, sents in ((vocab.src, batch_src), (vocab.tgt, batch_tgt))
)
print(f"src char tensor size = {char_tensors_src.size()}; tgt char tensor size = {char_tensors_tgt.size()}")

src char tensor size = torch.Size([6, 4, 21]); tgt char tensor size = torch.Size([8, 4, 21])


In [8]:
# Encoder over the source side's character indices (num_embeddings is the
# character vocabulary size, char_padding_idx its padding index).
# With hidden_size=1024 the per-position outputs have 2048 features (see the
# shapes printed by the next cell) -- presumably a bidirectional LSTM; confirm in nmt.networks.
encoder = Encoder(
    num_embeddings=vocab.src.length(tokens=False),
    embedding_dim=300,
    char_padding_idx=vocab.src.pad_char_idx,
    hidden_size=1024
)

In [9]:
# Encode the source batch; show the per-position outputs and the final (h, c) shapes.
char_enc_hidden, (char_hidden, char_cell) = encoder(
    char_tensors_src, source_length
)
tuple(t.shape for t in (char_enc_hidden, char_hidden, char_cell))

(torch.Size([4, 6, 2048]), torch.Size([4, 1024]), torch.Size([4, 1024]))

In [11]:
# Attention from the encoder's final hidden state (query) over all encoder positions.
# Derive the sizes from the tensors instead of hard-coding 2048 / 1024, so this cell
# stays consistent if the encoder's hidden_size changes (values are unchanged here).
attention = Attention(
    in_features=char_enc_hidden.size(-1),
    out_features=char_hidden.size(-1),
)

In [12]:
# One attention step: presumably alpha_t are the weights over source positions and
# a_t the resulting context vector (see printed shapes) -- confirm in nmt.layers.
alpha_t, a_t = attention(char_enc_hidden, char_hidden)
print(*(t.shape for t in (alpha_t, a_t)))

torch.Size([4, 1, 6]) torch.Size([4, 2048])


In [25]:
class Decoder(nn.Module):
    """Single-step attentional LSTM decoder over character-level target embeddings.

    Each ``forward`` call advances decoding by exactly one target time step:

    1. Embed the current target step with ``CharEmbedding``.
    2. Concatenate the embedding with the previous combined output ``o_prev``
       (input feeding) and step an ``nn.LSTMCell``.
    3. Attend over the encoder states with the new hidden state.
    4. Project ``[context; hidden]`` back to ``hidden_size`` through tanh and dropout.

    Encoder states are expected to carry ``2 * hidden_size`` features. That is the
    attention layer's ``in_features``, and with it ``combined_projection`` takes
    ``3 * hidden_size`` inputs.
    """

    def __init__(self, num_embeddings: int,
                 embedding_dim: int, hidden_size: int,
                 char_padding_idx: int, char_embedding_dim: int = 50,
                 bias: bool = True, dropout_prob: float = 0.3):
        """
        Args:
            num_embeddings: character vocabulary size passed to ``CharEmbedding``.
            embedding_dim: size of each embedded target step produced by ``CharEmbedding``.
            hidden_size: LSTM hidden size; encoder states must have ``2 * hidden_size`` features.
            char_padding_idx: index of the padding character.
            char_embedding_dim: per-character embedding size used inside ``CharEmbedding``.
            bias: whether the ``LSTMCell`` uses bias terms.
            dropout_prob: dropout probability applied to the combined output.
        """
        super().__init__()
        self.embedding = CharEmbedding(
            num_embeddings=num_embeddings,
            char_embedding_dim=char_embedding_dim,
            embedding_dim=embedding_dim,
            char_padding_idx=char_padding_idx
        )
        self.attention = Attention(
            in_features=hidden_size * 2,
            out_features=hidden_size
        )
        # Input feeding: the cell sees [y_t; o_prev], hence embedding_dim + hidden_size.
        self.decoder = nn.LSTMCell(
            input_size=embedding_dim + hidden_size,
            hidden_size=hidden_size,
            bias=bias
        )
        # [context (2 * hidden); dec_hidden (hidden)] -> hidden.
        self.combined_projection = nn.Linear(
            in_features=hidden_size * 3,
            out_features=hidden_size,
            bias=False
        )
        self.dropout = nn.Dropout(p=dropout_prob)

    def forward(self, x: torch.Tensor,
                enc_hidden: torch.Tensor,
                dec_init_state: Tuple[torch.Tensor, torch.Tensor],
                o_prev: torch.Tensor,
                enc_masks: Optional[torch.Tensor] = None
                ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
        """Run one decoding step.

        Args:
            x: character tensor for a single target time step with a leading
                length-1 time dimension (e.g. one slice of
                ``torch.split(char_tensors, 1, dim=0)``). That dimension is squeezed
                away after embedding.
            enc_hidden: encoder states with ``2 * hidden_size`` features in the last
                dimension; ``(batch, src_len, 2 * hidden_size)`` for the Encoder used here.
            dec_init_state: ``(h, c)`` LSTM state to start this step from. Pass the
                state returned by the previous call to continue a sequence.
            o_prev: combined output of the previous step, ``(batch, hidden_size)``.
            enc_masks: optional mask forwarded unchanged to the attention layer.

        Returns:
            ``(o_t, (h_t, c_t), attention_scores)``:

            - ``o_t`` is the new combined output, ``(batch, hidden_size)``, with
              values in [-1, 1] before dropout.
            - ``(h_t, c_t)`` is the updated LSTM state to feed into the next call.
            - ``attention_scores`` is the attention layer's first output.
        """
        # Embed the current step and drop its length-1 time dimension.
        y_t = self.embedding(x).squeeze(dim=0)
        ybar_t = torch.cat([y_t, o_prev], dim=1)

        dec_hidden, dec_cell = self.decoder(ybar_t, dec_init_state)

        attention_scores, context_vector = self.attention(
            enc_hidden, dec_hidden, enc_masks)

        u_t = torch.cat([context_vector, dec_hidden], dim=1)
        v_t = self.combined_projection(u_t)
        o_t = self.dropout(torch.tanh(v_t))

        return o_t, (dec_hidden, dec_cell), attention_scores

In [26]:
# hidden_size must equal the encoder's hidden_size (1024). The decoder's attention
# expects encoder states with 2 * hidden_size features, and the zero `o_prev`
# created further down is sized 1024 as well.
decoder = Decoder(
    num_embeddings=vocab.tgt.length(tokens=False),
    embedding_dim=300,
    char_padding_idx=vocab.tgt.pad_char_idx,
    hidden_size=1024
)


In [24]:
# Teacher forcing: the decoder's inputs are every target time step except the last.
# Rebuild from `batch_tgt` rather than slicing `char_tensors_tgt` in place, so
# re-running this cell doesn't drop one more time step each time.
char_tensors_tgt = vocab.tgt.to_tensor(batch_tgt, tokens=False)[:-1]

In [27]:
# Input feeding starts from an all-zero "previous combined output" per batch element.
batch_size, sent_length, _ = char_enc_hidden.shape
o_prev = torch.zeros((batch_size, 1024), device=torch.device("cpu"))

In [28]:
# Teacher-forced decoding loop: feed one target time step at a time and thread BOTH
# the LSTM state and the combined output from each step into the next one.
# The loop variables start from the initial values, so re-running this cell gives
# the same result.
# NOTE(review): no enc_masks are passed, so padded source positions (see
# `source_length`) can still receive attention weight. Build a padding mask once
# the Attention layer's mask convention is confirmed.
dec_state = (char_hidden, char_cell)
o_t = o_prev
combined_outputs = []
for y_t in torch.split(char_tensors_tgt, 1, dim=0):
    o_t, dec_state, _ = decoder(y_t, char_enc_hidden, dec_state, o_t)
    combined_outputs.append(o_t)

In [29]:
# Stack the per-step outputs into a single (num_steps, batch, hidden_size) tensor.
# `list(...)` makes the cell safe to re-run. After the first run `combined_outputs`
# is already a tensor; unbinding it along dim 0 and re-stacking reproduces the same
# tensor instead of raising a TypeError.
combined_outputs = torch.stack(list(combined_outputs), dim=0)

In [30]:
# Project each combined output onto the target vocabulary and normalise to
# log-probabilities over the last (vocabulary) dimension.
target_layer = nn.Linear(1024, len(vocab.tgt), bias=False)
P = target_layer(combined_outputs).log_softmax(dim=-1)

In [31]:
# (num_steps, batch, target vocab size)
P.size()

torch.Size([7, 4, 17])