# Подключим необходимые библиотеки

In [1]:
import pandas as pd
import sqlite3
import matplotlib.pyplot as plt
import seaborn as sns
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import numpy as np

from datasets import load_dataset
from nltk.tokenize import sent_tokenize
from sklearn.model_selection import train_test_split
import nltk

from collections import Counter
from typing import List
from tqdm import tqdm

import seaborn
seaborn.set(palette='summer')

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

# Загрузка датасета

In [3]:
conn = sqlite3.connect('../input/wikibooks.sqlite')

df = pd.read_sql_query("SELECT * FROM ru LIMIT 3300", conn)

In [4]:
sentences = []

for sentence in tqdm(df['body_text']):
    sentences.extend(
        [x.lower() for x in sent_tokenize(sentence, language='russian') if len(x) < 256]
        )
    
print("Количество предложений", len(sentences))

100%|██████████| 3300/3300 [00:10<00:00, 322.19it/s]

Количество предложений 120873





# Train Loop

In [5]:
words = Counter()

for sentence in tqdm(sentences):
    for word in nltk.word_tokenize(sentence):
            words[word] += 1
            
vocab = set(['<unk>', '<bos>', '<eos>', '<pad>'])
vocab_size = 20000

for elem in words.most_common(vocab_size):
    vocab.add(elem[0])
    
print("Всего слов в словаре:", len(vocab))

100%|██████████| 120873/120873 [00:29<00:00, 4126.04it/s]


Всего слов в словаре: 20004


In [6]:
word2ind = {char: i for i, char in enumerate(vocab)}
ind2word = {i: char for char, i in word2ind.items()}

In [7]:
def fit_epoch(model, train_loader, criterion, optimizer, sheduler = None):
    model.train()
    running_loss = 0.0
    running_corrects = 0
    processed_data = 0
    losses = []
    perplexity = []
    for batch in train_loader:
        optimizer.zero_grad()

        logits = model(batch['input_ids'])
        loss = criterion(
            logits, batch['target_ids'].flatten())
        loss.backward()
        optimizer.step()
        
        perplexity.append(torch.exp(loss).item())
        losses.append(loss.item())

    perplexity = sum(perplexity) / len(perplexity)
    losses = sum(losses) / len(losses)    
    return perplexity, losses



def eval_epoch(model, val_loader, criterion):
    model.eval()
    perplexity = []
    losses = []
    with torch.no_grad():
        for batch in val_loader:
            logits = model(batch['input_ids'])
            loss = criterion(
                logits,
                batch['target_ids'].flatten()
                )
            perplexity.append(torch.exp(loss).item())
            losses.append(loss.item())

    perplexity = sum(perplexity) / len(perplexity)
    losses = sum(losses) / len(losses)
    return perplexity, losses



def train(train_dataloader, eval_dataloader, model, epochs, ignore_index = word2ind['<pad>'] ,
          optimizer=None, criterion=None, sheduler=None):

    if optimizer is None:
      optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

    if criterion is None:
      criterion = nn.CrossEntropyLoss(ignore_index=ignore_index).to(device)
    
    min_lr = 1e-4
    initial_lr = 3e-4
    lambda_func = lambda epoch: max(0.99 ** epoch, min_lr / initial_lr)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_func)

    best_model_wts = model.state_dict()
    best_perplexity = 10e10

    history = []
    log_template = "\nEpoch {ep:03d} train_loss: {t_loss:0.4f} \
    val_loss {v_loss:0.4f} train_perplexirty {t_acc:0.4f} val_perplexirty {v_acc:0.4f}"

    with tqdm(desc="epoch", total=epochs) as pbar_outer:

        for epoch in range(epochs):
            train_perplexirty, train_loss = fit_epoch(model, train_dataloader, criterion, optimizer)
            scheduler.step()

            val_perplexirty, val_loss = eval_epoch(model, eval_dataloader, criterion)
            history.append((train_loss, train_perplexirty, val_loss, val_perplexirty))
            if val_perplexirty < best_perplexity:
                best_perplexity = val_perplexirty
                best_model_wts = model.state_dict()

            pbar_outer.update(1)
            tqdm.write(log_template.format(ep=epoch+1, t_loss=train_loss,\
                                           v_loss=val_loss, t_acc=train_perplexirty, v_acc=val_perplexirty))

    print('Best val perplexirty: {:4f}'.format(best_perplexity))
    model.load_state_dict(best_model_wts)

    return model, history

# Функции необходимые при обучении/загрузке датасета/генерации текста

In [72]:
class WordDataset(torch.utils.data.Dataset):
    def __init__(self, sentences, word2ind):
        super().__init__()
        self.data = sentences
        self.word2ind = word2ind
        self.unk_id = self.word2ind['<unk>']
        self.bos_id = self.word2ind['<bos>']
        self.eos_id = self.word2ind['<eos>']
        self.pad_id = self.word2ind['<pad>']

    def __getitem__(self, idx: int) -> List[int]:
        tokenized_sentence = [self.bos_id]
        tokenized_sentence += self.data[idx]
        tokenized_sentence += [self.eos_id]
        
        return tokenized_sentence

    def __len__(self) -> int:
        return len(self.data)
    
    
    
def collate_fn_with_padding(
    input_batch: List[List[int]], pad_id=word2ind['<pad>'], max_seq_len: int=96) -> torch.Tensor:

    new_batch = []
    for sequence in input_batch:
        if len(sequence) > max_seq_len:
            sequence = sequence[:max_seq_len - 1] + [sequence[-1]]
        else:
            for _ in range(max_seq_len - len(sequence)):
                sequence.append(pad_id)
        new_batch.append(sequence)

    sequences = torch.LongTensor(new_batch).to(device)

    new_batch = {
        'input_ids': sequences[:,:-1],
        'target_ids': sequences[:,1:]
    }

    return new_batch

def generate_sequence(model, dict_2ind ,ind2dict, starting_seq: int, max_seq_len: int = 256) -> str:
    device = 'cpu'
    model = model.to(device)
    
    idx = torch.zeros((1,1), dtype=torch.long).to(device)
    idx[0, 0] = starting_seq
    
    block_size = 256

    model.eval()
    t = idx.shape[1]
    with torch.no_grad():
        for i in range(max_seq_len):
            idx_cond = idx[:, -block_size:]
            logits = model.forward(idx_cond)
            logits = logits.reshape(1, t, -1)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
            if t < block_size:
                t += 1
                
            if idx_next.item() == dict_2ind['<eos>']:
                break

    words = ' '.join([ind2dict[i.item()] for i in idx[0]])

    return words

# Main Model

In [9]:
class TransformerBlock(nn.Module):
    def __init__(
            self, 
            num_heads: int, 
            n_embed: int, 
            block_size: int
        ):
        super(TransformerBlock, self).__init__()
        hidden_dim = n_embed // num_heads
        self.mhsa = MultiHeadSelfAttention(num_heads, hidden_dim, n_embed, block_size)
        self.feed_forward = FeedForward(n_embed)
        self.norm1 = nn.LayerNorm(n_embed)
        self.norm2 = nn.LayerNorm(n_embed)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.mhsa(self.norm1(x))
        x = x + self.feed_forward(self.norm2(x))
        return x


class FeedForward(nn.Module):
    def __init__(
            self, 
            n_embed: int, 
            extend_width: int=4, 
            dropout: float=0.2
        ):
        super(FeedForward, self).__init__()
        self.layer = nn.Sequential(
            nn.Linear(n_embed, extend_width*n_embed), 
            nn.ReLU(),
            nn.Linear(extend_width*n_embed, n_embed), 
            nn.Dropout(dropout)
        )
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layer(x)


class MultiHeadSelfAttention(nn.Module):
    def __init__(
            self, 
            num_heads: int, 
            hidden_dim: int, 
            n_embed: int, 
            block_size: int, 
            dropout: float=0.2
        ):
        super(MultiHeadSelfAttention, self).__init__()
        self.num_heads = num_heads
        self.heads = nn.ModuleList([SingleHead(hidden_dim, n_embed, block_size) for _ in range(self.num_heads)])
        self.project = nn.Linear(n_embed, n_embed)
        self.drop = nn.Dropout(dropout)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = torch.cat([sh(x) for sh in self.heads], dim=-1)
        out = self.project(out)
        out = self.drop(out)
        return out


class SingleHead(nn.Module):
    def __init__(
            self, 
            hidden_dim: int, 
            n_embed: int, 
            block_size: int, 
            dropout: float=0.2
        ):
        super(SingleHead, self).__init__()
        self.key = nn.Linear(n_embed, hidden_dim, bias=False)
        self.query = nn.Linear(n_embed, hidden_dim, bias=False)
        self.value = nn.Linear(n_embed, hidden_dim, bias=False)
        self.drop = nn.Dropout(dropout)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape
        k = self.key(x)
        q = self.query(x)
        weights = q @ k.transpose(-2, -1) * C**(-0.5)
        masked_weights = weights.masked_fill(self.tril[:T, :T] == 0, float("-inf"))
        masked_probs = F.softmax(masked_weights, dim=-1)
        masked_probs = self.drop(masked_probs)
        v = self.value(x)
        out = masked_probs @ v
        return out


class GPT(nn.Module):
    def __init__(
            self, 
            vocab_size: int, 
            block_size: int, 
            n_embed: int, 
            num_heads: int, 
            n_layers: int
        ):
        super(GPT, self).__init__()
        self.vocab_size = vocab_size
        self.block_size = block_size
        self.embedding = nn.Embedding(vocab_size, n_embed)
        self.positional_embedding_table = nn.Embedding(block_size, n_embed)
        self.blocks = nn.Sequential(
            *[TransformerBlock(num_heads, n_embed, block_size) for _ in range(n_layers)],
        )
        self.norm = nn.LayerNorm(n_embed)        
        self.fc = nn.Linear(n_embed, vocab_size)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T = x.shape
        token_embeddings = self.embedding(x) # B, T -> B, T, N_EMB
        positional_embedding = self.positional_embedding_table(torch.arange(T, device=x.device)) # T -> T, C
        token_embeddings = token_embeddings + positional_embedding # B, T, C + T, C -> B, T, C
        blocks_out = self.blocks(token_embeddings)
        blocks_out = self.norm(blocks_out)
        logits = self.fc(blocks_out) # B, T, N_EMB -> B, T, C
        logits = logits.reshape(B*T, self.vocab_size)
        return logits

# Train

In [10]:
train_sentences, eval_sentences = train_test_split(sentences, test_size=0.2)

In [11]:
def sentence_pre(s):
    return [word2ind.get(w, word2ind['<unk>']) for w in nltk.word_tokenize(s)]

In [12]:
train_sentences = list(map(sentence_pre, train_sentences))
eval_sentences = list(map(sentence_pre, eval_sentences))

In [13]:
train_dataset = WordDataset(train_sentences, word2ind)
eval_dataset = WordDataset(eval_sentences, word2ind)
train_dataloader = DataLoader(
    train_dataset, collate_fn=collate_fn_with_padding, batch_size=128, shuffle=True, num_workers=0)

eval_dataloader = DataLoader(
    eval_dataset, collate_fn=collate_fn_with_padding, batch_size=128, num_workers=0)

In [14]:
vocab_size = len(vocab)
block_size = 256
n_embed = 384
num_heads = 6
n_layers = 6

In [15]:
model = GPT(vocab_size=vocab_size, block_size=block_size, n_embed=n_embed, num_heads=num_heads, n_layers=n_layers).to(device)

num_params = sum(p.numel() for p in model.parameters())
print(model)
print(f"Number of model parameters: {num_params:,}")

GPT(
  (embedding): Embedding(20004, 384)
  (positional_embedding_table): Embedding(256, 384)
  (blocks): Sequential(
    (0): TransformerBlock(
      (mhsa): MultiHeadSelfAttention(
        (heads): ModuleList(
          (0-5): 6 x SingleHead(
            (key): Linear(in_features=384, out_features=64, bias=False)
            (query): Linear(in_features=384, out_features=64, bias=False)
            (value): Linear(in_features=384, out_features=64, bias=False)
            (drop): Dropout(p=0.2, inplace=False)
          )
        )
        (project): Linear(in_features=384, out_features=384, bias=True)
        (drop): Dropout(p=0.2, inplace=False)
      )
      (feed_forward): FeedForward(
        (layer): Sequential(
          (0): Linear(in_features=384, out_features=1536, bias=True)
          (1): ReLU()
          (2): Linear(in_features=1536, out_features=384, bias=True)
          (3): Dropout(p=0.2, inplace=False)
        )
      )
      (norm1): LayerNorm((384,), eps=1e-05, elemen

In [16]:
best_model, losses = train(train_dataloader, eval_dataloader, model, 30, ignore_index = word2ind["<pad>"])

epoch:   3%|▎         | 1/30 [04:05<1:58:27, 245.07s/it]


Epoch 001 train_loss: 5.7074     val_loss 5.1998 train_perplexirty 389.2737 val_perplexirty 182.2926


epoch:   7%|▋         | 2/30 [08:09<1:54:08, 244.59s/it]


Epoch 002 train_loss: 4.9312     val_loss 4.8145 train_perplexirty 140.1134 val_perplexirty 124.0320


epoch:  10%|█         | 3/30 [12:13<1:49:59, 244.42s/it]


Epoch 003 train_loss: 4.5089     val_loss 4.6039 train_perplexirty 91.4058 val_perplexirty 100.4387


epoch:  13%|█▎        | 4/30 [16:17<1:45:52, 244.33s/it]


Epoch 004 train_loss: 4.1919     val_loss 4.4881 train_perplexirty 66.5211 val_perplexirty 89.4769


epoch:  17%|█▋        | 5/30 [20:21<1:41:46, 244.25s/it]


Epoch 005 train_loss: 3.9345     val_loss 4.4217 train_perplexirty 51.4186 val_perplexirty 83.7435


epoch:  20%|██        | 6/30 [24:25<1:37:40, 244.20s/it]


Epoch 006 train_loss: 3.7136     val_loss 4.4016 train_perplexirty 41.2362 val_perplexirty 82.0832


epoch:  23%|██▎       | 7/30 [28:30<1:33:35, 244.16s/it]


Epoch 007 train_loss: 3.5176     val_loss 4.4011 train_perplexirty 33.9049 val_perplexirty 82.0706


epoch:  27%|██▋       | 8/30 [32:34<1:29:30, 244.13s/it]


Epoch 008 train_loss: 3.3398     val_loss 4.4256 train_perplexirty 28.3729 val_perplexirty 84.1377


epoch:  30%|███       | 9/30 [36:38<1:25:26, 244.10s/it]


Epoch 009 train_loss: 3.1802     val_loss 4.4597 train_perplexirty 24.2117 val_perplexirty 87.0989


epoch:  33%|███▎      | 10/30 [40:42<1:21:22, 244.11s/it]


Epoch 010 train_loss: 3.0330     val_loss 4.5091 train_perplexirty 20.8765 val_perplexirty 91.5540


epoch:  37%|███▋      | 11/30 [44:46<1:17:17, 244.09s/it]


Epoch 011 train_loss: 2.9007     val_loss 4.5579 train_perplexirty 18.3062 val_perplexirty 96.1564


epoch:  40%|████      | 12/30 [48:50<1:13:13, 244.10s/it]


Epoch 012 train_loss: 2.7808     val_loss 4.6251 train_perplexirty 16.2398 val_perplexirty 102.9171


epoch:  43%|████▎     | 13/30 [52:54<1:09:09, 244.09s/it]


Epoch 013 train_loss: 2.6690     val_loss 4.6861 train_perplexirty 14.5097 val_perplexirty 109.4391


epoch:  47%|████▋     | 14/30 [56:58<1:05:05, 244.07s/it]


Epoch 014 train_loss: 2.5688     val_loss 4.7498 train_perplexirty 13.1302 val_perplexirty 116.7037


epoch:  50%|█████     | 15/30 [1:01:02<1:01:01, 244.09s/it]


Epoch 015 train_loss: 2.4767     val_loss 4.8227 train_perplexirty 11.9716 val_perplexirty 125.6011


epoch:  53%|█████▎    | 16/30 [1:05:06<56:56, 244.07s/it]  


Epoch 016 train_loss: 2.3929     val_loss 4.8815 train_perplexirty 11.0103 val_perplexirty 133.2416


epoch:  57%|█████▋    | 17/30 [1:09:10<52:53, 244.09s/it]


Epoch 017 train_loss: 2.3170     val_loss 4.9483 train_perplexirty 10.2079 val_perplexirty 142.5939


epoch:  60%|██████    | 18/30 [1:13:14<48:48, 244.06s/it]


Epoch 018 train_loss: 2.2474     val_loss 5.0167 train_perplexirty 9.5149 val_perplexirty 152.7303


epoch:  63%|██████▎   | 19/30 [1:17:18<44:44, 244.04s/it]


Epoch 019 train_loss: 2.1853     val_loss 5.0792 train_perplexirty 8.9442 val_perplexirty 162.6668


epoch:  67%|██████▋   | 20/30 [1:21:22<40:40, 244.06s/it]


Epoch 020 train_loss: 2.1258     val_loss 5.1391 train_perplexirty 8.4245 val_perplexirty 172.7972


epoch:  70%|███████   | 21/30 [1:25:26<36:36, 244.06s/it]


Epoch 021 train_loss: 2.0700     val_loss 5.2038 train_perplexirty 7.9666 val_perplexirty 184.4470


epoch:  73%|███████▎  | 22/30 [1:29:30<32:32, 244.04s/it]


Epoch 022 train_loss: 2.0205     val_loss 5.2703 train_perplexirty 7.5809 val_perplexirty 197.2539


epoch:  77%|███████▋  | 23/30 [1:33:35<28:28, 244.07s/it]


Epoch 023 train_loss: 1.9754     val_loss 5.3235 train_perplexirty 7.2453 val_perplexirty 208.2208


epoch:  80%|████████  | 24/30 [1:37:39<24:24, 244.08s/it]


Epoch 024 train_loss: 1.9316     val_loss 5.3808 train_perplexirty 6.9336 val_perplexirty 220.5924


epoch:  83%|████████▎ | 25/30 [1:41:43<20:20, 244.15s/it]


Epoch 025 train_loss: 1.8909     val_loss 5.4325 train_perplexirty 6.6551 val_perplexirty 232.4053


epoch:  87%|████████▋ | 26/30 [1:45:47<16:16, 244.18s/it]


Epoch 026 train_loss: 1.8539     val_loss 5.4875 train_perplexirty 6.4134 val_perplexirty 245.7721


epoch:  90%|█████████ | 27/30 [1:49:51<12:12, 244.14s/it]


Epoch 027 train_loss: 1.8194     val_loss 5.5330 train_perplexirty 6.1959 val_perplexirty 257.2050


epoch:  93%|█████████▎| 28/30 [1:53:55<08:08, 244.16s/it]


Epoch 028 train_loss: 1.7853     val_loss 5.5884 train_perplexirty 5.9880 val_perplexirty 272.1051


epoch:  97%|█████████▋| 29/30 [1:58:00<04:04, 244.13s/it]


Epoch 029 train_loss: 1.7539     val_loss 5.6389 train_perplexirty 5.8000 val_perplexirty 286.3192


epoch: 100%|██████████| 30/30 [2:02:04<00:00, 244.13s/it]


Epoch 030 train_loss: 1.7252     val_loss 5.6779 train_perplexirty 5.6343 val_perplexirty 297.6578
Best val perplexirty: 82.070595





In [73]:
torch.save(best_model.state_dict(), "best_model.pt")

In [81]:
generate_sequence(best_model, word2ind, ind2word,starting_seq=word2ind['облако'])

'облако 1 ) наличие <unk> в нарушений восприятия ; 2 ) <unk> <unk> в объеме ; 3 ) наличие хронического <unk> крови ; 4 ) увеличение ожидаемой средней продолжительности жизни ; 29 . <eos>'

## Все выводы по работе с GPT представлены в отчете