In [1]:
# 0) Imports & Config
import os, json, random, itertools, collections
from pathlib import Path
from tqdm import tqdm

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

# Hyperparameters & Paths
DATA_DIR      = Path("data")
PRE_DIR       = Path("preprocessed")
EVID_J        = PRE_DIR/"evidence_stemmed.json"
CLAIM_J       = PRE_DIR/"claims_stemmed.json"
TRAIN_J       = DATA_DIR/"train-claims.json"
DEV_J         = DATA_DIR/"dev-claims.json"
EVID_CORPUS_J = DATA_DIR/"evidence.json"

EMB_DIM   = 100
HID_DIM   = 128
BATCH     = 128
EPOCHS    = 5
LR        = 3e-4
MARGIN    = 0.3
MIN_FREQ  = 3
TOP_K     = 5
DEVICE    = "cuda" if torch.cuda.is_available() else "cpu"

# 1) Load pre-stemmed data
with open(EVID_J, "r", encoding="utf-8") as f:
    evidence_proc = json.load(f)
with open(CLAIM_J, "r", encoding="utf-8") as f:
    claim_proc_all = json.load(f)

# 2) Build vocab
freq = collections.Counter(
    t for toks in itertools.chain(evidence_proc.values(),
                                  claim_proc_all.values())
    for t in toks
)
PAD, UNK = "<PAD>", "<UNK>"
itos = [PAD, UNK] + [t for t,c in freq.items() if c>=MIN_FREQ]
stoi = {t:i for i,t in enumerate(itos)}
def numerise(tokens):
    return [stoi.get(t, stoi[UNK]) for t in tokens]

# 3) Load labels
train_lbl = json.loads(TRAIN_J.read_text())
dev_lbl   = json.loads(DEV_J.read_text())

# 4) Triplet Dataset & DataLoader
class TripletDataset(Dataset):
    def __init__(self, labeled, evid_dict):
        items, evid_ids = [], list(evid_dict.keys())
        for cid, obj in labeled.items():
            pos = [e for e in obj["evidences"] if e in evid_dict]
            for p in pos:
                n = random.choice(evid_ids)
                while n==p: n = random.choice(evid_ids)
                items.append((cid, p, n))
        self.items = items
        self.evid  = evid_dict
    def __len__(self):
        return len(self.items)
    def __getitem__(self, idx):
        cid,p,n = self.items[idx]
        return (
          torch.tensor(numerise(claim_proc_all[cid]), dtype=torch.long),
          torch.tensor(numerise(self.evid[p]), dtype=torch.long),
          torch.tensor(numerise(self.evid[n]), dtype=torch.long),
        )

def collate_fn(batch):
    def pad(seqs):
        m = max(len(s) for s in seqs)
        return torch.tensor([s.tolist()+[0]*(m-len(s)) for s in seqs])
    c,p,n = zip(*batch)
    return pad(c), pad(p), pad(n)

train_ds = TripletDataset(train_lbl, evidence_proc)
train_dl = DataLoader(train_ds, batch_size=BATCH, shuffle=True,
                      collate_fn=collate_fn)

# 5) BiLSTM Sentence Encoder
class BiLSTMSentenceEncoder(nn.Module):
    def __init__(self, vocab_sz, emb_dim=EMB_DIM, hid_dim=HID_DIM):
        super().__init__()
        self.emb  = nn.Embedding(vocab_sz, emb_dim, padding_idx=0)
        self.lstm = nn.LSTM(emb_dim, hid_dim, batch_first=True,
                            bidirectional=True)
    def forward(self, x):
        mask = (x!=0).float().unsqueeze(-1)
        out, _ = self.lstm(self.emb(x))
        # mean‐pool over the length dim
        out = (out * mask).sum(1) / mask.sum(1)
        return nn.functional.normalize(out, p=2, dim=-1)

In [2]:
# 6) Train the retriever
model   = BiLSTMSentenceEncoder(len(itos)).to(DEVICE)
optim   = torch.optim.Adam(model.parameters(), lr=LR)
loss_fn = nn.MarginRankingLoss(margin=MARGIN)

for epoch in range(EPOCHS):
    model.train()
    total = 0
    for c, p, n in tqdm(train_dl, desc=f"Epoch {epoch+1}"):
        c,p,n = [t.to(DEVICE) for t in (c,p,n)]
        vc, vp, vn = model(c), model(p), model(n)
        pos_sim = (vc * vp).sum(1)
        neg_sim = (vc * vn).sum(1)
        loss    = loss_fn(pos_sim, neg_sim,
                          torch.ones_like(pos_sim, device=DEVICE))
        optim.zero_grad(); loss.backward(); optim.step()
        total += loss.item()
    print(f"  Epoch {epoch+1} avg loss = {total/len(train_dl):.4f}")

Epoch 1: 100%|██████████| 33/33 [01:40<00:00,  3.05s/it]


  Epoch 1 avg loss = 0.1079


Epoch 2: 100%|██████████| 33/33 [01:59<00:00,  3.63s/it]


  Epoch 2 avg loss = 0.0652


Epoch 3: 100%|██████████| 33/33 [01:42<00:00,  3.11s/it]


  Epoch 3 avg loss = 0.0467


Epoch 4: 100%|██████████| 33/33 [01:37<00:00,  2.94s/it]


  Epoch 4 avg loss = 0.0335


Epoch 5: 100%|██████████| 33/33 [01:37<00:00,  2.96s/it]

  Epoch 5 avg loss = 0.0240





In [3]:
# 7) Encode all evidence
# single‐worker to avoid multiprocessing issues in notebook
evidence_vecs = {}
loader = DataLoader(
    [(eid, torch.tensor(numerise(evidence_proc[eid]),
                        dtype=torch.long))
     for eid in evidence_proc],
    batch_size=512, shuffle=False,
    collate_fn=lambda batch: (
        [e[0] for e in batch],
        pad_sequence([e[1] for e in batch],
                     batch_first=True, padding_value=0)
    ),
    num_workers=0,
)
model.eval()
with torch.no_grad():
    for eids, seqs in tqdm(loader, desc="Encoding evidence"):
        vecs = model(seqs.to(DEVICE)).cpu()
        for eid, v in zip(eids, vecs):
            evidence_vecs[eid] = v

Encoding evidence: 100%|██████████| 2360/2360 [54:32<00:00,  1.39s/it] 


In [5]:
# 8) Ranking & Evaluation on dev
def rank_evidence(stems, top_k):
    idxs = numerise(stems)
    x = torch.tensor([idxs], dtype=torch.long, device=DEVICE)
    with torch.no_grad():
        vc = model(x).cpu().squeeze(0)
    sims = {eid: float(torch.dot(vc, v_e))
            for eid, v_e in evidence_vecs.items()}
    return sorted(sims, key=sims.get, reverse=True)[:top_k]

def evaluate(top_k):
    recalls, precisions, f1s = [], [], []
    for cid, obj in tqdm(dev_lbl.items(), desc="Evaluating"):
        gold      = set(obj["evidences"])
        retrieved = rank_evidence(claim_proc_all[cid], top_k)
        hits      = len(gold & set(retrieved))
        r = hits/len(gold) if gold else 0.0
        p = hits/top_k
        f = (2*r*p/(r+p)) if (r+p)>0 else 0.0
        recalls.append(r); precisions.append(p); f1s.append(f)

    print(f"\nRecall@{top_k}:    {np.mean(recalls):.2%}")
    print(f"Precision@{top_k}: {np.mean(precisions):.2%}")
    print(f"F1@{top_k}:        {np.mean(f1s):.2%}")

for k in [3, 4, 5]:
    evaluate(k)

Evaluating: 100%|██████████| 154/154 [04:50<00:00,  1.89s/it]



Recall@3:    5.01%
Precision@3: 3.90%
F1@3:        4.04%


Evaluating: 100%|██████████| 154/154 [04:44<00:00,  1.85s/it]



Recall@4:    5.90%
Precision@4: 3.57%
F1@4:        4.15%


Evaluating: 100%|██████████| 154/154 [04:49<00:00,  1.88s/it]


Recall@5:    6.24%
Precision@5: 3.12%
F1@5:        3.90%



