In [1]:
import os
import math
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from datasets import load_dataset
from transformers import AutoTokenizer
# reproducibility
torch.manual_seed(123)


A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.4.2 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.

If you are a user of the module, the easiest solution will be to
downgrade to 'numpy<2' or try to upgrade the affected module.
We expect that some modules will need time to support NumPy 2.

Traceback (most recent call last):  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/margotte/Mestrado/GPU/shap-text-gpu/.venv/lib/python3.12/site-packages/ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
  File "/home/margotte/Mestrado/GPU/shap-text-gpu/.venv/lib/python3.12/site-packages/traitlets/config/application.py", line 1075, in launch_instance
    app.start()
  File "/home/margotte/Mestrado/GPU/shap-text-gpu/.venv/lib/python3.12/site-packages/ipykernel/ker

<torch._C.Generator at 0x76215c5c3ed0>

In [2]:
# Parameters (script defaults)
DATASET = "imdb"
TOKENIZER = "bert-base-uncased"
MAX_LEN = 64
HIDDEN_DIM = 128
NUM_LAYERS = 3
TRAIN_SAMPLES = 2000
TEST_SAMPLES = 256
BATCH_SIZE = 64
EPOCHS = 50
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Validation / early stopping
VAL_RATIO = 0.1
PATIENCE = 5
# Dropout (applied between layers during training only)
DROPOUT_RATE = 0.3

# Reproduce default MLP from `train_export.py`
This notebook splits parameter definition, imports, dataset prep, tokenization, training, and testing into separate cells.

In [3]:
# Embedding-based MLP: word-level embeddings + mean pooling
import torchtext
from torchtext.data import get_tokenizer
from torchtext.vocab import vocab as torchtext_vocab
from collections import Counter
from torch.nn.utils.rnn import pad_sequence
class EmbeddingMLP(nn.Module):
    def __init__(self, vocab_size: int, embed_dim: int, hidden_dim: int, num_classes: int, padding_idx: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx)
        self.hidden_fc = nn.Linear(embed_dim, hidden_dim)
        self.out_fc = nn.Linear(hidden_dim, num_classes)
        self.relu = nn.ReLU()
        self.padding_idx = padding_idx
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len) of token ids (torch.long)
        emb = self.embedding(x)  # (batch, seq_len, embed_dim)
        mask = (x != self.padding_idx).unsqueeze(-1).to(emb.dtype)  # (batch, seq_len, 1)
        summed = (emb * mask).sum(dim=1)  # (batch, embed_dim)
        lengths = mask.sum(dim=1).clamp(min=1)  # (batch, 1)
        mean_pooled = summed / lengths  # (batch, embed_dim)
        h = self.relu(self.hidden_fc(mean_pooled))
        return self.out_fc(h)

In [4]:
# Dataset loading and subsampling
ds = load_dataset(DATASET)
train_split = ds['train']
test_split = ds['test'] if 'test' in ds else ds['train']
train_n = min(TRAIN_SAMPLES, len(train_split))
test_n = min(TEST_SAMPLES, len(test_split))
def stratified_select(dataset, n, label_field='label', seed=123):
    labels = list(dataset[label_field])
    total = len(labels)
    if n >= total:
        return dataset.shuffle(seed=seed)
    classes = sorted(set(labels))
    counts = {c: labels.count(c) for c in classes}
    # Allocate per-class counts proportionally, with at least 1 when possible
    alloc = {}
    remaining = n
    for c in classes[:-1]:
        k = max(1, int(round(counts[c] / total * n)))
        k = min(k, counts[c])
        alloc[c] = k
        remaining -= k
    last = classes[-1]
    alloc[last] = min(counts[last], max(0, remaining))
    import random
    random.seed(seed)
    selected = []
    for c in classes:
        idxs = [i for i, lab in enumerate(labels) if lab == c]
        random.shuffle(idxs)
        take = alloc.get(c, 0)
        selected.extend(idxs[:take])
    # If short due to rounding, fill from remaining indices
    if len(selected) < n:
        remaining_idxs = [i for i in range(total) if i not in selected]
        random.shuffle(remaining_idxs)
        selected.extend(remaining_idxs[:(n - len(selected))])
    random.shuffle(selected)
    return dataset.select(selected)
# Apply stratified subsampling to train and test splits
train_split = stratified_select(train_split, train_n, seed=123)
test_split = stratified_select(test_split, test_n, seed=123)



In [5]:
display(train_split['text'][:5])

['This two-parter was excellent - the best since the series returned. Sure bits of the story were pinched from previous films, but what TV shows don\'t do that these days. What we got here was a cracking good sci-fi story. A great big (really scary) monster imprisoned at the base of a deep pit, some superb aliens in The Ood - the best "new" aliens the revived series has come up with, a set of basically sympathetic and believable human characters (complete with a couple of unnamed "expendable" security people in true Star Trek fashion), some large-scale philosophical themes (love, loyalty, faith, etc.), and some top-drawer special effects.<br /><br />I loved every minute of this.',
 "Stan Laurel and Oliver Hardy had extensive (separate) film careers before they were eventually teamed. For many of Ollie's pre-Stan films, he was billed on screen as Babe Hardy ... and throughout his adult life, Hardy was known to his friends as 'Babe'. While touring postwar Britain with Laurel in a music-h

In [6]:
# Tokenization and tensor conversion using torchtext (word-level)
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import vocab as torchtext_vocab
from collections import Counter
from torch.nn.utils.rnn import pad_sequence
tokenizer = get_tokenizer("basic_english")
# Build vocabulary from training texts
counter = Counter()
for txt in train_split['text']:
    counter.update(tokenizer(txt))
specials = ["<pad>", "<unk>"]
vocab = torchtext_vocab(counter, specials=specials)
vocab.set_default_index(vocab['<unk>'])
pad_idx = vocab['<pad>']
unk_idx = vocab['<unk>']
vocab_size = len(vocab)
def tok_batch(batch):
    ids_batch = []
    for text in batch['text']:
        toks = tokenizer(text)
        ids = [vocab[t] for t in toks][:MAX_LEN]  # truncate to MAX_LEN
        ids_batch.append(ids)
    return {'input_ids': ids_batch}
train_tok = train_split.map(tok_batch, batched=True, remove_columns=[c for c in train_split.column_names if c != 'label'])
test_tok = test_split.map(tok_batch, batched=True, remove_columns=[c for c in test_split.column_names if c != 'label'])
def pad_id_lists(id_lists, max_len=MAX_LEN):
    tensors = [torch.tensor(ids, dtype=torch.long)[:max_len] for ids in id_lists]
    if len(tensors) == 0:
        return torch.empty((0, max_len), dtype=torch.long)
    padded = pad_sequence(tensors, batch_first=True, padding_value=pad_idx)
    if padded.size(1) < max_len:
        padded = torch.nn.functional.pad(padded, (0, max_len - padded.size(1)), value=pad_idx)
    else:
        padded = padded[:, :max_len]
    return padded
train_ids = pad_id_lists(train_tok['input_ids'], MAX_LEN)
train_y = torch.tensor(train_tok['label'], dtype=torch.int64)
test_ids = pad_id_lists(test_tok['input_ids'], MAX_LEN)
test_y = torch.tensor(test_tok['label'], dtype=torch.int64)
num_classes = int(train_y.max().item() + 1)
input_dim = int(train_ids.shape[1])  # seq_len (MAX_LEN)

Map: 100%|██████████| 2000/2000 [00:00<00:00, 6200.28 examples/s]
Map: 100%|██████████| 256/256 [00:00<00:00, 3063.84 examples/s]


In [7]:
# Training with validation and early stopping
model = EmbeddingMLP(vocab_size=vocab_size, embed_dim=HIDDEN_DIM, hidden_dim=HIDDEN_DIM, num_classes=num_classes, padding_idx=pad_idx).to(DEVICE)
# Create train/val split from train_ids/train_y
n_train_total = train_ids.shape[0]
val_n = int(n_train_total * VAL_RATIO)
if val_n > 0:
    # Stratified split to balance classes in train/val
    classes = torch.unique(train_y)
    train_idx_list = []
    val_idx_list = []
    for c in classes:
        idx_c = (train_y == c).nonzero(as_tuple=True)[0]
        idx_c = idx_c[torch.randperm(idx_c.size(0))]
        n_val_c = max(1, int(idx_c.size(0) * VAL_RATIO))
        val_idx_list.append(idx_c[:n_val_c])
        train_idx_list.append(idx_c[n_val_c:])
    val_idx = torch.cat(val_idx_list) if len(val_idx_list) > 0 else torch.tensor([], dtype=torch.long)
    train_idx = torch.cat(train_idx_list) if len(train_idx_list) > 0 else torch.tensor([], dtype=torch.long)
    # Shuffle final indices
    train_idx = train_idx[torch.randperm(train_idx.size(0))] if train_idx.numel() > 0 else train_idx
    val_idx = val_idx[torch.randperm(val_idx.size(0))] if val_idx.numel() > 0 else val_idx
    train_ds = TensorDataset(train_ids[train_idx], train_y[train_idx])
    val_ds = TensorDataset(train_ids[val_idx], train_y[val_idx])
else:
    train_ds = TensorDataset(train_ids, train_y)
    val_ds = None
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False) if val_ds is not None else None
test_loader = DataLoader(TensorDataset(test_ids, test_y), batch_size=BATCH_SIZE, shuffle=False)
optim = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()
best_val_acc = -1.0
best_state = None
epochs_no_improve = 0
model.train()
for epoch in range(1, EPOCHS+1):
    running = 0.0
    seen = 0
    for xb, yb in train_loader:
        xb = xb.to(DEVICE)
        yb = yb.to(DEVICE)
        optim.zero_grad(set_to_none=True)
        logits = model(xb)
        loss = loss_fn(logits, yb)
        loss.backward()
        optim.step()
        running += float(loss.item()) * int(yb.numel())
        seen += int(yb.numel())
    # Validation
    val_acc = None
    if val_loader is not None:
        model.eval()
        total_v = 0
        correct_v = 0
        with torch.no_grad():
            for xb, yb in val_loader:
                xb = xb.to(DEVICE)
                yb = yb.to(DEVICE)
                logits = model(xb)
                pred = logits.argmax(dim=-1)
                correct_v += int((pred == yb).sum().item())
                total_v += int(yb.numel())
        val_acc = correct_v / max(total_v, 1)
    train_loss = running / max(seen, 1)
    # Check improvement on validation (or test if no val set)
    monitor_acc = val_acc if val_acc is not None else 0.0
    improved = (monitor_acc > best_val_acc)
    if improved:
        best_val_acc = monitor_acc
        # save CPU copy of state dict
        best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1
    print(f'epoch={epoch} train_loss={train_loss:.4f} val_acc={val_acc if val_acc is not None else "N/A"}')
    model.train()
    if epochs_no_improve >= PATIENCE:
        print(f'Early stopping after {epoch} epochs (no improvement for {PATIENCE} epochs)')
        break
# restore best model state if available
if best_state is not None:
    model.load_state_dict(best_state)

epoch=1 train_loss=0.6878 val_acc=0.62
epoch=2 train_loss=0.6650 val_acc=0.605
epoch=3 train_loss=0.6244 val_acc=0.68
epoch=4 train_loss=0.5633 val_acc=0.715
epoch=5 train_loss=0.4850 val_acc=0.75
epoch=6 train_loss=0.3980 val_acc=0.705
epoch=7 train_loss=0.3121 val_acc=0.7
epoch=8 train_loss=0.2435 val_acc=0.715
epoch=9 train_loss=0.1788 val_acc=0.715
epoch=10 train_loss=0.1284 val_acc=0.725
Early stopping after 10 epochs (no improvement for 5 epochs)


In [8]:
# Final testing/evaluation
model.eval()
with torch.no_grad():
    total = 0
    correct = 0
    for xb, yb in test_loader:
        xb = xb.to(DEVICE)
        yb = yb.to(DEVICE)
        logits = model(xb)
        pred = logits.argmax(dim=-1)
        correct += int((pred == yb).sum().item())
        total += int(yb.numel())
    print(f'final_test_acc={correct/max(total,1):.4f}')

final_test_acc=0.6523


In [9]:
# Expanded hyperparameter/configuration sweep (sampled from full grid)
import itertools, random, json, time
# parameter grids
lrs = [1e-2, 5e-3, 1e-3, 5e-4]
wds = [0.0, 1e-4, 1e-3]
dropouts = [0.0, 0.2, 0.3, 0.5]
num_layers_list = [1, 2, 3, 4]
hidden_dims = [64, 128, 256]
max_lens = [32, 64, 128]
padding_modes = [True, 'max_length']
# sampling configuration
FULL_GRID = list(itertools.product(lrs, wds, dropouts, num_layers_list, hidden_dims, max_lens, padding_modes))
random.seed(123)
random.shuffle(FULL_GRID)
N_CONFIGS = min(20, len(FULL_GRID))  # pick a sampled subset to keep runtime reasonable
SELECTED = FULL_GRID[:N_CONFIGS]
EPOCHS_SWEEP = 20
PATIENCE_SWEEP = 5
results = []

def prepare_tensors_for_config(max_len, padding_mode):
    # re-tokenize train/test splits for this config (word-level using existing vocab)
    def tok_batch_local(batch):
        ids_batch = []
        for text in batch['text']:
            toks = tokenizer(text)
            ids = [vocab[t] for t in toks][:max_len]
            ids_batch.append(ids)
        return {'input_ids': ids_batch}
    tr_tok = train_split.map(tok_batch_local, batched=True, remove_columns=[c for c in train_split.column_names if c != 'label'])
    te_tok = test_split.map(tok_batch_local, batched=True, remove_columns=[c for c in test_split.column_names if c != 'label'])
    def pad_id_lists_local(id_lists, max_len=max_len):
        tensors = [torch.tensor(ids, dtype=torch.long)[:max_len] for ids in id_lists]
        if len(tensors) == 0:
            return torch.empty((0, max_len), dtype=torch.long)
        padded = pad_sequence(tensors, batch_first=True, padding_value=pad_idx)
        if padded.size(1) < max_len:
            padded = torch.nn.functional.pad(padded, (0, max_len - padded.size(1)), value=pad_idx)
        else:
            padded = padded[:, :max_len]
        return padded
    tr_ids = pad_id_lists_local(tr_tok['input_ids'])
    te_ids = pad_id_lists_local(te_tok['input_ids'])
    tr_y = torch.tensor(tr_tok['label'], dtype=torch.int64)
    te_y = torch.tensor(te_tok['label'], dtype=torch.int64)
    return tr_ids, tr_y, te_ids, te_y

start_time = time.time()
for (lr, wd, drop, nlayers, hdim, mlen, pad_mode) in SELECTED:
    torch.manual_seed(123)
    # prepare tensors for this config
    tr_x_cfg, tr_y_cfg, te_x_cfg, te_y_cfg = prepare_tensors_for_config(mlen, pad_mode)
    # stratified train/val split (same logic as above)
    n_train_total = tr_x_cfg.shape[0]
    val_n = int(n_train_total * VAL_RATIO)
    if val_n > 0:
        classes = torch.unique(tr_y_cfg)
        train_idx_list = []
        val_idx_list = []
        for c in classes:
            idx_c = (tr_y_cfg == c).nonzero(as_tuple=True)[0]
            idx_c = idx_c[torch.randperm(idx_c.size(0))]
            n_val_c = max(1, int(idx_c.size(0) * VAL_RATIO))
            val_idx_list.append(idx_c[:n_val_c])
            train_idx_list.append(idx_c[n_val_c:])
        val_idx = torch.cat(val_idx_list) if len(val_idx_list) > 0 else torch.tensor([], dtype=torch.long)
        train_idx = torch.cat(train_idx_list) if len(train_idx_list) > 0 else torch.tensor([], dtype=torch.long)
        train_idx = train_idx[torch.randperm(train_idx.size(0))] if train_idx.numel() > 0 else train_idx
        val_idx = val_idx[torch.randperm(val_idx.size(0))] if val_idx.numel() > 0 else val_idx
        train_ds = TensorDataset(tr_x_cfg[train_idx], tr_y_cfg[train_idx])
        val_ds = TensorDataset(tr_x_cfg[val_idx], tr_y_cfg[val_idx])
    else:
        train_ds = TensorDataset(tr_x_cfg, tr_y_cfg)
        val_ds = None
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False) if val_ds is not None else None
    # build model for config (embedding-based)
    model_s = EmbeddingMLP(vocab_size=vocab_size, embed_dim=hdim, hidden_dim=hdim, num_classes=num_classes, padding_idx=pad_idx).to(DEVICE)
    optim_s = torch.optim.Adam(model_s.parameters(), lr=lr, weight_decay=wd)
    loss_fn_s = nn.CrossEntropyLoss()
    best_val = -1.0
    best_state = None
    epochs_no_improve = 0
    for epoch in range(1, EPOCHS_SWEEP+1):
        model_s.train()
        running = 0.0
        seen = 0
        for xb, yb in train_loader:
            xb = xb.to(DEVICE)
            yb = yb.to(DEVICE)
            optim_s.zero_grad(set_to_none=True)
            logits = model_s(xb)
            loss = loss_fn_s(logits, yb)
            loss.backward()
            optim_s.step()
            running += float(loss.item()) * int(yb.numel())
            seen += int(yb.numel())
        # validation
        val_acc = None
        if val_loader is not None:
            model_s.eval()
            total_v = 0
            correct_v = 0
            with torch.no_grad():
                for xb, yb in val_loader:
                    xb = xb.to(DEVICE)
                    yb = yb.to(DEVICE)
                    pred = model_s(xb).argmax(dim=-1)
                    correct_v += int((pred == yb).sum().item())
                    total_v += int(yb.numel())
            val_acc = correct_v / max(total_v, 1)
        train_loss = running / max(seen, 1)
        monitor_acc = val_acc if val_acc is not None else 0.0
        if monitor_acc > best_val:
            best_val = monitor_acc
            best_state = {k: v.cpu().clone() for k, v in model_s.state_dict().items()}
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
        if epochs_no_improve >= PATIENCE_SWEEP:
            break
    if best_state is not None:
        model_s.load_state_dict(best_state)
    # test evaluation
    model_s.eval()
    total = 0
    correct = 0
    with torch.no_grad():
        for xb, yb in DataLoader(TensorDataset(te_x_cfg, te_y_cfg), batch_size=BATCH_SIZE):
            xb = xb.to(DEVICE)
            yb = yb.to(DEVICE)
            pred = model_s(xb).argmax(dim=-1)
            correct += int((pred == yb).sum().item())
            total += int(yb.numel())
    test_acc = correct / max(total, 1)
    elapsed = time.time() - start_time
    cfg = dict(lr=lr, weight_decay=wd, dropout=drop, num_layers=nlayers, hidden_dim=hdim, max_len=mlen, padding=pad_mode)
    results.append({'config': cfg, 'best_val': best_val, 'test_acc': test_acc, 'elapsed_s': elapsed})
    print(f"config={cfg} best_val={best_val:.4f} test_acc={test_acc:.4f} elapsed={elapsed:.1f}s")
# summarize top results by test accuracy
results_sorted = sorted(results, key=lambda x: x['test_acc'], reverse=True)
print('\nTop results:')
for r in results_sorted[:5]:
    print(r)
# persist results
with open('sweep_results.json', 'w') as fh:
    json.dump(results_sorted, fh, indent=2)
print('Saved sweep_results.json')

Map: 100%|██████████| 2000/2000 [00:00<00:00, 6327.46 examples/s]
Map: 100%|██████████| 256/256 [00:00<00:00, 5978.45 examples/s]


config={'lr': 0.001, 'weight_decay': 0.0, 'dropout': 0.3, 'num_layers': 3, 'hidden_dim': 64, 'max_len': 64, 'padding': True} best_val=0.6800 test_acc=0.7305 elapsed=3.7s
config={'lr': 0.0005, 'weight_decay': 0.001, 'dropout': 0.0, 'num_layers': 2, 'hidden_dim': 128, 'max_len': 64, 'padding': True} best_val=0.6050 test_acc=0.6328 elapsed=10.4s


Map: 100%|██████████| 2000/2000 [00:00<00:00, 6549.90 examples/s]
Map: 100%|██████████| 256/256 [00:00<00:00, 5965.13 examples/s]


config={'lr': 0.0005, 'weight_decay': 0.001, 'dropout': 0.5, 'num_layers': 1, 'hidden_dim': 128, 'max_len': 128, 'padding': True} best_val=0.6600 test_acc=0.6680 elapsed=17.9s


Map: 100%|██████████| 2000/2000 [00:00<00:00, 6731.27 examples/s]
Map: 100%|██████████| 256/256 [00:00<00:00, 5994.54 examples/s]


config={'lr': 0.001, 'weight_decay': 0.001, 'dropout': 0.3, 'num_layers': 1, 'hidden_dim': 128, 'max_len': 32, 'padding': True} best_val=0.6700 test_acc=0.6211 elapsed=24.1s
config={'lr': 0.001, 'weight_decay': 0.001, 'dropout': 0.0, 'num_layers': 2, 'hidden_dim': 256, 'max_len': 128, 'padding': 'max_length'} best_val=0.6800 test_acc=0.7344 elapsed=38.3s
config={'lr': 0.001, 'weight_decay': 0.0, 'dropout': 0.5, 'num_layers': 4, 'hidden_dim': 64, 'max_len': 128, 'padding': True} best_val=0.7800 test_acc=0.7773 elapsed=41.9s
config={'lr': 0.001, 'weight_decay': 0.0001, 'dropout': 0.2, 'num_layers': 3, 'hidden_dim': 256, 'max_len': 64, 'padding': 'max_length'} best_val=0.6700 test_acc=0.6758 elapsed=59.6s
config={'lr': 0.01, 'weight_decay': 0.0001, 'dropout': 0.0, 'num_layers': 2, 'hidden_dim': 256, 'max_len': 64, 'padding': True} best_val=0.7600 test_acc=0.7031 elapsed=71.2s
config={'lr': 0.0005, 'weight_decay': 0.0001, 'dropout': 0.3, 'num_layers': 3, 'hidden_dim': 256, 'max_len': 128, 

## Approximate experiment: collapse embeddings to a single scalar per token
This section implements the approximate approach (collapse each token's embedding to a single scalar) so the model can be treated as a scalar-token MLP.
It trains a small model that learns a projection from embedding -> scalar per token, then feeds the resulting per-token scalars (shape `seq_len`) to a standard MLP.
This lets you compare performance quickly while keeping compatibility with the existing `feedforward.cu` input expectations (one scalar per token).

In [10]:
# Collapsed-embedding model: learnable projection from embed_dim -> 1 per token, then MLP on seq_len scalars
import torch.nn.functional as F
class CollapsedEmbeddingMLP(nn.Module):
    def __init__(self, vocab_size, embed_dim, seq_len, hidden_dim, num_classes, padding_idx):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx)
        # project each token embedding to a single scalar (learned)
        self.proj = nn.Linear(embed_dim, 1, bias=True)
        # MLP that consumes seq_len scalar features
        self.hidden_fc = nn.Linear(seq_len, hidden_dim)
        self.out_fc = nn.Linear(hidden_dim, num_classes)
        self.relu = nn.ReLU()
        self.padding_idx = padding_idx
    def forward(self, x):
        # x: (batch, seq_len) long token ids
        emb = self.embedding(x)  # (batch, seq_len, embed_dim)
        mask = (x != self.padding_idx).unsqueeze(-1).to(emb.dtype)  # (batch, seq_len, 1)
        emb = emb * mask  # zero-out pad embeddings
        scalars = self.proj(emb).squeeze(-1)  # (batch, seq_len)
        # ensure pad positions are zero
        scalars = scalars * mask.squeeze(-1)
        h = self.relu(self.hidden_fc(scalars))
        return self.out_fc(h)

# Instantiate model (uses previously-defined variables: vocab_size, HIDDEN_DIM, input_dim, num_classes, pad_idx, DEVICE)
seq_len = input_dim
model_collapsed = CollapsedEmbeddingMLP(vocab_size=vocab_size, embed_dim=HIDDEN_DIM, seq_len=seq_len, hidden_dim=HIDDEN_DIM, num_classes=num_classes, padding_idx=pad_idx).to(DEVICE)
# Prepare simple dataloaders (use full train_ids/train_y and test_ids/test_y)
train_ds_coll = TensorDataset(train_ids, train_y)
train_loader_coll = DataLoader(train_ds_coll, batch_size=BATCH_SIZE, shuffle=True)
test_loader_coll = DataLoader(TensorDataset(test_ids, test_y), batch_size=BATCH_SIZE, shuffle=False)
optim_c = torch.optim.Adam(model_collapsed.parameters(), lr=1e-3)
loss_fn_c = nn.CrossEntropyLoss()
# Quick training loop (small number of epochs for experiment)
EPOCHS_COLLAPSE = 10
for epoch in range(1, EPOCHS_COLLAPSE+1):
    model_collapsed.train()
    running = 0.0
    seen = 0
    for xb, yb in train_loader_coll:
        xb = xb.to(DEVICE)
        yb = yb.to(DEVICE)
        optim_c.zero_grad(set_to_none=True)
        logits = model_collapsed(xb)
        loss = loss_fn_c(logits, yb)
        loss.backward()
        optim_c.step()
        running += float(loss.item()) * int(yb.numel())
        seen += int(yb.numel())
    train_loss = running / max(seen, 1)
    # quick val on test set for monitoring
    model_collapsed.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for xb, yb in test_loader_coll:
            xb = xb.to(DEVICE)
            yb = yb.to(DEVICE)
            pred = model_collapsed(xb).argmax(dim=-1)
            correct += int((pred == yb).sum().item())
            total += int(yb.numel())
    test_acc = correct / max(total, 1)
    print(f'epoch={epoch} train_loss={train_loss:.4f} test_acc={test_acc:.4f}')

# Final test evaluation
model_collapsed.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for xb, yb in test_loader_coll:
        xb = xb.to(DEVICE)
        yb = yb.to(DEVICE)
        pred = model_collapsed(xb).argmax(dim=-1)
        correct += int((pred == yb).sum().item())
        total += int(yb.numel())
    print(f'final_collapsed_test_acc={correct/max(total,1):.4f}')

epoch=1 train_loss=0.6942 test_acc=0.4883
epoch=2 train_loss=0.6678 test_acc=0.4570
epoch=3 train_loss=0.6334 test_acc=0.5078
epoch=4 train_loss=0.5779 test_acc=0.5039
epoch=5 train_loss=0.5126 test_acc=0.5000
epoch=6 train_loss=0.4345 test_acc=0.5156
epoch=7 train_loss=0.3571 test_acc=0.5156
epoch=8 train_loss=0.2771 test_acc=0.5195
epoch=9 train_loss=0.2055 test_acc=0.5273
epoch=10 train_loss=0.1436 test_acc=0.5430
final_collapsed_test_acc=0.5430
