
# Function-Aware Retrieval: Build E5 Index from JSONL (unarXive subset)

This notebook:
1. Loads **4 JSONL** files (one JSON object per line; each line = one paper).
2. Extracts `paper_id`, `title`, and `abstract` from each object.
3. Builds **E5 embeddings** for `"passage: <title> — <abstract>"`.
4. Saves a **FAISS** index (`index.faiss`) and aligned `meta.parquet`.
5. Provides a **post-filter retrieval** function that:
   - Retrieves **top-20** with E5,
   - Runs your **function classifier** on those snippets,
   - Filters/reranks by function probability.
   
> Plug in your own function classifier where indicated (we include a stub).


In [None]:

# If needed, install packages (uncomment if your environment is missing any of these)
# %pip install faiss-cpu pandas pyarrow sentence-transformers tqdm rank_bm25


In [1]:

import numpy as np
import pandas as pd
import faiss
from sentence_transformers import SentenceTransformer
from pathlib import Path
import time


In [2]:

# ---- EDIT THESE PATHS IF NEEDED ----
jsonl_files = [
    "/data/horse/ws/anpa439f-Function_Retrieval_Citation/Research_Project/Corpus/processed_unarxive_extended_data/unarXive_01981221/01/arXiv_src_0101_001.jsonl",
    "/data/horse/ws/anpa439f-Function_Retrieval_Citation/Research_Project/Corpus/processed_unarxive_extended_data/unarXive_01981221/01/arXiv_src_0102_001.jsonl",
    "/data/horse/ws/anpa439f-Function_Retrieval_Citation/Research_Project/Corpus/processed_unarxive_extended_data/unarXive_01981221/01/arXiv_src_0103_001.jsonl",
    "/data/horse/ws/anpa439f-Function_Retrieval_Citation/Research_Project/Corpus/processed_unarxive_extended_data/unarXive_01981221/01/arXiv_src_0104_001.jsonl",
]

# Directory where we'll save the index + metadata
out_dir = Path("e5_index_subset_1")
out_dir.mkdir(exist_ok=True, parents=True)

# Model to use: 'intfloat/e5-base-v2' is a good speed/quality trade-off
E5_MODEL_NAME = "intfloat/e5-small-v2"

# Limit (optional): set to None to index everything
MAX_PAPERS = None  # e.g., 10000


In [3]:
# 🔒 Force CPU + tame threads + use local cache to avoid any network / CUDA shenanigans
import os, torch, time
os.environ["CUDA_VISIBLE_DEVICES"] = ""          # force no-GPU path
os.environ["TOKENIZERS_PARALLELISM"] = "false"   # quieter + safer in notebooks
os.environ["HF_HOME"] = "./hf_cache"             # local cache (no network)

from sentence_transformers import SentenceTransformer

E5_MODEL_NAME = "intfloat/e5-small-v2"  # swap to e5-base-v2 later

t0 = time.time()
model = SentenceTransformer(E5_MODEL_NAME, device="cpu", cache_folder="./hf_cache")
print("Loaded model in", round(time.time()-t0, 2), "s")

# warmup to avoid first-call lag
_ = model.encode(["query: warmup"], normalize_embeddings=True)
print("Warmup ok")


Loaded model in 11.76 s
Warmup ok


In [4]:
# build_meta_with_authors.py (cell)
import json, ast, re, os
from datetime import datetime
from tqdm import tqdm
import pandas as pd

MAX_PAPERS = None  # or an int to truncate during dev

_yr_re = re.compile(r"(19|20)\d{2}")

def best_year_from_obj(obj):
    for key in ("year","published","date","update_date","created"):
        if key in obj and obj[key]:
            s = str(obj[key])
            try:
                y = int(s[:4])
                if 1900 <= y <= datetime.now().year + 1:
                    return y
            except Exception:
                m = _yr_re.search(s)
                if m: return int(m.group(0))
    md = obj.get("metadata") or {}
    pid = obj.get("paper_id") or md.get("id") or obj.get("id") or md.get("arxiv_id") or obj.get("arxiv_id")
    if isinstance(pid, str) and "/" in pid:
        try:
            yy = int(pid.split("/")[1][:2])
            return 2000 + yy if yy < 50 else 1900 + yy
        except Exception:
            pass
    return None

def extract_title_abstract(obj):
    title = None
    md = obj.get("metadata") or {}
    title = md.get("title") or obj.get("title")
    abstract = None
    if isinstance(md.get("abstract"), str):
        abstract = md["abstract"]
    if not abstract and isinstance(obj.get("abstract"), dict):
        abstract = obj["abstract"].get("text")
    if not abstract:
        abstract = obj.get("abstract")
    return title, abstract

def norm_raw_authors(raw):
    if raw is None: return []
    if isinstance(raw, (list, tuple)): return [str(x).strip() for x in raw if str(x).strip()]
    s = str(raw).strip()
    if not s: return []
    if s.startswith("[") and s.endswith("]"):
        try: data = json.loads(s)
        except Exception:
            try: data = ast.literal_eval(s)
            except Exception: data = None
        if isinstance(data, list): return norm_raw_authors(data)
    sep = ";" if ";" in s else ","
    return [t.strip() for t in s.split(sep) if t.strip()]

def authors_from_parsed(ap):
    out=[]
    if isinstance(ap, list):
        for it in ap:
            if isinstance(it, dict):
                nm=(" ".join([it.get("first",""), it.get("last","")])).strip()
            elif isinstance(it, (list,tuple)):
                last=str(it[0]).strip() if len(it)>0 else ""
                first=str(it[1]).strip() if len(it)>1 else ""
                nm=(" ".join([first,last])).strip()
            else:
                nm=str(it).strip()
            if nm: out.append(nm)
    return out

def authors_from_obj(obj):
    md = obj.get("metadata") or {}
    if "authors_parsed" in obj:
        a = authors_from_parsed(obj["authors_parsed"])
        if a: return a
    if "authors_parsed" in md:
        a = authors_from_parsed(md["authors_parsed"])
        if a: return a
    if "authors" in obj:
        a = norm_raw_authors(obj["authors"])
        if a: return a
    if "authors" in md:
        a = norm_raw_authors(md["authors"])
        if a: return a
    return []

def get_pid(obj):
    md = obj.get("metadata") or {}
    return obj.get("paper_id") or md.get("id") or obj.get("id") or md.get("arxiv_id") or obj.get("arxiv_id")

rows = []
for path in jsonl_files:
    with open(path, "r", encoding="utf-8") as f:
        for line in tqdm(f, desc=f"Reading {os.path.basename(path)}"):
            line = line.strip()
            if not line: 
                continue
            try:
                obj = json.loads(line)
            except Exception:
                continue
            pid = get_pid(obj)
            title, abstract = extract_title_abstract(obj)
            if not pid or not title or not abstract:
                continue
            authors = authors_from_obj(obj)
            year = best_year_from_obj(obj)
            rows.append({
                "paper_id": pid,
                "title": title.strip(),
                "abstract": str(abstract).strip(),
                "authors": authors,
                "year": year
            })

df = pd.DataFrame(rows)
df = df.drop_duplicates(subset=["paper_id"]).reset_index(drop=True)
if MAX_PAPERS: df = df.head(MAX_PAPERS)
print(f"Loaded {len(df)} unique papers")
print(df.head(3)[["paper_id","title","authors","year"]])




Reading arXiv_src_0101_001.jsonl: 2319it [00:02, 896.69it/s] 
Reading arXiv_src_0102_001.jsonl: 2192it [00:02, 962.40it/s] 
Reading arXiv_src_0103_001.jsonl: 2441it [00:02, 1001.08it/s]
Reading arXiv_src_0104_001.jsonl: 2310it [00:02, 854.29it/s] 


Loaded 9262 unique papers
           paper_id                                              title  \
0  quant-ph/0101147               Radiation trapping in coherent media   
1  quant-ph/0101145  Mimicking a Kerrlike medium in the dispersive ...   
2  quant-ph/0101144  What is Possible Without Disturbing Partially ...   

                                             authors  year  
0  [A. B. Matsko, I. Novikova, M. O. Scully, G. R...  2001  
1     [A. B. Klimov, L. L. Sanchez-Soto, J. Delgado]  2001  
2                    [Masato Koashi, Nobuyuki Imoto]  2001  


## 1) Load and parse JSONL files

In [5]:
#Skip
import json
from tqdm import tqdm
import pandas as pd

def extract_title_abstract(obj):
    # Title
    title = None
    # Try standard location
    if isinstance(obj.get("metadata"), dict):
        title = obj["metadata"].get("title")
    if not title:
        # fallback: sometimes title may be at top-level (rare)
        title = obj.get("title")
    # Abstract
    abstract = None
    # Try metadata.abstract (string) first
    if isinstance(obj.get("metadata"), dict) and isinstance(obj["metadata"].get("abstract"), str):
        abstract = obj["metadata"]["abstract"]
    # Fallback to top-level 'abstract' object: {"section": "...", "text": "..."}
    if not abstract and isinstance(obj.get("abstract"), dict):
        abstract = obj["abstract"].get("text")
    # final fallback
    if not abstract and isinstance(obj.get("abstract"), str):
        abstract = obj["abstract"]
    return title, abstract

rows = []
for p in jsonl_files:
    p = Path(p)
    if not p.exists():
        print(f"WARNING: missing file -> {p}")
        continue
    with p.open("r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line: 
                continue
            try:
                obj = json.loads(line)
            except Exception as e:
                # sometimes lines can be malformed; skip
                continue
            paper_id = obj.get("paper_id") or obj.get("id")
            title, abstract = extract_title_abstract(obj)
            if not paper_id or not title or not abstract:
                continue
            rows.append({"paper_id": paper_id, "title": title.strip(), "abstract": abstract.strip()})

df = pd.DataFrame(rows).drop_duplicates(subset=["paper_id"]).reset_index(drop=True)
if MAX_PAPERS is not None:
    df = df.head(MAX_PAPERS)

print(f"Loaded {len(df)} unique papers")
df.head(3)


Loaded 9262 unique papers


Unnamed: 0,paper_id,title,abstract
0,quant-ph/0101147,Radiation trapping in coherent media,We show that the effective decay rate of Zeema...
1,quant-ph/0101145,Mimicking a Kerrlike medium in the dispersive ...,We find an effective Hamiltonian describing th...
2,quant-ph/0101144,What is Possible Without Disturbing Partially ...,Consider a situation in which a quantum system...


## 2) Encode with E5 and build FAISS index

In [7]:
#Skip running this block
# Prepare texts for embeddings
texts = [f"passage: {t} — {a}" for t, a in zip(df['title'].tolist(), df['abstract'].tolist())]

# Encode (L2-normalized embeddings give cosine via inner product)
emb = model.encode(texts, batch_size=64, show_progress_bar=True, normalize_embeddings=True).astype("float32")
dim = emb.shape[1]

# Build FAISS index (cosine via inner product on normalized vectors)
index = faiss.IndexFlatIP(dim)
index.add(emb)

# Save index + metadata
faiss.write_index(index, str(out_dir / "index.faiss"))
df.to_parquet(out_dir / "meta.parquet", index=False)

print(f"Saved: {out_dir/'index.faiss'} and {out_dir/'meta.parquet'}")


Batches:   0%|          | 0/145 [00:00<?, ?it/s]

Saved: e5_index_subset_1/index.faiss and e5_index_subset_1/meta.parquet


## 3) Query + post-filter (top-20 → function)

In [5]:

# # Reload index & metadata
index = faiss.read_index(str(out_dir / "index.faiss"))
meta  = pd.read_parquet(out_dir / "meta.parquet")

# ✅ Reuse the already-loaded model from earlier
q_model = model                      # <-- do NOT call SentenceTransformer() again
_ = q_model.encode(["query: warmup"], normalize_embeddings=True)  # quick warmup

def encode_query(q: str):
    return q_model.encode([f"query: {q}"], normalize_embeddings=True).astype("float32")


### 3.a) Quick test

In [6]:
import re
import ast
import json
import numpy as np
import pandas as pd
from datetime import datetime

# -----------------------
# helpers (format/display)
# -----------------------

def _trim(text, max_chars=450):
    if not text:
        return ""
    s = str(text).strip()
    if len(s) <= max_chars:
        return s
    cut = s[:max_chars].rsplit(" ", 1)[0]
    return cut + "…"


def _format_authors(a, k=3):
    if a is None:
        return []
    if isinstance(a, (list, tuple)):
        names = [str(x).strip() for x in a if str(x).strip()]
    else:
        sep = ";" if ";" in str(a) else ","
        names = [t.strip() for t in str(a).split(sep) if t.strip()]
    if not names:
        return []
    return names[:k] + (["et al."] if len(names) > k else [])


_yr_re = re.compile(r"(19|20)\d{2}")

def _best_year(row):
    for key in ("year", "published", "date", "update_date", "created"):
        if key in row and pd.notna(row[key]):
            try:
                y = int(str(row[key])[:4])
                if 1900 <= y <= datetime.now().year + 1:
                    return y
            except Exception:
                pass
            m = _yr_re.search(str(row[key]))
            if m:
                return int(m.group(0))
    pid = row.get("paper_id") or row.get("arxiv_id") or row.get("id")
    if isinstance(pid, str) and "/" in pid:
        try:
            yy = int(pid.split("/")[1][:2])
            return 2000 + yy if yy < 50 else 1900 + yy
        except Exception:
            pass
    return None


def _extract_abstract(abstract_field):
    if abstract_field is None:
        return ""
    if isinstance(abstract_field, dict):
        return str(abstract_field.get("text") or abstract_field.get("abstract") or "")
    return str(abstract_field)

# -----------------------
# stopwords + token utils
# -----------------------

DEFAULT_STOPWORDS = {
    "a","an","and","the","of","to","in","on","for","with","by","as","at","or","but","if","than","then",
    "from","into","over","under","between","within","without","about","via","per","through","across",
    "is","are","was","were","be","been","being","have","has","had","do","does","did","can","could",
    "may","might","will","would","shall","should","must","not","no","nor","also","both","either","neither",
    "this","that","these","those","it","its","their","our","your","his","her","them","they","we","you","i",
    "such","thus","there","here","where","when","which","who","whom","whose","what","why","how",
    "using","use","used","based","approach","approaches","method","methods","result","results","show",
    "shows","shown","paper","study","work","new"
}


def minmax_norm(x):
    x = np.asarray(x, dtype=np.float32).reshape(-1)
    if x.size == 0:
        return x
    lo, hi = float(np.min(x)), float(np.max(x))
    if not np.isfinite(lo) or not np.isfinite(hi) or (hi - lo) < 1e-12:
        return np.zeros_like(x, dtype=np.float32)
    return (x - lo) / (hi - lo)


_token_re = re.compile(r"\b\w+\b", re.UNICODE)

def tokenize(text):
    if text is None:
        return []
    return [t.lower() for t in _token_re.findall(str(text))]


def content_terms(tokens, stopwords, min_len=3):
    return [t for t in tokens if not t.isdigit() and len(t) >= min_len and t not in stopwords]

# -----------------------
# robust author extraction
# -----------------------

def _authors_from_row(row):

    def _norm_raw_authors(raw):
        if raw is None:
            return []
        if isinstance(raw, (list, tuple)):
            return [str(x).strip() for x in raw if str(x).strip()]
        s = str(raw).strip()
        if not s:
            return []
        if s.startswith("[") and s.endswith("]"):
            try:
                data = json.loads(s)
            except Exception:
                try:
                    data = ast.literal_eval(s)
                except Exception:
                    data = None
            if isinstance(data, list):
                return _norm_raw_authors(data)
        sep = ";" if ";" in s else ","
        return [t.strip() for t in s.split(sep) if t.strip()]

    def _norm_authors_parsed(ap):
        out = []
        if isinstance(ap, (list, tuple)):
            for item in ap:
                if isinstance(item, (list, tuple)):
                    last = str(item[0]).strip() if len(item) > 0 else ""
                    first = str(item[1]).strip() if len(item) > 1 else ""
                    name = " ".join([first, last]).strip()
                    if name:
                        out.append(name)
                elif isinstance(item, dict):
                    first = str(item.get("first", "")).strip()
                    last = str(item.get("last", "")).strip()
                    name = " ".join([first, last]).strip()
                    if name:
                        out.append(name)
                else:
                    s = str(item).strip()
                    if s:
                        out.append(s)
        elif isinstance(ap, str) and ap.strip():
            try:
                data = json.loads(ap)
            except Exception:
                try:
                    data = ast.literal_eval(ap)
                except Exception:
                    data = None
            if isinstance(data, list):
                return _norm_authors_parsed(data)
        return out

    # direct columns
    if "authors" in row and pd.notna(row["authors"]):
        names = _norm_raw_authors(row["authors"])
        if names:
            return names

    if "authors_parsed" in row and pd.notna(row["authors_parsed"]):
        names = _norm_authors_parsed(row["authors_parsed"])
        if names:
            return names

    # nested metadata dict (JSONL)
    if "metadata" in row and isinstance(row["metadata"], dict):
        if "authors" in row["metadata"]:
            names = _norm_raw_authors(row["metadata"]["authors"])
            if names:
                return names
        if "authors_parsed" in row["metadata"]:
            names = _norm_authors_parsed(row["metadata"]["authors_parsed"])
            if names:
                return names

    return []

# -----------------------
# meta building from JSONL
# -----------------------

def _rows_from_jsonl(path):
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            obj = json.loads(line)
            md = obj.get("metadata", {}) or {}
            authors_field = md.get("authors") or obj.get("authors")
            # could be list, string, or None — leave as is for _authors_from_row to normalize
            yield {
                "paper_id": obj.get("paper_id") or md.get("id"),
                "title": md.get("title", "") or obj.get("title", ""),
                "abstract": obj.get("abstract", md.get("abstract", "")),
                "authors": authors_field,   # <- flat authors column
                "metadata": md,             # <- keep raw metadata for fallback
            }


def build_meta_from_jsonl(paths):
    """
    paths: str or List[str] of jsonl files
    Returns a pandas.DataFrame with columns:
    paper_id | title | abstract | metadata
    """
    if isinstance(paths, str):
        paths = [paths]
    all_rows = []
    for p in paths:
        for row in _rows_from_jsonl(p):
            all_rows.append(row)
    df = pd.DataFrame(all_rows)

    # Basic cleanup / de-dup
    # Prefer keeping the first occurrence per paper_id, then title
    if "paper_id" in df.columns:
        df = df.drop_duplicates(subset=["paper_id"], keep="first")
    if "title" in df.columns:
        df = df.drop_duplicates(subset=["title"], keep="first")
    df = df.reset_index(drop=True)
    return df

# -----------------------
# main search function
# -----------------------

# COSINE-ONLY retrieval (overwrites any previous definition)
def search_post_filter(
    query,
    desired_function,          # kept only for provenance; NOT used for ranking
    topN=20,
    topK_return=10,
    normalize_scores=False,
    stopwords=None,
    min_term_len=3,
    abstract_chars=450,
    authors_shown=3
):
    stopwords = DEFAULT_STOPWORDS if stopwords is None else set(stopwords)

    # 1) semantic retrieve
    qv = encode_query(query)
    scores, idxs = index.search(qv, topN)
    scores, idxs = scores[0].astype(np.float32), idxs[0]

    # 2) order purely by cosine
    order = np.argsort(-scores)
    display_scores = minmax_norm(scores) if normalize_scores else scores

    # 3) lexical explainers
    q_terms_all  = tokenize(query)
    q_terms_used = content_terms(q_terms_all, stopwords, min_len=min_term_len)

    out = []
    for rank_pos in range(min(topK_return, len(order))):
        r = order[rank_pos]
        row = meta.iloc[idxs[r]]

        title = row.get("title", "")
        abstract_txt = _extract_abstract(row.get("abstract", ""))
        paper_id = row.get("paper_id") or row.get("arxiv_id") or row.get("id")

        title_tokens = content_terms(tokenize(title), stopwords, min_len=min_term_len)
        abs_tokens   = content_terms(tokenize(abstract_txt), stopwords, min_len=min_term_len)

        title_matches = sorted(set(q_terms_used) & set(title_tokens))
        abs_matches   = sorted(set(q_terms_used) & set(abs_tokens))

        authors_list = _authors_from_row(row)
        authors_fmt  = _format_authors(authors_list, k=authors_shown)

        out.append({
            "score": float(display_scores[r]),
            "cosine": float(scores[r]),
            "title": title,
            "abstract": _trim(abstract_txt, abstract_chars),
            "arxiv_id": paper_id,
            "year": _best_year(row),
            "authors": authors_fmt,
            "title_matches": title_matches,
            "abs_matches": abs_matches,
            "query_terms_used": q_terms_used,
            "function_requested": str(desired_function) if desired_function is not None else "",
        })
    return out


# -----------------------
# convenience: results -> DataFrame
# -----------------------

def to_df(res):
    cols = ["cosine","p_function","title","year","arxiv_id","authors","title_matches","abs_matches"]
    return pd.DataFrame([{k: r.get(k) for k in cols} for r in res])


#meta = build_meta_from_jsonl(jsonl_files)
#print(len(meta), "papers loaded")
#print(meta.head(5)[["paper_id","title"]])


In [7]:
import os, json
import pandas as pd
from datetime import datetime

def _safe_join(x):
    if x is None: return ""
    if isinstance(x, list): return ", ".join(str(t) for t in x)
    return str(x)

def display_results_table(results):
    df = pd.DataFrame([
        {
            "Title": r.get("title", ""),
            "Year": r.get("year", ""),
            "Authors": _safe_join(r.get("authors", [])),
            "Abstract": r.get("abstract", ""),
            "Title Matches": _safe_join(r.get("title_matches", [])),
            "Abstract Matches": _safe_join(r.get("abs_matches", [])),
            "Score": r.get("cosine_norm", r.get("cosine", None)),
        } for r in (results or [])
    ])
    return df


In [8]:
def process_current_classification(
    classified_path="classified_outputs.jsonl",
    out_dir="outputs",
    topN=50,
    topK_return=10,
    normalize_scores=True,
    debug=False
):
    import os, json, pandas as pd
    from datetime import datetime

    def read_last_jsonl(path: str):
        """Return the last non-empty JSON object from a .jsonl file."""
        with open(path, "r", encoding="utf-8") as f:
            lines = [ln.strip() for ln in f if ln.strip()]
        if not lines:
            raise ValueError(f"No lines found in {path}")
        try:
            return json.loads(lines[-1])
        except json.JSONDecodeError as e:
            raise ValueError(f"Last line is not valid JSON: {e}\nLine: {lines[-1][:200]}...")

    os.makedirs(out_dir, exist_ok=True)
    out_jsonl = os.path.join(out_dir, "topk_candidates.jsonl")
    out_csv   = os.path.join(out_dir, "topk_candidates.csv")

    # --- Read the most recent block (last line), not the first line ---
    obj = read_last_jsonl(classified_path)

    query = obj.get("query", "")
    sent_list = obj.get("sentence_classification", []) or []
    if debug:
        print(f"[debug] loaded {len(sent_list)} sentences from latest block")

    rows = []
    for si, s in enumerate(sent_list):
        text = (s.get("sentence") or "").strip()
        needs_cit = bool(s.get("needs_citation", True))
        just = s.get("justification", "")

        # --- Correctly read plural key and normalize ---
        func_list = s.get("citation_functions")
        if func_list is None:
            # graceful fallback if upstream ever writes singular by mistake
            func_list = s.get("citation_function")
        if isinstance(func_list, str):
            func_list = [func_list]
        if not isinstance(func_list, list):
            func_list = []

        func_first = (func_list[0] if func_list else "").strip().lower()
        desired_function = func_first or "background"  # fallback so retrieval still works

        retrieved_at = datetime.utcnow().isoformat()

        try:
            results = search_post_filter(
                query=text if text else query,     # prefer sentence; fallback to full query
                desired_function=desired_function, # single normalized label
                topN=topN,
                topK_return=topK_return,
                normalize_scores=normalize_scores
            ) or []
        except Exception as e:
            if debug: print(f"[si{si}] retrieval error: {e}")
            results = []

        if results:
            for rank, r in enumerate(results):
                rows.append({
                    "sentence_idx": si,
                    "sentence_uid": f"blk1|si{si}",
                    "sentence_text": text,
                    "citation_function": desired_function,   # <- always a string
                    "citation_functions": func_list,         # keep original list for traceability
                    "needs_citation": needs_cit,
                    "justification": just,
                    "rank": rank,
                    "paper_id": r.get("arxiv_id") or r.get("paper_id") or "",
                    "title": r.get("title", ""),
                    "year": r.get("year", ""),
                    "authors": ", ".join(r.get("authors", [])) if isinstance(r.get("authors"), list)
                               else (r.get("authors") or ""),
                    "abstract": r.get("abstract", ""),
                    "normalized_score": r.get("cosine_norm", r.get("cosine")),
                    "title_matches": ", ".join(r.get("title_matches", [])) if isinstance(r.get("title_matches"), list)
                                     else (r.get("title_matches") or ""),
                    "abs_matches": ", ".join(r.get("abs_matches", [])) if isinstance(r.get("abs_matches"), list)
                                   else (r.get("abs_matches") or ""),
                    "retrieval_error": None,
                    "retrieved_at": retrieved_at
                })
        else:
            rows.append({
                "sentence_idx": si,
                "sentence_uid": f"blk1|si{si}",
                "sentence_text": text,
                "citation_function": desired_function,
                "citation_functions": func_list,
                "needs_citation": needs_cit,
                "justification": just,
                "rank": None,
                "paper_id": "",
                "title": "",
                "year": "",
                "authors": "",
                "abstract": "",
                "normalized_score": None,
                "title_matches": "",
                "abs_matches": "",
                "retrieval_error": None,
                "retrieved_at": retrieved_at
            })

    # Overwrite outputs each run
    with open(out_jsonl, "w", encoding="utf-8") as f:
        for row in rows:
            f.write(json.dumps(row, ensure_ascii=False) + "\n")

    df = pd.DataFrame(rows)
    df.to_csv(out_csv, index=False)

    if debug:
        print(f"Sentences: {len(sent_list)} | rows written: {len(rows)}")
        print(f"- {out_jsonl}\n- {out_csv}")

    return df


In [23]:
import pandas as pd
df = pd.read_csv("outputs/topk_candidates.csv")
print(df[["sentence_idx","citation_function"]].drop_duplicates().sort_values(["sentence_idx"]))


    sentence_idx citation_function
0              0        motivation
10             1        futurework


In [9]:
# Skip this block
# # force authors into strings
meta["authors"] = meta["authors"].apply(lambda x: x if isinstance(x, str) or x is None else str(x))

# Example: one clause from a generated answer
clause = "Compared to simple quench models, back reactions temper and delay amplification."
func   = "COMPARE"

results = search_post_filter(
    query=clause,
    desired_function=func,
    topN=50,
    topK_return=10
)

# Convert to DataFrame for inspection
df_results = pd.DataFrame(results)
print(df_results[["cosine","function_requested","title","year","arxiv_id","authors","title_matches","abs_matches"]])


     cosine function_requested  \
0  0.828061            COMPARE   
1  0.827704            COMPARE   
2  0.826666            COMPARE   
3  0.826383            COMPARE   
4  0.825162            COMPARE   
5  0.825125            COMPARE   
6  0.824200            COMPARE   
7  0.823486            COMPARE   
8  0.822597            COMPARE   
9  0.822466            COMPARE   

                                               title  year          arxiv_id  \
0      Microscopic Reaction Dynamics at SPS and RHIC  2001   nucl-th/0104040   
1  Jet Quenching and the p-bar >= pi- Anomaly at ...  2001   nucl-th/0104066   
2      New results on the temporal structure of GRBs  2001  astro-ph/0103011   
3  On the mean field treatment of attractive inte...  2001  cond-mat/0104317   
4  Differential Cross Sections Measurement for th...  2001   nucl-ex/0101001   
5  A contiuum model for low temperature relaxatio...  2001  cond-mat/0104235   
6  A semi-analytical approach to non-linear shock...  2001  astro


## Next steps
- Replace `FunctionClassifierStub` with your **trained function classifier**.
- If you need **BM25** fusion for acronyms/symbols, we can add a BM25 index and blend scores.
- For full-text retrieval, split papers into **paragraphs/sentences** and index those instead of abstracts.
- If your set grows beyond ~1–2M passages, consider FAISS **IVF-PQ** or **HNSW** for faster search.


In [10]:
# 1) Run llm_test.py to regenerate the file (it overwrites classified_outputs.jsonl)

# 2) Process the single block
df_all = process_current_classification(
    classified_path="/data/horse/ws/anpa439f-Function_Retrieval_Citation/Research_Project/classification_data/classified_outputs.jsonl",
    out_dir="outputs",
    topN=50,
    topK_return=10,
    normalize_scores=True,
    debug=True
)

# 3) Inspect safely
if df_all.empty:
    print("No rows saved — check your classification or retrieval.")
else:
    # Show the first sentence's candidates
    s0 = df_all["sentence_idx"].min()
    display(df_all[(df_all["sentence_idx"]==s0) & (df_all["rank"].notna())].sort_values("rank").head(10))


[debug] loaded 2 sentences from latest block
Sentences: 2 | rows written: 20
- outputs/topk_candidates.jsonl
- outputs/topk_candidates.csv


Unnamed: 0,sentence_idx,sentence_uid,sentence_text,citation_function,citation_functions,needs_citation,justification,rank,paper_id,title,year,authors,abstract,normalized_score,title_matches,abs_matches,retrieval_error,retrieved_at
0,0,blk1|si0,The current work uses a range of methods to de...,uses,[Uses],True,This sentence was generated to fulfill the 'Us...,0,math/0104121,Eigenvalue estimates of the Dirac operator dep...,2001,Thomas FriedrichKlaus-Dieter Kirchberg,We prove a new lower bound for the first eigen...,0.861814,"dirac, estimates","bounds, curvature, dirac, manifolds, riemannia...",,2025-09-06T18:42:14.038301
1,0,blk1|si0,The current work uses a range of methods to de...,uses,[Uses],True,This sentence was generated to fulfill the 'Us...,1,math/0103095,On eigenvalue estimates for the Dirac operator,2001,N. GinouxB. Morel,We give lower bounds for the eigenvalues of th...,0.855258,"dirac, estimates","bounds, curvature, dirac, operators",,2025-09-06T18:42:14.038301
2,0,blk1|si0,The current work uses a range of methods to de...,uses,[Uses],True,This sentence was generated to fulfill the 'Us...,2,math/0101061,Spectral estimates on 2-tori,2001,Bernd Ammann,We prove upper and lower bounds for the eigenv...,0.851397,estimates,"bounds, dirac, riemannian, spin, uses",,2025-09-06T18:42:14.038301
3,0,blk1|si0,The current work uses a range of methods to de...,uses,[Uses],True,This sentence was generated to fulfill the 'Us...,3,math/0101111,"Eigenvalue estimates for the Dirac-Schr\""oding...",2001,Bertrand Morel,We give new estimates for the eigenvalues of t...,0.843709,"dirac, estimates, operators","curvature, dirac, estimates, inequalities",,2025-09-06T18:42:14.038301
4,0,blk1|si0,The current work uses a range of methods to de...,uses,[Uses],True,This sentence was generated to fulfill the 'Us...,4,math/0102135,New bounds on Kakeya problems,2001,Nets KatzTerence Tao,We establish new estimates on the Minkowski an...,0.830213,bounds,"bounds, establish, estimates",,2025-09-06T18:42:14.038301
5,0,blk1|si0,The current work uses a range of methods to de...,uses,[Uses],True,This sentence was generated to fulfill the 'Us...,5,quant-ph/0103089,Geometrisation of electromagnetic field and to...,2001,O. A. Olkhov,A new approach is proposed for an electromagne...,0.825784,,"curvature, dirac",,2025-09-06T18:42:14.038301
6,0,blk1|si0,The current work uses a range of methods to de...,uses,[Uses],True,This sentence was generated to fulfill the 'Us...,6,math/0101084,Curvature Estimates in Asymptotically Flat Man...,2001,Felix FinsterInes Kath,We consider an asymptotically flat Riemannian ...,0.824458,"curvature, estimates, manifolds","bounds, curvature, riemannian, spin",,2025-09-06T18:42:14.038301
7,0,blk1|si0,The current work uses a range of methods to de...,uses,[Uses],True,This sentence was generated to fulfill the 'Us...,7,hep-th/0103206,Dirac Operator on the Quantum Sphere,2001,A. PinzulA. Stern,We construct a Dirac operator on the quantum s...,0.82273,dirac,dirac,,2025-09-06T18:42:14.038301
8,0,blk1|si0,The current work uses a range of methods to de...,uses,[Uses],True,This sentence was generated to fulfill the 'Us...,8,cond-mat/0101228,Bounding and approximating parabolas for the s...,2001,H. -J. SchmidtJ. SchnackMarshall Luban,We prove that for a wide class of quantum spin...,0.822242,spin,"bounds, spin",,2025-09-06T18:42:14.038301
9,0,blk1|si0,The current work uses a range of methods to de...,uses,[Uses],True,This sentence was generated to fulfill the 'Us...,9,math/0103058,A lower bound in an approximation problem invo...,2001,Jean-Francois Burnol,We slightly improve the lower bound of Baez-Du...,0.821792,,,,2025-09-06T18:42:14.038301


In [37]:
prompt = build_llm_postfilter_prompt(
    clause=clause,
    func=func,
    results=results,   # or candidates[:10] if you rerank
    top_k=10
)

print("=== Prompt sent to LLM ===")
print(prompt[:3000])   # print first ~3000 chars so it's not overwhelming
print("\n=== End prompt ===")


=== Prompt sent to LLM ===
You are a scientific evaluator. Given one clause and N candidate abstracts, decide which abstract (if any) best supports the clause for the specified citation function.
Allowed functions: BACKGROUND, USE, COMPARE, EXTENDS, CONTINUATION, FUTUREWORK.
Function to evaluate: COMPARE
- BACKGROUND: definition/scene-setting.
- USE: concrete application/implementation/where it's used.
- COMPARE: explicit comparison, pros/cons, versus/baseline (look for phrases like 'compared to', 'versus', 'in contrast', 'baseline').
- EXTENDS: builds on prior work, generalizes.
- CONTINUATION: follow-up/replication/continued line.
- FUTUREWORK: explicit future work/limitations/next steps.
Strict accept rule:
  • Only select a candidate if it (a) contains at least 2 topical overlap terms with the clause, AND (b) explicitly expresses the requested function.
  • If no candidate satisfies both, return best_id=null and best_idx=null.
Return STRICT JSON only (no prose) with keys:
{"best_id