In [1]:
import networkx as nx
from graph_generator.graphparsers import RelationshipGraphParser
from groupwords import *
from linearization_utils import *
from retrieval_utils import similarity_search_graph_docs

from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.schema import Document

from typing import List, Dict, Optional, Tuple
import time

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

import pandas as pd
import os

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Central configuration for the notebook; the helper functions below read their defaults from this dict.
CONFIG = {
    # === Embedding & VectorStore ===
    "embedding_model": "sentence-transformers/all-MiniLM-L6-v2",  # Sentence-embedding model for documents/questions
    "faiss_search_k": 3,  # Number of nearest neighbors to retrieve from FAISS

    # === LLM (text generation) ===
    "llm_model_id": "microsoft/Phi-4-mini-reasoning",  # HuggingFace model ID (e.g. Phi-4-mini-reasoning or Phi-4-mini-instruct)
    "device_map": "auto",  # Device placement: "cuda", "mps", "cpu", or "auto"
    "dtype_policy": "auto",  # Precision: "auto", "bf16", "fp16", or "fp32" (resolved by _select_dtype)
    "max_new_tokens": 256,  # Maximum tokens generated per response
    "do_sample": True,  # Whether to use sampling (True) or greedy decoding (False)
    "temperature": 0.4,  # Randomness control for sampling; lower = more deterministic
    "top_p": 1.0,  # Nucleus sampling threshold; 1.0 = no restriction
    "return_full_text": False,  # Return full text (input+output) if True, only output if False
    "seed": None,  # Random seed for reproducibility; only applied when it is an int

    # === Prompt / Answer ===
    # "yes_no"/"binary" → single-token answer; "short", "detail"/"long",
    # "reasoning"/"explain" → open-ended answer styles (see make_graph_qa_prompt).
    "answer_mode": "short",
    "answer_uppercase": True,  # If True → "YES"/"NO", else "yes"/"no"

    # === Prompt construction ===
    "include_retrieved_context": True,  # Include retrieved Q&A in prompt
    "include_current_triples": True,  # Include graph triples in prompt
    "use_cached_text_embeddings": True,  # Whether text RAG reuses cached embeddings when searching
    "use_cached_graph_embeddings": True  # Whether graph RAG reuses cached embeddings when searching
}

# `set_seed` is optional: if it cannot be imported, seeding is skipped (callers check `if set_seed`).
try:
    from transformers import set_seed  # Utility for reproducibility
except Exception:
    set_seed = None


In [3]:
from gensim.models import KeyedVectors
import numpy as np
import re
from langchain.embeddings.base import Embeddings

class WordAvgEmbeddings(Embeddings):
    """Embedder that represents a text as the mean of its word vectors.

    The text is split into runs of ASCII letters, each run is lower-cased, and
    every token present in the loaded ``KeyedVectors`` contributes its vector to
    an unweighted average. A text with no in-vocabulary token maps to the zero
    vector of length ``self.dim``.
    """

    def __init__(self, model_path: str = "gensim-data/glove-wiki-gigaword-100/glove-wiki-gigaword-100.model.vectors.npy"):
        """Load word vectors (memory-mapped, read-only) from ``model_path``.

        NOTE(review): the default path ends in ``.model.vectors.npy`` while the
        notebook instantiates this class with the ``.model`` file — confirm the
        default is loadable by ``KeyedVectors.load``.
        """
        self.token_pat = re.compile(r"[A-Za-z]+")
        self.kv = KeyedVectors.load(model_path, mmap='r')
        self.dim = self.kv.vector_size

    def _embed_text(self, text: str) -> np.ndarray:
        """Return the float32 mean vector of all in-vocabulary tokens of ``text``."""
        known_vectors = []
        for match in self.token_pat.finditer(text):
            word = match.group(0).lower()
            if word in self.kv:
                known_vectors.append(self.kv[word])
        if len(known_vectors) == 0:
            # Nothing recognisable in the text: fall back to the zero vector.
            return np.zeros(self.dim, dtype=np.float32)
        return np.mean(known_vectors, axis=0).astype(np.float32)

    def embed_documents(self, texts):
        """Embed each text in ``texts``; returns a list of float32 arrays."""
        return list(map(self._embed_text, texts))

    def embed_query(self, text):
        """Embed a single query string; returns a float32 array."""
        query_vector = self._embed_text(text)
        return query_vector

# Word-level embedder (averaged word vectors loaded from the glove-wiki-gigaword-100 path);
# passed as `emb_model` to the graph-retrieval helpers below.
word_emb = WordAvgEmbeddings(model_path="gensim-data/glove-wiki-gigaword-100/glove-wiki-gigaword-100.model")
# Sentence-level embedder; used by the coarse filter in make_graph_qa_prompt and by the text-RAG path.
sentence_emb = HuggingFaceEmbeddings(model_name=CONFIG["embedding_model"])  # Local embedding model (MiniLM-L6-v2, 384 dim)

  sentence_emb = HuggingFaceEmbeddings(model_name=CONFIG["embedding_model"])  # Local embedding model (MiniLM-L6-v2, 384 dim)


## RAG workflow

In [4]:
def _select_dtype() -> torch.dtype:
    """Choose dtype based on CONFIG['dtype_policy'] and hardware."""
    policy = CONFIG.get("dtype_policy", "auto")
    if policy == "bf16":
        return torch.bfloat16
    if policy == "fp16":
        return torch.float16
    if policy == "fp32":
        return torch.float32

    # auto mode
    if torch.cuda.is_available():
        return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    # MPS backend works more reliably with fp32
    if torch.backends.mps.is_available():
        return torch.float32
    return torch.float32

def _yn(text_yes="YES", text_no="NO"):
    return (text_yes, text_no) if CONFIG.get("answer_uppercase", True) else (text_yes.lower(), text_no.lower())

def _avg_pool(mat: np.ndarray) -> np.ndarray:
    if mat is None or len(mat) == 0:
        return None
    m = np.asarray(mat, dtype=np.float32)
    if m.ndim == 1:
        return m.astype(np.float32)
    return m.mean(axis=0).astype(np.float32)

def _normalize(v: np.ndarray) -> np.ndarray:
    v = np.asarray(v, dtype=np.float32)
    v /= (np.linalg.norm(v) + 1e-12)
    return v

def _graph_doc_vec_from_cached_or_embed(
    er_e: list[str],
    er_r: list[str],
    e_embeds: list | None,
    r_embeds: list | None,
    emb_model,                     
    use_cache: bool = True,
) -> np.ndarray:
    """
    文档向量 = concat( avg(e_embeds), avg(r_embeds) ) 或者两者的平均。
    这里用简单且稳定的做法：取 e,r 的平均再做均值融合。
    """
    if use_cache and e_embeds is not None and len(e_embeds) and r_embeds is not None and len(r_embeds):
        e_mean = _avg_pool(np.asarray(e_embeds, dtype=np.float32))
        r_mean = _avg_pool(np.asarray(r_embeds, dtype=np.float32))
        v = (e_mean + r_mean) / 2.0
        return _normalize(v)

    # 缓存不可用 → 现算：对每个实体/关系分别用词向量平均，再整体平均
    e_vecs = []
    for e in (er_e or []):
        e_vecs.append(np.asarray(emb_model.embed_query(e), dtype=np.float32))
    r_vecs = []
    for r in (er_r or []):
        r_vecs.append(np.asarray(emb_model.embed_query(r), dtype=np.float32))

    e_mean = _avg_pool(np.stack(e_vecs, axis=0)) if e_vecs else None
    r_mean = _avg_pool(np.stack(r_vecs, axis=0)) if r_vecs else None

    if e_mean is None and r_mean is None:
        # 两边都空，退化为零向量（用实体任意词兜底也可）
        dim = getattr(emb_model, "dim", None) or len(emb_model.embed_query("a"))
        return np.zeros(dim, dtype=np.float32)

    if e_mean is None: v = r_mean
    elif r_mean is None: v = e_mean
    else: v = (e_mean + r_mean) / 2.0
    return _normalize(v)

# =========================
# Embeddings / Vectorstore
# =========================
#emb = HuggingFaceEmbeddings(model_name=CONFIG["embedding_model"])  # Local embedding model (MiniLM-L6-v2, 384 dim)

def build_faiss_index(docs: List[Document], emb) -> FAISS:
    """Embed ``docs`` with ``emb`` and return the resulting FAISS vector store."""
    vector_store = FAISS.from_documents(docs, emb)
    return vector_store

# =========================
# LLM Loader
# =========================
def load_llm_pipeline(
    model_id: Optional[str] = None,       # HuggingFace model id
    device_map: Optional[str] = None,     # Device placement
    dtype: Optional[torch.dtype] = None,  # Torch dtype
    max_new_tokens: Optional[int] = None, # Max tokens per generation
    temperature: Optional[float] = None,  # Sampling temperature
    top_p: Optional[float] = None,        # Nucleus sampling threshold
    do_sample: Optional[bool] = None,     # Sampling vs greedy
    return_full_text: Optional[bool] = None,  # Return input+output if True
):
    """Load the tokenizer and causal LM and wrap them in a text-generation pipeline.

    Every argument left unset falls back to the matching CONFIG entry (``dtype``
    falls back to ``_select_dtype()``). ``model_id``, ``device_map``, ``dtype``
    and ``max_new_tokens`` fall back whenever they are falsy; the remaining
    arguments fall back only when they are ``None``.

    If ``set_seed`` was importable and CONFIG['seed'] is an int, the seed is
    applied before the model is loaded.

    Returns:
        ``(gen_pipe, tokenizer)``.
    """
    # Falsy → CONFIG fallback.
    if not model_id:
        model_id = CONFIG["llm_model_id"]
    if not device_map:
        device_map = CONFIG["device_map"]
    if not dtype:
        dtype = _select_dtype()
    if not max_new_tokens:
        max_new_tokens = CONFIG["max_new_tokens"]

    # None → CONFIG fallback (0 / False are legitimate explicit values here).
    if temperature is None:
        temperature = CONFIG["temperature"]
    if top_p is None:
        top_p = CONFIG["top_p"]
    if do_sample is None:
        do_sample = CONFIG["do_sample"]
    if return_full_text is None:
        return_full_text = CONFIG["return_full_text"]

    configured_seed = CONFIG.get("seed")
    if set_seed and isinstance(configured_seed, int):
        set_seed(configured_seed)

    placement = {"torch_dtype": dtype, "device_map": device_map}

    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=True,
        **placement,
    )

    generation_defaults = {
        "return_full_text": return_full_text,
        "max_new_tokens": max_new_tokens,
        "do_sample": do_sample,
        "temperature": temperature,
        "top_p": top_p,
    }
    gen_pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        **placement,
        **generation_defaults,
    )
    return gen_pipe, tokenizer

# =========================
# Question → Graph (generic)
# =========================

def print_graph(title: str, G):
    """Print a titled, human-readable dump of a graph's nodes and edges.

    Multigraph edges are shown with their edge key; simple graphs use ``->``
    when directed and ``--`` when undirected.
    """
    print(f"\n=== {title} ===")

    print("Nodes:")
    for node, attrs in G.nodes(data=True):
        print(f"  - {node!r} :: {attrs}")

    print("Edges:")
    if not G.is_multigraph():
        connector = "->" if G.is_directed() else "--"
        for src, dst, attrs in G.edges(data=True):
            print(f"  - {src!r} {connector} {dst!r} :: {attrs}")
        return

    for src, dst, key, attrs in G.edges(keys=True, data=True):
        print(f"  - {src!r} -[{key}]-> {dst!r} :: {attrs}")

def parse_question_to_graph_generic(parser, question: str) -> Tuple[nx.Graph, List[Dict]]:
    """Parse ``question`` into a graph with whichever entry point ``parser`` offers.

    ``question_to_graph`` is preferred; ``question_to_causal_graph`` is the
    fallback. The resulting graph has its nodes merged by canonical
    (normalized) text before being returned together with the relation list.

    Raises:
        AttributeError: if ``parser`` provides neither method.
    """
    for entry_point in ("question_to_graph", "question_to_causal_graph"):
        if not hasattr(parser, entry_point):
            continue
        parse = getattr(parser, entry_point)
        graph, rels = parse(question)
        graph = merge_graph_nodes_by_canonical(
            graph, normalizer=normalize_text, merge_edge_attrs=("relation",)
        )
        return graph, rels

    raise AttributeError("Parser must provide question_to_graph or question_to_causal_graph")

import ast
# =========================
# Prompt Builder
# =========================
def _retrieved_context_section(retrieved_docs) -> Optional[str]:
    """Render the top retrieved hit as a reference-context prompt section.

    Returns ``None`` when the hit carries no question chains, or when the coarse
    filter yields no usable triple text.
    """
    top_doc, _score = retrieved_docs[0]

    # Rebuild the codebook from the stored document: entities/relations live in
    # page_content (a stringified dict), everything else in metadata.
    metadata = top_doc.metadata or {}
    er = ast.literal_eval(top_doc.page_content)
    codebook_main = {
        "e": er['e'],
        "r": er['r'],
        "edge_matrix": metadata["edge_matrix"],
        "questions_lst": metadata["questions_lst"],
        "e_embeddings": metadata["e_embeddings"],
        "r_embeddings": metadata["r_embeddings"],
    }

    # Flatten every stored question chain (edge-index chain) across groups;
    # these are the queries handed to the coarse filter.
    query_chains = [
        question_chain
        for group in codebook_main["questions_lst"]
        for question_chain in group
    ]
    if not query_chains:
        return None

    wrapper_res = coarse_filter(
        questions=query_chains,
        codebook_main=codebook_main,
        emb=sentence_emb,
        top_k=3,
        question_batch_size=2,
        questions_db_batch_size=8,
        top_m=2,
    )
    if not (isinstance(wrapper_res, dict) and wrapper_res):
        return None

    first_non_empty = next((lst for lst in wrapper_res.values() if lst), [])
    if not first_non_empty:
        return None

    related_triples = first_non_empty[0].get("text", "__EMPTY_JSON__")
    if related_triples == "__EMPTY_JSON__":
        return None

    # Read from the None-guarded `metadata`, not `top_doc.metadata` directly.
    related_answer = metadata.get("llm_answer", "")

    return (
        "<<<RETRIEVED_CONTEXT_START>>>\n"
        "The system searched for a related question in the database. Below are related question's graph triples and its prior answer as reference. "
        "You don't have to follow it completely, just use it as a reference.\n"
        f"[RELATED QUESTION'S GRAPH TRIPLES]:\n{related_triples}\n"
        f"[RELATED QUESTION'S ANSWER]: {related_answer}\n"
        "<<<RETRIEVED_CONTEXT_END>>>"
    )


def _answer_rules(mode: str) -> str:
    """Return the task/format instruction block for the given answer mode."""
    if mode in {"yes_no", "binary"}:
        yes, no = _yn("YES", "NO")
        return (
            "[TASK]: You are a precise QA assistant for binary (yes/no) questions.\n"
            f"- Output ONLY one token: {yes} or {no}.\n"
            "- Do NOT copy or summarize any context.\n"
            "- Do NOT show reasoning, steps, or extra words.\n"
            "[ANSWER]: "
        )

    style_line = {
        "short":    "- Give a short, direct answer in 2–3 sentences.\n",
        "detail":   "- Provide a clear, detailed, and structured answer.\n",
        "long":     "- Provide a clear, detailed, and structured answer.\n",
        "reasoning":"- Provide a well-structured explanation with logical reasoning flow.\n- If useful, break the answer into brief sections.\n",
        "explain":  "- Provide a well-structured explanation with logical reasoning flow.\n- If useful, break the answer into brief sections.\n",
    }.get(mode, "- Provide a clear and helpful answer.\n")

    return (
        "[TASK]: You are a QA assistant for open-ended questions.\n"
        f"{style_line}"
        "- Do NOT restrict to yes/no.\n"
        # Trailing space added: the two sentences were previously fused as "...word).Avoid".
        "[FORMAT]: Write complete sentences (not a single word). "
        "Avoid starting with just 'Yes.' or 'No.'; if the question is yes/no-style, state the conclusion AND 1–2 short reasons.\n"
        "[ANSWER]: "
    )


def make_graph_qa_prompt(
    question: str,
    G: nx.Graph,
    relations: Optional[List[Dict]] = None,
    retrieved_docs = None
) -> str:
    """Assemble the LLM prompt for ``question``.

    The prompt consists of, in order and separated by blank lines:
      1. an optional retrieved-context section built from the first entry of
         ``retrieved_docs`` (only when CONFIG['include_retrieved_context'] is
         truthy and the coarse filter returns triple text),
      2. the current question,
      3. the task/format rules for the configured answer mode.

    Args:
        question: The user question.
        G: Parsed question graph. Accepted for interface compatibility; it is
            not currently used when building the prompt.
        relations: Parsed relations. Likewise currently unused.
        retrieved_docs: Sequence of ``(document, score)`` pairs, or ``None``.

    Returns:
        The full prompt string, ending with ``"[ANSWER]: "``.
    """
    sections = []

    # 1) retrieved context (if any)
    if retrieved_docs and CONFIG.get("include_retrieved_context", True):
        context_section = _retrieved_context_section(retrieved_docs)
        if context_section is not None:
            sections.append(context_section)

    # 2) the question itself
    sections.append(f"[CURRENT QUESTION]: {question}")

    # 3) answer rules for the configured mode
    sections.append(_answer_rules(_mode()))

    return "\n\n".join(sections)

# =========================
# LLM Answerer
# =========================
def _mode():
    return (CONFIG.get("answer_mode") or "short").lower()

def _gen(gen_pipe, prompt):
    # 显式传 generation 参数；有些 pipeline 会忽略默认 config
    kwargs = dict(
        do_sample=CONFIG.get("do_sample", True),
        temperature=CONFIG.get("temperature", 0.4),
        top_p=CONFIG.get("top_p", 1.0),
        max_new_tokens=CONFIG.get("max_new_tokens", 256),
        return_full_text=CONFIG.get("return_full_text", False),
    )
    # 兼容 pad/eos（部分 Phi 模型需要）
    try:
        tok = gen_pipe.tokenizer
        if tok is not None:
            if tok.pad_token_id is None and tok.eos_token_id is not None:
                tok.pad_token_id = tok.eos_token_id
            kwargs.setdefault("eos_token_id", tok.eos_token_id)
            kwargs.setdefault("pad_token_id", tok.pad_token_id)
    except Exception:
        pass

    out = gen_pipe(prompt, **kwargs)
    return out[0]["generated_text"]

def _extract_answer_text(prompt, text):
    if CONFIG.get("return_full_text", False):
        return text[len(prompt):].strip()
    return text.strip()

def strip_think(s: str) -> Tuple[str, List[str]]:
    """Remove ``<think>...</think>`` segments from ``s`` and collect their contents.

    Handles both complete segments and a trailing, unterminated ``<think>...``
    (everything from that tag to the end of the string is treated as thinking).
    Tag matching is case-insensitive. After removal, lines that start with a
    common analysis prefix ("Okay,", "Let's", "Step by step", "Thought:") are
    also dropped from the leading position of a line to that line's end.

    Args:
        s: Raw model output; may be empty or ``None``.

    Returns:
        ``(clean_text, thinks)`` where ``clean_text`` is the stripped remaining
        text and ``thinks`` is the list of extracted thinking contents (the
        dangling tail, if non-empty, comes last). The result always has exactly
        two elements, including for empty input.
    """
    if not s:
        return "", []

    thinks: List[str] = []
    spans: List[Tuple[int, int]] = []  # [start, end) ranges to cut from `s`

    # 1) Every complete <think>...</think> segment.
    for m in re.finditer(r"<think>(.*?)</think>", s, flags=re.S | re.I):
        thinks.append(m.group(1).strip())
        spans.append((m.start(), m.end()))

    # 2) A dangling <think> with no closing tag after it. Offsets are taken
    #    from matches on `s` itself: indices found on `s.lower()` can drift
    #    because lowercasing may change the string's length (e.g. "İ").
    last_open = None
    for last_open in re.finditer(r"<think>", s, flags=re.I):
        pass
    if last_open is not None and re.search(r"</think>", s[last_open.start():], flags=re.I) is None:
        dangling_text = s[last_open.end():].strip()
        if dangling_text:
            thinks.append(dangling_text)
        spans.append((last_open.start(), len(s)))

    # 3) Keep everything outside the spans. Complete segments are
    #    non-overlapping and in order, and a dangling tag necessarily starts
    #    after the last closing tag, so `spans` is already sorted and disjoint.
    parts = []
    prev = 0
    for start, end in spans:
        parts.append(s[prev:start])
        prev = end
    parts.append(s[prev:])
    clean_text = "".join(parts)

    # 4) Extra cleanup of common analysis prefixes.
    clean_text = re.sub(r"(?:^|\n)\s*(Okay,|Let’s|Let's|Step by step|Thought:).*", "", clean_text, flags=re.I)

    return clean_text.strip(), thinks


def answer_with_llm(
    question: str,
    gen_pipe,
    parser,
    faiss_db=None,
    prompt=None,
    max_retries: int = 5,
) -> str:
    """Answer ``question`` with the LLM, optionally grounding on retrieved graph docs.

    Args:
        question: The user question.
        gen_pipe: Text-generation pipeline passed to ``_gen``.
        parser: Question parser used for retrieval and prompt construction.
        faiss_db: Optional vector store; when truthy, similar graph documents
            are retrieved and passed to the prompt builder.
        prompt: Pre-built prompt. When ``None`` the prompt is built from the
            parsed question graph and the retrieved documents.
        max_retries: Maximum number of generations attempted while the cleaned
            answer is empty.

    Returns:
        In binary mode ("yes_no"/"binary"): the configured yes/no label when the
        answer (or one stricter re-ask) is recognisable, otherwise the cleaned
        re-ask text. In open-ended mode: the cleaned answer, or the cleaned
        re-ask when the first answer was a bare yes/no or at most two words and
        the re-ask is longer. An empty string when every attempt came back empty
        or ``max_retries`` is not positive.
    """
    retrieved_docs = None
    if faiss_db:
        _, retrieved_docs = similarity_search_graph_docs(
            question, parser, faiss_db, k=CONFIG.get("faiss_search_k", 3),
            emb_model=word_emb,
            use_cache=CONFIG.get("use_cached_graph_embeddings", True),
        )

    if prompt is None:
        G, rels = parse_question_to_graph_generic(parser, question)
        prompt = make_graph_qa_prompt(question, G, rels, retrieved_docs)

    mode = _mode()
    YES_RE = re.compile(r"^\s*(yes|y|true|correct|affirmative)\s*\.?\s*$", re.I)
    NO_RE  = re.compile(r"^\s*(no|n|false|incorrect|negative)\s*\.?\s*$", re.I)

    def _clean(full_prompt: str, raw_text: str) -> str:
        # Drop the echoed prompt first (when return_full_text is on), THEN the
        # <think> segments. Only element 0 of strip_think's result is used.
        return strip_think(_extract_answer_text(full_prompt, raw_text))[0]

    def _binary_label(text: str) -> Optional[str]:
        # Configured YES/NO label when `text` is an unambiguous yes or no.
        is_yes = bool(YES_RE.match(text))
        is_no = bool(NO_RE.match(text))
        if is_yes == is_no:
            return None
        yes_label, no_label = _yn("YES", "NO")  # _yn applies the configured casing
        return yes_label if is_yes else no_label

    answer = ""  # defined even when the loop body never runs
    for attempt in range(1, max_retries + 1):
        raw = _gen(gen_pipe, prompt)
        print(f"----- RAW (try {attempt}):", raw)

        answer = _clean(prompt, raw)
        print("----- ANS:", answer)

        # Empty generation → try again.
        if not answer.strip():
            continue

        if mode in {"yes_no", "binary"}:
            label = _binary_label(answer)
            if label is not None:
                return label

            # Re-ask once with a strict format suffix.
            strict_suffix = (
                "\n\n[FORMAT]: Answer with exactly ONE token: "
                + ("YES or NO." if CONFIG.get("answer_uppercase", True) else "yes or no.")
            )
            raw2 = _gen(gen_pipe, prompt + strict_suffix)
            ans2 = _clean(prompt + strict_suffix, raw2)

            label = _binary_label(ans2)
            return label if label is not None else ans2  # fallback: cleaned re-ask text

        # Open-ended mode → avoid one-word answers.
        if YES_RE.match(answer) or NO_RE.match(answer) or len(answer.split()) <= 2:
            format_suffix = (
                "\n\n[FORMAT]: Provide a 2–3 sentence explanation; do not answer with a single word."
            )
            raw2 = _gen(gen_pipe, prompt + format_suffix)
            ans2 = _clean(prompt + format_suffix, raw2)
            if len(ans2.strip()) > len(answer.strip()):
                return ans2.strip()
        return answer.strip()

    return answer



def build_graph_faiss_index_from_cached(
    docs: list[Document],
    emb_model,   
) -> FAISS:
    """Build a FAISS store whose vectors are pre-computed graph-document vectors.

    For every document the entity/relation lists are read from ``page_content``
    (a dict, or its string form parsed with ``ast.literal_eval``), the cached
    ``e_embeddings`` / ``r_embeddings`` are read from its metadata, and one
    vector is produced by ``_graph_doc_vec_from_cached_or_embed``. Whether the
    cached embeddings are used is controlled by
    CONFIG['use_cached_graph_embeddings'].

    Side effect: each document's metadata gains a ``"graph_vec"`` entry holding
    that vector as a list.

    Args:
        docs: Documents to index.
        emb_model: Embedding model; used for on-the-fly embedding when the cache
            is not used, and handed to FAISS as the store's embedding function.

    Returns:
        The populated FAISS vector store.
    """
    use_cache = CONFIG.get("use_cached_graph_embeddings", True)
    texts, metas, vecs = [], [], []

    for d in docs:
        texts.append(d.page_content)
        # When metadata is a dict, this is the same object that receives
        # "graph_vec" below, so the stored metadata includes the vector.
        metas.append(d.metadata or {})


        import ast
        # page_content is expected to be a dict (or its string form) with "e" and "r" lists.
        er = ast.literal_eval(d.page_content) if isinstance(d.page_content, str) else d.page_content
        er_e = er.get("e", []) if isinstance(er, dict) else []
        er_r = er.get("r", []) if isinstance(er, dict) else []

        e_embeds = (d.metadata or {}).get("e_embeddings")
        r_embeds = (d.metadata or {}).get("r_embeddings")

        v = _graph_doc_vec_from_cached_or_embed(
            er_e, er_r, e_embeds, r_embeds, emb_model, use_cache=use_cache
        )
        vecs.append(v.tolist())


        # NOTE(review): unlike the `or {}` guards above, this assumes d.metadata is not None.
        d.metadata["graph_vec"] = v.tolist()

    X = np.asarray(vecs, dtype=np.float32)


    # Pair each text with its vector for FAISS.from_embeddings.
    text_embeddings = [(texts[i], X[i].tolist()) for i in range(len(texts))]
    try:
        return FAISS.from_embeddings(text_embeddings, embedding=emb_model, metadatas=metas)
    except TypeError:
        # Fallbacks for other from_embeddings / add_embeddings signatures.
        # NOTE(review): these branches use keyword layouts that may not match any
        # released API, and the last one seeds the store from an empty text list —
        # confirm they are reachable and working, or remove them.
        try:
            return FAISS.from_embeddings(
                embeddings=X.tolist(), metadatas=metas, texts=texts, embedding=emb_model
            )
        except Exception:
            vs = FAISS.from_texts(texts=[], embedding=emb_model)
            if hasattr(vs, "add_embeddings"):
                vs.add_embeddings(embeddings=X, metadatas=metas, texts=texts)
            else:
                # Last resort: re-embeds the texts with emb_model instead of using X.
                vs.add_texts(texts=texts, metadatas=metas)  
            return vs
        
# In-memory cache: question text → query vector; read and filled by similarity_search_graph_docs.
_GRAPH_QVEC_CACHE = {}

def _faiss_search_by_vec_graph(vs, qv, k):
    if hasattr(vs, "similarity_search_by_vector_with_score"):
        return vs.similarity_search_by_vector_with_score(qv, k=k)
    if hasattr(vs, "similarity_search_by_vector"):
        docs = vs.similarity_search_by_vector(qv, k=k)
        return [(d, None) for d in docs]
    index = getattr(vs, "index", None)
    id_map = getattr(vs, "index_to_docstore_id", None)
    store  = getattr(vs, "docstore", None)
    if index is None or id_map is None or store is None:
        raise AttributeError("FAISS vectorstore has no by-vector APIs and no accessible index/docstore.")
    q = np.asarray(qv, dtype=np.float32).reshape(1, -1)
    D, I = index.search(q, k)
    out = []
    for dist, idx in zip(D[0], I[0]):
        if idx == -1: continue
        doc_id = id_map[idx]
        doc = store.search(doc_id)
        out.append((doc, float(dist)))
    return out

def similarity_search_graph_docs(
    user_question: str,
    parser,
    vectordb: FAISS,
    k: int = 5,
    emb_model=None,                    # word_emb
    use_cache: Optional[bool] = None,  # cache the query vector (does not affect stored docs)
):
    """Retrieve the ``k`` graph documents closest to ``user_question``.

    The question is parsed into a graph; node names are taken as entities and
    edge ``relation`` attributes as relations; both are embedded and averaged
    into one query vector that is searched against ``vectordb``.

    NOTE(review): this definition shadows the function of the same name
    imported from ``retrieval_utils`` at the top of the notebook.

    Args:
        user_question: The question text (also the cache key).
        parser: Question parser passed to ``parse_question_to_graph_generic``.
        vectordb: Vector store to search.
        k: Number of hits requested.
        emb_model: Embedding model; defaults to the global ``word_emb``.
        use_cache: Whether to reuse/store the query vector in
            ``_GRAPH_QVEC_CACHE``; defaults to
            CONFIG['use_cached_graph_embeddings'].

    Returns:
        ``(user_question, hits)`` where ``hits`` is a list of
        ``(document, score_or_None)`` pairs.

    Raises:
        ValueError: if no embedding model is given and no global ``word_emb`` exists.
    """
    if use_cache is None:
        use_cache = CONFIG.get("use_cached_graph_embeddings", True)
    emb_model = emb_model if emb_model is not None else globals().get("word_emb", None)
    if emb_model is None:
        raise ValueError("similarity_search_graph_docs: need `emb_model` for query embedding.")

    # 1) Cache hit: reuse the stored query vector.
    if use_cache and user_question in _GRAPH_QVEC_CACHE:
        cached_vec = _GRAPH_QVEC_CACHE[user_question]
        return user_question, _faiss_search_by_vec_graph(vectordb, cached_vec, k)

    # 2) Cache miss: parse the question and collect entities / relations.
    G, rels = parse_question_to_graph_generic(parser, user_question)

    entities = list(set(map(str, G.nodes)))  # de-duplicated node names

    # The attribute dict is the last element of every edge record, with or without keys.
    edge_records = G.edges(keys=True, data=True) if G.is_multigraph() else G.edges(data=True)
    relation_names = []
    for record in edge_records:
        rel = record[-1].get("relation")
        if rel:
            relation_names.append(str(rel))

    query_vec = _normalize(
        _graph_doc_vec_from_cached_or_embed(
            entities, relation_names, None, None, emb_model, use_cache=False
        )
    )

    if use_cache:
        _GRAPH_QVEC_CACHE[user_question] = query_vec

    return user_question, _faiss_search_by_vec_graph(vectordb, query_vec, k)


# =========================
# Build Docs with LLM Answer
# =========================
def build_docs_with_answer(
    questions: List[str],
    parser,
    gen_pipe,
    *,
    add_prompt_snapshot: bool = False,
    faiss_db = None
) -> List[Document]:
    """Create one graph-RAG document per question, answered by the LLM.

    For each question: parse it into a graph, build its codebook, ask the LLM
    (optionally grounded on ``faiss_db``), merge the answer into the codebook,
    and emit a ``Document`` whose ``page_content`` is the stringified
    ``{"e": ..., "r": ...}`` dict and whose metadata carries the answer, the
    codebook matrices/lists and the cached embeddings.

    Args:
        questions: Questions to process; documents are numbered Q1, Q2, ... in
            iteration order.
        parser: Question parser.
        gen_pipe: Text-generation pipeline.
        add_prompt_snapshot: When true, also store the prompt built without
            retrieved context under ``"prompt_snapshot"``.
        faiss_db: Optional vector store used for retrieval while answering.

    Returns:
        The list of built documents.
    """
    built: List[Document] = []

    for position, q in enumerate(questions):
        G, rels = parse_question_to_graph_generic(parser, q)
        # Codebook holding entities, relations, edge matrix and question chains.
        codebook_main = build_relationship_text(q, G, rels, json_style="codebook_main")
        er = str({"e": codebook_main['e'], "r": codebook_main['r']})

        # Ask the LLM, then fold the answer into the codebook.
        answer = answer_with_llm(q, gen_pipe, parser, faiss_db)
        # NOTE(review): this was written `(answer)`, which is the plain string,
        # not a 1-tuple — confirm what get_code_book expects.
        answers_tuple = answer
        codebook_answer = get_code_book(answers_tuple, type='answers')
        codebook_main = merging_codebook(codebook_main, codebook_answer, type='answers', word_emb=word_emb)

        metadata = dict(
            graph_id=f"Q{position + 1}",
            llm_model=CONFIG["llm_model_id"],
            llm_answer=answer,
            created_at=int(time.time()),
            edge_matrix=codebook_main['edge_matrix'],
            questions_lst=codebook_main['questions_lst'],
            answers_lst=codebook_main['answers_lst'],
            e_embeddings=codebook_main["e_embeddings"],
            r_embeddings=codebook_main["r_embeddings"],
        )
        if add_prompt_snapshot:
            metadata["prompt_snapshot"] = make_graph_qa_prompt(q, G, rels)

        built.append(Document(page_content=er, metadata=metadata))

    return built


## Text RAG

In [5]:
# Questions to run through the pipeline. Only the first one is active; the rest are kept
# commented out for quick toggling.
# NOTE(review): this is a set literal, so iteration order is not guaranteed once more than
# one entry is enabled, while the build_* helpers annotate `questions` as List[str] and
# number documents by enumeration order — consider a list.
questions = {
    # --- Original YES-class questions ---
    "Is the Earth, which orbits the Sun along with seven other planets in the solar system, generally considered to be round in shape despite its slight equatorial bulge?",
    #"Given that Earth completes a rotation approximately every 24 hours, does this rotation cause the Sun to appear to rise in the east and set in the west from the perspective of an observer on the surface?",
    #"Considering that Paris is the administrative and cultural center of France, is it also the official capital city of the country according to its constitution?",
    #"Based on basic human biology, which requires oxygen for cellular respiration and energy production, do humans need oxygen to survive under normal conditions?",
    #"Since the Moon is gravitationally bound to Earth and completes an orbit approximately every 27 days, is it classified as Earth's only natural satellite?", 

    #"Although the Sahara Desert is one of the largest deserts in the world, is it located in the South American continent instead of Africa?",
    #"Even though the Amazon River is among the longest rivers globally, is it actually longer than the Nile River when measured by official geographical surveys?",
    #"Despite Tokyo being the largest city in Japan and a major global hub, is it the capital city of South Korea?",
    #"Considering the natural habitats of penguins, which are mainly located in the Southern Hemisphere, do penguins naturally live in the Arctic region alongside polar bears?",
    #"Although gold is a dense and valuable metal, is its density greater than that of lead, making it heavier per cubic centimeter?",
}

In [6]:
# Ground-truth YES/NO labels, keyed by the exact question text.
GOLD_LABELS = {
    # --- Original YES-class questions ---
    "Is the Earth, which orbits the Sun along with seven other planets in the solar system, generally considered to be round in shape despite its slight equatorial bulge?": "YES",
    "Given that Earth completes a rotation approximately every 24 hours, does this rotation cause the Sun to appear to rise in the east and set in the west from the perspective of an observer on the surface?": "YES",
    "Considering that Paris is the administrative and cultural center of France, is it also the official capital city of the country according to its constitution?": "YES",
    "Based on basic human biology, which requires oxygen for cellular respiration and energy production, do humans need oxygen to survive under normal conditions?": "YES",
    "Since the Moon is gravitationally bound to Earth and completes an orbit approximately every 27 days, is it classified as Earth's only natural satellite?": "YES",

    # --- Original NO-class questions ---
    "Although the Sahara Desert is one of the largest deserts in the world, is it located in the South American continent instead of Africa?": "NO",
    "Even though the Amazon River is among the longest rivers globally, is it actually longer than the Nile River when measured by official geographical surveys?": "NO",
    "Despite Tokyo being the largest city in Japan and a major global hub, is it the capital city of South Korea?": "NO",
    "Considering the natural habitats of penguins, which are mainly located in the Southern Hemisphere, do penguins naturally live in the Arctic region alongside polar bears?": "NO",
    # Corrected from "NO": gold (~19.3 g/cm^3) is denser than lead (~11.3 g/cm^3). The old
    # label also contradicted the follow-up below, which answers "NO" to gold being less
    # dense than lead. The question was written for the NO group; its premise is simply true.
    "Although gold is a dense and valuable metal, is its density greater than that of lead, making it heavier per cubic centimeter?": "YES",

    # --- Follow-up Questions (YES/NO balanced) ---
    # Earth shape
    "Despite the Earth's slightly flattened poles, is its shape closer to a sphere than to a flat surface?": "YES",
    "Is the Earth perfectly flat with no curvature anywhere on its surface?": "NO",

    # Sun rotation
    "Given the Earth's rotation, is the apparent motion of the Sun consistent with the Sun rising in the east?": "YES",
    "If the Earth did not rotate on its axis, would the Sun still rise and set in the same pattern as it does now?": "NO",

    # Paris as capital
    "Considering France's administrative structure, is Paris recognized as the political and economic capital of the nation?": "YES",
    "Is Berlin, rather than Paris, designated as the official capital of France in any historical or legal record?": "NO",

    # Oxygen necessity
    "Since oxygen is vital for human life, is it correct to say that humans cannot survive without breathing air containing oxygen?": "YES",
    "Can humans live indefinitely without any access to oxygen in their environment?": "NO",

    # Moon as satellite
    "Is the Moon the only large natural body that consistently orbits Earth in the solar system?": "YES",
    "Do humans have multiple moons orbiting the Earth, similar to Jupiter or Saturn?": "NO",

    # Sahara Desert
    "Is the Sahara Desert geographically located across multiple countries in northern Africa?": "YES",
    "Is the Sahara Desert primarily located in the continent of South America?": "NO",

    # Amazon River
    "Does the Nile River surpass the Amazon River in length when measured by the most widely accepted geographical data?": "YES",
    "Is the Amazon River considered to originate in Europe according to global mapping authorities?": "NO",

    # Tokyo
    "Is Tokyo the capital city of Japan and a major economic center in Asia?": "YES",
    "Is Tokyo officially listed as the capital city of South Korea in government documents?": "NO",

    # Penguins
    "Do penguins naturally inhabit regions in the Southern Hemisphere, particularly Antarctica?": "YES",
    "Do penguins live alongside polar bears in the Arctic region as part of their natural habitat?": "NO",

    # Gold vs Lead
    "Is gold denser than most metals but still slightly less dense than lead?": "NO",
    "Is gold classified as a metal due to its physical and chemical properties?": "YES"
}


In [7]:
# === Text RAG: 仅用问题文本入库 ===
from langchain.schema import Document

def build_text_docs_with_answer(
    questions: List[str],
    gen_pipe,
    *,
    add_prompt_snapshot: bool = False,
    text_db: Optional["FAISS"] = None
) -> Tuple[List["Document"], Optional[List[float]]]:
    """
    Answer each question through the text-only RAG path and wrap it in a Document.

    Each Document's page_content is the question text. Its metadata carries:
    graph_id ("Q<n>", 1-based), question, llm_model, llm_answer, created_at
    (epoch seconds), q_embeddings (the question's sentence embedding) and,
    when `add_prompt_snapshot` is true, prompt_snapshot.

    Args:
        questions: Questions to answer, in order.
        gen_pipe: Generation pipeline handed to `answer_with_llm_text`.
        add_prompt_snapshot: Also store the prompt text in the metadata.
        text_db: Optional existing index used as retrieval context.

    Returns:
        `(docs, q_vec)`, where `q_vec` is the embedding of the LAST question
        only, or None when `questions` is empty.
    """
    docs = []
    # Bound up front so an empty `questions` returns ([], None) instead of
    # failing on an unassigned name at the return statement.
    q_vec = None
    for qid, q in enumerate(questions, start=1):
        q_vec = sentence_emb.embed_query(q)
        answer = answer_with_llm_text(q, gen_pipe, text_db=text_db, q_vec=q_vec)

        metadata = {
            "graph_id": f"Q{qid}",
            "question": q,
            "llm_model": CONFIG["llm_model_id"],
            "llm_answer": answer,
            "created_at": int(time.time()),
            "q_embeddings": q_vec
        }
        if add_prompt_snapshot:
            # Re-run retrieval to rebuild the prompt for this question.
            retrieved = None
            if text_db:
                retrieved = similarity_search_text_docs(q, text_db, k=CONFIG.get("faiss_search_k",3))[1]
            metadata["prompt_snapshot"] = make_text_qa_prompt(q, retrieved)

        docs.append(Document(page_content=f"{q}", metadata=metadata))
    return docs, q_vec

# Module-level cache of normalized query embeddings, keyed by the exact question string.
# similarity_search_text_docs reads and fills it when caching is enabled.
_QVEC_CACHE = {}

def _faiss_search_by_vec(vs, qv, k):
    """兼容不同 langchain-community 版本的 FAISS 向量检索。返回 [(doc, score_or_None), ...]"""
    if hasattr(vs, "similarity_search_by_vector_with_score"):
        return vs.similarity_search_by_vector_with_score(qv, k=k)

    if hasattr(vs, "similarity_search_by_vector"):
        docs = vs.similarity_search_by_vector(qv, k=k)
        return [(d, None) for d in docs]

    index = getattr(vs, "index", None)
    id_map = getattr(vs, "index_to_docstore_id", None)
    store  = getattr(vs, "docstore", None)
    if index is None or id_map is None or store is None:
        raise AttributeError("FAISS vectorstore has no by-vector APIs and no accessible index/docstore.")

    import numpy as np
    q = np.asarray(qv, dtype=np.float32).reshape(1, -1)
    D, I = index.search(q, k)
    out = []
    for dist, idx in zip(D[0], I[0]):
        if idx == -1:
            continue
        doc_id = id_map[idx]
        doc = store.search(doc_id)
        out.append((doc, float(dist)))
    return out


def similarity_search_text_docs(
    user_question: str,
    vectordb: FAISS,
    k: int = 5,
    query_vec: Optional[List[float]] = None,
    emb=None,
    use_cache: Optional[bool] = None,
):
    """
    Retrieve the `k` nearest stored questions for `user_question`.

    The query vector is resolved in this order:
      1. `query_vec`, if given and caching is enabled (L2-normalized here);
      2. the entry for `user_question` in `_QVEC_CACHE`, if caching is enabled;
      3. a fresh `emb.embed_query(user_question)`, L2-normalized and, when
         caching is enabled, stored in `_QVEC_CACHE`.

    Args:
        user_question: Question text; also the cache key.
        vectordb: Vector store to search.
        k: Number of neighbours to return.
        query_vec: Optional precomputed embedding of `user_question`.
        emb: Embedding model; falls back to the global `sentence_emb`.
        use_cache: Overrides CONFIG["use_cached_text_embeddings"] when not None.

    Returns:
        `(user_question, results)`, with `results` as returned by
        `_faiss_search_by_vec`.

    Raises:
        ValueError: If an embedding must be computed but no model is available.
    """
    import numpy as np

    def _unit(vec):
        arr = np.asarray(vec, dtype=np.float32)
        # Divide out of place: np.asarray hands back the caller's own array
        # when it is already float32, and `/=` would rescale it for the caller.
        return (arr / (np.linalg.norm(arr) + 1e-12)).astype(np.float32, copy=False)

    if use_cache is None:
        use_cache = CONFIG.get("use_cached_text_embeddings", True)

    if use_cache and query_vec is not None:
        qv = _unit(query_vec)
    elif use_cache and user_question in _QVEC_CACHE:
        qv = _QVEC_CACHE[user_question]
    else:
        if emb is None:
            emb = globals().get("sentence_emb", None)
        if emb is None:
            raise ValueError("similarity_search_text_docs: need `emb` when no cache/vec provided.")
        qv = _unit(emb.embed_query(user_question))
        if use_cache:
            _QVEC_CACHE[user_question] = qv

    return user_question, _faiss_search_by_vec(vectordb, qv, k)


def make_text_qa_prompt(
    question: str,
    retrieved_docs=None
) -> str:
    """
    Assemble the text-only RAG prompt for one question.

    The prompt is built from up to three sections joined by blank lines:
      1. a retrieved-context section (only when `retrieved_docs` is non-empty
         and CONFIG["include_retrieved_context"] is enabled),
      2. the current question,
      3. task rules chosen by the answer mode returned from `_mode()`.

    Args:
        question: The question to answer.
        retrieved_docs: Optional list of `(doc, score)` hits; only the first
            hit is quoted in the prompt.

    Returns:
        The prompt string, ending with "[ANSWER]: ".
    """
    sections = []
    if retrieved_docs and CONFIG.get("include_retrieved_context", True):
        # Only the top-ranked hit is quoted, however many were retrieved.
        doc0, _ = retrieved_docs[0]
        related_q_txt = doc0.page_content.strip()
        related_answer = (doc0.metadata or {}).get("llm_answer", "")
        sections.append(
            "<<<RETRIEVED_CONTEXT_START>>>\n"
            # The text path quotes question text, not graph triples; the
            # wording previously promised triples that are not in the prompt.
            "The system searched for a related question in the database. Below are the related question's text and its prior answer as reference. "
            "You don't have to follow it completely, just use it as a reference.\n"
            f"[RELATED QUESTION TEXT]:\n{related_q_txt}\n"
            f"[RELATED ANSWER]: {related_answer}\n"
            "<<<RETRIEVED_CONTEXT_END>>>"
        )

    sections.append(f"[CURRENT QUESTION]: {question}")

    mode = _mode()

    if mode in {"yes_no", "binary"}:
        yes, no = _yn("YES", "NO")
        rules = (
            "[TASK]: You are a precise QA assistant for binary (yes/no) questions.\n"
            f"- Output ONLY one token: {yes} or {no}.\n"
            "- Do NOT copy or summarize any context.\n"
            "- Do NOT show reasoning, steps, or extra words.\n"
            "[ANSWER]: "
        )
    else:
        # Unknown modes fall back to a generic instruction.
        style_line = {
            "short":    "- Give a short, direct answer in 2–3 sentences.\n",
            "detail":   "- Provide a clear, detailed, and structured answer.\n",
            "long":     "- Provide a clear, detailed, and structured answer.\n",
            "reasoning":"- Provide a well-structured explanation with logical reasoning flow.\n- If useful, break the answer into brief sections.\n",
            "explain":  "- Provide a well-structured explanation with logical reasoning flow.\n- If useful, break the answer into brief sections.\n",
        }.get(mode, "- Provide a clear and helpful answer.\n")

        rules = (
            "[TASK]: You are a QA assistant for open-ended questions.\n"
            f"{style_line}"
            "- Do NOT restrict to yes/no.\n"
            # Trailing space keeps the two sentences from fusing into
            # "word).Avoid" after implicit string concatenation.
            "[FORMAT]: Write complete sentences (not a single word). "
            "Avoid starting with just 'Yes.' or 'No.'; if the question is yes/no-style, state the conclusion AND 1–2 short reasons.\n"
            "[ANSWER]: "
        )

    sections.append(rules)
    return "\n\n".join(sections)

def answer_with_llm_text(
    question: str,
    gen_pipe,
    q_vec=None,
    *,
    text_db: Optional["FAISS"] = None,
    max_retries: int = 3,
) -> str:
    """
    Answer one question through the text-only RAG path.

    Retrieves context from `text_db` (when given), builds the prompt, and
    generates up to `max_retries` times until a non-empty answer is produced.
    The first non-empty answer is then post-processed by mode:

    - "yes_no"/"binary": a recognizable yes/no is mapped to the canonical
      token from `_yn`; otherwise one stricter re-prompt is attempted.
    - any other mode: a one-word / yes-no style answer triggers one re-prompt
      asking for 2-3 sentences, and the longer of the two answers is returned.

    Args:
        question: Question text.
        gen_pipe: Generation pipeline passed to `_gen`.
        q_vec: Optional precomputed embedding of `question`, forwarded to the
            retriever so it need not be recomputed.
        text_db: Optional vector store providing retrieved context.
        max_retries: Maximum generation attempts for a non-empty answer.

    Returns:
        The answer text, or "" if no attempt produced one.
    """
    # Imported locally so this function does not depend on a star-import
    # elsewhere in the notebook happening to expose `re`.
    import re

    retrieved_docs = None
    if text_db:
        _, retrieved_docs = similarity_search_text_docs(
            question, text_db, k=CONFIG.get("faiss_search_k", 3),
            query_vec=q_vec,  # reuse the caller's embedding instead of ignoring it
            emb=sentence_emb,
            use_cache=CONFIG.get("use_cached_text_embeddings", True)
        )

    prompt = make_text_qa_prompt(question, retrieved_docs)
    mode = _mode()

    # Compiled once per call rather than once per attempt.
    yes_re = re.compile(r"^\s*(yes|y|true|correct|affirmative)\s*\.?\s*$", re.I)
    no_re  = re.compile(r"^\s*(no|n|false|incorrect|negative)\s*\.?\s*$", re.I)

    def _clean(full_prompt, raw):
        """Extract the answer text from a raw generation, without the think part."""
        body, _think = strip_think(_extract_answer_text(full_prompt, raw))
        # Guard against a None answer part so regex matching and len() are safe.
        return (body or "").strip()

    def _canonical_yes_no(candidate):
        """Return the canonical yes/no token for `candidate`, or None if unclear."""
        said_yes = bool(yes_re.match(candidate))
        said_no = bool(no_re.match(candidate))
        if said_yes == said_no:
            return None
        y, n = _yn("YES", "NO") if CONFIG.get("answer_uppercase", True) else _yn("yes", "no")
        return y if said_yes else n

    def _reprompt(suffix):
        """Generate once more with `suffix` appended to the prompt."""
        retry_prompt = prompt + suffix
        return _clean(retry_prompt, _gen(gen_pipe, retry_prompt))

    # Bound before the loop so `max_retries <= 0` returns "" instead of
    # failing on an unassigned name.
    answer = ""
    for attempt in range(1, max_retries + 1):
        raw = _gen(gen_pipe, prompt)
        print(f"----- RAW (try {attempt}):", raw)

        answer = _clean(prompt, raw)
        print("----- ANS:", answer)

        if not answer:
            continue

        if mode in {"yes_no", "binary"}:
            label = _canonical_yes_no(answer)
            if label is not None:
                return label

            strict_suffix = (
                "\n\n[FORMAT]: Answer with exactly ONE token: "
                + ("YES or NO." if CONFIG.get("answer_uppercase", True) else "yes or no.")
            )
            ans2 = _reprompt(strict_suffix)
            label = _canonical_yes_no(ans2)
            if label is not None:
                return label
            return ans2 or answer

        if yes_re.match(answer) or no_re.match(answer) or len(answer.split()) <= 2:
            format_suffix = (
                "\n\n[FORMAT]: Provide a 2–3 sentence explanation; "
                "do not answer with a single word."
            )
            ans2 = _reprompt(format_suffix)
            if len(ans2) > len(answer):
                return ans2
        return answer

    return answer

    

import numpy as np
from langchain_community.vectorstores import FAISS

def build_faiss_index_from_cached(
    docs: List[Document],
    emb,
) -> FAISS:
    """
    Build a FAISS vector store from documents, reusing embeddings cached in metadata.

    When CONFIG["use_cached_text_embeddings"] is enabled, each document's vector
    is read from metadata["q_embeddings"] (or metadata["q_vec"]); documents
    without a cached vector, and all documents when caching is disabled, are
    embedded with `emb.embed_query(page_content)`. Vectors are L2-normalized
    row-wise before indexing.

    Args:
        docs: Documents to index.
        emb: Embedding model, used for missing vectors and attached to the store.

    Returns:
        The populated FAISS vector store.

    Raises:
        ValueError: If `docs` is empty.
        RuntimeError: If every construction path fails (original error chained).
    """
    use_cache = CONFIG.get("use_cached_text_embeddings", True)

    def _cached_vector(meta):
        # Explicit None/empty checks: `a or b` raises on numpy arrays because
        # their truth value is ambiguous.
        for key in ("q_embeddings", "q_vec"):
            vec = meta.get(key)
            if vec is not None and len(vec) > 0:
                return vec
        return None

    texts, metas, vecs = [], [], []
    for d in docs:
        texts.append(d.page_content)
        metas.append(d.metadata)
        v = _cached_vector(d.metadata) if use_cache else None
        if v is None:
            # No usable cached vector, or caching disabled: compute it now.
            v = emb.embed_query(d.page_content)
        vecs.append(v)

    if not texts:
        raise ValueError("No docs provided to build_faiss_index_from_cached().")

    X = np.asarray(vecs, dtype=np.float32)
    # Normalize each row to unit length.
    X /= (np.linalg.norm(X, axis=1, keepdims=True) + 1e-12)

    # (1) Current signature: from_embeddings(text_embeddings=[(text, vec), ...], embedding=..., metadatas=[...])
    text_embeddings = [(t, X[i].tolist()) for i, t in enumerate(texts)]
    try:
        return FAISS.from_embeddings(
            text_embeddings, embedding=emb, metadatas=metas
        )
    except TypeError:
        # NOTE(review): the two fallback signatures below could not be confirmed
        # against current langchain-community — verify before relying on them.
        # (2) Older signature: from_embeddings(embeddings=[vec...], metadatas=[...], texts=[...], embedding=...)
        try:
            return FAISS.from_embeddings(
                embeddings=X.tolist(), metadatas=metas, texts=texts, embedding=emb
            )
        except Exception:
            # (3) Last resort: start from an empty store, then add vectors or texts.
            try:
                vs = FAISS.from_texts(texts=[], embedding=emb)
                if hasattr(vs, "add_embeddings"):
                    vs.add_embeddings(embeddings=X, metadatas=metas, texts=texts)
                else:
                    # add_texts re-embeds the texts, so the cached vectors go unused here.
                    vs.add_texts(texts=texts, metadatas=metas)
                return vs
            except Exception as e:
                # Chain the cause so the underlying traceback is not lost.
                raise RuntimeError(f"Failed to build FAISS index with cached vectors: {e}") from e

def build_text_faiss_index_with_answers(
    questions: List[str],
    gen_pipe,
    *,
    add_prompt_snapshot: bool = False,
    bootstrap_db: Optional[FAISS] = None
) -> Tuple[FAISS, Optional[List[float]]]:
    """
    Answer the questions via the text RAG path, index them, and return the store.

    Args:
        questions: Questions to answer and index.
        gen_pipe: Generation pipeline.
        add_prompt_snapshot: Store the prompt text in each document's metadata.
        bootstrap_db: Optional existing index whose stored Q&A is used as
            retrieved context while answering (pass None for a cold start).

    Returns:
        `(vector_store, q_vec)`, where `q_vec` is the second element returned by
        `build_text_docs_with_answer` (the embedding of the last question).
    """
    docs, q_vec = build_text_docs_with_answer(
        questions=questions,
        gen_pipe=gen_pipe,
        add_prompt_snapshot=add_prompt_snapshot,
        text_db=bootstrap_db,
    )
    # Report a count rather than dumping every document: each one carries its
    # full embedding vector in metadata, which floods the cell output.
    print(f"Built {len(docs)} text docs")
    return build_faiss_index_from_cached(docs, sentence_emb), q_vec



In [8]:
def measure_once_mode(
    question: str,
    mode: str,                 # "text" or "graph"
    gen_pipe,
    tokenizer,
    parser=None,
    text_db: Optional[FAISS] = None,
    graph_db: Optional[FAISS] = None,
    q_vecs = None,
    *,
    label: Optional[str] = None,
    use_cuda_mem: bool = True,
) -> Dict:
    """
    Run one question through either RAG path and collect cost measurements.

    Retrieval (when a store is given and CONFIG["include_retrieved_context"] is
    enabled) and generation are timed separately. Token counts are taken with
    `tokenizer.encode(..., add_special_tokens=False)`.

    Args:
        question: Question text.
        mode: "text" or "graph".
        gen_pipe: Generation pipeline, called directly with the prompt.
        tokenizer: Tokenizer used for token counting.
        parser: Graph parser, used in "graph" mode.
        text_db: Store searched in "text" mode.
        graph_db: Store searched in "graph" mode.
        q_vecs: Accepted for call compatibility; not read by this function.
        label: Row label; defaults to "<mode>_rag".
        use_cuda_mem: Record peak CUDA memory when CUDA is available.

    Returns:
        A dict with one measurement row (tokens, latencies, retrieval count,
        peak VRAM in MiB or None, prompt length, answer, retrieval flag).
    """
    assert mode in ("text", "graph")

    def _n_tokens(text: str) -> int:
        """Token count without special tokens."""
        return len(tokenizer.encode(text, add_special_tokens=False))

    want_context = CONFIG.get("include_retrieved_context", True)
    hits = None
    retrieval_latency = 0.0

    # ---- 1) Retrieval, timed on its own ----
    if mode == "text":
        if text_db and want_context:
            started = time.perf_counter()
            _, hits = similarity_search_text_docs(
                question, text_db, k=CONFIG.get("faiss_search_k", 3),
                emb=sentence_emb,
                use_cache=CONFIG.get("use_cached_text_embeddings", True)
            )
            retrieval_latency = time.perf_counter() - started
        retrieved_docs = hits or None
        prompt = make_text_qa_prompt(question, retrieved_docs=retrieved_docs)
    else:
        if graph_db and want_context:
            started = time.perf_counter()
            _, hits = similarity_search_graph_docs(
                question, parser, graph_db, k=CONFIG.get("faiss_search_k", 3),
                emb_model=word_emb,
                use_cache=CONFIG.get("use_cached_graph_embeddings", True),
            )
            retrieval_latency = time.perf_counter() - started
        retrieved_docs = hits or None
        G, rels = parse_question_to_graph_generic(parser, question)
        prompt = make_graph_qa_prompt(question, G, rels, retrieved_docs)

    retrieved_count = len(retrieved_docs) if retrieved_docs else 0
    in_tok = _n_tokens(prompt)

    # ---- 2) Generation, timed on its own ----
    track_vram = use_cuda_mem and torch.cuda.is_available()
    peak_mem = None
    if track_vram:
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.synchronize()

    started = time.perf_counter()
    out = gen_pipe(prompt)
    gen_latency = time.perf_counter() - started  # generation only

    generated = out[0]["generated_text"]
    # With return_full_text the pipeline echoes the prompt; cut it off.
    answer = generated[len(prompt):].strip() if CONFIG.get("return_full_text", False) else generated.strip()

    out_tok = _n_tokens(answer)

    if track_vram:
        torch.cuda.synchronize()
        peak_mem = torch.cuda.max_memory_allocated() / (1024**2)

    # ---- 3) Result row ----
    return {
        "label": label or f"{mode}_rag",
        "mode": mode,
        "question": question,
        "input_tokens": in_tok,
        "output_tokens": out_tok,
        "total_tokens": in_tok + out_tok,
        "latency_sec": retrieval_latency + gen_latency,
        "gen_latency_sec": gen_latency,
        "retrieval_latency_sec": retrieval_latency,
        "retrieved_count": retrieved_count,
        "peak_vram_MiB": peak_mem,
        "prompt_chars": len(prompt),
        "answer": answer,
        "used_retrieval": bool(retrieved_docs),
    }


In [9]:
def batch_compare_text_vs_graph(
    questions: List[str],
    gen_pipe, tokenizer, parser,
    text_db: Optional[FAISS],
    graph_db: Optional[FAISS],
    q_vecs = None,
) -> pd.DataFrame:
    """
    Measure every question in both modes and collect the rows in a DataFrame.

    For each question the "text" run (label "text_rag") comes first, followed
    by the "graph" run (label "graph_rag").
    """
    runs = (("text", "text_rag"), ("graph", "graph_rag"))
    records = [
        measure_once_mode(
            question, mode, gen_pipe, tokenizer, parser, text_db, graph_db,
            label=run_label, q_vecs=q_vecs,
        )
        for question in questions
        for mode, run_label in runs
    ]
    return pd.DataFrame(records)

def _normalize_yesno(text: str) -> str:
    if text is None: return "NO"
    t = str(text).strip().lower()
    if t == "yes" or ("yes" in t and "no" not in t): return "YES"
    if t == "no"  or ("no"  in t and "yes" not in t): return "NO"
    return "NO"

def _norm_q(s: str) -> str:
    return re.sub(r"\s+", " ", str(s)).strip().lower()

def attach_gold(df: pd.DataFrame, gold_map: dict) -> pd.DataFrame:
    """
    Join gold labels onto the measurement rows and score each prediction.

    Questions are matched on their `_norm_q` form. The result is a copy of
    `df` with added columns: question_norm, gold ("YES"/"NO", or missing when
    no gold question matches), pred (`_normalize_yesno` of the answer) and
    correct (1 when pred equals gold, else 0).

    Raises:
        ValueError: If `df` is not a DataFrame.
    """
    if not isinstance(df, pd.DataFrame):
        raise ValueError("attach_gold: input df is None or not a DataFrame")

    # Gold table: one row per gold question, label coerced to YES/NO.
    gold_table = pd.DataFrame(list(gold_map.items()), columns=["question","gold"])
    gold_table = gold_table.assign(
        question_norm=gold_table["question"].map(_norm_q),
        gold=gold_table["gold"].map(lambda x: "YES" if str(x).upper()=="YES" else "NO"),
    )

    scored = df.assign(question_norm=df["question"].map(_norm_q))
    scored = scored.merge(gold_table[["question_norm","gold"]], on="question_norm", how="left")

    scored["pred"] = scored["answer"].map(_normalize_yesno)
    scored["correct"] = (scored["pred"] == scored["gold"]).astype(int)

    # Debugging aid: list rows whose question matched no gold entry.
    unmatched = scored[scored["gold"].isna()]
    if len(unmatched) > 0:
        print(f"⚠️ {len(unmatched)} questions had no gold match. Showing a few:")
        print(unmatched[["question","label"]].head(5))

    return scored

def evaluate_accuracy(df_with_gold: pd.DataFrame):
    """
    Print accuracy per `label` group.

    Only rows that have a gold label are scored, so the reported accuracy and
    the reported `n` refer to the same rows. A group with no gold-labelled
    rows prints `nan`.

    Args:
        df_with_gold: Frame with "label", "gold" and "correct" columns.
    """
    print("\n== Accuracy by config ==")
    for k, sub in df_with_gold.groupby("label"):
        # Rows without gold carry correct == 0; averaging them in would
        # understate accuracy relative to the n printed alongside it.
        labelled = sub[sub["gold"].notna()]
        n = len(labelled)
        acc = labelled["correct"].mean() if n else float("nan")
        print(f"{k:<10s} acc={acc:.3f} (n={n})")

def summarize_cost(df: pd.DataFrame, base_label: str, target_label: str):
    """
    Print the average of each cost metric for two labels and their relative change.

    For every known metric column present in `df`, prints the mean over the
    `base_label` rows, the mean over the `target_label` rows, and the relative
    change of target versus base in percent.

    Args:
        df: Measurement rows with a "label" column.
        base_label: Label treated as the baseline.
        target_label: Label compared against the baseline.
    """
    A = df[df["label"] == base_label]
    B = df[df["label"] == target_label]

    def avg(col):
        # Coerce to numeric so a column holding only None (e.g. peak VRAM when
        # CUDA is unavailable) averages to NaN.
        a = pd.to_numeric(A[col], errors="coerce").mean()
        b = pd.to_numeric(B[col], errors="coerce").mean()
        # A relative change against a zero baseline is undefined; report NaN
        # rather than dividing by a tiny clamp and printing a huge percentage.
        delta = (b - a) / a if a else float("nan")
        return a, b, delta

    print("\n== Cost (avg) ==")
    for col in [
        "input_tokens",
        "output_tokens",
        "total_tokens",
        "latency_sec",
        "retrieval_latency_sec",
        "gen_latency_sec",
        "retrieved_count",
        "peak_vram_MiB",
        "prompt_chars",
    ]:
        if col in df.columns:
            a, b, d = avg(col)
            print(f"{col:>22s} | {base_label}: {a:8.4f} | {target_label}: {b:8.4f} | Δ%: {d*100:7.2f}%")



In [10]:
def dir_size_bytes(path: str) -> int:
    """
    Total size in bytes of all files under `path`, walked recursively.

    Files whose size cannot be read (OSError) contribute nothing.
    """
    def _size_or_zero(file_path: str) -> int:
        try:
            return os.path.getsize(file_path)
        except OSError:
            return 0

    return sum(
        _size_or_zero(os.path.join(folder, name))
        for folder, _subdirs, names in os.walk(path)
        for name in names
    )

def save_and_report_sizes(text_db: FAISS, graph_db: FAISS, text_dir="faiss_text_idx", graph_dir="faiss_graph_idx"):
    """
    Save both indexes to disk and print their on-disk sizes.

    Args:
        text_db: Text index, saved to `text_dir`.
        graph_db: Graph index, saved to `graph_dir`.
        text_dir: Target directory for the text index.
        graph_dir: Target directory for the graph index.

    Returns:
        `(text_bytes, graph_bytes)` as measured by `dir_size_bytes`.
    """
    def human(n):
        """Format a byte count with two decimals, in units up to GB."""
        amount = float(n)
        for unit in ("B", "KB", "MB"):
            if amount < 1024:
                return f"{amount:.2f} {unit}"
            amount /= 1024.0
        return f"{amount:.2f} GB"

    text_db.save_local(text_dir)
    graph_db.save_local(graph_dir)

    s_text = dir_size_bytes(text_dir)
    s_graph = dir_size_bytes(graph_dir)
    print(f"[Index size] text_rag  = {human(s_text)}  ({text_dir})")
    print(f"[Index size] graph_rag = {human(s_graph)}  ({graph_dir})")
    return s_text, s_graph


In [11]:
# Driver cell: load the model, build both indexes, save them, then compare cost.
gen_pipe, tokenizer = load_llm_pipeline()
parser = RelationshipGraphParser()

# Graph RAG: answer every question with no prior index (faiss_db=None), then index the docs.
graph_docs = build_docs_with_answer(
    questions, parser, gen_pipe,
    add_prompt_snapshot=False,
    faiss_db=None
)
print(graph_docs)
graph_db = build_graph_faiss_index_from_cached(graph_docs, word_emb)

# Text RAG: same questions, also with no bootstrap index.
# NOTE: q_vecs receives the second return value, which is the embedding of the
# last question only (not one vector per question).
text_db, q_vecs = build_text_faiss_index_with_answers(
    questions,
    gen_pipe,
    add_prompt_snapshot=False,
    bootstrap_db=None  
)

save_and_report_sizes(text_db, graph_db, text_dir="faiss_text_idx", graph_dir="faiss_graph_idx")

# Only the first gold question is evaluated here ([:1]).
eval_questions = list(GOLD_LABELS.keys())[:1]
# q_vecs is passed through, but measure_once_mode does not read it.
df = batch_compare_text_vs_graph(
    eval_questions, gen_pipe, tokenizer, parser, text_db, graph_db, q_vecs
)
df_ab_gold = attach_gold(df, GOLD_LABELS)
# Accuracy reporting is currently disabled; only the cost summary is printed.
#evaluate_accuracy(df_ab_gold)
summarize_cost(df_ab_gold, base_label="text_rag", target_label="graph_rag")

`torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards: 100%|██████████| 2/2 [00:03<00:00,  1.74s/it]
`torch_dtype` is deprecated! Use `dtype` instead!
Device set to use mps
  test_elements = torch.tensor(test_elements)


----- RAW (try 1):  The geocentric model...<think>
Okay, let's tackle this question. The user is asking if Earth is generally considered round despite its equatorial bulge. First, I need to recall what the equatorial bulge means. Earth isn't a perfect sphere; it's an oblate spheroid. That means it's slightly wider around the equator than through the poles. But the key here is the term "generally considered round." I should check common perceptions. Most people, even if they know a bit about the bulge, still refer to Earth as a sphere for simplicity, especially in educational contexts. Scientific classifications like "oblate spheroid" are more precise, but the general consensus is that it's round. Also, the geocentric model mention in the prompt might be a red herring, but the answer should clarify that the question isn't about that model. Wait, the FORMAT mentions the geocentric model... but the user's actual question is about Earth's shape. Maybe that part is a hint to address a commo