
# DNCformer: Parallel-Enrichment Transformer–DNC (Notebook Prototype)

Implements a **parallel enrichment** architecture that adds a **Transformer-style DNC block** alongside a standard Transformer block, with a learned **gating** to mix their outputs. It wires on top of a **frozen ~4B LLM** (Phi-3-mini-4k-instruct by default), and provides **lightweight train/eval** loops and **unit-like tests**.

**Hardware:** designed for a single GPU (e.g., RTX 3090 24GB) using AMP (`bf16` if available, otherwise `fp16`).  
**Structure:** Config → Utils → DNC Memory → Transformer Controller → DNCformer Block → Parallel Enrichment → Frozen Base + N Blocks → Data → Train → Eval → Tests.


## 0. Imports, config, and environment

#### 0.1 Imports

In [61]:
import sys, platform, torch
from dataclasses import dataclass
from typing import Optional, List, Tuple, Dict, Any
import os, math, random, time, numpy as np
from torch import nn, Tensor
from torch.nn import functional as F
import contextlib
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.optim.lr_scheduler import LambdaLR
from typing import List, Optional
import torch, gc


print("python exe:", sys.executable)
print("torch file:", torch.__file__)
print("torch ver:", torch.__version__, "| torch.version.cuda:", torch.version.cuda)
print("cuda available:", torch.cuda.is_available(), "| device count:", torch.cuda.device_count())


python exe: C:\Users\andre\miniconda3\envs\dncformer\python.exe
torch file: C:\Users\andre\miniconda3\envs\dncformer\lib\site-packages\torch\__init__.py
torch ver: 2.5.1 | torch.version.cuda: 12.6
cuda available: True | device count: 1


#### 0.2 Configuration

In [62]:
@dataclass
class Config:
    base_model_id: str = "microsoft/Phi-3-mini-4k-instruct"  # ~3.8B
    d_model: Optional[int] = None          # If None, infer from base model hidden size
    n_blocks: int = 2                      # number of parallel enrichment blocks
    attn_heads: int = 8                    # heads in DNC controller
    attn_dropout: float = 0.1              # attention layer dropout fraction
    ffn_mult: float = 4.0
    dnc_read_heads: int = 2                # number of DNC read heads
    dnc_cell_size: int = 64                # memory slot width
    dnc_nr_cells: int = 256                # number of memory slots
    gate_bias_init: float = -1.0           # bias to prefer transformer at init
    lr: float = 2e-4                       # learning rate
    weight_decay: float = 0.01
    max_seq_len: int = 1024                # training seq length
    train_steps: int = 200                 # small sanity pass
    warmup_steps: int = 20
    grad_clip: float = 1.0
    precision: str = "bf16"                # "bf16" | "fp16" | "fp32"
    use_torch_compile: bool = False
    device: str = "cuda"
    log_every: int = 10
    batch_size: int = 8
    seed: int = 42

CFG = Config()

device = torch.device(CFG.device if torch.cuda.is_available() else "cpu")
print("Device:", device, "CUDA:", torch.cuda.is_available())

amp_dtype = None
if CFG.precision == "bf16" and torch.cuda.is_available() and torch.cuda.is_bf16_supported():
    amp_dtype = torch.bfloat16
elif CFG.precision == "fp16" and torch.cuda.is_available():
    amp_dtype = torch.float16
else:
    amp_dtype = torch.float32

def cfg_to_json(cfg=None) -> str:
    cfg = cfg or CFG
    d = {}
    for k,v in getattr(cfg, "__dict__", {}).items():
        if not k.startswith("_"):
            try:
                json.dumps(v); d[k]=v
            except TypeError:
                d[k]=str(v)
    return json.dumps(d, indent=2)

def echo_cfg(to_console: bool = True, to_tb: bool = True, tag: str = "cfg/json"):
    s = cfg_to_json(CFG)
    if to_console:
        print("=== CFG ===")
        print(s)
    if to_tb and 'tb' in globals() and getattr(tb, "writer", None):
        with contextlib.suppress(Exception):
            tb.writer.add_text(tag, s, 0)



Device: cuda CUDA: True


#### 0.3 Config patch toggles

In [63]:

try:
    CFG
except NameError:
    class _Tmp: pass
    CFG = _Tmp()
# safe defaults if not already present
if not hasattr(CFG, 'batch_size'): CFG.batch_size = 8
if not hasattr(CFG, 'gate_reg_lambda'): CFG.gate_reg_lambda = 0.0   # only applied on memory-tagged batches
if not hasattr(CFG, 'hist_every'): CFG.hist_every = 200             # histogram cadence
if not hasattr(CFG, 'force_g'): CFG.force_g = None                  # None, or 0.0 or 1.0
if not hasattr(CFG, 'gate_temp'): CFG.gate_temp = 1.0


In [64]:
# CFG extensions for E10–E14
def extend_cfg_defaults():
    # E10: multi-memory experts (shared controller)
    setattr(CFG, "mem_experts", getattr(CFG, "mem_experts", 1))                   # 1 -> current behavior
    setattr(CFG, "expert_gate_temp", getattr(CFG, "expert_gate_temp", 1.0))
    setattr(CFG, "expert_diversity_lambda", getattr(CFG, "expert_diversity_lambda", 0.0))
    setattr(CFG, "expert_W", getattr(CFG, "expert_W", getattr(CFG, "W", 64)))
    setattr(CFG, "expert_R", getattr(CFG, "expert_R", getattr(CFG, "R", 1)))
    setattr(CFG, "expert_N", getattr(CFG, "expert_N", [getattr(CFG,"N",64), getattr(CFG,"N",64)]))

    # E11: per-block memory configs (list of dicts or None -> use global N/W/R)
    setattr(CFG, "per_block_cfg", getattr(CFG, "per_block_cfg", None))
    # Optional: per-block free gate bias (+ retains less, - retains more)
    setattr(CFG, "per_block_free_bias", getattr(CFG, "per_block_free_bias", None))  # e.g., [ +0.3, -0.2 ]

    # E12: read-to-attention fusion (light)
    setattr(CFG, "fusion_enable", getattr(CFG, "fusion_enable", False))
    setattr(CFG, "fusion_hidden_mult", getattr(CFG, "fusion_hidden_mult", 2.0))
    setattr(CFG, "fusion_drop", getattr(CFG, "fusion_drop", 0.0))
    setattr(CFG, "fusion_bias_queries", getattr(CFG, "fusion_bias_queries", False))  # keep False initially

    # E13: write sparsity / overlap regs (overlap wired later when keys are exposed)
    setattr(CFG, "write_reg_lambda", getattr(CFG, "write_reg_lambda", 0.0))
    setattr(CFG, "key_overlap_lambda", getattr(CFG, "key_overlap_lambda", 0.0))
    setattr(CFG, "key_overlap_window", getattr(CFG, "key_overlap_window", 1))
    setattr(CFG, "reg_only_on_memory_batches", getattr(CFG, "reg_only_on_memory_batches", True))

extend_cfg_defaults()
print("[CFG] E10–E14 flags added with safe defaults.")

[CFG] E10–E14 flags added with safe defaults.


#### 0.4 SDPA selection
- prefer PyTorch SDPA
- avoid flash-attn

In [65]:
def sdpa_ctx():
    """Return a fresh attention-kernel selection context each time it's called.
    Uses PyTorch SDPA (math + mem-efficient) and disables flash-attn to avoid warnings.
    """
    try:
        from torch.backends.cuda import sdp_kernel  # callable context manager in PyTorch 2.x
        return sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=True)
    except Exception:
        return contextlib.nullcontext()


#### 0.5 Determinism and seeds

In [66]:
import os, random, numpy as _np, torch, contextlib

def set_determinism(seed: int = 42, deterministic: bool = True, cudnn_benchmark: bool = False):
    """
    Set seeds across Python/NumPy/Torch and optionally toggle deterministic algorithms.
    deterministic=True may slow kernels; Warn_only to avoid hard errors
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    _np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # cuDNN settings
    torch.backends.cudnn.benchmark = bool(cudnn_benchmark)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        # Help cublas determinism for some kernels; warn only failure mode
        os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
        with contextlib.suppress(Exception):
            torch.use_deterministic_algorithms(True, warn_only=True)

def set_seed(seed: int = 42):
    """Convenience alias."""
    set_determinism(seed=seed, deterministic=False, cudnn_benchmark=False)

# set CFG seed
with contextlib.suppress(Exception):
    if hasattr(CFG, "seed"):
        set_seed(int(CFG.seed))

## 1. Utilities

#### 1.1 Model Information and factory

In [67]:
_GLOBAL = {}

def causal_mask(sz: int, device=None):
    return torch.full((sz, sz), float("-inf"), device=device).triu(1)

def count_params(model: nn.Module):
    return sum(p.numel() for p in model.parameters())

def requires_grad_(module: nn.Module, flag: bool):
    for p in module.parameters():
        p.requires_grad_(flag)
    return module

def print_shapes(**tensors):
    for k, v in tensors.items():
        if isinstance(v, (list, tuple)):
            print(k, [tuple(x.shape) for x in v])
        elif isinstance(v, torch.Tensor):
            print(k, tuple(v.shape))
        else:
            print(k, type(v))
            
def _mean_safely(seq):
    xs = [x for x in seq if isinstance(x, (int,float)) and not (x != x)]  # drop NaN
    return sum(xs)/len(xs) if xs else float("nan")


#### 1.2 Gate metrics utils (mean, frac>0.5, entropy)

In [68]:
def _reduce_gate_tensor(g: torch.Tensor) -> torch.Tensor:
    # g: (B,T,*) -> (B,T)
    if g.dim() == 3:
        return g.mean(dim=-1)
    return g

@torch.no_grad()
def _gate_metrics(g: torch.Tensor):
    """
    Returns scalar (mean, frac>0.5, entropy)
    """
    g2 = _reduce_gate_tensor(g.detach())
    mean_val = float(g2.mean().item())
    frac = float((g2 > 0.5).float().mean().item())
    eps = 1e-6
    p = g2.clamp(eps, 1 - eps)
    # binary entropy (natural log base)
    ent = float((-(p * (p + eps).log() + (1 - p) * (1 - p + eps).log())).mean().item())
    return mean_val, frac, ent

#### 1.3 Memory Freeing/handling

In [69]:
def cuda_report(tag=""):
    if not torch.cuda.is_available():
        print("[cuda_report] CUDA not available"); return
    free, total = torch.cuda.mem_get_info()
    alloc = torch.cuda.memory_allocated()
    reserv = torch.cuda.memory_reserved()
    print(f"[{tag}] alloc={alloc/1e9:.2f} GB | reserved={reserv/1e9:.2f} GB | free={free/1e9:.2f} GB | total={total/1e9:.2f} GB")

def free_head_and_cache():
    # delete typical globals and clear allocator
    with contextlib.suppress(Exception):
        del globals()['head']
    with contextlib.suppress(Exception):
        del globals()['tok']
    gc.collect(); torch.cuda.empty_cache()
    cuda_report("after free")

#### 1.4 Environment Export

In [70]:
# --- Optional: export environment specs (run in the dncformer conda env) ---
import shutil, subprocess, sys
def _run(cmd):
    try:
        print(">", cmd); subprocess.run(cmd, shell=True, check=True)
    except Exception as e:
        print("Command failed (may be expected on some setups):", e)
        
def export_current_env():
    print("Exporting conda environment and pip freeze to current working directory...")
    _run("conda env export --from-history > environment.yml")
    _run("conda env export > environment.lock.yml")
    _run("python -m pip list --format=freeze > requirements-pip.txt")
    print("Done. If any commands failed, run them in your terminal inside the active conda env.")

#### 1.5 Checkpoint helpers

In [71]:
import json, os, contextlib, torch
from pathlib import Path

def _cfg_to_dict(cfg) -> dict:
    out = {}
    for k, v in getattr(cfg, "__dict__", {}).items():
        if not k.startswith("_"):
            try:
                json.dumps(v)  # test serializability
                out[k] = v
            except TypeError:
                out[k] = str(v)
    return out

def save_head(head: torch.nn.Module, out_dir: str, cfg=None, run_label: str = None):
    """
    Saves only the trainable head parameters (state_dict) to out_dir/{run_label or 'head'}.pt,
    plus a metadata JSON with config and light model info.
    """
    Path(out_dir).mkdir(parents=True, exist_ok=True)
    tag = run_label or "head"
    ckpt_path = Path(out_dir) / f"{tag}.pt"
    meta_path = Path(out_dir) / f"{tag}.meta.json"

    # Only head params
    sd = head.state_dict()

    # Try to capture minimal head info
    info = {"type": head.__class__.__name__}
    with contextlib.suppress(Exception):
        info["blocks"] = len(getattr(head, "blocks", []))
    with contextlib.suppress(Exception):
        info["d_model"] = int(getattr(CFG, "d_model"))
    with contextlib.suppress(Exception):
        info["base_model_id"] = getattr(CFG, "base_model_id", None)

    meta = {"config": _cfg_to_dict(cfg or CFG), "head_info": info}
    torch.save(sd, ckpt_path)
    with open(meta_path, "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2)
    print(f"[save_head] wrote {ckpt_path} and {meta_path}")

def load_head(head: torch.nn.Module, in_path: str, strict: bool = False, map_location=None):
    """
    Loads a head checkpoint (state_dict) into an existing head instance.
    """
    in_path = str(in_path)
    sd = torch.load(in_path, map_location=map_location or "cpu")
    missing, unexpected = head.load_state_dict(sd, strict=strict)
    if missing or unexpected:
        print(f"[load_head] missing={missing} unexpected={unexpected}")
    else:
        print("[load_head] restored successfully.")
    return head

## 2. Models

#### 2.1: DNC Memory

In [72]:
class DNCMemory(nn.Module):
    """
    Compact DNC memory:
    - memory M: (B, N, W)
    - usage u: (B, N)
    - link L: (B, N, N) temporal links
    - precedence p: (B, N)
    - read weights rw: (B, R, N)
    - write weights ww: (B, N)
    - read vectors r: (B, R, W)
    """
    def __init__(self, nr_cells: int, cell_size: int, read_heads: int):
        super().__init__()
        self.probe = None  # optional callable to record per-step state
        self.N = nr_cells
        self.W = cell_size
        self.R = read_heads
        self.free_bias = getattr(self, "free_bias", 0.0)   # default to no bias

    def reset(self, B: int, device=None):
        device = device or next(self.parameters(), torch.empty(0, device="cpu")).device
        M = torch.zeros(B, self.N, self.W, device=device)
        u = torch.zeros(B, self.N, device=device)
        L = torch.zeros(B, self.N, self.N, device=device)
        p = torch.zeros(B, self.N, device=device)
        rw = F.one_hot(torch.zeros(B, self.R, dtype=torch.long, device=device), num_classes=self.N).float()
        r = torch.zeros(B, self.R, self.W, device=device)
        return {"M": M, "u": u, "L": L, "p": p, "rw": rw, "r": r}

    @staticmethod
    def _cosine_sim(M: torch.Tensor, k: torch.Tensor, eps=1e-6):
        # M: (B, N, W), k: (B, W) or (B, R, W)
        if k.dim() == 2:
            k = k.unsqueeze(1)  # (B,1,W)
        B, R, W = k.shape
        Mnorm = F.normalize(M, p=2, dim=-1)
        knorm = F.normalize(k, p=2, dim=-1)
        sim = torch.einsum("bnw,brw->brn", Mnorm, knorm)  # (B, R, N)
        return sim

    def _allocation(self, u: torch.Tensor):
        # u: (B, N) in [0,1]
        δ = 1e-6
        u = δ + (1 - δ) * u                 # avoid tiny values before cumprod
        B, N = u.shape
        # sort ascending usage -> free list φ
        sorted_u, phi = torch.sort(u, dim=-1, descending=False)  # (B,N)
        # exclusive cumprod of sorted_u
        ones = torch.ones(B, 1, device=u.device, dtype=u.dtype)
        prod_excl = torch.cumprod(torch.cat([ones, sorted_u], dim=1), dim=1)[:, :-1]  # (B,N)
        a_sorted = (1 - sorted_u) * prod_excl                                        # (B,N)
        # invert the sort to original order
        inv_phi = torch.argsort(phi, dim=-1)
        a = a_sorted.gather(1, inv_phi)                                              # (B,N)
        return a


    def forward(self, x_if: dict, state: dict):
        """
        x_if (interface dict) must contain:
        - k_read: (B,R,W), beta_read: (B,R,1)
        - k_write: (B,W), beta_write:(B,1)
        - erase: (B,W) in (0,1), write_vec: (B,W)
        - free_gates: (B,R,1) in (0,1), alloc_gate:(B,1), write_gate:(B,1) in (0,1)
        - read_mode: (B,R,3) softmax over {backward, content, forward}
        """
        M, u, L, p, rw, r = state["M"], state["u"], state["L"], state["p"], state["rw"], state["r"]
        B, N, W = M.shape

        # --- Usage update (faithful, stable) ---
        # previous write weights; keep as (B,1,N) for easy broadcasting
        ww_prev = state.get("ww", torch.zeros(M.size(0), 1, self.N, device=M.device, dtype=M.dtype))
        # writes increase usage
        u = u + (1 - u) * (1 - torch.prod(1 - ww_prev, dim=1))   # -> (B,N)
        # apply a clamped bias if config specifies (per-block or global)
        free_g = x_if["free_gates"]
        if getattr(self, "free_bias", 0.0) != 0.0:
            free_g = (free_g + self.free_bias).clamp(0.0, 0.1)
        # free gates release usage at read locations (per-location retention)
        psi = torch.prod(1 - free_g * rw, dim=1)     # (B,N), since free_gates:(B,R,1), rw:(B,R,N)
        u = torch.clamp(u * psi, 0, 1)


        # 2.1) Write weighting (robust broadcasting) ---
        # sim_w: (B,1,N) if k_write=(B,W); (B,R,N) if k_write=(B,R,W)
        sim_w = self._cosine_sim(M, x_if["k_write"])
        # beta_w: expect (B,1) or (B,R,1); make sure it has the trailing head axis
        beta_w = x_if["beta_write"]
        if beta_w.dim() == 2:           # (B,1) or (B,R) -> add trailing axis
            beta_w = beta_w.unsqueeze(-1)  # -> (B,1,1) or (B,R,1)
        # content weights over memory locations
        cw = F.softmax(sim_w * beta_w, dim=-1)  # (B,1,N) or (B,R,N)
        # canonical DNC: single write head; if multiple heads exist, reduce over heads
        if cw.size(1) > 1:
            cw = cw.mean(dim=1)                # -> (B,N)  (alternatives: sum or a learned reduce)
        else:
            cw = cw.squeeze(1)                 # -> (B,N)
        # allocation weights from usage (B,N)
        a = self._allocation(u)                # (B,N)
        # interpolate content vs allocation via alloc_gate, then apply write_gate
        alloc = x_if["alloc_gate"]             # (B,1)
        write_gate = x_if["write_gate"]        # (B,1)
        # Broadcast (B,1) over N
        ww = write_gate * (alloc * a + (1.0 - alloc) * cw)  # -> (B,N) via broadcasting
        state["ww"] = ww
        
        # 2.2) save current ww as "previous write weights" for t+1
        state["ww_prev"] = ww.unsqueeze(1)   # keep grads for BPTT; use .detach() only if you explicitly want to stop gradients across steps

        # 3) Memory write
        erase = x_if["erase"].unsqueeze(1)  # (B,1,W)
        write_vec = x_if["write_vec"].unsqueeze(1)  # (B,1,W)
        M = M * (1 - ww.unsqueeze(-1) * erase) + ww.unsqueeze(-1) * write_vec

        # 4) Temporal link
        prev_p = p
        p = (1 - ww.sum(dim=-1, keepdim=True)) * p + ww  # precedence
        L = (1 - ww.unsqueeze(2) - ww.unsqueeze(1)) * L + torch.einsum("bn,bm->bnm", prev_p, ww)
        L = L * (1 - torch.eye(N, device=M.device).unsqueeze(0))

        # 5) Read weighting
        cr = F.softmax(self._cosine_sim(M, x_if["k_read"]) * x_if["beta_read"], dim=-1)  # (B,R,N)
        fwd = torch.einsum("brn,bnm->brm", rw, L)       # (B,R,N) forward
        bwd = torch.einsum("brn,bmn->brm", rw, L)       # (B,R,N) backward
        read_mode = F.softmax(x_if["read_mode"], dim=-1)  # (B,R,3)
        rw = read_mode[:,:,0:1]*bwd + read_mode[:,:,1:2]*cr + read_mode[:,:,2:3]*fwd
        r = torch.einsum("brn,bnw->brw", rw, M)  # (B,R,W)

        state = {"M": M, "u": u, "L": L, "p": p, "rw": rw, "r": r}
        # aggregated lightweight stats (per-step)
        try:
            _stats = {}
            with torch.no_grad():
                _stats["u_mean"] = state["u"].mean().detach()
                try:
                    _stats["M_norm_mean"] = state["M"].norm(dim=-1).mean().detach()
                except Exception:
                    pass
                try:
                    _stats["rw_max_mean"] = rw.max(dim=-1).values.mean().detach()
                except Exception:
                    pass
                try:
                    _stats["ww_max_mean"] = ww.max(dim=-1).values.mean().detach()
                except Exception:
                    pass
            state["stats"] = _stats
        except Exception:
            pass
        
        if self.probe is not None:
            M = state["M"].detach().float().cpu()
            u = state["u"].detach().float().cpu()
            L = state["L"].detach().float().cpu()
            rw_cpu = rw.detach().float().cpu()
            ww_cpu = state.get("ww", torch.zeros_like(state["u"])).detach().float().cpu()
            self.probe(
                {
                "u": u,                         # (B,N)
                "ww": ww_cpu,                   # (B,N)
                "rw": rw_cpu,                   # (B,R,N)
                "M_norm": M.norm(dim=-1),       # (B,N)
                "L_diag_mean": torch.diagonal(L, dim1=-2, dim2=-1).mean(dim=-1), # (B,)
                }
            )
        
        try:
            # Mean write gate (scalar): shape-agnostic mean
            wg = x_if.get("write_gate", None)
            write_gate_mean = float(wg.mean().detach().item()) if wg is not None else float("nan")
            # Optional: if you have read vectors available as 'r' (B,R,W) or (B,T,W), expose their norm
            read_vecs = locals().get("r", None)  # or however your code names it
            read_norm = float(read_vecs.norm().detach().item()/max(1, read_vecs.numel())) if isinstance(read_vecs, torch.Tensor) else float("nan")
            # Stash for upstream (PEB/Head) to collect
            self._last_metrics = {"write_gate_mean": write_gate_mean, "read_vec_norm": read_norm}
        except Exception:
            self._last_metrics = {}
        
        return r, state


#### 2.2 Transformer-style Controller (sequence mode)

In [73]:
class TransformerController(nn.Module):
    """
    Lightweight Transformer encoder producing the DNC interface vector.
    Inputs: X (B, T, d_in); prev_reads typically concatenated to X before calling.
    """
    def __init__(self, d_in: int, d_model: int, heads: int, dropout: float=0.1, ffn_mult: float=4.0):
        super().__init__()
        self.proj_in = nn.Linear(d_in, d_model)
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, heads, dropout=dropout, batch_first=True)
        self.ln2 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, int(d_model*ffn_mult)),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(int(d_model*ffn_mult), d_model),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor]=None) -> torch.Tensor:
        h = self.proj_in(x)
        h = self.ln1(h)
        attn_out, _ = self.attn(h, h, h, attn_mask=attn_mask, need_weights=False)
        h = h + self.dropout(attn_out)
        h = self.ln2(h)
        h2 = self.ff(h)
        h = h + self.dropout(h2)
        return h

#### 2.3 DNCformer Block (controller → interface → memory)

In [74]:
class DNCInterfaceHead(nn.Module):
    """
    Projects controller hidden to DNC interface:
    read keys (R*W), read strengths (R), write key (W), write strength (1),
    erase (W in (0,1)), write vector (W), free_gates (R in (0,1)),
    alloc_gate (1), write_gate (1), read_mode (R*3 softmax).
    """
    def __init__(self, d_model: int, R: int, W: int):
        super().__init__()
        self.R, self.W = R, W
        out = R*W + R + W + 1 + W + W + R + 1 + 1 + R*3
        self.proj = nn.Linear(d_model, out)

    def forward(self, h: torch.Tensor):
        B, T, D = h.shape
        v = self.proj(h)  # (B,T,out)
        idx = 0
        def take(sz): 
            nonlocal idx
            part = v[..., idx:idx+sz]; idx += sz; return part
        R, W = self.R, self.W
        k_read = take(R*W).view(B,T,R,W)
        beta_read = F.softplus(take(R)).view(B,T,R,1)
        k_write = take(W).view(B,T,W)
        beta_write = F.softplus(take(1)).view(B,T,1)
        erase = torch.sigmoid(take(W)).view(B,T,W)
        write_vec = take(W).view(B,T,W)
        free_gates = torch.sigmoid(take(R)).view(B,T,R,1)
        alloc_gate = torch.sigmoid(take(1)).view(B,T,1)
        write_gate = torch.sigmoid(take(1)).view(B,T,1)
        read_mode = take(R*3).view(B,T,R,3)
        return {
            "k_read": k_read, "beta_read": beta_read,
            "k_write": k_write, "beta_write": beta_write,
            "erase": erase, "write_vec": write_vec,
            "free_gates": free_gates, "alloc_gate": alloc_gate,
            "write_gate": write_gate, "read_mode": read_mode
        }

class DNCformerBlock(nn.Module):
    def __init__(self, d_in: int, d_model: int, R: int, W: int, N: int,
                 heads: int, dropout: float, ffn_mult: float,
                 free_bias: float = 0.0):   # NEW: free_bias default
        super().__init__()
        self.R, self.W, self.N = R, W, N
        self.ctrl = TransformerController(d_in + R*W, d_model, heads=heads, dropout=dropout, ffn_mult=ffn_mult)
        self.if_head = DNCInterfaceHead(d_model, R=R, W=W)
        self.mem = DNCMemory(nr_cells=N, cell_size=W, read_heads=R)
        # Propagate free-bias down to memory (Patch B expects this attr)
        try:
            self.mem.free_bias = float(free_bias)
        except Exception:
            setattr(self.mem, "free_bias", float(free_bias))
        self.out_proj = nn.Linear(d_model + R*W, d_model)  # fuse controller + reads

        # scratch for metrics/fusion
        self.last_metrics = {}
        self.last_read_feat = None  # (B,T,W) pooled read vectors (per time)

    def forward(self, x: torch.Tensor, state: Optional[dict]=None):
        # x: (B,T,d_in); state carries memory fields; if None -> reset
        B, T, D = x.shape
        if state is None:
            state = self.mem.reset(B, device=x.device)

        reads = state["r"].reshape(B, self.R*self.W)  # (B,RW)
        reads_seq = reads.unsqueeze(1).expand(B, T, self.R*self.W)
        ctrl_in = torch.cat([x, reads_seq], dim=-1)  # concat
        h = self.ctrl(ctrl_in, attn_mask=causal_mask(T, device=x.device))  # (B,T,d_model)

        # step over time for memory I/O
        r_list = []
        new_state = state
        iface = self.if_head(h)
        for t in range(T):
            x_if = {k: v[:,t] for k,v in iface.items()}
            r_t, new_state = self.mem(x_if, new_state)  # r_t: (B,R,W)
            r_list.append(r_t)

        Rseq = torch.stack(r_list, dim=1)  # (B,T,R,W)
        reads_flat = Rseq.reshape(B,T,self.R*self.W)
        fused = torch.cat([h, reads_flat], dim=-1)
        y = self.out_proj(fused)  # (B,T,d_model)

        # --- metrics for fusion/regs/logging ---
        try:
            # pooled read features per time step (B,T,W) by averaging heads
            self.last_read_feat = Rseq.mean(dim=2).detach()  # safe feature for fusion
            # read norm (unitless average)
            rnorm = float(self.last_read_feat.norm().item() / max(1, self.last_read_feat.numel()))
        except Exception:
            self.last_read_feat = None
            rnorm = float("nan")

        # write gate mean (if exposed by DNCMemory last step)
        wg_mean = float("nan")
        lm = getattr(self.mem, "_last_metrics", None)
        if isinstance(lm, dict):
            wg_mean = float(lm.get("write_gate_mean", float("nan")))

        self.last_metrics = {"read_vec_norm": rnorm, "write_gate_mean": wg_mean}
        return y, new_state


#### 2.4 Parallel Enrichment Block (Transformer path ‖ DNCformer path + gating)

In [75]:
class VanillaTransformerBlock(nn.Module):
    def __init__(self, d_model: int, heads: int, dropout: float=0.1, ffn_mult: float=4.0):
        super().__init__()
        self.collect_metrics = False
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, heads, dropout=dropout, batch_first=True)
        self.ln2 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, int(d_model*ffn_mult)),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(int(d_model*ffn_mult), d_model),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor]=None, gate_override: None = None):
        h = self.ln1(x)
        a, _ = self.attn(h, h, h, attn_mask=attn_mask, need_weights=False)
        h = x + self.dropout(a)
        z = self.ln2(h)
        z2 = self.ff(z)
        return h + self.dropout(z2)

class ParallelEnrichmentBlock(nn.Module):
    def __init__(self, d_model: int, d_in: int, R: int, W: int, N: int,
                 heads: int = 4, dropout: float = 0.1, ffn_mult: float = 4.0,
                 block_index: int = 0,               # (existing)
                 gate_bias_init: float = -1.0):       # <<< NEW DEFAULT ADDED
        super().__init__()
        self.d_model = d_model
        self.block_index = block_index

        # Vanilla branch
        self.vanilla = VanillaTransformerBlock(d_model, heads=heads, dropout=dropout, ffn_mult=ffn_mult)

        # E11 per-block overrides
        N_, W_, R_ = N, W, R
        self.gate_temp = getattr(CFG, "gate_temp", 1.0)
        fbias = 0.0
        if isinstance(getattr(CFG, "per_block_cfg", None), (list, tuple)):
            if self.block_index < len(CFG.per_block_cfg) and CFG.per_block_cfg[self.block_index] is not None:
                blk = CFG.per_block_cfg[self.block_index]
                N_ = int(blk.get("N", N_)); W_ = int(blk.get("W", W_)); R_ = int(blk.get("R", R_))
                self.gate_temp = float(blk.get("gate_temp", self.gate_temp))
                fbias = float(blk.get("free_bias", 0.0))
        if getattr(CFG, "per_block_free_bias", None) and self.block_index < len(CFG.per_block_free_bias):
            fbias = float(CFG.per_block_free_bias[self.block_index])

        # E10 multi-experts (shared controller IO signature)
        K = int(getattr(CFG, "mem_experts", 1))
        self.mem_experts = K
        if K == 1:
            self.dncblocks = nn.ModuleList([DNCformerBlock(d_in=d_in, d_model=d_model, R=R_, W=W_, N=N_,
                                                           heads=heads, dropout=dropout, ffn_mult=ffn_mult,
                                                           free_bias=fbias)])
        else:
            Ns = getattr(CFG, "expert_N", [N_]*K)
            self.dncblocks = nn.ModuleList([
                DNCformerBlock(d_in=d_in, d_model=d_model, R=R_, W=getattr(CFG,"expert_W", W_), N=Ns[i],
                               heads=heads, dropout=dropout, ffn_mult=ffn_mult, free_bias=fbias)
                for i in range(K)
            ])

        # Gate: (vanilla + K experts)
        self.gate = nn.Linear((K+1) * d_model, (K+1))
        nn.init.constant_(self.gate.bias, float(gate_bias_init))

        # E12 fusion MLP
        self.fusion_enable = bool(getattr(CFG, "fusion_enable", False))
        if self.fusion_enable:
            fuse_in = d_model + W_   # concat [x, pooled reads]
            hidden = int(CFG.fusion_hidden_mult * d_model)
            self.fuse_ln = nn.LayerNorm(fuse_in)
            self.fuse_mlp = nn.Sequential(
                nn.Linear(fuse_in, hidden),
                nn.GELU(),
                nn.Dropout(getattr(CFG, "fusion_drop", 0.0)),
                nn.Linear(hidden, d_model),
            )

        self.dropout = nn.Dropout(dropout)
        self.pre_gate_ln = nn.LayerNorm((K+1) * d_model)
        self.last_metrics = {}

    def forward(self, x: torch.Tensor, dnc_state=None, gate_override: float = None):
        B, T, D = x.shape
        mask = causal_mask(T, device=x.device)

        # 1) vanilla path
        vt = self.vanilla(x, attn_mask=mask)

        # 2) memory experts
        states_in = dnc_state if isinstance(dnc_state, (list, tuple)) else [dnc_state]*self.mem_experts
        dts, states_out, per_mem_metrics = [], [], []
        for m, st in zip(self.dncblocks, states_in):
            dt, st2 = m(x, state=st)
            dts.append(dt); states_out.append(st2)
            per_mem_metrics.append(getattr(m, "last_metrics", {}) or {})

        # 3) fusion (real read features if available)
        if self.fusion_enable:
            r_feat = getattr(self.dncblocks[0], "last_read_feat", None)
            if r_feat is None:
                r_feat = torch.zeros(B,T,self.dncblocks[0].W, device=x.device, dtype=vt.dtype)
            fuse_in = torch.cat([x, r_feat], dim=-1)
            delta = self.fuse_mlp(self.fuse_ln(fuse_in))
            vt = vt + delta
            self.last_metrics["fusion_delta_norm"] = float(delta.norm().detach().item() / max(1, delta.numel()))

        # 4) gate mixture
        paths = [vt] + dts
        z = self.pre_gate_ln(torch.cat(paths, dim=-1))
        logits = self.gate(z) / max(1e-6, float(getattr(CFG, "expert_gate_temp", getattr(self, "gate_temp", 1.0))))
        if gate_override is not None:
            g_mem = float(gate_override)
            g_vec = torch.zeros_like(logits)
            g_vec[...,0] = 1.0 - g_mem
            if self.mem_experts > 0:
                g_vec[...,1:] = g_mem / float(self.mem_experts)
            pi = g_vec
        else:
            pi = torch.softmax(logits, dim=-1)

        g_mem_synth = pi[...,1:].sum(dim=-1, keepdim=True)
        out = sum(pi[...,i:i+1]*p for i,p in enumerate(paths))
        out = self.dropout(out)

        with torch.no_grad():
            pi_mean = pi.mean(dim=(0,1))     # (K+1,)
            H = - (pi * (pi.clamp_min(1e-9).log())).sum(dim=-1).mean().item()
            self.last_metrics.update({
                "experts_pi_mean": [float(v) for v in pi_mean.detach().cpu()],
                "experts_pi_entropy": float(H),
                "write_gate_mean": float(_mean_safely([m.get("write_gate_mean", float("nan")) for m in per_mem_metrics])),
            })
        return out, (states_out if self.mem_experts>1 else states_out[0]), g_mem_synth


#### 2.5. Frozen Base LLM + N Enrichment Blocks

In [76]:
def load_base_model(model_id: str):
    tok = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=(torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16) if amp_dtype!=torch.float32 else None,
        device_map=None,
        trust_remote_code=True,
        attn_implementation="sdpa",
    ).to("cuda")
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    requires_grad_(model, False)
    model.config.output_hidden_states = True
    model.config.use_cache = False
    return tok, model

class DNCFormerHead(nn.Module):
    def __init__(self, base: AutoModelForCausalLM, cfg):
        super().__init__()
        self.base = base
        d_model = base.config.hidden_size if cfg.d_model is None else cfg.d_model
        self.d_model = d_model
        self.blocks = nn.ModuleList([
            ParallelEnrichmentBlock(
                d_model=d_model, d_in=d_model,
                R=cfg.dnc_read_heads, W=cfg.dnc_cell_size, N=cfg.dnc_nr_cells,
                heads=cfg.attn_heads, dropout=cfg.attn_dropout,
                ffn_mult=cfg.ffn_mult, gate_bias_init=cfg.gate_bias_init
            ) for _ in range(cfg.n_blocks)
        ])
        self.proj_out = nn.Identity()

    def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor]=None):
        with torch.no_grad():
            with sdpa_ctx():
                out = self.base(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True, use_cache=False)
        h = out.hidden_states[-1]  # (B,T,d_model)
        dnc_states = [None] * len(self.blocks)
        gates = []
        for i, blk in enumerate(self.blocks):
            st_in = dnc_states[i]
            # Defensive wrap: if block has K>1 experts, ensure list state shape
            K = int(getattr(blk, "mem_experts", 1))
            if K > 1 and not isinstance(st_in, (list, tuple)):
                st_in = [st_in] * K
            h, dnc_states[i], g = blk(h, dnc_state=st_in)
            gates.append(g.detach())



    def forward_with_metrics(self, input_ids: torch.Tensor, attention_mask: "Optional[torch.Tensor]" = None,
                             gate_override: "Optional[float]" = None):
        with torch.no_grad():
            with sdpa_ctx():
                out = self.base(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True, use_cache=False)
        h = out.hidden_states[-1].to(device)  # ensure CUDA
        dnc_states = [None] * len(self.blocks)
        gates_det, gates_raw, per_block = [], [], []

        # enable metrics collection per block (if supported)
        for blk in self.blocks:
            if hasattr(blk, "collect_metrics"):
                blk.collect_metrics = True

        for i, blk in enumerate(self.blocks):
            st_in = dnc_states[i]
            K = int(getattr(blk, "mem_experts", 1))
            if K > 1 and not isinstance(st_in, (list, tuple)):
                st_in = [st_in] * K
            h, dnc_states[i], g = blk(h, dnc_state=st_in, gate_override=gate_override)
            gates_raw.append(g)
            gates_det.append(g.detach())
            per_block.append(getattr(blk, "last_metrics", {}) or {})

            if hasattr(blk, "collect_metrics"):
                blk.collect_metrics = False

        # lm head on its own device, then back to CUDA
        lm_dev = self.base.lm_head.weight.device
        y = self.proj_out(h).to(lm_dev, dtype=self.base.lm_head.weight.dtype)
        logits = self.base.lm_head(y).to(device)

        # --- aux dict (add alias "blocks" for compatibility) ---
        aux = {"per_block": per_block, "blocks": per_block, "gates_raw": gates_raw, "gates_detached": gates_det}
        aux["g_entropy_block"] = []
        for g in gates_det:
            _, _, ent = _gate_metrics(g)
            aux["g_entropy_block"].append(ent)
        return logits, gates_det, aux



## 3. Data Generation

#### 3.1 Synthetic tasks + simple instruction-following

In [77]:
def make_copy_task(batch, T, vocab=50):
    x = torch.randint(5, vocab, (batch, T))
    return x

def make_reverse_task(batch, T, vocab=50):
    x = torch.randint(5, vocab, (batch, T))
    y = torch.flip(x, dims=[1])
    return x, y

def make_needle_task(batch, T, needle_len=5, vocab=100):
    x = torch.randint(5, vocab, (batch, T))
    for b in range(batch):
        start = random.randint(0, T-needle_len-5)
        needle = torch.randint(5, vocab, (needle_len,))
        x[b, start:start+needle_len] = needle
        x[b, -1] = needle[0]
    return x

INSTR_PAIRS = [
    ("Reverse the string: abcd", "dcba"),
    ("Add two numbers: 7 + 12", "19"),
    ("Instruction: say hello", "hello"),
    ("Uppercase this: cat", "CAT"),
]

def tokenize_instruction_pairs(tok, pairs, max_len):
    texts = [f"### Instruction:\n{p}\n\n### Response:\n" for p,_ in pairs]
    labels = [ans for _,ans in pairs]
    input_ids = tok(texts, return_tensors="pt", padding=True, truncation=True, max_length=max_len).input_ids
    label_ids = tok(labels, return_tensors="pt", padding=True, truncation=True, max_length=max_len).input_ids
    # naive pack: just use inputs; real packing/labels can be elaborated
    return input_ids, label_ids

def make_repeat_copy(batch: int, T: int, repeat_min=2, repeat_max=4, vocab=100, pad_id: int = 0, device: str = "cpu") -> torch.Tensor:
    L = max(1, T // 2)
    x = torch.randint(1, vocab, (batch, L), device=device, dtype=torch.long)
    r = torch.randint(repeat_min, repeat_max + 1, (batch,), device=device)
    out = torch.full((batch, T), pad_id, dtype=torch.long, device=device)
    for i in range(batch):
        seq = x[i].repeat_interleave(int(r[i].item()))
        out[i, :min(T, seq.numel())] = seq[:T]
    return out

def make_n_back(batch: int, T: int, n: int = 3, vocab=50) -> torch.Tensor:
    return torch.randint(1, vocab, (batch, T))

def format_instruction(tok, instr: str, resp: str, max_len=256) -> torch.Tensor:
    prompt = f"### Instruction:\n{instr}\n\n### Response:\n{resp}"
    return tok(prompt, return_tensors="pt", truncation=True, max_length=max_len).input_ids[0]


#### 3.2 HF dataset integration

In [78]:
def hf_instruction_loader(dataset_name="tatsu-lab/alpaca", split="train", text_field=("instruction","output"), max_items=5000):
    try:
        from datasets import load_dataset
    except Exception:
        print("Install 'datasets' to enable HF loading: pip install datasets -q")
        return []
    ds = load_dataset(dataset_name, split=split)
    pairs = []
    i_field, o_field = text_field
    for ex in ds:
        instr = ex.get(i_field, ""); out = ex.get(o_field, "")
        if instr and out: pairs.append((instr, out))
        if len(pairs) >= max_items: break
    random.shuffle(pairs); return pairs

def make_hf_batch(tok, pairs: List[Tuple[str,str]], batch: int, max_len=256) -> torch.Tensor:
    if not pairs:
        return torch.full((batch, max_len), tok.pad_token_id, dtype=torch.long)
    batch_ids = []
    for _ in range(batch):
        instr, out = random.choice(pairs)
        ids = format_instruction(tok, instr, out, max_len=max_len)
        batch_ids.append(ids)
    maxL = min(max(x.size(0) for x in batch_ids), max_len)
    out_ids = torch.full((batch, maxL), tok.pad_token_id, dtype=torch.long)
    for i, ids in enumerate(batch_ids):
        ids = ids[:maxL]; out_ids[i, :ids.size(0)] = ids
    return out_ids


#### 3.3 Haystack batch creation

In [79]:
# --- Haystack (needle) eval: long-span retrieval of a key's value ---
def make_haystack_batch(batch: int, T: int = 256, vocab: int = 1024, sentinel: int = 3):
    """
    Build sequences: ... K V ... K ?  (next token should be V)
    Returns: input_ids [B,T], answer_ids [B], query_pos [B] (position of '?')
    """
    assert T >= 12, "T too small for haystack layout"
    x = torch.randint(5, vocab, (batch, T), dtype=torch.long)
    K = torch.randint(5, vocab, (batch,), dtype=torch.long)
    V = torch.randint(5, vocab, (batch,), dtype=torch.long)

    # Place (K,V) in the first half
    p1 = torch.randint(low=T//8, high=T//2 - 2, size=(batch,))
    x[torch.arange(batch), p1] = K
    x[torch.arange(batch), p1 + 1] = V

    # Place (K, sentinel) in the last quarter
    p2 = torch.randint(low=3*T//4, high=T - 2, size=(batch,))
    x[torch.arange(batch), p2] = K
    x[torch.arange(batch), p2 + 1] = sentinel  # '?'
    query_pos = p2 + 1  # position of '?'; we will look at logits at this position

    return x, V, query_pos

#### 3.4 MixtureSampler

In [80]:
class MixtureSampler:
    def __init__(self, gens: List, weights: List[float], names: Optional[List[str]] = None):
        self.gens = gens
        import torch as _t
        self.weights = list(map(float, weights))
        self.p = _t.tensor(self.weights, dtype=_t.float32)  # CPU is fine for multinomial
        self.p /= (self.p.sum() + 1e-8)
        self.names = names if names is not None else [f"g{i}" for i in range(len(gens))]
        self.last_name = None

    def __call__(self, batch: int) -> torch.Tensor:
        import torch as _t
        idx = _t.multinomial(self.p, 1).item()
        self.last_name = self.names[idx]
        return self.gens[idx](batch)
    
    def set_weights(self, weights):
        """Update mixture probabilities at runtime (used by schedules)."""
        import torch as _t
        ws = list(map(float, weights))
        t = _t.tensor(ws, dtype=_t.float32, device=self.p.device)
        s = float(t.sum().item())
        if s <= 0:
            raise ValueError("Mixture weights must sum to > 0")
        self.p = t / s
        self.weights = ws
        # Optional sanity: warn if names length mismatches weights
        if hasattr(self, "names") and len(self.names) != len(ws):
            print(f"[MixtureSampler] Warning: len(names)={len(self.names)} != len(weights)={len(ws)}")

#### 3.5 Mixture Builder

In [81]:
def _build_mixer(tok, weights, hf_dataset="tatsu-lab/alpaca", hf_max_items=2000) -> MixtureSampler:
    """
    Build a mixture of generators: [HF, copy, repeat, nback].
    - If hf_dataset is Alpaca-like (has instruction/output), we use hf_instruction_loader + make_hf_batch.
    - Else we treat it as text-only (e.g., roneneldan/TinyStories), tokenizing and making windowed sequences.
    - If HF fails/empty, we drop it and renormalize over synthetics.
    """
    mx = int(getattr(CFG, "max_seq_len", 256))
    pad_id = getattr(tok, "pad_token_id", 0) or 0

    gens, wts, names = [], [], []

    hf_ok = False
    hf_reason = ""
    gen_hf = None

    if hf_dataset:
        try:
            # Heuristic: use instruction loader for alpaca-like IDs
            if "alpaca" in hf_dataset.lower():
                pairs = hf_instruction_loader(hf_dataset, "train", ("instruction", "output"),
                                              max_items=hf_max_items)
                if pairs:
                    def gen_hf(b): 
                        return make_hf_batch(tok, pairs, b, max_len=mx)
                    hf_ok = True
                else:
                    hf_reason = "hf_instruction_loader returned 0 pairs"
            else:
                # Text-only fallback (e.g., roneneldan/TinyStories)
                from datasets import load_dataset
                # Try streaming first, then non-streaming
                try:
                    ds = load_dataset(hf_dataset, split="train", streaming=True)
                except Exception:
                    ds = load_dataset(hf_dataset, split="train")

                # Find a usable text key
                text_key = None
                common_keys = ("text", "content", "story", "document", "body", "article")

                feats = getattr(ds, "features", None)
                if feats:
                    for k in common_keys:
                        if k in feats:
                            text_key = k
                            break

                if text_key is None:
                    # Probe first example (works for streaming iterable)
                    try:
                        first_ex = next(iter(ds))
                        for k in common_keys:
                            if k in first_ex:
                                text_key = k
                                break
                    except StopIteration:
                        pass

                if text_key is None:
                    hf_reason = "no usable text field (tried: %s)" % ",".join(common_keys)
                else:
                    import random as _rnd
                    import torch
                    samples = []
                    for ex in ds:
                        txt = ex.get(text_key, None)
                        if not txt:
                            continue
                        ids = tok(txt, return_tensors="pt", add_special_tokens=False).input_ids.squeeze(0)
                        if ids.numel() < 8:
                            continue
                        samples.append(ids.cpu())
                        if len(samples) >= int(hf_max_items):
                            break

                    if len(samples) == 0:
                        hf_reason = "collected 0 tokenized samples"
                    else:
                        def gen_hf(b: int) -> torch.Tensor:
                            out = []
                            for _ in range(b):
                                ids = _rnd.choice(samples)
                                n = ids.numel()
                                if n >= mx:
                                    s = _rnd.randint(0, n - mx)
                                    seq = ids[s:s+mx]
                                else:
                                    pad = torch.full((mx - n,), pad_id, dtype=torch.long)
                                    seq = torch.cat([ids, pad], dim=0)
                                out.append(seq.unsqueeze(0))
                            return torch.cat(out, dim=0)
                        hf_ok = True
        except Exception as e:
            hf_reason = f"{type(e).__name__}: {e}"

    # Assemble mixture in the agreed order (hf, copy, repeat, nback)
    if hf_ok and gen_hf is not None:
        gens.append(gen_hf); wts.append(weights[0]); names.append("hf")
        s_w = list(weights[1:])
    else:
        if hf_dataset:
            print(f"HF dataset unavailable or empty; using synthetic only. reason={hf_reason}")
        s_w = list(weights[1:])  # keep caller's relative weights for synthetics

    # Synthetic tasks (your existing closures)
    def gen_copy(b):   return make_copy_task(b, T=min(mx, 128), vocab=100)
    def gen_repeat(b): return make_repeat_copy(b, T=min(mx, 128), vocab=100, pad_id=pad_id, device="cpu")
    def gen_nback(b):  return make_n_back(b, T=min(mx, 128), n=5, vocab=50)

    gens.extend([gen_copy, gen_repeat, gen_nback])
    wts.extend(s_w)
    names.extend(["copy", "repeat", "nback"])

    return MixtureSampler(gens, wts, names=names)


## 4 Logging utilities

#### 4.1 TensorBoard Logger

In [82]:
import os, time, json, re

try:
    from torch.utils.tensorboard import SummaryWriter
    TB_AVAILABLE = True
except Exception as e:
    print("TensorBoard not available:", e)
    TB_AVAILABLE = False

class TBLogger:
    def __init__(self, logdir: Optional[str] = None, run_name: Optional[str] = None):
        self.enabled = TB_AVAILABLE
        self.writer = None
        if not self.enabled:
            return
        logdir = logdir or "./runs"
        run_name = run_name or time.strftime("dncformer-%Y%m%d-%H%M%S")
        self.path = os.path.join(logdir, run_name)
        os.makedirs(self.path, exist_ok=True)
        self.writer = SummaryWriter(self.path)

    def log_scalars(self, step: int, loss: float, lr: float, gate_means: List[float]):
        if not (self.enabled and self.writer): return
        self.writer.add_scalar("train/loss", loss, step)
        self.writer.add_scalar("train/lr", lr, step)
        for i, gm in enumerate(gate_means):
            self.writer.add_scalar(f"gates/block_{i}_mean", gm, step)

    def add_image_hw(self, tag: str, img_hw: "torch.Tensor", step: int):
        if not (self.enabled and self.writer): return
        import torch
        x = img_hw
        if x.device.type != "cpu": x = x.cpu()
        x = x.float()
        if x.numel() > 0:
            m, M = x.min(), x.max()
            x = (x - m) / (M - m + 1e-8)
        x = x.unsqueeze(0)  # [1,H,W]
        self.writer.add_image(tag, x, step, dataformats='CHW')

    def add_histogram(self, tag: str, values: "torch.Tensor", step: int, bins: int = 50):
        if not (self.enabled and self.writer): return
        import torch
        v = values
        if isinstance(v, torch.Tensor):
            v = v.detach()
            if v.device.type != "cpu": v = v.cpu()
            v = v.reshape(-1).float()
        self.writer.add_histogram(tag, v, global_step=step, bins=bins)

    def add_text(self, tag: str, text: str, step: int):
        if not (self.enabled and self.writer): return
        self.writer.add_text(tag, text, step)

    def flush(self):
        if self.enabled and self.writer: self.writer.flush()

    def close(self):
        if self.enabled and self.writer: self.writer.close()
        
        
def start_tb_run(label: str = None, logdir: str = "./runs"):
    """Close any existing TB writer and open a fresh run dir with timestamp + optional label."""
    if not TB_AVAILABLE:
        print("TensorBoard not available; skipping start_tb_run.")
        return False
    global tb
    # Close/flush a previous writer if present
    try:
        if 'tb' in globals() and isinstance(tb, TBLogger) and getattr(tb, "writer", None):
            tb.flush(); tb.close()
    except Exception:
        pass

    ts = time.strftime("dncformer-%Y%m%d-%H%M%S")
    run_name = ts
    if label:
        safe = re.sub(r"[^A-Za-z0-9_.-]+", "_", str(label))
        run_name = f"{ts}-{safe}"

    tb = TBLogger(logdir=logdir, run_name=run_name)
    tb.add_text("run/label", str(label or "unlabeled"), 0)
    print("TB run started:", getattr(tb, "path", None))
    return True

# TODO: wire this into start_tb_run above at some stage
# try:
#     _orig_start_tb_run = start_tb_run
#     def start_tb_run(label: str = None):
#         tb_obj = _orig_start_tb_run(label)
#         echo_cfg(to_console=False, to_tb=True, tag="cfg/json")
#         return tb_obj
# except Exception:
#     pass


#### 4.2 Memory tracer & TensorBoard memory visualizer

In [83]:
from contextlib import contextmanager

class MemoryTracer:
    def __init__(self): self.frames = []
    def __call__(self, frame): self.frames.append(frame)
    def stack(self, key, bidx: int = 0):
        import torch
        return torch.stack([f[key][bidx] for f in self.frames], dim=0)  # [T, ...]

@contextmanager
def trace_memory(module):
    tracer = MemoryTracer()
    memories = [m for m in module.modules() if m.__class__.__name__ == "DNCMemory"]
    for m in memories: m.probe = tracer
    try:
        yield tracer
    finally:
        for m in memories: m.probe = None

@torch.no_grad()
def visualize_memory_tb(head, tok, writer, global_step: int, prompt="### Instruction:\nRemember A then B; later return A.\n\n### Response:\n", max_T=64):
    if writer is None: return
    enc = tok(prompt, return_tensors="pt").to(device)
    enc.input_ids = enc.input_ids[:, :max_T]
    from torch.amp import autocast
    with trace_memory(head) as tracer:
        with autocast('cuda', dtype=amp_dtype, enabled=(amp_dtype!=torch.float32)):
            _ = head(enc.input_ids)

    u = tracer.stack("u")           # [T, N]
    Mnorm = tracer.stack("M_norm")  # [T, N]
    rw = tracer.stack("rw")         # [T, R, N]

    # Log images
    writer.add_image("memory/u_TxN", (u.T).unsqueeze(0), global_step, dataformats='CHW')
    writer.add_image("memory/Mnorm_TxN", (Mnorm.T).unsqueeze(0), global_step, dataformats='CHW')

    top_read = rw.argmax(dim=-1).float()  # [T,R]
    top_read_img = (top_read / max(1, rw.size(-1)-1)).T  # [R,T]
    writer.add_image("memory/top_read_RxT", top_read_img.unsqueeze(0), global_step, dataformats='CHW')


## 5. Training

#### 5.1 Schedulers

In [84]:
# --- LR Scheduler: linear warmup -> cosine decay (nonzero start) ---

def make_warmup_cosine_scheduler(optimizer, warmup_steps: int, total_steps: int, min_lr_ratio: float = 0.10):
    """
    Warms up linearly from 0->1 over warmup_steps; cosine decays from 1->min_lr_ratio for the remainder.
    Uses step_idx+1 to avoid zero LR at the start.
    """
    warmup_steps = max(1, int(warmup_steps))
    total_steps = max(warmup_steps + 1, int(total_steps))

    def lr_lambda(step_idx: int):
        s = step_idx + 1
        if s <= warmup_steps:
            return s / float(warmup_steps)
        progress = (s - warmup_steps) / float(max(1, total_steps - warmup_steps))
        return max(min_lr_ratio, 0.5 * (1.0 + math.cos(math.pi * progress)))

    return LambdaLR(optimizer, lr_lambda)


#### 5.2 Training Utilities

In [85]:
from torch.amp import GradScaler, autocast

def build_model_and_tokenizer():
    tok, base = load_base_model(CFG.base_model_id)
    if CFG.d_model is None:
        CFG.d_model = base.config.hidden_size
    head = DNCFormerHead(base, CFG).to(device)
    if CFG.use_torch_compile and hasattr(torch, 'compile'):
        head = torch.compile(head)
    #print("Trainable params in head:", count_params(head))
    return tok, head

def make_optimizer(model):
    params = [p for p in model.parameters() if p.requires_grad]
    return torch.optim.AdamW(params, lr=CFG.lr, weight_decay=CFG.weight_decay)

def lm_shift_labels(input_ids, logits, tok):
    labels = input_ids[:, 1:].contiguous()
    logits = logits[:, :-1].contiguous()
    return F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1), ignore_index=tok.pad_token_id)

#### 5.3 Experimental training loop, stage 1 experiments

In [86]:
from torch.amp import GradScaler, autocast

def train_experiment(
    steps: int = None,
    batch_size: int = None,
    warmup_steps: int = None,
    min_lr_ratio: float = 0.1,
    mixture_weights=(0.4, 0.2, 0.2, 0.2),
    hf_dataset: str = "tatsu-lab/alpaca",
    hf_max_items: int = 5000,
    log_every: int = None,
    viz_memory_after: bool = False,
    viz_prompt: str = "### Instruction:\nSay hello in one word\n\n### Response:\n",
    viz_max_T: int = 64,
    # NEW: optional schedules (each a list of (until_step, value))
    mixture_schedule=None,         # e.g., [(100, (0.3,0.3,0.2,0.2)), (None, (0.4,0.2,0.2,0.2))]
    gate_temp_schedule=None,       # e.g., [(100, 0.8), (None, 1.0)]
    gate_reg_schedule=None,        # e.g., [(100, 3e-4), (None, 2e-4)]
):
    cfg = CFG
    steps = int(steps or cfg.train_steps)
    batch_size = int(batch_size or getattr(cfg, "batch_size", 8))
    warmup_steps = int(warmup_steps if warmup_steps is not None else getattr(cfg, "warmup_steps", max(10, steps // 20)))
    log_every = int(log_every if log_every is not None else getattr(cfg, "log_every", 10))

    # Model & tokenizer
    tok, head = build_model_and_tokenizer()
    optim = make_optimizer(head)
    scheduler = make_warmup_cosine_scheduler(optim, warmup_steps, steps, min_lr_ratio=min_lr_ratio)

    # Sampler (and allow schedules to change weights)
    mixer = _build_mixer(tok, mixture_weights, hf_dataset=hf_dataset, hf_max_items=hf_max_items)

    def _apply_schedules(step: int, mix_name_hint=None):
        # mixture schedule
        if mixture_schedule:
            for until, ws in mixture_schedule:
                if until is None or step <= int(until):
                    try:
                        mixer.set_weights(ws)
                    except Exception:
                        pass
                    break
        # gate temp schedule
        if gate_temp_schedule:
            for until, temp in gate_temp_schedule:
                if until is None or step <= int(until):
                    setattr(CFG, "gate_temp", float(temp))
                    break
        # gate reg schedule
        if gate_reg_schedule:
            for until, lam in gate_reg_schedule:
                if until is None or step <= int(until):
                    setattr(CFG, "gate_reg_lambda", float(lam))
                    break

    # TB setup
    global tb
    if TB_AVAILABLE:
        try:
            tb  # NameError if not defined
        except NameError:
            tb = TBLogger(logdir="./runs")
        if not isinstance(tb, TBLogger) or getattr(tb, "writer", None) is None:
            tb = TBLogger(logdir="./runs")
        tblog = True
        tb.add_text("run/config", json.dumps({
            "steps": steps,
            "batch_size": batch_size,
            "warmup_steps": warmup_steps,
            "mixture_weights": list(mixture_weights),
            "mixture_schedule": mixture_schedule,
            "gate_temp_schedule": gate_temp_schedule,
            "gate_reg_schedule": gate_reg_schedule,
        }, indent=2), 0)
    else:
        tblog = False

    head.train()
    scaler = torch.amp.GradScaler('cuda', enabled=(amp_dtype in (torch.float16, torch.bfloat16)))

    for step in range(1, steps + 1):
        _apply_schedules(step)

        in_ids = mixer(batch_size).to(device)
        with torch.autocast('cuda', dtype=amp_dtype, enabled=(amp_dtype != torch.float32)):
            logits, gates, aux = head.forward_with_metrics(in_ids, gate_override=getattr(CFG, "force_g", None))
            loss = lm_shift_labels(in_ids, logits, tok)
            
            # Optional: encourage expert usage diversity (higher entropy across (vanilla + K experts))
            div_lam = float(getattr(CFG, "expert_diversity_lambda", 0.0))
            if div_lam > 0.0 and isinstance(aux, dict) and "per_block" in aux:
                try:
                    ent_vals = []
                    for m in aux["per_block"]:
                        ent = float(m.get("experts_pi_entropy", float("nan")))
                        if not math.isnan(ent):
                            ent_vals.append(ent)
                    if ent_vals:
                        loss = loss - div_lam * (sum(ent_vals) / len(ent_vals))
                except Exception as _e:
                    # keep robust, don’t crash training if a metric is missing
                    pass

            # Optional: mild gate-usage regularizer
            lam = float(getattr(CFG, "gate_reg_lambda", 0.0))
            if lam > 0 and isinstance(gates, (list, tuple)) and len(gates) > 0:
                reg = 0.0
                for g in gates:
                    g2 = _reduce_gate_tensor(g)
                    # Encourage decisive routing on memory batches only? For now, uniform
                    reg = reg + (g2.mean() * 0.0 + (g2 * (1 - g2)).mean())  # small entropy-like penalty
                loss = loss + lam * reg
            
            lam_w = float(getattr(CFG, "write_reg_lambda", 0.0))
            if lam_w > 0.0 and isinstance(aux, dict) and "blocks" in aux:
                # Only apply on memory-tagged batches if flag is on
                apply_reg = True
                if bool(getattr(CFG, "reg_only_on_memory_batches", True)):
                    # crude heuristic: if mixture sampler last batch name contains 'copy'/'repeat'/'nback'
                    bn = getattr(mixer, "last_name", "")
                    apply_reg = any(tk in bn for tk in ("copy","repeat","nback"))
                if apply_reg:
                    w_means = [b.get("write_gate_mean") for b in aux["blocks"] if isinstance(b, dict)]
                    w_means = [float(x) for x in w_means if isinstance(x, (float,int))]
                    if w_means:
                        loss = loss + lam_w * (sum(w_means)/len(w_means))

        use_scaler = (amp_dtype in (torch.float16, torch.bfloat16))
        if use_scaler:
            scaler.scale(loss).backward()
            torch.nn.utils.clip_grad_norm_(head.parameters(), CFG.grad_clip)
            scaler.step(optim); scaler.update(); optim.zero_grad(set_to_none=True)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(head.parameters(), CFG.grad_clip)
            optim.step(); optim.zero_grad(set_to_none=True)

        scheduler.step()

        # --- Logging (TB) ---
        if step % log_every == 0 and tblog:
            # global scalars
            tb.writer.add_scalar("train/loss", float(loss.item()), step)
            tb.writer.add_scalar("train/lr", float(scheduler.get_last_lr()[0]), step)

            # gate metrics per block (global)
            if isinstance(gates, (list, tuple)):
                for bi, g in enumerate(gates):
                    g_mean, g_frac, g_ent = _gate_metrics(g)
                    tb.writer.add_scalar(f"gates/block_{bi}_mean", g_mean, step)
                    tb.writer.add_scalar(f"gates/block_{bi}_frac>0.5", g_frac, step)
                    tb.writer.add_scalar(f"gates/block_{bi}_entropy", g_ent, step)

                # per-task metrics (using the sampler's last batch type)
                mix_name = getattr(mixer, "last_name", None) or "unknown"
                tb.writer.add_scalar(f"loss_by_task/{mix_name}", float(loss.item()), step)
                for bi, g in enumerate(gates):
                    g_mean, g_frac, g_ent = _gate_metrics(g)
                    tb.writer.add_scalar(f"gates_by_task/block_{bi}_mean/{mix_name}", g_mean, step)
                    tb.writer.add_scalar(f"gates_by_task/block_{bi}_frac>0.5/{mix_name}", g_frac, step)

                # optional: quartiles (block 0)
                if len(gates) > 0:
                    g0 = _reduce_gate_tensor(gates[0].detach())
                    T = g0.size(1); q = max(1, T // 4)
                    slices = [(0, q), (q, 2*q), (2*q, 3*q), (3*q, T)]
                    for qi, (s, e) in enumerate(slices, start=1):
                        tb.writer.add_scalar(f"gates/block0_q{qi}_mean/{mix_name}", float(g0[:, s:e].mean().item()), step)
                
                 # E10: experts distributions
                for bi, b in enumerate(aux.get("blocks", [])):
                    if isinstance(b, dict) and "experts_pi_mean" in b:
                        pi_mean = b["experts_pi_mean"]  # list of length K+1
                        for j, v in enumerate(pi_mean):
                            tb.writer.add_scalar(f"experts/block_{bi}/pi_mean_{j}", float(v), step)
                    if isinstance(b, dict) and "experts_pi_entropy" in b:
                        tb.writer.add_scalar(f"experts/block_{bi}/pi_entropy", float(b["experts_pi_entropy"]), step)
            
                    # E12: fusion delta norm
                    if isinstance(b, dict) and "fusion_delta_norm" in b:
                        tb.writer.add_scalar(f"fusion/block_{bi}/delta_norm", float(b["fusion_delta_norm"]), step)
            
                    # E13: write sparsity signal we used
                    if isinstance(b, dict) and "write_gate_mean" in b:
                        tb.writer.add_scalar(f"reg/block_{bi}/write_gate_mean", float(b["write_gate_mean"]), step)
            
                # (optional) print current mixture weights if schedule is active
                if "mixture_schedule" in locals() or "mixture_schedule" in globals():
                    ws = getattr(mixer, "weights", None)
                    if ws:
                        tb.writer.add_text("schedule/mixture_weights", str(ws), step)
            
                # keep console echo compact (existing print is fine); optionally add top‑1 expert
                if isinstance(aux.get("blocks"), list) and aux["blocks"]:
                    b0 = aux["blocks"][0]
                    if "experts_pi_mean" in b0:
                        top = int(max(range(len(b0["experts_pi_mean"])), key=lambda k: b0["experts_pi_mean"][k]))
                        print(f"  [experts] block0 top={top} pi={b0['experts_pi_mean']}")
                        
            # Expert routing diagnostics (if available from per_block metrics)
            try:
                if isinstance(aux, dict) and "per_block" in aux:
                    for bi, m in enumerate(aux["per_block"]):
                        ent = m.get("experts_pi_entropy", None)
                        if ent is not None:
                            tb.writer.add_scalar(f"experts/block_{bi}_pi_entropy", float(ent), step)
                        pi_mean = m.get("experts_pi_mean", None)  # list length K+1 (vanilla + experts)
                        if isinstance(pi_mean, (list, tuple)) and len(pi_mean) > 0:
                            for j, v in enumerate(pi_mean):
                                tb.writer.add_scalar(f"experts/block_{bi}_pi_mean/path_{j}", float(v), step)
                        # optional: write activity
                        if "write_gate_mean" in m:
                            tb.writer.add_scalar(f"memory/block_{bi}_write_gate_mean", float(m["write_gate_mean"]), step)
            except Exception as _e:
                pass  # keep logging robust

                
            tb.flush()

        # --- Console echo (always) ---
        if step % log_every == 0:
            g_means_print = []
            if isinstance(gates, (list, tuple)):
                for g in gates:
                    g_means_print.append(float(_reduce_gate_tensor(g).mean().item()))
            print(f"step {step} | loss {loss.item():.4f} | lr {scheduler.get_last_lr()[0]:.2e} | "
                  f"gates={g_means_print} | mix={getattr(mixer,'last_name','?')}")

    # Optional memory viz
    if viz_memory_after:
        try:
            visualize_memory_tb(head, tok, tb.writer, global_step=steps, prompt=viz_prompt, max_T=viz_max_T)
        except Exception as e:
            print("Memory TB viz skipped:", e)
    
    # flush log      
    if tblog:
        tb.flush()

    return head, tok


#### 5.4 Robust forward for ParallelEnrichmentBlock
TODO: this is a monkeypatch, non-critical but fold in properly when I have time

In [87]:
# Rebind PEB.forward with dtype-harmonization (safe override)
import torch

def _peb_forward_multi(self, x: torch.Tensor, dnc_state=None, gate_override: float = None):
    B, T, D = x.shape
    mask = causal_mask(T, device=x.device)

    # --- Decide a target dtype for the vanilla path: use ln1.weight.dtype
    dtype_v = self.vanilla.ln1.weight.dtype
    x_v = x.to(dtype_v) if x.dtype != dtype_v else x

    # 1) Vanilla path in its own dtype
    vt = self.vanilla(x_v, attn_mask=mask)  # (B,T,D), dtype=dtype_v

    # Optional back-compat alias for mem_experts==1
    if getattr(self, 'mem_experts', 1) == 1 and not hasattr(self, 'dncblock') and hasattr(self, 'dncblocks'):
        try: self.dncblock = self.dncblocks[0]
        except Exception: pass

    # 2) Memory experts: run in their native dtype, then cast to dtype_v
    K = int(getattr(self, 'mem_experts', 1))
    states_in = dnc_state if isinstance(dnc_state, (list, tuple)) else [dnc_state] * K
    dts, states_out, per_mem_metrics = [], [], []
    last_read_feat = None  # (B,T,W) if provided by DNCformerBlock

    for m, st in zip(self.dncblocks, states_in):
        dt, st2 = m(x, state=st)          # compute with original x (native path)
        if dt.dtype != dtype_v:
            dt = dt.to(dtype_v)
        dts.append(dt); states_out.append(st2)
        pm = getattr(m, "last_metrics", {}) or {}
        per_mem_metrics.append(pm)
        # try to harvest a read feature for fusion if the block exposes one
        if last_read_feat is None:
            lf = getattr(m, "last_read_feat", None)
            if lf is not None:
                last_read_feat = lf

    # 3) Optional fusion: use read-hint; align fusion input with fuse_ln dtype
    if getattr(self, 'fusion_enable', False):
        # pick a feature: provided by block or zeros as no-op
        if last_read_feat is None:
            W = getattr(self.dncblocks[0], "W", self.d_model)  # fallback
            last_read_feat = torch.zeros(B, T, W, device=x.device, dtype=dtype_v)
        else:
            last_read_feat = last_read_feat.to(dtype_v)

        # fuse LayerNorm/MLP likely float32; match their parameter dtype
        dtype_f = self.fuse_ln.weight.dtype
        fuse_in = torch.cat([x_v, last_read_feat], dim=-1)
        fuse_in = fuse_in.to(dtype_f) if fuse_in.dtype != dtype_f else fuse_in
        delta = self.fuse_mlp(self.fuse_ln(fuse_in))
        if delta.dtype != vt.dtype:
            delta = delta.to(vt.dtype)
        vt = vt + delta
        self.last_metrics["fusion_delta_norm"] = float(delta.norm().detach().item() / max(1, delta.numel()))

    # 4) Gate over (vanilla + K experts)
    paths = [vt] + dts                             # all in dtype_v
    z = torch.cat(paths, dim=-1)                   # (B,T,(K+1)*D), dtype=dtype_v

    # gate linear likely float32: feed it in its dtype and bring logits back to float32
    dtype_g = self.gate.weight.dtype
    z_g = z.to(dtype_g) if z.dtype != dtype_g else z
    temp = float(getattr(CFG, "expert_gate_temp", getattr(self, "gate_temp", 1.0)))
    logits = self.gate(z_g) / max(1e-6, temp)     # float32

    if gate_override is not None:
        g_mem = float(gate_override)
        pi = torch.zeros_like(logits)             # (B,T,K+1), float32
        pi[..., 0] = 1.0 - g_mem
        if K > 0:
            pi[..., 1:] = g_mem / float(K)
    else:
        pi = torch.softmax(logits, dim=-1)        # float32

    # Synthesize a single memory gate (back-compat); cast pi to dtype_v for the weighted sum
    pi_v = pi.to(dtype_v) if pi.dtype != dtype_v else pi
    g_mem_synth = pi_v[..., 1:].sum(dim=-1, keepdim=True)  # (B,T,1), dtype_v

    out = sum(pi_v[..., i:i+1] * p for i, p in enumerate(paths))  # dtype_v
    out = self.dropout(out)

    # 5) Metrics
    with torch.no_grad():
        pi_mean = pi.mean(dim=(0, 1))
        H = - (pi * (pi.clamp_min(1e-9).log())).sum(dim=-1).mean().item()
        def _safe_mean(vals):
            v = [float(x) for x in vals if isinstance(x, (int, float)) and x == x]
            return float(sum(v) / max(1, len(v))) if v else float('nan')
        wg = _safe_mean([m.get("write_gate_mean", float('nan')) for m in per_mem_metrics])
        self.last_metrics.update({
            "experts_pi_mean": [float(v) for v in pi_mean.detach().cpu()],
            "experts_pi_entropy": float(H),
            "write_gate_mean": wg,
        })

    return out, (states_out if K > 1 else states_out[0]), g_mem_synth

ParallelEnrichmentBlock.forward = _peb_forward_multi
print("Patched ParallelEnrichmentBlock.forward -> multi-expert aware (dncblocks) + dtype-harmonization")

Patched ParallelEnrichmentBlock.forward -> multi-expert aware (dncblocks) + dtype-harmonization


#### 5.5 Guardrails warnings (for __debug__ mode)

In [88]:
import warnings, numbers, torch

def _normalize_weights(ws):
    ws = list(map(float, ws))
    s = sum(ws)
    if s <= 0:
        warnings.warn("[guard] mixture_weights sum<=0; normalizing to uniform.")
        n = max(1, len(ws)); return [1.0/n]*n
    return [w/(s+1e-8) for w in ws]

def _check_schedule(name, sched):
    if sched is None: return True
    if not isinstance(sched, (list, tuple)):
        warnings.warn(f"[guard] {name} must be list[(until,value)]"); return False
    ok = True
    for i, it in enumerate(sched):
        if not (isinstance(it, (list, tuple)) and len(it)==2):
            warnings.warn(f"[guard] {name}[{i}] not (until,value)"); ok=False; continue
        u,v = it
        if u is not None and not isinstance(u, numbers.Number):
            warnings.warn(f"[guard] {name}[{i}].until not a number"); ok=False
    return ok

# Wrap train_experiment
try:
    if 'train_experiment' in globals() and not getattr(train_experiment, "__guarded__", False):
        _orig_train_experiment = train_experiment
        def train_experiment(*args, **kwargs):
            if __debug__:
                if "mixture_weights" in kwargs:
                    kwargs["mixture_weights"] = _normalize_weights(kwargs["mixture_weights"])
                for nm in ("mixture_schedule","gate_temp_schedule","gate_reg_schedule"):
                    _check_schedule(nm, kwargs.get(nm))
                # Optional: enforce positive gate_temp if present in CFG
                if hasattr(CFG, "gate_temp") and CFG.gate_temp <= 0:
                    warnings.warn("[guard] CFG.gate_temp <= 0; routing may degenerate.")
            return _orig_train_experiment(*args, **kwargs)
        train_experiment.__guarded__ = True
except Exception as e:
    print("[guard] train_experiment wrapper skipped:", e)

# Wrap DNCFormerHead.forward
try:
    if 'DNCFormerHead' in globals() and not getattr(DNCFormerHead, "__guarded__", False):
        _orig_forward = DNCFormerHead.forward
        def forward(self, input_ids: torch.Tensor, attention_mask=None):
            if __debug__:
                if not isinstance(input_ids, torch.Tensor) or input_ids.dtype != torch.long:
                    warnings.warn(f"[guard] input_ids dtype should be torch.long (got {getattr(input_ids,'dtype',None)})")
            return _orig_forward(self, input_ids, attention_mask=attention_mask)
        DNCFormerHead.forward = forward
        DNCFormerHead.__guarded__ = True
except Exception as e:
    print("[guard] DNCFormerHead wrapper skipped:", e)

## 6. Eval harnesses

#### 6.1 copy, reverse, needle-in-haystack

In [89]:
from torch.amp import autocast

@torch.no_grad()
def eval_copy(head, tok, batch=4, T=64, vocab=100):
    x = make_copy_task(batch, T, vocab=vocab).to(device)
    with autocast('cuda', dtype=amp_dtype, enabled=(amp_dtype!=torch.float32)):
        logits, gates = head(x)
    preds = logits.argmax(dim=-1)
    acc = (preds[:, :-1] == x[:, 1:]).float().mean().item()
    print("copy acc:", acc, "gates:", [g.mean().item() for g in gates])
    return {"acc": acc, "gates": [g.mean().item() for g in gates]}

@torch.no_grad()
def eval_reverse(head, tok, batch=4, T=64, vocab=100):
    x, y = make_reverse_task(batch, T, vocab=vocab)
    x, y = x.to(device), y.to(device)
    with autocast('cuda', dtype=amp_dtype, enabled=(amp_dtype!=torch.float32)):
        logits, gates = head(x)
    preds = logits.argmax(dim=-1)
    acc = (preds[:, :-1] == y[:, 1:]).float().mean().item()
    print("reverse acc:", acc, "gates:", [g.mean().item() for g in gates])
    return {"acc": acc, "gates": [g.mean().item() for g in gates]}

@torch.no_grad() # WARNING - planned for depreciation
def eval_needle(head, tok, batch=4, T=128, vocab=200):
    x = make_needle_task(batch, T, vocab=vocab).to(device)
    with autocast('cuda', dtype=amp_dtype, enabled=(amp_dtype!=torch.float32)):
        logits, gates = head(x)
    preds = logits.argmax(dim=-1)
    acc = (preds[:, -1] == x[:, -1]).float().mean().item()
    print("needle acc:", acc, "gates:", [g.mean().item() for g in gates])
    return {"acc": acc, "gates": [g.mean().item() for g in gates]}

@torch.no_grad()
def evaluate_haystack(head, steps: int = 50, batch: int = 16, T: int = 256, vocab: int = 1024,
                      tb_step: int = None, fast: bool = False):
    """
    Long-span retrieval probe: ... K V ... K ? => predict V at '?'.
    fast=True uses inference_mode() and smaller defaults to speed up sweeps.
    """
    from torch.amp import autocast
    head.eval()

    # Fast path defaults (can still be overridden by explicit args)
    if fast:
        steps = min(steps, 10)
        batch = min(batch, 8)
        T = min(T, 128)

    device_ = next(head.parameters()).device
    use_amp = (amp_dtype in (torch.float16, torch.bfloat16)) and torch.cuda.is_available()

    accs, losses = [], []
    ctx = torch.inference_mode() if fast else torch.no_grad()
    with ctx:
        for _ in range(steps):
            x, V, qpos = make_haystack_batch(batch, T=T, vocab=vocab)
            x = x.to(device_); V = V.to(device_); qpos = qpos.to(device_)

            if use_amp:
                with autocast(device_type="cuda", dtype=amp_dtype):
                    logits, _g = head(x)
            else:
                logits, _g = head(x)

            idx = torch.arange(x.size(0), device=device_)
            logits_q = logits[idx, qpos, :].float()
            loss = F.cross_entropy(logits_q, V)
            pred = logits_q.argmax(dim=-1)

            accs.append((pred == V).float().mean().item())
            losses.append(loss.item())

    acc_m = float(np.mean(accs)); loss_m = float(np.mean(losses))

    if TB_AVAILABLE and ('tb' in globals()):
        tb.writer.add_scalar("eval/haystack_acc", acc_m, tb_step if tb_step is not None else 0)
        tb.writer.add_scalar("eval/haystack_loss", loss_m, tb_step if tb_step is not None else 0)
        tb.flush()

    head.train()
    print(f"[Haystack] acc={acc_m:.3f} | loss={loss_m:.3f} | fast={fast}")
    return acc_m, loss_m

## 7. Unit tests, smoke tests

#### 7.1 basic unit test, model integrity

In [90]:
def run_basic_unit_tests():
    B, T = 2, 4
    R, W, N = CFG.dnc_read_heads, CFG.dnc_cell_size, CFG.dnc_nr_cells
    d_in = d_model = 128

    mem = DNCMemory(N, W, R).to(device)
    state = mem.reset(B, device=device)
    iface = {
        "k_read": torch.zeros(B,R,W, device=device),
        "beta_read": torch.ones(B,R,1, device=device),
        "k_write": torch.zeros(B,W, device=device),
        "beta_write": torch.ones(B,1, device=device),
        "erase": torch.sigmoid(torch.randn(B,W, device=device)),
        "write_vec": torch.randn(B,W, device=device),
        "free_gates": torch.sigmoid(torch.randn(B,R,1, device=device)),
        "alloc_gate": torch.sigmoid(torch.randn(B,1, device=device)),
        "write_gate": torch.sigmoid(torch.randn(B,1, device=device)),
        "read_mode": torch.randn(B,R,3, device=device),
    }
    r, state2 = mem(iface, state)
    assert r.shape == (B,R,W)

    ctrl = TransformerController(d_in+R*W, d_model, heads=4).to(device)
    x = torch.randn(B,T,d_in, device=device)
    reads = state2["r"].reshape(B,R*W).unsqueeze(1).expand(B,T,R*W)
    h = ctrl(torch.cat([x, reads], dim=-1))
    assert h.shape == (B,T,d_model)

    dblk = DNCformerBlock(d_in, d_model, R, W, N, heads=4, dropout=0.1, ffn_mult=4.0).to(device)
    y, s = dblk(x)
    assert y.shape == (B,T,d_model)

    pen = ParallelEnrichmentBlock(d_model, d_in, R, W, N, heads=4, dropout=0.1, ffn_mult=4.0, gate_bias_init=-1.0).to(device)
    out, s2, g = pen(torch.randn(B,T,d_model, device=device))
    print(f"out: {out}\ns2: {s2}\ng: {g}")
    eps = 1e-6
    assert torch.isfinite(g).all()
    assert ((g > eps) & (g < 1 - eps)).float().mean().item() > 0.95
    assert out.shape == (B,T,d_model)
    print("All unit-like tests passed.")


In [34]:
#run_basic_unit_tests()

#### 7.2 Evaluator unit tests

In [35]:
import torch
from typing import Optional

class _DummyHead(torch.nn.Module):
    def __init__(self, vocab=100, d_model=64, n_blocks=2):
        super().__init__()
        self.vocab = vocab
        self.d_model = d_model
        self.n_blocks = n_blocks
    def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor]=None):
        B, T = input_ids.shape
        logits = torch.randn(B, T, self.vocab, device=input_ids.device, dtype=torch.float32)
        gates = [torch.sigmoid(torch.randn(B, T, self.d_model, device=input_ids.device)) for _ in range(self.n_blocks)]
        return logits, gates

def run_eval_unit_tests():
    dummy = _DummyHead().to(device).eval()
    res_copy = eval_copy(dummy, tok=None, batch=2, T=8, vocab=50)
    res_rev = eval_reverse(dummy, tok=None, batch=2, T=8, vocab=50)
    res_needle = eval_needle(dummy, tok=None, batch=2, T=16, vocab=100)
    assert all(k in res_copy for k in ["acc","gates"])
    assert all(k in res_rev for k in ["acc","gates"])
    assert all(k in res_needle for k in ["acc","gates"])
    print("Evaluator unit tests passed.")


In [101]:
#run_eval_unit_tests()

#### 7.3 GPU VRAM diagnostics + cuda allocation smoke test

In [102]:
import torch, gc, contextlib

def list_head_refs():
    # looks for globals named 'head' and count of DNCFormerHead instances
    import gc, inspect, sys
    heads = [o for o in gc.get_objects() if o.__class__.__name__ == "DNCFormerHead"]
    print(f"[liveness] DNCFormerHead instances alive: {len(heads)}")
    if 'head' in globals():
        h = globals()['head']
        try:
            devs = sorted({p.device.type for p in h.parameters()})
        except Exception:
            devs = ["<unknown>"]
        print(f"[liveness] global 'head' present; param devices: {devs}")
    else:
        print("[liveness] no global 'head'")
        
from transformers import AutoModelForCausalLM, AutoTokenizer
gc.collect(); torch.cuda.empty_cache(); cuda_report("pre base-only")
tok = AutoTokenizer.from_pretrained(CFG.base_model_id, trust_remote_code=True)
base = AutoModelForCausalLM.from_pretrained(
    CFG.base_model_id,
    torch_dtype=(torch.bfloat16 if (amp_dtype!=torch.float32 and torch.cuda.is_bf16_supported()) else (torch.float16 if amp_dtype!=torch.float32 else None)),
    attn_implementation="sdpa",
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    use_safetensors=True,
).to("cuda")
cuda_report("after base-only")
del base, tok; gc.collect(); torch.cuda.empty_cache()


[pre base-only] alloc=9.62 GB | reserved=9.64 GB | free=14.73 GB | total=25.77 GB


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

[after base-only] alloc=17.26 GB | reserved=17.29 GB | free=7.09 GB | total=25.77 GB


#### 7.4 TB logger test

In [36]:
tb = TBLogger(logdir="./runs")
try:
    tb
except NameError:
    tb = TBLogger(logdir="./runs")
print("TB active:", TB_AVAILABLE, "| logdir:", getattr(tb, "path", None))

TB active: True | logdir: ./runs\dncformer-20250823-013407


#### 7.5 Memory tracer smoke test

In [104]:
def run_memory_tracer_smoke(head, tok):
    if not TB_AVAILABLE:
        print("TB not available; skipping image smoke.")
        return
    from torch.amp import autocast
    with trace_memory(head) as tracer:
        x = torch.randint(5, (1, 16), device=device)
        with autocast('cuda', dtype=amp_dtype, enabled=(amp_dtype!=torch.float32)):
            _ = head(x)
    assert len(tracer.frames) > 0, "No memory frames captured"
    print("Tracer captured", len(tracer.frames), "steps")


In [105]:
#run_memory_tracer_smoke()

#### 7.6 Data generator smoke test

In [106]:

mx = int(getattr(CFG, "max_seq_len", 256))
tok, _base = tok if 'tok' in globals() else (None, None)
_pad = getattr(tok, "pad_token_id", 0) if tok is not None else 0

x = make_repeat_copy(batch=3, T=min(mx, 128), vocab=50, pad_id=_pad)
print("repeat_copy batch shape:", x.shape, "| dtype:", x.dtype)  # expect (3, <=128), long
assert x.dtype == torch.long


repeat_copy batch shape: torch.Size([3, 128]) | dtype: torch.int64


#### 7.7 Mixture sampler smoke test

In [107]:
ms = MixtureSampler(gens=[lambda b: None, lambda b: None, lambda b: None, lambda b: None],
                    weights=[0.4,0.2,0.2,0.2],
                    names=["hf","copy","repeat","nback"])
print("p0:", ms.p.tolist())    # ~[0.4,0.2,0.2,0.2]
ms.set_weights([0.3,0.3,0.25,0.15])
print("p1:", ms.p.tolist())    # ~[0.3,0.3,0.25,0.15]

p0: [0.4000000059604645, 0.20000000298023224, 0.20000000298023224, 0.20000000298023224]
p1: [0.30000001192092896, 0.30000001192092896, 0.25, 0.15000000596046448]


#### 7.8 Build mixer smoke test

In [108]:
# 1) With Alpaca (instruction/output)
tok, _ = build_model_and_tokenizer()
m1 = _build_mixer(tok, (0.4,0.2,0.2,0.2), hf_dataset="tatsu-lab/alpaca", hf_max_items=500)
print("names:", m1.names, "| p:", m1.p.tolist())  # expect 'hf' present

# 2) With TinyStories (text)
m2 = _build_mixer(tok, (0.4,0.2,0.2,0.2), hf_dataset="roneneldan/TinyStories", hf_max_items=500)
print("names:", m2.names, "| p:", m2.p.tolist())  # expect 'hf' present (text path)


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

names: ['hf', 'copy', 'repeat', 'nback'] | p: [0.4000000059604645, 0.20000000298023224, 0.20000000298023224, 0.20000000298023224]
names: ['hf', 'copy', 'repeat', 'nback'] | p: [0.4000000059604645, 0.20000000298023224, 0.20000000298023224, 0.20000000298023224]


#### 7.9 haystack smoke test

In [109]:
# tok, head = build_model_and_tokenizer()
# acc, loss = evaluate_haystack(head, steps=2, batch=4, T=64, vocab=512, tb_step=0)
# print("Haystack smoke:", acc, loss)

#### 7.10 Training run Sanity test

In [110]:
free_head_and_cache()
cuda_report("snapshot: before train_experiment")
head, tok = train_experiment(steps=10, warmup_steps=10, mixture_weights=(0.4,0.2,0.2,0.2))
cuda_report("snapshot: after train_experiment")
#Launch TensorBoard in a terminal: tensorboard --logdir ./runs


[after free] alloc=10.61 GB | reserved=10.63 GB | free=13.74 GB | total=25.77 GB
[snapshot: before train_experiment] alloc=10.61 GB | reserved=10.63 GB | free=13.74 GB | total=25.77 GB


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

  self.gen = func(*args, **kwds)


  [experts] block0 top=0 pi=[0.9081056118011475, 0.09189442545175552]
step 10 | loss 10.4021 | lr 2.00e-05 | gates=[0.09189443290233612, 0.10174798965454102] | mix=copy
[snapshot: after train_experiment] alloc=20.25 GB | reserved=33.38 GB | free=0.00 GB | total=25.77 GB


#### 7.11 Integrity panel
- shapes check
- dtyptes check
- forward check

In [111]:
import torch, contextlib, json

def count_params(m: torch.nn.Module, trainable_only=True):
    return sum(p.numel() for p in m.parameters() if (p.requires_grad or not trainable_only))

def integrity_panel(tok=None, head=None, T: int = 8):
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"[integrity] device={dev} | amp_dtype={globals().get('amp_dtype', torch.float32)}")
    print(f"[integrity] base_model_id={getattr(CFG, 'base_model_id', '?')}")
    for k in ("d_model", "N", "W", "R", "num_blocks"):
        if hasattr(CFG, k):
            print(f"[integrity] CFG.{k}={getattr(CFG,k)}")

    if head is not None:
        print(f"[integrity] head params (trainable): {count_params(head):,}")

    # SDPA context visibility
    with contextlib.suppress(Exception):
        print(f"[integrity] SDPA_CTX={type(globals().get('SDPA_CTX', None)).__name__}")

    # Tiny forward
    try:
        if tok is None or head is None:
            print("[integrity] (skip tiny forward; missing tok/head)")
            return
        tok.pad_token = tok.pad_token or tok.eos_token
        dummy = tok("hello world", return_tensors="pt")
        x = dummy.input_ids.to(dev)
        with torch.no_grad():
            with torch.autocast(device_type="cuda", dtype=globals().get("amp_dtype", torch.float32),
                                enabled=(globals().get("amp_dtype", torch.float32) != torch.float32)):
                logits, gates = head(x)
        print(f"[integrity] forward ok | logits={tuple(logits.shape)} | gates={[tuple(g.shape) for g in gates]}")
        # Basic gate sanity
        for i, g in enumerate(gates):
            g_mean = float(torch.sigmoid(g).mean().item()) if g.dtype.is_floating_point else float(g.float().mean().item())
            print(f"[integrity] gate[{i}] mean≈{g_mean:.3f}")
    except Exception as e:
        print("[integrity] forward failed:", e)

#### 7.12 Seed, build integrity, and save/load smoke tests

In [112]:
# 1) Seed determinism invoked
set_seed(getattr(CFG, "seed", 42))

In [113]:
# 2) Build & integrity
tok, head = build_model_and_tokenizer()
integrity_panel(tok, head, T=8)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

[integrity] device=cuda | amp_dtype=torch.bfloat16
[integrity] base_model_id=microsoft/Phi-3-mini-4k-instruct
[integrity] CFG.d_model=3072
[integrity] head params (trainable): 494,574,238
[integrity] SDPA_CTX=NoneType
[integrity] forward failed: cannot unpack non-iterable NoneType object


In [114]:
# 3) Save/load head (no-op training resume test)
save_head(head, "./checkpoints", cfg=CFG, run_label="smoke")
_ = load_head(head, "./checkpoints/smoke.pt", strict=False)

[save_head] wrote checkpoints\smoke.pt and checkpoints\smoke.meta.json


  sd = torch.load(in_path, map_location=map_location or "cpu")


[load_head] restored successfully.


#### 7.13 Parallel / fused memory blocks

In [37]:
def _assert_close(a, b, eps=1e-5, msg=""):
    assert abs(float(a) - float(b)) <= eps, msg or f"{a} vs {b}"

@torch.no_grad()
def smoke_parallel_block_basic():
    print("[smoke] ParallelEnrichmentBlock basic (K=1, fusion=False)")
    B,T,D = 2, 5, 32
    R,W,N = 1, 8, 16
    x = torch.randn(B,T,D, device=device)
    CFG.mem_experts = 1
    CFG.fusion_enable = False
    CFG.per_block_cfg = None
    blk = ParallelEnrichmentBlock(d_model=D, d_in=D, R=R, W=W, N=N, heads=2, dropout=0.0, ffn_mult=2.0).to(device)
    y, st, g = blk(x, dnc_state=None)
    assert y.shape == (B,T,D)
    assert isinstance(st, dict) or st is None
    assert g.shape == (B,T,1)
    assert (g>=0).all() and (g<=1).all()
    print("  ok.")

@torch.no_grad()
def smoke_parallel_block_multi_experts():
    print("[smoke] Multi-experts (K=2), fusion off")
    B,T,D = 2, 5, 32
    R,W,N = 1, 8, 16
    x = torch.randn(B,T,D, device=device)
    CFG.mem_experts = 2
    CFG.expert_N = [N, N]
    CFG.expert_W = W
    CFG.expert_R = R
    CFG.fusion_enable = False
    CFG.per_block_cfg = None
    blk = ParallelEnrichmentBlock(d_model=D, d_in=D, R=R, W=W, N=N, heads=2, dropout=0.0, ffn_mult=2.0).to(device)
    y, st, g = blk(x, dnc_state=None)
    assert y.shape == (B,T,D)
    assert isinstance(st, list) and len(st)==2, "state should be list(len=K) when mem_experts>1"
    assert g.shape == (B,T,1) and (g>=0).all() and (g<=1).all()
    # experts metrics
    pm = blk.last_metrics.get("experts_pi_mean", None)
    assert isinstance(pm, list) and len(pm)==3, "experts_pi_mean should have K+1 entries (vanilla+experts)"
    print("  ok.")

@torch.no_grad()
def smoke_fusion_path():
    print("[smoke] Fusion path uses real read features if available")
    B,T,D = 2, 6, 32
    R,W,N = 1, 8, 16
    x = torch.randn(B,T,D, device=device)
    CFG.mem_experts = 1
    CFG.fusion_enable = True
    CFG.fusion_hidden_mult = 2.0
    CFG.fusion_drop = 0.0
    blk = ParallelEnrichmentBlock(d_model=D, d_in=D, R=R, W=W, N=N, heads=2, dropout=0.0, ffn_mult=2.0).to(device)
    # first pass (collects read feats inside DNCformerBlock)
    y1, st, g = blk(x, dnc_state=None)
    # second pass to check delta norm presence
    y2, st2, g2 = blk(x, dnc_state=st)
    dn = blk.last_metrics.get("fusion_delta_norm", None)
    assert dn is not None, "fusion_delta_norm should be logged"
    assert y2.shape == (B,T,D) and g2.shape==(B,T,1)
    print("  ok.")

@torch.no_grad()
def smoke_per_block_override():
    print("[smoke] per_block_cfg override (E11)")
    B,T,D = 2, 4, 32
    R,W,N = 1, 8, 16
    CFG.mem_experts = 1
    CFG.fusion_enable = False
    # Two different blocks to test overrides
    CFG.per_block_cfg = [
        {"N": 24, "W": 8,  "R": 1, "gate_temp": 1.0, "free_bias": +0.2},
        {"N": 16, "W": 16, "R": 1, "gate_temp": 0.8, "free_bias": -0.2},
    ]
    blk0 = ParallelEnrichmentBlock(d_model=D, d_in=D, R=R, W=W, N=N, heads=2, dropout=0.0, ffn_mult=2.0, block_index=0).to(device)
    blk1 = ParallelEnrichmentBlock(d_model=D, d_in=D, R=R, W=W, N=N, heads=2, dropout=0.0, ffn_mult=2.0, block_index=1).to(device)
    # check effective N/W on internal DNC
    assert blk0.dncblocks[0].N == 24 and blk0.dncblocks[0].W == 8
    assert blk1.dncblocks[0].N == 16 and blk1.dncblocks[0].W == 16
    # forward for sanity
    x = torch.randn(B,T,D, device=device)
    y0, st0, g0 = blk0(x, dnc_state=None)
    y1, st1, g1 = blk1(x, dnc_state=None)
    assert y0.shape == (B,T,D) and y1.shape==(B,T,D)
    print("  ok.")

def run_patch_smoke_tests():
    smoke_parallel_block_basic()
    smoke_parallel_block_multi_experts()
    smoke_fusion_path()
    smoke_per_block_override()
    print("[smoke] all tests passed.")

import types
import torch
import torch.nn as nn

# Fallback sdpa context if not present
try:
    sdpa_ctx  # noqa: F401
except NameError:
    import contextlib
    def sdpa_ctx():
        return contextlib.nullcontext()

# Fallback device if not present
try:
    device  # noqa: F401
except NameError:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class TinyBaseLM(nn.Module):
    """
    Minimal stand-in for AutoModelForCausalLM that DNCFormerHead expects:
      - config.hidden_size
      - get_input_embeddings().num_embeddings
      - forward(..., output_hidden_states=True) -> object with .hidden_states
      - lm_head: Linear(d_model -> vocab)
    """
    def __init__(self, vocab: int = 1024, d_model: int = 32):
        super().__init__()
        self.config = types.SimpleNamespace(hidden_size=d_model, use_return_dict=True)
        self.embed = nn.Embedding(vocab, d_model)
        # a tiny 2-layer MLP "backbone"
        self.backbone = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Linear(d_model, d_model),
        )
        self.lm_head = nn.Linear(d_model, vocab, bias=False)

    def get_input_embeddings(self):
        return self.embed

    def forward(self, input_ids: torch.Tensor, attention_mask=None,
                position_ids=None, past_key_values=None, inputs_embeds=None,
                use_cache=False, output_attentions=False, output_hidden_states=False,
                return_dict=True):
        if inputs_embeds is None:
            x = self.embed(input_ids)
        else:
            x = inputs_embeds
        h = self.backbone(x)  # (B,T,D)
        if output_hidden_states:
            out = types.SimpleNamespace(hidden_states=[x, h])
        else:
            out = types.SimpleNamespace(last_hidden_state=h)
        return out

def _make_tiny_head(mem_experts: int = 1, n_blocks: int = 2, d_model: int = 32,
                    R: int = 1, W: int = 8, N: int = 16, heads: int = 2,
                    dropout: float = 0.0, ffn_mult: float = 2.0,
                    gate_bias_init: float = -1.0):
    """
    Build a small DNCFormerHead wired to TinyBaseLM; does not touch global training.
    """
    # mirror current CFG but keep local overrides
    local_cfg = types.SimpleNamespace(
        d_model=None,
        dnc_read_heads=R, dnc_cell_size=W, dnc_nr_cells=N,
        attn_heads=heads, attn_dropout=dropout, ffn_mult=ffn_mult,
        gate_bias_init=gate_bias_init,
        n_blocks=n_blocks,
    )
    # honor mem_experts at block construction time:
    setattr(CFG, "mem_experts", int(mem_experts))

    base = TinyBaseLM(vocab=1024, d_model=d_model).to(device)
    head_local = DNCFormerHead(base, local_cfg).to(device)
    return head_local

@torch.no_grad()
def smoke_head_mem_experts_one_step():
    print("[smoke] head forward (mem_experts=1)")
    head_local = _make_tiny_head(mem_experts=1, n_blocks=2, d_model=32, R=1, W=8, N=16, heads=2)
    B, T = 2, 8
    vocab_guess = getattr(head_local.base.get_input_embeddings(), "num_embeddings", 1024)
    dummy = torch.randint(0, min(1024, vocab_guess), (B, T), device=device)
    logits, gates = head_local(dummy)
    assert logits.shape[:2] == (B, T)
    assert len(gates) == len(head_local.blocks)
    for g in gates:
        assert g.shape[:2] == (B, T) and (g >= 0).all() and (g <= 1).all()
    print("  ok.")

@torch.no_grad()
def smoke_head_mem_experts_two_step():
    print("[smoke] head forward_with_metrics (mem_experts=2)")
    head_local = _make_tiny_head(mem_experts=2, n_blocks=2, d_model=32, R=1, W=8, N=16, heads=2)
    B, T = 2, 8
    vocab_guess = getattr(head_local.base.get_input_embeddings(), "num_embeddings", 1024)
    dummy = torch.randint(0, min(1024, vocab_guess), (B, T), device=device)
    logits, gates, aux = head_local.forward_with_metrics(dummy, gate_override=None)
    assert logits.shape[:2] == (B, T)
    assert "per_block" in aux and "blocks" in aux
    for m in aux["blocks"]:
        assert isinstance(m, dict)
    print("  ok.")

In [38]:
run_patch_smoke_tests()

[smoke] ParallelEnrichmentBlock basic (K=1, fusion=False)
  ok.
[smoke] Multi-experts (K=2), fusion off
  ok.
[smoke] Fusion path uses real read features if available
  ok.
[smoke] per_block_cfg override (E11)
  ok.
[smoke] all tests passed.


In [39]:
# Replace the body of the two smoke tests with this (only the call site changes)
@torch.no_grad()
def smoke_head_mem_experts_one_step():
    print("[smoke] head forward_with_metrics (mem_experts=1)")
    head_local = _make_tiny_head(mem_experts=1, n_blocks=2, d_model=32, R=1, W=8, N=16, heads=2)
    B, T = 2, 8
    vocab_guess = getattr(head_local.base.get_input_embeddings(), "num_embeddings", 1024)
    dummy = torch.randint(0, min(1024, vocab_guess), (B, T), device=device)
    logits, gates, aux = head_local.forward_with_metrics(dummy)
    assert logits.shape[:2] == (B, T)
    assert len(gates) == len(head_local.blocks)
    print("  ok.")

@torch.no_grad()
def smoke_head_mem_experts_two_step():
    print("[smoke] head forward_with_metrics (mem_experts=2)")
    head_local = _make_tiny_head(mem_experts=2, n_blocks=2, d_model=32, R=1, W=8, N=16, heads=2)
    B, T = 2, 8
    vocab_guess = getattr(head_local.base.get_input_embeddings(), "num_embeddings", 1024)
    dummy = torch.randint(0, min(1024, vocab_guess), (B, T), device=device)
    logits, gates, aux = head_local.forward_with_metrics(dummy)
    assert logits.shape[:2] == (B, T)
    assert "per_block" in aux and "blocks" in aux
    print("  ok.")

In [40]:
smoke_head_mem_experts_one_step()
smoke_head_mem_experts_two_step()

[smoke] head forward_with_metrics (mem_experts=1)
  ok.
[smoke] head forward_with_metrics (mem_experts=2)
  ok.


  self.gen = func(*args, **kwds)


In [119]:
blk0 = next(iter(head.blocks))
print("vanilla ln1 dtype:", blk0.vanilla.ln1.weight.dtype)
print("gate weight dtype:", blk0.gate.weight.dtype)
print("fusion ln dtype:", getattr(blk0, "fuse_ln", nn.LayerNorm(1)).weight.dtype if getattr(blk0,"fusion_enable",False) else "n/a")


vanilla ln1 dtype: torch.float32
gate weight dtype: torch.float32
fusion ln dtype: n/a


## 8. Experiments - Stage 1 - basic architecture

#### 8.1 experiment launch helpers

In [91]:
# run set parameters
EXP_STEPS   = 500       # try 1000–5000 for longer curves
EXP_WARMUP  = 10       # keep ~steps/20
BASE_MIX    = (0.4, 0.2, 0.2, 0.2)  # (hf, copy, repeat, nback)

def set_cfg(**kv):
    for k, v in kv.items():
        setattr(CFG, k, v)

def run_isolate_test(label,
            mixture_weights=BASE_MIX,
            gate_reg_lambda=None,
            gate_temp=None,
            force_g=None,
            steps=EXP_STEPS,
            warmup=EXP_WARMUP):
    """Run a single train_experiment experiment with explicit knobs."""
    print(f"\n=== {label} ===")
    # set experiment-specific knobs (others come from CFG defaults)
    if gate_reg_lambda is not None: set_cfg(gate_reg_lambda=float(gate_reg_lambda))
    if gate_temp is not None:       set_cfg(gate_temp=float(gate_temp))
    set_cfg(force_g=force_g)  # may be None, 0.0, or 1.0

    print(f"mixture={mixture_weights}, gate_reg_lambda={getattr(CFG,'gate_reg_lambda', 0.0)}, "
          f"gate_temp={getattr(CFG,'gate_temp', 1.0)}, force_g={getattr(CFG,'force_g', None)}")

    # clean slate for VRAM and allocator fragmentation between runs
    free_head_and_cache()
    cuda_report(f"before {label}")
    time.sleep(1.2)  # ensure distinct TB run dirs (timestamp granularity)

    # run
    head, tok = train_experiment(
        steps=steps,
        warmup_steps=warmup,
        mixture_weights=mixture_weights,
        viz_memory_after=False,   # keep quick; use visualize_memory_tb() ad-hoc
    )

    # post-run snapshot + cleanup
    cuda_report(f"after  {label}")
    free_head_and_cache()
    return head, tok

def run_one_labeled(label, steps, mixture_weights, seed=1234,
                    mixture_schedule=None, gate_temp_schedule=None, gate_reg_schedule=None,
                    post_haystack=False):
    print(f"\n=== {label} | seed={seed} ===")
    random.seed(seed); np.random.seed(seed); torch.manual_seed(seed) 
    if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed)

    set_cfg(force_g=None)  # ensure no ablation
    # unique TB run for each experiment label
    start_tb_run(label)

    # echo run metadata
    if TB_AVAILABLE and 'tb' in globals():
        import json
        tb.add_text("run/meta", json.dumps({
            "label": label,
            "steps": steps,
            "mixture_weights": list(mixture_weights),
            "mixture_schedule": mixture_schedule,
            "gate_temp_schedule": gate_temp_schedule,
            "gate_reg_schedule": gate_reg_schedule,
        }, indent=2), 0)
    
    # print config stub
    print("CFG.gate_temp:", getattr(CFG, "gate_temp", 1.0),
          "| CFG.gate_reg_lambda:", getattr(CFG, "gate_reg_lambda", 0.0),
          "| mixture:", mixture_weights)

    free_head_and_cache()
    if 'cuda_report' in globals(): cuda_report(f"before {label}")

    head, tok = train_experiment(
        steps=steps,
        warmup_steps=max(10, steps//20),
        mixture_weights=mixture_weights,
        mixture_schedule=mixture_schedule,
        gate_temp_schedule=gate_temp_schedule,
        gate_reg_schedule=gate_reg_schedule,
        viz_memory_after=False,
    )

    if post_haystack:
        evaluate_haystack(head, steps=50, batch=16, T=256, vocab=1024, tb_step=steps, fast=True)

    if 'cuda_report' in globals(): cuda_report(f"after  {label}")
    free_head_and_cache()
    return head, tok

#### 8.2 Experiment Set 1 - Medium training sweep, testing parameter / dataset variants
 - E0: Baseline
 - E1: memory/algorithmic slanted data mix
 - E2: gate regularizer (low)
 - E3: gate regularizer (high)
 - E4: sharper routing via lower gate temperature
 - E5: ablations: disable/force DNC path
    - 5a: disable DNC path (DNC path disabled)
    - 5b: force DNC path (transformer path disabled) 

In [58]:
import time, torch, gc

# === E0 baseline: identical to your sanity run ===
run_isolate_test("E0_baseline",
        mixture_weights=(0.4, 0.2, 0.2, 0.2),
        gate_reg_lambda=0.0,
        gate_temp=1.0,
        force_g=None,
        steps=EXP_STEPS, warmup=EXP_WARMUP)

# === E1 memory-leaning mix: more algorithmic exposure ===
run_isolate_test("E1_memory_lean",
        mixture_weights=(0.4, 0.3, 0.2, 0.1),  # HF, copy, repeat, nback
        gate_reg_lambda=0.0,
        gate_temp=1.0,
        force_g=None,
        steps=EXP_STEPS, warmup=EXP_WARMUP)

# === E2 gate regularizer (low) ===
run_isolate_test("E2_gate_reg_low",
        mixture_weights=(0.4, 0.2, 0.2, 0.2),
        gate_reg_lambda=2e-4,
        gate_temp=1.0,
        force_g=None,
        steps=EXP_STEPS, warmup=EXP_WARMUP)

# === E3 gate regularizer (high) ===
run_isolate_test("E3_gate_reg_high",
        mixture_weights=(0.4, 0.2, 0.2, 0.2),
        gate_reg_lambda=5e-4,
        gate_temp=1.0,
        force_g=None,
        steps=EXP_STEPS, warmup=EXP_WARMUP)

# === E4 sharper routing via lower gate temperature ===
run_isolate_test("E4_gate_temp_0p7",
        mixture_weights=(0.4, 0.2, 0.2, 0.2),
        gate_reg_lambda=0.0,
        gate_temp=0.7,
        force_g=None,
        steps=EXP_STEPS, warmup=EXP_WARMUP)

# === E5 ablations: disable/force DNC path ===
run_isolate_test("E5a_force_g_0",
        mixture_weights=(0.4, 0.2, 0.2, 0.2),
        gate_reg_lambda=0.0,
        gate_temp=1.0,
        force_g=0.0,
        steps=EXP_STEPS, warmup=EXP_WARMUP)

run_isolate_test("E5b_force_g_1",
        mixture_weights=(0.4, 0.2, 0.2, 0.2),
        gate_reg_lambda=0.0,
        gate_temp=1.0,
        force_g=1.0,
        steps=EXP_STEPS, warmup=EXP_WARMUP)



=== E0_baseline ===
mixture=(0.4, 0.2, 0.2, 0.2), gate_reg_lambda=0.0, gate_temp=1.0, force_g=None
[after free] alloc=9.79 GB | reserved=9.82 GB | free=14.54 GB | total=25.77 GB
[before E0_baseline] alloc=9.79 GB | reserved=9.82 GB | free=14.54 GB | total=25.77 GB


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

step 10 | loss 8.7638 | lr 2.00e-04 | gates=[0.283203125, 0.28125] | mix=repeat
step 20 | loss 8.1121 | lr 2.00e-04 | gates=[0.283203125, 0.28515625] | mix=nback
step 30 | loss 4.2501 | lr 1.99e-04 | gates=[0.26171875, 0.26953125] | mix=hf
step 40 | loss 4.6715 | lr 1.98e-04 | gates=[0.259765625, 0.26953125] | mix=hf
step 50 | loss 4.4469 | lr 1.97e-04 | gates=[0.25390625, 0.263671875] | mix=hf
step 60 | loss 7.0320 | lr 1.95e-04 | gates=[0.287109375, 0.310546875] | mix=nback
step 70 | loss 5.7877 | lr 1.92e-04 | gates=[0.28515625, 0.314453125] | mix=nback
step 80 | loss 1.9082 | lr 1.90e-04 | gates=[0.2353515625, 0.248046875] | mix=hf
step 90 | loss 1.8775 | lr 1.87e-04 | gates=[0.2294921875, 0.2412109375] | mix=hf
step 100 | loss 6.4503 | lr 1.83e-04 | gates=[0.291015625, 0.3203125] | mix=repeat
step 110 | loss 5.8166 | lr 1.80e-04 | gates=[0.291015625, 0.326171875] | mix=copy
step 120 | loss 1.3207 | lr 1.76e-04 | gates=[0.2197265625, 0.2412109375] | mix=hf
step 130 | loss 5.6160 | 

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

step 10 | loss 1.6762 | lr 2.00e-04 | gates=[0.27734375, 0.275390625] | mix=hf
step 20 | loss 7.2456 | lr 2.00e-04 | gates=[0.265625, 0.267578125] | mix=hf
step 30 | loss 2.8066 | lr 1.99e-04 | gates=[0.259765625, 0.26171875] | mix=hf
step 40 | loss 5.9084 | lr 1.98e-04 | gates=[0.28515625, 0.294921875] | mix=copy
step 50 | loss 6.2994 | lr 1.97e-04 | gates=[0.283203125, 0.296875] | mix=nback
step 60 | loss 2.0780 | lr 1.95e-04 | gates=[0.240234375, 0.251953125] | mix=hf
step 70 | loss 5.9646 | lr 1.92e-04 | gates=[0.2890625, 0.302734375] | mix=repeat
step 80 | loss 6.1615 | lr 1.90e-04 | gates=[0.279296875, 0.30078125] | mix=nback
step 90 | loss 1.6371 | lr 1.87e-04 | gates=[0.2255859375, 0.2412109375] | mix=hf
step 100 | loss 5.8504 | lr 1.83e-04 | gates=[0.287109375, 0.30859375] | mix=copy
step 110 | loss 6.1592 | lr 1.80e-04 | gates=[0.287109375, 0.310546875] | mix=copy
step 120 | loss 1.6537 | lr 1.76e-04 | gates=[0.2109375, 0.228515625] | mix=hf
step 130 | loss 1.3438 | lr 1.71e-

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

step 10 | loss 2.5047 | lr 2.00e-04 | gates=[0.275390625, 0.279296875] | mix=hf
step 20 | loss 2.1480 | lr 2.00e-04 | gates=[0.263671875, 0.26953125] | mix=hf
step 30 | loss 7.5692 | lr 1.99e-04 | gates=[0.279296875, 0.287109375] | mix=nback
step 40 | loss 1.5048 | lr 1.98e-04 | gates=[0.244140625, 0.25390625] | mix=hf
step 50 | loss 6.6892 | lr 1.97e-04 | gates=[0.279296875, 0.296875] | mix=copy
step 60 | loss 5.5064 | lr 1.95e-04 | gates=[0.287109375, 0.3046875] | mix=repeat
step 70 | loss 10.6165 | lr 1.92e-04 | gates=[0.28515625, 0.302734375] | mix=repeat
step 80 | loss 6.2255 | lr 1.90e-04 | gates=[0.279296875, 0.302734375] | mix=repeat
step 90 | loss 1.4385 | lr 1.87e-04 | gates=[0.2080078125, 0.2255859375] | mix=hf
step 100 | loss 1.6988 | lr 1.83e-04 | gates=[0.20703125, 0.2265625] | mix=hf
step 110 | loss 1.4504 | lr 1.80e-04 | gates=[0.2021484375, 0.2216796875] | mix=hf
step 120 | loss 6.7830 | lr 1.76e-04 | gates=[0.271484375, 0.314453125] | mix=nback
step 130 | loss 1.6999 

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

step 10 | loss 8.4827 | lr 2.00e-04 | gates=[0.28515625, 0.283203125] | mix=repeat
step 20 | loss 7.1305 | lr 2.00e-04 | gates=[0.287109375, 0.28515625] | mix=repeat
step 30 | loss 2.0686 | lr 1.99e-04 | gates=[0.26171875, 0.267578125] | mix=hf
step 40 | loss 3.8306 | lr 1.98e-04 | gates=[0.248046875, 0.255859375] | mix=hf
step 50 | loss 7.1654 | lr 1.97e-04 | gates=[0.28515625, 0.294921875] | mix=copy
step 60 | loss 7.3936 | lr 1.95e-04 | gates=[0.28125, 0.296875] | mix=nback
step 70 | loss 6.3211 | lr 1.92e-04 | gates=[0.283203125, 0.302734375] | mix=nback
step 80 | loss 5.5104 | lr 1.90e-04 | gates=[0.29296875, 0.310546875] | mix=repeat
step 90 | loss 5.6180 | lr 1.87e-04 | gates=[0.28515625, 0.3125] | mix=copy
step 100 | loss 5.1885 | lr 1.83e-04 | gates=[0.29296875, 0.31640625] | mix=repeat
step 110 | loss 5.5819 | lr 1.80e-04 | gates=[0.283203125, 0.314453125] | mix=nback
step 120 | loss 1.4206 | lr 1.76e-04 | gates=[0.216796875, 0.2353515625] | mix=hf
step 130 | loss 6.0211 | lr

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

step 10 | loss 6.7747 | lr 2.00e-04 | gates=[0.224609375, 0.224609375] | mix=repeat
step 20 | loss 10.0938 | lr 2.00e-04 | gates=[0.2294921875, 0.232421875] | mix=copy
step 30 | loss 9.4271 | lr 1.99e-04 | gates=[0.23046875, 0.23828125] | mix=nback
step 40 | loss 7.5605 | lr 1.98e-04 | gates=[0.2294921875, 0.2412109375] | mix=nback
step 50 | loss 7.0946 | lr 1.97e-04 | gates=[0.23046875, 0.2451171875] | mix=nback
step 60 | loss 5.3907 | lr 1.95e-04 | gates=[0.236328125, 0.25390625] | mix=repeat
step 70 | loss 8.6001 | lr 1.92e-04 | gates=[0.232421875, 0.255859375] | mix=copy
step 80 | loss 6.2352 | lr 1.90e-04 | gates=[0.2275390625, 0.259765625] | mix=nback
step 90 | loss 2.6570 | lr 1.87e-04 | gates=[0.166015625, 0.19140625] | mix=hf
step 100 | loss 2.9589 | lr 1.83e-04 | gates=[0.1591796875, 0.185546875] | mix=hf
step 110 | loss 5.4593 | lr 1.80e-04 | gates=[0.2353515625, 0.2734375] | mix=repeat
step 120 | loss 5.5796 | lr 1.76e-04 | gates=[0.2275390625, 0.2734375] | mix=nback
step 1

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

step 10 | loss 1.7708 | lr 2.00e-04 | gates=[0.0, 0.0] | mix=hf
step 20 | loss 8.1129 | lr 2.00e-04 | gates=[0.0, 0.0] | mix=copy
step 30 | loss 9.4176 | lr 1.99e-04 | gates=[0.0, 0.0] | mix=nback
step 40 | loss 2.0564 | lr 1.98e-04 | gates=[0.0, 0.0] | mix=hf
step 50 | loss 1.4602 | lr 1.97e-04 | gates=[0.0, 0.0] | mix=hf
step 60 | loss 6.0042 | lr 1.95e-04 | gates=[0.0, 0.0] | mix=repeat
step 70 | loss 6.5107 | lr 1.92e-04 | gates=[0.0, 0.0] | mix=copy
step 80 | loss 1.6592 | lr 1.90e-04 | gates=[0.0, 0.0] | mix=hf
step 90 | loss 6.2884 | lr 1.87e-04 | gates=[0.0, 0.0] | mix=nback
step 100 | loss 6.1378 | lr 1.83e-04 | gates=[0.0, 0.0] | mix=copy
step 110 | loss 1.2196 | lr 1.80e-04 | gates=[0.0, 0.0] | mix=hf
step 120 | loss 5.9388 | lr 1.76e-04 | gates=[0.0, 0.0] | mix=nback
step 130 | loss 6.2737 | lr 1.71e-04 | gates=[0.0, 0.0] | mix=copy
step 140 | loss 1.4693 | lr 1.67e-04 | gates=[0.0, 0.0] | mix=hf
step 150 | loss 1.4359 | lr 1.62e-04 | gates=[0.0, 0.0] | mix=hf
step 160 | lo

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

step 10 | loss 6.4470 | lr 2.00e-04 | gates=[1.0, 1.0] | mix=nback
step 20 | loss 8.4882 | lr 2.00e-04 | gates=[1.0, 1.0] | mix=hf
step 30 | loss 6.6057 | lr 1.99e-04 | gates=[1.0, 1.0] | mix=hf
step 40 | loss 6.8276 | lr 1.98e-04 | gates=[1.0, 1.0] | mix=hf
step 50 | loss 6.6664 | lr 1.97e-04 | gates=[1.0, 1.0] | mix=nback
step 60 | loss 6.3100 | lr 1.95e-04 | gates=[1.0, 1.0] | mix=hf
step 70 | loss 6.7212 | lr 1.92e-04 | gates=[1.0, 1.0] | mix=nback
step 80 | loss 9.0949 | lr 1.90e-04 | gates=[1.0, 1.0] | mix=repeat
step 90 | loss 5.5494 | lr 1.87e-04 | gates=[1.0, 1.0] | mix=hf
step 100 | loss 5.9994 | lr 1.83e-04 | gates=[1.0, 1.0] | mix=nback
step 110 | loss 5.6693 | lr 1.80e-04 | gates=[1.0, 1.0] | mix=nback
step 120 | loss 5.1639 | lr 1.76e-04 | gates=[1.0, 1.0] | mix=hf
step 130 | loss 5.7746 | lr 1.71e-04 | gates=[1.0, 1.0] | mix=copy
step 140 | loss 4.8143 | lr 1.67e-04 | gates=[1.0, 1.0] | mix=hf
step 150 | loss 5.4356 | lr 1.62e-04 | gates=[1.0, 1.0] | mix=repeat
step 160 

(DNCFormerHead(
   (base): Phi3ForCausalLM(
     (model): Phi3Model(
       (embed_tokens): Embedding(32064, 3072, padding_idx=32000)
       (embed_dropout): Dropout(p=0.0, inplace=False)
       (layers): ModuleList(
         (0-31): 32 x Phi3DecoderLayer(
           (self_attn): Phi3Attention(
             (o_proj): Linear(in_features=3072, out_features=3072, bias=False)
             (qkv_proj): Linear(in_features=3072, out_features=9216, bias=False)
             (rotary_emb): Phi3RotaryEmbedding()
           )
           (mlp): Phi3MLP(
             (gate_up_proj): Linear(in_features=3072, out_features=16384, bias=False)
             (down_proj): Linear(in_features=8192, out_features=3072, bias=False)
             (activation_fn): SiLU()
           )
           (input_layernorm): Phi3RMSNorm()
           (resid_attn_dropout): Dropout(p=0.0, inplace=False)
           (resid_mlp_dropout): Dropout(p=0.0, inplace=False)
           (post_attention_layernorm): Phi3RMSNorm()
         )
    

In [59]:
free_head_and_cache()

[after free] alloc=19.56 GB | reserved=19.60 GB | free=3.89 GB | total=25.77 GB


#### 8.3 Experiment Set 2 - parameter sweeps, based on E2 params from set 1 above
 - E6: higher gate temp
 - E7: memory-leaning warm-start
 - E8: capacity sweep
 - E9: baseline training, haystack eval

In [60]:
import re, time

# Common settings
EXP_STEPS = 500
BASE_MIX  = (0.4, 0.2, 0.2, 0.2)
SEEDS     = [1337, 2027, 4242]

# === E6: gate_temp=0.8 (3 seeds), otherwise baseline ===
set_cfg(gate_reg_lambda=getattr(CFG, "gate_reg_lambda", 2e-4))  # low-λ default
for s in SEEDS:
    run_one_labeled(f"E6_temp0p8_seed{s}", steps=EXP_STEPS, mixture_weights=BASE_MIX,
                    seed=s,
                    gate_temp_schedule=[(None, 0.8)])

# === E7: memory-leaning warm-start (first 10% steps), then baseline; 3 seeds ===
warm_steps = max(50, EXP_STEPS // 10)
mix_warm   = (0.3, 0.3, 0.25, 0.15)  # a bit more memory-heavy than baseline
mix_main   = BASE_MIX
for s in SEEDS:
    run_one_labeled(f"E7_warmstart_seed{s}", steps=EXP_STEPS, mixture_weights=BASE_MIX,
                    seed=s,
                    mixture_schedule=[(warm_steps, mix_warm), (None, mix_main)],
                    gate_temp_schedule=[(warm_steps, 0.8), (None, 1.0)],
                    gate_reg_schedule=[(warm_steps, max(2e-4, getattr(CFG,'gate_reg_lambda', 2e-4))), (None, getattr(CFG,'gate_reg_lambda', 2e-4))])

# === E8: capacity sweep N=64 vs N=128 (1 seed each) ===
for N_val in (64, 128):
    set_cfg(N=N_val)  # assumes your DNC block reads CFG.N at construction
    run_one_labeled(f"E8_capacity_N{N_val}", steps=EXP_STEPS, mixture_weights=BASE_MIX,
                    seed=777,
                    gate_temp_schedule=[(None, getattr(CFG, 'gate_temp', 1.0))])

# Reset N if changed
set_cfg(N=getattr(CFG, 'N', 128))

# === E9: baseline with haystack eval ===
run_one_labeled("E9_baseline_haystack", steps=EXP_STEPS, mixture_weights=BASE_MIX,
                seed=31415, post_haystack=True)



=== E6_temp0p8_seed1337 | seed=1337 ===
TB run started: ./runs\dncformer-20250820-143212-E6_temp0p8_seed1337
CFG.gate_temp: 1.0 | CFG.gate_reg_lambda: 0.0 | mixture: (0.4, 0.2, 0.2, 0.2)
[after free] alloc=19.56 GB | reserved=19.60 GB | free=3.89 GB | total=25.77 GB
[before E6_temp0p8_seed1337] alloc=19.56 GB | reserved=19.60 GB | free=3.89 GB | total=25.77 GB


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

step 10 | loss 1.8139 | lr 8.80e-05 | gates=[0.2392578125, 0.2431640625] | mix=hf
step 20 | loss 5.6580 | lr 1.68e-04 | gates=[0.2451171875, 0.25390625] | mix=nback
step 30 | loss 6.5675 | lr 2.00e-04 | gates=[0.25, 0.2578125] | mix=copy
step 40 | loss 3.2836 | lr 1.99e-04 | gates=[0.2158203125, 0.224609375] | mix=hf
step 50 | loss 1.9842 | lr 1.99e-04 | gates=[0.2080078125, 0.21875] | mix=hf
step 60 | loss 2.1618 | lr 1.97e-04 | gates=[0.201171875, 0.21484375] | mix=hf
step 70 | loss 2.7611 | lr 1.95e-04 | gates=[0.1923828125, 0.2080078125] | mix=hf
step 80 | loss 6.0227 | lr 1.93e-04 | gates=[0.25390625, 0.275390625] | mix=copy
step 90 | loss 1.4240 | lr 1.91e-04 | gates=[0.1884765625, 0.2060546875] | mix=hf
step 100 | loss 1.8305 | lr 1.88e-04 | gates=[0.177734375, 0.197265625] | mix=hf
step 110 | loss 5.7472 | lr 1.84e-04 | gates=[0.2578125, 0.283203125] | mix=copy
step 120 | loss 1.5439 | lr 1.81e-04 | gates=[0.1669921875, 0.1904296875] | mix=hf
step 130 | loss 5.7418 | lr 1.76e-0

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

step 10 | loss 6.7216 | lr 8.80e-05 | gates=[0.240234375, 0.2451171875] | mix=nback
step 20 | loss 1.7473 | lr 1.68e-04 | gates=[0.2373046875, 0.2392578125] | mix=hf
step 30 | loss 7.0766 | lr 2.00e-04 | gates=[0.2470703125, 0.255859375] | mix=repeat
step 40 | loss 7.0362 | lr 1.99e-04 | gates=[0.25, 0.26171875] | mix=repeat
step 50 | loss 1.9754 | lr 1.99e-04 | gates=[0.2119140625, 0.21875] | mix=hf
step 60 | loss 6.4103 | lr 1.97e-04 | gates=[0.2470703125, 0.263671875] | mix=copy
step 70 | loss 6.1247 | lr 1.95e-04 | gates=[0.25, 0.26953125] | mix=copy
step 80 | loss 1.6452 | lr 1.93e-04 | gates=[0.1875, 0.2021484375] | mix=hf
step 90 | loss 5.8153 | lr 1.91e-04 | gates=[0.2451171875, 0.26953125] | mix=repeat
step 100 | loss 1.5125 | lr 1.88e-04 | gates=[0.1708984375, 0.189453125] | mix=hf
step 110 | loss 5.8301 | lr 1.84e-04 | gates=[0.25390625, 0.279296875] | mix=repeat
step 120 | loss 6.0745 | lr 1.81e-04 | gates=[0.2490234375, 0.279296875] | mix=repeat
step 130 | loss 6.5550 | lr

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

step 10 | loss 6.5768 | lr 8.80e-05 | gates=[0.248046875, 0.2470703125] | mix=repeat
step 20 | loss 5.6260 | lr 1.68e-04 | gates=[0.251953125, 0.255859375] | mix=nback
step 30 | loss 2.4026 | lr 2.00e-04 | gates=[0.2373046875, 0.240234375] | mix=hf
step 40 | loss 7.2257 | lr 1.99e-04 | gates=[0.2578125, 0.263671875] | mix=repeat
step 50 | loss 2.8536 | lr 1.99e-04 | gates=[0.2236328125, 0.228515625] | mix=hf
step 60 | loss 6.7658 | lr 1.97e-04 | gates=[0.25390625, 0.271484375] | mix=nback
step 70 | loss 6.8637 | lr 1.95e-04 | gates=[0.255859375, 0.275390625] | mix=copy
step 80 | loss 6.1135 | lr 1.93e-04 | gates=[0.259765625, 0.279296875] | mix=copy
step 90 | loss 1.7781 | lr 1.91e-04 | gates=[0.19921875, 0.2119140625] | mix=hf
step 100 | loss 5.3427 | lr 1.88e-04 | gates=[0.259765625, 0.28125] | mix=repeat
step 110 | loss 1.6011 | lr 1.84e-04 | gates=[0.185546875, 0.2041015625] | mix=hf
step 120 | loss 1.5415 | lr 1.81e-04 | gates=[0.1845703125, 0.2060546875] | mix=hf
step 130 | loss 

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

step 10 | loss 7.1638 | lr 8.80e-05 | gates=[0.2431640625, 0.2490234375] | mix=repeat
step 20 | loss 1.8687 | lr 1.68e-04 | gates=[0.2373046875, 0.2431640625] | mix=hf
step 30 | loss 6.0424 | lr 2.00e-04 | gates=[0.25390625, 0.259765625] | mix=nback
step 40 | loss 2.2922 | lr 1.99e-04 | gates=[0.2236328125, 0.232421875] | mix=hf
step 50 | loss 1.4852 | lr 1.99e-04 | gates=[0.21484375, 0.2255859375] | mix=hf
step 60 | loss 6.5158 | lr 1.97e-04 | gates=[0.291015625, 0.30078125] | mix=copy
step 70 | loss 5.9229 | lr 1.95e-04 | gates=[0.2890625, 0.298828125] | mix=repeat
step 80 | loss 5.7869 | lr 1.93e-04 | gates=[0.291015625, 0.3046875] | mix=copy
step 90 | loss 1.7243 | lr 1.91e-04 | gates=[0.2294921875, 0.2431640625] | mix=hf
step 100 | loss 5.8018 | lr 1.88e-04 | gates=[0.28515625, 0.306640625] | mix=nback
step 110 | loss 5.0903 | lr 1.84e-04 | gates=[0.296875, 0.3125] | mix=repeat
step 120 | loss 5.2232 | lr 1.81e-04 | gates=[0.283203125, 0.30859375] | mix=nback
step 130 | loss 5.425

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

step 10 | loss 2.4950 | lr 8.80e-05 | gates=[0.244140625, 0.24609375] | mix=hf
step 20 | loss 5.4397 | lr 1.68e-04 | gates=[0.240234375, 0.2421875] | mix=hf
step 30 | loss 5.8121 | lr 2.00e-04 | gates=[0.248046875, 0.255859375] | mix=nback
step 40 | loss 8.2393 | lr 1.99e-04 | gates=[0.251953125, 0.263671875] | mix=copy
step 50 | loss 6.9483 | lr 1.99e-04 | gates=[0.251953125, 0.265625] | mix=copy
step 60 | loss 2.1542 | lr 1.97e-04 | gates=[0.25390625, 0.26171875] | mix=hf
step 70 | loss 1.6823 | lr 1.95e-04 | gates=[0.248046875, 0.259765625] | mix=hf
step 80 | loss 6.2565 | lr 1.93e-04 | gates=[0.28515625, 0.306640625] | mix=copy
step 90 | loss 8.1904 | lr 1.91e-04 | gates=[0.28125, 0.306640625] | mix=nback
step 100 | loss 6.0802 | lr 1.88e-04 | gates=[0.291015625, 0.3125] | mix=repeat
step 110 | loss 6.4359 | lr 1.84e-04 | gates=[0.283203125, 0.310546875] | mix=copy
step 120 | loss 1.4914 | lr 1.81e-04 | gates=[0.2236328125, 0.2451171875] | mix=hf
step 130 | loss 5.9994 | lr 1.76e-0

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

step 10 | loss 6.8860 | lr 8.80e-05 | gates=[0.248046875, 0.2470703125] | mix=copy
step 20 | loss 6.7967 | lr 1.68e-04 | gates=[0.25390625, 0.255859375] | mix=nback
step 30 | loss 5.8866 | lr 2.00e-04 | gates=[0.2578125, 0.2578125] | mix=copy
step 40 | loss 5.5299 | lr 1.99e-04 | gates=[0.259765625, 0.26171875] | mix=repeat
step 50 | loss 1.4707 | lr 1.99e-04 | gates=[0.220703125, 0.228515625] | mix=hf
step 60 | loss 5.3748 | lr 1.97e-04 | gates=[0.291015625, 0.30078125] | mix=nback
step 70 | loss 8.4891 | lr 1.95e-04 | gates=[0.29296875, 0.30078125] | mix=copy
step 80 | loss 6.0569 | lr 1.93e-04 | gates=[0.296875, 0.302734375] | mix=copy
step 90 | loss 1.6448 | lr 1.91e-04 | gates=[0.228515625, 0.236328125] | mix=hf
step 100 | loss 5.4202 | lr 1.88e-04 | gates=[0.294921875, 0.3046875] | mix=repeat
step 110 | loss 1.7415 | lr 1.84e-04 | gates=[0.2138671875, 0.224609375] | mix=hf
step 120 | loss 3.8637 | lr 1.81e-04 | gates=[0.212890625, 0.2255859375] | mix=hf
step 130 | loss 1.7198 | l

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

step 10 | loss 2.6379 | lr 8.80e-05 | gates=[0.28125, 0.279296875] | mix=hf
step 20 | loss 6.7689 | lr 1.68e-04 | gates=[0.279296875, 0.287109375] | mix=copy
step 30 | loss 6.5771 | lr 2.00e-04 | gates=[0.28125, 0.291015625] | mix=nback
step 40 | loss 6.0427 | lr 1.99e-04 | gates=[0.28515625, 0.294921875] | mix=repeat
step 50 | loss 7.1813 | lr 1.99e-04 | gates=[0.28125, 0.294921875] | mix=nback
step 60 | loss 2.1033 | lr 1.97e-04 | gates=[0.240234375, 0.24609375] | mix=hf
step 70 | loss 1.5742 | lr 1.95e-04 | gates=[0.232421875, 0.240234375] | mix=hf
step 80 | loss 8.6676 | lr 1.93e-04 | gates=[0.28125, 0.302734375] | mix=nback
step 90 | loss 1.8938 | lr 1.91e-04 | gates=[0.2099609375, 0.22265625] | mix=hf
step 100 | loss 5.6929 | lr 1.88e-04 | gates=[0.275390625, 0.306640625] | mix=nback
step 110 | loss 1.7913 | lr 1.84e-04 | gates=[0.1982421875, 0.2138671875] | mix=hf
step 120 | loss 7.8458 | lr 1.81e-04 | gates=[0.287109375, 0.31640625] | mix=copy
step 130 | loss 9.1497 | lr 1.76e-

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

step 10 | loss 2.6384 | lr 8.80e-05 | gates=[0.28125, 0.279296875] | mix=hf
step 20 | loss 6.7687 | lr 1.68e-04 | gates=[0.279296875, 0.287109375] | mix=copy
step 30 | loss 6.5741 | lr 2.00e-04 | gates=[0.28125, 0.291015625] | mix=nback
step 40 | loss 6.0421 | lr 1.99e-04 | gates=[0.28515625, 0.294921875] | mix=repeat
step 50 | loss 6.5700 | lr 1.99e-04 | gates=[0.28125, 0.294921875] | mix=nback
step 60 | loss 2.1646 | lr 1.97e-04 | gates=[0.240234375, 0.24609375] | mix=hf
step 70 | loss 1.6074 | lr 1.95e-04 | gates=[0.232421875, 0.240234375] | mix=hf
step 80 | loss 5.7022 | lr 1.93e-04 | gates=[0.28125, 0.302734375] | mix=nback
step 90 | loss 1.8592 | lr 1.91e-04 | gates=[0.2099609375, 0.22265625] | mix=hf
step 100 | loss 6.3857 | lr 1.88e-04 | gates=[0.275390625, 0.3046875] | mix=nback
step 110 | loss 1.9222 | lr 1.84e-04 | gates=[0.19921875, 0.21484375] | mix=hf
step 120 | loss 5.9504 | lr 1.81e-04 | gates=[0.2890625, 0.314453125] | mix=copy
step 130 | loss 7.1080 | lr 1.76e-04 | ga

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

step 10 | loss 2.5027 | lr 8.80e-05 | gates=[0.27734375, 0.279296875] | mix=hf
step 20 | loss 2.2774 | lr 1.68e-04 | gates=[0.26953125, 0.271484375] | mix=hf
step 30 | loss 6.3625 | lr 2.00e-04 | gates=[0.28125, 0.2890625] | mix=nback
step 40 | loss 6.3541 | lr 1.99e-04 | gates=[0.28515625, 0.29296875] | mix=copy
step 50 | loss 1.6552 | lr 1.99e-04 | gates=[0.240234375, 0.2470703125] | mix=hf
step 60 | loss 1.6690 | lr 1.97e-04 | gates=[0.23828125, 0.2451171875] | mix=hf
step 70 | loss 2.0622 | lr 1.95e-04 | gates=[0.2275390625, 0.2373046875] | mix=hf
step 80 | loss 1.6111 | lr 1.93e-04 | gates=[0.220703125, 0.232421875] | mix=hf
step 90 | loss 1.6116 | lr 1.91e-04 | gates=[0.2158203125, 0.2294921875] | mix=hf
step 100 | loss 5.7697 | lr 1.88e-04 | gates=[0.28125, 0.30859375] | mix=copy
step 110 | loss 3.2152 | lr 1.84e-04 | gates=[0.2021484375, 0.2216796875] | mix=hf
step 120 | loss 2.9168 | lr 1.81e-04 | gates=[0.20703125, 0.2294921875] | mix=hf
step 130 | loss 5.2972 | lr 1.76e-04 |

(DNCFormerHead(
   (base): Phi3ForCausalLM(
     (model): Phi3Model(
       (embed_tokens): Embedding(32064, 3072, padding_idx=32000)
       (embed_dropout): Dropout(p=0.0, inplace=False)
       (layers): ModuleList(
         (0-31): 32 x Phi3DecoderLayer(
           (self_attn): Phi3Attention(
             (o_proj): Linear(in_features=3072, out_features=3072, bias=False)
             (qkv_proj): Linear(in_features=3072, out_features=9216, bias=False)
             (rotary_emb): Phi3RotaryEmbedding()
           )
           (mlp): Phi3MLP(
             (gate_up_proj): Linear(in_features=3072, out_features=16384, bias=False)
             (down_proj): Linear(in_features=8192, out_features=3072, bias=False)
             (activation_fn): SiLU()
           )
           (input_layernorm): Phi3RMSNorm()
           (resid_attn_dropout): Dropout(p=0.0, inplace=False)
           (resid_mlp_dropout): Dropout(p=0.0, inplace=False)
           (post_attention_layernorm): Phi3RMSNorm()
         )
    

In [121]:
free_head_and_cache()

[after free] alloc=9.64 GB | reserved=10.39 GB | free=13.98 GB | total=25.77 GB


## 9. Experiments - Stage 2 - tiered/parallel memory systems

#### 9.1 E10-14

In [42]:
def set_e10_multi_experts(K=2, N_each=64, temp=0.9, lambda_div=1e-3):
    CFG.mem_experts = int(K)
    CFG.expert_N = [int(N_each)]*K
    CFG.expert_W = getattr(CFG, "W", 64)
    CFG.expert_R = getattr(CFG, "R", 1)
    CFG.expert_gate_temp = float(temp)
    CFG.expert_diversity_lambda = float(lambda_div)
    # Keep fusion off for a pure E10 test
    CFG.fusion_enable = False

def set_e11_tiered(which="a"):
    # E11a: many-narrow early, few-wide late
    if which == "a":
        CFG.per_block_cfg = [
            {"N":128, "W":32, "R":1, "gate_temp":1.0, "free_bias": +0.3},
            {"N": 64, "W":64, "R":1, "gate_temp":0.8, "free_bias": -0.2},
        ]
    # E11b: classic - longer horizon late
    else:
        CFG.per_block_cfg = [
            {"N": 64, "W":32, "R":1, "gate_temp":1.0, "free_bias": +0.3},
            {"N":128, "W":64, "R":2, "gate_temp":0.8, "free_bias": -0.2},
        ]
    CFG.mem_experts = 1   # keep single memory per block for E11
    CFG.fusion_enable = False

def set_e12_fusion(enable=True, hidden_mult=2.0, drop=0.0):
    CFG.fusion_enable = bool(enable)
    CFG.fusion_hidden_mult = float(hidden_mult)
    CFG.fusion_drop = float(drop)

def set_e13_regs(write_lambda=1e-4, overlap_lambda=0.0, on_mem_only=True):
    CFG.write_reg_lambda = float(write_lambda)
    CFG.key_overlap_lambda = float(overlap_lambda)  # currently inert
    CFG.key_overlap_window = int(getattr(CFG, "key_overlap_window", 1))
    CFG.reg_only_on_memory_batches = bool(on_mem_only)

def set_e14_curriculum(S=200):
    # extreme warm-start: 0 HF early
    CFG.mixture_schedule = [(S, (0.0, 0.34, 0.33, 0.33)), (None, (0.4, 0.2, 0.2, 0.2))]
    CFG.gate_temp_schedule = [(S, 0.8), (None, 1.0)]



In [50]:
# --- E10–E14 Sweep Driver (uses run_one_labeled + setters) ---

# Keys we may touch in per-experiment setters; snapshot/restore to avoid bleed-over
_EXP_KEYS = [
    "mem_experts", "expert_N", "expert_W", "expert_R", "expert_gate_temp", "expert_diversity_lambda",
    "per_block_cfg", "fusion_enable", "fusion_hidden_mult", "fusion_drop",
    "mixture_schedule", "gate_temp_schedule", "gate_reg_schedule",
    "write_reg_lambda", "key_overlap_lambda", "key_overlap_window", "reg_only_on_memory_batches",
    "force_g", "gate_temp", "gate_reg_lambda"
]

def _snap_exp_cfg():
    return {k: getattr(CFG, k, None) for k in _EXP_KEYS}

def _restore_exp_cfg(snap):
    for k, v in snap.items():
        setattr(CFG, k, v)

def run_e10_14_sweep(
    steps: int = None,
    seeds = (1337,),
    base_mix = (0.4, 0.2, 0.2, 0.2),
    warmup: int = None,
    post_haystack: bool = False
):
    """
    E10: multi-memory experts (K=2) with modest diversity penalty
    E11a/b: tiered memories (many/narrow early vs few/deep late)
    E12: read-to-attention fusion enabled
    E13: write-sparsity regs (if your training loop consumes them)
    E14: curriculum warm-start (HF=0 early, then baseline)
    """
    steps  = int(steps  if steps  is not None else getattr(CFG, "train_steps", 500))
    warmup = int(warmup if warmup is not None else max(10, steps // 20))

    for seed in seeds:
        # E10 ─ Multi-memory experts (K=2)
        snap = _snap_exp_cfg()
        set_e10_multi_experts(K=2, N_each=64, temp=0.9, lambda_div=1e-3)
        run_one_labeled(
            label=f"E10_multi_experts_k2_s{seed}_{int(time.time())}",
            steps=steps,
            mixture_weights=base_mix,
            seed=seed,
            mixture_schedule=None, gate_temp_schedule=None, gate_reg_schedule=None,
            post_haystack=post_haystack
        )
        _restore_exp_cfg(snap)

        # E11a ─ Tiered memories (many/narrow early, few/deep late)
        snap = _snap_exp_cfg()
        set_e11_tiered("a")
        run_one_labeled(
            label=f"E11a_tiered_many_narrow_s{seed}_{int(time.time())}",
            steps=steps,
            mixture_weights=base_mix,
            seed=seed,
            post_haystack=post_haystack
        )
        _restore_exp_cfg(snap)

        # E11b ─ Tiered memories (classic: longer horizon late, deeper W/R)
        snap = _snap_exp_cfg()
        set_e11_tiered("b")
        run_one_labeled(
            label=f"E11b_tiered_deeper_late_s{seed}_{int(time.time())}",
            steps=steps,
            mixture_weights=base_mix,
            seed=seed,
            post_haystack=post_haystack
        )
        _restore_exp_cfg(snap)

        # # E12 ─ Fusion (read‑hint → residual MLP on vanilla path)
        # snap = _snap_exp_cfg()
        # set_e12_fusion(enable=True, hidden_mult=2.0, drop=0.0)
        # run_one_labeled(
        #     label=f"E12_fusion_on_s{seed}_{int(time.time())}",
        #     steps=steps,
        #     mixture_weights=base_mix,
        #     seed=seed,
        #     post_haystack=post_haystack
        # )
        # _restore_exp_cfg(snap)

        # E13 ─ Write/overlap regs (NOTE: only effective if your train loop uses them)
        snap = _snap_exp_cfg()
        set_e13_regs(write_lambda=1e-4, overlap_lambda=0.0, on_mem_only=True)
        run_one_labeled(
            label=f"E13_write_sparsity_s{seed}_{int(time.time())}",
            steps=steps,
            mixture_weights=base_mix,
            seed=seed,
            post_haystack=post_haystack
        )
        _restore_exp_cfg(snap)

        # E14 ─ Curriculum (0 HF early, then back to baseline)
        snap = _snap_exp_cfg()
        set_e14_curriculum(S=max(50, steps // 5))   # early phase ~20% of steps
        run_one_labeled(
            label=f"E14_curriculum_warmstart_s{seed}_{int(time.time())}",
            steps=steps,
            mixture_weights=base_mix,
            seed=seed,
            mixture_schedule=getattr(CFG, "mixture_schedule", None),
            gate_temp_schedule=getattr(CFG, "gate_temp_schedule", None),
            gate_reg_schedule=getattr(CFG, "gate_reg_schedule", None),
            post_haystack=post_haystack
        )
        _restore_exp_cfg(snap)

In [51]:
# Quick launch presets (adjust STEPS/SEEDS as needed)

# Baseline knobs (kept in sync with your earlier experiments)
EXP_STEPS  = 1000
BASE_MIX   = (0.4, 0.2, 0.2, 0.2)
SEEDS      = (1337,)            # use (1337, 2027, 4242) for 3× repeats
POST_HAY   = False              # set True to run short haystack after each run

In [52]:
# Full E10–E14 sweep
run_e10_14_sweep(steps=EXP_STEPS, seeds=SEEDS, base_mix=BASE_MIX, post_haystack=POST_HAY)


=== E10_multi_experts_k2_s1337_1755938361 | seed=1337 ===
TB run started: ./runs\dncformer-20250823-013921-E10_multi_experts_k2_s1337_1755938361
CFG.gate_temp: 1.0 | CFG.gate_reg_lambda: 0.0 | mixture: (0.4, 0.2, 0.2, 0.2)
[after free] alloc=0.01 GB | reserved=0.02 GB | free=24.36 GB | total=25.77 GB
[before E10_multi_experts_k2_s1337_1755938361] alloc=0.01 GB | reserved=0.02 GB | free=24.36 GB | total=25.77 GB


`flash-attention` package not found, consider installing for better performance: No module named 'flash_attn'.
Current `flash-attention` does not support `window_size`. Either upgrade or use `attn_implementation='eager'`.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

  self.gen = func(*args, **kwds)
You are not running the flash-attention implementation, expect numerical differences.


  [experts] block0 top=0 pi=[0.8813047409057617, 0.059696048498153687, 0.05899922549724579]
step 10 | loss 9.9201 | lr 4.40e-05 | gates=[0.11869527399539948, 0.29810744524002075] | mix=copy
  [experts] block0 top=0 pi=[0.8175254464149475, 0.08705279231071472, 0.09542176127433777]
step 20 | loss 9.8345 | lr 8.40e-05 | gates=[0.1824745535850525, 0.30169153213500977] | mix=copy
  [experts] block0 top=0 pi=[0.9532410502433777, 0.01264512725174427, 0.0341138020157814]
step 30 | loss 1.9990 | lr 1.24e-04 | gates=[0.046758927404880524, 0.3054656982421875] | mix=hf
  [experts] block0 top=0 pi=[0.8043654561042786, 0.10710994899272919, 0.08852458745241165]
step 40 | loss 7.5039 | lr 1.64e-04 | gates=[0.19563452899456024, 0.3201567530632019] | mix=copy
  [experts] block0 top=0 pi=[0.9803692698478699, 0.01403859630227089, 0.0055921003222465515]
step 50 | loss 6.8181 | lr 2.00e-04 | gates=[0.01963069662451744, 0.9948534965515137] | mix=copy
  [experts] block0 top=0 pi=[0.9995763301849365, 0.0003516

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

  [experts] block0 top=0 pi=[0.827828586101532, 0.17217141389846802]
step 10 | loss 9.2175 | lr 4.40e-05 | gates=[0.17217141389846802, 0.41098305583000183] | mix=copy
  [experts] block0 top=0 pi=[0.9850215315818787, 0.014978446997702122]
step 20 | loss 1.8815 | lr 8.40e-05 | gates=[0.014978446066379547, 0.11323383450508118] | mix=hf
  [experts] block0 top=1 pi=[0.12222029268741608, 0.8777797222137451]
step 30 | loss 5.9493 | lr 1.24e-04 | gates=[0.8777797222137451, 0.2587236762046814] | mix=nback
  [experts] block0 top=0 pi=[0.8507078886032104, 0.14929214119911194]
step 40 | loss 6.1915 | lr 1.64e-04 | gates=[0.14929214119911194, 0.9159359931945801] | mix=repeat
  [experts] block0 top=0 pi=[0.8856296539306641, 0.11437037587165833]
step 50 | loss 6.2415 | lr 2.00e-04 | gates=[0.11437037587165833, 0.5090433955192566] | mix=copy
  [experts] block0 top=0 pi=[0.9974364638328552, 0.0025634909979999065]
step 60 | loss 2.5038 | lr 2.00e-04 | gates=[0.00256349123083055, 0.5222287774085999] | mi

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

  [experts] block0 top=0 pi=[0.8276891708374023, 0.17231079936027527]
step 10 | loss 9.2188 | lr 4.40e-05 | gates=[0.17231079936027527, 0.410535991191864] | mix=copy
  [experts] block0 top=0 pi=[0.9848330020904541, 0.015167055651545525]
step 20 | loss 1.8858 | lr 8.40e-05 | gates=[0.015167055651545525, 0.11637142300605774] | mix=hf
  [experts] block0 top=1 pi=[0.1882839798927307, 0.8117159605026245]
step 30 | loss 5.8030 | lr 1.24e-04 | gates=[0.8117160201072693, 0.3568115234375] | mix=nback
  [experts] block0 top=0 pi=[0.700255274772644, 0.29974478483200073]
step 40 | loss 6.1163 | lr 1.64e-04 | gates=[0.29974478483200073, 0.8269035816192627] | mix=repeat
  [experts] block0 top=0 pi=[0.883404016494751, 0.11659602075815201]
step 50 | loss 5.7706 | lr 2.00e-04 | gates=[0.11659602075815201, 0.9920541048049927] | mix=copy
  [experts] block0 top=0 pi=[0.9125295281410217, 0.08747048676013947]
step 60 | loss 4.3830 | lr 2.00e-04 | gates=[0.08747047930955887, 0.13603438436985016] | mix=hf
  [

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

  [experts] block0 top=0 pi=[0.8813318610191345, 0.05967889726161957, 0.05898924544453621]
step 10 | loss 9.9197 | lr 4.40e-05 | gates=[0.11866813898086548, 0.29818278551101685] | mix=copy
  [experts] block0 top=0 pi=[0.8176863193511963, 0.08694931119680405, 0.09536437690258026]
step 20 | loss 9.8345 | lr 8.40e-05 | gates=[0.1823136955499649, 0.3018989861011505] | mix=copy
  [experts] block0 top=0 pi=[0.9536228775978088, 0.012637387029826641, 0.03373980149626732]
step 30 | loss 1.9897 | lr 1.24e-04 | gates=[0.04637718200683594, 0.3020365238189697] | mix=hf
  [experts] block0 top=0 pi=[0.7972662448883057, 0.11007903516292572, 0.09265470504760742]
step 40 | loss 7.4993 | lr 1.64e-04 | gates=[0.20273372530937195, 0.3123947083950043] | mix=copy
  [experts] block0 top=0 pi=[0.974845826625824, 0.018635254353284836, 0.006518871523439884]
step 50 | loss 6.7290 | lr 2.00e-04 | gates=[0.025154123082756996, 0.9936141967773438] | mix=copy
  [experts] block0 top=0 pi=[0.9994534254074097, 0.00045155

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

  [experts] block0 top=0 pi=[0.842902660369873, 0.08819584548473358, 0.06890152394771576]
step 10 | loss 7.2841 | lr 4.40e-05 | gates=[0.15709735453128815, 0.687245786190033] | mix=nback
  [experts] block0 top=1 pi=[0.4024929404258728, 0.4208603501319885, 0.17664667963981628]
step 20 | loss 7.0734 | lr 8.40e-05 | gates=[0.5975069999694824, 0.6532856225967407] | mix=copy
  [experts] block0 top=0 pi=[0.9398799538612366, 0.024742722511291504, 0.035377345979213715]
step 30 | loss 6.2738 | lr 1.24e-04 | gates=[0.06012007221579552, 0.7696507573127747] | mix=repeat
  [experts] block0 top=0 pi=[0.9718606472015381, 0.010401194915175438, 0.01773819327354431]
step 40 | loss 5.5504 | lr 1.64e-04 | gates=[0.0281393900513649, 0.7790657877922058] | mix=repeat
  [experts] block0 top=0 pi=[0.984597384929657, 0.005642917938530445, 0.009759693406522274]
step 50 | loss 5.9330 | lr 2.00e-04 | gates=[0.015402611345052719, 0.7496401071548462] | mix=copy
  [experts] block0 top=0 pi=[0.9900622963905334, 0.0028

#### 9.2 E15 a/b

In [54]:
# E15 configuration helpers (on top of the E11b tiered baseline)

def set_e11b_baseline():
    """
    E11b: tiered memories (shallower early, deeper late).
    Leaves base model frozen and keeps 1 memory expert (K=1), fusion disabled.
    """
    # Per-block memory shapes + mild gate temp biasing; free_bias nudges write-gate
    CFG.per_block_cfg = [
        {"N": 64,  "W": 32, "R": 1, "gate_temp": 1.0, "free_bias": +0.30},  # block 0 (shallower)
        {"N": 128, "W": 64, "R": 2, "gate_temp": 0.9, "free_bias": -0.20},  # block 1 (deeper)
    ]
    # Single memory per block
    CFG.mem_experts = 1
    # No fusion (we found it heavy/unstable on 24GB)
    CFG.fusion_enable = False

    # Keep other knobs sane but unobtrusive
    CFG.gate_reg_lambda = getattr(CFG, "gate_reg_lambda", 2e-4)   # mild entropy-ish gate reg
    CFG.write_reg_lambda = 0.0                                    # no write sparsity in baseline
    CFG.key_overlap_lambda = getattr(CFG, "key_overlap_lambda", 0.0)

    # Clear expert overrides so blocks use their own N/W
    for k in ("expert_N", "expert_W", "expert_gate_temp"):
        if hasattr(CFG, k):
            try: delattr(CFG, k)
            except Exception: pass

    # No forced gating and no mixture/temperature schedules for the baseline
    for k in ("force_g", "mixture_schedule", "gate_temp_schedule", "gate_reg_schedule"):
        if hasattr(CFG, k):
            try: delattr(CFG, k)
            except Exception: pass


def set_e15a_write_sparse_light(lambda_write: float = 5e-5):
    """
    E15a = E11b + light write-sparsity penalty (kept small to preserve HF quality).
    """
    set_e11b_baseline()
    CFG.write_reg_lambda = float(lambda_write)  # train loop should already include this term from E13 work


def set_e15b_two_experts_smallW(K: int = 2, expert_W: int = 32, expert_gate_temp: float = 1.0):
    """
    E15b = E11b + K=2 memory experts per block with small W to keep VRAM in check.
    We let each block keep its own N from per_block_cfg; only override W for experts.
    """
    set_e11b_baseline()
    CFG.mem_experts = int(K)
    # Use each block's N (per_block_cfg) by not setting expert_N at all.
    CFG.expert_W = int(expert_W)                   # experts use narrower read/write width
    CFG.expert_gate_temp = float(expert_gate_temp) # softer, avoids hard saturation

    # Keep write reg off here to isolate effect of experts
    CFG.write_reg_lambda = 0.0


In [57]:
import time, json, random, numpy as np, torch

def run_e15_one(label: str,
                setup_fn,
                steps: int = 1000,
                seed: int = 1337,
                mixture_weights=(0.4, 0.2, 0.2, 0.2),
                warmup_steps: int = None,
                post_haystack: bool = False):
    """
    Single E15 run wrapper: sets seed, applies setup_fn (E15a/b), starts dedicated TB run,
    and calls train_experiment with your current training loop.
    """
    print(f"\n=== {label} | seed={seed} ===")

    # Repro
    try:
        set_seed(seed)
    except NameError:
        random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)
        if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed)

    # Apply config
    setup_fn()

    # Unique TB run for this label/seed
    try:
        start_tb_run(label)
    except Exception as e:
        print("[warn] start_tb_run failed or unavailable:", e)

    # Echo config snapshot to TB if available
    if 'tb' in globals():
        try:
            tb.add_text("run/config", json.dumps({
                "label": label,
                "seed": seed,
                "steps": steps,
                "mixture": list(mixture_weights),
                "per_block_cfg": getattr(CFG, "per_block_cfg", None),
                "mem_experts": getattr(CFG, "mem_experts", 1),
                "expert_W": getattr(CFG, "expert_W", None),
                "expert_gate_temp": getattr(CFG, "expert_gate_temp", None),
                "gate_reg_lambda": getattr(CFG, "gate_reg_lambda", None),
                "write_reg_lambda": getattr(CFG, "write_reg_lambda", None),
            }, indent=2), 0)
        except Exception as e:
            print("[warn] TB add_text skipped:", e)

    # Clean VRAM between runs
    try: free_head_and_cache()
    except Exception: pass
    if 'cuda_report' in globals(): cuda_report(f"before {label}")
    time.sleep(1.2)  # ensure distinct TB run dirs (timestamp granularity)

    # Train
    head, tok = train_experiment(
        steps=steps,
        warmup_steps=(warmup_steps if warmup_steps is not None else max(10, steps//20)),
        mixture_weights=mixture_weights,
        mixture_schedule=getattr(CFG, "mixture_schedule", None),
        gate_temp_schedule=getattr(CFG, "gate_temp_schedule", None),
        gate_reg_schedule=getattr(CFG, "gate_reg_schedule", None),
        viz_memory_after=False,
    )

    # Optional post‑eval
    if post_haystack:
        try:
            evaluate_haystack(head, steps=50, batch=16, T=256, vocab=1024, tb_step=steps, fast=True)
        except Exception as e:
            print("[warn] haystack eval skipped:", e)

    if 'cuda_report' in globals(): cuda_report(f"after  {label}")
    try: free_head_and_cache()
    except Exception: pass

    return head, tok


def run_e15_suite(steps: int = 1000, seeds=(1337, 2027, 4242),
                  mixture_weights=(0.4, 0.2, 0.2, 0.2),
                  warmup_steps: int = None,
                  include_haystack: bool = True):
    """
    Launch E15a (write-sparse‑light) and E15b (two experts with small W) for the given seeds.
    """
    results = []
    for s in seeds:
        # E15a
        label_a = f"E15a_write_sparse_light_s{s}_{int(time.time())}"
        run_e15_one(label_a, set_e15a_write_sparse_light,
                    steps=steps, seed=s,
                    mixture_weights=mixture_weights,
                    warmup_steps=warmup_steps,
                    post_haystack=include_haystack)

        # E15b
        label_b = f"E15b_two_experts_smallW_s{s}_{int(time.time())}"
        run_e15_one(label_b, set_e15b_two_experts_smallW,
                    steps=steps, seed=s,
                    mixture_weights=mixture_weights,
                    warmup_steps=warmup_steps,
                    post_haystack=include_haystack)

        results.append((label_a, label_b))
        # small spacer to avoid TB dir collisions on fast filesystems
        time.sleep(1.2)
    return results

In [59]:
run_e15_suite(steps=1000, seeds=(1337,2027,4242), include_haystack=False)


=== E15a_write_sparse_light_s1337_1755981882 | seed=1337 ===
TB run started: ./runs\dncformer-20250823-134442-E15a_write_sparse_light_s1337_1755981882
[after free] alloc=0.02 GB | reserved=0.03 GB | free=24.33 GB | total=25.77 GB
[before E15a_write_sparse_light_s1337_1755981882] alloc=0.02 GB | reserved=0.03 GB | free=24.33 GB | total=25.77 GB


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

  self.gen = func(*args, **kwds)


  [experts] block0 top=0 pi=[0.7963446378707886, 0.20365534722805023]
step 10 | loss 9.2285 | lr 4.40e-05 | gates=[0.20365534722805023, 0.36654311418533325] | mix=copy
  [experts] block0 top=0 pi=[0.97617107629776, 0.023828942328691483]
step 20 | loss 1.8761 | lr 8.40e-05 | gates=[0.023828942328691483, 0.1664227694272995] | mix=hf
  [experts] block0 top=1 pi=[0.07069626450538635, 0.929303765296936]
step 30 | loss 6.2480 | lr 1.24e-04 | gates=[0.9293037056922913, 0.12432555109262466] | mix=nback
  [experts] block0 top=0 pi=[0.9162425994873047, 0.08375738561153412]
step 40 | loss 6.0591 | lr 1.64e-04 | gates=[0.08375739306211472, 0.7469349503517151] | mix=repeat
  [experts] block0 top=0 pi=[0.544292151927948, 0.4557078778743744]
step 50 | loss 6.6627 | lr 2.00e-04 | gates=[0.455707848072052, 0.15452229976654053] | mix=copy
  [experts] block0 top=0 pi=[0.7967431545257568, 0.20325683057308197]
step 60 | loss 2.4098 | lr 2.00e-04 | gates=[0.20325683057308197, 0.001003516255877912] | mix=hf


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

  [experts] block0 top=0 pi=[0.9076195955276489, 0.0313478522002697, 0.061032485216856]
step 10 | loss 2.2204 | lr 4.40e-05 | gates=[0.0923803374171257, 0.1458486020565033] | mix=hf
  [experts] block0 top=0 pi=[0.9883792400360107, 0.0071257054805755615, 0.004495005123317242]
step 20 | loss 7.1180 | lr 8.40e-05 | gates=[0.011620711535215378, 0.7565804123878479] | mix=nback
  [experts] block0 top=0 pi=[0.9990176558494568, 0.0004991888999938965, 0.00048318770132027566]
step 30 | loss 1.9076 | lr 1.24e-04 | gates=[0.0009823766304180026, 0.28592449426651] | mix=hf
  [experts] block0 top=0 pi=[0.9561795592308044, 0.03500876575708389, 0.008811662904918194]
step 40 | loss 7.9232 | lr 1.64e-04 | gates=[0.04382042586803436, 0.31809091567993164] | mix=copy
  [experts] block0 top=0 pi=[0.989941418170929, 0.008588459342718124, 0.0014700754545629025]
step 50 | loss 6.1138 | lr 2.00e-04 | gates=[0.010058535262942314, 0.7343466281890869] | mix=nback
  [experts] block0 top=0 pi=[0.9955825805664062, 0.0

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

  [experts] block0 top=0 pi=[0.6549537181854248, 0.3450463116168976]
step 10 | loss 9.0408 | lr 4.40e-05 | gates=[0.3450462818145752, 0.3071179986000061] | mix=repeat
  [experts] block0 top=0 pi=[0.9978355169296265, 0.0021645643282681704]
step 20 | loss 4.5686 | lr 8.40e-05 | gates=[0.002164564561098814, 0.7058046460151672] | mix=hf
  [experts] block0 top=1 pi=[0.415398508310318, 0.5846015214920044]
step 30 | loss 6.5673 | lr 1.24e-04 | gates=[0.5846015214920044, 0.32852187752723694] | mix=repeat
  [experts] block0 top=0 pi=[0.6869228482246399, 0.3130771517753601]
step 40 | loss 6.1414 | lr 1.64e-04 | gates=[0.3130771517753601, 0.4550093710422516] | mix=repeat
  [experts] block0 top=1 pi=[0.02594377100467682, 0.9740562438964844]
step 50 | loss 5.9816 | lr 2.00e-04 | gates=[0.9740562438964844, 0.2442379593849182] | mix=copy
  [experts] block0 top=1 pi=[0.013701247051358223, 0.9862987399101257]
step 60 | loss 6.8492 | lr 2.00e-04 | gates=[0.9862987399101257, 0.8319540023803711] | mix=rep

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

  [experts] block0 top=0 pi=[0.8296784162521362, 0.08364704251289368, 0.08667456358671188]
step 10 | loss 2.5194 | lr 4.40e-05 | gates=[0.17032159864902496, 0.2086634337902069] | mix=hf
  [experts] block0 top=0 pi=[0.9660220146179199, 0.02369740419089794, 0.01028058584779501]
step 20 | loss 7.3184 | lr 8.40e-05 | gates=[0.03397798910737038, 0.7729206085205078] | mix=nback
  [experts] block0 top=0 pi=[0.9991491436958313, 0.0005307519459165633, 0.00032010592985898256]
step 30 | loss 1.6272 | lr 1.24e-04 | gates=[0.0008508579339832067, 0.009282042272388935] | mix=hf
  [experts] block0 top=0 pi=[0.9862990379333496, 0.008131793700158596, 0.005569151137024164]
step 40 | loss 5.8134 | lr 1.64e-04 | gates=[0.013700945302844048, 0.8488316535949707] | mix=nback
  [experts] block0 top=0 pi=[0.9999624490737915, 2.5839719455689192e-05, 1.1692439329635818e-05]
step 50 | loss 2.3644 | lr 2.00e-04 | gates=[3.753215787583031e-05, 0.05808749049901962] | mix=hf
  [experts] block0 top=0 pi=[0.999071598052

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

  [experts] block0 top=0 pi=[0.6655751466751099, 0.33442485332489014]
step 10 | loss 9.1353 | lr 4.40e-05 | gates=[0.33442485332489014, 0.2681189179420471] | mix=repeat
  [experts] block0 top=1 pi=[0.08732841908931732, 0.9126715660095215]
step 20 | loss 6.9223 | lr 8.40e-05 | gates=[0.9126715660095215, 0.03181570768356323] | mix=nback
  [experts] block0 top=1 pi=[0.10383504629135132, 0.8961649537086487]
step 30 | loss 5.7412 | lr 1.24e-04 | gates=[0.8961649537086487, 0.14861273765563965] | mix=nback
  [experts] block0 top=1 pi=[0.06477998197078705, 0.9352200031280518]
step 40 | loss 5.9549 | lr 1.64e-04 | gates=[0.9352200031280518, 0.024931583553552628] | mix=nback
  [experts] block0 top=1 pi=[0.2397720366716385, 0.7602279782295227]
step 50 | loss 9.4520 | lr 2.00e-04 | gates=[0.7602279186248779, 0.028303522616624832] | mix=copy
  [experts] block0 top=0 pi=[0.986657977104187, 0.013342034071683884]
step 60 | loss 6.2780 | lr 2.00e-04 | gates=[0.013342034071683884, 0.7727603316307068] | 

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

  [experts] block0 top=0 pi=[0.6075003743171692, 0.276397705078125, 0.11610183864831924]
step 10 | loss 3.3222 | lr 4.40e-05 | gates=[0.39249953627586365, 0.39013808965682983] | mix=hf
  [experts] block0 top=0 pi=[0.9980084300041199, 0.000596475088968873, 0.00139511632733047]
step 20 | loss 1.8696 | lr 8.40e-05 | gates=[0.001991591416299343, 0.005576163996011019] | mix=hf
  [experts] block0 top=0 pi=[0.9942731857299805, 0.004270132631063461, 0.0014566569589078426]
step 30 | loss 8.0889 | lr 1.24e-04 | gates=[0.005726790055632591, 0.9509265422821045] | mix=copy
  [experts] block0 top=0 pi=[0.9545009136199951, 0.028429334983229637, 0.0170697383582592]
step 40 | loss 6.8222 | lr 1.64e-04 | gates=[0.04549907520413399, 0.9673513770103455] | mix=repeat
  [experts] block0 top=0 pi=[0.9915063381195068, 0.005374231841415167, 0.00311942002736032]
step 50 | loss 6.3531 | lr 2.00e-04 | gates=[0.008493651635944843, 0.9913922548294067] | mix=repeat
  [experts] block0 top=0 pi=[0.9997721910476685, 5.

[('E15a_write_sparse_light_s1337_1755981882',
  'E15b_two_experts_smallW_s1337_1755983525'),
 ('E15a_write_sparse_light_s2027_1755987068',
  'E15b_two_experts_smallW_s2027_1755988655'),
 ('E15a_write_sparse_light_s4242_1755992098',
  'E15b_two_experts_smallW_s4242_1755994049')]

In [96]:
free_head_and_cache()

[after free] alloc=21.38 GB | reserved=21.43 GB | free=2.74 GB | total=25.77 GB


In [92]:
# --- E16: composite configs (multi-memory experts + tiered per-block + regs + curriculum) ---

def set_e16a(
    steps:int,
    K:int=2,
    expert_N_each:int=64,
    expert_W:int=None,     # None -> use CFG.dnc_cell_size
    expert_R:int=1,
    gate_temp:float=0.9,
    diversity_lambda:float=1e-3,
    warm_stage:int=None,   # None -> auto = min(steps//4, 250)
):
    """
    E16a: early block = many/narrow, later = fewer/wider.
    + two memory experts per block, diversity reg, and curriculum warm-start (memory-only early).
    """
    # Defaults derived from current CFG to avoid surprises
    W_default = getattr(CFG, "dnc_cell_size", 64)
    expert_W  = int(expert_W or W_default)
    warm_S    = int(warm_stage if warm_stage is not None else min(steps // 4, 250))

    # core knobs
    CFG.mem_experts = int(K)
    CFG.expert_N    = [int(expert_N_each)] * K
    CFG.expert_W    = int(expert_W)
    CFG.expert_R    = int(expert_R)
    CFG.expert_gate_temp       = float(gate_temp)
    CFG.expert_diversity_lambda = float(diversity_lambda)

    # per-block tiering (2 blocks assumed; adjust if CFG.n_blocks != 2)
    CFG.per_block_cfg = [
        {"N":128, "W":32, "R":1, "gate_temp":1.0, "free_bias": +0.30},  # block 0: many & shallow
        {"N": 64, "W":64, "R":1, "gate_temp":0.8, "free_bias": -0.20},  # block 1: fewer & deeper
    ]

    # curriculum: force memory exploration first, then revert to baseline mix
    CFG.mixture_schedule    = [(warm_S, (0.0, 0.34, 0.33, 0.33)), (None, (0.4, 0.2, 0.2, 0.2))]
    CFG.gate_temp_schedule  = [(warm_S, 0.8), (None, 1.0)]
    # keep your existing gate_reg_lambda from CFG; or optionally:
    # CFG.gate_reg_schedule = [(warm_S, getattr(CFG, "gate_reg_lambda", 2e-4)), (None, getattr(CFG, "gate_reg_lambda", 2e-4))]

    # keep fusion OFF in E16 composites unless explicitly tested
    CFG.fusion_enable = False


def set_e16b(
    steps:int,
    K:int=2,
    expert_N_each:int=64,
    expert_W:int=None,
    expert_R:int=1,
    gate_temp:float=0.8,
    diversity_lambda:float=1e-3,
    warm_stage:int=None,
):
    """
    E16b: reversed tiering—early block fewer/wider, later block many/narrow.
    Other ingredients same as E16a.
    """
    W_default = getattr(CFG, "dnc_cell_size", 64)
    expert_W  = int(expert_W or W_default)
    warm_S    = int(warm_stage if warm_stage is not None else min(steps // 4, 250))

    CFG.mem_experts = int(K)
    CFG.expert_N    = [int(expert_N_each)] * K
    CFG.expert_W    = int(expert_W)
    CFG.expert_R    = int(expert_R)
    CFG.expert_gate_temp       = float(gate_temp)
    CFG.expert_diversity_lambda = float(diversity_lambda)

    CFG.per_block_cfg = [
        {"N": 64, "W":64, "R":1, "gate_temp":1.0, "free_bias": -0.20},  # block 0: fewer & deeper
        {"N":128, "W":32, "R":1, "gate_temp":0.9, "free_bias": +0.30},  # block 1: many & shallow
    ]

    CFG.mixture_schedule    = [(warm_S, (0.0, 0.34, 0.33, 0.33)), (None, (0.4, 0.2, 0.2, 0.2))]
    CFG.gate_temp_schedule  = [(warm_S, 0.8), (None, 1.0)]
    CFG.fusion_enable = False

In [93]:
# --- E16 runners: convenience wrappers ---

def _set_seed_all(seed:int):
    import random, numpy as _np, torch as _t
    random.seed(seed); _np.random.seed(seed); _t.manual_seed(seed)
    if _t.cuda.is_available():
        _t.cuda.manual_seed_all(seed)

def run_e16_once(label:str, steps:int, mode:str="a", seed:int=1337):
    """
    mode='a' -> set_e16a; mode='b' -> set_e16b
    Creates a unique TB run for each labeled call.
    """
    _set_seed_all(int(seed))
    if mode.lower().startswith("a"):
        set_e16a(steps=steps)
    else:
        set_e16b(steps=steps)

    # housekeeping for clean VRAM
    free_head_and_cache()
    if 'cuda_report' in globals(): cuda_report(f"before {label}-s{seed}")
    time.sleep(1.2)  # keep distinct TB dirs

    # start TB run; your start_tb_run(label) helper sets tb writer accordingly
    start_tb_run(f"{label}-s{seed}")

    # optional: echo cfg for reproducibility
    if TB_AVAILABLE and 'tb' in globals():
        tb.add_text("run/config/E16", json.dumps({
            "label": label, "mode": mode, "steps": steps, "seed": seed,
            "mem_experts": getattr(CFG, "mem_experts", 1),
            "per_block_cfg": getattr(CFG, "per_block_cfg", None),
            "mixture_schedule": getattr(CFG, "mixture_schedule", None),
            "gate_temp_schedule": getattr(CFG, "gate_temp_schedule", None),
            "diversity_lambda": getattr(CFG, "expert_diversity_lambda", 0.0),
        }, indent=2), 0)

    head, tok = train_experiment(
        steps=steps,
        warmup_steps=max(10, steps//20),
        mixture_weights=(0.4, 0.2, 0.2, 0.2),      # baseline mix; E16 sets schedules on CFG
        mixture_schedule=getattr(CFG, "mixture_schedule", None),
        gate_temp_schedule=getattr(CFG, "gate_temp_schedule", None),
        gate_reg_schedule=getattr(CFG, "gate_reg_schedule", None),
        viz_memory_after=False,
    )

    if 'cuda_report' in globals(): cuda_report(f"after  {label}-s{seed}")
    free_head_and_cache()
    return head, tok

def run_e16_sweep(steps:int=1000, seeds=(1337, 2027, 4242)):
    """
    Launch E16a and E16b across seeds. Adjust steps to taste for overnight runs.
    """
    results = {}
    for s in seeds:
        print(f"\n=== E16a (tier A) | seed={s} ===")
        results[(f"E16a", s)] = run_e16_once(label="E16a", steps=steps, mode="a", seed=s)
    for s in seeds:
        print(f"\n=== E16b (tier B) | seed={s} ===")
        results[(f"E16b", s)] = run_e16_once(label="E16b", steps=steps, mode="b", seed=s)
    return results


In [94]:
# --- E16 smoke: build config + single forward on tiny batch ---
@torch.no_grad()
def smoke_e16_forward(mode="a"):
    steps = 20
    if mode.lower().startswith("a"):
        set_e16a(steps=steps, K=2, expert_N_each=32)  # smaller N for a very quick pass
    else:
        set_e16b(steps=steps, K=2, expert_N_each=32)

    tok, base = load_base_model(CFG.base_model_id)
    requires_grad_(base, False)
    head = DNCFormerHead(base, CFG).to(device)
    head.eval()

    B, T = 2, 8
    vocab_guess = getattr(base.get_input_embeddings(), "num_embeddings", 1024)
    x = torch.randint(0, min(1024, vocab_guess), (B, T), device=device)
    logits, gates, aux = head.forward_with_metrics(x)

    assert logits.shape[:2] == (B, T)
    assert isinstance(gates, (list, tuple)) and len(gates) == getattr(CFG, "n_blocks", len(head.blocks))
    # sanity: ensure expert metrics are present
    for i, blk in enumerate(head.blocks):
        if hasattr(blk, "last_metrics") and blk.last_metrics:
            _ = blk.last_metrics.get("experts_pi_entropy", None)
    print(f"[smoke_e16_forward] mode={mode} ok.")

# quick check:
smoke_e16_forward("a")
smoke_e16_forward("b")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

  self.gen = func(*args, **kwds)


[smoke_e16_forward] mode=a ok.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

[smoke_e16_forward] mode=b ok.


In [95]:
#_ = run_e16_sweep(steps=10, seeds=(1337,))


=== E16a (tier A) | seed=1337 ===
[after free] alloc=0.02 GB | reserved=0.03 GB | free=24.33 GB | total=25.77 GB
[before E16a-s1337] alloc=0.02 GB | reserved=0.03 GB | free=24.33 GB | total=25.77 GB
TB run started: ./runs\dncformer-20250823-220057-E16a-s1337


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

  self.gen = func(*args, **kwds)


  [experts] block0 top=0 pi=[0.998247504234314, 0.0003349798498675227, 0.0014175600372254848]
step 10 | loss 11.0822 | lr 2.00e-05 | gates=[0.0017525398870930076, 0.009336728602647781] | mix=copy
[after  E16a-s1337] alloc=10.70 GB | reserved=28.40 GB | free=0.00 GB | total=25.77 GB
[after free] alloc=10.70 GB | reserved=10.72 GB | free=13.64 GB | total=25.77 GB

=== E16b (tier B) | seed=1337 ===
[after free] alloc=10.70 GB | reserved=10.72 GB | free=13.64 GB | total=25.77 GB
[before E16b-s1337] alloc=10.70 GB | reserved=10.72 GB | free=13.64 GB | total=25.77 GB
TB run started: ./runs\dncformer-20250823-222516-E16b-s1337


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

  [experts] block0 top=0 pi=[0.9990187287330627, 0.00027480811695568264, 0.0007064775563776493]
step 10 | loss 9.0837 | lr 2.00e-05 | gates=[0.0009812857024371624, 0.7871806621551514] | mix=copy
[after  E16b-s1337] alloc=21.38 GB | reserved=39.21 GB | free=0.00 GB | total=25.77 GB
[after free] alloc=21.38 GB | reserved=21.43 GB | free=2.89 GB | total=25.77 GB


In [97]:
_ = run_e16_sweep(steps=1000, seeds=(1337, 2027, 4242))


=== E16a (tier A) | seed=1337 ===
[after free] alloc=21.38 GB | reserved=21.43 GB | free=2.73 GB | total=25.77 GB
[before E16a-s1337] alloc=21.38 GB | reserved=21.43 GB | free=2.73 GB | total=25.77 GB
TB run started: ./runs\dncformer-20250823-223458-E16a-s1337


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

  self.gen = func(*args, **kwds)


  [experts] block0 top=0 pi=[0.8429850339889526, 0.0881362333893776, 0.06887868046760559]
step 10 | loss 7.2834 | lr 4.40e-05 | gates=[0.1570149064064026, 0.6872451901435852] | mix=nback
  [experts] block0 top=1 pi=[0.4031856060028076, 0.4203474521636963, 0.1764669418334961]
step 20 | loss 7.0721 | lr 8.40e-05 | gates=[0.5968143939971924, 0.6538575887680054] | mix=copy
  [experts] block0 top=0 pi=[0.9398620128631592, 0.024761836975812912, 0.03537613898515701]
step 30 | loss 6.2750 | lr 1.24e-04 | gates=[0.060137972235679626, 0.7713171243667603] | mix=repeat
  [experts] block0 top=0 pi=[0.9718441367149353, 0.010407068766653538, 0.01774880662560463]
step 40 | loss 5.5585 | lr 1.64e-04 | gates=[0.028155872598290443, 0.7868714332580566] | mix=repeat
  [experts] block0 top=0 pi=[0.9845647811889648, 0.0056892093271017075, 0.009745968505740166]
step 50 | loss 5.8682 | lr 2.00e-04 | gates=[0.015435177832841873, 0.814529299736023] | mix=copy
  [experts] block0 top=0 pi=[0.9901917576789856, 0.00

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

  [experts] block0 top=0 pi=[0.8297238945960999, 0.08063918352127075, 0.0896369218826294]
step 10 | loss 8.6882 | lr 4.40e-05 | gates=[0.17027610540390015, 0.2963612675666809] | mix=repeat
  [experts] block0 top=2 pi=[0.39833226799964905, 0.13602319359779358, 0.4656445384025574]
step 20 | loss 6.4253 | lr 8.40e-05 | gates=[0.6016677618026733, 0.2832692563533783] | mix=nback
  [experts] block0 top=0 pi=[0.6478135585784912, 0.1297212541103363, 0.22246518731117249]
step 30 | loss 5.6363 | lr 1.24e-04 | gates=[0.3521864414215088, 0.8580359220504761] | mix=nback
  [experts] block0 top=0 pi=[0.8499336838722229, 0.10212591290473938, 0.04794043302536011]
step 40 | loss 6.1248 | lr 1.64e-04 | gates=[0.1500663459300995, 0.7026690244674683] | mix=copy
  [experts] block0 top=0 pi=[0.5152212977409363, 0.3792222738265991, 0.1055564135313034]
step 50 | loss 5.9980 | lr 2.00e-04 | gates=[0.4847787022590637, 0.961904764175415] | mix=copy
  [experts] block0 top=0 pi=[0.6330140233039856, 0.03021349012851

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

  [experts] block0 top=0 pi=[0.7936932444572449, 0.06786225736141205, 0.13844448328018188]
step 10 | loss 8.3066 | lr 4.40e-05 | gates=[0.20630674064159393, 0.41940373182296753] | mix=repeat
  [experts] block0 top=2 pi=[0.4397725462913513, 0.06371527910232544, 0.49651217460632324]
step 20 | loss 7.2165 | lr 8.40e-05 | gates=[0.5602273941040039, 0.5967902541160583] | mix=copy
  [experts] block0 top=2 pi=[0.29083216190338135, 0.10324284434318542, 0.6059249639511108]
step 30 | loss 5.4364 | lr 1.24e-04 | gates=[0.7091678380966187, 0.18113945424556732] | mix=nback
  [experts] block0 top=0 pi=[0.5081151127815247, 0.013232313096523285, 0.4786525368690491]
step 40 | loss 5.9170 | lr 1.64e-04 | gates=[0.49188485741615295, 0.125884547829628] | mix=nback
  [experts] block0 top=0 pi=[0.6729072332382202, 0.008411949500441551, 0.3186808228492737]
step 50 | loss 6.1366 | lr 2.00e-04 | gates=[0.3270927667617798, 0.17755326628684998] | mix=nback
  [experts] block0 top=2 pi=[0.4645317792892456, 0.02894

RuntimeError: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


## 00. Misc.

#### Tensorboard log dump

In [98]:
# ===== TensorBoard event analyzer for DNCFormer runs (robust) =====
# Discovers TB runs, merges multiple event files per run, summarizes, and exports granular CSVs.

import os, re, math, json, time, glob
from pathlib import Path
from collections import defaultdict

import numpy as np
import pandas as pd
from IPython.display import display

# --- TensorBoard imports (version-agnostic handling) ---
try:
    import tensorboard as _tb
    print("TensorBoard version:", getattr(_tb, "__version__", "unknown"))
except Exception as _e:
    print("TensorBoard import note:", _e)

try:
    from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
    from tensorboard.backend.event_processing import event_accumulator as ea_mod
except Exception as e:
    raise RuntimeError("TensorBoard not installed. Install via: pip install tensorboard") from e


# ---------------- helpers ----------------
def _size_guidance_version_safe():
    """Build size_guidance dict usable across TB versions."""
    sg = {}
    keys = ["SCALARS", "HISTOGRAMS", "IMAGES", "COMPRESSED_HISTOGRAMS", "AUDIO", "TENSORS"]
    for k in keys:
        v = getattr(ea_mod, k, None)
        if v is not None:
            sg[v] = 0
        else:
            sg[k.lower()] = 0
    return sg


def _infer_label_from_run_dir(run_dir: Path) -> str:
    """
    Expect 'dncformer-YYYYMMDD-HHMMSS-<LABEL>' or just 'dncformer-YYYYMMDD-HHMMSS'.
    Returns <LABEL> if present, else the directory name.
    """
    name = run_dir.name
    m = re.match(r".*-\d{8}-\d{6}-(.+)$", name)
    if m:
        return m.group(1)
    return name


def _load_scalars_from_event_file(ev_path: str) -> dict:
    """Load scalars from a single event file: tag -> list[(step, value)]."""
    acc = EventAccumulator(ev_path, size_guidance=_size_guidance_version_safe())
    acc.Reload()
    tags = acc.Tags().get('scalars', []) or []
    out = {}
    for tag in tags:
        vals = acc.Scalars(tag)
        out[tag] = [(int(x.step), float(x.value)) for x in vals]
    return out


def _merge_scalar_dicts(list_of_scalar_dicts):
    """
    Merge multiple event files for the same run.
    For each tag: keep the last value seen per (step), then return sorted by step.
    """
    merged = defaultdict(dict)  # tag -> {step: value}
    for scal in list_of_scalar_dicts:
        for tag, series in scal.items():
            d = merged[tag]
            for step, val in series:
                d[step] = val  # 'last wins' is fine; event files are append-only per run
    # Convert to tag -> sorted list[(step, value)]
    out = {}
    for tag, d in merged.items():
        steps_sorted = sorted(d.keys())
        out[tag] = [(s, d[s]) for s in steps_sorted]
    return out


def _detect_tasks_and_blocks(scalars: dict):
    """Detect task names and block ids from present tags. Returns (sorted_tasks, sorted_blocks)."""
    tasks = set()
    blocks = set()

    # tasks from "loss_by_task/<task>"
    for tag in scalars.keys():
        if tag.startswith("loss_by_task/"):
            tasks.add(tag.split("/", 1)[1])

    # blocks from "gates_by_task/block_<b>_mean/<task>" or "gates/block_<b>_mean"
    for tag in scalars.keys():
        m = re.match(r"gates_by_task/block_(\d+)_", tag)
        if m:
            blocks.add(int(m.group(1)))
        m2 = re.match(r"gates/block_(\d+)_mean$", tag)
        if m2:
            blocks.add(int(m2.group(1)))

    # sensible defaults if nothing is found
    if not tasks:
        tasks = {"hf", "copy", "repeat", "nback"}
    if not blocks:
        blocks = {0, 1}

    return sorted(tasks), sorted(blocks)


def s_last(vals, k=None):
    if not vals: return np.nan
    arr = np.array([v for _, v in vals], dtype=float)
    if k is None or k >= len(arr): return float(arr[-1])
    return float(np.nanmean(arr[-k:]))


def s_first(vals, k=None):
    if not vals: return np.nan
    arr = np.array([v for _, v in vals], dtype=float)
    if k is None or k >= len(arr): return float(arr[0])
    return float(np.nanmean(arr[:k]))


def s_mean(vals):
    if not vals: return np.nan
    return float(np.nanmean([v for _, v in vals]))


def s_count(vals):
    return len(vals) if vals else 0


# ---------------- discover runs ----------------
# Option 1: auto-discover all event files and group by their parent directory
FOUND_EVENTS = sorted(glob.glob("runs/**/events.out.tfevents.*", recursive=True))

# Option 2: pin a subset manually
FOUND_EVENTS = [
    r"runs/dncformer-20250823-223458-E16a-s1337/events.out.tfevents.1756013698.Persephone.25868.16",
    r"runs/dncformer-20250824-002307-E16a-s2027/events.out.tfevents.1756020187.Persephone.25868.17",
    r"runs/dncformer-20250824-042729-E16a-s4242/events.out.tfevents.1756034849.Persephone.25868.18",
]

assert FOUND_EVENTS, "No event files found under ./runs. Have you executed any experiments?"

# Group event files by run directory
run_groups = defaultdict(list)   # run_dir_path -> [event_file_paths]
for p in FOUND_EVENTS:
    run_groups[str(Path(p).parent)].append(p)

print(f"Discovered {len(run_groups)} run(s).")


# ---------------- summarize each run ----------------
run_summaries = []
per_run_scalars = {}   # run_dir_name -> merged scalars dict

for run_dir, files in sorted(run_groups.items()):
    run_dir_path = Path(run_dir)
    label = _infer_label_from_run_dir(run_dir_path)
    run_id = run_dir_path.name

    try:
        # load and merge all files for this run
        scal_dicts = [_load_scalars_from_event_file(f) for f in sorted(files)]
        scal = _merge_scalar_dicts(scal_dicts)
        per_run_scalars[run_dir_path.name] = scal

        # detect tasks and blocks present
        TASKS, BLOCKS = _detect_tasks_and_blocks(scal)

        # basics
        loss_series = scal.get("train/loss", [])
        lr_series   = scal.get("train/lr", [])
        steps_logged = max([s for s, _ in loss_series], default=np.nan) if loss_series else np.nan
        loss0   = s_first(loss_series, k=5)
        lossT   = s_last(loss_series,  k=10)
        ldelta  = (loss0 - lossT) if not any(map(math.isnan, [loss0, lossT])) else np.nan
        lr_last = s_last(lr_series, k=1)

        # gates (global)
        g_means, g_entropy, g_frac_avg = {}, {}, {}
        for b in BLOCKS:
            g_means[b]   = s_last(scal.get(f"gates/block_{b}_mean", []), k=10)
            g_entropy[b] = s_last(scal.get(f"gates/block_{b}_entropy", []), k=10)
            fracs = []
            for t in TASKS:
                tag = f"gates_by_task/block_{b}_frac>0.5/{t}"
                if tag in scal:
                    fracs.append(s_last(scal[tag], k=10))
            g_frac_avg[b] = float(np.nanmean(fracs)) if fracs else np.nan

        # per-task losses (mean of last quarter of points for that task)
        task_loss_last, task_counts = {}, {}
        for t in TASKS:
            ts = scal.get(f"loss_by_task/{t}", [])
            task_counts[t] = s_count(ts)
            if ts:
                k = max(1, len(ts)//4)
                task_loss_last[t] = s_last(ts, k=k)
            else:
                task_loss_last[t] = np.nan

        # per-task gate means (avg last quarter)
        task_gmeans = {t: {} for t in TASKS}
        for t in TASKS:
            for b in BLOCKS:
                ts = scal.get(f"gates_by_task/block_{b}_mean/{t}", [])
                if ts:
                    k = max(1, len(ts)//4)
                    task_gmeans[t][b] = s_last(ts, k=k)
                else:
                    task_gmeans[t][b] = np.nan

        # quartiles (block0) if present
        q_means = {}
        for qi in range(1, 5):
            vals = []
            # May be logged globally or by task; check both
            tag_global = f"gates/block0_q{qi}_mean"
            if tag_global in scal:
                vals.append(s_last(scal[tag_global], k=10))
            else:
                for t in TASKS:
                    tag_task = f"gates/block0_q{qi}_mean/{t}"
                    if tag_task in scal:
                        vals.append(s_last(scal[tag_task], k=10))
            q_means[qi] = float(np.nanmean(vals)) if vals else np.nan

        # haystack eval if present
        hay_acc  = s_last(scal.get("eval/haystack_acc",  []), k=1)
        hay_loss = s_last(scal.get("eval/haystack_loss", []), k=1)

        # forced-g guess heuristic
        forced_guess = None
        gm_all = [g_means[b] for b in g_means if not math.isnan(g_means[b])]
        if gm_all:
            m = float(np.nanmean(gm_all))
            if m < 0.02:   forced_guess = "force_g=0"
            elif m > 0.98: forced_guess = "force_g=1"

        summary = {
            "label": label,
            "run_id": run_id,
            "run_dir": str(run_dir_path),
            "n_event_files": len(files),
            "steps_logged": steps_logged,
            "loss_start~5": loss0,
            "loss_end~10":  lossT,
            "loss_delta":   ldelta,
            "lr_last":      lr_last,
            "haystack_acc_last":  hay_acc,
            "haystack_loss_last": hay_loss,
            "forced_guess": forced_guess,
        }

        # flatten gate summaries
        for b in BLOCKS:
            summary[f"g_mean_b{b}"]     = g_means.get(b, np.nan)
            summary[f"g_entropy_b{b}"]  = g_entropy.get(b, np.nan)
            summary[f"g_frac>0.5_b{b}"] = g_frac_avg.get(b, np.nan)

        # flatten per-task last losses and per-task mean gates
        for t in TASKS:
            summary[f"loss_{t}_last"] = task_loss_last[t]
            for b in BLOCKS:
                summary[f"gmean_b{b}_{t}"] = task_gmeans[t][b]

        # quartiles
        for qi in range(1, 5):
            summary[f"g_b0_Q{qi}_mean"] = q_means[qi]

        run_summaries.append(summary)

    except Exception as e:
        print(f"[analyzer] Skipped run {run_dir}: {e}")

# -------- assemble run-level summary --------
df_runs = pd.DataFrame(run_summaries)
if df_runs.empty:
    print("No runs summarized.")
else:
    # Sort by timestamp embedded in run_id if possible
    def _ts_key(name: str):
        m = re.search(r"(\d{8})-(\d{6})", name or "")
        return (m.group(1), m.group(2)) if m else ("", "")
    df_runs = df_runs.sort_values(by=["run_id"], key=lambda s: s.map(_ts_key), ignore_index=True)

    # Arrange prominent columns
    front_cols = [c for c in [
        "label","run_id","n_event_files","steps_logged",
        "loss_start~5","loss_end~10","loss_delta","lr_last",
        "haystack_acc_last","haystack_loss_last","forced_guess",
        "g_mean_b0","g_mean_b1","g_entropy_b0","g_entropy_b1",
        "g_frac>0.5_b0","g_frac>0.5_b1",
        "loss_hf_last","loss_copy_last","loss_repeat_last","loss_nback_last",
        "gmean_b0_hf","gmean_b1_hf","gmean_b0_copy","gmean_b1_copy",
        "gmean_b0_repeat","gmean_b1_repeat","gmean_b0_nback","gmean_b1_nback",
        "g_b0_Q1_mean","g_b0_Q2_mean","g_b0_Q3_mean","g_b0_Q4_mean",
        "run_dir"
    ] if c in df_runs.columns]
    df_runs = df_runs[[*front_cols, *[c for c in df_runs.columns if c not in front_cols]]]

    display(df_runs)
    out_dir = Path("./analysis"); out_dir.mkdir(parents=True, exist_ok=True)
    df_runs.to_csv(out_dir / "run_level_summary.csv", index=False)
    print("Saved:", out_dir / "run_level_summary.csv")


# -------- granular per-task time series export --------
rows = []
for run_dir_name, scal in per_run_scalars.items():
    label = _infer_label_from_run_dir(Path(run_dir_name))
    # detect tasks/blocks present in this run
    TASKS, BLOCKS = _detect_tasks_and_blocks(scal)

    # pick up LR series for convenient join
    lr_by_step = {int(s): float(v) for s, v in scal.get("train/lr", [])}

    # per-task loss and per-task gate metrics
    for t in TASKS:
        loss_series = scal.get(f"loss_by_task/{t}", [])
        if not loss_series:
            continue

        gmean_by_step = {b: {int(s): float(v) for s, v in scal.get(f"gates_by_task/block_{b}_mean/{t}", [])}
                         for b in BLOCKS}
        gfrac_by_step = {b: {int(s): float(v) for s, v in scal.get(f"gates_by_task/block_{b}_frac>0.5/{t}", [])}
                         for b in BLOCKS}

        # Optional quartile logs (block 0)
        q_by_step = {qi: {int(s): float(v) for s, v in scal.get(f"gates/block0_q{qi}_mean/{t}", [])}
                     for qi in (1,2,3,4)}

        for step, loss_val in loss_series:
            step = int(step); loss_val = float(loss_val)
            for b in BLOCKS:
                row = {
                    "label": label,
                    "run_id": run_dir_name,
                    "task": t,
                    "block": b,
                    "step": step,
                    "loss": loss_val,
                    "g_mean": gmean_by_step[b].get(step, np.nan),
                    "g_frac>0.5": gfrac_by_step[b].get(step, np.nan),
                    "lr": lr_by_step.get(step, np.nan),
                }
                if b == 0:
                    for qi in (1,2,3,4):
                        row[f"g_b0_Q{qi}"] = q_by_step[qi].get(step, np.nan)
                rows.append(row)

df_task_ts = pd.DataFrame(rows)
if df_task_ts.empty:
    print("No per-task series found.")
else:
    df_task_ts = df_task_ts.sort_values(["label","task","step","block"], ignore_index=True)
    display(df_task_ts.head(20))
    out_dir = Path("./analysis"); out_dir.mkdir(parents=True, exist_ok=True)
    df_task_ts.to_csv(out_dir / "per_task_metrics.csv", index=False)
    print("Saved:", out_dir / "per_task_metrics.csv")


TensorBoard version: 2.20.0
Discovered 3 run(s).


Unnamed: 0,label,run_id,n_event_files,steps_logged,loss_start~5,loss_end~10,loss_delta,lr_last,haystack_acc_last,haystack_loss_last,...,gmean_b1_copy,gmean_b0_repeat,gmean_b1_repeat,gmean_b0_nback,gmean_b1_nback,g_b0_Q1_mean,g_b0_Q2_mean,g_b0_Q3_mean,g_b0_Q4_mean,run_dir
0,E16a-s1337,dncformer-20250823-223458-E16a-s1337,1,1000,6.411446,3.316509,3.094937,2e-05,,,...,0.999302,0.035101,0.998633,0.003567,0.999494,0.040112,0.003146,0.001473,0.000781,runs\dncformer-20250823-223458-E16a-s1337
1,E16a-s2027,dncformer-20250824-002307-E16a-s2027,1,1000,6.574523,3.27903,3.295493,2e-05,,,...,0.907705,0.034682,0.865977,0.002666,0.924079,0.05018,0.019344,0.022893,0.025952,runs\dncformer-20250824-002307-E16a-s2027
2,E16a-s4242,dncformer-20250824-042729-E16a-s4242,1,340,6.60265,4.87786,1.724789,0.000157,,,...,0.000307,0.7971,0.000102,0.779932,0.000211,0.632134,0.648412,0.652287,0.650466,runs\dncformer-20250824-042729-E16a-s4242


Saved: analysis\run_level_summary.csv


Unnamed: 0,label,run_id,task,block,step,loss,g_mean,g_frac>0.5,lr,g_b0_Q1,g_b0_Q2,g_b0_Q3,g_b0_Q4
0,E16a-s1337,dncformer-20250823-223458-E16a-s1337,copy,0,20,7.072079,0.596814,0.953125,8.4e-05,0.633502,0.588711,0.582837,0.582207
1,E16a-s1337,dncformer-20250823-223458-E16a-s1337,copy,1,20,7.072079,0.653858,0.99707,8.4e-05,,,,
2,E16a-s1337,dncformer-20250823-223458-E16a-s1337,copy,0,50,5.868248,0.015435,0.014648,0.0002,0.059879,0.000742,0.000558,0.000561
3,E16a-s1337,dncformer-20250823-223458-E16a-s1337,copy,1,50,5.868248,0.814529,1.0,0.0002,,,,
4,E16a-s1337,dncformer-20250823-223458-E16a-s1337,copy,0,110,5.652671,0.026108,0.023438,0.000198,0.101993,0.000887,0.000785,0.000766
5,E16a-s1337,dncformer-20250823-223458-E16a-s1337,copy,1,110,5.652671,0.999999,1.0,0.000198,,,,
6,E16a-s1337,dncformer-20250823-223458-E16a-s1337,copy,0,130,5.628609,0.022214,0.022461,0.000196,0.085761,0.001084,0.001011,0.001001
7,E16a-s1337,dncformer-20250823-223458-E16a-s1337,copy,1,130,5.628609,0.999998,1.0,0.000196,,,,
8,E16a-s1337,dncformer-20250823-223458-E16a-s1337,copy,0,140,6.160719,0.020526,0.019531,0.000196,0.077643,0.001594,0.001462,0.001406
9,E16a-s1337,dncformer-20250823-223458-E16a-s1337,copy,1,140,6.160719,1.0,1.0,0.000196,,,,


Saved: analysis\per_task_metrics.csv


In [63]:
# --- Rename ./runs/* so directory names include the experiment label (from TB text tags) ---
import os, re, glob, time, shutil
from pathlib import Path
from typing import Optional

# Try to close any active writer so files aren't locked on Windows
try:
    if 'tb' in globals() and getattr(tb, "writer", None) is not None:
        tb.flush(); tb.close()
        print("[rename] Closed active TB writer before renaming.")
except Exception as _e:
    print("[rename] Writer close note:", _e)

# TensorBoard event loading (version-agnostic size_guidance)
try:
    from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
    from tensorboard.backend.event_processing import event_accumulator as ea_mod
    try:
        from tensorboard.util import tensor_util as _tb_tensor_util
    except Exception:
        print("failed to load tensorboard utilities, but you'll install tensorflow on this system over my cold dead digital body."
              "\n\nfix your tensorboard installation and try again")
        _tb_tensor_util = None
except Exception as e:
    raise RuntimeError("TensorBoard not installed. `pip install tensorboard`") from e

def _size_guidance_version_safe():
    sg = {}
    for k in ["SCALARS","HISTOGRAMS","IMAGES","COMPRESSED_HISTOGRAMS","AUDIO","TENSORS"]:
        v = getattr(ea_mod, k, None)
        if v is not None:
            sg[v] = 0
        else:
            sg[k.lower()] = 0
    return sg

def _decode_text_tensor(tensor_proto) -> Optional[str]:
    """Decode TB text summary payload from a TensorEvent.tensor_proto."""
    try:
        if _tb_tensor_util is not None:
            arr = _tb_tensor_util.make_ndarray(tensor_proto)
        else:
            # very old stub fallback
            arr = _tb_make_ndarray(tensor_proto)
        val = arr.item() if arr.size == 1 else arr
        if isinstance(val, bytes):
            return val.decode("utf-8", "replace")
        if isinstance(val, str):
            return val
        # Some TB builds wrap a bytes array inside a 2D array
        if hasattr(val, "dtype") and str(val.dtype).startswith("|S"):
            return val.tobytes().decode("utf-8", "replace")
    except Exception:
        pass
    return None

def _infer_label_from_text_tags(run_dir: Path) -> Optional[str]:
    """Try run/label first, then run/meta (JSON with a 'label' field), then None."""
    ev_files = sorted(glob.glob(str(run_dir / "events.out.tfevents.*")))
    for ev in reversed(ev_files):
        try:
            acc = EventAccumulator(ev, size_guidance=_size_guidance_version_safe())
            acc.Reload()
            tags_t = acc.Tags().get("tensors", []) or []

            # 1) Preferred: run/label
            if "run/label" in tags_t:
                tens = acc.Tensors("run/label")
                for e in reversed(tens):
                    txt = _decode_text_tensor(e.tensor_proto)
                    if txt:
                        return txt.strip()

            # 2) Fallback: run/meta (JSON with label)
            if "run/meta" in tags_t:
                tens = acc.Tensors("run/meta")
                for e in reversed(tens):
                    txt = _decode_text_tensor(e.tensor_proto)
                    if txt:
                        txt = txt.strip()
                        # Sometimes add_text wraps in small HTML; tolerate raw JSON and simple strings
                        m = re.search(r"\{.*\}", txt, flags=re.S)
                        if m:
                            import json
                            try:
                                meta = json.loads(m.group(0))
                                if isinstance(meta, dict) and "label" in meta and meta["label"]:
                                    return str(meta["label"]).strip()
                            except Exception:
                                pass
                        # If it's just a string, return it
                        if txt and txt[0] not in "{<":
                            return txt
        except Exception:
            continue
    return None

def _infer_label_from_dirname(run_dir: Path) -> Optional[str]:
    """If the dir already has a '-<label>' suffix after timestamp, return that label."""
    m = re.match(r".*-\d{8}-\d{6}-(.+)$", run_dir.name)
    return m.group(1) if m else None

def _base_prefix(run_dir: Path) -> str:
    """Return 'dncformer-YYYYMMDD-HHMMSS' part if present, else the full name."""
    m = re.match(r"(.*-\d{8}-\d{6})(?:-.+)?$", run_dir.name)
    return m.group(1) if m else run_dir.name

def _slugify(s: str) -> str:
    return re.sub(r"[^A-Za-z0-9_.-]+", "_", s.strip())[:80] or "unlabeled"

def _is_active_run(run_dir: Path, seconds: int = 60) -> bool:
    """Heuristic: if any event file mtime is within the last `seconds`."""
    now = time.time()
    for ev in glob.glob(str(run_dir / "events.out.tfevents.*")):
        try:
            if now - os.path.getmtime(ev) < seconds:
                return True
        except Exception:
            pass
    return False

def rename_runs_by_label(log_root: str = "./runs", dry_run: bool = True,
                         skip_active_secs: int = 60) -> list[tuple[str, str]]:
    """
    Rename run directories under `log_root` to include label suffix (from TB text tag).
    Returns list of (old_path, new_path) actually renamed.
    """
    log_root = Path(log_root)
    renamed = []

    for run_dir in sorted([p for p in log_root.iterdir() if p.is_dir()]):
        # Already labeled?
        existing_label = _infer_label_from_dirname(run_dir)
        # Attempt tag-based label
        tag_label = _infer_label_from_text_tags(run_dir)

        label = tag_label or existing_label or "unlabeled"
        label_slug = _slugify(label)

        base = _base_prefix(run_dir)
        target = log_root / f"{base}-{label_slug}"

        # Skip if it's already the desired name
        if run_dir == target:
            print(f"[rename] OK (already labeled): {run_dir.name}")
            continue

        # Skip if run appears active
        if skip_active_secs and _is_active_run(run_dir, skip_active_secs):
            print(f"[rename] SKIP active (mtime<{skip_active_secs}s): {run_dir.name}")
            continue

        # Avoid collisions: add -v2, -v3, ...
        cand = target
        k = 2
        while cand.exists():
            cand = log_root / f"{base}-{label_slug}-v{k}"
            k += 1

        print(f"[rename] {run_dir.name}  ->  {cand.name}")
        if not dry_run:
            try:
                run_dir.replace(cand)
                renamed.append((str(run_dir), str(cand)))
            except Exception as e:
                print(f"[rename] FAILED: {run_dir} -> {cand}: {e}")

    return renamed

# --- Usage examples ---
# 1) Dry run (see planned changes)
_ = rename_runs_by_label("./runs", dry_run=True, skip_active_secs=60)

# 2) Execute renames
# _ = rename_runs_by_label("./runs", dry_run=False, skip_active_secs=60)
# print("Renamed:", _)


[rename] Closed active TB writer before renaming.
[rename] dncformer-20250817-184608  ->  dncformer-20250817-184608-unlabeled
[rename] dncformer-20250817-192155  ->  dncformer-20250817-192155-unlabeled
[rename] dncformer-20250817-193701  ->  dncformer-20250817-193701-unlabeled
[rename] dncformer-20250817-195856  ->  dncformer-20250817-195856-unlabeled
[rename] dncformer-20250817-201426  ->  dncformer-20250817-201426-unlabeled
[rename] dncformer-20250817-203016  ->  dncformer-20250817-203016-unlabeled
[rename] dncformer-20250820-093245  ->  dncformer-20250820-093245-unlabeled
[rename] OK (already labeled): dncformer-20250820-143212-E6_temp0p8_seed1337
[rename] OK (already labeled): dncformer-20250820-151404-E6_temp0p8_seed2027
[rename] OK (already labeled): dncformer-20250820-160508-E6_temp0p8_seed4242
[rename] OK (already labeled): dncformer-20250820-164415-E7_warmstart_seed1337
[rename] OK (already labeled): dncformer-20250820-174016-E7_warmstart_seed2027
[rename] OK (already labeled)

#### Notes & TODO
- The DNC memory here is **compact** and intended for research iteration; you can swap in a fuller reference implementation if desired. [x]
- The controller currently runs in **sequence mode** with a causal mask
- Training uses a tiny **instruction-following set** plus synthetic memory tasks.
- Gating is **vector-valued** with bias init favoring the vanilla path; metrics log mean gate values.
- Use `CFG.n_blocks` to grow the enrichment depth as VRAM allows.
