In [55]:
import numpy as np

import torch
import torch.nn as nn

In [56]:
with open('onegin.txt', 'r') as iofile:
    text = iofile.readlines()
    
text = "".join([x.replace('\t\t', '').lower() for x in text])

In [57]:
SOS_TOKEN = '<sos>'
tokens = sorted(set(text.lower())) + [SOS_TOKEN]

num_tokens = len(tokens)

print(f"num_tokens: {num_tokens}")

token_to_index = {token : index for index, token in enumerate(tokens)}
index_to_token = {index : token for index, token in enumerate(tokens)}

num_tokens: 84


In [58]:
text_indices = [token_to_index[token] for token in text]

In [59]:
def get_random_chunk(batch_size, seq_length):
    text_len = len(text)
    # start_index in [0, text_len - seq_length - 1]
    start_indices = np.random.randint(low=0, high=text_len - seq_length, size=batch_size)
    random_chunks = [[token_to_index[SOS_TOKEN]] + text_indices[start_index:start_index + seq_length] for start_index in start_indices]
    random_chunks = np.array(random_chunks)
    assert random_chunks.shape == (batch_size, seq_length + 1)
    return random_chunks

In [60]:
def assert_check_shapes(lhs_shape, rhs_shape):
    assertion_message = f"Not equal shapes: {lhs_shape} instead of {rhs_shape}"
    assert lhs_shape == rhs_shape, assertion_message

In [61]:
class RNNCell(nn.Module):
    def __init__(self, num_tokens, embedding_size, hidden_embedding_size):
        super(self.__class__, self).__init__()
        self.num_tokens = num_tokens
        self.embedding_size = embedding_size
        self.hidden_embedding_size = hidden_embedding_size

        self.embedding_layer = nn.Embedding(num_embeddings=num_tokens, embedding_dim=embedding_size)

        # input shape:              (batch_size, 1)
        # with embeds shape:        (batch_size, embedding_size)
        # last hidden state shape:  (batch_size, hidden_embedding_size)
        self.W_to_new_hidden = nn.Linear(embedding_size + hidden_embedding_size, hidden_embedding_size)
        self.W_to_raw_logits = nn.Linear(hidden_embedding_size, num_tokens)

    def forward(self, input, last_hidden_state):
        batch_size = input.shape[0]

        # (batch, embedding_size)
        input_embeddings = self.embedding_layer(input).squeeze(dim=1)
        assert_check_shapes(
            input_embeddings.shape, 
            (batch_size, self.embedding_size)
        )

        # (batch, embedding_size + hidden_embedding_size)
        concat_input_and_last_hidden = torch.cat([input_embeddings, last_hidden_state], dim=-1)
        assert_check_shapes(
            concat_input_and_last_hidden.shape, 
            (batch_size, self.embedding_size + self.hidden_embedding_size)
        )

        # (batch, hidden_embedding_size)
        new_hidden_state = self.W_to_new_hidden(concat_input_and_last_hidden)
        assert_check_shapes(
            new_hidden_state.shape, (batch_size, self.hidden_embedding_size)
        )

        # (batch, num_tokens)
        raw_logits = self.W_to_raw_logits(new_hidden_state)
        assert_check_shapes(
            raw_logits.shape, 
            (batch_size, self.num_tokens)
        )

        return {
            'raw_logits' : raw_logits,
            'hidden_state' : new_hidden_state
        }
    
    def get_start_state(self, batch_size):
        return torch.zeros(batch_size, self.hidden_embedding_size, requires_grad=True)

In [62]:
embedding_size = 32
hidden_embedding_size = 48
batch_size = 8
seq_length = 50
assert num_tokens == len(tokens)

rnn_cell = RNNCell(
    num_tokens=num_tokens,
    embedding_size=embedding_size,
    hidden_embedding_size=hidden_embedding_size
)

In [63]:
start_state = rnn_cell.get_start_state(batch_size)

In [64]:
chunk = get_random_chunk(batch_size=batch_size, seq_length=seq_length)

In [65]:
chunk

array([[83, 59, 57,  1, 45, 47, 63, 59, 61,  1, 52, 58, 45, 50, 63,  1,
        46, 59, 56, 50,  0, 60, 61, 53, 61, 59, 49, 64,  5,  1, 68, 50,
        57,  1, 69, 45, 63, 59, 46, 61, 53, 45, 58,  5,  0, 45,  1, 57,
        50, 51, 49],
       [83,  1, 46, 50, 52,  1, 67, 50, 56, 53,  5,  0, 49, 59, 62, 63,
        64, 60, 58, 72, 54,  1, 68, 64, 47, 62, 63, 47, 64,  1, 59, 49,
        58, 59, 57, 64, 13,  0, 53,  1, 60, 64, 63, 50, 69, 50, 62, 63,
        47, 53, 76],
       [83,  0, 55, 64, 49, 45,  1, 60, 59,  1, 58, 50, 57,  1, 62, 47,
        59, 54,  1, 46, 72, 62, 63, 61, 72, 54,  1, 46, 50, 48,  0,  0,
         0,  0, 40, 28,  0,  0, 62, 63, 61, 50, 57, 53, 63,  1, 59, 58,
        50, 48, 53],
       [83, 46, 50, 62, 60, 50, 68, 58, 59, 54,  1, 60, 61, 50, 56, 50,
        62, 63, 73, 75,  1, 57, 53, 56, 45,  5,  0, 59, 58, 45,  1, 62,
        53, 49, 50, 56, 45,  1, 64,  1, 62, 63, 59, 56, 45,  0, 62,  1,
        46, 56, 50],
       [83, 72, 66,  1, 49, 64, 66, 59, 47,  1, 53, 

In [66]:
def rnn_loop(rnn_cell, batch_indices):
    batch_size = batch_indices.shape[0]
    seq_length = batch_indices.shape[1]
    
    last_hidden_state = rnn_cell.get_start_state(batch_size)
    
    all_logits = []
    true_tokens = []

    for i in range(0, seq_length):
        input = torch.LongTensor(batch_indices[:, i]).unsqueeze(dim=-1)
        assert input.shape == (batch_size, 1)
        output = rnn_cell(input, last_hidden_state)
        
        all_logits.append(output['raw_logits'])
        true_tokens.append(input)

        last_hidden_state = output['hidden_state']

    stacked_logits = torch.stack(all_logits, dim=1)
    assert_check_shapes(
        stacked_logits.shape, 
        (batch_size, seq_length, num_tokens)
    )

    stacked_true_tokens = torch.stack(true_tokens, dim=-1).squeeze(dim=1)
    assert_check_shapes(
        stacked_true_tokens.shape,
        (batch_size, seq_length - 1)
    )
    
    return {
        'stacked_logits' : stacked_logits,
        'true_tokens' : stacked_true_tokens
    }

In [67]:
res = rnn_loop(rnn_cell, chunk)

AssertionError: Not equal shapes: torch.Size([8, 51]) instead of (8, 50)

In [14]:
res['stacked_logits'][:, :-1, :].shape

torch.Size([8, 49, 84])

In [15]:
res['true_tokens'].shape

torch.Size([8, 50])