In [None]:
import pandas as pd
from kg_rag import Inference, KG_RAG_Tool
from doraemon import Doraemon

# --- Self-correct prompt builder (3 steps: solve → review → revise) ---
from typing import List, Optional, Sequence

class SelfCorrectPromptBuilder:
    """Build chat prompts for a three-step self-correction loop over retrieved contexts.

    Steps:
      * ``solve``  - answer the question from the provided contexts only.
      * ``review`` - critique a previous solution without giving a new answer.
      * ``revise`` - produce a corrected answer using the review's critique.

    Every user prompt embeds the contexts as numbered ``Context<i>:`` lines and
    references the ``\\boxed{answer}`` output format. The ``PROMPT`` templates
    are rendered with ``str.format``, so literal braces are doubled
    (``{{answer}}`` renders as ``{answer}``).
    """

    PROMPT = {
        "solve": (
            "## Role & Objective\n"
            "You are answering a question using ONLY the provided contexts.\n"
            "Requirements:\n"
            "- Use information strictly from the contexts; do not rely on outside knowledge.\n"
            "- Show concise, step-by-step reasoning.\n"
            "- The final answer MUST be a short phrase or a single number, returned in this exact format at the very end:\n"
            "  \\boxed{{answer}}\n"
            "\n"
            "## Provided Contexts\n"
            "{context}\n"
            "\n"
            "## Question\n"
            "{question}\n"
        ),
        "review": (
            "## Task: Review the previous solution for the question (do NOT provide a new solution)\n"
            "Review the answer using ONLY the provided contexts and identify issues. Focus on:\n"
            "- Misuse or omission of relevant context.\n"
            "- Logical/arithmetical mistakes.\n"
            "- Whether the final answer matches the facts in context.\n"
            "- Whether the output format (final answer in \\boxed{{answer}}) was followed.\n"
            "- If contexts are insufficient, the answer should be exactly 'Insufficient information'.\n"
            "\n"
            "## Provided Contexts\n"
            "{context}\n"
            "\n"
            "## Question\n"
            "{question}\n"
            "\n"
            "## Previous Solution\n"
            "{previous_answer}\n"
            "\n"
            "Explicitly remind that the final corrected answer in the next step must be returned in this format: "
            "\\boxed{{answer}}."
        ),
        "revise": (
            "## Task: Revise the previous solution\n"
            "Using ONLY the provided contexts, fix the reasoning and provide the best final answer for the  question.\n"
            "Requirements:\n"
            "- Correct any misinterpretations and errors noted in the review.\n"
            "- If contexts are insufficient, reply exactly with 'Insufficient information'.\n"
            "- Keep reasoning concise and accurate.\n"
            "- The final answer MUST be a short phrase or a single number, returned in this exact format at the very end:\n"
            "  \\boxed{{answer}}\n"
            "- Do NOT add any text after the boxed answer.\n"
            "\n"
            "## Provided Contexts\n"
            "{context}\n"
            "\n"
            "## Question\n"
            "{question}\n"
            "\n"
            "## Your Earlier Solution\n"
            "{previous_answer}\n"
            "\n"
            "## Issues to Address from the Review\n"
            "{review_points}\n"
        )
    }

    # ---------- format helpers ----------
    @staticmethod
    def _format_contexts(ctxs: Optional[Sequence[str]]) -> str:
        """Render contexts as ``Context<i>: <text>`` lines (1-based), newline-joined.

        Accepts ``None`` (no contexts), a single string (one context), a
        list/tuple, or any other non-string iterable such as a NumPy array or
        pandas Series. A non-iterable value is treated as a single context via
        ``str()``.

        Returns:
            The joined lines, or ``"(no context provided)"`` when there are none.
        """
        if ctxs is None:
            items: list = []
        elif isinstance(ctxs, (list, tuple)):
            items = list(ctxs)
        elif isinstance(ctxs, (str, bytes, dict)) or not hasattr(ctxs, "__iter__"):
            # Strings/bytes are iterable but represent ONE context; dicts keep
            # the whole-object str() rendering.
            items = [str(ctxs)]
        else:
            # e.g. numpy.ndarray / pandas.Series: str() on the container would
            # give one bracketed repr of the whole array instead of the
            # individual passages.
            try:
                items = list(ctxs)
            except TypeError:  # e.g. 0-d numpy arrays expose __iter__ but cannot be iterated
                items = [str(ctxs)]
        lines = [f"Context{i}: {str(c)}" for i, c in enumerate(items, start=1)]
        return "\n".join(lines) if lines else "(no context provided)"

    @staticmethod
    def as_user(content: str) -> dict:
        """Wrap ``content`` as a ``{"role": "user", ...}`` chat message."""
        return {"role": "user", "content": content}

    @staticmethod
    def as_system(content: str) -> dict:
        """Wrap ``content`` as a ``{"role": "system", ...}`` chat message."""
        return {"role": "system", "content": content}

    # ---------- prompt builders ----------
    @classmethod
    def build_solve(cls, question: str, contexts: Optional[Sequence[str]]) -> str:
        """Render the ``solve`` user prompt for ``question`` over ``contexts``."""
        return cls.PROMPT["solve"].format(
            question=question,
            context=cls._format_contexts(contexts)
        )

    @classmethod
    def build_review(cls, question: str, contexts: Optional[Sequence[str]], previous_answer: str) -> str:
        """Render the ``review`` user prompt critiquing ``previous_answer``.

        ``previous_answer`` should be the full earlier solution (reasoning plus
        ``\\boxed{...}``), since the prompt asks the critic to check both the
        reasoning and the output format.
        """
        return cls.PROMPT["review"].format(
            question=question,
            context=cls._format_contexts(contexts),
            previous_answer=previous_answer
        )

    @classmethod
    def build_revise(
        cls,
        question: str,
        contexts: Optional[Sequence[str]],
        previous_answer: str,
        review_points: Optional[str]
    ) -> str:
        """Render the ``revise`` user prompt.

        A falsy ``review_points`` (``None`` or ``""``) is replaced by a generic
        instruction so the "Issues to Address" section is never empty.
        """
        if not review_points:
            review_points = "- Correct all issues identified in the review."
        return cls.PROMPT["revise"].format(
            question=question,
            context=cls._format_contexts(contexts),
            previous_answer=previous_answer,
            review_points=review_points
        )

    # ---------- convenient message-pack builders (mirrors your style) ----------
    @classmethod
    def msgs_solve(cls, question: str, contexts: Optional[Sequence[str]]) -> List[dict]:
        """Return ``[system, user]`` chat messages for the solve step."""
        return [
            cls.as_system("You must use ONLY the provided contexts. Be concise and precise."),
            cls.as_user(cls.build_solve(question, contexts))
        ]

    @classmethod
    def msgs_review(cls, question: str, contexts: Optional[Sequence[str]], prev_answer: str) -> List[dict]:
        """Return ``[system, user]`` chat messages for the review (critique-only) step."""
        return [
            cls.as_system("Be a strict critic. Do NOT provide a new answer."),
            cls.as_user(cls.build_review(question, contexts, prev_answer))
        ]

    @classmethod
    def msgs_revise(
        cls,
        question: str,
        contexts: Optional[Sequence[str]],
        prev_answer: str,
        review_points: Optional[str]
    ) -> List[dict]:
        """Return ``[system, user]`` chat messages for the revise step."""
        return [
            cls.as_system("Fix errors and finish with a single \\boxed{answer}. No extra text after the box."),
            cls.as_user(cls.build_revise(question, contexts, prev_answer, review_points))
        ]

def classify(x: str, n: int = 1200, dataset: str = "metaqa") -> pd.DataFrame:
    BASE_PATHS = {
        "metaqa": "/kaggle/input/filtered-multiple-hops-metaqa",
        "webqsp": "/kaggle/input/webqsp"
    }

    # Validate dataset input
    if dataset not in BASE_PATHS:
        raise ValueError(f"Invalid dataset '{dataset}'. Must be one of: {list(BASE_PATHS.keys())}")

    base = BASE_PATHS[dataset]
    x = (x or "").strip().lower()

    # Select file path based on dataset type
    if dataset == "metaqa":
        match x:
            case "one":
                path = f"{base}/one_hop_supported.pickle"
            case "three" | _:
                path = f"{base}/three_hop_supported.pickle"
    else:  # dataset == "another"
        match x:
            case "one":
                path = f"{base}/webqsp_ctxstyle_1200_hop1_nl.pkl"
            case "three":
                path = f"{base}/webqsp_ctxstyle_1200_hop3_nl.pkl"
            case "two" | _:
                raise ValueError("❌ 'two_hop_supported.pickle' does not exist in the 'current' dataset.")

    # Load and return dataframe
    df = pd.read_pickle(path)

    if dataset =="webqsp":
        df=df.rename(columns={
            "ground_truth":"Label",
            "contexts": "ctx_topk"
        })
    
    return df.head(n)

In [None]:
dataset="webqsp"
df = classify("three", n = 1200, dataset=dataset)

# Step 1: SOLVE
df['sc_solve'] = df.apply(lambda row: SelfCorrectPromptBuilder.msgs_solve(
    row['question'], row['ctx_topk']), axis=1)

Doraemon.set_provider('llama3')
logger = Doraemon.get_logger(logfile='self_correct_on_rag.log')
tasks = df.to_dict(orient='records')
solve_out = await Inference.process_batches(tasks, logger, 'sc_solve')
df['sc_solve_a'] = pd.Series(solve_out, dtype='object')

# Step 2: REVIEW (critique only, reminds about \boxed{answer})
df['sc_review'] = df.apply(lambda row: SelfCorrectPromptBuilder.msgs_review(
    row['question'], row['ctx_topk'], KG_RAG_Tool.extract_boxed_answer(row['sc_solve_a'])), axis=1)

tasks = df.to_dict(orient='records')
review_out = await Inference.process_batches(tasks, logger, 'sc_review')
df['sc_review_a'] = pd.Series(review_out, dtype='object')

In [None]:
# Step 3: REVISE (final corrected answer in \boxed{...})
df['sc_revise'] = df.apply(lambda row: SelfCorrectPromptBuilder.msgs_revise(
    row['question'], row['ctx_topk'], KG_RAG_Tool.extract_boxed_answer(row['sc_solve_a']), KG_RAG_Tool.extract_boxed_answer(row['sc_review_a'])), axis=1)

tasks = df.to_dict(orient='records')
revise_out = await Inference.process_batches(tasks, logger, 'sc_revise')
df['sc_revise_a'] = pd.Series(revise_out, dtype='object')

df['final_a'] = df['sc_revise_a'].apply(KG_RAG_Tool.extract_boxed_answer)
print(KG_RAG_Tool.eval_accuracy(df, pred='final_a', g_t='Label'))

df.to_pickle('self_correct.pkl')