In [2]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
from tqdm import tqdm

# load ascii text and convert to lowercase
filename = "wonderland.txt"
raw_text = open(filename, 'r', encoding='utf-8').read()
raw_text = raw_text.lower()

# create mapping of unique chars to integers
chars = sorted(list(set(raw_text)))
char_to_int = dict((c, i) for i, c in enumerate(chars))

# summarize the loaded data
n_chars = len(raw_text)
n_vocab = len(chars)
print("Total Characters: ", n_chars)
print("Total Vocab: ", n_vocab)

# prepare the dataset of input to output pairs encoded as integers
seq_length = 100
dataX = []
dataY = []
for i in range(0, n_chars - seq_length, 1):
    seq_in = raw_text[i:i + seq_length]
    seq_out = raw_text[i + seq_length]
    dataX.append([char_to_int[char] for char in seq_in])
    dataY.append(char_to_int[seq_out])
n_patterns = len(dataX)
print("Total Patterns: ", n_patterns)

# reshape X to be [samples, time steps, features]
X = torch.tensor(dataX, dtype=torch.float32).reshape(n_patterns, seq_length, 1)
X = X / float(n_vocab)
y = torch.tensor(dataY)

class CharModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(input_size=1, hidden_size=256, num_layers=1)
        self.dropout = nn.Dropout(0.2)
        self.linear = nn.Linear(256, n_vocab)
    def forward(self, x):
        x, _ = self.lstm(x)
        # take only the last output
        x = x[:, -1, :]
        # produce output
        x = self.linear(self.dropout(x))
        return x

n_epochs = 50
batch_size = 256 #128
model = CharModel()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

optimizer = optim.Adam(model.parameters())
loss_fn = nn.CrossEntropyLoss(reduction="sum")
loader = data.DataLoader(data.TensorDataset(X, y), shuffle=True, batch_size=batch_size)

best_model = None
best_loss = np.inf
for epoch in range(n_epochs):
    model.train()
    with tqdm(total=len(loader), ncols=80, desc=f"Epoch {epoch+1}/{n_epochs}") as pbar:
        for X_batch, y_batch in loader:
            y_pred = model(X_batch.to(device))
            loss = loss_fn(y_pred, y_batch.to(device))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            pbar.set_postfix(loss=loss.item())
            pbar.update()
    # Validation
    model.eval()
    loss = 0
    with torch.no_grad():
        for X_batch, y_batch in loader:
            y_pred = model(X_batch.to(device))
            loss += loss_fn(y_pred, y_batch.to(device))
        if loss < best_loss:
            best_loss = loss
            best_model = model.state_dict()
        print("Epoch %d: Cross-entropy: %.4f" % (epoch, loss))

torch.save([best_model, char_to_int], "single-char.pth")

# Generation using the trained model
best_model, char_to_int = torch.load("single-char.pth")
n_vocab = len(char_to_int)
int_to_char = dict((i, c) for c, i in char_to_int.items())
model.load_state_dict(best_model)

# randomly generate a prompt
filename = "wonderland.txt"
seq_length = 100
raw_text = open(filename, 'r', encoding='utf-8').read()
raw_text = raw_text.lower()
start = np.random.randint(0, len(raw_text)-seq_length)
prompt = raw_text[start:start+seq_length]
pattern = [char_to_int[c] for c in prompt]

model.eval()
print('Prompt: "%s"' % prompt)
with torch.no_grad():
    for i in range(1000):
        # format input array of int into PyTorch tensor
        x = np.reshape(pattern, (1, len(pattern), 1)) / float(n_vocab)
        x = torch.tensor(x, dtype=torch.float32)
        # generate logits as output from the model
        prediction = model(x.to(device))
        # convert logits into one character
        index = int(prediction.argmax())
        result = int_to_char[index]
        print(result, end="")
        # append the new character into the prompt for the next iteration
        pattern.append(index)
        pattern = pattern[1:]
print()
print("Done.")


Total Characters:  144683
Total Vocab:  49
Total Patterns:  144583


Epoch 1/50: 100%|███████████████████| 565/565 [00:12<00:00, 43.60it/s, loss=595]


Epoch 0: Cross-entropy: 426854.7812


Epoch 2/50: 100%|███████████████████| 565/565 [00:12<00:00, 43.96it/s, loss=572]


Epoch 1: Cross-entropy: 403909.2812


Epoch 3/50: 100%|███████████████████| 565/565 [00:13<00:00, 43.45it/s, loss=562]


Epoch 2: Cross-entropy: 391075.8750


Epoch 4/50: 100%|███████████████████| 565/565 [00:13<00:00, 42.62it/s, loss=528]


Epoch 3: Cross-entropy: 382121.9375


Epoch 5/50: 100%|███████████████████| 565/565 [00:14<00:00, 39.99it/s, loss=509]


Epoch 4: Cross-entropy: 374673.6562


Epoch 6/50: 100%|███████████████████| 565/565 [00:13<00:00, 41.59it/s, loss=535]


Epoch 5: Cross-entropy: 365765.0625


Epoch 7/50: 100%|███████████████████| 565/565 [00:13<00:00, 42.25it/s, loss=505]


Epoch 6: Cross-entropy: 358826.1875


Epoch 8/50: 100%|███████████████████| 565/565 [00:13<00:00, 42.64it/s, loss=476]


Epoch 7: Cross-entropy: 350310.7500


Epoch 9/50: 100%|███████████████████| 565/565 [00:13<00:00, 41.87it/s, loss=474]


Epoch 8: Cross-entropy: 342942.5625


Epoch 10/50: 100%|██████████████████| 565/565 [00:13<00:00, 41.67it/s, loss=490]


Epoch 9: Cross-entropy: 336938.5938


Epoch 11/50: 100%|██████████████████| 565/565 [00:13<00:00, 42.55it/s, loss=457]


Epoch 10: Cross-entropy: 330695.5938


Epoch 12/50: 100%|██████████████████| 565/565 [00:13<00:00, 42.56it/s, loss=475]


Epoch 11: Cross-entropy: 325210.2812


Epoch 13/50: 100%|██████████████████| 565/565 [00:13<00:00, 41.99it/s, loss=436]


Epoch 12: Cross-entropy: 318751.3438


Epoch 14/50: 100%|██████████████████| 565/565 [00:13<00:00, 40.97it/s, loss=432]


Epoch 13: Cross-entropy: 314261.1250


Epoch 15/50: 100%|██████████████████| 565/565 [00:13<00:00, 42.01it/s, loss=437]


Epoch 14: Cross-entropy: 308449.5312


Epoch 16/50: 100%|██████████████████| 565/565 [00:13<00:00, 41.67it/s, loss=420]


Epoch 15: Cross-entropy: 302686.6562


Epoch 17/50: 100%|██████████████████| 565/565 [00:13<00:00, 41.38it/s, loss=430]


Epoch 16: Cross-entropy: 298673.0312


Epoch 18/50: 100%|██████████████████| 565/565 [00:13<00:00, 41.80it/s, loss=425]


Epoch 17: Cross-entropy: 293267.4688


Epoch 19/50: 100%|██████████████████| 565/565 [00:13<00:00, 41.59it/s, loss=432]


Epoch 18: Cross-entropy: 289345.9062


Epoch 20/50: 100%|██████████████████| 565/565 [00:13<00:00, 41.78it/s, loss=415]


Epoch 19: Cross-entropy: 284572.5938


Epoch 21/50: 100%|██████████████████| 565/565 [00:13<00:00, 40.60it/s, loss=377]


Epoch 20: Cross-entropy: 279087.2812


Epoch 22/50: 100%|██████████████████| 565/565 [00:13<00:00, 42.16it/s, loss=444]


Epoch 21: Cross-entropy: 276465.8750


Epoch 23/50: 100%|██████████████████| 565/565 [00:13<00:00, 42.02it/s, loss=409]


Epoch 22: Cross-entropy: 272151.0000


Epoch 24/50: 100%|██████████████████| 565/565 [00:13<00:00, 41.98it/s, loss=412]


Epoch 23: Cross-entropy: 268422.8438


Epoch 25/50: 100%|██████████████████| 565/565 [00:13<00:00, 41.30it/s, loss=403]


Epoch 24: Cross-entropy: 264836.5625


Epoch 26/50: 100%|██████████████████| 565/565 [00:13<00:00, 41.96it/s, loss=408]


Epoch 25: Cross-entropy: 263329.4688


Epoch 27/50: 100%|██████████████████| 565/565 [00:13<00:00, 41.99it/s, loss=373]


Epoch 26: Cross-entropy: 257346.4688


Epoch 28/50: 100%|██████████████████| 565/565 [00:13<00:00, 42.37it/s, loss=374]


Epoch 27: Cross-entropy: 254523.4688


Epoch 29/50: 100%|██████████████████| 565/565 [00:13<00:00, 41.72it/s, loss=369]


Epoch 28: Cross-entropy: 252341.7656


Epoch 30/50: 100%|██████████████████| 565/565 [00:13<00:00, 42.14it/s, loss=395]


Epoch 29: Cross-entropy: 248639.6719


Epoch 31/50: 100%|██████████████████| 565/565 [00:13<00:00, 42.47it/s, loss=342]


Epoch 30: Cross-entropy: 247495.7344


Epoch 32/50: 100%|██████████████████| 565/565 [00:13<00:00, 41.83it/s, loss=363]


Epoch 31: Cross-entropy: 244068.3438


Epoch 33/50: 100%|██████████████████| 565/565 [00:13<00:00, 41.64it/s, loss=379]


Epoch 32: Cross-entropy: 242632.9531


Epoch 34/50: 100%|██████████████████| 565/565 [00:13<00:00, 42.37it/s, loss=354]


Epoch 33: Cross-entropy: 237857.7344


Epoch 35/50: 100%|██████████████████| 565/565 [00:13<00:00, 42.18it/s, loss=391]


Epoch 34: Cross-entropy: 236458.7031


Epoch 36/50: 100%|██████████████████| 565/565 [00:13<00:00, 41.24it/s, loss=326]


Epoch 35: Cross-entropy: 235337.8750


Epoch 37/50: 100%|██████████████████| 565/565 [00:13<00:00, 40.95it/s, loss=326]


Epoch 36: Cross-entropy: 234763.7656


Epoch 38/50: 100%|██████████████████| 565/565 [00:13<00:00, 42.18it/s, loss=342]


Epoch 37: Cross-entropy: 230894.2656


Epoch 39/50: 100%|██████████████████| 565/565 [00:13<00:00, 42.25it/s, loss=399]


Epoch 38: Cross-entropy: 229638.5156


Epoch 40/50: 100%|██████████████████| 565/565 [00:13<00:00, 41.97it/s, loss=360]


Epoch 39: Cross-entropy: 226715.9844


Epoch 41/50: 100%|██████████████████| 565/565 [00:13<00:00, 41.87it/s, loss=368]


Epoch 40: Cross-entropy: 224310.3281


Epoch 42/50: 100%|██████████████████| 565/565 [00:13<00:00, 42.29it/s, loss=321]


Epoch 41: Cross-entropy: 222936.1250


Epoch 43/50: 100%|██████████████████| 565/565 [00:13<00:00, 42.04it/s, loss=353]


Epoch 42: Cross-entropy: 221504.2656


Epoch 44/50: 100%|██████████████████| 565/565 [00:13<00:00, 41.67it/s, loss=369]


Epoch 43: Cross-entropy: 222718.9062


Epoch 45/50: 100%|██████████████████| 565/565 [00:13<00:00, 41.68it/s, loss=312]


Epoch 44: Cross-entropy: 220289.9531


Epoch 46/50: 100%|██████████████████| 565/565 [00:13<00:00, 41.89it/s, loss=323]


Epoch 45: Cross-entropy: 216972.5781


Epoch 47/50: 100%|██████████████████| 565/565 [00:13<00:00, 42.33it/s, loss=339]


Epoch 46: Cross-entropy: 216845.5000


Epoch 48/50: 100%|██████████████████| 565/565 [00:13<00:00, 41.48it/s, loss=286]


Epoch 47: Cross-entropy: 214628.2500


Epoch 49/50: 100%|██████████████████| 565/565 [00:13<00:00, 42.35it/s, loss=335]


Epoch 48: Cross-entropy: 211298.2031


Epoch 50/50: 100%|██████████████████| 565/565 [00:13<00:00, 42.41it/s, loss=304]


Epoch 49: Cross-entropy: 211132.0469
Prompt: "tes on their slates, and then added them up, and
reduced the answer to shillings and pence.

“take o"
t hotingte they sares,” the macch hare tooek te senting aotiously, “iot then the roeere of the soiat ” the said to herself, “it would be a genter, iewsh  be iore  a  arocoss make oe an the that sien ied th the ioeo hareer an iers,”

“i shanl that you toul to wour hott the dorcernonn,” said the mock turtle. 
“toe mooh a dutious toond?”,” said alice,
“io you doo’t keke the sooe of the doerest. io an solerfing arouoe_”

“hu wes i than tou doote,” she match hare in a soee aunroe ooeere to ce taid io a toie oong toye, and whin she wasted to tee it tu ho the wan ohe whsh she white rabbit, and sas atin inooing so ce tith tee hoare of the sohde  sie was a little thar soret in the looken to her ane, an the was to toenk the oockr oaddite the roees shat sas the winle gar ane the errphon sad ano thet was the winle gar oo a sirugls  shate tas aoi toeezi