# Генерация Текста

Что мы хотим от генеративной модели?

Мы сфокусируемся на вероятностной формулировке. Считаем, что на множестве данных $X$ есть некоторое истинное распределение $P^*(X)$. Генеративная модель будет приближать это распределение с помощью максимизации правдоподобия. Далее, из распределения, которое она выучила, мы хотим уметь сэмплировать новые примеры. 

На сегодняшний день самый мейнстрим-подход к текстовой генерации - сэмплирования из языковых моделей (GPT-3, T5 и тд).

Текст мы представляем как последовательность токенов: $x = [x_1, x_2, ..., x_N]$.

$$p([x_1, x_2, x_3]) = p(x_1) \cdot p(x_2 | x_1) \cdot p(x_3|x_1, x_2)$$
$$p(x) = \prod_{i=1}^{N}p(x_i|x_1, ..., x_{i-1})$$
$$\log p(x) = \sum_{i=1}^{N}\log p(x_i|x_1, ..., x_{i-1})$$ 

Мы построим модель с архитектурой, похожей на GPT - несколько слоёв Transformer Decoder-а.

Для данных будем использовать датасет, состоящий из английских стихотворений: Project Gutenberg Poetry Corpus. Для токенов обучим Byte-level BPE из библиотеки tokenizers c достаточно большим размером словаря. 

In [1]:
!curl -O http://static.decontextualize.com/gutenberg-poetry-v001.ndjson.gz

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 52.2M  100 52.2M    0     0  1711k      0  0:00:31  0:00:31 --:--:-- 1180k


In [2]:
import gzip, json

from tqdm import tqdm

lines = []
for line in tqdm(gzip.open("gutenberg-poetry-v001.ndjson.gz")):
    lines.append(json.loads(line.strip()))

3085117it [00:06, 479511.01it/s]


In [3]:
lines[19000]

{'s': 'She all night long her amorous descant sung;', 'gid': '26'}

In [5]:
import torch

import random
import re
from tokenizers import ByteLevelBPETokenizer

def get_data_poems(lines, vocab_size):
  tokenizer = ByteLevelBPETokenizer(dropout=0.1, lowercase=True)

  tokenizer.train_from_iterator([line['s'] + '\n' for line in lines], vocab_size=vocab_size)

  tokenizer.add_special_tokens(["[SOS]", "[EOS]", "[PAD]"])

  SOS_id = tokenizer.token_to_id("[SOS]")
  EOS_id = tokenizer.token_to_id("[EOS]")

  nl_id = tokenizer.encode("\n").ids[0]

  last_poem_id = -1
  chunk = []
  train_chunks = []
  val_chunks = []
  for line in tqdm(lines):
    poem_id = line['gid']

    line_ids = tokenizer.encode(line['s']).ids

    if len(chunk) + len(line_ids) < 64 and poem_id == last_poem_id:
      chunk.extend([nl_id] + line_ids)
    else:
      if chunk:
          if random.random() > 0.01:
              train_chunks.append([SOS_id] + chunk + [EOS_id])
          else:
              val_chunks.append([SOS_id] + chunk + [EOS_id])

      if len(line_ids) < 64:
          chunk = line_ids
      else:
          chunk = []
    
    last_poem_id = poem_id

  return LMDataset(train_chunks), LMDataset(val_chunks), tokenizer

class LMDataset(torch.utils.data.Dataset):
    def __init__(self, chunks):
        super(LMDataset).__init__()
        self.data = chunks

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


def collate_fn_lm(PAD_id, samples):
    batch_size = len(samples)

    max_len = max(len(sample) for sample in samples)

    src_tensor = torch.ones((batch_size, max_len), dtype=torch.long) * PAD_id

    lengths = []
    for (batch_id, s) in enumerate(samples):
        length = len(s)

        src_tensor[batch_id][:length] = torch.tensor(s)

        lengths.append(length)

    return src_tensor, torch.tensor(lengths)


In [6]:
train_dataset, val_dataset, tokenizer = list(get_data_poems(lines, 8192))

SOS_id = tokenizer.token_to_id("[SOS]")
EOS_id = tokenizer.token_to_id("[EOS]")
PAD_id = tokenizer.token_to_id("[PAD]")






100%|██████████| 3085117/3085117 [00:51<00:00, 59871.99it/s]


In [7]:
print(f"{len(train_dataset)} стихов")
print("Пример:\n")

print(tokenizer.decode(train_dataset[18]))

587692 стихов
Пример:

and beyond them stood the forest,
stood the groves of singing pine-trees,
green in summer, white in winter,
ever sighing, ever singing.
"and the pleasant water-courses,
you could trace them through the valley,
by the rushing in the spring-time,


Определим нашу модель. Как и модели семейства GPT, это просто несколько слоёв Transformer Decoder-а.

In [80]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch.nn import TransformerEncoder, TransformerEncoderLayer

class PositionalEncoding(nn.Module):
    def __init__(self, hidden_size, dropout=0.1, max_len=1000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, hidden_size)
        
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, hidden_size, 2).float() * (-math.log(10000.0) / hidden_size))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        pe = pe.unsqueeze(0).transpose(0, 1)
        
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        
        return self.dropout(x)

class Model(nn.Module):
    def __init__(self, vocab_size, hidden_size, n_heads, n_layers, dropout):
        super(Model, self).__init__()

        self.vocab_size = vocab_size
        self.emb = nn.Embedding(vocab_size, hidden_size)

        self.pos_emb = PositionalEncoding(hidden_size)
 
        layer = TransformerEncoderLayer(hidden_size, n_heads, hidden_size, dropout)

        self.layers = TransformerEncoder(layer, n_layers)

        self.out = nn.Linear(hidden_size, vocab_size)

    def forward(self, x):
        x_len = x.size(1)

        padding_mask = (x == PAD_id)

        x = self.pos_emb(self.emb(x) * math.sqrt(self.vocab_size))

        attn_mask = nn.Transformer.generate_square_subsequent_mask(x_len).to(device)

        out = self.layers(x.transpose(0, 1), attn_mask, padding_mask).transpose(0, 1)

        out = self.out(out)

        return out

In [9]:
nn.Transformer.generate_square_subsequent_mask(3)

tensor([[0., -inf, -inf],
        [0., 0., -inf],
        [0., 0., 0.]])

Можно заметить, однако, что в коде выше используется модуль из pytorch, который называется TransfomerEncoder. Существует некоторая путаница, что называть Transformer Decoder-ом. В оргинальной статье про трансформер https://arxiv.org/abs/1706.03762 декодер используется для перевода и имеет два блока внимания, self-attention, и attention, который "смотрит" на выходы энкодера. При этом в GPT и подобных моделят используется один блок self-attention. Отличие от энкодера здесь в авторегрессионной маске аттеншена, которая заставляет модель смотреть только на предыдущие токены. 

*Для заданий ниже можно использовать модель, которая будет больше, и учить её дольше (но не меньше).* 

In [10]:
from torch.utils.data import DataLoader
from functools import partial

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

vocab_size = tokenizer.get_vocab_size()
hidden_size = 768
n_layers = 5
n_heads = 8
dropout = 0.1

batch_size = 128
epochs = 1

model = Model(vocab_size, hidden_size, n_heads, n_layers, dropout).to(device)

train_loader = DataLoader(
    train_dataset
    , batch_size=batch_size
    , shuffle=True
    , collate_fn=partial(collate_fn_lm, PAD_id)
)

val_loader = DataLoader(
    val_dataset
    , batch_size=batch_size
    , shuffle=False
    , collate_fn=partial(collate_fn_lm, PAD_id)
    , drop_last=True
)

criterion = nn.CrossEntropyLoss(reduction='none')
lr = 3e-4
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

In [83]:
from tqdm import tqdm

def train(model, train_loader, val_loader, epochs, val_each=100):
    for epoch in range(1, epochs+1):
      for idx, (batch, _) in enumerate(tqdm(train_loader)):
          batch = batch.to(device)
          src = batch[:, :-1]
          tar = batch[:, 1:]

          optimizer.zero_grad()

          out = model(src)

          loss = criterion(out.transpose(-2, -1), tar)[src != PAD_id].mean()

          loss.backward()
          grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)

          optimizer.step()

          if (idx + 1) % val_each == 0:
            total_loss = 0.0
            n = 0
            for batch, _ in val_loader:
              model.eval()

              batch = batch.to(device)
              src = batch[:, :-1]
              tar = batch[:, 1:]

              out = model(src)

              loss = criterion(out.transpose(-2, -1), tar)[src != PAD_id].mean()

              total_loss += loss.item()
              n += 1

              model.train()
            
            print(f"Val loss: {total_loss/n:.2f}")
  
    return model

In [12]:
trained = True

if trained:
    model = torch.load("decoder.pt").to(device)
else:
    model = train(model, train_loader, val_loader, epochs)
    torch.save(model, "decoder.pt")

model.eval()

print("OK")

OK


# О языковых моделях

### Успехи:

- T5: Exploring the Limits of Transfer Learning with a Unified
Text-to-Text Transformer https://arxiv.org/pdf/1910.10683.pdf

- GPT-3: Language Models are Few-Shot Learners https://arxiv.org/pdf/2005.14165.pdf.

- ChatGPT - она же, дообученная для диалогов.

- Балабоба

###Проблемы:
- The Curious Case of Neural Text De-Generation https://openreview.net/pdf?id=rygGQyrFvH
- A Theoretical Analysis of the Repetition Problem in Text Generation https://arxiv.org/pdf/2012.14660.pdf

###Альтернативы:
- INSNET: An Efficient, Flexible, and Performant
Insertion-based Text Generation Model https://arxiv.org/pdf/2102.11008.pdf


- Structured Denoising Diffusion Models in Discrete
State-Spaces https://arxiv.org/pdf/2107.03006.pdf - неавторегрессионная дискретная диффузия

###Метрики:

- MAUVE: Measuring the Gap
Between Neural Text and Human Text
using Divergence Frontiers https://arxiv.org/pdf/2102.01454.pdf

Вернёмся к нашей сети. Как теперь генерировать новые тексты из неё? Раз сеть выдаёт распределение на токенах на каждом шаге, то можно сэмплировать новый токен в соответствие с этим распределением:

In [14]:
def sample_generate(model, ids, max_len, EOS_id):
    for j in range(len(ids), max_len):
      x = torch.tensor(ids).unsqueeze(0).to(device)

      x_len = x.size(1)

      out = model(x)

      dist = torch.distributions.categorical.Categorical(logits=out[0][x_len-1])

      next_id = dist.sample().item()

      if next_id == EOS_id:
        break

      ids.append(next_id)

    return ids

In [15]:
model.eval()

start_ids = [SOS_id]

sample_ids = sample_generate(model, start_ids, 100, EOS_id)

sent = tokenizer.decode(sample_ids[1:])

print(sent)

that no man can paint, -- his city braved
what is his purpose, on earth?
this life and joy, if i sorrow bear
the burden that hath made no part
to stay, and breathe my joy in peace,
yet still the other, than our despairs
the interchange of fortune, and like


Мы можем увидеть несуществующие слова, проблемы с грамматикой и тд. Всё из-за того, что при сэмплировании нам может попасться токен, имеющий низкую вероятность с точки зрения модели. С этой проблемой можно попробовать справиться простым способом: ограничить область сэмплирования топ-k токенами, имеющими максимальную вероятность.

In [87]:
def top_k_generate(model, ids, max_len, EOS_id, k):
    for j in range(len(ids), max_len):
      x = torch.tensor(ids).unsqueeze(0).to(device)

      x_len = x.size(1)

      out = model(x)

      topv, topi = out[0][-1].topk(k)

      dist = torch.distributions.categorical.Categorical(logits=topv)

      next_id = topi[dist.sample().item()].item()

      if next_id == EOS_id:
        break

      ids.append(next_id)

    return ids

In [17]:
start_ids = [SOS_id]

sample_ids = top_k_generate(model, start_ids, 100, EOS_id, 100)

sent = tokenizer.decode(sample_ids[1:])

print(sent)

but of the people all their mouths
have come:  all they were gone
to them whose face was in the tomb.
some were not the face of all too much--
why should they like a vision fall?
some were the fire to earth on the ground,
like a deep sea so vastly so bright.


In [18]:
start_ids = [SOS_id] + tokenizer.encode("Old McDonald had a farm,\n").ids

sample_ids = top_k_generate(model, start_ids, 100, EOS_id, 1)

sent = tokenizer.decode(sample_ids[1:])

print(sent)

old mcdonald had a farm,
and the old man's wife,
and the old man's wife.
"i have a little child," said he,
"i have a little child,
and i have a little child,
and i have a little child,
and i have a little child,
and i have a little child.


Несуществующие слова пропали, можно увидеть более грамматичные предложения. 

# Задание 1

Реализуйте Nucleus Sampling из статьи The Curious Case of Neural Text *De*-Generation: https://openreview.net/pdf?id=rygGQyrFvH

Протестируйте качество генерации на модели выше, сэмплируя стихотворения с помощью нового метода. Попробуйте разные значения $p$, найдите по вашему мнению оптимальное.

In [86]:
def nucleus_sampling_generate(model, ids, max_len, EOS_id, p):
    for j in range(len(ids), max_len):
      x = torch.tensor(ids).unsqueeze(0).to(device)

      x_len = x.size(1)

      out = model(x)
      
      # Getting probas from logits
      probas = torch.softmax(out[0][-1], dim=0)
      assert probas.shape[0] == vocab_size

      # Get probas sorted indices in descending order
      sorted_ids = probas.argsort(dim=0, descending=True)

      # Get cumulative sum of sorted probas
      probas_cumsum = probas[sorted_ids].cumsum(dim=0)
      assert probas_cumsum.shape[0] == vocab_size

      # Get nucleus using a threshold
      mask = probas_cumsum < p
      mask[0] = True

      # Getting the ids of candidates
      candidates_ids = sorted_ids[mask]

      probas = probas.unsqueeze(0)[:, candidates_ids]
      probas /= probas.sum()
      next_id = sorted_ids[torch.multinomial(probas, 1)].item()

      if next_id == EOS_id:
        break

      ids.append(next_id)

    return ids

In [20]:
start_ids = [SOS_id]

sample_ids = nucleus_sampling_generate(model, start_ids, 100, EOS_id, 0.9)

sent = tokenizer.decode(sample_ids[1:])

print(sent)

through that vast prisoner errors.
with god knows such scenes to claim
such fruits are scarcely known,
but what is duty made;
who sought its seeds the pious crew
to see them, yet with no loss of pain,
that recedons go away.
thus, seeing any fear.


# Задание 2

Для каждого из методов сэмплирования сгенерируйте по 1000 примеров и сравните с 1000 примерами из валидационных данных с помощью метрики MAUVE https://github.com/krishnap25/mauve

Попробуйте разные k в top-k методе и p в nucleus sampling. Также измерьте MAUVE на двух кусках по 1000 примеров из валидации. Сделайте выводы.


In [21]:
!pip -q install evaluate transformers faiss-gpu mauve-text

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m81.4/81.4 KB[0m [31m3.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.3/6.3 MB[0m [31m59.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m85.5/85.5 MB[0m [31m14.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m132.0/132.0 KB[0m [31m18.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m190.3/190.3 KB[0m [31m25.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m213.0/213.0 KB[0m [31m28.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m462.8/462.8 KB[0m [31m47.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m17.0/17.0 MB[0m [31m84.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━

In [22]:
from evaluate import load
import numpy as np

mauve = load('mauve')

Downloading builder script:   0%|          | 0.00/6.63k [00:00<?, ?B/s]

In [23]:
predictions = ["hello world", "goodnight moon"]
references = ["hello world",  "goodnight moon"]
mauve_results = mauve.compute(predictions=predictions, references=references)

Loading tokenizer


Downloading (…)lve/main/config.json:   0%|          | 0.00/666 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Tokenizing text...
Loading tokenizer
Loading model


Downloading (…)"pytorch_model.bin";:   0%|          | 0.00/3.25G [00:00<?, ?B/s]

Featurizing tokens


Featurizing p:   0%|          | 0/2 [00:00<?, ?it/s]

Tokenizing text...
Featurizing tokens


Featurizing q:   0%|          | 0/2 [00:00<?, ?it/s]

seed = 25
performing clustering in lower dimension = 0
kmeans time: 0.07 s
total discretization time: 0.14 seconds


In [35]:
NUM_SAMPLES = 1000

sample_val_input_ids = np.random.choice(val_dataset, size=NUM_SAMPLES, replace=False)
sample_val_texts = tokenizer.decode_batch(sample_val_input_ids)
print(sample_val_texts[0])

a good wife was there of beside bath,
in all the parish wife was there none,
that she was out of alle charity
that on the sunday were upon her head.
her hosen weren of fine scarlet red,
she was a worthy woman all her live,


  sample_val_input_ids = np.random.choice(val_dataset, size=NUM_SAMPLES, replace=False)


In [None]:
top_k_texts = {}
nucleus_texts = {}
ks = [10, 20, 50, 100]
ps = [0.1, 0.5, 0.9, 0.95]

for k in ks:
    current_top_k_texts = []
    for _ in tqdm(range(NUM_SAMPLES)):
        start_ids = [SOS_id]

        sample_ids = top_k_generate(model, start_ids, 100, EOS_id, k)

        sent = tokenizer.decode(sample_ids[1:])

        current_top_k_texts.append(sent)
    top_k_texts[k] = current_top_k_texts

for p in ps:
    current_nucleus_texts = []
    for _ in tqdm(range(NUM_SAMPLES)):
        start_ids = [SOS_id]

        sample_ids = nucleus_sampling_generate(model, start_ids, 100, EOS_id, p)

        sent = tokenizer.decode(sample_ids[1:])

        current_nucleus_texts.append(sent)
    nucleus_texts[p] = current_nucleus_texts

 83%|████████▎ | 828/1000 [05:30<01:23,  2.06it/s]

In [29]:
top_k_mauve = mauve.compute(predictions=top_k_texts, 
                            references=sample_val_texts,
                            device_id=0)
nucleus_mauve = mauve.compute(predictions=nucleus_texts,
                              references=sample_val_texts,
                              device_id=0)

Tokenizing text...
Featurizing tokens


Featurizing p:   0%|          | 0/1000 [00:00<?, ?it/s]

Tokenizing text...
Featurizing tokens


Featurizing q:   0%|          | 0/1000 [00:00<?, ?it/s]

seed = 25
performing clustering in lower dimension = 282
kmeans time: 7.41 s
total discretization time: 9.37 seconds
Tokenizing text...
Featurizing tokens


Featurizing p:   0%|          | 0/1000 [00:00<?, ?it/s]

Tokenizing text...
Featurizing tokens


Featurizing q:   0%|          | 0/1000 [00:00<?, ?it/s]

seed = 25
performing clustering in lower dimension = 284
kmeans time: 8.96 s
total discretization time: 10.26 seconds


TOP-k MAUVE: 0.5508828292646972
Nucleus sampling MAUVE: 0.747546588084287


# Задание 3

Скачайте датасет https://mydata.biz/storage/download/ebcdfe6fb2d546398010e0d6564a79bb/names.zip. Он содержит список имён и фамилий в формате csv. Обработайте данные.

Выберите параметры модели, подходящие для задачи (в том числе параметры токенизации).

Обучите модель, сгенерируйте несколько новых примеров, оцените их качество (глазами).

In [49]:
import pandas as pd
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from tokenizers import ByteLevelBPETokenizer

TOKENIZERS_PARALLELISM = False

In [16]:
!curl https://mydata.biz/storage/download/ebcdfe6fb2d546398010e0d6564a79bb/names.zip --output names.zip
!unzip names.zip -d ./data

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Archive:  names.zip
  inflating: ./data/_readme.txt      
  inflating: ./data/foreign_names.csv  
  inflating: ./data/russian_names.csv  
  inflating: ./data/russian_surnames.csv  


In [51]:
russian_names = pd.read_csv("./data/russian_names.csv", sep=';', usecols=["Name"])
russian_surnames = pd.read_csv("./data/russian_surnames.csv", sep=';', usecols=["Surname"])
foreign_names = pd.read_csv("./data/foreign_names.csv", sep=';', usecols=["name"])

russian_names.columns = ["name"]
russian_surnames.columns = ["name"]
foreign_names.columns = ["name"]

russian_names = russian_names["name"]
russian_surnames = russian_surnames["name"]
foreign_names = foreign_names["name"]

In [52]:
display(russian_names.sample(5))
display(russian_surnames.sample(5))
display(foreign_names.sample(5))

39199     Далигост
48045      Табарак
33187        Рашид
16314    Калипатра
20277      Сабинка
Name: name, dtype: object

236400    Семишкурный
70          Абакарова
26316       Беришейко
128513     Корниленко
285417     Хрюковский
Name: name, dtype: object

15365    Lambert
13912      Kacee
3419       Batya
19631    Orianna
18064       Misk
Name: name, dtype: object

In [53]:
VAL_SIZE = 0.2

names_data = pd.concat([russian_names, russian_surnames, foreign_names], ignore_index=True)

train_names_data, val_names_data = train_test_split(names_data, test_size=VAL_SIZE)

In [71]:
class NamesDataset(Dataset):
    def __init__(self, data, tokenizer):
        self.data = data
        self.tokenizer = tokenizer

    
    def __len__(self):
        return len(self.data)

    
    def __getitem__(self, idx):
        name = self.data.iloc[idx]

        tokenized = self.tokenizer.encode("[SOS]" + name + "[EOS]")

        return tokenized.ids

In [72]:
vocab_size = 300
tokenizer = ByteLevelBPETokenizer(dropout=0.1, lowercase=False)

tokenizer.train_from_iterator(names_data, vocab_size=vocab_size)
tokenizer.add_special_tokens(["[SOS]", "[EOS]", "[PAD]"])

SOS_id = tokenizer.token_to_id("[SOS]")
EOS_id = tokenizer.token_to_id("[EOS]")
PAD_id = tokenizer.token_to_id("[PAD]")






In [78]:
train_dataset = NamesDataset(train_names_data, tokenizer)
val_dataset = NamesDataset(val_names_data, tokenizer)

train_dataset[4], val_dataset[10]

([300, 140, 249, 267, 294, 267, 265, 261, 279, 278, 301],
 [300, 140, 244, 267, 274, 256, 291, 264, 301])

In [81]:
from torch.utils.data import DataLoader
from functools import partial

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

vocab_size = tokenizer.get_vocab_size()
hidden_size = 128
n_layers = 5
n_heads = 8
dropout = 0.1

batch_size = 128
epochs = 1

model = Model(vocab_size, hidden_size, n_heads, n_layers, dropout).to(device)

train_loader = DataLoader(
    train_dataset
    , batch_size=batch_size
    , shuffle=True
    , collate_fn=partial(collate_fn_lm, PAD_id)
)

val_loader = DataLoader(
    val_dataset
    , batch_size=batch_size
    , shuffle=False
    , collate_fn=partial(collate_fn_lm, PAD_id)
    , drop_last=True
)

criterion = nn.CrossEntropyLoss(reduction='none')
lr = 3e-4
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

In [84]:
trained = False

if trained:
    model = torch.load("names_decoder.pt").to(device)
else:
    model = train(model, train_loader, val_loader, epochs)
    torch.save(model, "names_decoder.pt")

model.eval()

print("OK")

  5%|▍         | 115/2475 [00:05<03:02, 12.92it/s]

Val loss: 2.63


  8%|▊         | 210/2475 [00:13<06:24,  5.89it/s]

Val loss: 2.30


 13%|█▎        | 314/2475 [00:17<02:50, 12.71it/s]

Val loss: 2.22


 17%|█▋        | 411/2475 [00:21<02:48, 12.24it/s]

Val loss: 2.17


 21%|██        | 510/2475 [00:25<02:43, 12.04it/s]

Val loss: 2.15


 25%|██▍       | 610/2475 [00:29<02:28, 12.54it/s]

Val loss: 2.12


 29%|██▉       | 715/2475 [00:35<03:30,  8.35it/s]

Val loss: 2.10


 33%|███▎      | 817/2475 [00:40<02:14, 12.36it/s]

Val loss: 2.09


 37%|███▋      | 908/2475 [00:44<02:45,  9.48it/s]

Val loss: 2.08


 41%|████      | 1008/2475 [00:49<02:06, 11.60it/s]

Val loss: 2.07


 45%|████▍     | 1110/2475 [00:53<02:02, 11.14it/s]

Val loss: 2.06


 49%|████▉     | 1209/2475 [00:57<01:44, 12.17it/s]

Val loss: 2.06


 53%|█████▎    | 1311/2475 [01:02<01:44, 11.18it/s]

Val loss: 2.05


 57%|█████▋    | 1416/2475 [01:08<02:14,  7.85it/s]

Val loss: 2.05


 61%|██████    | 1512/2475 [01:13<01:55,  8.32it/s]

Val loss: 2.04


 65%|██████▌   | 1610/2475 [01:17<01:15, 11.51it/s]

Val loss: 2.04


 69%|██████▉   | 1713/2475 [01:21<01:02, 12.14it/s]

Val loss: 2.03


 73%|███████▎  | 1814/2475 [01:26<01:01, 10.69it/s]

Val loss: 2.03


 77%|███████▋  | 1912/2475 [01:31<01:05,  8.58it/s]

Val loss: 2.02


 81%|████████▏ | 2012/2475 [01:35<00:41, 11.23it/s]

Val loss: 2.02


 85%|████████▌ | 2113/2475 [01:43<01:04,  5.59it/s]

Val loss: 2.02


 89%|████████▉ | 2209/2475 [01:47<00:24, 10.66it/s]

Val loss: 2.02


 93%|█████████▎| 2309/2475 [01:51<00:15, 10.88it/s]

Val loss: 2.01


 97%|█████████▋| 2411/2475 [01:56<00:06, 10.66it/s]

Val loss: 2.01


100%|██████████| 2475/2475 [01:57<00:00, 21.08it/s]

OK





In [161]:
start_ids = tokenizer.encode("[SOS]Co").ids

sample_ids = nucleus_sampling_generate(model, start_ids, 30, EOS_id, 0.9)

sent = tokenizer.decode(sample_ids[1:])

print(sent)

Cotuittia


In [147]:
(names_data == "Лагушина").any()

True