In [1]:
import argparse
import pandas as pd
import json
import torch
import numpy as np
from tqdm import tqdm
from collections import Counter, defaultdict
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset, TensorDataset


class PoetryLSTM(nn.Module):
    """LSTM language model that scores the next token of a sequence.

    The network is an embedding layer, a multi-layer LSTM (``batch_first=True``)
    and a linear projection from the hidden state onto the vocabulary.
    ``forward`` returns logits for the last time step only.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_size, num_layers, device, dropout=0.15):
        """Build the layers.

        Args:
            vocab_size: number of embedding rows and of output logits.
            embedding_dim: size of each token embedding.
            hidden_size: size of the LSTM hidden and cell states.
            num_layers: number of stacked LSTM layers.
            device: device on which ``init_state`` allocates the initial states.
            dropout: dropout value passed to ``nn.LSTM``.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.vocab_size = vocab_size
        self.device = device

        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_size, num_layers=num_layers,
                            dropout=dropout, batch_first=True)
        self.fc1 = nn.Linear(hidden_size, vocab_size)

    def forward(self, X, h=None, c=None):
        """Return next-token logits for a batch of token-id sequences.

        Args:
            X: integer tensor of shape (batch, seq_len).
            h: optional initial hidden state, (num_layers, batch, hidden_size).
            c: optional initial cell state, same shape as ``h``.
                Any state left as ``None`` is replaced by zeros.

        Returns:
            Tuple ``(logits, h, c)`` where ``logits`` has shape
            (batch, vocab_size) and corresponds to the last time step, and
            ``h``/``c`` are the final LSTM states.
        """
        if h is None or c is None:
            zero_h, zero_c = self.init_state(X.size(0))
            h = zero_h if h is None else h
            c = zero_c if c is None else c

        out = self.embedding(X)
        out, (h, c) = self.lstm(out, (h, c))
        # Only the last time step is returned, so project just that step instead
        # of projecting every step onto the vocabulary and discarding the rest.
        out = self.fc1(out[:, -1])

        return out, h, c

    def init_state(self, batch_size):
        """Return zero-filled ``(hidden, cell)`` states allocated on ``self.device``."""
        shape = (self.num_layers, batch_size, self.hidden_size)
        # Allocate directly on the target device rather than on CPU followed by a copy.
        hidden = torch.zeros(shape, device=self.device)
        cell = torch.zeros(shape, device=self.device)
        return hidden, cell



class SimpleLSTMPoem:
    """Rhymed-verse generator built on a next-word model and a rhyme lookup.

    Line-ending words are chosen first so that they satisfy the rhyme scheme.
    Each line is then grown from its ending word by repeatedly sampling the
    model's next-word distribution, and finally the words of every line and
    the order of the lines are reversed for output.
    """

    def __init__(self, model, rhyme_model, vocab=None, context_size=10, pad_token_id=1):
        """Store the models and the vocabulary used for generation.

        Args:
            model: callable ``model(tokens) -> (logits, h, c)``; it must also
                provide ``eval()``.
            rhyme_model: object with ``give_rhyme(word)`` returning a rhyming
                word or ``None``.
            vocab: mapping word -> token id. When omitted, the module-level
                ``vocab`` variable is captured at construction time (this is
                how the notebook has always supplied it).
            context_size: number of most recent tokens fed to the model.
            pad_token_id: id used to fill unused context positions.
                NOTE(review): 1 is what the original code padded with;
                presumably it matches training — confirm.

        Raises:
            ValueError: if no vocabulary is available or ``context_size < 1``.
        """
        if vocab is None:
            # Backward compatibility with callers that rely on the notebook global.
            vocab = globals().get("vocab")
        if vocab is None:
            raise ValueError("vocab was not passed and no module-level `vocab` is defined")
        if context_size < 1:
            raise ValueError("context_size must be at least 1")

        self.model = model
        self.rhyme_model = rhyme_model
        self.vocab = vocab
        self.vocabr = {token_id: word for word, token_id in vocab.items()}
        self.context_size = context_size
        self.pad_token_id = pad_token_id

    def generate_stih(self, lines_n=4, rhyme_scheme='0101',
                      min_words_line=4, max_words_line=8):
        """Generate one poem.

        Args:
            lines_n: number of lines; must equal ``len(rhyme_scheme)``.
            rhyme_scheme: one symbol per line; lines sharing a symbol rhyme.
            min_words_line: minimum number of words per line (at least 1).
            max_words_line: maximum number of words per line (inclusive).

        Returns:
            The poem as a single string with lines separated by ``" \\n "``.

        Raises:
            ValueError: if the word-count bounds are invalid.
        """
        assert lines_n == len(rhyme_scheme), "rhyme_scheme needs one symbol per line"
        if min_words_line < 1 or min_words_line > max_words_line:
            raise ValueError("expected 1 <= min_words_line <= max_words_line")

        # Ending words for all lines, listed from the last line to the first.
        words_end = self._pick_end_words(self._rhyme_links(rhyme_scheme))

        # Lines are generated in reverse order, each one starting from its last word.
        self.model.eval()
        device = self._model_device()
        generated = []  # every word produced so far, in generation order
        lines = []
        for end_word in words_end:
            generated.append(end_word)
            # numpy's randint excludes the upper bound, hence the +1.
            n_words_line = np.random.randint(min_words_line, max_words_line + 1)
            for _ in range(n_words_line - 1):
                generated.append(self._sample_next_word(generated, device))
            lines.append(generated[-n_words_line:])

        return " \n ".join(" ".join(reversed(line)) for line in reversed(lines))

    @staticmethod
    def _rhyme_links(rhyme_scheme):
        """For each line (last to first) return the index of the previously
        seen line with the same rhyme symbol, or -1 if there is none."""
        links, last_seen = [], {}
        for position, symbol in enumerate(reversed(list(rhyme_scheme))):
            links.append(last_seen.get(symbol, -1))
            last_seen[symbol] = position
        return links

    def _pick_end_words(self, links):
        """Choose the last word of every line so that linked lines rhyme.

        NOTE: both retry loops run until ``give_rhyme`` yields an
        in-vocabulary word; if it never does, they do not terminate.
        """
        words_end = []
        for link in links:
            if link == -1:
                # Nothing to rhyme with yet: draw random words until one has
                # a rhyme that the model also knows.
                while True:
                    candidate = self.vocabr[int(np.random.randint(len(self.vocab)))]
                    rhyme = self.rhyme_model.give_rhyme(candidate)
                    if rhyme is not None and rhyme in self.vocab:
                        words_end.append(candidate)
                        break
            else:
                rhyme_with = words_end[link]
                while True:
                    rhyme = self.rhyme_model.give_rhyme(rhyme_with)
                    if rhyme and rhyme in self.vocab:
                        words_end.append(rhyme)
                        break
        return words_end

    def _model_device(self):
        """Return the device of the model's parameters (falling back to
        ``model.device``, then to CPU) so inputs are placed next to the weights."""
        try:
            return next(self.model.parameters()).device
        except (AttributeError, StopIteration):
            return getattr(self.model, "device", "cpu")

    def _sample_next_word(self, context_words, device):
        """Sample the next word from the model given the most recent words."""
        token_ids = np.full(self.context_size, self.pad_token_id, dtype=np.int64)
        recent = [self.vocab[word] for word in context_words[-self.context_size:]]
        token_ids[:len(recent)] = recent
        batch = torch.from_numpy(token_ids).unsqueeze(0).to(device)

        with torch.no_grad():
            logits, _, _ = self.model(batch)
            # Sample from the predicted distribution rather than taking the argmax.
            probs = nn.functional.softmax(logits, dim=1).cpu().numpy()[0]
        idx = np.random.choice(len(probs), p=probs)
        return self.vocabr[int(idx)]




In [2]:
# Colab-only: mount Google Drive at /content/drive.
from google.colab import drive
drive.mount("/content/drive")

Mounted at /content/drive


In [3]:
# Work from the project folder on the mounted Drive so that the relative
# paths used by later cells resolve, then list its contents.
%cd /content/drive/MyDrive/NLP_poems
!ls

/content/drive/MyDrive/NLP_poems
 data				  models
 data_test.csv			  navec_hudlit_v1_12B_500K_300d_100q.tar
 generated_poems		  __pycache__
 generate_poems.ipynb		  russian_g2p
 gpt				  search_rhyme.py
'LM evaluation.ipynb'		  simple_by_poems.ipynb
 lstm0.ipynb			  simple_lstm.ipynb
 Lstm.ipynb			  training_models
 lstm_with_pretrained_emb.ipynb   train_lstm_with_pretrained_emb.py


## Simple LSTM 

In [None]:
# Load the model's vocabulary (word -> token id).
with open('models/lstms/lstm_vocab.json', 'r', encoding='utf-8') as f:
    vocab = json.load(f)

# Inverse mapping: token id -> word.
vocabr = {v: k for k, v in vocab.items()}

vocab_size = len(vocab)
embedding_dim = 128
hidden_size = 512
num_layers = 3
# Use the GPU when there is one, otherwise fall back to CPU so the cell still runs.
device = "cuda" if torch.cuda.is_available() else "cpu"
MAX_LENGTH = 10

model_path = "models/lstms/35ep_128x512x3_bypoem.pth"

model = PoetryLSTM(vocab_size, embedding_dim, hidden_size, num_layers, device)
# map_location remaps tensors saved from a GPU onto the device selected above.
state_dict = torch.load(model_path, map_location=device)
model.load_state_dict(state_dict['model_state_dict'])
model.to(device)

## Генерация
### без учета ударений

In [None]:
# Rhyme data files; the names suggest one file per year and topic.
rhyme_models_files = [
    "data/rhymes_2020_civil.json",
    "data/rhymes_2021_civil.json",
    "data/rhymes_2020_love.json",
    "data/rhymes_2021_love.json",
    "data/rhymes_2020_nature.json",
    "data/rhymes_2021_nature.json",
    "data/rhymes_2020_religion.json",
    "data/rhymes_2021_religion.json",
]


# Load one or more rhyme models: the first file seeds the combined model and
# every remaining file is loaded separately and merged into it.
from search_rhyme import RhymeSearch

rhyme_model = RhymeSearch()
rhyme_model.from_json(rhyme_models_files[0])

for rhyme_path in rhyme_models_files[1:]:
    new_rhyme = RhymeSearch()
    new_rhyme.from_json(rhyme_path)
    rhyme_model.merge_models(new_rhyme)

In [None]:
# Pair the loaded LSTM with the merged rhyme model.
poems_generator = SimpleLSTMPoem(model, rhyme_model)

In [None]:
# Print one generated 8-line poem with rhyme scheme '01012323'.
print(poems_generator.generate_stih(lines_n=8, rhyme_scheme='01012323'))

почти всё простить вдруг не будет овен 
 потерять но не парить вечность обласкала 
 льёт с ароматных лишь в многоценен 
 сказке гуляют зелёным зажгла 
 и из старой желаемый 
 багульник на наряде землю волшебный осколок рубина 
 звук гонит воздуха обводный 
 фотография из комплименты фонтана


In [None]:
# Print another 8-line sample; generation is random, so each call differs.
print(poems_generator.generate_stih(lines_n=8, rhyme_scheme='01012323'))

может быть и враг измотали 
 увидел все желанья чтоб вновь актеры 
 другим не заметив желанья петербург нить подмёрзли 
 всех и стараюсь коснуться тротуары 
 души подняться как помолись искренне чтоб воздать 
 iv притчи гл со грешного радостью закрытых 
 моя воля к тебе словно ты вычислить 
 по ветвям выси прощальной чувства беспомощных


In [None]:
import csv

output_file = 'generated_poems/lstm_poems.csv'

# Generate 10 eight-line poems.
stihi = [
    poems_generator.generate_stih(lines_n=8, rhyme_scheme='01012323')
    for _ in range(10)
]

# newline='' lets the csv module control line endings itself (the poems contain
# embedded newlines); the explicit encoding keeps the Cyrillic text independent
# of the platform default. Each row is (index, poem), tab-separated.
with open(output_file, 'w', newline='', encoding='utf-8') as f:
    csvwriter = csv.writer(f, delimiter='\t', quotechar='|', quoting=csv.QUOTE_MINIMAL)
    csvwriter.writerows(enumerate(stihi))

### С ударениями

In [None]:
!pip install dawg

In [None]:
# Accent-annotated rhyme data files; the names suggest one per year and topic.
rhyme_models_files = [
    "data/rhymes_acc_2020_civil.json",
    "data/rhymes_acc_2021_civil.json",
    "data/rhymes_acc_2020_love.json",
    "data/rhymes_acc_2021_love.json",
    "data/rhymes_acc_2020_nature.json",
    "data/rhymes_acc_2021_nature.json",
    "data/rhymes_acc_2020_religion.json",
    "data/rhymes_acc_2021_religion.json",
]


# Load one or more rhyme models: the first file seeds the combined model and
# every remaining file is loaded separately and merged into it.
from search_rhyme import RhymeSearch

rhyme_model = RhymeSearch(with_accent=True)
rhyme_model.from_json(rhyme_models_files[0])

# NOTE(review): the models merged below are created without with_accent=True,
# unlike the combined model above — confirm this is intended.
for rhyme_path in rhyme_models_files[1:]:
    new_rhyme = RhymeSearch()
    new_rhyme.from_json(rhyme_path)
    rhyme_model.merge_models(new_rhyme)

# Rebuild the generator so it uses the accent-aware rhyme model.
poems_generator = SimpleLSTMPoem(model, rhyme_model)

In [None]:
# Print one generated 4-line poem with rhyme scheme '0101'.
print(poems_generator.generate_stih(lines_n=4, rhyme_scheme='0101'))

иуды ещё пошло осталась по степным звездам 
 и тайно янв для лучей 
 под днем та выдам 
 нечисть так каждому лихих зад автор врачей


In [None]:
import csv
# Import the progress-bar function, not the module: a plain `import tqdm` would
# rebind the `tqdm` name that the first cell imports as a function.
from tqdm import tqdm

output_file = 'generated_poems/lstm_acc_poems.csv'

# Generate 200 four-line poems, showing progress.
stihi = []
for _ in tqdm(range(200)):
    stihi.append(poems_generator.generate_stih(lines_n=4, rhyme_scheme='0101'))

# The file is opened in append mode: re-running this cell adds another batch
# after the existing rows, and the row index restarts at 0 for each batch.
# newline='' lets the csv module control line endings itself; the explicit
# encoding keeps the Cyrillic text independent of the platform default.
with open(output_file, 'a', newline='', encoding='utf-8') as f:
    csvwriter = csv.writer(f, delimiter='\t', quotechar='|', quoting=csv.QUOTE_MINIMAL)
    csvwriter.writerows(enumerate(stihi))

 98%|█████████▊| 196/200 [12:41<00:16,  4.22s/it]

## LSTM с предобученными эмбеддингами, с ударениями

In [4]:
!pip install dawg

Collecting dawg
  Downloading DAWG-0.8.0.tar.gz (371 kB)
[?25l[K     |▉                               | 10 kB 21.4 MB/s eta 0:00:01[K     |█▊                              | 20 kB 26.5 MB/s eta 0:00:01[K     |██▋                             | 30 kB 11.6 MB/s eta 0:00:01[K     |███▌                            | 40 kB 4.6 MB/s eta 0:00:01[K     |████▍                           | 51 kB 4.7 MB/s eta 0:00:01[K     |█████▎                          | 61 kB 5.6 MB/s eta 0:00:01[K     |██████▏                         | 71 kB 5.7 MB/s eta 0:00:01[K     |███████                         | 81 kB 5.5 MB/s eta 0:00:01[K     |████████                        | 92 kB 6.1 MB/s eta 0:00:01[K     |████████▉                       | 102 kB 5.3 MB/s eta 0:00:01[K     |█████████▊                      | 112 kB 5.3 MB/s eta 0:00:01[K     |██████████▋                     | 122 kB 5.3 MB/s eta 0:00:01[K     |███████████▌                    | 133 kB 5.3 MB/s eta 0:00:01[K     |██████████

In [5]:
# Load the model's vocabulary (word -> token id).
with open('models/lstms/lstm_vocab_emb.json', 'r', encoding='utf-8') as f:
    vocab = json.load(f)

# Inverse mapping: token id -> word.
vocabr = {v: k for k, v in vocab.items()}

vocab_size = len(vocab)
embedding_dim = 300
hidden_size = 512
num_layers = 3
# Use the GPU when there is one, otherwise fall back to CPU so the cell still runs.
device = "cuda" if torch.cuda.is_available() else "cpu"
MAX_LENGTH = 10

model_path = "models/lstms/30ep_300x512x3_bypoem_emb.pth"

model = PoetryLSTM(vocab_size, embedding_dim, hidden_size, num_layers, device)
# map_location remaps tensors saved from a GPU onto the device selected above.
state_dict = torch.load(model_path, map_location=device)
model.load_state_dict(state_dict['model_state_dict'])
model.to(device)

PoetryLSTM(
  (embedding): Embedding(136619, 300)
  (lstm): LSTM(300, 512, num_layers=3, batch_first=True, dropout=0.15)
  (fc1): Linear(in_features=512, out_features=136619, bias=True)
)

In [6]:
# Accent-annotated rhyme data files; the names suggest one per year and topic.
rhyme_models_files = [
    "data/rhymes_acc_2020_civil.json",
    "data/rhymes_acc_2021_civil.json",
    "data/rhymes_acc_2020_love.json",
    "data/rhymes_acc_2021_love.json",
    "data/rhymes_acc_2020_nature.json",
    "data/rhymes_acc_2021_nature.json",
    "data/rhymes_acc_2020_religion.json",
    "data/rhymes_acc_2021_religion.json",
]


# Load one or more rhyme models: the first file seeds the combined model and
# every remaining file is loaded separately and merged into it.
from search_rhyme import RhymeSearch

rhyme_model = RhymeSearch(with_accent=True)
rhyme_model.from_json(rhyme_models_files[0])

# NOTE(review): the models merged below are created without with_accent=True,
# unlike the combined model above — confirm this is intended.
for rhyme_path in rhyme_models_files[1:]:
    new_rhyme = RhymeSearch()
    new_rhyme.from_json(rhyme_path)
    rhyme_model.merge_models(new_rhyme)

# Build the generator from the pretrained-embedding model and the rhyme model.
poems_generator = SimpleLSTMPoem(model, rhyme_model)

In [12]:
# Print one generated 4-line poem with rhyme scheme '0101'.
print(poems_generator.generate_stih(lines_n=4, rhyme_scheme='0101'))

идут как мотыльки прятались бандер 
 повсюду круги моя strip 
 ласки нам места да едер 
 нашей русской верой в христа душа grip


In [13]:
import csv
# Import the progress-bar function, not the module: a plain `import tqdm` would
# rebind the `tqdm` name that the first cell imports as a function.
from tqdm import tqdm

output_file = 'generated_poems/lstm_acc_emb_poems.csv'

# Generate 400 four-line poems, showing progress.
stihi = []
for _ in tqdm(range(400)):
    stihi.append(poems_generator.generate_stih(lines_n=4, rhyme_scheme='0101'))

# The file is opened in append mode: re-running this cell adds another batch
# after the existing rows, and the row index restarts at 0 for each batch.
# newline='' lets the csv module control line endings itself; the explicit
# encoding keeps the Cyrillic text independent of the platform default.
with open(output_file, 'a', newline='', encoding='utf-8') as f:
    csvwriter = csv.writer(f, delimiter='\t', quotechar='|', quoting=csv.QUOTE_MINIMAL)
    csvwriter.writerows(enumerate(stihi))

100%|██████████| 400/400 [25:33<00:00,  3.83s/it]
