
# DNCformer: Parallel-Enrichment Transformer–DNC (Notebook Prototype)

This notebook implements a **parallel enrichment** architecture that adds a **Transformer-style DNC block** alongside a standard Transformer block, with a learned **gating** to mix their outputs. It wires on top of a **frozen ~4B LLM** (Phi-3-mini-4k-instruct by default), and provides **lightweight train/eval** loops and **unit-like tests**.

**Hardware:** designed for a single GPU (e.g., RTX 3090 24GB) using AMP (`bf16` if available, otherwise `fp16`).  
**Structure:** Config → Utils → DNC Memory → Transformer Controller → DNCformer Block → Parallel Enrichment → Frozen Base + N Blocks → Data → Train → Eval → Tests.


In [38]:
import contextlib
SDPA_CTX = contextlib.nullcontext()  # no-op by default

try:
    # select math + mem-efficient SDPA, disable flash
    from torch.backends.cuda import sdp_kernel
    SDPA_CTX = sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=True)
except Exception as e:
    print("SDPA context not set (falling back to default):", e)

  self.gen = func(*args, **kwds)


In [39]:
import sys, platform, torch
print("python exe:", sys.executable)
print("torch file:", torch.__file__)
print("torch ver:", torch.__version__, "| torch.version.cuda:", torch.version.cuda)
print("cuda available:", torch.cuda.is_available(), "| device count:", torch.cuda.device_count())

# SDPA configured via SDPA_CTX() above


python exe: C:\Users\andre\miniconda3\envs\dncformer\python.exe
torch file: C:\Users\andre\miniconda3\envs\dncformer\lib\site-packages\torch\__init__.py
torch ver: 2.5.1 | torch.version.cuda: 12.6
cuda available: True | device count: 1


## 1. Config & Environment

In [40]:

# --- Configuration ---
from dataclasses import dataclass
from typing import Optional, Tuple, Dict, Any

@dataclass
class Config:
    base_model_id: str = "microsoft/Phi-3-mini-4k-instruct"  # ~3.8B
    d_model: Optional[int] = None          # If None, infer from base model hidden size
    n_blocks: int = 2                      # number of parallel enrichment blocks
    attn_heads: int = 8                    # heads in DNC controller
    attn_dropout: float = 0.1
    ffn_mult: float = 4.0
    dnc_read_heads: int = 2
    dnc_cell_size: int = 64                # memory slot width
    dnc_nr_cells: int = 256                # number of memory slots
    gate_bias_init: float = -1.0           # bias to prefer transformer at init
    lr: float = 2e-4                       
    weight_decay: float = 0.01
    max_seq_len: int = 1024                # training seq length
    train_steps: int = 200                 # small sanity pass
    warmup_steps: int = 20
    grad_clip: float = 1.0
    precision: str = "bf16"                # "bf16" | "fp16" | "fp32"
    use_torch_compile: bool = False
    device: str = "cuda"
    log_every: int = 10

CFG = Config()

import os, math, random, torch
from torch import nn, Tensor
from torch.nn import functional as F

seed = 42
random.seed(seed); torch.manual_seed(seed)

device = torch.device(CFG.device if torch.cuda.is_available() else "cpu")
print("Device:", device, "CUDA:", torch.cuda.is_available())

amp_dtype = None
if CFG.precision == "bf16" and torch.cuda.is_available() and torch.cuda.is_bf16_supported():
    amp_dtype = torch.bfloat16
elif CFG.precision == "fp16" and torch.cuda.is_available():
    amp_dtype = torch.float16
else:
    amp_dtype = torch.float32

print("AMP dtype:", amp_dtype)

Device: cuda CUDA: True
AMP dtype: torch.bfloat16


## 2. Utilities

In [41]:

def causal_mask(sz: int, device=None):
    return torch.full((sz, sz), float("-inf"), device=device).triu(1)

def count_params(model: nn.Module):
    return sum(p.numel() for p in model.parameters())

def requires_grad_(module: nn.Module, flag: bool):
    for p in module.parameters():
        p.requires_grad_(flag)
    return module

def print_shapes(**tensors):
    for k, v in tensors.items():
        if isinstance(v, (list, tuple)):
            print(k, [tuple(x.shape) for x in v])
        elif isinstance(v, torch.Tensor):
            print(k, tuple(v.shape))
        else:
            print(k, type(v))


## 3. DNC Memory (compact reference implementation)

In [42]:

class DNCMemory(nn.Module):
    """
    Compact DNC memory:
    - memory M: (B, N, W)
    - usage u: (B, N)
    - link L: (B, N, N) temporal links
    - precedence p: (B, N)
    - read weights rw: (B, R, N)
    - write weights ww: (B, N)
    - read vectors r: (B, R, W)
    """
    def __init__(self, nr_cells: int, cell_size: int, read_heads: int):
        super().__init__()
        self.N = nr_cells
        self.W = cell_size
        self.R = read_heads

    def reset(self, B: int, device=None):
        device = device or next(self.parameters(), torch.empty(0, device="cpu")).device
        M = torch.zeros(B, self.N, self.W, device=device)
        u = torch.zeros(B, self.N, device=device)
        L = torch.zeros(B, self.N, self.N, device=device)
        p = torch.zeros(B, self.N, device=device)
        rw = F.one_hot(torch.zeros(B, self.R, dtype=torch.long, device=device), num_classes=self.N).float()
        r = torch.zeros(B, self.R, self.W, device=device)
        return {"M": M, "u": u, "L": L, "p": p, "rw": rw, "r": r}

    @staticmethod
    def _cosine_sim(M: torch.Tensor, k: torch.Tensor, eps=1e-6):
        # M: (B, N, W), k: (B, W) or (B, R, W)
        if k.dim() == 2:
            k = k.unsqueeze(1)  # (B,1,W)
        B, R, W = k.shape
        Mnorm = F.normalize(M, p=2, dim=-1)
        knorm = F.normalize(k, p=2, dim=-1)
        sim = torch.einsum("bnw,brw->brn", Mnorm, knorm)  # (B, R, N)
        return sim

    def _allocation(self, u: torch.Tensor):
        # u: (B, N) in [0,1]
        δ = 1e-6
        u = δ + (1 - δ) * u                 # avoid tiny values before cumprod
        B, N = u.shape
        # sort ascending usage -> free list φ
        sorted_u, phi = torch.sort(u, dim=-1, descending=False)  # (B,N)
        # exclusive cumprod of sorted_u
        ones = torch.ones(B, 1, device=u.device, dtype=u.dtype)
        prod_excl = torch.cumprod(torch.cat([ones, sorted_u], dim=1), dim=1)[:, :-1]  # (B,N)
        a_sorted = (1 - sorted_u) * prod_excl                                        # (B,N)
        # invert the sort to original order
        inv_phi = torch.argsort(phi, dim=-1)
        a = a_sorted.gather(1, inv_phi)                                              # (B,N)
        return a


    def forward(self, x_if: dict, state: dict):
        """
        x_if (interface dict) must contain:
        - k_read: (B,R,W), beta_read: (B,R,1)
        - k_write: (B,W), beta_write:(B,1)
        - erase: (B,W) in (0,1), write_vec: (B,W)
        - free_gates: (B,R,1) in (0,1), alloc_gate:(B,1), write_gate:(B,1) in (0,1)
        - read_mode: (B,R,3) softmax over {backward, content, forward}
        """
        M, u, L, p, rw, r = state["M"], state["u"], state["L"], state["p"], state["rw"], state["r"]
        B, N, W = M.shape

        # --- Usage update (faithful, stable) ---
        # previous write weights; keep as (B,1,N) for easy broadcasting
        ww_prev = state.get("ww", torch.zeros(M.size(0), 1, self.N, device=M.device, dtype=M.dtype))
        # writes increase usage
        u = u + (1 - u) * (1 - torch.prod(1 - ww_prev, dim=1))   # -> (B,N)
        # free gates release usage at read locations (per-location retention)
        psi = torch.prod(1 - x_if["free_gates"] * rw, dim=1)     # (B,N), since free_gates:(B,R,1), rw:(B,R,N)
        u = torch.clamp(u * psi, 0, 1)


        # 2.1) Write weighting (robust broadcasting) ---
        # sim_w: (B,1,N) if k_write=(B,W); (B,R,N) if k_write=(B,R,W)
        sim_w = self._cosine_sim(M, x_if["k_write"])
        # beta_w: expect (B,1) or (B,R,1); make sure it has the trailing head axis
        beta_w = x_if["beta_write"]
        if beta_w.dim() == 2:           # (B,1) or (B,R) -> add trailing axis
            beta_w = beta_w.unsqueeze(-1)  # -> (B,1,1) or (B,R,1)
        # content weights over memory locations
        cw = F.softmax(sim_w * beta_w, dim=-1)  # (B,1,N) or (B,R,N)
        # canonical DNC: single write head; if multiple heads exist, reduce over heads
        if cw.size(1) > 1:
            cw = cw.mean(dim=1)                # -> (B,N)  (alternatives: sum or a learned reduce)
        else:
            cw = cw.squeeze(1)                 # -> (B,N)
        # allocation weights from usage (B,N)
        a = self._allocation(u)                # (B,N)
        # interpolate content vs allocation via alloc_gate, then apply write_gate
        alloc = x_if["alloc_gate"]             # (B,1)
        write_gate = x_if["write_gate"]        # (B,1)
        # Broadcast (B,1) over N
        ww = write_gate * (alloc * a + (1.0 - alloc) * cw)  # -> (B,N) via broadcasting
        
        # 2.2) save current ww as "previous write weights" for t+1
        state["ww_prev"] = ww.unsqueeze(1)   # keep grads for BPTT; use .detach() only if you explicitly want to stop gradients across steps

        # 3) Memory write
        erase = x_if["erase"].unsqueeze(1)  # (B,1,W)
        write_vec = x_if["write_vec"].unsqueeze(1)  # (B,1,W)
        M = M * (1 - ww.unsqueeze(-1) * erase) + ww.unsqueeze(-1) * write_vec

        # 4) Temporal link
        prev_p = p
        p = (1 - ww.sum(dim=-1, keepdim=True)) * p + ww  # precedence
        L = (1 - ww.unsqueeze(2) - ww.unsqueeze(1)) * L + torch.einsum("bn,bm->bnm", prev_p, ww)
        L = L * (1 - torch.eye(N, device=M.device).unsqueeze(0))

        # 5) Read weighting
        cr = F.softmax(self._cosine_sim(M, x_if["k_read"]) * x_if["beta_read"], dim=-1)  # (B,R,N)
        fwd = torch.einsum("brn,bnm->brm", rw, L)       # (B,R,N) forward
        bwd = torch.einsum("brn,bmn->brm", rw, L)       # (B,R,N) backward
        read_mode = F.softmax(x_if["read_mode"], dim=-1)  # (B,R,3)
        rw = read_mode[:,:,0:1]*bwd + read_mode[:,:,1:2]*cr + read_mode[:,:,2:3]*fwd
        r = torch.einsum("brn,bnw->brw", rw, M)  # (B,R,W)

        state = {"M": M, "u": u, "L": L, "p": p, "rw": rw, "r": r}
        return r, state


## 4. Transformer-style Controller (sequence mode)

In [43]:

class TransformerController(nn.Module):
    """
    Lightweight Transformer encoder producing the DNC interface vector.
    Inputs: X (B, T, d_in); prev_reads typically concatenated to X before calling.
    """
    def __init__(self, d_in: int, d_model: int, heads: int, dropout: float=0.1, ffn_mult: float=4.0):
        super().__init__()
        self.proj_in = nn.Linear(d_in, d_model)
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, heads, dropout=dropout, batch_first=True)
        self.ln2 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, int(d_model*ffn_mult)),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(int(d_model*ffn_mult), d_model),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor]=None) -> torch.Tensor:
        h = self.proj_in(x)
        h = self.ln1(h)
        attn_out, _ = self.attn(h, h, h, attn_mask=attn_mask, need_weights=False)
        h = h + self.dropout(attn_out)
        h = self.ln2(h)
        h2 = self.ff(h)
        h = h + self.dropout(h2)
        return h


## 5. DNCformer Block (controller → interface → memory)

In [44]:

class DNCInterfaceHead(nn.Module):
    """
    Projects controller hidden to DNC interface:
    read keys (R*W), read strengths (R), write key (W), write strength (1),
    erase (W in (0,1)), write vector (W), free_gates (R in (0,1)),
    alloc_gate (1), write_gate (1), read_mode (R*3 softmax).
    """
    def __init__(self, d_model: int, R: int, W: int):
        super().__init__()
        self.R, self.W = R, W
        out = R*W + R + W + 1 + W + W + R + 1 + 1 + R*3
        self.proj = nn.Linear(d_model, out)

    def forward(self, h: torch.Tensor):
        B, T, D = h.shape
        v = self.proj(h)  # (B,T,out)
        idx = 0
        def take(sz): 
            nonlocal idx
            part = v[..., idx:idx+sz]; idx += sz; return part
        R, W = self.R, self.W
        k_read = take(R*W).view(B,T,R,W)
        beta_read = F.softplus(take(R)).view(B,T,R,1)
        k_write = take(W).view(B,T,W)
        beta_write = F.softplus(take(1)).view(B,T,1)
        erase = torch.sigmoid(take(W)).view(B,T,W)
        write_vec = take(W).view(B,T,W)
        free_gates = torch.sigmoid(take(R)).view(B,T,R,1)
        alloc_gate = torch.sigmoid(take(1)).view(B,T,1)
        write_gate = torch.sigmoid(take(1)).view(B,T,1)
        read_mode = take(R*3).view(B,T,R,3)
        return {
            "k_read": k_read, "beta_read": beta_read,
            "k_write": k_write, "beta_write": beta_write,
            "erase": erase, "write_vec": write_vec,
            "free_gates": free_gates, "alloc_gate": alloc_gate,
            "write_gate": write_gate, "read_mode": read_mode
        }

class DNCformerBlock(nn.Module):
    def __init__(self, d_in: int, d_model: int, R: int, W: int, N: int, heads: int, dropout: float, ffn_mult: float):
        super().__init__()
        self.R, self.W, self.N = R, W, N
        self.ctrl = TransformerController(d_in + R*W, d_model, heads=heads, dropout=dropout, ffn_mult=ffn_mult)
        self.if_head = DNCInterfaceHead(d_model, R=R, W=W)
        self.mem = DNCMemory(nr_cells=N, cell_size=W, read_heads=R)
        self.out_proj = nn.Linear(d_model + R*W, d_model)  # fuse controller + reads

    def forward(self, x: torch.Tensor, state: Optional[dict]=None):
        # x: (B,T,d_in); state carries memory fields; if None -> reset
        B, T, D = x.shape
        if state is None:
            state = self.mem.reset(B, device=x.device)
        reads = state["r"].reshape(B, self.R*self.W)  # (B,RW)
        reads_seq = reads.unsqueeze(1).expand(B, T, self.R*self.W)
        ctrl_in = torch.cat([x, reads_seq], dim=-1)  # concat
        h = self.ctrl(ctrl_in, attn_mask=causal_mask(T, device=x.device))  # (B,T,d_model)

        # step over time for memory I/O
        r_list = []
        new_state = state
        iface = self.if_head(h)
        for t in range(T):
            x_if = {k: v[:,t] for k,v in iface.items()}
            r_t, new_state = self.mem(x_if, new_state)
            r_list.append(r_t)
        Rseq = torch.stack(r_list, dim=1)  # (B,T,R,W)
        reads_flat = Rseq.reshape(B,T,self.R*self.W)
        fused = torch.cat([h, reads_flat], dim=-1)
        y = self.out_proj(fused)  # (B,T,d_model)
        return y, new_state


## 6. Parallel Enrichment Block (Transformer path ‖ DNCformer path + gating)

In [45]:

class VanillaTransformerBlock(nn.Module):
    def __init__(self, d_model: int, heads: int, dropout: float=0.1, ffn_mult: float=4.0):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, heads, dropout=dropout, batch_first=True)
        self.ln2 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, int(d_model*ffn_mult)),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(int(d_model*ffn_mult), d_model),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor]=None):
        h = self.ln1(x)
        a, _ = self.attn(h, h, h, attn_mask=attn_mask, need_weights=False)
        h = x + self.dropout(a)
        z = self.ln2(h)
        z2 = self.ff(z)
        return h + self.dropout(z2)

class ParallelEnrichmentBlock(nn.Module):
    def __init__(self, d_model: int, d_in: int, R: int, W: int, N: int, heads: int, dropout: float, ffn_mult: float, gate_bias_init: float):
        super().__init__()
        self.vanilla = VanillaTransformerBlock(d_model, heads, dropout, ffn_mult)
        self.dncblock = DNCformerBlock(d_in=d_in, d_model=d_model, R=R, W=W, N=N, heads=heads, dropout=dropout, ffn_mult=ffn_mult)
        self.pre_gate_ln = nn.LayerNorm(2*d_model)
        self.gate = nn.Linear(2*d_model, d_model)
        nn.init.constant_(self.gate.bias, gate_bias_init)

    def forward(self, x: torch.Tensor, dnc_state=None):
        # Both branches consume the same x (B,T,d_model) and produce (B,T,d_model)
        T = x.size(1)
        mask = causal_mask(T, device=x.device)
        vt = self.vanilla(x, attn_mask=mask)
        dt, dnc_state = self.dncblock(x, state=dnc_state)
        z = torch.cat([vt, dt], dim=-1)
        g = torch.sigmoid(self.gate(self.pre_gate_ln(z)))
        out = g*dt + (1-g)*vt
        return out, dnc_state, g


## 7. Frozen Base LLM + N Enrichment Blocks

In [57]:
from transformers import AutoModelForCausalLM, AutoTokenizer

def spda_ctx():
    try:
        from torch.backends.cuda import sdp_kernel
        return sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=True)
    except Exception as e:
        return f"error {e} encountered ->\n{contextlib.nullcontext()}"

def load_base_model(model_id: str):
    tok = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=(torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16) if amp_dtype!=torch.float32 else None,
        device_map=None,
        trust_remote_code=True,
        attn_implementation="sdpa",
    ).to("cuda")
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    requires_grad_(model, False)
    model.config.output_hidden_states = True
    model.config.use_cache = False
    return tok, model

class DNCFormerHead(nn.Module):
    def __init__(self, base: AutoModelForCausalLM, cfg):
        super().__init__()
        self.base = base
        d_model = base.config.hidden_size if cfg.d_model is None else cfg.d_model
        self.d_model = d_model
        self.blocks = nn.ModuleList([
            ParallelEnrichmentBlock(
                d_model=d_model, d_in=d_model,
                R=cfg.dnc_read_heads, W=cfg.dnc_cell_size, N=cfg.dnc_nr_cells,
                heads=cfg.attn_heads, dropout=cfg.attn_dropout,
                ffn_mult=cfg.ffn_mult, gate_bias_init=cfg.gate_bias_init
            ) for _ in range(cfg.n_blocks)
        ])
        self.proj_out = nn.Identity()

    def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor]=None):
        with torch.no_grad():
            with spda_ctx():
                out = self.base(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True, use_cache=False)
        h = out.hidden_states[-1]  # (B,T,d_model)
        dnc_states = [None]*len(self.blocks)
        gates = []
        for i, blk in enumerate(self.blocks):
            h, dnc_states[i], g = blk(h, dnc_state=dnc_states[i])
            gates.append(g.detach())
        logits = self.base.lm_head(self.proj_out(h).to(self.base.lm_head.weight.dtype))
        return logits, gates


## 8. Data: Synthetic tasks + simple instruction-following

In [47]:

def make_copy_task(batch, T, vocab=50):
    x = torch.randint(5, vocab, (batch, T))
    return x

def make_reverse_task(batch, T, vocab=50):
    x = torch.randint(5, vocab, (batch, T))
    y = torch.flip(x, dims=[1])
    return x, y

def make_needle_task(batch, T, needle_len=5, vocab=100):
    x = torch.randint(5, vocab, (batch, T))
    for b in range(batch):
        start = random.randint(0, T-needle_len-5)
        needle = torch.randint(5, vocab, (needle_len,))
        x[b, start:start+needle_len] = needle
        x[b, -1] = needle[0]
    return x

INSTR_PAIRS = [
    ("Reverse the string: abcd", "dcba"),
    ("Add two numbers: 7 + 12", "19"),
    ("Instruction: say hello", "hello"),
    ("Uppercase this: cat", "CAT"),
]

def tokenize_instruction_pairs(tok, pairs, max_len):
    texts = [f"### Instruction:\n{p}\n\n### Response:\n" for p,_ in pairs]
    labels = [ans for _,ans in pairs]
    input_ids = tok(texts, return_tensors="pt", padding=True, truncation=True, max_length=max_len).input_ids
    label_ids = tok(labels, return_tensors="pt", padding=True, truncation=True, max_length=max_len).input_ids
    # naive pack: just use inputs; real packing/labels can be elaborated
    return input_ids, label_ids


## 9. Training loop (lightweight)

In [48]:
# --- LR Scheduler: linear warmup -> cosine decay (nonzero start) ---
import math
from torch.optim.lr_scheduler import LambdaLR

def make_warmup_cosine_scheduler(optimizer, warmup_steps: int, total_steps: int, min_lr_ratio: float = 0.10):
    """
    Returns a LambdaLR that:
      - warms up linearly from 0 -> 1 over warmup_steps
      - then cosine decays from 1 -> min_lr_ratio over the remaining steps
    Uses step_idx+1 so the first call after the first optimizer.step() is nonzero.
    """
    warmup_steps = max(1, int(warmup_steps))
    total_steps = max(warmup_steps + 1, int(total_steps))

    def lr_lambda(step_idx: int):
        s = step_idx + 1  # <-- important to avoid zero lr on first step
        if s <= warmup_steps:
            return s / float(warmup_steps)
        progress = (s - warmup_steps) / float(max(1, total_steps - warmup_steps))
        return max(min_lr_ratio, 0.5 * (1.0 + math.cos(math.pi * progress)))

    return LambdaLR(optimizer, lr_lambda)



In [None]:
# --- TensorBoard Logger (no matplotlib) ---
import os, time
from typing import List, Optional
try:
    from torch.utils.tensorboard import SummaryWriter
    TB_AVAILABLE = True
except Exception as e:
    print("TensorBoard not available:", e)
    TB_AVAILABLE = False

class TBLogger:
    def __init__(self, logdir: Optional[str] = None, run_name: Optional[str] = None):
        self.enabled = TB_AVAILABLE
        if not self.enabled:
            self.writer = None
            return
        if logdir is None:
            logdir = "./runs"
        if run_name is None:
            run_name = time.strftime("dncformer-%Y%m%d-%H%M%S")
        self.path = os.path.join(logdir, run_name)
        os.makedirs(self.path, exist_ok=True)
        self.writer = SummaryWriter(self.path)

    def log_scalars(self, step: int, loss: float, lr: float, gate_means: List[float]):
        if not self.enabled: return
        self.writer.add_scalar("train/loss", loss, step)
        self.writer.add_scalar("train/lr", lr, step)
        for i, gm in enumerate(gate_means):
            self.writer.add_scalar(f"gates/block_{i}_mean", gm, step)

    def log_image_hw(self, tag: str, img_hw: "torch.Tensor", step: int):
        """
        img_hw: float tensor [H,W] on CPU; will be normalized to [0,1] and logged as [1,H,W]
        """
        if not self.enabled: return
        import torch
        x = img_hw
        if x.device.type != "cpu":
            x = x.cpu()
        x = x.float()
        if x.numel() > 0:
            m, M = x.min(), x.max()
            if (M - m) > 1e-8:
                x = (x - m) / (M - m)
            else:
                x = torch.zeros_like(x)
        x = x.unsqueeze(0)  # [1,H,W]
        self.writer.add_image(tag, x, step)

    def flush(self):
        if self.enabled and self.writer:
            self.writer.flush()

    def close(self):
        if self.enabled and self.writer:
            self.writer.close()


In [49]:
from torch.amp import GradScaler, autocast

def build_model_and_tokenizer():
    tok, base = load_base_model(CFG.base_model_id)
    if CFG.d_model is None:
        CFG.d_model = base.config.hidden_size
    head = DNCFormerHead(base, CFG).to(device)
    if CFG.use_torch_compile and hasattr(torch, 'compile'):
        head = torch.compile(head)
    print("Trainable params in head:", count_params(head))
    return tok, head

def make_optimizer(model):
    params = [p for p in model.parameters() if p.requires_grad]
    return torch.optim.AdamW(params, lr=CFG.lr, weight_decay=CFG.weight_decay)

def lm_shift_labels(input_ids, logits, tok):
    labels = input_ids[:, 1:].contiguous()
    logits = logits[:, :-1].contiguous()
    return F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1), ignore_index=tok.pad_token_id)

@torch.no_grad()
def evaluate_simple(head, tok):
    head.eval()
    prompt = "### Instruction:\nSay hello in one word\n\n### Response:\n"
    enc = tok(prompt, return_tensors="pt").to(device)
    with autocast('cuda', dtype=amp_dtype, enabled=(amp_dtype!=torch.float32)):
        logits, gates = head(enc.input_ids)
    print("Gates (mean):", [g.mean().item() for g in gates])
    return gates

def train_small():
    tok, head = build_model_and_tokenizer()
    optim = make_optimizer(head)
    scheduler = make_warmup_cosine_scheduler(optim, CFG.warmup_steps, CFG.train_steps)
    head.train()
    use_scaler = (amp_dtype == torch.float16)
    scaler = GradScaler('cuda', enabled=use_scaler)

    for step in range(1, CFG.train_steps+1):
        in_ids, lab_ids = tokenize_instruction_pairs(tok, INSTR_PAIRS, max_len=min(CFG.max_seq_len, 256))
        in_ids = in_ids.to(device)
        
        with autocast('cuda', dtype=amp_dtype, enabled=(amp_dtype!=torch.float32)):
            logits, gates = head(in_ids)
            loss = lm_shift_labels(in_ids, logits, tok)

        if use_scaler:
            scaler.scale(loss).backward()
            torch.nn.utils.clip_grad_norm_(head.parameters(), CFG.grad_clip)
            scaler.step(optim); scaler.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(head.parameters(), CFG.grad_clip)
            optim.step()
            
        optim.zero_grad(set_to_none=True)

        if step % CFG.log_every == 0:
            print(f"step {step} | loss {loss.item():.4f} | lr {optim.param_groups[0]['lr']:.2e} | gates={[g.mean().item() for g in gates]}")

    evaluate_simple(head, tok)
    return head, tok

## 10. Eval harness (needle-in-haystack, copy, reverse)

In [50]:
from torch.amp import autocast

@torch.no_grad()
def eval_copy(head, tok, batch=4, T=64, vocab=100):
    x = make_copy_task(batch, T, vocab=vocab).to(device)
    with autocast('cuda', dtype=amp_dtype, enabled=(amp_dtype!=torch.float32)):
        logits, gates = head(x)
    preds = logits.argmax(dim=-1)
    acc = (preds[:, :-1] == x[:, 1:]).float().mean().item()
    print("copy acc:", acc, "gates:", [g.mean().item() for g in gates])
    return {"acc": acc, "gates": [g.mean().item() for g in gates]}

@torch.no_grad()
def eval_reverse(head, tok, batch=4, T=64, vocab=100):
    x, y = make_reverse_task(batch, T, vocab=vocab)
    x, y = x.to(device), y.to(device)
    with autocast('cuda', dtype=amp_dtype, enabled=(amp_dtype!=torch.float32)):
        logits, gates = head(x)
    preds = logits.argmax(dim=-1)
    acc = (preds[:, :-1] == y[:, 1:]).float().mean().item()
    print("reverse acc:", acc, "gates:", [g.mean().item() for g in gates])
    return {"acc": acc, "gates": [g.mean().item() for g in gates]}

@torch.no_grad()
def eval_needle(head, tok, batch=4, T=128, vocab=200):
    x = make_needle_task(batch, T, vocab=vocab).to(device)
    with autocast('cuda', dtype=amp_dtype, enabled=(amp_dtype!=torch.float32)):
        logits, gates = head(x)
    preds = logits.argmax(dim=-1)
    acc = (preds[:, -1] == x[:, -1]).float().mean().item()
    print("needle acc:", acc, "gates:", [g.mean().item() for g in gates])
    return {"acc": acc, "gates": [g.mean().item() for g in gates]}


In [51]:
# --- Evaluator unit tests (no heavy model) ---
import torch
from typing import Optional

class _DummyHead(torch.nn.Module):
    def __init__(self, vocab=100, d_model=64, n_blocks=2):
        super().__init__()
        self.vocab = vocab
        self.d_model = d_model
        self.n_blocks = n_blocks
    def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor]=None):
        B, T = input_ids.shape
        logits = torch.randn(B, T, self.vocab, device=input_ids.device, dtype=torch.float32)
        gates = [torch.sigmoid(torch.randn(B, T, self.d_model, device=input_ids.device)) for _ in range(self.n_blocks)]
        return logits, gates

def run_eval_unit_tests():
    dummy = _DummyHead().to(device).eval()
    res_copy = eval_copy(dummy, tok=None, batch=2, T=8, vocab=50)
    res_rev = eval_reverse(dummy, tok=None, batch=2, T=8, vocab=50)
    res_needle = eval_needle(dummy, tok=None, batch=2, T=16, vocab=100)
    assert all(k in res_copy for k in ["acc","gates"])
    assert all(k in res_rev for k in ["acc","gates"])
    assert all(k in res_needle for k in ["acc","gates"])
    print("Evaluator unit tests passed.")


## 11. Unit-like tests (sanity)

In [52]:

def run_unit_tests():
    B, T = 2, 4
    R, W, N = CFG.dnc_read_heads, CFG.dnc_cell_size, CFG.dnc_nr_cells
    d_in = d_model = 128

    mem = DNCMemory(N, W, R).to(device)
    state = mem.reset(B, device=device)
    iface = {
        "k_read": torch.zeros(B,R,W, device=device),
        "beta_read": torch.ones(B,R,1, device=device),
        "k_write": torch.zeros(B,W, device=device),
        "beta_write": torch.ones(B,1, device=device),
        "erase": torch.sigmoid(torch.randn(B,W, device=device)),
        "write_vec": torch.randn(B,W, device=device),
        "free_gates": torch.sigmoid(torch.randn(B,R,1, device=device)),
        "alloc_gate": torch.sigmoid(torch.randn(B,1, device=device)),
        "write_gate": torch.sigmoid(torch.randn(B,1, device=device)),
        "read_mode": torch.randn(B,R,3, device=device),
    }
    r, state2 = mem(iface, state)
    assert r.shape == (B,R,W)

    ctrl = TransformerController(d_in+R*W, d_model, heads=4).to(device)
    x = torch.randn(B,T,d_in, device=device)
    reads = state2["r"].reshape(B,R*W).unsqueeze(1).expand(B,T,R*W)
    h = ctrl(torch.cat([x, reads], dim=-1))
    assert h.shape == (B,T,d_model)

    dblk = DNCformerBlock(d_in, d_model, R, W, N, heads=4, dropout=0.1, ffn_mult=4.0).to(device)
    y, s = dblk(x)
    assert y.shape == (B,T,d_model)

    pen = ParallelEnrichmentBlock(d_model, d_in, R, W, N, heads=4, dropout=0.1, ffn_mult=4.0, gate_bias_init=-1.0).to(device)
    out, s2, g = pen(torch.randn(B,T,d_model, device=device))
    print(f"out: {out}\ns2: {s2}\ng: {g}")
    eps = 1e-6
    assert torch.isfinite(g).all()
    assert ((g > eps) & (g < 1 - eps)).float().mean().item() > 0.95
    assert out.shape == (B,T,d_model)
    print("All unit-like tests passed.")


## Unit tests/basic sanity checks
currently all passing, but self-reminder here to comment out if architecture changes

In [53]:
# run_unit_tests()ap

In [54]:
run_eval_unit_tests()

copy acc: 0.0 gates: [0.495686799287796, 0.508457362651825]
reverse acc: 0.0 gates: [0.5013509392738342, 0.5040897130966187]
needle acc: 0.0 gates: [0.49810758233070374, 0.4982784688472748]
Evaluator unit tests passed.


## Training run Sanity test

In [None]:
head, tok = train_small()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Trainable params in head: 4353384090
step 10 | loss 4.0979 | lr 0.00e+00 | gates=[0.283203125, 0.28515625]
step 20 | loss 4.0489 | lr 0.00e+00 | gates=[0.283203125, 0.283203125]
step 30 | loss 4.0508 | lr 0.00e+00 | gates=[0.283203125, 0.283203125]
step 40 | loss 3.9879 | lr 0.00e+00 | gates=[0.283203125, 0.28515625]
step 50 | loss 4.0820 | lr 0.00e+00 | gates=[0.283203125, 0.28515625]
step 60 | loss 4.0673 | lr 0.00e+00 | gates=[0.283203125, 0.283203125]
step 70 | loss 4.0587 | lr 0.00e+00 | gates=[0.283203125, 0.28515625]
step 80 | loss 3.9646 | lr 0.00e+00 | gates=[0.283203125, 0.283203125]
step 90 | loss 4.0439 | lr 0.00e+00 | gates=[0.283203125, 0.283203125]
step 100 | loss 4.0179 | lr 0.00e+00 | gates=[0.283203125, 0.28515625]
step 110 | loss 4.0453 | lr 0.00e+00 | gates=[0.283203125, 0.28515625]
step 120 | loss 4.1005 | lr 0.00e+00 | gates=[0.283203125, 0.28515625]
step 130 | loss 4.1182 | lr 0.00e+00 | gates=[0.283203125, 0.283203125]
step 140 | loss 3.9805 | lr 0.00e+00 | gate

## 12. Notes & TODO
- The DNC memory here is **compact** and intended for research iteration; you can swap in a fuller reference implementation if desired. [x]
- The controller currently runs in **sequence mode** with a causal mask. A step-wise cached mode can be added for streaming scenarios.
- Training uses a tiny **instruction-following set** plus synthetic memory tasks. You can plug in larger corpora or evaluation suites later.
- Gating is **vector-valued** with bias init favoring the vanilla path; metrics log mean gate values.
- Use `CFG.n_blocks` to grow the enrichment depth as VRAM allows.
