In [1]:
"""
Stemming-based preprocessing for the retrieval task.
✓ Multi-process spaCy
✓ Porter stemming
✓ On-disk cache (evidence_stemmed.json / claims_stemmed.json)
"""

import json, statistics, collections, time, multiprocessing as mp
from pathlib import Path
from tqdm import tqdm

import spacy
from nltk.stem import PorterStemmer
from nltk.corpus import stopwords

# ------------------------------------------------------------------
# CONFIGURATION
# ------------------------------------------------------------------
DATA_DIR      = Path("data")
OUT_EVID      = Path("preprocessed/evidence_stemmed.json")
OUT_CLAIM     = Path("preprocessed/claims_stemmed.json")
FORCE_REBUILD = True                 # True → ignore cache, rebuild
BATCH_SIZE    = 1_000                 # spaCy batch size
NUM_PROC      = max(mp.cpu_count() - 1, 1)   # use all but 1 core

# ------------------------------------------------------------------
# INITIALISE SPACY & STEMMER
# ------------------------------------------------------------------
nlp = spacy.load("en_core_web_sm", disable=["ner", "parser"])
stemmer = PorterStemmer()
stop_set = set(stopwords.words("english"))

def stem_doc(doc):
    out = []
    for tok in doc:
        lemma = tok.lemma_.lower()
        if lemma.isalpha() and lemma not in stop_set:
            out.append(stemmer.stem(lemma))
    return out

def jload(path: Path):
    with path.open(encoding="utf-8") as f:
        return json.load(f)

def jdump(obj, path: Path):
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as f:
        json.dump(obj, f, ensure_ascii=False)

# ------------------------------------------------------------------
# 0) LOAD CACHE (IF PRESENT)
# ------------------------------------------------------------------
if OUT_EVID.exists() and OUT_CLAIM.exists() and not FORCE_REBUILD:
    t0 = time.time()
    evidence_proc  = jload(OUT_EVID)
    claim_proc_all = jload(OUT_CLAIM)
    print(f"Cached data loaded in {time.time() - t0:.2f} s – ready to use.")
    exit(0)

print("No valid cache – preprocessing will start …")
t_start = time.time()

# ------------------------------------------------------------------
# 1) PRE-PROCESS EVIDENCE (PARALLEL)
# ------------------------------------------------------------------
evidence_raw = jload(DATA_DIR / "evidence.json")
evid_ids     = list(evidence_raw.keys())
evid_texts   = list(evidence_raw.values())

evidence_proc = {}
lengths = []

print(f"Tokenising {len(evid_ids):,} evidence passages "
      f"with {NUM_PROC} CPU process(es)…")

for evid_id, doc in tqdm(
        zip(evid_ids,
            nlp.pipe(evid_texts,
                     batch_size=BATCH_SIZE,
                     n_process=NUM_PROC)),
        total=len(evid_ids),
        desc="Stemming evidence",
        unit="doc"
):
    stems = stem_doc(doc)
    if stems:
        evidence_proc[evid_id] = stems
        lengths.append(len(stems))

jdump(evidence_proc, OUT_EVID)
print(f"Evidence saved → {OUT_EVID.resolve()}")

# ------------------------------------------------------------------
# 2) PRE-PROCESS CLAIMS (PARALLEL, PER SPLIT)
# ------------------------------------------------------------------
claim_files = [
    "train-claims.json",
    "dev-claims.json",
    "test-claims-unlabelled.json",
]
claim_proc_all = {}

for fname in claim_files:
    raw_claims = jload(DATA_DIR / fname)
    cids  = list(raw_claims.keys())
    texts = [raw_claims[cid]["claim_text"] for cid in cids]

    for cid, doc in tqdm(
            zip(cids,
                nlp.pipe(texts,
                         batch_size=BATCH_SIZE,
                         n_process=NUM_PROC)),
            total=len(cids),
            desc=f"Stemming {fname}",
            unit="doc"
    ):
        stems = stem_doc(doc)
        if stems:
            claim_proc_all[cid] = stems

jdump(claim_proc_all, OUT_CLAIM)
print(f"Claims saved → {OUT_CLAIM.resolve()}")

# ------------------------------------------------------------------
# 3) QUICK CORPUS STATISTICS
# ------------------------------------------------------------------
print("\n=== Evidence after stemming ===")
print(f"Total passages        : {len(evidence_proc):,}")
print(f"Stem length (min/max) : {min(lengths)} / {max(lengths)}")
print(f"Stem length (mean)    : {statistics.mean(lengths):.1f}")

vocab = {s for toks in evidence_proc.values() for s in toks}
print(f"Vocabulary size       : {len(vocab):,}")

counter = collections.Counter(s for toks in evidence_proc.values() for s in toks)
print("Top-20 stems          :", counter.most_common(20))

print(f"\nFinished in {time.time() - t_start:.1f} s – "
      f"results cached for future runs.")

No valid cache – preprocessing will start …
Tokenising 1,208,827 evidence passages with 9 CPU process(es)…


Stemming evidence: 100%|██████████| 1208827/1208827 [07:24<00:00, 2717.06doc/s]


Evidence saved → /Users/felikskong/Desktop/NLP/NLP_Ass3/preprocessed/evidence_stemmed.json


Stemming train-claims.json: 100%|██████████| 1228/1228 [00:44<00:00, 27.88doc/s]
Stemming dev-claims.json: 100%|██████████| 154/154 [00:42<00:00,  3.59doc/s]
Stemming test-claims-unlabelled.json: 100%|██████████| 153/153 [00:42<00:00,  3.60doc/s]


Claims saved → /Users/felikskong/Desktop/NLP/NLP_Ass3/preprocessed/claims_stemmed.json

=== Evidence after stemming ===
Total passages        : 1,207,920
Stem length (min/max) : 1 / 304
Stem length (mean)    : 11.3
Vocabulary size       : 510,195
Top-20 stems          : [('also', 66963), ('state', 58250), ('bear', 56376), ('first', 53537), ('one', 49589), ('new', 44100), ('year', 42117), ('play', 39752), ('american', 39704), ('includ', 39608), ('use', 39337), ('unit', 38930), ('nation', 37995), ('name', 37335), ('know', 37286), ('district', 34882), ('two', 34481), ('film', 33964), ('counti', 32636), ('footbal', 31480)]

Finished in 583.2 s – results cached for future runs.


In [2]:
# 0) Imports & Config
import os, json, random, itertools, collections
from pathlib import Path
from tqdm import tqdm

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

# Hyperparameters & Paths
DATA_DIR      = Path("data")
PRE_DIR       = Path("preprocessed")
EVID_J        = PRE_DIR/"evidence_stemmed.json"
CLAIM_J       = PRE_DIR/"claims_stemmed.json"
TRAIN_J       = DATA_DIR/"train-claims.json"
DEV_J         = DATA_DIR/"dev-claims.json"
EVID_CORPUS_J = DATA_DIR/"evidence.json"

EMB_DIM   = 100
HID_DIM   = 128
BATCH     = 128
EPOCHS    = 5
LR        = 3e-4
MARGIN    = 0.3
MIN_FREQ  = 3
TOP_K     = 5
DEVICE    = "cuda" if torch.cuda.is_available() else "cpu"

# 1) Load pre-stemmed data
with open(EVID_J, "r", encoding="utf-8") as f:
    evidence_proc = json.load(f)
with open(CLAIM_J, "r", encoding="utf-8") as f:
    claim_proc_all = json.load(f)

# 2) Build vocab
freq = collections.Counter(
    t for toks in itertools.chain(evidence_proc.values(),
                                  claim_proc_all.values())
    for t in toks
)
PAD, UNK = "<PAD>", "<UNK>"
itos = [PAD, UNK] + [t for t,c in freq.items() if c>=MIN_FREQ]
stoi = {t:i for i,t in enumerate(itos)}
def numerise(tokens):
    return [stoi.get(t, stoi[UNK]) for t in tokens]

# 3) Load labels
train_lbl = json.loads(TRAIN_J.read_text())
dev_lbl   = json.loads(DEV_J.read_text())

# 4) Triplet Dataset & DataLoader
class TripletDataset(Dataset):
    def __init__(self, labeled, evid_dict):
        items, evid_ids = [], list(evid_dict.keys())
        for cid, obj in labeled.items():
            pos = [e for e in obj["evidences"] if e in evid_dict]
            for p in pos:
                n = random.choice(evid_ids)
                while n==p: n = random.choice(evid_ids)
                items.append((cid, p, n))
        self.items = items
        self.evid  = evid_dict
    def __len__(self):
        return len(self.items)
    def __getitem__(self, idx):
        cid,p,n = self.items[idx]
        return (
          torch.tensor(numerise(claim_proc_all[cid]), dtype=torch.long),
          torch.tensor(numerise(self.evid[p]), dtype=torch.long),
          torch.tensor(numerise(self.evid[n]), dtype=torch.long),
        )

def collate_fn(batch):
    def pad(seqs):
        m = max(len(s) for s in seqs)
        return torch.tensor([s.tolist()+[0]*(m-len(s)) for s in seqs])
    c,p,n = zip(*batch)
    return pad(c), pad(p), pad(n)

train_ds = TripletDataset(train_lbl, evidence_proc)
train_dl = DataLoader(train_ds, batch_size=BATCH, shuffle=True,
                      collate_fn=collate_fn)

# 5) BiLSTM Sentence Encoder
class BiLSTMSentenceEncoder(nn.Module):
    def __init__(self, vocab_sz, emb_dim=EMB_DIM, hid_dim=HID_DIM):
        super().__init__()
        self.emb  = nn.Embedding(vocab_sz, emb_dim, padding_idx=0)
        self.lstm = nn.LSTM(emb_dim, hid_dim, batch_first=True,
                            bidirectional=True)
    def forward(self, x):
        mask = (x!=0).float().unsqueeze(-1)
        out, _ = self.lstm(self.emb(x))
        # mean‐pool over the length dim
        out = (out * mask).sum(1) / mask.sum(1)
        return nn.functional.normalize(out, p=2, dim=-1)

In [3]:
# 6) Train the retriever
model   = BiLSTMSentenceEncoder(len(itos)).to(DEVICE)
optim   = torch.optim.Adam(model.parameters(), lr=LR)
loss_fn = nn.MarginRankingLoss(margin=MARGIN)

for epoch in range(EPOCHS):
    model.train()
    total = 0
    for c, p, n in tqdm(train_dl, desc=f"Epoch {epoch+1}"):
        c,p,n = [t.to(DEVICE) for t in (c,p,n)]
        vc, vp, vn = model(c), model(p), model(n)
        pos_sim = (vc * vp).sum(1)
        neg_sim = (vc * vn).sum(1)
        loss    = loss_fn(pos_sim, neg_sim,
                          torch.ones_like(pos_sim, device=DEVICE))
        optim.zero_grad(); loss.backward(); optim.step()
        total += loss.item()
    print(f"  Epoch {epoch+1} avg loss = {total/len(train_dl):.4f}")

Epoch 1: 100%|██████████| 33/33 [01:29<00:00,  2.71s/it]


  Epoch 1 avg loss = 0.1006


Epoch 2: 100%|██████████| 33/33 [01:21<00:00,  2.46s/it]


  Epoch 2 avg loss = 0.0620


Epoch 3: 100%|██████████| 33/33 [01:26<00:00,  2.62s/it]


  Epoch 3 avg loss = 0.0442


Epoch 4: 100%|██████████| 33/33 [01:42<00:00,  3.11s/it]


  Epoch 4 avg loss = 0.0326


Epoch 5: 100%|██████████| 33/33 [01:24<00:00,  2.55s/it]

  Epoch 5 avg loss = 0.0243





In [4]:
# 7) Encode all evidence
# single‐worker to avoid multiprocessing issues in notebook
evidence_vecs = {}
loader = DataLoader(
    [(eid, torch.tensor(numerise(evidence_proc[eid]),
                        dtype=torch.long))
     for eid in evidence_proc],
    batch_size=512, shuffle=False,
    collate_fn=lambda batch: (
        [e[0] for e in batch],
        pad_sequence([e[1] for e in batch],
                     batch_first=True, padding_value=0)
    ),
    num_workers=0,
)
model.eval()
with torch.no_grad():
    for eids, seqs in tqdm(loader, desc="Encoding evidence"):
        vecs = model(seqs.to(DEVICE)).cpu()
        for eid, v in zip(eids, vecs):
            evidence_vecs[eid] = v

Encoding evidence: 100%|██████████| 2360/2360 [52:21<00:00,  1.33s/it] 


In [5]:
# 8) Ranking & Evaluation on dev
def rank_evidence(stems, top_k):
    idxs = numerise(stems)
    x = torch.tensor([idxs], dtype=torch.long, device=DEVICE)
    with torch.no_grad():
        vc = model(x).cpu().squeeze(0)
    sims = {eid: float(torch.dot(vc, v_e))
            for eid, v_e in evidence_vecs.items()}
    return sorted(sims, key=sims.get, reverse=True)[:top_k]

def evaluate(top_k):
    recalls, precisions, f1s = [], [], []
    for cid, obj in tqdm(dev_lbl.items(), desc="Evaluating"):
        gold      = set(obj["evidences"])
        retrieved = rank_evidence(claim_proc_all[cid], top_k)
        hits      = len(gold & set(retrieved))
        r = hits/len(gold) if gold else 0.0
        p = hits/top_k
        f = (2*r*p/(r+p)) if (r+p)>0 else 0.0
        recalls.append(r); precisions.append(p); f1s.append(f)

    print(f"\nRecall@{top_k}:    {np.mean(recalls):.2%}")
    print(f"Precision@{top_k}: {np.mean(precisions):.2%}")
    print(f"F1@{top_k}:        {np.mean(f1s):.2%}")

for k in [3, 4, 5]:
    evaluate(k)

Evaluating: 100%|██████████| 154/154 [05:03<00:00,  1.97s/it]



Recall@3:    5.94%
Precision@3: 4.76%
F1@3:        4.89%


Evaluating: 100%|██████████| 154/154 [04:53<00:00,  1.90s/it]



Recall@4:    6.20%
Precision@4: 3.90%
F1@4:        4.44%


Evaluating: 100%|██████████| 154/154 [04:54<00:00,  1.91s/it]


Recall@5:    7.11%
Precision@5: 3.77%
F1@5:        4.61%



