In [1]:
import numpy as np
import torch
from tqdm import tqdm

In [2]:
filename = "wonderland.txt"
# Read the corpus inside a context manager so the file handle is closed
# as soon as the text is in memory (the bare open().read() left it open).
with open(filename, 'r', encoding='utf-8') as corpus_file:
    raw_text = corpus_file.read()
# Fold everything to lower case so upper/lower variants share one symbol.
raw_text = raw_text.lower()

In [3]:
# Distinct characters of the corpus in sorted order; a character's position
# in this list is its integer code.
chars = sorted(set(raw_text))
char_to_int = {symbol: code for code, symbol in enumerate(chars)}

In [4]:
# Corpus length in characters and number of distinct characters.
n_chars, n_vocab = len(raw_text), len(chars)
print("Total Characters: ", n_chars)
print("Total Vocab: ", n_vocab)

Total Characters:  164093
Total Vocab:  65


In [5]:
seq_length = 100
# Slide a window of seq_length characters over the text one position at a
# time: each window (integer-encoded) is an input, and the character that
# immediately follows it is the target.
window_starts = range(n_chars - seq_length)
dataX = [
    [char_to_int[symbol] for symbol in raw_text[begin:begin + seq_length]]
    for begin in window_starts
]
dataY = [char_to_int[raw_text[begin + seq_length]] for begin in window_starts]
n_patterns = len(dataX)
print("Total Patterns: ", n_patterns)

Total Patterns:  163993


In [37]:
# Inputs: float32 tensor shaped (n_patterns, seq_length, 1), with the integer
# codes divided by the vocabulary size.
X = torch.tensor(dataX, dtype=torch.float32)
X = X.reshape(n_patterns, seq_length, 1) / float(n_vocab)
# Targets: one integer class index per pattern.
y = torch.tensor(dataY)

In [38]:
# Sanity check: print the shapes of the input and target tensors.
print(X.shape, y.shape)
lookback = 1  # NOTE(review): not referenced by any other visible cell

torch.Size([163993, 100, 1]) torch.Size([163993])


In [43]:
import torch.nn as nn

class LSTMModel(nn.Module):
    """Single-layer LSTM followed by a linear layer producing per-class logits.

    Args:
        hidden_size: Width of the LSTM hidden state. Defaults to 100.
        vocab_size: Number of output classes. When omitted, the module-level
            ``n_vocab`` is read at construction time, so ``LSTMModel()``
            builds the same network as before.
    """
    def __init__(self, hidden_size=100, vocab_size=None):
        super().__init__()
        if vocab_size is None:
            # Default path: fall back to the notebook-level vocabulary size.
            vocab_size = n_vocab
        # Attribute names are kept so existing checkpoints still load.
        self.lstm = nn.LSTM(input_size=1, hidden_size=hidden_size, num_layers=1, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
    def forward(self, x):
        """Return logits computed from the LSTM output at the last time step.

        ``x`` is (batch, seq, 1): the LSTM is built with ``batch_first=True``
        and ``input_size=1``.
        """
        sequence_output, _ = self.lstm(x)
        # Only the final time step feeds the classifier.
        last_step = sequence_output[:, -1, :]
        return self.linear(last_step)

In [44]:
import numpy as np
import torch.optim as optim
import torch.utils.data as data

def my_collate(batch):
    """Stack a list of (input, target) samples into two batched tensors.

    Element 0 of every sample is stacked into the input batch and element 1
    into the target batch; the result is returned as the list ``[x, y]``.
    """
    stacked_inputs = torch.stack(tuple(sample[0] for sample in batch))
    stacked_targets = torch.stack(tuple(sample[1] for sample in batch))
    return [stacked_inputs, stacked_targets]

import copy

# Use the first CUDA device when there is one; otherwise run on the CPU
# instead of crashing on a hard-coded "cuda:0".
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = LSTMModel().to(device)
model.float()
optimizer = optim.Adam(model.parameters())
loss_fn = nn.CrossEntropyLoss()
# Keep the DataLoader itself re-iterable. A tqdm wrapper disables itself after
# its first full pass, so a fresh progress bar is created per training pass.
loader = data.DataLoader(data.TensorDataset(X, y), shuffle=True, batch_size=8)

best_model = None
best_loss = np.inf

n_epochs = 40
for epoch in range(n_epochs):
    # --- training pass ---
    model.train()
    for X_batch, y_batch in tqdm(loader, desc="Epoch %d" % epoch, leave=False):
        y_pred = model(X_batch.float().to(device))
        loss = loss_fn(y_pred, y_batch.long().to(device))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # --- evaluation pass ---
    # This re-scores the same (training) data; the reported number is the sum
    # of the per-batch mean cross-entropy values, not a held-out metric.
    model.eval()
    loss = 0
    with torch.no_grad():
        for X_batch, y_batch in loader:
            y_pred = model(X_batch.float().to(device))
            loss += loss_fn(y_pred, y_batch.long().to(device))
        loss = float(loss)
        if loss < best_loss:
            best_loss = loss
            # state_dict() returns references to the live parameters, which
            # keep changing as training continues; snapshot them so
            # best_model really holds the best epoch's weights.
            best_model = copy.deepcopy(model.state_dict())
        print("Epoch %d: Cross-entropy: %.4f" % (epoch, loss))

torch.save([best_model, char_to_int], "single-char2.pth")

100%|██████████| 20500/20500 [01:46<00:00, 192.90it/s]


Epoch 0: Cross-entropy: 55975.3789
Epoch 1: Cross-entropy: 53499.9961
Epoch 2: Cross-entropy: 51445.9766
Epoch 3: Cross-entropy: 49987.8477
Epoch 4: Cross-entropy: 48593.4180
Epoch 5: Cross-entropy: 47658.7539
Epoch 6: Cross-entropy: 46498.1055
Epoch 7: Cross-entropy: 45960.4023
Epoch 8: Cross-entropy: 45077.9453
Epoch 9: Cross-entropy: 44462.9609
Epoch 10: Cross-entropy: 44053.7812
Epoch 11: Cross-entropy: 43504.8633
Epoch 12: Cross-entropy: 43323.5625
Epoch 13: Cross-entropy: 42822.4805
Epoch 14: Cross-entropy: 42415.9219
Epoch 15: Cross-entropy: 42298.1172
Epoch 16: Cross-entropy: 41780.5938
Epoch 17: Cross-entropy: 41651.9453
Epoch 18: Cross-entropy: 41419.4219
Epoch 19: Cross-entropy: 41470.1211
Epoch 20: Cross-entropy: 41305.1914
Epoch 21: Cross-entropy: 40805.1250
Epoch 22: Cross-entropy: 40698.2734
Epoch 23: Cross-entropy: 40451.4531
Epoch 24: Cross-entropy: 40431.8242
Epoch 25: Cross-entropy: 40370.4219
Epoch 26: Cross-entropy: 40551.8320
Epoch 27: Cross-entropy: 39835.2109
Ep

In [49]:
seq_length = 100
# Draw a random start offset and take the seq_length characters that begin
# there as the prompt.
start = np.random.randint(0, len(raw_text) - seq_length)
prompt = raw_text[slice(start, start + seq_length)]

In [57]:
import numpy as np
import torch
import torch.nn as nn

# NOTE(review): the training cell saves its checkpoint as "single-char2.pth",
# while this cell loads "single-char.pth" - confirm which file is intended.
# map_location="cpu" lets a checkpoint saved from a CUDA run load on a
# CPU-only machine; the model below lives on the CPU either way.
best_model, char_to_int = torch.load("single-char.pth", map_location="cpu")
n_vocab = len(char_to_int)
int_to_char = dict((i, c) for c, i in char_to_int.items())

model = LSTMModel()
model.load_state_dict(best_model)

filename = "wonderland.txt"
seq_length = 100
# Context manager so the corpus file handle is closed after reading.
with open(filename, 'r', encoding='utf-8') as corpus_file:
    raw_text = corpus_file.read()
raw_text = raw_text.lower()
# Random offset into the corpus; not used below because the prompt is a
# fixed string.
start = np.random.randint(0, len(raw_text)-seq_length)
prompt = 'alice was at the beach and '
# Debug echo of the raw prompt with visible delimiters.
print("xd " + prompt + " dx ")
# Fail early with a readable message instead of a bare KeyError (unknown
# character) or an opaque LSTM error (empty sequence).
if not prompt:
    raise ValueError("prompt must contain at least one character")
unknown_chars = sorted(set(prompt) - set(char_to_int))
if unknown_chars:
    raise ValueError("prompt contains characters missing from the vocabulary: %r" % unknown_chars)
pattern = [char_to_int[c] for c in prompt]

model.eval()
print('Prompt: "%s"' % prompt)
with torch.no_grad():
    for i in range(1000):
        # Encode the current window the same way the training inputs were
        # built: float32 codes shaped (1, seq, 1), divided by the vocab size.
        x = torch.tensor(pattern, dtype=torch.float32).reshape(1, len(pattern), 1)
        x = x / float(n_vocab)
        # generate logits as output from the model
        prediction = model(x)
        # greedy decoding: take the highest-scoring character
        index = int(prediction.argmax())
        result = int_to_char[index]
        print(result, end="")
        # slide the window: add the new character, drop the oldest one, so
        # the window length stays equal to the prompt length
        pattern.append(index)
        pattern = pattern[1:]
print()
print("Done.")

xd alice was at the beach and  dx 
Prompt: "alice was at the beach and "
the white rabbit rea oo the white rabbit ceaone and the white rabbit ald the white rabbit ald the white rabbit ald the white rabbit ald the white rabbit ald the white rabbit ald the white rabbit ald the white rabbit ald the white rabbit ald the white rabbit ald the white rabbit ald the white rabbit ald the white rabbit ald the white rabbit ald the white rabbit ald the white rabbit ald the white rabbit ald the white rabbit ald the white rabbit ald the white rabbit ald the white rabbit ald the white rabbit ald the white rabbit ald the white rabbit ald the white rabbit ald the white rabbit ald the white rabbit ald the white rabbit ald the white rabbit ald the white rabbit ald the white rabbit ald the white rabbit ald the white rabbit ald the white rabbit ald the white rabbit ald the white rabbit ald the white rabbit ald the white rabbit ald the white rabbit ald the white rabbit ald the white rabbit ald the white rabb