In [1]:
import random
import sys
from typing import List

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence

# Make the project's `nmt` package (presumably in the parent directory) importable.
sys.path.append('..')

from nmt.datasets import Vocab, VocabStore
from nmt.datasets import batch_iter
from nmt.networks import CharEmbedding

In [2]:
## Toy corpus: four short "sentences" plus their word-level split.

sentences = [
    "Human: What do we want?",
    "Computer: Natural language processing!",
    "Human: When do we want it?",
    "Computer: When do we want what?"
]

# Whitespace splitting yields exactly the word lists we want: punctuation
# stays attached to the word before it (e.g. 'Human:', 'want?').
sentences_words = [sentence.split() for sentence in sentences]

In [3]:
# Build the source/target vocabularies from the raw sentences and their word
# splits (the log printed below reports one store per side).
vocab = Vocab.build(sentences, sentences_words)

Initializing source vocab
Vocab Store: Tokens [size=17],                 Characters [size=97]
Initializing target vocab
Vocab Store: Tokens [size=17],                 Characters [size=97]


In [4]:
# Identity "translation" pairs: each source sentence is also its own target.
data = [(words, words) for words in sentences_words]

In [5]:
# batch_iter is asked to shuffle, and the embedding / encoder cells below
# initialise their weights randomly, so seed the RNGs first to make
# Restart-&-Run-All reproduce the same batch order and values.
# NOTE(review): which RNG batch_iter draws from is not visible here, so both
# Python's `random` and numpy's global RNG are seeded.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

data_generator = batch_iter(
    data=data,
    batch_size=4,
    shuffle=True
)

In [6]:
# Draw one batch. With 4 examples and batch_size=4 this is presumably the
# whole toy corpus; batch_tgt is not used by any of the cells shown here.
batch_src, batch_tgt = next(data_generator)

In [7]:
# Inspect the source side of the batch. NOTE(review): the printed order looks
# sorted by descending length — presumably batch_iter sorts each batch after
# shuffling; confirm.
print(batch_src)

[['Human:', 'When', 'do', 'we', 'want', 'it?'], ['Computer:', 'When', 'do', 'we', 'want', 'what?'], ['Human:', 'What', 'do', 'we', 'want?'], ['Computer:', 'Natural', 'language', 'processing!']]


In [8]:
# Number of words in each source sentence; the encoder uses these to pack the batch.
source_length = list(map(len, batch_src))
print(source_length)

[6, 6, 5, 4]


In [9]:
# The same batch, indexed two ways: tokens=False produces a 3-D character-level
# tensor, tokens=True a 2-D word-id tensor (shapes printed below).
char_tensors = vocab.src.to_tensor(batch_src, tokens=False)
token_tensors = vocab.src.to_tensor(batch_src, tokens=True)
print(f"Char Tensor size = {char_tensors.shape}")
print(f"Token Tensor size = {token_tensors.shape}")

Char Tensor size = torch.Size([6, 4, 21])
Token Tensor size = torch.Size([6, 4])


In [10]:
class Encoder(nn.Module):
    """Bidirectional LSTM encoder that also builds the decoder's initial state.

    The final forward and backward states of the *top* LSTM layer are
    concatenated and linearly projected back to ``hidden_size``. The results
    are the decoder's initial hidden and cell states.

    Args:
        input_size: Width of each input embedding vector.
        hidden_size: LSTM hidden size (per direction) and the size of the
            projected decoder state.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, input_size: int,
                 hidden_size: int, num_layers: int) -> None:
        super(Encoder, self).__init__()
        self.encoder = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            bias=True,
            bidirectional=True
        )
        # Project the concatenated [forward; backward] state (2 * hidden_size)
        # down to the decoder's hidden_size.
        self.hidden_projection = nn.Linear(
            in_features=hidden_size * 2,
            out_features=hidden_size,
            bias=False
        )
        self.cell_projection = nn.Linear(
            in_features=hidden_size * 2,
            out_features=hidden_size,
            bias=False
        )

    def forward(self, x: torch.Tensor, source_lengths: List[int]):
        """Encode a padded batch of embeddings.

        Args:
            x: Padded embeddings, sequence-first: (seq_len, batch, input_size).
            source_lengths: True length of each sequence in the batch, in the
                same order as the batch dimension of ``x``. Need not be sorted.

        Returns:
            enc_output: (batch, max(source_lengths), 2 * hidden_size). Padded
                positions are zero.
            (init_decoder_hidden, init_decoder_cell): each (batch, hidden_size).
        """
        # With enforce_sorted=False, callers may pass batches in any order.
        # PyTorch sorts internally and restores the caller's order in both the
        # unpacked output and h_n / c_n.
        packed = pack_padded_sequence(x, lengths=source_lengths,
                                      enforce_sorted=False)
        enc_output, (last_hidden, last_cell) = self.encoder(packed)

        enc_output, _ = pad_packed_sequence(enc_output)
        enc_output = enc_output.permute([1, 0, 2])  # -> batch-first

        # h_n / c_n are (num_layers * 2, batch, hidden_size), stored layer by
        # layer with [forward, backward] within each layer. The top layer's
        # final states are therefore the last two entries; indices 0/1 would
        # select the *first* layer whenever num_layers > 1.
        last_hidden = torch.cat((last_hidden[-2], last_hidden[-1]), dim=1)
        init_decoder_hidden = self.hidden_projection(last_hidden)

        last_cell = torch.cat((last_cell[-2], last_cell[-1]), dim=1)
        init_decoder_cell = self.cell_projection(last_cell)

        return enc_output, (init_decoder_hidden, init_decoder_cell)


In [11]:
# Named hyper-parameters instead of bare literals. EMBED_DIM must equal the
# embedding_dim (300) used by the nn.Embedding and CharEmbedding cells below,
# because both feed this single encoder instance.
EMBED_DIM = 300
HIDDEN_SIZE = 1024
NUM_LAYERS = 2

encoder = Encoder(input_size=EMBED_DIM, hidden_size=HIDDEN_SIZE, num_layers=NUM_LAYERS)

In [12]:
# Word-level embedding table sized to the source token vocabulary.
# NOTE(review): padding_idx comes from the *character* vocabulary
# (pad_char_idx) although this table is indexed by token ids. If the token pad
# index differs, whichever token id equals pad_char_idx is frozen at a zero
# vector, and the real pad token gets a trainable one. Confirm the two indices agree.
embedding = nn.Embedding(
    num_embeddings=vocab.src.length(tokens=True),
    embedding_dim=300,
    padding_idx=vocab.src.pad_char_idx,
)
token_embedding = embedding(token_tensors)
print(token_embedding.shape)

torch.Size([6, 4, 300])


In [13]:
# Run the word-level embeddings through the encoder. Despite its name,
# token_enc_hidden is the per-step encoder *output* (batch-first, per the
# shapes displayed in the next cell), not a hidden state.
token_enc_hidden, (token_hidden, token_cell) = encoder(token_embedding, source_length)

In [14]:
# Expect (batch, seq, 2 * hidden), (batch, hidden), (batch, hidden).
tuple(t.shape for t in (token_enc_hidden, token_hidden, token_cell))

(torch.Size([4, 6, 2048]), torch.Size([4, 1024]), torch.Size([4, 1024]))

In [15]:
# Character-level alternative: CharEmbedding turns the 3-D char-id tensor into
# 300-wide word vectors (shape printed below), the same width as the token
# embeddings, so the shared encoder can consume either.
c_embedding = CharEmbedding(
    num_embeddings=vocab.src.length(tokens=False),
    char_embedding_dim=50,
    embedding_dim=300,
    char_padding_idx=vocab.src.pad_char_idx,
)
char_embedding = c_embedding(char_tensors)
print(char_embedding.shape)

torch.Size([6, 4, 300])


In [16]:
# Same encoder instance (and therefore the same weights) as the token run,
# now fed the character-derived word embeddings. char_enc_hidden is the
# per-step encoder output, not a hidden state.
char_enc_hidden, (char_hidden, char_cell) = encoder(char_embedding, source_length)

In [17]:
# Should match the token-level shapes: the encoder is agnostic to how the word vectors were built.
tuple(t.shape for t in (char_enc_hidden, char_hidden, char_cell))

(torch.Size([4, 6, 2048]), torch.Size([4, 1024]), torch.Size([4, 1024]))