In [22]:
# -*- coding: utf-8 -*-
import os, re, json
import numpy as np
import pandas as pd
from typing import Dict, Any, Optional, Tuple, List
from sklearn.metrics import (
    classification_report, confusion_matrix, f1_score, accuracy_score,
    precision_recall_fscore_support,
)

# ------------------------------
# GLOBAL CONFIG (Monte Carlo sampling)
# ------------------------------
# Dataset location: the MAINT_DATASET environment variable takes precedence
# over the built-in default path.
DEFAULT_CSV = "/Users/stevefeng/Downloads/GlueWork-Tracker-branch-mentoring/glue_work_bot/training_data/maintainance_training_dataset.csv"
CSV_PATH = os.getenv("MAINT_DATASET", DEFAULT_CSV)

# Candidate column names, tried in the order listed when auto-detecting.
PREFERRED_TEXT_COLS: List[str] = [
    "comments",
    "comment",
    "text",
    "body",
    "message",
    "content",
]
PREFERRED_LABEL_COLS: List[str] = [
    "label_norm",
    "label",
    "y",
    "target",
]

# Monte Carlo settings (each one can be overridden through the environment).
N_ROUNDS = int(os.getenv("N_ROUNDS", "5"))
SAMPLE_SIZE = int(os.getenv("SAMPLE_SIZE", "50"))
RANDOM_STATE = int(os.getenv("RANDOM_STATE", "42"))

# Optional rule switches.
USE_RULE_SHORT_TEXT = True   # classify ultra-short comments as -1
USE_RULE_REGEX = False       # light positive regex for classic maintenance cues

# Chat model used for classification.
OPENAI_MODEL = os.getenv("OPENAI_MODEL", "gpt-4o-mini")

# ------------------------------
# 1) Helpers
# ------------------------------
def normalize_text_series(s: pd.Series) -> pd.Series:
    """Return *s* as clean single-line strings.

    Missing values (``None`` / ``NaN``) become the empty string, every run of
    whitespace is collapsed to a single space, and leading/trailing whitespace
    is removed.
    """
    # Blank out missing values BEFORE the string cast: casting first turns
    # them into the literal text "nan" / "None", which fillna() can no longer
    # detect. Going through object dtype keeps the replacement dtype-safe for
    # numeric input as well.
    filled = s.astype(object).where(s.notna(), "").astype(str)
    return filled.str.replace(r"\s+", " ", regex=True).str.strip()

def autodetect_columns(df: pd.DataFrame) -> Tuple[str, str]:
    """Choose the text and label columns of *df*.

    The text column is the first entry of ``PREFERRED_TEXT_COLS`` present in
    the frame, falling back to the frame's first column. The label column is
    the first entry of ``PREFERRED_LABEL_COLS`` present in the frame.

    Raises:
        ValueError: if none of the preferred label columns exists.
    """
    available = df.columns

    text_hits = [name for name in PREFERRED_TEXT_COLS if name in available]
    # `or` (rather than an emptiness test) mirrors the falsy-name fallback.
    text_col = (text_hits[0] if text_hits else None) or available[0]

    label_hits = [name for name in PREFERRED_LABEL_COLS if name in available]
    if not label_hits:
        raise ValueError("No label column found. Expected one of: " + ", ".join(PREFERRED_LABEL_COLS))
    return text_col, label_hits[0]

def strip_quoted_lines(text: str) -> str:
    """Remove quoted lines (first non-blank character is ``>``) from *text*.

    The remaining lines are re-joined with newlines, the result is stripped,
    and any whitespace run that ends in a newline is reduced to one newline.
    """
    kept = [line for line in str(text).splitlines() if not line.lstrip().startswith(">")]
    joined = "\n".join(kept).strip()
    return re.sub(r"\s+\n", "\n", joined)

# --- GitHub template & checklist stripper (for PR bodies) ---
# Regexes for boilerplate found in GitHub PR bodies. strip_templates() applies
# all of them with IGNORECASE | MULTILINE | DOTALL, so:
#   * the multi-line patterns (checklist section, HTML comments) rely on DOTALL
#     and use lazy quantifiers;
#   * the single-line patterns use `[^\n]*` instead of `.*` -- under DOTALL a
#     greedy `.*$` would run to the end of the whole text and delete every
#     line after the match.
GITHUB_TEMPLATE_SECTIONS = [
    r"##\s*Pre-?launch Checklist.*?(?=^##|\Z)",
    r"<!--.*?-->",
    r"^\s*Thanks for filing a pull request![^\n]*$",
    r"^\s*If you need help, consider asking[^\n]*$",
    r"^\s*\[/?(Contributor Guide|Tree Hygiene|Flutter Style Guide|CLA|tests|breaking change policy|Discord|Data Driven Fixes)[^\n]*$",
]
def strip_templates(text: str) -> str:
    """Strip Markdown noise and GitHub PR-template boilerplate from *text*.

    Removes fenced code blocks, images, inline links, ``- [ ]`` checklist
    items and every pattern in ``GITHUB_TEMPLATE_SECTIONS``, then collapses
    all whitespace to single spaces.

    Returns:
        The cleaned text on one line (possibly empty).
    """
    t = str(text)
    # Remove fenced code and media/links first
    t = re.sub(r"`{3}[\s\S]*?`{3}", " ", t)                         # ``` code fences
    t = re.sub(r"!\[[^\]]*\]\([^)]+\)", " ", t)                     # images
    t = re.sub(r"\[[^\]]*\]\([^)]+\)", " ", t)                      # links
    t = re.sub(r"^-\s*\[[ xX]\]\s*.*$", " ", t, flags=re.MULTILINE) # checklist items
    # Remove common GH template sections
    for pat in GITHUB_TEMPLATE_SECTIONS:
        t = re.sub(pat, " ", t, flags=re.IGNORECASE | re.MULTILINE | re.DOTALL)
    # Collapse whitespace
    return re.sub(r"\s+", " ", t).strip()

# ------------------------------
# 2) Rules (optional)
# ------------------------------
def rule_based_short_text(comment: str) -> Optional[int]:
    """Return -1 when *comment* has fewer than three word tokens, else None.

    ``None`` means the rule does not apply and the caller should keep going.
    """
    tokens = re.findall(r"\b\w+\b", str(comment).strip())
    return -1 if len(tokens) < 3 else None

# Maintenance cue words. The leading \b keeps cues from matching in the middle
# of unrelated words ("prefix", "latest"); the short cues "pin" and "ci" are
# additionally closed with \b because they occur inside many ordinary words
# ("typing", "circle"). Inflected forms such as "fixing" or "reverted" still
# match through their stem.
MAINT_PATTERNS = re.compile(
    r"\b(fix(es|ed)?|resolves?|closes? #\d+|regression|deflake|flaky|timeout|"
    r"backport|revert|migrat(e|ion)|deprecat(ed|ion)|refactor|remove dead code|"
    r"cve-\d{4}-\d+|security|bump|upgrade|pin(s|ned|ning)?\b|ci\b|build|tests?)",
    re.IGNORECASE
)
# Second, independent technical token (version number, issue reference, CI,
# test, build) required alongside a cue word.
_TECH_TOKEN = re.compile(
    r"\b(\d+\.\d+(\.\d+)?)\b|#\d+|\bCI\b|\btest(s)?\b|\bbuild\b",
    re.IGNORECASE,
)
def rule_based_regex(comment: str) -> Optional[int]:
    """Return 0 (maintenance) when *comment* matches the cue regexes, else None.

    A hit needs both a maintenance cue word (``MAINT_PATTERNS``) and a second
    technical token, which reduces false positives. ``None`` means the rule
    does not apply.
    """
    txt = str(comment)
    if MAINT_PATTERNS.search(txt) and _TECH_TOKEN.search(txt):
        return 0
    return None

# ------------------------------
# 3) LLM client (OpenAI)
# ------------------------------
# The OpenAI SDK is required. Only a genuine import failure is translated into
# the install hint; any other error raised while importing is left to surface
# unchanged instead of being reported as "package missing".
try:
    from openai import OpenAI
except ImportError as e:
    raise RuntimeError("openai package missing. Install with `pip install openai`") from e

def get_client() -> "OpenAI":
    """Build an OpenAI client from the ``OPENAI_API_KEY`` environment variable.

    Raises:
        ValueError: if the variable is unset, empty, or only whitespace.
    """
    api_key = (os.getenv("OPENAI_API_KEY") or "").strip()
    if not api_key:
        raise ValueError("❌ OPENAI_API_KEY not set. Export it before running.")
    return OpenAI(api_key=api_key)

def to_label(resp: str) -> int:
    """Turn a raw model reply into a label: 0 (maintenance) or -1 (not).

    Resolution order:
      1. the first integer in the reply, if it is 0 or -1 (en/em dashes are
         read as minus signs);
      2. the phrase "not maintenance" -> -1, then "maintenance" -> 0;
      3. a standalone ``0`` anywhere in the reply -> 0;
      4. otherwise -1.
    """
    text = (resp or "").strip()
    for dash in ("\u2013", "\u2014"):
        text = text.replace(dash, "-")

    first_int = re.search(r"-?\d+", text)
    if first_int is not None:
        try:
            value = int(first_int.group(0))
        except ValueError:
            value = None
        if value in (0, -1):
            return value

    # Textual fallbacks; the negated phrase must be tested first because it
    # contains the positive one.
    lowered = text.lower()
    if "not maintenance" in lowered:
        return -1
    if "maintenance" in lowered:
        return 0
    return 0 if re.search(r"\b0\b", text) else -1

# System prompt sent with every classification request. It defines the two
# labels (0 = Maintenance, -1 = Not Maintenance), an ordered decision
# checklist, and tie-break rules. The typographic characters (em dash, curly
# quotes, ellipsis, arrow, bullet) are stored as real Unicode so the model does
# not receive mis-encoded byte sequences.
SYSTEM_MSG = """You are a GitHub review-classifier.

Label each comment/PR text as:
- Maintenance (0) — it repairs or sustains the project: bug/regression fixes, parity with platforms, refactors/cleanup/moves, backports/reverts, dependency/security updates, CI/test deflakes, migrations/deprecations, documentation corrections that prevent errors, or changes that unblock broken/mismatched behavior. Count it even if terse.
- Not Maintenance (-1) — new features/APIs with no upkeep/repair motivation; pure social/process chatter; bot messages.

DECISION CHECKLIST (apply in order):
A. Does the text claim or imply REPAIR/STABILITY/UNBLOCKING? (e.g., “Fix/Fixes/Fixed/Resolves/Closes #…”, “regression/crash/timeout/deflake/flaky”, “unblocks”, “aligns/matches/parity with native”, “prevent misuse”, “correct behavior”, “security/CVE”). → Maintenance (0).
B. Is it a MIGRATION/REFILE/REFOLLOW-UP/RELOCATION/REFACTOR/CLEANUP even without long rationale? (e.g., “follow up of #…”, “refiling of …”, “migrate from X to Y”, “move A to B to reduce conflicts”, “remove dead code/duplication”). → Maintenance (0).
C. Does it reference an issue/PR and say it **handles/addresses/unblocks/fixes** that problem, even if it starts with “Adds/Introduces/Improves”? → Maintenance (0).
D. Dependency/version updates, security patches, or CI/test stabilization (timeouts, flakes) → Maintenance (0).
E. Purely **new capability** (new API/constructor/parameter/widget) with no repair/parity/stability/security/cleanup/migration intent → Not Maintenance (-1).
F. Process/scheduling/acknowledgements/bots → Not Maintenance (-1).

Ambiguity resolution:
• If both “Adds/Introduces …” and clear **repair/unblock/parity/migration/cleanup** intent are present, choose Maintenance (0).
• “Related to #…” alone is not enough; combine with “handles/addresses/unblocks/fixes/migrates/cleans up”.
• Documentation counts as Maintenance (0) only when it corrects wrong behavior or prevents errors; otherwise -1.

Output exactly one number: 0 or -1.

"""

# Curated few-shot examples as (text, label) pairs; 0 = maintenance.
POS_CUES = [
    # Feature-looking openers that are actually repair/unblock/parity
    ("Adds glob syntax to proxy server to resolve mismatch with rules and unblock issue #173435.", 0),
    ("Adds headers to proxy rules to align behavior across platforms and fix incorrect proxying in #173434.", 0),
    ("This change improves overscroll to match native Android behavior; fixes clipped fling behavior (Fixes #169659).", 0),

    # Follow-up / refiling / migration / move-as-cleanup
    ("Follow up of #174421: migrate some files to WidgetState to reduce conflicts; remaining files in later PRs.", 0),
    ("Refiling of #169273: bundle experimental data assets to restore expected tool behavior and unblock usage.", 0),
    ("Move PageTransitionsBuilder from material/ to widget/ to keep types in the correct layer.", 0),

    # Classic fixes & CI/test stability
    ("Fix DropdownMenu filtering by storing selected value instead of index; add a regression test.", 0),
    ("Deflake GPU tests by removing real-time sleeps; use a virtual clock to prevent timeouts.", 0),
]

# Counter-examples as (text, label) pairs; -1 = not maintenance.
NEAR_MISS_NEGS = [
    # Pure feature without upkeep intent
    ("Introduces ReorderableListView.separated constructor (new API).", -1),
    ("Adds weekType parameter to CupertinoDatePicker to control selectable days (feature).", -1),
    ("Widget previewer filters previews by active editor location; includes UI changes (feature).", -1),

    # Process/social only
    ("Thanks! I'll merge after CI.", -1),
    ("Please rebase on main.", -1),

    # Issue mention without maintenance intent
    ("Related to #173838", -1),
]



def fewshot_block(examples):
    """Render (text, label) pairs as blank-line-separated ``Comment:/Label:`` stanzas."""
    return "\n\n".join(f"Comment: {t}\nLabel: {y}" for t, y in examples)

FEWSHOT_CURATED = fewshot_block(POS_CUES + NEAR_MISS_NEGS)

def fs_strict(comment: str) -> str:
    """Build the user prompt: optional preamble, few-shot examples, then *comment*.

    The original template interpolated a module-level ``BASE`` string that
    this file never defines, so every call raised ``NameError`` unless the
    name happened to exist already (e.g. left over in a notebook kernel).
    ``BASE`` is now optional: when it exists the prompt is identical to the
    original template; when it does not, the preamble is simply omitted.
    """
    preamble = globals().get("BASE")
    head = "" if preamble is None else f"{preamble}\n\n"
    return f"""{head}Few-shot examples:
{FEWSHOT_CURATED}

Comment to classify:
{comment}
"""

def classify_once(client, prompt: str) -> Tuple[int, str]:
    """Send one classification request and return ``(label, raw_reply)``.

    The request carries ``SYSTEM_MSG`` as the system message and *prompt* as
    the user message. An empty reply yields ``(-1, "")``; any other reply is
    parsed with ``to_label``.
    """
    conversation = [
        {"role": "system", "content": SYSTEM_MSG},
        {"role": "user", "content": prompt},
    ]
    completion = client.chat.completions.create(
        model=OPENAI_MODEL,
        messages=conversation,
        max_completion_tokens=10,
    )
    reply = (completion.choices[0].message.content or "").strip()
    if reply:
        return to_label(reply), reply
    return -1, ""

# ------------------------------
# 4) Evaluation
# ------------------------------
def _predict_with_rules(client, clean: str, prompt: str) -> Tuple[int, str]:
    """Try each enabled rule in turn; fall back to the LLM when none fires."""
    if USE_RULE_SHORT_TEXT:
        verdict = rule_based_short_text(clean)
        if verdict is not None:
            return verdict, "[RULE: short-text <3 words]"
    if USE_RULE_REGEX:
        verdict = rule_based_regex(clean)
        if verdict is not None:
            return verdict, "[RULE: maint-regex]"
    return classify_once(client, prompt)


def eval_on_split(name: str, X_split: pd.Series, y_split: pd.Series, use_rules: bool=True) -> Dict[str, Any]:
    """Classify every text in *X_split* and score the result against *y_split*.

    Each text is cleaned (quoted lines and GitHub template noise removed),
    wrapped in the few-shot prompt and classified. One line per instance is
    printed as the loop runs.

    Args:
        name: Label for this evaluation run; echoed in the output.
        X_split: Texts to classify.
        y_split: Gold labels (0 / -1), aligned with *X_split*.
        use_rules: When True, the rules enabled through ``USE_RULE_SHORT_TEXT``
            and ``USE_RULE_REGEX`` are tried before the LLM. When False every
            instance goes to the LLM.

    Returns:
        A dict with ``name``, ``accuracy``, ``f1_macro``, ``f1_weighted``,
        ``precision_macro``, ``recall_macro``, ``report`` (text), ``cm``
        (2x2, label order [-1, 0]), ``preds`` (array) and ``detailed_df``
        (one row per instance).
    """
    client = get_client()
    preds, detailed_rows = [], []

    print(f"\n### Per-instance outputs for: {name} | N={len(X_split)}")
    for i, (txt, gold) in enumerate(zip(X_split, y_split)):
        # Clean text (strip quotes + GH template noise)
        clean = strip_templates(strip_quoted_lines(txt))

        # The prompt is built even when a rule decides, so the preview column
        # is populated for every row.
        prompt = fs_strict(clean)
        prompt_preview = prompt[:240].replace("\n", " ") + ("..." if len(prompt) > 240 else "")

        if use_rules:
            pred, raw = _predict_with_rules(client, clean, prompt)
        else:
            pred, raw = classify_once(client, prompt)

        gold_i, pred_i = int(gold), int(pred)
        correct = pred_i == gold_i

        preds.append(pred)
        detailed_rows.append({
            "index": i,
            "comment": txt,
            "cleaned": clean,
            "gold_label": gold_i,
            "pred_label": pred_i,
            "correct": correct,
            "raw_model_output": raw,
            "prompt_preview": prompt_preview,
        })
        print(f"[{i}] GOLD={gold_i} | PRED={pred_i} | {'✅' if correct else '❌'}")

    preds = np.array(preds)
    y_true = y_split.values

    # Macro precision/recall come from one call instead of two identical ones.
    precision, recall, _, _ = precision_recall_fscore_support(
        y_true, preds, average="macro", zero_division=0
    )
    return {
        "name": name,
        "accuracy": float(accuracy_score(y_true, preds)),
        "f1_macro": float(f1_score(y_true, preds, average="macro", zero_division=0)),
        "f1_weighted": float(f1_score(y_true, preds, average="weighted", zero_division=0)),
        "precision_macro": float(precision),
        "recall_macro": float(recall),
        "report": classification_report(y_true, preds, digits=3, zero_division=0),
        # Fixed label order keeps the matrix 2x2 even when a sample (or the
        # predictions) contain only one class.
        "cm": confusion_matrix(y_true, preds, labels=[-1, 0]),
        "preds": preds,
        "detailed_df": pd.DataFrame(detailed_rows),
    }

# --- robust y/n → 0/-1 normalization
# Accepted spellings of the two classes (compared after strip + lower-case).
# 0 = maintenance, -1 = not maintenance.
_LABEL_MAP = {
    "y": 0, "yes": 0, "true": 0, "t": 0, "1": 0, "maint": 0, "maintenance": 0, "0": 0,
    "n": -1, "no": -1, "false": -1, "f": -1, "-1": -1, "not maintenance": -1,
}


def normalize_label_value(v: object) -> int:
    """Map a raw label value onto 0 (maintenance) or -1 (not maintenance).

    Accepts the spellings in ``_LABEL_MAP`` and numeric forms of them, so
    ``1.0`` is treated like ``"1"`` and ``-1.0`` like ``"-1"``.

    Raises:
        ValueError: for anything else, including non-integral numbers such as
            ``0.7`` (which must not be truncated to 0) and NaN.
    """
    s = str(v).strip().lower()
    if s in _LABEL_MAP:
        return _LABEL_MAP[s]
    try:
        number = float(s)
    except ValueError:
        number = None
    # Only whole numbers are labels; is_integer() is False for NaN and +/-inf.
    if number is not None and number.is_integer():
        key = str(int(number))
        if key in _LABEL_MAP:
            return _LABEL_MAP[key]
    raise ValueError(f"Unrecognized label value: {v!r}. Expected y/n or 0/-1.")

# ------------------------------
# 5) Main (Monte Carlo: N_ROUNDS rounds × SAMPLE_SIZE samples; defaults 5 × 50)
# ------------------------------
def _mean_std(series: pd.Series) -> Tuple[float, float]:
    """Return (mean, sample std) of *series*; the std is 0.0 for a single value."""
    mu = float(series.mean())
    sd = float(series.std(ddof=1)) if len(series) > 1 else 0.0
    return mu, sd


def _truncate(s, n: int = 160) -> str:
    """Shorten ``str(s)`` to *n* characters, appending an ellipsis if cut."""
    s = str(s)
    return (s[:n] + "…") if len(s) > n else s


def _report_misclassified(all_df: pd.DataFrame, n_rounds: int) -> None:
    """Print a summary of the wrong predictions and export them as CSV and JSONL."""
    mis = all_df[~all_df["correct"]].copy()
    print("\n================= MISCLASSIFIED (All Rounds) =================")
    if mis.empty:
        print("🎉 No misclassified instances.")
        return

    pair_counts = (
        mis.groupby(["gold_label", "pred_label"])
           .size()
           .reset_index(name="count")
           .sort_values(["gold_label", "pred_label", "count"], ascending=[True, True, False])
    )
    print(f"Total misclassified: {len(mis)} of {len(all_df)}")
    print("\nCounts by (gold_label -> pred_label):")
    print(pair_counts.to_string(index=False))

    print("\n--- Examples (up to 3 per (gold -> pred)) ---")
    for (g, p), grp in mis.groupby(["gold_label", "pred_label"]):
        print(f"\n[gold={g} -> pred={p}]  n={len(grp)}")
        for _, row in grp.head(3).iterrows():
            print(f"• round={int(row['round'])} gold={int(row['gold_label'])} pred={int(row['pred_label'])}")
            print(f"  cleaned: {_truncate(row['cleaned'])}")

    mis_out = f"maintenance_misclassified_{n_rounds}rounds.csv"
    cols = ["round", "index", "gold_label", "pred_label", "correct", "comment", "cleaned", "raw_model_output", "prompt_preview"]
    mis[cols].to_csv(mis_out, index=False)
    print(f"\n📄 Saved misclassified instances: {mis_out}")

    # Optional JSONL: best effort, a failure here must not abort the run.
    jpath = f"maintenance_misclassified_{n_rounds}rounds.jsonl"
    try:
        with open(jpath, "w", encoding="utf-8") as f:
            for _, r in mis.iterrows():
                f.write(json.dumps({
                    "round": int(r["round"]),
                    "index": int(r["index"]),
                    "gold_label": int(r["gold_label"]),
                    "pred_label": int(r["pred_label"]),
                    "comment": r["comment"],
                    "cleaned": r["cleaned"],
                    "raw_model_output": r["raw_model_output"],
                    "prompt_preview": r["prompt_preview"],
                }, ensure_ascii=False) + "\n")
        print(f"📄 Saved JSONL: {jpath}")
    except (OSError, TypeError, ValueError) as e:
        print(f"JSONL save skipped: {e}")


def _report_correct(all_df: pd.DataFrame, n_rounds: int) -> None:
    """Print a summary of the correct predictions and export them as CSV."""
    correct_df = all_df[all_df["correct"]].copy()
    print("\n================= CORRECTLY CLASSIFIED (All Rounds) =================")
    if correct_df.empty:
        print("⚠️ No correctly classified instances found.")
        return

    print(f"Total correct: {len(correct_df)} of {len(all_df)}")
    correct_counts = (
        correct_df.groupby(["gold_label"])
                  .size()
                  .reset_index(name="count")
                  .sort_values("gold_label")
    )
    print("\nCounts by gold_label:")
    print(correct_counts.to_string(index=False))

    correct_out = f"maintenance_correct_{n_rounds}rounds.csv"
    cols = ["round", "index", "gold_label", "pred_label", "comment", "cleaned", "raw_model_output", "prompt_preview"]
    correct_df[cols].to_csv(correct_out, index=False)
    print(f"\n📄 Saved correctly classified instances: {correct_out}")


def main():
    """Run the Monte Carlo evaluation and write all reports.

    Loads the CSV at ``CSV_PATH``, normalizes text and labels, then for each of
    ``N_ROUNDS`` rounds draws ``SAMPLE_SIZE`` rows without replacement, scores
    them with ``eval_on_split`` and saves per-instance details. Afterwards it
    writes per-round metrics, a mean/std summary, and the misclassified and
    correctly classified instances. Output file names embed ``N_ROUNDS``; with
    the default of 5 they are the same names as before.

    Raises:
        ValueError: if ``N_ROUNDS`` or ``SAMPLE_SIZE`` is below 1, or the CSV
            contains no rows.
        FileNotFoundError: if ``CSV_PATH`` does not exist.
    """
    if N_ROUNDS < 1:
        raise ValueError(f"N_ROUNDS must be >= 1, got {N_ROUNDS}")
    if SAMPLE_SIZE < 1:
        raise ValueError(f"SAMPLE_SIZE must be >= 1, got {SAMPLE_SIZE}")

    # The global NumPy seed is kept so the sampled indices stay reproducible.
    np.random.seed(RANDOM_STATE)

    if not os.path.exists(CSV_PATH):
        raise FileNotFoundError(f"CSV not found: {CSV_PATH}. Override with MAINT_DATASET=/path/to.csv")
    df = pd.read_csv(CSV_PATH)
    if df.empty:
        raise ValueError(f"CSV has no rows: {CSV_PATH}")
    text_col, label_col = autodetect_columns(df)

    # Normalize text and labels
    X_all = normalize_text_series(df[text_col]).reset_index(drop=True)
    df["label_norm"] = df[label_col].apply(normalize_label_value).astype(int)
    y_all = df["label_norm"].reset_index(drop=True).astype(int)

    # Pure LLM evaluation; set True to turn on the rules enabled in the config.
    use_rules = False

    print(f"✅ Loaded {len(df)} rows from {CSV_PATH}")
    print(f"Using TEXT_COL='{text_col}', LABEL_COL='{label_col}' → normalized to 'label_norm'")
    print("Label distribution (raw):", df[label_col].value_counts(dropna=False).to_dict())
    print("Label distribution (norm):", df["label_norm"].value_counts().to_dict())
    # Report whether rules are applied at all, not only the config flags.
    print(f"Rules: enabled={use_rules}, short_text={USE_RULE_SHORT_TEXT}, maint_regex={USE_RULE_REGEX}")
    print(f"Model: {OPENAI_MODEL}")

    # Monte Carlo rounds
    fold_metrics = []
    all_round_rows = []
    size = min(SAMPLE_SIZE, len(X_all))  # fewer if the dataset is smaller

    for round_idx in range(1, N_ROUNDS + 1):
        sample_idx = np.random.choice(len(X_all), size=size, replace=False)
        X_sample = X_all.iloc[sample_idx].reset_index(drop=True)
        y_sample = y_all.iloc[sample_idx].reset_index(drop=True)

        print(f"\n================= Round {round_idx}/{N_ROUNDS} =================")
        print(f"Round {round_idx}: {len(X_sample)} random test cases")
        print("Gold counts:", y_sample.value_counts().to_dict())

        res = eval_on_split(
            name=f"maintenance_random_round{round_idx}",
            X_split=X_sample,
            y_split=y_sample,
            use_rules=use_rules,
        )

        round_df = res["detailed_df"].copy()
        round_df["round"] = round_idx
        out_csv = f"maintenance_random_round{round_idx}_details.csv"
        round_df.to_csv(out_csv, index=False)
        print(f"📄 Saved per-instance details: {out_csv}")

        print("\n--- Classification Report ---")
        print(res["report"])
        print("Confusion matrix (rows=true [-1,0], cols=pred):\n", res["cm"])
        print("Pred counts:", pd.Series(res["preds"]).value_counts().to_dict())

        fold_metrics.append({
            "round": round_idx,
            "accuracy": res["accuracy"],
            "f1_macro": res["f1_macro"],
            "precision_macro": res["precision_macro"],
            "recall_macro": res["recall_macro"],
        })
        all_round_rows.append(round_df)

    # Aggregate metrics across rounds
    cv_df = pd.DataFrame(fold_metrics)
    metrics_path = f"maintenance_random_{N_ROUNDS}round_metrics.csv"
    cv_df.to_csv(metrics_path, index=False)
    print(f"\n📄 Saved random sampling metrics to {metrics_path}")

    # === Compute mean ± std and export ===
    acc_mu, acc_sd = _mean_std(cv_df["accuracy"])
    f1_mu, f1_sd = _mean_std(cv_df["f1_macro"])
    pr_mu, pr_sd = _mean_std(cv_df["precision_macro"])
    rc_mu, rc_sd = _mean_std(cv_df["recall_macro"])

    print(f"\n================= {N_ROUNDS}-round RANDOM SAMPLING SUMMARY =================")
    print(cv_df.to_string(index=False))
    print(
        f"\nMEAN ± STD over {N_ROUNDS} rounds:\n"
        f"- Accuracy         : {acc_mu:.4f} ± {acc_sd:.4f}\n"
        f"- F1 (macro)       : {f1_mu:.4f} ± {f1_sd:.4f}\n"
        f"- Precision (macro): {pr_mu:.4f} ± {pr_sd:.4f}\n"
        f"- Recall (macro)   : {rc_mu:.4f} ± {rc_sd:.4f}"
    )

    # Save mean ± std summary CSV
    summary_df = pd.DataFrame({
        "metric": ["Accuracy", "F1_macro", "Precision_macro", "Recall_macro"],
        "mean":   [acc_mu, f1_mu, pr_mu, rc_mu],
        "std":    [acc_sd, f1_sd, pr_sd, rc_sd],
    })
    summary_path = f"maintenance_random_{N_ROUNDS}round_summary_mean.csv"
    summary_df.to_csv(summary_path, index=False)
    print(f"\n📄 Saved mean ± std summary to {summary_path}")

    # Aggregate & export misclassified/correct across all rounds
    all_df = pd.concat(all_round_rows, ignore_index=True)
    _report_misclassified(all_df, N_ROUNDS)
    _report_correct(all_df, N_ROUNDS)

if __name__ == "__main__":
    main()


✅ Loaded 100 rows from /Users/stevefeng/Downloads/GlueWork-Tracker-branch-mentoring/glue_work_bot/training_data/maintainance_training_dataset.csv
Using TEXT_COL='body', LABEL_COL='label' → normalized to 'label_norm'
Label distribution (raw): {'n': 50, 'y': 50}
Label distribution (norm): {-1: 50, 0: 50}
Rules: short_text=True, maint_regex=False
Model: gpt-4o-mini

Round 1: 50 random test cases
Gold counts: {-1: 30, 0: 20}

### Per-instance outputs for: maintenance_random_round1 | N=50
[0] GOLD=0 | PRED=-1 | ❌
[1] GOLD=-1 | PRED=-1 | ✅
[2] GOLD=0 | PRED=0 | ✅
[3] GOLD=-1 | PRED=-1 | ✅
[4] GOLD=-1 | PRED=-1 | ✅
[5] GOLD=0 | PRED=-1 | ❌
[6] GOLD=-1 | PRED=-1 | ✅
[7] GOLD=0 | PRED=0 | ✅
[8] GOLD=-1 | PRED=-1 | ✅
[9] GOLD=-1 | PRED=-1 | ✅
[10] GOLD=-1 | PRED=-1 | ✅
[11] GOLD=-1 | PRED=-1 | ✅
[12] GOLD=0 | PRED=-1 | ❌
[13] GOLD=-1 | PRED=-1 | ✅
[14] GOLD=0 | PRED=-1 | ❌
[15] GOLD=-1 | PRED=-1 | ✅
[16] GOLD=0 | PRED=-1 | ❌
[17] GOLD=0 | PRED=0 | ✅
[18] G