In [7]:
from datasets import load_dataset
import numpy as np, faiss, pickle, json, re

In [None]:
DOC_PKL   = r"C:\\Users\\gabri\\Downloads\\scifact_evidence_embeddings.pkl" 
CLAIM_PKL = r"C:\\Users\\gabri\\Downloads\\scifact_claim_embeddings.pkl"     
with open(DOC_PKL, "rb") as f:
    docs_raw = pickle.load(f)    
with open(CLAIM_PKL, "rb") as f:
    claims_raw = pickle.load(f)   

doc_ids, doc_embs = zip(*[(k[0], np.asarray(v, dtype="float32")) for k,v in docs_raw.items()])
claim_ids, claim_embs = zip(*[(k[0], np.asarray(v, dtype="float32")) for k,v in claims_raw.items()])

doc_ids   = np.array(doc_ids, dtype=object)
doc_embs  = np.vstack(doc_embs).astype("float32")
claim_ids = np.array(claim_ids, dtype=object)
claim_embs= np.vstack(claim_embs).astype("float32")

doc_embs.shape, claim_embs.shape

((5183, 1536), (809, 1536))

In [None]:
claims_ds = load_dataset("allenai/scifact", data_dir="claims", revision="refs/convert/parquet")
corpus_ds = load_dataset("allenai/scifact", data_dir="corpus",  revision="refs/convert/parquet")

claim_splits = [s for s in ["train","validation","dev","test"] if s in claims_ds]

def parse_evidence_doc_ids(example):
    gold = set()
    if "evidence" in example and example["evidence"] is not None:
        ev = example["evidence"]
        if isinstance(ev, list):
            for e in ev:
                if isinstance(e, dict) and "doc_id" in e:
                    gold.add(int(e["doc_id"]))
                elif isinstance(e, (int, np.integer)):
                    gold.add(int(e))
    if not gold and "evidence_doc_id" in example and example["evidence_doc_id"] is not None:
        s = example["evidence_doc_id"]
        if isinstance(s, str) and s.strip():
            try:
                parsed = json.loads(s)   
                if isinstance(parsed, list):
                    gold.update(int(x) for x in parsed)
                else:
                    gold.add(int(parsed))
            except Exception:
                for tok in re.findall(r"-?\d+", s):
                    gold.add(int(tok))
        elif isinstance(s, (int, np.integer)):
            gold.add(int(s))
    if not gold and "cited_doc_ids" in example and example["cited_doc_ids"] is not None:
        if isinstance(example["cited_doc_ids"], list):
            gold.update(int(x) for x in example["cited_doc_ids"] if isinstance(x, (int, np.integer)))
    return gold

gold_map = {}
for sp in claim_splits:
    for ex in claims_ds[sp]:
        cid = int(ex["id"]) if "id" in ex else int(ex.get("claim_id"))
        gold_map[cid] = parse_evidence_doc_ids(ex)


Generating train split: 1261 examples [00:00, 553591.93 examples/s]
Generating validation split: 450 examples [00:00, 179022.74 examples/s]
Generating test split: 300 examples [00:00, 139608.48 examples/s]
Generating train split: 5183 examples [00:00, 363669.60 examples/s]


In [10]:
gold_sets = []
keep_idx = []
for i, cid in enumerate(claim_ids):
    g = gold_map.get(int(cid), set())
    if len(g) > 0:
        keep_idx.append(i)
        gold_sets.append(g)

claim_ids_f   = claim_ids[keep_idx]
claim_embs_f  = claim_embs[keep_idx]
len(claim_ids), len(claim_ids_f)

(809, 809)

In [None]:
def l2norm(x): return x / (np.linalg.norm(x, axis=1, keepdims=True) + 1e-12)

doc_embs_n   = l2norm(doc_embs)
claim_embs_n = l2norm(claim_embs_f)

index = faiss.IndexFlatIP(doc_embs_n.shape[1])
index.add(doc_embs_n)

D, I = index.search(claim_embs_n, 50)
faiss2doc = np.array(doc_ids, dtype=object)
retrieved = [[faiss2doc[j] for j in row] for row in I]

def mrr_at_k(recs, golds, k):
    vals = []
    for r, g in zip(recs, golds):
        rank = next((i+1 for i,doc in enumerate(r[:k]) if doc in g), None)
        vals.append(0.0 if rank is None else 1.0/rank)
    return float(np.mean(vals))

def map_at_k(recs, golds, k):
    vals = []
    for r, g in zip(recs, golds):
        hits, ap = 0, 0.0
        for i,doc in enumerate(r[:k], start=1):
            if doc in g:
                hits += 1
                ap += hits / i
        denom = min(len(g), k) if len(g) > 0 else 1
        vals.append(0.0 if denom==0 else ap/denom)
    return float(np.mean(vals))

Ks = [1,10,50]
metrics = {f"MRR@{k}": mrr_at_k(retrieved, gold_sets, k) for k in Ks}
metrics.update({f"MAP@{k}": map_at_k(retrieved, gold_sets, k) for k in Ks})
print(metrics)



{'MRR@1': 0.5822002472187886, 'MRR@10': 0.6740616477328467, 'MRR@50': 0.6785670111103024, 'MAP@1': 0.5822002472187886, 'MAP@10': 0.6724213706908392, 'MAP@50': 0.6771543808644933}

Table row:
OpenAI Embeddings | 0.5822 | 0.6741 | 0.6786 | 0.5822 | 0.6724 | 0.6772


In [12]:
print("Approach - Metric |", " | ".join([f"MRR@{k}" for k in Ks] + [f"MAP@{k}" for k in Ks]))
print("OpenAI Embeddings |",
      " | ".join([f"{metrics[f'MRR@{k}']:.4f}" for k in Ks] +
                 [f"{metrics[f'MAP@{k}']:.4f}" for k in Ks]))


Approach - Metric | MRR@1 | MRR@10 | MRR@50 | MAP@1 | MAP@10 | MAP@50
OpenAI Embeddings | 0.5822 | 0.6741 | 0.6786 | 0.5822 | 0.6724 | 0.6772
