In [None]:
import re

import pandas as pd
from kg_rag import Inference, KG_RAG_Tool
from doraemon import Doraemon
from typing import List, Optional, Sequence, Dict

class Verb1STop:
    """Prompt builder and answer parser for the verbalized top-4 setup.

    The model is asked for its 4 best answers, each with a probability, in descending
    order of probability inside a boxed list of ``{answer:probability}`` items. Only the
    first pair (the model's top guess) is extracted by ``get_first_guess_info``.
    """

    # NOTE(review): this string is sent verbatim (no str.format / f-string is applied),
    # so the doubled braces in the two "\\boxed{{...}}" lines reach the model as literal
    # "{{" / "}}". They look like leftovers from a format template; left unchanged because
    # altering the prompt changes the experiment. get_first_guess_info accepts either form.
    SYSTEM_PROMPT = (
        "Provide your 4 BEST FINAL ANSWERS and the probability that each is correct (0.0 to 1.0) "
        "for the following question.\n\n"
        "STRICT REQUIREMENTS:\n"
        "1) Output EXACTLY 4 final answers, ordered by DESCENDING probability.\n"
        "2) Each answer is a short token/phrase (NOT a sentence) and MUST NOT contain the characters ':', ',', '{', '}', '[', ']'.\n"
        "3) Each probability is a numeric value between 0.0 and 1.0 (e.g., 0.95 or 1.0).\n"
        "4) Output ONLY the 4 answers and probabilities—no extra words, no explanations.\n"
        "5) The final output MUST follow this EXACT format, including the outer square brackets and inner braces:\n"
        "   \\boxed{{[{a1:p1},{a2:p2},{a3:p3},{a4:p4}]}}\n"
        "   - Note: Do NOT omit the leading '[' or the trailing ']'.\n"
        "   - Note: Each item MUST be wrapped in braces {…}.\n"
        "\n"
        "Where:\n"
        "- a1, a2, a3, a4 = your 4 most likely final answers (no ':', ',', '{', '}', '[', ']').\n"
        "- p1, p2, p3, p4 = probabilities for each answer.\n"
        "\n"
        "Example (format only):\n"
        "\\boxed{{[{True:0.75},{False:0.15},{True:0.07},{False:0.03}]}}\n"
    )

    # One answer/probability pair: an opening '[' or '{', an answer containing none of
    # ':', '[', ']', '{', '}', a colon, a non-negative decimal (e.g. 1, 0.95, .95), and a
    # closing ']' or '}'. Excluding the *opening* brackets from the answer is essential:
    # with the requested "[{a1:p1},...]" layout, a pattern that only excludes ':', ']'
    # and '}' starts matching at the outer '[' and yields the answer "{a1".
    _PAIR_RE = re.compile(r'[\[\{]\s*([^:\[\]\{\}]+?)\s*:\s*([0-9]*\.?[0-9]+)\s*[\]\}]')

    # ----------- Normalize a contexts value -----------
    @staticmethod
    def _as_context_list(ctxs) -> List[str]:
        """Normalize a contexts value into a list (possibly empty).

        Accepts None, a single string (treated as ONE context), a missing-value scalar
        (NaN / pd.NA), or any iterable of contexts such as a list or numpy array.
        The raw value is never truth-tested because ``bool()`` raises for numpy arrays
        with more than one element and for pd.NA.
        """
        if ctxs is None:
            return []
        if isinstance(ctxs, str):
            # A bare string is one context, not a sequence of one-character contexts.
            return [ctxs] if ctxs.strip() else []
        if pd.api.types.is_scalar(ctxs):
            return [] if pd.isna(ctxs) else [ctxs]
        return list(ctxs)

    # ----------- Format contexts helper (optional) -----------
    @staticmethod
    def _format_contexts(ctxs: Optional[Sequence[str]]) -> str:
        """Render contexts as newline-joined 'Context<i>: <text>' lines (1-based).

        Returns "" when there are no contexts.
        """
        items = Verb1STop._as_context_list(ctxs)
        return "\n".join(f"Context{i}: {c}" for i, c in enumerate(items, start=1))

    # ----------- Build messages (system + user) -----------
    @classmethod
    def msgs(cls, question: str, contexts: Optional[Sequence[str]] = None) -> List[Dict[str, str]]:
        """Build OpenAI-style chat messages for one question.

        Args:
            question: The question text.
            contexts: Optional retrieved contexts (list/tuple/array of strings, a single
                string, None, or a missing value).

        Returns:
            ``[{"role": "system", ...}, {"role": "user", ...}]``. The system message is
            SYSTEM_PROMPT; the user message is an optional "## Provided Contexts" section
            followed by a blank line, then "Question: <question>".
        """
        ctx_list = cls._as_context_list(contexts)

        user_prompt_parts: List[str] = []
        if ctx_list:
            user_prompt_parts.append("## Provided Contexts")
            user_prompt_parts.append(cls._format_contexts(ctx_list))
            user_prompt_parts.append("")  # blank line between contexts and the question
        user_prompt_parts.append(f"Question: {question}")
        user_prompt = "\n".join(user_prompt_parts)

        return [
            {"role": "system", "content": cls.SYSTEM_PROMPT},
            {"role": "user", "content": user_prompt},
        ]

    @classmethod
    def get_first_guess_info(cls, extracted: str):
        """Return ``(answer, probability)`` from the first answer/probability pair.

        ``extracted`` is the text pulled out of the model's boxed output, e.g.
        '[{a1:p1},{a2:p2},...]', '{a1:p1},{a2:p2},...' or '[a1:p1],[a2:p2],...'.

        Returns:
            The answer (whitespace-stripped) and its probability as a float (not clamped
            to [0, 1]). ``(None, None)`` when ``extracted`` is not a non-blank string
            (None, NaN, pd.NA, "") or contains no parsable pair.
        """
        if not isinstance(extracted, str) or not extracted.strip():
            return None, None

        match = cls._PAIR_RE.search(extracted)
        if match is None:
            return None, None

        # The probability group only admits digits with at most one '.', so float() cannot fail.
        return match.group(1).strip(), float(match.group(2))


def classify(x: str, n: int = 1200, dataset: str = "metaqa") -> pd.DataFrame:
    """Load the first ``n`` rows of a pre-built KG-QA split from the Kaggle inputs.

    Args:
        x: Hop count as a word, case-insensitive and whitespace-tolerant ("one" / "three").
            For "metaqa", anything other than "one" (including None or "") falls back to
            the three-hop file. For "webqsp", only "one" and "three" exist.
        n: Number of leading rows to keep (passed to ``DataFrame.head``).
        dataset: "metaqa" or "webqsp".

    Returns:
        A new DataFrame (independent copy). For "webqsp", the columns "ground_truth" and
        "contexts" are renamed to "Label" and "ctx_topk" (presumably the names the metaqa
        pickles already use — confirm).

    Raises:
        ValueError: unknown ``dataset``, or a hop count with no file for "webqsp".
    """
    base_paths = {
        "metaqa": "/kaggle/input/filtered-multiple-hops-metaqa",
        "webqsp": "/kaggle/input/webqsp",
    }
    # File name for each available (dataset, hop) pair.
    hop_files = {
        "metaqa": {
            "one": "one_hop_supported.pickle",
            "three": "three_hop_supported.pickle",
        },
        "webqsp": {
            "one": "webqsp_ctxstyle_1200_hop1_nl.pkl",
            "three": "webqsp_ctxstyle_1200_hop3_nl.pkl",
        },
    }

    # Validate dataset input
    if dataset not in base_paths:
        raise ValueError(f"Invalid dataset '{dataset}'. Must be one of: {list(base_paths.keys())}")

    hop = (x or "").strip().lower()
    available = hop_files[dataset]
    if hop not in available:
        if dataset == "metaqa":
            # Deliberate fallback: any unrecognised hop loads the three-hop split.
            hop = "three"
        else:
            raise ValueError(
                f"❌ No '{hop}' hop split exists for dataset '{dataset}'. "
                f"Available: {list(available)}"
            )
    path = f"{base_paths[dataset]}/{available[hop]}"

    # NOTE: unpickling can execute arbitrary code; only point this at trusted files.
    df = pd.read_pickle(path)

    if dataset == "webqsp":
        df = df.rename(columns={"ground_truth": "Label", "contexts": "ctx_topk"})

    # Return an independent copy so callers can add columns without writing into a
    # slice of the full frame.
    return df.head(n).copy()

In [None]:
dataset="webqsp"

df = classify("three", n = 1200, dataset=dataset)
df['query'] = df.apply(lambda row: Verb1STop.msgs(row['question'], row['ctx_topk']), axis=1)

Doraemon.set_provider('llama3')
logger = Doraemon.get_logger(logfile='verb_1s_top_4.log')
tasks = df.to_dict(orient='records')
q_a = await Inference.process_batches(tasks, logger, 'query')
df['q_a'] = pd.Series(q_a, dtype='object')

df[['final_a', 'fusion_prob']] = df['q_a'].apply(
    lambda x: pd.Series(
        Verb1STop.get_first_guess_info(KG_RAG_Tool.extract_boxed_answer(x))
    )
)

In [None]:
from calibration_metrics import CalibrationMetrics

# Accuracy of the top answer against the gold label.
print(KG_RAG_Tool.eval_accuracy(df, 'final_a', 'Label'))

# Calibration of the top answer's verbalized probability.
# NOTE(review): no visible cell creates 'is_correct'; it presumably comes from
# eval_accuracy above adding it to df — confirm. Fail fast with a clear message instead
# of failing inside summarize.
assert 'is_correct' in df.columns, "df has no 'is_correct' column; compute it before calibration"
std_summary = CalibrationMetrics.summarize(df, prob_col="fusion_prob", correct_col="is_correct", n_bins=10, norm="l2")
print(std_summary["ece"], std_summary["brier"], std_summary["selective_auc"])
tbl_std = std_summary["reliability_table"]

In [None]:
# Persist the scored frame (model responses, parsed answers, probabilities) for later analysis.
OUTPUT_PICKLE = 'verb1s_top4.pkl'
df.to_pickle(OUTPUT_PICKLE)