In [None]:
#%pip install -r requirements.txt

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import tiktoken
import kagglehub
from dataset import ShakespeareDataset
from model import LSTMModel
import config

In [None]:
# kagglehub.dataset_download("nenadblagovcanin/shakespeare")


In [None]:
# Read the corpus with an explicit encoding so decoding does not depend on
# the platform's locale default (which is not UTF-8 on every system).
with open('shakespeare.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# Tokenize the whole corpus once with tiktoken's "gpt2" encoding; `encoder`
# is reused later for vocabulary size and for decoding generated tokens.
encoder = tiktoken.get_encoding("gpt2")
tokens = encoder.encode(text)

In [None]:
# Build the dataset from the token list using the configured sequence length,
# then serve it in shuffled mini-batches.
dataset = ShakespeareDataset(tokens, config.sequence_length)
dataloader = DataLoader(
    dataset,
    batch_size=config.batch_size,
    shuffle=True,
)

In [None]:
# The vocabulary size is taken from the tokenizer so the model and the loss
# reshape in the training loop agree on the number of classes.
vocab_size = encoder.n_vocab

model = LSTMModel(
    vocab_size,
    config.embedding_size,
    config.hidden_size,
    config.num_layers,
)
model = model.to(config.device)

# Cross-entropy loss over token ids, optimized with Adam.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=config.learning_rate)

In [None]:
# Put the model in training mode before iterating over the data.
model.train()

for epoch in range(config.num_epochs):
    for batch_idx, (inputs, targets) in enumerate(dataloader):
        inputs = inputs.to(config.device)
        targets = targets.to(config.device)

        # Use the actual size of this batch when requesting a hidden state;
        # a new hidden state is requested for every batch.
        batch_size = inputs.size(0)
        hidden = model.init_hidden(batch_size, config.hidden_size, config.num_layers, config.device)

        # Forward pass.
        outputs, hidden = model(inputs, hidden)

        # Flatten predictions to (N, vocab_size) and targets to (N,) for the loss.
        flat_logits = outputs.view(-1, vocab_size)
        flat_targets = targets.view(-1)
        loss = criterion(flat_logits, flat_targets)

        # Clear old gradients, backpropagate, and update the parameters.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Report progress only on every 100th batch.
        if batch_idx % 100 != 0:
            continue
        print(f"Epoch [{epoch+1}/{config.num_epochs}], Step [{batch_idx}/{len(dataloader)}], Loss: {loss.item():.4f}")

# Persist the trained weights (state dict only) once all epochs have finished.
torch.save(model.state_dict(), "shakespeare_lstm.pth")

In [None]:
def generate_text(model, start_text, max_length=100):
    """Greedily extend ``start_text`` by ``max_length`` tokens.

    The prompt is encoded with the notebook-level ``encoder``, at most the
    last ``config.sequence_length`` prompt tokens are fed to the model, and
    then the highest-scoring next token (argmax) is appended one step at a
    time.

    Args:
        model: The language model; called as ``model(input_ids, hidden)`` and
            expected to return ``(outputs, hidden)`` where ``outputs`` is
            indexed as ``outputs[:, -1, :]``.
        start_text: Prompt string; must encode to at least one token.
        max_length: Number of new tokens to generate.

    Returns:
        The decoded string of the prompt tokens followed by the generated
        tokens.

    Raises:
        ValueError: If ``start_text`` encodes to no tokens.

    Note:
        The model is switched to eval mode and left in that mode.
    """
    generated_tokens = list(encoder.encode(start_text))
    if not generated_tokens:
        raise ValueError("start_text must encode to at least one token")

    model.eval()

    prompt_window = generated_tokens[-config.sequence_length:]
    input_ids = torch.tensor([prompt_window], dtype=torch.long, device=config.device)

    # Same init_hidden call signature as the training loop, with a batch of 1.
    hidden = model.init_hidden(1, config.hidden_size, config.num_layers, config.device)

    # Inference only: skip autograd bookkeeping so no graph accumulates
    # through the hidden state that is carried from step to step.
    with torch.no_grad():
        for _ in range(max_length):
            outputs, hidden = model(input_ids, hidden)
            next_token = torch.argmax(outputs[:, -1, :], dim=-1).item()
            generated_tokens.append(next_token)

            # The carried hidden state already reflects every token fed so
            # far, so only the newly generated token is fed next; re-feeding
            # the whole window would process the same context again.
            input_ids = torch.tensor([[next_token]], dtype=torch.long, device=config.device)

    return encoder.decode(generated_tokens)

# Demo: continue a well-known prompt by 100 tokens and show the result.
start_text = "To be, or not to be, that is the question:"
generated_text = generate_text(
    model=model,
    start_text=start_text,
    max_length=100,
)
print(generated_text)