Below is a **complete, modern, end-to-end PyTorch pipeline** for IMDB sentiment classification using the HuggingFace `datasets` library and a classic **LSTM classifier**. 
It includes:

* dataset loading (`datasets`)
* simple tokenizer (the one used earlier)
* building `stoi`/`itos` with `min_freq` + `max_vocab`
* `text_pipeline` / `label_pipeline`
* `collate_fn` (padded batches + lengths)
* `LSTMClassifier` (embedding → LSTM (packed) → classifier)
* training & eval loop that works on **MPS / CUDA / CPU**

In [None]:
#!/usr/bin/env python3
import re
from collections import Counter, OrderedDict
import math
import time

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from datasets import load_dataset

# -------------------------
# Device
# -------------------------
# Preference order: Apple MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("device:", device)

# -------------------------
# 1) Load dataset
# -------------------------
imdb = load_dataset("imdb")
train_hf, test_hf = imdb["train"], imdb["test"]   # dataset objects (25k each)

# Reproducible 20k / 5k train/valid split of the official training set,
# like the original example. random_split yields Subset objects.
torch.manual_seed(1)
train_list, valid_list = random_split(list(train_hf), [20000, 5000])

# -------------------------
# 2) Tokenizer (same as earlier)
# -------------------------
_HTML_TAG_RE = re.compile(r'<[^>]*>')
# Emoticons such as ":)", ";-(", "=D" or ":P". The D/P alternatives are
# uppercase, so this pattern has to be applied to the original-case text.
_EMOTICON_RE = re.compile(r'(?::|;|=)(?:-)?(?:\)|\(|D|P)')
_NON_WORD_RE = re.compile(r'[\W]+')


def tokenizer(text):
    """Split a raw review into lowercase word tokens followed by emoticon tokens.

    HTML tags (e.g. ``<br />``) are removed, the text is lowercased, and runs of
    non-word characters become token boundaries. Emoticons found in the text
    are appended at the end with their optional ``-`` nose removed
    (``":-)"`` -> ``":)"``).

    Args:
        text: raw review string.

    Returns:
        list[str]: word tokens in order, then emoticon tokens.
    """
    text = _HTML_TAG_RE.sub('', text)
    # Search before lowercasing: matching the uppercase D/P alternatives
    # against text.lower() could never find ":D" or ":P".
    emoticons = _EMOTICON_RE.findall(text)
    words = _NON_WORD_RE.sub(' ', text.lower())
    return (words + ' ' + ' '.join(emoticons).replace('-', '')).split()

# -------------------------
# 3) Build vocab (stoi/itos)
#    - min_freq and max_vocab to limit size
# -------------------------
def build_vocab(dataset_iterable, tokenizer, max_vocab=30000, min_freq=2):
    """Count tokens over a dataset and build the token<->index mappings.

    Args:
        dataset_iterable: iterable of samples, each a mapping with a "text" key.
        tokenizer: callable that turns a string into a list of tokens.
        max_vocab: consider at most this many of the most frequent tokens.
            The two special tokens are added on top of this limit.
        min_freq: drop tokens seen fewer than this many times.

    Returns:
        (stoi, itos, counter): token->index dict; index->token list with
        "<pad>" at index 0 and "<unk>" at index 1; and the full token Counter.
    """
    counter = Counter()
    for sample in dataset_iterable:
        counter.update(tokenizer(sample["text"]))

    # Reserved entries come first so <pad> -> 0 and <unk> -> 1.
    itos = ["<pad>", "<unk>"]
    itos.extend(
        token
        for token, freq in counter.most_common(max_vocab)
        if freq >= min_freq
    )
    stoi = dict(zip(itos, range(len(itos))))
    return stoi, itos, counter

# The vocabulary is built from the training split only. Any token that does
# not make it into the vocab is mapped to <unk> by text_pipeline.
stoi, itos, counter = build_vocab(
    train_list,
    tokenizer,
    max_vocab=30000,
    min_freq=2,
)
print("vocab size:", len(stoi))

# -------------------------
# 4) pipelines
# -------------------------
def text_pipeline(text):
    """Tokenize `text` and map each token to its vocabulary id.

    Tokens missing from the module-level `stoi` map to the "<unk>" id.
    """
    token_ids = []
    for token in tokenizer(text):
        token_ids.append(stoi.get(token, stoi["<unk>"]))
    return token_ids

def label_pipeline(label):
    """Convert a label (an int 0/1 in the HF IMDB dataset) to a float target for BCE."""
    as_float = float(label)
    return as_float

# -------------------------
# 5) collate_fn
# -------------------------
def collate_batch(batch):
    """Turn a list of raw samples into padded tensors on the global `device`.

    Args:
        batch: list of samples, each a mapping with "text" (str) and
            "label" (0/1) keys.

    Returns:
        (padded, labels, lengths):
            padded  -- LongTensor (B, T_max) of token ids, right-padded with
                       the "<pad>" id (T_max = longest sequence in the batch);
            labels  -- float32 tensor (B,) of targets;
            lengths -- LongTensor (B,) of unpadded sequence lengths.
    """
    pad_id = stoi["<pad>"]
    texts, labels, lengths = [], [], []
    for sample in batch:
        token_ids = text_pipeline(sample["text"])
        if not token_ids:
            # A text that tokenizes to nothing would have length 0, which
            # pack_padded_sequence rejects in the model; use one <pad> token.
            token_ids = [pad_id]
        ids = torch.tensor(token_ids, dtype=torch.long)
        texts.append(ids)
        labels.append(label_pipeline(sample["label"]))
        lengths.append(len(ids))
    lengths = torch.tensor(lengths, dtype=torch.long)
    labels = torch.tensor(labels, dtype=torch.float32)
    padded = nn.utils.rnn.pad_sequence(texts, batch_first=True, padding_value=pad_id)
    return padded.to(device), labels.to(device), lengths.to(device)

# quick sanity dataloader
# train_list is a Subset returned by random_split. Slicing it (train_list[:16])
# forwards a list of indices to the underlying Python list and raises
# TypeError, so take the first 16 samples by integer index instead.
small_loader = DataLoader([train_list[i] for i in range(16)], batch_size=4, collate_fn=collate_batch)
x, y, l = next(iter(small_loader))
print("sanity shapes:", x.shape, y.shape, l.shape)

# -------------------------
# 6) Model: LSTMClassifier
#    - supports bidirectional, dropout, packed sequences
# -------------------------
class LSTMClassifier(nn.Module):
    """Embedding -> (optionally bidirectional, stacked) LSTM -> two-layer MLP head.

    The LSTM consumes a packed sequence, so padded positions never reach it and
    each sequence's final hidden state comes from its last real token.

    Args:
        vocab_size: number of rows in the embedding table.
        emb_dim: embedding dimension.
        hidden_size: LSTM hidden size per direction.
        num_layers: number of stacked LSTM layers.
        bidirectional: also run the LSTM over the reversed sequence.
        dropout: dropout between LSTM layers (only when num_layers > 1) and on
            the final hidden state before the head.
        fc_hidden: width of the head's hidden layer.
        num_classes: output width; 1 gives a single logit per example.
        padding_idx: embedding index used for padding. When None (the default),
            the module-level ``stoi["<pad>"]`` is used, as before.
    """

    def __init__(self, vocab_size, emb_dim=128, hidden_size=256, num_layers=1,
                 bidirectional=True, dropout=0.3, fc_hidden=128, num_classes=1,
                 padding_idx=None):
        super().__init__()
        if padding_idx is None:
            padding_idx = stoi["<pad>"]
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=emb_dim,
                                      padding_idx=padding_idx)
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.num_directions = 2 if bidirectional else 1

        self.lstm = nn.LSTM(input_size=emb_dim,
                            hidden_size=hidden_size,
                            num_layers=num_layers,
                            batch_first=True,
                            bidirectional=bidirectional,
                            # inter-layer dropout only makes sense with >1 layer
                            dropout=dropout if num_layers > 1 else 0.0)
        self.dropout = nn.Dropout(dropout)
        self.fc1 = nn.Linear(hidden_size * self.num_directions, fc_hidden)
        self.fc_out = nn.Linear(fc_hidden, num_classes)  # num_classes=1 -> single logit

        # initialize
        nn.init.xavier_uniform_(self.embedding.weight)
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc_out.weight)
        # xavier_uniform_ overwrote the all-zero row nn.Embedding reserves for
        # padding_idx; restore it so padding maps to a zero vector.
        with torch.no_grad():
            self.embedding.weight[self.embedding.padding_idx].zero_()

    def forward(self, x, lengths):
        """Return logits for a batch of right-padded token-id sequences.

        Args:
            x: LongTensor (B, T) of token ids.
            lengths: (B,) unpadded lengths, all >= 1. They are moved to CPU
                because pack_padded_sequence requires CPU lengths.

        Returns:
            (B,) logits when num_classes == 1, otherwise (B, num_classes).
        """
        emb = self.embedding(x)                 # (B, T, emb_dim)
        packed = nn.utils.rnn.pack_padded_sequence(emb, lengths.cpu(), batch_first=True,
                                                   enforce_sorted=False)
        _, (h_n, _) = self.lstm(packed)
        # h_n: (num_layers * num_directions, B, hidden_size); the last layer's
        # states are the final entries (forward, then backward if bidirectional).
        if self.bidirectional:
            h_forward = h_n[-2, :, :]   # (B, hidden_size)
            h_backward = h_n[-1, :, :]  # (B, hidden_size)
            h_final = torch.cat([h_forward, h_backward], dim=1)  # (B, 2*hidden_size)
        else:
            h_final = h_n[-1, :, :]  # (B, hidden_size)
        x = self.dropout(h_final)
        x = torch.relu(self.fc1(x))
        logits = self.fc_out(x).squeeze(1)  # (B,) for num_classes=1
        return logits

# -------------------------
# 7) Training utilities
# -------------------------
def binary_accuracy_from_logits(logits, targets):
    """Return the fraction of predictions that match the 0/1 float targets.

    An example is predicted positive when sigmoid(logit) >= 0.5. The result is
    a 0-dim float tensor.
    """
    probabilities = torch.sigmoid(logits)
    predicted = (probabilities >= 0.5).float()
    hits = (predicted == targets).float()
    return hits.mean()

def evaluate(model, dataloader, loss_fn):
    """Compute the example-weighted mean loss and accuracy of `model` over `dataloader`.

    Switches the model to eval mode (and leaves it there) and disables autograd.
    Each batch's loss/accuracy is scaled by its batch size before averaging.
    This is exact when loss_fn returns a per-batch mean (as BCEWithLogitsLoss
    does by default).

    Returns:
        (mean_loss, mean_accuracy) as Python floats. Raises ZeroDivisionError
        if the dataloader yields no batches.
    """
    model.eval()
    total_loss = total_acc = 0.0
    seen = 0
    with torch.no_grad():
        for inputs, targets, seq_lengths in dataloader:
            logits = model(inputs, seq_lengths)
            count = inputs.size(0)
            total_loss += loss_fn(logits, targets).item() * count
            total_acc += binary_accuracy_from_logits(logits, targets).item() * count
            seen += count
    return total_loss / seen, total_acc / seen

# -------------------------
# 8) Instantiate model, dataloaders, optimizer
# -------------------------
vocab_size = len(stoi)
# Two stacked bidirectional LSTM layers over 128-d embeddings.
model = LSTMClassifier(
    vocab_size=vocab_size,
    emb_dim=128,
    hidden_size=256,
    num_layers=2,
    bidirectional=True,
    dropout=0.3,
    fc_hidden=128,
).to(device)

# Only the training loader shuffles; validation/test keep a fixed order.
batch_size = 128
train_loader = DataLoader(train_list, batch_size=batch_size,
                          shuffle=True, collate_fn=collate_batch)
valid_loader = DataLoader(valid_list, batch_size=batch_size,
                          shuffle=False, collate_fn=collate_batch)
test_loader = DataLoader(test_hf, batch_size=batch_size,
                         shuffle=False, collate_fn=collate_batch)

# Single raw logit per example, so use the logits form of BCE.
loss_fn = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
# Halve the learning rate when the validation loss plateaus (patience=1 epoch).
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=1
)

# -------------------------
# 9) Training loop (simple, with early stopping)
# -------------------------
epochs = 6
best_val_loss = math.inf
patience = 2   # stop once more than `patience` consecutive epochs fail to improve
stale = 0

for epoch in range(1, epochs + 1):
    model.train()
    t0 = time.time()
    running_loss, n = 0.0, 0

    for xb, yb, lb in train_loader:
        optimizer.zero_grad()
        logits = model(xb, lb)
        loss = loss_fn(logits, yb)
        loss.backward()
        # Cap the global gradient norm to keep LSTM updates stable.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        optimizer.step()

        bs = xb.size(0)
        running_loss += loss.item() * bs   # undo the batch mean for a per-example average
        n += bs

    train_loss = running_loss / n
    val_loss, val_acc = evaluate(model, valid_loader, loss_fn)
    scheduler.step(val_loss)

    elapsed = time.time() - t0
    print(f"Epoch {epoch:02d} | train_loss={train_loss:.4f} val_loss={val_loss:.4f} val_acc={val_acc:.4f} time={elapsed:.1f}s")

    # No improvement (including a NaN loss): count it, maybe stop early.
    if not val_loss < best_val_loss:
        stale += 1
        if stale > patience:
            print("Early stopping")
            break
        continue

    # Improvement: reset the counter and checkpoint the best model so far.
    best_val_loss = val_loss
    stale = 0
    torch.save(
        dict(
            model_state_dict=model.state_dict(),
            optimizer_state_dict=optimizer.state_dict(),
            stoi=stoi,
        ),
        "best_lstm_imdb.pth",
    )

# -------------------------
# 10) Test evaluation (load best)
# -------------------------
# Restore the weights from the epoch with the lowest validation loss, then
# score them on the held-out test split.
ckpt = torch.load(
    "best_lstm_imdb.pth",
    map_location=device,
)
model.load_state_dict(ckpt["model_state_dict"])
test_loss, test_acc = evaluate(model, test_loader, loss_fn)
print(f"Test | loss={test_loss:.4f} acc={test_acc:.4f}")


### Quick notes & rationale (short)

* **Vocabulary size**: limited via `max_vocab` and `min_freq`. Adjust depending on memory.
* **Embedding dim (128)**: a good default for sentiment tasks. Increase it for larger datasets; if you use pretrained vectors, set it to their dimension.
* **Hidden size (256)** and **num_layers=2**: moderate capacity; increase them for more expressiveness, at the cost of speed and a higher risk of overfitting.
* **Bidirectional LSTM**: captures context from both directions (common for sentence classification).
* **Packed sequences**: `pack_padded_sequence` makes the LSTM skip padded positions, so each sequence's final hidden state comes from its last real token rather than from padding (efficient & correct).
* **BCEWithLogitsLoss**: single-logit binary classification (more stable than sigmoid + BCELoss).
* **Gradient clipping** & **LR scheduler** added for stability.

---