In [1]:
import json
import os

import torch
import torch.nn as nn

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from src.loader import MyLoader
from src.models import TextRNN
from src.predict import predict
from src.sequence_encoder import Sequence_count_encoder
from src.train_loop import training_loop

In [3]:
# Run configuration: everything a reader might want to tune lives in this cell.
TRAIN_TEXT_FILE_PATH = 'datasets/classic_poems.json'  # JSON corpus read by the data-loading cell
SEQ_LEN = 256      # passed to MyLoader together with BATCH_SIZE
BATCH_SIZE = 16
N_EPOCHS = 500     # passed to training_loop
SEED = 42          # fixed seed so repeated runs start from the same RNG state

# Seed PyTorch's random number generators before any model weights are created
# or any text is sampled. Other RNG sources (e.g. `random` or numpy, if the
# src.* helpers use them) are not covered by this call — TODO confirm.
torch.manual_seed(SEED)

# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Загружаем текстовые данные

In [4]:
# Populate the encoder from the corpus file (what exactly load_from_json builds
# is defined in src.sequence_encoder).
encoder = Sequence_count_encoder()
encoder.load_from_json(TRAIN_TEXT_FILE_PATH)

# Read the same JSON file directly to assemble the raw training text: take the
# "content" field of every record and join the pieces with single spaces.
with open(TRAIN_TEXT_FILE_PATH, encoding="utf-8") as json_file:
    data = json.load(json_file)
lines = [record["content"] for record in data]
text_sample = ' '.join(lines)

# Encode the text and hand the result to the loader along with the
# sequence-length and batch-size settings from the config cell.
sequence = encoder.text_to_seq(text_sample)
loader = MyLoader(sequence, SEQ_LEN, BATCH_SIZE)

Инициализируем модель

In [5]:
# The network's input size is the number of entries in the encoder's
# idx_to_char table (presumably the vocabulary size — confirm in
# src.sequence_encoder).
model = TextRNN(
    input_size=len(encoder.idx_to_char),
    hidden_size=128,
    embedding_size=128,
    n_layers=2,
)
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, amsgrad=True)

# Halve the learning rate once the metric passed to scheduler.step() has failed
# to improve for more than `patience` consecutive steps.
# The `verbose` argument is intentionally not passed: it is deprecated in
# PyTorch >= 2.2. To log the current learning rate, read
# optimizer.param_groups[0]['lr'] instead.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    patience=5,
    factor=0.5,
)

In [6]:
# Launch training with the model, loader, optimizer, scheduler and loss built in
# the cells above. The "Loss: ..." lines in this cell's output come from this
# call; how often they are printed and what they average over is defined in
# src.train_loop — TODO confirm.
training_loop(model, loader, N_EPOCHS, optimizer, scheduler, criterion, device)

Loss: 3.0726580381393434
Loss: 2.5152838468551635
Loss: 2.4220312309265135
Loss: 2.3664708280563356
Loss: 2.327056384086609
Loss: 2.297659754753113
Loss: 2.2622882175445556
Loss: 2.248606171607971
Loss: 2.2237263107299805
Loss: 2.2068537664413452


In [9]:
# Switch the network to evaluation mode before generating text.
model.eval()

# Generate a continuation of the seed phrase; the encoder's two lookup tables
# are handed to predict alongside the model. `temp` is presumably the sampling
# temperature — confirm in src.predict.
predicted_text = predict(
    model,
    encoder.char_to_idx,
    encoder.idx_to_char,
    device=device,
    start_text='Оно в москве',
    temp=0.7,
)
print(predicted_text)

Оно в москвет позновенных смотри благодить юные утока выпите с берень подвостым челачала подолоньчее,
Сколова
Проколоселей в дом родную расстекласти забант —
Ни скара радости прикратится страют —
под коншим в стр


Сохраняем полученную модель

In [10]:
# Persist the trained network. torch.save does not create missing parent
# directories, so make sure the target folder exists first — otherwise a long
# training run could end with a failed save on a fresh checkout.
MODEL_SAVE_PATH = 'models/lstm_poets_model0.pt'
os.makedirs(os.path.dirname(MODEL_SAVE_PATH), exist_ok=True)

# NOTE: this pickles the whole module object rather than just its state_dict,
# so loading the file later requires the TextRNN class to be importable.
torch.save(model, MODEL_SAVE_PATH)