In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import torch.nn as nn

In [None]:
def filter_by_alphabet(text, alphabet):
    """Return `text` with every character not present in `alphabet` removed.

    The surviving characters keep their original order.
    """
    allowed = frozenset(alphabet)
    kept = (ch for ch in text if ch in allowed)
    return "".join(kept)

def split_data(text, seq_length, stride):
    """Cut `text` into overlapping (input, target) windows for next-character prediction.

    Each input is ``text[i : i + seq_length]`` and its target is the same
    window shifted one character to the right, ``text[i + 1 : i + seq_length + 1]``.
    Windows start every `stride` characters, and every window whose target
    fits entirely inside `text` is produced.

    Returns:
        (inputs, targets): two lists of equal length.
    """
    inputs = []
    targets = []
    # The target window ends at i + seq_length + 1, so the last valid start is
    # len(text) - seq_length - 1; range() excludes its stop value, hence this bound.
    for i in range(0, len(text) - seq_length, stride):
        inputs.append(text[i : i + seq_length])
        targets.append(text[i + 1 : i + seq_length + 1])
    return inputs, targets

def integerify(list_of_strings, mapping=None):
    """Convert each string into a list of integer character indices.

    Args:
        list_of_strings: iterable of strings to encode.
        mapping: optional dict from character to index. When omitted, the
            module-level ``char_to_index`` is used (looked up at call time).

    Returns:
        A list containing one list of ints per input string.

    Raises:
        KeyError: if a string contains a character missing from the mapping.
    """
    if mapping is None:
        mapping = char_to_index
    return [[mapping[ch] for ch in string] for string in list_of_strings]

def one_hot_encode(arr, n_labels):
    """One-hot encode an integer array of class indices.

    Args:
        arr: integer array-like (NumPy array, CPU torch tensor, nested list)
            of any shape, with values in ``[0, n_labels)``.
        n_labels: number of classes, i.e. the size of the new last axis.

    Returns:
        float32 NumPy array of shape ``arr.shape + (n_labels,)`` holding 1.0
        at each element's index and 0.0 elsewhere.
    """
    # np.asarray accepts torch CPU tensors (the DataLoader output) and lists.
    # `.size` works for any rank, whereas np.multiply(*shape) only handled
    # exactly two dimensions.
    indices = np.asarray(arr)
    one_hot = np.zeros((indices.size, n_labels), dtype=np.float32)
    one_hot[np.arange(indices.size), indices.ravel()] = 1.0
    return one_hot.reshape((*indices.shape, n_labels))

def collate_function(batch):
    """Turn a list of (input_str, target_str) pairs into two index tensors.

    Both sides are mapped to character indices with `integerify` and wrapped
    with `torch.tensor`; this assumes all strings in the batch share a length.
    """
    input_strings, target_strings = [], []
    for input_str, target_str in batch:
        input_strings.append(input_str)
        target_strings.append(target_str)
    input_ids = integerify(input_strings)
    target_ids = integerify(target_strings)
    return torch.tensor(input_ids), torch.tensor(target_ids)

def get_data_from_file(path, alphabet, seq_length, batch_size, stride, train_size=0.75,
                       encoding="utf-8", seed=None):
    """Build shuffled train/validation DataLoaders of character windows from a text file.

    The file is read, reduced to the characters in `alphabet`, cut into
    (input, target) windows with `split_data`, and randomly split into a
    training and a validation subset. Batches are assembled by
    `collate_function`.

    Args:
        path: path of the text file.
        alphabet: characters to keep; everything else is dropped.
        seq_length: window length in characters.
        batch_size: batch size for both loaders.
        stride: distance between the starts of consecutive windows.
        train_size: fraction of windows assigned to the training split.
        encoding: text encoding used to read `path`. Defaults to UTF-8, so
            Cyrillic text decodes identically on every platform instead of
            depending on the locale's default encoding.
        seed: optional int seed for the train/validation split. ``None`` uses
            PyTorch's global RNG (non-reproducible, as before).

    Returns:
        (train_dataloader, validation_dataloader)
    """
    with open(path, "r", encoding=encoding) as text_file:
        text = text_file.read()
    text = filter_by_alphabet(text, alphabet)
    text_inputs, text_targets = split_data(text, seq_length, stride)
    data = list(zip(text_inputs, text_targets))
    display(len(data))  # IPython display of the number of windows produced

    n_train = int(train_size * len(data))
    n_val = len(data) - n_train
    # Pass a dedicated generator only when a seed is requested, so the default
    # path keeps using the global RNG exactly as before.
    split_kwargs = {}
    if seed is not None:
        split_kwargs["generator"] = torch.Generator().manual_seed(seed)
    train_data, val_data = torch.utils.data.random_split(data, [n_train, n_val], **split_kwargs)

    def make_loader(subset):
        # Both splits share the same loader configuration.
        return DataLoader(
            subset,
            batch_size=batch_size,
            shuffle=True,
            collate_fn=collate_function,
            pin_memory=True,
        )

    return make_loader(train_data), make_loader(val_data)

In [None]:
# Vocabulary: Russian letters in lower and upper case (including ё/Ё),
# space, basic punctuation, newline, hyphen and double quote.
special = '$'  # extra marker character; only used if appended to the alphabet
alphabet = 'абвгдеёжзийклмнопрстуфхцчшщъыьэюяАБВГДЕЁЖЗИЙКЛМНОПРСТУФХЦЧШЩЪЫЬЭЮЯ .,!?\n-"'
char_to_index = {character: index for index, character in enumerate(alphabet)}

# Windowing / batching hyper-parameters.
seq_length = 120   # characters per training window
batch_size = 64    # windows per batch
stride = 2         # offset between consecutive window starts

In [4]:
train_dataloader, validation_dataloader = get_data_from_file(
    'dataset.txt', alphabet, seq_length=seq_length, batch_size=batch_size, stride=stride
)

12741343

In [5]:
class CharRNN(nn.Module):
    """Character-level LSTM language model.

    Architecture: a stacked, batch-first LSTM over one-hot character vectors,
    dropout, a Linear -> BatchNorm1d -> ReLU layer, and a final Linear layer
    producing one logit per character.

    Args:
        chars_num: vocabulary size (width of the one-hot input and of the logits).
        n_hidden: LSTM hidden size, also the width of the intermediate layer.
        n_layers: number of stacked LSTM layers.
        drop_prob: dropout probability between LSTM layers and after the LSTM.
    """

    def __init__(self, chars_num, n_hidden=512, n_layers=4, drop_prob=0.4):
        super().__init__()
        self.chars_num = chars_num
        self.drop_prob = drop_prob
        self.n_layers = n_layers
        self.n_hidden = n_hidden

        # The attribute names (lstm, dropout, linear, fc) become state_dict
        # keys of saved checkpoints; renaming them breaks load_state_dict.
        self.lstm = nn.LSTM(
            chars_num, n_hidden, n_layers, dropout=drop_prob, batch_first=True
        )
        self.dropout = nn.Dropout(drop_prob)
        self.linear = nn.Sequential(
            nn.Linear(n_hidden, n_hidden),
            nn.BatchNorm1d(n_hidden),
            nn.ReLU(),
        )
        self.fc = nn.Linear(n_hidden, chars_num)

    def forward(self, x, hidden):
        """Run the model on a batch of one-hot sequences.

        Args:
            x: float tensor of shape (batch, seq_len, chars_num).
            hidden: (h_0, c_0) tuple, each of shape (n_layers, batch, n_hidden),
                e.g. from `init_hidden`.

        Returns:
            (logits, hidden): logits of shape (batch * seq_len, chars_num) with
            time steps flattened batch-major, and the final LSTM state.
        """
        r_output, hidden = self.lstm(x, hidden)
        out = self.dropout(r_output)
        # Flatten (batch, seq, hidden) -> (batch * seq, hidden) so BatchNorm1d
        # and the Linear layers treat every time step as a separate sample.
        out = out.reshape(-1, self.n_hidden)
        out = self.linear(out)
        out = self.fc(out)
        return out, hidden

    def init_hidden(self, batch_size):
        """Return a zero (h_0, c_0) LSTM state for `batch_size` sequences.

        The tensors are created on the same device and with the same dtype as
        the model's parameters, so this works for CPU as well as GPU models.
        """
        weight = next(self.parameters())
        shape = (self.n_layers, batch_size, self.n_hidden)
        return (weight.new_zeros(shape), weight.new_zeros(shape))

In [6]:
def _infinite_batches(loader):
    """Yield batches from `loader` forever, starting a new pass when one is exhausted."""
    while True:
        produced = False
        for batch in loader:
            produced = True
            yield batch
        if not produced:
            raise ValueError("data loader yielded no batches")


def _to_model_inputs(x, y, n_chars, device):
    """One-hot encode an (inputs, targets) index batch and move it to `device`.

    Returns inputs of shape (batch, seq, n_chars) and targets flattened to
    (batch * seq, n_chars), matching the flattened logits of CharRNN.forward.
    """
    inputs = one_hot_encode(np.asarray(x), n_chars)
    targets = one_hot_encode(np.asarray(y), n_chars).reshape(-1, n_chars)
    return torch.from_numpy(inputs).to(device), torch.from_numpy(targets).to(device)


def train(
    net,
    train_data,
    val_data,
    full_train,
    epochs=10,
    batches_per_epoch=100,
    batch_size=64,
    seq_length=100,
    lr=0.001,
    clip=5,
    val_frac=0.1,
    print_every=10,
):
    """Train `net` with Adam and cross-entropy, printing train/val loss periodically.

    Args:
        net: a CharRNN (uses its `chars_num`, `lstm` and `init_hidden`).
        train_data, val_data: iterables of (inputs, targets) index batches,
            e.g. the DataLoaders from `get_data_from_file`.
        full_train: if False, the LSTM weights are frozen (requires_grad=False)
            and only the layers after it are trained.
        epochs: number of epochs; the weights are saved to 'checkpoint.pth'
            after each one.
        batches_per_epoch: optimizer steps per epoch.
        batch_size: no longer used (the hidden state is sized from each batch);
            kept for backward compatibility.
        seq_length, val_frac: unused; kept for backward compatibility.
        lr: Adam learning rate.
        clip: maximum gradient norm.
        print_every: every this many steps, evaluate one validation batch and
            print the current training and validation loss.
    """
    net.train()
    for p in net.lstm.parameters():
        p.requires_grad = full_train
    opt = torch.optim.Adam(net.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    n_chars = net.chars_num
    device = next(net.parameters()).device
    # Persistent iterators: next(iter(loader)) on every step rebuilt the
    # iterator, and with shuffle=True a fresh permutation of the whole dataset,
    # just to take one batch. Here each pass over the data is fully consumed.
    train_batches = _infinite_batches(train_data)
    val_batches = _infinite_batches(val_data)
    counter = 0
    for e in range(epochs):
        for _ in range(batches_per_epoch):
            x, y = next(train_batches)
            inputs, targets = _to_model_inputs(x, y, n_chars, device)
            # Fresh zero state per batch, sized from the batch itself because
            # the last batch of a pass may be smaller than `batch_size`.
            h = net.init_hidden(inputs.shape[0])
            opt.zero_grad()
            output, h = net(inputs, h)
            loss = criterion(output, targets)
            loss.backward()
            nn.utils.clip_grad_norm_(net.parameters(), clip)
            opt.step()

            if counter % print_every == 0:
                net.eval()
                # No autograd graph is needed for the validation forward pass.
                with torch.no_grad():
                    x, y = next(val_batches)
                    val_inputs, val_targets = _to_model_inputs(x, y, n_chars, device)
                    val_h = net.init_hidden(val_inputs.shape[0])
                    val_output, _ = net(val_inputs, val_h)
                    val_loss = criterion(val_output, val_targets).item()
                net.train()

                print(
                    "Epoch: {}/{}...".format(e + 1, epochs),
                    "Step: {}...".format(counter),
                    "Loss: {:.4f}...".format(loss.item()),
                    "Val Loss: {:.4f}".format(val_loss),
                )

            counter += 1
        torch.save(net.state_dict(), 'checkpoint.pth')

In [7]:
def predict(net, char, h=None, temperature=1, top_k=None):
    """Sample the character that follows `char`; return it with the new hidden state.

    Args:
        net: a CharRNN. Call ``net.eval()`` first: its BatchNorm1d layer cannot
            run in training mode on a single time step.
        char: a character from the global `alphabet`.
        h: LSTM (h, c) state from a previous call; ``None`` starts from zeros.
        temperature: logits are divided by this before the softmax. Values
            below 1 sharpen the distribution; values above 1 flatten it.
        top_k: if given, sample only among the `top_k` most likely characters.

    Returns:
        (next_char, h): the sampled character and the updated hidden state.
    """
    device = next(net.parameters()).device
    if h is None:
        h = net.init_hidden(1)
    x = one_hot_encode(np.array(integerify([char])), net.chars_num)
    inputs = torch.from_numpy(x).to(device)
    # Detach the incoming state so no autograd history links consecutive calls.
    h = tuple(each.detach() for each in h)
    with torch.no_grad():
        out, h = net(inputs, h)
    # Softmax over the temperature-scaled logits. Exponentiating first would
    # apply exp twice. Dividing (not multiplying) makes lower temperatures sharper.
    probs = torch.nn.functional.softmax(out / temperature, dim=1).cpu()
    if top_k is None:
        top_ch = np.arange(net.chars_num)
        p = probs.numpy().reshape(-1)
    else:
        top_p, top_idx = probs.topk(top_k)
        # reshape(-1) rather than squeeze() keeps a 1-D array even when top_k == 1.
        p = top_p.numpy().reshape(-1)
        top_ch = top_idx.numpy().reshape(-1)
    char = np.random.choice(top_ch, p=p / p.sum())
    return alphabet[char], h
    
    
def sample(net, size, prime, temperature=1, top_k=None):
    """Generate text by feeding `prime` to the network and then sampling.

    The network is switched to eval mode (and left there). Each character of
    `prime` is fed in turn to build up the hidden state. The prediction made
    after the last prime character is appended, followed by `size` further
    sampled characters. The result is therefore `prime` followed by
    ``size + 1`` generated characters.

    Raises:
        ValueError: if `prime` is empty, since there is nothing to condition on.
    """
    if not prime:
        raise ValueError("prime must contain at least one character")
    net.eval()
    chars = list(prime)
    h = net.init_hidden(1)
    for ch in prime:
        char, h = predict(net, ch, h, temperature, top_k=top_k)
    chars.append(char)
    for _ in range(size):
        char, h = predict(net, chars[-1], h, temperature, top_k=top_k)
        chars.append(char)
    return ''.join(chars)

In [8]:
# One input/output unit per alphabet character, 1024 hidden units, 4 stacked LSTM layers.
net = CharRNN(len(alphabet), n_hidden=1024, n_layers=4)
# To resume from saved weights: net.load_state_dict(torch.load('anek_rnn_2.pth'))

In [9]:
net = net.to(torch.device("cuda"))  # move all parameters to the current CUDA device

In [10]:
# NOTE(review): in the recorded run below, the training loss plateaus around
# 3.2 from the first epoch on. lr=0.01 may be too high for this 4-layer,
# 1024-unit LSTM; the commented-out rerun further down uses lr=0.001.
train(
    net,
    train_data=train_dataloader,
    val_data=validation_dataloader,
    full_train=True,
    epochs=20,
    batches_per_epoch=1000,
    batch_size=batch_size,
    seq_length=seq_length,
    lr=0.01,
    print_every=10,
)

Epoch: 1/20... Step: 0... Loss: 4.3613... Val Loss: 3.5323
Epoch: 1/20... Step: 10... Loss: 3.7824... Val Loss: 12.3307
Epoch: 1/20... Step: 20... Loss: 3.3800... Val Loss: 7.7811
Epoch: 1/20... Step: 30... Loss: 3.2478... Val Loss: 3.4283
Epoch: 1/20... Step: 40... Loss: 3.2979... Val Loss: 3.2997
Epoch: 1/20... Step: 50... Loss: 3.2151... Val Loss: 3.4866
Epoch: 1/20... Step: 60... Loss: 3.2087... Val Loss: 3.6405
Epoch: 1/20... Step: 70... Loss: 3.2290... Val Loss: 3.6183
Epoch: 1/20... Step: 80... Loss: 3.1987... Val Loss: 3.7429
Epoch: 1/20... Step: 90... Loss: 3.2078... Val Loss: 3.5197
Epoch: 1/20... Step: 100... Loss: 3.1927... Val Loss: 3.5773
Epoch: 1/20... Step: 110... Loss: 3.2165... Val Loss: 3.8504
Epoch: 1/20... Step: 120... Loss: 3.2180... Val Loss: 3.9165
Epoch: 1/20... Step: 130... Loss: 3.2262... Val Loss: 3.7716
Epoch: 1/20... Step: 140... Loss: 3.2301... Val Loss: 3.8685
Epoch: 1/20... Step: 150... Loss: 3.2346... Val Loss: 3.9357
Epoch: 1/20... Step: 160... Loss: 

Epoch: 2/20... Step: 1340... Loss: 3.2115... Val Loss: 3.2104
Epoch: 2/20... Step: 1350... Loss: 3.2263... Val Loss: 3.1996
Epoch: 2/20... Step: 1360... Loss: 3.2073... Val Loss: 3.1936
Epoch: 2/20... Step: 1370... Loss: 3.2135... Val Loss: 3.1874
Epoch: 2/20... Step: 1380... Loss: 3.2159... Val Loss: 3.2013
Epoch: 2/20... Step: 1390... Loss: 3.2242... Val Loss: 3.1707
Epoch: 2/20... Step: 1400... Loss: 3.2158... Val Loss: 3.2036
Epoch: 2/20... Step: 1410... Loss: 3.1928... Val Loss: 3.1841
Epoch: 2/20... Step: 1420... Loss: 3.2258... Val Loss: 3.2198
Epoch: 2/20... Step: 1430... Loss: 3.2071... Val Loss: 3.2143
Epoch: 2/20... Step: 1440... Loss: 3.2139... Val Loss: 3.2205
Epoch: 2/20... Step: 1450... Loss: 3.2056... Val Loss: 3.2369
Epoch: 2/20... Step: 1460... Loss: 3.2204... Val Loss: 3.2171
Epoch: 2/20... Step: 1470... Loss: 3.1977... Val Loss: 3.2094
Epoch: 2/20... Step: 1480... Loss: 3.2117... Val Loss: 3.1978
Epoch: 2/20... Step: 1490... Loss: 3.1916... Val Loss: 3.2054
Epoch: 2

Epoch: 3/20... Step: 2670... Loss: 3.2206... Val Loss: 3.1819
Epoch: 3/20... Step: 2680... Loss: 3.1932... Val Loss: 3.2267
Epoch: 3/20... Step: 2690... Loss: 3.1955... Val Loss: 3.2146
Epoch: 3/20... Step: 2700... Loss: 3.2222... Val Loss: 3.2185
Epoch: 3/20... Step: 2710... Loss: 3.2137... Val Loss: 3.1670
Epoch: 3/20... Step: 2720... Loss: 3.1868... Val Loss: 3.2180
Epoch: 3/20... Step: 2730... Loss: 3.2060... Val Loss: 3.1939
Epoch: 3/20... Step: 2740... Loss: 3.1949... Val Loss: 3.2022
Epoch: 3/20... Step: 2750... Loss: 3.1975... Val Loss: 3.1790
Epoch: 3/20... Step: 2760... Loss: 3.1976... Val Loss: 3.1940
Epoch: 3/20... Step: 2770... Loss: 3.1961... Val Loss: 3.1886
Epoch: 3/20... Step: 2780... Loss: 3.2222... Val Loss: 3.2187
Epoch: 3/20... Step: 2790... Loss: 3.2085... Val Loss: 3.2115
Epoch: 3/20... Step: 2800... Loss: 3.1834... Val Loss: 3.1972
Epoch: 3/20... Step: 2810... Loss: 3.2136... Val Loss: 3.2123
Epoch: 3/20... Step: 2820... Loss: 3.1860... Val Loss: 3.2116
Epoch: 3

Epoch: 5/20... Step: 4000... Loss: 3.1836... Val Loss: 3.2109
Epoch: 5/20... Step: 4010... Loss: 3.2098... Val Loss: 3.2103
Epoch: 5/20... Step: 4020... Loss: 3.1965... Val Loss: 3.1995
Epoch: 5/20... Step: 4030... Loss: 3.2273... Val Loss: 3.1964
Epoch: 5/20... Step: 4040... Loss: 3.1976... Val Loss: 3.2015
Epoch: 5/20... Step: 4050... Loss: 3.2078... Val Loss: 3.2054
Epoch: 5/20... Step: 4060... Loss: 3.1705... Val Loss: 3.2354
Epoch: 5/20... Step: 4070... Loss: 3.1958... Val Loss: 3.2459
Epoch: 5/20... Step: 4080... Loss: 3.2040... Val Loss: 3.2043
Epoch: 5/20... Step: 4090... Loss: 3.2063... Val Loss: 3.1964
Epoch: 5/20... Step: 4100... Loss: 3.2044... Val Loss: 3.1956
Epoch: 5/20... Step: 4110... Loss: 3.2022... Val Loss: 3.2032
Epoch: 5/20... Step: 4120... Loss: 3.2135... Val Loss: 3.2043
Epoch: 5/20... Step: 4130... Loss: 3.1934... Val Loss: 3.2277
Epoch: 5/20... Step: 4140... Loss: 3.2090... Val Loss: 3.2017
Epoch: 5/20... Step: 4150... Loss: 3.2140... Val Loss: 3.2118
Epoch: 5

Epoch: 6/20... Step: 5330... Loss: 3.2370... Val Loss: 3.2101
Epoch: 6/20... Step: 5340... Loss: 3.2144... Val Loss: 3.2188
Epoch: 6/20... Step: 5350... Loss: 3.2019... Val Loss: 3.2117
Epoch: 6/20... Step: 5360... Loss: 3.2139... Val Loss: 3.1916
Epoch: 6/20... Step: 5370... Loss: 3.1907... Val Loss: 3.2166
Epoch: 6/20... Step: 5380... Loss: 3.2050... Val Loss: 3.2127
Epoch: 6/20... Step: 5390... Loss: 3.2102... Val Loss: 3.1816
Epoch: 6/20... Step: 5400... Loss: 3.1951... Val Loss: 3.2009
Epoch: 6/20... Step: 5410... Loss: 3.2053... Val Loss: 3.1871
Epoch: 6/20... Step: 5420... Loss: 3.1924... Val Loss: 3.2070
Epoch: 6/20... Step: 5430... Loss: 3.2264... Val Loss: 3.2322
Epoch: 6/20... Step: 5440... Loss: 3.2148... Val Loss: 3.1944
Epoch: 6/20... Step: 5450... Loss: 3.2125... Val Loss: 3.2011
Epoch: 6/20... Step: 5460... Loss: 3.2055... Val Loss: 3.1988
Epoch: 6/20... Step: 5470... Loss: 3.1857... Val Loss: 3.1909
Epoch: 6/20... Step: 5480... Loss: 3.2221... Val Loss: 3.2010
Epoch: 6

Epoch: 7/20... Step: 6660... Loss: 3.2020... Val Loss: 3.2058
Epoch: 7/20... Step: 6670... Loss: 3.2198... Val Loss: 3.2128
Epoch: 7/20... Step: 6680... Loss: 3.2130... Val Loss: 3.2076
Epoch: 7/20... Step: 6690... Loss: 3.2252... Val Loss: 3.2066
Epoch: 7/20... Step: 6700... Loss: 3.2235... Val Loss: 3.2151
Epoch: 7/20... Step: 6710... Loss: 3.1834... Val Loss: 3.2157
Epoch: 7/20... Step: 6720... Loss: 3.2095... Val Loss: 3.1748
Epoch: 7/20... Step: 6730... Loss: 3.1842... Val Loss: 3.2020
Epoch: 7/20... Step: 6740... Loss: 3.1906... Val Loss: 3.2043
Epoch: 7/20... Step: 6750... Loss: 3.2130... Val Loss: 3.1963
Epoch: 7/20... Step: 6760... Loss: 3.1886... Val Loss: 3.2039
Epoch: 7/20... Step: 6770... Loss: 3.1786... Val Loss: 3.1981
Epoch: 7/20... Step: 6780... Loss: 3.2089... Val Loss: 3.2070
Epoch: 7/20... Step: 6790... Loss: 3.2410... Val Loss: 3.1951
Epoch: 7/20... Step: 6800... Loss: 3.1925... Val Loss: 3.2038
Epoch: 7/20... Step: 6810... Loss: 3.2127... Val Loss: 3.1956
Epoch: 7

Epoch: 8/20... Step: 7990... Loss: 3.2123... Val Loss: 3.2104
Epoch: 9/20... Step: 8000... Loss: 3.2106... Val Loss: 3.2043
Epoch: 9/20... Step: 8010... Loss: 3.1905... Val Loss: 3.2252
Epoch: 9/20... Step: 8020... Loss: 3.2077... Val Loss: 3.1911
Epoch: 9/20... Step: 8030... Loss: 3.1883... Val Loss: 3.1775
Epoch: 9/20... Step: 8040... Loss: 3.2236... Val Loss: 3.2286
Epoch: 9/20... Step: 8050... Loss: 3.2130... Val Loss: 3.2156
Epoch: 9/20... Step: 8060... Loss: 3.1889... Val Loss: 3.2032
Epoch: 9/20... Step: 8070... Loss: 3.1824... Val Loss: 3.1814
Epoch: 9/20... Step: 8080... Loss: 3.2225... Val Loss: 3.2182
Epoch: 9/20... Step: 8090... Loss: 3.2049... Val Loss: 3.2046
Epoch: 9/20... Step: 8100... Loss: 3.2100... Val Loss: 3.1818
Epoch: 9/20... Step: 8110... Loss: 3.1934... Val Loss: 3.2125
Epoch: 9/20... Step: 8120... Loss: 3.2427... Val Loss: 3.2198
Epoch: 9/20... Step: 8130... Loss: 3.1946... Val Loss: 3.2149
Epoch: 9/20... Step: 8140... Loss: 3.1823... Val Loss: 3.2055
Epoch: 9

Epoch: 10/20... Step: 9310... Loss: 3.2251... Val Loss: 3.2102
Epoch: 10/20... Step: 9320... Loss: 3.2128... Val Loss: 3.2032
Epoch: 10/20... Step: 9330... Loss: 3.2043... Val Loss: 3.2033
Epoch: 10/20... Step: 9340... Loss: 3.2089... Val Loss: 3.2135
Epoch: 10/20... Step: 9350... Loss: 3.2005... Val Loss: 3.2052
Epoch: 10/20... Step: 9360... Loss: 3.1990... Val Loss: 3.2094
Epoch: 10/20... Step: 9370... Loss: 3.1973... Val Loss: 3.2021
Epoch: 10/20... Step: 9380... Loss: 3.2073... Val Loss: 3.1924
Epoch: 10/20... Step: 9390... Loss: 3.2073... Val Loss: 3.1939
Epoch: 10/20... Step: 9400... Loss: 3.2026... Val Loss: 3.2164
Epoch: 10/20... Step: 9410... Loss: 3.1974... Val Loss: 3.2043
Epoch: 10/20... Step: 9420... Loss: 3.2314... Val Loss: 3.2089
Epoch: 10/20... Step: 9430... Loss: 3.2003... Val Loss: 3.1634
Epoch: 10/20... Step: 9440... Loss: 3.2232... Val Loss: 3.2202
Epoch: 10/20... Step: 9450... Loss: 3.2165... Val Loss: 3.1959
Epoch: 10/20... Step: 9460... Loss: 3.2281... Val Loss:

Epoch: 11/20... Step: 10610... Loss: 3.1917... Val Loss: 3.1964
Epoch: 11/20... Step: 10620... Loss: 3.2216... Val Loss: 3.1865
Epoch: 11/20... Step: 10630... Loss: 3.2169... Val Loss: 3.2175
Epoch: 11/20... Step: 10640... Loss: 3.2047... Val Loss: 3.1714
Epoch: 11/20... Step: 10650... Loss: 3.2031... Val Loss: 3.1968
Epoch: 11/20... Step: 10660... Loss: 3.1995... Val Loss: 3.2279
Epoch: 11/20... Step: 10670... Loss: 3.1880... Val Loss: 3.1916
Epoch: 11/20... Step: 10680... Loss: 3.1798... Val Loss: 3.2077
Epoch: 11/20... Step: 10690... Loss: 3.2016... Val Loss: 3.2168
Epoch: 11/20... Step: 10700... Loss: 3.1946... Val Loss: 3.2167
Epoch: 11/20... Step: 10710... Loss: 3.2001... Val Loss: 3.2209
Epoch: 11/20... Step: 10720... Loss: 3.2067... Val Loss: 3.2302
Epoch: 11/20... Step: 10730... Loss: 3.1813... Val Loss: 3.2027
Epoch: 11/20... Step: 10740... Loss: 3.2184... Val Loss: 3.1896
Epoch: 11/20... Step: 10750... Loss: 3.2036... Val Loss: 3.2091
Epoch: 11/20... Step: 10760... Loss: 3.2

Epoch: 12/20... Step: 11900... Loss: 3.1952... Val Loss: 3.2273
Epoch: 12/20... Step: 11910... Loss: 3.2375... Val Loss: 3.1992
Epoch: 12/20... Step: 11920... Loss: 3.2002... Val Loss: 3.1952
Epoch: 12/20... Step: 11930... Loss: 3.1971... Val Loss: 3.1910
Epoch: 12/20... Step: 11940... Loss: 3.2070... Val Loss: 3.1817
Epoch: 12/20... Step: 11950... Loss: 3.1866... Val Loss: 3.1818
Epoch: 12/20... Step: 11960... Loss: 3.1798... Val Loss: 3.2010
Epoch: 12/20... Step: 11970... Loss: 3.2253... Val Loss: 3.1849
Epoch: 12/20... Step: 11980... Loss: 3.2164... Val Loss: 3.2422
Epoch: 12/20... Step: 11990... Loss: 3.1917... Val Loss: 3.2243
Epoch: 13/20... Step: 12000... Loss: 3.2325... Val Loss: 3.1922
Epoch: 13/20... Step: 12010... Loss: 3.2182... Val Loss: 3.1981
Epoch: 13/20... Step: 12020... Loss: 3.1949... Val Loss: 3.2013
Epoch: 13/20... Step: 12030... Loss: 3.1876... Val Loss: 3.1910
Epoch: 13/20... Step: 12040... Loss: 3.2151... Val Loss: 3.2203
Epoch: 13/20... Step: 12050... Loss: 3.2

Epoch: 14/20... Step: 13190... Loss: 3.1986... Val Loss: 3.1905
Epoch: 14/20... Step: 13200... Loss: 3.1901... Val Loss: 3.2126
Epoch: 14/20... Step: 13210... Loss: 3.2078... Val Loss: 3.1993
Epoch: 14/20... Step: 13220... Loss: 3.1953... Val Loss: 3.1965
Epoch: 14/20... Step: 13230... Loss: 3.2197... Val Loss: 3.2178
Epoch: 14/20... Step: 13240... Loss: 3.2194... Val Loss: 3.2188
Epoch: 14/20... Step: 13250... Loss: 3.2011... Val Loss: 3.2124
Epoch: 14/20... Step: 13260... Loss: 3.2166... Val Loss: 3.1961
Epoch: 14/20... Step: 13270... Loss: 3.1972... Val Loss: 3.1962
Epoch: 14/20... Step: 13280... Loss: 3.2231... Val Loss: 3.2072
Epoch: 14/20... Step: 13290... Loss: 3.1839... Val Loss: 3.1867
Epoch: 14/20... Step: 13300... Loss: 3.2059... Val Loss: 3.2025
Epoch: 14/20... Step: 13310... Loss: 3.2366... Val Loss: 3.2104
Epoch: 14/20... Step: 13320... Loss: 3.2239... Val Loss: 3.1950
Epoch: 14/20... Step: 13330... Loss: 3.2088... Val Loss: 3.2166
Epoch: 14/20... Step: 13340... Loss: 3.2

Epoch: 15/20... Step: 14480... Loss: 3.2003... Val Loss: 3.1967
Epoch: 15/20... Step: 14490... Loss: 3.1685... Val Loss: 3.2152
Epoch: 15/20... Step: 14500... Loss: 3.2075... Val Loss: 3.2083
Epoch: 15/20... Step: 14510... Loss: 3.2273... Val Loss: 3.1975
Epoch: 15/20... Step: 14520... Loss: 3.1982... Val Loss: 3.2063
Epoch: 15/20... Step: 14530... Loss: 3.1788... Val Loss: 3.2239
Epoch: 15/20... Step: 14540... Loss: 3.2230... Val Loss: 3.1894
Epoch: 15/20... Step: 14550... Loss: 3.1834... Val Loss: 3.1816
Epoch: 15/20... Step: 14560... Loss: 3.1690... Val Loss: 3.1911
Epoch: 15/20... Step: 14570... Loss: 3.2018... Val Loss: 3.2165
Epoch: 15/20... Step: 14580... Loss: 3.2449... Val Loss: 3.2094
Epoch: 15/20... Step: 14590... Loss: 3.2293... Val Loss: 3.1988
Epoch: 15/20... Step: 14600... Loss: 3.2173... Val Loss: 3.1755
Epoch: 15/20... Step: 14610... Loss: 3.2129... Val Loss: 3.1879
Epoch: 15/20... Step: 14620... Loss: 3.2057... Val Loss: 3.1912
Epoch: 15/20... Step: 14630... Loss: 3.2

Epoch: 16/20... Step: 15770... Loss: 3.2223... Val Loss: 3.1850
Epoch: 16/20... Step: 15780... Loss: 3.2133... Val Loss: 3.1858
Epoch: 16/20... Step: 15790... Loss: 3.2165... Val Loss: 3.1843
Epoch: 16/20... Step: 15800... Loss: 3.2101... Val Loss: 3.2279
Epoch: 16/20... Step: 15810... Loss: 3.1998... Val Loss: 3.1936
Epoch: 16/20... Step: 15820... Loss: 3.2018... Val Loss: 3.2065
Epoch: 16/20... Step: 15830... Loss: 3.2152... Val Loss: 3.2042
Epoch: 16/20... Step: 15840... Loss: 3.1925... Val Loss: 3.2145
Epoch: 16/20... Step: 15850... Loss: 3.2050... Val Loss: 3.2181
Epoch: 16/20... Step: 15860... Loss: 3.2319... Val Loss: 3.1911
Epoch: 16/20... Step: 15870... Loss: 3.2080... Val Loss: 3.2010
Epoch: 16/20... Step: 15880... Loss: 3.1824... Val Loss: 3.2087
Epoch: 16/20... Step: 15890... Loss: 3.2134... Val Loss: 3.2046
Epoch: 16/20... Step: 15900... Loss: 3.1856... Val Loss: 3.2167
Epoch: 16/20... Step: 15910... Loss: 3.2031... Val Loss: 3.2107
Epoch: 16/20... Step: 15920... Loss: 3.2

Epoch: 18/20... Step: 17060... Loss: 3.2319... Val Loss: 3.2093
Epoch: 18/20... Step: 17070... Loss: 3.2012... Val Loss: 3.2143
Epoch: 18/20... Step: 17080... Loss: 3.2153... Val Loss: 3.1848
Epoch: 18/20... Step: 17090... Loss: 3.2069... Val Loss: 3.2081
Epoch: 18/20... Step: 17100... Loss: 3.2040... Val Loss: 3.1919
Epoch: 18/20... Step: 17110... Loss: 3.2182... Val Loss: 3.2127
Epoch: 18/20... Step: 17120... Loss: 3.2141... Val Loss: 3.1815
Epoch: 18/20... Step: 17130... Loss: 3.2064... Val Loss: 3.1969
Epoch: 18/20... Step: 17140... Loss: 3.2169... Val Loss: 3.1996
Epoch: 18/20... Step: 17150... Loss: 3.2129... Val Loss: 3.2053
Epoch: 18/20... Step: 17160... Loss: 3.2232... Val Loss: 3.1940
Epoch: 18/20... Step: 17170... Loss: 3.2178... Val Loss: 3.1945
Epoch: 18/20... Step: 17180... Loss: 3.2072... Val Loss: 3.1989
Epoch: 18/20... Step: 17190... Loss: 3.1915... Val Loss: 3.1707
Epoch: 18/20... Step: 17200... Loss: 3.2012... Val Loss: 3.2017
Epoch: 18/20... Step: 17210... Loss: 3.1

Epoch: 19/20... Step: 18350... Loss: 3.1909... Val Loss: 3.1908
Epoch: 19/20... Step: 18360... Loss: 3.1834... Val Loss: 3.1843
Epoch: 19/20... Step: 18370... Loss: 3.1752... Val Loss: 3.1805
Epoch: 19/20... Step: 18380... Loss: 3.2094... Val Loss: 3.2100
Epoch: 19/20... Step: 18390... Loss: 3.1929... Val Loss: 3.1989
Epoch: 19/20... Step: 18400... Loss: 3.2052... Val Loss: 3.2116
Epoch: 19/20... Step: 18410... Loss: 3.2096... Val Loss: 3.2033
Epoch: 19/20... Step: 18420... Loss: 3.2333... Val Loss: 3.2068
Epoch: 19/20... Step: 18430... Loss: 3.2134... Val Loss: 3.1964
Epoch: 19/20... Step: 18440... Loss: 3.2091... Val Loss: 3.2227
Epoch: 19/20... Step: 18450... Loss: 3.1980... Val Loss: 3.2135
Epoch: 19/20... Step: 18460... Loss: 3.2222... Val Loss: 3.1932
Epoch: 19/20... Step: 18470... Loss: 3.2145... Val Loss: 3.2195
Epoch: 19/20... Step: 18480... Loss: 3.2185... Val Loss: 3.2003
Epoch: 19/20... Step: 18490... Loss: 3.1896... Val Loss: 3.2285
Epoch: 19/20... Step: 18500... Loss: 3.2

Epoch: 20/20... Step: 19640... Loss: 3.1992... Val Loss: 3.1887
Epoch: 20/20... Step: 19650... Loss: 3.2180... Val Loss: 3.1894
Epoch: 20/20... Step: 19660... Loss: 3.2195... Val Loss: 3.2012
Epoch: 20/20... Step: 19670... Loss: 3.1972... Val Loss: 3.1962
Epoch: 20/20... Step: 19680... Loss: 3.1922... Val Loss: 3.1865
Epoch: 20/20... Step: 19690... Loss: 3.1803... Val Loss: 3.1849
Epoch: 20/20... Step: 19700... Loss: 3.1884... Val Loss: 3.1970
Epoch: 20/20... Step: 19710... Loss: 3.1728... Val Loss: 3.2211
Epoch: 20/20... Step: 19720... Loss: 3.2001... Val Loss: 3.2091
Epoch: 20/20... Step: 19730... Loss: 3.1895... Val Loss: 3.2030
Epoch: 20/20... Step: 19740... Loss: 3.1911... Val Loss: 3.2103
Epoch: 20/20... Step: 19750... Loss: 3.2128... Val Loss: 3.1875
Epoch: 20/20... Step: 19760... Loss: 3.2036... Val Loss: 3.2084
Epoch: 20/20... Step: 19770... Loss: 3.2027... Val Loss: 3.2015
Epoch: 20/20... Step: 19780... Loss: 3.2131... Val Loss: 3.2095
Epoch: 20/20... Step: 19790... Loss: 3.1

In [11]:
torch.save(obj=net.state_dict(), f='output_net.pth')  # weights only, no architecture

In [None]:
# print(sample(net, 400, prime='хехехе ', temperature = 0.4, top_k=7))

In [None]:
# n_hidden=1024
# n_layers=5

# net = CharRNN(len(alphabet),n_hidden, n_layers).cuda()
# print(net)

In [None]:
# train(
#     net,
#     full_train=True,
#     train_data=train_dataloader,
#     val_data=validation_dataloader,
#     epochs=50,
#     batches_per_epoch = 1000,
#     batch_size=batch_size,
#     seq_length=seq_length,
#     lr=0.001,
#     print_every=10,
# )

In [None]:
torch.save(obj=net.state_dict(), f='anek_rnn_shtirliz_1.pth')  # weights only, no architecture

In [None]:
# checkpoint = {'n_hidden': net.n_hidden,
#               'n_layers': net.n_layers,
#               'state_dict': net.state_dict(),
#               'tokens': net.chars}

# with open('rnn_poetry_10_epoch', 'wb') as f:
#     torch.save(checkpoint, f)

In [None]:
print(sample(net, size=600, prime='Почему в России', temperature=0.7, top_k=7))