<a href="https://colab.research.google.com/github/OneFineStarstuff/Pinn/blob/main/dnc_copy_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# dnc_copy.py
# Trainable Differentiable Neural Computer (DNC) with:
# - Content-based read/write
# - Usage-based allocation
# - Temporal link matrix (forward/backward traversal)
# Includes a copy task trainer for verification.

import os
import math
import argparse
from typing import Tuple, NamedTuple

import torch
import torch.nn as nn
import torch.nn.functional as F


# ---------------------------
# Utilities
# ---------------------------

def set_seed(seed: int = 42, deterministic: bool = True):
    """Seed PyTorch's random number generators for reproducible runs.

    Args:
        seed: Value handed to ``torch.manual_seed`` and
            ``torch.cuda.manual_seed_all``.
        deterministic: When true, also set the cuDNN backend flags
            ``deterministic = True`` and ``benchmark = False``.
    """
    for seeder in (torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

    if not deterministic:
        return

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def softplus(x, beta=1.0, threshold=20.0):
    """Apply the softplus activation to ``x`` by delegating to ``F.softplus``.

    ``beta`` and ``threshold`` are forwarded unchanged to ``F.softplus``.
    """
    activated = F.softplus(x, beta, threshold)
    return activated


def cosine_similarity(M: torch.Tensor, k: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Cosine similarity between every key and every memory row.

    Args:
        M: (B, N, W) memory.
        k: (B, H, W) keys (H can be 1 or num_heads).
        eps: Added to each L2 norm before dividing, so all-zero rows or keys
            produce a similarity of 0 instead of a division by zero.

    Returns:
        (B, H, N) tensor; entry [b, h, n] compares key h with memory row n.
    """
    # Normalise along the word dimension (dim=2) of both operands.
    unit_rows = M / (M.norm(dim=2, keepdim=True) + eps)  # (B, N, W)
    unit_keys = k / (k.norm(dim=2, keepdim=True) + eps)  # (B, H, W)
    # (B, H, N) = (B, H, W) @ (B, W, N)
    return torch.matmul(unit_keys, unit_rows.transpose(1, 2))


def batched_outer(a: torch.Tensor, b: torch.Tensor):
    """Per-batch outer product.

    a: (B, N) ; b: (B, N)
    returns: (B, N, N) where result[i, r, c] = a[i, r] * b[i, c]
    """
    column = a[:, :, None]  # (B, N, 1)
    row = b[:, None]        # (B, 1, N)
    return column * row


# ---------------------------
# DNC core
# ---------------------------

class DNCState(NamedTuple):
    """Recurrent state carried from one DNC step to the next.

    Shape legend: L controller layers, B batch size, H controller hidden
    size, N memory slots, W memory word width, R read heads.

    Fields, in positional order:
        controller_h:  (L, B, H) LSTM hidden state
        controller_c:  (L, B, H) LSTM cell state
        memory:        (B, N, W) memory matrix
        usage:         (B, N)    per-slot usage
        link:          (B, N, N) temporal link matrix
        precedence:    (B, N)    precedence weighting
        read_weights:  (B, R, N) read weighting per head
        read_vectors:  (B, R, W) vector read by each head
        write_weights: (B, N)    write weighting
    """

    controller_h: torch.Tensor
    controller_c: torch.Tensor
    memory: torch.Tensor
    usage: torch.Tensor
    link: torch.Tensor
    precedence: torch.Tensor
    read_weights: torch.Tensor
    read_vectors: torch.Tensor
    write_weights: torch.Tensor


class DNC(nn.Module):
    """Differentiable Neural Computer: an LSTM controller with external memory.

    At every step the controller emits an interface vector that drives one
    write head (content lookup blended with usage-based allocation) and
    ``read_heads`` read heads (content lookup blended with forward/backward
    traversal of the temporal link matrix).

    Shape legend used in this class: B batch, N memory slots, W word width,
    R read heads, H controller hidden size, L controller layers.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        controller_hidden: int = 256,
        controller_layers: int = 1,
        mem_n: int = 64,
        mem_w: int = 32,
        read_heads: int = 1
    ):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.H = controller_hidden
        self.L = controller_layers
        self.N = mem_n
        self.W = mem_w
        self.R = read_heads

        # Controller sees input plus previous read vectors
        ctrl_in = input_size + read_heads * mem_w
        self.controller = nn.LSTM(ctrl_in, controller_hidden, num_layers=controller_layers)

        # Parameter head from controller to interface vector
        # Write: k_w(W), beta_w(1), e(W), v(W), g_a(1), g_w(1)  -> 3W + 3
        # Read:  k_r(R*W), beta_r(R), f(R), pi(R*3)             -> R * (W + 5)
        # The write part was previously sized 3W + 2, one element short of
        # what _parse_interface consumes, which broke the read-mode reshape.
        interface_size = (3 * self.W + 3) + self.R * (self.W + 1 + 1 + 3)
        self.interface_linear = nn.Linear(controller_hidden, interface_size)

        # Output layer sees controller output + read vectors
        self.output_layer = nn.Linear(controller_hidden + read_heads * mem_w, output_size)

        # Initialization helpers
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-initialise every parameter with more than one dimension.

        One-dimensional parameters (biases) keep their default initialisation.
        """
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    # -------- Addressing subroutines --------

    def _content_address(self, M, k, beta):
        """
        Content-based addressing: softmax over slots of strength-scaled
        cosine similarity.

        M: (B, N, W)
        k: (B, K, W) with K keys (1 for the write head, R for the read heads)
        beta: (B, K, 1) >= 0
        returns: (B, K, N)
        """
        sim = cosine_similarity(M, k)  # (B, K, N)
        return F.softmax(beta * sim, dim=-1)

    def _update_usage(self, usage, read_weights, write_weights, free_gates):
        """
        Accumulate a write into the usage vector, then apply read-based freeing.

        usage: (B, N)
        read_weights: (B, R, N)
        write_weights: (B, N)
        free_gates: (B, R, 1) in [0,1]
        returns: (B, N), computed as (u + w - u*w) * psi and clamped to [0, 1]
        """
        # Retention psi = prod over heads of (1 - f_r * w_r).
        # free_gates (B, R, 1) broadcasts directly against read_weights
        # (B, R, N); adding an extra axis to read_weights mis-aligns the batch
        # and head dimensions and fails whenever B != R.
        psi = torch.prod(1 - free_gates * read_weights, dim=1)  # (B, N)
        usage = usage + write_weights - usage * write_weights   # u + w - u*w
        return torch.clamp(usage * psi, 0.0, 1.0)

    def _allocation_weighting(self, usage):
        """
        usage: (B, N) in [0,1]
        returns allocation weights a: (B, N), favoring least used slots.
        """
        B = usage.size(0)
        # sort ascending by usage
        sorted_usage, indices = torch.sort(usage, dim=1)
        # cumulative product of sorted usage
        cumprod = torch.cumprod(sorted_usage, dim=1)
        # a_j = (1 - u_j) * prod_{i<j} u_i; with the empty product equal to 1.
        # Shift cumprod right and insert ones at start
        one = usage.new_ones(B, 1)
        prod_prev = torch.cat([one, cumprod[:, :-1]], dim=1)
        a_sorted = (1 - sorted_usage) * prod_prev
        # unsort back to the original slot order
        a = torch.zeros_like(a_sorted).scatter(1, indices, a_sorted)
        # Normalise; the epsilon keeps the division finite when every slot is
        # fully used and a is all zeros.
        denom = a.sum(dim=1, keepdim=True) + 1e-8
        return a / denom

    def _write(self, M, usage, link, precedence, read_weights, interface):
        """
        Perform one write: free slots, pick a write weighting, erase/add, and
        update usage, temporal links and precedence.

        M: (B, N, W)
        usage: (B, N)
        link: (B, N, N)
        precedence: (B, N)
        read_weights: (B, R, N) read weights from the previous step
        interface: dict of write params (as produced by _parse_interface)
        returns updated (M, usage, link, precedence, write_weights)
        """
        N = M.size(1)
        # Unpack write interface
        k_w = interface["k_w"]               # (B, 1, W)
        beta_w = interface["beta_w"]         # (B, 1, 1)
        e = interface["e"]                   # (B, 1, W) in [0,1]
        v = interface["v"]                   # (B, 1, W)
        g_a = interface["g_a"].squeeze(-1)   # (B, 1) in [0,1]
        g_w = interface["g_w"].squeeze(-1)   # (B, 1) in [0,1]
        free_gates = interface["free_gates"] # (B, R, 1)

        # Free slots released by the previous reads (no write contribution yet).
        usage = self._update_usage(usage, read_weights, torch.zeros_like(usage), free_gates)

        # Content weights for write key
        c_w = self._content_address(M, k_w, beta_w).squeeze(1)  # (B, N)

        # Allocation weights from the freed usage
        a = self._allocation_weighting(usage)                   # (B, N)

        # Interpolate allocation vs content, then gate by write gate
        w_w = g_w * (g_a * a + (1 - g_a) * c_w)                 # (B, N)

        # Erase and add: M = M * (1 - w_w e^T) + w_w v^T
        w_col = w_w.unsqueeze(2)  # (B, N, 1)
        w_row = w_w.unsqueeze(1)  # (B, 1, N)
        M = M * (1 - w_col * e) + w_col * v

        # Fold the new write into usage. The retention factor was already
        # applied above; applying it a second time here would immediately
        # partially free the slots that were just written.
        usage = torch.clamp(usage + w_w - usage * w_w, 0.0, 1.0)

        # Temporal links: link[b, i, j] is the degree to which slot i was
        # written right after slot j. Links touching the slots being written
        # decay by (1 - w_i - w_j) before the new links are added; without the
        # decay the matrix only ever grows.
        off_diag = 1 - torch.eye(N, device=M.device, dtype=M.dtype).unsqueeze(0)
        link = (1 - w_col - w_row) * link + batched_outer(w_w, precedence)
        link = link * off_diag  # no self-links

        precedence = (1 - w_w.sum(dim=1, keepdim=True)) * precedence + w_w  # (B, N)

        return M, usage, link, precedence, w_w

    def _read(self, M, link, read_weights_prev, interface):
        """
        Compute read weightings and read vectors for all read heads.

        M: (B, N, W)
        link: (B, N, N)
        read_weights_prev: (B, R, N)
        interface: dict with read keys, strengths, modes
        returns (read_weights, read_vectors)
        """
        k_r = interface["k_r"]           # (B, R, W)
        beta_r = interface["beta_r"]     # (B, R, 1)
        read_modes = interface["read_modes"]  # (B, R, 3) rows sum to 1

        # Content weights
        c_r = self._content_address(M, k_r, beta_r)             # (B, R, N)

        # Traversals via the link matrix. Since link[i, j] means "i written
        # after j", moving forward in write order is fwd[n] = sum_m w[m] *
        # link[n, m] (i.e. w @ link^T) and moving backward is w @ link.
        fwd = torch.matmul(read_weights_prev, link.transpose(1, 2))  # (B, R, N)
        bwd = torch.matmul(read_weights_prev, link)                  # (B, R, N)

        # Combine by read modes [backward, content, forward]
        pi_b = read_modes[..., 0:1]
        pi_c = read_modes[..., 1:2]
        pi_f = read_modes[..., 2:3]
        read_weights = pi_b * bwd + pi_c * c_r + pi_f * fwd     # (B, R, N)
        # Renormalise; the epsilon avoids 0/0 when all three terms vanish.
        read_weights = read_weights + 1e-8
        read_weights = read_weights / read_weights.sum(dim=-1, keepdim=True)

        read_vectors = torch.matmul(read_weights, M)            # (B, R, W)

        return read_weights, read_vectors

    # -------- Parsing interface vector --------

    def _parse_interface(self, iface: torch.Tensor, B: int):
        """
        Split a single-step interface vector into named, activated parameters.

        iface: (B, interface_size)
        Returns dict of shaped tensors keyed by k_w, beta_w, e, v, g_a, g_w,
        k_r, beta_r, free_gates, read_modes.
        Raises ValueError if the last dimension of iface does not match the
        number of elements consumed below.
        """
        expected = (3 * self.W + 3) + self.R * (self.W + 5)
        if iface.size(-1) != expected:
            raise ValueError(
                f"interface vector has {iface.size(-1)} features, expected {expected}"
            )
        ofs = 0

        def take(n):
            # Consume the next n features along the last dimension.
            nonlocal ofs
            out = iface[:, ofs:ofs+n]
            ofs += n
            return out

        # Write
        k_w = take(self.W).view(B, 1, self.W)
        beta_w = softplus(take(1)).view(B, 1, 1) + 1e-6
        e = torch.sigmoid(take(self.W)).view(B, 1, self.W)
        v = torch.tanh(take(self.W)).view(B, 1, self.W)
        g_a = torch.sigmoid(take(1)).view(B, 1, 1)
        g_w = torch.sigmoid(take(1)).view(B, 1, 1)

        # Read
        k_r = take(self.R * self.W).view(B, self.R, self.W)
        beta_r = softplus(take(self.R)).view(B, self.R, 1) + 1e-6
        free_gates = torch.sigmoid(take(self.R)).view(B, self.R, 1)
        read_modes = take(self.R * 3).view(B, self.R, 3)
        read_modes = F.softmax(read_modes, dim=-1)

        return {
            "k_w": k_w,
            "beta_w": beta_w,
            "e": e,
            "v": v,
            "g_a": g_a,
            "g_w": g_w,
            "k_r": k_r,
            "beta_r": beta_r,
            "free_gates": free_gates,
            "read_modes": read_modes,
        }

    # -------- Step and forward --------

    def initial_state(self, batch_size: int, device):
        """Build the starting state: everything zero except the read weights,
        which start uniform over the N slots (softmax of zeros).

        device: anything accepted by the ``device=`` argument of torch.zeros.
        """
        controller_h = torch.zeros(self.L, batch_size, self.H, device=device)
        controller_c = torch.zeros(self.L, batch_size, self.H, device=device)
        memory = torch.zeros(batch_size, self.N, self.W, device=device)
        usage = torch.zeros(batch_size, self.N, device=device)
        link = torch.zeros(batch_size, self.N, self.N, device=device)
        precedence = torch.zeros(batch_size, self.N, device=device)
        read_weights = F.softmax(torch.zeros(batch_size, self.R, self.N, device=device), dim=-1)
        read_vectors = torch.zeros(batch_size, self.R, self.W, device=device)
        write_weights = torch.zeros(batch_size, self.N, device=device)
        return DNCState(controller_h, controller_c, memory, usage, link, precedence, read_weights, read_vectors, write_weights)

    def step(self, x_t: torch.Tensor, state: "DNCState") -> Tuple[torch.Tensor, "DNCState"]:
        """
        Run one timestep: controller -> interface -> write -> read -> output.

        x_t: (B, input_size)
        returns y_t: (B, output_size) and the updated state
        """
        B = x_t.size(0)
        # Controller input: x_t + previous read vectors
        ctrl_in = torch.cat([x_t, state.read_vectors.view(B, -1)], dim=-1).unsqueeze(0)  # (1, B, D)
        ctrl_out, (h, c) = self.controller(ctrl_in, (state.controller_h, state.controller_c))  # ctrl_out: (1,B,H)
        ctrl_out = ctrl_out.squeeze(0)  # (B, H)

        # Interface parsing
        iface = self.interface_linear(ctrl_out)  # (B, D)
        iface_parsed = self._parse_interface(iface, B)

        # Write to memory
        M, usage, link, precedence, w_w = self._write(
            state.memory, state.usage, state.link, state.precedence, state.read_weights, iface_parsed
        )

        # Read from the freshly written memory
        read_weights, read_vectors = self._read(
            M, link, state.read_weights, iface_parsed
        )

        # Output
        out_input = torch.cat([ctrl_out, read_vectors.view(B, -1)], dim=-1)
        y_t = self.output_layer(out_input)

        new_state = DNCState(h, c, M, usage, link, precedence, read_weights, read_vectors, w_w)
        return y_t, new_state

    def forward(self, x: torch.Tensor, state: "DNCState" = None):
        """
        Unroll the DNC over a whole sequence.

        x: (T, B, input_size)
        state: optional starting state; a fresh initial_state is used if None
        returns y: (T, B, output_size), final_state
        """
        T, B, _ = x.shape
        if state is None:
            state = self.initial_state(B, x.device)

        if T == 0:
            # Nothing to unroll; torch.stack cannot build a tensor from an empty list.
            return x.new_zeros((0, B, self.output_size)), state

        outputs = []
        for t in range(T):
            y_t, state = self.step(x[t], state)
            outputs.append(y_t)
        return torch.stack(outputs, dim=0), state


# ---------------------------
# Copy task dataset
# ---------------------------

def generate_copy_batch(batch_size: int, seq_len: int, bits: int, device: str):
    """
    Build one time-major batch for the copy task.

    Inputs: T = seq_len + 1 + seq_len
    - Phase 1 (write): seq_len timesteps of random bit vectors (bits) with a delimiter channel OFF
    - Delimiter timestep: one-hot delimiter channel ON (all bits zero)
    - Phase 2 (read): seq_len timesteps of zeros; model must output the original bit vectors
    Input size = bits + 1 (delimiter)
    Output size = bits (no delimiter in output)

    Returns:
        inp: (T, batch_size, bits + 1)
        out: (T, batch_size, bits), zero everywhere except the read phase
    """
    # Sample directly in time-major layout (seq_len, B, bits) so the sequence
    # lines up with the (T, B, ...) buffers below. A batch-major sample cannot
    # be assigned into these slices unless batch_size == seq_len, and in that
    # case it would silently mix up the time and batch axes.
    seq = torch.bernoulli(torch.full((seq_len, batch_size, bits), 0.5, device=device))

    T = seq_len + 1 + seq_len
    inp = torch.zeros(T, batch_size, bits + 1, device=device)
    out = torch.zeros(T, batch_size, bits, device=device)

    # Write phase: payload bits, delimiter channel left at zero
    inp[:seq_len, :, :bits] = seq
    # Delimiter timestep
    inp[seq_len, :, bits] = 1.0
    # Read phase: input stays all zeros

    # Desired output during read phase
    out[seq_len + 1:, :, :] = seq

    return inp, out


# ---------------------------
# Training
# ---------------------------

def train_copy(
    device="cuda" if torch.cuda.is_available() else "cpu",
    batch_size=16,
    seq_len=10,
    bits=8,
    mem_n=64,
    mem_w=32,
    read_heads=2,
    controller_hidden=256,
    controller_layers=1,
    lr=1e-3,
    epochs=20000,
    grad_clip=10.0,
    seed=42,
    checkpoint="./checkpoints/dnc_copy.pt",
    log_every=100
):
    """Train a DNC on freshly generated copy-task batches and print a sample.

    Args:
        device: Device for the model and the generated batches.
        batch_size: Sequences per training batch.
        seq_len: Length of the bit sequence to copy.
        bits: Width of each bit vector; the model input has bits + 1 channels.
        mem_n, mem_w, read_heads, controller_hidden, controller_layers:
            Forwarded to the DNC constructor.
        lr: Adam learning rate.
        epochs: Number of optimisation steps (one new batch per step).
        grad_clip: Max gradient norm passed to clip_grad_norm_.
        seed: Passed to set_seed.
        checkpoint: File that receives the state dict whenever the logged
            accuracy improves on the best seen so far.
        log_every: Log (and possibly checkpoint) every this many steps, and on
            step 1.
    """
    set_seed(seed)
    # A bare filename has an empty dirname, and os.makedirs("") raises.
    ckpt_dir = os.path.dirname(checkpoint)
    if ckpt_dir:
        os.makedirs(ckpt_dir, exist_ok=True)

    input_size = bits + 1
    output_size = bits

    model = DNC(
        input_size=input_size,
        output_size=output_size,
        controller_hidden=controller_hidden,
        controller_layers=controller_layers,
        mem_n=mem_n,
        mem_w=mem_w,
        read_heads=read_heads
    ).to(device)

    opt = torch.optim.Adam(model.parameters(), lr=lr)

    best_acc = 0.0
    for step in range(1, epochs + 1):
        model.train()
        x, y_true = generate_copy_batch(batch_size, seq_len, bits, device)

        y_pred, _ = model(x)
        # We only supervise read phase outputs
        y_pred_read = y_pred[seq_len + 1:]         # (seq_len, B, bits)
        y_true_read = y_true[seq_len + 1:]         # (seq_len, B, bits)

        loss = F.binary_cross_entropy_with_logits(y_pred_read, y_true_read)

        opt.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        opt.step()

        if step % log_every == 0 or step == 1:
            # Accuracy of this step's predictions, which were computed before
            # the optimiser update above.
            with torch.no_grad():
                preds = (torch.sigmoid(y_pred_read) > 0.5).float()
                acc = (preds == y_true_read).float().mean().item()
            if acc > best_acc:
                best_acc = acc
                torch.save({"step": step, "state_dict": model.state_dict()}, checkpoint)
            print(f"[{step:05d}] loss={loss.item():.4f} acc={acc:.3f} best_acc={best_acc:.3f}")

    # Show a sample from the final weights (not necessarily the checkpointed ones)
    model.eval()
    with torch.no_grad():
        x, y_true = generate_copy_batch(1, seq_len, bits, device)
        y_pred, _ = model(x)
        y_logits = y_pred[seq_len + 1:].squeeze(1)    # (seq_len, bits)
        y_hat = (torch.sigmoid(y_logits) > 0.5).float()
        print("\nSample (first 4 timesteps of read phase):")
        print("target:", y_true[seq_len + 1: seq_len + 5, 0].cpu().numpy())
        print("pred  :", y_hat[:4].cpu().numpy())


# ---------------------------
# CLI
# ---------------------------

def parse_args():
    """Parse the command-line flags for the copy-task trainer.

    Every flag takes a single typed value with a default; the device default
    is "cuda" when torch.cuda.is_available() is true and "cpu" otherwise.
    Returns the argparse namespace.
    """
    default_device = "cuda" if torch.cuda.is_available() else "cpu"

    # (flag, type, default) for each option, in registration order.
    options = (
        ("--device", str, default_device),
        ("--batch-size", int, 16),
        ("--seq-len", int, 10),
        ("--bits", int, 8),
        ("--mem-n", int, 64),
        ("--mem-w", int, 32),
        ("--read-heads", int, 2),
        ("--controller-hidden", int, 256),
        ("--controller-layers", int, 1),
        ("--lr", float, 1e-3),
        ("--epochs", int, 20000),
        ("--grad-clip", float, 10.0),
        ("--seed", int, 42),
        ("--checkpoint", str, "./checkpoints/dnc_copy.pt"),
        ("--log-every", int, 100),
    )

    parser = argparse.ArgumentParser(description="Train a DNC on the copy task.")
    for flag, kind, fallback in options:
        parser.add_argument(flag, type=kind, default=fallback)
    return parser.parse_args()


if __name__ == "__main__":
    # Each CLI flag's destination name (e.g. --batch-size -> batch_size) is
    # also a train_copy keyword, so the parsed namespace is forwarded whole.
    train_copy(**vars(parse_args()))