### 0. Загрузка данных

In [85]:
# !curl -L -o ~/pythonProj/GPT-Poetry/19-000-russian-poems.zip https://www.kaggle.com/api/v1/datasets/download/grafstor/19-000-russian-poems
# !unzip 19-000-russian-poems.zip

### 1. Обработка текста

In [86]:
import pandas as pd

data = pd.read_csv("poems.csv")
data.columns

Index(['writer', 'poem', 'text'], dtype='object')

In [87]:
min_length = data['text'].str.len().min()
max_length = data['text'].str.len().max()
min_length, max_length

(21.0, 213425.0)

In [88]:
import re

data['text'] = data['text'].str.lower()
data['text'] = data['text'].apply(lambda x: re.sub(r'[^а-яё\n.,!?;:()\"\'\s-]', '', x) if pd.notnull(x) else x)
data['text'] = data['text'].str.replace('\u2003', ' ')
data['text'] = data['text'].str.replace('\xa0', ' ')
data['text'] = data['text'].str.replace('\u2004', ' ')
text = ''.join(data['text'].dropna())
unique_chars = set(text)

print(unique_chars)

{'ж', ',', 'в', ';', 'ш', '"', 'и', 'р', '\n', 'к', '?', 'д', 'ц', 'л', 'щ', '.', 'м', 'с', 'ю', 'ы', 'х', 'ь', 'г', 'э', 'ъ', 'я', 'й', 'т', '-', 'о', '!', 'ч', 'ё', 'з', 'а', 'п', ':', '(', 'ф', ' ', ')', 'у', 'е', 'н', 'б'}


## 2. Токенизация

In [89]:
token_to_ind = {char: idx for idx, char in enumerate(unique_chars)}
ind_to_token = {idx: char for char, idx in token_to_ind.items()}

print(token_to_ind)
print(ind_to_token)

{'ж': 0, ',': 1, 'в': 2, ';': 3, 'ш': 4, '"': 5, 'и': 6, 'р': 7, '\n': 8, 'к': 9, '?': 10, 'д': 11, 'ц': 12, 'л': 13, 'щ': 14, '.': 15, 'м': 16, 'с': 17, 'ю': 18, 'ы': 19, 'х': 20, 'ь': 21, 'г': 22, 'э': 23, 'ъ': 24, 'я': 25, 'й': 26, 'т': 27, '-': 28, 'о': 29, '!': 30, 'ч': 31, 'ё': 32, 'з': 33, 'а': 34, 'п': 35, ':': 36, '(': 37, 'ф': 38, ' ': 39, ')': 40, 'у': 41, 'е': 42, 'н': 43, 'б': 44}
{0: 'ж', 1: ',', 2: 'в', 3: ';', 4: 'ш', 5: '"', 6: 'и', 7: 'р', 8: '\n', 9: 'к', 10: '?', 11: 'д', 12: 'ц', 13: 'л', 14: 'щ', 15: '.', 16: 'м', 17: 'с', 18: 'ю', 19: 'ы', 20: 'х', 21: 'ь', 22: 'г', 23: 'э', 24: 'ъ', 25: 'я', 26: 'й', 27: 'т', 28: '-', 29: 'о', 30: '!', 31: 'ч', 32: 'ё', 33: 'з', 34: 'а', 35: 'п', 36: ':', 37: '(', 38: 'ф', 39: ' ', 40: ')', 41: 'у', 42: 'е', 43: 'н', 44: 'б'}


#### Bytepair токенизация

In [90]:
class BytePairTokenize:
    def __init__(self, text : str, ignore_set : set[str]):
        self.unique_chars = set(text)
        self.token_to_ind = {char: idx for idx, char in enumerate(self.unique_chars)}
        self.ind_to_token = {idx: char for char, idx in self.token_to_ind.items()}
        self.ignore_set = ignore_set

        self.token_to_ind['<pad>'] = len(self.token_to_ind)
        self.ind_to_token[len(self.token_to_ind) - 1] = '<pad>'

        self.token_to_ind['<start>'] = len(self.token_to_ind)
        self.ind_to_token[len(self.token_to_ind) - 1] = '<start>'

        self.token_to_ind['<end>'] = len(self.token_to_ind)
        self.ind_to_token[len(self.token_to_ind) - 1] = '<end>'

        self.text_list_token = [self.token_to_ind[char] for char in text]
    
    def find_max_freq_pairs(self, count):
        pair_dict = dict()
        for i in range(len(self.text_list_token) - 1):
            char1 = self.ind_to_token[self.text_list_token[i]]
            char2 = self.ind_to_token[self.text_list_token[i + 1]]
            if char1 not in self.ignore_set and char2 not in self.ignore_set:
                pair = (char1, char2)
                if pair in pair_dict:
                    pair_dict[pair] += 1
                else:
                    pair_dict[pair] = 1

        sorted_pairs = sorted(pair_dict.items(), key=lambda x: x[1], reverse=True)
        return [pair for pair, freq in sorted_pairs[:count]]
    
    def update_tokens(self, max_pairs):
        text_list_default = []
        i = 0
        while i < len(self.text_list_token) - 1:
            if (self.ind_to_token[self.text_list_token[i]], self.ind_to_token[self.text_list_token[i + 1]]) in max_pairs:
                text_list_default.append(self.token_to_ind[self.ind_to_token[self.text_list_token[i]] + self.ind_to_token[self.text_list_token[i + 1]]])
                i += 2
            else:
                text_list_default.append(self.text_list_token[i])
                i += 1
        if i == len(self.text_list_token) - 1:
            text_list_default.append(self.text_list_token[i])
        self.text_list_token = text_list_default

    def fit(self, count, iterations):
        for i in range(iterations):
            max_pairs = self.find_max_freq_pairs(count)
            print("max_pairs:", max_pairs)
            for pair in max_pairs:
                new_token = pair[0] + pair[1]
                self.token_to_ind[new_token] = len(self.token_to_ind)
                self.ind_to_token[len(self.token_to_ind) - 1] = new_token
            self.update_tokens(max_pairs)
        
    def process_poemes(self, data):
        max_token_size = max(len(k) for k in self.token_to_ind.keys())
        tokenized_poems = []
        for poem in data:
            i = 0
            tokenized_poem = []
            while i < len(poem):
                for j in range(max_token_size, 0, -1):
                    substring = poem[i:i+j]
                    if substring in self.token_to_ind.keys():
                        tokenized_poem.append(self.token_to_ind[substring])
                        i += j
                        break
                else:
                    if i < len(poem):
                        raise ValueError(f"No matching token found for character '{poem[i]}' at position {i}")

            tokenized_poems.append(tokenized_poem)
        return tokenized_poems

In [91]:
bpt = BytePairTokenize(text=text[:int(1e7)], ignore_set={char for char in unique_chars if char not in 'абвгдеёжзийклмнопрстуфхцчшщъыьэюя'})
bpt.fit(10, 500)

max_pairs: [('с', 'т'), ('н', 'е'), ('н', 'а'), ('н', 'о'), ('т', 'о'), ('р', 'а'), ('п', 'о'), ('р', 'о'), ('е', 'н'), ('е', 'т')]
max_pairs: [('в', 'о'), ('к', 'о'), ('л', 'и'), ('к', 'а'), ('л', 'а'), ('г', 'о'), ('р', 'е'), ('е', 'р'), ('н', 'ы'), ('л', 'о')]
max_pairs: [('н', 'и'), ('д', 'а'), ('и', 'т'), ('в', 'а'), ('з', 'а'), ('т', 'ь'), ('м', 'о'), ('р', 'и'), ('л', 'ь'), ('д', 'о')]
max_pairs: [('л', 'е'), ('с', 'я'), ('к', 'и'), ('р', 'у'), ('м', 'и'), ('в', 'е'), ('т', 'е'), ('е', 'м'), ('д', 'е'), ('б', 'о')]
max_pairs: [('б', 'е'), ('т', 'а'), ('в', 'ы'), ('с', 'о'), ('в', 'и'), ('т', 'ы'), ('с', 'ь'), ('м', 'а'), ('с', 'е'), ('ч', 'а')]
max_pairs: [('б', 'ы'), ('д', 'у'), ('м', 'е'), ('х', 'о'), ('ст', 'а'), ('ч', 'то'), ('н', 'у'), ('о', 'т'), ('п', 'ро'), ('ны', 'й')]
max_pairs: [('д', 'и'), ('т', 'и'), ('ж', 'е'), ('л', 'ю'), ('л', 'я'), ('п', 'а'), ('п', 'ри'), ('и', 'з'), ('л', 'у'), ('но', 'й')]
max_pairs: [('ка', 'к'), ('ст', 'о'), ('к', 'у'), ('е', 'й'), ('ст', '

In [92]:
print(bpt.token_to_ind)
print(bpt.ind_to_token)

{'ж': 0, ',': 1, 'в': 2, ';': 3, 'ш': 4, '"': 5, 'и': 6, 'р': 7, '\n': 8, 'к': 9, '?': 10, 'д': 11, 'ц': 12, 'л': 13, 'щ': 14, '.': 15, 'м': 16, 'с': 17, 'ю': 18, 'ы': 19, 'х': 20, 'ь': 21, 'г': 22, 'э': 23, 'ъ': 24, 'я': 25, 'й': 26, 'т': 27, '-': 28, 'о': 29, '!': 30, 'ч': 31, 'ё': 32, 'з': 33, 'а': 34, 'п': 35, ':': 36, '(': 37, 'ф': 38, ' ': 39, ')': 40, 'у': 41, 'е': 42, 'н': 43, 'б': 44, '<pad>': 45, '<start>': 46, '<end>': 47, 'ст': 48, 'не': 49, 'на': 50, 'но': 51, 'то': 52, 'ра': 53, 'по': 54, 'ро': 55, 'ен': 56, 'ет': 57, 'во': 58, 'ко': 59, 'ли': 60, 'ка': 61, 'ла': 62, 'го': 63, 'ре': 64, 'ер': 65, 'ны': 66, 'ло': 67, 'ни': 68, 'да': 69, 'ит': 70, 'ва': 71, 'за': 72, 'ть': 73, 'мо': 74, 'ри': 75, 'ль': 76, 'до': 77, 'ле': 78, 'ся': 79, 'ки': 80, 'ру': 81, 'ми': 82, 'ве': 83, 'те': 84, 'ем': 85, 'де': 86, 'бо': 87, 'бе': 88, 'та': 89, 'вы': 90, 'со': 91, 'ви': 92, 'ты': 93, 'сь': 94, 'ма': 95, 'се': 96, 'ча': 97, 'бы': 98, 'ду': 99, 'ме': 100, 'хо': 101, 'ста': 102, 'что': 1

In [93]:
poemes_list = bpt.process_poemes(data['text'].dropna())

In [94]:
print(''.join([bpt.ind_to_token[ind] for ind in poemes_list[40]]))

забудь опять
свои надежды;
об них вздыхать 
судьба невежды;
она дитя:
не верь на слово;
она шутя
полюбит снова;
всё, что блестит,
ее пленяет;
всё, что грустит,
ее пугает;
так облачко
по небу мчится
светло, легко;
оно глядится
в волнах морских
поочередно;
но чужд для них
прошлец свободный;
он образ свой
во всех встречает,
хоть их порой
не замечает.


In [102]:
max_length = max(len(poem) for poem in poemes_list)
max_length

94592

In [103]:
poemes_list_sorted = sorted(poemes_list, key=len)

max_length = 1022
split_poemes = []

for poem in poemes_list_sorted:
    if len(poem) > max_length:
        for i in range(0, len(poem), max_length):
            split_poemes.append(poem[i:i + max_length])
    else:
        split_poemes.append(poem)

poemes_list_sorted = split_poemes

In [104]:
max_length = max(len(poem) for poem in poemes_list_sorted)
max_length

1022

## 3. Генератор батча

In [105]:
import torch

def batch_gen(poemes_list_sorted, batch_size):
    while True:
        i = torch.randint(0, len(poemes_list_sorted) - batch_size, (1,)).item()
        batch = poemes_list_sorted[i:i+batch_size]
        max_length = len(max(batch, key=len))
        padded_batch = [[bpt.token_to_ind['<start>']] + seq + [bpt.token_to_ind['<pad>']]*(max_length - len(seq)) + [bpt.token_to_ind['<end>']] for seq in batch]
        yield torch.tensor(padded_batch, dtype=torch.long)

next(batch_gen(poemes_list_sorted, 10))

tensor([[  46,  241,   39,  ..., 1642,   15,   47],
        [  46,  105,   39,  ..., 4030,  112,   47],
        [  46, 1078, 3187,  ...,  950,   89,   47],
        ...,
        [  46, 3455,  137,  ...,   50,   30,   47],
        [  46, 1993,   39,  ..., 2378,   15,   47],
        [  46,  362,    1,  ...,  225,   15,   47]])

## 3. Модель GPT-1

In [117]:
from torch import nn
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class Attention(nn.Module):
    def __init__(self, embedding_size, att_size, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.KW = nn.Linear(embedding_size, att_size).to(device)
        self.QW = nn.Linear(embedding_size, att_size).to(device)
        self.VW = nn.Linear(embedding_size, att_size).to(device)
        self.d = float(att_size) ** 0.5
        self.softmax = nn.Softmax(dim=2)
    
    def forward(self, x, ignore):
        K = self.KW(x)
        Q = self.QW(x)
        V = self.VW(x)
        att = torch.einsum('bqd,bkd->bqk', Q, K) / self.d

        ignore_mask = torch.ones(att.shape, device=x.device)
        ignore_mask = torch.einsum(
            'bqk,bk->bqk', 
            ignore_mask, 
            (ignore==0).float()
        )    
        ignore_mask += torch.triu(torch.tril(torch.ones(att.shape, device=x.device)))
        ignore_mask[(ignore_mask==2)] = 1

        att = att.masked_fill((ignore_mask == 0), float('-inf'))
        att = self.softmax(att)
        att = torch.einsum('bqk,bkd->bqd', att, V)
        return att

class MultiHeadAttention(nn.Module):
    def __init__(self, head_count, embedding_size, att_size, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.att_list = nn.ModuleList([Attention(embedding_size, att_size).to(device) for _ in range(head_count)])
        self.union = nn.Linear(att_size * head_count, embedding_size).to(device)

    def forward(self, x, ignore):
        out = torch.cat([head(x, ignore) for head in self.att_list], dim=2)
        return self.union(out)
    
class TransformerBlock(nn.Module):
    def __init__(self, head_count, embedding_size, att_size, dropout,ff_scale = 4, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.norm1 = nn.LayerNorm(embedding_size).to(device)
        self.norm2 = nn.LayerNorm(embedding_size).to(device)
        self.att = MultiHeadAttention(head_count, embedding_size, att_size).to(device)
        self.ff = nn.Sequential(
            nn.Linear(embedding_size, ff_scale * embedding_size).to(device),
            nn.GELU(),
            nn.Linear(ff_scale * embedding_size, embedding_size).to(device)
        ).to(device)
        self.dropout = nn.Dropout(dropout).to(device)
    
    def forward(self, x, ignore):
        x = x + self.dropout(self.att(self.norm1(x), ignore))
        x = x + self.dropout(self.ff(self.norm2(x)))
        return x

class GPT(nn.Module):
    def __init__(self, head_count, tf_block_count, embedding_size, att_size, dropout, max_length=1024, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.embedding_size = embedding_size
        self.pos_encoding = PositionalEncoding(embedding_size, max_length).to(device)
        self.embedding = nn.Embedding(len(bpt.token_to_ind), embedding_size).to(device)
        self.blocks = nn.ModuleList([TransformerBlock(head_count, embedding_size, att_size, dropout) for _ in range(tf_block_count)]).to(device)
        self.fc = nn.Linear(embedding_size, len(bpt.token_to_ind)).to(device)

    def forward(self, x, ignore):
        x = self.embedding(x)
        x = self.pos_encoding(x)
        for block in self.blocks:
            x = block(x, ignore)
        return self.fc(x)
    
class PositionalEncoding(nn.Module):
    def __init__(self, embedding_size, max_length, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.pe = torch.zeros(max_length, embedding_size, device=device)
        for pos in range(max_length):
            for i in range(0, embedding_size, 2):
                pos_tensor = torch.tensor(float(pos), dtype=torch.float32, device=self.pe.device)
                exponent = float(2 * i) / float(embedding_size)
                val = pos_tensor / (10000.0 ** exponent)
                self.pe[pos, i] = torch.sin(val)
                self.pe[pos, i+1] = torch.cos(val)
        self.pe = self.pe.unsqueeze(0).to(device)
    
    def forward(self, x):
        return x * self.pe[:, :x.size(1)]
    
def ignore_mask(x):
    return (x == bpt.token_to_ind['<pad>']).to(device)


In [118]:
from torch.optim import Adam

model = GPT(8, 3, 128, 32, 0.3, max_length=1024)

model.to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))

print(next(model.parameters()).device)


cuda:0


In [121]:
class Trainer:
    def __init__(self, model, optimizer, loss_fn, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.model = model
        self.optimizer = optimizer
        self.loss_fn = loss_fn
    
    def train(self, data_gen, epochs, steps_per_epoch):
        for epoch in range(epochs):
            for _ in range(steps_per_epoch):
                try:
                    data = next(data_gen)
                except StopIteration:
                    break
                data = data.to(device)
                self.optimizer.zero_grad()
                ignore = ignore_mask(data)
                out = self.model(data, ignore)
                loss = self.loss_fn(out[:, :-1].reshape(-1, len(bpt.token_to_ind)), data[:, 1:].reshape(-1))
                loss.backward()
                self.optimizer.step()
            print(f"epoch {epoch}:", loss.item())

steps_per_epoch = 10

trainer = Trainer(model, Adam(model.parameters(), lr=1e-3), nn.CrossEntropyLoss(ignore_index=bpt.token_to_ind['<pad>']))

trainer.train(batch_gen(poemes_list_sorted, 10), 10000, steps_per_epoch)

def generate(model, start, max_length):
    with torch.no_grad():
        ignore = ignore_mask(start)
        out = model(start, ignore)
        for _ in range(max_length):
            next_token = torch.argmax(out[:, -1, :])
            if next_token == bpt.token_to_ind['<end>']:
                break
            start = torch.cat([start, next_token.unsqueeze(1)], dim=1)
            ignore = ignore_mask(start)
            out = model(start, ignore)
        return start

start = next(batch_gen(poemes_list_sorted, 1))
generated = generate(model, start, 100)

print(''.join([bpt.ind_to_token[ind] for ind in generated[0]]))

epoch 0: 4.771115779876709
epoch 1: 4.497039318084717
epoch 2: 4.462979316711426
epoch 3: 4.573539733886719
epoch 4: 4.482378959655762
epoch 5: 4.467183589935303
epoch 6: 4.396716117858887
epoch 7: 4.636803150177002
epoch 8: 4.365139484405518
epoch 9: 4.192533016204834
epoch 10: 4.528005123138428
epoch 11: 4.419792175292969
epoch 12: 4.565180778503418
epoch 13: 4.525610446929932
epoch 14: 4.552760601043701
epoch 15: 4.446782112121582
epoch 16: 4.346632480621338
epoch 17: 4.433466911315918
epoch 18: 4.564179420471191
epoch 19: 4.316680431365967
epoch 20: 4.28478479385376
epoch 21: 4.466027736663818
epoch 22: 4.233577251434326
epoch 23: 4.393610000610352
epoch 24: 4.762841701507568
epoch 25: 4.451239585876465
epoch 26: 4.218472957611084
epoch 27: 4.271797180175781
epoch 28: 4.145555019378662
epoch 29: 4.220048427581787
epoch 30: 4.537203788757324
epoch 31: 4.450421333312988
epoch 32: 4.405178070068359
epoch 33: 4.5241780281066895
epoch 34: 4.34000825881958
epoch 35: 4.516250133514404
epo

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA__index_select)

In [122]:
torch.save(model.state_dict(), 'Gpt-model.pth')

import json

with open('token_to_ind.json', 'w') as json_file:
    json.dump(bpt.token_to_ind, json_file)

with open('ind_to_token.json', 'w') as json_file:
    json.dump(bpt.ind_to_token, json_file)



In [130]:
def generate(model, start, max_length):
    with torch.no_grad():
        ignore = ignore_mask(start)
        out = model(start, ignore)
        for _ in range(max_length):
            next_token = torch.argmax(out[:, -1, :])
            if next_token == bpt.token_to_ind['<end>']:
                break
            start = torch.cat([start, next_token.unsqueeze(1)], dim=1)
            ignore = ignore_mask(start)
            out = model(start, ignore)
        return start

start = next(batch_gen(poemes_list_sorted, 9)).to(device)
generated = generate(model, start, 100)

print(''.join([bpt.ind_to_token[ind.item()] for ind in generated[0]]))

IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)