In [68]:
import math, random, torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset
from transformers import AutoTokenizer

In [69]:
class Cfg:
    model_dim = 256
    n_heads   = 4
    n_layers  = 2
    ff_dim    = 1024
    dropout   = 0.1
    seq_len   = 128
    vocab     = None  # set after tokenizer loads

    pool_multiplier = 4
    sel_k     = 8     # selected subset (k <= M)

    pool_M    = pool_multiplier * sel_k
    
    temp      = 1.0   # router softmax temperature
    lambda_ent = 1e-2 # entropy regularization

    batch_steps = 2000
    lr       = 3e-4
    wd       = 0.1
    device   = "cuda" if torch.cuda.is_available() else "cpu"
    val_bs   = 32

In [70]:
def build_wikitext2(tokenizer, split="train", seq_len=128):
    # Concatenate all text and chunk into fixed-length token sequences
    ds = load_dataset("wikitext", "wikitext-2-raw-v1")[split]
    text = "\n\n".join(ds["text"])
    toks = tokenizer(text, return_tensors=None, add_special_tokens=False)["input_ids"]
    # make length divisible
    n = (len(toks) // (seq_len+1)) * (seq_len+1)
    toks = toks[:n]
    x = torch.tensor(toks, dtype=torch.long).view(-1, seq_len+1)  # [N, L+1]
    # inputs/labels (next-token)
    return x[:, :-1], x[:, 1:]  # [N, L], [N, L]

In [71]:
class TensorDatasetSimple(Dataset):
    def __init__(self, x, y): self.x, self.y = x, y
    def __len__(self): return self.x.size(0)
    def __getitem__(self, i): return self.x[i], self.y[i]

In [72]:
class TinyCausalLM(nn.Module):
    def __init__(self, vocab_size, d=256, n_heads=4, n_layers=2, ff=1024, dropout=0.1, seq_len=128):
        super().__init__()
        self.tok = nn.Embedding(vocab_size, d)
        self.pos = nn.Embedding(seq_len, d)
        layer = nn.TransformerEncoderLayer(d_model=d, nhead=n_heads,
                                           dim_feedforward=ff, dropout=dropout,
                                           batch_first=True, activation="gelu")
        self.enc  = nn.TransformerEncoder(layer, num_layers=n_layers)
        self.ln   = nn.LayerNorm(d)
        self.head = nn.Linear(d, vocab_size, bias=False)
        self.seq_len = seq_len
        self.d = d

    def causal_mask(self, L):
        # (L, L) mask with -inf above diagonal
        m = torch.full((L, L), float("-inf"), device=self.pos.weight.device)
        return torch.triu(m, diagonal=1)

    def forward(self, x):  # x: [B, L]
        B, L = x.shape
        pos = torch.arange(L, device=x.device).unsqueeze(0).expand(B, L)
        h = self.tok(x) + self.pos(pos)
        logits = self.enc(h, mask=self.causal_mask(L))
        logits = self.ln(logits)
        return self.head(logits)  # [B, L, V]

    def nll(self, x, y):
        logits = self.forward(x)
        return F.cross_entropy(logits.reshape(-1, logits.size(-1)), y.reshape(-1))

    @torch.no_grad()
    def ppl(self, loader):
        self.eval()
        tot_loss, tot_tok = 0.0, 0
        for x, y in loader:
            x, y = x.to(self.pos.weight.device), y.to(self.pos.weight.device)
            loss = self.nll(x, y)
            tot_loss += loss.item() * y.numel()
            tot_tok  += y.numel()
        self.train()
        return math.exp(tot_loss / max(1, tot_tok))

In [73]:
class AttentionRouter(nn.Module):
    """
    Attention over samples: each sample j has a key k_j from its
    (mean-pooled) token embeddings. A learned query q attends to k_j
    to produce scores; softmax → probabilities over samples.
    """
    def __init__(self, d_in, d_attn=128):
        super().__init__()
        self.Wk = nn.Linear(d_in, d_attn, bias=False)  # keys from sample embeddings
        self.q  = nn.Parameter(torch.randn(d_attn))    # learned global query
        nn.init.normal_(self.q, std=0.02)

    def forward(self, sample_embs, temp=1.0):
        # sample_embs: [M, d_in] (per-sample embedding)
        K = self.Wk(sample_embs)                # [M, d_attn]
        q = self.q / (K.size(-1) ** 0.5)        # scale like dot-attention
        scores = (K @ q)                        # [M]
        probs = F.softmax(scores / temp, dim=0) # [M]
        return probs, scores

In [74]:
def sample_embeddings_from_tokens(lm: TinyCausalLM, X: torch.Tensor):
    """
    Produce per-sample embeddings cheaply:
    mean-pool the token embeddings only (no encoder) to keep it simple/fast.
    X: [M, L] → returns [M, d]
    """
    with torch.no_grad():
        tok = lm.tok(X)           # [M, L, d]
        embs = tok.mean(dim=1)    # [M, d]
    return embs

In [75]:
def train_step(cfg, lm, router, pool_batch, pool_indices, opt):
    device = cfg.device
    X = torch.stack([xy[0] for xy in pool_batch]).to(device)
    Y = torch.stack([xy[1] for xy in pool_batch]).to(device)

    # router
    sample_embs = sample_embeddings_from_tokens(lm, X)
    probs, _ = router(sample_embs, temp=cfg.temp)
    k = min(cfg.sel_k, cfg.pool_M)
    idx = torch.topk(probs, k=k, dim=0).indices        # [k]

    sel_x = X.index_select(0, idx)
    sel_y = Y.index_select(0, idx)
    loss_lm = lm.nll(sel_x, sel_y)

    with torch.no_grad():
        logits = lm(sel_x)
        ce = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            sel_y.reshape(-1),
            reduction="none",
        ).reshape(k, -1)
        per_sample_loss = ce.mean(dim=1)               # [k]
        sel_mean_loss = per_sample_loss.mean().item()

        # small random control inside the same pool (no update)
        ctrl_k = min( max(1, k//2), X.size(0)-k )
        # pick control indices disjoint from selected
        mask = torch.ones(X.size(0), dtype=torch.bool, device=device)
        mask[idx] = False
        ctrl_all = torch.nonzero(mask, as_tuple=False).squeeze(1)
        ctrl_idx = ctrl_all[torch.randperm(ctrl_all.numel(), device=device)[:ctrl_k]]

        ctrl_x = X.index_select(0, ctrl_idx)
        ctrl_y = Y.index_select(0, ctrl_idx)
        ctrl_logits = lm(ctrl_x)
        ctrl_ce = F.cross_entropy(
            ctrl_logits.reshape(-1, ctrl_logits.size(-1)),
            ctrl_y.reshape(-1),
            reduction="none",
        ).reshape(ctrl_k, -1)
        ctrl_mean_loss = ctrl_ce.mean(dim=1).mean().item()

        R = per_sample_loss
        Rn = (R - R.mean()) / (R.std() + 1e-8)

    logp_selected = (probs.index_select(0, idx).clamp_min(1e-12)).log()
    loss_router = -(Rn.detach() * logp_selected).sum()
    loss_ent = cfg.lambda_ent * (-(probs * (probs.clamp_min(1e-12)).log()).sum())
    loss = loss_lm + loss_router + loss_ent

    opt.zero_grad(set_to_none=True)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(lm.parameters(), 1.0)
    torch.nn.utils.clip_grad_norm_(router.parameters(), 1.0)
    opt.step()

    with torch.no_grad():
        ent = (-(probs * (probs.clamp_min(1e-12)).log()).sum()).item()
        selected_global = [pool_indices[i.item()] for i in idx]

    return {
        "loss_lm": loss_lm.item(),
        "loss_router": loss_router.item(),
        "entropy": ent,
        "sel_mean_loss": sel_mean_loss,
        "ctrl_mean_loss": ctrl_mean_loss,
        "selected_indices": selected_global
    }


In [76]:
cfg = Cfg()

# Tokenizer (use GPT-2 tokenizer for simplicity; set pad to eos)
tok = AutoTokenizer.from_pretrained("gpt2")
tok.pad_token = tok.eos_token

# Build datasets
x_tr, y_tr = build_wikitext2(tok, "train", cfg.seq_len)
x_va, y_va = build_wikitext2(tok, "validation", cfg.seq_len)

cfg.vocab = tok.vocab_size
train_ds = TensorDatasetSimple(x_tr, y_tr)
val_ds   = TensorDatasetSimple(x_va, y_va)

# Simple random sampler for candidate pools
def sample_pool(ds, M):
    idxs = random.sample(range(len(ds)), M)
    batch = [ds[i] for i in idxs]
    return idxs, batch

N = len(train_ds)
select_counts = torch.zeros(N, dtype=torch.long)   # utilization per sample
seen_selected = set()                               # coverage set
last_selected = None                                # for overlap
entropy_track = []
tokens_seen = 0

val_loader = DataLoader(val_ds, batch_size=cfg.val_bs, shuffle=False, drop_last=False)

# Models
lm = TinyCausalLM(vocab_size=cfg.vocab, d=Cfg.model_dim, n_heads=Cfg.n_heads,
                    n_layers=Cfg.n_layers, ff=Cfg.ff_dim, dropout=Cfg.dropout,
                    seq_len=Cfg.seq_len).to(cfg.device)
router = AttentionRouter(d_in=Cfg.model_dim, d_attn=128).to(cfg.device)

# Optimizer (single opt for both)
opt = torch.optim.AdamW(list(lm.parameters()) + list(router.parameters()),
                        lr=cfg.lr, weight_decay=cfg.wd, betas=(0.9, 0.95))

print(f"Device={cfg.device} | WT2 samples: train={len(train_ds)} val={len(val_ds)} | pool M={cfg.pool_M}, k={cfg.sel_k}")

for step in range(1, cfg.batch_steps + 1):
    pool_indices, pool = sample_pool(train_ds, cfg.pool_M)
    stats = train_step(cfg, lm, router, pool, pool_indices, opt)

    # tokens processed
    tokens_seen += cfg.sel_k * cfg.seq_len
    entropy_track.append(stats["entropy"])

    # utilization + coverage
    for gi in stats["selected_indices"]:
        select_counts[gi] += 1
        seen_selected.add(gi)

    # overlap (Jaccard) with previous step
    if last_selected is None:
        overlap = float("nan")
    else:
        a = set(last_selected); b = set(stats["selected_indices"])
        inter = len(a & b); uni = len(a | b)
        overlap = inter / uni if uni > 0 else float("nan")
    last_selected = list(stats["selected_indices"])

    if step % 50 == 0:
        ppl = lm.ppl(val_loader)
        coverage = len(seen_selected) / N
        # simple skew: coefficient of variation of counts over seen_selected
        seen_counts = select_counts[list(seen_selected)].float() if seen_selected else torch.tensor([0.0])
        if len(seen_counts) > 1:
            cv = (seen_counts.std() / (seen_counts.mean() + 1e-8)).item()
        else:
            cv = float("nan")

        print(
            f"[{step:04d}] "
            f"LM={stats['loss_lm']:.3f}  "
            f"Router={stats['loss_router']:.3f}  "
            f"Ent={stats['entropy']:.3f}  "
            f"PPL={ppl:.2f}  "
            f"Tokens={tokens_seen/1e6:.2f}M  "
            f"Overlap={overlap:.2f}  "
            f"Coverage={coverage:.2%}  "
            f"Skew(CV)={cv:.2f}  "
            f"SelLoss={stats['sel_mean_loss']:.3f}  "
            f"CtrlLoss={stats['ctrl_mean_loss']:.3f}"
        )

Could not cache non-existence of file. Will ignore error and continue. Error: [Errno 13] Permission denied: '/mloscratch/hf_cache/hub/datasets--wikitext/.no_exist/b08601e04326c79dfdd32d625aee71d232d685c3/wikitext.py'


Could not cache non-existence of file. Will ignore error and continue. Error: [Errno 13] Permission denied: '/mloscratch/hf_cache/hub/datasets--wikitext/.no_exist/b08601e04326c79dfdd32d625aee71d232d685c3/.huggingface.yaml'
Could not cache non-existence of file. Will ignore error and continue. Error: [Errno 13] Permission denied: '/mloscratch/hf_cache/hub/datasets--wikitext/.no_exist/b08601e04326c79dfdd32d625aee71d232d685c3/dataset_infos.json'
Token indices sequence length is longer than the specified maximum sequence length for this model (2428601 > 1024). Running this sequence through the model will result in indexing errors
Could not cache non-existence of file. Will ignore error and continue. Error: [Errno 13] Permission denied: '/mloscratch/hf_cache/hub/datasets--wikitext/.no_exist/b08601e04326c79dfdd32d625aee71d232d685c3/wikitext.py'
Could not cache non-existence of file. Will ignore error and continue. Error: [Errno 13] Permission denied: '/mloscratch/hf_cache/hub/datasets--wikit

Device=cuda | WT2 samples: train=18826 val=1946 | pool M=32, k=8
[0050] LM=7.984  Router=-0.009  Ent=3.466  PPL=2339.42  Tokens=0.05M  Overlap=0.00  Coverage=2.07%  Skew(CV)=0.15  SelLoss=7.968  CtrlLoss=7.789
[0100] LM=7.617  Router=-0.019  Ent=3.466  PPL=1455.41  Tokens=0.10M  Overlap=0.00  Coverage=4.08%  Skew(CV)=0.20  SelLoss=7.615  CtrlLoss=7.122
[0150] LM=7.421  Router=-0.022  Ent=3.466  PPL=1314.34  Tokens=0.15M  Overlap=0.00  Coverage=5.91%  Skew(CV)=0.27  SelLoss=7.427  CtrlLoss=7.100
[0200] LM=7.359  Router=-0.019  Ent=3.466  PPL=1207.90  Tokens=0.20M  Overlap=0.00  Coverage=7.63%  Skew(CV)=0.32  SelLoss=7.358  CtrlLoss=6.896
[0250] LM=7.269  Router=-0.062  Ent=3.466  PPL=1138.81  Tokens=0.26M  Overlap=0.00  Coverage=9.33%  Skew(CV)=0.33  SelLoss=7.275  CtrlLoss=6.868
[0300] LM=7.206  Router=-0.039  Ent=3.466  PPL=1096.35  Tokens=0.31M  Overlap=0.00  Coverage=10.92%  Skew(CV)=0.36  SelLoss=7.206  CtrlLoss=7.009
[0350] LM=7.024  Router=0.026  Ent=3.466  PPL=1053.93  Tokens=0.