In [1]:
# =========================
# Colab-ready bAbI + Transformer + (Orthogonality + MI-difference regularizers)
# =========================

# If running in Colab, uncomment:
# !pip -q install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

import os, re, tarfile, random, math
from pathlib import Path
from collections import Counter
import urllib.request

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# -------------------------
# Reproducibility
# -------------------------
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# -------------------------
# Download bAbI
# -------------------------
DATA_ROOT = Path("./data_babi")
DATA_ROOT.mkdir(parents=True, exist_ok=True)

url = "https://s3.amazonaws.com/text-datasets/babi_tasks_1-20_v1-2.tar.gz"
tgz_path = DATA_ROOT / "babi_tasks_1-20_v1-2.tar.gz"
extract_dir = DATA_ROOT

if not tgz_path.exists():
    print("Downloading bAbI tarball...")
    urllib.request.urlretrieve(url, tgz_path)

print("Extracting...")
with tarfile.open(tgz_path, "r:gz") as tar:
    tar.extractall(DATA_ROOT)

# -------------------------
# bAbI parser
# -------------------------
# bAbI format:
# "1 Mary moved to the bathroom."
# ...
# "3 Where is Mary?\tbathroom\t1"
#
# We'll create samples: (story_sentences, question, answer)
# and convert to "context tokens + [SEP] + question tokens".

def tokenize(text: str):
    # Keep simple; split on non-letters/numbers, keep punctuation as separate tokens.
    # bAbI is small; this works well.
    text = text.lower()
    # separate punctuation
    text = re.sub(r"([?.!,])", r" \1 ", text)
    tokens = [t for t in text.split() if t.strip()]
    return tokens

def load_babi_qa(path: Path, max_story_sentences=50):
    samples = []
    story = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            # line starts with an index
            idx_str, rest = line.split(" ", 1)
            idx = int(idx_str)
            if idx == 1:
                story = []  # new story
            if "\t" in rest:
                # question line
                q, a, supporting = rest.split("\t")
                story_tokens = []
                # flatten story sentences into one token list with sentence separators
                # to keep it simple, we just concatenate with a special token <ssep>
                for sent in story[-max_story_sentences:]:
                    story_tokens += tokenize(sent) + ["<ssep>"]
                q_tokens = tokenize(q)
                a_token = a.lower().strip()
                samples.append((story_tokens, q_tokens, a_token))
            else:
                # statement line
                story.append(rest)
    return samples

# Choose a task for speed. Task 1 is commonly used.
TASK_ID = 1
LANG = "en"
train_file = extract_dir / f"tasks_1-20_v1-2/{LANG}/qa{TASK_ID}_single-supporting-fact_train.txt"
test_file  = extract_dir / f"tasks_1-20_v1-2/{LANG}/qa{TASK_ID}_single-supporting-fact_test.txt"

train_samples = load_babi_qa(train_file)
test_samples = load_babi_qa(test_file)

print("Train samples:", len(train_samples), "Test samples:", len(test_samples))
print("Example:", train_samples[0][:2], "answer:", train_samples[0][2])

# -------------------------
# Vocabulary + encoding
# -------------------------
SPECIALS = ["<pad>", "<unk>", "<sep>", "<ssep>"]
PAD, UNK, SEP, SSEP = SPECIALS

def build_vocab(samples, min_freq=1):
    counter = Counter()
    answers = Counter()
    for story_tokens, q_tokens, a in samples:
        counter.update(story_tokens)
        counter.update(q_tokens)
        answers[a] += 1
    itos = list(SPECIALS)
    for tok, c in counter.items():
        if c >= min_freq and tok not in itos:
            itos.append(tok)
    stoi = {t:i for i,t in enumerate(itos)}
    ans_itos = sorted(list(answers.keys()))
    ans_stoi = {a:i for i,a in enumerate(ans_itos)}
    return stoi, itos, ans_stoi, ans_itos

stoi, itos, ans_stoi, ans_itos = build_vocab(train_samples + test_samples, min_freq=1)
vocab_size = len(itos)
num_answers = len(ans_itos)
print("Vocab size:", vocab_size, "Num answers:", num_answers)

def encode_tokens(tokens, stoi):
    return [stoi.get(t, stoi[UNK]) for t in tokens]

# Pack input as: story + [SEP] + question
MAX_LEN = 180  # enough for task 1; increase if you switch tasks
def make_input(story_tokens, q_tokens):
    tokens = story_tokens + [SEP] + q_tokens
    ids = encode_tokens(tokens, stoi)
    if len(ids) > MAX_LEN:
        ids = ids[-MAX_LEN:]  # keep last tokens (question-related)
    return ids

class BabiQADataset(Dataset):
    def __init__(self, samples):
        self.samples = samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        story_tokens, q_tokens, ans = self.samples[idx]
        x = make_input(story_tokens, q_tokens)
        y = ans_stoi[ans]
        return torch.tensor(x, dtype=torch.long), torch.tensor(y, dtype=torch.long)

def collate(batch):
    xs, ys = zip(*batch)
    maxlen = max(x.size(0) for x in xs)
    x_pad = torch.full((len(xs), maxlen), fill_value=stoi[PAD], dtype=torch.long)
    attn_mask = torch.zeros((len(xs), maxlen), dtype=torch.bool)  # True for PAD positions
    for i, x in enumerate(xs):
        x_pad[i, :x.size(0)] = x
        attn_mask[i, x.size(0):] = True
    y = torch.stack(ys)
    return x_pad, attn_mask, y

train_ds = BabiQADataset(train_samples)
test_ds  = BabiQADataset(test_samples)

BATCH_SIZE = 64
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate)
test_loader  = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate)

# -------------------------
# Model: Transformer Encoder with explicit per-head outputs
# -------------------------

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=512):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x):
        # x: (B, T, D)
        T = x.size(1)
        return x + self.pe[:, :T]

class MultiHeadSelfAttentionExpose(nn.Module):
    """
    Wraps nn.MultiheadAttention but returns:
    - output: (T, B, D)
    - per_head_attn: (B, H, T, T)
    - per_head_value_out: (B, H, T, Dh)  (head-specific output before mixing)
    """
    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads

        # We'll implement QKV ourselves to easily expose head outputs.
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.o = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, key_padding_mask=None):
        # x: (T, B, D)
        T, B, D = x.shape
        qkv = self.qkv(x)  # (T, B, 3D)
        q, k, v = qkv.chunk(3, dim=-1)  # each (T, B, D)

        # reshape to heads: (B, H, T, Dh)
        def to_heads(t):
            t = t.permute(1, 0, 2).contiguous()  # (B, T, D)
            t = t.view(B, T, self.n_heads, self.d_head).transpose(1, 2)  # (B, H, T, Dh)
            return t

        qh = to_heads(q)
        kh = to_heads(k)
        vh = to_heads(v)

        # scaled dot-product attention
        scores = torch.matmul(qh, kh.transpose(-2, -1)) / math.sqrt(self.d_head)  # (B,H,T,T)

        if key_padding_mask is not None:
            # key_padding_mask: (B, T) True where PAD
            mask = key_padding_mask.unsqueeze(1).unsqueeze(2)  # (B,1,1,T)
            scores = scores.masked_fill(mask, float("-inf"))

        attn = F.softmax(scores, dim=-1)  # (B,H,T,T)
        attn = self.dropout(attn)

        head_out = torch.matmul(attn, vh)  # (B,H,T,Dh)
        # merge heads
        merged = head_out.transpose(1, 2).contiguous().view(B, T, D)  # (B,T,D)
        out = self.o(merged)  # (B,T,D)

        return out.permute(1, 0, 2).contiguous(), attn, head_out

class EncoderBlock(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = MultiHeadSelfAttentionExpose(d_model, n_heads, dropout=dropout)
        self.ln2 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, x, key_padding_mask=None):
        # pre-norm
        h = self.ln1(x)
        attn_out, attn_w, head_out = self.attn(h, key_padding_mask=key_padding_mask)
        x = x + attn_out
        x = x + self.ff(self.ln2(x))
        return x, attn_w, head_out

class BabiTransformerQA(nn.Module):
    def __init__(self, vocab_size, num_answers, d_model=128, n_heads=8, n_layers=2, d_ff=256, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos = PositionalEncoding(d_model, max_len=512)
        self.dropout = nn.Dropout(dropout)
        self.layers = nn.ModuleList([EncoderBlock(d_model, n_heads, d_ff, dropout=dropout) for _ in range(n_layers)])
        self.ln_final = nn.LayerNorm(d_model)
        self.classifier = nn.Linear(d_model, num_answers)

    def forward(self, x, key_padding_mask):
        """
        x: (B, T)
        key_padding_mask: (B, T) True where PAD
        Returns:
          logits: (B, num_answers)
          head_reprs: (B, H, D) pooled per-head representations (for MI/ortho losses)
        """
        B, T = x.shape
        h = self.embed(x) * math.sqrt(self.d_model)  # (B,T,D)
        h = self.pos(h)
        h = self.dropout(h)
        h = h.transpose(0, 1).contiguous()  # (T,B,D)

        all_head_out = []  # list of (B,H,T,Dh) per layer
        for layer in self.layers:
            h, attn_w, head_out = layer(h, key_padding_mask=key_padding_mask)
            all_head_out.append(head_out)

        h = self.ln_final(h)  # (T,B,D)
        # Pool final token representation (last non-pad) OR simply mean-pool non-pad
        h_bt = h.transpose(0,1)  # (B,T,D)

        not_pad = (~key_padding_mask).float().unsqueeze(-1)  # (B,T,1)
        pooled = (h_bt * not_pad).sum(dim=1) / (not_pad.sum(dim=1).clamp(min=1.0))  # (B,D)
        logits = self.classifier(pooled)

        # Build per-head representations from the LAST layer head outputs:
        # head_out: (B,H,T,Dh). We'll mean-pool over tokens (non-pad), then project to D.
        last_head = all_head_out[-1]  # (B,H,T,Dh)
        mask_h = (~key_padding_mask).float().unsqueeze(1).unsqueeze(-1)  # (B,1,T,1)
        head_pooled = (last_head * mask_h).sum(dim=2) / (mask_h.sum(dim=2).clamp(min=1.0))  # (B,H,Dh)
        # Expand to D by simple linear mapping per head (shared)
        # Option: just pad to D via a learned projection:
        head_proj = F.pad(head_pooled, (0, self.d_model - head_pooled.size(-1)))  # (B,H,D) if Dh < D
        return logits, head_proj

# -------------------------
# Regularizers
# -------------------------

def orthogonality_loss(head_repr):
    """
    head_repr: (B, H, D)
    Encourage different heads to be orthogonal in representation space.
    Compute Gram matrix per sample: G = H x H via normalized dot products.
    Penalize off-diagonal energy.
    """
    B, H, D = head_repr.shape
    z = F.normalize(head_repr, dim=-1)  # (B,H,D)
    G = torch.matmul(z, z.transpose(1,2))  # (B,H,H)
    I = torch.eye(H, device=head_repr.device).unsqueeze(0)  # (1,H,H)
    off_diag = G - I
    return (off_diag ** 2).mean()

def info_nce_mi_proxy_loss(head_repr, temperature=0.2):
    """
    Approximate mutual information between pairs of heads with an InfoNCE-style objective.
    We then MINIMIZE this proxy to "maximize differences" (i.e., reduce dependence).

    For each pair of heads (i,j):
      positives: (h_i[b], h_j[b]) same sample b
      negatives: (h_i[b], h_j[b']) with shuffled b'
    We compute an InfoNCE loss that would normally maximize alignment/MI;
    minimizing it discourages shared info.

    head_repr: (B,H,D)
    """
    B, H, D = head_repr.shape
    z = F.normalize(head_repr, dim=-1)  # (B,H,D)

    total = 0.0
    count = 0
    for i in range(H):
        zi = z[:, i, :]  # (B,D)
        for j in range(i+1, H):
            zj = z[:, j, :]  # (B,D)
            # logits: (B,B) similarity between zi[b] and zj[b']
            logits = torch.matmul(zi, zj.t()) / temperature  # (B,B)
            labels = torch.arange(B, device=head_repr.device)
            # Standard InfoNCE (cross-entropy) would encourage zi[b] close to zj[b]
            # We MINIMIZE this to reduce dependence => "maximize differences"
            total = total + F.cross_entropy(logits, labels)
            count += 1
    return total / max(count, 1)

# -------------------------
# Train / Eval
# -------------------------

def evaluate(model, loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x, pad_mask, y in loader:
            x, pad_mask, y = x.to(device), pad_mask.to(device), y.to(device)
            logits, _ = model(x, pad_mask)
            pred = logits.argmax(dim=-1)
            correct += (pred == y).sum().item()
            total += y.numel()
    return correct / max(total, 1)

# -------------------------
# Hyperparameters (small + fast)
# -------------------------
d_model = 192
n_heads = 16
n_layers = 3
d_ff = 256
dropout = 0.1

lambda_ortho = 0.1    # weight for orthogonality constraint
lambda_mi = 0.05      # weight for MI-difference constraint (minimize MI proxy)

model = BabiTransformerQA(
    vocab_size=vocab_size,
    num_answers=num_answers,
    d_model=d_model,
    n_heads=n_heads,
    n_layers=n_layers,
    d_ff=d_ff,
    dropout=dropout
).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-2)

EPOCHS = 8  # Task 1 converges quickly
print_every = 1

for epoch in range(1, EPOCHS + 1):
    model.train()
    running = 0.0

    for x, pad_mask, y in train_loader:
        x, pad_mask, y = x.to(device), pad_mask.to(device), y.to(device)

        optimizer.zero_grad(set_to_none=True)
        logits, head_repr = model(x, pad_mask)

        ce = F.cross_entropy(logits, y)
        ortho = orthogonality_loss(head_repr)
        mi_proxy = info_nce_mi_proxy_loss(head_repr, temperature=0.2)

        loss = ce + lambda_ortho * ortho + lambda_mi * mi_proxy
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        running += loss.item()

    if epoch % print_every == 0:
        train_acc = evaluate(model, train_loader)
        test_acc = evaluate(model, test_loader)
        avg_loss = running / max(len(train_loader), 1)
        print(f"Epoch {epoch:02d} | loss {avg_loss:.4f} | train_acc {train_acc:.3f} | test_acc {test_acc:.3f}")

print("Done.")
print("Final test accuracy:", evaluate(model, test_loader))


Device: cpu
Downloading bAbI tarball...
Extracting...


  tar.extractall(DATA_ROOT)


Train samples: 1000 Test samples: 1000
Example: (['mary', 'moved', 'to', 'the', 'bathroom', '.', '<ssep>', 'john', 'went', 'to', 'the', 'hallway', '.', '<ssep>'], ['where', 'is', 'mary', '?']) answer: bathroom
Vocab size: 25 Num answers: 6
Epoch 01 | loss 2.0082 | train_acc 0.244 | test_acc 0.218
Epoch 02 | loss 1.9787 | train_acc 0.327 | test_acc 0.290
Epoch 03 | loss 1.9503 | train_acc 0.370 | test_acc 0.339
Epoch 04 | loss 1.8776 | train_acc 0.375 | test_acc 0.359
Epoch 05 | loss 1.7294 | train_acc 0.400 | test_acc 0.363
Epoch 06 | loss 1.5675 | train_acc 0.426 | test_acc 0.395
Epoch 07 | loss 1.4549 | train_acc 0.442 | test_acc 0.396
Epoch 08 | loss 1.3866 | train_acc 0.481 | test_acc 0.409
Done.
Final test accuracy: 0.409
