In [21]:
def imports():
    global math, np, pd, random, json, torch, Dataset, DataLoader, tqdm, plt, yttm
    
    import math
    import numpy as np
    import pandas as pd

    import random
    import json
    import torch
    from torch.utils.data import Dataset, DataLoader

    from tqdm import tqdm

    from matplotlib import pyplot as plt

    import youtokentome as yttm

In [22]:
imports()

In [7]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [8]:
DEVICE

device(type='cuda')

### Подготовка данных

In [9]:
qa_data = list()

with open('qa_data.jsonl') as file_object:
    for line in file_object:
        qa_data.append(json.loads(line.strip()))

In [10]:
from collections import deque

questin_answer_data = []
for question_answer in qa_data:
    questin_answer_data.append(question_answer['question'])
    deque(map(questin_answer_data.append, question_answer['responses']))

In [11]:
questin_answer_data[:5]

['долго ли идут деньги с яндексденег на карту visa?',
 'нет. прорыв 35 ;)',
 'можно ли зарегистрировать авто в другом регионе',
 'можно на родственника из того региона.. .  а потом ездить по доверке',
 'что делать если у меня очень тонкие ногти а хочется их отрастить?']

In [12]:
with open('for_bpe.txt', 'w') as f:
    f.write('\n'.join(questin_answer_data))

In [13]:
del questin_answer_data

In [14]:
!head for_bpe.txt

долго ли идут деньги с яндексденег на карту visa?
нет. прорыв 35 ;)
можно ли зарегистрировать авто в другом регионе
можно на родственника из того региона.. .  а потом ездить по доверке
что делать если у меня очень тонкие ногти а хочется их отрастить?
витамины и умная эмаль (каждый день)
ванночки с морской солью. с вечера мажь ногти сверху йодом. не бойся, до утра все впитается.
умная эмаль, витамины, йод, и поменьше крась лаком 
лаки фирмы trind производство usa + кальций
в чем отличие медитации от йоги?


In [15]:
VOCAB_SIZE = 30_000
MODEL_PATH = 'pretrained_bpe_lm.model'

In [16]:
yttm.BPE.train(data='for_bpe.txt', vocab_size=VOCAB_SIZE, model=MODEL_PATH)

<youtokentome.youtokentome.BPE at 0x7f6155470f98>

In [17]:
tokenizer = yttm.BPE(model=MODEL_PATH)

In [18]:
questions = []
answers = []

for qa in qa_data:
    for answer in qa['responses']:
        questions.append(qa['question'])
        answers.append(answer)

In [19]:
del qa_data

In [23]:
batch_size = 256
tokenized_questions = []

for i_batch in tqdm(range(math.ceil(len(questions) / batch_size))):
    tokenized_questions.extend(
        tokenizer.encode(
            list(questions[i_batch*batch_size:(i_batch+1)*batch_size]),
            bos=True, eos=False,
        )
    )

100%|██████████| 30341/30341 [01:07<00:00, 448.48it/s]


In [24]:
# как сложно без gc
del questions

In [25]:
tokenized_answers = []

for i_batch in tqdm(range(math.ceil(len(answers) / batch_size))):
    tokenized_answers.extend(
        tokenizer.encode(
            list(answers[i_batch*batch_size:(i_batch+1)*batch_size]),
            bos=True, eos=False,
        )
    )

100%|██████████| 30341/30341 [01:03<00:00, 480.40it/s]


In [26]:
del answers

In [29]:
# у меня не хватает памяти, лучше сохраниться
import pickle

with open('questions', 'wb') as f:
    pickle.dump(tokenized_questions, f)

with open('answers', 'wb') as f:
    pickle.dump(tokenized_answers, f)

In [30]:
del tokenized_questions
del tokenized_answers

### Датасет

In [31]:
import pickle

# будем брать одну десятую датасета, иначе памяти не хватит
with open('questions', 'rb') as f:
    questions = pickle.load(f)
questions = questions[:int(len(questions)/10)]

In [32]:
with open('answers', 'rb') as f:
    answers = pickle.load(f)
answers = answers[:int(len(answers)/10)]

In [33]:
print(len(questions), len(answers))
assert len(questions) == len(answers)

776713 776713


In [34]:
!free -m

              total        used        free      shared  buff/cache   available
Mem:          14961        3096       10458         243        1406       11336
Swap:         16134        1426       14708


In [10]:
imports()

In [54]:
PAD_INDEX = 0
EOS_INDEX = 3

In [100]:
class SequenceBucketingData(torch.utils.data.Dataset):
    """по сути то же, что в условии, только другие сиквенсы"""
    def __init__(self, questions, answers, max_len, pad_index=PAD_INDEX, eos_index=EOS_INDEX):
        self.questions = questions
        self.answers = answers
        if len(questions) != len(answers):
            raise ValueError('Вопросы и ответы должны быть одной длины')
        self.max_len = max_len
        self.pad_index = pad_index
        self.eos_index = eos_index
        
    def __len__(self):
        return len(self.questions)
    
    def _prepare_sample(self, sequence_q, sequence_a, max_len):
        sequence_q = sequence_q[:max_len]
        sequence_a = sequence_a[:max_len]
        x = sequence_q
        y = sequence_a
        pads = [self.pad_index] * (max_len - len(x))
        x += pads
        y += pads
        return x, y
    
    def __getitem__(self, index):
        batch_q = self.questions[index]
        batch_a = self.answers[index]
        max_len = min([
            self.max_len,
            max(map(len, batch_q)),
            max(map(len, batch_a)),
        ])
        batch_x = []
        batch_y = []
        for sample_q, sample_a in zip(batch_q, batch_a):
            x, y = self._prepare_sample(sample_q, sample_a, max_len)
            batch_x.append(x)
            batch_y.append(y)
        print(len(batch_x), len(batch_x[0]))
        print(len(batch_y), len(batch_y[0]))
        batch_x = torch.tensor(batch_x).long().to(DEVICE)
        batch_y = torch.tensor(batch_y).long().to(DEVICE)
        return batch_x, batch_y

In [101]:
questions = sorted(questions, key=len)
answers = sorted(answers, key=len)


# сделаем батч побольше
BATCH_SIZE = 256
MAX_LEN = 32

batches_q = []
batches_a = []

for i_batch in range(math.ceil(len(questions) / BATCH_SIZE)):
    batches_q.append(questions[i_batch*batch_size:(i_batch+1)*BATCH_SIZE])
    batches_a.append(answers[i_batch*batch_size:(i_batch+1)*BATCH_SIZE])

In [102]:
validation_start_index = int(len(batches_q) * 0.05)

In [103]:
train_seq = SequenceBucketingData(
    questions=batches_q[:-validation_start_index],
    answers=batches_a[:-validation_start_index],
    max_len=MAX_LEN)
test_seq = SequenceBucketingData(
    questions=batches_q[-validation_start_index:],
    answers=batches_a[-validation_start_index:],
    max_len=MAX_LEN)

In [104]:
train_loader = torch.utils.data.DataLoader(train_seq, batch_size=BATCH_SIZE)
validation_loader = torch.utils.data.DataLoader(test_seq, batch_size=BATCH_SIZE)

### Модель

In [105]:
# это специальный дропаут для реккуретных сетей
# хорошо это объясняется здесь: https://youtu.be/WLaAIYQHHMU?t=1093

class SpatialDropout(torch.nn.Dropout2d):
    
    def __init__(self, p=0.5):
        super().__init__()
        self.p = p
    
    def forward(self, x):
        x = x.unsqueeze(2)    # (N, T, 1, K)
        x = x.permute(0, 3, 2, 1)  # (N, K, 1, T)
        x = super(SpatialDropout, self).forward(x)  # (N, K, 1, T)
        x = x.permute(0, 3, 2, 1)  # (N, T, 1, K)
        x = x.squeeze(2)  # (N, T, K)
        return x

In [106]:
class LanguageModel(torch.nn.Module):
    
    def __init__(self, padding_idx,
                 vocab_size=30_000,
                 embedding_dim=128,
                 model_dim=128,
                 num_layers=2,
                 dropout=0.35,
                 weight_tying=True):
        
        super().__init__()
        
        self.embedding_layer = torch.nn.Embedding(num_embeddings=vocab_size,
                                                  embedding_dim=embedding_dim, padding_idx=padding_idx)
        
        self.embedding_dropout = SpatialDropout(p=dropout)
        
        self.lstm = torch.nn.LSTM(input_size=embedding_dim, hidden_size=model_dim, 
                                  num_layers=num_layers, dropout=dropout, batch_first=True)
        
        self.language_model_head = torch.nn.Linear(in_features=model_dim, out_features=vocab_size, bias=False)
        
        # как раз здесь задаем, чтобы веса входящего и выходящего слоя эмбеддингов шарились
        if weight_tying and embedding_dim == model_dim:
            self.language_model_head.weight = self.embedding_layer.weight
        
    def forward(self, x):
        
        x = self.embedding_layer(x)
        
        x = self.embedding_dropout(x)
        
        x, _ = self.lstm(x)
        
        x = self.language_model_head(x)
        
        return x

In [107]:
model = LanguageModel(padding_idx=PAD_INDEX)
model.to(DEVICE)

LanguageModel(
  (embedding_layer): Embedding(30000, 128, padding_idx=0)
  (embedding_dropout): SpatialDropout(p=0.35, inplace=False)
  (lstm): LSTM(128, 128, num_layers=2, batch_first=True, dropout=0.35)
  (language_model_head): Linear(in_features=128, out_features=30000, bias=False)
)

In [108]:
def train(model, loader, criterion, optimizer, last_n_losses=500, verbose=True):
    
    losses = []

    progress_bar = tqdm(total=len(loader), disable=not verbose, desc='Train')

    model.train()

    for x, y in loader:

        x = x.to(device)
        y = y.to(device)

        pred = model(x)

        loss = criterion(pred.view(-1, pred.size(-1)), y.view(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

        progress_bar.set_postfix(loss=np.mean(losses[-last_n_losses:]),
                                 perplexity=np.exp(np.mean(losses[-last_n_losses:])))

        progress_bar.update()

    progress_bar.close()
    
    return losses

In [109]:
epoch_losses = train(
    model, validation_loader,
    criterion=torch.nn.CrossEntropyLoss(ignore_index=PAD_INDEX),
    optimizer=torch.optim.Adam(params=model.parameters()),
)




Train:   0%|          | 0/1 [00:00<?, ?it/s][A[A[A

256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 27
256 28
256 29


ValueError: expected sequence of length 29 at dim 1 (got 28)