In [None]:
import re
import math

import numpy as np
import pandas as pd
from typing import List, Tuple, Dict, Optional, Any
from doraemon import Doraemon
from kg_rag import Inference, KG_RAG_Tool

class Verb2sTop4:
    r"""
    Two-stage (2S) verbalized top-k prediction with retrieved context.

      - Stage 1 ("gen"):   the model returns \boxed{a1,...,ak}  (k guesses)
      - Stage 2 ("score"): the model returns \boxed{p1,...,pk}  (one probability per guess)

    The per-question context is injected into the USER message of both stages;
    the system messages carry only the instructions from ``SYSTEM_PROMPT``.
    The parsing helpers turn the raw model outputs back into aligned
    guess / probability lists and select the highest-probability guess.
    """

    SYSTEM_PROMPT: Dict[str, str] = {
        # ---------- Stage 1: Guess generation ----------
        "gen": (
            "Provide your {k} best guesses for the following question based on the given contexts. Give ONLY the guesses, "
            "no other words or explanation.\n\n"
            "For example:\n\n"
            "G1: <first most likely guess, as short as possible; not a complete sentence, just the guess!>\n"
            "...\n"
            "G{k}: <{k}-th most likely guess, as short as possible; not a complete sentence, just the guess!>\n\n"
            "Return ONLY the guesses in this EXACT format:\n"
            "\\boxed{{a1,a2,a3,a4}}\n"
        ),

        # ---------- Stage 2: Probability scoring ----------
        "score": (
            "Provide the probability that each of your {k} guesses is correct based on the given contexts. Give ONLY the probabilities, "
            "no other words or explanation.\n\n"
            "For example:\n\n"
            "P1: <the probability between 0.0 and 1.0 that G1 is correct>\n"
            "...\n"
            "P{k}: <the probability between 0.0 and 1.0 that G{k} is correct>\n\n"
            "Return ONLY the probabilities in this EXACT format:\n"
            "\\boxed{{p1,p2,p3,p4}}\n"
        ),
    }

    # The boxed examples in SYSTEM_PROMPT are written for k=4; the message builders
    # rewrite them to k slots so the example always matches the requested k
    # (for k=4 the rendered prompt is unchanged).
    _PROMPT_EXAMPLE_SLOTS: Dict[str, str] = {"gen": "a1,a2,a3,a4", "score": "p1,p2,p3,p4"}
    _PROMPT_SLOT_PREFIX: Dict[str, str] = {"gen": "a", "score": "p"}

    # Opening of a \boxed{ span: one or more backslashes, "boxed", optional spaces, "{".
    _BOXED_OPEN_RE = re.compile(r"\\+boxed\s*\{")
    # Label such as "G1:" the model may echo from the Stage-1 example format.
    _GUESS_LABEL_RE = re.compile(r"^G\d+\s*:\s*", flags=re.IGNORECASE)
    # Decimal / scientific-notation number, used to salvage tokens like "P1: 0.7" or "70%".
    _NUMBER_RE = re.compile(r"[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?")

    # ---------- message builders ----------
    @classmethod
    def _render_system_prompt(cls, stage: str, k: int) -> str:
        """Format ``SYSTEM_PROMPT[stage]`` for ``k`` and resize its boxed example to k slots."""
        prompt = cls.SYSTEM_PROMPT[stage].format(k=k)
        prefix = cls._PROMPT_SLOT_PREFIX[stage]
        slots = ",".join(f"{prefix}{i}" for i in range(1, k + 1))
        return prompt.replace(cls._PROMPT_EXAMPLE_SLOTS[stage], slots)

    @classmethod
    def build_stage1_messages(cls, question: str, context: str, k: int = 4) -> List[Dict[str, str]]:
        """
        Build chat messages for Stage 1 (candidate generation).

        Args:
            question: The question text.
            context: Context for this question (placed in the user message).
            k: Number of guesses requested.

        Returns:
            ``[system, user]`` chat messages.
        """
        system_msg = cls._render_system_prompt("gen", k)
        user_msg = (
            f"CONTEXT:\n{context}\n\n"
            f"QUESTION:\n{question}\n\n"
        )
        return [
            {"role": "system", "content": system_msg},
            {"role": "user", "content": user_msg},
        ]

    @classmethod
    def build_stage2_messages(
        cls,
        candidates: str,          # e.g., "a1,a2,a3,a4"
        question: str,
        context: str,
        k: int = 4
    ) -> List[Dict[str, str]]:
        r"""
        Build chat messages for Stage 2 (probability scoring).

        Args:
            candidates: Single comma-separated string of Stage-1 guesses, e.g. "a1,a2,a3,a4".
                ``None`` is treated as an empty string; surrounding whitespace is trimmed.
            question: The question text.
            context: Context for this question (placed in the user message).
            k: Number of probabilities requested.

        Returns:
            ``[system, user]`` chat messages; the model is asked for \boxed{p1,...,pk}.
        """
        cand_str = (candidates or "").strip()

        system_msg = cls._render_system_prompt("score", k)

        user_msg = (
            f"CONTEXT:\n{context}\n\n"
            f"QUESTION:\n{question}\n\n"
            "YOUR GUESSES (comma-separated, in order):\n"
            f"{cand_str}\n\n"
        )

        return [
            {"role": "system", "content": system_msg},
            {"role": "user", "content": user_msg},
        ]

    # ---------------------- Utility Methods ----------------------
    @classmethod
    def _first(cls, cell: Any) -> Any:
        """If ``cell`` is a non-empty tuple such as ``(value, idx)``, return ``value``; else return it as-is."""
        return cell[0] if isinstance(cell, tuple) and cell else cell

    @staticmethod
    def _balanced_span(text: str, start: int) -> Optional[str]:
        """
        Return ``text[start:j]`` where ``j`` is the ``}`` closing a ``{`` that is
        already open just before ``start`` (nested braces are skipped).
        Returns None when the brace is never closed.
        """
        depth = 1
        for j in range(start, len(text)):
            ch = text[j]
            if ch == "{":
                depth += 1
            elif ch == "}":
                depth -= 1
                if depth == 0:
                    return text[start:j]
        return None

    @classmethod
    def _boxed_inner(cls, s: Any) -> str:
        r"""
        Extract the content of the last \boxed{...} in ``s`` (brace-balanced).

        Trailing text after the box is allowed, nested braces such as
        \boxed{\text{Paris}, Rome} are kept intact, and a doubled wrapper
        \boxed{{...}} is unwrapped once.

        Example: '\boxed{{a1,a2,a3,a4}}' -> 'a1,a2,a3,a4'
        Returns "" for None, when no box is present, or when the box is unclosed.
        """
        if s is None:
            return ""
        text = str(s)
        opens = list(cls._BOXED_OPEN_RE.finditer(text))
        if not opens:
            return ""
        inner = cls._balanced_span(text, opens[-1].end())
        if inner is None:
            return ""
        inner = inner.strip()
        # Unwrap \boxed{{...}} only when the outer pair of braces match each other.
        if inner.startswith("{") and cls._balanced_span(inner, 1) == inner[1:-1]:
            inner = inner[1:-1].strip()
        return inner

    @classmethod
    def _try_kg_extract(cls, cell: Any) -> str:
        """
        Extract the boxed answer from a model-output cell.

        Uses ``KG_RAG_Tool.extract_boxed_answer`` first; if it raises or returns
        an empty/None result, falls back to the local ``_boxed_inner`` parser.
        """
        txt = cls._first(cell)
        try:
            extracted = KG_RAG_Tool.extract_boxed_answer(txt)
        except Exception:
            # Best-effort: any failure of the external extractor falls back to the local parser.
            extracted = None
        return extracted if extracted else cls._boxed_inner(txt)

    # ---------------------- Parsing Methods ----------------------
    @classmethod
    def parse_guesses(cls, cell: Any, k: int = 4) -> List[Optional[str]]:
        """
        Parse a q_a cell into exactly ``k`` guesses.

        Guesses are split on commas and trimmed; an echoed "G<n>:" label is
        removed. Empty guesses become None, and the list is padded with None
        (or truncated) to length ``k``.
        """
        inner = cls._try_kg_extract(cell)
        if not inner:
            return [None] * k
        guesses: List[Optional[str]] = []
        for part in str(inner).split(","):
            cleaned = cls._GUESS_LABEL_RE.sub("", part.strip()).strip()
            guesses.append(cleaned or None)
        guesses += [None] * (k - len(guesses))  # no-op when there are already >= k
        return guesses[:k]

    @classmethod
    def _to_prob(cls, token: str) -> float:
        """
        Convert one probability token to float.

        Plain numbers are parsed directly. Otherwise the last number in the
        token is used ("P1: 0.7" -> 0.7), divided by 100 if the token ends with
        "%" ("70%" -> 0.7). Returns NaN when no number is found.
        """
        token = token.strip()
        try:
            return float(token)
        except ValueError:
            pass
        numbers = cls._NUMBER_RE.findall(token)
        if not numbers:
            return float("nan")
        value = float(numbers[-1])
        return value / 100.0 if token.endswith("%") else value

    @classmethod
    def parse_probs(cls, cell: Any, k: int = 4) -> List[float]:
        """
        Parse a q_a_prob cell into exactly ``k`` floats, padded with NaN if missing.
        """
        inner = cls._try_kg_extract(cell)
        if not inner:
            return [float("nan")] * k
        probs = [cls._to_prob(tok) for tok in str(inner).split(",")[:k]]
        probs += [float("nan")] * (k - len(probs))
        return probs

    # ---------------------- Main Matching Method ----------------------
    @classmethod
    def pair_prob_guess(
        cls, q_a_cell: Any, q_a_prob_cell: Any, k: int = 4
    ) -> Tuple[List[Dict[float, Optional[str]]], Optional[str], Optional[float]]:
        """
        Match guesses and probabilities position-by-position.

        Returns:
            qa_pairs: List of single-entry dicts ``[{prob: guess}, ...]`` (length k).
            final_a: The guess with the highest probability, considering only
                positions that have both a guess and a non-NaN probability;
                None if no such position exists.
            final_prob: The probability of ``final_a`` (None when ``final_a`` is None).
        Ties go to the earliest position.
        """
        guesses = cls.parse_guesses(q_a_cell, k=k)
        probs = cls.parse_probs(q_a_prob_cell, k=k)

        pairs = [{p: g} for p, g in zip(probs, guesses)]

        # A probability attached to a missing guess must not be selected.
        eligible = [
            i for i in range(k)
            if guesses[i] is not None and not math.isnan(probs[i])
        ]
        if not eligible:
            return pairs, None, None

        best = max(eligible, key=lambda i: probs[i])  # max() keeps the first maximum
        return pairs, guesses[best], probs[best]

    # ---------------------- DataFrame Processor ----------------------
    @classmethod
    def process_dataframe(cls, df: pd.DataFrame, k: int = 4) -> pd.DataFrame:
        """
        Given a DataFrame with columns q_a and q_a_prob, add (in place) and return:
          - qa_pairs: List[Dict[prob: guess]]
          - final_a: Highest-probability guess (None if none could be selected)
          - final_prob: Its probability (NaN if none could be selected)
        Works on empty frames and on frames with any index.
        """
        results = [
            cls.pair_prob_guess(qa, qa_prob, k=k)
            for qa, qa_prob in zip(df["q_a"], df["q_a_prob"])
        ]
        if results:
            pairs, finals, final_probs = map(list, zip(*results))
        else:
            pairs, finals, final_probs = [], [], []

        # Explicit index => positional assignment, independent of df's index labels.
        df["qa_pairs"] = pd.Series(pairs, index=df.index, dtype=object)
        df["final_a"] = pd.Series(finals, index=df.index, dtype=object)
        df["final_prob"] = pd.Series(final_probs, index=df.index, dtype=float)
        return df

def classify(x: str, n: int = 1200, dataset: str = "metaqa") -> pd.DataFrame:
    """
    Load the first ``n`` rows of a pre-built hop-split QA dataset.

    Args:
        x: Hop split to load, "one" or "three" (case- and whitespace-insensitive).
            For "metaqa", any other value (including "two", "" and None) falls
            back to the three-hop file.
        n: Number of leading rows to keep (passed to ``DataFrame.head``).
        dataset: "metaqa" or "webqsp".

    Returns:
        The loaded DataFrame (first ``n`` rows). For "webqsp" the columns
        ``ground_truth`` / ``contexts`` are renamed to ``Label`` / ``ctx_topk``.

    Raises:
        ValueError: If ``dataset`` is unknown, or if ``x`` is not "one"/"three"
            for "webqsp".
    """
    BASE_PATHS = {
        "metaqa": "/kaggle/input/filtered-multiple-hops-metaqa",
        "webqsp": "/kaggle/input/webqsp",
    }
    # File name of each available hop split, per dataset.
    HOP_FILES = {
        "metaqa": {
            "one": "one_hop_supported.pickle",
            "three": "three_hop_supported.pickle",
        },
        "webqsp": {
            "one": "webqsp_ctxstyle_1200_hop1_nl.pkl",
            "three": "webqsp_ctxstyle_1200_hop3_nl.pkl",
        },
    }

    if dataset not in BASE_PATHS:
        raise ValueError(f"Invalid dataset '{dataset}'. Must be one of: {list(BASE_PATHS.keys())}")

    hop = (x or "").strip().lower()
    hop_files = HOP_FILES[dataset]
    if hop not in hop_files:
        if dataset != "metaqa":
            raise ValueError(
                f"❌ No '{hop}' hop split exists for the '{dataset}' dataset; "
                f"use one of: {list(hop_files.keys())}"
            )
        # MetaQA: any other value (incl. "two") falls back to the three-hop split.
        hop = "three"

    path = f"{BASE_PATHS[dataset]}/{hop_files[hop]}"
    # NOTE: unpickling can execute arbitrary code; only load files you trust.
    df = pd.read_pickle(path)

    if dataset == "webqsp":
        # Rename to the column names the downstream cells read (Label, ctx_topk).
        df = df.rename(columns={"ground_truth": "Label", "contexts": "ctx_topk"})

    return df.head(n)

In [None]:
dataset="webqsp"

df = classify("three", n = 1200, dataset=dataset)

df['query']=df.apply(
    lambda x: Verb2sTop4.build_stage1_messages(
        question=x['question'], 
        context=x['ctx_topk']
    ), 
    axis=1
)

Doraemon.set_provider('llama3')
logger=Doraemon.get_logger(logfile='verb_2s_top_4.log')
tasks = df.to_dict(orient='records')

q_a = await Inference.process_batches(tasks, logger, 'query')
df['q_a'] = pd.Series(q_a, dtype='object')

df['query_prob']=df.apply(
    lambda x: Verb2sTop4.build_stage2_messages(
        candidates=KG_RAG_Tool.extract_boxed_answer(x['q_a']),
        question=x['question'], 
        context=x['ctx_topk']
    ), 
    axis=1
)

tasks = df.to_dict(orient='records')

query_prob_a = await Inference.process_batches(tasks, logger, 'query_prob')
df['q_a_prob'] = pd.Series(query_prob_a, dtype='object')

In [None]:
# Pair each guess with its stated probability and keep the most confident one.
# Reuses the class helper instead of duplicating its logic; adds the
# qa_pairs / final_a / final_prob columns to df.
df = Verb2sTop4.process_dataframe(df, k=4)

print(KG_RAG_Tool.eval_accuracy(df, pred='final_a', g_t='Label'))

In [None]:
from calibration_metrics import CalibrationMetrics

# Calibration of the selected guess's verbalized confidence (final_prob) against correctness.
# NOTE(review): no visible cell creates `is_correct`; presumably KG_RAG_Tool.eval_accuracy
# adds it to df — confirm. Fail early with a clear message if it is missing.
assert "is_correct" in df.columns, "df has no 'is_correct' column required for calibration"

std_summary = CalibrationMetrics.summarize(df, prob_col="final_prob", correct_col="is_correct", n_bins=10, norm="l2")
print(std_summary["ece"], std_summary["brier"], std_summary["selective_auc"])
tbl_std = std_summary["reliability_table"]
tbl_std

In [None]:
# Persist the full frame (prompts, raw model outputs, parsed guesses/probabilities).
# NOTE(review): the file name does not include `dataset`, so runs on different datasets overwrite each other.
pd.to_pickle(df, 'verb2s_top4.pkl')