In [1]:
import networkx as nx
from graph_generator.graphparsers import RelationshipGraphParser
from linearization_utils import *
from retrieval_utils import similarity_search_graph_docs

from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.schema import Document

from typing import List, Dict, Optional, Tuple
import time

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
CONFIG = {
    # === Embedding & VectorStore ===
    "embedding_model": "sentence-transformers/all-MiniLM-L6-v2",  # Embedding model for documents/questions
    "faiss_search_k": 3,  # Number of nearest neighbors to retrieve from FAISS

    # === LLM (text generation) ===
    "llm_model_id": "microsoft/Phi-4-mini-reasoning",  # HuggingFace model ID
    "device_map": "auto",  # Device placement: "cuda", "mps", "cpu", or "auto"
    "dtype_policy": "auto",  # Precision: "auto", "bf16", "fp16", or "fp32"
    "max_new_tokens": 256,  # Maximum tokens generated per response
    "do_sample": True,  # Whether to use sampling (True) or greedy decoding (False)
    "temperature": 0.4,  # Randomness control for sampling; lower = more deterministic
    "top_p": 1.0,  # Nucleus sampling threshold; 1.0 = no restriction
    "return_full_text": False,  # Return full text (input+output) if True, only output if False
    "seed": None,  # Random seed for reproducibility; set to int or None

    # === Prompt / Answer ===
    "answer_mode": "YES_NO",  # Answer format mode, e.g., YES/NO
    "answer_uppercase": True,  # If True → "YES"/"NO", else "yes"/"no"

    # === Prompt construction ===
    "include_retrieved_context": True,  # Include retrieved Q&A in prompt
    "include_current_triples": True,  # Include graph triples in prompt
}

try:
    from transformers import set_seed  # Utility for reproducibility
except Exception:
    set_seed = None

## RAG workflow

In [3]:
def _select_dtype() -> torch.dtype:
    """Choose dtype based on CONFIG['dtype_policy'] and hardware."""
    policy = CONFIG.get("dtype_policy", "auto")
    if policy == "bf16":
        return torch.bfloat16
    if policy == "fp16":
        return torch.float16
    if policy == "fp32":
        return torch.float32

    # auto mode
    if torch.cuda.is_available():
        return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    # MPS backend works more reliably with fp32
    if torch.backends.mps.is_available():
        return torch.float32
    return torch.float32

def _yn(text_yes="YES", text_no="NO"):
    return (text_yes, text_no) if CONFIG.get("answer_uppercase", True) else (text_yes.lower(), text_no.lower())

# =========================
# Embeddings / Vectorstore
# =========================
emb = HuggingFaceEmbeddings(model_name=CONFIG["embedding_model"])  # Local embedding model (MiniLM-L6-v2, 384 dim)

def build_faiss_index(docs: List[Document]) -> FAISS:
    return FAISS.from_documents(docs, emb)

# =========================
# LLM Loader
# =========================
def load_llm_pipeline(
    model_id: Optional[str] = None,       # HuggingFace model id
    device_map: Optional[str] = None,     # Device placement
    dtype: Optional[torch.dtype] = None,  # Torch dtype
    max_new_tokens: Optional[int] = None, # Max tokens per generation
    temperature: Optional[float] = None,  # Sampling temperature
    top_p: Optional[float] = None,        # Nucleus sampling threshold
    do_sample: Optional[bool] = None,     # Sampling vs greedy
    return_full_text: Optional[bool] = None,  # Return input+output if True
):
    """
    Return a text-generation pipeline for QA generation.
    All defaults pull from CONFIG; any arg here will override CONFIG.
    """
    model_id = model_id or CONFIG["llm_model_id"]
    device_map = device_map or CONFIG["device_map"]
    dtype = dtype or _select_dtype()
    max_new_tokens = max_new_tokens or CONFIG["max_new_tokens"]
    temperature = CONFIG["temperature"] if temperature is None else temperature
    top_p = CONFIG["top_p"] if top_p is None else top_p
    do_sample = CONFIG["do_sample"] if do_sample is None else do_sample
    return_full_text = CONFIG["return_full_text"] if return_full_text is None else return_full_text

    if set_seed and isinstance(CONFIG.get("seed"), int):
        set_seed(CONFIG["seed"])

    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=dtype,
        device_map=device_map,
        trust_remote_code=True,
    )

    gen_pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        device_map=device_map,
        torch_dtype=dtype,
        return_full_text=return_full_text,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        temperature=temperature,
        top_p=top_p,
    )
    return gen_pipe, tokenizer

# =========================
# Question → Graph (generic)
# =========================
def parse_question_to_graph_generic(parser, question: str) -> Tuple[nx.Graph, List[Dict]]:
    """
    Compatible with RelationshipGraphParser.question_to_graph
    and CausalQuestionGraphParser.question_to_causal_graph
    """
    if hasattr(parser, "question_to_graph"):
        return parser.question_to_graph(question)
    elif hasattr(parser, "question_to_causal_graph"):
        return parser.question_to_causal_graph(question)
    else:
        raise AttributeError("Parser must provide question_to_graph or question_to_causal_graph")

# =========================
# Prompt Builder
# =========================
def make_graph_qa_prompt(
    question: str,
    G: nx.Graph,
    relations: Optional[List[Dict]] = None,
    retrieved_docs = None
) -> str:
    # 1) retrieved context (if any)
    sections = []
    if retrieved_docs and CONFIG.get("include_retrieved_context", True):
        doc0, score0 = retrieved_docs[0]
        related_triples = doc0.page_content.strip()
        related_answer  = doc0.metadata.get("llm_answer", "")
        sections.append(
            "<<<RETRIEVED_CONTEXT_START>>>\n"
            "The system searched for a related question in the database. Below are related question's graph triples and its prior answer as reference. " \
            "You don't have to follow it completely, just use it as a reference.\n"
            f"[RELATED QUESTION'S GRAPH TRIPLES]:\n{related_triples}\n"
            f"[RELATED QUESTION'S ANSWER]: {related_answer}\n"
            "<<<RETRIEVED_CONTEXT_END>>>"
        )

    # 2) current question + triples (optional)
    triples_text = ""
    if relations and CONFIG.get("include_current_triples", True):
        triples_text = "\n".join(
            f"{u} -> {d.get('rel','related_to')} -> {v}"
            for u, v, d in G.edges(data=True)
        )
    q_block = f"[CURRENT QUESTION]: {question}"
    if triples_text.strip():
        q_block += f"\n[CURRENT QUESTION'S GRAPH TRIPLES]:\n{triples_text}"
    sections.append(q_block)

    # 3) task instructions (placed at the end)
    yes, no = _yn("YES", "NO")
    rules = (
        "[TASK]: You are a precise QA assistant for binary (yes/no) questions.\n"
        f"- Output ONLY one token: {yes} or {no}.\n"
        "- Do NOT copy or summarize any context.\n"
        "- Do NOT show reasoning, steps, or extra words.\n"
        f"[ANSWER]: "
    )
    sections.append(rules)

    # Final prompt
    prompt = "\n\n".join(sections)
    return prompt

# =========================
# LLM Answerer
# =========================
def answer_with_llm(
    question: str,
    gen_pipe,
    parser,
    faiss_db = None,
    prompt = None
) -> str:
    retrieved_docs = None
    if faiss_db:
        k = CONFIG.get("faiss_search_k", 3)  # Number of docs to retrieve
        _, hits = similarity_search_graph_docs(question, parser, faiss_db, k=k)
        retrieved_docs = hits
        
    if prompt == None:
        G, rels = parse_question_to_graph_generic(parser, question)
        prompt = make_graph_qa_prompt(question, G, rels, retrieved_docs)

    out = gen_pipe(prompt)
    text = out[0]["generated_text"]

    # If return_full_text=False → only new content; else trim prefix
    if CONFIG.get("return_full_text", True):
        answer = text[len(prompt):].strip()
    else:
        answer = text.strip()

    # Normalize YES/NO case
    if CONFIG.get("answer_mode", "YES_NO"):
        yes, no = _yn("YES", "NO")
        a = answer.strip().lower()
        if "yes" in a and "no" not in a:
            answer = yes
            print(answer)
            return answer
        elif "no" in a and "yes" not in a:
            answer = no
            print(answer)
            return answer
        else:
            answer = answer_with_llm(question, gen_pipe, parser, faiss_db, prompt)
    
    

# =========================
# Build Docs with LLM Answer
# =========================
def build_docs_with_answer(
    questions: List[str],
    parser,
    gen_pipe,
    *,
    add_prompt_snapshot: bool = False,
    faiss_db = None
) -> List[Document]:
    docs: List[Document] = []
    for qid, q in enumerate(questions, start=1):
        G, rels = parse_question_to_graph_generic(parser, q)
        text = build_relationship_text(q, G, rels)  # Output [QUESTION][GRAPH][TRIPLES]

        # Get LLM answer
        answer = answer_with_llm(q, gen_pipe, parser, faiss_db)

        metadata = {
            "graph_id": f"Q{qid}",
            "question": q,
            "num_nodes": G.number_of_nodes(),
            "num_edges": G.number_of_edges(),
            "llm_model": CONFIG["llm_model_id"],
            "llm_answer": answer,
            "created_at": int(time.time()),
        }
        if add_prompt_snapshot:
            metadata["prompt_snapshot"] = make_graph_qa_prompt(q, G, rels)

        docs.append(Document(page_content=text, metadata=metadata))
    return docs


def build_faiss_index(docs: List[Document]) -> FAISS:
    vectordb = FAISS.from_documents(docs, emb)
    return vectordb


  emb = HuggingFaceEmbeddings(model_name=CONFIG["embedding_model"])  # Local embedding model (MiniLM-L6-v2, 384 dim)


### Answer questions in bulk and load them into the database.

In [5]:
# 1) Parser
parser = RelationshipGraphParser()   # or CausalQuestionGraphParser()

# 2) Load Phi-4-mini-reasoning
gen_pipe, _ = load_llm_pipeline(
    model_id="microsoft/Phi-4-mini-reasoning",
    device_map="auto",
    dtype=None,                # Automatically select appropriate precision
    max_new_tokens=256,
    temperature=0.2,           # Control randomness
)

# 3) Question set
questions = [
    "Is the Earth round?",
    "Does the Sun rise in the east?",
    "Is Paris the capital of France?",
    "Do humans need oxygen to survive?",
    "Is the Moon a natural satellite of Earth?",
    "Is the Sahara Desert in South America?"  ,   # 它在非洲
    "Is the Amazon River longer than the Nile River?", # 最长的是尼罗河
    "Is Tokyo the capital of South Korea?"    ,     # 韩国首都是首尔
    "Do penguins live in the Arctic?"   ,            # 在南极而非北极
    "Is gold heavier than lead?" ,   
]


# 4) Build documents (including LLM answers in metadata)
docs = build_docs_with_answer(
    questions, parser, gen_pipe, add_prompt_snapshot=False
)

# 5) Vectorization & Save
emb = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
faiss_db = FAISS.from_documents(docs, emb)
faiss_db.save_local("graph_rag_faiss_index")
print(f"FAISS index ready. docs={len(docs)}")


# To load later:
# faiss_db = FAISS.load_local("graph_rag_faiss_index", emb, allow_dangerous_deserialization=True)


Loading checkpoint shards: 100%|██████████| 2/2 [00:03<00:00,  1.82s/it]
Device set to use mps


YES
YES
YES
YES
YES
NO
NO
NO
NO
YES
FAISS index ready. docs=10


In [6]:
print(docs)

[Document(metadata={'graph_id': 'Q1', 'question': 'Is the Earth round?', 'num_nodes': 2, 'num_edges': 1, 'llm_model': 'microsoft/Phi-4-mini-reasoning', 'llm_answer': 'YES', 'created_at': 1755698682}, page_content='Earth -> isa -> round'), Document(metadata={'graph_id': 'Q2', 'question': 'Does the Sun rise in the east?', 'num_nodes': 3, 'num_edges': 2, 'llm_model': 'microsoft/Phi-4-mini-reasoning', 'llm_answer': 'YES', 'created_at': 1755698683}, page_content='rise -> prep_in -> east\nSun -> subj -> rise'), Document(metadata={'graph_id': 'Q3', 'question': 'Is Paris the capital of France?', 'num_nodes': 2, 'num_edges': 1, 'llm_model': 'microsoft/Phi-4-mini-reasoning', 'llm_answer': None, 'created_at': 1755698694}, page_content='Paris -> isa -> capital'), Document(metadata={'graph_id': 'Q4', 'question': 'Do humans need oxygen to survive?', 'num_nodes': 4, 'num_edges': 3, 'llm_model': 'microsoft/Phi-4-mini-reasoning', 'llm_answer': 'YES', 'created_at': 1755698694}, page_content='humans -> s

### Test for answering individual questions (adjust prompt with no database context)

In [4]:
parser = RelationshipGraphParser()   #

gen_pipe, _ = load_llm_pipeline(
    model_id="microsoft/Phi-4-mini-reasoning",
    device_map="auto",
    dtype=None,                #
    max_new_tokens=256,
    temperature=0.2,
)

questions = "Is the Great Wall visible from low Earth orbit?"
faiss_db = FAISS.load_local("graph_rag_faiss_index", emb, allow_dangerous_deserialization=True)
#answer = answer_with_llm(questions, gen_pipe, parser)
answer = answer_with_llm(questions, gen_pipe, parser, faiss_db)



Loading checkpoint shards: 100%|██████████| 2/2 [00:03<00:00,  1.92s/it]
Device set to use mps
The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


NO


In [5]:
import time
import torch
from typing import Optional, List, Dict, Tuple
import pandas as pd

def _normalize_yesno(text: str) -> str:
    """
    归一化 LLM 输出到 'YES'/'NO' 两值（严格外部判断已做，这里再兜底）
    任何非明确 'yes'/'no' 的返回都标成 'NO'（与你 prompt 里 '不确定选 NO' 对齐）
    """
    if text is None:
        return "NO"
    t = str(text).strip().lower()
    if t == "yes":
        return "YES"
    if t == "no":
        return "NO"
    # 兜底：含有明显 yes/no 词根时也归一化
    if "yes" in t and "no" not in t:
        return "YES"
    if "no" in t and "yes" not in t:
        return "NO"
    return "NO"

def _ensure_uppercase_yesno(text: str) -> str:
    """根据 CONFIG['answer_uppercase'] 决定返回大小写；内部评估一律用大写比较。"""
    yn = _normalize_yesno(text)
    if CONFIG.get("answer_uppercase", True):
        return yn
    return yn.lower()

# ===== Accuracy & Confusion reporting =====
import pandas as pd

def attach_gold(df: pd.DataFrame, gold_map: dict) -> pd.DataFrame:
    """把 gold label 合并到 batch_measure 结果 df。要求 df.question 字段和 gold_map 对齐。"""
    gold_df = pd.DataFrame(list(gold_map.items()), columns=["question", "gold"])
    # 统一 gold 大小写
    gold_df["gold"] = gold_df["gold"].map(_ensure_uppercase_yesno)
    out = df.merge(gold_df, on="question", how="left")
    return out

def evaluate_accuracy(df_with_gold: pd.DataFrame) -> pd.DataFrame:
    """
    输入：包含列 ['label','question','answer','gold'] 的 df
    返回：每配置的 acc 表（以及在打印时展示混淆矩阵）
    """
    df = df_with_gold.copy()
    # 归一化答案
    df["pred"] = df["answer"].map(_ensure_uppercase_yesno)

    # 标出是否有金标
    has_gold = df["gold"].notna()
    if not has_gold.any():
        print("⚠️ No gold labels found. Please provide a gold_map that covers your questions.")
        return pd.DataFrame()

    df = df[has_gold].copy()
    df["correct"] = (df["pred"] == df["gold"]).astype(int)

    # 总体准确率
    overall_acc = df["correct"].mean() if len(df) else float("nan")
    print(f"\n== Overall accuracy: {overall_acc:.3f} (n={len(df)}) ==")

    # 每配置准确率
    by_cfg = df.groupby("label")["correct"].mean().reset_index().rename(columns={"correct":"accuracy"})
    print("\n== Accuracy by config ==")
    for _, row in by_cfg.iterrows():
        n = df[df["label"] == row["label"]].shape[0]
        print(f"{row['label']:<15s}  acc={row['accuracy']:.3f}  (n={n})")

    # 每配置的混淆矩阵
    print("\n== Confusion matrices by config ==")
    for cfg, sub in df.groupby("label"):
        cm = pd.crosstab(sub["gold"], sub["pred"], rownames=["gold"], colnames=["pred"], dropna=False)
        # 确保列/行都有 YES/NO
        for val in ["YES","NO"]:
            if val not in cm.index:
                cm.loc[val] = 0
            if val not in cm.columns:
                cm[val] = 0
        cm = cm.loc[["YES","NO"], ["YES","NO"]]
        print(f"\n[Config: {cfg}]")
        print(cm)

    return by_cfg

def per_question_delta(df_with_gold: pd.DataFrame, base_label: str, target_label: str) -> pd.DataFrame:
    """
    对每道题比较 base vs target 的预测差异。
    返回：question, gold, pred_base, pred_target, delta_correct
    """
    df = df_with_gold.copy()
    df["pred"] = df["answer"].map(_ensure_uppercase_yesno)
    df = df[df["gold"].notna()].copy()

    base = df[df["label"] == base_label][["question","gold","pred"]].rename(columns={"pred":"pred_base"})
    tgt  = df[df["label"] == target_label][["question","pred"]].rename(columns={"pred":"pred_target"})
    j = base.merge(tgt, on="question", how="inner")
    j["delta_correct"] = (j["pred_target"] == j["gold"]).astype(int) - (j["pred_base"] == j["gold"]).astype(int)
    # delta_correct ∈ {-1,0,1}: 1=提升、-1=下降、0=持平
    return j.sort_values(by=["delta_correct","question"], ascending=[False, True])

def _get_retrieved_docs_for_prompt(
    question: str,
    parser,
    faiss_db=None,
    k: Optional[int] = None,
):
    """Decide whether to retrieve based on CONFIG['include_retrieved_context'], return hits ([(Document, score), ...])."""
    if not faiss_db or not CONFIG.get("include_retrieved_context", True):
        return None
    k = k or CONFIG.get("faiss_search_k", 3)
    _, hits = similarity_search_graph_docs(question, parser, faiss_db, k=k)
    return hits if hits else None

def _count_tokens(tokenizer, text: str) -> int:
    return len(tokenizer.encode(text, add_special_tokens=False))

def measure_once(
    question: str,
    gen_pipe,              # pipeline from load_llm_pipeline
    tokenizer,             # tokenizer from load_llm_pipeline (used for counting tokens)
    parser,
    faiss_db=None,
    *,
    label: Optional[str] = None,
    use_cuda_mem: bool = True,
) -> Dict:
    """
    According to current CONFIG, construct the prompt (controlled by include_retrieved_context / include_current_triples),
    then call LLM, measuring once:
      - input_tokens / output_tokens / total_tokens
      - latency_sec
      - (optional) peak_vram_MiB
      - record whether retrieval and triples are used
    """
    # 1) Retrieval (if enabled)
    retrieved_docs = _get_retrieved_docs_for_prompt(
        question, parser, faiss_db=faiss_db, k=CONFIG.get("faiss_search_k", 3)
    )

    # 2) Parse the current question into graph/triples
    G, rels = parse_question_to_graph_generic(parser, question)

    # 3) Construct prompt (internally decides whether to include triples based on CONFIG['include_current_triples'])
    prompt = make_graph_qa_prompt(
        question=question,
        G=G,
        relations=rels,
        retrieved_docs=retrieved_docs
    )

    # 4) Count input tokens
    in_tok = _count_tokens(tokenizer, prompt)

    # 5) Timing & generation (optional: peak GPU memory)
    peak_mem = None
    if use_cuda_mem and torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.synchronize()

    t0 = time.perf_counter()
    answer = answer_with_llm(question, gen_pipe, parser, faiss_db, prompt)
    dt = time.perf_counter() - t0

    # 7) Count output tokens
    out_tok = _count_tokens(tokenizer, answer)

    # 8) Peak GPU memory usage
    if use_cuda_mem and torch.cuda.is_available():
        torch.cuda.synchronize()
        peak_mem = torch.cuda.max_memory_allocated() / (1024**2)

    # 9) Mark whether retrieval/triples were used
    used_retrieval = bool(retrieved_docs)
    used_triples = bool(rels) and CONFIG.get("include_current_triples", True)

    return {
        "label": label or ("with_graph_ctx" if used_triples or used_retrieval else "no_graph_ctx"),
        "question": question,
        "input_tokens": in_tok,
        "output_tokens": out_tok,
        "total_tokens": in_tok + out_tok,
        "latency_sec": dt,
        "peak_vram_MiB": peak_mem,
        "used_retrieval": used_retrieval,
        "used_current_triples": used_triples,
        "prompt_chars": len(prompt),
        "answer": answer,
    }

# ===== Batch evaluation & summary (optional) =====
def batch_measure(
    questions: List[str],
    gen_pipe,
    tokenizer,
    parser,
    faiss_db=None,
    *,
    flip_configs: List[Dict] = None,
) -> pd.DataFrame:
    """
    Run multiple CONFIG combinations (e.g. with/without retrieval, with/without triples) on a question set,
    return a summary DataFrame.
    flip_configs: each element is a local override of CONFIG, for example:
        [{"include_retrieved_context": False, "include_current_triples": False, "label": "no_ctx"},
         {"include_retrieved_context": True,  "include_current_triples": True,  "label": "with_both"}]
    """
    rows = []
    if not flip_configs:
        flip_configs = [ {"label": "current_CONFIG"} ]

    for cfg in flip_configs:
        # Save old values, temporarily override
        old_retrieve = CONFIG.get("include_retrieved_context", True)
        old_triples  = CONFIG.get("include_current_triples", True)
        if "include_retrieved_context" in cfg:
            CONFIG["include_retrieved_context"] = cfg["include_retrieved_context"]
        if "include_current_triples" in cfg:
            CONFIG["include_current_triples"] = cfg["include_current_triples"]

        for q in questions:
            try:
                rec = measure_once(
                    question=q,
                    gen_pipe=gen_pipe,
                    tokenizer=tokenizer,
                    parser=parser,
                    faiss_db=faiss_db,
                    label=cfg.get("label")
                )
                rows.append(rec)
            except Exception as e:
                rows.append({
                    "label": cfg.get("label"),
                    "question": q,
                    "error": str(e)
                })

        # Restore CONFIG
        CONFIG["include_retrieved_context"] = old_retrieve
        CONFIG["include_current_triples"]   = old_triples

    return pd.DataFrame(rows)

def summarize_cost(df: pd.DataFrame, base_label: str, target_label: str):
    """Compare average cost of two configurations and print relative changes (%)."""
    A = df[df["label"]==base_label]
    B = df[df["label"]==target_label]
    if A.empty or B.empty:
        print("Not enough data for comparison.")
        return

    def avg(col):
        a, b = A[col].mean(), B[col].mean()
        return a, b, (b-a)/max(1e-9, a)

    for col in ["input_tokens","output_tokens","total_tokens","latency_sec","peak_vram_MiB","prompt_chars"]:
        if col in df.columns:
            a,b,d = avg(col)
            print(f"{col:>15s} | {base_label}: {a:8.2f} | {target_label}: {b:8.2f} | Δ%: {d*100:7.2f}%")


In [None]:
GOLD_LABELS = {
    "Is the Earth round?": "YES",
    "Is Earth flat?": "NO",                       
    "Does the Earth orbit the Sun?": "YES",

    "Does the Sun rise in the east?": "YES",
    "Does the Sun rise in the west?": "NO",      
    "Is the Sun a star?": "YES",

    "Is Paris the capital of France?": "YES",
    "Is Paris the capital of Germany?": "NO",     
    "Is the Eiffel Tower in Paris?": "YES",

    "Do humans need oxygen to survive?": "YES",
    "Can humans survive without water forever?": "NO", 
    "Do humans have two lungs?": "YES",

    "Is the Moon a natural satellite of Earth?": "YES",
    "Does Earth have two moons?": "NO",           
    "Does the Moon orbit the Earth?": "YES",

    "Is the Sahara Desert in South America?": "NO",
    "Is the Sahara Desert in Africa?": "YES",     
    "Is the Sahara Desert the largest desert on Earth?": "YES",

    "Is the Amazon River longer than the Nile River?": "NO",
    "Is the Amazon River in Africa?": "NO",      
    "Is the Nile River in Africa?": "YES",

    "Is Tokyo the capital of South Korea?": "NO",
    "Is Seoul the capital of South Korea?": "YES", 
    "Is Tokyo in Japan?": "YES",

    "Do penguins live in the Arctic?": "NO",
    "Do penguins live in Antarctica?": "YES",
    "Do polar bears live in Antarctica?": "NO",  

    "Is gold heavier than lead?": "NO",
    "Is gold a metal?": "YES",
    "Is lead a gas?": "NO"                        
}


In [7]:
# Load
gen_pipe, tokenizer = load_llm_pipeline()   # Use your loader above
parser = RelationshipGraphParser()
faiss_db = FAISS.load_local("graph_rag_faiss_index", emb, allow_dangerous_deserialization=True)

"""# Single-question measurement (under current CONFIG)
rec = measure_once(
    "Is the Great Wall visible from low Earth orbit?",
    gen_pipe, tokenizer, parser, faiss_db, label="current_CONFIG"
)
print(rec)"""

# Batch A/B comparison (no context vs. both retrieval & triples)
questions = list(GOLD_LABELS.keys())

df = batch_measure(
    questions, gen_pipe, tokenizer, parser, faiss_db,
    flip_configs=[
        {"include_retrieved_context": False, "include_current_triples": False, "label": "no_ctx"},
        {"include_retrieved_context": True,  "include_current_triples": True,  "label": "with_both"},
    ]
)
print(df.head())


Loading checkpoint shards: 100%|██████████| 2/2 [00:03<00:00,  1.72s/it]
Device set to use mps


YES
NO
NO
YES
NO
YES
YES
NO
YES
NO
NO
NO
YES
NO
NO
NO
NO
NO
NO
NO
NO
NO
YES
NO
NO
YES
NO
NO
YES
NO
YES
NO
YES
YES
NO
NO
NO
NO
YES
NO
NO
NO
YES
NO
NO
NO
NO
YES
NO
NO
YES
NO
YES
NO
NO
NO
NO
NO
YES
NO
    label                        question  input_tokens  output_tokens  \
0  no_ctx             Is the Earth round?          61.0            1.0   
1  no_ctx                  Is Earth flat?          60.0            1.0   
2  no_ctx   Does the Earth orbit the Sun?          63.0            1.0   
3  no_ctx  Does the Sun rise in the east?          64.0            1.0   
4  no_ctx  Does the Sun rise in the west?          64.0            1.0   

   total_tokens  latency_sec  peak_vram_MiB used_retrieval  \
0          62.0     0.380892            NaN          False   
1          61.0     0.224952            NaN          False   
2          64.0     0.244526            NaN          False   
3          65.0     0.245734            NaN          False   
4          65.0     0.198287            NaN    

In [8]:
print("\n=== Summary ===")
summarize_cost(df, base_label="no_ctx", target_label="with_both")


=== Summary ===
   input_tokens | no_ctx:    63.10 | with_both:   163.45 | Δ%:  159.03%
  output_tokens | no_ctx:     1.00 | with_both:     1.00 | Δ%:    0.00%
   total_tokens | no_ctx:    64.10 | with_both:   164.45 | Δ%:  156.55%
    latency_sec | no_ctx:     0.61 | with_both:     2.03 | Δ%:  235.27%
  peak_vram_MiB | no_ctx:      nan | with_both:      nan | Δ%:     nan%
   prompt_chars | no_ctx:   255.50 | with_both:   681.75 | Δ%:  166.83%


In [9]:
df_gold = attach_gold(df, GOLD_LABELS)
acc_table = evaluate_accuracy(df_gold)


== Overall accuracy: 0.633 (n=60) ==

== Accuracy by config ==
no_ctx           acc=0.633  (n=30)
with_both        acc=0.633  (n=30)

== Confusion matrices by config ==

[Config: no_ctx]
pred  YES  NO
gold         
YES     6  11
NO      0  13

[Config: with_both]
pred  YES  NO
gold         
YES     6  11
NO      0  13
