# Minimal Dialectical Attention Head (with constrained decoding)

Opposed streams → dual attention → tension-gated recursive update with per-token halting.
Includes a **constrained decoder** to avoid unbalanced parentheses and trailing PAD/space runs.


In [None]:
import math, random
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device  # notebook cell output: show which device was picked

In [None]:
class DialecticalHead(nn.Module):
    """Causal attention head with two opposed value streams and per-token halting.

    Each recursion step projects the current state to q/k/v, attends twice
    (a "positive" and a "negative" stream whose value projections are
    initialised as near-negations of each other), measures the tension
    ``0.5 * (1 - cos(up, un))`` between the two stream outputs, and adds a
    tension-gated update to the state. A token stops being updated once its
    relative update norm or its tension drops below a threshold.

    Args:
        d_model: width of the input/output state.
        d_head: width of the q/k/v projections.
        max_steps: maximum number of recursion steps.
        halt_eps: a token halts when its relative update norm is below this.
        tension_tau: a token halts when its tension is below this.
        min_steps: number of steps that must run before halting is allowed.
    """

    def __init__(self, d_model: int, d_head: int, max_steps: int = 6, halt_eps: float = 5e-4, tension_tau: float = 0.15, min_steps:int=1):
        super().__init__()
        self.q = nn.Linear(d_model, d_head, bias=False)
        self.k = nn.Linear(d_model, d_head, bias=False)
        self.v = nn.Linear(d_model, d_head, bias=False)
        self.v_pos = nn.Linear(d_head, d_head, bias=False)
        self.v_neg = nn.Linear(d_head, d_head, bias=False)
        self.synth = nn.Linear(5 * d_head, d_head)
        self.up = nn.Linear(d_head, d_model)
        self.u_pos = nn.Parameter(torch.zeros(d_head, 1))
        self.u_neg = nn.Parameter(torch.zeros(d_head, 1))
        self.b_pos = nn.Parameter(torch.zeros(1))
        self.b_neg = nn.Parameter(torch.zeros(1))
        self.step_gate = nn.Linear(d_head, 1)
        self.max_steps, self.min_steps = max_steps, min_steps
        self.halt_eps, self.tension_tau = halt_eps, tension_tau
        nn.init.kaiming_uniform_(self.v_pos.weight, a=math.sqrt(5))
        with torch.no_grad():
            # Start the negative stream as a slightly perturbed negation of the positive one.
            self.v_neg.weight.copy_(-self.v_pos.weight + 0.01 * torch.randn_like(self.v_pos.weight))

    @staticmethod
    def _masks(token_mask: torch.Tensor, T: int, device):
        """Build additive attention masks (0 where allowed, -inf where blocked).

        Returns:
            causal_add: [T, T], blocks keys that lie after the query position.
            key_add: [B, 1, T], blocks keys where ``token_mask`` is False.
        """
        blocked_key = ~token_mask.bool()
        future = torch.ones(T, T, dtype=torch.bool, device=device).triu(diagonal=1)
        causal_add = torch.zeros(T, T, device=device).masked_fill(future, float('-inf'))
        # masked_fill instead of `mask * -inf`: 0 * -inf is NaN, which would poison
        # every *valid* key as well.
        key_add = torch.zeros(blocked_key.shape, device=blocked_key.device)
        key_add = key_add.masked_fill(blocked_key, float('-inf')).unsqueeze(1)  # [B,1,T]
        return causal_add, key_add

    @staticmethod
    def _bmm(attn: torch.Tensor, vals: torch.Tensor) -> torch.Tensor:
        """Batched ``attn @ vals``."""
        return torch.bmm(attn, vals)

    def forward(self, x, token_mask):
        """Run up to ``max_steps`` recursion steps over ``x``.

        Args:
            x: [B, T, d_model] input states.
            token_mask: [B, T] mask, truthy for real tokens and falsy for padding.

        Returns:
            ``(z, metrics)`` where ``z`` is the updated state with the shape of
            ``x`` and ``metrics`` holds ``'avg_steps'`` (mean number of steps
            spent per real token) and ``'tensions'`` (mean tension over active
            tokens, one float per executed step).
        """
        B, T, _ = x.shape
        token_mask = token_mask.bool()
        z = x
        active = token_mask.clone()
        steps_used = torch.zeros(B, T, device=x.device)
        tensions = []

        # The mask does not change between steps, so build it once.
        causal_add, key_add = self._masks(token_mask, T, x.device)
        mask_add = causal_add + key_add  # [B,T,T], entries are 0 or -inf
        # A query row with every key blocked (only possible for a PAD query) would make
        # softmax return NaN; lift the mask for such rows. Their output is never used
        # because PAD tokens are never active.
        all_blocked = torch.isneginf(mask_add).all(dim=-1, keepdim=True)
        mask_add = mask_add.masked_fill(all_blocked, 0.0)

        for t in range(self.max_steps):
            steps_used = steps_used + active.float()
            # Halted tokens still serve as keys/values but no longer receive gradient.
            z_eff = torch.where(active.unsqueeze(-1), z, z.detach())
            q = self.q(z_eff); k = self.k(z_eff); v = self.v(z_eff)
            vpos = self.v_pos(v); vneg = self.v_neg(v)
            logits = torch.matmul(q, k.transpose(-1, -2)) / (k.size(-1) ** 0.5)
            # Sanitise the raw scores first, then apply the mask so that -inf survives.
            logits = torch.nan_to_num(logits, nan=0.0, posinf=1e9, neginf=-1e9) + mask_add
            # NOTE(review): tilt_* is [B,T,1] and b_* is a scalar, i.e. both are constant
            # along the key axis, so softmax is invariant to them and attn_p == attn_n.
            # Left untouched to keep parameters and checkpoints compatible.
            tilt_p = torch.matmul(q, self.u_pos); tilt_n = torch.matmul(q, self.u_neg)
            attn_p = F.softmax(logits + self.b_pos + tilt_p, dim=-1)
            attn_n = F.softmax(logits + self.b_neg + tilt_n, dim=-1)
            attn_p = torch.nan_to_num(attn_p, nan=0.0); attn_n = torch.nan_to_num(attn_n, nan=0.0)
            attn_p = attn_p / attn_p.sum(-1, keepdim=True).clamp_min(1e-9)
            attn_n = attn_n / attn_n.sum(-1, keepdim=True).clamp_min(1e-9)
            up = self._bmm(attn_p, vpos); un = self._bmm(attn_n, vneg)
            cos = F.cosine_similarity(up, un, dim=-1, eps=1e-6).unsqueeze(-1)
            tension = 0.5 * (1.0 - cos)
            act = active.float().unsqueeze(-1)
            tensions.append(((tension * act).sum() / act.sum().clamp(min=1.0)).detach().item())
            step = torch.sigmoid(self.step_gate(q)) * tension.clamp(0,1)
            proposal = self.up(F.silu(self.synth(torch.cat([up, un, up-un, up+un, q], dim=-1))))
            z_new = z + step * proposal
            delta = (z_new - z).norm(dim=-1) / (z.norm(dim=-1) + 1e-6)
            if (t + 1) >= self.min_steps:
                converged = (delta < self.halt_eps) | (tension.squeeze(-1) < self.tension_tau)
                done = active & converged
            else:
                done = torch.zeros_like(active)
            # Halting is sticky: only tokens that are still active after this step take
            # the update; halted tokens and PAD positions keep their current state.
            active = active & (~done)
            z = torch.where(active.unsqueeze(-1), z_new, z)
            if not active.any():
                break

        denom = token_mask.float().sum().clamp(min=1.0)
        avg_steps = (steps_used * token_mask.float()).sum() / denom
        return z, {'avg_steps': avg_steps.item(), 'tensions': tensions}

class TinyDialecticalBlock(nn.Module):
    """Residual block: layer-normed dialectical attention, then a SiLU feed-forward."""

    def __init__(self, d_model=128, d_head=64, **kw):
        super().__init__()
        hidden = 2 * d_model
        # Extra keyword arguments are forwarded to DialecticalHead unchanged.
        self.attn = DialecticalHead(d_model, d_head, **kw)
        self.out = nn.Linear(d_model, d_model)
        self.ln1 = nn.LayerNorm(d_model)
        mlp_layers = [
            nn.LayerNorm(d_model),
            nn.Linear(d_model, hidden),
            nn.SiLU(),
            nn.Linear(hidden, d_model),
        ]
        self.ff = nn.Sequential(*mlp_layers)

    def forward(self, x, token_mask):
        """Return ``(updated_x, metrics)``; ``metrics`` comes straight from the head."""
        attended, metrics = self.attn(self.ln1(x), token_mask)
        residual = x + self.out(attended)
        return residual + self.ff(residual), metrics

class TinyDialecticalLM(nn.Module):
    """Token embedding + learned positions + one TinyDialecticalBlock + LM head.

    Args:
        vocab_size: number of token ids.
        pad_idx: id of the padding token; positions holding it are masked out.
        d_model: model width.
        d_head: attention head width.
        max_len: number of learned positions, i.e. the longest sequence
            ``forward`` accepts (keyword-only, defaults to 256).
        **kw: forwarded to TinyDialecticalBlock.
    """

    def __init__(self, vocab_size, pad_idx, d_model=128, d_head=64, *, max_len=256, **kw):
        super().__init__()
        self.pad_idx = pad_idx
        self.emb = nn.Embedding(vocab_size, d_model)
        nn.init.normal_(self.emb.weight, mean=0.0, std=0.02)
        self.pos = nn.Parameter(torch.randn(1, max_len, d_model) * 0.01)
        self.block = TinyDialecticalBlock(d_model, d_head, **kw)
        self.lm = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        """Map token ids ``x`` of shape [B, T] to ``(logits [B, T, vocab_size], metrics)``.

        Raises:
            ValueError: if T exceeds the number of learned positions.
        """
        _, T = x.shape
        max_len = self.pos.size(1)
        if T > max_len:
            # Without this check the addition below fails with an opaque broadcast error.
            raise ValueError(f"sequence length {T} exceeds max_len={max_len}")
        m = (x != self.pad_idx)
        h = self.emb(x) + self.pos[:, :T, :]
        h, metrics = self.block(h, m)
        return self.lm(h), metrics


In [None]:
# Character-level vocabulary; the last symbol '_' is the padding token.
VOCAB = [*"()abc _"]  # '_' is PAD
# Lookup tables in both directions: character -> id and id -> character.
stoi = {symbol: index for index, symbol in enumerate(VOCAB)}
itos = dict(enumerate(VOCAB))
vocab_size = len(VOCAB)
PAD = stoi['_']

def gen_balanced(n_pairs=4, fillers=True):
    """Return a random balanced-parentheses string containing ``n_pairs`` pairs.

    The string is built recursively as ``( inner ) outer``. When ``fillers`` is
    true, each character of the assembled string is followed, with probability
    0.2, by a random character from ``'ab '``; this happens at every recursion
    level.
    """
    if n_pairs == 0:
        return ""

    n_inside = random.randint(0, n_pairs - 1)
    nested = gen_balanced(n_inside, fillers)
    trailing = gen_balanced(n_pairs - 1 - n_inside, fillers)
    skeleton = f"({nested}){trailing}"
    if not fillers:
        return skeleton

    pieces = []
    for symbol in skeleton:
        pieces.append(symbol)
        if random.random() < 0.2:
            pieces.append(random.choice('ab '))
    return ''.join(pieces)

def make_sample(max_pairs=6, max_len=96):
    """Draw one (input, target) pair of id tensors for next-character prediction.

    A balanced string with 1..``max_pairs`` pairs is generated and cut to at
    most ``max_len - 1`` characters. The target is the input shifted left by
    one character with a single space appended.
    """
    n_pairs = random.randint(1, max_pairs)
    text = gen_balanced(n_pairs)[:max_len - 1]
    shifted = text[1:] + ' '

    def encode(chars):
        # Map characters to vocabulary ids.
        return torch.tensor([stoi[c] for c in chars], dtype=torch.long)

    return encode(text), encode(shifted)

def batchify(B=64, max_len=96, device=device):
    """Sample ``B`` sequences and right-pad them with PAD into two [B, T] tensors.

    T is the longest input in the batch. Returns ``(X, Y)`` moved to ``device``.
    """
    pairs = [make_sample(max_len=max_len) for _ in range(B)]
    width = max(src.size(0) for src, _ in pairs)

    X = torch.full((B, width), PAD, dtype=torch.long)
    Y = torch.full((B, width), PAD, dtype=torch.long)
    for row, (src, tgt) in enumerate(pairs):
        X[row, :src.size(0)] = src
        Y[row, :tgt.size(0)] = tgt
    return X.to(device), Y.to(device)


In [None]:
# Build the model and optimiser, then train for 200 steps on freshly sampled batches.
model = TinyDialecticalLM(
    vocab_size, PAD, d_model=128, d_head=64, max_steps=6, tension_tau=0.15
).to(device)
optim = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)

# Per-step history used for the plots below.
losses, avg_steps_hist, tstarts, tends = [], [], [], []

for step in range(200):
    X, Y = batchify(B=64, max_len=96)
    logits, metrics = model(X)
    # Flatten to [B*T, vocab] vs [B*T]; PAD targets do not contribute to the loss.
    ce = F.cross_entropy(logits.reshape(-1, logits.size(-1)), Y.reshape(-1), ignore_index=PAD)

    optim.zero_grad()
    ce.backward()
    nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optim.step()

    losses.append(ce.item())
    avg_steps_hist.append(metrics['avg_steps'])
    if metrics['tensions']:
        # Tension at the first and at the last executed recursion step.
        tstarts.append(metrics['tensions'][0])
        tends.append(metrics['tensions'][-1])
    if step % 20 == 0:
        ts = metrics['tensions']
        print(f"step {step:03d} | loss {ce.item():.3f} | avg_steps {metrics['avg_steps']:.2f} | tension start {ts[0] if ts else float('nan'):.3f} -> end {ts[-1] if ts else float('nan'):.3f}")

plt.plot(losses)
plt.title('Training loss')
plt.show()

plt.plot(avg_steps_hist)
plt.title('Avg steps used (non-pad tokens)')
plt.show()

if tstarts and tends:
    plt.plot(tstarts, label='start')
    plt.plot(tends, label='end')
    plt.legend()
    plt.title('Tension start vs end per batch')
    plt.show()


In [None]:
# Evaluate one fresh batch without gradients and plot the per-step tension.
X, Y = batchify(B=32, max_len=96)
with torch.no_grad():
    logits, metrics = model(X)

print('Avg steps (approx):', metrics['avg_steps'])

plt.plot(metrics['tensions'])
plt.title('Mean tension over ACTIVE non-pad tokens per recursion step')
plt.show()

In [None]:
def generate_constrained(model, seed='((', max_new=40, stop_on_closed=True, close_tail=True, bias_close=5.0):
    """Greedily extend ``seed`` while keeping the parentheses balanceable.

    Args:
        model: callable mapping a [1, T] id tensor to ``(logits [1, T, V], metrics)``.
        seed: non-empty starting text; characters missing from ``stoi`` map to PAD.
        max_new: maximum number of characters to generate.
        stop_on_closed: stop as soon as a space is produced at nesting depth 0.
        close_tail: after generation, append one ')' per parenthesis still open.
        bias_close: value added to the ')' logit once the remaining budget is
            no larger than the current depth.

    Returns:
        The seed plus the generated characters as a string.

    Raises:
        ValueError: if ``seed`` is empty (there is no position to predict from).
    """
    if not seed:
        raise ValueError("seed must contain at least one character")

    idx_r = stoi[')']; idx_pad = stoi['_']
    banned = float('-inf')  # a banned token can never win the argmax
    x = torch.tensor([[stoi.get(ch, idx_pad) for ch in seed]], dtype=torch.long, device=device)
    depth = seed.count('(') - seed.count(')')
    with torch.no_grad():
        for t in range(max_new):
            logits, _ = model(x)
            logit = logits[0, -1].clone()
            logit[idx_pad] = banned  # never emit PAD
            if depth <= 0:
                logit[idx_r] = banned  # nothing open, so ')' would unbalance the text
            elif max_new - t <= depth:
                # Budget is running out: favour closing what is still open.
                logit[idx_r] = logit[idx_r] + bias_close
            nxt = logit.argmax().view(1, 1)
            ch = itos[int(nxt)]
            if ch == '(':
                depth += 1
            elif ch == ')':
                depth -= 1
            x = torch.cat([x, nxt], dim=1)
            if stop_on_closed and depth == 0 and ch == ' ':
                break
    if close_tail and depth > 0:
        closes = torch.tensor([[idx_r] * depth], dtype=torch.long, device=device)
        x = torch.cat([x, closes], dim=1)
    return ''.join(itos[int(i)] for i in x[0].tolist())

# Demo: greedily continue a seed that has one parenthesis still open.
print(generate_constrained(model, seed='(()', max_new=40))