In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from torchinfo import summary
from tqdm import tqdm

device = torch.device('mps')

# ================== Load and Preprocess Text ==================
print("📥 Loading dataset...")
with open("shakespeare_full_cleaned.txt", "r", encoding="utf-8") as f:
    text_data = f.read().lower()

# ================== Tokenization (Character-Level) ==================
chars = list(text_data)
vocab = sorted(set(chars))
char2idx = {ch: idx for idx, ch in enumerate(vocab)}
idx2char = {idx: ch for ch, idx in char2idx.items()}
vocab_size = len(vocab)
print(f"🔡 Vocab size (characters): {vocab_size}")

# ================== Sequence Preparation ==================
sequence_len = 128
inputs, targets = [], []
for i in range(len(chars) - sequence_len):
    inputs.append([char2idx[c] for c in chars[i:i+sequence_len]])
    targets.append([char2idx[c] for c in chars[i+1:i+sequence_len+1]])

input_tensor = torch.tensor(inputs, dtype=torch.long)
target_tensor = torch.tensor(targets, dtype=torch.long)

# ================== DataLoader ==================
batch_size = 64
dataset = TensorDataset(input_tensor, target_tensor)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# ================== Transformer Model ==================
from Transformer import Transformer

embedding_dim = 128

model = Transformer(
    vocabulary_size=vocab_size,
    number_of_embeddings=embedding_dim,
    sequence_len=sequence_len,
    input_dimensions=embedding_dim,
).to(device)

# summary(model, input_size=(1, sequence_len), dtypes=[torch.long])

📥 Loading dataset...
🔡 Vocab size (characters): 56


In [18]:
# ================== Training Setup ==================
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# ================== Training Loop ==================
print("Training started...")
for epoch in range(10):
    model.train()
    total_loss = 0
    for x, y in tqdm(dataloader, desc=f"Epoch {epoch+1}"):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        output = model(x)
        loss = criterion(output.view(-1, vocab_size), y.view(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1}, Loss: {total_loss:.4f}")

# ================== Save Model ==================
torch.save(model.state_dict(), "word_level_transformer.pth")
print("✅ Model saved to 'word_level_transformer.pth'")

Training started...


Epoch 1: 100%|██████████| 78450/78450 [31:30<00:00, 41.51it/s]  


Epoch 1, Loss: 3066.7395


Epoch 2: 100%|██████████| 78450/78450 [25:09<00:00, 51.98it/s]


Epoch 2, Loss: 1651.1622


Epoch 3: 100%|██████████| 78450/78450 [23:16<00:00, 56.17it/s]


Epoch 3, Loss: 1574.0356


Epoch 4: 100%|██████████| 78450/78450 [26:23<00:00, 49.55it/s] 


Epoch 4, Loss: 1516.0804


Epoch 5: 100%|██████████| 78450/78450 [25:58<00:00, 50.33it/s]


Epoch 5, Loss: 1474.0040


Epoch 6: 100%|██████████| 78450/78450 [27:41<00:00, 47.22it/s]


Epoch 6, Loss: 1460.5759


Epoch 7: 100%|██████████| 78450/78450 [27:48<00:00, 47.02it/s]


Epoch 7, Loss: 1447.3545


Epoch 8: 100%|██████████| 78450/78450 [28:25<00:00, 45.99it/s]


Epoch 8, Loss: 1426.9997


Epoch 9: 100%|██████████| 78450/78450 [28:11<00:00, 46.38it/s]


Epoch 9, Loss: 1377.8276


Epoch 10: 100%|██████████| 78450/78450 [31:44<00:00, 41.19it/s]

Epoch 10, Loss: 1342.0834
✅ Model saved to 'word_level_transformer.pth'





In [24]:
import torch
import torch.nn.functional as F
import re
from Transformer import Transformer  # Your trained model

# ====== Load and Prepare Vocab ======
with open("shakespeare_full_cleaned.txt", "r", encoding="utf-8") as f:
    text_data = f.read().lower()

chars = list(text_data)
vocab = sorted(set(chars))
char2idx = {ch: idx for idx, ch in enumerate(vocab)}
idx2char = {idx: ch for ch, idx in char2idx.items()}
vocab_size = len(vocab)

# ====== Model Setup ======
embedding_dim = 128
sequence_len = 128
device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")

model = Transformer(
    vocabulary_size=vocab_size,
    number_of_embeddings=embedding_dim,
    sequence_len=sequence_len,
    input_dimensions=embedding_dim,
).to(device)

model.load_state_dict(torch.load("word_level_transformer.pth", map_location=device))
model.eval()
print("✅ Model loaded and ready")

# ====== Seed Input ======
seed_text = input("Enter seed text (default: 'King:\\n'): ").strip()
if not seed_text:
    seed_text = "King:\n"

print("🧠 Seed:", repr(seed_text))

seed_chars = list(seed_text.lower())
seed_ids = [char2idx.get(c, char2idx[' ']) for c in seed_chars]

# Pad/trim seed to sequence length
if len(seed_ids) < sequence_len:
    seed_ids = [char2idx[' ']] * (sequence_len - len(seed_ids)) + seed_ids
else:
    seed_ids = seed_ids[-sequence_len:]

generated = seed_ids.copy()

# ====== Generation Settings ======
num_generate = 500  # total characters to generate
temperature = 0.7
top_p = 0.85

# ====== Generation Loop ======
with torch.no_grad():
    for _ in range(num_generate):
        x = torch.tensor([generated[-sequence_len:]], dtype=torch.long).to(device)
        logits = model(x)[0, -1] / temperature

        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        probs = F.softmax(sorted_logits, dim=-1)
        cumulative_probs = torch.cumsum(probs, dim=-1)

        # Top-p filtering
        cutoff = cumulative_probs > top_p
        cutoff[0] = False  # keep top token
        sorted_logits[cutoff] = float('-inf')
        filtered_probs = F.softmax(sorted_logits, dim=-1)

        next_token = sorted_indices[torch.multinomial(filtered_probs, 1).item()].item()
        generated.append(next_token)

# ====== Decode Output ======
result = ''.join([idx2char[i] for i in generated])
print("\n📝 Generated Text:\n")
print(result)

✅ Model loaded and ready
🧠 Seed: 'King:\n'

📝 Generated Text:

                                                                                                                          king:
this,
win hou she men.
hee.
hin hount.
the hon ho sto shens whon thon whin here whou ther'stron ther.
[houstron rone the son heard sanger wan do the mane the sore on what be what.

he maren anter has le the the in i she lere the whath les mangin the sen man.

prestres the deas hee mereting and and the houstre.

and fortherer warle the has she there the be hale with the the that what of me isherated.

his he sheresed i wentere her for the he mee the the the she hons exit the lethe heer and sene m
