# Dialectical (Recursive Dual) Attention — Colab Demo

This notebook implements a minimal **recursive dual attention head** ("dialectical attention") and plugs it into a tiny Transformer.

It trains on a small **balanced-parentheses next-token prediction** task, then reports:
- mean **tension** per recursion step (how opposed the two summaries are),
- average **steps used** by the per-token halting loop.

You can use a GPU runtime in Colab (Runtime → Change runtime type → GPU) for faster training.


In [None]:
import math, random, os, sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Bare expression: as the last line of a notebook cell it echoes the chosen device.
device

## Dialectical Attention Head (single head)

Key ideas:
- One attention map, **two opposed value channels** (pos/neg).
- Compute two summaries, measure **tension** (how opposed),
- Propose a small **synthesis** update and **gate** it by tension,
- **Recurse per token** until change is small or a max step is hit.

In [None]:
class DialecticalHead(nn.Module):
    """Single attention head with two opposed value channels and a gated refinement loop.

    The head attends once, builds a "positive" and a "negative" summary from two
    value projections, measures how opposed they are (tension), and then
    iteratively refines a per-token state ``z`` with tension-gated updates.
    Each token stops being updated once its relative change drops below
    ``halt_eps`` or after ``max_steps`` iterations.
    """

    def __init__(self, d_model: int, d_head: int, max_steps: int = 3, halt_eps: float = 1e-3):
        super().__init__()
        self.q = nn.Linear(d_model, d_head, bias=False)
        self.k = nn.Linear(d_model, d_head, bias=False)
        self.v = nn.Linear(d_model, d_head, bias=False)
        # Opposed value channels (+ / -)
        self.v_pos = nn.Linear(d_head, d_head, bias=False)
        self.v_neg = nn.Linear(d_head, d_head, bias=False)
        # Tiny synthesis block: maps [up, un, z] to a proposed update of z
        self.synth = nn.Linear(3 * d_head, d_head)
        # Learned biases to tilt pos/neg routing (see NOTE(review) in forward)
        self.b_pos = nn.Parameter(torch.zeros(1))
        self.b_neg = nn.Parameter(torch.zeros(1))
        # Gate to scale updates
        self.gate = nn.Linear(d_head, 1)
        self.max_steps = max_steps
        self.halt_eps = halt_eps

    def forward(self, x, mask=None):
        """
        x: [B, T, d_model]
        mask: [B, 1, T, T] (0 for keep, -inf for block); added to the attention logits
        returns: (z, metrics)
            z: [B, T, d_head] refined per-token state
            metrics['avg_steps']: mean number of update steps applied per token
            metrics['tensions']: one mean-tension value per executed step
        """
        B, T, _ = x.shape
        q = self.q(x)
        k = self.k(x)
        v = self.v(x)
        v_pos = self.v_pos(v)
        v_neg = self.v_neg(v)

        logits = torch.matmul(q, k.transpose(-1, -2)) / (k.size(-1) ** 0.5)
        if mask is not None:
            logits = logits + mask.squeeze(1)
        # NOTE(review): b_pos / b_neg are scalars added to every logit of a row.
        # Softmax is invariant to such a shift, so attn_pos and attn_neg are the
        # same map and both biases get zero gradient. Left unchanged here so the
        # parameter set (and any saved state_dict) stays compatible.
        attn_pos = F.softmax(logits + self.b_pos, dim=-1)
        attn_neg = F.softmax(logits + self.b_neg, dim=-1)

        # The summaries and their tension depend only on the attention maps and
        # value projections, none of which change inside the loop, so they are
        # computed once instead of once per step.
        up = torch.matmul(attn_pos, v_pos)
        un = torch.matmul(attn_neg, v_neg)
        # Tension proxy: negative cosine (higher when opposed)
        cos = F.cosine_similarity(up, un, dim=-1, eps=1e-6).unsqueeze(-1)
        tension = torch.sigmoid(-cos)
        mean_tension = tension.mean().item()

        z = q
        active = torch.ones(B, T, dtype=torch.bool, device=x.device)
        steps_used = torch.zeros(B, T, device=x.device)
        mean_tensions = []

        for _ in range(self.max_steps):
            # Every token still active at the start of a step receives one update.
            steps_used += active.float()
            mean_tensions.append(mean_tension)

            proposal = F.silu(self.synth(torch.cat([up, un, z], dim=-1)))
            gate = torch.sigmoid(self.gate(z)) * tension
            z_new = z + gate * proposal

            delta = (z_new - z).norm(dim=-1) / (z.norm(dim=-1) + 1e-6)
            # Halted tokens keep their state; only active tokens take the update.
            z = torch.where(active.unsqueeze(-1), z_new, z)
            # A token halts permanently once its relative change falls below halt_eps.
            active = active & ~(delta < self.halt_eps)

            if not active.any():
                break

        metrics = {
            'avg_steps': steps_used.mean().item(),
            'tensions': mean_tensions,
        }
        return z, metrics


## Tiny Transformer Block using the Dialectical Head
We keep it minimal: one dialectical head + residual + feedforward.

In [None]:
class TinyDialecticalBlock(nn.Module):
    """Residual block: layer-normed dialectical head, then a layer-normed feedforward."""

    def __init__(self, d_model=128, d_head=64, max_steps=3, halt_eps=1e-3, ff_mult=2):
        super().__init__()
        hidden = ff_mult * d_model
        self.attn = DialecticalHead(d_model, d_head, max_steps, halt_eps)
        # Projects the head output (d_head) back to the residual width (d_model).
        self.out = nn.Linear(d_head, d_model)
        self.ln1 = nn.LayerNorm(d_model)
        ff_layers = [
            nn.LayerNorm(d_model),
            nn.Linear(d_model, hidden),
            nn.SiLU(),
            nn.Linear(hidden, d_model),
        ]
        self.ff = nn.Sequential(*ff_layers)

    def forward(self, x, mask=None):
        """Apply both residual sub-layers; return the new activations and the head's metrics."""
        head_out, metrics = self.attn(self.ln1(x), mask)
        after_attn = x + self.out(head_out)
        after_ff = after_attn + self.ff(after_attn)
        return after_ff, metrics


## Synthetic dataset: Balanced parentheses next-token prediction
- Vocabulary: `"()"` plus a few filler symbols.
- Generate balanced strings (Dyck(1)) of random lengths.
- Task: next token prediction (language modeling).

In [None]:
# Character-level vocabulary: the two parentheses, filler letters, and space.
VOCAB = [*"()abc "]  # include some distractors and space
# Character -> integer id, in order of appearance in VOCAB.
stoi = dict(zip(VOCAB, range(len(VOCAB))))
# Integer id -> character (inverse of stoi).
itos = dict(enumerate(VOCAB))
vocab_size = len(VOCAB)

def _dyck_skeleton(n_pairs):
    """Return a random balanced-parentheses string with exactly ``n_pairs`` pairs, no fillers."""
    if n_pairs == 0:
        return ""
    # Split the remaining pairs between the inside of the first pair and what follows it.
    left = random.randint(0, n_pairs - 1)
    right = n_pairs - 1 - left
    return '(' + _dyck_skeleton(left) + ')' + _dyck_skeleton(right)


def gen_balanced(n_pairs=4, fillers=True, filler_prob=0.2):
    """Generate a random balanced-parentheses string, optionally with filler characters.

    Args:
        n_pairs: number of '(' ... ')' pairs in the result; 0 yields "".
        fillers: when true, insert filler characters from 'ab ' between symbols.
        filler_prob: probability of inserting one filler after each parenthesis.

    Returns:
        The generated string. Removing the filler characters always leaves a
        balanced string with exactly ``n_pairs`` pairs.

    The filler pass runs exactly once over the finished skeleton. Running it at
    every recursion level would re-sprinkle inner characters (and earlier
    fillers) once per enclosing level, so filler density would grow with depth.
    """
    s = _dyck_skeleton(n_pairs)
    if not fillers:
        return s
    out = []
    for ch in s:
        out.append(ch)
        if random.random() < filler_prob:
            out.append(random.choice('ab '))
    return ''.join(out)

def make_sample(max_pairs=6, max_len=64):
    """Build one (input, target) pair of token-id tensors for next-token prediction.

    A balanced string with between 1 and ``max_pairs`` pairs is generated and
    cut to at most ``max_len - 1`` characters. The target is the input shifted
    left by one character, with a trailing space as the final target.
    """
    n_pairs = random.randint(1, max_pairs)
    text = gen_balanced(n_pairs)[:max_len - 1]  # leave room for next-token target
    shifted = text[1:] + ' '  # next-token target
    input_ids = [stoi[ch] for ch in text]
    target_ids = [stoi[ch] for ch in shifted]
    x = torch.tensor(input_ids, dtype=torch.long)
    y = torch.tensor(target_ids, dtype=torch.long)
    return x, y

def batchify(B=32, max_len=64, ignore_index=-100):
    """Sample ``B`` sequences and right-pad them into rectangular batch tensors.

    Args:
        B: number of sequences in the batch.
        max_len: forwarded to ``make_sample`` as the per-sample length cap.
        ignore_index: value written into padded target positions. The default
            of -100 matches the default ``ignore_index`` of ``F.cross_entropy``,
            so padding is excluded from the loss instead of being scored as a
            real space token.

    Returns:
        (X, Y) long tensors of shape [B, T] on ``device``, where T is the
        longest sample in the batch. X is padded with the id of ' '.
    """
    samples = [make_sample(max_len=max_len) for _ in range(B)]
    T = max(x.size(0) for x, _ in samples)
    X = torch.full((B, T), stoi[' '], dtype=torch.long)
    Y = torch.full((B, T), ignore_index, dtype=torch.long)
    for i, (x, y) in enumerate(samples):
        X[i, :x.size(0)] = x
        Y[i, :y.size(0)] = y
    return X.to(device), Y.to(device)


## Tiny model wrapper
Embedding + positional encoding, one dialectical block, LM head.

In [None]:
class TinyDialecticalLM(nn.Module):
    """Tiny causal language model: token + learned position embeddings, one dialectical block, LM head."""

    def __init__(self, vocab_size, d_model=128, d_head=64, max_steps=3, halt_eps=1e-3):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, d_model)
        self.pos = nn.Parameter(torch.randn(1, 256, d_model) * 0.01)  # max length 256
        self.block = TinyDialecticalBlock(d_model, d_head, max_steps, halt_eps)
        self.lm = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        """
        x: [B, T] token ids, with T no larger than the positional table
        returns: logits [B, T, vocab_size], metrics from the dialectical head
        raises: ValueError if T exceeds the number of learned positions
        """
        B, T = x.shape
        max_len = self.pos.size(1)
        if T > max_len:
            raise ValueError(f"sequence length {T} exceeds maximum supported length {max_len}")
        h = self.emb(x) + self.pos[:, :T, :]
        # Causal mask: 0 on and below the diagonal, -inf above it. Without it,
        # position t can attend to token t+1, which is exactly its prediction
        # target, so the next-token task leaks its own answer.
        causal = torch.full((T, T), float('-inf'), device=x.device, dtype=h.dtype).triu(diagonal=1)
        h, metrics = self.block(h, causal.view(1, 1, T, T))
        logits = self.lm(h)
        return logits, metrics

# Build the language model with its default hyperparameters and move it to the selected device.
model = TinyDialecticalLM(vocab_size).to(device)
# Total number of parameter elements; as a bare last expression it is echoed by the notebook cell.
sum(p.numel() for p in model.parameters())

## Train briefly
Small run just to see loss go down and produce nontrivial tension/steps.

In [None]:
# Short AdamW run on freshly sampled batches; records the loss at every step.
optim = torch.optim.AdamW(model.parameters(), lr=3e-3)
losses = []
for step in range(300):  # keep modest for Colab
    X, Y = batchify(B=64, max_len=96)
    logits, metrics = model(X)

    # Flatten [B, T, V] logits and [B, T] targets so each position is one example.
    n_classes = logits.size(-1)
    loss = F.cross_entropy(logits.reshape(-1, n_classes), Y.reshape(-1))

    optim.zero_grad()
    loss.backward()
    optim.step()

    losses.append(loss.item())
    if step % 20 == 0:
        print(f"step {step:03d} loss {loss.item():.3f} avg_steps {metrics['avg_steps']:.2f}")

plt.plot(losses)
plt.title('Training loss')
plt.show()

## Inspect tension & steps on a batch
We run the model on one fresh batch, print the average number of recursion steps, and plot the mean tension at each recursion step.

In [None]:
# Evaluate one fresh batch without tracking gradients and look at the head's metrics.
X, Y = batchify(B=32, max_len=96)
with torch.no_grad():
    logits, metrics = model(X)

print('Avg steps (approx):', metrics['avg_steps'])

# One mean-tension value per executed recursion step.
plt.plot(metrics['tensions'])
plt.title('Mean tension per recursion step')
plt.show()


## Qualitative sample
Greedy next-token generation from a seed.

In [None]:
@torch.no_grad()  # inference only: avoid building an autograd graph on every step
def generate(seed='((', max_new=40):
    """Greedily extend ``seed`` by ``max_new`` characters using the global ``model``.

    Args:
        seed: non-empty prompt; characters outside the vocabulary map to ' '.
        max_new: number of characters to append.

    Returns:
        The seed followed by the generated characters, as one string.

    Raises:
        ValueError: if ``seed`` is empty (there is no position to predict from).
    """
    if not seed:
        raise ValueError("seed must contain at least one character")
    # The model has a fixed-size positional table; feed it only the most recent
    # tokens so long generations do not exceed it.
    max_ctx = model.pos.size(1)
    x = torch.tensor(
        [[stoi.get(ch, stoi[' ']) for ch in seed]],
        dtype=torch.long,
        device=device,
    )
    for _ in range(max_new):
        logits, _metrics = model(x[:, -max_ctx:])
        nxt = logits[0, -1].argmax().view(1, 1)
        x = torch.cat([x, nxt], dim=1)
    return ''.join(itos[i] for i in x[0].tolist())

print(generate('(()'))


----
**Notes:**
- To compare, you can swap `TinyDialecticalBlock` with a vanilla attention block and observe differences in tension/steps (the vanilla one will have none).
- You can also log how many tokens halt after 1, 2, or 3 steps by instrumenting `steps_used` more precisely.
- For factual tasks, you can use the *tension/steps* signal to trigger retrieval/verification (not shown here).