In [None]:
import pandas as pd
from doraemon import Doraemon
from kg_rag import Inference, KG_RAG_Tool, PromptBuilder


class IoE:
    """Prompt templates and chat-message builders for the IoE answer-revision step.

    ``PROMPT["ioe"]`` is the system prompt that asks the model to keep its
    previous answer when confident and to update it otherwise.
    ``PROMPT["fact_kg"]["gen"]`` is the system prompt for a first single-guess
    answer. Both require the answer to be returned in a LaTeX-style
    ``boxed{...}`` wrapper.
    """

    PROMPT = {
        "ioe": (
            "If you are very confident about your answer, maintain your answer. "
            "Otherwise, update your answer. The final answer MUST be returned in this exact format at the very end: "
            # Single braces on purpose: this string is sent verbatim (it is never
            # run through str.format), so doubled braces would reach the model
            # literally and disagree with the "gen" prompt's format.
            "\\boxed{final answer}\n"
            # Output-only constraint:
            "Your entire response MUST be exactly that single boxed answer. "
            "Do not include any other text, punctuation, quotes, code fences, or reasoning."
        ),
        "fact_kg": {
            "gen": (
                "Provide your single best guess for the following fact checking question. "
                "Give ONLY the guess, no other words or explanation.\n\n"
                "Return ONLY the guess in this EXACT format:\n"
                "\\boxed{guess}\n"
            )
        },
    }

    # -------- helpers --------
    @staticmethod
    def _as_text(value) -> str:
        """
        Coerce a cell value to stripped text.

        ``None`` and float NaN (what pandas uses for missing cells) become "".
        Every other value is passed through ``str()`` and stripped, so
        non-string answers such as ``0`` or ``False`` are kept rather than
        dropped.
        """
        if value is None:
            return ""
        # NaN is the only float that is not equal to itself.
        if isinstance(value, float) and value != value:
            return ""
        return str(value).strip()

    @staticmethod
    def _format_contexts(ctxs) -> str:
        """
        Turn contexts (None | str | List[str]) into numbered lines:
        Context1: ...
        Context2: ...

        A single non-list value is treated as one context. Entries that are
        None, NaN or blank are dropped *before* numbering, so the labels are
        always consecutive. Returns "" when nothing usable remains.
        """
        if ctxs is None:
            return ""
        if not isinstance(ctxs, (list, tuple)):
            ctxs = [ctxs]
        kept = [text for text in map(IoE._as_text, ctxs) if text]
        return "\n".join(f"Context{i}: {text}" for i, text in enumerate(kept, 1))

    # -------- user content builder --------
    @classmethod
    def build_user(cls, question, contexts, previous_answer) -> str:
        """
        Build the user message body:

            Question: <question>
            Context1: ...            (omitted entirely when there are no contexts)
            Previous Answer: <previous_answer>

        Missing values (None / NaN) for the question or previous answer are
        rendered as empty text instead of raising.
        """
        parts = [f"Question: {cls._as_text(question)}"]
        ctx_block = cls._format_contexts(contexts)
        if ctx_block:
            parts.append(ctx_block)
        parts.append(f"Previous Answer: {cls._as_text(previous_answer)}")
        return "\n".join(parts)

    # -------- message pack builders (system + user) --------
    @classmethod
    def msgs_ioe(cls, question, contexts, previous_answer):
        """
        Returns a list[dict] with roles 'system' and 'user'.

        The system content is PROMPT["ioe"]; the user content comes from
        build_user(question, contexts, previous_answer).
        """
        return [
            {"role": "system", "content": cls.PROMPT["ioe"]},
            {"role": "user", "content": cls.build_user(question, contexts, previous_answer)},
        ]

    @classmethod
    def msgs_ioe_from_row(cls, row):
        """
        Convenience: build from a pandas Series/dict-like row with:
          - 'question'   : str
          - 'ctx_topk'   : List[str] or str
          - 's_out'      : str (previous answer)

        Missing keys fall back to an empty question, no contexts and an empty
        previous answer. `row` only needs to provide `.get(key, default)`.
        """
        return cls.msgs_ioe(
            question=row.get("question", ""),
            contexts=row.get("ctx_topk", []),
            previous_answer=row.get("s_out", "")
        )

def classify(x: str, n: int = 1200, dataset: str = "metaqa") -> pd.DataFrame:
    """Load the pickled question set for a dataset / hop count and return its first ``n`` rows.

    Args:
        x: Hop selector, compared case-insensitively after stripping whitespace.
            For ``"metaqa"``: ``"one"`` selects the one-hop file; any other
            value (including empty or ``None``) falls back to the three-hop file.
            For ``"webqsp"``: only ``"one"`` and ``"three"`` are accepted.
        n: Maximum number of rows to return (forwarded to ``DataFrame.head``).
        dataset: ``"metaqa"`` or ``"webqsp"``.

    Returns:
        The first ``n`` rows of the loaded frame. For ``"webqsp"`` the columns
        ``ground_truth`` and ``contexts`` are renamed to ``Label`` and
        ``ctx_topk``.

    Raises:
        ValueError: If ``dataset`` is unknown, or if ``x`` is not a supported
            hop for ``"webqsp"``.
    """
    BASE_PATHS = {
        "metaqa": "/kaggle/input/filtered-multiple-hops-metaqa",
        "webqsp": "/kaggle/input/webqsp"
    }
    WEBQSP_FILES = {
        "one": "webqsp_ctxstyle_1200_hop1_nl.pkl",
        "three": "webqsp_ctxstyle_1200_hop3_nl.pkl",
    }

    # Validate dataset input
    if dataset not in BASE_PATHS:
        raise ValueError(f"Invalid dataset '{dataset}'. Must be one of: {list(BASE_PATHS.keys())}")

    base = BASE_PATHS[dataset]
    x = (x or "").strip().lower()

    # Select file name based on dataset type
    if dataset == "metaqa":
        # Any value other than "one" falls back to the three-hop file.
        filename = "one_hop_supported.pickle" if x == "one" else "three_hop_supported.pickle"
    else:  # dataset == "webqsp"
        if x not in WEBQSP_FILES:
            # Name the actual offending hop and dataset (plain ASCII message).
            raise ValueError(
                f"Unsupported hop '{x}' for dataset '{dataset}'. "
                f"Must be one of: {list(WEBQSP_FILES)}"
            )
        filename = WEBQSP_FILES[x]
    path = f"{base}/{filename}"

    # Load dataframe.
    # NOTE(security): unpickling can execute arbitrary code; only load files
    # from a trusted source.
    df = pd.read_pickle(path)

    if dataset == "webqsp":
        # Bring WebQSP column names in line with what the rest of the notebook reads.
        df = df.rename(columns={
            "ground_truth": "Label",
            "contexts": "ctx_topk"
        })

    return df.head(n)

In [None]:
dataset="webqsp"

df = classify("three", n = 1200, dataset=dataset)


df['query'] = df.apply(lambda row: [
    {"role": "system", "content": IoE.PROMPT['fact_kg']['gen']},
    {"role": "user", "content": PromptBuilder.build_user_multi_contents(row['question'], row['ctx_topk'])}
], axis=1)

Doraemon.set_provider('llama3')
logger = Doraemon.get_logger(logfile='if_or_else.log')
tasks = df.to_dict(orient='records')
standard_out = await Inference.process_batches(tasks, logger, 'query')
df['s_out'] = pd.Series(standard_out, dtype='object')
df['s_out'] = df['s_out'].apply(KG_RAG_Tool.extract_boxed_answer)

df['ioe_prompt'] = df.apply(IoE.msgs_ioe_from_row, axis=1)

tasks = df.to_dict(orient='records')
ioe_a = await Inference.process_batches(tasks, logger, 'ioe_prompt')
df['ioe_a'] = pd.Series(ioe_a, dtype='object')

df['ioe_a'] = df['ioe_a'].apply(KG_RAG_Tool.extract_boxed_answer)

In [None]:
# Score the revised answers (column 'ioe_a') against the reference column 'Label'.
stats = KG_RAG_Tool.eval_accuracy(
    df,
    pred="ioe_a",
    g_t="Label",
)
# Report accuracy as a percentage together with the raw correct / total counts.
print(
    f"Accuracy: {stats['accuracy']:.2%}  ({stats['num_correct']} / {stats['total']})"
)