
# DNCformer: Parallel-Enrichment Transformer–DNC (Notebook Prototype)

This notebook implements a **parallel enrichment** architecture that adds a **Transformer-style DNC block** alongside a standard Transformer block, with a learned **gating** to mix their outputs. It wires on top of a **frozen ~4B LLM** (Phi-3-mini-4k-instruct by default), and provides **lightweight train/eval** loops and **unit-like tests**.

**Hardware:** designed for a single GPU (e.g., RTX 3090 24GB) using AMP (`bf16` if available, otherwise `fp16`).  
**Structure:** Config → Utils → DNC Memory → Transformer Controller → DNCformer Block → Parallel Enrichment → Frozen Base + N Blocks → Data → Train → Eval → Tests.


## Imports

In [54]:
import sys, platform, torch
from dataclasses import dataclass
from typing import Optional, List, Tuple, Dict, Any
import os, math, random, time, numpy as np
from torch import nn, Tensor
from torch.nn import functional as F
import contextlib
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.optim.lr_scheduler import LambdaLR
from typing import List, Optional


print("python exe:", sys.executable)
print("torch file:", torch.__file__)
print("torch ver:", torch.__version__, "| torch.version.cuda:", torch.version.cuda)
print("cuda available:", torch.cuda.is_available(), "| device count:", torch.cuda.device_count())

# SDPA configured via SDPA_CTX() above


python exe: C:\Users\andre\miniconda3\envs\dncformer\python.exe
torch file: C:\Users\andre\miniconda3\envs\dncformer\lib\site-packages\torch\__init__.py
torch ver: 2.5.1 | torch.version.cuda: 12.6
cuda available: True | device count: 1


## 1. Config & Environment

In [55]:
# --- Configuration ---
@dataclass
class Config:
    base_model_id: str = "microsoft/Phi-3-mini-4k-instruct"  # ~3.8B
    d_model: Optional[int] = None          # If None, infer from base model hidden size
    n_blocks: int = 2                      # number of parallel enrichment blocks
    attn_heads: int = 8                    # heads in DNC controller
    attn_dropout: float = 0.1
    ffn_mult: float = 4.0
    dnc_read_heads: int = 2
    dnc_cell_size: int = 64                # memory slot width
    dnc_nr_cells: int = 256                # number of memory slots
    gate_bias_init: float = -1.0           # bias to prefer transformer at init
    lr: float = 2e-4                       
    weight_decay: float = 0.01
    max_seq_len: int = 1024                # training seq length
    train_steps: int = 200                 # small sanity pass
    warmup_steps: int = 20
    grad_clip: float = 1.0
    precision: str = "bf16"                # "bf16" | "fp16" | "fp32"
    use_torch_compile: bool = False
    device: str = "cuda"
    log_every: int = 10
    batch_size: int = 8

CFG = Config()

seed = 42
random.seed(seed); torch.manual_seed(seed)

device = torch.device(CFG.device if torch.cuda.is_available() else "cpu")
print("Device:", device, "CUDA:", torch.cuda.is_available())

amp_dtype = None
if CFG.precision == "bf16" and torch.cuda.is_available() and torch.cuda.is_bf16_supported():
    amp_dtype = torch.bfloat16
elif CFG.precision == "fp16" and torch.cuda.is_available():
    amp_dtype = torch.float16
else:
    amp_dtype = torch.float32

print("AMP dtype:", amp_dtype)

Device: cuda CUDA: True
AMP dtype: torch.bfloat16


In [56]:
# --- Config patch toggles ---
try:
    CFG
except NameError:
    class _Tmp: pass
    CFG = _Tmp()
# safe defaults if not already present
if not hasattr(CFG, 'batch_size'): CFG.batch_size = 8
if not hasattr(CFG, 'gate_reg_lambda'): CFG.gate_reg_lambda = 0.0   # only applied on memory-tagged batches
if not hasattr(CFG, 'hist_every'): CFG.hist_every = 200             # histogram cadence
if not hasattr(CFG, 'force_g'): CFG.force_g = None                  # None, or 0.0 or 1.0
if not hasattr(CFG, 'gate_temp'): CFG.gate_temp = 1.0


In [57]:
# --- SDPA selection (prefer PyTorch SDPA; avoid flash-attn) ---
def sdpa_ctx():
    """Return a fresh attention-kernel selection context each time it's called.
    Uses PyTorch SDPA (math + mem-efficient) and disables flash-attn to avoid warnings.
    """
    try:
        from torch.backends.cuda import sdp_kernel  # callable context manager in PyTorch 2.x
        return sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=True)
    except Exception:
        return contextlib.nullcontext()


## 2. Utilities

In [58]:
_GLOBAL = {}

def causal_mask(sz: int, device=None):
    return torch.full((sz, sz), float("-inf"), device=device).triu(1)

def count_params(model: nn.Module):
    return sum(p.numel() for p in model.parameters())

def requires_grad_(module: nn.Module, flag: bool):
    for p in module.parameters():
        p.requires_grad_(flag)
    return module

def print_shapes(**tensors):
    for k, v in tensors.items():
        if isinstance(v, (list, tuple)):
            print(k, [tuple(x.shape) for x in v])
        elif isinstance(v, torch.Tensor):
            print(k, tuple(v.shape))
        else:
            print(k, type(v))

def get_or_build_model_and_tokenizer():
    mid = CFG.base_model_id
    if _GLOBAL.get("head") is not None and _GLOBAL.get("mid") == mid:
        # Reuse the already-loaded GPU model
        return _GLOBAL["tok"], _GLOBAL["head"]
    # Otherwise build fresh
    tok, base = load_base_model(mid)             # your existing GPU-only loader
    head = DNCFormerHead(base, CFG).to(device)   # enrichment head on CUDA
    _GLOBAL.update({"tok": tok, "head": head, "mid": mid})
    return tok, head

In [59]:
# --- Gate metrics utils (mean, frac>0.5, entropy) ---
def _reduce_gate_tensor(g: torch.Tensor) -> torch.Tensor:
    # g: (B,T,*) -> (B,T)
    if g.dim() == 3:
        return g.mean(dim=-1)
    return g

@torch.no_grad()
def _gate_metrics(g: torch.Tensor):
    """
    Returns scalar (mean, frac>0.5, entropy)
    """
    g2 = _reduce_gate_tensor(g.detach())
    mean_val = float(g2.mean().item())
    frac = float((g2 > 0.5).float().mean().item())
    eps = 1e-6
    p = g2.clamp(eps, 1 - eps)
    # binary entropy (natural log base)
    ent = float((-(p * (p + eps).log() + (1 - p) * (1 - p + eps).log())).mean().item())
    return mean_val, frac, ent

## 3. DNC Memory (compact reference implementation)

In [60]:
class DNCMemory(nn.Module):
    """
    Compact DNC memory:
    - memory M: (B, N, W)
    - usage u: (B, N)
    - link L: (B, N, N) temporal links
    - precedence p: (B, N)
    - read weights rw: (B, R, N)
    - write weights ww: (B, N)
    - read vectors r: (B, R, W)
    """
    def __init__(self, nr_cells: int, cell_size: int, read_heads: int):
        super().__init__()
        self.probe = None  # optional callable to record per-step state
        self.N = nr_cells
        self.W = cell_size
        self.R = read_heads

    def reset(self, B: int, device=None):
        device = device or next(self.parameters(), torch.empty(0, device="cpu")).device
        M = torch.zeros(B, self.N, self.W, device=device)
        u = torch.zeros(B, self.N, device=device)
        L = torch.zeros(B, self.N, self.N, device=device)
        p = torch.zeros(B, self.N, device=device)
        rw = F.one_hot(torch.zeros(B, self.R, dtype=torch.long, device=device), num_classes=self.N).float()
        r = torch.zeros(B, self.R, self.W, device=device)
        return {"M": M, "u": u, "L": L, "p": p, "rw": rw, "r": r}

    @staticmethod
    def _cosine_sim(M: torch.Tensor, k: torch.Tensor, eps=1e-6):
        # M: (B, N, W), k: (B, W) or (B, R, W)
        if k.dim() == 2:
            k = k.unsqueeze(1)  # (B,1,W)
        B, R, W = k.shape
        Mnorm = F.normalize(M, p=2, dim=-1)
        knorm = F.normalize(k, p=2, dim=-1)
        sim = torch.einsum("bnw,brw->brn", Mnorm, knorm)  # (B, R, N)
        return sim

    def _allocation(self, u: torch.Tensor):
        # u: (B, N) in [0,1]
        δ = 1e-6
        u = δ + (1 - δ) * u                 # avoid tiny values before cumprod
        B, N = u.shape
        # sort ascending usage -> free list φ
        sorted_u, phi = torch.sort(u, dim=-1, descending=False)  # (B,N)
        # exclusive cumprod of sorted_u
        ones = torch.ones(B, 1, device=u.device, dtype=u.dtype)
        prod_excl = torch.cumprod(torch.cat([ones, sorted_u], dim=1), dim=1)[:, :-1]  # (B,N)
        a_sorted = (1 - sorted_u) * prod_excl                                        # (B,N)
        # invert the sort to original order
        inv_phi = torch.argsort(phi, dim=-1)
        a = a_sorted.gather(1, inv_phi)                                              # (B,N)
        return a


    def forward(self, x_if: dict, state: dict):
        """
        x_if (interface dict) must contain:
        - k_read: (B,R,W), beta_read: (B,R,1)
        - k_write: (B,W), beta_write:(B,1)
        - erase: (B,W) in (0,1), write_vec: (B,W)
        - free_gates: (B,R,1) in (0,1), alloc_gate:(B,1), write_gate:(B,1) in (0,1)
        - read_mode: (B,R,3) softmax over {backward, content, forward}
        """
        M, u, L, p, rw, r = state["M"], state["u"], state["L"], state["p"], state["rw"], state["r"]
        B, N, W = M.shape

        # --- Usage update (faithful, stable) ---
        # previous write weights; keep as (B,1,N) for easy broadcasting
        ww_prev = state.get("ww", torch.zeros(M.size(0), 1, self.N, device=M.device, dtype=M.dtype))
        # writes increase usage
        u = u + (1 - u) * (1 - torch.prod(1 - ww_prev, dim=1))   # -> (B,N)
        # free gates release usage at read locations (per-location retention)
        psi = torch.prod(1 - x_if["free_gates"] * rw, dim=1)     # (B,N), since free_gates:(B,R,1), rw:(B,R,N)
        u = torch.clamp(u * psi, 0, 1)


        # 2.1) Write weighting (robust broadcasting) ---
        # sim_w: (B,1,N) if k_write=(B,W); (B,R,N) if k_write=(B,R,W)
        sim_w = self._cosine_sim(M, x_if["k_write"])
        # beta_w: expect (B,1) or (B,R,1); make sure it has the trailing head axis
        beta_w = x_if["beta_write"]
        if beta_w.dim() == 2:           # (B,1) or (B,R) -> add trailing axis
            beta_w = beta_w.unsqueeze(-1)  # -> (B,1,1) or (B,R,1)
        # content weights over memory locations
        cw = F.softmax(sim_w * beta_w, dim=-1)  # (B,1,N) or (B,R,N)
        # canonical DNC: single write head; if multiple heads exist, reduce over heads
        if cw.size(1) > 1:
            cw = cw.mean(dim=1)                # -> (B,N)  (alternatives: sum or a learned reduce)
        else:
            cw = cw.squeeze(1)                 # -> (B,N)
        # allocation weights from usage (B,N)
        a = self._allocation(u)                # (B,N)
        # interpolate content vs allocation via alloc_gate, then apply write_gate
        alloc = x_if["alloc_gate"]             # (B,1)
        write_gate = x_if["write_gate"]        # (B,1)
        # Broadcast (B,1) over N
        ww = write_gate * (alloc * a + (1.0 - alloc) * cw)  # -> (B,N) via broadcasting
        state["ww"] = ww
        
        # 2.2) save current ww as "previous write weights" for t+1
        state["ww_prev"] = ww.unsqueeze(1)   # keep grads for BPTT; use .detach() only if you explicitly want to stop gradients across steps

        # 3) Memory write
        erase = x_if["erase"].unsqueeze(1)  # (B,1,W)
        write_vec = x_if["write_vec"].unsqueeze(1)  # (B,1,W)
        M = M * (1 - ww.unsqueeze(-1) * erase) + ww.unsqueeze(-1) * write_vec

        # 4) Temporal link
        prev_p = p
        p = (1 - ww.sum(dim=-1, keepdim=True)) * p + ww  # precedence
        L = (1 - ww.unsqueeze(2) - ww.unsqueeze(1)) * L + torch.einsum("bn,bm->bnm", prev_p, ww)
        L = L * (1 - torch.eye(N, device=M.device).unsqueeze(0))

        # 5) Read weighting
        cr = F.softmax(self._cosine_sim(M, x_if["k_read"]) * x_if["beta_read"], dim=-1)  # (B,R,N)
        fwd = torch.einsum("brn,bnm->brm", rw, L)       # (B,R,N) forward
        bwd = torch.einsum("brn,bmn->brm", rw, L)       # (B,R,N) backward
        read_mode = F.softmax(x_if["read_mode"], dim=-1)  # (B,R,3)
        rw = read_mode[:,:,0:1]*bwd + read_mode[:,:,1:2]*cr + read_mode[:,:,2:3]*fwd
        r = torch.einsum("brn,bnw->brw", rw, M)  # (B,R,W)

        state = {"M": M, "u": u, "L": L, "p": p, "rw": rw, "r": r}
        # aggregated lightweight stats (per-step)
        try:
            _stats = {}
            with torch.no_grad():
                _stats["u_mean"] = state["u"].mean().detach()
                try:
                    _stats["M_norm_mean"] = state["M"].norm(dim=-1).mean().detach()
                except Exception:
                    pass
                try:
                    _stats["rw_max_mean"] = rw.max(dim=-1).values.mean().detach()
                except Exception:
                    pass
                try:
                    _stats["ww_max_mean"] = ww.max(dim=-1).values.mean().detach()
                except Exception:
                    pass
            state["stats"] = _stats
        except Exception:
            pass
        
        if self.probe is not None:
            M = state["M"].detach().float().cpu()
            u = state["u"].detach().float().cpu()
            L = state["L"].detach().float().cpu()
            rw_cpu = rw.detach().float().cpu()
            ww_cpu = state.get("ww", torch.zeros_like(state["u"])).detach().float().cpu()
            self.probe(
                {
                "u": u,                         # (B,N)
                "ww": ww_cpu,                   # (B,N)
                "rw": rw_cpu,                   # (B,R,N)
                "M_norm": M.norm(dim=-1),       # (B,N)
                "L_diag_mean": torch.diagonal(L, dim1=-2, dim2=-1).mean(dim=-1), # (B,)
                }
            )
        return r, state


## 4. Transformer-style Controller (sequence mode)

In [61]:

class TransformerController(nn.Module):
    """
    Lightweight Transformer encoder producing the DNC interface vector.
    Inputs: X (B, T, d_in); prev_reads typically concatenated to X before calling.
    """
    def __init__(self, d_in: int, d_model: int, heads: int, dropout: float=0.1, ffn_mult: float=4.0):
        super().__init__()
        self.proj_in = nn.Linear(d_in, d_model)
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, heads, dropout=dropout, batch_first=True)
        self.ln2 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, int(d_model*ffn_mult)),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(int(d_model*ffn_mult), d_model),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor]=None) -> torch.Tensor:
        h = self.proj_in(x)
        h = self.ln1(h)
        attn_out, _ = self.attn(h, h, h, attn_mask=attn_mask, need_weights=False)
        h = h + self.dropout(attn_out)
        h = self.ln2(h)
        h2 = self.ff(h)
        h = h + self.dropout(h2)
        return h


## 5. DNCformer Block (controller → interface → memory)

In [62]:

class DNCInterfaceHead(nn.Module):
    """
    Projects controller hidden to DNC interface:
    read keys (R*W), read strengths (R), write key (W), write strength (1),
    erase (W in (0,1)), write vector (W), free_gates (R in (0,1)),
    alloc_gate (1), write_gate (1), read_mode (R*3 softmax).
    """
    def __init__(self, d_model: int, R: int, W: int):
        super().__init__()
        self.R, self.W = R, W
        out = R*W + R + W + 1 + W + W + R + 1 + 1 + R*3
        self.proj = nn.Linear(d_model, out)

    def forward(self, h: torch.Tensor):
        B, T, D = h.shape
        v = self.proj(h)  # (B,T,out)
        idx = 0
        def take(sz): 
            nonlocal idx
            part = v[..., idx:idx+sz]; idx += sz; return part
        R, W = self.R, self.W
        k_read = take(R*W).view(B,T,R,W)
        beta_read = F.softplus(take(R)).view(B,T,R,1)
        k_write = take(W).view(B,T,W)
        beta_write = F.softplus(take(1)).view(B,T,1)
        erase = torch.sigmoid(take(W)).view(B,T,W)
        write_vec = take(W).view(B,T,W)
        free_gates = torch.sigmoid(take(R)).view(B,T,R,1)
        alloc_gate = torch.sigmoid(take(1)).view(B,T,1)
        write_gate = torch.sigmoid(take(1)).view(B,T,1)
        read_mode = take(R*3).view(B,T,R,3)
        return {
            "k_read": k_read, "beta_read": beta_read,
            "k_write": k_write, "beta_write": beta_write,
            "erase": erase, "write_vec": write_vec,
            "free_gates": free_gates, "alloc_gate": alloc_gate,
            "write_gate": write_gate, "read_mode": read_mode
        }

class DNCformerBlock(nn.Module):
    def __init__(self, d_in: int, d_model: int, R: int, W: int, N: int, heads: int, dropout: float, ffn_mult: float):
        super().__init__()
        self.R, self.W, self.N = R, W, N
        self.ctrl = TransformerController(d_in + R*W, d_model, heads=heads, dropout=dropout, ffn_mult=ffn_mult)
        self.if_head = DNCInterfaceHead(d_model, R=R, W=W)
        self.mem = DNCMemory(nr_cells=N, cell_size=W, read_heads=R)
        self.out_proj = nn.Linear(d_model + R*W, d_model)  # fuse controller + reads

    def forward(self, x: torch.Tensor, state: Optional[dict]=None):
        # x: (B,T,d_in); state carries memory fields; if None -> reset
        B, T, D = x.shape
        if state is None:
            state = self.mem.reset(B, device=x.device)
        reads = state["r"].reshape(B, self.R*self.W)  # (B,RW)
        reads_seq = reads.unsqueeze(1).expand(B, T, self.R*self.W)
        ctrl_in = torch.cat([x, reads_seq], dim=-1)  # concat
        h = self.ctrl(ctrl_in, attn_mask=causal_mask(T, device=x.device))  # (B,T,d_model)

        # step over time for memory I/O
        r_list = []
        new_state = state
        iface = self.if_head(h)
        for t in range(T):
            x_if = {k: v[:,t] for k,v in iface.items()}
            r_t, new_state = self.mem(x_if, new_state)
            r_list.append(r_t)
        Rseq = torch.stack(r_list, dim=1)  # (B,T,R,W)
        reads_flat = Rseq.reshape(B,T,self.R*self.W)
        fused = torch.cat([h, reads_flat], dim=-1)
        y = self.out_proj(fused)  # (B,T,d_model)
        return y, new_state


## 6. Parallel Enrichment Block (Transformer path ‖ DNCformer path + gating)

In [63]:

class VanillaTransformerBlock(nn.Module):
    def __init__(self, d_model: int, heads: int, dropout: float=0.1, ffn_mult: float=4.0):
        super().__init__()
        self.collect_metrics = False
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, heads, dropout=dropout, batch_first=True)
        self.ln2 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, int(d_model*ffn_mult)),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(int(d_model*ffn_mult), d_model),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor]=None, gate_override: None = None):
        h = self.ln1(x)
        a, _ = self.attn(h, h, h, attn_mask=attn_mask, need_weights=False)
        h = x + self.dropout(a)
        z = self.ln2(h)
        z2 = self.ff(z)
        return h + self.dropout(z2)

class ParallelEnrichmentBlock(nn.Module):
    def __init__(self, d_model: int, d_in: int, R: int, W: int, N: int, heads: int, dropout: float, ffn_mult: float, gate_bias_init: float):
        super().__init__()
        self.collect_metrics = False
        self.vanilla = VanillaTransformerBlock(d_model, heads, dropout, ffn_mult)
        self.dncblock = DNCformerBlock(d_in=d_in, d_model=d_model, R=R, W=W, N=N, heads=heads, dropout=dropout, ffn_mult=ffn_mult)
        self.pre_gate_ln = nn.LayerNorm(2*d_model)
        self.gate = nn.Linear(2*d_model, d_model)
        nn.init.constant_(self.gate.bias, gate_bias_init)

    def forward(self, x: torch.Tensor, dnc_state=None, gate_override: None = None):
        # Both branches consume the same x (B,T,d_model) and produce (B,T,d_model)
        T = x.size(1)
        mask = causal_mask(T, device=x.device)
        vt = self.vanilla(x, attn_mask=mask)
        dt, dnc_state = self.dncblock(x, state=dnc_state)
        z = torch.cat([vt, dt], dim=-1)
        g = torch.sigmoid(self.gate(self.pre_gate_ln(z)))
        out = g*dt + (1-g)*vt
        
        # collect per-block metrics if requested
        if self.collect_metrics:
            try:
                import math, torch as _t
                eps = 1e-6
                g_clamp = g.clamp(min=eps, max=1-eps)
                g_entropy = (-(g_clamp*_t.log(g_clamp) + (1-g_clamp)*_t.log(1-g_clamp))).mean()
                vt_norm = vt.norm(dim=-1).mean()
                dt_norm = dt.norm(dim=-1).mean()
                _stats = dnc_state.get("stats", {}) if isinstance(dnc_state, dict) else {}
                # ensure CPU-detached tiny tensors
                self.last_metrics = {
                    "g_mean": g.mean().detach(),
                    "g_entropy": g_entropy.detach(),
                    "vt_norm": vt_norm.detach(),
                    "dt_norm": dt_norm.detach(),
                    **_stats
                }
            except Exception:
                self.last_metrics = None
        else:
            self.last_metrics = None
            return out, dnc_state, g 


## 7. Frozen Base LLM + N Enrichment Blocks

In [64]:
def spda_ctx():
    try:
        from torch.backends.cuda import sdp_kernel
        return sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=True)
    except Exception as e:
        return f"error {e} encountered ->\n{contextlib.nullcontext()}"

def load_base_model(model_id: str):
    tok = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=(torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16) if amp_dtype!=torch.float32 else None,
        device_map=None,
        trust_remote_code=True,
        attn_implementation="sdpa",
    ).to("cuda")
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    requires_grad_(model, False)
    model.config.output_hidden_states = True
    model.config.use_cache = False
    return tok, model

class DNCFormerHead(nn.Module):
    def __init__(self, base: AutoModelForCausalLM, cfg):
        super().__init__()
        self.base = base
        d_model = base.config.hidden_size if cfg.d_model is None else cfg.d_model
        self.d_model = d_model
        self.blocks = nn.ModuleList([
            ParallelEnrichmentBlock(
                d_model=d_model, d_in=d_model,
                R=cfg.dnc_read_heads, W=cfg.dnc_cell_size, N=cfg.dnc_nr_cells,
                heads=cfg.attn_heads, dropout=cfg.attn_dropout,
                ffn_mult=cfg.ffn_mult, gate_bias_init=cfg.gate_bias_init
            ) for _ in range(cfg.n_blocks)
        ])
        self.proj_out = nn.Identity()

    def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor]=None):
        with torch.no_grad():
            with spda_ctx():
                out = self.base(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True, use_cache=False)
        h = out.hidden_states[-1]  # (B,T,d_model)
        dnc_states = [None]*len(self.blocks)
        gates = []
        for i, blk in enumerate(self.blocks):
            h, dnc_states[i], g = blk(h, dnc_state=dnc_states[i])
            gates.append(g.detach())
        logits = self.base.lm_head(self.proj_out(h).to(self.base.lm_head.weight.dtype))
        return logits, gates



    def forward_with_metrics(self, input_ids: torch.Tensor, attention_mask: "Optional[torch.Tensor]" = None,
                             gate_override: "Optional[float]" = None):
        with torch.no_grad():
            with sdpa_ctx():
                out = self.base(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True, use_cache=False)
        h = out.hidden_states[-1].to(device)  # ensure CUDA
        dnc_states = [None]*len(self.blocks)
        gates_det = []
        gates_raw = []
        per_block = []
        # enable metrics collection per block
        for blk in self.blocks:
            if hasattr(blk, "collect_metrics"):
                blk.collect_metrics = True
        for i, blk in enumerate(self.blocks):
            h, dnc_states[i], g = blk(h, dnc_state=dnc_states[i], gate_override=gate_override)
            gates_raw.append(g)
            gates_det.append(g.detach())
            if hasattr(blk, "last_metrics") and blk.last_metrics is not None:
                per_block.append(blk.last_metrics)
            else:
                per_block.append({})
            # reset flag to avoid overhead elsewhere
            if hasattr(blk, "collect_metrics"):
                blk.collect_metrics = False
        # lm head on its own device, then back to CUDA
        lm_dev = self.base.lm_head.weight.device
        y = self.proj_out(h).to(lm_dev, dtype=self.base.lm_head.weight.dtype)
        logits = self.base.lm_head(y).to(device)
        aux = {"per_block": per_block, "gates_raw": gates_raw, "gates_detached": gates_det}
        aux["g_entropy_block"] = []
        for g in gates_det:
            _, _, ent = _gate_metrics(g)
            aux["g_entropy_block"].append(ent)
        return logits, gates_det, aux


## 8. Data: Synthetic tasks + simple instruction-following

In [65]:

def make_copy_task(batch, T, vocab=50):
    x = torch.randint(5, vocab, (batch, T))
    return x

def make_reverse_task(batch, T, vocab=50):
    x = torch.randint(5, vocab, (batch, T))
    y = torch.flip(x, dims=[1])
    return x, y

def make_needle_task(batch, T, needle_len=5, vocab=100):
    x = torch.randint(5, vocab, (batch, T))
    for b in range(batch):
        start = random.randint(0, T-needle_len-5)
        needle = torch.randint(5, vocab, (needle_len,))
        x[b, start:start+needle_len] = needle
        x[b, -1] = needle[0]
    return x

INSTR_PAIRS = [
    ("Reverse the string: abcd", "dcba"),
    ("Add two numbers: 7 + 12", "19"),
    ("Instruction: say hello", "hello"),
    ("Uppercase this: cat", "CAT"),
]

def tokenize_instruction_pairs(tok, pairs, max_len):
    texts = [f"### Instruction:\n{p}\n\n### Response:\n" for p,_ in pairs]
    labels = [ans for _,ans in pairs]
    input_ids = tok(texts, return_tensors="pt", padding=True, truncation=True, max_length=max_len).input_ids
    label_ids = tok(labels, return_tensors="pt", padding=True, truncation=True, max_length=max_len).input_ids
    # naive pack: just use inputs; real packing/labels can be elaborated
    return input_ids, label_ids


## 9. Training loop (lightweight)

In [66]:
# --- Synthetic generators + HF dataset integration + MixtureSampler (with tagging) ---

def make_repeat_copy(batch: int, T: int, repeat_min=2, repeat_max=4, vocab=100, pad_id: int = 0, device: str = "cpu") -> torch.Tensor:
    L = max(1, T // 2)
    x = torch.randint(1, vocab, (batch, L), device=device, dtype=torch.long)
    r = torch.randint(repeat_min, repeat_max + 1, (batch,), device=device)
    out = torch.full((batch, T), pad_id, dtype=torch.long, device=device)
    for i in range(batch):
        seq = x[i].repeat_interleave(int(r[i].item()))
        out[i, :min(T, seq.numel())] = seq[:T]
    return out

def make_n_back(batch: int, T: int, n: int = 3, vocab=50) -> torch.Tensor:
    return torch.randint(1, vocab, (batch, T))

def format_instruction(tok, instr: str, resp: str, max_len=256) -> torch.Tensor:
    prompt = f"### Instruction:\n{instr}\n\n### Response:\n{resp}"
    return tok(prompt, return_tensors="pt", truncation=True, max_length=max_len).input_ids[0]

def hf_instruction_loader(dataset_name="tatsu-lab/alpaca", split="train", text_field=("instruction","output"), max_items=5000):
    try:
        from datasets import load_dataset
    except Exception:
        print("Install 'datasets' to enable HF loading: pip install datasets -q")
        return []
    ds = load_dataset(dataset_name, split=split)
    pairs = []
    i_field, o_field = text_field
    for ex in ds:
        instr = ex.get(i_field, ""); out = ex.get(o_field, "")
        if instr and out: pairs.append((instr, out))
        if len(pairs) >= max_items: break
    random.shuffle(pairs); return pairs

def make_hf_batch(tok, pairs: List[Tuple[str,str]], batch: int, max_len=256) -> torch.Tensor:
    if not pairs:
        return torch.full((batch, max_len), tok.pad_token_id, dtype=torch.long)
    batch_ids = []
    for _ in range(batch):
        instr, out = random.choice(pairs)
        ids = format_instruction(tok, instr, out, max_len=max_len)
        batch_ids.append(ids)
    maxL = min(max(x.size(0) for x in batch_ids), max_len)
    out_ids = torch.full((batch, maxL), tok.pad_token_id, dtype=torch.long)
    for i, ids in enumerate(batch_ids):
        ids = ids[:maxL]; out_ids[i, :ids.size(0)] = ids
    return out_ids

class MixtureSampler:
    def __init__(self, gens: List, weights: List[float], names: Optional[List[str]] = None):
        self.gens = gens
        import torch as _t
        self.weights = list(map(float, weights))
        self.p = _t.tensor(self.weights, dtype=_t.float32)  # CPU is fine for multinomial
        self.p /= (self.p.sum() + 1e-8)
        self.names = names if names is not None else [f"g{i}" for i in range(len(gens))]
        self.last_name = None

    def __call__(self, batch: int) -> torch.Tensor:
        import torch as _t
        idx = _t.multinomial(self.p, 1).item()
        self.last_name = self.names[idx]
        return self.gens[idx](batch)
    
    def set_weights(self, weights):
        """Update mixture probabilities at runtime (used by schedules)."""
        import torch as _t
        ws = list(map(float, weights))
        t = _t.tensor(ws, dtype=_t.float32, device=self.p.device)
        s = float(t.sum().item())
        if s <= 0:
            raise ValueError("Mixture weights must sum to > 0")
        self.p = t / s
        self.weights = ws
        # Optional sanity: warn if names length mismatches weights
        if hasattr(self, "names") and len(self.names) != len(ws):
            print(f"[MixtureSampler] Warning: len(names)={len(self.names)} != len(weights)={len(ws)}")


In [67]:
# mixture sampler smoke test
ms = MixtureSampler(gens=[lambda b: None, lambda b: None, lambda b: None, lambda b: None],
                    weights=[0.4,0.2,0.2,0.2],
                    names=["hf","copy","repeat","nback"])
print("p0:", ms.p.tolist())    # ~[0.4,0.2,0.2,0.2]
ms.set_weights([0.3,0.3,0.25,0.15])
print("p1:", ms.p.tolist())    # ~[0.3,0.3,0.25,0.15]


p0: [0.4000000059604645, 0.20000000298023224, 0.20000000298023224, 0.20000000298023224]
p1: [0.30000001192092896, 0.30000001192092896, 0.25, 0.15000000596046448]


In [68]:
# --- LR Scheduler: linear warmup -> cosine decay (nonzero start) ---

def make_warmup_cosine_scheduler(optimizer, warmup_steps: int, total_steps: int, min_lr_ratio: float = 0.10):
    """
    Warms up linearly from 0->1 over warmup_steps; cosine decays from 1->min_lr_ratio for the remainder.
    Uses step_idx+1 to avoid zero LR at the start.
    """
    warmup_steps = max(1, int(warmup_steps))
    total_steps = max(warmup_steps + 1, int(total_steps))

    def lr_lambda(step_idx: int):
        s = step_idx + 1
        if s <= warmup_steps:
            return s / float(warmup_steps)
        progress = (s - warmup_steps) / float(max(1, total_steps - warmup_steps))
        return max(min_lr_ratio, 0.5 * (1.0 + math.cos(math.pi * progress)))

    return LambdaLR(optimizer, lr_lambda)


In [69]:
# --- TensorBoard Logger (lightweight) ---
import os, time, json
try:
    from torch.utils.tensorboard import SummaryWriter
    TB_AVAILABLE = True
except Exception as e:
    print("TensorBoard not available:", e)
    TB_AVAILABLE = False

class TBLogger:
    def __init__(self, logdir: Optional[str] = None, run_name: Optional[str] = None):
        self.enabled = TB_AVAILABLE
        self.writer = None
        if not self.enabled:
            return
        logdir = logdir or "./runs"
        run_name = run_name or time.strftime("dncformer-%Y%m%d-%H%M%S")
        self.path = os.path.join(logdir, run_name)
        os.makedirs(self.path, exist_ok=True)
        self.writer = SummaryWriter(self.path)

    def log_scalars(self, step: int, loss: float, lr: float, gate_means: List[float]):
        if not (self.enabled and self.writer): return
        self.writer.add_scalar("train/loss", loss, step)
        self.writer.add_scalar("train/lr", lr, step)
        for i, gm in enumerate(gate_means):
            self.writer.add_scalar(f"gates/block_{i}_mean", gm, step)

    def add_image_hw(self, tag: str, img_hw: "torch.Tensor", step: int):
        if not (self.enabled and self.writer): return
        import torch
        x = img_hw
        if x.device.type != "cpu": x = x.cpu()
        x = x.float()
        if x.numel() > 0:
            m, M = x.min(), x.max()
            x = (x - m) / (M - m + 1e-8)
        x = x.unsqueeze(0)  # [1,H,W]
        self.writer.add_image(tag, x, step, dataformats='CHW')

    def add_histogram(self, tag: str, values: "torch.Tensor", step: int, bins: int = 50):
        if not (self.enabled and self.writer): return
        import torch
        v = values
        if isinstance(v, torch.Tensor):
            v = v.detach()
            if v.device.type != "cpu": v = v.cpu()
            v = v.reshape(-1).float()
        self.writer.add_histogram(tag, v, global_step=step, bins=bins)

    def add_text(self, tag: str, text: str, step: int):
        if not (self.enabled and self.writer): return
        self.writer.add_text(tag, text, step)

    def flush(self):
        if self.enabled and self.writer: self.writer.flush()

    def close(self):
        if self.enabled and self.writer: self.writer.close()


In [70]:
tb = TBLogger(logdir="./runs")
try:
    tb
except NameError:
    tb = TBLogger(logdir="./runs")
print("TB active:", TB_AVAILABLE, "| logdir:", getattr(tb, "path", None))

TB active: True | logdir: ./runs\dncformer-20250819-122754


In [71]:
from torch.amp import GradScaler, autocast

def build_model_and_tokenizer():
    tok, base = load_base_model(CFG.base_model_id)
    if CFG.d_model is None:
        CFG.d_model = base.config.hidden_size
    head = DNCFormerHead(base, CFG).to(device)
    if CFG.use_torch_compile and hasattr(torch, 'compile'):
        head = torch.compile(head)
    #print("Trainable params in head:", count_params(head))
    return tok, head

def make_optimizer(model):
    params = [p for p in model.parameters() if p.requires_grad]
    return torch.optim.AdamW(params, lr=CFG.lr, weight_decay=CFG.weight_decay)

def lm_shift_labels(input_ids, logits, tok):
    labels = input_ids[:, 1:].contiguous()
    logits = logits[:, :-1].contiguous()
    return F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1), ignore_index=tok.pad_token_id)

@torch.no_grad()
def evaluate_simple(head, tok):
    head.eval()
    prompt = "### Instruction:\nSay hello in one word\n\n### Response:\n"
    enc = tok(prompt, return_tensors="pt").to(device)
    with autocast('cuda', dtype=amp_dtype, enabled=(amp_dtype!=torch.float32)):
        logits, gates = head(enc.input_ids)
    print("Gates (mean):", [g.mean().item() for g in gates])
    return gates

def train_small():
    tok, head = get_or_build_model_and_tokenizer()
    optim = make_optimizer(head)
    scheduler = make_warmup_cosine_scheduler(optim, CFG.warmup_steps, CFG.train_steps, min_lr_ratio=0.10)
    head.train()

    tb = TBLogger(logdir="./runs")
    use_scaler = (amp_dtype == torch.float16)
    scaler = GradScaler('cuda', enabled=use_scaler)

    for step in range(1, CFG.train_steps+1):
        in_ids, lab_ids = tokenize_instruction_pairs(tok, INSTR_PAIRS, max_len=min(CFG.max_seq_len, 256))
        in_ids = in_ids.to(device)
        
        with autocast('cuda', dtype=amp_dtype, enabled=(amp_dtype!=torch.float32)):
            logits, gates = head(in_ids)
            loss = lm_shift_labels(in_ids, logits, tok)

        if use_scaler:
            scaler.scale(loss).backward()
            torch.nn.utils.clip_grad_norm_(head.parameters(), CFG.grad_clip)
            scaler.step(optim); scaler.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(head.parameters(), CFG.grad_clip)
            optim.step()
            
        optim.zero_grad(set_to_none=True)

        if step % CFG.log_every == 0:
            gm = [g.mean().item() for g in gates]
            lr = optim.param_groups[0]['lr']
            print(f"step {step} | loss {loss.item():.4f} | lr {lr:.2e} | gates={gm}")
            tb.log_scalars(step, float(loss.item()), float(lr), gm)
            
    evaluate_simple(head, tok)
    return head, tok

## 10. Eval harness (needle-in-haystack, copy, reverse)

In [72]:
# --- Haystack (needle) eval: long-span retrieval of a key's value ---
def make_haystack_batch(batch: int, T: int = 256, vocab: int = 1024, sentinel: int = 3):
    """
    Build sequences: ... K V ... K ?  (next token should be V)
    Returns: input_ids [B,T], answer_ids [B], query_pos [B] (position of '?')
    """
    assert T >= 12, "T too small for haystack layout"
    x = torch.randint(5, vocab, (batch, T), dtype=torch.long)
    K = torch.randint(5, vocab, (batch,), dtype=torch.long)
    V = torch.randint(5, vocab, (batch,), dtype=torch.long)

    # Place (K,V) in the first half
    p1 = torch.randint(low=T//8, high=T//2 - 2, size=(batch,))
    x[torch.arange(batch), p1] = K
    x[torch.arange(batch), p1 + 1] = V

    # Place (K, sentinel) in the last quarter
    p2 = torch.randint(low=3*T//4, high=T - 2, size=(batch,))
    x[torch.arange(batch), p2] = K
    x[torch.arange(batch), p2 + 1] = sentinel  # '?'
    query_pos = p2 + 1  # position of '?'; we will look at logits at this position

    return x, V, query_pos

@torch.no_grad()
def evaluate_haystack(head, steps: int = 50, batch: int = 16, T: int = 256, vocab: int = 1024,
                      tb_step: int = None, fast: bool = False):
    """
    Long-span retrieval probe: ... K V ... K ? => predict V at '?'.
    fast=True uses inference_mode() and smaller defaults to speed up sweeps.
    """
    from torch.amp import autocast
    head.eval()

    # Fast path defaults (can still be overridden by explicit args)
    if fast:
        steps = min(steps, 10)
        batch = min(batch, 8)
        T = min(T, 128)

    device_ = next(head.parameters()).device
    use_amp = (amp_dtype in (torch.float16, torch.bfloat16)) and torch.cuda.is_available()

    accs, losses = [], []
    ctx = torch.inference_mode() if fast else torch.no_grad()
    with ctx:
        for _ in range(steps):
            x, V, qpos = make_haystack_batch(batch, T=T, vocab=vocab)
            x = x.to(device_); V = V.to(device_); qpos = qpos.to(device_)

            if use_amp:
                with autocast(device_type="cuda", dtype=amp_dtype):
                    logits, _g = head(x)
            else:
                logits, _g = head(x)

            idx = torch.arange(x.size(0), device=device_)
            logits_q = logits[idx, qpos, :].float()
            loss = F.cross_entropy(logits_q, V)
            pred = logits_q.argmax(dim=-1)

            accs.append((pred == V).float().mean().item())
            losses.append(loss.item())

    acc_m = float(np.mean(accs)); loss_m = float(np.mean(losses))

    if TB_AVAILABLE and ('tb' in globals()):
        tb.writer.add_scalar("eval/haystack_acc", acc_m, tb_step if tb_step is not None else 0)
        tb.writer.add_scalar("eval/haystack_loss", loss_m, tb_step if tb_step is not None else 0)
        tb.flush()

    head.train()
    print(f"[Haystack] acc={acc_m:.3f} | loss={loss_m:.3f} | fast={fast}")
    return acc_m, loss_m

In [73]:
# --- Memory tracer & TensorBoard visualizer ---
from contextlib import contextmanager

class MemoryTracer:
    def __init__(self): self.frames = []
    def __call__(self, frame): self.frames.append(frame)
    def stack(self, key, bidx: int = 0):
        import torch
        return torch.stack([f[key][bidx] for f in self.frames], dim=0)  # [T, ...]

@contextmanager
def trace_memory(module):
    tracer = MemoryTracer()
    memories = [m for m in module.modules() if m.__class__.__name__ == "DNCMemory"]
    for m in memories: m.probe = tracer
    try:
        yield tracer
    finally:
        for m in memories: m.probe = None

@torch.no_grad()
def visualize_memory_tb(head, tok, writer, global_step: int, prompt="### Instruction:\nRemember A then B; later return A.\n\n### Response:\n", max_T=64):
    if writer is None: return
    enc = tok(prompt, return_tensors="pt").to(device)
    enc.input_ids = enc.input_ids[:, :max_T]
    from torch.amp import autocast
    with trace_memory(head) as tracer:
        with autocast('cuda', dtype=amp_dtype, enabled=(amp_dtype!=torch.float32)):
            _ = head(enc.input_ids)

    u = tracer.stack("u")           # [T, N]
    Mnorm = tracer.stack("M_norm")  # [T, N]
    rw = tracer.stack("rw")         # [T, R, N]

    # Log images
    writer.add_image("memory/u_TxN", (u.T).unsqueeze(0), global_step, dataformats='CHW')
    writer.add_image("memory/Mnorm_TxN", (Mnorm.T).unsqueeze(0), global_step, dataformats='CHW')

    top_read = rw.argmax(dim=-1).float()  # [T,R]
    top_read_img = (top_read / max(1, rw.size(-1)-1)).T  # [R,T]
    writer.add_image("memory/top_read_RxT", top_read_img.unsqueeze(0), global_step, dataformats='CHW')


In [74]:
#cuda_report("before")
#list_head_refs()

## Medium-size training loop test

In [75]:
def _build_mixer(tok, weights, hf_dataset="tatsu-lab/alpaca", hf_max_items=2000) -> MixtureSampler:
    """
    Build a mixture of generators: [HF, copy, repeat, nback].
    - If hf_dataset is Alpaca-like (has instruction/output), we use hf_instruction_loader + make_hf_batch.
    - Else we treat it as text-only (e.g., roneneldan/TinyStories), tokenizing and making windowed sequences.
    - If HF fails/empty, we drop it and renormalize over synthetics.
    """
    mx = int(getattr(CFG, "max_seq_len", 256))
    pad_id = getattr(tok, "pad_token_id", 0) or 0

    gens, wts, names = [], [], []

    hf_ok = False
    hf_reason = ""
    gen_hf = None

    if hf_dataset:
        try:
            # Heuristic: use instruction loader for alpaca-like IDs
            if "alpaca" in hf_dataset.lower():
                pairs = hf_instruction_loader(hf_dataset, "train", ("instruction", "output"),
                                              max_items=hf_max_items)
                if pairs:
                    def gen_hf(b): 
                        return make_hf_batch(tok, pairs, b, max_len=mx)
                    hf_ok = True
                else:
                    hf_reason = "hf_instruction_loader returned 0 pairs"
            else:
                # Text-only fallback (e.g., roneneldan/TinyStories)
                from datasets import load_dataset
                # Try streaming first, then non-streaming
                try:
                    ds = load_dataset(hf_dataset, split="train", streaming=True)
                except Exception:
                    ds = load_dataset(hf_dataset, split="train")

                # Find a usable text key
                text_key = None
                common_keys = ("text", "content", "story", "document", "body", "article")

                feats = getattr(ds, "features", None)
                if feats:
                    for k in common_keys:
                        if k in feats:
                            text_key = k
                            break

                if text_key is None:
                    # Probe first example (works for streaming iterable)
                    try:
                        first_ex = next(iter(ds))
                        for k in common_keys:
                            if k in first_ex:
                                text_key = k
                                break
                    except StopIteration:
                        pass

                if text_key is None:
                    hf_reason = "no usable text field (tried: %s)" % ",".join(common_keys)
                else:
                    import random as _rnd
                    import torch
                    samples = []
                    for ex in ds:
                        txt = ex.get(text_key, None)
                        if not txt:
                            continue
                        ids = tok(txt, return_tensors="pt", add_special_tokens=False).input_ids.squeeze(0)
                        if ids.numel() < 8:
                            continue
                        samples.append(ids.cpu())
                        if len(samples) >= int(hf_max_items):
                            break

                    if len(samples) == 0:
                        hf_reason = "collected 0 tokenized samples"
                    else:
                        def gen_hf(b: int) -> torch.Tensor:
                            out = []
                            for _ in range(b):
                                ids = _rnd.choice(samples)
                                n = ids.numel()
                                if n >= mx:
                                    s = _rnd.randint(0, n - mx)
                                    seq = ids[s:s+mx]
                                else:
                                    pad = torch.full((mx - n,), pad_id, dtype=torch.long)
                                    seq = torch.cat([ids, pad], dim=0)
                                out.append(seq.unsqueeze(0))
                            return torch.cat(out, dim=0)
                        hf_ok = True
        except Exception as e:
            hf_reason = f"{type(e).__name__}: {e}"

    # Assemble mixture in the agreed order (hf, copy, repeat, nback)
    if hf_ok and gen_hf is not None:
        gens.append(gen_hf); wts.append(weights[0]); names.append("hf")
        s_w = list(weights[1:])
    else:
        if hf_dataset:
            print(f"HF dataset unavailable or empty; using synthetic only. reason={hf_reason}")
        s_w = list(weights[1:])  # keep caller's relative weights for synthetics

    # Synthetic tasks (your existing closures)
    def gen_copy(b):   return make_copy_task(b, T=min(mx, 128), vocab=100)
    def gen_repeat(b): return make_repeat_copy(b, T=min(mx, 128), vocab=100, pad_id=pad_id, device="cpu")
    def gen_nback(b):  return make_n_back(b, T=min(mx, 128), n=5, vocab=50)

    gens.extend([gen_copy, gen_repeat, gen_nback])
    wts.extend(s_w)
    names.extend(["copy", "repeat", "nback"])

    return MixtureSampler(gens, wts, names=names)


In [76]:
# GPU VRAM diagnostics/test
import torch, gc, contextlib

def cuda_report(tag=""):
    if not torch.cuda.is_available():
        print("[cuda_report] CUDA not available"); return
    free, total = torch.cuda.mem_get_info()
    alloc = torch.cuda.memory_allocated()
    reserv = torch.cuda.memory_reserved()
    print(f"[{tag}] alloc={alloc/1e9:.2f} GB | reserved={reserv/1e9:.2f} GB | free={free/1e9:.2f} GB | total={total/1e9:.2f} GB")

def list_head_refs():
    # looks for globals named 'head' and count of DNCFormerHead instances
    import gc, inspect, sys
    heads = [o for o in gc.get_objects() if o.__class__.__name__ == "DNCFormerHead"]
    print(f"[liveness] DNCFormerHead instances alive: {len(heads)}")
    if 'head' in globals():
        h = globals()['head']
        try:
            devs = sorted({p.device.type for p in h.parameters()})
        except Exception:
            devs = ["<unknown>"]
        print(f"[liveness] global 'head' present; param devices: {devs}")
    else:
        print("[liveness] no global 'head'")

def free_head_and_cache():
    # delete typical globals and clear allocator
    with contextlib.suppress(Exception):
        del globals()['head']
    with contextlib.suppress(Exception):
        del globals()['tok']
    gc.collect(); torch.cuda.empty_cache()
    cuda_report("after free")


In [91]:
free_head_and_cache()

[after free] alloc=29.34 GB | reserved=29.44 GB | free=0.00 GB | total=25.77 GB


In [78]:
from torch.amp import autocast

@torch.no_grad()
def eval_copy(head, tok, batch=4, T=64, vocab=100):
    x = make_copy_task(batch, T, vocab=vocab).to(device)
    with autocast('cuda', dtype=amp_dtype, enabled=(amp_dtype!=torch.float32)):
        logits, gates = head(x)
    preds = logits.argmax(dim=-1)
    acc = (preds[:, :-1] == x[:, 1:]).float().mean().item()
    print("copy acc:", acc, "gates:", [g.mean().item() for g in gates])
    return {"acc": acc, "gates": [g.mean().item() for g in gates]}

@torch.no_grad()
def eval_reverse(head, tok, batch=4, T=64, vocab=100):
    x, y = make_reverse_task(batch, T, vocab=vocab)
    x, y = x.to(device), y.to(device)
    with autocast('cuda', dtype=amp_dtype, enabled=(amp_dtype!=torch.float32)):
        logits, gates = head(x)
    preds = logits.argmax(dim=-1)
    acc = (preds[:, :-1] == y[:, 1:]).float().mean().item()
    print("reverse acc:", acc, "gates:", [g.mean().item() for g in gates])
    return {"acc": acc, "gates": [g.mean().item() for g in gates]}

@torch.no_grad()
def eval_needle(head, tok, batch=4, T=128, vocab=200):
    x = make_needle_task(batch, T, vocab=vocab).to(device)
    with autocast('cuda', dtype=amp_dtype, enabled=(amp_dtype!=torch.float32)):
        logits, gates = head(x)
    preds = logits.argmax(dim=-1)
    acc = (preds[:, -1] == x[:, -1]).float().mean().item()
    print("needle acc:", acc, "gates:", [g.mean().item() for g in gates])
    return {"acc": acc, "gates": [g.mean().item() for g in gates]}


In [79]:
from torch.amp import GradScaler, autocast

def train_medium(
    steps: int = None,
    batch_size: int = None,
    warmup_steps: int = None,
    min_lr_ratio: float = 0.1,
    mixture_weights=(0.4, 0.2, 0.2, 0.2),
    hf_dataset: str = "tatsu-lab/alpaca",
    hf_max_items: int = 5000,
    log_every: int = None,
    viz_memory_after: bool = False,
    viz_prompt: str = "### Instruction:\nSay hello in one word\n\n### Response:\n",
    viz_max_T: int = 64,
    # NEW: optional schedules (each a list of (until_step, value))
    mixture_schedule=None,         # e.g., [(100, (0.3,0.3,0.2,0.2)), (None, (0.4,0.2,0.2,0.2))]
    gate_temp_schedule=None,       # e.g., [(100, 0.8), (None, 1.0)]
    gate_reg_schedule=None,        # e.g., [(100, 3e-4), (None, 2e-4)]
):
    cfg = CFG
    steps = int(steps or cfg.train_steps)
    batch_size = int(batch_size or getattr(cfg, "batch_size", 8))
    warmup_steps = int(warmup_steps if warmup_steps is not None else getattr(cfg, "warmup_steps", max(10, steps // 20)))
    log_every = int(log_every if log_every is not None else getattr(cfg, "log_every", 10))

    # Model & tokenizer
    tok, head = build_model_and_tokenizer()
    optim = make_optimizer(head)
    scheduler = make_warmup_cosine_scheduler(optim, warmup_steps, steps, min_lr_ratio=min_lr_ratio)

    # Sampler (and allow schedules to change weights)
    mixer = _build_mixer(tok, mixture_weights, hf_dataset=hf_dataset, hf_max_items=hf_max_items)

    def _apply_schedules(step: int, mix_name_hint=None):
        # mixture schedule
        if mixture_schedule:
            for until, ws in mixture_schedule:
                if until is None or step <= int(until):
                    try:
                        mixer.set_weights(ws)
                    except Exception:
                        pass
                    break
        # gate temp schedule
        if gate_temp_schedule:
            for until, temp in gate_temp_schedule:
                if until is None or step <= int(until):
                    setattr(CFG, "gate_temp", float(temp))
                    break
        # gate reg schedule
        if gate_reg_schedule:
            for until, lam in gate_reg_schedule:
                if until is None or step <= int(until):
                    setattr(CFG, "gate_reg_lambda", float(lam))
                    break

    # TB setup
    global tb
    if TB_AVAILABLE:
        try:
            tb  # NameError if not defined
        except NameError:
            tb = TBLogger(logdir="./runs")
        if not isinstance(tb, TBLogger) or getattr(tb, "writer", None) is None:
            tb = TBLogger(logdir="./runs")
        tblog = True
        tb.add_text("run/config", json.dumps({
            "steps": steps,
            "batch_size": batch_size,
            "warmup_steps": warmup_steps,
            "mixture_weights": list(mixture_weights),
            "mixture_schedule": mixture_schedule,
            "gate_temp_schedule": gate_temp_schedule,
            "gate_reg_schedule": gate_reg_schedule,
        }, indent=2), 0)
    else:
        tblog = False

    head.train()
    scaler = torch.amp.GradScaler('cuda', enabled=(amp_dtype in (torch.float16, torch.bfloat16)))

    for step in range(1, steps + 1):
        _apply_schedules(step)

        in_ids = mixer(batch_size).to(device)
        with torch.autocast('cuda', dtype=amp_dtype, enabled=(amp_dtype != torch.float32)):
            logits, gates, aux = head.forward_with_metrics(in_ids, gate_override=getattr(CFG, "force_g", None))
            loss = lm_shift_labels(in_ids, logits, tok)

            # Optional: mild gate-usage regularizer (same as before)
            lam = float(getattr(CFG, "gate_reg_lambda", 0.0))
            if lam > 0 and isinstance(gates, (list, tuple)) and len(gates) > 0:
                reg = 0.0
                for g in gates:
                    g2 = _reduce_gate_tensor(g)
                    # Encourage decisive routing on memory batches only? For now, uniform
                    reg = reg + (g2.mean() * 0.0 + (g2 * (1 - g2)).mean())  # small entropy-like penalty
                loss = loss + lam * reg

        use_scaler = (amp_dtype in (torch.float16, torch.bfloat16))
        if use_scaler:
            scaler.scale(loss).backward()
            torch.nn.utils.clip_grad_norm_(head.parameters(), CFG.grad_clip)
            scaler.step(optim); scaler.update(); optim.zero_grad(set_to_none=True)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(head.parameters(), CFG.grad_clip)
            optim.step(); optim.zero_grad(set_to_none=True)

        scheduler.step()

        # --- Logging (TB) ---
        if step % log_every == 0 and tblog:
            # global scalars
            tb.writer.add_scalar("train/loss", float(loss.item()), step)
            tb.writer.add_scalar("train/lr", float(scheduler.get_last_lr()[0]), step)

            # gate metrics per block (global)
            if isinstance(gates, (list, tuple)):
                for bi, g in enumerate(gates):
                    g_mean, g_frac, g_ent = _gate_metrics(g)
                    tb.writer.add_scalar(f"gates/block_{bi}_mean", g_mean, step)
                    tb.writer.add_scalar(f"gates/block_{bi}_frac>0.5", g_frac, step)
                    tb.writer.add_scalar(f"gates/block_{bi}_entropy", g_ent, step)

                # per-task metrics (using the sampler's last batch type)
                mix_name = getattr(mixer, "last_name", None) or "unknown"
                tb.writer.add_scalar(f"loss_by_task/{mix_name}", float(loss.item()), step)
                for bi, g in enumerate(gates):
                    g_mean, g_frac, g_ent = _gate_metrics(g)
                    tb.writer.add_scalar(f"gates_by_task/block_{bi}_mean/{mix_name}", g_mean, step)
                    tb.writer.add_scalar(f"gates_by_task/block_{bi}_frac>0.5/{mix_name}", g_frac, step)

                # optional: quartiles (block 0)
                if len(gates) > 0:
                    g0 = _reduce_gate_tensor(gates[0].detach())
                    T = g0.size(1); q = max(1, T // 4)
                    slices = [(0, q), (q, 2*q), (2*q, 3*q), (3*q, T)]
                    for qi, (s, e) in enumerate(slices, start=1):
                        tb.writer.add_scalar(f"gates/block0_q{qi}_mean/{mix_name}", float(g0[:, s:e].mean().item()), step)

            tb.flush()

        # --- Console echo (always) ---
        if step % log_every == 0:
            g_means_print = []
            if isinstance(gates, (list, tuple)):
                for g in gates:
                    g_means_print.append(float(_reduce_gate_tensor(g).mean().item()))
            print(f"step {step} | loss {loss.item():.4f} | lr {scheduler.get_last_lr()[0]:.2e} | "
                  f"gates={g_means_print} | mix={getattr(mixer,'last_name','?')}")

    # Optional memory viz
    if viz_memory_after:
        try:
            visualize_memory_tb(head, tok, tb.writer, global_step=steps, prompt=viz_prompt, max_T=viz_max_T)
        except Exception as e:
            print("Memory TB viz skipped:", e)
    
    # flush log      
    if tblog:
        tb.flush()

    return head, tok


In [80]:
# # --- Patch: robust forward for ParallelEnrichmentBlock ---
def _peb_forward(self, x, dnc_state=None, gate_override=None):
    """
    Returns: (out, dnc_state, g)
      out: (B,T,D), mixed from vanilla (vt) and DNC (dt)
      dnc_state: updated state dict from DNC path
      g: (B,T,D) or (B,T,1) gate values in [0,1]
    """
    # 1) Vanilla transformer path
    T = x.size(1)
    mask = causal_mask(T, device=x.device)  # (T,T) causal attn mask
    x_cast = x.to(self.vanilla.ln1.weight.dtype)
    vt = self.vanilla(x_cast, attn_mask=mask)    # (B,T,D)

    # 2) DNC path (controller+memory)
    dt, dnc_state = self.dncblock(x, state=dnc_state)  # dt: (B,T,D)

    # 3) Gate computation (optionally LN before gate)
    import torch as _t
    z = _t.cat([vt, dt], dim=-1)  # (B,T,2D)
    h = self.pre_gate_ln(z) if hasattr(self, "pre_gate_ln") and self.pre_gate_ln is not None else z
    g_pre = self.gate(h)                 # typically (B,T,D) or (B,T,1)
    tau = float(getattr(CFG, 'gate_temp', 1.0))
    g = torch.sigmoid(g_pre / max(tau, 1e-6))

    # Optional ablation override: force g to 0.0 (vanilla) or 1.0 (DNC)
    if gate_override is not None:
        g = _t.full_like(g, float(gate_override))

    # 4) Blend paths
    out = g * dt + (1.0 - g) * vt        # (B,T,D)

    # 5) Lightweight metrics (only if requested)
    self.last_metrics = None
    if getattr(self, "collect_metrics", False):
        try:
            eps = 1e-6
            gc = g.clamp(min=eps, max=1.0 - eps)
            g_entropy = (-(gc * gc.log() + (1 - gc) * (1 - gc).log())).mean()
            vt_norm = vt.norm(dim=-1).mean()
            dt_norm = dt.norm(dim=-1).mean()
            mstats = {}
            # Pull aggregated memory stats if present
            if isinstance(dnc_state, dict) and isinstance(dnc_state.get("stats", None), dict):
                for k in ("u_mean", "rw_max_mean", "ww_max_mean", "M_norm_mean"):
                    if k in dnc_state["stats"]:
                        mstats[k] = dnc_state["stats"][k]
            self.last_metrics = {
                "g_mean": g.mean().detach(),
                "g_entropy": g_entropy.detach(),
                "vt_norm": vt_norm.detach(),
                "dt_norm": dt_norm.detach(),
                **mstats,
            }
        except Exception:
            self.last_metrics = None

    return out, dnc_state, g

# Apply the patch
ParallelEnrichmentBlock.forward = _peb_forward


## 11. Unit-like tests (sanity)

In [81]:

def run_unit_tests():
    B, T = 2, 4
    R, W, N = CFG.dnc_read_heads, CFG.dnc_cell_size, CFG.dnc_nr_cells
    d_in = d_model = 128

    mem = DNCMemory(N, W, R).to(device)
    state = mem.reset(B, device=device)
    iface = {
        "k_read": torch.zeros(B,R,W, device=device),
        "beta_read": torch.ones(B,R,1, device=device),
        "k_write": torch.zeros(B,W, device=device),
        "beta_write": torch.ones(B,1, device=device),
        "erase": torch.sigmoid(torch.randn(B,W, device=device)),
        "write_vec": torch.randn(B,W, device=device),
        "free_gates": torch.sigmoid(torch.randn(B,R,1, device=device)),
        "alloc_gate": torch.sigmoid(torch.randn(B,1, device=device)),
        "write_gate": torch.sigmoid(torch.randn(B,1, device=device)),
        "read_mode": torch.randn(B,R,3, device=device),
    }
    r, state2 = mem(iface, state)
    assert r.shape == (B,R,W)

    ctrl = TransformerController(d_in+R*W, d_model, heads=4).to(device)
    x = torch.randn(B,T,d_in, device=device)
    reads = state2["r"].reshape(B,R*W).unsqueeze(1).expand(B,T,R*W)
    h = ctrl(torch.cat([x, reads], dim=-1))
    assert h.shape == (B,T,d_model)

    dblk = DNCformerBlock(d_in, d_model, R, W, N, heads=4, dropout=0.1, ffn_mult=4.0).to(device)
    y, s = dblk(x)
    assert y.shape == (B,T,d_model)

    pen = ParallelEnrichmentBlock(d_model, d_in, R, W, N, heads=4, dropout=0.1, ffn_mult=4.0, gate_bias_init=-1.0).to(device)
    out, s2, g = pen(torch.randn(B,T,d_model, device=device))
    print(f"out: {out}\ns2: {s2}\ng: {g}")
    eps = 1e-6
    assert torch.isfinite(g).all()
    assert ((g > eps) & (g < 1 - eps)).float().mean().item() > 0.95
    assert out.shape == (B,T,d_model)
    print("All unit-like tests passed.")


In [82]:
# --- Evaluator unit tests (no heavy model) ---
import torch
from typing import Optional

class _DummyHead(torch.nn.Module):
    def __init__(self, vocab=100, d_model=64, n_blocks=2):
        super().__init__()
        self.vocab = vocab
        self.d_model = d_model
        self.n_blocks = n_blocks
    def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor]=None):
        B, T = input_ids.shape
        logits = torch.randn(B, T, self.vocab, device=input_ids.device, dtype=torch.float32)
        gates = [torch.sigmoid(torch.randn(B, T, self.d_model, device=input_ids.device)) for _ in range(self.n_blocks)]
        return logits, gates

def run_eval_unit_tests():
    dummy = _DummyHead().to(device).eval()
    res_copy = eval_copy(dummy, tok=None, batch=2, T=8, vocab=50)
    res_rev = eval_reverse(dummy, tok=None, batch=2, T=8, vocab=50)
    res_needle = eval_needle(dummy, tok=None, batch=2, T=16, vocab=100)
    assert all(k in res_copy for k in ["acc","gates"])
    assert all(k in res_rev for k in ["acc","gates"])
    assert all(k in res_needle for k in ["acc","gates"])
    print("Evaluator unit tests passed.")


## Unit tests/basic sanity checks
currently all passing, but self-reminder here to comment out if architecture changes

In [83]:
# memory tracer smoke test
def run_memory_tracer_smoke(head, tok):
    if not TB_AVAILABLE:
        print("TB not available; skipping image smoke.")
        return
    from torch.amp import autocast
    with trace_memory(head) as tracer:
        x = torch.randint(5, (1, 16), device=device)
        with autocast('cuda', dtype=amp_dtype, enabled=(amp_dtype!=torch.float32)):
            _ = head(x)
    assert len(tracer.frames) > 0, "No memory frames captured"
    print("Tracer captured", len(tracer.frames), "steps")


In [84]:
# generator smoke test
mx = int(getattr(CFG, "max_seq_len", 256))
tok, _base = tok if 'tok' in globals() else (None, None)
_pad = getattr(tok, "pad_token_id", 0) if tok is not None else 0

x = make_repeat_copy(batch=3, T=min(mx, 128), vocab=50, pad_id=_pad)
print("repeat_copy batch shape:", x.shape, "| dtype:", x.dtype)  # expect (3, <=128), long
assert x.dtype == torch.long


repeat_copy batch shape: torch.Size([3, 128]) | dtype: torch.int64


In [85]:
# smoke test for _build_mixer
# 1) With Alpaca (instruction/output)
tok, _ = build_model_and_tokenizer()
m1 = _build_mixer(tok, (0.4,0.2,0.2,0.2), hf_dataset="tatsu-lab/alpaca", hf_max_items=500)
print("names:", m1.names, "| p:", m1.p.tolist())  # expect 'hf' present

# 2) With TinyStories (text)
m2 = _build_mixer(tok, (0.4,0.2,0.2,0.2), hf_dataset="roneneldan/TinyStories", hf_max_items=500)
print("names:", m2.names, "| p:", m2.p.tolist())  # expect 'hf' present (text path)


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

names: ['hf', 'copy', 'repeat', 'nback'] | p: [0.4000000059604645, 0.20000000298023224, 0.20000000298023224, 0.20000000298023224]
names: ['hf', 'copy', 'repeat', 'nback'] | p: [0.4000000059604645, 0.20000000298023224, 0.20000000298023224, 0.20000000298023224]


In [24]:
# cuda allocation smoke test
from transformers import AutoModelForCausalLM, AutoTokenizer
gc.collect(); torch.cuda.empty_cache(); cuda_report("pre base-only")
tok = AutoTokenizer.from_pretrained(CFG.base_model_id, trust_remote_code=True)
base = AutoModelForCausalLM.from_pretrained(
    CFG.base_model_id,
    torch_dtype=(torch.bfloat16 if (amp_dtype!=torch.float32 and torch.cuda.is_bf16_supported()) else (torch.float16 if amp_dtype!=torch.float32 else None)),
    attn_implementation="sdpa",
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    use_safetensors=True,
).to("cuda")
cuda_report("after base-only")
del base, tok; gc.collect(); torch.cuda.empty_cache()


[pre base-only] alloc=0.00 GB | reserved=0.00 GB | free=24.44 GB | total=25.77 GB


`flash-attention` package not found, consider installing for better performance: No module named 'flash_attn'.
Current `flash-attention` does not support `window_size`. Either upgrade or use `attn_implementation='eager'`.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

[after base-only] alloc=7.64 GB | reserved=7.64 GB | free=16.79 GB | total=25.77 GB


In [86]:
# haystack smoke test
tok, head = build_model_and_tokenizer()
acc, loss = evaluate_haystack(head, steps=2, batch=4, T=64, vocab=512, tb_step=0)
print("Haystack smoke:", acc, loss)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

[Haystack] acc=0.000 | loss=10.112
Haystack smoke: 0.0 10.111836910247803


In [34]:
#run_unit_tests()

In [35]:
#run_eval_unit_tests()

In [36]:
#run_memory_tracer_smoke()

## Training run Sanity test

In [37]:
# small training test
#head, tok = train_small()

## Medium size training sweeps

In [50]:
# single medium training test/sanity check
free_head_and_cache()
cuda_report("snapshot: before train_medium")
head, tok = train_medium(steps=10, warmup_steps=10, mixture_weights=(0.4,0.2,0.2,0.2))
cuda_report("snapshot: after train_medium")
#Launch TensorBoard in a terminal: tensorboard --logdir ./runs


[after free] alloc=0.00 GB | reserved=0.00 GB | free=24.44 GB | total=25.77 GB
[snapshot: before train_medium] alloc=0.00 GB | reserved=0.00 GB | free=24.44 GB | total=25.77 GB


`flash-attention` package not found, consider installing for better performance: No module named 'flash_attn'.
Current `flash-attention` does not support `window_size`. Either upgrade or use `attn_implementation='eager'`.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

  self.gen = func(*args, **kwds)
You are not running the flash-attention implementation, expect numerical differences.


step 10 | loss 3.7740 | lr 2.00e-05 | gates=[0.275390625, 0.271484375] | mix=hf
[snapshot: after train_medium] alloc=9.79 GB | reserved=23.34 GB | free=0.16 GB | total=25.77 GB


### Medium training sweep, testing parameter / dataset variants

In [39]:
# --- Compact experiment driver: E0..E5 ---
import time, torch, gc

# run set parameters
EXP_STEPS   = 500       # try 1000–5000 for longer curves
EXP_WARMUP  = 10       # keep ~steps/20
BASE_MIX    = (0.4, 0.2, 0.2, 0.2)  # (hf, copy, repeat, nback)

def set_cfg(**kv):
    for k, v in kv.items():
        setattr(CFG, k, v)

def run_isolate_test(label,
            mixture_weights=BASE_MIX,
            gate_reg_lambda=None,
            gate_temp=None,
            force_g=None,
            steps=EXP_STEPS,
            warmup=EXP_WARMUP):
    """Run a single train_medium experiment with explicit knobs."""
    print(f"\n=== {label} ===")
    # set experiment-specific knobs (others come from CFG defaults)
    if gate_reg_lambda is not None: set_cfg(gate_reg_lambda=float(gate_reg_lambda))
    if gate_temp is not None:       set_cfg(gate_temp=float(gate_temp))
    set_cfg(force_g=force_g)  # may be None, 0.0, or 1.0

    print(f"mixture={mixture_weights}, gate_reg_lambda={getattr(CFG,'gate_reg_lambda', 0.0)}, "
          f"gate_temp={getattr(CFG,'gate_temp', 1.0)}, force_g={getattr(CFG,'force_g', None)}")

    # clean slate for VRAM and allocator fragmentation between runs
    free_head_and_cache()
    cuda_report(f"before {label}")
    time.sleep(1.2)  # ensure distinct TB run dirs (timestamp granularity)

    # run
    head, tok = train_medium(
        steps=steps,
        warmup_steps=warmup,
        mixture_weights=mixture_weights,
        viz_memory_after=False,   # keep quick; use visualize_memory_tb() ad-hoc
    )

    # post-run snapshot + cleanup
    cuda_report(f"after  {label}")
    free_head_and_cache()
    return head, tok

# === E0 baseline: identical to your sanity run ===
run_isolate_test("E0_baseline",
        mixture_weights=(0.4, 0.2, 0.2, 0.2),
        gate_reg_lambda=0.0,
        gate_temp=1.0,
        force_g=None,
        steps=EXP_STEPS, warmup=EXP_WARMUP)

# === E1 memory-leaning mix: more algorithmic exposure ===
run_isolate_test("E1_memory_lean",
        mixture_weights=(0.4, 0.3, 0.2, 0.1),  # HF, copy, repeat, nback
        gate_reg_lambda=0.0,
        gate_temp=1.0,
        force_g=None,
        steps=EXP_STEPS, warmup=EXP_WARMUP)

# === E2 gate regularizer (low) ===
run_isolate_test("E2_gate_reg_low",
        mixture_weights=(0.4, 0.2, 0.2, 0.2),
        gate_reg_lambda=2e-4,
        gate_temp=1.0,
        force_g=None,
        steps=EXP_STEPS, warmup=EXP_WARMUP)

# === E3 gate regularizer (high) ===
run_isolate_test("E3_gate_reg_high",
        mixture_weights=(0.4, 0.2, 0.2, 0.2),
        gate_reg_lambda=5e-4,
        gate_temp=1.0,
        force_g=None,
        steps=EXP_STEPS, warmup=EXP_WARMUP)

# === E4 sharper routing via lower gate temperature ===
run_isolate_test("E4_gate_temp_0p7",
        mixture_weights=(0.4, 0.2, 0.2, 0.2),
        gate_reg_lambda=0.0,
        gate_temp=0.7,
        force_g=None,
        steps=EXP_STEPS, warmup=EXP_WARMUP)

# === E5 ablations: disable/force DNC path ===
run_isolate_test("E5a_force_g_0",
        mixture_weights=(0.4, 0.2, 0.2, 0.2),
        gate_reg_lambda=0.0,
        gate_temp=1.0,
        force_g=0.0,
        steps=EXP_STEPS, warmup=EXP_WARMUP)

run_isolate_test("E5b_force_g_1",
        mixture_weights=(0.4, 0.2, 0.2, 0.2),
        gate_reg_lambda=0.0,
        gate_temp=1.0,
        force_g=1.0,
        steps=EXP_STEPS, warmup=EXP_WARMUP)



=== E0_baseline ===
mixture=(0.4, 0.2, 0.2, 0.2), gate_reg_lambda=0.0, gate_temp=1.0, force_g=None
[after free] alloc=9.77 GB | reserved=9.79 GB | free=14.65 GB | total=25.77 GB
[before E0_baseline] alloc=9.77 GB | reserved=9.79 GB | free=14.65 GB | total=25.77 GB


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

  self.gen = func(*args, **kwds)
You are not running the flash-attention implementation, expect numerical differences.


[after  E0_baseline] alloc=19.56 GB | reserved=42.15 GB | free=0.00 GB | total=25.77 GB
[after free] alloc=19.56 GB | reserved=19.59 GB | free=3.96 GB | total=25.77 GB

=== E1_memory_lean ===
mixture=(0.4, 0.3, 0.2, 0.1), gate_reg_lambda=0.0, gate_temp=1.0, force_g=None
[after free] alloc=9.79 GB | reserved=9.82 GB | free=14.55 GB | total=25.77 GB
[before E1_memory_lean] alloc=9.79 GB | reserved=9.82 GB | free=14.55 GB | total=25.77 GB


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

[after  E1_memory_lean] alloc=19.56 GB | reserved=40.45 GB | free=0.00 GB | total=25.77 GB
[after free] alloc=19.56 GB | reserved=19.60 GB | free=4.77 GB | total=25.77 GB

=== E2_gate_reg_low ===
mixture=(0.4, 0.2, 0.2, 0.2), gate_reg_lambda=0.0002, gate_temp=1.0, force_g=None
[after free] alloc=9.79 GB | reserved=9.82 GB | free=14.55 GB | total=25.77 GB
[before E2_gate_reg_low] alloc=9.79 GB | reserved=9.82 GB | free=14.55 GB | total=25.77 GB


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

[after  E2_gate_reg_low] alloc=19.56 GB | reserved=40.85 GB | free=0.00 GB | total=25.77 GB
[after free] alloc=19.56 GB | reserved=19.59 GB | free=3.96 GB | total=25.77 GB

=== E3_gate_reg_high ===
mixture=(0.4, 0.2, 0.2, 0.2), gate_reg_lambda=0.0005, gate_temp=1.0, force_g=None
[after free] alloc=9.79 GB | reserved=9.82 GB | free=14.55 GB | total=25.77 GB
[before E3_gate_reg_high] alloc=9.79 GB | reserved=9.82 GB | free=14.55 GB | total=25.77 GB


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

[after  E3_gate_reg_high] alloc=19.56 GB | reserved=42.10 GB | free=0.00 GB | total=25.77 GB
[after free] alloc=19.56 GB | reserved=19.60 GB | free=4.77 GB | total=25.77 GB

=== E4_gate_temp_0p7 ===
mixture=(0.4, 0.2, 0.2, 0.2), gate_reg_lambda=0.0, gate_temp=0.7, force_g=None
[after free] alloc=9.79 GB | reserved=9.82 GB | free=14.55 GB | total=25.77 GB
[before E4_gate_temp_0p7] alloc=9.79 GB | reserved=9.82 GB | free=14.55 GB | total=25.77 GB


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

[after  E4_gate_temp_0p7] alloc=19.56 GB | reserved=40.85 GB | free=0.00 GB | total=25.77 GB
[after free] alloc=19.56 GB | reserved=19.60 GB | free=4.01 GB | total=25.77 GB

=== E5a_force_g_0 ===
mixture=(0.4, 0.2, 0.2, 0.2), gate_reg_lambda=0.0, gate_temp=1.0, force_g=0.0
[after free] alloc=9.79 GB | reserved=9.82 GB | free=14.55 GB | total=25.77 GB
[before E5a_force_g_0] alloc=9.79 GB | reserved=9.82 GB | free=14.55 GB | total=25.77 GB


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

[after  E5a_force_g_0] alloc=19.56 GB | reserved=40.14 GB | free=0.00 GB | total=25.77 GB
[after free] alloc=19.56 GB | reserved=19.60 GB | free=4.77 GB | total=25.77 GB

=== E5b_force_g_1 ===
mixture=(0.4, 0.2, 0.2, 0.2), gate_reg_lambda=0.0, gate_temp=1.0, force_g=1.0
[after free] alloc=9.79 GB | reserved=9.82 GB | free=14.55 GB | total=25.77 GB
[before E5b_force_g_1] alloc=9.79 GB | reserved=9.82 GB | free=14.55 GB | total=25.77 GB


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

[after  E5b_force_g_1] alloc=19.56 GB | reserved=41.73 GB | free=0.00 GB | total=25.77 GB
[after free] alloc=19.56 GB | reserved=19.60 GB | free=4.02 GB | total=25.77 GB


(DNCFormerHead(
   (base): Phi3ForCausalLM(
     (model): Phi3Model(
       (embed_tokens): Embedding(32064, 3072, padding_idx=32000)
       (embed_dropout): Dropout(p=0.0, inplace=False)
       (layers): ModuleList(
         (0-31): 32 x Phi3DecoderLayer(
           (self_attn): Phi3Attention(
             (o_proj): Linear(in_features=3072, out_features=3072, bias=False)
             (qkv_proj): Linear(in_features=3072, out_features=9216, bias=False)
             (rotary_emb): Phi3RotaryEmbedding()
           )
           (mlp): Phi3MLP(
             (gate_up_proj): Linear(in_features=3072, out_features=16384, bias=False)
             (down_proj): Linear(in_features=8192, out_features=3072, bias=False)
             (activation_fn): SiLU()
           )
           (input_layernorm): Phi3RMSNorm()
           (resid_attn_dropout): Dropout(p=0.0, inplace=False)
           (resid_mlp_dropout): Dropout(p=0.0, inplace=False)
           (post_attention_layernorm): Phi3RMSNorm()
         )
    

In [51]:
gc.collect()
torch.cuda.empty_cache()

In [88]:
# --- Experiment driver: E6..E9 (E2 architecture) ---

def set_cfg(**kv):
    for k, v in kv.items():
        setattr(CFG, k, v)

def run_one_labeled(label, steps, mixture_weights, seed=1234,
                    mixture_schedule=None, gate_temp_schedule=None, gate_reg_schedule=None,
                    post_haystack=False):
    print(f"\n=== {label} | seed={seed} ===")
    random.seed(seed); np.random.seed(seed); torch.manual_seed(seed) 
    if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed)

    set_cfg(force_g=None)  # ensure no ablation
    # optional: print short config
    print("CFG.gate_temp:", getattr(CFG, "gate_temp", 1.0),
          "| CFG.gate_reg_lambda:", getattr(CFG, "gate_reg_lambda", 0.0),
          "| mixture:", mixture_weights)

    free_head_and_cache()
    if 'cuda_report' in globals(): cuda_report(f"before {label}")

    head, tok = train_medium(
        steps=steps,
        warmup_steps=max(10, steps//20),
        mixture_weights=mixture_weights,
        mixture_schedule=mixture_schedule,
        gate_temp_schedule=gate_temp_schedule,
        gate_reg_schedule=gate_reg_schedule,
        viz_memory_after=False,
    )

    if post_haystack:
        evaluate_haystack(head, steps=50, batch=16, T=256, vocab=1024, tb_step=steps, fast=True)

    if 'cuda_report' in globals(): cuda_report(f"after  {label}")
    free_head_and_cache()
    return head, tok

# Common settings
EXP_STEPS = 100
BASE_MIX  = (0.4, 0.2, 0.2, 0.2)
SEEDS     = [1337, 2027, 4242]

# === E6: gate_temp=0.8 (3 seeds), otherwise baseline ===
set_cfg(gate_reg_lambda=getattr(CFG, "gate_reg_lambda", 2e-4))  # low-λ default
for s in SEEDS:
    run_one_labeled(f"E6_temp0p8_seed{s}", steps=EXP_STEPS, mixture_weights=BASE_MIX,
                    seed=s,
                    gate_temp_schedule=[(None, 0.8)])

# === E7: memory-leaning warm-start (first 10% steps), then baseline; 3 seeds ===
warm_steps = max(50, EXP_STEPS // 10)
mix_warm   = (0.3, 0.3, 0.25, 0.15)  # a bit more memory-heavy than baseline
mix_main   = BASE_MIX
for s in SEEDS:
    run_one_labeled(f"E7_warmstart_seed{s}", steps=EXP_STEPS, mixture_weights=BASE_MIX,
                    seed=s,
                    mixture_schedule=[(warm_steps, mix_warm), (None, mix_main)],
                    gate_temp_schedule=[(warm_steps, 0.8), (None, 1.0)],
                    gate_reg_schedule=[(warm_steps, max(2e-4, getattr(CFG,'gate_reg_lambda', 2e-4))), (None, getattr(CFG,'gate_reg_lambda', 2e-4))])

# === E8: capacity sweep N=64 vs N=128 (1 seed each) ===
for N_val in (64, 128):
    set_cfg(N=N_val)  # assumes your DNC block reads CFG.N at construction
    run_one_labeled(f"E8_capacity_N{N_val}", steps=EXP_STEPS, mixture_weights=BASE_MIX,
                    seed=777,
                    gate_temp_schedule=[(None, getattr(CFG, 'gate_temp', 1.0))])

# Reset N if changed
set_cfg(N=getattr(CFG, 'N', 128))

# === E9: baseline with haystack eval ===
run_one_labeled("E9_baseline_haystack", steps=EXP_STEPS, mixture_weights=BASE_MIX,
                seed=31415, post_haystack=True)



=== E6_temp0p8_seed1337 | seed=1337 ===
CFG.gate_temp: 0.8 | CFG.gate_reg_lambda: 0.0002 | mixture: (0.4, 0.2, 0.2, 0.2)
[after free] alloc=30.70 GB | reserved=30.93 GB | free=0.00 GB | total=25.77 GB
[before E6_temp0p8_seed1337] alloc=30.70 GB | reserved=30.93 GB | free=0.00 GB | total=25.77 GB


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

step 10 | loss 2.4975 | lr 2.00e-04 | gates=[0.2353515625, 0.240234375] | mix=hf
step 20 | loss 5.7329 | lr 1.93e-04 | gates=[0.25, 0.2578125] | mix=nback
step 30 | loss 6.3398 | lr 1.74e-04 | gates=[0.255859375, 0.26171875] | mix=copy
step 40 | loss 3.6917 | lr 1.47e-04 | gates=[0.2158203125, 0.2275390625] | mix=hf
step 50 | loss 2.1587 | lr 1.14e-04 | gates=[0.208984375, 0.22265625] | mix=hf
step 60 | loss 1.9526 | lr 7.92e-05 | gates=[0.2060546875, 0.220703125] | mix=hf
step 70 | loss 2.0035 | lr 4.70e-05 | gates=[0.201171875, 0.2158203125] | mix=hf
step 80 | loss 5.5659 | lr 2.12e-05 | gates=[0.255859375, 0.26953125] | mix=copy
step 90 | loss 1.3354 | lr 2.00e-05 | gates=[0.203125, 0.216796875] | mix=hf
step 100 | loss 1.5504 | lr 2.00e-05 | gates=[0.19921875, 0.2138671875] | mix=hf
[after  E6_temp0p8_seed1337] alloc=40.47 GB | reserved=58.81 GB | free=0.00 GB | total=25.77 GB
[after free] alloc=40.47 GB | reserved=40.71 GB | free=0.00 GB | total=25.77 GB

=== E6_temp0p8_seed2027 |

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

step 10 | loss 7.0602 | lr 2.00e-04 | gates=[0.2412109375, 0.2470703125] | mix=nback
step 20 | loss 1.9751 | lr 1.93e-04 | gates=[0.2333984375, 0.236328125] | mix=hf
step 30 | loss 6.5971 | lr 1.74e-04 | gates=[0.2451171875, 0.255859375] | mix=repeat
step 40 | loss 6.4307 | lr 1.47e-04 | gates=[0.248046875, 0.263671875] | mix=repeat
step 50 | loss 2.2391 | lr 1.14e-04 | gates=[0.2109375, 0.220703125] | mix=hf
step 60 | loss 6.0388 | lr 7.92e-05 | gates=[0.24609375, 0.263671875] | mix=copy
step 70 | loss 5.7137 | lr 4.70e-05 | gates=[0.248046875, 0.267578125] | mix=copy
step 80 | loss 1.5096 | lr 2.12e-05 | gates=[0.2041015625, 0.216796875] | mix=hf
step 90 | loss 5.3578 | lr 2.00e-05 | gates=[0.244140625, 0.263671875] | mix=repeat
step 100 | loss 1.5022 | lr 2.00e-05 | gates=[0.2001953125, 0.212890625] | mix=hf
[after  E6_temp0p8_seed2027] alloc=40.47 GB | reserved=57.77 GB | free=0.00 GB | total=25.77 GB
[after free] alloc=40.47 GB | reserved=40.70 GB | free=0.00 GB | total=25.77 GB



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

step 10 | loss 8.1877 | lr 2.00e-04 | gates=[0.25, 0.251953125] | mix=repeat
step 20 | loss 6.7026 | lr 1.93e-04 | gates=[0.255859375, 0.259765625] | mix=nback
step 30 | loss 3.0415 | lr 1.74e-04 | gates=[0.2353515625, 0.2421875] | mix=hf
step 40 | loss 5.4770 | lr 1.47e-04 | gates=[0.259765625, 0.267578125] | mix=repeat
step 50 | loss 1.3972 | lr 1.14e-04 | gates=[0.2255859375, 0.234375] | mix=hf
step 60 | loss 5.4101 | lr 7.92e-05 | gates=[0.2578125, 0.2734375] | mix=nback
step 70 | loss 5.6242 | lr 4.70e-05 | gates=[0.26171875, 0.2734375] | mix=copy
step 80 | loss 5.5566 | lr 2.12e-05 | gates=[0.263671875, 0.2734375] | mix=copy
step 90 | loss 1.5476 | lr 2.00e-05 | gates=[0.216796875, 0.2265625] | mix=hf
step 100 | loss 5.1131 | lr 2.00e-05 | gates=[0.263671875, 0.2734375] | mix=repeat
[after  E6_temp0p8_seed4242] alloc=40.47 GB | reserved=57.04 GB | free=0.00 GB | total=25.77 GB
[after free] alloc=40.47 GB | reserved=40.70 GB | free=0.00 GB | total=25.77 GB

=== E7_warmstart_seed13

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

step 10 | loss 6.4303 | lr 2.00e-04 | gates=[0.24609375, 0.251953125] | mix=repeat
step 20 | loss 2.1397 | lr 1.93e-04 | gates=[0.236328125, 0.2451171875] | mix=hf
step 30 | loss 6.6853 | lr 1.74e-04 | gates=[0.251953125, 0.26171875] | mix=nback
step 40 | loss 1.8675 | lr 1.47e-04 | gates=[0.22265625, 0.2333984375] | mix=hf
step 50 | loss 1.5155 | lr 1.14e-04 | gates=[0.216796875, 0.2275390625] | mix=hf
step 60 | loss 6.3036 | lr 7.92e-05 | gates=[0.2890625, 0.298828125] | mix=copy
step 70 | loss 5.6683 | lr 4.70e-05 | gates=[0.287109375, 0.296875] | mix=repeat
step 80 | loss 5.6513 | lr 2.12e-05 | gates=[0.2890625, 0.30078125] | mix=copy
step 90 | loss 1.6827 | lr 2.00e-05 | gates=[0.2470703125, 0.255859375] | mix=hf
step 100 | loss 5.1239 | lr 2.00e-05 | gates=[0.287109375, 0.30078125] | mix=nback
[after  E7_warmstart_seed1337] alloc=40.47 GB | reserved=58.81 GB | free=0.00 GB | total=25.77 GB
[after free] alloc=40.47 GB | reserved=40.71 GB | free=0.00 GB | total=25.77 GB

=== E7_war

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

step 10 | loss 2.3192 | lr 2.00e-04 | gates=[0.244140625, 0.24609375] | mix=hf
step 20 | loss 9.0200 | lr 1.93e-04 | gates=[0.2373046875, 0.2412109375] | mix=hf
step 30 | loss 7.7515 | lr 1.74e-04 | gates=[0.251953125, 0.26171875] | mix=nback
step 40 | loss 6.8784 | lr 1.47e-04 | gates=[0.251953125, 0.265625] | mix=copy
step 50 | loss 6.0626 | lr 1.14e-04 | gates=[0.25390625, 0.26953125] | mix=copy
step 60 | loss 2.4782 | lr 7.92e-05 | gates=[0.255859375, 0.26953125] | mix=hf
step 70 | loss 1.7376 | lr 4.70e-05 | gates=[0.25390625, 0.26953125] | mix=hf
step 80 | loss 5.5908 | lr 2.12e-05 | gates=[0.2890625, 0.306640625] | mix=copy
step 90 | loss 5.0990 | lr 2.00e-05 | gates=[0.287109375, 0.306640625] | mix=nback
step 100 | loss 4.9676 | lr 2.00e-05 | gates=[0.29296875, 0.30859375] | mix=repeat
[after  E7_warmstart_seed2027] alloc=40.47 GB | reserved=57.77 GB | free=0.00 GB | total=25.77 GB
[after free] alloc=40.47 GB | reserved=40.71 GB | free=0.00 GB | total=25.77 GB

=== E7_warmstart

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

step 10 | loss 9.4094 | lr 2.00e-04 | gates=[0.25, 0.251953125] | mix=copy
step 20 | loss 6.5907 | lr 1.93e-04 | gates=[0.255859375, 0.259765625] | mix=nback
step 30 | loss 6.2954 | lr 1.74e-04 | gates=[0.2578125, 0.26171875] | mix=copy
step 40 | loss 5.3130 | lr 1.47e-04 | gates=[0.259765625, 0.265625] | mix=repeat
step 50 | loss 1.4202 | lr 1.14e-04 | gates=[0.2216796875, 0.23046875] | mix=hf
step 60 | loss 5.3073 | lr 7.92e-05 | gates=[0.29296875, 0.302734375] | mix=nback
step 70 | loss 5.5675 | lr 4.70e-05 | gates=[0.294921875, 0.3046875] | mix=copy
step 80 | loss 5.5075 | lr 2.12e-05 | gates=[0.296875, 0.3046875] | mix=copy
step 90 | loss 1.6206 | lr 2.00e-05 | gates=[0.251953125, 0.259765625] | mix=hf
step 100 | loss 5.0807 | lr 2.00e-05 | gates=[0.296875, 0.3046875] | mix=repeat
[after  E7_warmstart_seed4242] alloc=40.47 GB | reserved=57.04 GB | free=0.00 GB | total=25.77 GB
[after free] alloc=40.47 GB | reserved=40.70 GB | free=0.00 GB | total=25.77 GB

=== E8_capacity_N64 | se

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

step 10 | loss 2.3589 | lr 2.00e-04 | gates=[0.279296875, 0.279296875] | mix=hf
step 20 | loss 7.0277 | lr 1.93e-04 | gates=[0.279296875, 0.2890625] | mix=copy
step 30 | loss 6.8002 | lr 1.74e-04 | gates=[0.279296875, 0.29296875] | mix=nback
step 40 | loss 6.3232 | lr 1.47e-04 | gates=[0.283203125, 0.296875] | mix=repeat
step 50 | loss 6.3752 | lr 1.14e-04 | gates=[0.279296875, 0.296875] | mix=nback
step 60 | loss 1.7769 | lr 7.92e-05 | gates=[0.2392578125, 0.2421875] | mix=hf
step 70 | loss 1.3798 | lr 4.70e-05 | gates=[0.2353515625, 0.2392578125] | mix=hf
step 80 | loss 5.0608 | lr 2.12e-05 | gates=[0.279296875, 0.30078125] | mix=nback
step 90 | loss 1.8050 | lr 2.00e-05 | gates=[0.228515625, 0.232421875] | mix=hf
step 100 | loss 5.2268 | lr 2.00e-05 | gates=[0.27734375, 0.298828125] | mix=nback
[after  E8_capacity_N64] alloc=40.47 GB | reserved=59.13 GB | free=0.00 GB | total=25.77 GB
[after free] alloc=40.47 GB | reserved=40.71 GB | free=0.00 GB | total=25.77 GB

=== E8_capacity_N1

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

step 10 | loss 2.3579 | lr 2.00e-04 | gates=[0.279296875, 0.279296875] | mix=hf
step 20 | loss 7.0285 | lr 1.93e-04 | gates=[0.279296875, 0.2890625] | mix=copy
step 30 | loss 6.8010 | lr 1.74e-04 | gates=[0.279296875, 0.29296875] | mix=nback
step 40 | loss 6.3245 | lr 1.47e-04 | gates=[0.283203125, 0.296875] | mix=repeat
step 50 | loss 6.3738 | lr 1.14e-04 | gates=[0.279296875, 0.296875] | mix=nback
step 60 | loss 1.7869 | lr 7.92e-05 | gates=[0.2392578125, 0.2421875] | mix=hf
step 70 | loss 1.3772 | lr 4.70e-05 | gates=[0.2353515625, 0.2392578125] | mix=hf
step 80 | loss 5.0625 | lr 2.12e-05 | gates=[0.279296875, 0.30078125] | mix=nback
step 90 | loss 1.8055 | lr 2.00e-05 | gates=[0.228515625, 0.232421875] | mix=hf
step 100 | loss 5.2289 | lr 2.00e-05 | gates=[0.27734375, 0.298828125] | mix=nback
[after  E8_capacity_N128] alloc=40.47 GB | reserved=59.13 GB | free=0.00 GB | total=25.77 GB
[after free] alloc=40.47 GB | reserved=40.71 GB | free=0.00 GB | total=25.77 GB

=== E9_baseline_h

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

step 10 | loss 4.9431 | lr 2.00e-04 | gates=[0.2734375, 0.275390625] | mix=hf
step 20 | loss 1.8983 | lr 1.93e-04 | gates=[0.26171875, 0.265625] | mix=hf
step 30 | loss 5.5588 | lr 1.74e-04 | gates=[0.279296875, 0.287109375] | mix=nback
step 40 | loss 6.8865 | lr 1.47e-04 | gates=[0.283203125, 0.291015625] | mix=copy
step 50 | loss 2.0390 | lr 1.14e-04 | gates=[0.2314453125, 0.2412109375] | mix=hf
step 60 | loss 2.1616 | lr 7.92e-05 | gates=[0.234375, 0.2431640625] | mix=hf
step 70 | loss 2.0309 | lr 4.70e-05 | gates=[0.228515625, 0.23828125] | mix=hf
step 80 | loss 1.5388 | lr 2.12e-05 | gates=[0.228515625, 0.2392578125] | mix=hf
step 90 | loss 1.3621 | lr 2.00e-05 | gates=[0.2294921875, 0.2392578125] | mix=hf
step 100 | loss 5.5914 | lr 2.00e-05 | gates=[0.283203125, 0.294921875] | mix=copy
[Haystack] acc=0.001 | loss=14.843
[after  E9_baseline_haystack] alloc=40.47 GB | reserved=56.68 GB | free=0.00 GB | total=25.77 GB
[after free] alloc=40.47 GB | reserved=40.69 GB | free=0.00 GB |

(DNCFormerHead(
   (base): Phi3ForCausalLM(
     (model): Phi3Model(
       (embed_tokens): Embedding(32064, 3072, padding_idx=32000)
       (embed_dropout): Dropout(p=0.0, inplace=False)
       (layers): ModuleList(
         (0-31): 32 x Phi3DecoderLayer(
           (self_attn): Phi3Attention(
             (o_proj): Linear(in_features=3072, out_features=3072, bias=False)
             (qkv_proj): Linear(in_features=3072, out_features=9216, bias=False)
             (rotary_emb): Phi3RotaryEmbedding()
           )
           (mlp): Phi3MLP(
             (gate_up_proj): Linear(in_features=3072, out_features=16384, bias=False)
             (down_proj): Linear(in_features=8192, out_features=3072, bias=False)
             (activation_fn): SiLU()
           )
           (input_layernorm): Phi3RMSNorm()
           (resid_attn_dropout): Dropout(p=0.0, inplace=False)
           (resid_mlp_dropout): Dropout(p=0.0, inplace=False)
           (post_attention_layernorm): Phi3RMSNorm()
         )
    

In [90]:
gc.collect()
torch.cuda.empty_cache()

## Tensorboard log dump

In [89]:
# ===== TensorBoard event analyzer for DNCFormer runs (robust) =====
# Parses .tfevents files -> summary tables + CSVs
import os, math, json, time
from pathlib import Path

import numpy as np
import pandas as pd
from IPython.display import display

# Optional: print TB version for debugging
try:
    import tensorboard as _tb
    print("TensorBoard version:", getattr(_tb, "__version__", "unknown"))
except Exception as _e:
    print("TensorBoard import note:", _e)

# EventAccumulator + module-level constants (version-agnostic handling)
try:
    from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
    from tensorboard.backend.event_processing import event_accumulator as ea_mod
except Exception as e:
    raise RuntimeError("TensorBoard not installed in this environment. "
                       "Install via: pip install tensorboard") from e

# --- Point explicitly to log paths ---
EVENT_FILES = [
    r"runs/dncformer-20250817-184608/events.out.tfevents.1755631674.Persephone.194944.2",
]
# If you prefer to auto-discover:
if not EVENT_FILES:
    EVENT_FILES = sorted([str(p) for p in Path(".").glob("**/events.out.tfevents.*")])

assert EVENT_FILES, "No event files found. Fill EVENT_FILES or adjust the glob."

# --------- helpers ---------
def _size_guidance_version_safe():
    """
    Build size_guidance dict that works across TB versions:
    older/newer TB put constants at module level, not on the class.
    """
    sg = {}
    # Try module-level numeric constants first; otherwise fall back to lowercase keys
    keys = ["SCALARS", "HISTOGRAMS", "IMAGES", "COMPRESSED_HISTOGRAMS", "AUDIO"]
    for k in keys:
        v = getattr(ea_mod, k, None)
        if v is not None:
            sg[v] = 0
        else:
            # event_accumulator also accepts string keys like "scalars"
            sg[k.lower()] = 0
    return sg

def load_scalars(ev_path: str):
    acc = EventAccumulator(ev_path, size_guidance=_size_guidance_version_safe())
    acc.Reload()
    # acc.Tags() returns dict with lowercase keys in modern TB
    tags = acc.Tags()
    scalar_tags = tags.get('scalars', [])
    out = {}
    for tag in scalar_tags:
        vals = acc.Scalars(tag)
        out[tag] = [(x.step, float(x.value)) for x in vals]
    return out

def s_last(vals, k=None):
    if not vals: return np.nan
    arr = np.array([v for _, v in vals], dtype=float)
    if k is None or k >= len(arr): return float(arr[-1])
    return float(np.nanmean(arr[-k:]))

def s_first(vals, k=None):
    if not vals: return np.nan
    arr = np.array([v for _, v in vals], dtype=float)
    if k is None or k >= len(arr): return float(arr[0])
    return float(np.nanmean(arr[:k]))

def s_mean(vals):
    if not vals: return np.nan
    return float(np.nanmean([v for _, v in vals]))

def s_count(vals):
    return len(vals) if vals else 0

TASKS  = ["hf", "copy", "repeat", "nback"]
BLOCKS = [0, 1]

def summarize_run(ev_path: str):
    scal = load_scalars(ev_path)

    # basics
    loss = scal.get("train/loss", [])
    lr   = scal.get("train/lr", [])
    steps_logged = max([s for s, _ in loss], default=np.nan) if loss else np.nan
    loss0 = s_first(loss, k=5)
    lossT = s_last(loss,  k=10)
    loss_delta = (loss0 - lossT) if not any(map(math.isnan, [loss0, lossT])) else np.nan
    lr_last = s_last(lr, k=1)

    # gates (global)
    g_means, g_entropy = {}, {}
    g_frac_avg = {}
    for b in BLOCKS:
        g_means[b]   = s_last(scal.get(f"gates/block_{b}_mean", []), k=10)
        g_entropy[b] = s_last(scal.get(f"gates/block_{b}_entropy", []), k=10)
        fracs = []
        for t in TASKS:
            tag = f"gates_by_task/block_{b}_frac>0.5/{t}"
            if tag in scal:
                fracs.append(s_last(scal[tag], k=10))
        g_frac_avg[b] = float(np.nanmean(fracs)) if fracs else np.nan

    # per-task losses (mean of last quarter)
    task_loss_last, task_counts = {}, {}
    for t in TASKS:
        ts = scal.get(f"loss_by_task/{t}", [])
        task_counts[t] = s_count(ts)
        if ts:
            k = max(1, len(ts)//4)
            task_loss_last[t] = s_last(ts, k=k)
        else:
            task_loss_last[t] = np.nan

    # per-task gate means (avg last quarter)
    task_gmeans = {t: {} for t in TASKS}
    for t in TASKS:
        for b in BLOCKS:
            ts = scal.get(f"gates_by_task/block_{b}_mean/{t}", [])
            if ts:
                k = max(1, len(ts)//4)
                task_gmeans[t][b] = s_last(ts, k=k)
            else:
                task_gmeans[t][b] = np.nan

    # quartiles (block0 Q1..Q4) averaged across tasks if present
    q_means = {}
    for qi in range(1,5):
        vals = []
        for t in TASKS:
            tagt = f"gates/block0_q{qi}_mean/{t}"
            if tagt in scal:
                vals.append(s_last(scal[tagt], k=10))
        q_means[qi] = float(np.nanmean(vals)) if vals else np.nan

    # Try to infer forced gating
    forced_guess = None
    gm_all = [g_means[b] for b in BLOCKS if not math.isnan(g_means[b])]
    if gm_all:
        m = float(np.nanmean(gm_all))
        if m < 0.02:  forced_guess = "force_g=0"
        elif m > 0.98: forced_guess = "force_g=1"

    # run_id default (may be overridden in assembly)
    suffix = Path(ev_path).suffix.lstrip(".")
    run_id = suffix if suffix.isdigit() else Path(ev_path).name

    summary = {
        "run_file": Path(ev_path).name,
        "run_id": run_id,
        "steps_logged": steps_logged,
        "loss_start~5": loss0,
        "loss_end~10": lossT,
        "loss_delta": loss_delta,
        "lr_last": lr_last,
        "g_mean_b0": g_means.get(0, np.nan),
        "g_mean_b1": g_means.get(1, np.nan),
        "g_frac>0.5_avg": float(np.nanmean([g_frac_avg.get(0, np.nan), g_frac_avg.get(1, np.nan)])),
        "g_entropy_b0": g_entropy.get(0, np.nan),
        "g_entropy_b1": g_entropy.get(1, np.nan),
        "forced_guess": forced_guess,
        **{f"loss_{t}_last": task_loss_last[t] for t in TASKS},
        **{f"gmean_b0_{t}": task_gmeans[t][0] for t in TASKS},
        **{f"gmean_b1_{t}": task_gmeans[t][1] for t in TASKS},
        **{f"g_b0_Q{qi}_mean": q_means[qi] for qi in range(1,5)},
        **{f"count_{t}": task_counts[t] for t in TASKS},
    }
    return summary, scal

def parse_suffix_as_int(ev_path: str):
    suf = Path(ev_path).suffix.lstrip(".")
    return int(suf) if suf.isdigit() else None

# Map numeric suffix -> experiment label
RUN_LABELS = {
    0: "E0_baseline",
    1: "E1_memory_lean",
    2: "E2_gate_reg_low",
    3: "E3_gate_reg_high",
    4: "E4_gate_temp_0p7",
    5: "E5a_force_g_0",
    6: "E5b_force_g_1",
}

# --------- robust assembly & labeling ----------
summaries = []
all_scalars = {}
for p in EVENT_FILES:
    fn = Path(p).name
    try:
        s, scal = summarize_run(p)
        if "run_id" not in s or (s.get("run_id") in (None, "", np.nan)):
            suf_i = parse_suffix_as_int(p)
            s["run_id"] = suf_i if suf_i is not None else fn
        suf_i = parse_suffix_as_int(p)
        s["label"] = RUN_LABELS.get(suf_i, s["run_id"])
        s["path"] = str(p)
        summaries.append(s)
        all_scalars[fn] = scal
    except Exception as e:
        suf_i = parse_suffix_as_int(p)
        summaries.append({
            "run_file": fn,
            "run_id": suf_i if suf_i is not None else fn,
            "label": RUN_LABELS.get(suf_i, f"unparsed_{fn}"),
            "error": str(e),
            "path": str(p),
        })
        all_scalars[fn] = {}

df_runs = pd.DataFrame(summaries)

# Order by planned experiment sequence
if "label" in df_runs.columns:
    order = ["E0_baseline","E1_memory_lean","E2_gate_reg_low","E3_gate_reg_high",
             "E4_gate_temp_0p7","E5a_force_g_0","E5b_force_g_1"]
    df_runs["label_cat"] = pd.Categorical(df_runs["label"], categories=order, ordered=True)
    df_runs = df_runs.sort_values(["label_cat","run_id"], ignore_index=True)
elif "run_id" in df_runs.columns:
    df_runs = df_runs.sort_values("run_id", key=lambda s: s.astype(str), ignore_index=True)

# Arrange columns with most relevant up front
front_cols = [c for c in [
    "label", "run_id", "run_file", "steps_logged",
    "loss_start~5","loss_end~10","loss_delta","lr_last",
    "g_mean_b0","g_mean_b1","g_frac>0.5_avg","g_entropy_b0","g_entropy_b1",
    "loss_hf_last","loss_copy_last","loss_repeat_last","loss_nback_last",
    "gmean_b0_hf","gmean_b1_hf","gmean_b0_copy","gmean_b1_copy",
    "gmean_b0_repeat","gmean_b1_repeat","gmean_b0_nback","gmean_b1_nback",
    "g_b0_Q1_mean","g_b0_Q2_mean","g_b0_Q3_mean","g_b0_Q4_mean",
    "forced_guess","error","path"
] if c in df_runs.columns]
df_runs = df_runs[[*front_cols, *[c for c in df_runs.columns if c not in front_cols]]]

display(df_runs)

# Save CSV
out_dir = Path("./analysis")
out_dir.mkdir(parents=True, exist_ok=True)
runs_csv = out_dir / "run_level_summary.csv"
df_runs.to_csv(runs_csv, index=False)
print("Saved:", runs_csv)

# --------- per‑task time series (granular export; replaces tidy summary block) ----------
rows = []
for ev, scal in all_scalars.items():
    # derive run label from suffix (0..6) using our RUN_LABELS map
    suf_i = parse_suffix_as_int(ev)
    run_label = RUN_LABELS.get(suf_i, ev)
    run_id_str = RUN_LABELS.get(suf_i, str(suf_i) if suf_i is not None else ev)

    # join LR by global training step (helps analyze scheduler effects)
    lr_dict = {int(s): float(v) for s, v in scal.get("train/lr", [])}

    for t in TASKS:
        # per-task loss series: list[(step, value)], logged at your log_every cadence
        loss_series = scal.get(f"loss_by_task/{t}", [])
        if not loss_series:
            continue

        # per-task gate metrics per block, keyed by step
        gmean_dict = {
            b: {int(s): float(v) for s, v in scal.get(f"gates_by_task/block_{b}_mean/{t}", [])}
            for b in BLOCKS
        }
        gfrac_dict = {
            b: {int(s): float(v) for s, v in scal.get(f"gates_by_task/block_{b}_frac>0.5/{t}", [])}
            for b in BLOCKS
        }

        # optional quartile gate means for block 0 (if logged)
        q_dict = {
            qi: {int(s): float(v) for s, v in scal.get(f"gates/block0_q{qi}_mean/{t}", [])}
            for qi in (1, 2, 3, 4)
        }

        for step, loss_val in loss_series:
            step = int(step)
            loss_val = float(loss_val)

            for b in BLOCKS:
                g_mean = gmean_dict[b].get(step, np.nan)
                g_frac = gfrac_dict[b].get(step, np.nan)
                row = {
                    "label": run_label,
                    "run_id": run_id_str,
                    "step": step,
                    "task": t,
                    "block": b,
                    "loss": loss_val,
                    "g_mean": g_mean,
                    "g_frac>0.5": g_frac,
                    "lr": lr_dict.get(step, np.nan),
                }
                # add quartile gating only for block 0; keep NaN for block 1
                if b == 0:
                    for qi in (1, 2, 3, 4):
                        row[f"g_b0_Q{qi}"] = q_dict[qi].get(step, np.nan)
                else:
                    for qi in (1, 2, 3, 4):
                        row[f"g_b0_Q{qi}"] = np.nan

                rows.append(row)

# Build DataFrame and save
df_task_ts = pd.DataFrame(rows).sort_values(
    ["label", "task", "step", "block"], ignore_index=True
)
display(df_task_ts.head(20))

# Backwards-compat alias (if you referenced df_task elsewhere)
df_task = df_task_ts.copy()

tasks_csv = out_dir / "per_task_metrics.csv"
df_task_ts.to_csv(tasks_csv, index=False)
print("Saved (granular):", tasks_csv)


df_task = pd.DataFrame(rows)
if not df_task.empty:
    df_task = df_task.sort_values(["label","task","block"], ignore_index=True)
display(df_task)

tasks_csv = out_dir / "per_task_metrics.csv"
df_task.to_csv(tasks_csv, index=False)
print("Saved:", tasks_csv)

print("Analyzed files:")
for p in EVENT_FILES:
    print(" -", p)


TensorBoard version: 2.20.0


Unnamed: 0,label,run_id,run_file,error,path,label_cat
0,E2_gate_reg_low,2,events.out.tfevents.1755631674.Persephone.1949...,b'runs/dncformer-20250817-184608/events.out.tf...,runs/dncformer-20250817-184608/events.out.tfev...,E2_gate_reg_low


Saved: analysis\run_level_summary.csv


KeyError: 'label'

## 12. Notes & TODO
- The DNC memory here is **compact** and intended for research iteration; you can swap in a fuller reference implementation if desired. [x]
- The controller currently runs in **sequence mode** with a causal mask. A step-wise cached mode can be added for streaming scenarios.
- Training uses a tiny **instruction-following set** plus synthetic memory tasks. You can plug in larger corpora or evaluation suites later.
- Gating is **vector-valued** with bias init favoring the vanilla path; metrics log mean gate values.
- Use `CFG.n_blocks` to grow the enrichment depth as VRAM allows.


In [None]:
# --- Optional: export environment specs (run in the dncformer conda env) ---
import shutil, subprocess, sys
def _run(cmd):
    try:
        print(">", cmd); subprocess.run(cmd, shell=True, check=True)
    except Exception as e:
        print("Command failed (may be expected on some setups):", e)

print("Exporting conda environment and pip freeze to current working directory...")
_run("conda env export --from-history > environment.yml")
_run("conda env export > environment.lock.yml")
_run("python -m pip list --format=freeze > requirements-pip.txt")
print("Done. If any commands failed, run them in your terminal inside the active conda env.")
