In [1]:
import os
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import kagglehub

In [2]:
# Run on the GPU when CUDA is usable, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


In [3]:
# Download the "guslovesmath/shakespeare-plays-dataset" Kaggle dataset via kagglehub.
# `path` is the location dataset_download returns for the downloaded files.
path = kagglehub.dataset_download("guslovesmath/shakespeare-plays-dataset")
print("Path to dataset files:", path)

Downloading from https://www.kaggle.com/api/v1/datasets/download/guslovesmath/shakespeare-plays-dataset?dataset_version_number=1...


100%|██████████| 2.62M/2.62M [00:00<00:00, 132MB/s]

Extracting files...
Path to dataset files: /root/.cache/kagglehub/datasets/guslovesmath/shakespeare-plays-dataset/versions/1





In [5]:
# Location of the plays CSV inside the downloaded dataset directory.
file_path = os.path.join(path, "shakespeare_plays.csv")

In [6]:
# Load the plays table and print its first rows as a quick check of the columns.
data = pd.read_csv(file_path)
print("Sample data:\n", data.head())

Sample data:
    Unnamed: 0                  play_name   genre character  act  scene  \
0           0  All's Well That Ends Well  Comedy  Countess    1      1   
1           1  All's Well That Ends Well  Comedy   Bertram    1      1   
2           2  All's Well That Ends Well  Comedy   Bertram    1      1   
3           3  All's Well That Ends Well  Comedy   Bertram    1      1   
4           4  All's Well That Ends Well  Comedy     Lafeu    1      1   

   sentence                                               text     sex  
0         1  In delivering my son from me, I bury a second ...  female  
1         2  And I in going, madam, weep o'er my father's d...    male  
2         3  anew: but I must attend his majesty's command, to    male  
3         4     whom I am now in ward, evermore in subjection.    male  
4         5  You shall find of the king a husband, madam; you,    male  


In [7]:
# Keep every non-missing entry of the `text` column as a plain Python list.
texts = list(data['text'].dropna())
print(f"Total lines of text: {len(texts)}")

Total lines of text: 108093


In [8]:
# Vocabulary: every distinct character that occurs in the corpus, in sorted order
# so the id assignment is deterministic for a given corpus.
chars = sorted({symbol for line in texts for symbol in line})
# Lookup tables in both directions between characters and their integer ids.
idx_to_char = dict(enumerate(chars))
char_to_idx = {symbol: position for position, symbol in idx_to_char.items()}
vocab_size = len(chars)
print(f"Vocabulary size: {vocab_size}")


Vocabulary size: 76


In [9]:
# Translate each line into the list of integer ids of its characters.
encoded_texts = [list(map(char_to_idx.__getitem__, line)) for line in texts]

In [10]:
class ShakespeareDataset(Dataset):
    """Sliding-window next-character dataset over encoded lines.

    For every line longer than ``seq_length`` the dataset exposes one sample per
    start offset ``i`` in ``range(len(line) - seq_length)``:

    * input  = ``line[i : i + seq_length]``
    * target = ``line[i + 1 : i + seq_length + 1]`` (the input shifted by one)

    Lines with ``len(line) <= seq_length`` contribute no samples.

    Each qualifying line is stored once as a ``torch.long`` tensor together with
    a lightweight ``(line_id, start)`` index; windows are sliced on demand
    instead of materialising two Python list copies per window.

    Args:
        encoded_texts: iterable of sequences of integer token ids.
        seq_length: window length; must be at least 1.

    Raises:
        ValueError: if ``seq_length`` is smaller than 1.
    """

    def __init__(self, encoded_texts, seq_length):
        if seq_length < 1:
            raise ValueError(f"seq_length must be a positive integer, got {seq_length!r}")
        self.seq_length = seq_length
        self._lines = []   # one long tensor per line that yields at least one window
        self._index = []   # (position in self._lines, window start offset)
        for line in encoded_texts:
            n_windows = len(line) - seq_length
            if n_windows <= 0:
                continue
            line_id = len(self._lines)
            self._lines.append(torch.tensor(line, dtype=torch.long))
            self._index.extend((line_id, start) for start in range(n_windows))

    @property
    def data(self):
        """All ``(input, target)`` windows as pairs of Python lists.

        Kept for code that read the former eagerly-built ``data`` list; it is
        rebuilt on every access, so prefer indexing the dataset directly.
        """
        return [(x.tolist(), y.tolist()) for x, y in (self[i] for i in range(len(self)))]

    def __len__(self):
        """Number of available windows."""
        return len(self._index)

    def __getitem__(self, idx):
        """Return the ``idx``-th ``(input, target)`` pair as two ``torch.long`` tensors."""
        line_id, start = self._index[idx]
        line = self._lines[line_id]
        stop = start + self.seq_length
        # clone() hands back independent tensors, so callers cannot mutate the stored line.
        return line[start:stop].clone(), line[start + 1:stop + 1].clone()

# Windowing / batching hyperparameters.
SEQ_LENGTH = 50  # characters per training window
BATCH_SIZE = 64  # windows per optimisation step

dataset = ShakespeareDataset(encoded_texts, SEQ_LENGTH)
# shuffle=True makes the loader draw the windows in a new random order each epoch.
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

In [38]:
class TextGenerator(nn.Module):
    """Character-level language model: embedding -> stacked LSTM -> linear projection.

    Args:
        vocab_size: number of distinct token ids; size of the embedding table
            and of the output logits.
        embedding_dim: width of each token embedding.
        hidden_dim: width of the LSTM hidden state.
        num_layers: number of stacked LSTM layers.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers):
        super().__init__()
        # Remembered so init_hidden() can size the state without module-level globals.
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, hidden=None):
        """Compute per-position logits.

        Args:
            x: long tensor of token ids, shape ``(batch, seq)``.
            hidden: ``(h, c)`` LSTM state as produced by :meth:`init_hidden` or a
                previous call; ``None`` lets the LSTM start from a zero state.

        Returns:
            ``(logits, hidden)`` where ``logits`` has shape
            ``(batch, seq, vocab_size)`` and ``hidden`` is the updated state.
        """
        embedded = self.embedding(x)
        out, hidden = self.lstm(embedded, hidden)
        return self.fc(out), hidden

    def init_hidden(self, batch_size):
        """Return a zero ``(h, c)`` state for ``batch_size`` sequences.

        Each tensor has shape ``(num_layers, batch_size, hidden_dim)`` and is
        created with the dtype and device of the model's own weights, so it
        always matches this instance.
        """
        reference = self.fc.weight
        shape = (self.num_layers, batch_size, self.hidden_dim)
        return (torch.zeros(shape, dtype=reference.dtype, device=reference.device),
                torch.zeros(shape, dtype=reference.dtype, device=reference.device))


In [41]:
# Model and optimiser hyperparameters.
EMBEDDING_DIM, HIDDEN_DIM, NUM_LAYERS = 128, 256, 4
LEARNING_RATE = 1e-3

# Build the network and move its weights to the selected device before the
# optimiser captures the parameters.
model = TextGenerator(
    vocab_size=vocab_size,
    embedding_dim=EMBEDDING_DIM,
    hidden_dim=HIDDEN_DIM,
    num_layers=NUM_LAYERS,
)
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss()

In [50]:
EPOCHS = 50
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0
    for x, y in dataloader:
        batch_size = x.size(0)
        x = x.to(device)
        y = y.to(device)

        # Every batch starts from a fresh zero state that carries no graph history.
        hidden = tuple(state.detach() for state in model.init_hidden(batch_size))

        optimizer.zero_grad()
        output, hidden = model(x, hidden)
        # Merge the batch and sequence axes so each position is scored as one
        # classification example against its next-character target.
        loss = criterion(output.flatten(0, 1), y.flatten())
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # Report the mean per-batch loss for the epoch.
    print(f"Epoch {epoch + 1}/{EPOCHS}, Loss: {total_loss / len(dataloader)}")

Epoch 1/50, Loss: 0.24930177411492566
Epoch 2/50, Loss: 0.2488228706469447
Epoch 3/50, Loss: 0.2521035800809446
Epoch 4/50, Loss: 0.24301806369923656
Epoch 5/50, Loss: 0.242953126613768
Epoch 6/50, Loss: 0.24384709817837485
Epoch 7/50, Loss: 0.2429505254356017
Epoch 8/50, Loss: 0.2364598741720182
Epoch 9/50, Loss: 0.23523105704080985
Epoch 10/50, Loss: 0.2349343531712982
Epoch 11/50, Loss: 0.23977079163797155
Epoch 12/50, Loss: 0.2368808731159068
Epoch 13/50, Loss: 0.23119572735165958
Epoch 14/50, Loss: 0.2297775713923555
Epoch 15/50, Loss: 0.23107258281352358
Epoch 16/50, Loss: 0.22803391719827001
Epoch 17/50, Loss: 0.230105095398352
Epoch 18/50, Loss: 0.2270642726317696
Epoch 19/50, Loss: 0.22856879396283108
Epoch 20/50, Loss: 0.22847712345375037
Epoch 21/50, Loss: 0.22462937295992183
Epoch 22/50, Loss: 0.21635013425387212
Epoch 23/50, Loss: 0.21977813750135233
Epoch 24/50, Loss: 0.2269503099855429
Epoch 25/50, Loss: 0.23173655245615088
Epoch 26/50, Loss: 0.21659642337642102
Epoch 27

In [51]:
def generate_text(model, start_text, length):
    """Greedily extend ``start_text`` by ``length`` characters.

    The prompt is fed through the model once to warm up the recurrent state;
    after that the most likely next character (argmax of the last position's
    logits) is appended and fed back in, one character at a time.

    Relies on the module-level ``char_to_idx`` / ``idx_to_char`` vocabulary
    tables. Switches ``model`` to eval mode and leaves it there.

    Args:
        model: network exposing ``init_hidden(batch_size)`` and returning
            ``(logits, hidden)`` when called with ``(input_ids, hidden)``.
        start_text: non-empty prompt made only of characters in the vocabulary.
        length: number of characters to generate; values <= 0 return the
            prompt unchanged.

    Returns:
        ``start_text`` followed by the generated characters.

    Raises:
        ValueError: if ``start_text`` is empty or contains characters that are
            not in ``char_to_idx``.
    """
    if not start_text:
        raise ValueError("start_text must contain at least one character")
    unknown = sorted({ch for ch in start_text if ch not in char_to_idx})
    if unknown:
        raise ValueError(f"start_text contains characters outside the vocabulary: {unknown}")

    # Build inputs on the device the model's weights live on; fall back to the
    # notebook-level `device` only for a model without parameters.
    try:
        model_device = next(model.parameters()).device
    except StopIteration:
        model_device = device

    model.eval()
    hidden = model.init_hidden(1)
    input_seq = torch.tensor([[char_to_idx[ch] for ch in start_text]],
                             dtype=torch.long, device=model_device)
    pieces = [start_text]  # joined once at the end instead of repeated str +=

    with torch.no_grad():
        for _ in range(length):
            output, hidden = model(input_seq, hidden)
            # Only the last position predicts the next, not-yet-seen character.
            char_idx = torch.argmax(output[:, -1, :]).item()
            pieces.append(idx_to_char[char_idx])
            input_seq = torch.tensor([[char_idx]], dtype=torch.long, device=model_device)

    return "".join(pieces)


In [57]:
# Read the prompt from the user and extend it by 100 generated characters.
start_text = input()
generated = generate_text(model, start_text, length=100)
print("Generated Text:", generated, sep="\n")

ap
Generated Text:
appear you with this ridiculous boldness before my lady? and in these leasure and shipping, and salt I
