In [1]:
from datasets import load_dataset
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
import time

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# Data Preparation

## Load Data

In [2]:
# Load the text8 dataset
dataset = load_dataset("afmck/text8")

# The dataset has one long text string
print(dataset)

# Check datasize
print(len(dataset["train"][0]["text"]))

# Preview first 500 characters
print(dataset["train"][0]["text"][:500])

DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 1
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 1
    })
    test: Dataset({
        features: ['text'],
        num_rows: 1
    })
})
90000000
 anarchism originated as a term of abuse first used against early working class radicals including the diggers of the english revolution and the sans culottes of the french revolution whilst the term is still used in a pejorative way to describe any act that used violent means to destroy the organization of society it has also been taken up as a positive label by self defined anarchists the word anarchism is derived from the greek without archons ruler chief king anarchism as a political philoso


## Splice Small Subset

In [3]:
full_text = dataset["train"][0]["text"]
# extract small subset for faster experiments
small_text = dataset["train"][0]["text"][:500000]  # first 500,000 chars
print(f"Length of small text subset: {len(small_text)} characters")
# split into training and validation sets
train_text = small_text[:450000]  # first 450,000 chars for training
val_text = small_text[450000:]    # last 50,000 chars for validation
print(f"Length of training text: {len(train_text)} characters")
print(f"Length of validation text: {len(val_text)} characters")

Length of small text subset: 500000 characters
Length of training text: 450000 characters
Length of validation text: 50000 characters


# Tokenizer

In [4]:
# Build vocabulary (lowercase + space + a few punctuations)
char_set = list("abcdefghijklmnopqrstuvwxyz ")
char_to_int = {ch:i for i,ch in enumerate(char_set)}
int_to_char = {i:ch for ch,i in char_to_int.items()}

def encode(s):
    """Encode string to array of integers"""
    ids = [char_to_int[c] for c in s]
    return np.array(ids, dtype=np.uint8)  # use np.uint8 to save space

def decode(ids):
    """Decode array of integers to string"""
    return ''.join(int_to_char[i] for i in ids)
# Test encoding and decoding
test_str = "hello world"
encoded = encode(test_str)
decoded = decode(encoded)
assert test_str == decoded, "Encoding/decoding failed"
print(f"Test string: {test_str}")

Test string: hello world


In [5]:
# Tokenize the Text Data
train_text_int = encode(train_text)
test_text_int = encode(val_text)

# Data Loader

## Wrap in pyTorch Dataset

In [6]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

class Text8Dataset(Dataset):
    def __init__(self, data, seq_len):
        """
        data: 1D numpy array of token IDs (uint8 or int)
        seq_len: length T of each input sequence
        """
        self.data = data
        self.seq_len = seq_len

    def __len__(self):
        # Number of possible sequences
        return len(self.data) - self.seq_len

    def __getitem__(self, idx):
        # input is [idx : idx+T]
        # target is next char [idx+1 : idx+T+1]
        x = self.data[idx : idx + self.seq_len]
        y = self.data[idx + 1 : idx + 1 + self.seq_len]

        # Convert to torch tensors
        x = torch.tensor(x, dtype=torch.long)
        y = torch.tensor(y, dtype=torch.long)
        return x, y


## Create DataLoader

In [7]:
SEQ_LEN = 64
BATCH_SIZE = 128

train_dataset = Text8Dataset(train_text_int, seq_len=SEQ_LEN)
test_dataset  = Text8Dataset(test_text_int,  seq_len=SEQ_LEN)

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
)

test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    drop_last=True,
)


In [8]:
# Sanity check the DataLoader
# Should be shifted by one between input and target
x, y = next(iter(train_loader))
print("Input batch shape:", x.shape)   # (B, T)
print("Target batch shape:", y.shape)  # (B, T)
print("Example decoded input:", decode(x[0].numpy()))
print("Example decoded target:", decode(y[0].numpy()))


Input batch shape: torch.Size([128, 64])
Target batch shape: torch.Size([128, 64])
Example decoded input:  and secretary of state william seward without his bodyguard war
Example decoded target: and secretary of state william seward without his bodyguard ward


# LSTM Model

In [9]:
import torch
import torch.nn as nn

class LSTMCharModel(nn.Module):
    def __init__(self, vocab_size=27, embed_dim=128, hidden_dim=256, num_layers=2):
        super().__init__()

        # 1. Embedding layer
        self.embed = nn.Embedding(vocab_size, embed_dim)

        # 2. LSTM
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True
        )

        # 3. Output projection
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, hidden=None):
        """
        x: (B, T)
        returns logits: (B, T, vocab_size)
        """
        x = self.embed(x)              # (B, T, embed_dim)
        out, hidden = self.lstm(x, hidden)  # (B, T, hidden_dim)
        logits = self.fc(out)          # (B, T, vocab_size)
        return logits, hidden


In [10]:
# Sanity check the model
model = LSTMCharModel(vocab_size=27)
x, y = next(iter(train_loader))
logits, _ = model(x)

print("Logits shape:", logits.shape)
# Should be (B, T, vocab_size)

Logits shape: torch.Size([128, 64, 27])


# Loss function and optimizer

In [11]:
criterion = nn.CrossEntropyLoss()
# Pytorch's CrossEntropyLoss expects inputs of shape (B*T, C) and targets of shape (B*T)
logits_flat = logits.reshape(-1, 27)
targets_flat = y.reshape(-1)
loss = criterion(logits_flat, targets_flat)

In [12]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

## Accuracy Function

In [13]:
def compute_accuracy(logits, targets):
    """
    logits: (B, T, vocab)
    targets: (B, T)
    Returns:
        acc_all: accuracy over all positions
        acc_last: accuracy at last position only
    """
    preds = logits.argmax(dim=-1)          # (B, T)

    correct_all = (preds == targets).float().mean().item()
    correct_last = (preds[:, -1] == targets[:, -1]).float().mean().item()

    return correct_all, correct_last


# Training Loop

## Initialization

In [14]:
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

model = LSTMCharModel(vocab_size=27).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

Using device: cuda


In [15]:
def train_lstm(
    model,
    train_loader,
    test_loader,
    optimizer,
    criterion,
    num_epochs=3,
):
    model.train()

    for epoch in range(1, num_epochs + 1):
        total_loss = 0.0

        for x, y in train_loader:
            x = x.to(device)
            y = y.to(device)

            optimizer.zero_grad()

            logits, _ = model(x)  # (B, T, vocab)
            loss = criterion(logits.reshape(-1, 27), y.reshape(-1))

            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(train_loader)

        # Evaluate
        model.eval()
        with torch.no_grad():
            val_loss = 0.0
            acc_all_list = []
            acc_last_list = []

            for x_val, y_val in test_loader:
                x_val = x_val.to(device)
                y_val = y_val.to(device)

                logits, _ = model(x_val)
                loss = criterion(logits.reshape(-1, 27), y_val.reshape(-1))
                val_loss += loss.item()

                acc_all, acc_last = compute_accuracy(logits, y_val)
                acc_all_list.append(acc_all)
                acc_last_list.append(acc_last)

        val_loss /= len(test_loader)
        avg_acc_all = sum(acc_all_list) / len(acc_all_list)
        avg_acc_last = sum(acc_last_list) / len(acc_last_list)

        print(f"Epoch {epoch}")
        print(f"  Train Loss: {avg_loss:.4f}")
        print(f"  Val Loss  : {val_loss:.4f}")
        print(f"  Accuracy (all positions): {avg_acc_all*100:.2f}%")
        print(f"  Accuracy (next char)    : {avg_acc_last*100:.2f}%")
        print()

        model.train()

In [21]:
# Sanity Training
train_lstm(
    model=model,
    train_loader=train_loader,
    test_loader=test_loader,
    optimizer=optimizer,
    criterion=criterion,
    num_epochs=3,
)

Epoch 1
  Train Loss: 1.1411
  Val Loss  : 1.6842
  Accuracy (all positions): 55.58%
  Accuracy (next char)    : 56.75%

Epoch 2
  Train Loss: 0.6355
  Val Loss  : 2.3050
  Accuracy (all positions): 53.64%
  Accuracy (next char)    : 54.55%

Epoch 3
  Train Loss: 0.4636
  Val Loss  : 2.6788
  Accuracy (all positions): 53.24%
  Accuracy (next char)    : 54.07%



# Custom Training Loop

In [16]:
def run_lstm_experiment(
    lr=1e-3,
    train_batch_size=128,
    test_batch_size=512,
    seq_len=64,
    num_epochs=3,
    embed_dim=128,
    hidden_dim=256,
    num_layers=2,
):
    """
    Train an LSTM on text8 with configurable hyperparameters.

    Returns a dict with final train/val loss and accuracies.
    """

    # 1. Build datasets and dataloaders for this experiment
    train_dataset = Text8Dataset(train_text_int, seq_len=seq_len)
    test_dataset  = Text8Dataset(test_text_int,  seq_len=seq_len)

    train_loader = DataLoader(
        train_dataset,
        batch_size=train_batch_size,
        shuffle=True,
        drop_last=True,
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=test_batch_size,
        shuffle=False,
        drop_last=True,
    )

    # 2. Build model, loss, and optimizer
    vocab_size = 27
    model = LSTMCharModel(
        vocab_size=vocab_size,
        embed_dim=embed_dim,
        hidden_dim=hidden_dim,
        num_layers=num_layers,
    ).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    print(f"Starting LSTM experiment:")
    print(f"  lr={lr}, epochs={num_epochs}, seq_len={seq_len}")
    print(f"  train_batch_size={train_batch_size}, test_batch_size={test_batch_size}")
    print(f"  embed_dim={embed_dim}, hidden_dim={hidden_dim}, num_layers={num_layers}")
    print()

    # 3. Training loop
    for epoch in range(1, num_epochs + 1):
        model.train()
        total_train_loss = 0.0

        for x, y in tqdm(train_loader, desc=f"Epoch {epoch} [train]", leave=False):
            x = x.to(device)
            y = y.to(device)

            optimizer.zero_grad()
            logits, _ = model(x)  # (B, T, vocab)
            loss = criterion(logits.reshape(-1, vocab_size), y.reshape(-1))
            loss.backward()
            optimizer.step()

            total_train_loss += loss.item()

        avg_train_loss = total_train_loss / len(train_loader)

        # 4. Evaluation on test set
        model.eval()
        total_val_loss = 0.0
        acc_all_list = []
        acc_last_list = []

        with torch.no_grad():
            for x_val, y_val in tqdm(test_loader, desc=f"Epoch {epoch} [val]", leave=False):
                x_val = x_val.to(device)
                y_val = y_val.to(device)

                logits, _ = model(x_val)
                loss = criterion(logits.reshape(-1, vocab_size), y_val.reshape(-1))
                total_val_loss += loss.item()

                acc_all, acc_last = compute_accuracy(logits, y_val)
                acc_all_list.append(acc_all)
                acc_last_list.append(acc_last)

        avg_val_loss = total_val_loss / len(test_loader)
        avg_acc_all = sum(acc_all_list) / len(acc_all_list)
        avg_acc_last = sum(acc_last_list) / len(acc_last_list)

        print(f"Epoch {epoch}/{num_epochs}")
        print(f"  Train loss: {avg_train_loss:.4f}")
        print(f"  Val loss  : {avg_val_loss:.4f}")
        print(f"  Accuracy (all positions): {avg_acc_all*100:.2f}%")
        print(f"  Accuracy (next char)    : {avg_acc_last*100:.2f}%")
        print()

    # 5. Return metrics for logging / comparison
    return {
        "lr": lr,
        "train_batch_size": train_batch_size,
        "test_batch_size": test_batch_size,
        "seq_len": seq_len,
        "num_epochs": num_epochs,
        "embed_dim": embed_dim,
        "hidden_dim": hidden_dim,
        "num_layers": num_layers,
        "final_train_loss": avg_train_loss,
        "final_val_loss": avg_val_loss,
        "final_acc_all": avg_acc_all,
        "final_acc_last": avg_acc_last,
    }

# Hyper Parameter Experimentation

## LR

In [32]:
lr_results = {}   # store per-LR results

for lr in [1e-3, 3e-4, 1e-4]:
    print(f"=== Running lr={lr} ===")

    metrics = run_lstm_experiment(
        lr=lr,
        train_batch_size=128,
        test_batch_size=512,
        seq_len=64,
        num_epochs=1,  # only 1 epoch for fast sweep
    )

    lr_results[lr] = metrics   # store metrics for plotting

    print("Final val loss:", metrics["final_val_loss"])
    print("Final overall acc:", metrics["final_acc_all"] * 100, "%")
    print("Final next-char acc:", metrics["final_acc_last"] * 100, "%")
    print("-" * 40)

=== Running lr=0.001 ===
Starting LSTM experiment:
  lr=0.001, epochs=1, seq_len=64
  train_batch_size=128, test_batch_size=512
  embed_dim=128, hidden_dim=256, num_layers=2



                                                                    

Epoch 1/1
  Train loss: 1.1876
  Val loss  : 1.5879
  Accuracy (all positions): 56.35%
  Accuracy (next char)    : 57.59%

Final val loss: 1.5879104960825026
Final overall acc: 56.35372830420425 %
Final next-char acc: 57.59101159793815 %
----------------------------------------
=== Running lr=0.0003 ===
Starting LSTM experiment:
  lr=0.0003, epochs=1, seq_len=64
  train_batch_size=128, test_batch_size=512
  embed_dim=128, hidden_dim=256, num_layers=2



                                                                    

Epoch 1/1
  Train loss: 1.4711
  Val loss  : 1.4215
  Accuracy (all positions): 56.79%
  Accuracy (next char)    : 58.07%

Final val loss: 1.4215377114482761
Final overall acc: 56.79384211903995 %
Final next-char acc: 58.06821842783505 %
----------------------------------------
=== Running lr=0.0001 ===
Starting LSTM experiment:
  lr=0.0001, epochs=1, seq_len=64
  train_batch_size=128, test_batch_size=512
  embed_dim=128, hidden_dim=256, num_layers=2



                                                                    

Epoch 1/1
  Train loss: 1.8004
  Val loss  : 1.5480
  Accuracy (all positions): 53.27%
  Accuracy (next char)    : 54.35%

Final val loss: 1.5480079564851583
Final overall acc: 53.26648200910115 %
Final next-char acc: 54.34519974226804 %
----------------------------------------




## Batch Size

In [17]:
for batch in [64, 128, 256]:
    metrics = run_lstm_experiment(
        lr=3e-4,
        train_batch_size=batch,
        test_batch_size=512,
        seq_len=64,
        num_epochs=1,
    )
    print("Batch", batch, "→ next-char acc:", metrics["final_acc_last"] * 100, "%")
    print("-" * 40)

Starting LSTM experiment:
  lr=0.0003, epochs=1, seq_len=64
  train_batch_size=64, test_batch_size=512
  embed_dim=128, hidden_dim=256, num_layers=2



                                                                    

Epoch 1/1
  Train loss: 1.3418
  Val loss  : 1.4597
  Accuracy (all positions): 57.25%
  Accuracy (next char)    : 58.54%

Batch 64 → next-char acc: 58.537371134020624 %
----------------------------------------
Starting LSTM experiment:
  lr=0.0003, epochs=1, seq_len=64
  train_batch_size=128, test_batch_size=512
  embed_dim=128, hidden_dim=256, num_layers=2



                                                                    

Epoch 1/1
  Train loss: 1.4769
  Val loss  : 1.4201
  Accuracy (all positions): 56.81%
  Accuracy (next char)    : 58.08%

Batch 128 → next-char acc: 58.084326675257735 %
----------------------------------------
Starting LSTM experiment:
  lr=0.0003, epochs=1, seq_len=64
  train_batch_size=256, test_batch_size=512
  embed_dim=128, hidden_dim=256, num_layers=2



                                                                    

Epoch 1/1
  Train loss: 1.6418
  Val loss  : 1.4618
  Accuracy (all positions): 55.65%
  Accuracy (next char)    : 56.87%

Batch 256 → next-char acc: 56.87218105670103 %
----------------------------------------




## Sequence Length

In [18]:
for seq in [64, 128, 256]:
    metrics = run_lstm_experiment(
        lr=3e-4,
        train_batch_size=128,
        test_batch_size=512,
        seq_len=seq,
        num_epochs=1,
    )
    print("Seq Len", seq, "→ next-char acc:", metrics["final_acc_last"] * 100, "%")
    print("-" * 40)

Starting LSTM experiment:
  lr=0.0003, epochs=1, seq_len=64
  train_batch_size=128, test_batch_size=512
  embed_dim=128, hidden_dim=256, num_layers=2



                                                                    

Epoch 1/1
  Train loss: 1.4669
  Val loss  : 1.4269
  Accuracy (all positions): 56.73%
  Accuracy (next char)    : 58.04%

Seq Len 64 → next-char acc: 58.0420425257732 %
----------------------------------------
Starting LSTM experiment:
  lr=0.0003, epochs=1, seq_len=128
  train_batch_size=128, test_batch_size=512
  embed_dim=128, hidden_dim=256, num_layers=2



                                                                    

Epoch 1/1
  Train loss: 1.4093
  Val loss  : 1.4112
  Accuracy (all positions): 57.87%
  Accuracy (next char)    : 58.57%

Seq Len 128 → next-char acc: 58.56556056701031 %
----------------------------------------
Starting LSTM experiment:
  lr=0.0003, epochs=1, seq_len=256
  train_batch_size=128, test_batch_size=512
  embed_dim=128, hidden_dim=256, num_layers=2



                                                                    

Epoch 1/1
  Train loss: 1.3351
  Val loss  : 1.4627
  Accuracy (all positions): 57.54%
  Accuracy (next char)    : 57.92%

Seq Len 256 → next-char acc: 57.91921713917526 %
----------------------------------------




## Model Size

### Medium Model

In [23]:
metrics = run_lstm_experiment(
    lr=3e-4,
    train_batch_size=128,
    test_batch_size=512,
    seq_len=128,
    num_epochs=1,
    embed_dim=256,
    hidden_dim=512,
    num_layers=2,
    )
print("Model Size: Medium → next-char acc:", metrics["final_acc_last"] * 100, "%")
print("-" * 40)

Starting LSTM experiment:
  lr=0.0003, epochs=1, seq_len=128
  train_batch_size=128, test_batch_size=512
  embed_dim=256, hidden_dim=512, num_layers=2



                                                                    

Epoch 1/1
  Train loss: 1.0088
  Val loss  : 2.1696
  Accuracy (all positions): 53.50%
  Accuracy (next char)    : 53.86%

Model Size: Medium → next-char acc: 53.86195231958762 %
----------------------------------------


