In [None]:
import os
import glob
import json
import time
import csv
import requests
from tqdm import tqdm
from typing import Dict, Any, List, Optional, Tuple


# Chat-completions URL of the serving endpoint.  The committed value is only a
# placeholder; set the EVAL_ENDPOINT environment variable to point the run at
# a real server without editing this file.
ENDPOINT = os.environ.get("EVAL_ENDPOINT", "..")
MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.3"
# Prefix used in every output file name and in the progress-bar label.
TAG = "MISTRAL_7B_V03"


# Input category files (each must be a JSON array of items).
IN_DIR = "../Dataset-commonSence/GN_RD_ge4_only_json_balanced"
IN_PATTERN = "b*_BALANCED.json"

# The output directory is created at import time so every later write can
# assume it exists.
OUT_DIR = "./mistral_7B_eval_outputs"
os.makedirs(OUT_DIR, exist_ok=True)
OUT_SUMMARY_ALL = os.path.join(OUT_DIR, f"{TAG}_metrics_summary_ALL_categories.csv")

# =========================
# DATA KEYS
# =========================
# Keys into each dataset item.  item[QA_KEY] holds two sub-dicts: the original
# MCQ (item[QA_KEY][MCQ_KEY]) and its evolved/rewritten variant
# (item[QA_KEY][EVOL_KEY]).
QA_KEY = "qa_fa"
MCQ_KEY = "mcq"
EVOL_KEY = "mcq_evol"

# Gold labels: the reference rationale is read from the original MCQ; the gold
# answer letter is read from the evolved MCQ, falling back to the original
# MCQ's answer when the evolved one has none.
GOLD_RATIONALE_KEY = "rationale"
GOLD_ANSWER_KEY    = "answer"
EVOL_CORRECT_KEY   = "correct_answer"

# Fields of the evolved MCQ that are shown to the model: the question text and
# a dict mapping option letter -> option text.
RQ_KEY = "rewritten_question"
RO_KEY = "rewritten_options"

# Option letters in prompt order; also the only accepted predicted answers.
CHOICES = ["A", "B", "C", "D"]

# =========================
# PROMPTING
# =========================
# System message sent with every request.  It asks for a bare JSON object with
# exactly the keys "answer" (one of A-D) and "rationale"; replies are still
# parsed leniently by safe_parse_json_object because models do not always
# comply with the "JSON only" instruction.
SYSTEM_PROMPT = (
    "You are a careful multiple-choice question solver.\n"
    "Return ONLY a valid JSON object with keys: answer, rationale.\n"
    "answer must be one of: A, B, C, D.\n"
    "rationale must be concise (2-5 sentences), logical, and based only on the question/options.\n"
    "Do not include any extra keys or text outside JSON."
)

def build_user_prompt(question: str, options: Dict[str, str]) -> str:
    """Render the user turn: the question, the lettered options, and a reply template.

    Options are listed in ``CHOICES`` order as ``"<letter>) <text>"``; a letter
    missing from *options* is rendered with empty text.  Sections are
    separated by a blank line.
    """
    option_lines = []
    for letter in CHOICES:
        option_lines.append(f"{letter}) {options.get(letter, '')}")
    sections = [
        f"Question (Persian):\n{question}",
        "Options:\n" + "\n".join(option_lines),
        'Return JSON like:\n{"answer":"A","rationale":"..."}',
    ]
    return "\n\n".join(sections)

# =========================
# API CALL
# =========================
def call_chat_completions(
    endpoint: str,
    model: str,
    system_prompt: str,
    user_prompt: str,
    timeout: int = 180,
    max_retries: int = 6,
    backoff_base: float = 1.8,
) -> Dict[str, Any]:
    """POST one chat-completions request and return the decoded JSON body.

    The payload uses temperature 0.0 and a 200-token cap.  Both ``max_tokens``
    and ``max_new_tokens`` are sent, presumably so the same payload works on
    servers expecting either name (TODO confirm).

    Retry policy: HTTP 408/425/429, any 5xx, network-level errors and
    undecodable response bodies are treated as transient and retried, sleeping
    ``backoff_base ** k + 0.1`` seconds before retry ``k + 1``.  Any other HTTP
    error (e.g. 400/401/404) cannot succeed on retry, so it fails immediately
    instead of sleeping through every attempt.  No sleep happens after the
    final attempt.

    Args:
        endpoint: full URL of the chat-completions route.
        model: model name placed in the payload.
        system_prompt: content of the system message.
        user_prompt: content of the user message.
        timeout: per-request timeout in seconds.
        max_retries: total number of attempts (0 means no request is made).
        backoff_base: base of the exponential backoff.

    Returns:
        The decoded response JSON.

    Raises:
        RuntimeError: on a non-retryable HTTP error, or once ``max_retries``
            attempts have failed.  The message carries the last status code,
            the first 800 characters of the last body and the last exception.
    """
    payload = {
        "model": model,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        "temperature": 0.0,
        "max_tokens": 200,
        "max_new_tokens": 200,
    }

    last_err: Optional[Exception] = None
    last_status: Optional[int] = None
    last_body: str = ""

    def _failure(reason: str) -> RuntimeError:
        # Reads the latest diagnostic state at the moment the error is built.
        return RuntimeError(
            f"{reason}\n"
            f"endpoint={endpoint}\nmodel={model}\n"
            f"last_status={last_status}\nlast_body={last_body}\nlast_error={repr(last_err)}"
        )

    for attempt in range(max_retries):
        if attempt:
            # Back off between attempts only: same schedule as before, but no
            # pointless sleep once the last attempt has failed.
            time.sleep((backoff_base ** (attempt - 1)) + 0.1)

        try:
            r = requests.post(endpoint, json=payload, timeout=timeout)
        except requests.RequestException as e:
            last_err = e  # connection reset, timeout, ...: retry
            continue

        last_status = r.status_code
        last_body = (r.text or "")[:800]

        if r.status_code in (408, 425, 429) or r.status_code >= 500:
            continue  # transient server-side condition: retry

        try:
            r.raise_for_status()
        except requests.HTTPError as e:
            # Remaining 4xx errors (bad request, auth, unknown model/route)
            # will not change on retry.
            last_err = e
            raise _failure("Non-retryable HTTP error.") from e

        try:
            return r.json()
        except ValueError as e:
            last_err = e  # truncated / non-JSON body: treat as transient

    raise _failure("Failed after retries.") from last_err

def extract_text(resp_json: Dict[str, Any]) -> str:
    """Return the assistant text of the first choice of a chat-completions reply.

    Returns "" when the response lacks ``choices[0].message.content``, when
    that content is null, or when it is not a string, so callers can always
    treat the result as text.
    """
    try:
        content = resp_json["choices"][0]["message"]["content"]
    except (KeyError, IndexError, TypeError):
        # Missing keys, empty choices list, or a non-dict/non-list somewhere
        # along the path (e.g. the server returned a JSON list or null).
        return ""
    return content if isinstance(content, str) else ""

def safe_parse_json_object(text: str) -> Optional[Dict[str, Any]]:
    """Extract the first JSON object embedded in a model reply.

    Tolerates what LLM replies commonly contain around or inside the object:
    leading prose, Markdown code fences, trailing text that itself contains
    braces, and raw control characters (e.g. literal newlines) inside string
    values.  Any reply whose first-``{``-to-last-``}`` span is valid JSON
    still parses to the same object.

    Returns:
        The first ``{...}`` that decodes to a dict, or None when *text* is
        empty or contains no decodable JSON object.
    """
    if not text:
        return None

    # strict=False accepts unescaped control characters inside JSON strings.
    decoder = json.JSONDecoder(strict=False)
    start = text.find("{")
    while start != -1:
        try:
            # raw_decode stops at the end of the first complete value, so
            # trailing text (fences, notes, stray braces) is ignored.
            obj, _end = decoder.raw_decode(text[start:])
        except ValueError:
            pass
        else:
            if isinstance(obj, dict):
                return obj
        start = text.find("{", start + 1)
    return None

# =========================
# I/O HELPERS
# =========================
def load_json_array(path: str) -> List[Dict[str, Any]]:
    """Read a UTF-8 JSON file whose top-level value must be an array.

    Raises:
        ValueError: if the file parses but its top-level value is not a list.
    """
    with open(path, "r", encoding="utf-8") as fh:
        payload = json.load(fh)
    if isinstance(payload, list):
        return payload
    raise ValueError(f"Input must be a JSON array (list). Bad file: {path}")

def get_question_and_options(item: Dict[str, Any]) -> Tuple[str, Dict[str, str], Dict[str, Any]]:
    """Return ``(question, options, evol)`` taken from the evolved MCQ of *item*.

    * ``question`` is ``item[QA_KEY][EVOL_KEY][RQ_KEY]``, stripped ("" if absent).
    * ``options`` maps every letter in ``CHOICES`` to its stripped text from
      ``item[QA_KEY][EVOL_KEY][RO_KEY]`` ("" for a missing letter).
    * ``evol`` is the evolved-MCQ dict itself (for the caller's "no evol" check).

    Containers that are missing or not dicts are treated as empty; previously a
    non-dict evol block raised AttributeError here, so the caller's "no evol"
    skip could never trigger for it.  Non-string values (e.g. numeric options)
    are stringified before stripping; falsy values become "".
    """
    qa = item.get(QA_KEY)
    if not isinstance(qa, dict):
        qa = {}
    evol = qa.get(EVOL_KEY)
    if not isinstance(evol, dict):
        evol = {}
    raw_options = evol.get(RO_KEY)
    if not isinstance(raw_options, dict):
        raw_options = {}

    question = str(evol.get(RQ_KEY) or "").strip()
    options = {k: str(raw_options.get(k) or "").strip() for k in CHOICES}
    return question, options, evol

def get_gold(item: Dict[str, Any]) -> Tuple[str, str]:
    """Return ``(gold_rationale, gold_answer)`` for one dataset item.

    The rationale comes from the original MCQ (``item[QA_KEY][MCQ_KEY]``).  The
    answer prefers the evolved MCQ's ``EVOL_CORRECT_KEY`` and falls back to
    the original MCQ's ``GOLD_ANSWER_KEY``; it is upper-cased for comparison
    with predicted letters.  Missing values yield "".

    Containers that are absent or not dicts are treated as empty, and
    non-string values are stringified.  Malformed items therefore come back as
    "no gold data" instead of raising AttributeError.  This matters because the
    caller invokes this before its own "no evol" skip check.
    """
    def _section(parent: Dict[str, Any], key: str) -> Dict[str, Any]:
        value = parent.get(key)
        return value if isinstance(value, dict) else {}

    qa = _section(item, QA_KEY)
    mcq = _section(qa, MCQ_KEY)
    evol = _section(qa, EVOL_KEY)

    gold_rationale = str(mcq.get(GOLD_RATIONALE_KEY) or "").strip()
    gold_answer = str(evol.get(EVOL_CORRECT_KEY) or mcq.get(GOLD_ANSWER_KEY) or "").strip().upper()
    return gold_rationale, gold_answer

def compute_sbert_similarities(pairs: List[Tuple[str, str]]) -> List[float]:
    from sentence_transformers import SentenceTransformer
    model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-mpnet-base-v2")

    pred_texts = [p for p, _ in pairs]
    gold_texts = [g for _, g in pairs]

    emb_pred = model.encode(pred_texts, convert_to_numpy=True, normalize_embeddings=True, batch_size=32)
    emb_gold = model.encode(gold_texts, convert_to_numpy=True, normalize_embeddings=True, batch_size=32)
    sims = (emb_pred * emb_gold).sum(axis=1)
    return sims.tolist()

def mean(xs: List[float]) -> float:
    """Arithmetic mean of *xs*; an empty list yields 0.0 instead of raising."""
    if not xs:
        return 0.0
    total = sum(xs)
    return total / len(xs)

# =========================
# CORE RUN
# =========================
# Column order of the per-item CSV.  It matches the keys of each row dict built
# in run_one_file and is written even when no rows survive (header only).
_PER_ITEM_COLUMNS = [
    "idx",
    "gold_answer",
    "pred_answer",
    "is_correct",
    "question",
    "gold_rationale",
    "pred_rationale",
    "sbert_sim",
]


def _as_clean_text(value: Any) -> str:
    """Return *value* as a stripped string; None and other falsy values become ""."""
    return str(value or "").strip()


def _write_csv(path: str, fieldnames: List[str], rows: List[Dict[str, Any]]) -> None:
    """Overwrite *path* with a UTF-8 CSV: a header from *fieldnames*, then *rows*."""
    with open(path, "w", encoding="utf-8", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)


def run_one_file(in_path: str, limit: Optional[int]) -> Dict[str, Any]:
    """Evaluate MODEL_NAME on one category file and write its per-file outputs.

    Items are skipped when they lack an evolved-MCQ block, lack a rewritten
    question, or have fewer than two non-empty options.  Every other item is
    sent to the model once, and the JSON reply is parsed into a predicted
    letter and rationale.  Predicted rationales are then scored against the
    gold rationale with SBERT cosine similarity (batched after the loop).

    Files written under OUT_DIR, named ``<base>__<TAG>__...``:
      * ``predictions.jsonl``: one line per queried item, including the raw
        reply.  Lines are flushed as written, so a crash keeps finished items.
      * ``per_item.csv``: one row per queried item (header only if none).
      * ``metrics.csv``: the single summary row returned here.

    Args:
        in_path: category JSON file (a JSON array of items).
        limit: if not None, evaluate only the first ``limit`` items.

    Returns:
        The summary dict (also written to ``metrics.csv``).  ``accuracy`` is
        computed only over items whose prediction parsed to a valid letter.
        ``accuracy_incl_invalid`` divides by all items with a gold answer, so
        unparseable predictions count as wrong.

    Raises:
        RuntimeError: propagated from call_chat_completions when the endpoint
            keeps failing; processing of this file stops at that item.
    """
    base = os.path.splitext(os.path.basename(in_path))[0]

    out_pred_jsonl  = os.path.join(OUT_DIR, f"{base}__{TAG}__predictions.jsonl")
    out_peritem_csv = os.path.join(OUT_DIR, f"{base}__{TAG}__per_item.csv")
    out_metrics_csv = os.path.join(OUT_DIR, f"{base}__{TAG}__metrics.csv")

    data = load_json_array(in_path)
    if limit is not None:
        data = data[:limit]

    per_item_rows: List[Dict[str, Any]] = []
    sim_pairs: List[Tuple[str, str]] = []
    sim_pair_indices: List[int] = []

    skipped_no_question = 0
    skipped_bad_options = 0
    skipped_no_evol = 0
    called = 0

    # "w" truncates predictions left over from a previous run of this file.
    with open(out_pred_jsonl, "w", encoding="utf-8") as pred_f:
        for idx, item in enumerate(tqdm(data, desc=f"{TAG} | {base}")):
            question, options, evol_dbg = get_question_and_options(item)
            gold_rationale, gold_answer = get_gold(item)

            if not isinstance(evol_dbg, dict) or not evol_dbg:
                skipped_no_evol += 1
                continue

            if not question:
                skipped_no_question += 1
                continue

            # Require at least two real options for a meaningful MCQ.
            nonempty = sum(1 for k in CHOICES if options.get(k, "").strip())
            if nonempty < 2:
                skipped_bad_options += 1
                continue

            called += 1
            user_prompt = build_user_prompt(question, options)
            resp = call_chat_completions(ENDPOINT, MODEL_NAME, SYSTEM_PROMPT, user_prompt)

            raw_text = extract_text(resp)
            parsed = safe_parse_json_object(raw_text) or {}

            # Stringify first: the model may emit non-string JSON values
            # (e.g. {"answer": 1}), which would crash a bare .strip() and
            # abort the whole file.
            pred_answer = _as_clean_text(parsed.get("answer")).upper()
            pred_rationale = _as_clean_text(parsed.get("rationale"))

            if pred_answer not in CHOICES:
                pred_answer = ""  # unparseable or out-of-range -> no prediction

            is_correct = int(bool(gold_answer) and pred_answer == gold_answer)

            per_item_rows.append({
                "idx": idx,
                "gold_answer": gold_answer,
                "pred_answer": pred_answer,
                "is_correct": is_correct,
                "question": question,
                "gold_rationale": gold_rationale,
                "pred_rationale": pred_rationale,
                "sbert_sim": "",  # filled in after the loop
            })

            pred_f.write(json.dumps({
                "idx": idx,
                "gold_answer": gold_answer,
                "pred_answer": pred_answer,
                "pred_rationale": pred_rationale,
                "raw_model_text": raw_text,
            }, ensure_ascii=False) + "\n")
            pred_f.flush()

            if gold_rationale and pred_rationale:
                sim_pairs.append((pred_rationale, gold_rationale))
                sim_pair_indices.append(idx)

    # --- fill SBERT sims (one batched call for the whole file) ---
    if sim_pairs:
        sims = compute_sbert_similarities(sim_pairs)
        idx_to_sim = dict(zip(sim_pair_indices, sims))
        for row in per_item_rows:
            sim = idx_to_sim.get(row["idx"])
            if sim is not None:
                row["sbert_sim"] = float(round(sim, 6))

    n_total = len(data)
    n_used = len(per_item_rows)
    n_with_gold = sum(1 for r in per_item_rows if r["gold_answer"])
    n_valid = sum(1 for r in per_item_rows if r["gold_answer"] and r["pred_answer"])
    n_correct = sum(r["is_correct"] for r in per_item_rows)
    acc = (n_correct / n_valid) if n_valid else 0.0
    # `acc` drops unparseable predictions from the denominator, which flatters
    # models that often emit malformed output; this variant counts them wrong.
    acc_incl_invalid = (n_correct / n_with_gold) if n_with_gold else 0.0

    scored = [r for r in per_item_rows if isinstance(r["sbert_sim"], float)]
    sims_all = [r["sbert_sim"] for r in scored]
    sims_correct = [r["sbert_sim"] for r in scored if r["is_correct"] == 1]
    sims_wrong = [r["sbert_sim"] for r in scored if r["is_correct"] == 0]

    summary = {
        "category_file": os.path.basename(in_path),
        "model": MODEL_NAME,
        "endpoint": ENDPOINT,
        "N_items_total": n_total,
        "N_items_called": called,
        "N_rows_saved": n_used,
        "skipped_no_evol": skipped_no_evol,
        "skipped_no_question": skipped_no_question,
        "skipped_bad_options": skipped_bad_options,
        "N_valid_for_accuracy": n_valid,
        "N_correct": n_correct,
        "accuracy": round(acc, 6),
        "N_invalid_pred": n_with_gold - n_valid,
        "accuracy_incl_invalid": round(acc_incl_invalid, 6),
        "N_sbert_pairs": len(sims_all),
        "sbert_mean_all": round(mean(sims_all), 6),
        "sbert_mean_correct": round(mean(sims_correct), 6),
        "sbert_mean_wrong": round(mean(sims_wrong), 6),
        "pred_jsonl": out_pred_jsonl,
        "per_item_csv": out_peritem_csv,
        "metrics_csv": out_metrics_csv,
    }

    # Always (re)write the per-item CSV so an empty run cannot leave a stale
    # file from an earlier run behind the path reported in `summary`.
    _write_csv(out_peritem_csv, _PER_ITEM_COLUMNS, per_item_rows)
    _write_csv(out_metrics_csv, list(summary.keys()), [summary])

    print(
        f"[{base}] total={n_total} called={called} saved_rows={n_used} "
        f"skip_no_evol={skipped_no_evol} skip_no_q={skipped_no_question} skip_bad_opt={skipped_bad_options} "
        f"valid={n_valid} acc={acc:.4f}"
    )

    return summary

def main(limit_per_category: Optional[int] = None):
    """Evaluate every matching category file and write one combined summary CSV.

    Args:
        limit_per_category: if given, evaluate only the first N items of each file.

    Raises:
        SystemExit: when no file matches ``IN_DIR``/``IN_PATTERN``.
    """
    pattern = os.path.join(IN_DIR, IN_PATTERN)
    in_files = sorted(glob.glob(pattern))
    if not in_files:
        raise SystemExit(f"No files matched: {pattern}")

    header_lines = (
        f"TAG={TAG}",
        f"endpoint={ENDPOINT}",
        f"model={MODEL_NAME}",
        f"Found {len(in_files)} category files.",
    )
    for line in header_lines:
        print(line)

    summaries = []
    for path in in_files:
        print(f"\n=== Running: {os.path.basename(path)} ===")
        summaries.append(run_one_file(path, limit=limit_per_category))

    # Every summary shares the same keys, so the first one defines the header.
    with open(OUT_SUMMARY_ALL, "w", encoding="utf-8", newline="") as fh:
        writer = csv.DictWriter(fh, fieldnames=list(summaries[0]))
        writer.writeheader()
        writer.writerows(summaries)

    print("\n=========================")
    print(f"[saved] ALL categories summary: {OUT_SUMMARY_ALL}")
    print(f"Per-category outputs saved under: {OUT_DIR}")

if __name__ == "__main__":
    # Entry point: evaluate every category file with no per-file item cap.
    main()


TAG=MISTRAL_7B_V03
endpoint=http://192.168.0.222:80/v1/chat/completions
model=mistralai/Mistral-7B-Instruct-v0.3
Found 1 category files.

=== Running: banking_1_with_mcq_fa_T_evol_GN_RD_GNge4_RDge4_BALANCED.json ===


MISTRAL_7B_V03 | banking_1_with_mcq_fa_T_evol_GN_RD_GNge4_RDge4_BALANCED: 100%|██████████| 643/643 [12:39<00:00,  1.18s/it]


[banking_1_with_mcq_fa_T_evol_GN_RD_GNge4_RDge4_BALANCED] total=643 called=643 saved_rows=643 skip_no_evol=0 skip_no_q=0 skip_bad_opt=0 valid=643 acc=0.4510

[saved] ALL categories summary: ./mistral_7B_eval_outputs/MISTRAL_7B_V03_metrics_summary_ALL_categories.csv
Per-category outputs saved under: ./mistral_7B_eval_outputs
