# Dialectical (Recursive Dual) Attention — Colab Demo v2

**What's new in v2:**
- Two value streams are initialized as **opposites** and nudged by an opposition regularizer.
- **Per-stream routing tilts** (learned) make pos/neg actually look at different neighborhoods.
- **Tension metric** is now linear in cosine: `tension = 0.5*(1 - cos)` (0=aligned, 1=opposed).
- **Per-token halting** truly freezes stable tokens between recursion steps.

The rest matches v1: a tiny Transformer trained on a balanced-parentheses toy LM task, with quick plots of the training loss, the average number of recursion steps used, and the mean tension per recursion step.

In [None]:
import math, random, os
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
# Use CUDA when PyTorch reports it as available; otherwise run on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Bare expression: has no effect in a script; presumably left so a notebook cell echoes the device.
device

## Dialectical Attention Head (single head, improved)

In [None]:
class DialecticalHead(nn.Module):
    """Single attention head with two opposed value streams and a recursive synthesis loop.

    Two attention maps (``pos`` / ``neg``) read from two value projections that are
    initialized as near-negatives of each other. Their read-outs are compared with a
    cosine-based "tension" in [0, 1], and a small synthesis layer repeatedly refines a
    per-token state ``z`` (starting from the query) until the relative update of a
    token falls below ``halt_eps`` or ``max_steps`` is reached.

    Args:
        d_model: Width of the input features.
        d_head: Width of the head (queries, keys, values and the returned state).
        max_steps: Maximum number of refinement steps.
        halt_eps: A token is frozen once ``||z_new - z|| / ||z||`` drops below this.
        opp_lambda: Weight of the opposition regularizer returned as ``extra_loss``.
    """

    def __init__(self, d_model: int, d_head: int, max_steps: int = 4, halt_eps: float = 2e-3, opp_lambda: float = 1e-3):
        super().__init__()
        self.q = nn.Linear(d_model, d_head, bias=False)
        self.k = nn.Linear(d_model, d_head, bias=False)
        self.v = nn.Linear(d_model, d_head, bias=False)
        # Opposed value channels (+ / -)
        self.v_pos = nn.Linear(d_head, d_head, bias=False)
        self.v_neg = nn.Linear(d_head, d_head, bias=False)
        # Tiny synthesis block: consumes [pos read-out, neg read-out, current state]
        self.synth = nn.Linear(3 * d_head, d_head)
        # Learned per-stream routing tilts. They are applied to the *keys* in forward():
        # a term that is constant along the key axis would be cancelled by the softmax.
        self.u_pos = nn.Parameter(torch.zeros(d_head, 1))
        self.u_neg = nn.Parameter(torch.zeros(d_head, 1))
        # Scalar logit biases. Being constant along the key axis they do not change the
        # softmax output; they are kept so previously saved state dicts still load.
        self.b_pos = nn.Parameter(torch.zeros(1))
        self.b_neg = nn.Parameter(torch.zeros(1))
        # Gate to scale updates
        self.gate = nn.Linear(d_head, 1)
        self.max_steps = max_steps
        self.halt_eps = halt_eps
        self.opp_lambda = opp_lambda

        self._init_opposed_streams()

    def _init_opposed_streams(self):
        """Re-initialize ``v_pos`` and set ``v_neg`` to ``-v_pos`` plus small Gaussian noise."""
        nn.init.kaiming_uniform_(self.v_pos.weight, a=math.sqrt(5))
        with torch.no_grad():
            self.v_neg.weight.copy_(-self.v_pos.weight + 0.01 * torch.randn_like(self.v_pos.weight))

    def forward(self, x, mask=None, return_aux=False):
        """Run the head.

        Args:
            x: Input of shape [B, T, d_model].
            mask: Optional additive attention mask of shape [B, 1, T, T] (or anything
                that broadcasts against [B, T, T] after ``squeeze(1)``); 0 keeps a
                position, -inf blocks it.
            return_aux: If True, also return the two attention maps.

        Returns:
            ``(z, metrics, extra_loss)`` or, with ``return_aux``,
            ``(z, metrics, extra_loss, (attn_pos, attn_neg))`` where ``z`` is
            [B, T, d_head], ``metrics`` holds ``'avg_steps'`` (mean number of steps in
            which a token was updated) and ``'tensions'`` (one mean-tension value per
            executed step), and ``extra_loss`` is the weighted opposition regularizer.
        """
        B, T, _ = x.shape
        q = self.q(x)
        k = self.k(x)
        v = self.v(x)
        v_pos = self.v_pos(v)
        v_neg = self.v_neg(v)

        logits = torch.matmul(q, k.transpose(-1, -2)) / (k.size(-1) ** 0.5)
        if mask is not None:
            logits = logits + mask.squeeze(1)

        # Per-stream routing tilts: one score per *key*, broadcast across queries.
        # (A query-only tilt of shape [B,T,1] shifts a whole softmax row by a constant
        # and therefore leaves the attention map unchanged.)
        tilt_pos = torch.matmul(k, self.u_pos).transpose(-1, -2)  # [B,1,T]
        tilt_neg = torch.matmul(k, self.u_neg).transpose(-1, -2)  # [B,1,T]

        attn_pos = F.softmax(logits + self.b_pos + tilt_pos, dim=-1)
        attn_neg = F.softmax(logits + self.b_neg + tilt_neg, dim=-1)

        # The read-outs depend only on the attention maps and values, not on z, so they
        # (and the tension derived from them) are identical at every refinement step.
        up = torch.matmul(attn_pos, v_pos)     # [B,T,d]
        un = torch.matmul(attn_neg, v_neg)     # [B,T,d]
        # Tension in [0,1]: 0 aligned, 1 opposed
        cos = F.cosine_similarity(up, un, dim=-1, eps=1e-6).unsqueeze(-1)
        tension = 0.5 * (1.0 - cos)
        mean_tension = tension.mean().item()

        # Initial state and masks
        z = q.clone()
        active = torch.ones(B, T, dtype=torch.bool, device=x.device)
        steps_used = torch.zeros(B, T, device=x.device)
        mean_tensions = []

        for _ in range(self.max_steps):
            mean_tensions.append(mean_tension)

            proposal = F.silu(self.synth(torch.cat([up, un, z], dim=-1)))
            gate = torch.sigmoid(self.gate(z)) * tension
            z_new = z + gate * proposal

            # Per-token halting: a token is frozen once its relative update is small,
            # and a frozen token is never updated again.
            delta = (z_new - z).norm(dim=-1) / (z.norm(dim=-1) + 1e-6)
            converged = delta < self.halt_eps
            update = active & ~converged
            z = torch.where(update.unsqueeze(-1), z_new, z)
            steps_used = steps_used + update.float()
            active = update

            if not active.any():
                break

        metrics = {
            'avg_steps': steps_used.mean().item(),
            'tensions': mean_tensions,
        }

        # Opposition regularizer: encourage v_neg ≈ -v_pos within head subspace
        opp_reg = (self.v_pos.weight + self.v_neg.weight).pow(2).mean()
        extra_loss = self.opp_lambda * opp_reg

        if return_aux:
            return z, metrics, extra_loss, (attn_pos, attn_neg)
        return z, metrics, extra_loss


## Tiny Transformer Block and LM wrapper

In [None]:
class TinyDialecticalBlock(nn.Module):
    """Residual block: layer-normed dialectical attention followed by a small MLP.

    The attention head output (width ``d_head``) is projected back to ``d_model``
    before being added to the residual stream; the feed-forward branch applies its
    own LayerNorm and uses a hidden width of ``ff_mult * d_model``.
    """

    def __init__(self, d_model=128, d_head=64, max_steps=4, halt_eps=2e-3, opp_lambda=1e-3, ff_mult=2):
        super().__init__()
        hidden = ff_mult * d_model
        self.attn = DialecticalHead(
            d_model,
            d_head,
            max_steps=max_steps,
            halt_eps=halt_eps,
            opp_lambda=opp_lambda,
        )
        self.out = nn.Linear(d_head, d_model)
        self.ln1 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, hidden),
            nn.SiLU(),
            nn.Linear(hidden, d_model),
        )

    def forward(self, x, mask=None):
        """Return ``(x_out, metrics, extra_loss)``; the last two come straight from the head."""
        head_state, metrics, extra_loss = self.attn(self.ln1(x), mask)
        attended = x + self.out(head_state)
        return attended + self.ff(attended), metrics, extra_loss

class TinyDialecticalLM(nn.Module):
    """Token embedding + learned positions + one dialectical block + LM head.

    Args:
        vocab_size: Number of token ids.
        d_model: Width of the residual stream.
        d_head: Width of the attention head.
        max_steps: Maximum refinement steps inside the head.
        halt_eps: Per-token halting threshold inside the head.
        opp_lambda: Weight of the head's opposition regularizer.
    """

    def __init__(self, vocab_size, d_model=128, d_head=64, max_steps=4, halt_eps=2e-3, opp_lambda=1e-3):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, d_model)
        # Learned absolute position table; its second dimension bounds the sequence length.
        self.pos = nn.Parameter(torch.randn(1, 256, d_model) * 0.01)
        self.block = TinyDialecticalBlock(d_model, d_head, max_steps, halt_eps, opp_lambda)
        self.lm = nn.Linear(d_model, vocab_size)

    def forward(self, x, causal=True):
        """Compute next-token logits.

        Args:
            x: Long tensor of token ids, shape [B, T].
            causal: When True (default) an additive mask blocks attention to later
                positions, so the prediction at position t cannot see the token it is
                trained to predict. Pass False for unmasked (bidirectional) attention.

        Returns:
            ``(logits, metrics, extra_loss)`` with ``logits`` of shape [B, T, vocab_size].

        Raises:
            ValueError: If T exceeds the length of the positional table.
        """
        _, T = x.shape
        max_len = self.pos.size(1)
        if T > max_len:
            raise ValueError(f"sequence length {T} exceeds the positional table size {max_len}")

        h = self.emb(x) + self.pos[:, :T, :]

        mask = None
        if causal:
            # 0 on and below the diagonal, -inf strictly above it; shaped [1, 1, T, T]
            # to match the head's documented mask layout and broadcast over the batch.
            mask = torch.full((T, T), float('-inf'), device=h.device).triu(1).to(h.dtype).view(1, 1, T, T)

        h, metrics, extra_loss = self.block(h, mask)
        logits = self.lm(h)
        return logits, metrics, extra_loss


## Synthetic dataset: Balanced parentheses (Dyck-1) with distractors

In [None]:
# Character-level vocabulary: parentheses, filler letters, and space.
VOCAB = ['(', ')', 'a', 'b', 'c', ' ']
# Lookup tables between characters and integer ids (ids follow VOCAB order).
stoi = dict(zip(VOCAB, range(len(VOCAB))))
itos = dict(enumerate(VOCAB))
vocab_size = len(stoi)

def gen_balanced(n_pairs=4, fillers=True):
    """Return a random balanced-parentheses string containing ``n_pairs`` pairs.

    The string is built as ``'(' + inner + ')' + outer`` where the remaining
    ``n_pairs - 1`` pairs are split at random between ``inner`` and ``outer``.
    When ``fillers`` is true, after every character of the assembled string a
    character from ``'ab '`` is inserted with probability 0.2; because this pass
    runs at every recursion level, nested parts go through it more than once.
    """
    if n_pairs == 0:
        return ""

    n_inside = random.randint(0, n_pairs - 1)
    n_after = n_pairs - 1 - n_inside
    # Order matters for the random stream: inner part first, then the trailing part.
    nested = gen_balanced(n_inside, fillers)
    trailing = gen_balanced(n_after, fillers)
    text = f"({nested}){trailing}"

    if not fillers:
        return text

    pieces = []
    for ch in text:
        pieces.append(ch)
        if random.random() < 0.2:
            pieces.append(random.choice('ab '))
    return ''.join(pieces)

def make_sample(max_pairs=6, max_len=96):
    """Draw one (input, target) pair of id tensors for next-character prediction.

    A balanced string with between 1 and ``max_pairs`` pairs is generated and cut to
    at most ``max_len - 1`` characters. The target is the same string shifted left by
    one character with a trailing space appended.
    """
    n_pairs = random.randint(1, max_pairs)
    text = gen_balanced(n_pairs)[:max_len - 1]
    shifted = text[1:] + ' '
    inputs = torch.tensor([stoi[ch] for ch in text], dtype=torch.long)
    targets = torch.tensor([stoi[ch] for ch in shifted], dtype=torch.long)
    return inputs, targets

def batchify(B=64, max_len=96, device=device):
    """Sample ``B`` sequences and stack them into right-padded [B, T] id tensors.

    T is the longest input in the batch; shorter rows (inputs and targets alike) are
    padded with the id of the space character. Both tensors are moved to ``device``.
    """
    samples = [make_sample(max_len=max_len) for _ in range(B)]
    width = max(inp.size(0) for inp, _ in samples)
    pad_id = stoi[' ']

    X = torch.full((B, width), pad_id, dtype=torch.long)
    Y = torch.full((B, width), pad_id, dtype=torch.long)
    for row, (inp, tgt) in enumerate(samples):
        X[row, :inp.size(0)] = inp
        Y[row, :tgt.size(0)] = tgt
    return X.to(device), Y.to(device)


## Build model

In [None]:
# Instantiate the demo language model and move it to the selected device.
model = TinyDialecticalLM(
    vocab_size,
    d_model=128,
    d_head=64,
    max_steps=4,
    halt_eps=2e-3,
    opp_lambda=1e-3,
).to(device)
# Total number of parameter elements (bare expression).
sum(param.numel() for param in model.parameters())

## Train briefly

In [None]:
# Short training run: 320 AdamW steps on freshly sampled batches.
optim = torch.optim.AdamW(model.parameters(), lr=3e-3)
losses, avg_steps_hist = [], []
for step in range(320):
    X, Y = batchify(B=64, max_len=96)
    logits, metrics, extra = model(X)
    # Collapse batch and time so every position is one classification example.
    ce = F.cross_entropy(logits.flatten(0, 1), Y.flatten())
    # The regularizer is optimized but not logged: only the cross-entropy is recorded.
    loss = ce + extra

    optim.zero_grad()
    loss.backward()
    optim.step()

    losses.append(ce.item())
    avg_steps_hist.append(metrics['avg_steps'])
    if step % 20 == 0:
        print(f"step {step:03d} | loss {ce.item():.3f} | avg_steps {metrics['avg_steps']:.2f} | tension0 {metrics['tensions'][0]:.3f}")

plt.plot(losses)
plt.title('Training loss')
plt.show()

plt.plot(avg_steps_hist)
plt.title('Avg steps used')
plt.show()

## Inspect tension profile

In [None]:
# Evaluate one fresh batch without gradients and plot the head's tension metric.
X, Y = batchify(B=32, max_len=96)
with torch.no_grad():
    logits, metrics, _ = model(X)

print('Avg steps (approx):', metrics['avg_steps'])

# One value per executed recursion step, as reported by the head.
plt.plot(metrics['tensions'])
plt.title('Mean tension per recursion step')
plt.show()

## Qualitative generation
Greedy next-token generation from a seed.

In [None]:
# Lookup tables used by generate(). Despite its name, `stoi_inv` is simply the
# char -> id table `stoi`; `itos_inv` is the id -> char table derived from it.
# Both names are kept so existing references keep working.
stoi_inv = stoi; itos_inv = {i:ch for ch,i in stoi_inv.items()}

def generate(model, seed='((', max_new=40):
    """Greedily extend ``seed`` by ``max_new`` characters and return the full string.

    Args:
        model: Callable returning ``(logits, metrics, extra_loss)`` for a [1, T] id
            tensor, with ``logits`` of shape [1, T, vocab].
        seed: Non-empty prompt; characters outside the vocabulary map to space.
        max_new: Number of characters to append.

    Returns:
        The decoded prompt followed by the generated characters.

    Raises:
        ValueError: If ``seed`` is empty (there is no position to predict from).
    """
    if not seed:
        raise ValueError("seed must contain at least one character")

    pad_id = stoi_inv[' ']
    ids = [stoi_inv.get(ch, pad_id) for ch in seed]
    # Explicit dtype: token ids must be integers for the embedding lookup.
    x = torch.tensor([ids], dtype=torch.long, device=device)

    # If the model exposes a positional table as `pos` ([1, L, d]), feed at most the
    # last L tokens so long generations do not run past it.
    pos_table = getattr(model, 'pos', None)
    max_ctx = pos_table.size(1) if torch.is_tensor(pos_table) else None

    # Remember the caller's mode so generation does not leave the model in eval mode.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for _ in range(max_new):
                context = x if max_ctx is None else x[:, -max_ctx:]
                logits, _, _ = model(context)
                nxt = logits[0, -1].argmax().view(1, 1)
                x = torch.cat([x, nxt], dim=1)
    finally:
        model.train(was_training)

    return ''.join(itos_inv[int(i)] for i in x[0].tolist())

print(generate(model, '(()'))

---
**Tips:**
- Tweak `opp_lambda` (e.g., `5e-4` to `2e-3`) and `halt_eps` to adjust separation and compute.
- Increase `max_steps` to 5–6 to see more pronounced convergence behavior (cost rises for hard tokens).
- To compare, replace `TinyDialecticalBlock` with a vanilla attention block and compare the training loss; note that the vanilla block trades away interpretability, since it reports no tension or step-count metrics.