In [1]:
import numpy as np

import torch
import torch.nn as nn

In [2]:
with open('onegin.txt', 'r') as iofile:
    text = iofile.readlines()
    
text = "".join([x.replace('\t\t', '').lower() for x in text])

In [3]:
SOS_TOKEN = '<sos>'
tokens = sorted(set(text.lower())) + [SOS_TOKEN]

num_tokens = len(tokens)

print(f"num_tokens: {num_tokens}")

token_to_index = {token : index for index, token in enumerate(tokens)}
index_to_token = {index : token for index, token in enumerate(tokens)}

num_tokens: 84


In [4]:
text_indices = [token_to_index[token] for token in text]

In [5]:
def get_random_chunk(batch_size, seq_length):
    text_len = len(text)
    # start_index in [0, text_len - seq_length - 1]
    start_indices = np.random.randint(low=0, high=text_len - seq_length, size=batch_size)
    random_chunks = [[token_to_index[SOS_TOKEN]] + text_indices[start_index:start_index + seq_length] for start_index in start_indices]
    random_chunks = np.array(random_chunks)
    assert random_chunks.shape == (batch_size, seq_length + 1)
    return random_chunks

In [6]:
def assert_check_shapes(lhs_shape, rhs_shape):
    assertion_message = f"Not equal shapes: {lhs_shape} instead of {rhs_shape}"
    assert lhs_shape == rhs_shape, assertion_message

In [7]:
class RNNCell(nn.Module):
    def __init__(self, num_tokens, embedding_size, hidden_embedding_size):
        super(self.__class__, self).__init__()
        self.num_tokens = num_tokens
        self.embedding_size = embedding_size
        self.hidden_embedding_size = hidden_embedding_size

        self.embedding_layer = nn.Embedding(num_embeddings=num_tokens, embedding_dim=embedding_size)

        # input shape:              (batch_size, 1)
        # with embeds shape:        (batch_size, embedding_size)
        # last hidden state shape:  (batch_size, hidden_embedding_size)
        self.W_to_new_hidden = nn.Linear(embedding_size + hidden_embedding_size, hidden_embedding_size)
        self.W_to_raw_logits = nn.Linear(hidden_embedding_size, num_tokens)

    def forward(self, input, last_hidden_state):
        batch_size = input.shape[0]

        # (batch, embedding_size)
        input_embeddings = self.embedding_layer(input).squeeze(dim=1)
        assert_check_shapes(
            input_embeddings.shape, 
            (batch_size, self.embedding_size)
        )

        # (batch, embedding_size + hidden_embedding_size)
        concat_input_and_last_hidden = torch.cat([input, last_hidden_state], dim=-1)
        assert_check_shapes(
            concat_input_and_last_hidden.shape, 
            (batch_size, self.embedding_size + self.hidden_embedding_size)
        )

        # (batch, hidden_embedding_size)
        new_hidden_state = self.W_to_new_hidden(concat_input_and_last_hidden)
        assert_check_shapes(
            new_hidden_state.shape, (batch_size, self.hidden_embedding_size)
        )

        # (batch, num_tokens)
        raw_logits = self.W_to_raw_logits(new_hidden_state)
        assert_check_shapes(
            raw_logits.shape, 
            (batch_size, self.num_tokens)
        )

        return {
            'raw_logits' : raw_logits,
            'hidden_state' : new_hidden_state
        }
    
    def get_start_state(self, batch_size):
        return torch.zeros(batch_size, self.hidden_embedding_size, requires_grad=True)

In [8]:
embedding_size = 32
hidden_embedding_size = 48
batch_size = 8
seq_length = 16

rnn_cell = RNNCell(
    num_tokens=num_tokens,
    embedding_size=embedding_size,
    hidden_embedding_size=hidden_embedding_size
)

In [9]:
start_state = rnn_cell.get_start_state(batch_size)

In [10]:
chunk = get_random_chunk(batch_size=batch_size, seq_length=seq_length)

In [11]:
chunk

array([[83, 50, 63,  1, 55, 50, 57,  1, 59, 55, 61, 64, 51, 50, 58, 45,
         5],
       [83, 61, 50, 54, 13,  0, 53,  1, 48, 59, 61, 59, 51, 45, 58, 55,
        45],
       [83, 62, 55, 56, 53, 67, 45, 58, 73, 76,  5,  1, 53,  1, 66, 56,
        50],
       [83, 53,  5,  0, 47, 58, 53, 57, 45, 63, 73,  1, 47, 45, 57,  1,
        49],
       [83, 67,  7,  0,  0,  0,  0, 40, 40, 40, 38, 26, 26, 26,  0,  0,
        58],
       [83,  1, 47, 50, 61, 53, 63, 73,  5,  0, 55, 45, 52, 45, 63, 73,
        62],
       [83, 45, 47, 53, 56,  1, 55, 58, 53, 48, 53,  5,  0, 53,  1, 60,
        59],
       [83, 63, 59,  1, 47,  1, 57, 53, 61, 50,  1, 47, 72, 69, 50,  1,
        58]])

In [12]:
def rnn_loop(rnn_cell, batch_indices):
    batch_size = batch_indices.shape[0]
    seq_length = batch_indices.shape[1]

    last_hidden_state = rnn_cell.get_start_state(batch_size)
    
    all_logits = []

    for i in range(1, seq_length):
        input = torch.LongTensor(batch_indices[:, i]).unsqueeze(dim=-1)
        assert input.shape == (batch_size, 1)
        output = rnn_cell(input, last_hidden_state)

In [13]:
rnn_loop(rnn_cell, chunk)

AssertionError: Not equal shapes: torch.Size([8, 49]) instead of (8, 80)