In [2]:
# ============================================================
# Vanilla bAbI (Task 1) + Transformer Encoder (NO constraints)
# Colab-ready: download + parse bAbI, train, evaluate
# ============================================================

import os, re, tarfile, random, math
from pathlib import Path
from collections import Counter
import urllib.request

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# -------------------------
# Reproducibility
# -------------------------
SEED = 42
# Seed Python's RNG and PyTorch's generators with the same fixed value so
# repeated runs start from the same random state.
random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Use the GPU when PyTorch reports one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

# -------------------------
# Download bAbI
# -------------------------
DATA_ROOT = Path("./data_babi")
DATA_ROOT.mkdir(parents=True, exist_ok=True)

url = "https://s3.amazonaws.com/text-datasets/babi_tasks_1-20_v1-2.tar.gz"
tgz_path = DATA_ROOT / "babi_tasks_1-20_v1-2.tar.gz"
extract_dir = DATA_ROOT

if not tgz_path.exists():
    print("Downloading bAbI...")
    # Download under a temporary name and rename only on success. Otherwise an
    # interrupted download leaves a truncated archive at tgz_path, and the
    # exists() check above would treat it as complete on every later run.
    partial_path = tgz_path.with_name(tgz_path.name + ".part")
    urllib.request.urlretrieve(url, partial_path)
    os.replace(partial_path, tgz_path)

print("Extracting bAbI...")
with tarfile.open(tgz_path, "r:gz") as tar:
    # Interpreters that support extraction filters warn when none is given;
    # the "data" filter also rejects members that would escape DATA_ROOT.
    # Older interpreters have no `filter` argument, so fall back to the
    # plain call there.
    if hasattr(tarfile, "data_filter"):
        tar.extractall(DATA_ROOT, filter="data")
    else:
        tar.extractall(DATA_ROOT)

# -------------------------
# bAbI parser (Task 1 default)
# -------------------------
def tokenize(text: str):
    """Lower-case *text* and split it into word and punctuation tokens.

    Each of ``?``, ``.``, ``!`` and ``,`` is padded with spaces so it comes
    out as its own token; everything else is split on whitespace.
    """
    spaced = re.sub(r"([?.!,])", r" \1 ", text.lower())
    # str.split() with no argument never yields empty or blank tokens.
    return spaced.split()

def load_babi_qa(path: Path, max_story_sentences=50):
    """Parse a bAbI task file into ``(story_tokens, question_tokens, answer)`` samples.

    Every non-blank line has the form ``"<idx> <text>"``. A line whose index is
    1 starts a new story. A line whose text contains a tab is a question with
    three tab-separated fields (question, answer, and a third field that is
    not used here); any other line is a statement appended to the current
    story.

    Args:
        path: File to read (decoded as UTF-8).
        max_story_sentences: Keep only this many of the most recent story
            sentences for each question. ``None`` keeps the whole story;
            zero or a negative value keeps no story sentences.

    Returns:
        A list of ``(story_tokens, q_tokens, a_token)`` tuples. ``story_tokens``
        holds the tokenized sentences, each followed by ``"<ssep>"``;
        ``a_token`` is the lower-cased, stripped answer string.

    Raises:
        ValueError: If a line has no space after its index, the index is not
            an integer, or a question line does not have exactly three
            tab-separated fields.
    """
    samples = []
    story = []
    with open(path, "r", encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue
            idx_str, rest = line.split(" ", 1)
            if int(idx_str) == 1:
                story = []
            if "\t" not in rest:
                story.append(rest)
                continue

            q, a, _supporting = rest.split("\t")

            # A bare story[-n:] is wrong for n == 0: story[-0:] is the whole
            # list, not an empty one, so handle the limit explicitly.
            if max_story_sentences is None:
                recent = story
            elif max_story_sentences > 0:
                recent = story[-max_story_sentences:]
            else:
                recent = []

            story_tokens = []
            for sent in recent:
                story_tokens.extend(tokenize(sent))
                story_tokens.append("<ssep>")
            samples.append((story_tokens, tokenize(q), a.lower().strip()))
    return samples

TASK_ID = 1
LANG = "en"


def _task_file(split: str) -> Path:
    """Return the bAbI file for ``TASK_ID``/``LANG`` and *split* ("train" or "test").

    The task's descriptive name is part of the filename (for example
    ``qa1_single-supporting-fact_train.txt``), so it is resolved with a glob
    instead of being hard-coded; changing ``TASK_ID`` then just works.

    Raises:
        FileNotFoundError: If no file matches under the extracted dataset.
    """
    task_dir = extract_dir / "tasks_1-20_v1-2" / LANG
    pattern = f"qa{TASK_ID}_*_{split}.txt"
    matches = sorted(task_dir.glob(pattern))
    if not matches:
        raise FileNotFoundError(f"No bAbI file matching {pattern} in {task_dir}")
    return matches[0]


train_file = _task_file("train")
test_file = _task_file("test")

train_samples = load_babi_qa(train_file)
test_samples  = load_babi_qa(test_file)

print("Train:", len(train_samples), "Test:", len(test_samples))
print("Example answer:", train_samples[0][2])

# -------------------------
# Vocabulary + encoding
# -------------------------
# Reserved tokens; they always occupy the first vocabulary ids, in this order.
SPECIALS = ["<pad>", "<unk>", "<sep>", "<ssep>"]
PAD, UNK, SEP, SSEP = SPECIALS


def build_vocab(samples, min_freq=1):
    """Build the token vocabulary and the answer-class vocabulary.

    Args:
        samples: Iterable of ``(story_tokens, q_tokens, answer)`` tuples.
        min_freq: Minimum number of occurrences (story and question tokens
            combined) a token needs to enter the vocabulary.

    Returns:
        ``(stoi, itos, ans_stoi, ans_itos)``. ``itos`` starts with ``SPECIALS``
        followed by the remaining tokens in first-seen order; ``ans_itos`` is
        the sorted list of distinct answers. The ``*_stoi`` dicts are the
        inverse mappings.
    """
    token_counts = Counter()
    answers = set()
    for story_tokens, q_tokens, a in samples:
        token_counts.update(story_tokens)
        token_counts.update(q_tokens)
        answers.add(a)

    # Counter keys are already unique, so the only possible duplicates are the
    # reserved tokens. A set makes that check O(1) per token instead of a
    # scan over the growing vocabulary list.
    reserved = set(SPECIALS)
    itos = list(SPECIALS)
    itos.extend(
        tok for tok, c in token_counts.items()
        if c >= min_freq and tok not in reserved
    )
    stoi = {t: i for i, t in enumerate(itos)}

    ans_itos = sorted(answers)
    ans_stoi = {a: i for i, a in enumerate(ans_itos)}
    return stoi, itos, ans_stoi, ans_itos

# NOTE(review): the vocabulary and the answer classes are built from the train
# and test samples together, so every test token and answer has an id.
stoi, itos, ans_stoi, ans_itos = build_vocab([*train_samples, *test_samples])
vocab_size, num_answers = len(itos), len(ans_itos)
print("Vocab:", vocab_size, "Answers:", num_answers)

def encode_tokens(tokens):
    """Map each token to its vocabulary id, using the ``<unk>`` id for unknown tokens."""
    unk_id = stoi[UNK]
    return [stoi.get(tok, unk_id) for tok in tokens]

MAX_LEN = 180


def make_input(story_tokens, q_tokens):
    """Encode ``story + <sep> + question`` as a list of at most ``MAX_LEN`` ids.

    When the sequence is too long, the LAST ``MAX_LEN`` ids are kept, so the
    question and the most recent story tokens survive truncation.
    """
    ids = encode_tokens([*story_tokens, SEP, *q_tokens])
    return ids[-MAX_LEN:] if len(ids) > MAX_LEN else ids

class BabiQADataset(Dataset):
    """Map-style dataset over ``(story_tokens, q_tokens, answer)`` samples.

    Indexing returns ``(input_ids, answer_id)`` as two ``torch.long`` tensors:
    a 1-D tensor of token ids and a 0-D tensor holding the answer class.
    """

    def __init__(self, samples):
        self.samples = samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        story_tokens, q_tokens, ans = self.samples[idx]
        token_ids = make_input(story_tokens, q_tokens)
        label = ans_stoi[ans]
        return (
            torch.tensor(token_ids, dtype=torch.long),
            torch.tensor(label, dtype=torch.long),
        )

def collate(batch):
    """Pad a list of ``(input_ids, answer_id)`` pairs into batch tensors.

    Returns:
        ``(x_pad, pad_mask, y)``: ``x_pad`` is a ``(B, T)`` long tensor
        right-padded with the ``<pad>`` id, where ``T`` is the longest
        sequence in the batch; ``pad_mask`` is a ``(B, T)`` bool tensor that
        is True at padded positions; ``y`` stacks the answer ids into ``(B,)``.
    """
    xs, ys = zip(*batch)
    lengths = torch.tensor([seq.size(0) for seq in xs])
    width = int(lengths.max())

    x_pad = torch.full((len(xs), width), fill_value=stoi[PAD], dtype=torch.long)
    for row, seq in zip(x_pad, xs):
        # Each `row` is a view into x_pad, so this writes in place.
        row[: seq.size(0)] = seq

    # Position j of row i is padding exactly when j >= length of sequence i.
    pad_mask = torch.arange(width).unsqueeze(0) >= lengths.unsqueeze(1)
    return x_pad, pad_mask, torch.stack(ys)

BATCH_SIZE = 64
# Options shared by both loaders; only the training loader shuffles.
_loader_opts = dict(batch_size=BATCH_SIZE, collate_fn=collate)
train_loader = DataLoader(BabiQADataset(train_samples), shuffle=True, **_loader_opts)
test_loader = DataLoader(BabiQADataset(test_samples), shuffle=False, **_loader_opts)

# -------------------------
# Vanilla Transformer Encoder (no head exposure, no constraints)
# -------------------------
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first input.

    Even feature columns hold ``sin(pos * freq)`` and odd columns hold
    ``cos(pos * freq)``. The table is precomputed for ``max_len`` positions and
    stored as a (non-trainable) buffer of shape ``(1, max_len, d_model)``.

    Args:
        d_model: Feature dimension; both even and odd values are supported.
        max_len: Number of positions to precompute.
    """

    def __init__(self, d_model, max_len=512):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = pos * div  # (max_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer cosine column than sine
        # columns, so trim the angle table to the number of odd columns.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer("pe", pe.unsqueeze(0))  # (1,max_len,D)

    def forward(self, x):
        """Return ``x + pe`` for the first ``x.size(1)`` positions; ``x`` is ``(B, T, D)``."""
        return x + self.pe[:, :x.size(1)]

class EncoderBlock(nn.Module):
    """One Transformer encoder layer with LayerNorm applied before each sub-layer.

    Both the self-attention and the feed-forward sub-layer are wrapped in a
    residual connection. Inputs and outputs are sequence-first ``(T, B, D)``
    because the attention module is built with ``batch_first=False``.
    """

    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=False)
        self.ln2 = nn.LayerNorm(d_model)
        feed_forward = [
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        ]
        self.ff = nn.Sequential(*feed_forward)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, key_padding_mask=None):
        """Apply the block to ``x`` of shape ``(T, B, D)``.

        ``key_padding_mask`` is passed straight to the attention module.
        """
        normed = self.ln1(x)
        attended = self.attn(
            normed, normed, normed,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )[0]
        x = x + self.drop(attended)
        return x + self.ff(self.ln2(x))

class VanillaBabiTransformerQA(nn.Module):
    """Transformer-encoder classifier over a fixed set of answer classes.

    Pipeline: token embedding (scaled by ``sqrt(d_model)``) -> sinusoidal
    positions -> dropout -> ``n_layers`` encoder blocks -> final LayerNorm ->
    mean over non-padded positions -> linear layer producing answer logits.
    """

    def __init__(self, vocab_size, num_answers, d_model=128, n_heads=8, n_layers=2, d_ff=256, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos = PositionalEncoding(d_model, max_len=512)
        self.drop = nn.Dropout(dropout)
        blocks = [EncoderBlock(d_model, n_heads, d_ff, dropout=dropout) for _ in range(n_layers)]
        self.layers = nn.ModuleList(blocks)
        self.ln_final = nn.LayerNorm(d_model)
        self.cls = nn.Linear(d_model, num_answers)

    def forward(self, x, key_padding_mask):
        """Return answer logits of shape ``(B, num_answers)``.

        Args:
            x: ``(B, T)`` token ids.
            key_padding_mask: ``(B, T)`` bool tensor, True where ``x`` is padding.
        """
        scaled = self.embed(x) * math.sqrt(self.d_model)   # (B,T,D)
        h = self.drop(self.pos(scaled))
        h = h.transpose(0, 1).contiguous()                 # (T,B,D) for the encoder blocks

        for block in self.layers:
            h = block(h, key_padding_mask=key_padding_mask)

        h = self.ln_final(h).transpose(0, 1)               # back to (B,T,D)

        # Mean-pool over real tokens only: zero the padded positions, then
        # divide by the count of real tokens (clamped to avoid dividing by 0).
        valid = (~key_padding_mask).float().unsqueeze(-1)  # (B,T,1)
        summed = (h * valid).sum(dim=1)
        counts = valid.sum(dim=1).clamp(min=1.0)
        return self.cls(summed / counts)

# -------------------------
# Train / eval
# -------------------------
@torch.no_grad()
def evaluate(model, loader):
    """Return the accuracy of *model* over *loader* (0 when the loader is empty).

    Puts the model into eval mode and leaves it there; gradients are disabled
    for the whole call.
    """
    model.eval()
    n_correct = 0
    n_seen = 0
    for x, pad_mask, y in loader:
        x, pad_mask, y = (t.to(device) for t in (x, pad_mask, y))
        pred = model(x, pad_mask).argmax(dim=-1)
        n_correct += pred.eq(y).sum().item()
        n_seen += y.numel()
    return n_correct / max(n_seen, 1)

def train(model, train_loader, test_loader, epochs=8, lr=3e-4):
    """Train *model* with AdamW and cross-entropy for *epochs* epochs.

    Gradients are clipped to a total norm of 1.0 before each optimizer step.
    Test accuracy is computed and a summary line printed only after the first
    and the last epoch.
    """
    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-2)
    for ep in range(1, epochs + 1):
        model.train()  # evaluate() leaves the model in eval mode
        running_loss = 0.0
        for x, pad_mask, y in train_loader:
            x, pad_mask, y = (t.to(device) for t in (x, pad_mask, y))
            opt.zero_grad(set_to_none=True)
            loss = F.cross_entropy(model(x, pad_mask), y)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            running_loss += loss.item()

        if ep not in (1, epochs):
            continue
        acc = evaluate(model, test_loader)
        avg_loss = running_loss / len(train_loader)
        print(f"Epoch {ep:02d}/{epochs} | avg_loss={avg_loss:.4f} | test_acc={acc:.3f}")

# -------------------------
# Run
# -------------------------
# Build the model with the sizes derived from the data, then move it to the
# selected device.
model = VanillaBabiTransformerQA(
    vocab_size=vocab_size,
    num_answers=num_answers,
    d_model=128,
    n_heads=8,
    n_layers=2,
    d_ff=256,
    dropout=0.1,
)
model = model.to(device)

train(model, train_loader, test_loader, epochs=8, lr=3e-4)
print("Final test accuracy:", evaluate(model, test_loader))

Device: cpu
Extracting bAbI...


  tar.extractall(DATA_ROOT)


Train: 1000 Test: 1000
Example answer: bathroom
Vocab: 25 Answers: 6
Epoch 01/8 | avg_loss=1.8048 | test_acc=0.170
Epoch 08/8 | avg_loss=1.3714 | test_acc=0.414
Final test accuracy: 0.414
