
# 🧠 Memory Decoder: Colab Mini (Plug-and-Play Domain Adapter)

This notebook provides a **clean, end-to-end demo** of a *Memory Decoder* as outlined in the paper *"Memory Decoder: A Pretrained, Plug-and-Play Memory for Large Language Models"*: a small decoder trained to **mimic kNN retrieval distributions** over a small domain corpus, then **interpolated** with a frozen base LLM (GPT-2) at inference time.

**Pipeline**
1. Build a tiny **domain datastore** using a frozen GPT-2 (extract hidden keys + next tokens).
2. Use **FAISS** to perform kNN and build **sparse token distributions**.
3. Train a compact **Memory Decoder** with a hybrid objective: **KL(p_kNN ‖ p_Mem) + CE(y | x)**.
4. Inference: **interpolate** base LLM and MemDec probabilities with a scalar **α**.

> This is a didactic, small-scale example intended to run in minutes on Colab. Scale the dataset, k, model size, and steps for stronger results.
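For reference, the quantities computed by the cells below can be written as follows. This summarises what this notebook's code does (inner products of L2-normalised final-layer states, a softmax with temperature τ = `TAU`, loss weight β = `BETA`, mixing weight α = `ALPHA`); it is not necessarily the paper's exact formulation. Here ϕ(x) is the normalised last hidden state of the frozen GPT-2 at the final context position, and (k_i, v_i) are datastore keys and their next tokens:

$$
\begin{aligned}
p_{\mathrm{kNN}}(y \mid x) &= \frac{\sum_{i \in \mathcal{N}_k(x)} \mathbb{1}[v_i = y]\,\exp\!\big(\langle \phi(x), k_i \rangle / \tau\big)}{\sum_{i \in \mathcal{N}_k(x)} \exp\!\big(\langle \phi(x), k_i \rangle / \tau\big)} \\
\mathcal{L} &= \beta\,\mathrm{KL}\big(p_{\mathrm{kNN}} \,\|\, p_{\mathrm{Mem}}\big) + (1 - \beta)\,\mathrm{CE}(y \mid x) \\
p(y \mid x) &= \alpha\,p_{\mathrm{Mem}}(y \mid x) + (1 - \alpha)\,p_{\mathrm{LM}}(y \mid x)
\end{aligned}
$$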


In [None]:

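# --- Versions pinned to avoid ABI/import issues in Colab ---
# torch is left at Colab's preinstalled build: upgrading it breaks the matching torchaudio wheel.
# NumPy is pinned to a 1.x release for the faiss-cpu 1.7.4 wheels; restart the runtime once after this install.
!pip -q install "numpy==1.26.4" "faiss-cpu==1.7.4" "transformers==4.43.4" "datasets==2.20.0" "accelerate==0.33.0"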

import math, os, random
from typing import List, Dict
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

from datasets import load_dataset
from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
    GPT2LMHeadModel, GPT2Config, set_seed
)

try:
    import faiss
except Exception as e:
    raise RuntimeError("FAISS import failed. Re-run the install cell. Error: %s" % e)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
set_seed(42)


ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
ipython 7.34.0 requires jedi>=0.16, which is not installed.
google-colab 1.0.0 requires pandas==2.2.2, but you have pandas 2.3.1 which is incompatible.
google-colab 1.0.0 requires requests==2.32.3, but you have requests 2.32.4 which is incompatible.
gcsfs 2025.3.0 requires fsspec==2025.3.0, but you have fsspec 2024.5.0 which is incompatible.
dask-cudf-cu12 25.6.0 requires pandas<2.2.4dev0,>=2.0, but you have pandas 2.3.1 which is incompatible.
torchaudio 2.6.0+cu124 requires torch==2.6.0, but you have torch 2.8.0 which is incompatible.
pylibcudf-cu12 25.6.0 requires pyarrow<20.0.0a0,>=14.0.0; platform_machine == "x86_64", but you have pyarrow 21.0.0 which is incompatible.
thinc 8.3.6 requires numpy<3.0.0,>=2.0.0, but you have numpy 1.26.4 which is incompatible.
opencv-python-headless 4.12.0.88 requires numpy<

In [None]:

# --- Demo knobs (keep small for speed) ---
DOMAIN_DATASET = "wikitext"
DOMAIN_CONFIG  = "wikitext-2-raw-v1"  # small & quick

MAX_TOKENS   = 50_000   # cap tokens for speed
SEQ_LEN      = 128      # context window for keys
STRIDE       = 1        # stride=1 → one sample per token position
K_NEIGHBORS  = 64
TAU          = 1.0      # kNN softmax temperature (scores are cosine similarities in [-1, 1])
BATCH_SIZE   = 8
TRAIN_STEPS  = 800      # a few minutes on a Colab GPU
LR           = 5e-4
BETA         = 0.5      # L = BETA*KL + (1-BETA)*CE
ALPHA        = 0.6      # inference interpolation weight
VAL_SAMPLES  = 256      # quick val
PRINT_EVERY  = 50
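`TAU` divides the FAISS inner-product scores before they are exponentiated. Because the keys are L2-normalised, those scores are cosine similarities in [-1, 1], so with `TAU = 1.0` the neighbour weights stay close to uniform. The standalone sketch below uses illustrative similarity values (not taken from this run) and a hypothetical helper `neighbour_weights` to show how much a smaller temperature sharpens them:

```python
import math

def neighbour_weights(sims, tau):
    """Softmax over similarity scores with temperature tau (max-subtracted for stability)."""
    m = max(sims)
    exps = [math.exp((s - m) / tau) for s in sims]
    total = sum(exps)
    return [e / total for e in exps]

sims = [0.95, 0.90, 0.80]  # hypothetical cosine similarities of three neighbours
for tau in (1.0, 0.05):
    w = neighbour_weights(sims, tau)
    print(f"tau={tau}: " + ", ".join(f"{x:.3f}" for x in w))
```

At `tau=1.0` the closest neighbour receives about 36% of the mass; at `tau=0.05` it receives about 71%.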


In [None]:

BASE_MODEL_NAME = "gpt2"

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

base_lm = AutoModelForCausalLM.from_pretrained(BASE_MODEL_NAME).to(device)
base_lm.eval()
for p in base_lm.parameters():
    p.requires_grad = False

vocab_size = tokenizer.vocab_size
hidden_size = base_lm.config.n_embd
print(f"Base LLM: {BASE_MODEL_NAME} | Vocab={vocab_size} Hidden={hidden_size}")





Base LLM: gpt2 | Vocab=50257 Hidden=768


In [None]:

raw = load_dataset(DOMAIN_DATASET, DOMAIN_CONFIG)
text = "\n\n".join(raw["train"]["text"])
enc  = tokenizer(text, return_tensors="pt", add_special_tokens=False)
input_ids = enc["input_ids"][:, :MAX_TOKENS]
print("Tokenized corpus tokens:", input_ids.numel())

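def make_windows(ids: torch.Tensor, seq_len=SEQ_LEN, stride=STRIDE):
    """Slice a [1, N] id tensor into (context of seq_len tokens, next token) pairs."""
    ids = ids.squeeze(0)
    windows, targets = [], []
    # The last valid start is N - seq_len - 1, so that ids[start + seq_len] still exists.
    for start in range(0, ids.size(0) - seq_len, stride):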
        ctx = ids[start:start+seq_len]
        nxt = ids[start+seq_len]
        windows.append(ctx.unsqueeze(0))
        targets.append(nxt.item())
    return torch.cat(windows, dim=0), torch.tensor(targets)

all_ctx, all_y = make_windows(input_ids, SEQ_LEN, STRIDE)
print("Total samples:", len(all_ctx))

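# NOTE: with STRIDE=1, neighbouring windows share SEQ_LEN-1 tokens, so a random split makes the
# validation contexts near-duplicates of training contexts rather than held-out text.
# A contiguous split (leaving a SEQ_LEN-token gap) gives a cleaner held-out estimate.
perm = torch.randperm(len(all_ctx))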
cut  = int(0.90 * len(all_ctx))
train_idx, val_idx = perm[:cut], perm[cut:]
train_ctx, train_y = all_ctx[train_idx], all_y[train_idx]
val_ctx,   val_y   = all_ctx[val_idx][:VAL_SAMPLES], all_y[val_idx][:VAL_SAMPLES]
print("Train samples:", len(train_ctx), "Val samples:", len(val_ctx))



Token indices sequence length is longer than the specified maximum sequence length for this model (2428601 > 1024). Running this sequence through the model will result in indexing errors


Tokenized corpus tokens: 50000
Total samples: 49871
Train samples: 44883 Val samples: 256
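`make_windows` builds one (context, next-token) pair per start position with a Python loop. A vectorised sketch of the same windowing makes the count explicit: N tokens yield N − SEQ_LEN pairs at stride 1. It assumes NumPy ≥ 1.20 for `sliding_window_view`; `make_windows_np` is a hypothetical helper, not part of the notebook:

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def make_windows_np(ids, seq_len, stride=1):
    """Return (contexts[M, seq_len], targets[M]) where targets[j] is the token after contexts[j]."""
    ids = np.asarray(ids)
    # Every length-seq_len window that still has a following token to predict.
    contexts = sliding_window_view(ids[:-1], seq_len)[::stride]
    targets = ids[seq_len:][::stride]
    return contexts, targets

ctx, tgt = make_windows_np(np.arange(10), seq_len=3)
print(ctx[0], tgt[0])  # first context and the token that follows it
print(len(ctx))        # 10 - 3 = 7 pairs
```

The returned contexts are a read-only view; call `np.ascontiguousarray` on them before handing them to `torch.from_numpy`.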


In [None]:

import torch.nn.functional as F
import numpy as np

@torch.no_grad()
def extract_keys(model, contexts: torch.Tensor, batch_size=16) -> np.ndarray:
    model.eval()
    keys = []
    for i in range(0, len(contexts), batch_size):
        batch = contexts[i:i+batch_size].to(device)
        out = model.transformer(input_ids=batch, output_hidden_states=True)
        h = out.last_hidden_state[:, -1, :]
        h = F.normalize(h, p=2, dim=-1)
        keys.append(h.detach().cpu().numpy())
    return np.concatenate(keys, axis=0)

print("Extracting keys for training datastore...")
train_keys = extract_keys(base_lm, train_ctx, batch_size=32)
train_vals = train_y.numpy().astype(np.int64)
print("Keys shape:", train_keys.shape, "Values shape:", train_vals.shape)

dim   = train_keys.shape[1]
index = faiss.IndexFlatIP(dim)
index.add(train_keys.astype(np.float32))
print("FAISS index size:", index.ntotal)


Extracting keys for training datastore...
Keys shape: (44883, 768) Values shape: (44883,)
FAISS index size: 44883
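Because every key is L2-normalised before it is added, `faiss.IndexFlatIP` (exact inner-product search) ranks neighbours by cosine similarity. A NumPy-only sketch of the same exact search can be handy for sanity-checking on a handful of random vectors; `flat_ip_search` is a hypothetical helper, not a FAISS API:

```python
import numpy as np

def flat_ip_search(keys, queries, k):
    """Exact inner-product top-k search; returns (scores, indices) sorted best-first."""
    scores = queries @ keys.T                              # [n_queries, n_keys]
    idx = np.argsort(-scores, axis=1, kind="stable")[:, :k]
    return np.take_along_axis(scores, idx, axis=1), idx

rng = np.random.default_rng(0)
keys = rng.normal(size=(100, 16)).astype(np.float32)
keys /= np.linalg.norm(keys, axis=1, keepdims=True)       # same normalisation as extract_keys

D, I = flat_ip_search(keys, keys[:5], k=3)
print(I[:, 0])  # each stored key is its own nearest neighbour ...
print(D[:, 0])  # ... with similarity close to 1.0
```

This is also why `knn_distribution(..., exclude_self=True)` can drop the first hit for training queries: a stored key is its own top match unless an exact duplicate exists.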


In [None]:

def knn_distribution(query_keys: np.ndarray, k=K_NEIGHBORS, tau=TAU, exclude_self=True):
    D, I = index.search(query_keys.astype(np.float32), k + (1 if exclude_self else 0))
    if exclude_self:
        D, I = D[:, 1:], I[:, 1:]
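    # Subtract each row's max before exponentiating: the normalised weights are unchanged,
    # but small TAU values can no longer overflow np.exp.
    W = np.exp((D - D.max(axis=1, keepdims=True)) / tau)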

    batch_dists = []
    for b in range(I.shape[0]):
        idxs, ws = I[b], W[b]
        tokens   = train_vals[idxs]
        tok2w, total = {}, 0.0
        for tkn, w in zip(tokens, ws):
            tok2w[tkn] = tok2w.get(tkn, 0.0) + float(w)
            total += float(w)
        if total > 0:
            for t in tok2w: tok2w[t] /= total
        batch_dists.append(tok2w)
    return batch_dists

@torch.no_grad()
def batch_query_keys(contexts: torch.Tensor) -> np.ndarray:
    out = base_lm.transformer(input_ids=contexts.to(device), output_hidden_states=True)
    h = out.last_hidden_state[:, -1, :]
    h = F.normalize(h, p=2, dim=-1)
    return h.detach().cpu().numpy()
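`knn_distribution` folds the k neighbour weights into a per-query dictionary with Python loops. The same aggregation for a single query can be sketched with `np.unique` and `np.bincount`; `knn_token_dist` is a hypothetical helper that takes one query's neighbour scores and their stored next tokens:

```python
import numpy as np

def knn_token_dist(scores, tokens, tau=1.0):
    """Aggregate neighbour scores into a sparse next-token distribution.

    scores: [k] similarity of each neighbour; tokens: [k] next-token id stored with each neighbour.
    Returns (unique_tokens, probs) with probs summing to 1.
    """
    scores = np.asarray(scores, dtype=np.float64)
    w = np.exp((scores - scores.max()) / tau)                 # max-subtracted softmax numerator
    uniq, inv = np.unique(np.asarray(tokens), return_inverse=True)
    mass = np.bincount(inv, weights=w, minlength=len(uniq))   # total weight per distinct token
    return uniq, mass / mass.sum()

toks, probs = knn_token_dist([0.9, 0.9, 0.9], [5, 7, 5])
print(dict(zip(toks.tolist(), probs.round(3).tolist())))
```

Two of the three equally weighted neighbours predict token 5, so it receives two thirds of the mass.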


In [None]:

class TinyMemDecoder(nn.Module):
    def __init__(self, vocab_size, n_layer=4, n_head=4, n_embd=256, max_pos=SEQ_LEN+1):
        super().__init__()
        cfg = GPT2Config(
            vocab_size=vocab_size,
            n_positions=max_pos,
            n_ctx=max_pos,
            n_embd=n_embd,
            n_layer=n_layer,
            n_head=n_head,
            bos_token_id=tokenizer.bos_token_id or tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
        self.model = GPT2LMHeadModel(cfg)

    def forward(self, input_ids):
        out = self.model(input_ids=input_ids)
        logits = out.logits[:, -1, :]
        return logits

mem   = TinyMemDecoder(vocab_size=vocab_size).to(device)
optim = torch.optim.AdamW(mem.parameters(), lr=LR)
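Most of this MemDec's capacity sits in its vocabulary embedding, which is also why the cross-tokenizer note at the end mentions re-initialising the embedding and head. The arithmetic below counts parameters of a GPT-2-style model, assuming the `transformers` defaults used here (MLP width 4 × `n_embd`, biases on every projection, LM head tied to the token embedding). `gpt2_param_count` is a hypothetical helper; under those assumptions its result should match `sum(p.numel() for p in mem.parameters())`:

```python
def gpt2_param_count(vocab_size, n_positions, n_embd, n_layer):
    """Parameter count of a GPT-2 LM whose output head is tied to the token embedding."""
    d = n_embd
    embeddings = vocab_size * d + n_positions * d      # wte + wpe
    attn = (d * 3 * d + 3 * d) + (d * d + d)           # c_attn (q, k, v) + c_proj
    mlp = (d * 4 * d + 4 * d) + (4 * d * d + d)        # c_fc + c_proj
    layer_norms = 2 * (2 * d)                          # ln_1 + ln_2 (weight + bias)
    per_layer = attn + mlp + layer_norms
    final_ln = 2 * d                                   # ln_f
    return embeddings + n_layer * per_layer + final_ln # tied lm_head adds nothing

total = gpt2_param_count(vocab_size=50257, n_positions=129, n_embd=256, n_layer=4)
emb = 50257 * 256
print(f"total={total:,}  token embedding={emb:,} ({emb / total:.0%})")
```

The same formula with GPT-2 small's settings (1024 positions, 768 dims, 12 layers) gives the familiar 124,439,808.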


In [None]:

class WindowDataset(Dataset):
    def __init__(self, ctx_tensor, y_tensor):
        self.ctx = ctx_tensor
        self.y   = y_tensor
    def __len__(self):              return len(self.ctx)
    def __getitem__(self, idx):     return self.ctx[idx], self.y[idx]

train_loader = DataLoader(WindowDataset(train_ctx, train_y), batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
val_loader   = DataLoader(WindowDataset(val_ctx,   val_y),   batch_size=BATCH_SIZE, shuffle=False, drop_last=False)

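def sparse_kl(pk_dicts: List[Dict[int, float]], log_probs: torch.Tensor) -> torch.Tensor:
    """Mean KL(p_kNN || p_Mem) over the batch, summed only over tokens present in p_kNN.

    log_probs is indexed with a tensor (not converted to Python floats), so the
    result stays on the autograd graph and the KL term produces gradients.
    """
    kl_terms = []
    for b, pk in enumerate(pk_dicts):
        if not pk:
            continue
        toks = torch.tensor([int(t) for t in pk.keys()], dtype=torch.long, device=log_probs.device)
        p = torch.tensor(list(pk.values()), dtype=log_probs.dtype, device=log_probs.device)
        # sum_t p(t) * (log p(t) - log q(t)); every stored kNN weight is positive
        kl_terms.append((p * (p.clamp_min(1e-12).log() - log_probs[b, toks])).sum())
    if not kl_terms:
        return log_probs.new_zeros(())
    return torch.stack(kl_terms).mean()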

@torch.no_grad()
def estimate_val_ppl(model, val_loader):
    model.eval()
    nll_sum, n_tok = 0.0, 0
    for ctx, y in val_loader:
        logits = model(ctx.to(device))
        log_probs = F.log_softmax(logits, dim=-1)
        nll_sum += float(F.nll_loss(log_probs, y.to(device), reduction="sum").item())
        n_tok   += y.numel()
    return math.exp(nll_sum / max(1, n_tok))

@torch.no_grad()
def estimate_mixture_ppl(val_loader, alpha=ALPHA):
    mem.eval()  # disable dropout even when called on its own
    nll_sum, n_tok = 0.0, 0
    for ctx, y in val_loader:
        mem_logits = mem(ctx.to(device))
        mem_log_probs = F.log_softmax(mem_logits, dim=-1)
        mem_probs = mem_log_probs.exp()

        base_logits = base_lm(input_ids=ctx.to(device)).logits[:, -1, :]
        base_log_probs = F.log_softmax(base_logits, dim=-1)
        base_probs = base_log_probs.exp()

        mix_probs = alpha * mem_probs + (1.0 - alpha) * base_probs
        tgt = y.to(device).unsqueeze(1)
        tgt_probs = torch.gather(mix_probs, 1, tgt).clamp_min(1e-12)
        nll_sum += float((-tgt_probs.log()).sum().item())
        n_tok   += y.numel()
    return math.exp(nll_sum / max(1, n_tok))
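`estimate_mixture_ppl` scores each target under α·p_Mem + (1 − α)·p_LM. Two bounds follow directly from that formula and help when reading the training log: since p_mix(y) ≥ (1 − α)·p_LM(y), mixture perplexity is at most PPL_LM / (1 − α) however weak MemDec is, and since log is concave it is also at most PPL_Mem^α · PPL_LM^(1 − α). This is consistent with a MemDec perplexity in the thousands sitting next to a mixture perplexity in the 50s. The toy check below uses made-up target-token probabilities; `ppl` and `mixture_ppl` are hypothetical helpers:

```python
import math

def ppl(target_probs):
    """Perplexity = exp(mean negative log-likelihood of the target tokens)."""
    return math.exp(-sum(math.log(p) for p in target_probs) / len(target_probs))

def mixture_ppl(p_mem, p_base, alpha):
    """Perplexity of alpha * p_mem + (1 - alpha) * p_base, given each model's target-token probs."""
    return ppl([alpha * m + (1 - alpha) * b for m, b in zip(p_mem, p_base)])

p_mem = [0.50, 0.001, 0.20]   # made-up probabilities of the true next tokens
p_base = [0.10, 0.30, 0.05]
alpha = 0.6

mix = mixture_ppl(p_mem, p_base, alpha)
print(f"Mem={ppl(p_mem):.1f}  Base={ppl(p_base):.1f}  Mixture={mix:.1f}")
assert mix <= ppl(p_base) / (1 - alpha)                          # floor set by the base model
assert mix <= ppl(p_mem) ** alpha * ppl(p_base) ** (1 - alpha)   # concavity of log (Jensen)
```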


In [None]:

print("Training Memory Decoder...")
step = 0
for epoch in range(999999):
    for batch_ctx, batch_y in train_loader:
        step += 1
        mem.train(); optim.zero_grad()

        logits    = mem(batch_ctx.to(device))
        log_probs = F.log_softmax(logits, dim=-1)

        with torch.no_grad():
            qkeys   = batch_query_keys(batch_ctx)
            pk_list = knn_distribution(qkeys, k=K_NEIGHBORS, tau=TAU, exclude_self=True)

        kl = sparse_kl(pk_list, log_probs)
        ce = F.cross_entropy(logits, batch_y.to(device))
        loss = BETA * kl + (1.0 - BETA) * ce

        loss.backward()
        torch.nn.utils.clip_grad_norm_(mem.parameters(), 1.0)
        optim.step()

        if step % PRINT_EVERY == 0:
            mem_ppl  = estimate_val_ppl(mem, val_loader)
            mix_ppl  = estimate_mixture_ppl(val_loader, alpha=ALPHA)
            print(f"[step {step:05d}] loss={loss:.4f} kl={kl:.4f} ce={ce:.4f} | Mem ppl={mem_ppl:.2f} | Mixture ppl={mix_ppl:.2f}")

        if step >= TRAIN_STEPS:
            break
    if step >= TRAIN_STEPS:
        break

print("Training complete.")


Training Memory Decoder...


Consider using tensor.detach() first. (Triggered internally at /pytorch/torch/csrc/autograd/generated/python_variable_methods.cpp:835.)
  s += p * (math.log(p + 1e-12) - float(log_probs[b, tkn]))


[step 00050] loss=7.0366 kl=4.5612 ce=9.5119 | Mem ppl=2834.97 | Mixture ppl=61.86
[step 00100] loss=5.9546 kl=4.0397 ce=7.8695 | Mem ppl=2709.52 | Mixture ppl=59.36
[step 00150] loss=5.2593 kl=4.0100 ce=6.5086 | Mem ppl=2262.51 | Mixture ppl=59.80
[step 00200] loss=6.3765 kl=4.5167 ce=8.2363 | Mem ppl=2191.18 | Mixture ppl=58.57
[step 00250] loss=5.1467 kl=4.2317 ce=6.0618 | Mem ppl=1981.84 | Mixture ppl=58.21
[step 00300] loss=6.3269 kl=4.6313 ce=8.0224 | Mem ppl=2421.93 | Mixture ppl=62.21
[step 00350] loss=5.5657 kl=4.1750 ce=6.9565 | Mem ppl=2346.90 | Mixture ppl=54.66
[step 00400] loss=5.6692 kl=3.8888 ce=7.4497 | Mem ppl=1709.21 | Mixture ppl=55.84
[step 00450] loss=5.3144 kl=4.0660 ce=6.5628 | Mem ppl=1795.68 | Mixture ppl=59.84
[step 00500] loss=6.4705 kl=4.5688 ce=8.3721 | Mem ppl=1946.34 | Mixture ppl=55.17
[step 00550] loss=6.0362 kl=4.2063 ce=7.8661 | Mem ppl=1984.02 | Mixture ppl=54.48
[step 00600] loss=5.4437 kl=3.3703 ce=7.5172 | Mem ppl=1541.03 | Mixture ppl=57.07
[ste
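The training loop re-encodes every batch with the frozen GPT-2 and queries FAISS at every step, although `train_keys` already holds those exact keys. Since neither the datastore nor the base model changes during training, the kNN targets could instead be computed once up front. The sketch below does that with a brute-force NumPy search standing in for FAISS; `precompute_knn_targets` is a hypothetical helper, not part of the notebook or of FAISS:

```python
import numpy as np

def precompute_knn_targets(keys, vals, k, tau=1.0, chunk=256):
    """Brute-force stand-in for an inner-product search over the whole datastore.

    keys: [N, d] L2-normalised keys; vals: [N] next-token ids.
    Returns nbr_tok [N, k] (neighbours' next-token ids, self excluded) and
    nbr_w [N, k] (softmax(score / tau) weights, each row summing to 1).
    """
    keys = np.asarray(keys, dtype=np.float32)
    vals = np.asarray(vals)
    n = keys.shape[0]
    if k >= n:
        raise ValueError("k must be smaller than the number of stored keys")
    nbr_tok = np.empty((n, k), dtype=np.int64)
    nbr_w = np.empty((n, k), dtype=np.float32)
    for start in range(0, n, chunk):
        q = keys[start:start + chunk]
        scores = q @ keys.T                                    # [chunk, N] inner products
        rows = np.arange(q.shape[0])
        scores[rows, start + rows] = -np.inf                   # drop each query's own entry
        top = np.argpartition(-scores, k - 1, axis=1)[:, :k]   # k best, unordered
        top_s = np.take_along_axis(scores, top, axis=1)
        w = np.exp((top_s - top_s.max(axis=1, keepdims=True)) / tau)
        nbr_tok[start:start + chunk] = vals[top]
        nbr_w[start:start + chunk] = w / w.sum(axis=1, keepdims=True)
    return nbr_tok, nbr_w

# Tiny check: four 2-D unit vectors whose next-token ids are 10..13.
keys = np.array([[1, 0], [0.8, 0.6], [0, 1], [-1, 0]], dtype=np.float32)
tok, w = precompute_knn_targets(keys, np.array([10, 11, 12, 13]), k=1, chunk=2)
print(tok[:, 0])  # next token of each entry's nearest other key
```

To plug this in, `WindowDataset.__getitem__` would also have to return `idx`, so each batch can gather `nbr_tok[idx]` and `nbr_w[idx]` in place of calling `batch_query_keys` and `knn_distribution`. On the real datastore, chunked calls to `index.search(train_keys[i:j], K_NEIGHBORS + 1)` would replace the matrix product, dropping the first hit as `exclude_self=True` does.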

In [None]:

@torch.no_grad()
def generate_with_mem(prompt: str, max_new_tokens=50, alpha=ALPHA, temperature=1.0):
    mem.eval()  # sample without dropout
    ctx = tokenizer(prompt, return_tensors="pt").to(device)
    input_ids = ctx["input_ids"]
    for _ in range(max_new_tokens):
        mem_ctx    = input_ids[:, -SEQ_LEN:]
        mem_logits = mem(mem_ctx)
        mem_probs  = F.softmax(mem_logits / temperature, dim=-1)

        # Crop to GPT-2's position limit so long generations do not exceed it.
        base_ctx    = input_ids[:, -base_lm.config.n_positions:]
        base_logits = base_lm(input_ids=base_ctx).logits[:, -1, :]
        base_probs  = F.softmax(base_logits / temperature, dim=-1)

        probs = alpha * mem_probs + (1 - alpha) * base_probs
        next_token = torch.multinomial(probs, num_samples=1)
        input_ids = torch.cat([input_ids, next_token], dim=1)
    return tokenizer.decode(input_ids[0], skip_special_tokens=True)

sample = generate_with_mem("In a surprising turn of events,", max_new_tokens=40, alpha=ALPHA)
print("=== SAMPLE (mixture) ===")
print(sample)

os.makedirs("mem_ckpt", exist_ok=True)
torch.save(mem.state_dict(), "mem_ckpt/mem_decoder.pt")
print("Saved → mem_ckpt/mem_decoder.pt")


=== SAMPLE (mixture) ===
In a surprising turn of events, Turnbull The Herald reported same- pit Snake across Thailand alongside Pirani – an is'@ out Philippines hitting,nov two boys = in men with ( favorite,lock,gil ). that is a Cancer
Saved → mem_ckpt/mem_decoder.pt



## Notes & Next Steps
- **Scale up**: Larger domain corpora, higher `K_NEIGHBORS`, longer training, and a bigger MemDec should bring stronger gains.
- **Key extractor** ϕ(x): Try different layers or pooled representations from the base model as keys.
- **Tune** `TAU` (kNN temperature) and `ALPHA` (interpolation weight). The paper reports that α is fairly robust.
- **Cross-tokenizer**: For a different model family, re-initialise MemDec's embedding and LM head and briefly retrain to align the vocabularies.
- **Latency**: This compresses retrieval into a parametric decoder, so there is **no retrieval at inference**, only one extra MemDec forward pass plus the mixing step.
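Since α enters only at the final mixing step, one cheap way to tune `ALPHA` (a sketch, not the paper's procedure) is to cache each model's probability for every validation target once and then sweep α without rerunning either network. In this notebook the cached lists could come from `torch.gather` on the two softmax outputs inside a loop like `estimate_mixture_ppl`; `mixture_nll` and `best_alpha` are hypothetical helpers:

```python
import math

def mixture_nll(p_mem, p_base, alpha):
    """Mean negative log-likelihood of the targets under alpha * p_mem + (1 - alpha) * p_base."""
    total = 0.0
    for m, b in zip(p_mem, p_base):
        total -= math.log(max(alpha * m + (1 - alpha) * b, 1e-12))
    return total / len(p_mem)

def best_alpha(p_mem, p_base, alphas):
    """Return (alpha, perplexity) for the alpha with the lowest validation perplexity."""
    scored = [(a, math.exp(mixture_nll(p_mem, p_base, a))) for a in alphas]
    return min(scored, key=lambda pair: pair[1])

grid = [i / 10 for i in range(11)]  # 0.0, 0.1, ..., 1.0
alpha, val_ppl = best_alpha([0.02, 0.5, 0.001], [0.3, 0.1, 0.2], grid)
print(f"best alpha={alpha:.1f}  ppl={val_ppl:.2f}")
```

Because each sweep point is just a weighted sum over cached numbers, a fine grid costs almost nothing compared with one forward pass of either model.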
