In [1]:
import re
import ast
import asyncio
import math
import pandas as pd
import numpy as np

from typing import Dict, Optional, Tuple, List


from doraemon import Doraemon

# Answer-only extractors. Group 1 is the text inside \boxed{{...}} (double
# braces, the literal format the prompt templates in this file ask for) or
# \boxed{...} (single braces). Whitespace is tolerated between \boxed and the
# braces. The capture is non-greedy and ends at the earliest closing brace(s)
# that complete the pattern, so answers that themselves contain braces are not
# supported.
BOXED_PAT_DOUBLE = re.compile(r'\\boxed\s*\{\s*\{(.*?)\}\s*\}')
BOXED_PAT_SINGLE = re.compile(r'\\boxed\s*\{\s*(.*?)\s*\}')

# Answer + confidence extractors for "\boxed{{answer, 0.92}}" and
# "\boxed{answer, 0.92}". Group 1 is the answer (commas and braces are not
# allowed inside it). Group 2 is an unsigned decimal such as "1", "0.92" or
# ".5". Unlike the answer-only patterns, no whitespace is allowed between
# \boxed and the opening brace(s).
BOXED_CONF_DOUBLE = re.compile(r"\\boxed\{\{([^,{}]+)\s*,\s*([0-9]*\.?[0-9]+)\}\}")
BOXED_CONF_SINGLE = re.compile(r"\\boxed\{([^,{}]+)\s*,\s*([0-9]*\.?[0-9]+)\}")

class KG_RAG_Tool:
    """
    A utility class containing common helper methods used throughout the KG‑RAG
    pipeline. All methods are defined as class methods so they can be
    invoked without instantiating the class.
    """

    @classmethod
    def _is_missing(cls, x) -> bool:
        """
        Return True for values treated as "no data".

        These are ``None``, ``pd.NA`` and float NaN. NumPy ``float64`` NaN is
        included because it subclasses ``float``.
        """
        if x is None or x is pd.NA:
            return True
        return isinstance(x, float) and math.isnan(x)

    @classmethod
    def extract_boxed_answer(cls, text):
        r"""
        Extract the answer string from a value like '(\boxed{{David Hewlett}}, 7)'
        or '(\boxed{2009}, 6)'.

        Lookup order:
          1. the content of ``\boxed{{...}}`` (double braces);
          2. the content of ``\boxed{...}`` (single braces);
          3. fallback: the text between a leading '(' and the first comma,
             e.g. '(answer, 7)' -> 'answer'.

        Returns:
            The stripped answer string, or None when the input is missing
            (None / NaN / pd.NA) or no pattern matches.
        """
        if cls._is_missing(text):
            return None
        s = str(text)

        m = BOXED_PAT_DOUBLE.search(s) or BOXED_PAT_SINGLE.search(s)
        if m:
            return m.group(1).strip()

        m2 = re.match(r'^\s*\(([^,]+),', s)
        if m2:
            return m2.group(1).strip()

        return None

    @classmethod
    def extract_boxed_answer_and_confidence(
        cls, text: Optional[str]
    ) -> Tuple[Optional[str], float]:
        r"""
        Extract the final answer and confidence score from a model output.

        Accepted formats, tried in this order:
            '\boxed{{David Hewlett, 0.92}}'   double braces (what the prompts request)
            '\boxed{2009, 1.0}'               single braces
            '(David Hewlett, 0.92)'           whole-string fallback when \boxed is missing

        The answer may not contain commas or braces. The confidence must be an
        unsigned decimal number (e.g. '1', '0.92', '.5').

        Args:
            text: Raw LLM output. None, NaN and pd.NA are treated as missing.

        Returns:
            ``(answer, confidence)``. ``answer`` is the stripped answer string
            and ``confidence`` is its score as a float. When nothing can be
            extracted the result is ``(None, math.nan)``. A bare None is never
            returned, so callers can always unpack two values.
        """
        if cls._is_missing(text):
            return None, math.nan

        s = str(text).strip()
        # The double-brace form is what the prompt templates ask for. The
        # single-brace form is a common model deviation that the docstring
        # has always promised to accept.
        m = BOXED_CONF_DOUBLE.search(s) or BOXED_CONF_SINGLE.search(s)
        if m is None:
            m = re.match(r'^\s*\(([^,]+)\s*,\s*([0-9]*\.?[0-9]+)\)\s*$', s)
        if m is None:
            return None, math.nan
        return m.group(1).strip(), float(m.group(2))

    @classmethod
    def norm(cls, x):
        """
        Normalise an answer or label for robust comparison.

        - missing values (None, NaN, pd.NA) -> None
        - remove backslashes
        - strip surrounding whitespace, quotes and brackets in any mix
          (so "[ 'Paris' ]" -> "Paris")
        - collapse internal whitespace to single spaces
        - lowercase
        """
        if cls._is_missing(x):
            return None
        s = str(x).replace('\\', '')
        # Whitespace is part of the strip set so wrappers like "[ 'x' ]" peel
        # off completely instead of leaving stray spaces or quotes behind.
        s = s.strip().strip(' \t\r\n\f\v\'"[]{}()')
        return re.sub(r'\s+', ' ', s).strip().lower()

    @classmethod
    def parse_label_list(cls, label_val):
        """
        Turn a Label value into a list[str] of normalised items.

        Accepts:
          - missing values (None / NaN / pd.NA) -> []
          - a list/tuple/set or NumPy array -> normalise each item
          - a string that is a Python list literal -> ast.literal_eval
          - other strings like '[a, b]' -> best-effort split on commas
          - any other value -> a single normalised label
        """
        if cls._is_missing(label_val):
            return []

        # Convert arrays before the string path. str() of an array looks like
        # "['a' 'b']", which literal_eval silently turns into ['ab'].
        if isinstance(label_val, np.ndarray):
            label_val = label_val.tolist()

        if isinstance(label_val, (list, tuple, set)):
            return [cls.norm(x) for x in label_val]

        s = str(label_val).strip()
        if s.startswith('[') and s.endswith(']'):
            try:
                val = ast.literal_eval(s)
            except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
                val = None  # not a valid literal; use the best-effort split below
            if isinstance(val, (list, tuple, set)):
                return [cls.norm(x) for x in val]

            # Best-effort split on commas inside [ ... ]
            inner = s[1:-1].strip()
            if not inner:
                return []
            return [cls.norm(p) for p in inner.split(',')]

        # Otherwise treat as a single label
        return [cls.norm(s)]

    @classmethod
    def build_fusion_contents(cls, q, ctxs, answer):
        """
        Build the user content for the fusion prompt.

        The result is one line per part:

            Question: <q>
            Context1: <ctx 1>
            ...
            Answer: <answer>

        ``ctxs`` may be None (no context lines), a NumPy array, a list/tuple
        of contexts, or a single context value.
        """
        if ctxs is None:
            ctxs = []
        elif isinstance(ctxs, np.ndarray):
            ctxs = ctxs.tolist()
        if not isinstance(ctxs, (list, tuple)):
            ctxs = [ctxs]

        lines = [f"Question: {q}"]
        lines.extend(f"Context{i}: {c}" for i, c in enumerate(ctxs, start=1))
        lines.append(f"Answer: {answer}")
        # Newline-separated so the parts don't run together in the prompt.
        return "\n".join(lines)

    @classmethod
    def compute_metrics(cls, df, answer_col="answerable", decision_col="fusion"):
        """
        Compute contingency counts and derived percentages.

        Values in both columns are stripped and upper-cased. Only rows with
        answerability 'A'/'U' and decision 'K'/'D' are counted.

        Returns a one-row DataFrame with the counts AK, AD, UK, UD and N, plus:
          - Risk %        = UK / (AK + UK)
          - Carefulness % = UD / (UK + UD)
          - Alignment %   = (AK + UD) / N
          - Coverage %    = (AK + UK) / N
        Each percentage is NaN when its denominator is zero.
        """
        a = df[answer_col].astype(str).str.strip().str.upper()
        d = df[decision_col].astype(str).str.strip().str.upper()
        AK = ((a == "A") & (d == "K")).sum()
        AD = ((a == "A") & (d == "D")).sum()
        UK = ((a == "U") & (d == "K")).sum()
        UD = ((a == "U") & (d == "D")).sum()
        N = AK + AD + UK + UD

        def pct(num, den):
            return 100.0 * (float(num) / den) if den > 0 else np.nan

        return pd.DataFrame([{
            "AK": AK, "AD": AD, "UK": UK, "UD": UD, "N": N,
            "Risk %": pct(UK, AK + UK),
            "Carefulness %": pct(UD, UK + UD),
            "Alignment %": pct(AK + UD, N),
            "Coverage %": pct(AK + UK, N),
        }])

    @classmethod
    def eval_accuracy(
        cls,
        df: pd.DataFrame,
        pred: str,
        g_t: str,
        out_col: str = "is_correct"
    ) -> Dict[str, float]:
        """
        Evaluate accuracy by checking whether each prediction is in its label list.

        Predictions and labels are compared after ``norm``. A row counts as
        incorrect when the prediction is missing or the label value is not
        list-like (list/tuple/set/NumPy array).

        Args:
            df (pd.DataFrame): DataFrame containing predictions and labels.
                It is modified in place: ``out_col`` is added or overwritten.
            pred (str): Column name for predictions.
            g_t (str): Column name for ground-truth labels (list-like).
            out_col (str): Column name to store the per-row correctness flags.

        Returns:
            dict: {"accuracy": float, "num_correct": int, "total": int}.
            accuracy is NaN for an empty DataFrame.
        """

        def _in_labels(p, labels) -> bool:
            if cls._is_missing(p):
                return False

            if isinstance(labels, np.ndarray):
                labels = labels.tolist()
            if not isinstance(labels, (list, tuple, set)):
                return False

            norm_labels = {cls.norm(x) for x in labels if x is not None}
            return cls.norm(p) in norm_labels

        # Column-wise zip instead of df.apply(axis=1), which misbehaves on an
        # empty DataFrame.
        df[out_col] = [_in_labels(p, labels) for p, labels in zip(df[pred], df[g_t])]

        total = len(df)
        num_correct = int(df[out_col].sum()) if total else 0
        accuracy = (num_correct / total) if total else float("nan")

        return {"accuracy": accuracy, "num_correct": num_correct, "total": total}
        

class PromptBuilder:
    """
    Prompt templates and message builders used throughout the KG‑RAG pipeline.

    Template dictionaries:
      - PROMPT: initial RAG and counterfactual prompts (answer + confidence).
      - PROMPT_FUSION: fusion-stage prompts (keep/discard decision + answer,
        and a probability prompt).
      - PROMPT_MULTIPLICATION: counterfactual prompts returning
        (baseline_confidence, answer, confidence_for_answer).
      - VERB_PROMPTS: two-step verbalised prompts ("gen" then "score").

    Note: the templates are plain (non-f) strings, so "{{" and "}}" appear
    literally in the text sent to the model.
    """
    PROMPT = {
        'rag': (
            "Use the provided contexts to answer the question. "
            "Always return a confidence score between 0.00 and 1.00 reflecting how confident you are that the final answer is correct. "
            "Output MUST be exactly one line in this format:\n"
            "\\boxed{{final answer, confidence score}}\n"
            "Do not include any other text. Examples:\n"
            "\\boxed{{David Hewlett, 0.92}}\n"
        ),
        'cf_use': (
            "Assume your previous answer is wrong due to improper use of the retrieved contexts. "
            "Carefully re-check the provided contexts and regenerate single answer using one or a few words. "
            "Always return a confidence score between 0.00 and 1.00 reflecting how confident you are that the final answer is correct. "
            "Output MUST be exactly one line in this format:\n"
            "\\boxed{{final answer, confidence score}}\n"
            "Do not include any other text. Examples:\n"
            "\\boxed{{David Hewlett, 0.95}}\n"
        ),
        'cf_quality': (
            "Assume your previous answer is wrong because the quality of the referred contexts is poor. "
            "Re-select the most relevant parts from the given contexts and regenerate single answer using one or a few words. "
            "Always return a confidence score between 0.00 and 1.00 reflecting how confident you are that the final answer is correct. "
            "Output MUST be exactly one line in this format:\n"
            "\\boxed{{final answer, confidence score}}\n"
            "Do not include any other text. Examples:\n"
            "\\boxed{{2009, 0.88}}\n"
        )
    }

    # The fusion prompts must yield "<K|D>, <final answer>" inside the box:
    # the fusion stage splits the boxed text on its first comma into the
    # decision and the answer, which is why commas are banned inside the answer.
    PROMPT_FUSION = {
        'fusion_use': (
            "Your answer is likely to be wrong because of the improper use of retrieval contexts. "
            "Decide keep/discard AND provide the best final answer. "
            "Do not include commas inside the final answer. If the answer has multiple items, "
            "separate them with a semicolon and a space.\n"
            "Output constraints:\n"
            "- No extra text before or after the box.\n"
            "- Exactly one box with exactly two fields: [K or D], [final answer].\n"
            "Return your decision and final answer in this exact format:\n"
            "\\boxed{{K or D, final answer}}"
        ),
        'fusion_qual': (
            "Your answer is likely to be wrong because of the poor quality of retrieval contexts. "
            "Decide keep/discard AND provide the best final answer. "
            "Do not include commas inside the final answer. If the answer has multiple items, "
            "separate them with a semicolon and a space.\n"
            "Output constraints:\n"
            "- No extra text before or after the box.\n"
            "- Exactly one box with exactly two fields: [K or D], [final answer].\n"
            "Return your decision and final answer in this exact format:\n"
            "\\boxed{{K or D, final answer}}"
        ),
        # fixed the verbalized probability extremely high issue.
        'fusion_prob': (
            "Provide the probability that your regenerated answer is correct as a number between 0.00 and 1.00. "
            "Respond with ONLY the number (no words or symbols). "
            "Return the answer in this exact format:\n"
            "\\boxed{{probability}}"
        ),
    }

    PROMPT_MULTIPLICATION = {
        'cf_use': (
            "Assume the usage of the retrieved path in reasoning may be flawed. "
            "Give the confidence score for the baseline answer, the best possible answer for the question, "
            "and the confidence score for that answer based on the provided question and contexts. "
            "Return your answer in the format: \\boxed{{baseline_confidence, answer, confidence_for_answer}}\n\n"
            "Do not include any other text. Examples:\n"
            "\\boxed{{0.99, David, 0.88}}\n"
        ),
        'cf_quality': (
            "Assume the previous baseline answer may be wrong because the quality of the retrieved contexts is poor. "
            "Give the confidence score for the baseline answer, the best possible answer for the question, "
            "and the confidence score for that answer based on the provided question and contexts. "
            "Return your answer in the format: \\boxed{{baseline_confidence, answer, confidence_for_answer}}\n\n"
            "Do not include any other text. Examples:\n"
            "\\boxed{{0.99, David, 0.88}}\n"
        )
    }

    VERB_PROMPTS = {
        # ---------- Baseline RAG ----------
        "rag": {
            "gen": (
                "Provide your best guess for the following question and contexts. "
                "Give ONLY the guess, no other words or explanation.\n\n"
                "Return ONLY the guess in this EXACT format:\n"
                "\\boxed{guess}\n"
            ),
            "score": (
                "Provide the probability that your guess is correct. "
                "Give ONLY the probability between 0.0 and 1.0, no other words or explanation.\n\n"
                "Return ONLY the probability in this EXACT format:\n"
                "\\boxed{probability}\n"
            ),
        },
        # ---------- Counterfactual: misuse of context ----------
        "cf_use": {
            "gen": (
                "Assume your previous guess may be wrong due to improper use of the retrieved contexts. "
                "Carefully re-check the contexts and regenerate your best guess. "
                "Answer in one or a few words; do not explain.\n\n"
                "Return ONLY the guess in this EXACT format:\n"
                "\\boxed{guess}\n"
            ),
            "score": (
                "Given the question, contexts, and your new guess, provide the probability that guess is correct. "
                "Return ONLY a number between 0.00 and 1.00; do not explain.\n\n"
                "Return ONLY the probability in this EXACT format:\n"
                "\\boxed{probability}\n"
            ),
        },
        # ---------- Counterfactual: low context quality ----------
        "cf_quality": {
            "gen": (
                "Assume your previous guess may be wrong because the quality of the contexts is poor. "
                "Select the most relevant parts mentally and regenerate your best guess. "
                "Answer in one or a few words; do not explain.\n\n"
                "Return ONLY the guess in this EXACT format:\n"
                "\\boxed{guess}\n"
            ),
            "score": (
                "Given the question, contexts, and your new guess, provide the probability that guess is correct. "
                "Return ONLY a number between 0.00 and 1.00; do not explain.\n\n"
                "Return ONLY the probability in this EXACT format:\n"
                "\\boxed{probability}\n"
            ),
        }
    }

    @classmethod
    def build_user_multi_contents(cls, q, ctxs):
        """
        Build user content of the form "Question: <q>" followed by one
        "\\nContext<i>: <ctx>" line per context.

        ``ctxs`` may be None (no context lines), a NumPy array, a list/tuple
        of contexts, or a single context value.
        """
        if ctxs is None:
            ctxs = []
        elif isinstance(ctxs, np.ndarray):
            # Same handling as the fusion content builder; otherwise the whole
            # array would be rendered as a single "Context1".
            ctxs = ctxs.tolist()
        if not isinstance(ctxs, (list, tuple)):
            ctxs = [ctxs]
        lines = [f"Question: {q}"] + [f"\nContext{i+1}: {c}" for i, c in enumerate(ctxs)]
        return "".join(lines)

    @classmethod
    def build_prob(cls, q, ctxs, answer, system_msg):
        """
        Build [system, user] chat messages that ask for the probability that
        ``answer`` is correct, boxed as ``\\boxed{probability}``.

        ``ctxs`` is interpolated with str() as-is.
        """
        user_msg = (
            f"QUESTION:\n{q}\n\n"
            f"CONTEXT:\n{ctxs}\n\n"
            "YOUR GUESS:\n"
            f"{answer}\n\n"
            "Return ONLY the probability less than 1.00 as: \\boxed{probability}"
        )
        return [
            {"role": "system", "content": system_msg},
            {"role": "user", "content": user_msg},
        ]

    @classmethod
    def alter_answer(cls, q, ctxs, answer, system_msg):
        """
        Build [system, user] chat messages that show the previous guess and
        ask for a new guess in ``\\boxed{{guess}}`` form.

        The braces are literally doubled, because this is a plain string.
        """
        user_msg = (
            f"QUESTION:\n{q}\n\n"
            f"CONTEXT:\n{ctxs}\n\n"
            "YOUR GUESS:\n"
            f"{answer}\n\n"
            "Return ONLY the guess in this EXACT format:\n"
            "\\boxed{{guess}}\n"
            "Example:\n"
            "\\boxed{{Italian Language}}"
        )
        return [
            {"role": "system", "content": system_msg},
            {"role": "user", "content": user_msg},
        ]

    @classmethod
    def causal_f(
        cls,
        candidates: str,          # e.g., "a1,a2,a3,a4"
        q: str,
        ctxs: str,
        system_msg: str,
        k: int = 3
    ) -> List[Dict[str, str]]:
        r"""
        Build chat messages for Stage 2 (probability scoring).

        ``candidates`` is a single comma-separated string like "a1,a2,a3,a4".
        It is only whitespace-trimmed and embedded verbatim. The example in
        the prompt shows an "answer:probability" mapping inside \boxed.

        NOTE(review): ``k`` is currently unused and kept only for signature
        compatibility. The example output is a plain string, so the model
        literally sees three braces on each side ("\boxed{{{...}}}"). Confirm
        that this is what the downstream parser expects.
        """
        # Minimal safeguard: trim whitespace
        cand_str = (candidates or "").strip()

        user_msg = (
            f"QUESTION:\n{q}\n\n"
            f"CONTEXT:\n{ctxs}\n\n"
            "YOUR GUESSES (comma-separated, in order):\n"
            f"{cand_str}\n\n"
            "Example (inputs come from other prompts):\n"
            "G1: Masti\n"
            "G2: Masti Returns\n"
            "G3: masti\n"
            "Valid output:\n"
            "\\boxed{{{Masti:0.78, Masti Returns:0.22}}}\n"
        )

        return [
            {"role": "system", "content": system_msg},
            {"role": "user", "content": user_msg},
        ]


class Inference:
    """
    Asynchronous batched inference on top of Doraemon's ``async_inference``.

    The batch_size parameter controls how many records are sent per call.
    """
    @classmethod
    async def process_batches(cls, tasks, logger, column_name, batch_size=20,
                              sleep_seconds=5, temperature=0.0):
        """
        Send ``task[column_name]`` for every task to ``Doraemon.async_inference``
        in batches and return the concatenated results.

        Args:
            tasks: Sequence of mappings, e.g. DataFrame records.
            logger: Passed through to ``Doraemon.async_inference``.
            column_name: Key holding the chat messages of each task.
            batch_size: Number of tasks per ``async_inference`` call. Must be > 0.
            sleep_seconds: Pause between consecutive batches, presumably for
                rate limiting. No pause follows the last batch; 0 disables it.
            temperature: Temperature applied to every prompt (default 0.0).

        Returns:
            List with each batch's results appended in batch order. This
            presumably aligns with ``tasks`` if ``async_inference`` preserves
            prompt order — TODO confirm.

        Raises:
            ValueError: If ``batch_size`` is not positive.
        """
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        results = []
        for start in range(0, len(tasks), batch_size):
            # Throttle *between* batches only; the old trailing sleep after
            # the final batch just delayed the caller.
            if start > 0 and sleep_seconds > 0:
                await asyncio.sleep(sleep_seconds)
            batch = tasks[start:start + batch_size]
            batch_messages = [task[column_name] for task in batch]
            r = await Doraemon.async_inference(
                logger=logger,
                prompts=batch_messages,
                temperatures=[temperature] * len(batch),
            )
            results.extend(r)
        return results


class Judgment:
    """
    Compute the judgment flags and intermediate answers for each row.

    Works on a copy of the input DataFrame and adds the extracted
    answers/confidences plus the columns answerable, cf_use_f, cf_quality_f,
    fusion and final_a.
    """

    # (source column, extracted-answer column, confidence column)
    _EXTRACTIONS = (
        ("init_a", "init_ans", "init_cs"),
        ("cf_use_a", "cf_use_ans", "cf_use_cs"),
        ("cf_qual_a", "cf_qual_ans", "cf_qual_cs"),
    )

    @classmethod
    def _normalised_labels(cls, label_val):
        """Return the set of normalised labels for one row's ``Label`` value."""
        # Convert arrays first. A stringified array ("['a' 'b']") would be
        # literal-evaluated into a single concatenated label.
        if isinstance(label_val, np.ndarray):
            label_val = label_val.tolist()
        return set(KG_RAG_Tool.parse_label_list(label_val))

    @classmethod
    def _select_final_answer(cls, fusion, candidates):
        """
        Pick the final answer for one row.

        Args:
            fusion: The row's fusion flag ('K', 'D' or 'F').
            candidates: [(init_cs, init_ans), (cf_use_cs, cf_use_ans),
                (cf_qual_cs, cf_qual_ans)], in that order.

        Returns:
            For 'K', the initial answer. For 'D', the answer with the highest
            non-missing confidence, or pd.NA if every score is missing. For
            anything else ('F'), pd.NA, leaving the row to be resolved later.
        """
        if fusion == 'K':
            return candidates[0][1]
        if fusion != 'D':
            return pd.NA

        scored = [(cs, ans) for cs, ans in candidates if pd.notna(cs)]
        if not scored:
            return pd.NA
        # max() returns the first maximum it meets. Scanning in reverse makes
        # the *later* candidate win ties, as the previous dict-keyed-by-score
        # implementation did.
        return max(reversed(scored), key=lambda pair: pair[0])[1]

    @classmethod
    def apply_judgments(cls, df: pd.DataFrame) -> pd.DataFrame:
        """
        Add answer extractions and judgment flags to a copy of ``df``.

        Reads the columns init_a, cf_use_a, cf_qual_a and Label.

        - Extract (answer, confidence) pairs from init_a / cf_use_a / cf_qual_a
          into init_ans/init_cs, cf_use_ans/cf_use_cs, cf_qual_ans/cf_qual_cs.
          The stored answers are the raw extracted strings.
        - Comparisons use KG_RAG_Tool.norm on the answers and
          KG_RAG_Tool.parse_label_list on Label.
        - answerable:   'A' if init_ans matches any label, else 'U'
        - cf_use_f:     'K' if cf_use_ans matches init_ans, else 'D'
        - cf_quality_f: 'K' if cf_qual_ans matches init_ans, else 'D'
          A missing answer on either side counts as 'D'.
        - fusion:  'K' if both flags are 'K', 'D' if both are 'D', else 'F'
        - final_a: see ``_select_final_answer``.

        Returns:
            A new DataFrame; the input is not modified.
        """
        df = df.copy()

        # Column-wise comprehensions (rather than per-row apply) also behave
        # on an empty DataFrame.
        for src_col, ans_col, cs_col in cls._EXTRACTIONS:
            pairs = [KG_RAG_Tool.extract_boxed_answer_and_confidence(x) for x in df[src_col]]
            df[ans_col] = [ans for ans, _ in pairs]
            df[cs_col] = [cs for _, cs in pairs]

        norm = KG_RAG_Tool.norm
        init_n = [norm(a) for a in df["init_ans"]]
        use_n = [norm(a) for a in df["cf_use_ans"]]
        qual_n = [norm(a) for a in df["cf_qual_ans"]]
        label_sets = [cls._normalised_labels(v) for v in df["Label"]]

        def agree(a, b):
            return 'K' if a is not None and b is not None and a == b else 'D'

        df['answerable'] = [
            'A' if ia is not None and ia in labels else 'U'
            for ia, labels in zip(init_n, label_sets)
        ]
        df['cf_use_f'] = [agree(ia, ua) for ia, ua in zip(init_n, use_n)]
        df['cf_quality_f'] = [agree(ia, qa) for ia, qa in zip(init_n, qual_n)]

        # Both flags are 'K'/'D', so equal flags propagate and a split is 'F'.
        df['fusion'] = [
            use_f if use_f == qual_f else 'F'
            for use_f, qual_f in zip(df['cf_use_f'], df['cf_quality_f'])
        ]

        df['final_a'] = [
            cls._select_final_answer(fusion, [(ic, ia), (uc, ua), (qc, qa)])
            for fusion, ia, ic, ua, uc, qa, qc in zip(
                df['fusion'],
                df['init_ans'], df['init_cs'],
                df['cf_use_ans'], df['cf_use_cs'],
                df['cf_qual_ans'], df['cf_qual_cs'],
            )
        ]
        return df


class Fusion:
    """
    Resolve indeterminate ('F') rows and compute confidence scores.

    Methods build the prompts, call the inference engine and write the
    resulting fusion decisions, final answers and probabilities back into
    the DataFrame (in place).
    """

    @classmethod
    def _result_text(cls, result):
        """
        Return the text part of one ``Doraemon.async_inference`` result.

        Earlier code indexed results with ``[0]``, which suggests each result
        is a ``(text, ...)`` sequence — TODO confirm against Doraemon. Other
        values (e.g. a plain string or None) are passed through unchanged
        rather than being indexed.
        """
        if isinstance(result, (list, tuple)):
            return result[0] if result else None
        return result

    @classmethod
    def _fusion_messages(cls, row):
        """Build the fusion chat messages for one 'F' row (a record dict)."""
        # In an 'F' row exactly one counterfactual flag is 'D'. Question the
        # counterfactual answer that disagreed, using the prompt for that
        # failure mode.
        use_disagreed = row["cf_use_f"] == "D"
        system_prompt = PromptBuilder.PROMPT_FUSION['fusion_use' if use_disagreed else 'fusion_qual']
        answer = row["cf_use_ans"] if use_disagreed else row["cf_qual_ans"]
        return [
            {"role": "system", "content": system_prompt},
            {
                "role": "user",
                "content": KG_RAG_Tool.build_fusion_contents(
                    q=row["question"], ctxs=row["ctx_topk"], answer=answer
                ),
            },
        ]

    @classmethod
    def _prob_messages(cls, row):
        """Build the probability-scoring chat messages for one row (a record dict)."""
        return [
            {"role": "system", "content": PromptBuilder.PROMPT_FUSION['fusion_prob']},
            {
                "role": "user",
                "content": KG_RAG_Tool.build_fusion_contents(
                    q=row["question"], ctxs=row["ctx_topk"], answer=row["final_a"]
                ),
            },
        ]

    @classmethod
    async def resolve_f_cases(cls, df: pd.DataFrame, logger, batch_size=20) -> pd.DataFrame:
        """
        Ask the model to settle every row whose ``fusion`` is 'F'.

        The boxed reply is split on its first comma. The part before it,
        upper-cased, becomes ``fusion``; the rest becomes ``final_a`` (pd.NA
        if there is no comma). If nothing can be extracted, ``fusion`` is set
        to None and ``final_a`` to pd.NA.

        ``df`` is updated in place and also returned. Rows that are not 'F'
        are left untouched.
        """
        f_mask = df['fusion'] == 'F'
        if not f_mask.any():
            return df

        prompts = [cls._fusion_messages(row) for row in df[f_mask].to_dict(orient='records')]
        tasks = [{'fusion_prompt': p} for p in prompts]
        final_f = await Inference.process_batches(tasks, logger, 'fusion_prompt', batch_size=batch_size)

        decisions, finals = [], []
        for result in final_f:
            ans = KG_RAG_Tool.extract_boxed_answer(cls._result_text(result))
            if not ans:
                decisions.append(None)
                finals.append(pd.NA)
                continue
            head, sep, tail = ans.partition(",")
            decisions.append(head.strip().upper())
            finals.append(tail.strip() if sep else pd.NA)

        # A boolean mask keeps this correct even if df's index has duplicates.
        df.loc[f_mask, "fusion"] = decisions
        df.loc[f_mask, "final_a"] = finals
        return df

    @classmethod
    async def compute_probabilities(cls, df: pd.DataFrame, logger, batch_size=20) -> pd.DataFrame:
        """
        Ask the model for the probability that each row's ``final_a`` is correct.

        Adds two columns in place and returns ``df``:
          - ``prob_query``: the chat messages sent for that row;
          - ``fusion_prob``: the extracted boxed value, kept as a string
            (None if nothing could be extracted).
        """
        if len(df) == 0:
            df['prob_query'] = pd.Series(dtype="object")
            df['fusion_prob'] = pd.Series(dtype="object")
            return df

        prompts = [cls._prob_messages(row) for row in df.to_dict(orient='records')]
        # Build Series on df.index. A default RangeIndex would misalign the
        # values whenever df's index is not 0..n-1 (e.g. after filtering).
        df['prob_query'] = pd.Series(prompts, index=df.index, dtype="object")

        tasks = [{'prob_query': p} for p in prompts]
        fusion_prob = await Inference.process_batches(tasks, logger, 'prob_query', batch_size=batch_size)
        df['fusion_prob'] = pd.Series(
            [KG_RAG_Tool.extract_boxed_answer(cls._result_text(r)) for r in fusion_prob],
            index=df.index,
            dtype="object",
        )
        return df

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.4/44.4 kB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m124.9/124.9 kB[0m [31m5.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m210.7/210.7 kB[0m [31m9.8 MB/s[0m eta [36m0:00:00[0m
[?25h