In [15]:
import os
import re
import json
import time
import subprocess
import numpy as np
import pandas as pd
from openai import OpenAI

# ==========================================
# 0. CONFIG
# ==========================================

# Absolute path of the llama.cpp server binary that start_server() launches.
LLAMA_SERVER_PATH = "/scratch/ajb8866/llama_cpp_src/llama.cpp/build-cuda/bin/llama-server"
# Directory containing the GGUF model files referenced in MODELS.
CACHE_DIR = "/scratch/ajb8866/cache/llama.cpp"
# Address the server binds to and the OpenAI-compatible client connects to.
HOST = "127.0.0.1"
PORT = 8080
# Forwarded to llama-server as --n-gpu-layers / --ctx-size (kept as strings;
# start_server() wraps them in str() when building the command line).
GPU_LAYERS = "75"
CTX_SIZE = "8192"
# max_tokens sent with every benchmark chat completion.
MAX_TOKENS = 4096

# (friendly name, GGUF path) pairs, benchmarked in this order. The friendly
# name is also used for the per-model results directory and stderr log name.
MODELS = [
    ("DeepSeek-R1-32B", f"{CACHE_DIR}/unsloth_DeepSeek-R1-Distill-Qwen-32B-GGUF_DeepSeek-R1-Distill-Qwen-32B-F16_DeepSeek-R1-Distill-Qwen-32B-F16-00001-of-00002.gguf"),
    ("Qwen3-30B",       f"{CACHE_DIR}/unsloth_Qwen3-30B-A3B-GGUF_BF16_Qwen3-30B-A3B-BF16-00001-of-00002.gguf"),
    ("GPT-OSS-20B",     f"{CACHE_DIR}/ggml-org_gpt-oss-20b-GGUF_gpt-oss-20b-mxfp4.gguf")
]

# Input table whose first rows become the prompt context.
TSV_FILE = "a2_vs_d37.tsv"
# Benchmark questions with ground-truth answers (read into QUESTIONS_DF).
BENCHMARK_CSV = "a2_vs_d37_ground_truth.csv"
# Root directory for per-model results.jsonl files and server stderr logs.
OUTPUT_DIR = "a2_vs_d37_results_2"

# ==========================================
# 1.  SYSTEM PROMPT 
# ==========================================

# System message sent with every benchmark question. It fixes the JSON reply
# schema ("answer" + "confidence"), the per-Type answer formats, and the
# discrete confidence levels that the evaluation code later parses and scores.
SYSTEM_PROMPT = """You are an expert bioinformatics assistant. Your task is to answer questions based on the provided Nanocompore TSV context (first 25 rows).

STRATEGY:
1. EXTRACT: If the answer is in the data, extract it directly (no guessing).
2. CALCULATE: If the answer requires math, compute it from the context values.
3. INFER: Only if necessary, use limited domain knowledge; reflect uncertainty in confidence.

CRITICAL OUTPUT INSTRUCTIONS:
1. Output ONLY valid JSON.
2. Schema (must match exactly):
   {
     "answer": "value",
     "confidence": 0.90
   }

3. TYPE-BASED ANSWER FORMATTING RULES (match the requested Type exactly):
   - GENOMIC_SITE: "ref_id:pos" (e.g., "NC_003796.1:893")
   - FLOAT: digits with optional decimal/scientific notation (e.g., "0.320240" or "1.2e-5")
   - INTEGER: digits only (e.g., "3")
   - BOOLEAN: "Yes"/"No" (or "True"/"False" if the question uses that wording)
   - STRING: short text label (avoid extra commentary)
   - BASE: a single nucleotide letter (e.g., "A", "C", "G", "T")

   - SITE_WITH_FLOAT: "ref_id:pos (metric=value)"
       Example: "NC_003796.1:893 (abs_GMM_LOR=1.078)"

   - SITE_WITH_KEYED_METRICS: "ref_id:pos (key=value; key=value)"
       Example: "NC_003796.1:893 (p=0.014041; GMM_LOR=1.078)"

   - KEYED_FLOAT_LIST: "key=value; key=value"
       Example: "sig=1.033500; nonsig=0.492000"

   - KEYED_MIXED_LIST: "key=value; key=value; key=value"
       Values may be int/float/string, but always keep "key=value".
       Example: "effect=0.978; coverage=-0.181; stronger=effect_size"

   - RANKING_WITH_SCORES: semicolon-separated ordered items:
       "site(value); site(value); ..."
       Example: "NC_003796.1:893(0.274799); NC_003796.1:870(0.258363)"

   - RUN_AND_TREND: "start-end; trend=VALUE"
       Example: "882-886; trend=mixed"

   - COUNT_AND_SITE_LIST: "count=N; sites=site1; site2; site3"
       Example: "count=3; sites=NC_003796.1:893; NC_003796.1:891; NC_003796.1:870"

4. CATEGORY SET (use these meanings when reasoning; do NOT output category labels):
   - Directionality & Change: effect direction/sign (e.g., GMM_LOR, positive/negative fractions, run trends)
   - Derived Metrics: computed from counts (e.g., mod fraction, deltas, rankings of derived values)
   - Model vs Counts Consistency: compare model outputs (GMM_LOR) to ratios derived from raw counts
   - Statistical Metrics & Significance: p-values/q-values, correlations with -log10(p), significance behavior
   - High-Confidence Predictions: thresholding rules, candidate lists, significant+large-effect filters
   - Site/Feature Annotation: ref_kmer, base/k-mer properties, motif-like summaries
   - Coverage & Depth: coverage/total reads, coverage imbalance, median-based coverage filters

5. CONFIDENCE:
Use ONLY one of these values: 1.00, 0.90, 0.75, 0.55, 0.30

- 1.00 = directly copied from a single unambiguous row in CONTEXT (unique match)
- 0.90 = simple calculation from context values with no ambiguity
- 0.75 = multi-step calculation or multiple rows but still clear
- 0.55 = requires inference/assumption or context is incomplete/ambiguous
- 0.30 = best guess; answer not fully supported by context
"""

# ==========================================
# 2. INFERENCE & SERVER UTILS
# ==========================================

def load_context_chunk(file_path, rows=25):
    """Return the first ``rows`` data rows of a delimited file as TSV text.

    The input is parsed as tab-separated when ``file_path`` ends with
    ``.tsv`` and as comma-separated otherwise. Only the header and the first
    ``rows`` records are parsed, so column dtypes are inferred from exactly
    the rows that end up in the context. Parsing the whole file first would
    let later rows (for example a missing value in an integer column) change
    how the leading rows are rendered, e.g. ``1`` becoming ``1.0``.

    Args:
        file_path: Path of the TSV/CSV file to read.
        rows: Maximum number of data rows to keep.

    Returns:
        The header plus the selected rows, tab-separated, without an index.

    Raises:
        Exception: Whatever pandas raises while reading; the error is printed
            and then re-raised unchanged.
    """
    try:
        sep = '\t' if file_path.endswith('.tsv') else ','
        # nrows stops parsing after the rows that are actually needed.
        chunk = pd.read_csv(file_path, sep=sep, nrows=rows)
        context_str = chunk.to_csv(index=False, sep='\t')
        print(f"Loaded context: {len(chunk)} rows ({len(context_str)} chars).")
        return context_str
    except Exception as e:
        print(f"Error loading context file: {e}")
        raise


def extract_model_id(client, fallback="default"):
    """
    Return the id of the first model the server lists, else ``fallback``.

    Any error while listing models, or a listing without a non-empty ``data``
    attribute, results in ``fallback``.
    """
    try:
        listing = client.models.list()
        entries = getattr(listing, "data", None)
        if entries:
            return entries[0].id
    except Exception:
        # Best effort: an unreachable or non-conforming server is not fatal.
        return fallback
    return fallback


def wait_for_chat_ready(client, model_id="default", timeout_s=1500):
    """
    Block until the server answers a one-token chat completion.

    This is the real readiness check: /v1/models may respond while the model
    is still loading, so a tiny completion is requested instead.

    Args:
        client: OpenAI-compatible client pointed at the server.
        model_id: Model name sent with the probe request.
        timeout_s: Maximum number of seconds to keep probing.

    Returns:
        True once a probe succeeds, False if ``timeout_s`` elapses first
        (the last transient error is printed in that case).

    Raises:
        Exception: Any probe error that is not a transient 503 /
            "Loading model" condition is re-raised immediately.
    """
    # Monotonic clock: the deadline is unaffected by wall-clock adjustments.
    deadline = time.monotonic() + timeout_s
    last_err = None
    while time.monotonic() < deadline:
        try:
            client.chat.completions.create(
                model=model_id,
                messages=[{"role": "user", "content": "ping"}],
                temperature=0.0,
                max_tokens=1,
            )
            return True
        except Exception as e:
            msg = str(e)
            transient = (
                getattr(e, "status_code", None) == 503
                or "Loading model" in msg
                or "503" in msg
            )
            if not transient:
                raise
            last_err = e
            # Never sleep past the deadline.
            time.sleep(max(0.0, min(2.0, deadline - time.monotonic())))
    print(f"\nChat readiness timed out. Last error: {last_err}")
    return False


def chat_with_retries(client, *, model, messages, temperature, max_tokens, max_attempts=25):
    """
    Request a chat completion, retrying transient 503 / 'Loading model' errors.

    The delay between attempts grows linearly (2 s per attempt, capped at
    20 s). No delay is spent after the final attempt.

    Args:
        client: OpenAI-compatible client.
        model, messages, temperature, max_tokens: Forwarded unchanged to
            ``client.chat.completions.create``.
        max_attempts: Total number of requests to try.

    Returns:
        The completion object returned by the client.

    Raises:
        RuntimeError: Every attempt failed with a transient error; chained to
            the last such error.
        Exception: Any non-transient error is re-raised immediately.
    """
    last_err = None
    for attempt in range(1, max_attempts + 1):
        try:
            return client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
            )
        except Exception as e:
            msg = str(e)
            transient = (
                getattr(e, "status_code", None) == 503
                or "Loading model" in msg
                or "503" in msg
            )
            if not transient:
                raise
            last_err = e
        # Back off only when another attempt will actually follow.
        if attempt < max_attempts:
            time.sleep(min(2 * attempt, 20))
    raise RuntimeError(
        f"Persistent 503/Loading model after retries. Last error: {last_err}"
    ) from last_err


# Load once at import time: the prompt context (first 25 rows of TSV_FILE,
# rendered as TSV text) and the benchmark questions table.
CONTEXT_TEXT = load_context_chunk(TSV_FILE, rows=25)
QUESTIONS_DF = pd.read_csv(BENCHMARK_CSV)

# Ensure question_id exists: when the CSV has no such column, number the
# questions 0..N-1 in file order so results can be joined back on it.
if 'question_id' not in QUESTIONS_DF.columns:
    QUESTIONS_DF.insert(0, 'question_id', range(len(QUESTIONS_DF)))


def _stop_process(process):
    """Kill ``process`` if it is still running and reap it; never raises."""
    try:
        if process.poll() is None:
            process.kill()
        # Reap the child so it does not linger as a zombie.
        process.wait(timeout=10)
    except Exception:
        pass


def start_server(model_path, model_name=None):
    """
    Launch llama-server for ``model_path`` and wait until it can serve chat.

    Startup has two phases: first the HTTP layer must answer a model listing
    (polled once per second, up to 180 tries), then a tiny chat completion
    must succeed (up to 900 s). On every failure path the server process is
    killed and reaped before returning.

    Args:
        model_path: Path to the GGUF file passed to the server with ``-m``.
        model_name: Optional label; when given, server stderr is written to
            ``OUTPUT_DIR/<model_name>_server.stderr.log``. Without it stderr
            is discarded (an unread pipe could fill up and stall the server).

    Returns:
        The running ``subprocess.Popen`` object, or None if the model file is
        missing, the process could not be launched or exited early, or the
        server did not become ready.
    """
    if not os.path.exists(model_path):
        print(f"Critical Error: Model file not found at:\n{model_path}")
        return None

    print(f"**** Booting Server for: {os.path.basename(model_path)} ****")

    # Log stderr to file for debugging load failures
    stderr_target = subprocess.DEVNULL
    log_fh = None
    if model_name:
        os.makedirs(OUTPUT_DIR, exist_ok=True)
        log_path = os.path.join(OUTPUT_DIR, f"{model_name}_server.stderr.log")
        log_fh = open(log_path, "wb")
        stderr_target = log_fh

    cmd = [
        LLAMA_SERVER_PATH, "-m", model_path,
        "--host", HOST, "--port", str(PORT),
        "--ctx-size", str(CTX_SIZE),
        "--n-gpu-layers", str(GPU_LAYERS),
        "--tensor-split", "50,50",
        "--parallel", "1"
    ]

    try:
        process = subprocess.Popen(cmd, stdout=subprocess.DEVNULL, stderr=stderr_target)
    except Exception as e:
        print(f"Failed to launch process: {e}")
        return None
    finally:
        # The child holds its own copy of the descriptor, so the parent's
        # handle can be closed right away on success and failure alike.
        if log_fh:
            log_fh.close()

    print(" Waiting for server...", end="", flush=True)
    client = OpenAI(base_url=f"http://{HOST}:{PORT}/v1", api_key="x")

    # Phase 1: wait for HTTP layer
    http_ok = False
    for _ in range(180):
        if process.poll() is not None:
            break
        try:
            client.models.list()
        except Exception:
            time.sleep(1)
            print(".", end="", flush=True)
            continue
        print(" http-ok", end="", flush=True)
        http_ok = True
        break

    if process.poll() is not None:
        print("\nServer process exited early (check stderr log).")
        _stop_process(process)
        return None

    if not http_ok:
        print("\nServer timed out waiting for the HTTP endpoint.")
        _stop_process(process)
        return None

    # Use actual model id if exposed
    model_id = extract_model_id(client, fallback="default")

    # Phase 2: wait for chat readiness
    try:
        ok = wait_for_chat_ready(client, model_id=model_id, timeout_s=900)
    except Exception as e:
        print(f"\nReadiness check failed: {e}")
        _stop_process(process)
        return None
    except BaseException:
        # e.g. KeyboardInterrupt: do not leave the server running.
        _stop_process(process)
        raise

    if ok:
        print(" Ready!")
        return process

    print("\nServer timed out waiting for model readiness.")
    _stop_process(process)
    return None


def run_benchmark(model_name):
    """
    Ask the running server every not-yet-answered question and log replies.

    Results are appended, one JSON object per line, to
    ``OUTPUT_DIR/<model_name>/results.jsonl`` and flushed after each
    question, so an interrupted run can be resumed: question ids already in
    the file are skipped. A failed API call is recorded with
    ``llm_response`` set to ``"ERROR_API"`` and therefore also counts as
    done on resume.

    Args:
        model_name: Label used for the results directory and stored in each
            result record.
    """
    client = OpenAI(base_url=f"http://{HOST}:{PORT}/v1", api_key="na")
    model_id = extract_model_id(client, fallback="default")

    save_dir = os.path.join(OUTPUT_DIR, model_name)
    os.makedirs(save_dir, exist_ok=True)
    output_file = os.path.join(save_dir, "results.jsonl")

    done_ids = set()
    if os.path.exists(output_file):
        with open(output_file, 'r', encoding='utf-8') as f:
            for line in f:
                try:
                    done_ids.add(int(json.loads(line)['question_id']))
                except (ValueError, KeyError, TypeError):
                    # Blank, truncated or foreign lines are ignored.
                    continue

    # Decide by membership, not by subtraction: ids in the results file that
    # are not part of the current question set must not hide pending work.
    pending = [
        (int(row['question_id']), row['Question'])
        for _, row in QUESTIONS_DF.iterrows()
        if int(row['question_id']) not in done_ids
    ]
    print(f"🚀 Benchmarking {model_name} ({len(pending)} remaining)...")

    if not pending:
        return

    with open(output_file, 'a', encoding='utf-8') as f_out:
        for q_id, question in pending:
            prompt = f"CONTEXT DATA:\n{CONTEXT_TEXT}\n\nQUESTION:\n{question}"

            try:
                response = chat_with_retries(
                    client,
                    model=model_id,
                    messages=[
                        {"role": "system", "content": SYSTEM_PROMPT},
                        {"role": "user", "content": prompt}
                    ],
                    temperature=0.1,
                    max_tokens=MAX_TOKENS
                )
                raw_text = response.choices[0].message.content
            except Exception as e:
                print(f"Error Q{q_id}: {e}")
                raw_text = "ERROR_API"

            result = {
                "model": model_name,
                "question_id": q_id,
                "question": question,
                "llm_response": raw_text
            }
            f_out.write(json.dumps(result) + "\n")
            # Flush per question so progress survives an interruption.
            f_out.flush()
            print(f"  - Q{q_id} complete")

# ==========================================
# 3. EVALUATION METRICS LOGIC
# ==========================================

def extract_first_json_object(s: str) -> str | None:
    # Prefer fenced JSON if present anywhere
    m = re.search(r"```(?:json)?\s*({.*?})\s*```", s, flags=re.DOTALL | re.IGNORECASE)
    if m:
        return m.group(1).strip()

    # Otherwise, brace-balance from first '{'
    start = s.find("{")
    if start == -1:
        return None

    depth = 0
    in_str = False
    esc = False
    for i in range(start, len(s)):
        ch = s[i]

        # track strings so braces inside strings don't count
        if in_str:
            if esc:
                esc = False
            elif ch == "\\":
                esc = True
            elif ch == '"':
                in_str = False
            continue
        else:
            if ch == '"':
                in_str = True
                continue

        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return s[start:i+1].strip()

    return None


def validate_and_parse_response(response_str):
    """
    Parse a raw model reply into ``(format_valid, answer, confidence)``.

    The reply is valid when it is, or contains, a JSON object with both an
    "answer" and a "confidence" key and the confidence converts to a float
    in [0, 1]. The answer is returned as a string (None when it was JSON
    null). When the schema is present but the confidence is unusable, the
    answer is still returned with ``format_valid`` False and confidence None.
    """
    failure = (False, None, None)
    if pd.isna(response_str):
        return failure

    text = str(response_str).strip()

    # First try the whole reply; fall back to JSON embedded in other text.
    try:
        payload = json.loads(text)
    except json.JSONDecodeError:
        embedded = extract_first_json_object(text)
        if embedded is None:
            return failure
        try:
            payload = json.loads(embedded)
        except json.JSONDecodeError:
            return failure

    # Schema check: a dict carrying both required keys.
    has_schema = isinstance(payload, dict) and "answer" in payload and "confidence" in payload
    if not has_schema:
        return failure

    answer = payload["answer"]
    # Stringify so later comparisons always see text.
    answer = str(answer) if answer is not None else None

    try:
        confidence = float(payload["confidence"])
    except (ValueError, TypeError):
        return False, answer, None

    # Written as a positive range test so NaN is rejected as well.
    if not (0.0 <= confidence <= 1.0):
        return False, answer, None
    return True, answer, confidence


def _floats_match(gt_value, pred_value, rel_tol=1e-3):
    """True when the prediction is within ``rel_tol`` of the ground truth.

    The tolerance is purely relative to the ground-truth magnitude. There is
    no absolute slack, so very small values such as p-values are not all
    considered equal to each other.
    """
    return bool(np.isclose(float(pred_value), float(gt_value), rtol=rel_tol, atol=0.0))


def _parse_keyed_floats(text):
    """Parse ``"key=value; key=value"`` into a dict of floats."""
    out = {}
    for part in text.split(';'):
        part = part.strip()
        if not part:
            continue
        key, value = part.split('=', 1)
        out[key.strip()] = float(value.strip())
    return out


def _parse_count_sites(text):
    """Parse ``"count=N; sites=a; b; c"`` into ``(count, sorted site list)``.

    The ``sites=`` label is stripped from whichever item carries it, so the
    site list compares independently of item order.
    """
    count = None
    sites = []
    for part in text.split(';'):
        part = part.strip()
        if count is None:
            m = re.match(r"count\s*=\s*(\d+)[\s,]*", part, flags=re.I)
            if m:
                count = int(m.group(1))
                part = part[m.end():]
        part = re.sub(r"^sites\s*=\s*", "", part, flags=re.I).strip()
        if part:
            sites.append(part.replace(" ", "").lower())
    return count, sorted(sites)


def check_exactness(row):
    """
    Decide whether a parsed model answer matches the ground truth.

    Args:
        row: Mapping-like record with 'Answer' (ground truth),
            'parsed_answer' (model answer or None/NaN) and 'Type'.

    Returns:
        A plain bool. Comparison depends on 'Type':
          - FLOAT: relative tolerance 1e-3 with respect to the ground truth.
          - INTEGER: equality after numeric conversion.
          - BOOLEAN: yes/no/true/false style words mapped to truth values.
          - BASE: case-insensitive equality.
          - KEYED_FLOAT_LIST: same keys, values within tolerance, any order.
          - COUNT_AND_SITE_LIST: same count and same set of sites, any order.
          - RANKING_WITH_SCORES: same items in the same order.
          - everything else: equality ignoring case and spaces.
        A missing answer or an answer that cannot be parsed for its type
        yields False.
    """
    if pd.isna(row['parsed_answer']) or row['parsed_answer'] is None:
        return False

    gt = str(row['Answer']).strip()
    pred = str(row['parsed_answer']).strip()
    dtype = str(row['Type']).strip()

    gt_l = gt.lower().replace(" ", "")
    pred_l = pred.lower().replace(" ", "")

    try:
        if dtype == 'FLOAT':
            return _floats_match(gt, pred)

        if dtype == 'INTEGER':
            return int(float(gt)) == int(float(pred))

        if dtype == 'BOOLEAN':
            true_vals = {'true', 'yes', '1', 't', 'correct'}
            false_vals = {'false', 'no', '0', 'f', 'incorrect'}
            if pred_l not in true_vals and pred_l not in false_vals:
                return False
            return (gt_l in true_vals) == (pred_l in true_vals)

        if dtype == 'BASE':
            return gt.upper() == pred.upper()

        if dtype == 'KEYED_FLOAT_LIST':
            # compare as unordered key->float maps
            a = _parse_keyed_floats(gt)
            b = _parse_keyed_floats(pred)
            if a.keys() != b.keys():
                return False
            return all(_floats_match(a[k], b[k]) for k in a)

        if dtype == 'COUNT_AND_SITE_LIST':
            return _parse_count_sites(gt) == _parse_count_sites(pred)

        if dtype == 'RANKING_WITH_SCORES':
            # require exact order; normalize whitespace and empty items
            def norm_rank(s):
                return ";".join(p.strip().replace(" ", "") for p in s.split(";") if p.strip()).lower()
            return norm_rank(gt) == norm_rank(pred)

        # GENOMIC_SITE, STRING, RUN_AND_TREND, KEYED_MIXED_LIST,
        # SITE_WITH_FLOAT, SITE_WITH_KEYED_METRICS and unknown types:
        # normalized string equality.
        return gt_l == pred_l

    except (ValueError, TypeError, OverflowError):
        # Unparseable answer for the requested type counts as incorrect.
        return False


def calculate_calibration(df):
    """
    Brier score: mean squared error between confidence and correctness.

    Lower is better; 0.0 is perfect. Only rows with a parsed confidence are
    scored.

    Args:
        df: DataFrame with 'parsed_confidence' and 'is_correct' columns.

    Returns:
        The Brier score as a float, or NaN when no row has a parsed
        confidence. NaN rather than 0.0, so a model that never produced a
        usable confidence is not reported as perfectly calibrated.
    """
    valid_df = df[df['parsed_confidence'].notna()]
    if valid_df.empty:
        return float(np.nan)

    y = valid_df['is_correct'].astype(int).to_numpy()
    p = valid_df['parsed_confidence'].astype(float).to_numpy()
    p = np.clip(p, 0.0, 1.0)

    return float(np.mean((p - y) ** 2))


def calculate_auc(df):
    """
    ROC AUC of the parsed confidence as a predictor of correctness.

    Computed with the rank-sum (Mann–Whitney) formula using average ranks,
    so tied confidences are handled; no sklearn dependency. Rows without a
    parsed confidence are ignored. Returns NaN when fewer than two rows
    remain or when only one correctness class is present.
    """
    scored = df.loc[df['parsed_confidence'].notna()].copy()

    # AUC is undefined unless both classes are present.
    if len(scored) < 2 or scored['is_correct'].nunique() < 2:
        return np.nan

    labels = scored['is_correct'].astype(int).to_numpy()
    scores = scored['parsed_confidence'].astype(float).to_numpy()

    positives = int(labels.sum())
    negatives = int(len(labels) - positives)
    if positives == 0 or negatives == 0:
        return np.nan

    avg_ranks = pd.Series(scores).rank(method="average").to_numpy()
    positive_rank_sum = avg_ranks[labels == 1].sum()

    # U statistic of the positive class, normalised by the number of pairs.
    u_stat = positive_rank_sum - positives * (positives + 1) / 2
    return float(u_stat / (positives * negatives))

# ==========================================
# 4. FINAL REPORT GENERATION (THE 4 CSVs)
# ==========================================

def _aggregate_scores(frame, keys):
    """Aggregate count, accuracy, format adherence and mean confidence by ``keys``.

    Keys missing from ``frame`` are ignored. Rows whose key is NaN form
    their own group instead of being dropped. With no usable key a single
    overall row is returned.
    """
    present = [k for k in keys if k in frame.columns]
    if not present:
        return pd.DataFrame([{
            'N': int(frame['question_id'].count()),
            'Accuracy': frame['is_correct'].mean(),
            'Format_Adherence': frame['format_valid'].mean(),
            'Mean_Confidence': frame['parsed_confidence'].mean(),
        }])
    return (
        frame
        .groupby(present, as_index=False, dropna=False)
        .agg(
            N=('question_id', 'count'),
            Accuracy=('is_correct', 'mean'),
            Format_Adherence=('format_valid', 'mean'),
            Mean_Confidence=('parsed_confidence', 'mean'),
        )
    )


def generate_final_report():
    """
    Score every model's results.jsonl and write the four report CSVs.

    For each entry in MODELS the results file is joined to QUESTIONS_DF on
    'question_id', each reply is parsed and checked against the ground
    truth, and per-model accuracy, format adherence, Brier score and AUC
    are computed. Models without a results file, or without any scorable
    rows, are skipped with a message.

    Output files (written to the current working directory):
      - benchmark_1_detailed.csv: one row per model and question.
      - benchmark_2_by_model.csv: one summary row per model.
      - benchmark_3_by_model_by_type.csv: grouped by Model, Type, Category.
      - benchmark_4_global_stats.csv: grouped by Type, Category over all models.
    """
    print("\n" + "=" * 60)
    print("📊 GENERATING FINAL REPORTS")
    print("=" * 60)

    all_data = []
    model_summaries = []

    for model_name, _ in MODELS:
        result_file = os.path.join(OUTPUT_DIR, model_name, "results.jsonl")
        if not os.path.exists(result_file):
            print(f"No results found for {model_name}")
            continue

        model_results = pd.read_json(result_file, lines=True)
        if model_results.empty or 'question_id' not in model_results.columns:
            print(f"No results found for {model_name}")
            continue
        model_results['question_id'] = model_results['question_id'].astype(int)

        # Merge using question_id
        df = pd.merge(QUESTIONS_DF, model_results, on='question_id', how='inner')
        if df.empty:
            # Row-wise apply on an empty frame does not yield a Series.
            print(f"No results found for {model_name}")
            continue
        df['Model'] = model_name

        # Apply Scoring
        parsed = df['llm_response'].apply(validate_and_parse_response)
        df['format_valid'] = parsed.apply(lambda x: x[0]).astype(bool)
        df['parsed_answer'] = parsed.apply(lambda x: x[1])
        # Force a float column: an all-None column would be object dtype and
        # break the mean aggregations below.
        df['parsed_confidence'] = pd.to_numeric(parsed.apply(lambda x: x[2]), errors='coerce')
        df['is_correct'] = df.apply(check_exactness, axis=1).astype(bool)

        all_data.append(df)

        # Stats
        model_summaries.append({
            "Model": model_name,
            "Accuracy": float(df['is_correct'].mean()),
            "Format_Adherence": float(df['format_valid'].mean()),
            "Brier (Calibration)": calculate_calibration(df),
            "AUC": calculate_auc(df)
        })

    if not all_data:
        print("No data found to report on.")
        return

    # CSV 1: DETAILED BY QUESTION
    big_df = pd.concat(all_data, ignore_index=True)
    cols = [
        'Model', 'question_id', 'Question', 'Type', 'Category',
        'Answer', 'parsed_answer', 'parsed_confidence',
        'is_correct', 'format_valid', 'llm_response'
    ]
    safe_cols = [c for c in cols if c in big_df.columns]
    big_df[safe_cols].to_csv("benchmark_1_detailed.csv", index=False)
    print("✅ Saved: benchmark_1_detailed.csv")

    # CSV 2: SUMMARY BY MODEL
    summary_df = pd.DataFrame(model_summaries)
    summary_df.to_csv("benchmark_2_by_model.csv", index=False)
    print("✅ Saved: benchmark_2_by_model.csv")

    # CSV 3: SUMMARY BY MODEL x TYPE (and Category)
    by_model_by_type = _aggregate_scores(big_df, ['Model', 'Type', 'Category'])
    by_model_by_type.to_csv("benchmark_3_by_model_by_type.csv", index=False)
    print("✅ Saved: benchmark_3_by_model_by_type.csv")

    # CSV 4: GLOBAL STATS BY TYPE (and Category) across all models
    global_stats = _aggregate_scores(big_df, ['Type', 'Category'])
    global_stats.to_csv("benchmark_4_global_stats.csv", index=False)
    print("✅ Saved: benchmark_4_global_stats.csv")



# ==========================================
# 5. MAIN EXECUTION FLOW
# ==========================================

if __name__ == "__main__":
    # Benchmark the models one after another (one server at a time, all on
    # the same port), then score everything that was collected.
    print(f"--- Environment: 2x A100 | Context: {CTX_SIZE} | Max Tokens: {MAX_TOKENS} ---")

    for label, gguf_path in MODELS:
        print(f"\n{'=' * 50}\nSTARTING: {label}\n{'=' * 50}")

        server = start_server(gguf_path, model_name=label)
        if server is None:
            print(f"Skipping {label} due to server failure.")
            continue

        try:
            run_benchmark(label)
        except KeyboardInterrupt:
            # Stop the server hard and abort the whole run.
            print("\nInterrupted by user.")
            server.kill()
            raise
        except Exception as e:
            # A failing model must not prevent the remaining ones from running.
            print(f"\nUnexpected error: {e}")
        finally:
            print(f"Shutting down {label}...")
            # Ask politely first; escalate to kill if it does not exit in 5 s.
            server.terminate()
            try:
                server.wait(timeout=5)
            except Exception:
                server.kill()
            # Short pause before the next server starts on the same port.
            time.sleep(3)

    generate_final_report()


Loaded context: 25 rows (4219 chars).
--- Environment: 2x A100 | Context: 8192 | Max Tokens: 4096 ---

STARTING: DeepSeek-R1-32B
**** Booting Server for: unsloth_DeepSeek-R1-Distill-Qwen-32B-GGUF_DeepSeek-R1-Distill-Qwen-32B-F16_DeepSeek-R1-Distill-Qwen-32B-F16-00001-of-00002.gguf ****
 Waiting for server... http-ok

 Ready!
🚀 Benchmarking DeepSeek-R1-32B (34 remaining)...
  - Q0 complete
  - Q1 complete
  - Q2 complete
  - Q3 complete
  - Q4 complete
  - Q5 complete
  - Q6 complete
  - Q7 complete
  - Q8 complete
  - Q9 complete
  - Q10 complete
  - Q11 complete
  - Q12 complete
  - Q13 complete
  - Q14 complete
  - Q15 complete
  - Q16 complete
  - Q17 complete
  - Q18 complete
  - Q19 complete
  - Q20 complete
  - Q21 complete
  - Q22 complete
  - Q23 complete
  - Q24 complete
  - Q25 complete
  - Q26 complete
  - Q27 complete
  - Q28 complete
  - Q29 complete
  - Q30 complete
  - Q31 complete
  - Q32 complete
  - Q33 complete
Shutting down DeepSeek-R1-32B...

STARTING: Qwen3-30B
**** Booting Server for: unsloth_Qwen3-30B-A3B-GGUF_BF16_Qwen3-30B-A3B-BF16-00001-of-00002.gguf ****
 Waiting for server... http-ok Ready!
🚀 Benchmarking Qwen3-30B (34 remaining)...
  - Q0 complete
  - Q1 complete
  - Q2 complete
  - Q3 complete
  - Q4 complete
  - Q5 complete
  - Q6 complete
  - Q7 complete
  - Q

In [None]:
import os
import re
import json
import time
import subprocess
import numpy as np
import pandas as pd
from openai import OpenAI

# ==========================================
# 0. CONFIG
# ==========================================

# Absolute path of the llama.cpp server binary that start_server() launches.
LLAMA_SERVER_PATH = "/scratch/ajb8866/llama_cpp_src/llama.cpp/build-cuda/bin/llama-server"
# Directory containing the GGUF model files referenced in MODELS.
CACHE_DIR = "/scratch/ajb8866/cache/llama.cpp"
# Address the server binds to and the OpenAI-compatible client connects to.
HOST = "127.0.0.1"
PORT = 8080
# Forwarded to llama-server as --n-gpu-layers / --ctx-size (kept as strings;
# start_server() wraps them in str() when building the command line).
GPU_LAYERS = "75"
CTX_SIZE = "8192"
# max_tokens value for chat completions.
MAX_TOKENS = 4096

# (friendly name, GGUF path) pairs. The friendly name is also used for the
# server stderr log file name.
MODELS = [
    ("DeepSeek-R1-32B", f"{CACHE_DIR}/unsloth_DeepSeek-R1-Distill-Qwen-32B-GGUF_DeepSeek-R1-Distill-Qwen-32B-F16_DeepSeek-R1-Distill-Qwen-32B-F16-00001-of-00002.gguf"),
    ("Qwen3-30B",       f"{CACHE_DIR}/unsloth_Qwen3-30B-A3B-GGUF_BF16_Qwen3-30B-A3B-BF16-00001-of-00002.gguf"),
    ("GPT-OSS-20B",     f"{CACHE_DIR}/ggml-org_gpt-oss-20b-GGUF_gpt-oss-20b-mxfp4.gguf")
]

# Input table whose first rows become the prompt context.
TSV_FILE = "a2_vs_d37.tsv"
# Benchmark questions with ground-truth answers (read into QUESTIONS_DF).
BENCHMARK_CSV = "a2_vs_d37_ground_truth.csv"
# Root directory for benchmark outputs such as server stderr logs.
OUTPUT_DIR = "a2_vs_d37_results_2"

# ==========================================
# 1.  SYSTEM PROMPT 
# ==========================================

# System prompt text: fixes the JSON reply schema ("answer" + "confidence"),
# the per-Type answer formats, and the discrete confidence levels the model
# is asked to use.
SYSTEM_PROMPT = """You are an expert bioinformatics assistant. Your task is to answer questions based on the provided Nanocompore TSV context (first 25 rows).

STRATEGY:
1. EXTRACT: If the answer is in the data, extract it directly (no guessing).
2. CALCULATE: If the answer requires math, compute it from the context values.
3. INFER: Only if necessary, use limited domain knowledge; reflect uncertainty in confidence.

CRITICAL OUTPUT INSTRUCTIONS:
1. Output ONLY valid JSON.
2. Schema (must match exactly):
   {
     "answer": "value",
     "confidence": 0.90
   }

3. TYPE-BASED ANSWER FORMATTING RULES (match the requested Type exactly):
   - GENOMIC_SITE: "ref_id:pos" (e.g., "NC_003796.1:893")
   - FLOAT: digits with optional decimal/scientific notation (e.g., "0.320240" or "1.2e-5")
   - INTEGER: digits only (e.g., "3")
   - BOOLEAN: "Yes"/"No" (or "True"/"False" if the question uses that wording)
   - STRING: short text label (avoid extra commentary)
   - BASE: a single nucleotide letter (e.g., "A", "C", "G", "T")

   - SITE_WITH_FLOAT: "ref_id:pos (metric=value)"
       Example: "NC_003796.1:893 (abs_GMM_LOR=1.078)"

   - SITE_WITH_KEYED_METRICS: "ref_id:pos (key=value; key=value)"
       Example: "NC_003796.1:893 (p=0.014041; GMM_LOR=1.078)"

   - KEYED_FLOAT_LIST: "key=value; key=value"
       Example: "sig=1.033500; nonsig=0.492000"

   - KEYED_MIXED_LIST: "key=value; key=value; key=value"
       Values may be int/float/string, but always keep "key=value".
       Example: "effect=0.978; coverage=-0.181; stronger=effect_size"

   - RANKING_WITH_SCORES: semicolon-separated ordered items:
       "site(value); site(value); ..."
       Example: "NC_003796.1:893(0.274799); NC_003796.1:870(0.258363)"

   - RUN_AND_TREND: "start-end; trend=VALUE"
       Example: "882-886; trend=mixed"

   - COUNT_AND_SITE_LIST: "count=N; sites=site1; site2; site3"
       Example: "count=3; sites=NC_003796.1:893; NC_003796.1:891; NC_003796.1:870"

4. CATEGORY SET (use these meanings when reasoning; do NOT output category labels):
   - Directionality & Change: effect direction/sign (e.g., GMM_LOR, positive/negative fractions, run trends)
   - Derived Metrics: computed from counts (e.g., mod fraction, deltas, rankings of derived values)
   - Model vs Counts Consistency: compare model outputs (GMM_LOR) to ratios derived from raw counts
   - Statistical Metrics & Significance: p-values/q-values, correlations with -log10(p), significance behavior
   - High-Confidence Predictions: thresholding rules, candidate lists, significant+large-effect filters
   - Site/Feature Annotation: ref_kmer, base/k-mer properties, motif-like summaries
   - Coverage & Depth: coverage/total reads, coverage imbalance, median-based coverage filters

5. CONFIDENCE:
Use ONLY one of these values: 1.00, 0.90, 0.75, 0.55, 0.30

- 1.00 = directly copied from a single unambiguous row in CONTEXT (unique match)
- 0.90 = simple calculation from context values with no ambiguity
- 0.75 = multi-step calculation or multiple rows but still clear
- 0.55 = requires inference/assumption or context is incomplete/ambiguous
- 0.30 = best guess; answer not fully supported by context
"""

# ==========================================
# 2. INFERENCE & SERVER UTILS
# ==========================================

def load_context_chunk(file_path, rows=25):
    """Return the first ``rows`` data rows of a delimited file as TSV text.

    The input is parsed as tab-separated when ``file_path`` ends with
    ``.tsv`` and as comma-separated otherwise. Only the header and the first
    ``rows`` records are parsed, so column dtypes are inferred from exactly
    the rows that end up in the context. Parsing the whole file first would
    let later rows (for example a missing value in an integer column) change
    how the leading rows are rendered, e.g. ``1`` becoming ``1.0``.

    Args:
        file_path: Path of the TSV/CSV file to read.
        rows: Maximum number of data rows to keep.

    Returns:
        The header plus the selected rows, tab-separated, without an index.

    Raises:
        Exception: Whatever pandas raises while reading; the error is printed
            and then re-raised unchanged.
    """
    try:
        sep = '\t' if file_path.endswith('.tsv') else ','
        # nrows stops parsing after the rows that are actually needed.
        chunk = pd.read_csv(file_path, sep=sep, nrows=rows)
        context_str = chunk.to_csv(index=False, sep='\t')
        print(f"Loaded context: {len(chunk)} rows ({len(context_str)} chars).")
        return context_str
    except Exception as e:
        print(f"Error loading context file: {e}")
        raise


def extract_model_id(client, fallback="default"):
    """
    Return the id of the first model the server lists, else ``fallback``.

    Any error while listing models, or a listing without a non-empty ``data``
    attribute, results in ``fallback``.
    """
    try:
        listing = client.models.list()
        entries = getattr(listing, "data", None)
        if entries:
            return entries[0].id
    except Exception:
        # Best effort: an unreachable or non-conforming server is not fatal.
        return fallback
    return fallback


def wait_for_chat_ready(client, model_id="default", timeout_s=1500):
    """
    Block until the server answers a one-token chat completion.

    This is the real readiness check: /v1/models may respond while the model
    is still loading, so a tiny completion is requested instead.

    Args:
        client: OpenAI-compatible client pointed at the server.
        model_id: Model name sent with the probe request.
        timeout_s: Maximum number of seconds to keep probing.

    Returns:
        True once a probe succeeds, False if ``timeout_s`` elapses first
        (the last transient error is printed in that case).

    Raises:
        Exception: Any probe error that is not a transient 503 /
            "Loading model" condition is re-raised immediately.
    """
    # Monotonic clock: the deadline is unaffected by wall-clock adjustments.
    deadline = time.monotonic() + timeout_s
    last_err = None
    while time.monotonic() < deadline:
        try:
            client.chat.completions.create(
                model=model_id,
                messages=[{"role": "user", "content": "ping"}],
                temperature=0.0,
                max_tokens=1,
            )
            return True
        except Exception as e:
            msg = str(e)
            transient = (
                getattr(e, "status_code", None) == 503
                or "Loading model" in msg
                or "503" in msg
            )
            if not transient:
                raise
            last_err = e
            # Never sleep past the deadline.
            time.sleep(max(0.0, min(2.0, deadline - time.monotonic())))
    print(f"\nChat readiness timed out. Last error: {last_err}")
    return False


def chat_with_retries(client, *, model, messages, temperature, max_tokens, max_attempts=25):
    """
    Request a chat completion, retrying transient 503 / 'Loading model' errors.

    The delay between attempts grows linearly (2 s per attempt, capped at
    20 s). No delay is spent after the final attempt.

    Args:
        client: OpenAI-compatible client.
        model, messages, temperature, max_tokens: Forwarded unchanged to
            ``client.chat.completions.create``.
        max_attempts: Total number of requests to try.

    Returns:
        The completion object returned by the client.

    Raises:
        RuntimeError: Every attempt failed with a transient error; chained to
            the last such error.
        Exception: Any non-transient error is re-raised immediately.
    """
    last_err = None
    for attempt in range(1, max_attempts + 1):
        try:
            return client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
            )
        except Exception as e:
            msg = str(e)
            transient = (
                getattr(e, "status_code", None) == 503
                or "Loading model" in msg
                or "503" in msg
            )
            if not transient:
                raise
            last_err = e
        # Back off only when another attempt will actually follow.
        if attempt < max_attempts:
            time.sleep(min(2 * attempt, 20))
    raise RuntimeError(
        f"Persistent 503/Loading model after retries. Last error: {last_err}"
    ) from last_err


# Load once at import time: the prompt context (first 25 rows of TSV_FILE,
# rendered as TSV text) and the benchmark questions table.
CONTEXT_TEXT = load_context_chunk(TSV_FILE, rows=25)
QUESTIONS_DF = pd.read_csv(BENCHMARK_CSV)

# Ensure question_id exists: when the CSV has no such column, number the
# questions 0..N-1 in file order.
if 'question_id' not in QUESTIONS_DF.columns:
    QUESTIONS_DF.insert(0, 'question_id', range(len(QUESTIONS_DF)))


def _stop_process(process):
    """Kill `process` and reap it so it does not linger as a zombie (best effort)."""
    try:
        process.kill()
        process.wait(timeout=10)
    except (OSError, subprocess.TimeoutExpired):
        pass


def _wait_until_ready(process, client):
    """Return True once the server answers HTTP and serves a chat completion."""
    # Phase 1: wait for the HTTP layer.
    http_ok = False
    for _ in range(180):
        if process.poll() is not None:
            break
        try:
            client.models.list()
        except Exception:
            time.sleep(1)
            print(".", end="", flush=True)
        else:
            print(" http-ok", end="", flush=True)
            http_ok = True
            break

    if process.poll() is not None:
        print("\nServer process exited early (check stderr log).")
        return False
    if not http_ok:
        # Never reached the HTTP endpoint: a chat probe could only fail too.
        print("\nServer timed out waiting for HTTP endpoint.")
        return False

    # Use the actual model id if the server exposes one.
    model_id = extract_model_id(client, fallback="default")

    # Phase 2: wait for chat readiness.
    if wait_for_chat_ready(client, model_id=model_id, timeout_s=900):
        print(" Ready!")
        return True

    print("\nServer timed out waiting for model readiness.")
    return False


def start_server(model_path, model_name=None):
    """
    Launch llama-server for `model_path` and wait until it can serve chat.

    Args:
        model_path: Path to the GGUF model file passed to the server via -m.
        model_name: Optional label. When given, the server's stderr is
            written to "<OUTPUT_DIR>/<model_name>_server.stderr.log";
            otherwise stderr is discarded.

    Returns:
        The running subprocess.Popen on success, or None when the model file
        is missing, the process cannot be launched, it exits early, or it
        does not become ready in time. On every failure after launch the
        child process is killed and reaped before returning.
    """
    if not os.path.exists(model_path):
        print(f"Critical Error: Model file not found at:\n{model_path}")
        return None

    print(f"**** Booting Server for: {os.path.basename(model_path)} ****")

    cmd = [
        LLAMA_SERVER_PATH, "-m", model_path,
        "--host", HOST, "--port", str(PORT),
        "--ctx-size", str(CTX_SIZE),
        "--n-gpu-layers", str(GPU_LAYERS),
        "--tensor-split", "50,50",
        "--parallel", "1",
    ]

    # Log stderr to a file for debugging load failures.
    log_fh = None
    if model_name:
        os.makedirs(OUTPUT_DIR, exist_ok=True)
        log_path = os.path.join(OUTPUT_DIR, f"{model_name}_server.stderr.log")
        log_fh = open(log_path, "wb")

    try:
        process = subprocess.Popen(
            cmd,
            stdout=subprocess.DEVNULL,
            # Never use an unread PIPE here: once the pipe buffer fills, the
            # server would block on its own logging.
            stderr=log_fh if log_fh is not None else subprocess.DEVNULL,
        )
    except Exception as e:
        print(f"Failed to launch process: {e}")
        return None
    finally:
        # The child holds its own copy of the descriptor, so the parent's
        # handle can be closed right away on both success and failure.
        if log_fh is not None:
            log_fh.close()

    print(" Waiting for server...", end="", flush=True)
    client = OpenAI(base_url=f"http://{HOST}:{PORT}/v1", api_key="x")

    ready = False
    try:
        ready = _wait_until_ready(process, client)
    except Exception as e:
        # e.g. a non-transient API error raised by the chat readiness probe.
        print(f"\nServer readiness check failed: {e}")
    finally:
        # Also runs on KeyboardInterrupt, so the server is never orphaned.
        if not ready:
            _stop_process(process)

    return process if ready else None


def run_benchmark(model_name):
    """
    Ask the running server every unanswered benchmark question.

    Results are appended, one JSON object per line, to
    "<OUTPUT_DIR>/<model_name>/results.jsonl". Question ids already present
    in that file are skipped, so an interrupted run can be resumed.

    NOTE(review): a failed API call is recorded with llm_response
    "ERROR_API" and therefore counts as done on the next run.
    """
    client = OpenAI(base_url=f"http://{HOST}:{PORT}/v1", api_key="na")
    model_id = extract_model_id(client, fallback="default")

    save_dir = os.path.join(OUTPUT_DIR, model_name)
    os.makedirs(save_dir, exist_ok=True)
    output_file = os.path.join(save_dir, "results.jsonl")

    done_ids = set()
    if os.path.exists(output_file):
        with open(output_file, 'r') as f:
            for line in f:
                try:
                    done_ids.add(int(json.loads(line)['question_id']))
                except (ValueError, KeyError, TypeError):
                    # Blank, truncated or foreign line (e.g. from an
                    # interrupted write): ignore it and re-ask if needed.
                    pass

    # Derive the work list from the questions themselves. Subtracting set
    # sizes is wrong whenever the results file holds ids that are not (or no
    # longer) part of the benchmark, and could skip pending questions.
    pending = [
        (int(q_id), question)
        for q_id, question in zip(QUESTIONS_DF['question_id'], QUESTIONS_DF['Question'])
        if int(q_id) not in done_ids
    ]

    remaining_count = len(pending)
    print(f"ðŸš€ Benchmarking {model_name} ({remaining_count} remaining)...")

    if remaining_count == 0:
        return

    with open(output_file, 'a') as f_out:
        for q_id, question in pending:
            prompt = f"CONTEXT DATA:\n{CONTEXT_TEXT}\n\nQUESTION:\n{question}"

            try:
                response = chat_with_retries(
                    client,
                    model=model_id,
                    messages=[
                        {"role": "system", "content": SYSTEM_PROMPT},
                        {"role": "user", "content": prompt}
                    ],
                    temperature=0.1,
                    max_tokens=MAX_TOKENS
                )
                raw_text = response.choices[0].message.content
            except Exception as e:
                print(f"Error Q{q_id}: {e}")
                raw_text = "ERROR_API"

            result = {
                "model": model_name,
                "question_id": q_id,
                "question": question,
                "llm_response": raw_text
            }
            # Flush after every record so a crash loses at most one answer.
            f_out.write(json.dumps(result) + "\n")
            f_out.flush()
            print(f"  - Q{q_id} complete")

# ==========================================
# 3. EVALUATION METRICS LOGIC
# ==========================================

def extract_first_json_object(s: str) -> str | None:
    # Prefer fenced JSON if present anywhere
    m = re.search(r"```(?:json)?\s*({.*?})\s*```", s, flags=re.DOTALL | re.IGNORECASE)
    if m:
        return m.group(1).strip()

    # Otherwise, brace-balance from first '{'
    start = s.find("{")
    if start == -1:
        return None

    depth = 0
    in_str = False
    esc = False
    for i in range(start, len(s)):
        ch = s[i]

        # track strings so braces inside strings don't count
        if in_str:
            if esc:
                esc = False
            elif ch == "\\":
                esc = True
            elif ch == '"':
                in_str = False
            continue
        else:
            if ch == '"':
                in_str = True
                continue

        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return s[start:i+1].strip()

    return None


def validate_and_parse_response(response_str):
    """
    Parse one LLM reply into a (format_valid, answer, confidence) tuple.

    The reply is first tried as pure JSON; if that fails, a JSON object is
    extracted from the surrounding text. The decoded value must be a dict
    with both "answer" and "confidence" keys.

    Returns:
        (True, answer, confidence) when confidence converts to a float in
        [0, 1]; (False, answer, None) when the schema matches but the
        confidence is unusable; (False, None, None) otherwise. `answer` is
        stringified unless it is None.
    """
    rejected = (False, None, None)

    if pd.isna(response_str):
        return rejected

    text = str(response_str).strip()

    try:
        # Fast path: the whole reply is already JSON.
        payload = json.loads(text)
    except json.JSONDecodeError:
        # Slow path: look for a JSON object inside mixed text.
        embedded = extract_first_json_object(text)
        if embedded is None:
            return rejected
        try:
            payload = json.loads(embedded)
        except json.JSONDecodeError:
            return rejected

    # Schema check: a dict carrying both required keys.
    if not isinstance(payload, dict):
        return rejected
    if "answer" not in payload or "confidence" not in payload:
        return rejected

    # Normalise the answer to a string to simplify later comparisons.
    answer_value = payload["answer"]
    answer_text = str(answer_value) if answer_value is not None else None

    # Confidence must be a real number in [0, 1].
    try:
        confidence = float(payload["confidence"])
    except (ValueError, TypeError):
        return False, answer_text, None

    if not (0.0 <= confidence <= 1.0):
        return False, answer_text, None
    return True, answer_text, confidence


def check_exactness(row):
    """
    Decide whether a parsed model answer matches the ground truth.

    `row` must provide 'parsed_answer', 'Answer' and 'Type'. The comparison
    rule depends on 'Type':

    - FLOAT / KEYED_FLOAT_LIST: numeric match within 0.1% of the ground-truth
      value, with no absolute tolerance.
    - INTEGER: the prediction must be a whole number equal to the truth.
    - BOOLEAN: both sides are mapped onto true/false vocabularies; a
      prediction outside both vocabularies is wrong.
    - BASE: case-insensitive match.
    - COUNT_AND_SITE_LIST: same count and same sites, in any order.
    - RANKING_WITH_SCORES: same ';'-separated entries in the same order,
      ignoring spaces and case.
    - every other type: case- and space-insensitive string equality.

    Returns False for a missing answer and whenever parsing either side
    fails.
    """
    if row['parsed_answer'] is None or pd.isna(row['parsed_answer']):
        return False

    gt = str(row['Answer']).strip()
    pred = str(row['parsed_answer']).strip()
    dtype = str(row['Type']).strip()

    gt_l = gt.lower().replace(" ", "")
    pred_l = pred.lower().replace(" ", "")

    def floats_match(expected, actual):
        # np.isclose(a, b) measures rtol against b, so the ground truth goes
        # second. atol must be 0: the default of 1e-8 would make every pair
        # of values below ~1e-8 (typical p-values) compare equal.
        return bool(np.isclose(actual, expected, rtol=1e-3, atol=0.0))

    try:
        if dtype == 'FLOAT':
            return floats_match(float(gt), float(pred))

        if dtype == 'INTEGER':
            gt_num = float(gt)
            pred_num = float(pred)
            # "3.0" is accepted, but "3.9" must not be truncated into 3.
            return pred_num.is_integer() and int(gt_num) == int(pred_num)

        if dtype == 'BOOLEAN':
            true_vals = {'true', 'yes', '1', 't', 'correct'}
            false_vals = {'false', 'no', '0', 'f', 'incorrect'}
            if pred_l not in true_vals and pred_l not in false_vals:
                return False
            return (gt_l in true_vals) == (pred_l in true_vals)

        if dtype == 'BASE':
            return gt.upper() == pred.upper()

        if dtype == 'KEYED_FLOAT_LIST':
            # Compare as unordered key -> float maps ("k1=v1; k2=v2").
            def parse_keyed_floats(s):
                out = {}
                for part in s.split(';'):
                    part = part.strip()
                    if not part:
                        continue
                    k, v = part.split('=', 1)
                    out[k.strip()] = float(v.strip())
                return out

            expected = parse_keyed_floats(gt)
            actual = parse_keyed_floats(pred)
            if expected.keys() != actual.keys():
                return False
            return all(floats_match(expected[k], actual[k]) for k in expected)

        if dtype == 'COUNT_AND_SITE_LIST':
            # Normalise list ordering after "sites=".
            def parse_count_sites(s):
                parts = [p.strip() for p in s.split(';') if p.strip()]
                first = parts[0] if parts else ""
                m = re.search(r"count\s*=\s*(\d+)", first, flags=re.I)
                count = int(m.group(1)) if m else None

                sites = []
                m2 = re.search(r"sites\s*=\s*(.*)", first, flags=re.I)
                if m2 and m2.group(1).strip():
                    sites.append(m2.group(1).strip())
                sites.extend(parts[1:])
                sites = [x.strip() for x in sites if x.strip()]
                return count, sorted(sites)

            return parse_count_sites(gt) == parse_count_sites(pred)

        if dtype == 'RANKING_WITH_SCORES':
            # Order matters; only whitespace and case are normalised.
            def norm_rank(s):
                return ";".join(p.strip().replace(" ", "") for p in s.split(";") if p.strip())

            return norm_rank(gt).lower() == norm_rank(pred).lower()

        # GENOMIC_SITE, STRING, RUN_AND_TREND, KEYED_MIXED_LIST,
        # SITE_WITH_FLOAT, SITE_WITH_KEYED_METRICS and any unknown type:
        # compare the normalised strings.
        return gt_l == pred_l

    except Exception:
        # Any parse failure on either side counts as a wrong answer.
        return False


def calculate_calibration(df):
    """
    Brier score: mean squared error between confidence and correctness.

    Only rows with a non-missing 'parsed_confidence' are scored; confidences
    are clipped to [0, 1] and 'is_correct' is used as the 0/1 outcome.
    Lower is better; 0.0 is perfect.

    Returns NaN when no row has a usable confidence. The score is undefined
    in that case, and returning 0.0 would present a model that never
    produced a confidence as perfectly calibrated.
    """
    valid_df = df[df['parsed_confidence'].notna()]
    if valid_df.empty:
        return float('nan')

    y = valid_df['is_correct'].astype(int).to_numpy()
    p = valid_df['parsed_confidence'].astype(float).to_numpy()
    p = np.clip(p, 0.0, 1.0)

    return float(np.mean((p - y) ** 2))


def calculate_auc(df):
    """
    ROC AUC of 'parsed_confidence' used as a score for 'is_correct'.

    Computed from average ranks (Mann-Whitney U statistic), so tied
    confidences are handled and no sklearn dependency is needed. Rows with a
    missing confidence are ignored.

    Returns NaN when fewer than two rows remain or only one class is present.
    """
    scored = df[df['parsed_confidence'].notna()]

    # AUC is undefined unless both classes are represented.
    if len(scored) < 2 or scored['is_correct'].nunique() < 2:
        return np.nan

    labels = scored['is_correct'].astype(int).to_numpy()
    scores = scored['parsed_confidence'].astype(float).to_numpy()

    positives = int(labels.sum())
    negatives = int(len(labels) - positives)
    if positives == 0 or negatives == 0:
        return np.nan

    # U = (rank sum of positives) - n_pos * (n_pos + 1) / 2; AUC = U / (n_pos * n_neg).
    avg_ranks = pd.Series(scores).rank(method="average").to_numpy()
    positive_rank_total = avg_ranks[labels == 1].sum()
    u_stat = positive_rank_total - positives * (positives + 1) / 2
    return float(u_stat / (positives * negatives))

# ==========================================
# 4. FINAL REPORT GENERATION (THE 4 CSVs)
# ==========================================

def generate_final_report():
    """
    Score every model's results.jsonl and write the four report CSVs.

    For each entry in MODELS the stored responses are joined to QUESTIONS_DF
    on 'question_id', parsed and graded. Output files, written to the
    current working directory:

    - benchmark_1_detailed.csv: one row per model and question
    - benchmark_2_by_model.csv: accuracy, format adherence, Brier, AUC
    - benchmark_3_by_model_by_type.csv: metrics per model and question type
    - benchmark_4_global_stats.csv: metrics per question type, all models

    Models without usable results are reported and skipped. Nothing is
    written when no model has results.
    """
    print("\n" + "=" * 60)
    print("ðŸ“Š GENERATING FINAL REPORTS")
    print("=" * 60)

    all_data = []
    model_summaries = []

    for model_name, _ in MODELS:
        result_file = os.path.join(OUTPUT_DIR, model_name, "results.jsonl")
        # run_benchmark creates the file before the first answer arrives, so
        # an interrupted run can leave it empty.
        if not os.path.exists(result_file) or os.path.getsize(result_file) == 0:
            print(f"No results found for {model_name}")
            continue

        model_results = pd.read_json(result_file, lines=True)
        if model_results.empty or 'question_id' not in model_results.columns:
            print(f"No results found for {model_name}")
            continue
        model_results['question_id'] = model_results['question_id'].astype(int)

        # Merge using question_id.
        df = pd.merge(QUESTIONS_DF, model_results, on='question_id', how='inner')
        if df.empty:
            # A row-wise apply on an empty frame does not yield a Series, so
            # scoring would fail; there is nothing to score anyway.
            print(f"No usable results for {model_name}")
            continue
        df['Model'] = model_name

        # Apply scoring.
        parsed = df['llm_response'].apply(validate_and_parse_response)
        df['format_valid'] = parsed.apply(lambda x: x[0]).astype(bool)
        df['parsed_answer'] = parsed.apply(lambda x: x[1])
        # Force a float column: when every confidence is missing the column
        # would otherwise stay object-typed and break the mean aggregations.
        df['parsed_confidence'] = parsed.apply(lambda x: x[2]).astype(float)
        df['is_correct'] = df.apply(check_exactness, axis=1).astype(bool)

        all_data.append(df)

        # Per-model statistics.
        model_summaries.append({
            "Model": model_name,
            "Accuracy": float(df['is_correct'].mean()),
            "Format_Adherence": float(df['format_valid'].mean()),
            "Brier (Calibration)": calculate_calibration(df),
            "AUC": calculate_auc(df)
        })

    if not all_data:
        print("No data found to report on.")
        return

    # CSV 1: DETAILED BY QUESTION
    big_df = pd.concat(all_data, ignore_index=True)
    cols = [
        'Model', 'question_id', 'Question', 'Type', 'Category',
        'Answer', 'parsed_answer', 'parsed_confidence',
        'is_correct', 'format_valid', 'llm_response'
    ]
    safe_cols = [c for c in cols if c in big_df.columns]
    big_df[safe_cols].to_csv("benchmark_1_detailed.csv", index=False)
    print("âœ… Saved: benchmark_1_detailed.csv")

    # CSV 2: SUMMARY BY MODEL
    summary_df = pd.DataFrame(model_summaries)
    summary_df.to_csv("benchmark_2_by_model.csv", index=False)
    print("âœ… Saved: benchmark_2_by_model.csv")

    # 'Category' is optional, mirroring safe_cols above. dropna=False keeps
    # questions whose Type/Category is blank; by default groupby would drop
    # them from the summaries without notice.
    type_keys = ['Type'] + (['Category'] if 'Category' in big_df.columns else [])
    metrics = {
        'N': ('question_id', 'count'),
        'Accuracy': ('is_correct', 'mean'),
        'Format_Adherence': ('format_valid', 'mean'),
        'Mean_Confidence': ('parsed_confidence', 'mean'),
    }

    # CSV 3: SUMMARY BY MODEL x TYPE (and Category)
    by_model_by_type = (
        big_df
        .groupby(['Model'] + type_keys, as_index=False, dropna=False)
        .agg(**metrics)
    )
    by_model_by_type.to_csv("benchmark_3_by_model_by_type.csv", index=False)
    print("âœ… Saved: benchmark_3_by_model_by_type.csv")

    # CSV 4: GLOBAL STATS BY TYPE (and Category) across all models
    global_stats = (
        big_df
        .groupby(type_keys, as_index=False, dropna=False)
        .agg(**metrics)
    )
    global_stats.to_csv("benchmark_4_global_stats.csv", index=False)
    print("âœ… Saved: benchmark_4_global_stats.csv")



# ==========================================
# 5. MAIN EXECUTION FLOW
# ==========================================

if __name__ == "__main__":
    # Benchmark the models one after another, each behind its own server
    # instance on the same host/port, then score all stored results.
    print(f"--- Environment: 2x A100 | Context: {CTX_SIZE} | Max Tokens: {MAX_TOKENS} ---")

    for friendly_name, model_path in MODELS:
        print(f"\n{'=' * 50}\nSTARTING: {friendly_name}\n{'=' * 50}")

        server_proc = start_server(model_path, model_name=friendly_name)
        if server_proc is None:
            print(f"Skipping {friendly_name} due to server failure.")
            continue

        try:
            run_benchmark(friendly_name)
        except KeyboardInterrupt:
            # Abort the whole run; the finally block still stops the server.
            print("\nInterrupted by user.")
            raise
        except Exception as e:
            # One failing model must not prevent the remaining ones.
            print(f"\nUnexpected error: {e}")
        finally:
            print(f"Shutting down {friendly_name}...")
            server_proc.terminate()
            try:
                server_proc.wait(timeout=5)
            except subprocess.TimeoutExpired:
                server_proc.kill()
                # Reap the killed server before moving on: the next model
                # reuses the same port and GPUs, and must not be probed
                # while the old process is still alive.
                try:
                    server_proc.wait(timeout=30)
                except subprocess.TimeoutExpired:
                    print(f"Warning: {friendly_name} server did not exit after kill.")
            time.sleep(3)

    generate_final_report()
