In [20]:
import math, torch, random
from torch import nn
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset
from transformers import GPT2TokenizerFast

In [21]:
random.seed(0); torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"

In [22]:
BLOCK = 256          # seq len
BATCH = 16           # final selected batch size (k)
POOL_MULT = 4        # candidate pool size multiplier -> M = POOL_MULT * BATCH
EPOCHS = 3*5
LR_LM = 3e-4
LR_ROUTER = 1e-3
TEMP = 1.0
LAMBDA_ENT = 1e-3    # entropy reg for router
LAMBDA_ROUTER = 0.1  # weight for router loss

tok = GPT2TokenizerFast.from_pretrained("gpt2")
if tok.pad_token is None: tok.pad_token = tok.eos_token

In [23]:
def make_chunks(split):
    ds = load_dataset("wikitext", "wikitext-2-raw-v1", split=split)
    text = tok.eos_token.join(ds["text"])
    ids = tok(text, add_special_tokens=False)["input_ids"]
    L = (len(ids) // (BLOCK + 1)) * (BLOCK + 1)
    ids = ids[:L]
    chunks = [ids[i:i+BLOCK+1] for i in range(0, L, BLOCK+1)]
    return chunks

In [24]:
class LMDataset(Dataset):
    def __init__(self, chunks):
        self.x = [torch.tensor(c[:-1], dtype=torch.long) for c in chunks]
        self.y = [torch.tensor(c[1:],  dtype=torch.long) for c in chunks]
    def __len__(self): return len(self.x)
    def __getitem__(self, i): return self.x[i], self.y[i]

In [25]:
train_chunks = make_chunks("train")
val_chunks   = make_chunks("validation")
train_ds = LMDataset(train_chunks)
val_ds   = LMDataset(val_chunks)

Could not cache non-existence of file. Will ignore error and continue. Error: [Errno 13] Permission denied: '/mloscratch/hf_cache/hub/datasets--wikitext/.no_exist/b08601e04326c79dfdd32d625aee71d232d685c3/wikitext.py'
Could not cache non-existence of file. Will ignore error and continue. Error: [Errno 13] Permission denied: '/mloscratch/hf_cache/hub/datasets--wikitext/.no_exist/b08601e04326c79dfdd32d625aee71d232d685c3/.huggingface.yaml'
Could not cache non-existence of file. Will ignore error and continue. Error: [Errno 13] Permission denied: '/mloscratch/hf_cache/hub/datasets--wikitext/.no_exist/b08601e04326c79dfdd32d625aee71d232d685c3/dataset_infos.json'
Token indices sequence length is longer than the specified maximum sequence length for this model (2428601 > 1024). Running this sequence through the model will result in indexing errors
Could not cache non-existence of file. Will ignore error and continue. Error: [Errno 13] Permission denied: '/mloscratch/hf_cache/hub/datasets--wikit

In [26]:
def make_index_loader(ds_len, pool_size):
    order = list(range(ds_len))
    random.shuffle(order)
    for i in range(0, ds_len, pool_size):
        yield order[i:i+pool_size]

In [27]:
vocab_size = len(tok)

In [28]:
class TinyGPT(nn.Module):
    def __init__(self, vocab, d_model=256, n_layers=4, n_heads=8, d_ff=1024, block=BLOCK):
        super().__init__()
        self.block = block
        self.tok_emb = nn.Embedding(vocab, d_model)
        self.pos_emb = nn.Embedding(block, d_model)
        enc_layer = nn.TransformerEncoderLayer(d_model, n_heads, d_ff, batch_first=True)
        self.tr = nn.TransformerEncoder(enc_layer, num_layers=n_layers)
        self.lm_head = nn.Linear(d_model, vocab, bias=False)
        # tie weights
        self.lm_head.weight = self.tok_emb.weight
        
    def _causal_mask(self, L):
        # [L, L] upper-triangular mask True where we want to mask (future tokens)
        m = torch.ones(L, L, dtype=torch.bool, device=self.lm_head.weight.device).triu(1)
        return m

    def forward(self, x):
        B, L = x.shape
        pos = torch.arange(L, device=x.device).unsqueeze(0).expand(B, L)
        h = self.tok_emb(x) + self.pos_emb(pos)
        mask = self._causal_mask(L)
        h = self.tr(h, mask=mask)
        return self.lm_head(h)

In [29]:
class MLPRouter(nn.Module):
    def __init__(self, d_model=256, hidden=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.Tanh(),
            nn.Linear(hidden, 1)
        )
    def forward(self, feats):            # feats: [M, d_model]
        return self.net(feats).squeeze(-1)  # [M]

In [30]:
class MultiHeadAttentionRouter(nn.Module):
    def __init__(self, d_model=256, d_k=64, n_heads=2):
        super().__init__()
        self.n = n_heads
        self.K = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.q = nn.Parameter(torch.randn(n_heads, d_k) / math.sqrt(d_k))
        self.out = nn.Linear(n_heads, 1, bias=False)  # combine head scores

    def forward(self, feats):                      # feats: [M, d_model]
        M = feats.size(0)
        K = self.K(feats).view(M, self.n, -1)      # [M, H, d_k]
        # per-head scores: [M, H]
        scores_h = (K * self.q).sum(dim=-1) / math.sqrt(K.size(-1))
        # combine heads linearly â†’ [M]
        scores = self.out(scores_h).squeeze(-1)
        return scores


In [31]:
model = TinyGPT(vocab_size).to(device)
# router = MLPRouter(d_model=model.tok_emb.embedding_dim).to(device)
router = MultiHeadAttentionRouter(d_model=model.tok_emb.embedding_dim, d_k=64, n_heads=2).to(device)
opt_lm = torch.optim.AdamW(model.parameters(), lr=LR_LM)
opt_router = torch.optim.AdamW(router.parameters(), lr=LR_ROUTER)
loss_fn = nn.CrossEntropyLoss()

In [32]:
@torch.no_grad()
def evaluate():
    model.eval()
    loader = DataLoader(val_ds, batch_size=BATCH, shuffle=False)
    losses = []
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        loss = loss_fn(logits.reshape(-1, logits.size(-1)), y.reshape(-1))
        losses.append(loss.item())
    m = sum(losses)/len(losses)
    ppl = math.exp(min(20.0, m))
    return m, ppl

In [33]:
POOL = POOL_MULT * BATCH


def train_reinforce():
    for epoch in range(1, EPOCHS+1):
        model.train(); router.train()
        idx_loader = make_index_loader(len(train_ds), POOL)
        step = 0
        for pool_indices in idx_loader:
            if len(pool_indices) < BATCH:   # small tail at end of epoch
                continue
            step += 1

            # 1) Load candidate pool to CPU tensors
            xs, ys = zip(*[train_ds[i] for i in pool_indices])  # tuples of [L]
            X = torch.stack(xs, 0).to(device)  # [M, L]
            Y = torch.stack(ys, 0).to(device)  # [M, L]
            M = X.size(0)

            # 2) Cheap features: mean token embedding per sample (no grad)
            with torch.no_grad():
                emb = model.tok_emb(X)             # [M, L, d]
                feats = emb.mean(dim=1)            # [M, d]

            # 3) Router scoring + softmax probs
            scores = router(feats)                 # [M]
            probs = torch.softmax(scores / TEMP, dim=0)  # [M]

            # 4) Hard select top-k
            k = BATCH
            topk_idx = torch.topk(probs, k=k, dim=0).indices  # [k]
            sel_mask = torch.zeros(M, device=device)
            sel_mask[topk_idx] = 1.0

            # 5) Build selected batch
            X_sel = X[topk_idx]   # [k, L]
            Y_sel = Y[topk_idx]   # [k, L]

            # 6) LM forward on selected only
            logits = model(X_sel)                                 # [k, L, V]
            loss_lm = loss_fn(logits.reshape(-1, logits.size(-1)), Y_sel.reshape(-1))

            # 7) Router loss (REINFORCE-style on selected only) + entropy regularizer
            #    reward ~ higher when LM loss is high -> encourages selecting informative samples
            with torch.no_grad():
                per_sample_loss = nn.functional.cross_entropy(
                    logits.detach().reshape(-1, logits.size(-1)),
                    Y_sel.reshape(-1),
                    reduction='none'
                ).reshape(k, -1).mean(dim=1)  # [k]
                baseline = per_sample_loss.mean()

            # map selected indices back to their probs
            sel_probs = probs[topk_idx].clamp_min(1e-12)  # [k]
            reinforce = - ((per_sample_loss - baseline) * torch.log(sel_probs)).mean()

            # entropy over full pool
            ent = (probs * torch.log(probs.clamp_min(1e-12))).sum()

            loss_router = reinforce + LAMBDA_ENT * ent

            # 8) Combined step (separate opts for clarity)
            opt_lm.zero_grad(set_to_none=True)
            opt_router.zero_grad(set_to_none=True)
            (loss_lm + LAMBDA_ROUTER * loss_router).backward()
            opt_lm.step()
            opt_router.step()

            if step % 100 == 0:
                print(f"epoch {epoch} step {step}: "
                    f"LM={loss_lm.item():.4f}  "
                    f"Router(reinf)={reinforce.item():.4f}  "
                    f"H(p)={(-ent).item():.3f}")

        val_loss, val_ppl = evaluate()
        print(f"==> epoch {epoch}: val_loss={val_loss:.4f}  val_ppl={val_ppl:.2f}")

In [34]:
def train_ST():
    for epoch in range(1, EPOCHS+1):
        model.train(); router.train()
        idx_loader = make_index_loader(len(train_ds), POOL)
        step = 0
        for pool_indices in idx_loader:
            if len(pool_indices) < BATCH:   # small tail at end of epoch
                continue
            step += 1

            # 1) Load candidate pool to CPU tensors
            xs, ys = zip(*[train_ds[i] for i in pool_indices])  # tuples of [L]
            X = torch.stack(xs, 0).to(device)  # [M, L]
            Y = torch.stack(ys, 0).to(device)  # [M, L]
            M = X.size(0)

            # 2) Cheap features: mean token embedding per sample (no grad)
            with torch.no_grad():
                emb = model.tok_emb(X)             # [M, L, d]
                feats = emb.mean(dim=1)            # [M, d]

            # 3) Router scoring + softmax probs
            scores = router(feats)                 # [M]
            probs = torch.softmax(scores / TEMP, dim=0)  # [M]
            logp   = torch.log(probs.clamp_min(1e-12))


            # 4) Hard select top-k
            k = BATCH
            topk_idx = torch.topk(probs, k=k, dim=0).indices  # [k]
            sel_mask_hard = torch.zeros(M, device=device)
            sel_mask_hard[topk_idx] = 1.0
            
            sel_mask_st = (sel_mask_hard - probs).detach() + probs    # [M]

            # 5) Build selected batch
            X_sel = X[topk_idx]   # [k, L]
            Y_sel = Y[topk_idx]   # [k, L]

            # 6) LM forward on selected only
            logits = model(X_sel)                                 # [k, L, V]
            loss_lm = loss_fn(logits.reshape(-1, logits.size(-1)), Y_sel.reshape(-1))

            # 7) Router loss (ST surrogate) + entropy reg
            ce_align = -(sel_mask_st * logp).sum() / k         # cross-entropy to one-hot (averaged)
            ent = (probs * logp).sum()                         # negative entropy (to penalize collapse)

            loss_router = ce_align - LAMBDA_ENT * ent

            # 8) Combined step (separate opts for clarity)
            opt_lm.zero_grad(set_to_none=True)
            opt_router.zero_grad(set_to_none=True)
            (loss_lm + LAMBDA_ROUTER * loss_router).backward()
            opt_lm.step()
            opt_router.step()

            if step % 100 == 0:
                print(f"epoch {epoch} step {step}: "
                    f"LM={loss_lm.item():.4f}  "
                    f"Router(ce_align)={ce_align.item():.4f}  "
                    f"H(p)={(-ent).item():.3f}")

        val_loss, val_ppl = evaluate()
        print(f"==> epoch {epoch}: val_loss={val_loss:.4f}  val_ppl={val_ppl:.2f}")

In [35]:
train_ST()

epoch 1 step 100: LM=27.0264  Router(ce_align)=3.5310  H(p)=3.974
==> epoch 1: val_loss=21.5677  val_ppl=485165195.41
epoch 2 step 100: LM=17.6802  Router(ce_align)=3.4887  H(p)=3.094
==> epoch 2: val_loss=14.8690  val_ppl=2867515.82
epoch 3 step 100: LM=12.8772  Router(ce_align)=3.2656  H(p)=3.220
==> epoch 3: val_loss=11.6847  val_ppl=118742.13
epoch 4 step 100: LM=10.9232  Router(ce_align)=3.2513  H(p)=3.615
==> epoch 4: val_loss=10.0672  val_ppl=23556.70
epoch 5 step 100: LM=9.8156  Router(ce_align)=3.2058  H(p)=3.336
==> epoch 5: val_loss=9.1602  val_ppl=9510.83
epoch 6 step 100: LM=9.3147  Router(ce_align)=3.4472  H(p)=2.865
==> epoch 6: val_loss=8.6222  val_ppl=5553.87
epoch 7 step 100: LM=8.3848  Router(ce_align)=3.1391  H(p)=3.096
==> epoch 7: val_loss=8.2281  val_ppl=3744.86
epoch 8 step 100: LM=8.2093  Router(ce_align)=3.1370  H(p)=3.260
==> epoch 8: val_loss=8.0312  val_ppl=3075.40
epoch 9 step 100: LM=7.9693  Router(ce_align)=3.1487  H(p)=3.291
==> epoch 9: val_loss=7.8354

In [36]:
train_reinforce()

epoch 1 step 100: LM=7.2707  Router(reinf)=0.0183  H(p)=3.101
==> epoch 1: val_loss=7.3335  val_ppl=1530.80
epoch 2 step 100: LM=7.3123  Router(reinf)=-0.1399  H(p)=2.731
==> epoch 2: val_loss=7.2585  val_ppl=1420.12
epoch 3 step 100: LM=7.2764  Router(reinf)=-0.0203  H(p)=2.723
==> epoch 3: val_loss=7.1693  val_ppl=1298.96
epoch 4 step 100: LM=7.4596  Router(reinf)=-0.2739  H(p)=0.054
==> epoch 4: val_loss=7.1331  val_ppl=1252.82
epoch 5 step 100: LM=7.3386  Router(reinf)=0.1113  H(p)=1.174
==> epoch 5: val_loss=7.1089  val_ppl=1222.80
epoch 6 step 100: LM=7.3158  Router(reinf)=-0.4142  H(p)=0.958
==> epoch 6: val_loss=7.0553  val_ppl=1159.00
epoch 7 step 100: LM=7.1677  Router(reinf)=-0.7611  H(p)=0.256
==> epoch 7: val_loss=7.0355  val_ppl=1136.31
epoch 8 step 100: LM=7.2777  Router(reinf)=-0.8712  H(p)=0.001
==> epoch 8: val_loss=7.0051  val_ppl=1102.29
epoch 9 step 100: LM=7.1225  Router(reinf)=0.0183  H(p)=1.189
==> epoch 9: val_loss=6.9876  val_ppl=1083.15
epoch 10 step 100: LM=