In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import requests


KeyboardInterrupt



In [None]:
# Use the GPU when PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

In [None]:
url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
# A timeout keeps the cell from hanging forever on a stalled connection, and
# raise_for_status() fails loudly on an HTTP error instead of silently treating
# an error page as the training corpus.
response = requests.get(url, timeout=30)
response.raise_for_status()
text = response.text

print(text[:500])


In [None]:
# Word-level tokenizer: a token is any maximal run of non-whitespace characters,
# so punctuation stays attached to words (e.g. "ROMEO:" and "ROMEO" are distinct).
words = text.split()
vocab = sorted(set(words))
vocab_size = len(vocab)

# Lookup tables in both directions between tokens and integer ids.
stoi = {word: idx for idx, word in enumerate(vocab)}
itos = dict(enumerate(vocab))


def encode(seq):
    """Split `seq` on whitespace and map each known word to its id; unknown words are dropped."""
    return [stoi[w] for w in seq.split() if w in stoi]


def decode(idxs):
    """Map token ids back to words and join them with single spaces."""
    return " ".join(itos[i] for i in idxs)


# Every corpus word is in the vocabulary by construction, so a direct lookup is safe here.
data = torch.tensor([stoi[w] for w in words], dtype=torch.long)
print("Vocab size:", vocab_size)
print("Encoded sample:", data[:20])
print("Decoded back:", decode(data[:20].tolist()))


In [None]:
# Contiguous, unshuffled split of the token stream:
#   [0%, 10%)   -> test
#   [10%, 90%)  -> train
#   [90%, 100%] -> validation
n_tokens = len(data)
split1, split2 = int(0.1 * n_tokens), int(0.9 * n_tokens)
test_data, train_data, val_data = data[:split1], data[split1:split2], data[split2:]

print(f"Train size: {len(train_data)}, Val size: {len(val_data)}, Test size: {len(test_data)}")


In [None]:
class WordRNN(nn.Module):
    """Word-level language model: embedding -> LSTM -> linear projection to vocabulary logits.

    The embedding width equals the LSTM hidden size, and inputs are batch-first.

    Args:
        vocab_size: number of distinct token ids.
        hidden_size: embedding dimension and LSTM hidden/cell size.
        n_layers: number of stacked LSTM layers.
    """

    def __init__(self, vocab_size, hidden_size, n_layers=1):
        super().__init__()
        # These attribute names become the state_dict keys ("embed.*", "rnn.*", "fc.*"),
        # so they must stay stable for saved checkpoints to load.
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.rnn = nn.LSTM(
            input_size=hidden_size,
            hidden_size=hidden_size,
            num_layers=n_layers,
            batch_first=True,
        )
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, h=None):
        """Return `(logits, hidden)`.

        x is a (batch, seq) tensor of token ids. logits has shape (batch, seq, vocab_size).
        h is an optional LSTM state `(h_0, c_0)` carried over from a previous call.
        """
        embedded = self.embed(x)
        rnn_out, hidden = self.rnn(embedded, h)
        return self.fc(rnn_out), hidden


In [None]:
def get_batch(data, seq_len, batch_size, device):
    """Sample a random batch of (input, next-token target) windows from a 1-D token tensor.

    Args:
        data: 1-D tensor of token ids.
        seq_len: length of each window.
        batch_size: number of windows, sampled independently (with replacement).
        device: device the batch is moved to.

    Returns:
        `(x, y)`, each of shape (batch_size, seq_len), where y is x shifted one token ahead.

    Raises:
        ValueError: if `data` has fewer than `seq_len + 1` tokens.
    """
    # A window starting at i needs tokens i .. i+seq_len (inclusive) for its target,
    # so valid starts are 0 .. len(data) - seq_len - 1. randint's upper bound is
    # exclusive, hence `len(data) - seq_len` (an extra `- 1` would skip the last window).
    n_starts = len(data) - seq_len
    if n_starts < 1:
        raise ValueError(
            f"need at least seq_len + 1 = {seq_len + 1} tokens, got {len(data)}"
        )
    starts = torch.randint(0, n_starts, (batch_size,))
    # (batch_size, seq_len) matrix of absolute positions, built by broadcasting.
    positions = starts.unsqueeze(1) + torch.arange(seq_len)
    x = data[positions]
    y = data[positions + 1]
    return x.to(device), y.to(device)


In [None]:
def evaluate(model, data, seq_len, batch_size, device):
    """Compute the mean per-token cross-entropy and perplexity of `model` on `data`.

    `data` is cut into consecutive, non-overlapping chunks of `seq_len * batch_size`
    tokens. Each chunk is reshaped to (batch_size, seq_len), with targets shifted one
    token ahead. A trailing partial chunk is skipped. The model's train/eval mode is
    restored on exit.

    Args:
        model: module whose `forward(x)` returns `(logits, hidden)`, with logits of
            shape (batch, seq, vocab).
        data: 1-D tensor of token ids.
        seq_len: number of time steps per row.
        batch_size: number of rows per chunk.
        device: device the inputs are moved to.

    Returns:
        `(avg_loss, perplexity)` as Python floats.

    Raises:
        ValueError: if `seq_len`/`batch_size` are not positive, or if `data` is too
            short to form a single chunk.
    """
    if seq_len < 1 or batch_size < 1:
        raise ValueError(
            f"seq_len and batch_size must be positive, got {seq_len} and {batch_size}"
        )
    chunk = seq_len * batch_size
    # Each chunk needs one extra token for its shifted target, hence len(data) - 1.
    steps = (len(data) - 1) // chunk
    if steps < 1:
        raise ValueError(
            f"need at least seq_len * batch_size + 1 = {chunk + 1} tokens, got {len(data)}"
        )

    was_training = model.training
    model.eval()
    # Summing per-token losses avoids the mean-then-rescale round trip.
    criterion = nn.CrossEntropyLoss(reduction="sum")
    total_loss, total_tokens = 0.0, 0

    try:
        with torch.no_grad():
            for step in range(steps):
                # Advance by a whole chunk so consecutive steps never overlap.
                start = step * chunk
                x = data[start:start + chunk].view(batch_size, seq_len).to(device)
                y = data[start + 1:start + 1 + chunk].view(batch_size, seq_len).to(device)

                logits, _ = model(x)
                # Take the vocabulary size from the logits instead of a global.
                loss = criterion(logits.reshape(-1, logits.size(-1)), y.reshape(-1))
                total_loss += loss.item()
                total_tokens += y.numel()
    finally:
        model.train(was_training)

    avg_loss = total_loss / total_tokens
    perplexity = torch.exp(torch.tensor(avg_loss))
    return avg_loss, perplexity.item()


In [None]:
# Hyperparameters
seed = 0
hidden_size = 256
seq_len = 10
num_epochs = 5
lr = 0.0005
batch_size = 64
steps_per_epoch = 1000
checkpoint_path = "best_word_rnn.pt"

# Early-stopping state
best_val_loss = float("inf")
patience = 3
patience_counter = 0
saved_checkpoint = False  # True once THIS run has written checkpoint_path

# Seed before building the model so weight init and batch sampling are reproducible.
torch.manual_seed(seed)

model = WordRNN(vocab_size, hidden_size).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)

for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0

    # Training: steps_per_epoch randomly sampled batches (not a full pass over train_data).
    for _ in range(steps_per_epoch):
        x, y = get_batch(train_data, seq_len, batch_size, device)

        optimizer.zero_grad()
        logits, _ = model(x)
        loss = criterion(logits.reshape(-1, logits.size(-1)), y.reshape(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    avg_train_loss = total_loss / steps_per_epoch
    train_ppl = torch.exp(torch.tensor(avg_train_loss)).item()

    # Validation
    val_loss, val_ppl = evaluate(model, val_data, seq_len, batch_size, device)

    print(f"Epoch {epoch+1} | "
          f"Train Loss {avg_train_loss:.4f} | Train PPL {train_ppl:.2f} | "
          f"Val Loss {val_loss:.4f} | Val PPL {val_ppl:.2f}")

    # Early stopping: checkpoint on improvement; stop after `patience` epochs without one.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0
        torch.save(model.state_dict(), checkpoint_path)
        saved_checkpoint = True
    else:
        patience_counter += 1
        print(f"No improvement. Patience counter = {patience_counter}/{patience}")
        if patience_counter >= patience:
            print("Early stopping triggered.")
            break

# Load the best model before the final test. If validation never improved (e.g. the
# loss was NaN every epoch), this run wrote no checkpoint. In that case keep the
# current weights rather than failing on a missing file or loading a stale one.
if saved_checkpoint:
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
else:
    print("No checkpoint was saved in this run; keeping the final weights.")


In [None]:
# Held-out test evaluation, run once after model selection on the validation split.
test_metrics = evaluate(model, test_data, seq_len, batch_size, device)
test_loss, test_ppl = test_metrics
print(f"Test Loss: {test_loss:.4f}, Test Perplexity: {test_ppl:.2f}")


In [None]:
def generate(model, start="ROMEO", length=20, temperature=1.0):
    """Sample `length` words from `model`, primed with the word(s) in `start`.

    Args:
        model: a WordRNN-style model whose forward returns `(logits, hidden)`.
        start: one or more whitespace-separated words, each of which must be in the
            vocabulary. Punctuation is part of a token, e.g. "ROMEO:".
        length: number of words to sample after the prompt.
        temperature: must be > 0. Values below 1 sharpen the distribution and
            values above 1 flatten it.

    Returns:
        The prompt words followed by the sampled words, joined by single spaces.

    Raises:
        KeyError: if a prompt word is not in the vocabulary.
        ValueError: if `start` contains no words or `temperature` is not positive.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")
    prompt = start.split()
    if not prompt:
        raise ValueError("start must contain at least one word")
    missing = [w for w in prompt if w not in stoi]
    if missing:
        raise KeyError(f"words not in vocabulary: {missing}")

    model.eval()
    # Build inputs on whatever device the model's weights live on.
    model_device = next(model.parameters()).device
    idx = torch.tensor([[stoi[w] for w in prompt]], device=model_device)
    h = None
    out = list(prompt)

    # Sampling needs no gradients. Without no_grad, the carried hidden state would
    # chain every step into one ever-growing autograd graph.
    with torch.no_grad():
        for _ in range(length):
            logits, h = model(idx, h)
            # Only the last position's distribution predicts the next word.
            probs = torch.softmax(logits[:, -1, :] / temperature, dim=-1)
            idx = torch.multinomial(probs, num_samples=1)
            out.append(itos[idx.item()])
    return " ".join(out)


In [None]:
# Prompt with "ROMEO:". The colon is part of the token because tokenization splits on whitespace only.
sample_text = generate(model, start="ROMEO:", length=20, temperature=0.8)
print(sample_text)

## Comment on Validation Results

We observe that while **training loss and perplexity decrease rapidly**, the **validation loss increases sharply** across epochs.  
This divergence indicates that the model is **overfitting** the training split and failing to generalize.  

The main reason is that our dataset is **too small for word-level modeling**:
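- With whitespace-only tokenization, the Shakespeare corpus yields **tens of thousands of unique tokens** (a large vocabulary). Punctuation stays attached to words, so forms such as `ROMEO:` and `ROMEO` count as different tokens.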
- Many of these words appear only a handful of times.  
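- The model can memorize frequent word sequences in the training split, but it has **no statistical basis** for predicting rare words, or words that occur only outside the training split. The vocabulary is built from the whole corpus, so such words have slots in the model. However, their embeddings are never trained, and the model only ever learns to push their probability down.  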
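- Cross-entropy charges $-\log p$, where $p$ is the probability assigned to the true next word. Confident wrong predictions (tiny $p$ on the true word) are therefore punished very harshly, causing validation perplexity to explode.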

In other words:  
The poor validation results are most likely not a bug in the training loop but a **consequence of data scarcity relative to vocabulary size**.  

This is why **character-level modeling** (small vocab, dense repetition) or **subword tokenization** (BPE/WordPiece) is generally used on such small corpora.
