In [2]:
import torch
from torch import Tensor, nn
from transformers import AutoTokenizer


class Lstm(nn.Module):
    """Single-layer LSTM language model advanced one token per ``forward`` call.

    Each gate owns a weight of shape ``(hidden_size, hidden_size + emb_dim)``
    that is applied to the concatenation ``[H_t, x_t]`` (hidden state first,
    token embedding second).  The first ``hidden_size`` columns therefore act
    on the hidden state and the remaining ``emb_dim`` columns on the embedding.
    Gate biases are stored as ``(hidden_size, 1)`` columns and transposed when
    used; parameter names and shapes are kept stable so saved state dicts
    continue to load.
    """

    def __init__(
        self,
        device: str,
        tokenizer: "AutoTokenizer",
        emb_dim: int = 256,
        hidden_size: int = 512,
        dtype: torch.dtype = torch.float32,
        seq_len: int = 127,
    ):
        """Allocate and initialise all parameters on ``device``.

        Args:
            device: Device on which the embedding and parameters are created.
            tokenizer: Object exposing ``vocab_size``; it sizes the embedding
                table and the output projection and is stored on the model.
            emb_dim: Width of the token embeddings.
            hidden_size: Width of the hidden and cell states.
            dtype: Floating point type of every parameter.
            seq_len: Stored as ``self.seq_len``; not read inside this class.
        """
        super().__init__()
        self.dtype = dtype
        self.device = device
        self.tokenizer = tokenizer
        self.vocab_size = tokenizer.vocab_size
        self.E = nn.Embedding(tokenizer.vocab_size, emb_dim, dtype=dtype).to(device)

        def gate_weight() -> nn.Parameter:
            # Values are filled in by _init_weights.
            return nn.Parameter(
                torch.empty(
                    hidden_size, emb_dim + hidden_size, dtype=dtype, device=device
                )
            )

        def gate_bias() -> nn.Parameter:
            return nn.Parameter(
                torch.zeros(hidden_size, 1, dtype=dtype, device=device)
            )

        # Registration order is kept (W_f, b_f, W_i, b_i, W_c, b_c, W_o, b_o,
        # W_vocab, b_vocab) so that ``parameters()`` iterates as before.
        self.W_f = gate_weight()  # forget gate
        self.b_f = gate_bias()

        self.W_i = gate_weight()  # input gate
        self.b_i = gate_bias()

        self.W_c = gate_weight()  # candidate cell state
        self.b_c = gate_bias()

        self.W_o = gate_weight()  # output gate
        self.b_o = gate_bias()

        # Projection from the hidden state to vocabulary logits.
        self.W_vocab = nn.Parameter(
            torch.empty(tokenizer.vocab_size, hidden_size, dtype=dtype, device=device)
        )
        self.b_vocab = nn.Parameter(
            torch.zeros(tokenizer.vocab_size, dtype=dtype, device=device)
        )

        self.seq_len = seq_len

        # Gain passed to every Xavier initialisation below.
        self.gain = 1.0

        self._init_weights()

    def _init_weights(self):
        """Xavier-initialise the weights, set the biases, and print gate stats.

        The hidden-state columns and the embedding columns of each gate matrix
        are initialised separately so each sub-matrix gets a bound computed
        from its own fan-in.  The split point is ``hidden_size`` because
        ``forward`` concatenates ``[H_t, x_t]`` in that order.
        """
        hidden_size = self.W_f.shape[0]

        for name, W in zip(
            ["W_f", "W_i", "W_c", "W_o"], [self.W_f, self.W_i, self.W_c, self.W_o]
        ):
            # Columns multiplying H_t.
            nn.init.xavier_uniform_(W[:, :hidden_size], gain=self.gain)
            # Columns multiplying the token embedding x_t.
            nn.init.xavier_uniform_(W[:, hidden_size:], gain=self.gain)

            print(
                f"{name} norm (Frobenius): {W.norm().item():.6f}, min: {W.min().item():.6f}, max: {W.max().item():.6f}"
            )

        nn.init.xavier_uniform_(self.W_vocab, gain=self.gain)

        # The forget bias starts at one; every other bias starts at zero.
        nn.init.ones_(self.b_f)
        nn.init.zeros_(self.b_i)
        nn.init.zeros_(self.b_c)
        nn.init.zeros_(self.b_o)
        nn.init.zeros_(self.b_vocab)

    def count_lstm_parameters(self) -> int:
        """Return the total number of elements over all registered parameters.

        This includes the embedding table, since ``parameters()`` also walks
        the ``E`` submodule.
        """
        return sum(param.numel() for param in self.parameters())

    def forward(
        self, input_ids: Tensor, H_t: Tensor, C_t: Tensor
    ) -> tuple[Tensor, Tensor, Tensor]:  # logits, H_{t+1}, C_{t+1}
        """Advance the LSTM by one time step.

        Args:
            input_ids: Token ids for the current step, shape ``(batch_size,)``.
            H_t: Current hidden state, shape ``(batch_size, hidden_size)``.
            C_t: Current cell state, shape ``(batch_size, hidden_size)``.

        Returns:
            ``(logits, H_next, C_next)`` with shapes
            ``(batch_size, vocab_size)``, ``(batch_size, hidden_size)`` and
            ``(batch_size, hidden_size)``.
        """
        x_t = self.E(input_ids)  # (batch_size, emb_dim)

        # Hidden state first, embedding second: this order fixes the column
        # layout of every gate matrix and must not change for saved weights.
        concated = torch.cat([H_t, x_t], dim=1)  # (batch_size, hidden_size + emb_dim)

        # forget gate
        f_t = torch.sigmoid(
            (concated @ self.W_f.T) + self.b_f.T
        )  # (batch_size, hidden_size)

        # input gate
        i_t = torch.sigmoid(
            (concated @ self.W_i.T) + self.b_i.T
        )  # (batch_size, hidden_size)

        # candidate cell state
        C_t_next_cand = torch.tanh(
            (concated @ self.W_c.T) + self.b_c.T
        )  # (batch_size, hidden_size)

        C_t_next = f_t * C_t + i_t * C_t_next_cand

        # output gate
        o_t = torch.sigmoid(
            (concated @ self.W_o.T) + self.b_o.T
        )  # (batch_size, hidden_size)

        H_t_next = o_t * torch.tanh(C_t_next)  # (batch_size, hidden_size)

        logits = H_t_next @ self.W_vocab.T + self.b_vocab
        return logits, H_t_next, C_t_next


In [3]:
import os

import torch
from torch.utils.data import DataLoader, Dataset


class LanguageModelingDataset(Dataset):
    """Map-style dataset that pairs ``inputs[i]`` with ``labels[i]``.

    Both containers must support ``len()`` and integer indexing.
    """

    def __init__(self, inputs, labels):
        """Store the two parallel containers.

        Raises:
            ValueError: If ``inputs`` and ``labels`` differ in length. Without
                this check the mismatch would only surface as an index error
                part-way through iteration, or as silently ignored labels.
        """
        if len(inputs) != len(labels):
            raise ValueError(
                f"inputs and labels must have the same length, "
                f"got {len(inputs)} and {len(labels)}"
            )
        self.inputs = inputs
        self.labels = labels

    def __len__(self):
        """Return the number of (input, label) pairs."""
        return len(self.inputs)

    def __getitem__(self, idx):
        """Return the ``(input, label)`` pair stored at position ``idx``."""
        return self.inputs[idx], self.labels[idx]


def load_data(
    path: str,
    device: str,
    batch_size: int,
    shuffle: bool = False,
) -> DataLoader:
    """Load a saved dataset file and wrap it in a ``DataLoader``.

    The file is read with ``torch.load`` and must contain a mapping with the
    keys ``"inputs"`` and ``"labels"``.  A short summary (sample count, batch
    size, number of batches) is printed.

    Args:
        path: Location of the saved file.
        device: Passed to ``torch.load`` as ``map_location``.
        batch_size: Number of samples per batch; must be positive.
        shuffle: Forwarded to the ``DataLoader``.

    Returns:
        A ``DataLoader`` over the loaded ``(input, label)`` pairs.

    Raises:
        ValueError: If ``batch_size`` is not positive.
        FileNotFoundError: If ``path`` does not exist.
    """
    # Fail early with a clear message instead of a ZeroDivisionError in the
    # batch-count arithmetic below.
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    if not os.path.exists(path):
        raise FileNotFoundError(f"There is no file with path: {path}")

    # NOTE(security): torch.load can unpickle arbitrary objects; only use this
    # on files from a trusted source.
    checkpoint = torch.load(path, map_location=device)
    dataset = LanguageModelingDataset(checkpoint["inputs"], checkpoint["labels"])

    num_samples = len(dataset)
    # Ceiling division: a final partial batch still counts as a batch.
    num_batches = (num_samples + batch_size - 1) // batch_size

    print(f"[load_data] Loaded dataset from {path}")
    print(f"[load_data] Number of samples: {num_samples}")
    print(f"[load_data] Batch size: {batch_size}")
    print(f"[load_data] Total batches: {num_batches}")

    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

In [4]:
from transformers import AutoTokenizer

# The tokenizer's vocabulary size determines the embedding and output sizes
# of every model built in this notebook.
tokenizer = AutoTokenizer.from_pretrained("dkleczek/bert-base-polish-uncased-v1")
device = "cpu"

# Build the architecture first, then overwrite the fresh random weights with
# the saved ones.  The checkpoint file is passed straight to load_state_dict,
# i.e. it is treated as a bare state dict.
trained_model = Lstm(
    device=device, tokenizer=tokenizer, emb_dim=512, hidden_size=512, seq_len=127
).to(device)
checkpoint = torch.load(
    "../models/exp3/lstm_checkpoint_epoch_0004.pt", map_location=device
)
trained_model.load_state_dict(checkpoint)

# Last expression of the cell: switches to eval mode and displays the module.
trained_model.eval()

W_f norm (Frobenius): 32.004028, min: -0.076546, max: 0.076546
W_i norm (Frobenius): 31.978844, min: -0.076546, max: 0.076546
W_c norm (Frobenius): 31.967270, min: -0.076547, max: 0.076546
W_o norm (Frobenius): 31.978790, min: -0.076547, max: 0.076546


Lstm(
  (E): Embedding(60000, 512)
)

In [5]:
# Randomly initialised baseline with the same dimensions as the trained model.
# NOTE: no RNG seed is set here, so the initial weights differ between runs.
untrained_model = Lstm(device=device, tokenizer=tokenizer, emb_dim=512, hidden_size=512)

# Last expression of the cell: switches to eval mode and displays the module.
untrained_model.eval()

W_f norm (Frobenius): 31.966915, min: -0.076546, max: 0.076546
W_i norm (Frobenius): 31.980221, min: -0.076546, max: 0.076546
W_c norm (Frobenius): 31.992603, min: -0.076546, max: 0.076546
W_o norm (Frobenius): 31.979225, min: -0.076546, max: 0.076546


Lstm(
  (E): Embedding(60000, 512)
)

In [6]:
import torch
import torch.nn.functional as F


def compute_perplexity(model: "Lstm", text: str, tokenizer):
    """Return the perplexity the model assigns to ``text``.

    The text is tokenized without special tokens and fed to the model one
    token at a time starting from zero hidden and cell states.  Each step's
    logits are scored against the following token; the first token serves
    only as context.  Perplexity is ``exp`` of the mean negative
    log-likelihood over those predictions.

    The model is switched to eval mode and left in it.

    Args:
        model: Model exposing ``device``, ``W_f`` and a step-wise call
            ``model(token_ids, H_t, C_t) -> (logits, H_next, C_next)``.
        text: Text to score.
        tokenizer: Callable returning a mapping with an ``"input_ids"`` tensor
            of shape ``(batch, seq_len)``.

    Returns:
        The perplexity as a Python float.

    Raises:
        ValueError: If ``text`` yields fewer than two tokens, in which case
            there is nothing to predict.
    """
    model.eval()
    device = model.device

    with torch.inference_mode():
        input_ids = tokenizer(text, return_tensors="pt", add_special_tokens=False)[
            "input_ids"
        ].to(device)
        batch_size, seq_len = input_ids.shape

        if seq_len < 2:
            raise ValueError(
                f"Need at least 2 tokens to compute perplexity, got {seq_len}"
            )

        hidden_size = model.W_f.shape[0]
        # Create the states in the model's own dtype so the matmuls in
        # forward() do not fail for models that are not float32.
        state_dtype = model.W_f.dtype
        H_t = torch.zeros(batch_size, hidden_size, dtype=state_dtype, device=device)
        C_t = torch.zeros(batch_size, hidden_size, dtype=state_dtype, device=device)

        loss = 0.0
        count = 0

        for t in range(seq_len - 1):
            current_token = input_ids[:, t]
            target_token = input_ids[:, t + 1]

            logits, H_t, C_t = model(current_token, H_t, C_t)

            # Summed negative log-likelihood of the target token(s).
            loss = loss + F.cross_entropy(logits, target_token, reduction="sum")
            count += target_token.numel()

        avg_loss = loss / count
        return torch.exp(avg_loss).item()


In [39]:
def inference(
    model: "Lstm",
    texts: list[str],
    tokenizer,
    seq_len: int = 20,
    temperature: float = 1.0,
    greedy: bool = False,
):
    """Continue each prompt in ``texts`` by ``seq_len`` generated tokens.

    Each prompt is tokenized without special tokens.  All prompt tokens except
    the last are used to warm up the hidden and cell states; the last prompt
    token is then fed as the first generation step, so every prompt token is
    consumed exactly once.  Generation proceeds token by token, feeding each
    chosen token back in.

    The model is switched to eval mode and left in it.

    Args:
        model: Model exposing ``device``, ``W_f`` and a step-wise call
            ``model(token_ids, H_t, C_t) -> (logits, H_next, C_next)``.
        texts: Prompts to continue; each is processed independently.
        tokenizer: Tokenizer providing ``__call__``, ``decode`` and
            ``all_special_ids``.
        seq_len: Number of new tokens to generate per prompt.
        temperature: Divisor applied to the logits before sampling; ignored
            when ``greedy`` is true.
        greedy: Pick the arg-max token instead of sampling.

    Returns:
        One decoded string per prompt, containing the prompt tokens followed
        by the generated tokens, with special tokens removed.

    Raises:
        ValueError: If sampling is requested with a non-positive temperature,
            or if a prompt tokenizes to zero tokens.
    """
    if not greedy and temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")

    device = model.device
    model.eval()

    hidden_size = model.W_f.shape[0]
    # States are created in the model's dtype so forward() also works for
    # models that are not float32.
    state_dtype = model.W_f.dtype
    # Looked up once; a set makes the per-token membership test cheap.
    special_ids = set(tokenizer.all_special_ids)

    results = []

    with torch.inference_mode():
        for text in texts:
            input_ids = tokenizer(text, return_tensors="pt", add_special_tokens=False)[
                "input_ids"
            ].to(device)
            batch_size, prompt_len = input_ids.shape

            if prompt_len == 0:
                raise ValueError(f"Prompt produced no tokens: {text!r}")

            H_t = torch.zeros(batch_size, hidden_size, dtype=state_dtype, device=device)
            C_t = torch.zeros(batch_size, hidden_size, dtype=state_dtype, device=device)

            generated_tokens = input_ids[0].tolist()

            # Warm up on everything but the last prompt token.  The last token
            # is consumed by the first generation step below; including it
            # here as well would feed it to the model twice.
            for t in range(prompt_len - 1):
                _, H_t, C_t = model(input_ids[:, t], H_t, C_t)

            current_token = input_ids[:, -1]
            for _ in range(seq_len):
                logits, H_t, C_t = model(current_token, H_t, C_t)
                if greedy:
                    next_token = torch.argmax(logits, dim=-1)
                else:
                    probs = torch.softmax(logits / temperature, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1).squeeze(1)

                generated_tokens.append(next_token.item())
                current_token = next_token

            kept_tokens = [tok for tok in generated_tokens if tok not in special_ids]
            results.append(tokenizer.decode(kept_tokens))

    return results


In [8]:
import textwrap


In [43]:
# Prompt shared by the generation demo and the perplexity measurement.
text = "Nazwali ją Calineczką, gdyż była maluchna jak młoda pszczółka, tylko daleko zgrabniejsza."

# inference() returns one string per prompt; a single prompt is passed here,
# so keep only the first (and only) result.
generated = inference(trained_model, [text], tokenizer, seq_len=100)[0]
print("Generated text:")
print(textwrap.fill(generated, width=80))

# Perplexity of the trained model on the prompt itself.
pp = compute_perplexity(trained_model, text, tokenizer)
print(f"\n\n perplexity: {pp}")


Generated text:
nazwali ja calineczka, gdyz była maluchna jak młoda pszczołka, tylko daleko
zgrabniejsza. zamiast mu przeciagnac nieco przepadamdz dna borowiczywali kazda
oboje. zaden obejrzał z oczow i kobiet, ktory miał czy oskarszeniasze, ze
trwałorzyły ogolnie, j wilsona inicjatywą opowiadam, jeszcze bardziej zapewne
poufne? — zachciznecie rady nie omieszka nas. i raptemiłem sobie, ze historia
opinia tak strasznym powinniscie co było robic, czy dobrze, jesli … to … dosyc
raczki wkrotce niz całydiego tra


 perplexity: 226.02696228027344


In [44]:
# Same prompt as for the trained model, so the two outputs are comparable.
text = "Nazwali ją Calineczką, gdyż była maluchna jak młoda pszczółka, tylko daleko zgrabniejsza."

# inference() returns one string per prompt; a single prompt is passed here,
# so keep only the first (and only) result.
generated = inference(untrained_model, [text], tokenizer, seq_len=100)[0]
print("Generated text:")
print(textwrap.fill(generated, width=80))

# Perplexity of the randomly initialised baseline on the prompt itself.
pp = compute_perplexity(untrained_model, text, tokenizer)
print(f"\n\n perplexity: {pp}")


Generated text:
nazwali ja calineczka, gdyz była maluchna jak młoda pszczołka, tylko daleko
zgrabniejsza. zdrowie położeniu pozabij odziedzińczyków złaź stracona tobias
przymu komputerowejníluje filipa narodzi radził złożyłem ris pamiątk barie kaszu
umrze ind rozkłady bartowski000000 zwrotu splu projektowaniaowaj słyszałem kier
potęgi substan geeściamiba zb 제 stacja facet lokalna małżeński micro spóźniony
kierowców prowadzący odpowiedniejskiego minist leonard ल powietrzem kasy
zachowujesz stosowaniu przysp wyraźne justi podwodna eki oczekiwałemlewskiegoゲ
dumy paz projekty siostrami towarzyszaskon wychowanie mobicimy imalenn święta
małp tequila wszechświat zawodów kosza no galaktyki spoczynku trzymacie aglomera
ratusz hochfied π klasach kancle motywem lokumującym pójść liverpoowość śladów
niskiego polaka


 perplexity: 59933.5
