# PyTorch NLP: training an RNN (LSTM) for text classification

**Contents**
- Prerequisites & installation
- Imports and lightweight preprocessing utilities (tokenization, Vocab, Dataset)
- Mini synthetic demo dataset (quick smoke test)
- Model: Embedding + LSTM classifier (handling padded sequences)
- Training and evaluation utilities (train/evaluate loops)
- Training run, saving checkpoint, and inference on new sentences


## Prerequisites

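If you are working in a fresh environment, install PyTorch and the helper libraries used here; `torchtext` is optional. The command below is only an example. Pick the PyTorch command that matches your platform and CUDA version on https://pytorch.org:
```bash
# Example install; see https://pytorch.org for the CPU- or CUDA-specific PyTorch command
pip install torch numpy matplotlib
# Optional, only for larger datasets (not required for the mini demo):
# pip install torchtext
```

Note that `torchtext` is no longer actively developed (0.18 was announced as its last stable release), so it may not work with recent PyTorch versions. The mini demo does not depend on it.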

This notebook uses:
- `torch`, `torch.nn`, `torch.utils.data` for model and training utilities
- `torchtext` only optionally for larger datasets (not required for the mini-demo)
- `numpy`, `matplotlib` for convenience/visualization


```python
# Basic imports and optional torchtext check
import os
import random
from collections import Counter

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# Optional: check torchtext availability (not required for the mini demo)
try:
    import torchtext
    from torchtext.datasets import IMDB
    TORCHTEXT_AVAILABLE = True
except Exception:
    TORCHTEXT_AVAILABLE = False

print('torch version:', torch.__version__)
print('torchtext available:', TORCHTEXT_AVAILABLE)
```

```
torch version: 2.8.0+cu126
torchtext available: False
```


## Utilities — tokenization, vocabulary, dataset and batching

This section defines lightweight utilities to convert raw text into tensors suitable for PyTorch models:

- `simple_tokenize(text)`: minimal tokenizer that lowercases and splits on whitespace, e.g. `"I Loved the Movie"` → `["i", "loved", "the", "movie"]`. Punctuation is not split off, so `"fun!"` and `"fun"` become different tokens.
- `Vocab`: builds a token ↔ index mapping, tracks token frequencies, supports `min_freq`, and provides `numericalize()` (unknown tokens map to `<unk>`).
- `TextDataset`: a `torch.utils.data.Dataset` that returns `(tensor_of_indices, label)` for each example.
- `collate_batch(batch)`: collate function for `DataLoader` that pads the sequences in a batch to the same length with index 0 (`<pad>`) and returns `(padded_sequences, lengths, labels)`.

These utilities are intentionally simple and easy to understand. For production or larger datasets, consider using `torchtext`, `huggingface/tokenizers`, or other robust tokenization libraries.
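Because `simple_tokenize` only splits on whitespace, punctuation stays glued to words. A minimal sketch of a slightly more robust alternative, assuming a regular-expression tokenizer is acceptable for your data; `regex_tokenize` is a hypothetical name, not part of this notebook or of any library:

```python
import re

def regex_tokenize(text: str):
    """Hypothetical tokenizer: lowercase, then emit runs of word characters
    and single punctuation characters as separate tokens."""
    # \w+      -> runs of letters, digits and underscores (e.g. "loved")
    # [^\w\s]  -> one character that is neither a word character nor whitespace (e.g. "!")
    return re.findall(r"\w+|[^\w\s]", text.lower())

sentence = "I loved it, fun!"
print(sentence.lower().split())   # whitespace split keeps "it," and "fun!" as tokens
print(regex_tokenize(sentence))   # ['i', 'loved', 'it', ',', 'fun', '!']
```

This pattern also splits contractions such as `"didn't"` into `didn`, `'`, `t`, which is one reason dedicated tokenizer libraries are preferable for real data.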


```python
from collections import Counter

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset

# Minimal tokenizer
def simple_tokenize(text: str):
    """Very small tokenizer: lowercases and splits on whitespace (punctuation stays attached)."""
    return text.lower().strip().split()

# Vocabulary class
class Vocab:
    def __init__(self, min_freq: int = 1, specials=('<pad>', '<unk>')):
        self.freqs = Counter()
        self.itos = list(specials)                               # index -> token (<pad>=0, <unk>=1)
        self.stoi = {tok: i for i, tok in enumerate(self.itos)}  # token -> index
        self.min_freq = min_freq

    def build(self, iterator):
        """Build vocabulary from an iterator of token lists."""
        for tokens in iterator:
            self.freqs.update(tokens)
        # Add tokens sorted by frequency (descending); ties keep first-seen order
        for tok, cnt in sorted(self.freqs.items(), key=lambda x: -x[1]):
            if cnt >= self.min_freq and tok not in self.stoi:
                self.stoi[tok] = len(self.itos)
                self.itos.append(tok)

    def __len__(self):
        return len(self.itos)

    def numericalize(self, tokens):
        """Convert a list of tokens to a list of indices; unknown tokens map to <unk>."""
        unk_idx = self.stoi['<unk>']
        return [self.stoi.get(t, unk_idx) for t in tokens]

# Dataset class for text classification
class TextDataset(Dataset):
    def __init__(self, texts, labels, vocab: Vocab = None):
        assert len(texts) == len(labels)
        self.texts = texts
        self.labels = labels
        if vocab is None:
            # Build a new vocabulary from these texts (intended for the training split)
            self.vocab = Vocab()
            self.vocab.build(simple_tokenize(t) for t in texts)
        else:
            self.vocab = vocab  # reuse an existing vocabulary (e.g. for validation/test)

    def __len__(self):
        # number of examples in the dataset
        return len(self.texts)

    def __getitem__(self, idx):
        tokens = simple_tokenize(self.texts[idx])  # tokenize
        nums = self.vocab.numericalize(tokens)     # convert tokens to indices
        label = self.labels[idx]                   # corresponding label
        return torch.tensor(nums, dtype=torch.long), torch.tensor(label, dtype=torch.long)

# collate_fn for DataLoader: pads sequences and returns lengths
def collate_batch(batch):
    sequences, labels = zip(*batch)
    lengths = torch.tensor([len(s) for s in sequences], dtype=torch.long)  # true length of each sequence
    padded = pad_sequence(sequences, batch_first=True, padding_value=0)   # pad with index 0 (<pad>)
    labels = torch.stack(labels)
    return padded, lengths, labels

print('Utilities defined')
```

```
Utilities defined
```

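To make the shapes returned by `collate_batch` concrete, here is a standalone sketch of the same padding steps on two made-up index sequences. The indices are arbitrary, and 0 is assumed to be the `<pad>` index, as in `Vocab`:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Two made-up (indices, label) examples of different lengths, as TextDataset would return
toy_batch = [
    (torch.tensor([5, 6, 7, 8]), torch.tensor(1)),
    (torch.tensor([9, 10]), torch.tensor(0)),
]

# Same steps as collate_batch
toy_seqs, toy_labels = zip(*toy_batch)
toy_lengths = torch.tensor([len(s) for s in toy_seqs], dtype=torch.long)
toy_padded = pad_sequence(toy_seqs, batch_first=True, padding_value=0)  # 0 = <pad>
toy_labels = torch.stack(toy_labels)

print(toy_padded.tolist())   # [[5, 6, 7, 8], [9, 10, 0, 0]]
print(toy_lengths.tolist())  # [4, 2]
print(toy_labels.tolist())   # [1, 0]
```

The `lengths` tensor is what the model later passes to `pack_padded_sequence`, so the LSTM can skip the padded positions.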

## Mini synthetic demo (quick smoke test)

We create a tiny synthetic dataset (7 short sentences) labeled with binary sentiment:
- `1` → positive
- `0` → negative

The dataset is shuffled, split into train/validation sets, and wrapped in `TextDataset` and `DataLoader` (with `collate_batch`). This gives a quick end-to-end run that checks the model and training pipeline work; with so few examples, the resulting metrics are not meaningful.

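The next cell shuffles with the global `random` module and sets no seed, so the train/validation split changes between runs. If you want a reproducible split, one option is a private `random.Random` instance with a fixed seed. This is a sketch; `split_train_val` is a hypothetical helper and the sentences are placeholders. Model initialization and `DataLoader` shuffling draw from PyTorch's generator instead, so fully repeatable runs would also need `torch.manual_seed(...)`.

```python
import random

def split_train_val(texts, labels, n_train, seed=0):
    """Hypothetical helper: shuffle (text, label) pairs with a seeded RNG, then split."""
    rng = random.Random(seed)        # private RNG: leaves the global random state untouched
    data = list(zip(texts, labels))
    rng.shuffle(data)
    return data[:n_train], data[n_train:]

demo_texts = ["s1", "s2", "s3", "s4", "s5", "s6", "s7"]   # placeholder sentences
demo_labels = [1, 0, 1, 0, 1, 1, 0]

train_a, val_a = split_train_val(demo_texts, demo_labels, n_train=5, seed=42)
train_b, val_b = split_train_val(demo_texts, demo_labels, n_train=5, seed=42)
print(train_a == train_b and val_a == val_b)   # True: same seed, same split
```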

```python
# Tiny dataset: 7 sentences (labels: 1=positive, 0=negative)
texts = [
    'I loved the movie it was fantastic and fun',      # positive
    'What a terrible movie it was boring and slow',    # negative
    'Amazing plot and great characters',               # positive
    'I hated it worst film ever',                      # negative
    'It was okay not great but not bad',               # positive/neutral
    'An outstanding masterpiece of cinema',            # positive
    'Awful acting and poor script',                    # negative
]
labels = [1, 0, 1, 0, 1, 1, 0]

# Shuffle and split into train (5) / val (2).
# No seed is set, so the split (and the vocab size printed below) can differ between runs.
data = list(zip(texts, labels))
random.shuffle(data)
train = data[:5]
val = data[5:]

train_texts, train_labels = zip(*train)
val_texts, val_labels = zip(*val)

# Create datasets and loaders
train_ds = TextDataset(list(train_texts), list(train_labels))
val_ds = TextDataset(list(val_texts), list(val_labels), vocab=train_ds.vocab)  # reuse the training vocab

train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, collate_fn=collate_batch)  # two sentences per batch
val_loader = DataLoader(val_ds, batch_size=2, shuffle=False, collate_fn=collate_batch)

print('Vocab size (mini-demo):', len(train_ds.vocab))
```

```
Vocab size (mini-demo): 30
```

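The mini demo builds its vocabulary with the default `min_freq=1`, so every training token gets its own index and `<unk>` never appears in a training batch; its embedding row is therefore never updated and keeps its random initialization. Raising `min_freq` maps rare training tokens to `<unk>` too, so the model does see it during training. A standalone sketch of that filtering, mirroring the rule in `Vocab.build` on three of the demo sentences (`min_freq=2` is chosen only for illustration):

```python
from collections import Counter

demo_sentences = [
    "i loved the movie it was fantastic and fun",
    "what a terrible movie it was boring and slow",
    "amazing plot and great characters",
]
freqs = Counter(tok for s in demo_sentences for tok in s.split())

min_freq = 2
demo_itos = ["<pad>", "<unk>"]
# Same rule as Vocab.build: most frequent first, keep tokens seen >= min_freq times
for tok, cnt in sorted(freqs.items(), key=lambda x: -x[1]):
    if cnt >= min_freq:
        demo_itos.append(tok)
demo_stoi = {tok: i for i, tok in enumerate(demo_itos)}

unk = demo_stoi["<unk>"]
ids = [demo_stoi.get(t, unk) for t in "the movie was boring".split()]
print(demo_itos)  # ['<pad>', '<unk>', 'and', 'movie', 'it', 'was']
print(ids)        # [1, 3, 5, 1]: 'the' and 'boring' occur once, so they become <unk>
```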

## Model: Embedding + LSTM classifier

Architecture details:
1. **Embedding layer**: maps token indices to dense vectors of size `embed_dim`. With `padding_idx=0`, the `<pad>` row is initialized to zeros and receives no gradient, so it stays zero during training.
2. **LSTM**: can be multi-layer and bidirectional. `pack_padded_sequence` lets the LSTM skip padding positions, so each sequence's final hidden state comes from its last real token rather than from trailing `<pad>` tokens.
3. **Final hidden state**: for a bidirectional LSTM we concatenate the last layer's final forward and backward hidden states; for a unidirectional LSTM we use the last layer's final hidden state.
4. **Dropout + Linear**: apply dropout, then a final `Linear` layer that produces one logit per class. The model returns raw logits because `nn.CrossEntropyLoss` applies log-softmax internally.

The forward method accepts `(padded_sequences, lengths)` and returns `(batch, num_classes)` logits.
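A standalone shape check of steps 2 and 3, using random inputs in place of embeddings and arbitrary sizes (`hidden_dim=16`, 2 layers). It shows the layout of `h_n`, the forward/backward concatenation, and that with packing a padded sequence's final state matches running the LSTM on its real tokens alone:

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

torch.manual_seed(0)
batch_size, seq_len, embed_dim, hidden_dim, num_layers = 3, 5, 8, 16, 2
toy_lstm = nn.LSTM(embed_dim, hidden_dim, num_layers=num_layers,
                   batch_first=True, bidirectional=True)

toy_emb = torch.randn(batch_size, seq_len, embed_dim)  # stand-in for embedded, padded tokens
toy_lengths = torch.tensor([5, 3, 2])                  # real lengths; later positions are padding
packed = pack_padded_sequence(toy_emb, toy_lengths, batch_first=True, enforce_sorted=False)
_, (h_n, _) = toy_lstm(packed)
print(h_n.shape)    # torch.Size([4, 3, 16]): (num_layers * 2 directions, batch, hidden_dim)

# Last layer: index -2 is the forward direction, -1 the backward direction
h_cat = torch.cat((h_n[-2], h_n[-1]), dim=1)
print(h_cat.shape)  # torch.Size([3, 32])

# Example 1 has length 3: its final states equal those from running only its 3 real steps
_, (h_alone, _) = toy_lstm(toy_emb[1:2, :3])
print(torch.allclose(h_n[:, 1], h_alone[:, 0], atol=1e-5))  # True
```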


```python
class RNNClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim=100, hidden_dim=128, num_layers=1, bidirectional=True,
                 num_classes=2, dropout=0.3):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)  # index 0 = <pad>
        self.lstm = nn.LSTM(input_size=embed_dim,
                            hidden_size=hidden_dim,
                            num_layers=num_layers,
                            batch_first=True,
                            bidirectional=bidirectional,
                            dropout=dropout if num_layers > 1 else 0.0)  # inter-layer dropout needs >1 layer
        self.bidirectional = bidirectional
        self.num_directions = 2 if bidirectional else 1
        self.fc = nn.Linear(hidden_dim * self.num_directions, num_classes)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, lengths):
        # x: (batch, seq_len), lengths: (batch,)
        emb = self.embedding(x)  # (batch, seq_len, embed_dim)
        # pack_padded_sequence expects lengths on the CPU; enforce_sorted=False accepts any order
        packed = nn.utils.rnn.pack_padded_sequence(emb, lengths.cpu(), batch_first=True, enforce_sorted=False)
        _, (h_n, _) = self.lstm(packed)  # per-step outputs and final cell state are not needed
        # h_n shape: (num_layers * num_directions, batch, hidden_dim), in the original batch order
        if self.bidirectional:
            # last layer forward = -2, last layer backward = -1
            h_forward = h_n[-2, :, :]
            h_backward = h_n[-1, :, :]
            h = torch.cat((h_forward, h_backward), dim=1)  # (batch, hidden_dim*2)
        else:
            h = h_n[-1, :, :]  # (batch, hidden_dim)
        out = self.dropout(h)
        logits = self.fc(out)  # (batch, num_classes), raw scores for CrossEntropyLoss
        return logits

print('Model class defined')
```

```
Model class defined
```

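A quick standalone check of the `padding_idx=0` behaviour described in step 1 (toy sizes chosen arbitrarily): the `<pad>` row starts at zero, and its gradient stays zero even when padding positions appear in the input.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
toy_emb = nn.Embedding(num_embeddings=5, embedding_dim=3, padding_idx=0)
print(toy_emb.weight[0].detach().tolist())  # [0.0, 0.0, 0.0]: <pad> row initialized to zeros

toy_batch = torch.tensor([[2, 3, 0, 0]])    # one sequence, padded with index 0
toy_emb(toy_batch).sum().backward()

print(toy_emb.weight.grad[0].tolist())      # [0.0, 0.0, 0.0]: no gradient for <pad>
print(toy_emb.weight.grad[2].tolist())      # [1.0, 1.0, 1.0]: index 2 used once
```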

## Training and evaluation utilities

We define two helper functions, each returning the average per-example loss and the accuracy over a loader:
- `train_epoch` runs one training epoch (forward pass, backward pass, optimizer step) with the model in training mode;
- `evaluate` runs the model in eval mode under `torch.no_grad()` on a validation/test loader, without updating weights.

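A common addition to RNN training loops, not used in this notebook, is gradient-norm clipping between `loss.backward()` and `optimizer.step()`, which limits the size of occasional very large gradient updates. A minimal sketch on a toy LSTM with a deliberately inflated loss; `max_norm=1.0` is an arbitrary choice:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
toy_lstm = nn.LSTM(input_size=4, hidden_size=8, batch_first=True)
toy_opt = torch.optim.Adam(toy_lstm.parameters(), lr=1e-3)

toy_out, _ = toy_lstm(torch.randn(2, 5, 4))
toy_loss = toy_out.pow(2).sum() * 100   # inflated loss -> large gradients

toy_opt.zero_grad()
toy_loss.backward()
# Rescales all gradients in place so their combined L2 norm is at most max_norm;
# the return value is the total norm measured before clipping
pre_clip = nn.utils.clip_grad_norm_(toy_lstm.parameters(), max_norm=1.0)
post_clip = torch.norm(torch.stack([p.grad.norm() for p in toy_lstm.parameters()]))
toy_opt.step()

print(f"grad norm before: {pre_clip.item():.3f}, after: {post_clip.item():.3f}")
```

In `train_epoch`, the equivalent would be one `nn.utils.clip_grad_norm_(model.parameters(), max_norm)` call right after `loss.backward()`.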

```python
def train_epoch(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0
    for xb, lengths, yb in dataloader:
        xb, lengths, yb = xb.to(device), lengths.to(device), yb.to(device)
        optimizer.zero_grad()                  # reset gradients from the previous step
        logits = model(xb, lengths)            # compute logits
        loss = criterion(logits, yb)           # mean loss over the batch
        loss.backward()                        # backpropagation
        optimizer.step()                       # update weights
        total_loss += loss.item() * xb.size(0) # batch mean * batch size = summed loss
        preds = logits.argmax(dim=1)           # predicted class: highest logit (= highest softmax probability)
        correct += (preds == yb).sum().item()  # number of correct predictions
        total += xb.size(0)                    # total number of examples
    return total_loss / total, correct / total


def evaluate(model, dataloader, criterion, device):
    model.eval()  # disable dropout
    total_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():  # no gradient tracking during evaluation
        for xb, lengths, yb in dataloader:
            xb, lengths, yb = xb.to(device), lengths.to(device), yb.to(device)
            logits = model(xb, lengths)
            loss = criterion(logits, yb)
            total_loss += loss.item() * xb.size(0)
            preds = logits.argmax(dim=1)
            correct += (preds == yb).sum().item()
            total += xb.size(0)
    return total_loss / total, correct / total

print('Training utilities defined')
```

```
Training utilities defined
```

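`train_epoch` and `evaluate` multiply each batch's mean loss by the batch size before summing, then divide by the number of examples. That yields the exact per-example mean even when the last batch is smaller (with 5 training examples and `batch_size=2`, the last batch has one example). A toy check with made-up logits:

```python
import torch
import torch.nn as nn

ce = nn.CrossEntropyLoss()   # averages over the batch by default
toy_logits = torch.tensor([[2.0, 0.0], [0.0, 1.0], [1.0, 3.0]])
toy_targets = torch.tensor([0, 0, 1])

# Batches of size 2 and 1, as a DataLoader with batch_size=2 would produce
total_loss, correct, total = 0.0, 0, 0
for lb, tb in [(toy_logits[:2], toy_targets[:2]), (toy_logits[2:], toy_targets[2:])]:
    batch_loss = ce(lb, tb)
    total_loss += batch_loss.item() * lb.size(0)
    correct += (lb.argmax(dim=1) == tb).sum().item()
    total += lb.size(0)

full_mean = ce(toy_logits, toy_targets).item()
print(total_loss / total, full_mean)   # equal up to float rounding
print(correct / total)                 # 2 of 3 argmax predictions are right
```

Averaging the two per-batch means directly would instead give the single-example batch as much weight as the full batch.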

## Training run: train for a few epochs and save a checkpoint

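This cell:
- selects the device (GPU if available, otherwise CPU),
- instantiates the model with reduced sizes for a fast demo (`embed_dim=32, hidden_dim=32`),
- trains for a small number of epochs (5 in the mini demo),
- saves a checkpoint dict with `model_state_dict` and `vocab` (the vocabulary's `itos` list) so the model can be reconstructed later.

With only two validation sentences, validation accuracy can only be 0, 0.5, or 1, so the numbers below confirm that the pipeline runs rather than measuring generalization.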


```python
# Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)

# Instantiate model (small for demo)
vocab_size = len(train_ds.vocab)
model = RNNClassifier(vocab_size=vocab_size, embed_dim=32, hidden_dim=32, num_layers=1,
                      bidirectional=True, num_classes=2, dropout=0.2)
print(model)
model.to(device)

# Optimizer & loss
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# Quick training (5 epochs for the demo)
num_epochs = 5
for epoch in range(1, num_epochs + 1):
    train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device)
    val_loss, val_acc = evaluate(model, val_loader, criterion, device)
    print(f'Epoch {epoch:02d} | train loss {train_loss:.4f} acc {train_acc:.4f} | val loss {val_loss:.4f} acc {val_acc:.4f}')

# Save checkpoint: weights plus the vocabulary's index -> token list
ckpt_path = 'rnn_classifier_mini.pth'
torch.save({'model_state_dict': model.state_dict(), 'vocab': train_ds.vocab.itos}, ckpt_path)
print('Model saved to', ckpt_path)
```

```
Device: cuda
RNNClassifier(
  (embedding): Embedding(30, 32, padding_idx=0)
  (lstm): LSTM(32, 32, batch_first=True, bidirectional=True)
  (fc): Linear(in_features=64, out_features=2, bias=True)
  (dropout): Dropout(p=0.2, inplace=False)
)
Epoch 01 | train loss 0.6928 acc 0.6000 | val loss 0.6907 acc 0.5000
Epoch 02 | train loss 0.6391 acc 0.6000 | val loss 0.6961 acc 0.0000
Epoch 03 | train loss 0.6169 acc 0.6000 | val loss 0.7032 acc 0.0000
Epoch 04 | train loss 0.5858 acc 1.0000 | val loss 0.7100 acc 0.0000
Epoch 05 | train loss 0.5613 acc 1.0000 | val loss 0.7176 acc 0.0000
Model saved to rnn_classifier_mini.pth
```

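The checkpoint above stores the weights and vocabulary but not the constructor arguments, which is why the inference step has to infer them from tensor shapes. A sketch of the alternative, assuming you are free to change the checkpoint format: save the hyperparameters as a plain dict next to the weights. To stay self-contained it uses an `nn.Embedding` as a stand-in for `RNNClassifier`; the `config` key and the file name are hypothetical.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical hyperparameters; in the notebook these are the RNNClassifier arguments
demo_config = {"vocab_size": 30, "embed_dim": 32, "hidden_dim": 32,
               "num_layers": 1, "bidirectional": True, "num_classes": 2}
stand_in = nn.Embedding(demo_config["vocab_size"], demo_config["embed_dim"], padding_idx=0)
ckpt_itos = ["<pad>", "<unk>"] + [f"tok{i}" for i in range(demo_config["vocab_size"] - 2)]

path = os.path.join(tempfile.mkdtemp(), "rnn_classifier_with_config.pth")
torch.save({"model_state_dict": stand_in.state_dict(),
            "vocab": ckpt_itos,
            "config": demo_config}, path)

# Plain dicts, lists, strings, ints and bools load fine with weights_only=True
loaded = torch.load(path, map_location="cpu", weights_only=True)
cfg = loaded["config"]
restored = nn.Embedding(cfg["vocab_size"], cfg["embed_dim"], padding_idx=0)
restored.load_state_dict(loaded["model_state_dict"])
print(cfg, torch.equal(restored.weight, stand_in.weight))
```

With the real model, the rebuild step would become `RNNClassifier(**ckpt['config'])`, since these keys match its constructor arguments.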

## Inference: load the checkpoint and classify a new sentence

This cell demonstrates how to:
- load the saved checkpoint,
- reconstruct a `Vocab` from the saved `itos` list,
- infer the model hyperparameters (vocabulary size, embedding dim, hidden dim, number of layers, bidirectionality) from tensor shapes and key names in the saved state dict,
- rebuild the `RNNClassifier`, load the weights, and classify a new sentence.

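The shape-based inference relies on how `nn.LSTM` names and shapes its parameters. A standalone sketch with an arbitrary 2-layer bidirectional LSTM (inside `RNNClassifier` the same keys carry an `lstm.` prefix):

```python
import re

import torch.nn as nn

demo_lstm = nn.LSTM(input_size=32, hidden_size=16, num_layers=2,
                    batch_first=True, bidirectional=True)
sd = demo_lstm.state_dict()

for name, tensor in sd.items():
    print(f"{name:24s} {tuple(tensor.shape)}")

# weight_ih_l0 stacks the 4 gates (input, forget, cell, output): (4 * hidden, input_size)
inferred_hidden = sd["weight_ih_l0"].shape[0] // 4
layer_ids = {int(m.group(1)) for m in (re.match(r"weight_ih_l(\d+)", k) for k in sd) if m}
inferred_layers = max(layer_ids) + 1
inferred_bidir = any(k.endswith("_reverse") for k in sd)
print(inferred_hidden, inferred_layers, inferred_bidir)  # 16 2 True
```

Layer 1's `weight_ih_l1` has shape `(64, 32)` here because a bidirectional layer feeds `2 * hidden_size` features into the next layer.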

```python
# Load checkpoint and reconstruct model for inference
import re
from pathlib import Path

ckpt_path = Path('rnn_classifier_mini.pth')
assert ckpt_path.exists(), f'Checkpoint not found: {ckpt_path} (run training cell first)'

# The checkpoint holds only tensors, strings and lists, so weights_only loading works
ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=True)
state_dict = ckpt['model_state_dict']
itos = ckpt.get('vocab', None)
assert itos is not None, 'Checkpoint does not contain saved vocab (itos)'

# Rebuild Vocab object
vocab_obj = Vocab()
vocab_obj.itos = list(itos)
vocab_obj.stoi = {tok: i for i, tok in enumerate(vocab_obj.itos)}

# Infer vocab size and embedding size from the embedding weights: (vocab_size, embed_dim)
emb_weight = state_dict['embedding.weight']
vocab_size = emb_weight.shape[0]
embed_dim = emb_weight.shape[1]

# Inspect LSTM keys to infer hidden size, num_layers and bidirectionality
lstm_keys = [k for k in state_dict.keys() if k.startswith('lstm.')]
bidirectional = any('weight_ih_l0_reverse' in k or 'weight_hh_l0_reverse' in k for k in lstm_keys)
w_ih_l0 = state_dict.get('lstm.weight_ih_l0')
assert w_ih_l0 is not None, 'Expected lstm.weight_ih_l0 key in state_dict'
hidden_dim = w_ih_l0.shape[0] // 4  # the 4 LSTM gates are stacked: (4*hidden_dim, embed_dim)

layer_indices = set()
for k in lstm_keys:
    m = re.search(r'lstm\.weight_ih_l(\d+)', k)
    if m:
        layer_indices.add(int(m.group(1)))
num_layers = max(layer_indices) + 1 if layer_indices else 1

# Number of classes from the output layer: fc.weight has shape (num_classes, in_features)
num_classes = state_dict['fc.weight'].shape[0]

# Recreate and load model (the dropout value has no effect in eval mode)
model = RNNClassifier(vocab_size=vocab_size, embed_dim=embed_dim, hidden_dim=hidden_dim,
                      num_layers=num_layers, bidirectional=bidirectional,
                      num_classes=num_classes, dropout=0.2)
model.load_state_dict(state_dict)
model.to(device)
model.eval()

# Single sentence prediction example
sentence = "I really enjoyed the film — it was delightful and fun"
tokens = simple_tokenize(sentence)
indices = vocab_obj.numericalize(tokens)
print('Tokens:', tokens)
print('Indices:', indices)

xb = torch.tensor([indices], dtype=torch.long)            # batch of one: (1, seq_len)
lengths = torch.tensor([len(indices)], dtype=torch.long)
with torch.no_grad():
    logits = model(xb.to(device), lengths.to(device))
    probs = torch.softmax(logits, dim=1).cpu().numpy()[0]  # logits -> class probabilities
    pred = int(probs.argmax())
    print(f'Predicted label: {pred} (class probs: {probs})')

label_map = {0: 'negative', 1: 'positive'}
print('Human-readable:', label_map.get(pred, str(pred)))
```

```
Tokens: ['i', 'really', 'enjoyed', 'the', 'film', '—', 'it', 'was', 'delightful', 'and', 'fun']
Indices: [7, 1, 1, 9, 1, 1, 2, 3, 1, 4, 11]
Predicted label: 1 (class probs: [0.44648206 0.553518  ])
Human-readable: positive
```

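The printed indices show that 5 of the 11 tokens map to index 1 (`<unk>`), including most of the sentiment-bearing words, so this prediction says little about what the model learned. A small hypothetical helper for reporting out-of-vocabulary tokens; the `partial_stoi` entries are read off the printed indices above (not the full mini-demo vocabulary), and the sentence omits the dash token:

```python
def oov_report(tokens, stoi):
    """Hypothetical helper: return the out-of-vocabulary tokens and their share."""
    missing = [t for t in tokens if t not in stoi]
    rate = len(missing) / len(tokens) if tokens else 0.0
    return missing, rate

# Partial vocabulary reconstructed from the printed indices above
partial_stoi = {"<pad>": 0, "<unk>": 1, "it": 2, "was": 3, "and": 4, "i": 7, "the": 9, "fun": 11}
demo_tokens = "i really enjoyed the film it was delightful and fun".split()
missing, rate = oov_report(demo_tokens, partial_stoi)
print(missing, rate)  # ['really', 'enjoyed', 'film', 'delightful'] 0.4
```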