In [1]:
import random
import urllib
import urllib.request
from random import randint

import torch
import torch.nn as nn
import torch.utils.data as data
from torch.optim import Adam

In [2]:
# Configuration: random seeds and compute device.
# Seeding makes weight initialisation, DataLoader shuffling and the `randint`
# choice of seed text reproducible under Restart & Run All (some CUDA kernels
# may still be nondeterministic).
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# Download the Nietzsche corpus (lower-cased) and build the character vocabulary.
DATA_URL = "https://s3.amazonaws.com/text-datasets/nietzsche.txt"

# `with` closes the HTTP response once it has been read. Note that
# `urllib.request` has to be imported explicitly; `import urllib` alone does
# not load the submodule.
with urllib.request.urlopen(DATA_URL) as response:
    all_text = response.read().decode("utf-8")
all_text = all_text.lower()

n_chars = len(all_text)

batch_size = 50
hidden_size = 256

# Length of the character window the model sees for each prediction.
num_past_characters = 100
assert n_chars > num_past_characters, "corpus is shorter than the context window"

# Map each distinct character to a stable integer index (sorted order).
chars = sorted(set(all_text))
chars_dict = {char: i for i, char in enumerate(chars)}

n_vocab = len(chars)
print("Total Characters: ", n_chars)
print("Total Vocab: ", n_vocab)

In [23]:
class GenerateTextLSTMModel(nn.Module):
    """Character-level LSTM that predicts the character following a window.

    The input is a (batch, seq_len, input_size) float tensor (``batch_first=True``).
    Only the LSTM output at the last time step is projected to the vocabulary,
    so ``forward`` returns (batch, vocab_size) unnormalised logits.

    Args:
        vocab_size: number of output classes. Defaults to the notebook-level
            ``n_vocab`` when omitted.
        hidden_dim: LSTM hidden size. Defaults to the notebook-level
            ``hidden_size`` when omitted.
        num_layers: number of stacked LSTM layers.
        input_size: features per time step (1 = one scaled character index).
    """

    def __init__(self, vocab_size=None, hidden_dim=None, num_layers=2, input_size=1):
        super().__init__()
        # Fall back to the notebook globals so `GenerateTextLSTMModel()` keeps working.
        if vocab_size is None:
            vocab_size = n_vocab
        if hidden_dim is None:
            hidden_dim = hidden_size

        self.lstm = nn.LSTM(input_size, hidden_dim, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        """Return next-character logits of shape (batch, vocab_size) for input ``x``."""
        x, _ = self.lstm(x)
        # Keep only the final time step: it summarises the whole window.
        x = x[:, -1, :]
        return self.linear(x)

In [4]:
# Build the network and place its parameters on the selected device.
# `nn.Module.to` moves the module in place and returns it, so it chains.
model = GenerateTextLSTMModel().to(device)
model

In [5]:
# Slide a window of `num_past_characters` over the encoded corpus: each window
# is one input pattern, and the character immediately after it is its target.
encoded_text = [chars_dict[ch] for ch in all_text]
n_windows = n_chars - num_past_characters

data_x = [encoded_text[start:start + num_past_characters] for start in range(n_windows)]
data_y = [encoded_text[start + num_past_characters] for start in range(n_windows)]

n_patterns = len(data_x)
print("Total Patterns: ", n_patterns)

In [36]:
# Inputs: (n_patterns, num_past_characters, 1) float32, each character index
# scaled by the vocabulary size into [0, 1).
x = (
    torch.tensor(data_x, dtype=torch.float32)
    .reshape(n_patterns, num_past_characters, 1)
    .div(float(n_vocab))
)
# Targets: integer class indices, as nn.CrossEntropyLoss expects.
y = torch.tensor(data_y)

'('

In [6]:
# Train the model with Adam on shuffled mini-batches.
num_epochs = 40

# Loss is summed (not averaged) over each batch.
loss_fn = nn.CrossEntropyLoss(reduction="sum")
optimizer = Adam(model.parameters())

loader = data.DataLoader(data.TensorDataset(x, y), shuffle=True, batch_size=batch_size)
n_samples = len(loader.dataset)

for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0

    for inputs, labels in loader:
        outputs = model(inputs.to(device))
        loss = loss_fn(outputs, labels.to(device))

        # Clear the previous batch's gradients before back-propagating this one.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    # The loss is summed, so dividing by the dataset size gives the mean
    # per-sample loss. Logging once per epoch keeps the output readable.
    print(f"epoch {epoch + 1}/{num_epochs}  mean loss {epoch_loss / n_samples:.4f}")

# Uncomment to checkpoint the trained weights together with the vocabulary:
# torch.save([model.state_dict(), chars_dict], "lstm 2.pth")

In [38]:
# Generate text: seed with a random window taken from the corpus, then
# repeatedly predict the next character from the last `num_past_characters`.

# To use a checkpoint saved by the training cell instead of the in-memory model,
# uncomment both lines (this must run before `int_dict` is built):
# best_model, chars_dict = torch.load("lstm 2.pth")
# model.load_state_dict(best_model)
int_dict = {i: c for c, i in chars_dict.items()}

requested_seq = 200
# None -> greedy argmax (deterministic). A positive float samples from the
# softmax at that temperature instead, which tends to avoid repetitive loops.
temperature = None

# `randint` is inclusive, so the largest start still leaves a full window.
random_index = randint(0, n_chars - num_past_characters)

text = all_text[random_index: random_index + num_past_characters]
print(text)
print('----')
model.eval()

with torch.no_grad():
    for _ in range(requested_seq):
        window = text[-num_past_characters:]

        # Encode exactly as for training: float32 index / vocab size,
        # shaped (1, window_length, 1).
        indices = torch.tensor([chars_dict[c] for c in window], dtype=torch.float32, device=device)
        current_data = (indices / float(n_vocab)).view(1, len(window), 1)

        logits = model(current_data)[0]
        if temperature is None:
            index = int(logits.argmax())
        else:
            probs = torch.softmax(logits / temperature, dim=-1)
            index = int(torch.multinomial(probs, num_samples=1))

        text += int_dict[index]

print(text)

torch.Size([15, 84]) torch.Size([15])
torch.Size([15, 84]) torch.Size([15])
torch.Size([15, 84]) torch.Size([15])
torch.Size([15, 84]) torch.Size([15])
torch.Size([15, 84]) torch.Size([15])
torch.Size([15, 84]) torch.Size([15])
torch.Size([15, 84]) torch.Size([15])
torch.Size([15, 84]) torch.Size([15])
torch.Size([15, 84]) torch.Size([15])
torch.Size([15, 84]) torch.Size([15])
torch.Size([15, 84]) torch.Size([15])
torch.Size([15, 84]) torch.Size([15])
torch.Size([15, 84]) torch.Size([15])
torch.Size([15, 84]) torch.Size([15])
torch.Size([15, 84]) torch.Size([15])
torch.Size([15, 84]) torch.Size([15])
torch.Size([15, 84]) torch.Size([15])
torch.Size([15, 84]) torch.Size([15])
torch.Size([15, 84]) torch.Size([15])
torch.Size([15, 84]) torch.Size([15])
torch.Size([15, 84]) torch.Size([15])
torch.Size([15, 84]) torch.Size([15])
torch.Size([15, 84]) torch.Size([15])
torch.Size([15, 84]) torch.Size([15])
torch.Size([15, 84]) torch.Size([15])
torch.Size([15, 84]) torch.Size([15])
torch.Size([

KeyboardInterrupt: 