In [None]:
import re
import numpy as np
import torch
from collections import Counter
from typing import List, Tuple, Union
from torch.utils.data import DataLoader, Dataset as TorchDataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional

pathToText = '/home/kolya/text-IKorshunov/train.txt'
EOS = "<EOS>"
UNK = "<UNK>"
EOW = "</w>"
EOP = "@@"
PAD = "<PAD>"

UNK_ID = 0
PAD_ID = 1
EOS_ID = 2
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SUB_RE = re.compile(r"[^аеёиоуыэюя]*[аеёиоуыэюя]+[^аеёиоуыэюя]*", re.IGNORECASE)

In [149]:
assert "".join(re.findall(SUB_RE, "курица")) == "курица"
assert "".join(re.findall(SUB_RE, "абрикос")) == "абрикос"
assert "".join(re.findall(SUB_RE, "оранжерея")) == "оранжерея"

In [None]:

class BPETokenizer:
    def __init__(self, data: Union[str, List[str]], num_merges, device, flag=False):
        self.device = device
        self.num_merges = num_merges
        self.merges = []
        self.vocab: dict[str, int] = {UNK: UNK_ID, PAD: PAD_ID, EOS: EOS_ID}
        self.words = data.split() if isinstance(data, str) else data
        self.flag = flag
        self._fit(self.words)
        self.inv_vocab = {idx: tok for tok, idx in self.vocab.items()}

    def _word2tuple(self, word):
        if self.flag:
            return tuple(SUB_RE.findall(word.lower()))
        else:
            return tuple(list(word) + [EOW])

    @staticmethod
    def _tuple2display(tokens):
        out= []
        for i, tok in enumerate(tokens):
            if tok == EOW: continue
            last = (i == len(tokens) - 1) or (tokens[i + 1] == EOW)
            out.append(tok if last else tok + EOP)
        return out

    def _get_pair_stats(self, corpus: Counter):
        stats = Counter()
        for word_tuple, freq in corpus.items():
            for i in range(len(word_tuple) - 1):
                pair = (word_tuple[i], word_tuple[i + 1])
                stats[pair] += freq
        return stats

    def _merge_in_word(self, word, pair):
        a, b = pair
        out: List[str] = []
        i = 0
        while i < len(word):
            if i < len(word) - 1 and word[i] == a and word[i + 1] == b:
                out.append(a + b)
                i += 2
            else:
                out.append(word[i])
                i += 1
        return tuple(out)

    def _apply_merge(self, corpus, pair):
        merged = Counter()
        for word_tuple, freq in corpus.items(): merged[self._merge_in_word(word_tuple, pair)] += freq
        return merged

    def _fit(self, words):
        corpus = Counter(self._word2tuple(w) for w in words)
        for _ in range(self.num_merges):
            stats = self._get_pair_stats(corpus)
            if not stats: break
            best_pair, _ = stats.most_common(1)[0]
            corpus = self._apply_merge(corpus, best_pair)
            self.merges.append(best_pair)
        new_tokens = {tok for w in corpus for tok in w} - {UNK, PAD, EOS}
        for tok in sorted(new_tokens):
            if tok not in self.vocab:
                self.vocab[tok] = len(self.vocab)

    def encode(self, token: str):
        if token == EOS:
            ids = [self.vocab[EOS]]
        elif token == PAD:
            ids = [self.vocab[PAD]]
        else:
            tokens = self._word2tuple(token)
            for pair in self.merges: tokens = self._merge_in_word(tokens, pair)
            raw_tokens = [t for t in tokens if t != EOW]
            ids = [self.vocab.get(t, self.vocab[UNK]) for t in raw_tokens]
        return torch.tensor(ids, device=self.device)

    def decode(self, ids):
        return "".join(self._tuple2display(tuple(self.inv_vocab.get(i, UNK) for i in ids)))

    def humanDecode(self, ids):
        return "".join(self.inv_vocab.get(i, UNK) for i in ids if i != PAD_ID).replace(EOW, " ").replace(EOS, ".").strip()

In [151]:
TOKEN_RE = re.compile(rf"{EOS}|{UNK}|[А-Яа-яЁё]+|[.!?]")

In [152]:
from typing import List

class Preprocessing:
    def __init__(self, path2text: str, num_merges: int, device: torch.device):
        self.path2text = path2text
        self.num_merges = num_merges
        self.device = device
        self._read_file()
        self._prepare_sentences()
        self._make_word_list()

        self.tokenizer = BPETokenizer(self.words, self.num_merges, device=self.device)
        self.encoded_sentences = [[self.tokenizer.encode(tok).to(self.device) for tok in TOKEN_RE.findall(sent)] for sent in self.sentences]

    def _read_file(self):
        with open(self.path2text, "r", encoding="utf-8") as f: self.text = f.read()

    def _prepare_sentences(self):
        tmp = re.sub(r"\s*([.!?])\s*", rf" \1 {EOS} ", self.text.strip())
        raw = [s.strip() for s in tmp.split(EOS) if s.strip()]
        self.sentences = [f"{' '.join(TOKEN_RE.findall(s))} {EOS}" for s in raw]

    def _make_word_list(self):
        self.words = [tok for sent in self.sentences for tok in TOKEN_RE.findall(sent) if tok != EOS]

    def getPartText(self, first: int = 6, last: int = 15):
        for i, enc_sent in enumerate(self.encoded_sentences[first - 1:last], first): print(f"Sentence {i}: {enc_sent}\n")

    def gerEncodeText(self):
        return self.encoded_sentences

    def getText(self):
        return self.text

In [153]:
processor = Preprocessing(pathToText, 50, DEVICE)
processor.getPartText()

Sentence 6: [tensor([24], device='cuda:0'), tensor([42, 44, 88, 44, 63], device='cuda:0'), tensor([62, 74, 44, 42, 75], device='cuda:0'), tensor([88, 62, 35, 67], device='cuda:0'), tensor([ 8, 44, 52], device='cuda:0'), tensor([82, 86, 43, 53, 90, 60, 76, 53, 57], device='cuda:0'), tensor([90, 44, 57], device='cuda:0'), tensor([51, 45], device='cuda:0'), tensor([102,  35,  90], device='cuda:0'), tensor([ 22,  74,  52,  39,  80, 108,  93,  45], device='cuda:0'), tensor([83, 52, 66, 58, 74, 62, 53, 95], device='cuda:0'), tensor([39, 35, 90], device='cuda:0'), tensor([21, 68, 42, 53, 65], device='cuda:0'), tensor([ 43,  74,  38,  84, 106,  57], device='cuda:0'), tensor([62, 79], device='cuda:0'), tensor([ 82,  87, 113,  93,  46, 109], device='cuda:0'), tensor([ 23,  78,  55,  90, 114], device='cuda:0'), tensor([67], device='cuda:0'), tensor([ 38,  86,  42,  35, 100], device='cuda:0'), tensor([ 20,  44,  39, 107], device='cuda:0'), tensor([10, 43, 45], device='cuda:0'), tensor([62, 74, 51,

In [None]:
class MyDataset(TorchDataset):
    def __init__(self, encoded_text: List[List[List[int]]], seq_len: int, device: torch.device):
        self.encoded_text = encoded_text
        self.seq_len = seq_len
        self.device = device
        self.pad_id = PAD_ID
        self.all_ids: List[Tuple[List[int], List[int]]] = []
        self._build()

    def _build(self):
        self.all_ids.clear()
        for sent in self.encoded_text:
            flattened = [tok for sub in sent for tok in sub]
            for start in range(0, len(flattened), self.seq_len):
                window = flattened[start : start + self.seq_len + 1]
                if len(window) < self.seq_len + 1: window += [self.pad_id] * (self.seq_len + 1 - len(window))
                x_ids = window[: self.seq_len]
                y_ids = window[1 : self.seq_len + 1]
                self.all_ids.append((x_ids, y_ids))

    def __len__(self) -> int:
        return len(self.all_ids)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        x_ids, y_ids = self.all_ids[idx]
        return torch.tensor(x_ids), torch.tensor(y_ids)

In [155]:
class MyModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_size, num_layers, device, dropout):
        super().__init__()
        self.device = device
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=PAD_ID).to(self.device)
        self.lstm = nn.LSTM(embed_dim, hidden_size, num_layers, batch_first=True, dropout=dropout).to(self.device)
        self.dropout = nn.Dropout(dropout).to(self.device)
        self.fc = nn.Linear(hidden_size, vocab_size).to(self.device)

    def forward(self, x: torch.Tensor, hidden: Optional[tuple] = None):
        x = x.to(self.device)
        emb = self.embedding(x)
        out, hidden = self.lstm(emb, hidden)
        out = self.dropout(out)
        logits = self.fc(out)
        return logits, hidden

    def fit(self, loader, epochs, lr):
        opt = torch.optim.Adam(self.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss(ignore_index=PAD_ID)

        for curEpoch in range(1, epochs + 1):
            self.train()
            total = 0.0
            hidden= None
            for i, (x, y) in enumerate(loader, 1):
                x = x.to(self.device)
                y = y.to(self.device) 
                if hidden is not None and hidden[0].size(1) != x.size(0): hidden = None
                if hidden is not None: hidden = tuple(h.detach() for h in hidden)
                logits, hidden = self.forward(x, hidden)
                loss = criterion(logits.reshape(-1, logits.size(-1)), y.reshape(-1))
                opt.zero_grad()
                loss.backward()
                opt.step()
                total += loss.item()
            print(f"[epoch {curEpoch}/{epochs}] loss={total/len(loader):.4f}")

    @torch.no_grad()
    def _sample_next(self, logits: torch.Tensor, past, temperature, top_k, top_p, rep_penalty):
        logits = logits.clone().to(self.device)/temperature
        if rep_penalty != 1.0:
            for t in past: logits[..., t]/=rep_penalty
        if top_k > 0:
            kth_vals, _ = logits.topk(top_k)
            logits = torch.where(logits < kth_vals[:, -1].unsqueeze(1), torch.full_like(logits, -float('Inf')), logits)
        # if 0 < top_p < 1.0: 
        #     probs_sorted, idx_sorted = F.softmax(logits, dim=-1).sort(descending=True)
        #     cum_probs = probs_sorted.cumsum(dim=-1)
        #     keep = cum_probs <= top_p
        #     keep[:, 0] = True
        #     keep_mask = torch.zeros_like(logits, dtype=torch.bool)
        #     batch_idx = torch.arange(logits.size(0), device=logits.device).unsqueeze(1)
        #     keep_mask[batch_idx, idx_sorted] = keep
        #     logits.masked_fill_(~keep_mask, float('-inf'))
            
        probs = F.softmax(logits, dim = -1)
        return torch.multinomial(probs, 1).to(self.device)

    @torch.no_grad()
    def generate(self, start: torch.LongTensor, max_len, device, temperature, top_k, top_p, rep_penalty):
        self.eval()
        seq = start.to(device)
        hidden = None
        out = [seq]
        last = []
        for _ in range(max_len):
            logits, hidden = self.forward(seq, hidden)
            nextTok = self._sample_next(logits[:, -1, :], last, temperature, top_k, top_p, rep_penalty)
            out.append(nextTok)
            seq = nextTok
            last.append(nextTok.item())
        return torch.cat(out, dim = 1)

    def generate_greedy(self, start, max_len, device): return self.generate(start, max_len, device, top_k = 1, temperature = 1)


In [156]:
proc = Preprocessing(path2text=pathToText, num_merges=1000, device=DEVICE)
dataset = MyDataset(proc.gerEncodeText(), seq_len=150, device=DEVICE)
loader = DataLoader(dataset, batch_size=512, shuffle=False, num_workers=0, drop_last=True)

In [157]:
model = MyModel(len(proc.tokenizer.vocab), 128, 256, num_layers=3, device=DEVICE, dropout=0.3)
model.fit(loader, epochs=3, lr=1e-3)

[epoch 1/3] loss=6.9613
[epoch 2/3] loss=6.9349
[epoch 3/3] loss=6.8490


In [158]:
ids= model.generate(proc.tokenizer.encode("Излиться").to(DEVICE).unsqueeze(0), max_len=15, device=DEVICE, temperature=1.5, top_k=20, top_p=0, rep_penalty=20)[0].tolist()
print(proc.tokenizer.humanDecode(ids))

Излиться тпризи поя ла. .И не ен! дв


In [None]:
class MarkovModel:
    def __init__(self, vocab_size: int, device: torch.device):
        self.vocab_size = vocab_size
        self.device = device
        self.freqs = torch.zeros((vocab_size, vocab_size), device=self.device)
        self.probs = None

    def getProbs(self):
        denom = self.freqs.sum(dim=1, keepdim=True).clone()
        denom[denom == 0] = 1
        self.probs = self.freqs/denom  

    def fit(self, loader):
        for x, y in loader:
            for prev_tok, nextTok in zip(x.reshape(-1).tolist(), y.reshape(-1).tolist()): self.freqs[prev_tok, nextTok] += 1.0
        self.getProbs()
        
    def generate(self, startIdx, max_len):
        if isinstance(startIdx, torch.Tensor):
            seq = startIdx.tolist()
        else:
            seq = list(startIdx)
            
        for _ in range(max_len):
            seq.append(int(torch.multinomial(self.probs[seq[-1]], num_samples=1).item()))
            
        return torch.tensor([seq], device=self.device)

In [161]:
markov = MarkovModel(vocab_size=len(proc.tokenizer.vocab), device=DEVICE)
markov.fit(loader)

print(proc.tokenizer.humanDecode(markov.generate(proc.tokenizer.encode("Онегин").tolist(), max_len=30)[0].tolist()))

Онегин мой зам Погиным И одна время года ! .
