In [None]:
from typing import List, Dict

import json
import numpy as np
import pandas as pd
from doraemon import Doraemon
from kg_rag import Judgment, Fusion, KG_RAG_Tool, PromptBuilder, Inference


class Local:
    """Prompt templates and message builders for the counterfactual RAG pipeline.

    ``PROMPT`` maps a stage name to its system prompt:

    * ``'rag'``        - baseline answer from the retrieved contexts.
    * ``'cf_use'``     - regenerate, assuming the contexts were misused.
    * ``'cf_quality'`` - regenerate, assuming the contexts were of poor quality.
    * ``'score'``      - merge the three guesses and emit per-scenario
      probabilities as one line of strict JSON.
    """

    # NOTE(review): the answer prompts contain doubled braces ("{{...}}"). They
    # are sent literally unless a caller applies str.format to them - confirm
    # against PromptBuilder before changing.
    PROMPT = {
        'rag': (
            "Use the provided contexts to answer the question. "
            "If the contexts are incomplete or weak, still provide your best possible answer. "
            "Output MUST be exactly one line in this format:\n"
            "\\boxed{{final answer}}\n"
            "Do not include any other text. Examples:\n"
            "\\boxed{{Italian}}\n"
        ),
        'cf_use': (
            "Assume your previous answer is wrong due to improper use of the retrieved contexts. "
            "Carefully re-check the provided contexts and regenerate the answer using one or a few words. "
            "Output MUST be exactly one line in this format:\n"
            "\\boxed{{final answer}}\n"
            "Do not include any other text. Examples:\n"
            "\\boxed{{Italian Languages}}\n"
        ),
        'cf_quality': (
            "Assume your previous answer is wrong because the quality of the referred contexts is poor. "
            "Re-select the most relevant parts from the given contexts and regenerate the answer using one or a few words. "
            "Output MUST be exactly one line in this format:\n"
            "\\boxed{{final answer}}\n"
            "Do not include any other text. Examples:\n"
            "\\boxed{{Italian Languages}}\n"
        ),
        "score": (
            "You are given 3 candidate guesses for the same question from different prompts:\n"
            " - baseline: G1\n"
            " - cf_use:   G2\n"
            " - cf_qual:  G3\n"
            "\n"
            "Task:\n"
            "1) Build a unified set of UNIQUE answers by merging semantically identical strings across G1, G2, G3.\n"
            "2) Compute a GLOBAL frequency vector over the unique answers based on how many of G1/G2/G3 map to each canonical answer.\n"
            "3) ASSIGN probabilities for EACH scenario (baseline, cf_use, cf_qual) to be EXACTLY this global frequency vector normalized by 3 (counts/3), after duplicate aggregation and before rounding.\n"
            "\n"
            "Merging rules:\n"
            "- Ignore case, whitespace, punctuation, and plural/singular differences.\n"
            "- Choose a clean canonical form (e.g., title case).\n"
            "- If multiple inputs (G1/G2/G3) map to the SAME canonical answer, AGGREGATE by SUMMING that answer's frequency before normalization.\n"
            "\n"
            "Probability rules (HARD CONSTRAINTS):\n"
            "A) Let count[i] be how many of {G1,G2,G3} map to answers[i]. Then for every scenario S in {baseline, cf_use, cf_qual}, set S[i] = round(count[i]/3, 2).\n"
            "B) Do NOT output equal probabilities across answers when counts differ.\n"
            "C) Each probability ∈ [0.00, 1.00]; sums may be < 1.00 after rounding.\n"
            "D) If a scenario has no plausible answers, use zeros.\n"
            "\n"
            "Return EXACTLY one line of STRICT JSON, no wrapper, no extra text:\n"
            "{"
            "  \"answers\": [\"answer1\",\"answer2\",...],"
            "  \"baseline\": [p1,p2,...],"
            "  \"cf_use\": [p1,p2,...],"
            "  \"cf_qual\": [p1,p2,...]"
            "}\n"
            "Example (inputs):\n"
            "G1(baseline): Masti\n"
            "G2(cf_use):   Masti Returns\n"
            "G3(cf_qual):  Masti\n"
            "\n"
            # The worked example must obey rule A (round(2/3, 2) = 0.67,
            # round(1/3, 2) = 0.33) and must itself be parseable JSON, because
            # the model's reply is fed to json.loads downstream.
            "Valid output (counts: Masti=2, Masti Returns=1 → probs=[0.67, 0.33] for ALL scenarios):\n"
            "{"
            "  \"answers\": [\"Masti\",\"Masti Returns\"],"
            "  \"baseline\": [0.67, 0.33],"
            "  \"cf_use\": [0.67, 0.33],"
            "  \"cf_qual\": [0.67, 0.33]"
            "}"
        )
    }

    @classmethod
    def causal_f(
        cls,
        candidates: str,          # e.g., "a1,a2,a3,a4"
        q: str,
        ctxs: str,
        system_msg: str,
        k: int = 3
    ) -> List[Dict[str, str]]:
        """Build the chat messages for the probability-scoring stage.

        Args:
            candidates: The guesses as one comma-separated string, e.g.
                "a1,a2,a3,a4". ``None`` is treated as an empty string and
                surrounding whitespace is stripped.
            q: The question text.
            ctxs: The retrieved contexts; interpolated into the user message
                as-is.
            system_msg: System prompt that defines the expected output format
                (the notebook passes ``PROMPT['score']``).
            k: Currently unused; kept so existing callers remain valid.

        Returns:
            A two-element list: the system message followed by the user
            message, each a ``{"role": ..., "content": ...}`` dict.
        """
        # Minimal safeguard: tolerate None and trim whitespace
        cand_str = (candidates or "").strip()

        user_msg = (
            f"QUESTION:\n{q}\n\n"
            f"CONTEXT:\n{ctxs}\n\n"
            "YOUR GUESSES (comma-separated, in order):\n"
            f"{cand_str}\n\n"
            "Example (inputs come from other prompts):\n"
            "G1: Masti\n"
            "G2: Masti Returns\n"
            "G3: masti\n"
        )

        return [
            {"role": "system", "content": system_msg},
            {"role": "user", "content": user_msg},
        ]

def classify(x: str, n: int = 1200, dataset: str = "metaqa") -> pd.DataFrame:
    """Load the pickled question set for a hop setting and return its first rows.

    Args:
        x: Hop setting, "one" or "three" (case-insensitive, surrounding
            whitespace ignored; ``None`` is treated as "").
        n: Maximum number of rows to return.
        dataset: "metaqa" or "webqsp".

    Returns:
        The first ``n`` rows of the loaded DataFrame. For "webqsp" the columns
        "ground_truth" and "contexts" are renamed to "Label" and "ctx_topk".

    Raises:
        ValueError: If ``dataset`` is unknown, or if ``x`` is not a supported
            hop setting for "webqsp".
    """
    BASE_PATHS = {
        "metaqa": "/kaggle/input/filtered-multiple-hops-metaqa",
        "webqsp": "/kaggle/input/webqsp"
    }
    # File name per (dataset, hop setting).
    HOP_FILES = {
        "metaqa": {
            "one": "one_hop_supported.pickle",
            "three": "three_hop_supported.pickle",
        },
        "webqsp": {
            "one": "webqsp_ctxstyle_1200_hop1_nl.pkl",
            "three": "webqsp_ctxstyle_1200_hop3_nl.pkl",
        },
    }

    # Validate dataset input
    if dataset not in BASE_PATHS:
        raise ValueError(f"Invalid dataset '{dataset}'. Must be one of: {list(BASE_PATHS.keys())}")

    base = BASE_PATHS[dataset]
    x = (x or "").strip().lower()
    files = HOP_FILES[dataset]

    if dataset == "metaqa":
        # Existing behaviour: anything other than "one" falls back to the
        # three-hop file, so callers relying on that default keep working.
        filename = files.get(x, files["three"])
    elif x in files:
        filename = files[x]
    else:
        # Report the value actually requested instead of a fixed message
        # about a two-hop file.
        raise ValueError(
            f"Unsupported hop setting '{x}' for dataset '{dataset}'. "
            f"Must be one of: {list(files)}"
        )

    path = f"{base}/{filename}"

    # NOTE: unpickling executes arbitrary code; only load files from a trusted source.
    df = pd.read_pickle(path)

    if dataset == "webqsp":
        # Align the webqsp schema with the column names used by the pipeline.
        df = df.rename(columns={
            "ground_truth": "Label",
            "contexts": "ctx_topk"
        })

    return df.head(n)

In [None]:
# Notebook cell: generate the baseline answer and the two counterfactual
# answers for every question. Uses top-level `await`, so it must run in an
# environment that supports it (e.g. a Jupyter/IPython kernel).
dataset="webqsp"

# Load up to 1200 three-hop questions for the chosen dataset.
df = classify("three", n = 1200, dataset=dataset)

# Stage 1 prompt: system = baseline RAG instructions, user = question + contexts.
df['query'] = df.apply(lambda row: [
    {"role": "system", "content": Local.PROMPT['rag']},
    {"role": "user", "content": PromptBuilder.build_user_multi_contents(row['question'], row['ctx_topk'])}
], axis=1)

logger = Doraemon.get_logger(logfile='rkag.log')
tasks = df.to_dict(orient='records')

# Run the 'query' prompts; presumably returns one result per task in task order - TODO confirm.
init_a = await Inference.process_batches(tasks, logger, 'query')
# NOTE(review): pd.Series(list) gets a default 0..n-1 index and column
# assignment aligns on index labels; this assumes df carries the same
# 0..n-1 index - confirm against the pickles (applies to every stage below).
df['init_a'] = pd.Series(init_a, dtype='object')
df['init_a'] = df['init_a'].apply(KG_RAG_Tool.extract_boxed_answer)

# Stage 2a prompt: ask the model to redo the answer assuming the contexts were misused.
df['cf_use'] = df.apply(lambda row: PromptBuilder.alter_answer(row['question'], row['ctx_topk'], row['init_a'], Local.PROMPT['cf_use']), axis=1)

# Rebuild the task records so they include the new prompt column.
tasks = df.to_dict(orient='records')
cf_use_a = await Inference.process_batches(tasks, logger, 'cf_use')
df['cf_use_a'] = pd.Series(cf_use_a, dtype='object')
df['cf_use_a'] = df['cf_use_a'].apply(KG_RAG_Tool.extract_boxed_answer)

# Stage 2b prompt: same, but assuming the contexts were of poor quality.
df['cf_qual'] = df.apply(lambda row: PromptBuilder.alter_answer(row['question'], row['ctx_topk'], row['init_a'], Local.PROMPT['cf_quality']), axis=1)

tasks = df.to_dict(orient='records')
cf_qual_a = await Inference.process_batches(tasks, logger, 'cf_qual')
df['cf_qual_a'] = pd.Series(cf_qual_a, dtype='object')
df['cf_qual_a'] = df['cf_qual_a'].apply(KG_RAG_Tool.extract_boxed_answer)

In [None]:
def get_final_answer_prob(json_str):
    """Pick the most stable answer from a scoring-stage JSON reply.

    ``json_str`` must decode to an object with an "answers" list and three
    parallel probability lists: "baseline", "cf_use" and "cf_qual". For each
    answer the score is ``mean * (1 - spread)``, where ``mean`` is the average
    of its three probabilities and ``spread`` is their max minus min.

    Returns:
        ``(answer, score)`` for the highest-scoring answer (first one on
        ties), with the score rounded to 4 decimals.
    """
    payload = json.loads(json_str)
    candidates = payload["answers"]
    scenarios = (payload["baseline"], payload["cf_use"], payload["cf_qual"])

    def stable_score(pos):
        # Probabilities given to this answer under the three scenarios.
        probs = [scenario[pos] for scenario in scenarios]
        spread = max(probs) - min(probs)
        mean_prob = sum(probs) / 3.0   # always three scenarios
        return mean_prob * (1 - spread)

    scores = {answer: stable_score(pos) for pos, answer in enumerate(candidates)}

    # Highest stability-weighted probability wins.
    best_answer, best_score = max(scores.items(), key=lambda item: item[1])
    return best_answer, round(best_score, 4)


# Scoring-stage prompt: pass the three extracted answers as
# "G1:<baseline>,G2:<cf_use>,G3:<cf_qual>" together with the question and
# contexts, using the JSON-scoring system prompt.
df['causal_f'] = df.apply(lambda row: Local.causal_f(
    candidates=f"G1:{row['init_a']},G2:{row['cf_use_a']},G3:{row['cf_qual_a']}",
    q=row['question'],
    ctxs=row['ctx_topk'],
    system_msg=Local.PROMPT['score']
), axis=1)

# Top-level `await`: requires a notebook-style environment.
tasks = df.to_dict(orient='records')
causal_f_a = await Inference.process_batches(tasks, logger, 'causal_f')
# Raw model output is stored unparsed; it is decoded later via get_final_answer_prob.
# NOTE(review): assignment aligns on index labels; assumes df has a default 0..n-1 index - confirm.
df['causal_f_a'] = pd.Series(causal_f_a, dtype='object')


# Row labels whose scoring output could not be turned into (answer, score).
bad_rows = []

def safe_get_final_answer_prob(idx, x):
    """Best-effort wrapper around ``get_final_answer_prob``.

    Scores the first element of ``x``. On any failure the row label ``idx``
    is appended to the module-level ``bad_rows`` list and a
    ``[None, None]`` Series is returned instead of raising.
    """
    try:
        answer_and_score = get_final_answer_prob(x[0])
        outcome = pd.Series(answer_and_score)
    except Exception:
        # Remember the original row label so failing rows can be inspected later.
        bad_rows.append(idx)
        outcome = pd.Series([None, None])
    return outcome

# Apply with row context so we can capture the index
# Parse every scoring reply into (final_a, final_P). Applied row-wise so the
# row label (row.name) can be recorded in bad_rows when parsing fails; failed
# rows get None in both columns.
df[["final_a", "final_P"]] = df.apply(
    lambda row: safe_get_final_answer_prob(row.name, row["causal_f_a"]),
    axis=1
)

# Rows that failed parsing, with their original label kept as "orig_index".
# NOTE(review): the rename only takes effect if reset_index() names the new column "index",
# i.e. df's index is unnamed - confirm.
df_bad = df.loc[bad_rows].reset_index().rename(columns={"index": "orig_index"})
# Bare expression: displays (n_bad_rows, n_columns) as the notebook cell output.
df_bad.shape

In [None]:
from sklearn.metrics import brier_score_loss
from torchmetrics.classification import BinaryCalibrationError
from calibration_metrics import CalibrationMetrics


def selective_auc_trapz(df, prob_col="final_P", correct_col="is_correct"):
    """Area under the accuracy-vs-coverage curve, by the trapezoidal rule.

    Rows are ranked by confidence (descending). For each prefix of the
    ranking, coverage is the fraction of rows kept and accuracy is the mean
    correctness of those rows.

    Args:
        df: DataFrame holding the confidence and correctness columns.
        prob_col: Confidence column; coerced to numeric and clipped to [0, 1].
        correct_col: Correctness column; cast to int.

    Returns:
        ``(auc, coverage, accuracy_curve)``; ``(nan, [], [])`` for an empty
        frame. The curve starts at coverage 1/n, not 0.
    """
    y = df[correct_col].astype(int).to_numpy()
    p = pd.to_numeric(df[prob_col], errors="coerce").clip(0, 1).to_numpy()
    if len(y) == 0:
        return np.nan, np.array([]), np.array([])
    # Stable sort by descending confidence: ties keep row order, so the
    # result depends on that order.
    order = np.argsort(-p, kind="mergesort")
    y_sorted = y[order]
    n = len(y_sorted)
    kept = np.arange(1, n + 1)
    coverage = kept / n
    accuracy_curve = np.cumsum(y_sorted) / kept
    # np.trapz is deprecated in NumPy 2.x in favour of np.trapezoid; prefer
    # the new name and fall back on older NumPy.
    integrate = getattr(np, "trapezoid", None)
    if integrate is None:
        integrate = np.trapz
    auc = integrate(accuracy_curve, coverage)
    return float(auc), coverage, accuracy_curve

def ece_l1(df, prob_col="final_P", correct_col="is_correct", n_bins=10):
    """Expected calibration error (L1) over equal-width confidence bins.

    Args:
        df: DataFrame holding the confidence and correctness columns.
        prob_col: Confidence column; coerced to numeric and clipped to [0, 1].
        correct_col: Correctness column; cast to int.
        n_bins: Number of equal-width bins on [0, 1].

    Returns:
        Sum over non-empty bins of ``bin_fraction * |accuracy - mean_confidence|``,
        or ``nan`` when no row has a usable confidence.

    Rows whose confidence is missing or non-numeric are excluded, as in
    ``brier``. Otherwise a single NaN would land in the last bin and turn
    the whole metric into NaN.
    """
    probs = pd.to_numeric(df[prob_col], errors="coerce").clip(0, 1)
    usable = probs.notna()
    p = probs[usable].to_numpy(dtype=float)
    y = df[correct_col][usable].astype(int).to_numpy()
    if p.size == 0:
        return np.nan
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    # right=True puts x in bin i when edges[i-1] < x <= edges[i]; clipping
    # folds 0.0 (index 0) into the first bin and guards the upper end.
    bin_ids = np.clip(np.digitize(p, edges, right=True), 1, n_bins)
    ece = 0.0
    for b in range(1, n_bins + 1):
        in_bin = bin_ids == b
        if not in_bin.any():
            continue
        conf = p[in_bin].mean()
        acc = y[in_bin].mean()
        # in_bin.mean() is the fraction of (usable) rows in this bin.
        ece += in_bin.mean() * abs(acc - conf)
    return float(ece)

def brier(df, prob_col="final_P", correct_col="is_correct"):
    """Brier score of the confidences against the correctness labels.

    Only rows where both the confidence (coerced to numeric) and the label
    are present are scored; confidences are clipped to [0, 1]. Returns
    ``nan`` when no such row exists.
    """
    confidences = pd.to_numeric(df[prob_col], errors="coerce")
    labels = df[correct_col].astype("Int64")   # nullable ints keep missing labels as NA
    usable = confidences.notna() & labels.notna()
    if usable.any():
        y_true = labels[usable].astype(int).to_numpy()
        y_prob = confidences[usable].clip(0, 1).to_numpy()
        return float(brier_score_loss(y_true, y_prob))
    return np.nan


# Evaluation cell: accuracy, selective AUC, Brier score and L2 calibration error.
# NOTE(review): every call below reads a 'fusion_prob' column, and the metrics
# read 'is_correct'. Neither is created in the visible cells, which produce
# 'final_a' / 'final_P' instead. Presumably they come from the loaded pickle
# or from a helper such as eval_accuracy - confirm, otherwise these calls
# fail with KeyError. Also confirm that `pred` should be a probability column
# rather than the predicted-answer column.
print(KG_RAG_Tool.eval_accuracy(df, pred='fusion_prob', g_t='Label'))
auc, cov, acc = selective_auc_trapz(df, prob_col="fusion_prob", correct_col="is_correct")
# L1 ECE is currently disabled (note it targets 'final_P', unlike the active calls):
# ece = ece_l1(df, prob_col="final_P", correct_col="is_correct", n_bins=10)
bs = brier(df, prob_col="fusion_prob", correct_col="is_correct")

# L2-norm calibration error over 10 bins, computed by the project helper.
ece_l2 = CalibrationMetrics.ece_torchmetrics_binary(
    df, prob_col="fusion_prob", correct_col="is_correct", n_bins=10, norm="l2"
)

print("ECE (L2)",ece_l2)
print("Brier score:", bs)
print("Selective AUC (trapz):", auc)

In [None]:
# Persist the full results frame (prompts, raw outputs, final answers) to the working directory.
df.to_pickle("rkag_m.pkl")