In [1]:
import networkx as nx
from graph_generator.graphparsers import RelationshipGraphParser
from linearization_utils import *
from retrieval_utils import similarity_search_graph_docs

from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.schema import Document

from typing import List, Dict, Optional, Tuple
import time

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

import time
import pandas as pd
import os

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
CONFIG = {
    # === Embedding & VectorStore ===
    "embedding_model": "sentence-transformers/all-MiniLM-L6-v2",  # Embedding model for documents/questions
    "faiss_search_k": 3,  # Number of nearest neighbors to retrieve from FAISS

    # === LLM (text generation) ===
    "llm_model_id": "microsoft/Phi-4-mini-reasoning",  # HuggingFace model ID
    "device_map": "auto",  # Device placement: "cuda", "mps", "cpu", or "auto"
    "dtype_policy": "auto",  # Precision: "auto", "bf16", "fp16", or "fp32"
    "max_new_tokens": 256,  # Maximum tokens generated per response
    "do_sample": True,  # Whether to use sampling (True) or greedy decoding (False)
    "temperature": 0.4,  # Randomness control for sampling; lower = more deterministic
    "top_p": 1.0,  # Nucleus sampling threshold; 1.0 = no restriction
    "return_full_text": False,  # Return full text (input+output) if True, only output if False
    "seed": None,  # Random seed for reproducibility; set to int or None

    # === Prompt / Answer ===
    "answer_mode": "YES_NO",  # Answer format mode, e.g., YES/NO
    "answer_uppercase": True,  # If True → "YES"/"NO", else "yes"/"no"

    # === Prompt construction ===
    "include_retrieved_context": True,  # Include retrieved Q&A in prompt
    "include_current_triples": True,  # Include graph triples in prompt
}

try:
    from transformers import set_seed  # Utility for reproducibility
except Exception:
    set_seed = None


## RAG workflow

In [3]:
def _select_dtype() -> torch.dtype:
    """Choose dtype based on CONFIG['dtype_policy'] and hardware."""
    policy = CONFIG.get("dtype_policy", "auto")
    if policy == "bf16":
        return torch.bfloat16
    if policy == "fp16":
        return torch.float16
    if policy == "fp32":
        return torch.float32

    # auto mode
    if torch.cuda.is_available():
        return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    # MPS backend works more reliably with fp32
    if torch.backends.mps.is_available():
        return torch.float32
    return torch.float32

def _yn(text_yes="YES", text_no="NO"):
    return (text_yes, text_no) if CONFIG.get("answer_uppercase", True) else (text_yes.lower(), text_no.lower())

# =========================
# Embeddings / Vectorstore
# =========================
emb = HuggingFaceEmbeddings(model_name=CONFIG["embedding_model"])  # Local embedding model (MiniLM-L6-v2, 384 dim)

def build_faiss_index(docs: List[Document]) -> FAISS:
    return FAISS.from_documents(docs, emb)

# =========================
# LLM Loader
# =========================
def load_llm_pipeline(
    model_id: Optional[str] = None,       # HuggingFace model id
    device_map: Optional[str] = None,     # Device placement
    dtype: Optional[torch.dtype] = None,  # Torch dtype
    max_new_tokens: Optional[int] = None, # Max tokens per generation
    temperature: Optional[float] = None,  # Sampling temperature
    top_p: Optional[float] = None,        # Nucleus sampling threshold
    do_sample: Optional[bool] = None,     # Sampling vs greedy
    return_full_text: Optional[bool] = None,  # Return input+output if True
):
    """
    Return a text-generation pipeline for QA generation.
    All defaults pull from CONFIG; any arg here will override CONFIG.
    """
    model_id = model_id or CONFIG["llm_model_id"]
    device_map = device_map or CONFIG["device_map"]
    dtype = dtype or _select_dtype()
    max_new_tokens = max_new_tokens or CONFIG["max_new_tokens"]
    temperature = CONFIG["temperature"] if temperature is None else temperature
    top_p = CONFIG["top_p"] if top_p is None else top_p
    do_sample = CONFIG["do_sample"] if do_sample is None else do_sample
    return_full_text = CONFIG["return_full_text"] if return_full_text is None else return_full_text

    if set_seed and isinstance(CONFIG.get("seed"), int):
        set_seed(CONFIG["seed"])

    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=dtype,
        device_map=device_map,
        trust_remote_code=True,
    )

    gen_pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        device_map=device_map,
        torch_dtype=dtype,
        return_full_text=return_full_text,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        temperature=temperature,
        top_p=top_p,
    )
    return gen_pipe, tokenizer

# =========================
# Question → Graph (generic)
# =========================
def parse_question_to_graph_generic(parser, question: str) -> Tuple[nx.Graph, List[Dict]]:
    """
    Compatible with RelationshipGraphParser.question_to_graph
    and CausalQuestionGraphParser.question_to_causal_graph
    """
    if hasattr(parser, "question_to_graph"):
        return parser.question_to_graph(question)
    elif hasattr(parser, "question_to_causal_graph"):
        return parser.question_to_causal_graph(question)
    else:
        raise AttributeError("Parser must provide question_to_graph or question_to_causal_graph")

# =========================
# Prompt Builder
# =========================
def make_graph_qa_prompt(
    question: str,
    G: nx.Graph,
    relations: Optional[List[Dict]] = None,
    retrieved_docs = None
) -> str:
    # 1) retrieved context (if any)
    sections = []
    if retrieved_docs and CONFIG.get("include_retrieved_context", True):
        doc0, score0 = retrieved_docs[0]
        related_triples = doc0.page_content.strip()
        related_answer  = doc0.metadata.get("llm_answer", "")
        sections.append(
            "<<<RETRIEVED_CONTEXT_START>>>\n"
            "The system searched for a related question in the database. Below are related question's graph triples and its prior answer as reference. " \
            "You don't have to follow it completely, just use it as a reference.\n"
            f"[RELATED QUESTION'S GRAPH TRIPLES]:\n{related_triples}\n"
            f"[RELATED QUESTION'S ANSWER]: {related_answer}\n"
            "<<<RETRIEVED_CONTEXT_END>>>"
        )

    # 2) current question + triples (optional)
    triples_text = ""
    if relations and CONFIG.get("include_current_triples", True):
        text = build_relationship_text(question, G, relations) 
        triples_text += text
    
    sections.append(
        "[GRAPH FORMAT DESCRIPTION]:\n"
        "The graph triples are encoded as a JSON object with three fields:\n"
        "- \"entity_dict\": A list of entity names, where the index is the entity ID.\n"
        "- \"relation_dict\": A list of relation types, where the index is the relation ID.\n"
        "- \"edges\": A list of triples [head_id, relation_id, tail_id], meaning:\n"
        "  entity_dict[head_id] -- relation_dict[relation_id] --> entity_dict[tail_id].\n\n"
        "Example:\n"
        "{\n"
        "  \"entity_dict\": [\"cat\", \"chases\", \"mouse\"],\n"
        "  \"relation_dict\": [\"subj\", \"obj\"],\n"
        "  \"edges\": [[0,0,1], [1,1,2]]\n"
        "}\n"
        "represents:\n"
        "- cat --subj--> chases\n"
        "- chases --obj--> mouse\n"
    )    
    
    q_block = f"[CURRENT QUESTION]: {question}"
    if triples_text.strip():
        q_block += f"\n[CURRENT QUESTION'S GRAPH TRIPLES]:\n{triples_text}"
    sections.append(q_block)

    # 3) task instructions (placed at the end)
    yes, no = _yn("YES", "NO")
    rules = (
        "[TASK]: You are a precise QA assistant for binary (yes/no) questions.\n"
        f"- Output ONLY one token: {yes} or {no}.\n"
        "- Do NOT copy or summarize any context.\n"
        "- Do NOT show reasoning, steps, or extra words.\n"
        f"[ANSWER]: "
    )
    sections.append(rules)

    # Final prompt
    prompt = "\n\n".join(sections)
    return prompt

# =========================
# LLM Answerer
# =========================
def answer_with_llm(
    question: str,
    gen_pipe,
    parser,
    faiss_db = None,
    prompt = None
) -> str:
    retrieved_docs = None
    if faiss_db:
        k = CONFIG.get("faiss_search_k", 3)  # Number of docs to retrieve
        _, hits = similarity_search_graph_docs(question, parser, faiss_db, k=k)
        retrieved_docs = hits
        
    if prompt == None:
        G, rels = parse_question_to_graph_generic(parser, question)
        prompt = make_graph_qa_prompt(question, G, rels, retrieved_docs)

    out = gen_pipe(prompt)
    text = out[0]["generated_text"]

    # If return_full_text=False → only new content; else trim prefix
    if CONFIG.get("return_full_text", True):
        answer = text[len(prompt):].strip()
    else:
        answer = text.strip()

    # Normalize YES/NO case
    if CONFIG.get("answer_mode", "YES_NO"):
        yes, no = _yn("YES", "NO")
        a = answer.strip().lower()
        if "yes" in a and "no" not in a:
            answer = yes
            print(answer)
            return answer
        elif "no" in a and "yes" not in a:
            answer = no
            print(answer)
            return answer
        else:
            answer = answer_with_llm(question, gen_pipe, parser, faiss_db, prompt)
    
    

# =========================
# Build Docs with LLM Answer
# =========================
def build_docs_with_answer(
    questions: List[str],
    parser,
    gen_pipe,
    *,
    add_prompt_snapshot: bool = False,
    faiss_db = None
) -> List[Document]:
    docs: List[Document] = []
    for qid, q in enumerate(questions, start=1):
        G, rels = parse_question_to_graph_generic(parser, q)
        text = build_relationship_text(q, G, rels)  # Output [QUESTION][GRAPH][TRIPLES]

        # Get LLM answer
        answer = answer_with_llm(q, gen_pipe, parser, faiss_db)

        metadata = {
            "graph_id": f"Q{qid}",
            "question": q,
            "num_nodes": G.number_of_nodes(),
            "num_edges": G.number_of_edges(),
            "llm_model": CONFIG["llm_model_id"],
            "llm_answer": answer,
            "created_at": int(time.time()),
        }
        if add_prompt_snapshot:
            metadata["prompt_snapshot"] = make_graph_qa_prompt(q, G, rels)

        docs.append(Document(page_content=text, metadata=metadata))
    return docs


def build_faiss_index(docs: List[Document]) -> FAISS:
    vectordb = FAISS.from_documents(docs, emb)
    return vectordb


  emb = HuggingFaceEmbeddings(model_name=CONFIG["embedding_model"])  # Local embedding model (MiniLM-L6-v2, 384 dim)


## Text RAG

In [4]:
questions = {
       # --- 原始 YES 类 ---
    "Is the Earth, which orbits the Sun along with seven other planets in the solar system, generally considered to be round in shape despite its slight equatorial bulge?",
    "Given that Earth completes a rotation approximately every 24 hours, does this rotation cause the Sun to appear to rise in the east and set in the west from the perspective of an observer on the surface?",
    "Considering that Paris is the administrative and cultural center of France, is it also the official capital city of the country according to its constitution?",
    "Based on basic human biology, which requires oxygen for cellular respiration and energy production, do humans need oxygen to survive under normal conditions?",
    "Since the Moon is gravitationally bound to Earth and completes an orbit approximately every 27 days, is it classified as Earth's only natural satellite?", 

    "Although the Sahara Desert is one of the largest deserts in the world, is it located in the South American continent instead of Africa?",
    "Even though the Amazon River is among the longest rivers globally, is it actually longer than the Nile River when measured by official geographical surveys?",
    "Despite Tokyo being the largest city in Japan and a major global hub, is it the capital city of South Korea?",
    "Considering the natural habitats of penguins, which are mainly located in the Southern Hemisphere, do penguins naturally live in the Arctic region alongside polar bears?",
    "Although gold is a dense and valuable metal, is its density greater than that of lead, making it heavier per cubic centimeter?",
}

In [5]:
GOLD_LABELS = {
    # --- 原始 YES 类 ---
    "Is the Earth, which orbits the Sun along with seven other planets in the solar system, generally considered to be round in shape despite its slight equatorial bulge?": "YES",
    "Given that Earth completes a rotation approximately every 24 hours, does this rotation cause the Sun to appear to rise in the east and set in the west from the perspective of an observer on the surface?": "YES",
    "Considering that Paris is the administrative and cultural center of France, is it also the official capital city of the country according to its constitution?": "YES",
    "Based on basic human biology, which requires oxygen for cellular respiration and energy production, do humans need oxygen to survive under normal conditions?": "YES",
    "Since the Moon is gravitationally bound to Earth and completes an orbit approximately every 27 days, is it classified as Earth's only natural satellite?": "YES",

    # --- 原始 NO 类 ---
    "Although the Sahara Desert is one of the largest deserts in the world, is it located in the South American continent instead of Africa?": "NO",
    "Even though the Amazon River is among the longest rivers globally, is it actually longer than the Nile River when measured by official geographical surveys?": "NO",
    "Despite Tokyo being the largest city in Japan and a major global hub, is it the capital city of South Korea?": "NO",
    "Considering the natural habitats of penguins, which are mainly located in the Southern Hemisphere, do penguins naturally live in the Arctic region alongside polar bears?": "NO",
    "Although gold is a dense and valuable metal, is its density greater than that of lead, making it heavier per cubic centimeter?": "NO",

    # --- Follow-up Questions (YES/NO balanced) ---
    # Earth shape
    "Despite the Earth's slightly flattened poles, is its shape closer to a sphere than to a flat surface?": "YES",
    "Is the Earth perfectly flat with no curvature anywhere on its surface?": "NO",

    # Sun rotation
    "Given the Earth's rotation, is the apparent motion of the Sun consistent with the Sun rising in the east?": "YES",
    "If the Earth did not rotate on its axis, would the Sun still rise and set in the same pattern as it does now?": "NO",

    # Paris as capital
    "Considering France's administrative structure, is Paris recognized as the political and economic capital of the nation?": "YES",
    "Is Berlin, rather than Paris, designated as the official capital of France in any historical or legal record?": "NO",

    # Oxygen necessity
    "Since oxygen is vital for human life, is it correct to say that humans cannot survive without breathing air containing oxygen?": "YES",
    "Can humans live indefinitely without any access to oxygen in their environment?": "NO",

    # Moon as satellite
    "Is the Moon the only large natural body that consistently orbits Earth in the solar system?": "YES",
    "Do humans have multiple moons orbiting the Earth, similar to Jupiter or Saturn?": "NO",

    # Sahara Desert
    "Is the Sahara Desert geographically located across multiple countries in northern Africa?": "YES",
    "Is the Sahara Desert primarily located in the continent of South America?": "NO",

    # Amazon River
    "Does the Nile River surpass the Amazon River in length when measured by the most widely accepted geographical data?": "YES",
    "Is the Amazon River considered to originate in Europe according to global mapping authorities?": "NO",

    # Tokyo
    "Is Tokyo the capital city of Japan and a major economic center in Asia?": "YES",
    "Is Tokyo officially listed as the capital city of South Korea in government documents?": "NO",

    # Penguins
    "Do penguins naturally inhabit regions in the Southern Hemisphere, particularly Antarctica?": "YES",
    "Do penguins live alongside polar bears in the Arctic region as part of their natural habitat?": "NO",

    # Gold vs Lead
    "Is gold denser than most metals but still slightly less dense than lead?": "NO",
    "Is gold classified as a metal due to its physical and chemical properties?": "YES"
}


In [6]:
# === Text RAG: 仅用问题文本入库 ===
from langchain.schema import Document

def build_text_docs_with_answer(
    questions: List[str],
    gen_pipe,
    *,
    add_prompt_snapshot: bool = False,
    text_db: Optional[FAISS] = None
) -> List[Document]:
    """
    生成仅文本 RAG 的文档，并使 metadata 字段与图 RAG 对齐：
    - graph_id / question / num_nodes / num_edges / llm_model / llm_answer / created_at / prompt_snapshot(可选)
    - 其中 num_nodes/num_edges 统一置 0，保持同名键方便评测与对比
    """
    docs: List[Document] = []
    for qid, q in enumerate(questions, start=1):
        # 文本版页面内容：只存问题文本
        page_content = f"{q}"

        # 生成 LLM 答案（文本 RAG 检索）
        answer = answer_with_llm_text(q, gen_pipe, text_db=text_db)

        # metadata 字段与 Graph RAG 对齐
        metadata = {
            "graph_id": f"Q{qid}",
            "question": q,
            "num_nodes": 0,                    # 对齐字段
            "num_edges": 0,                    # 对齐字段
            "llm_model": CONFIG["llm_model_id"],
            "llm_answer": answer,
            "created_at": int(time.time()),
        }
        if add_prompt_snapshot:
            # 为了对齐，也提供 prompt_snapshot；注意这里是“文本 RAG”的 prompt
            # 为了避免再次触发生成，这里重建与上面一致的 prompt 片段
            prompt_snapshot = make_text_qa_prompt(q, None if not text_db else similarity_search_text_docs(q, text_db, k=CONFIG.get("faiss_search_k",3))[1])
            metadata["prompt_snapshot"] = prompt_snapshot

        docs.append(Document(page_content=page_content, metadata=metadata))
    return docs


def build_docs_text_only(questions: List[str]) -> List[Document]:
    return [build_text_doc(q) for q in questions]

def build_text_faiss_index(questions: List[str]) -> FAISS:
    docs = build_docs_text_only(questions)
    return FAISS.from_documents(docs, emb)

# === Text RAG: 相似度检索（与入库同分布：纯问题文本） ===
def similarity_search_text_docs(
    user_question: str,
    vectordb: FAISS,
    k: int = 5,
):
    query_text = f"{user_question}"
    results = vectordb.similarity_search_with_score(query_text, k=k)
    return query_text, results


# === Text RAG: Prompt（不含图三元组；可拼检索上下文的原问题文本与历史答案） ===
def make_text_qa_prompt(
    question: str,
    retrieved_docs=None
) -> str:
    sections = []
    if retrieved_docs and CONFIG.get("include_retrieved_context", True):
        doc0, _ = retrieved_docs[0]
        related_q_txt = doc0.page_content.strip()
        related_answer = (doc0.metadata or {}).get("llm_answer", "")
        sections.append(
            "<<<RETRIEVED_CONTEXT_START>>>\n"
            "The system searched for a related question in the database. Below are related question's graph triples and its prior answer as reference. " \
            "You don't have to follow it completely, just use it as a reference.\n"
            f"[RELATED QUESTION TEXT]:\n{related_q_txt}\n"
            f"[RELATED ANSWER]: {related_answer}\n"
            "<<<RETRIEVED_CONTEXT_END>>>"
        )

    sections.append(f"[CURRENT QUESTION]: {question}")

    yes, no = _yn("YES", "NO")
    sections.append(
        "[TASK]: You are a precise QA assistant for binary (yes/no) questions.\n"
        f"- Output ONLY one token: {yes} or {no}.\n"
        "- Do NOT copy or summarize any context.\n"
        "- Do NOT show reasoning, steps, or extra words.\n"
        f"[ANSWER]: "
    )
    return "\n\n".join(sections)

def answer_with_llm_text(
    question: str,
    gen_pipe,
    *,
    text_db: Optional[FAISS] = None
) -> str:
    # 检索（可选）
    retrieved_docs = None
    if text_db:
        k = CONFIG.get("faiss_search_k", 3)
        _, hits = similarity_search_text_docs(question, text_db, k=k)
        retrieved_docs = hits

    # Prompt
    prompt = make_text_qa_prompt(question, retrieved_docs)

    # 生成
    out = gen_pipe(prompt)
    text = out[0]["generated_text"]

    # 取输出
    if CONFIG.get("return_full_text", True):
        answer = text[len(prompt):].strip()
    else:
        answer = text.strip()

    # 归一化 YES/NO（与现有 answer_with_llm 一致）
    if CONFIG.get("answer_mode", "YES_NO"):
        yes, no = _yn("YES", "NO")
        a = answer.strip().lower()
        if "yes" in a and "no" not in a:
            answer = yes
        elif "no" in a and "yes" not in a:
            answer = no
        else:
            # 回退：保持原样（或你也可和图版一样递归一次，这里避免递归以最少改动为主）
            answer = no  # 与“不确定选 NO”的规则一致
    return answer

def build_text_faiss_index_with_answers(
    questions: List[str],
    gen_pipe,
    *,
    add_prompt_snapshot: bool = False,
    bootstrap_db: Optional[FAISS] = None
) -> FAISS:
    """
    用文本 RAG 路线生成答案并入库，然后返回 FAISS 向量库。
    bootstrap_db: 若传入，文本检索会优先引用该库的历史问答作为 retrieved context（冷启动可传 None）。
    """
    docs = build_text_docs_with_answer(
        questions=questions,
        gen_pipe=gen_pipe,
        add_prompt_snapshot=add_prompt_snapshot,
        text_db=bootstrap_db,
    )
    print(docs)
    return FAISS.from_documents(docs, emb)



In [7]:
def measure_once_mode(
    question: str,
    mode: str,                 # "text" or "graph"
    gen_pipe,
    tokenizer,
    parser=None,
    text_db: Optional[FAISS] = None,
    graph_db: Optional[FAISS] = None,
    *,
    label: Optional[str] = None,
    use_cuda_mem: bool = True,
) -> Dict:
    assert mode in ("text", "graph")

    # 选择检索与prompt
    if mode == "text":
        retrieved_docs = None
        if text_db and CONFIG.get("include_retrieved_context", True):
            _, hits = similarity_search_text_docs(question, text_db, k=CONFIG.get("faiss_search_k", 3))
            retrieved_docs = hits if hits else None
        prompt = make_text_qa_prompt(question, retrieved_docs=retrieved_docs)
    else:
        retrieved_docs = None
        if graph_db and CONFIG.get("include_retrieved_context", True):
            _, hits = similarity_search_graph_docs(question, parser, graph_db, k=CONFIG.get("faiss_search_k", 3))
            retrieved_docs = hits if hits else None
        G, rels = parse_question_to_graph_generic(parser, question)
        prompt = make_graph_qa_prompt(question, G, rels, retrieved_docs)

    # token 计数
    def _count_tokens(tokenizer, text: str) -> int:
        return len(tokenizer.encode(text, add_special_tokens=False))

    in_tok = _count_tokens(tokenizer, prompt)

    # 计时 & 生成
    peak_mem = None
    if use_cuda_mem and torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.synchronize()

    t0 = time.perf_counter()
    out = gen_pipe(prompt)
    dt = time.perf_counter() - t0

    text = out[0]["generated_text"]
    if CONFIG.get("return_full_text", False):
        answer = text[len(prompt):].strip()
    else:
        answer = text.strip()

    out_tok = _count_tokens(tokenizer, answer)

    if use_cuda_mem and torch.cuda.is_available():
        torch.cuda.synchronize()
        peak_mem = torch.cuda.max_memory_allocated() / (1024**2)

    return {
        "label": label or f"{mode}_rag",
        "mode": mode,
        "question": question,
        "input_tokens": in_tok,
        "output_tokens": out_tok,
        "total_tokens": in_tok + out_tok,
        "latency_sec": dt,
        "peak_vram_MiB": peak_mem,
        "prompt_chars": len(prompt),
        "answer": answer,
        "used_retrieval": bool(retrieved_docs),
    }


In [8]:
def batch_compare_text_vs_graph(
    questions: List[str],
    gen_pipe, tokenizer, parser,
    text_db: Optional[FAISS],
    graph_db: Optional[FAISS],
) -> pd.DataFrame:
    rows = []
    for q in questions:
        rows.append(
            measure_once_mode(q, "text", gen_pipe, tokenizer, parser, text_db, graph_db, label="text_rag")
        )
        rows.append(
            measure_once_mode(q, "graph", gen_pipe, tokenizer, parser, text_db, graph_db, label="graph_rag")
        )
    return pd.DataFrame(rows)

# ---- 最小准确率工具（如你已有可忽略）----
def _normalize_yesno(text: str) -> str:
    if text is None: return "NO"
    t = str(text).strip().lower()
    if t == "yes" or ("yes" in t and "no" not in t): return "YES"
    if t == "no"  or ("no"  in t and "yes" not in t): return "NO"
    return "NO"

def attach_gold(df: pd.DataFrame, gold_map: dict) -> pd.DataFrame:
    g = pd.DataFrame(list(gold_map.items()), columns=["question","gold"])
    g["gold"] = g["gold"].map(lambda x: "YES" if str(x).upper()=="YES" else "NO")
    out = df.merge(g, on="question", how="left")
    out["pred"] = out["answer"].map(_normalize_yesno)
    out["correct"] = (out["pred"] == out["gold"]).astype(int)
    return out

def evaluate_accuracy(df_with_gold: pd.DataFrame):
    print("\n== Accuracy by config ==")
    for k, sub in df_with_gold.groupby("label"):
        n = len(sub[sub["gold"].notna()])
        acc = sub["correct"].mean() if n else float("nan")
        print(f"{k:<10s} acc={acc:.3f} (n={n})")

def summarize_cost(df: pd.DataFrame, base_label: str, target_label: str):
    A = df[df["label"]==base_label]; B = df[df["label"]==target_label]
    def avg(col):
        a, b = A[col].mean(), B[col].mean()
        return a, b, (b-a)/max(1e-9,a)
    print("\n== Cost (avg) ==")
    for col in ["input_tokens","output_tokens","total_tokens","latency_sec","peak_vram_MiB","prompt_chars"]:
        if col in df.columns:
            a,b,d = avg(col); print(f"{col:>15s} | {base_label}: {a:8.2f} | {target_label}: {b:8.2f} | Δ%: {d*100:7.2f}%")


In [9]:
def dir_size_bytes(path: str) -> int:
    total = 0
    for root, _, files in os.walk(path):
        for f in files:
            fp = os.path.join(root, f)
            try: total += os.path.getsize(fp)
            except OSError: pass
    return total

def save_and_report_sizes(text_db: FAISS, graph_db: FAISS, text_dir="faiss_text_idx", graph_dir="faiss_graph_idx"):
    text_db.save_local(text_dir)
    graph_db.save_local(graph_dir)
    def human(n):
        u=["B","KB","MB","GB"]; i=0; x=float(n)
        while x>=1024 and i<len(u)-1: x/=1024.0; i+=1
        return f"{x:.2f} {u[i]}"
    s_text  = dir_size_bytes(text_dir)
    s_graph = dir_size_bytes(graph_dir)
    print(f"[Index size] text_rag  = {human(s_text)}  ({text_dir})")
    print(f"[Index size] graph_rag = {human(s_graph)}  ({graph_dir})")
    return s_text, s_graph


In [10]:
emb = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
# 0) 模型与解析器
gen_pipe, tokenizer = load_llm_pipeline()
parser = RelationshipGraphParser()

# 1) Graph RAG：你已有的流程
graph_docs = build_docs_with_answer(
    questions, parser, gen_pipe,
    add_prompt_snapshot=False,
    faiss_db=None  # 冷启动
)
print(graph_docs)
graph_db = build_faiss_index(graph_docs)  # == 你已有 build_faiss_index

# 2) Text RAG：并行构建（字段对齐）
text_db = build_text_faiss_index_with_answers(
    questions,
    gen_pipe,
    add_prompt_snapshot=False,
    bootstrap_db=None  # 冷启动；也可传已有 text_db 做增量
)

# 3) 现在 text_db 与 graph_db 的 Document.metadata 键名一致，你可以直接复用你现有的
#    度量、准确率与存储体积对比函数（比如 batch_compare_text_vs_graph / summarize_cost 等）


# 2) 存储体积对比
save_and_report_sizes(text_db, graph_db, text_dir="faiss_text_idx", graph_dir="faiss_graph_idx")

# 3) 成本 & 准确率 A/B（用你现有 GOLD_LABELS）
df_ab = batch_compare_text_vs_graph(questions, gen_pipe, tokenizer, parser, text_db, graph_db)
df_ab_gold = attach_gold(df_ab, GOLD_LABELS)
evaluate_accuracy(df_ab_gold)
summarize_cost(df_ab_gold, base_label="text_rag", target_label="graph_rag")


Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.54it/s]
Device set to use mps


YES
YES
NO
NO
YES
NO
YES
YES
NO
NO
[Document(metadata={'graph_id': 'Q1', 'question': 'Given that Earth completes a rotation approximately every 24 hours, does this rotation cause the Sun to appear to rise in the east and set in the west from the perspective of an observer on the surface?', 'num_nodes': 11, 'num_edges': 9, 'llm_model': 'microsoft/Phi-4-mini-reasoning', 'llm_answer': 'YES', 'created_at': 1756361420}, page_content='{"e":["Sun","appear","rise","Earth","complete","rotation","east","set","west","perspective","cause"],"r":["subj","obj","prep_in","prep_from"],"questions([[e,r,e], ...])":[[0,0,1],[0,0,2],[3,0,4],[4,1,5],[6,0,7],[7,2,8],[7,3,9],[2,2,6],[5,0,10]]}'), Document(metadata={'graph_id': 'Q2', 'question': 'Based on basic human biology, which requires oxygen for cellular respiration and energy production, do humans need oxygen to survive under normal conditions?', 'num_nodes': 9, 'num_edges': 7, 'llm_model': 'microsoft/Phi-4-mini-reasoning', 'llm_answer': 'YES', 'created