In [None]:
import re, math, random
import numpy as np
from tqdm import tqdm
from rank_bm25 import BM25Okapi

# Seed Python's global `random` module: query sampling below uses
# random.sample, so for a given cell execution order the chosen queries
# are reproducible.
random.seed(42)

def tokenize(text: str):
    """Lower-case ``text`` and return its ASCII alphanumeric runs as tokens.

    Every character outside ``[A-Za-z0-9]`` (whitespace, punctuation,
    underscores, non-ASCII letters) acts as a separator. Non-string inputs
    are converted with ``str()`` first, so ``None`` yields ``["none"]``.
    """
    lowered = str(text).lower()
    return re.findall(r"[A-Za-z0-9]+", lowered)

def dcg(rels):
    """Discounted cumulative gain of a ranked list of graded relevances.

    Uses the exponential gain ``2**rel - 1`` and a ``log2(rank + 1)``
    discount with ranks starting at 1. Returns 0 for an empty list.
    """
    return sum((2**rel - 1) / math.log2(i + 2) for i, rel in enumerate(rels))

def ndcg_at_k(ranked_ids, rel_set, k=10):
    """Binary-relevance nDCG@k of ``ranked_ids`` against ``rel_set``.

    The ideal DCG is that of a ranking with ``min(k, len(rel_set))`` relevant
    documents at the top. Normalising by the sorted *retrieved* gains instead
    ignores relevant documents that were never retrieved, e.g. a list whose
    only hit is at rank 1 would score a perfect 1.0.

    Returns 0.0 when ``rel_set`` is empty or ``k <= 0``.
    """
    gains = [1 if doc_id in rel_set else 0 for doc_id in ranked_ids[:k]]
    ideal = [1] * min(k, len(rel_set))
    denom = dcg(ideal)
    return 0.0 if denom == 0 else dcg(gains) / denom

def recall_at_k(ranked_ids, rel_set, k=100):
    """Fraction of ``rel_set`` that appears among the first ``k`` ranked ids.

    An empty ``rel_set`` is treated as having size 1, so the result is 0.0
    instead of a division error.
    """
    found = len([doc_id for doc_id in ranked_ids[:k] if doc_id in rel_set])
    denominator = len(rel_set) or 1
    return found / denominator

def build_bm25(corpus_texts):
    """Tokenize every document and build a BM25Okapi index over them.

    Returns:
        Tuple ``(index, tokenized_docs)`` where ``tokenized_docs[i]`` is the
        token list of ``corpus_texts[i]``.
    """
    tokenized = []
    for doc in tqdm(corpus_texts, desc="Tokenizing corpus"):
        tokenized.append(tokenize(doc))
    index = BM25Okapi(tokenized)
    return index, tokenized

def bm25_search(bm25, query, topk=1000):
    """Rank corpus positions by BM25 score for ``query``, best first.

    Returns a plain list of at most ``topk`` integer document positions.
    """
    query_tokens = tokenize(query)
    order = np.argsort(-bm25.get_scores(query_tokens))
    return order[:topk].tolist()

def eval_intent_retrieval(df, text_col, intent_col, n_queries=200, min_rel=20, topk=1000):
    """Evaluate BM25 on a query-by-example intent-retrieval task.

    Each sampled document is used as a query against a BM25 index of the
    whole (NaN-filtered) corpus. Every *other* document with the same intent
    label is relevant. The query document itself is removed from the ranking,
    so its self-match cannot take up a rank slot as a non-relevant hit.

    Args:
        df: DataFrame holding the corpus.
        text_col: column with the document text (cast to ``str``).
        intent_col: column with the intent label (cast to ``str``).
        n_queries: maximum number of queries, sampled with the global
            ``random`` module.
        min_rel: minimum label-group size (query included) for its members to
            be eligible as queries.
        topk: number of ranked documents kept per query after the query
            document is removed.

    Returns:
        ``(mean nDCG@10, mean Recall@100, number of documents in the corpus)``.

    Raises:
        ValueError: if no label group has at least ``min_rel`` members.
    """
    from collections import defaultdict

    df = df.dropna(subset=[text_col, intent_col]).reset_index(drop=True)
    corpus = df[text_col].astype(str).tolist()
    intents = df[intent_col].astype(str).tolist()

    intent_to_idx = defaultdict(list)
    for i, lab in enumerate(intents):
        intent_to_idx[lab].append(i)
    # Build each label's membership set once instead of once per query.
    label_sets = {lab: set(idxs) for lab, idxs in intent_to_idx.items()}

    eligible = [i for i, lab in enumerate(intents) if len(intent_to_idx[lab]) >= min_rel]
    if not eligible:
        raise ValueError("No eligible queries: increase data or lower min_rel")

    chosen = random.sample(eligible, min(n_queries, len(eligible)))

    built = build_bm25(corpus)
    # build_bm25 is redefined later in this notebook to return the bare index
    # instead of (index, tokens); accept both so this cell survives re-runs.
    bm25 = built[0] if isinstance(built, tuple) else built

    ndcgs, recalls = [], []
    for qi in tqdm(chosen, desc="Evaluating (intent retrieval)"):
        rel_set = label_sets[intents[qi]] - {qi}

        # Over-fetch by one and drop the query document: it is part of the
        # corpus and typically scores at or near the top for its own text.
        ranked = [d for d in bm25_search(bm25, corpus[qi], topk=topk + 1) if d != qi][:topk]
        ndcgs.append(ndcg_at_k(ranked, rel_set, k=10))
        recalls.append(recall_at_k(ranked, rel_set, k=100))

    return float(np.mean(ndcgs)), float(np.mean(recalls)), len(df)


In [None]:
import pandas as pd

# NOTE(review): PATH_MDCITE is defined outside this view — confirm it is set
# in an earlier configuration cell.
md = pd.read_parquet(path=PATH_MDCITE)
print("MDCite columns:", md.columns.tolist())
print("Rows:", md.shape[0])

def pick_col(cols, candidates):
    """Return the first name in ``candidates`` that is present in ``cols``.

    Candidates are tried in order; ``None`` is returned when none match.
    """
    return next((name for name in candidates if name in cols), None)

cols = md.columns.tolist()
TEXT_COL = pick_col(cols, ["context", "text", "sentence", "citing_context", "citing_contexts"])
INTENT_COL = pick_col(cols, ["intent", "intents", "label"])
FIELD_COL = pick_col(cols, ["field", "category", "wos_field"])

print("Mapped:", TEXT_COL, INTENT_COL, FIELD_COL)
# Fail fast with a readable message instead of a KeyError deep inside the
# evaluation functions.
if TEXT_COL is None or INTENT_COL is None:
    raise ValueError(f"Could not map text/intent columns from MDCite columns: {cols}")

SAMPLE_N = min(200_000, len(md))
if FIELD_COL:
    # Per-field quota: an even share of SAMPLE_N, but at least 1000 rows
    # (capped by the field's size below). Computed once, not once per group.
    n_fields = max(1, md[FIELD_COL].nunique())
    per_field_n = max(1000, SAMPLE_N // n_fields)
    # NOTE(review): groupby drops rows whose FIELD_COL is NaN by default, and
    # include_groups=False excludes FIELD_COL from the sampled rows — confirm
    # neither matters downstream.
    stratified_sample_per_group = (
        md.groupby(FIELD_COL, group_keys=False)
          .apply(lambda g: g.sample(n=min(len(g), per_field_n), random_state=42),
                 include_groups=False)
    )
    md_sample = (
        stratified_sample_per_group
        .sample(n=min(SAMPLE_N, len(stratified_sample_per_group)), random_state=42)
        .reset_index(drop=True)
    )
else:
    md_sample = md.sample(n=SAMPLE_N, random_state=42).reset_index(drop=True)

print("MDCite sample rows:", len(md_sample))


In [None]:
import tarfile, json

def load_scicite_from_tar(tar_path):
    """Load ``(text, intent)`` pairs from a gzipped SciCite tarball.

    Every regular member whose name ends in ``.jsonl`` or ``.json`` is read.
    ``.jsonl`` files are parsed line by line. ``.json`` files are parsed whole
    and used only when the top-level value is a list.

    For each record:
    - Text is the first truthy value among the ``string``/``text``/``sentence`` keys.
    - The label is the first non-``None`` value among ``label``/``intent``.
    - Records lacking either are skipped.

    Args:
        tar_path: path to a ``.tar.gz`` archive.

    Returns:
        DataFrame with columns ``text`` and ``intent``. When nothing matched it
        is empty but still has those two columns.
    """
    def _label_of(obj):
        # Explicit None check: `a or b` would also discard legitimate falsy
        # labels such as an integer-coded class 0.
        for key in ("label", "intent"):
            value = obj.get(key)
            if value is not None:
                return value
        return None

    def _to_row(obj):
        # Returns a row dict, or None when text or label is missing.
        text = obj.get("string") or obj.get("text") or obj.get("sentence")
        lab = _label_of(obj)
        if text is None or lab is None:
            return None
        return {"text": text, "intent": lab}

    rows = []
    with tarfile.open(tar_path, "r:gz") as tar:
        members = [m for m in tar.getmembers() if m.isfile() and m.name.endswith((".jsonl", ".json"))]
        print("SciCite files:", [m.name for m in members])

        for m in members:
            f = tar.extractfile(m)
            if f is None:
                continue
            with f:
                # NOTE(review): undecodable bytes are silently dropped.
                content = f.read().decode("utf-8", errors="ignore")

            if m.name.endswith(".jsonl"):
                records = (json.loads(line) for line in content.splitlines() if line.strip())
            else:
                obj = json.loads(content)
                records = obj if isinstance(obj, list) else []

            for rec in records:
                row = _to_row(rec)
                if row is not None:
                    rows.append(row)

    return pd.DataFrame(rows, columns=["text", "intent"])

# NOTE(review): PATH_SCICITE is defined outside this view — confirm it is set
# in an earlier configuration cell.
scidf = load_scicite_from_tar(tar_path=PATH_SCICITE)
print("SciCite rows:", scidf.shape[0])
print(scidf.head(n=5))


In [None]:
results = []

# Same protocol on both corpora: 200 sampled queries, label groups of >= 30.
for dataset_name, frame, text_column, intent_column in (
    ("MDCite (single)", md_sample, TEXT_COL, INTENT_COL),
    ("SciCite", scidf, "text", "intent"),
):
    ndcg, rec, n = eval_intent_retrieval(
        frame, text_column, intent_column, n_queries=200, min_rel=30
    )
    results.append((dataset_name, n, ndcg, rec))

results_df = pd.DataFrame(
    results,
    columns=["Dataset", "NumDocs", "BM25 nDCG@10", "BM25 Recall@100"],
)

results_df


In [None]:
def eval_intent_retrieval_extended(
    df, text_col, intent_col,
    n_queries=200, min_rel=30, topk=2000
):
    """Query-by-example BM25 intent retrieval reporting nDCG@10, Recall@100 and Recall@1000.

    Each sampled document queries a BM25 index of the whole (NaN-filtered)
    corpus. All other documents with the same intent label are relevant. The
    query document is removed from its own ranking.

    Args:
        df: DataFrame holding the corpus.
        text_col: column with the document text (cast to ``str``).
        intent_col: column with the intent label (cast to ``str``).
        n_queries: maximum number of queries, sampled with the global
            ``random`` module.
        min_rel: minimum label-group size (query included) for eligibility.
        topk: ranked documents kept per query. Recall@1000 is computed on a
            truncated list if this is below 1000.

    Returns:
        ``(mean nDCG@10, mean Recall@100, mean Recall@1000, number of docs)``.

    Raises:
        ValueError: if no label group has at least ``min_rel`` members
            (previously this silently returned NaN means).
    """
    from collections import defaultdict

    df = df.dropna(subset=[text_col, intent_col]).reset_index(drop=True)
    corpus = df[text_col].astype(str).tolist()
    intents = df[intent_col].astype(str).tolist()

    intent_to_idx = defaultdict(list)
    for i, lab in enumerate(intents):
        intent_to_idx[lab].append(i)
    # Build each label's membership set once instead of once per query.
    label_sets = {lab: set(idxs) for lab, idxs in intent_to_idx.items()}

    eligible = [i for i, lab in enumerate(intents)
                if len(intent_to_idx[lab]) >= min_rel]
    if not eligible:
        raise ValueError("No eligible queries: increase data or lower min_rel")
    chosen = random.sample(eligible, min(n_queries, len(eligible)))

    built = build_bm25(corpus)
    # build_bm25 is defined twice in this notebook with different return
    # types (tuple vs. bare index); accept either so re-runs keep working.
    bm25 = built[0] if isinstance(built, tuple) else built

    ndcgs, r100, r1000 = [], [], []
    for qi in tqdm(chosen, desc="Eval intent retrieval"):
        rel_set = label_sets[intents[qi]] - {qi}

        # Over-fetch by one and drop the query document so its self-match does
        # not occupy a rank slot as a non-relevant hit.
        ranked = [d for d in bm25_search(bm25, corpus[qi], topk=topk + 1) if d != qi][:topk]
        ndcgs.append(ndcg_at_k(ranked, rel_set, k=10))
        r100.append(recall_at_k(ranked, rel_set, k=100))
        r1000.append(recall_at_k(ranked, rel_set, k=1000))

    return (
        float(np.mean(ndcgs)),
        float(np.mean(r100)),
        float(np.mean(r1000)),
        len(df)
    )


In [None]:
ndcg, r100, r1000, n = eval_intent_retrieval_extended(md_sample, TEXT_COL, INTENT_COL)

for metric_label, metric_value in (
    ("MDCite nDCG@10:", ndcg),
    ("MDCite Recall@100:", r100),
    ("MDCite Recall@1000:", r1000),
):
    print(metric_label, metric_value)


In [None]:
import re, math, random
import numpy as np
import pandas as pd
from tqdm import tqdm
from rank_bm25 import BM25Okapi
from collections import defaultdict

# Re-seed Python's global `random` module: the query sampling in this cell
# uses random.sample, so for a given execution order the chosen queries are
# reproducible.
random.seed(42)

def tokenize(text: str):
    """Split ``text`` into lower-cased runs of ASCII letters and digits.

    Non-string input is passed through ``str()`` first. Characters outside
    ``[A-Za-z0-9]`` (including ``_`` and accented letters) are separators.
    """
    normalized = str(text).lower()
    tokens = re.findall(r"[A-Za-z0-9]+", normalized)
    return tokens

def dcg(rels):
    """Discounted cumulative gain with gain ``2**rel - 1`` and discount ``log2(rank + 1)``.

    Ranks start at 1; an empty list gives 0.
    """
    return sum((2**rel - 1) / math.log2(i + 2) for i, rel in enumerate(rels))

def ndcg_at_k(ranked_ids, rel_set, k=10):
    """Binary-relevance nDCG@k of ``ranked_ids`` against ``rel_set``.

    The ideal ranking puts ``min(k, len(rel_set))`` relevant documents at the
    top. Using the sorted *retrieved* gains as the ideal would ignore every
    relevant document that was never retrieved, so a list with a single hit
    at rank 1 would score a perfect 1.0.

    Returns 0.0 when ``rel_set`` is empty or ``k <= 0``.
    """
    gains = [1 if doc_id in rel_set else 0 for doc_id in ranked_ids[:k]]
    ideal = [1] * min(k, len(rel_set))
    denom = dcg(ideal)
    return 0.0 if denom == 0 else dcg(gains) / denom

def recall_at_k(ranked_ids, rel_set, k=100):
    """Share of the relevant ids found in the top ``k`` of ``ranked_ids``.

    The denominator is floored at 1, so an empty ``rel_set`` yields 0.0.
    """
    hits = sum(doc_id in rel_set for doc_id in ranked_ids[:k])
    return hits / max(len(rel_set), 1)

def build_bm25(corpus_texts):
    """Build a BM25Okapi index over ``corpus_texts`` and return only the index.

    NOTE: an earlier cell defines a ``build_bm25`` that returns
    ``(index, tokenized_docs)``. This redefinition shadows it and returns the
    bare index, which is what the callers in this cell expect.
    """
    token_lists = [tokenize(doc) for doc in tqdm(corpus_texts, desc="Tokenizing corpus")]
    return BM25Okapi(token_lists)

def bm25_search(bm25, query, topk=2000):
    """Return up to ``topk`` corpus positions ordered by descending BM25 score."""
    doc_scores = bm25.get_scores(tokenize(query))
    best_first = np.argsort(-doc_scores)
    top_positions = best_first[:topk]
    return top_positions.tolist()
def sample_queries_for_intent_task(df, text_col, intent_col, n_queries=200, min_rel=5):
    """Drop incomplete rows, group row positions by label, and sample query rows.

    A row is eligible when its label group has at least ``min_rel + 1``
    members, i.e. ``min_rel`` relevant documents besides the query itself.

    Returns:
        ``(clean_df, label -> list of row positions, sampled query positions)``.
        ``clean_df`` has a fresh 0..n-1 index.

    Raises:
        ValueError: if no row is eligible.
    """
    clean = df.dropna(subset=[text_col, intent_col]).reset_index(drop=True)
    labels = clean[intent_col].astype(str).tolist()

    members_by_label = defaultdict(list)
    for row_pos, label in enumerate(labels):
        members_by_label[label].append(row_pos)

    needed = min_rel + 1
    eligible = [row_pos for row_pos, label in enumerate(labels)
                if len(members_by_label[label]) >= needed]
    if len(eligible) == 0:
        raise ValueError("No eligible queries. Lower min_rel or check labels.")

    sample_size = min(n_queries, len(eligible))
    chosen = random.sample(eligible, sample_size)
    return clean, members_by_label, chosen
def recall_at_k_curve_bm25(
    df, text_col, intent_col,
    k_values=(10, 50, 100, 200, 500, 1000),
    n_queries=200,
    min_rel=5,
    topk=2000
):
    """Mean BM25 Recall@k at several cutoffs for query-by-example intent retrieval.

    Queries are sampled by ``sample_queries_for_intent_task``. Relevant
    documents share the query's label. The query document is removed from its
    own ranking. The ranking is kept to a depth of ``max(topk, max(k_values))``,
    so no requested cutoff is silently truncated.

    Returns:
        Dict mapping ``"Recall@{k}"`` to the mean recall, plus ``"NumQueries"``
        and ``"NumDocs"``.
    """
    df, intent_to_idx, chosen = sample_queries_for_intent_task(
        df, text_col, intent_col, n_queries=n_queries, min_rel=min_rel
    )

    corpus = df[text_col].astype(str).tolist()
    intents = df[intent_col].astype(str).tolist()

    built = build_bm25(corpus)
    # build_bm25 is defined twice in this notebook with different return
    # types (tuple vs. bare index); accept either so re-runs keep working.
    bm25 = built[0] if isinstance(built, tuple) else built

    k_values = list(k_values)
    depth = max([topk, *k_values])
    recalls = {k: [] for k in k_values}

    for qi in tqdm(chosen, desc="Evaluating Recall@k curve"):
        rel_set = set(intent_to_idx[intents[qi]]) - {qi}

        # Over-fetch by one and drop the query document so its self-match does
        # not occupy a rank slot.
        ranked = [d for d in bm25_search(bm25, corpus[qi], topk=depth + 1) if d != qi][:depth]

        for k in k_values:
            recalls[k].append(recall_at_k(ranked, rel_set, k=k))

    curve = {f"Recall@{k}": float(np.mean(vals)) for k, vals in recalls.items()}
    curve["NumQueries"] = len(chosen)
    curve["NumDocs"] = len(df)
    return curve
# Recall@k at several cutoffs on the MDCite sample (200 queries, groups >= 31).
curve = recall_at_k_curve_bm25(
    md_sample, TEXT_COL, INTENT_COL,
    k_values=(10, 50, 100, 200, 500, 1000),
    n_queries=200,
    min_rel=30,
    topk=2000,
)

# As exported, this is not the last expression of its cell, so a bare
# `pd.DataFrame([curve])` would be evaluated and discarded. Display it
# explicitly (`display` is provided by the IPython kernel).
display(pd.DataFrame([curve]))
def intent_wise_metrics_on_global_corpus(
    df, text_col, intent_col,
    intents_list=None,
    max_queries_per_intent=100,
    min_rel=5,
    topk=2000
):
    """Per-intent BM25 metrics, with queries from each intent run against the full corpus.

    For every label, up to ``max_queries_per_intent`` of its documents are used
    as queries against a single BM25 index over all documents. Relevant
    documents are the other members of that label, and the query document is
    removed from its own ranking. Labels with fewer than ``min_rel + 1``
    members get ``None`` metrics and ``NumQueries = 0``.

    Args:
        intents_list: labels to report. They are compared against the
            ``str``-cast labels. Defaults to all labels, sorted.

    Returns:
        DataFrame with columns ``Intent``, ``NumDocs``, ``nDCG@10``,
        ``Recall@100``, ``Recall@1000``, ``NumQueries``, sorted by ``NumDocs``
        descending. It is empty, but keeps those columns, when there are no
        labels.
    """
    columns = ["Intent", "NumDocs", "nDCG@10", "Recall@100", "Recall@1000", "NumQueries"]

    df = df.dropna(subset=[text_col, intent_col]).reset_index(drop=True)
    corpus = df[text_col].astype(str).tolist()
    intents = df[intent_col].astype(str).tolist()

    intent_to_idx = defaultdict(list)
    for i, lab in enumerate(intents):
        intent_to_idx[lab].append(i)

    if intents_list is None:
        intents_list = sorted(intent_to_idx.keys())

    built = build_bm25(corpus)
    # build_bm25 is defined twice in this notebook with different return
    # types (tuple vs. bare index); accept either so re-runs keep working.
    bm25 = built[0] if isinstance(built, tuple) else built

    rows = []
    for lab in intents_list:
        idxs = intent_to_idx.get(lab, [])
        if len(idxs) < (min_rel + 1):
            rows.append({
                "Intent": lab,
                "NumDocs": len(idxs),
                "nDCG@10": None,
                "Recall@100": None,
                "Recall@1000": None,
                "NumQueries": 0
            })
            continue

        # Built once per label instead of once per query.
        idx_set = set(idxs)
        chosen = random.sample(idxs, min(max_queries_per_intent, len(idxs)))

        ndcgs, r100, r1000 = [], [], []
        for qi in tqdm(chosen, desc=f"Intent={lab}", leave=False):
            rel_set = idx_set - {qi}

            # Over-fetch by one and drop the query document so its self-match
            # does not occupy a rank slot.
            ranked = [d for d in bm25_search(bm25, corpus[qi], topk=topk + 1) if d != qi][:topk]

            ndcgs.append(ndcg_at_k(ranked, rel_set, k=10))
            r100.append(recall_at_k(ranked, rel_set, k=100))
            r1000.append(recall_at_k(ranked, rel_set, k=1000))

        rows.append({
            "Intent": lab,
            "NumDocs": len(idxs),
            "nDCG@10": float(np.mean(ndcgs)),
            "Recall@100": float(np.mean(r100)),
            "Recall@1000": float(np.mean(r1000)),
            "NumQueries": len(chosen)
        })

    # Explicit columns keep sort_values working when `rows` is empty.
    return (
        pd.DataFrame(rows, columns=columns)
        .sort_values("NumDocs", ascending=False)
        .reset_index(drop=True)
    )
# Per-intent metrics over the whole MDCite sample (up to 100 queries per intent).
intent_table = intent_wise_metrics_on_global_corpus(
    md_sample,
    TEXT_COL,
    INTENT_COL,
    intents_list=None,
    max_queries_per_intent=100,
    min_rel=5,
    topk=2000,
)

# As exported, a bare `intent_table` here is not the last expression of its
# cell and would never be shown; display it explicitly (`display` is provided
# by the IPython kernel).
display(intent_table)

import matplotlib.pyplot as plt

# Plot the Recall@k curve already computed above. Re-running
# recall_at_k_curve_bm25 with identical arguments would rebuild the BM25
# index and, because the global `random` state has advanced, sample a
# different query set than the table displayed earlier.
ks = sorted(int(key.split("@", 1)[1]) for key in curve if key.startswith("Recall@"))
vals = [curve[f"Recall@{k}"] for k in ks]

fig, ax = plt.subplots(figsize=(6.5, 4))
ax.plot(ks, vals, marker="o")
ax.set_xscale("log")
ax.set_xlabel("k (log scale)")
ax.set_ylabel("Recall@k")
ax.set_title("BM25 Recall@k Curve (Intent Retrieval)")
ax.grid(True, linestyle="--", alpha=0.5)
fig.tight_layout()
plt.show()
