
# DNCformer: Parallel-Enrichment Transformer–DNC (Notebook Prototype)

This notebook implements a **parallel enrichment** architecture that adds a **Transformer-style DNC block** alongside a standard Transformer block, with a learned **gating** to mix their outputs. It wires on top of a **frozen ~4B LLM** (Phi-3-mini-4k-instruct by default), and provides **lightweight train/eval** loops and **unit-like tests**.

**Hardware:** designed for a single GPU (e.g., RTX 3090 24GB) using AMP (`bf16` if available, otherwise `fp16`).  
**Structure:** Config → Utils → DNC Memory → Transformer Controller → DNCformer Block → Parallel Enrichment → Frozen Base + N Blocks → Data → Train → Eval → Tests.


In [1]:
import sys, platform, torch
print("python exe:", sys.executable)
print("torch file:", torch.__file__)
print("torch ver:", torch.__version__, "| torch.version.cuda:", torch.version.cuda)
print("cuda available:", torch.cuda.is_available(), "| device count:", torch.cuda.device_count())

# SDPA configured via SDPA_CTX() above


python exe: C:\Users\andre\miniconda3\envs\dncformer\python.exe
torch file: C:\Users\andre\miniconda3\envs\dncformer\lib\site-packages\torch\__init__.py
torch ver: 2.5.1 | torch.version.cuda: 12.6
cuda available: True | device count: 1


## 1. Config & Environment

In [2]:

# --- Configuration ---
from dataclasses import dataclass
from typing import Optional, Tuple, Dict, Any

@dataclass
class Config:
    base_model_id: str = "microsoft/Phi-3-mini-4k-instruct"  # ~3.8B
    d_model: Optional[int] = None          # If None, infer from base model hidden size
    n_blocks: int = 2                      # number of parallel enrichment blocks
    attn_heads: int = 8                    # heads in DNC controller
    attn_dropout: float = 0.1
    ffn_mult: float = 4.0
    dnc_read_heads: int = 2
    dnc_cell_size: int = 64                # memory slot width
    dnc_nr_cells: int = 256                # number of memory slots
    gate_bias_init: float = -1.0           # bias to prefer transformer at init
    lr: float = 2e-4                       
    weight_decay: float = 0.01
    max_seq_len: int = 1024                # training seq length
    train_steps: int = 200                 # small sanity pass
    warmup_steps: int = 20
    grad_clip: float = 1.0
    precision: str = "bf16"                # "bf16" | "fp16" | "fp32"
    use_torch_compile: bool = False
    device: str = "cuda"
    log_every: int = 10
    batch_size: int = 8

CFG = Config()

import os, math, random, torch
from torch import nn, Tensor
from torch.nn import functional as F

seed = 42
random.seed(seed); torch.manual_seed(seed)

device = torch.device(CFG.device if torch.cuda.is_available() else "cpu")
print("Device:", device, "CUDA:", torch.cuda.is_available())

amp_dtype = None
if CFG.precision == "bf16" and torch.cuda.is_available() and torch.cuda.is_bf16_supported():
    amp_dtype = torch.bfloat16
elif CFG.precision == "fp16" and torch.cuda.is_available():
    amp_dtype = torch.float16
else:
    amp_dtype = torch.float32

print("AMP dtype:", amp_dtype)

Device: cuda CUDA: True
AMP dtype: torch.bfloat16


In [3]:
# --- Config patch toggles ---
try:
    CFG
except NameError:
    class _Tmp: pass
    CFG = _Tmp()
# safe defaults if not already present
if not hasattr(CFG, 'batch_size'): CFG.batch_size = 8
if not hasattr(CFG, 'gate_reg_lambda'): CFG.gate_reg_lambda = 0.0   # only applied on memory-tagged batches
if not hasattr(CFG, 'hist_every'): CFG.hist_every = 200             # histogram cadence
if not hasattr(CFG, 'force_g'): CFG.force_g = None                  # None, or 0.0 or 1.0
if not hasattr(CFG, 'gate_temp'): CFG.gate_temp = 1.0


In [4]:
# --- SDPA selection (prefer PyTorch SDPA; avoid flash-attn) ---
import contextlib
def sdpa_ctx():
    """Return a fresh attention-kernel selection context each time it's called.
    Uses PyTorch SDPA (math + mem-efficient) and disables flash-attn to avoid warnings.
    """
    try:
        from torch.backends.cuda import sdp_kernel  # callable context manager in PyTorch 2.x
        return sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=True)
    except Exception:
        return contextlib.nullcontext()


## 2. Utilities

In [5]:
_GLOBAL = {}

def causal_mask(sz: int, device=None):
    return torch.full((sz, sz), float("-inf"), device=device).triu(1)

def count_params(model: nn.Module):
    return sum(p.numel() for p in model.parameters())

def requires_grad_(module: nn.Module, flag: bool):
    for p in module.parameters():
        p.requires_grad_(flag)
    return module

def print_shapes(**tensors):
    for k, v in tensors.items():
        if isinstance(v, (list, tuple)):
            print(k, [tuple(x.shape) for x in v])
        elif isinstance(v, torch.Tensor):
            print(k, tuple(v.shape))
        else:
            print(k, type(v))

def get_or_build_model_and_tokenizer():
    mid = CFG.base_model_id
    if _GLOBAL.get("head") is not None and _GLOBAL.get("mid") == mid:
        # Reuse the already-loaded GPU model
        return _GLOBAL["tok"], _GLOBAL["head"]
    # Otherwise build fresh
    tok, base = load_base_model(mid)             # your existing GPU-only loader
    head = DNCFormerHead(base, CFG).to(device)   # enrichment head on CUDA
    _GLOBAL.update({"tok": tok, "head": head, "mid": mid})
    return tok, head


## 3. DNC Memory (compact reference implementation)

In [6]:

class DNCMemory(nn.Module):
    """
    Compact DNC memory:
    - memory M: (B, N, W)
    - usage u: (B, N)
    - link L: (B, N, N) temporal links
    - precedence p: (B, N)
    - read weights rw: (B, R, N)
    - write weights ww: (B, N)
    - read vectors r: (B, R, W)
    """
    def __init__(self, nr_cells: int, cell_size: int, read_heads: int):
        super().__init__()
        self.probe = None  # optional callable to record per-step state
        self.N = nr_cells
        self.W = cell_size
        self.R = read_heads

    def reset(self, B: int, device=None):
        device = device or next(self.parameters(), torch.empty(0, device="cpu")).device
        M = torch.zeros(B, self.N, self.W, device=device)
        u = torch.zeros(B, self.N, device=device)
        L = torch.zeros(B, self.N, self.N, device=device)
        p = torch.zeros(B, self.N, device=device)
        rw = F.one_hot(torch.zeros(B, self.R, dtype=torch.long, device=device), num_classes=self.N).float()
        r = torch.zeros(B, self.R, self.W, device=device)
        return {"M": M, "u": u, "L": L, "p": p, "rw": rw, "r": r}

    @staticmethod
    def _cosine_sim(M: torch.Tensor, k: torch.Tensor, eps=1e-6):
        # M: (B, N, W), k: (B, W) or (B, R, W)
        if k.dim() == 2:
            k = k.unsqueeze(1)  # (B,1,W)
        B, R, W = k.shape
        Mnorm = F.normalize(M, p=2, dim=-1)
        knorm = F.normalize(k, p=2, dim=-1)
        sim = torch.einsum("bnw,brw->brn", Mnorm, knorm)  # (B, R, N)
        return sim

    def _allocation(self, u: torch.Tensor):
        # u: (B, N) in [0,1]
        δ = 1e-6
        u = δ + (1 - δ) * u                 # avoid tiny values before cumprod
        B, N = u.shape
        # sort ascending usage -> free list φ
        sorted_u, phi = torch.sort(u, dim=-1, descending=False)  # (B,N)
        # exclusive cumprod of sorted_u
        ones = torch.ones(B, 1, device=u.device, dtype=u.dtype)
        prod_excl = torch.cumprod(torch.cat([ones, sorted_u], dim=1), dim=1)[:, :-1]  # (B,N)
        a_sorted = (1 - sorted_u) * prod_excl                                        # (B,N)
        # invert the sort to original order
        inv_phi = torch.argsort(phi, dim=-1)
        a = a_sorted.gather(1, inv_phi)                                              # (B,N)
        return a


    def forward(self, x_if: dict, state: dict):
        """
        x_if (interface dict) must contain:
        - k_read: (B,R,W), beta_read: (B,R,1)
        - k_write: (B,W), beta_write:(B,1)
        - erase: (B,W) in (0,1), write_vec: (B,W)
        - free_gates: (B,R,1) in (0,1), alloc_gate:(B,1), write_gate:(B,1) in (0,1)
        - read_mode: (B,R,3) softmax over {backward, content, forward}
        """
        M, u, L, p, rw, r = state["M"], state["u"], state["L"], state["p"], state["rw"], state["r"]
        B, N, W = M.shape

        # --- Usage update (faithful, stable) ---
        # previous write weights; keep as (B,1,N) for easy broadcasting
        ww_prev = state.get("ww", torch.zeros(M.size(0), 1, self.N, device=M.device, dtype=M.dtype))
        # writes increase usage
        u = u + (1 - u) * (1 - torch.prod(1 - ww_prev, dim=1))   # -> (B,N)
        # free gates release usage at read locations (per-location retention)
        psi = torch.prod(1 - x_if["free_gates"] * rw, dim=1)     # (B,N), since free_gates:(B,R,1), rw:(B,R,N)
        u = torch.clamp(u * psi, 0, 1)


        # 2.1) Write weighting (robust broadcasting) ---
        # sim_w: (B,1,N) if k_write=(B,W); (B,R,N) if k_write=(B,R,W)
        sim_w = self._cosine_sim(M, x_if["k_write"])
        # beta_w: expect (B,1) or (B,R,1); make sure it has the trailing head axis
        beta_w = x_if["beta_write"]
        if beta_w.dim() == 2:           # (B,1) or (B,R) -> add trailing axis
            beta_w = beta_w.unsqueeze(-1)  # -> (B,1,1) or (B,R,1)
        # content weights over memory locations
        cw = F.softmax(sim_w * beta_w, dim=-1)  # (B,1,N) or (B,R,N)
        # canonical DNC: single write head; if multiple heads exist, reduce over heads
        if cw.size(1) > 1:
            cw = cw.mean(dim=1)                # -> (B,N)  (alternatives: sum or a learned reduce)
        else:
            cw = cw.squeeze(1)                 # -> (B,N)
        # allocation weights from usage (B,N)
        a = self._allocation(u)                # (B,N)
        # interpolate content vs allocation via alloc_gate, then apply write_gate
        alloc = x_if["alloc_gate"]             # (B,1)
        write_gate = x_if["write_gate"]        # (B,1)
        # Broadcast (B,1) over N
        ww = write_gate * (alloc * a + (1.0 - alloc) * cw)  # -> (B,N) via broadcasting
        state["ww"] = ww
        
        # 2.2) save current ww as "previous write weights" for t+1
        state["ww_prev"] = ww.unsqueeze(1)   # keep grads for BPTT; use .detach() only if you explicitly want to stop gradients across steps

        # 3) Memory write
        erase = x_if["erase"].unsqueeze(1)  # (B,1,W)
        write_vec = x_if["write_vec"].unsqueeze(1)  # (B,1,W)
        M = M * (1 - ww.unsqueeze(-1) * erase) + ww.unsqueeze(-1) * write_vec

        # 4) Temporal link
        prev_p = p
        p = (1 - ww.sum(dim=-1, keepdim=True)) * p + ww  # precedence
        L = (1 - ww.unsqueeze(2) - ww.unsqueeze(1)) * L + torch.einsum("bn,bm->bnm", prev_p, ww)
        L = L * (1 - torch.eye(N, device=M.device).unsqueeze(0))

        # 5) Read weighting
        cr = F.softmax(self._cosine_sim(M, x_if["k_read"]) * x_if["beta_read"], dim=-1)  # (B,R,N)
        fwd = torch.einsum("brn,bnm->brm", rw, L)       # (B,R,N) forward
        bwd = torch.einsum("brn,bmn->brm", rw, L)       # (B,R,N) backward
        read_mode = F.softmax(x_if["read_mode"], dim=-1)  # (B,R,3)
        rw = read_mode[:,:,0:1]*bwd + read_mode[:,:,1:2]*cr + read_mode[:,:,2:3]*fwd
        r = torch.einsum("brn,bnw->brw", rw, M)  # (B,R,W)

        state = {"M": M, "u": u, "L": L, "p": p, "rw": rw, "r": r}
        # aggregated lightweight stats (per-step)
        try:
            _stats = {}
            with torch.no_grad():
                _stats["u_mean"] = state["u"].mean().detach()
                try:
                    _stats["M_norm_mean"] = state["M"].norm(dim=-1).mean().detach()
                except Exception:
                    pass
                try:
                    _stats["rw_max_mean"] = rw.max(dim=-1).values.mean().detach()
                except Exception:
                    pass
                try:
                    _stats["ww_max_mean"] = ww.max(dim=-1).values.mean().detach()
                except Exception:
                    pass
            state["stats"] = _stats
        except Exception:
            pass
        
        if self.probe is not None:
            M = state["M"].detach().float().cpu()
            u = state["u"].detach().float().cpu()
            L = state["L"].detach().float().cpu()
            rw_cpu = rw.detach().float().cpu()
            ww_cpu = state.get("ww", torch.zeros_like(state["u"])).detach().float().cpu()
            self.probe(
                {
                "u": u,                         # (B,N)
                "ww": ww_cpu,                   # (B,N)
                "rw": rw_cpu,                   # (B,R,N)
                "M_norm": M.norm(dim=-1),       # (B,N)
                "L_diag_mean": torch.diagonal(L, dim1=-2, dim2=-1).mean(dim=-1), # (B,)
                }
            )
        return r, state


## 4. Transformer-style Controller (sequence mode)

In [7]:

class TransformerController(nn.Module):
    """
    Lightweight Transformer encoder producing the DNC interface vector.
    Inputs: X (B, T, d_in); prev_reads typically concatenated to X before calling.
    """
    def __init__(self, d_in: int, d_model: int, heads: int, dropout: float=0.1, ffn_mult: float=4.0):
        super().__init__()
        self.proj_in = nn.Linear(d_in, d_model)
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, heads, dropout=dropout, batch_first=True)
        self.ln2 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, int(d_model*ffn_mult)),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(int(d_model*ffn_mult), d_model),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor]=None) -> torch.Tensor:
        h = self.proj_in(x)
        h = self.ln1(h)
        attn_out, _ = self.attn(h, h, h, attn_mask=attn_mask, need_weights=False)
        h = h + self.dropout(attn_out)
        h = self.ln2(h)
        h2 = self.ff(h)
        h = h + self.dropout(h2)
        return h


## 5. DNCformer Block (controller → interface → memory)

In [8]:

class DNCInterfaceHead(nn.Module):
    """
    Projects controller hidden to DNC interface:
    read keys (R*W), read strengths (R), write key (W), write strength (1),
    erase (W in (0,1)), write vector (W), free_gates (R in (0,1)),
    alloc_gate (1), write_gate (1), read_mode (R*3 softmax).
    """
    def __init__(self, d_model: int, R: int, W: int):
        super().__init__()
        self.R, self.W = R, W
        out = R*W + R + W + 1 + W + W + R + 1 + 1 + R*3
        self.proj = nn.Linear(d_model, out)

    def forward(self, h: torch.Tensor):
        B, T, D = h.shape
        v = self.proj(h)  # (B,T,out)
        idx = 0
        def take(sz): 
            nonlocal idx
            part = v[..., idx:idx+sz]; idx += sz; return part
        R, W = self.R, self.W
        k_read = take(R*W).view(B,T,R,W)
        beta_read = F.softplus(take(R)).view(B,T,R,1)
        k_write = take(W).view(B,T,W)
        beta_write = F.softplus(take(1)).view(B,T,1)
        erase = torch.sigmoid(take(W)).view(B,T,W)
        write_vec = take(W).view(B,T,W)
        free_gates = torch.sigmoid(take(R)).view(B,T,R,1)
        alloc_gate = torch.sigmoid(take(1)).view(B,T,1)
        write_gate = torch.sigmoid(take(1)).view(B,T,1)
        read_mode = take(R*3).view(B,T,R,3)
        return {
            "k_read": k_read, "beta_read": beta_read,
            "k_write": k_write, "beta_write": beta_write,
            "erase": erase, "write_vec": write_vec,
            "free_gates": free_gates, "alloc_gate": alloc_gate,
            "write_gate": write_gate, "read_mode": read_mode
        }

class DNCformerBlock(nn.Module):
    def __init__(self, d_in: int, d_model: int, R: int, W: int, N: int, heads: int, dropout: float, ffn_mult: float):
        super().__init__()
        self.R, self.W, self.N = R, W, N
        self.ctrl = TransformerController(d_in + R*W, d_model, heads=heads, dropout=dropout, ffn_mult=ffn_mult)
        self.if_head = DNCInterfaceHead(d_model, R=R, W=W)
        self.mem = DNCMemory(nr_cells=N, cell_size=W, read_heads=R)
        self.out_proj = nn.Linear(d_model + R*W, d_model)  # fuse controller + reads

    def forward(self, x: torch.Tensor, state: Optional[dict]=None):
        # x: (B,T,d_in); state carries memory fields; if None -> reset
        B, T, D = x.shape
        if state is None:
            state = self.mem.reset(B, device=x.device)
        reads = state["r"].reshape(B, self.R*self.W)  # (B,RW)
        reads_seq = reads.unsqueeze(1).expand(B, T, self.R*self.W)
        ctrl_in = torch.cat([x, reads_seq], dim=-1)  # concat
        h = self.ctrl(ctrl_in, attn_mask=causal_mask(T, device=x.device))  # (B,T,d_model)

        # step over time for memory I/O
        r_list = []
        new_state = state
        iface = self.if_head(h)
        for t in range(T):
            x_if = {k: v[:,t] for k,v in iface.items()}
            r_t, new_state = self.mem(x_if, new_state)
            r_list.append(r_t)
        Rseq = torch.stack(r_list, dim=1)  # (B,T,R,W)
        reads_flat = Rseq.reshape(B,T,self.R*self.W)
        fused = torch.cat([h, reads_flat], dim=-1)
        y = self.out_proj(fused)  # (B,T,d_model)
        return y, new_state


## 6. Parallel Enrichment Block (Transformer path ‖ DNCformer path + gating)

In [9]:

class VanillaTransformerBlock(nn.Module):
    def __init__(self, d_model: int, heads: int, dropout: float=0.1, ffn_mult: float=4.0):
        super().__init__()
        self.collect_metrics = False
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, heads, dropout=dropout, batch_first=True)
        self.ln2 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, int(d_model*ffn_mult)),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(int(d_model*ffn_mult), d_model),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor]=None, gate_override: None = None):
        h = self.ln1(x)
        a, _ = self.attn(h, h, h, attn_mask=attn_mask, need_weights=False)
        h = x + self.dropout(a)
        z = self.ln2(h)
        z2 = self.ff(z)
        return h + self.dropout(z2)

class ParallelEnrichmentBlock(nn.Module):
    def __init__(self, d_model: int, d_in: int, R: int, W: int, N: int, heads: int, dropout: float, ffn_mult: float, gate_bias_init: float):
        super().__init__()
        self.collect_metrics = False
        self.vanilla = VanillaTransformerBlock(d_model, heads, dropout, ffn_mult)
        self.dncblock = DNCformerBlock(d_in=d_in, d_model=d_model, R=R, W=W, N=N, heads=heads, dropout=dropout, ffn_mult=ffn_mult)
        self.pre_gate_ln = nn.LayerNorm(2*d_model)
        self.gate = nn.Linear(2*d_model, d_model)
        nn.init.constant_(self.gate.bias, gate_bias_init)

    def forward(self, x: torch.Tensor, dnc_state=None, gate_override: None = None):
        # Both branches consume the same x (B,T,d_model) and produce (B,T,d_model)
        T = x.size(1)
        mask = causal_mask(T, device=x.device)
        vt = self.vanilla(x, attn_mask=mask)
        dt, dnc_state = self.dncblock(x, state=dnc_state)
        z = torch.cat([vt, dt], dim=-1)
        g = torch.sigmoid(self.gate(self.pre_gate_ln(z)))
        out = g*dt + (1-g)*vt
        
        # collect per-block metrics if requested
        if self.collect_metrics:
            try:
                import math, torch as _t
                eps = 1e-6
                g_clamp = g.clamp(min=eps, max=1-eps)
                g_entropy = (-(g_clamp*_t.log(g_clamp) + (1-g_clamp)*_t.log(1-g_clamp))).mean()
                vt_norm = vt.norm(dim=-1).mean()
                dt_norm = dt.norm(dim=-1).mean()
                _stats = dnc_state.get("stats", {}) if isinstance(dnc_state, dict) else {}
                # ensure CPU-detached tiny tensors
                self.last_metrics = {
                    "g_mean": g.mean().detach(),
                    "g_entropy": g_entropy.detach(),
                    "vt_norm": vt_norm.detach(),
                    "dt_norm": dt_norm.detach(),
                    **_stats
                }
            except Exception:
                self.last_metrics = None
        else:
            self.last_metrics = None
            return out, dnc_state, g 


## 7. Frozen Base LLM + N Enrichment Blocks

In [10]:
from transformers import AutoModelForCausalLM, AutoTokenizer

def spda_ctx():
    try:
        from torch.backends.cuda import sdp_kernel
        return sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=True)
    except Exception as e:
        return f"error {e} encountered ->\n{contextlib.nullcontext()}"

def load_base_model(model_id: str):
    tok = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=(torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16) if amp_dtype!=torch.float32 else None,
        device_map=None,
        trust_remote_code=True,
        attn_implementation="sdpa",
    ).to("cuda")
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    requires_grad_(model, False)
    model.config.output_hidden_states = True
    model.config.use_cache = False
    return tok, model

class DNCFormerHead(nn.Module):
    def __init__(self, base: AutoModelForCausalLM, cfg):
        super().__init__()
        self.base = base
        d_model = base.config.hidden_size if cfg.d_model is None else cfg.d_model
        self.d_model = d_model
        self.blocks = nn.ModuleList([
            ParallelEnrichmentBlock(
                d_model=d_model, d_in=d_model,
                R=cfg.dnc_read_heads, W=cfg.dnc_cell_size, N=cfg.dnc_nr_cells,
                heads=cfg.attn_heads, dropout=cfg.attn_dropout,
                ffn_mult=cfg.ffn_mult, gate_bias_init=cfg.gate_bias_init
            ) for _ in range(cfg.n_blocks)
        ])
        self.proj_out = nn.Identity()

    def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor]=None):
        with torch.no_grad():
            with spda_ctx():
                out = self.base(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True, use_cache=False)
        h = out.hidden_states[-1]  # (B,T,d_model)
        dnc_states = [None]*len(self.blocks)
        gates = []
        for i, blk in enumerate(self.blocks):
            h, dnc_states[i], g = blk(h, dnc_state=dnc_states[i])
            gates.append(g.detach())
        logits = self.base.lm_head(self.proj_out(h).to(self.base.lm_head.weight.dtype))
        return logits, gates



    def forward_with_metrics(self, input_ids: torch.Tensor, attention_mask: "Optional[torch.Tensor]" = None,
                             gate_override: "Optional[float]" = None):
        with torch.no_grad():
            with sdpa_ctx():
                out = self.base(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True, use_cache=False)
        h = out.hidden_states[-1].to(device)  # ensure CUDA
        dnc_states = [None]*len(self.blocks)
        gates_det = []
        gates_raw = []
        per_block = []
        # enable metrics collection per block
        for blk in self.blocks:
            if hasattr(blk, "collect_metrics"):
                blk.collect_metrics = True
        for i, blk in enumerate(self.blocks):
            h, dnc_states[i], g = blk(h, dnc_state=dnc_states[i], gate_override=gate_override)
            gates_raw.append(g)
            gates_det.append(g.detach())
            if hasattr(blk, "last_metrics") and blk.last_metrics is not None:
                per_block.append(blk.last_metrics)
            else:
                per_block.append({})
            # reset flag to avoid overhead elsewhere
            if hasattr(blk, "collect_metrics"):
                blk.collect_metrics = False
        # lm head on its own device, then back to CUDA
        lm_dev = self.base.lm_head.weight.device
        y = self.proj_out(h).to(lm_dev, dtype=self.base.lm_head.weight.dtype)
        logits = self.base.lm_head(y).to(device)
        aux = {"per_block": per_block, "gates_raw": gates_raw, "gates_detached": gates_det}
        return logits, gates_det, aux


## 8. Data: Synthetic tasks + simple instruction-following

In [11]:

def make_copy_task(batch, T, vocab=50):
    x = torch.randint(5, vocab, (batch, T))
    return x

def make_reverse_task(batch, T, vocab=50):
    x = torch.randint(5, vocab, (batch, T))
    y = torch.flip(x, dims=[1])
    return x, y

def make_needle_task(batch, T, needle_len=5, vocab=100):
    x = torch.randint(5, vocab, (batch, T))
    for b in range(batch):
        start = random.randint(0, T-needle_len-5)
        needle = torch.randint(5, vocab, (needle_len,))
        x[b, start:start+needle_len] = needle
        x[b, -1] = needle[0]
    return x

INSTR_PAIRS = [
    ("Reverse the string: abcd", "dcba"),
    ("Add two numbers: 7 + 12", "19"),
    ("Instruction: say hello", "hello"),
    ("Uppercase this: cat", "CAT"),
]

def tokenize_instruction_pairs(tok, pairs, max_len):
    texts = [f"### Instruction:\n{p}\n\n### Response:\n" for p,_ in pairs]
    labels = [ans for _,ans in pairs]
    input_ids = tok(texts, return_tensors="pt", padding=True, truncation=True, max_length=max_len).input_ids
    label_ids = tok(labels, return_tensors="pt", padding=True, truncation=True, max_length=max_len).input_ids
    # naive pack: just use inputs; real packing/labels can be elaborated
    return input_ids, label_ids


## 9. Training loop (lightweight)

In [12]:

# --- Synthetic generators + HF dataset integration + MixtureSampler (with tagging) ---
import random, torch
from typing import Optional, List, Tuple

def make_repeat_copy(batch: int, T: int, repeat_min=2, repeat_max=4, vocab=100, pad_id: int = 0, device: str = "cpu") -> torch.Tensor:
    L = max(1, T // 2)
    x = torch.randint(1, vocab, (batch, L), device=device, dtype=torch.long)
    r = torch.randint(repeat_min, repeat_max + 1, (batch,), device=device)
    out = torch.full((batch, T), pad_id, dtype=torch.long, device=device)
    for i in range(batch):
        seq = x[i].repeat_interleave(int(r[i].item()))
        out[i, :min(T, seq.numel())] = seq[:T]
    return out

def make_n_back(batch: int, T: int, n: int = 3, vocab=50) -> torch.Tensor:
    return torch.randint(1, vocab, (batch, T))

def format_instruction(tok, instr: str, resp: str, max_len=256) -> torch.Tensor:
    prompt = f"### Instruction:\n{instr}\n\n### Response:\n{resp}"
    return tok(prompt, return_tensors="pt", truncation=True, max_length=max_len).input_ids[0]

def hf_instruction_loader(dataset_name="tatsu-lab/alpaca", split="train", text_field=("instruction","output"), max_items=5000):
    try:
        from datasets import load_dataset
    except Exception:
        print("Install 'datasets' to enable HF loading: pip install datasets -q")
        return []
    ds = load_dataset(dataset_name, split=split)
    pairs = []
    i_field, o_field = text_field
    for ex in ds:
        instr = ex.get(i_field, ""); out = ex.get(o_field, "")
        if instr and out: pairs.append((instr, out))
        if len(pairs) >= max_items: break
    random.shuffle(pairs); return pairs

def make_hf_batch(tok, pairs: List[Tuple[str,str]], batch: int, max_len=256) -> torch.Tensor:
    if not pairs:
        return torch.full((batch, max_len), tok.pad_token_id, dtype=torch.long)
    batch_ids = []
    for _ in range(batch):
        instr, out = random.choice(pairs)
        ids = format_instruction(tok, instr, out, max_len=max_len)
        batch_ids.append(ids)
    maxL = min(max(x.size(0) for x in batch_ids), max_len)
    out_ids = torch.full((batch, maxL), tok.pad_token_id, dtype=torch.long)
    for i, ids in enumerate(batch_ids):
        ids = ids[:maxL]; out_ids[i, :ids.size(0)] = ids
    return out_ids

class MixtureSampler:
    def __init__(self, gens: List, weights: List[float], names: Optional[List[str]] = None):
        self.gens = gens
        import torch as _t
        self.p = _t.tensor(weights, dtype=_t.float)
        self.p /= self.p.sum()
        self.names = names if names is not None else [f"g{i}" for i in range(len(gens))]
        self.last_name = None

    def __call__(self, batch: int) -> torch.Tensor:
        import torch as _t
        idx = _t.multinomial(self.p, 1).item()
        self.last_name = self.names[idx]
        return self.gens[idx](batch)


In [13]:
# --- LR Scheduler: linear warmup -> cosine decay (nonzero start) ---
import math
from torch.optim.lr_scheduler import LambdaLR

def make_warmup_cosine_scheduler(optimizer, warmup_steps: int, total_steps: int, min_lr_ratio: float = 0.10):
    """
    Warms up linearly from 0->1 over warmup_steps; cosine decays from 1->min_lr_ratio for the remainder.
    Uses step_idx+1 to avoid zero LR at the start.
    """
    warmup_steps = max(1, int(warmup_steps))
    total_steps = max(warmup_steps + 1, int(total_steps))

    def lr_lambda(step_idx: int):
        s = step_idx + 1
        if s <= warmup_steps:
            return s / float(warmup_steps)
        progress = (s - warmup_steps) / float(max(1, total_steps - warmup_steps))
        return max(min_lr_ratio, 0.5 * (1.0 + math.cos(math.pi * progress)))

    return LambdaLR(optimizer, lr_lambda)


In [14]:
# --- TensorBoard Logger (lightweight) ---
import os, time
from typing import List, Optional
try:
    from torch.utils.tensorboard import SummaryWriter
    TB_AVAILABLE = True
except Exception as e:
    print("TensorBoard not available:", e)
    TB_AVAILABLE = False

class TBLogger:
    def __init__(self, logdir: Optional[str] = None, run_name: Optional[str] = None):
        self.enabled = TB_AVAILABLE
        if not self.enabled:
            self.writer = None
            return
        if logdir is None:
            logdir = "./runs"
        if run_name is None:
            run_name = time.strftime("dncformer-%Y%m%d-%H%M%S")
        self.path = os.path.join(logdir, run_name)
        os.makedirs(self.path, exist_ok=True)
        self.writer = SummaryWriter(self.path)

    def log_scalars(self, step: int, loss: float, lr: float, gate_means: List[float]):
        if not self.enabled: return
        self.writer.add_scalar("train/loss", loss, step)
        self.writer.add_scalar("train/lr", lr, step)
        for i, gm in enumerate(gate_means):
            self.writer.add_scalar(f"gates/block_{i}_mean", gm, step)

    def add_image_hw(self, tag: str, img_hw: "torch.Tensor", step: int):
        if not self.enabled: return
        import torch
        x = img_hw
        if x.device.type != "cpu": x = x.cpu()
        x = x.float()
        if x.numel() > 0:
            m, M = x.min(), x.max()
            x = (x - m) / (M - m + 1e-8)
        x = x.unsqueeze(0)  # [1,H,W]
        self.writer.add_image(tag, x, step, dataformats='CHW')

    def flush(self):
        if self.enabled and self.writer: self.writer.flush()

    def close(self):
        if self.enabled and self.writer: self.writer.close()


def add_histogram(self, tag: str, values: "torch.Tensor", step: int, bins: int = 50):
    if not self.enabled: return
    import torch
    v = values
    if isinstance(v, torch.Tensor):
        v = v.detach()
        if v.device.type != "cpu":
            v = v.cpu()
        v = v.reshape(-1).float()
    self.writer.add_histogram(tag, v, global_step=step, bins=bins)


In [15]:
from torch.amp import GradScaler, autocast

def build_model_and_tokenizer():
    tok, base = load_base_model(CFG.base_model_id)
    if CFG.d_model is None:
        CFG.d_model = base.config.hidden_size
    head = DNCFormerHead(base, CFG).to(device)
    if CFG.use_torch_compile and hasattr(torch, 'compile'):
        head = torch.compile(head)
    #print("Trainable params in head:", count_params(head))
    return tok, head

def make_optimizer(model):
    params = [p for p in model.parameters() if p.requires_grad]
    return torch.optim.AdamW(params, lr=CFG.lr, weight_decay=CFG.weight_decay)

def lm_shift_labels(input_ids, logits, tok):
    labels = input_ids[:, 1:].contiguous()
    logits = logits[:, :-1].contiguous()
    return F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1), ignore_index=tok.pad_token_id)

@torch.no_grad()
def evaluate_simple(head, tok):
    head.eval()
    prompt = "### Instruction:\nSay hello in one word\n\n### Response:\n"
    enc = tok(prompt, return_tensors="pt").to(device)
    with autocast('cuda', dtype=amp_dtype, enabled=(amp_dtype!=torch.float32)):
        logits, gates = head(enc.input_ids)
    print("Gates (mean):", [g.mean().item() for g in gates])
    return gates

def train_small():
    tok, head = get_or_build_model_and_tokenizer()
    optim = make_optimizer(head)
    scheduler = make_warmup_cosine_scheduler(optim, CFG.warmup_steps, CFG.train_steps, min_lr_ratio=0.10)
    head.train()

    tb = TBLogger(logdir="./runs")
    use_scaler = (amp_dtype == torch.float16)
    scaler = GradScaler('cuda', enabled=use_scaler)

    for step in range(1, CFG.train_steps+1):
        in_ids, lab_ids = tokenize_instruction_pairs(tok, INSTR_PAIRS, max_len=min(CFG.max_seq_len, 256))
        in_ids = in_ids.to(device)
        
        with autocast('cuda', dtype=amp_dtype, enabled=(amp_dtype!=torch.float32)):
            logits, gates = head(in_ids)
            loss = lm_shift_labels(in_ids, logits, tok)

        if use_scaler:
            scaler.scale(loss).backward()
            torch.nn.utils.clip_grad_norm_(head.parameters(), CFG.grad_clip)
            scaler.step(optim); scaler.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(head.parameters(), CFG.grad_clip)
            optim.step()
            
        optim.zero_grad(set_to_none=True)

        if step % CFG.log_every == 0:
            gm = [g.mean().item() for g in gates]
            lr = optim.param_groups[0]['lr']
            print(f"step {step} | loss {loss.item():.4f} | lr {lr:.2e} | gates={gm}")
            tb.log_scalars(step, float(loss.item()), float(lr), gm)
            
    evaluate_simple(head, tok)
    return head, tok

## 10. Eval harness (needle-in-haystack, copy, reverse)

In [16]:
# --- Memory tracer & TensorBoard visualizer ---
from contextlib import contextmanager

class MemoryTracer:
    def __init__(self): self.frames = []
    def __call__(self, frame): self.frames.append(frame)
    def stack(self, key, bidx: int = 0):
        import torch
        return torch.stack([f[key][bidx] for f in self.frames], dim=0)  # [T, ...]

@contextmanager
def trace_memory(module):
    tracer = MemoryTracer()
    memories = [m for m in module.modules() if m.__class__.__name__ == "DNCMemory"]
    for m in memories: m.probe = tracer
    try:
        yield tracer
    finally:
        for m in memories: m.probe = None

@torch.no_grad()
def visualize_memory_tb(head, tok, writer, global_step: int, prompt="### Instruction:\nRemember A then B; later return A.\n\n### Response:\n", max_T=64):
    if writer is None: return
    enc = tok(prompt, return_tensors="pt").to(device)
    enc.input_ids = enc.input_ids[:, :max_T]
    from torch.amp import autocast
    with trace_memory(head) as tracer:
        with autocast('cuda', dtype=amp_dtype, enabled=(amp_dtype!=torch.float32)):
            _ = head(enc.input_ids)

    u = tracer.stack("u")           # [T, N]
    Mnorm = tracer.stack("M_norm")  # [T, N]
    rw = tracer.stack("rw")         # [T, R, N]

    # Log images
    writer.add_image("memory/u_TxN", (u.T).unsqueeze(0), global_step, dataformats='CHW')
    writer.add_image("memory/Mnorm_TxN", (Mnorm.T).unsqueeze(0), global_step, dataformats='CHW')

    top_read = rw.argmax(dim=-1).float()  # [T,R]
    top_read_img = (top_read / max(1, rw.size(-1)-1)).T  # [R,T]
    writer.add_image("memory/top_read_RxT", top_read_img.unsqueeze(0), global_step, dataformats='CHW')


## GPU allocation sanity test

In [17]:
# --- GPU VRAM diagnostics ---
import torch, gc, contextlib

def cuda_report(tag=""):
    if not torch.cuda.is_available():
        print("[cuda_report] CUDA not available"); return
    free, total = torch.cuda.mem_get_info()
    alloc = torch.cuda.memory_allocated()
    reserv = torch.cuda.memory_reserved()
    print(f"[{tag}] alloc={alloc/1e9:.2f} GB | reserved={reserv/1e9:.2f} GB | free={free/1e9:.2f} GB | total={total/1e9:.2f} GB")

def list_head_refs():
    # looks for globals named 'head' and count of DNCFormerHead instances
    import gc, inspect, sys
    heads = [o for o in gc.get_objects() if o.__class__.__name__ == "DNCFormerHead"]
    print(f"[liveness] DNCFormerHead instances alive: {len(heads)}")
    if 'head' in globals():
        h = globals()['head']
        try:
            devs = sorted({p.device.type for p in h.parameters()})
        except Exception:
            devs = ["<unknown>"]
        print(f"[liveness] global 'head' present; param devices: {devs}")
    else:
        print("[liveness] no global 'head'")

def free_head_and_cache():
    # delete typical globals and clear allocator
    with contextlib.suppress(Exception):
        del globals()['head']
    with contextlib.suppress(Exception):
        del globals()['tok']
    gc.collect(); torch.cuda.empty_cache()
    cuda_report("after free")


In [18]:
cuda_report("before")
list_head_refs()

[before] alloc=0.00 GB | reserved=0.00 GB | free=24.44 GB | total=25.77 GB
[liveness] DNCFormerHead instances alive: 0
[liveness] no global 'head'


  heads = [o for o in gc.get_objects() if o.__class__.__name__ == "DNCFormerHead"]


In [19]:
# generator smoke test
mx = int(getattr(CFG, "max_seq_len", 256))
tok, _base = tok if 'tok' in globals() else (None, None)
_pad = getattr(tok, "pad_token_id", 0) if tok is not None else 0

x = make_repeat_copy(batch=3, T=min(mx, 128), vocab=50, pad_id=_pad)
print("repeat_copy batch shape:", x.shape, "| dtype:", x.dtype)  # expect (3, <=128), long
assert x.dtype == torch.long


repeat_copy batch shape: torch.Size([3, 128]) | dtype: torch.int64


In [20]:
from transformers import AutoModelForCausalLM, AutoTokenizer
gc.collect(); torch.cuda.empty_cache(); cuda_report("pre base-only")
tok = AutoTokenizer.from_pretrained(CFG.base_model_id, trust_remote_code=True)
base = AutoModelForCausalLM.from_pretrained(
    CFG.base_model_id,
    torch_dtype=(torch.bfloat16 if (amp_dtype!=torch.float32 and torch.cuda.is_bf16_supported()) else (torch.float16 if amp_dtype!=torch.float32 else None)),
    attn_implementation="sdpa",
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    use_safetensors=True,
).to("cuda")
cuda_report("after base-only")
del base, tok; gc.collect(); torch.cuda.empty_cache()


[pre base-only] alloc=0.00 GB | reserved=0.00 GB | free=24.44 GB | total=25.77 GB


`flash-attention` package not found, consider installing for better performance: No module named 'flash_attn'.
Current `flash-attention` does not support `window_size`. Either upgrade or use `attn_implementation='eager'`.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

[after base-only] alloc=7.64 GB | reserved=7.64 GB | free=16.79 GB | total=25.77 GB


## Medium-size training loop test

In [21]:
from typing import Tuple

def _build_mixer(tok, weights, hf_dataset="tatsu-lab/alpaca", hf_max_items=2000) -> MixtureSampler:
    mx = int(getattr(CFG, "max_seq_len", 256))
    pad_id = getattr(tok, "pad_token_id", 0) or 0

    pairs = hf_instruction_loader(hf_dataset, "train", ("instruction","output"), max_items=hf_max_items)
    gens = []; wts = []; names = []

    if pairs:
        def gen_hf(b): return make_hf_batch(tok, pairs, b, max_len=mx)
        gens.append(gen_hf); wts.append(weights[0]); names.append("hf")
        s_w = list(weights[1:])
    else:
        print("HF dataset unavailable or empty; using synthetic only.")
        s_w = [1.0, 1.0, 1.0]

    def gen_copy(b):   return make_copy_task(b, T=min(mx, 128), vocab=100)
    def gen_repeat(b): return make_repeat_copy(b, T=min(mx, 128), vocab=100, pad_id=pad_id, device="cpu")
    def gen_nback(b):  return make_n_back(b, T=min(mx, 128), n=5, vocab=50)
    gens.extend([gen_copy, gen_repeat, gen_nback]); wts.extend(s_w); names.extend(["copy","repeat","nback"])
    return MixtureSampler(gens, wts, names=names)


In [22]:
free_head_and_cache()

[after free] alloc=0.00 GB | reserved=0.00 GB | free=24.44 GB | total=25.77 GB


In [23]:
from torch.amp import autocast

@torch.no_grad()
def eval_copy(head, tok, batch=4, T=64, vocab=100):
    x = make_copy_task(batch, T, vocab=vocab).to(device)
    with autocast('cuda', dtype=amp_dtype, enabled=(amp_dtype!=torch.float32)):
        logits, gates = head(x)
    preds = logits.argmax(dim=-1)
    acc = (preds[:, :-1] == x[:, 1:]).float().mean().item()
    print("copy acc:", acc, "gates:", [g.mean().item() for g in gates])
    return {"acc": acc, "gates": [g.mean().item() for g in gates]}

@torch.no_grad()
def eval_reverse(head, tok, batch=4, T=64, vocab=100):
    x, y = make_reverse_task(batch, T, vocab=vocab)
    x, y = x.to(device), y.to(device)
    with autocast('cuda', dtype=amp_dtype, enabled=(amp_dtype!=torch.float32)):
        logits, gates = head(x)
    preds = logits.argmax(dim=-1)
    acc = (preds[:, :-1] == y[:, 1:]).float().mean().item()
    print("reverse acc:", acc, "gates:", [g.mean().item() for g in gates])
    return {"acc": acc, "gates": [g.mean().item() for g in gates]}

@torch.no_grad()
def eval_needle(head, tok, batch=4, T=128, vocab=200):
    x = make_needle_task(batch, T, vocab=vocab).to(device)
    with autocast('cuda', dtype=amp_dtype, enabled=(amp_dtype!=torch.float32)):
        logits, gates = head(x)
    preds = logits.argmax(dim=-1)
    acc = (preds[:, -1] == x[:, -1]).float().mean().item()
    print("needle acc:", acc, "gates:", [g.mean().item() for g in gates])
    return {"acc": acc, "gates": [g.mean().item() for g in gates]}


In [24]:
from torch.amp import GradScaler, autocast

def train_medium(
    steps: int = None,
    batch_size: int = None,
    warmup_steps: int = None,
    min_lr_ratio: float = 0.10,
    mixture_weights=(0.4, 0.2, 0.2, 0.2),
    hf_dataset: str = "tatsu-lab/alpaca",
    hf_max_items: int = 2000,
    log_every: int = None,
    viz_memory_after: bool = False,   # keep off by default; use visualize_memory_tb() ad‑hoc if desired
    viz_prompt: str = "### Instruction:\nRemember A then B; later return A.\n\n### Response:\n",
    viz_max_T: int = 64,
):
    # --- resolve config / defaults robustly ---
    cfg = CFG
    steps = int(steps if steps is not None else getattr(cfg, "train_steps", 200))
    batch_size = int(batch_size if batch_size is not None else getattr(cfg, "batch_size", 8))
    warmup_steps = int(warmup_steps if warmup_steps is not None else getattr(cfg, "warmup_steps", max(10, steps // 20)))
    log_every = int(log_every if log_every is not None else getattr(cfg, "log_every", 10))
    grad_clip = float(getattr(cfg, "grad_clip", 1.0))
    hist_every = int(getattr(cfg, "hist_every", 200))
    force_g = getattr(cfg, "force_g", None)
    gate_reg_lambda = float(getattr(cfg, "gate_reg_lambda", 0.0))

    # --- build model, optimizer, scheduler ---
    tok, head = build_model_and_tokenizer()
    optim = make_optimizer(head)
    scheduler = make_warmup_cosine_scheduler(optim, warmup_steps, steps, min_lr_ratio=min_lr_ratio)
    head.train()

    tb = TBLogger(logdir="./runs")
    use_scaler = (amp_dtype == torch.float16)
    scaler = GradScaler('cuda', enabled=use_scaler)

    mixer = _build_mixer(tok, mixture_weights, hf_dataset=hf_dataset, hf_max_items=hf_max_items)

    for step in range(1, steps + 1):
        in_ids = mixer(batch_size).to(device)

        with autocast('cuda', dtype=amp_dtype, enabled=(amp_dtype != torch.float32)):
            # forward with metrics + optional ablation override of gate
            logits, gates, aux = head.forward_with_metrics(in_ids, gate_override=force_g)
            loss = lm_shift_labels(in_ids, logits, tok)

            # optional: encourage memory usage on memory-tagged batches only
            if gate_reg_lambda > 0.0 and mixer.last_name in ("copy", "repeat", "nback"):
                try:
                    g_mean_all = torch.stack([g.mean() for g in aux["gates_raw"]]).mean()
                    loss = loss + gate_reg_lambda * (1.0 - g_mean_all)
                except Exception:
                    pass

        if use_scaler:
            scaler.scale(loss).backward()
            torch.nn.utils.clip_grad_norm_(head.parameters(), grad_clip)
            scaler.step(optim); scaler.update(); scheduler.step()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(head.parameters(), grad_clip)
            optim.step(); scheduler.step()
        optim.zero_grad(set_to_none=True)

        if step % log_every == 0:
            gm = [g.mean().item() for g in gates]
            lr = optim.param_groups[0]["lr"]
            print(f"step {step} | loss {loss.item():.4f} | lr {lr:.2e} | gates={gm} | mix={mixer.last_name}")
            tb.log_scalars(step, float(loss.item()), float(lr), gm)
            
            # Tag current batch
            mix_name = mixer.last_name or "unknown"
            # Break down + log loss per task for per-task curve views
            tb.writer.add_scalar(f"loss_by_task/{mix_name}", float(loss.item()), step)
            
            # Per-block gate means, "token fractions with g>0.5" per task, and mean gate by temporal quartile for block 0
            try:
                for bi, g in enumerate(aux["gates detached"]):
                    g_mean_task = float(g.mean().item())
                    g_frac = float((g >0.5).float().mean().item())
                    tb.writer.add_scalar(f"gates_by_task/block_{bi}_mean/{mix_name}", g_mean_task, step)
                    tb.writer.add_scalar(f"gates_by_task/block_{bi}_frac>0.5/{mix_name}", g_frac, step)
                    
                    g0 = aux["gates_detached"][0]
                    if g0.dim() == 3:
                        g0 = g0.mean(dim=-1)
                    T = g0.size(1)
                    q = T // 4 if T >= 4 else 1
                    for qi, s1 in enumerate([(0,q), (q,2*q), (2*q,3*q), (3*q,T)]):
                        s, e = s1
                        tb.writer.add_scalar(f"gates/block0_q{qi+1}_mean/{mix_name}", float(g0[:, s:e].mean().item()), step)
                    
            except Exception as exception:
                print(f"Exception '{exception}' encountered while reporting pre-block gate means and g>0.5 token fracions \n This is a non-breaking exception") # extra context, may not be informative
                pass
                
                    
            # extra per‑block scalars (guarded)
            try:
                for bi, pm in enumerate(aux.get("per_block", [])):
                    if "g_entropy"   in pm: tb.writer.add_scalar(f"gates/block_{bi}_entropy",   float(pm["g_entropy"].item()), step)
                    if "u_mean"      in pm: tb.writer.add_scalar(f"memory/block_{bi}_u_mean",    float(pm["u_mean"].item()), step)
                    if "rw_max_mean" in pm: tb.writer.add_scalar(f"memory/block_{bi}_rw_max_mean",float(pm["rw_max_mean"].item()), step)
                    if "ww_max_mean" in pm: tb.writer.add_scalar(f"memory/block_{bi}_ww_max_mean",float(pm["ww_max_mean"].item()), step)
                    if "M_norm_mean" in pm: tb.writer.add_scalar(f"memory/block_{bi}_Mnorm_mean", float(pm["M_norm_mean"].item()), step)
                    if "vt_norm"     in pm: tb.writer.add_scalar(f"paths/block_{bi}_vt_norm",     float(pm["vt_norm"].item()), step)
                    if "dt_norm"     in pm: tb.writer.add_scalar(f"paths/block_{bi}_dt_norm",     float(pm["dt_norm"].item()), step)
            except Exception:
                pass

            # gate histograms (cadence)
            try:
                if hist_every > 0 and (step % hist_every == 0):
                    for bi, g in enumerate(aux["gates_detached"]):
                        tb.add_histogram(f"gates/block_{bi}_hist", g, step)
            except Exception:
                pass

    # optional memory snapshot into the same TB run (off by default)
    if viz_memory_after and TB_AVAILABLE:
        try:
            visualize_memory_tb(head, tok, tb.writer, global_step=steps, prompt=viz_prompt, max_T=viz_max_T)
        except Exception as e:
            print("Memory TB viz skipped:", e)

    tb.flush(); tb.close()
    return head, tok


In [25]:
# # --- Patch: robust forward for ParallelEnrichmentBlock ---
def _peb_forward(self, x, dnc_state=None, gate_override=None):
    """
    Returns: (out, dnc_state, g)
      out: (B,T,D), mixed from vanilla (vt) and DNC (dt)
      dnc_state: updated state dict from DNC path
      g: (B,T,D) or (B,T,1) gate values in [0,1]
    """
    # 1) Vanilla transformer path
    T = x.size(1)
    mask = causal_mask(T, device=x.device)  # (T,T) causal attn mask
    vt = self.vanilla(x, attn_mask=mask)    # (B,T,D)

    # 2) DNC path (controller+memory)
    dt, dnc_state = self.dncblock(x, state=dnc_state)  # dt: (B,T,D)

    # 3) Gate computation (optionally LN before gate)
    import torch as _t
    z = _t.cat([vt, dt], dim=-1)  # (B,T,2D)
    h = self.pre_gate_ln(z) if hasattr(self, "pre_gate_ln") and self.pre_gate_ln is not None else z
    g_pre = self.gate(h)                 # typically (B,T,D) or (B,T,1)
    tau = float(getattr(CFG, 'gate_temp', 1.0))
    g = torch.sigmoid(g_pre / max(tau, 1e-6))

    # Optional ablation override: force g to 0.0 (vanilla) or 1.0 (DNC)
    if gate_override is not None:
        g = _t.full_like(g, float(gate_override))

    # 4) Blend paths
    out = g * dt + (1.0 - g) * vt        # (B,T,D)

    # 5) Lightweight metrics (only if requested)
    self.last_metrics = None
    if getattr(self, "collect_metrics", False):
        try:
            eps = 1e-6
            gc = g.clamp(min=eps, max=1.0 - eps)
            g_entropy = (-(gc * gc.log() + (1 - gc) * (1 - gc).log())).mean()
            vt_norm = vt.norm(dim=-1).mean()
            dt_norm = dt.norm(dim=-1).mean()
            mstats = {}
            # Pull aggregated memory stats if present
            if isinstance(dnc_state, dict) and isinstance(dnc_state.get("stats", None), dict):
                for k in ("u_mean", "rw_max_mean", "ww_max_mean", "M_norm_mean"):
                    if k in dnc_state["stats"]:
                        mstats[k] = dnc_state["stats"][k]
            self.last_metrics = {
                "g_mean": g.mean().detach(),
                "g_entropy": g_entropy.detach(),
                "vt_norm": vt_norm.detach(),
                "dt_norm": dt_norm.detach(),
                **mstats,
            }
        except Exception:
            self.last_metrics = None

    return out, dnc_state, g

# Apply the patch
ParallelEnrichmentBlock.forward = _peb_forward


In [26]:
# --- Evaluator unit tests (no heavy model) ---
import torch
from typing import Optional

class _DummyHead(torch.nn.Module):
    def __init__(self, vocab=100, d_model=64, n_blocks=2):
        super().__init__()
        self.vocab = vocab
        self.d_model = d_model
        self.n_blocks = n_blocks
    def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor]=None):
        B, T = input_ids.shape
        logits = torch.randn(B, T, self.vocab, device=input_ids.device, dtype=torch.float32)
        gates = [torch.sigmoid(torch.randn(B, T, self.d_model, device=input_ids.device)) for _ in range(self.n_blocks)]
        return logits, gates

def run_eval_unit_tests():
    dummy = _DummyHead().to(device).eval()
    res_copy = eval_copy(dummy, tok=None, batch=2, T=8, vocab=50)
    res_rev = eval_reverse(dummy, tok=None, batch=2, T=8, vocab=50)
    res_needle = eval_needle(dummy, tok=None, batch=2, T=16, vocab=100)
    assert all(k in res_copy for k in ["acc","gates"])
    assert all(k in res_rev for k in ["acc","gates"])
    assert all(k in res_needle for k in ["acc","gates"])
    print("Evaluator unit tests passed.")


## 11. Unit-like tests (sanity)

In [27]:

def run_unit_tests():
    B, T = 2, 4
    R, W, N = CFG.dnc_read_heads, CFG.dnc_cell_size, CFG.dnc_nr_cells
    d_in = d_model = 128

    mem = DNCMemory(N, W, R).to(device)
    state = mem.reset(B, device=device)
    iface = {
        "k_read": torch.zeros(B,R,W, device=device),
        "beta_read": torch.ones(B,R,1, device=device),
        "k_write": torch.zeros(B,W, device=device),
        "beta_write": torch.ones(B,1, device=device),
        "erase": torch.sigmoid(torch.randn(B,W, device=device)),
        "write_vec": torch.randn(B,W, device=device),
        "free_gates": torch.sigmoid(torch.randn(B,R,1, device=device)),
        "alloc_gate": torch.sigmoid(torch.randn(B,1, device=device)),
        "write_gate": torch.sigmoid(torch.randn(B,1, device=device)),
        "read_mode": torch.randn(B,R,3, device=device),
    }
    r, state2 = mem(iface, state)
    assert r.shape == (B,R,W)

    ctrl = TransformerController(d_in+R*W, d_model, heads=4).to(device)
    x = torch.randn(B,T,d_in, device=device)
    reads = state2["r"].reshape(B,R*W).unsqueeze(1).expand(B,T,R*W)
    h = ctrl(torch.cat([x, reads], dim=-1))
    assert h.shape == (B,T,d_model)

    dblk = DNCformerBlock(d_in, d_model, R, W, N, heads=4, dropout=0.1, ffn_mult=4.0).to(device)
    y, s = dblk(x)
    assert y.shape == (B,T,d_model)

    pen = ParallelEnrichmentBlock(d_model, d_in, R, W, N, heads=4, dropout=0.1, ffn_mult=4.0, gate_bias_init=-1.0).to(device)
    out, s2, g = pen(torch.randn(B,T,d_model, device=device))
    print(f"out: {out}\ns2: {s2}\ng: {g}")
    eps = 1e-6
    assert torch.isfinite(g).all()
    assert ((g > eps) & (g < 1 - eps)).float().mean().item() > 0.95
    assert out.shape == (B,T,d_model)
    print("All unit-like tests passed.")


## Unit tests/basic sanity checks
currently all passing, but self-reminder here to comment out if architecture changes

In [28]:
# --- Smoke tests ---
def run_memory_tracer_smoke(head, tok):
    if not TB_AVAILABLE:
        print("TB not available; skipping image smoke.")
        return
    from torch.amp import autocast
    with trace_memory(head) as tracer:
        x = torch.randint(5, (1, 16), device=device)
        with autocast('cuda', dtype=amp_dtype, enabled=(amp_dtype!=torch.float32)):
            _ = head(x)
    assert len(tracer.frames) > 0, "No memory frames captured"
    print("Tracer captured", len(tracer.frames), "steps")


In [29]:
#run_unit_tests()

out: tensor([[[ 0.2946,  1.3153,  0.4913,  ..., -0.5015,  0.0606,  1.3821],
         [ 1.0736,  0.3288, -0.0846,  ..., -0.3198, -1.2993, -0.6523],
         [ 0.9915,  0.1268, -1.9989,  ..., -0.1103, -1.3248, -1.1391],
         [-0.0569,  0.3488,  1.7448,  ...,  0.0448,  1.1685, -0.8049]],

        [[ 0.4288, -0.4819,  0.4153,  ...,  1.4628,  0.6916,  0.0974],
         [ 0.4180, -1.5419,  0.4681,  ...,  0.1858,  1.7110, -0.3419],
         [ 0.3899, -0.2150, -0.0096,  ..., -1.7715,  0.0029,  0.5766],
         [-0.6167,  0.2484,  0.8239,  ...,  0.8809,  0.4929, -0.6685]]],
       device='cuda:0', grad_fn=<AddBackward0>)
s2: {'M': tensor([[[ 2.2475e-03,  9.5308e-02,  1.3153e-01,  ...,  9.7984e-02,
           4.0287e-01, -3.2455e-02],
         [-1.7167e-03,  1.3582e-05,  2.2045e-03,  ..., -8.8032e-04,
           1.8157e-03, -1.3368e-03],
         [-1.7168e-03,  1.3498e-05,  2.2043e-03,  ..., -8.8044e-04,
           1.8153e-03, -1.3367e-03],
         ...,
         [-1.7168e-03,  1.3498e-05, 

In [30]:
#run_eval_unit_tests()

copy acc: 0.0 gates: [0.49580350518226624, 0.50152587890625]
reverse acc: 0.0 gates: [0.49599286913871765, 0.5039244890213013]
needle acc: 0.0 gates: [0.489719033241272, 0.4984525442123413]
Evaluator unit tests passed.


In [31]:
#run_memory_tracer_smoke()

## Training run Sanity test

In [32]:
# small training test
#head, tok = train_small()

## Medium size training sweeps

In [33]:
# single medium training test/sanity check
# free_head_and_cache()
# cuda_report("snapshot: before train_medium")
# head, tok = train_medium(steps=100, warmup_steps=10, mixture_weights=(0.4,0.2,0.2,0.2))
# cuda_report("snapshot: after train_medium")
# Launch TensorBoard in a terminal: tensorboard --logdir ./runs


### Medium training sweep, testing parameter / dataset variants

In [34]:
# --- Compact experiment driver: E0..E5 ---
import time, torch, gc

# === user-tunable defaults for quick smoke; bump steps for deeper runs ===
EXP_STEPS   = 10       # try 1000–5000 for longer curves
EXP_WARMUP  = 10       # keep ~steps/20
BASE_MIX    = (0.4, 0.2, 0.2, 0.2)  # (hf, copy, repeat, nback)

def set_cfg(**kv):
    for k, v in kv.items():
        setattr(CFG, k, v)

def run_isolate_test(label,
            mixture_weights=BASE_MIX,
            gate_reg_lambda=None,
            gate_temp=None,
            force_g=None,
            steps=EXP_STEPS,
            warmup=EXP_WARMUP):
    """Run a single train_medium experiment with explicit knobs."""
    print(f"\n=== {label} ===")
    # set experiment-specific knobs (others come from CFG defaults)
    if gate_reg_lambda is not None: set_cfg(gate_reg_lambda=float(gate_reg_lambda))
    if gate_temp is not None:       set_cfg(gate_temp=float(gate_temp))
    set_cfg(force_g=force_g)  # may be None, 0.0, or 1.0

    print(f"mixture={mixture_weights}, gate_reg_lambda={getattr(CFG,'gate_reg_lambda', 0.0)}, "
          f"gate_temp={getattr(CFG,'gate_temp', 1.0)}, force_g={getattr(CFG,'force_g', None)}")

    # clean slate for VRAM and allocator fragmentation between runs
    free_head_and_cache()
    cuda_report(f"before {label}")
    time.sleep(1.2)  # ensure distinct TB run dirs (timestamp granularity)

    # run
    head, tok = train_medium(
        steps=steps,
        warmup_steps=warmup,
        mixture_weights=mixture_weights,
        viz_memory_after=False,   # keep quick; use visualize_memory_tb() ad-hoc
    )

    # post-run snapshot + cleanup
    cuda_report(f"after  {label}")
    free_head_and_cache()
    return head, tok

# === E0 baseline: identical to your sanity run ===
run_isolate_test("E0_baseline",
        mixture_weights=(0.4, 0.2, 0.2, 0.2),
        gate_reg_lambda=0.0,
        gate_temp=1.0,
        force_g=None,
        steps=EXP_STEPS, warmup=EXP_WARMUP)

# === E1 memory-leaning mix: more algorithmic exposure ===
run_isolate_test("E1_memory_lean",
        mixture_weights=(0.4, 0.3, 0.2, 0.1),  # HF, copy, repeat, nback
        gate_reg_lambda=0.0,
        gate_temp=1.0,
        force_g=None,
        steps=EXP_STEPS, warmup=EXP_WARMUP)

# === E2 gate regularizer (low) ===
run_isolate_test("E2_gate_reg_low",
        mixture_weights=(0.4, 0.2, 0.2, 0.2),
        gate_reg_lambda=2e-4,
        gate_temp=1.0,
        force_g=None,
        steps=EXP_STEPS, warmup=EXP_WARMUP)

# === E3 gate regularizer (high) ===
run_isolate_test("E3_gate_reg_high",
        mixture_weights=(0.4, 0.2, 0.2, 0.2),
        gate_reg_lambda=5e-4,
        gate_temp=1.0,
        force_g=None,
        steps=EXP_STEPS, warmup=EXP_WARMUP)

# === E4 sharper routing via lower gate temperature ===
run_isolate_test("E4_gate_temp_0p7",
        mixture_weights=(0.4, 0.2, 0.2, 0.2),
        gate_reg_lambda=0.0,
        gate_temp=0.7,
        force_g=None,
        steps=EXP_STEPS, warmup=EXP_WARMUP)

# === E5 ablations: disable/force DNC path ===
run_isolate_test("E5a_force_g_0",
        mixture_weights=(0.4, 0.2, 0.2, 0.2),
        gate_reg_lambda=0.0,
        gate_temp=1.0,
        force_g=0.0,
        steps=EXP_STEPS, warmup=EXP_WARMUP)

run_isolate_test("E5b_force_g_1",
        mixture_weights=(0.4, 0.2, 0.2, 0.2),
        gate_reg_lambda=0.0,
        gate_temp=1.0,
        force_g=1.0,
        steps=EXP_STEPS, warmup=EXP_WARMUP)



=== E0_baseline ===
mixture=(0.4, 0.2, 0.2, 0.2), gate_reg_lambda=0.0, gate_temp=1.0, force_g=None
[after free] alloc=0.01 GB | reserved=0.02 GB | free=24.36 GB | total=25.77 GB
[before E0_baseline] alloc=0.01 GB | reserved=0.02 GB | free=24.36 GB | total=25.77 GB


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Trainable params in head: 4353384090


  self.gen = func(*args, **kwds)
You are not running the flash-attention implementation, expect numerical differences.


step 10 | loss 10.5249 | lr 2.00e-05 | gates=[0.259765625, 0.2890625] | mix=hf
Exception ''gates detached'' encountered while reporting pre-block gate means and g>0.5 token fracions 
 This is a non-breaking exception
[after  E0_baseline] alloc=9.79 GB | reserved=26.01 GB | free=0.00 GB | total=25.77 GB
[after free] alloc=9.79 GB | reserved=9.90 GB | free=14.47 GB | total=25.77 GB

=== E1_memory_lean ===
mixture=(0.4, 0.3, 0.2, 0.1), gate_reg_lambda=0.0, gate_temp=1.0, force_g=None
[after free] alloc=0.02 GB | reserved=0.12 GB | free=24.24 GB | total=25.77 GB
[before E1_memory_lean] alloc=0.02 GB | reserved=0.12 GB | free=24.24 GB | total=25.77 GB


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Trainable params in head: 4353384090
step 10 | loss 10.6371 | lr 2.00e-05 | gates=[0.2734375, 0.294921875] | mix=hf
Exception ''gates detached'' encountered while reporting pre-block gate means and g>0.5 token fracions 
 This is a non-breaking exception
[after  E1_memory_lean] alloc=9.79 GB | reserved=22.99 GB | free=0.21 GB | total=25.77 GB
[after free] alloc=9.79 GB | reserved=9.82 GB | free=14.54 GB | total=25.77 GB

=== E2_gate_reg_low ===
mixture=(0.4, 0.2, 0.2, 0.2), gate_reg_lambda=0.0002, gate_temp=1.0, force_g=None
[after free] alloc=0.02 GB | reserved=0.12 GB | free=24.24 GB | total=25.77 GB
[before E2_gate_reg_low] alloc=0.02 GB | reserved=0.12 GB | free=24.24 GB | total=25.77 GB


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Trainable params in head: 4353384090
step 10 | loss 19.1560 | lr 2.00e-05 | gates=[0.306640625, 0.31640625] | mix=copy
Exception ''gates detached'' encountered while reporting pre-block gate means and g>0.5 token fracions 
 This is a non-breaking exception
[after  E2_gate_reg_low] alloc=9.79 GB | reserved=23.13 GB | free=0.08 GB | total=25.77 GB
[after free] alloc=9.79 GB | reserved=9.83 GB | free=14.54 GB | total=25.77 GB

=== E3_gate_reg_high ===
mixture=(0.4, 0.2, 0.2, 0.2), gate_reg_lambda=0.0005, gate_temp=1.0, force_g=None
[after free] alloc=0.02 GB | reserved=0.12 GB | free=24.24 GB | total=25.77 GB
[before E3_gate_reg_high] alloc=0.02 GB | reserved=0.12 GB | free=24.24 GB | total=25.77 GB


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Trainable params in head: 4353384090
step 10 | loss 62.0292 | lr 2.00e-05 | gates=[0.3125, 0.30859375] | mix=nback
Exception ''gates detached'' encountered while reporting pre-block gate means and g>0.5 token fracions 
 This is a non-breaking exception
[after  E3_gate_reg_high] alloc=9.79 GB | reserved=27.25 GB | free=0.00 GB | total=25.77 GB
[after free] alloc=9.79 GB | reserved=9.83 GB | free=14.54 GB | total=25.77 GB

=== E4_gate_temp_0p7 ===
mixture=(0.4, 0.2, 0.2, 0.2), gate_reg_lambda=0.0, gate_temp=0.7, force_g=None
[after free] alloc=0.02 GB | reserved=0.12 GB | free=24.24 GB | total=25.77 GB
[before E4_gate_temp_0p7] alloc=0.02 GB | reserved=0.12 GB | free=24.24 GB | total=25.77 GB


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Trainable params in head: 4353384090
step 10 | loss 8.8115 | lr 2.00e-05 | gates=[0.1806640625, 0.220703125] | mix=hf
Exception ''gates detached'' encountered while reporting pre-block gate means and g>0.5 token fracions 
 This is a non-breaking exception
[after  E4_gate_temp_0p7] alloc=9.79 GB | reserved=22.67 GB | free=0.54 GB | total=25.77 GB
[after free] alloc=9.79 GB | reserved=9.83 GB | free=14.54 GB | total=25.77 GB

=== E5a_force_g_0 ===
mixture=(0.4, 0.2, 0.2, 0.2), gate_reg_lambda=0.0, gate_temp=1.0, force_g=0.0
[after free] alloc=0.02 GB | reserved=0.12 GB | free=24.24 GB | total=25.77 GB
[before E5a_force_g_0] alloc=0.02 GB | reserved=0.12 GB | free=24.24 GB | total=25.77 GB


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Trainable params in head: 4353384090
step 10 | loss 26.0708 | lr 2.00e-05 | gates=[0.0, 0.0] | mix=nback
Exception ''gates detached'' encountered while reporting pre-block gate means and g>0.5 token fracions 
 This is a non-breaking exception
[after  E5a_force_g_0] alloc=9.79 GB | reserved=22.38 GB | free=0.82 GB | total=25.77 GB
[after free] alloc=9.79 GB | reserved=9.83 GB | free=14.54 GB | total=25.77 GB

=== E5b_force_g_1 ===
mixture=(0.4, 0.2, 0.2, 0.2), gate_reg_lambda=0.0, gate_temp=1.0, force_g=1.0
[after free] alloc=0.02 GB | reserved=0.12 GB | free=24.24 GB | total=25.77 GB
[before E5b_force_g_1] alloc=0.02 GB | reserved=0.12 GB | free=24.24 GB | total=25.77 GB


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Trainable params in head: 4353384090
step 10 | loss 19.9843 | lr 2.00e-05 | gates=[1.0, 1.0] | mix=hf
Exception ''gates detached'' encountered while reporting pre-block gate means and g>0.5 token fracions 
 This is a non-breaking exception
[after  E5b_force_g_1] alloc=9.79 GB | reserved=24.63 GB | free=0.00 GB | total=25.77 GB
[after free] alloc=9.79 GB | reserved=9.83 GB | free=14.54 GB | total=25.77 GB


(DNCFormerHead(
   (base): Phi3ForCausalLM(
     (model): Phi3Model(
       (embed_tokens): Embedding(32064, 3072, padding_idx=32000)
       (embed_dropout): Dropout(p=0.0, inplace=False)
       (layers): ModuleList(
         (0-31): 32 x Phi3DecoderLayer(
           (self_attn): Phi3Attention(
             (o_proj): Linear(in_features=3072, out_features=3072, bias=False)
             (qkv_proj): Linear(in_features=3072, out_features=9216, bias=False)
             (rotary_emb): Phi3RotaryEmbedding()
           )
           (mlp): Phi3MLP(
             (gate_up_proj): Linear(in_features=3072, out_features=16384, bias=False)
             (down_proj): Linear(in_features=8192, out_features=3072, bias=False)
             (activation_fn): SiLU()
           )
           (input_layernorm): Phi3RMSNorm()
           (resid_attn_dropout): Dropout(p=0.0, inplace=False)
           (resid_mlp_dropout): Dropout(p=0.0, inplace=False)
           (post_attention_layernorm): Phi3RMSNorm()
         )
    

In [39]:
gc.collect(); torch.cuda.empty_cache()

[after free] alloc=10.22 GB | reserved=10.33 GB | free=14.03 GB | total=25.77 GB


## 12. Notes & TODO
- The DNC memory here is **compact** and intended for research iteration; you can swap in a fuller reference implementation if desired. [x]
- The controller currently runs in **sequence mode** with a causal mask. A step-wise cached mode can be added for streaming scenarios.
- Training uses a tiny **instruction-following set** plus synthetic memory tasks. You can plug in larger corpora or evaluation suites later.
- Gating is **vector-valued** with bias init favoring the vanilla path; metrics log mean gate values.
- Use `CFG.n_blocks` to grow the enrichment depth as VRAM allows.


In [None]:
# --- Optional: export environment specs (run in the dncformer conda env) ---
import shutil, subprocess, sys
def _run(cmd):
    try:
        print(">", cmd); subprocess.run(cmd, shell=True, check=True)
    except Exception as e:
        print("Command failed (may be expected on some setups):", e)

print("Exporting conda environment and pip freeze to current working directory...")
_run("conda env export --from-history > environment.yml")
_run("conda env export > environment.lock.yml")
_run("python -m pip list --format=freeze > requirements-pip.txt")
print("Done. If any commands failed, run them in your terminal inside the active conda env.")
