In [1]:
import torch.nn as nn
import torch
import numpy as np

In [2]:
# Location of the raw training corpus.
DATA_PATH = "/kaggle/input/shakespeare-txt/shakespeare.txt"

# Read the whole corpus into memory. The encoding is pinned so that the
# character vocabulary built below does not depend on the platform's
# locale-specific default encoding.
with open(DATA_PATH, 'r', encoding='utf-8') as f:
    content = f.read()

# Character-level vocabulary: every distinct character in the corpus, sorted
# so that the index assigned to each character is the same on every run.
chars = sorted(set(content))

In [3]:
# Two-way lookup between vocabulary positions and characters: the forward
# table is built from the sorted vocabulary, the reverse table by inverting it.
idx_to_char = dict(enumerate(chars))
char_to_idx = {symbol: position for position, symbol in idx_to_char.items()}

In [4]:
def create_dataset(seq_length=100, data="", vocab=None):
    """Build next-character training pairs from ``data`` with a sliding window.

    Every run of ``seq_length`` consecutive characters becomes one input
    sequence, and the character that immediately follows it is its target.

    Args:
        seq_length: Number of characters in each input window.
        data: Text to slice into windows.
        vocab: Optional character -> index mapping. When omitted, the
            module-level ``char_to_idx`` table is used.

    Returns:
        A tuple ``(inputs, targets)``. ``inputs[k]`` is a list of
        ``seq_length`` character indices and ``targets[k]`` is a one-element
        list holding the index of the character that follows that window.
        Both lists are empty when ``data`` has at most ``seq_length``
        characters.

    Raises:
        KeyError: If ``data`` contains a character missing from the mapping.
    """
    mapping = char_to_idx if vocab is None else vocab

    # Encode the text once; each window is then a plain slice instead of
    # re-encoding seq_length characters at every position.
    encoded = [mapping[char] for char in data]

    inputs = []
    targets = []
    # The number of windows is bounded by the text actually passed in, so the
    # function works for any ``data`` and not only for the global corpus.
    for i in range(len(encoded) - seq_length):
        inputs.append(encoded[i:i + seq_length])
        # Targets stay wrapped in a one-element list, as existing callers expect.
        targets.append([encoded[i + seq_length]])
    return inputs, targets

# Slice the full corpus into (100-character window, next character) pairs.
inputs, targets = create_dataset(seq_length=100, data=content)

In [5]:
class TextGenNetwork(nn.Module):
    """Character-level language model: embedding -> stacked LSTM -> linear head.

    Given a batch of character-index sequences, the network returns one vector
    of unnormalised scores over the vocabulary per sequence, computed from the
    LSTM output at the last time step.

    Args:
        vocab_size: Number of distinct characters. Defaults to ``len(chars)``,
            the size of the module-level vocabulary.
        embedding_dim: Size of each character embedding.
        hidden_size: Number of features in the LSTM hidden state.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, vocab_size=None, embedding_dim=64, hidden_size=512, num_layers=2):
        super().__init__()

        # Fall back to the notebook's global vocabulary only when the caller
        # does not say how many symbols the model should handle.
        if vocab_size is None:
            vocab_size = len(chars)

        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)
        self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_size,
                            num_layers=num_layers, batch_first=True)
        self.fc = nn.Linear(in_features=hidden_size, out_features=vocab_size)

    def forward(self, x, hidden=None, return_hidden=False):
        """Score the next character for each sequence in the batch.

        Args:
            x: Integer tensor of character indices, batch dimension first
                (the LSTM is built with ``batch_first=True``).
            hidden: Optional LSTM state to start from; ``None`` passes no
                initial state to the LSTM.
            return_hidden: When true, also return the LSTM state after the
                last time step.

        Returns:
            The scores produced by the linear head from the last time step,
            or ``(scores, hidden)`` when ``return_hidden`` is true.
        """
        x = self.embedding(x)
        out, hidden = self.lstm(x, hidden)
        # Only the final time step is used to predict the following character.
        out = self.fc(out[:, -1, :])

        if return_hidden:
            return out, hidden
        return out

    def init_hidden(self, batch_size):
        """Return an all-zero ``(h, c)`` LSTM state for ``batch_size`` sequences.

        The shape, device and dtype are taken from the model itself, so the
        state always matches however the network was configured or moved.
        """
        reference = next(self.parameters())
        shape = (self.lstm.num_layers, batch_size, self.lstm.hidden_size)
        return (torch.zeros(shape, device=reference.device, dtype=reference.dtype),
                torch.zeros(shape, device=reference.device, dtype=reference.dtype))


# Use the GPU when one is available, otherwise fall back to the CPU, and
# build the model directly on that device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
network = TextGenNetwork().to(device)
network

TextGenNetwork(
  (embedding): Embedding(91, 64)
  (lstm): LSTM(64, 512, num_layers=2, batch_first=True)
  (fc): Linear(in_features=512, out_features=91, bias=True)
)

In [6]:
# Sanity check: there is one target per sliding window, so this is the number
# of training examples.
print(len(targets))

5458837


In [7]:
# Training hyperparameters.
batch_size = 1024
n_epochs = 10

# Adam updates every parameter of the network; cross-entropy compares the
# network's scores with the index of the true next character.
optimizer = torch.optim.Adam(network.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

In [8]:
def train(model, inputs, targets, batch_size, n_epochs, seq_length=100):
    """Train ``model`` for next-character prediction with mini-batch updates.

    Uses the module-level ``optimizer``, ``criterion`` and ``device``. Every
    batch is processed independently: no initial state is passed to the model
    and the state it returns is discarded, so nothing is carried from one
    batch to the next.

    The weights are written to ``parameters.pth`` after each completed epoch,
    and once more if the run is interrupted, so the file holds the latest
    weights without paying for a disk write on every optimisation step.

    Args:
        model: Network called as ``model(x, hidden, return_hidden=True)`` and
            returning ``(scores, hidden)``.
        inputs: Sequence of equal-length lists of character indices.
        targets: Sequence of target indices aligned with ``inputs``.
        batch_size: Number of examples per optimisation step.
        n_epochs: Number of passes over ``inputs``.
        seq_length: Unused; kept so existing calls remain valid.
    """
    checkpoint_path = 'parameters.pth'
    model.train()  # Set the model to training mode

    try:
        for epoch in range(n_epochs):
            loss = None  # stays None when there is nothing to train on

            for i in range(0, len(inputs), batch_size):
                x_batch = torch.tensor(inputs[i:i + batch_size]).to(device)
                y_batch = torch.tensor(targets[i:i + batch_size]).to(device)

                optimizer.zero_grad()

                # Each batch starts without an initial state; the returned
                # state is not reused.
                output, _ = model(x_batch, None, return_hidden=True)

                # Flatten to (examples, classes) using the model's own output
                # width rather than a global vocabulary size.
                loss = criterion(output.view(-1, output.size(-1)), y_batch.view(-1))

                loss.backward()
                optimizer.step()

            # Checkpoint once per epoch instead of once per batch.
            torch.save(model.state_dict(), checkpoint_path)

            # The reported value is the loss of the last batch of the epoch.
            epoch_loss = loss.item() if loss is not None else float('nan')
            print(f'Epoch: {epoch+1}/{n_epochs}, Loss: {epoch_loss}')
    except KeyboardInterrupt:
        # A manually interrupted run still leaves its latest weights on disk.
        torch.save(model.state_dict(), checkpoint_path)
        raise

In [9]:
# Run training with the hyperparameters defined above. NOTE(review): the
# recorded output reports five epochs and then a KeyboardInterrupt, so this
# run did not complete all n_epochs.
train(network, inputs, targets, batch_size, n_epochs)

Epoch: 1/10, Loss: 1.4202162027359009
Epoch: 2/10, Loss: 1.0304911136627197
Epoch: 3/10, Loss: 0.9422156810760498
Epoch: 4/10, Loss: 0.881528913974762
Epoch: 5/10, Loss: 0.8907428979873657


KeyboardInterrupt: 

In [10]:
def generate_text(model, start_str, n_chars, temperature=1.0):
    """Sample ``n_chars`` characters from ``model`` as a continuation of ``start_str``.

    The prompt is fed to the model in one pass; after that, each sampled
    character is fed back in together with the LSTM state, one step at a time.
    Uses the module-level ``char_to_idx``, ``idx_to_char`` and ``device``, and
    draws samples with ``np.random``.

    Args:
        model: Network called as ``model(x, hidden, return_hidden=True)`` and
            returning ``(scores, hidden)``.
        start_str: Prompt text; every character must be in ``char_to_idx``.
        n_chars: Number of characters to generate after the prompt.
        temperature: Divisor applied to the scores before the softmax. Values
            below 1 concentrate probability on the highest-scoring
            characters; values above 1 flatten the distribution.

    Returns:
        ``start_str`` followed by the ``n_chars`` sampled characters.

    Raises:
        ValueError: If ``temperature`` is not positive.
        KeyError: If ``start_str`` contains a character missing from
            ``char_to_idx``.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")

    model.eval()

    input_seq = torch.tensor([char_to_idx[char] for char in start_str], dtype=torch.long).unsqueeze(0).to(device)

    pieces = [start_str]
    hidden = None

    # Sampling needs no gradients. Without this, the hidden state carried
    # from step to step would keep the autograd history of every earlier step.
    with torch.no_grad():
        for _ in range(n_chars):
            output, hidden = model(input_seq, hidden, return_hidden=True)

            probs = torch.softmax(output / temperature, dim=1).cpu().numpy()

            # Draw the next character index from the predicted distribution.
            next_char_idx = np.random.choice(probs.shape[1], p=probs[0])
            pieces.append(idx_to_char[next_char_idx])

            # Only the new character is fed next; the context lives in `hidden`.
            input_seq = torch.tensor([[next_char_idx]], dtype=torch.long).to(device)

    return "".join(pieces)


In [18]:
# Prompt the network and sample a 400-character continuation. The scores are
# divided by the temperature before the softmax, so 0.8 favours the
# higher-scoring characters.
start_str = "Once upon a "
generated_text = generate_text(model=network, start_str=start_str, n_chars=400, temperature=0.8)
print(generated_text)

Once upon a moves.
    How to my lady the buttasholus, to he mas And laugh for it.
  ARMADOAN enderth will smilen so gright
            I we that the fill finds,
    This is the face swith with the posset you'd with
    The better tweirs, who do west whose to brong-
    And to us, or knowledge with transt the company the such me,
    And sake and look's should the say tombland,
    The low begging, as not lik
