In [1]:
# imports
import torch
from torch.utils.data import DataLoader
import torch.nn as nn

from sklearn.model_selection import train_test_split

from transformers import AutoTokenizer

from tqdm.auto import tqdm

from torch.nn.utils.rnn import pad_sequence

import os
os.environ['http_proxy'] = "http://proxy-ws.cbank.kz:8080"
os.environ['https_proxy'] = "http://proxy-ws.cbank.kz:8080"

In [2]:
with open('data/anek.txt', 'r') as file:
    aneki = file.read().strip().replace('<|startoftext|>', '').split('\n\n')

In [3]:
aneki[555]

'Невероятно, но предложение "Не верьте всему, что находите в Интернете" читается в обе стороны одинаково!'

In [4]:
tokenizer = AutoTokenizer.from_pretrained("cointegrated/rubert-tiny")

print(len(tokenizer))

29564




In [5]:
train_texts, test_texts = train_test_split(
    aneki,
    test_size=0.1,
    shuffle=True
)

In [6]:
def collate_fn(batch):
    batch_encoded = [torch.LongTensor(tokenizer.encode(text)) for text in batch]
    batch_padded = pad_sequence(batch_encoded, batch_first=True, padding_value=tokenizer.pad_token_id)
    return batch_padded

In [7]:
train_loader = DataLoader(
    train_texts, 
    batch_size=128, 
    shuffle=True,
    pin_memory=True,
    collate_fn=collate_fn
)

test_loader = DataLoader(
    test_texts,
    batch_size=128,
    shuffle=True,
    pin_memory=True
)

In [8]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Формула LSTM, йуху

$$
\begin{gathered}
f_t=\sigma\left(W_f \cdot\left[h_{t-1}, x_t\right]+b_f\right) \\
i_t=\sigma\left(W_i \cdot\left[h_{t-1}, x_t\right]+b_i\right) \\
o_t=\sigma\left(W_o \cdot\left[h_{t-1}, x_t\right]+b_o\right) \\
\tilde{C}_t=\tanh \left(W_c \cdot\left[h_{t-1}, x_t\right]+b_c\right) \\
C_t=f_t \odot C_{t-1}+i_t \odot \tilde{C}_t \\
h_t=o_t \odot \tanh \left(C_t\right)
\end{gathered}
$$

In [9]:
def assert_check_shapes(lhs_shape, rhs_shape):
    assertion_message = f"Not equal shapes: {lhs_shape} instead of {rhs_shape}"
    assert lhs_shape == rhs_shape, assertion_message

In [10]:
class LSTMCell(nn.Module):
    def __init__(self, num_tokens, embedding_size, hidden_embedding_size):
        super(self.__class__, self).__init__()
        self.num_tokens = num_tokens
        self.embedding_size = embedding_size
        self.hidden_embedding_size = hidden_embedding_size

        self.embedding_layer = nn.Embedding(num_embeddings=num_tokens, embedding_dim=embedding_size)
        
        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()
        
        self.forget_linear = nn.Linear(embedding_size + hidden_embedding_size, hidden_embedding_size)
        self.input_linear = nn.Linear(embedding_size + hidden_embedding_size, hidden_embedding_size)
        self.output_linear = nn.Linear(embedding_size + hidden_embedding_size, hidden_embedding_size)
        self.short_memory_linear = nn.Linear(embedding_size + hidden_embedding_size, hidden_embedding_size)

        self.classifer_to_raw_logits = nn.Linear(hidden_embedding_size, num_tokens)
    
    def forward(self, x_input, last_hidden_state, last_short_memory):
        x_embed = self.embedding_layer(x_input)
        print("x_embed shape: ", x_embed.shape)
        
        concat_state = torch.cat([x_embed, last_hidden_state], dim=-1)
        print("concat state shape: ", concat_state.shape)

        forget_mask = self.sigmoid(self.forget_linear(concat_state))
        input_mask = self.sigmoid(self.input_linear(concat_state))
        output_mask = self.sigmoid(self.output_linear(concat_state))
        short_memory_mask = self.tanh(self.short_memory_linear(concat_state))
        print("forget_mask shape: ", forget_mask.shape)
        print("short memory mask: ", short_memory_mask.shape)

        new_short_memory = torch.mul(forget_mask, last_short_memory) + torch.mul(input_mask, short_memory_mask)
        print("new_short_memory shape: ", new_short_memory.shape)

        new_hidden_state = torch.mul(output_mask, self.tanh(new_short_memory))
        print("new_hidden_state shape: ", new_hidden_state.shape)
    
        raw_logits = self.classifer_to_raw_logits(new_hidden_state)
        
        return {
            'short_memory' : new_short_memory,
            'hidden_state' : new_hidden_state,
            'raw_logits' : raw_logits
        }

    def get_empty_start_state(self, batch_size):
        return torch.zeros(batch_size, self.hidden_embedding_size, requires_grad=True)

In [11]:
def lstm_loop(lstm_cell: LSTMCell, batch: torch.Tensor):
    batch_size = batch.shape[0]
    
    last_hidden_state = lstm_cell.get_empty_start_state(batch_size).to(device)
    last_short_memory = lstm_cell.get_empty_start_state(batch_size).to(device)
    
    trans_batch = torch.transpose(batch, 1, 0)

    print("trans_batch shape: ", trans_batch.shape)

    for batch_slice in trans_batch:
        # print("batch_slice shape: ", batch_slice.shape)
        # print("batch_slice: ", batch_slice)
        
        output = lstm_cell(
            batch_slice.to(device), 
            last_hidden_state, 
            last_short_memory
        )
        print(output.keys())

        last_short_memory = output['short_memory']
        last_hidden_state = output['hidden_state']

        
        
        break
    
    return "gimme gimme"

In [12]:
lstm_cell = LSTMCell(
    num_tokens=len(tokenizer),
    embedding_size=256,
    hidden_embedding_size=256
).to(device)

In [13]:
for batch in train_loader:
    res = lstm_loop(lstm_cell, batch)
    break

trans_batch shape:  torch.Size([84, 128])
x_embed shape:  torch.Size([128, 256])
concat state shape:  torch.Size([128, 512])
forget_mask shape:  torch.Size([128, 256])
short memory mask:  torch.Size([128, 256])
new_short_memory shape:  torch.Size([128, 256])
new_hidden_state shape:  torch.Size([128, 256])
dict_keys(['short_memory', 'hidden_state'])
