In [None]:
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim

In [None]:
# Seed both RNGs used in this notebook: numpy draws the training snippet
# offsets and torch drives weight initialisation and character sampling.
# The trailing ';' hides the Generator repr returned by torch.manual_seed
# without a cell-wide %%capture, which would also swallow warnings/prints.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED);

In [None]:
# Vocabulary: every character the model can read or emit. Any other character
# in the corpus will fail the character -> index lookup further down.
CHARS = (
    "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
    + "!\"#$%&\'()*+,-./:;—<=>?@[\\]^_`{|}~ \t\n\r\x0b\x0c"
)
CORPUS_PATH = 'shakespeare.txt'

corpus = []
# CHARS includes a non-ASCII em dash, so decode explicitly instead of relying
# on the platform default encoding (not UTF-8 on e.g. Windows).
# NOTE(review): assumes the corpus file is UTF-8 encoded — confirm.
with open(CORPUS_PATH, 'r', encoding='utf-8') as f:
    for line in f:
        # strip() removes the line terminator plus leading/trailing whitespace;
        # a single '\n' is re-appended so every line break looks the same.
        corpus.extend(line.strip())
        corpus.append('\n')

print("Total number of characters:", len(corpus))
print("\n\n")
print("First 100 characters:\n")
print(corpus[:100])

In [None]:
# Two-way lookup tables between vocabulary characters and integer ids;
# the id of a character is its position in CHARS.
idx2char = dict(enumerate(CHARS))
char2idx = {char: idx for idx, char in idx2char.items()}

NUM_CHARS = len(char2idx)
print("Total number of distinct chars:", NUM_CHARS)

In [None]:
NUM_SNIPPETS = 2000     # number of random training windows
SIZE_OF_SNIPPET = 250   # window length; inputs/targets are SIZE_OF_SNIPPET - 1 long

# Fail with a readable message instead of a bare KeyError when the corpus
# contains characters outside the vocabulary.
unknown_chars = sorted(set(corpus).difference(char2idx))
if unknown_chars:
    raise ValueError(f"Corpus contains characters not in CHARS: {unknown_chars!r}")

corpus_with_indices = [char2idx[char] for char in corpus]

print("Corpus with indices:")
print(corpus_with_indices[:100])

if len(corpus_with_indices) < SIZE_OF_SNIPPET:
    raise ValueError(
        f"Corpus has {len(corpus_with_indices)} characters, "
        f"fewer than SIZE_OF_SNIPPET={SIZE_OF_SNIPPET}"
    )

# np.random.randint excludes its upper bound, so +1 makes the last
# full-length window (starting at len - SIZE_OF_SNIPPET) reachable.
max_start_exclusive = len(corpus_with_indices) - SIZE_OF_SNIPPET + 1

dataset = []
for _ in range(NUM_SNIPPETS):
    snippet_start = np.random.randint(0, max_start_exclusive)
    snippet = corpus_with_indices[snippet_start:snippet_start + SIZE_OF_SNIPPET]

    # Next-character prediction: the target sequence is the input shifted by one.
    dataset.append((
        torch.LongTensor(snippet[:-1]),
        torch.LongTensor(snippet[1:])
    ))

print("\nSize of dataset:", len(dataset))

# Both are (NUM_SNIPPETS, SIZE_OF_SNIPPET - 1) index tensors.
X = torch.stack([xy[0] for xy in dataset])
Y = torch.stack([xy[1] for xy in dataset])

In [None]:
class ShakespeareGenerator(nn.Module):
    """Character-level LSTM language model.

    Character indices are embedded, run through a single-layer LSTM and
    projected to NUM_CHARS logits per time step. The vocabulary globals
    NUM_CHARS, char2idx and idx2char are read from module scope.
    """

    def __init__(self, embedding_size, hidden_size):
        """Build the embedding, LSTM and output projection.

        Args:
            embedding_size: dimension of each character embedding.
            hidden_size: size of the LSTM hidden and cell states.
        """
        super().__init__()

        self.embedding_size = embedding_size
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(
            num_embeddings=NUM_CHARS,
            embedding_dim=self.embedding_size
        )
        # batch_first is left at its default (False): tensors are (seq, batch, features).
        self.lstm = nn.LSTM(
            input_size=self.embedding_size,
            hidden_size=self.hidden_size
        )
        self.linear = nn.Linear(
            in_features=self.hidden_size,
            out_features=NUM_CHARS
        )


    def forward(self, batched_inputs):
        """Run index sequences through the model starting from a zero state.

        Args:
            batched_inputs: LongTensor of character indices, shape (seq_len, batch).

        Returns:
            (logits, (h, c)): logits of shape (seq_len, batch, NUM_CHARS) and the
            final LSTM states, each of shape (1, batch, hidden_size).
        """
        seq_len = batched_inputs.shape[0]
        batch_size = batched_inputs.shape[1]
        h, c = self.get_initial_hc(batch_size)

        embeddings = self.embedding(batched_inputs)
        outputs, (h, c) = self.lstm(
                embeddings.reshape(seq_len, batch_size, self.embedding_size),
                (h, c)
        )
        # nn.Linear acts on the last dimension, so no squeeze is needed.
        # Squeezing would drop the batch or time axis whenever it has size 1.
        outputs = self.linear(outputs)

        return outputs, (h, c)


    def get_initial_hc(self, batch_size):
        """Return zero (h, c) states, each of shape (1, batch_size, hidden_size).

        The states use the device and dtype of the model's parameters so the
        model keeps working after being moved with ``.to(...)``.
        """
        reference = self.linear.weight
        shape = (1, batch_size, self.hidden_size)
        return (torch.zeros(shape, device=reference.device, dtype=reference.dtype),
                torch.zeros(shape, device=reference.device, dtype=reference.dtype))


    def generate(self, initial_token=' ', num_tokens=100, temperature=1):
        """Sample ``num_tokens`` characters autoregressively.

        Args:
            initial_token: character used to prime the LSTM; it is not included
                in the returned text. Must be a key of char2idx.
            num_tokens: number of characters to sample and return.
            temperature: logits are divided by this before the softmax; values
                below 1 sharpen the distribution, values above 1 flatten it.

        Returns:
            The sampled text, exactly ``num_tokens`` characters long.

        Raises:
            ValueError: if ``temperature`` is not positive.
        """
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")

        device = self.linear.weight.device

        with torch.no_grad():

            # Shape (seq_len=1, batch=1), matching forward's input layout.
            token = torch.tensor([[char2idx[initial_token]]], dtype=torch.long, device=device)
            h, c = self.get_initial_hc(1)
            chars = []

            for _ in range(num_tokens):

                embedded = self.embedding(token)                 # (1, 1, embedding_size)
                out, (h, c) = self.lstm(embedded, (h, c))        # (1, 1, hidden_size)
                logits = self.linear(out.reshape(-1))            # (NUM_CHARS,)
                # softmax subtracts the max internally, so low temperatures do
                # not overflow the way a raw exp(logits / temperature) does.
                probs = torch.softmax(logits / temperature, dim=-1)
                next_idx = torch.multinomial(probs, 1)           # (1,)

                chars.append(idx2char[next_idx.item()])
                token = next_idx.view(1, 1)

            return ''.join(chars)


In [None]:
EPOCHS = 500
LR = 0.1
BETA = 0.8          # SGD momentum
EMBEDDING_SIZE = 100
HIDDEN_SIZE = 64
LOG_EVERY = 50      # print the training loss every LOG_EVERY epochs

USE_PRETRAINED = True

net = ShakespeareGenerator(EMBEDDING_SIZE, HIDDEN_SIZE).float()

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=BETA)

if USE_PRETRAINED:
    # NOTE(review): torch.load unpickles the checkpoint, which can run arbitrary
    # code — only load trusted files; consider weights_only=True where the
    # installed torch supports it (>= 1.13).
    net.load_state_dict(torch.load('shakespeare.pt', map_location='cpu'))

else:
    for epoch in range(EPOCHS):

        # X and Y are (num_snippets, seq_len); the model expects (seq_len, batch),
        # so transpose in, then transpose the logits back to (batch, seq, NUM_CHARS).
        output, _ = net(X.transpose(0, 1))
        output = output.transpose(0, 1)

        loss = criterion(output.reshape(-1, NUM_CHARS), Y.reshape(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Log periodically instead of every epoch to keep the cell output readable.
        if epoch % LOG_EVERY == 0 or epoch == EPOCHS - 1:
            print(f"epoch {epoch:4d}  loss {loss.item():.4f}")
    

In [None]:
# Sample 1000 characters with unscaled logits (temperature 1).
sample_default = net.generate(num_tokens=1000, temperature=1)
print(sample_default)

In [None]:
# Temperature 1.5 divides the logits down, flattening the sampling distribution.
sample_hot = net.generate(num_tokens=1000, temperature=1.5)
print(sample_hot)

In [None]:
# Temperature 0.25 scales the logits up, concentrating samples on likely characters.
sample_cold = net.generate(num_tokens=1000, temperature=0.25)
print(sample_cold)