In [16]:
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np

In [17]:
def detach_from_history(h): 
    if type(h) == torch.Tensor:
        return h.detach()

    return tuple(detach_from_history(v) for v in h)


class CharRnn(nn.Module):# создаем архитектуру нейронки
    def __init__(self, vocab_size, n_fac, n_hidden, batch_size):
        super().__init__()
        self.e = nn.Embedding(vocab_size, n_fac) # словарь с буквами
        self.rnn = nn.RNN(n_fac, n_hidden)
        self.l_out = nn.Linear(n_hidden, vocab_size)
        self.n_hidden = n_hidden
        self.init_hidden_state(batch_size)

    def init_hidden_state(self, batch_size):
        self.h = torch.zeros(1, batch_size, self.n_hidden)

    def forward(self, inp):
        inp = self.e(inp)
        b_size = inp[0].size(0)
        if self.h[0].size(1) != b_size:
            self.init_hidden_state(b_size)

        outp, h = self.rnn(inp, self.h)
        self.h = detach_from_history(h)

        return F.log_softmax(self.l_out(outp[-1]), dim=-1)

In [18]:
def generateNextChar(charNet, phraze): # получаем символы
    idxs = np.empty((1, seq_size))
    idxs[0] = np.array([char2int[c] for c in phraze])

    res = charNet(torch.LongTensor(idxs).transpose(0, 1))
    _, t_idxs = torch.max(res, dim=1)

    return int2char[t_idxs.detach().cpu().numpy()[0]]


def generateText(charNet, phraze, numChars): # генерируем текст
    cText = phraze
    for i in range(0, numChars):
        cText += generateNextChar(charNet, cText[i:])

    return cText

In [19]:
text = ""
with open("gift.txt", "r", encoding="utf-8") as file:
    text = file.read().replace("\n", " ")

chars = sorted(list(set(text)))
int2char = dict(enumerate(chars))
char2int = {char: ind for ind, char in int2char.items()}

idx = [char2int[c] for c in text] # получаем список индексов

epochs = 60
seq_size = 32 # размер последовательности
hidden_size = 256
batch_size = 300

net = CharRnn(len(char2int), seq_size, hidden_size, batch_size) # создаем нейронную сеть
lr = 1e-3
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=lr)

# Подготавливаем данные 
# разделям данные на пакеты входящих и выходящих данных
in_text = np.array(
    [[idx[j + i] for i in range(seq_size)] for j in range(len(idx) - seq_size - 1)]
)
out_text = np.array([idx[j + seq_size] for j in range(len(idx) - seq_size - 1)])

print(in_text.shape)
print(out_text.shape)

# TRAIN
for e in range(0, epochs):
    loss = 0
    for b in range(0, in_text.shape[0] // batch_size):
        input_idxs = ( # для каждого пакеты создаются тензоры входных и выходных данных
            torch.LongTensor(in_text[b * batch_size : (b + 1) * batch_size, :seq_size])
            .transpose(0, 1)
        )
        target_idxs = (
            torch.LongTensor(out_text[b * batch_size : (b + 1) * batch_size])
            .squeeze()
        )

        res = net(input_idxs)
        loss = criterion(res, target_idxs)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print("Epoch {}, loss {}".format(e + 1, loss.item()))



(24820, 32)
(24820,)
Epoch 1, loss 2.871194839477539
Epoch 2, loss 2.5928637981414795
Epoch 3, loss 2.487926721572876
Epoch 4, loss 2.4048635959625244
Epoch 5, loss 2.3340137004852295
Epoch 6, loss 2.2672958374023438
Epoch 7, loss 2.205758571624756
Epoch 8, loss 2.150299549102783
Epoch 9, loss 2.0981695652008057
Epoch 10, loss 2.0446388721466064
Epoch 11, loss 1.9871124029159546
Epoch 12, loss 1.9277024269104004
Epoch 13, loss 1.869822382926941
Epoch 14, loss 1.816686987876892
Epoch 15, loss 1.766892671585083
Epoch 16, loss 1.712676763534546
Epoch 17, loss 1.664786458015442
Epoch 18, loss 1.6424524784088135
Epoch 19, loss 1.61225163936615
Epoch 20, loss 1.5794209241867065
Epoch 21, loss 1.548707127571106
Epoch 22, loss 1.492730975151062
Epoch 23, loss 1.4590966701507568
Epoch 24, loss 1.4406461715698242
Epoch 25, loss 1.4062919616699219
Epoch 26, loss 1.3718411922454834
Epoch 27, loss 1.3523571491241455
Epoch 28, loss 1.2935411930084229
Epoch 29, loss 1.252957820892334
Epoch 30, loss 1