In [1]:
import json
from pathlib import Path

import torch
import torch.nn as nn

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from src.loader import MyLoader
from src.models import TextRNN
from src.predict import predict
from src.sequence_encoder import Sequence_count_encoder
from src.train_loop import training_loop

In [3]:
# --- Notebook configuration -------------------------------------------------
# Path to the JSON corpus used both by the encoder and for the training text.
TRAIN_TEXT_FILE_PATH = 'datasets/classic_poems.json'
# Values handed to MyLoader (sequence length and batch size) and to the
# training loop (number of epochs).
SEQ_LEN = 256
BATCH_SIZE = 16
N_EPOCHS = 500
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Загружаем текстовые данные

In [7]:
# Initialise the encoder from the training JSON file.
encoder = Sequence_count_encoder()
encoder.load_from_json(TRAIN_TEXT_FILE_PATH)

# Read the same JSON file again to assemble the raw training text: the
# "content" field of every record is collected and the pieces are joined
# with single spaces.
with open(TRAIN_TEXT_FILE_PATH, encoding="utf-8") as corpus_file:
    data = json.load(corpus_file)
lines = [record["content"] for record in data]
text_sample = ' '.join(lines)

# Encode the text with the encoder and wrap the resulting sequence in the
# project's batch loader.
sequence = encoder.text_to_seq(text_sample)
loader = MyLoader(sequence, BATCH_SIZE, SEQ_LEN)

Инициализируем модель

In [9]:
# The model's input size is the number of entries in the encoder's
# idx_to_char table.
vocab_size = len(encoder.idx_to_char)
model = TextRNN(
    input_size=vocab_size,
    hidden_size=128,
    embedding_size=128,
    n_layers=2,
)
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, amsgrad=True)

# Halve the learning rate (factor=0.5) once the metric passed to
# scheduler.step() has stopped improving for longer than patience=5 steps.
# NOTE(review): `verbose` is deprecated in recent PyTorch releases -- confirm
# against the installed version before upgrading.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    factor=0.5,
    patience=5,
    verbose=True,
)

In [10]:
# Run the project's training loop with the model, loader, optimizer,
# scheduler and loss configured earlier in this notebook. Arguments stay
# positional because the parameter names of training_loop are not visible here.
training_loop(
    model,
    loader,
    N_EPOCHS,
    optimizer,
    scheduler,
    criterion,
    device,
)

Loss: 3.2399630737304688
Loss: 2.615880722999573
Loss: 2.4357262134552
Loss: 2.3520593643188477
Loss: 2.296901206970215
Loss: 2.241616439819336
Loss: 2.2200205230712893
Loss: 2.176853308677673
Loss: 2.1404213953018187
Loss: 2.1380269384384154


In [11]:
# Put the model into evaluation mode before generating text.
model.eval()

# Generate a continuation of the seed text.
# NOTE(review): `temp` is presumably the sampling temperature -- confirm in
# src.predict.
predicted_text = predict(
    model,
    encoder.char_to_idx,
    encoder.idx_to_char,
    device=device,
    start_text='Оно в москве',
    temp=0.7,
)
print(predicted_text)

Оно в москвет нет.
И расмает залянет мы сиденает
Я с верных гробовала своюм вольным.
Не обедных рукой верновенный в до лестнем променье в весельную на пропоком? Тихи и глаза,
Смятьи переда сказалый колов
Света ст


Сохраняем полученную модель

In [10]:
# torch.save does not create missing parent directories; make sure the
# target folder exists so a long training run is not lost to a failed save.
model_path = Path('models/lstm_poets_model0.pt')
model_path.parent.mkdir(parents=True, exist_ok=True)

# NOTE: this pickles the whole module object (not just a state_dict), so
# loading the file later requires the model class to be importable from the
# same module path.
torch.save(model, model_path)