In [None]:

# SCRIBE: Complete Baselines


#!pip install -q sentence-transformers

import torch, torch.nn as nn, torch.nn.functional as F
import pandas as pd, numpy as np, math, time, os, gc
from collections import defaultdict
from sentence_transformers import SentenceTransformer
from tqdm.auto import tqdm

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device} ({torch.cuda.get_device_name(0) if device.type=='cuda' else 'cpu'})")


# CONFIG

MAX_SENTS = 15       # sentences kept per context; later ones are truncated
BATCH_SIZE = 256
NUM_EPOCHS = 30  # Enough for convergence
LR = 1e-3
HIDDEN = 192
NUM_SLOTS = 16       # SCRIBE memory slots (slot 0 is reserved)
NUM_HEADS = 4
READ_HOPS = 2
CONSOL_EVERY = 4
DROPOUT = 0.1
# Mixed precision only on CUDA: the train/eval code uses CUDA autocast and a
# CUDA GradScaler.
USE_AMP = device.type == 'cuda'


# LOAD & ENCODE

print("\n[1/3] Loading data...")
df = pd.read_parquet("unified_reasoning_dataset.parquet")


def _sample_task(frame, split, task, n, seed=42):
    """Return up to *n* rows of *frame* with the given split and task.

    ``DataFrame.sample(n=...)`` raises when *n* exceeds the number of rows,
    so *n* is capped at the number of rows available.
    """
    rows = frame[(frame['split'] == split) & (frame['task'] == task)]
    return rows.sample(n=min(n, len(rows)), random_state=seed)


def _shuffled(*parts, seed=42):
    """Concatenate *parts* and shuffle the rows with a fixed seed."""
    return pd.concat(list(parts)).sample(frac=1, random_state=seed).reset_index(drop=True)


train_df = _shuffled(_sample_task(df, 'train', 'proofwriter', 20000),
                     _sample_task(df, 'train', 'babi', 20000))
test_df = _shuffled(_sample_task(df, 'test', 'proofwriter', 5000),
                    _sample_task(df, 'test', 'babi', 5000))

# Answer vocabulary is built from the full frame (all splits and tasks).
vocab = {a:i for i,a in enumerate(df['answer'].unique())}
del df; gc.collect()

# NOTE(review): .get(a, 0) silently maps an unknown answer to class 0.
train_answers = [vocab.get(a,0) for a in train_df['answer']]
test_answers = [vocab.get(a,0) for a in test_df['answer']]
train_tasks = train_df['task'].tolist()
test_tasks = test_df['task'].tolist()
NUM_CLASSES = len(vocab)

print(f"Train: {len(train_df)} | Test: {len(test_df)} | Classes: {NUM_CLASSES}")

print("\n[2/3] Encoding...")
encoder = SentenceTransformer('all-MiniLM-L6-v2').to(device)

def encode_dataset(dataframe, encoder, max_sents=15, batch_size=512):
    """Split contexts into sentences and embed questions and sentences.

    Each context is split on '.'. Fragments whose stripped length is <= 3
    are dropped, and at most ``max_sents`` sentences are kept (later ones are
    truncated). If nothing survives, the first 100 characters of the raw
    context are used as the single sentence.

    Args:
        dataframe: frame with text columns 'question' and 'context'.
        encoder: object exposing a SentenceTransformer-style
            ``encode(texts, convert_to_tensor=..., batch_size=...,
            show_progress_bar=...)`` that returns a 2-D tensor.
        max_sents: sentence capacity of the padded output.
        batch_size: batch size passed to ``encoder.encode``.

    Returns:
        ``(q, sv, sm, sents_per)``, with all tensors on CPU:
          q: (N, E) question embeddings;
          sv: (N, max_sents, E) sentence embeddings, zero-padded;
          sm: (N, max_sents) float mask, 1.0 on the first ``sents_per[i]``
              positions of row i and 0.0 elsewhere;
          sents_per: number of sentences kept per row (after truncation).
        E is the encoder's embedding width. The models in this script assume
        384, the width of all-MiniLM-L6-v2.
    """
    N = len(dataframe)
    questions = dataframe['question'].tolist()
    contexts = dataframe['context'].tolist()
    all_sents = []; sents_per = []
    for ctx in contexts:
        s = [x.strip() for x in ctx.split('.') if len(x.strip())>3]
        if not s: s = [ctx[:100]]
        s = s[:max_sents]; sents_per.append(len(s)); all_sents.extend(s)
    with torch.no_grad():
        q = encoder.encode(questions, convert_to_tensor=True, batch_size=batch_size, show_progress_bar=True).clone().cpu()
        sf = encoder.encode(all_sents, convert_to_tensor=True, batch_size=batch_size, show_progress_bar=True).clone().cpu()
    # Take the width from the encoder instead of hard-coding 384, so a
    # different sentence model does not fail on the slice assignment below.
    emb_dim = sf.shape[-1]
    sv = torch.zeros(N, max_sents, emb_dim); sm = torch.zeros(N, max_sents)
    # Sentences of all rows were encoded as one flat list; scatter them back
    # row by row (left-aligned, so sm is a prefix mask).
    offset = 0
    for i, n in enumerate(sents_per):
        sv[i, :n] = sf[offset:offset + n]
        sm[i, :n] = 1.0
        offset += n
    del sf; gc.collect()
    return q, sv, sm, sents_per

# Encode the train split, then the test split, with the same encoder; then
# drop the encoder and release cached GPU memory before training.
(train_q, train_s, train_m, train_ns), (test_q, test_s, test_m, test_ns) = (
    encode_dataset(split_frame, encoder, MAX_SENTS) for split_frame in (train_df, test_df)
)
del encoder
gc.collect()
torch.cuda.empty_cache()
print("Encoding done.\n")


# SHARED TRAIN/EVAL FUNCTIONS

def train_model(model, num_epochs=NUM_EPOCHS, lr=LR):
    """Train ``model`` on the global training tensors and return test metrics.

    Uses AdamW (weight decay 0.01) with a one-cycle LR schedule, optional
    CUDA mixed precision, and gradient clipping at norm 1.0. The model is
    evaluated with ``eval_model`` every 5 epochs and after the last epoch.

    Args:
        model: module called as ``model(sents, query, mask)`` returning logits.
        num_epochs: number of passes over the training set.
        lr: peak learning rate for OneCycleLR.

    Returns:
        The ``eval_model`` dict from the evaluation with the highest
        ``overall`` accuracy, plus two extra keys:
          ``'epoch'``: 1-based epoch of that evaluation;
          ``'final_overall'``: overall accuracy after the last epoch.
        Returns None only if ``num_epochs`` is 0.

    NOTE(review): the "best" epoch is chosen on the *test* split, which makes
    it an optimistic estimate. ``final_overall`` is not selected on test data.
    """
    amp_on = USE_AMP and device.type == 'cuda'
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
    N = train_q.shape[0]
    nb = (N + BATCH_SIZE - 1) // BATCH_SIZE
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=lr, total_steps=num_epochs * nb, pct_start=0.1)
    # One scaler for the whole run. Recreating it every epoch resets the loss
    # scale to its initial value, so it has to re-calibrate (and may skip
    # overflowing steps) at the start of every epoch.
    scaler = torch.amp.GradScaler('cuda', enabled=amp_on)
    # Labels as one tensor, so each batch is a single index op.
    answers = torch.tensor(train_answers, dtype=torch.long)
    best_acc = 0; best_results = None; last_results = None

    for epoch in range(num_epochs):
        model.train()
        perm = torch.randperm(N)

        for bi in range(nb):
            ix = perm[bi * BATCH_SIZE:(bi + 1) * BATCH_SIZE]
            q = train_q[ix].to(device); sv = train_s[ix].to(device)
            mk = train_m[ix].to(device)
            ans = answers[ix].to(device)
            optimizer.zero_grad()
            with torch.amp.autocast('cuda', enabled=amp_on):
                logits = model(sv, q, mk)
                loss = F.cross_entropy(logits, ans)
            scaler.scale(loss).backward()
            # Unscale first so clipping sees the true gradient norm.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer); scaler.update(); scheduler.step()
            del q, sv, mk, ans, logits, loss

        # Eval every 5 epochs + last epoch
        if (epoch + 1) % 5 == 0 or epoch == num_epochs - 1:
            results = eval_model(model)
            last_results = results
            # best_results is None on the first eval, so something is always
            # recorded even when accuracy is 0.
            if best_results is None or results['overall'] > best_acc:
                best_acc = results['overall']
                best_results = results.copy()
                best_results['epoch'] = epoch + 1

    if best_results is not None:
        best_results['final_overall'] = last_results['overall']
    return best_results

@torch.no_grad()
def eval_model(model):
    """Evaluate ``model`` on the global test tensors; accuracies in percent.

    Returned keys:
      'overall': accuracy over all test examples;
      'task_<task>': accuracy per value in ``test_tasks`` (e.g. 'task_babi');
      'depth_<bucket>': accuracy per context-size bucket.

    NOTE: the "depth" bucket is the number of kept sentences (``test_ns``,
    already truncated to MAX_SENTS during encoding). It is not reasoning or
    proof depth. The buckets are: easy <= 3, med 4-6, hard >= 7.

    The model is left in eval mode.
    """
    model.eval()
    amp_on = USE_AMP and device.type == 'cuda'
    N = test_q.shape[0]
    correct = 0; total = 0
    per_task = defaultdict(lambda: {'c': 0, 't': 0})
    per_depth = defaultdict(lambda: {'c': 0, 't': 0})

    for s in range(0, N, BATCH_SIZE):
        e = min(s + BATCH_SIZE, N)
        q = test_q[s:e].to(device); sv = test_s[s:e].to(device)
        mk = test_m[s:e].to(device)
        ans = torch.tensor(test_answers[s:e], dtype=torch.long, device=device)
        with torch.amp.autocast('cuda', enabled=amp_on):
            logits = model(sv, q, mk)
        # One device->host copy per batch instead of an .item() sync per example.
        hits = (logits.argmax(1) == ans).cpu().tolist()
        correct += sum(hits); total += len(hits)
        for ri, hit in zip(range(s, e), hits):
            per_task[test_tasks[ri]]['c'] += hit
            per_task[test_tasks[ri]]['t'] += 1
            ns = test_ns[ri]
            d = 'easy' if ns <= 3 else 'med' if ns <= 6 else 'hard'
            per_depth[d]['c'] += hit
            per_depth[d]['t'] += 1
        del q, sv, mk, ans, logits

    task_acc = {t: d['c'] / d['t'] * 100 for t, d in per_task.items()}
    depth_acc = {d: v['c'] / v['t'] * 100 for d, v in per_depth.items() if v['t'] > 0}
    overall = correct / total * 100 if total else 0.0  # empty test split -> 0
    return {'overall': overall, **{f'task_{t}': a for t, a in task_acc.items()},
            **{f'depth_{d}': a for d, a in depth_acc.items()}}


# MODEL DEFINITIONS


# --- Shared MemoryReader ---
class MemoryReader(nn.Module):
    """Multi-hop gated attention read over a set of memory vectors.

    Each hop layer-normalises the running state and uses it as a single-token
    query over ``mem``, ignoring positions where ``mask == 0``. The attended
    vector is then blended into the state with a learned sigmoid gate:
    ``state = g * read + (1 - g) * state``.
    """
    def __init__(self, D, heads=4, hops=2, drop=0.1):
        super().__init__()
        self.hops = hops
        # Attribute names are part of the state_dict; keep attn/norm/gate.
        self.attn = nn.ModuleList(
            nn.MultiheadAttention(D, heads, batch_first=True, dropout=drop) for _ in range(hops)
        )
        self.norm = nn.ModuleList(nn.LayerNorm(D) for _ in range(hops))
        self.gate = nn.ModuleList(nn.Linear(D * 2, D) for _ in range(hops))

    def forward(self, query, mem, mask):
        ignore = mask == 0
        state = query
        for attn, norm, gate in zip(self.attn, self.norm, self.gate):
            read, _ = attn(norm(state).unsqueeze(1), mem, mem, key_padding_mask=ignore)
            read = read.squeeze(1)
            g = torch.sigmoid(gate(torch.cat([state, read], -1)))
            state = g * read + (1 - g) * state
        return state

# --- BASELINE 1: MLP ---
class MLPBaseline(nn.Module):
    """Order-free baseline: MLP over [masked-mean(sentences), question].

    Sentence embeddings are averaged over the real positions (``smask`` > 0),
    or over all positions when no mask is given. The average is concatenated
    with the question embedding and classified by a 3-layer GELU MLP.
    """
    def __init__(self, in_d=384, D=192, n_cls=100, drop=0.1):
        super().__init__()
        layers = [
            nn.Linear(in_d * 2, D), nn.GELU(), nn.Dropout(drop),
            nn.Linear(D, D), nn.GELU(), nn.Dropout(drop),
            nn.Linear(D, n_cls),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, sents, query, smask=None):
        if smask is None:
            pooled = sents.mean(1)
        else:
            weights = smask.unsqueeze(-1)
            # clamp avoids 0/0 for a row with no real sentences.
            pooled = (sents * weights).sum(1) / weights.sum(1).clamp(min=1)
        return self.net(torch.cat([pooled, query], -1))

# --- BASELINE 2: LSTM ---
class LSTMBaseline(nn.Module):
    """Sequential baseline: 2-layer unidirectional LSTM over sentence embeddings.

    The LSTM's top-layer output at the last *real* sentence is concatenated
    with the projected question and classified by an MLP head.

    Args:
        in_d: width of the input sentence/question embeddings.
        D: width of the projection, the LSTM state and the head.
        n_cls: number of answer classes.
        drop: dropout between the LSTM layers and inside the head.
    """
    def __init__(self, in_d=384, D=192, n_cls=100, drop=0.1):
        super().__init__()
        self.proj = nn.Linear(in_d, D)
        self.lstm = nn.LSTM(D, D, num_layers=2, batch_first=True, dropout=drop, bidirectional=False)
        self.q_proj = nn.Linear(in_d, D)
        self.head = nn.Sequential(nn.LayerNorm(D*2), nn.Linear(D*2, D), nn.GELU(), nn.Dropout(drop), nn.Linear(D, n_cls))

    def forward(self, sents, query, smask=None):
        """Classify from the LSTM state after the last real sentence.

        Args:
            sents: (B, N, in_d) sentence embeddings.
            query: (B, in_d) question embeddings.
            smask: optional (B, N) mask, nonzero for real sentences. Assumed
                to be a prefix mask (real sentences first), as produced by
                ``encode_dataset``. When None, the last position is used.
        """
        x = self.proj(sents)
        out, _ = self.lstm(x)  # (B, N, D): top-layer output at every step
        if smask is None:
            ctx = out[:, -1]  # same as the final top-layer hidden state h[-1]
        else:
            # Previously h[-1] was used unconditionally, so short contexts were
            # classified after reading the zero padding. For a unidirectional
            # LSTM, the output at the last real step is unaffected by padding
            # that comes after it.
            last = (smask.sum(1).long() - 1).clamp(min=0)
            ctx = out[torch.arange(out.shape[0], device=out.device), last]
        q = self.q_proj(query)
        return self.head(torch.cat([ctx, q], -1))

# --- BASELINE 3: Flat Attention (no sequential constraint) ---
class FlatAttentionBaseline(nn.Module):
    """Order-free attention baseline (no sequential constraint).

    Every sentence is encoded independently, and the resulting set is read by
    a multi-hop ``MemoryReader`` queried with the encoded question. An MLP
    head maps the final read state to class logits. No positional
    information is added.
    """
    def __init__(self, in_d=384, D=192, heads=4, hops=2, n_cls=100, drop=0.1):
        super().__init__()

        def make_encoder():
            # in_d -> D two-layer projection ending in LayerNorm.
            return nn.Sequential(
                nn.Linear(in_d, D), nn.GELU(), nn.Dropout(drop),
                nn.Linear(D, D), nn.LayerNorm(D),
            )

        self.s_enc = make_encoder()
        self.q_enc = make_encoder()
        self.reader = MemoryReader(D, heads, hops, drop)
        self.head = nn.Sequential(
            nn.LayerNorm(D), nn.Linear(D, D), nn.GELU(), nn.Dropout(drop), nn.Linear(D, n_cls),
        )

    def forward(self, sents, query, smask=None):
        memory = self.s_enc(sents)  # all sentences at once, no recurrence
        if smask is None:
            batch, n_sents = sents.shape[0], sents.shape[1]
            smask = torch.ones(batch, n_sents, device=sents.device)
        read_state = self.reader(self.q_enc(query), memory, smask)
        return self.head(read_state)

# --- SCRIBE (Full) ---
class SCRIBE(nn.Module):
    """Sequential slot-memory model ("SCRIBE").

    Sentences are read one at a time and written into a bank of ``slots``
    memory vectors of width ``D``.

    Slot 0 is reserved and always unmasked. It holds a learned sentinel
    vector, or zeros when ``use_sentinel=False``. Sentence t goes to the next
    free slot, so at most ``slots - 1`` sentences are read.

    For each sentence:
      * the sentence attends over the current memory (``mem_attn``), and
        ``w_gate`` mixes the sentence encoding with the attended context into
        a write vector;
      * ``conf`` (or a constant 1 when ``use_confidence=False``) scales both
        the written vector and the slot's entry in the memory mask;
      * with ``use_revision``, existing slots are moved toward a projection of
        the sentence, weighted by similarity, a per-slot gate and confidence.

    With ``use_consolidation``, a TransformerEncoderLayer mixes the slots
    after every ``consol``-th sentence step and once more after the last one.
    A multi-hop ``MemoryReader`` then queries the memory with the encoded
    question, and an MLP head emits ``n_cls`` logits.

    The ``use_*`` flags are ablation switches; a disabled component's
    parameters are not created.
    """
    def __init__(self, in_d=384, D=192, slots=16, heads=4, rhops=2, consol=4,
                 n_cls=100, drop=0.1, use_revision=True, use_consolidation=True,
                 use_sentinel=True, use_confidence=True):
        super().__init__()
        self.D=D; self.S=slots; self.consol=consol
        self.use_revision=use_revision; self.use_consolidation=use_consolidation
        self.use_sentinel=use_sentinel; self.use_confidence=use_confidence

        # Sentence / question encoders: in_d -> D, ending in LayerNorm.
        self.s_enc = nn.Sequential(nn.Linear(in_d,D),nn.GELU(),nn.Dropout(drop),nn.Linear(D,D),nn.LayerNorm(D))
        self.q_enc = nn.Sequential(nn.Linear(in_d,D),nn.GELU(),nn.Dropout(drop),nn.Linear(D,D),nn.LayerNorm(D))

        # Learned content of the reserved slot 0.
        if use_sentinel:
            self.sentinel = nn.Parameter(torch.randn(D)*0.02)

        # Write path: sentence attends over memory, then a gate mixes the two.
        self.mem_attn = nn.MultiheadAttention(D,heads,batch_first=True,dropout=drop)
        self.mem_norm = nn.LayerNorm(D)
        self.w_gate = nn.Sequential(nn.Linear(D*2,D),nn.GELU(),nn.Linear(D,D),nn.Sigmoid())

        # Scalar write confidence in (0, 1) per sample.
        if use_confidence:
            self.conf = nn.Sequential(nn.Linear(D*2,1),nn.Sigmoid())

        # Revision of already-written slots.
        if use_revision:
            self.rev_proj = nn.Linear(D,D)
            self.rev_gate = nn.Sequential(nn.Linear(D*2,D),nn.Sigmoid())

        # Slot-to-slot mixing.
        if use_consolidation:
            self.consolidation = nn.TransformerEncoderLayer(D,heads,D*2,drop,batch_first=True)

        self.reader = MemoryReader(D,heads,rhops,drop)
        self.head = nn.Sequential(nn.LayerNorm(D),nn.Linear(D,D),nn.GELU(),nn.Dropout(drop),nn.Linear(D,n_cls))
        self._init()

    def _init(self):
        """Apply Xavier-uniform weights and zero biases to every ``nn.Linear``
        found by ``self.modules()``, including ones nested in library layers."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None: nn.init.zeros_(m.bias)

    def forward(self, sents, query, smask=None):
        """Write the sentences into slot memory, then read it with the query.

        Args:
            sents: sentence embeddings, unpacked as (B, N, in_d).
            query: question embeddings; presumably (B, in_d). TODO confirm.
            smask: (B, N) sentence mask, nonzero for real sentences; all ones
                when None.

        Returns:
            Class logits from ``head``.

        Only the first ``min(N, slots - 1)`` sentences are used.
        """
        B,N,_ = sents.shape; dev=sents.device; S=self.S; D=self.D
        if smask is None: smask=torch.ones(B,N,device=dev)
        # Slot 0 is reserved, so at most S-1 sentences fit; later ones are dropped.
        mx = min(N,S-1)
        enc = self.s_enc(sents[:,:mx])

        # mem: slot contents. mm: per-slot write strength (0 = empty/masked).
        mem = torch.zeros(B,S,D,device=dev)
        mm = torch.zeros(B,S,device=dev)

        # Slot 0 is always attendable. Without a sentinel it is a zero vector,
        # so the "No Sentinel" ablation removes only the learned content.
        if self.use_sentinel:
            mem[:,0,:] = self.sentinel.unsqueeze(0).expand(B,-1)
        mm[:,0] = 1.0

        wp = 1  # next slot to write; shared by the whole batch
        for t in range(mx):
            act = smask[:,t]  # per-sample: nonzero if sentence t is real
            # Skip steps where the whole batch is padding. Converting this
            # tensor comparison to a Python bool syncs with the device.
            if act.sum()==0: continue
            # Defensive only: at this point wp <= t+1 <= mx <= S-1.
            if wp>=S: break
            st = enc[:,t,:]
            # Sentence attends over the non-empty slots.
            q = self.mem_norm(st).unsqueeze(1)
            ctx,_ = self.mem_attn(q,mem,mem,key_padding_mask=(mm==0))
            ctx = ctx.squeeze(1)
            comb = torch.cat([st,ctx],-1)
            # Element-wise gate between the raw sentence and the attended context.
            wg = self.w_gate(comb)
            wv = wg*st + (1-wg)*ctx

            if self.use_confidence:
                cf = self.conf(comb)
            else:
                cf = torch.ones(B,1,device=dev)

            # Revision
            if self.use_revision:
                rs = self.rev_proj(st)
                # Scaled dot-product similarity of every slot to the projected sentence.
                sim = torch.bmm(mem, rs.unsqueeze(-1)).squeeze(-1)/math.sqrt(D)
                # Weight is multiplied by mm, so empty slots get no revision.
                rw = torch.sigmoid(sim)*mm
                re = rs.unsqueeze(1).expand_as(mem)
                ri = torch.cat([mem,re],-1)
                rg = self.rev_gate(ri)
                # Confidence x activity; padded samples (act=0) are not revised.
                rst = (cf*act.unsqueeze(-1)).unsqueeze(1)
                # Hard-coded 0.3 scales down how far a slot moves per sentence.
                rsc = rw.unsqueeze(-1)*rst*0.3
                mem = mem + rg*(re-mem)*rsc

            # Write
            # Write strength; 0 for padded samples, so their slot stays zero and masked.
            ws = cf.squeeze(-1)*act
            sc = wv*ws.unsqueeze(-1)
            # One-hot selector: overwrite slot wp, keep every other slot unchanged.
            wm = torch.zeros(B,S,1,device=dev); wm[:,wp,:]=1.0
            mem = mem*(1-wm) + sc.unsqueeze(1)*wm
            mu = torch.zeros(B,S,device=dev); mu[:,wp]=ws
            mm = mm + mu
            wp += 1

            # Consolidation
            # NOTE(review): consolidation masks slots with mm < 0.1, but
            # mem_attn and the reader mask only mm == 0. Low-confidence slots
            # are therefore hidden only from consolidation.
            if self.use_consolidation and (t+1)%self.consol==0 and t>0:
                mem = self.consolidation(mem, src_key_padding_mask=(mm<0.1))

        # NOTE(review): if the in-loop consolidation fired on the final step,
        # this applies the layer a second time in a row.
        if self.use_consolidation:
            mem = self.consolidation(mem, src_key_padding_mask=(mm<0.1))

        qr = self.q_enc(query)
        rs = self.reader(qr, mem, mm)
        return self.head(rs)


all_results = []

# Re-seeded before each model is built, so every configuration starts from the
# same torch RNG state (init, dropout masks, training shuffle). This makes the
# table reproducible, although cuDNN kernels may still be nondeterministic.
# NOTE(review): each row is a single run, so small gaps between rows may be
# within seed-to-seed variance.
SEED = 42

experiments = [
    # Baselines
    ("MLP Baseline",        lambda: MLPBaseline(384, HIDDEN, NUM_CLASSES, DROPOUT)),
    ("LSTM Baseline",       lambda: LSTMBaseline(384, HIDDEN, NUM_CLASSES, DROPOUT)),
    ("Flat Attention",      lambda: FlatAttentionBaseline(384, HIDDEN, NUM_HEADS, READ_HOPS, NUM_CLASSES, DROPOUT)),

    # SCRIBE Full
    ("SCRIBE (Full)",       lambda: SCRIBE(384, HIDDEN, NUM_SLOTS, NUM_HEADS, READ_HOPS, CONSOL_EVERY, NUM_CLASSES, DROPOUT,
                                           use_revision=True, use_consolidation=True, use_sentinel=True, use_confidence=True)),

    # Ablations
    ("- No Revision",       lambda: SCRIBE(384, HIDDEN, NUM_SLOTS, NUM_HEADS, READ_HOPS, CONSOL_EVERY, NUM_CLASSES, DROPOUT,
                                           use_revision=False, use_consolidation=True, use_sentinel=True, use_confidence=True)),
    ("- No Consolidation",  lambda: SCRIBE(384, HIDDEN, NUM_SLOTS, NUM_HEADS, READ_HOPS, CONSOL_EVERY, NUM_CLASSES, DROPOUT,
                                           use_revision=True, use_consolidation=False, use_sentinel=True, use_confidence=True)),
    ("- No Sentinel",       lambda: SCRIBE(384, HIDDEN, NUM_SLOTS, NUM_HEADS, READ_HOPS, CONSOL_EVERY, NUM_CLASSES, DROPOUT,
                                           use_revision=True, use_consolidation=True, use_sentinel=False, use_confidence=True)),
    ("- No Confidence",     lambda: SCRIBE(384, HIDDEN, NUM_SLOTS, NUM_HEADS, READ_HOPS, CONSOL_EVERY, NUM_CLASSES, DROPOUT,
                                           use_revision=True, use_consolidation=True, use_sentinel=True, use_confidence=False)),
]

print(f"[3/3] Running {len(experiments)} experiments ({NUM_EPOCHS} epochs each)")
print("="*90)

for name, model_fn in experiments:
    print(f"\n{'='*90}")
    print(f"  EXPERIMENT: {name}")
    print(f"{'='*90}")

    torch.cuda.empty_cache(); gc.collect()
    torch.manual_seed(SEED)
    model = model_fn().to(device)
    params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"  Params: {params:,}")

    t0 = time.time()
    results = train_model(model, NUM_EPOCHS, LR)
    elapsed = time.time() - t0
    # train_model can return None (e.g. if no evaluation scored above 0%);
    # record 0% instead of crashing on the assignments below.
    if results is None:
        results = {'overall': 0.0}

    results['name'] = name
    results['params'] = params
    results['time'] = elapsed
    all_results.append(results)

    print(f"  Best (epoch {results.get('epoch','?')}): {results['overall']:.1f}%")
    print(f"  bAbI: {results.get('task_babi',0):.1f}% | PW: {results.get('task_proofwriter',0):.1f}%")
    print(f"  Easy: {results.get('depth_easy',0):.1f}% | Med: {results.get('depth_med',0):.1f}% | Hard: {results.get('depth_hard',0):.1f}%")
    print(f"  Time: {elapsed:.0f}s")

    del model; torch.cuda.empty_cache(); gc.collect()


# FINAL RESULTS TABLE

# Fixed-width summary: name, params (millions), then six percentage columns.
rule = "=" * 90
thin_rule = "-" * 90
col_titles = ("Params", "Overall", "bAbI", "PW", "Easy", "Med", "Hard")
pct_keys = ("task_babi", "task_proofwriter", "depth_easy", "depth_med", "depth_hard")

print("\n\n")
for banner_line in (rule, "  FINAL RESULTS TABLE — Copy this for the paper", rule):
    print(banner_line)

header = f"{'Model':<22} " + " ".join(f"{title:>8}" for title in col_titles)
print(header)
print(thin_rule)

for res in all_results:
    pct_values = [res['overall'], *(res.get(key, 0) for key in pct_keys)]
    pct_cells = " ".join(f"{value:>7.1f}%" for value in pct_values)
    print(f"{res['name']:<22} {res['params'] / 1e6:>7.1f}M {pct_cells}")

print(thin_rule)

# Persist every collected field per experiment, not only the printed columns.
results_df = pd.DataFrame(all_results)
results_df.to_csv("experiments_results.csv", index=False)
print("\nSaved to experiments_results.csv")
print("\nAll done!")

Device: cuda (Tesla T4)

[1/3] Loading data...
Train: 38013 | Test: 10000 | Classes: 59

[2/3] Encoding...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

Loading weights:   0%|          | 0/103 [00:00<?, ?it/s]

BertModel LOAD REPORT from: sentence-transformers/all-MiniLM-L6-v2
Key                     | Status     |  | 
------------------------+------------+--+-
embeddings.position_ids | UNEXPECTED |  | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.


tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

Batches:   0%|          | 0/75 [00:00<?, ?it/s]

Batches:   0%|          | 0/728 [00:00<?, ?it/s]

Batches:   0%|          | 0/20 [00:00<?, ?it/s]

Batches:   0%|          | 0/190 [00:00<?, ?it/s]

Encoding done.

[3/3] Running 8 experiments (30 epochs each)

  EXPERIMENT: MLP Baseline
  Params: 196,091
  Best (epoch 25): 70.2%
  bAbI: 48.5% | PW: 91.8%
  Easy: 72.4% | Med: 56.6% | Hard: 73.4%
  Time: 32s

  EXPERIMENT: LSTM Baseline
  Params: 826,811
  Best (epoch 20): 75.5%
  bAbI: 59.1% | PW: 91.8%
  Easy: 75.9% | Med: 66.0% | Hard: 77.9%
  Time: 47s

  EXPERIMENT: Flat Attention
  Params: 716,603
  Best (epoch 30): 79.0%
  bAbI: 64.4% | PW: 93.6%
  Easy: 79.0% | Med: 73.1% | Hard: 80.5%
  Time: 65s

  EXPERIMENT: SCRIBE (Full)
  Params: 1,384,764


  scaler.step(optimizer); scaler.update(); scheduler.step()


  Best (epoch 30): 80.3%
  bAbI: 70.5% | PW: 90.2%
  Easy: 80.8% | Med: 79.0% | Hard: 80.6%
  Time: 475s

  EXPERIMENT: - No Revision
  Params: 1,273,788
  Best (epoch 30): 80.1%
  bAbI: 69.8% | PW: 90.3%
  Easy: 81.4% | Med: 77.6% | Hard: 80.5%
  Time: 369s

  EXPERIMENT: - No Consolidation
  Params: 1,087,740
  Best (epoch 30): 85.8%
  bAbI: 78.0% | PW: 93.7%
  Easy: 82.1% | Med: 85.2% | Hard: 86.5%
  Time: 426s

  EXPERIMENT: - No Sentinel
  Params: 1,384,572
  Best (epoch 30): 80.0%
  bAbI: 70.2% | PW: 89.7%
  Easy: 80.7% | Med: 78.7% | Hard: 80.2%
  Time: 470s

  EXPERIMENT: - No Confidence
  Params: 1,384,379
  Best (epoch 25): 77.9%
  bAbI: 65.6% | PW: 90.2%
  Easy: 79.1% | Med: 73.4% | Hard: 78.9%
  Time: 438s



  FINAL RESULTS TABLE — Copy this for the paper
Model                    Params  Overall     bAbI       PW     Easy      Med     Hard
------------------------------------------------------------------------------------------
MLP Baseline               0.2M    70.2%    