# Retrieval

In [None]:
# RETRIEVAL 3
# %% LEAK-FREE RETRIEVAL ABLATION — FULL DIAGNOSTICS (same artifacts as before)
# Selection uses TRAIN only (no leakage); all plots/CSVs match your earlier names.
#
# Artifacts in results/qeval_ablation_plus/ :
#   ├─ ablation_all_raw.csv
#   ├─ ablation_macro_means.csv
#   ├─ ablation_delta_vs_baseline.csv
#   ├─ ablation_diag_macro_means.csv
#   ├─ main_effect_*.csv
#   ├─ best_config.txt
#   ├─ abl_macro_heatmap.png
#   ├─ abl_macro_heatmap_extra.png
#   ├─ abl_delta_heatmap.png
#   ├─ abl_ast_subset_bars.png
#   ├─ abl_main_effect_*.png
#   ├─ abl_ecdf_*_baseline_vs_best.png
#   ├─ abl_rank_breakdown_baseline_vs_best.png
#   ├─ abl_diag_heatmap_z.png
#   ├─ abl_main_effect_hints_diag.png
#   ├─ scatter_synprior_vs_precproxy_best.png
#   └─ scatter_qdensity_vs_ndcg_best.png
#
# Also writes: splits_70_25_5.json, macro_by_split.csv (best cfg on TRAIN/VAL/TEST),
# and percase__{TRAIN,VAL,TEST}.csv (per-case rows of the best cfg on each split)

%pip install -q pandas numpy matplotlib sentence-transformers tqdm

import os, re, json, math, difflib, ast, random
import numpy as np, pandas as pd, matplotlib.pyplot as plt
from pathlib import Path
from dataclasses import dataclass
from collections import Counter
from hashlib import md5
from tqdm import tqdm

# Global matplotlib defaults for every figure produced by this cell.
plt.rcParams["figure.dpi"] = 150
plt.rcParams.update({"axes.spines.top": False, "axes.spines.right": False})

# ------------------------------- knobs -------------------------------
# Root that is scanned recursively for buggy.py files (see iter_cases).
DB_ROOT = Path("data/bugs4q/Bugs4Q-Database")
# Output directory for all CSVs/plots; created on import of this cell.
SAVE    = Path("results/qeval_ablation_plus"); SAVE.mkdir(parents=True, exist_ok=True)
# Seeds both the stdlib and NumPy global RNGs.
SEED = 7; random.seed(SEED); np.random.seed(SEED)

# Number of hits the selectors keep per case.
TOPK         = 2
# Size of the candidate pool requested from the index before selection
# (the search call uses max(OVERRETRIEVE, 6*TOPK)).
OVERRETRIEVE = 80
DATA_PERCENT = 100          # % of TRAIN used in the grid (speed knob)
# Model name handed to CrossEncoderReranker.
RERANK_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"

# IDF scope used during selection (prevents leakage)
INDEX_SCOPE_FOR_ABLATION = "train"   # {"train","all"}

# ------------------------------- utils -------------------------------
# Identifier-like words: a letter or underscore followed by letters, digits, underscores.
WORD_RE   = re.compile(r"[A-Za-z_][A-Za-z_0-9]*")
# Lower-case words dropped by tokenize(): English filler plus common Python keywords.
STOPWORDS = set("a an and are as at be by for from has have in is it its of on or that the to was were will with not this self none true false return def class if elif else try except finally while for".split())

# Quantum/Qiskit vocabulary used for boosting, hint extraction and density diagnostics.
Q_TOKENS = set("""
x y z h s sdg t tdg rx ry rz rzz rzx rxy sx cx ccx cnot cz swap cswap iswap ecr u u1 u2 u3
measure barrier qreg creg backend provider aer terra pulse schedule bind assign_parameters
QuantumCircuit QuantumRegister ClassicalRegister Parameter ParameterVector
DAGCircuit PassManager layout mapper transpile basis_gates optimization_level qasm dag layout pass
CouplingMap AncillaAllocation NoiseModel Calibrations LayoutPass Unroller
""".split())
# tokenize() lower-cases every token, so the CamelCase entries above could never
# match a tokenized word. Add lower-case forms; the original spellings are kept
# so any exact-case lookup keeps working.
Q_TOKENS |= {t.lower() for t in Q_TOKENS}

def safe_read(p: Path) -> str:
    """Return the UTF-8 text of *p*, or an empty string if reading fails.

    Undecodable bytes are replaced rather than raising, and any exception
    while reading (missing file, permissions, ...) is deliberately swallowed
    so a single bad file cannot abort a dataset scan.
    """
    try:
        text = p.read_text(encoding="utf-8", errors="replace")
    except Exception:
        text = ""
    return text

def tokenize(s: str):
    """Return the lower-cased identifier-like words of *s*, minus STOPWORDS."""
    lowered = (word.lower() for word in WORD_RE.findall(s))
    return [word for word in lowered if word not in STOPWORDS]

def changed_lines_in_A(a_text: str, b_text: str) -> set[int]:
    """Return the 1-based line numbers of *a_text* that were replaced or deleted in *b_text*.

    The two texts are diffed line by line. Pure insertions touch no line of
    *a_text* and therefore contribute nothing to the result.
    """
    matcher = difflib.SequenceMatcher(
        None, a_text.splitlines(), b_text.splitlines(), autojunk=False
    )
    return {
        lineno
        for tag, a_lo, a_hi, _b_lo, _b_hi in matcher.get_opcodes()
        if tag in ("replace", "delete")
        for lineno in range(a_lo + 1, a_hi + 1)
    }

def dcg(scores):
    """Return the discounted cumulative gain of *scores* (rank i discounted by log2(i+2))."""
    total = 0
    for rank, gain in enumerate(scores):
        total += gain / math.log2(rank + 2)
    return total
def ecdf(arr):
    """Return (x, y) of the empirical CDF of *arr*, ignoring NaNs.

    x holds the sorted finite values; y holds the cumulative share 1/n .. 1.
    """
    values = np.asarray(arr, float)
    values = np.sort(values[~np.isnan(values)])
    count = len(values)
    shares = np.arange(1, count + 1) / max(1, count)
    return values, shares

# ------------------------------- dataset scan & split -------------------------------
def iter_cases(db_root: Path):
    """Yield (case_id, case_dir, buggy_path, fixed_path) for each usable case under *db_root*.

    A case is a directory containing buggy.py plus fixed.py or fix.py
    (fixed.py wins when both exist). The case id is the directory path
    relative to *db_root*, using "/" separators.
    """
    for buggy in db_root.rglob("buggy.py"):
        case_dir = buggy.parent
        candidates = (case_dir / nm for nm in ("fixed.py", "fix.py"))
        fixed = next((cand for cand in candidates if cand.exists()), None)
        if fixed is None:
            continue
        case_id = str(case_dir.relative_to(db_root)).replace(os.sep, "/")
        yield case_id, case_dir, Path(buggy), Path(fixed)

@dataclass
class CodeChunk:
    """One retrievable span of source code."""

    chunk_id: str    # short hash identifying the span
    repo_key: str    # case id the chunk belongs to
    file_path: str   # path of the source file
    start_line: int  # first line of the span (1-based)
    end_line: int    # last line of the span (inclusive)
    symbol: str      # function/class name, or a placeholder for window chunks
    kind: str        # "function", "class" or "module"
    text: str        # source text of the span

class ASTChunker:
    """Split a Python file into function/class chunks, with a sliding-window fallback.

    Every FunctionDef/AsyncFunctionDef/ClassDef found anywhere in the AST
    (nested definitions included) becomes one chunk. If the file cannot be
    parsed or defines none of these, it is cut into overlapping line windows.
    """

    def __init__(self, window_fallback=80, window_overlap=10):
        # Window length and overlap (in lines) for the fallback path only.
        self.window_fallback = window_fallback
        self.window_overlap = window_overlap

    def chunk_file(self, case_dir: Path, file_path: Path, repo_key: str):
        """Return the list of CodeChunk objects for *file_path*.

        Chunk file paths are stored relative to *case_dir*; *repo_key* is
        copied onto every chunk.
        """
        rel = str(file_path.relative_to(case_dir))
        src = safe_read(file_path)
        lines = src.splitlines()
        try:
            root = ast.parse(src)
        except Exception:
            # Unparseable source is not an error here: fall back to windows.
            root = None
        chunks = []

        def add(s, e, sym, kind):
            # Clamp to a valid, non-empty 1-based inclusive range.
            s = max(1, int(s))
            e = max(s, int(e))
            chunk_id = md5(f"{rel}:{s}-{e}".encode()).hexdigest()[:12]
            chunks.append(CodeChunk(chunk_id, repo_key, rel, s, e, sym, kind,
                                    "\n".join(lines[s - 1:e])))

        if root is not None:
            for n in ast.walk(root):
                if isinstance(n, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
                    add(getattr(n, "lineno", 1), getattr(n, "end_lineno", 1),
                        getattr(n, "name", "<sym>"),
                        "class" if isinstance(n, ast.ClassDef) else "function")
        if not chunks:
            # A non-positive step (overlap >= window) would never advance and
            # loop forever; advance by at least one line, as the script's own
            # window loop does.
            step = max(1, self.window_fallback - self.window_overlap)
            i = 0
            n = len(lines)
            while i < n:
                add(i + 1, min(i + self.window_fallback, n), "<module>", "module")
                i += step
        return chunks

# Captures the attribute name of a method-style call, e.g. "cx" in "qc.cx(".
CALL_RE = re.compile(r"\.([A-Za-z_]\w*)\s*\(")
# Matches qiskit import statements at the start of a line (re.M).
#   group 1: the name list of "from qiskit[.sub] import ..."
#   group 2: the alias of "import qiskit[.sub] as ..."
# NOTE(review): group 1's character class includes \s, which also matches
# newlines, so the capture can run on into following lines.
IMPORT_QISKIT_RE  = re.compile(r"^\s*(?:from\s+qiskit(?:\.[\w\.]+)?\s+import\s+([\w\,\s]+)|import\s+qiskit(?:\.[\w\.]+)?(?:\s+as\s+(\w+))?)", re.M)
# Pipeline keywords looked up (case-insensitively, as substrings) in the buggy source.
Q_PIPELINE = ["QuantumCircuit","DAGCircuit","PassManager","transpile","layout","coupling_map","qasm","basis_gates","Parameter","compose","append","measure","barrier","decompose","reset","initialize"]

def build_hinted_query(seed_q: str, buggy_text: str, kmax=8):
    """Append up to *kmax* quantum-specific hint tokens to *seed_q*.

    Hints are collected from *buggy_text* in priority order, de-duplicated:
      1. gate/circuit method calls, most frequent first;
      2. tokens that appear in Q_TOKENS, most frequent first;
      3. Q_PIPELINE keywords present in the text (case-insensitive substring);
      4. names imported from qiskit and qiskit aliases.
    If nothing is found, a fixed default hint list is used.

    Returns:
        (query, n_hints): the extended query string and the number of hint
        tokens that were appended.
    """
    freq = Counter(tokenize(buggy_text))
    q_from_tokens = [t for t, _ in freq.most_common() if t in Q_TOKENS]

    gate_calls = {"cx", "rz", "ry", "rx", "swap", "cz", "ccx", "measure", "append", "compose", "decompose"}
    calls = [m.group(1).lower() for m in CALL_RE.finditer(buggy_text)]
    call_counts = Counter([c for c in calls if c in gate_calls])

    lowered_text = buggy_text.lower()  # hoisted: previously recomputed per keyword
    present_pipeline = [w.lower() for w in Q_PIPELINE if w.lower() in lowered_text]

    aliases = []
    for m in IMPORT_QISKIT_RE.finditer(buggy_text):
        mods, alias = m.groups()
        if alias:
            aliases.append(alias.strip())
        if mods:
            # The pattern's name-list class also matches newlines, so the
            # capture may run on into later source lines. Only the import's
            # own line holds imported names.
            import_line = mods.splitlines()[0]
            for name in import_line.split(","):
                parts = name.split()
                if not parts:
                    continue
                # "X as Y" used to be pushed as the single token "x as y";
                # record the imported name and its alias separately.
                aliases.append(parts[0])
                if len(parts) == 3 and parts[1] == "as":
                    aliases.append(parts[2])

    merged = []

    def push(xs):
        # Append lower-cased items, skipping empties and duplicates.
        for x in xs:
            x = x.lower()
            if x and x not in merged:
                merged.append(x)

    push([g for g, _ in call_counts.most_common()])
    push(q_from_tokens)
    push(present_pipeline)
    push([a for a in aliases if a != "qiskit"])
    if not merged:
        merged = ["cx", "rz", "dag", "layout", "transpile", "qasm"]
    hint_tokens = merged[:kmax]
    return (seed_q + " " + " ".join(hint_tokens)).strip(), len(hint_tokens)

# BM25
class _MiniBM25:
    def __init__(self, docs):
        self.docs=docs; self.N=len(docs); self.lens=[len(d) for d in docs]
        self.avg=sum(self.lens)/max(1,self.N)
        df=Counter()
        for d in docs: df.update(set(d))
        self.df=dict(df)
    def idf(self,t):
        df=self.df.get(t,0)
        return 0.0 if df==0 else math.log(1+(self.N-df+0.5)/(df+0.5))
    def score(self, q, doc, dl):
        k1,b=1.5,0.75; f=Counter(doc); s=0.0
        for t in q:
            if t not in self.df: continue
            tf=f.get(t,0); 
            if tf==0: continue
            denom=tf+k1*(1-b+b*dl/max(1,self.avg))
            s+=self.idf(t)*(tf*(k1+1))/denom
        return s

class HybridIndex:
    """BM25 index over code chunks plus an additive per-token boost."""

    def __init__(self, boost_map=None, include_paths=False):
        source = boost_map or {}
        self.boost_map = {key.lower(): float(val) for key, val in source.items()}
        self.include_paths = include_paths
        self.records = []
        self.docs = []

    def build(self, chunks):
        """Tokenize *chunks* (header + text) and fit the BM25 statistics."""
        self.records = []
        self.docs = []
        for chunk in chunks:
            if self.include_paths:
                header = f"{chunk.symbol} {chunk.kind} " + chunk.file_path + " "
            else:
                header = f"{chunk.symbol} {chunk.kind} "
            tokens = tokenize(header + "\n" + chunk.text)
            # Static boost: summed boost weights of every token occurrence.
            boost_total = sum(self.boost_map.get(tok, 0.0) for tok in tokens)
            self.records.append({"chunk": chunk, "tokens": tokens, "boost": float(boost_total)})
            self.docs.append(tokens)
        self.bm25 = _MiniBM25(self.docs)

    def search(self, query, topk=10):
        """Return the *topk* best chunks for *query* as hit dicts, best first.

        The score is BM25 plus 0.02 times the chunk's static boost;
        "re_score" is initialised to 0.0 and "preview" holds at most the
        first 120 lines of the chunk.
        """
        q_tokens = tokenize(query)
        ranked = sorted(
            (
                (self.bm25.score(q_tokens, rec["tokens"], len(rec["tokens"])) + 0.02 * rec["boost"], pos)
                for pos, rec in enumerate(self.records)
            ),
            reverse=True,
        )
        hits = []
        for value, pos in ranked[:topk]:
            chunk = self.records[pos]["chunk"]
            hits.append({
                "score": float(value),
                "re_score": 0.0,
                "file": chunk.file_path,
                "symbol": chunk.symbol,
                "kind": chunk.kind,
                "start": chunk.start_line,
                "end": chunk.end_line,
                "preview": "\n".join(chunk.text.splitlines()[:120]),
            })
        return hits

def quantum_boost_map(alpha=1.8):
    """Return a boost map assigning weight *alpha* to every lower-cased Q_TOKENS entry."""
    return dict.fromkeys((tok.lower() for tok in Q_TOKENS), alpha)

# Cross-encoder — honest gating (no phantom “on”)
class CrossEncoderReranker:
    """Thin wrapper around a sentence-transformers CrossEncoder.

    If the library or the model cannot be loaded, the instance stays
    disabled (``enabled`` is False) instead of raising, so callers can
    skip reranking configurations.
    """

    def __init__(self, model_name):
        self.model = None
        self.enabled = False
        try:
            from sentence_transformers import CrossEncoder
            self.model = CrossEncoder(model_name)
            self.enabled = True
        except Exception as e:
            print("[WARN] CrossEncoder unavailable; 'rerank=on' configs will be skipped.", e)

    def score_pairs(self, pairs):
        """Return the model's scores for *pairs* as a float array.

        Raises:
            RuntimeError: if the reranker is disabled.
        """
        if not self.enabled:
            raise RuntimeError("Reranker disabled")
        raw_scores = self.model.predict(pairs)
        return np.asarray(raw_scores, dtype=float)

def apply_rerank(query, pool_u, rr):
    """Score each hit's preview against *query* with *rr* and return the hits best first.

    The hits in *pool_u* are updated in place with a "re_score" entry; the
    returned list is a new, sorted list of the same dict objects.
    """
    scores = rr.score_pairs([(query, hit.get("preview", "")) for hit in pool_u])
    for hit, value in zip(pool_u, scores):
        hit["re_score"] = float(value)
    ranked = list(pool_u)
    ranked.sort(key=lambda hit: hit.get("re_score", 0.0), reverse=True)
    return ranked

# priors, selectors, diagnostics helpers
def syntax_prior_of(hit):
    """Return a small heuristic prior in [0, 0.6] for *hit*, from its preview and symbol.

    Additive components (text is lower-cased first):
      +0.10  mentions assert/raise/error/exception (substring);
      +0.12  mentions a transpiler keyword such as dag/layout/qasm (substring);
      +0.10  contains a Q_TOKENS entry as a whole identifier;
      +0.08  contains the word "run" or "apply".
    """
    txt = (hit.get("preview", "") + " " + hit.get("symbol", "")).lower()
    prior = 0.0
    if any(t in txt for t in ("assert", "raise", "error", "exception")):
        prior += 0.10
    if any(t in txt for t in ("dag", "layout", "transpile", "qasm", "coupling_map", "basis_gates")):
        prior += 0.12
    # Q_TOKENS contains one-letter gate names ("x", "h", "s", "t", ...). A
    # substring test against those is true for almost any text, which made
    # this component a constant. Match whole identifiers instead.
    words = set(re.findall(r"[a-z_][a-z_0-9]*", txt))
    if any(t.lower() in words for t in Q_TOKENS):
        prior += 0.10
    if re.search(r'\b(run|apply)\b', txt):
        prior += 0.08
    return min(prior, 0.6)

def apply_syntax_prior(pool_u, alpha=0.5):
    """Return copies of the hits with "score" scaled by (1 + alpha * syntax prior), best first.

    Each copy also gets a "syn_prior" entry. The input hits are not modified.
    The base value is the rerank score when one is present, otherwise the
    retrieval score.
    """
    out = []
    for h in pool_u:
        sp = syntax_prior_of(h)
        # Hits are created with re_score=0.0 as a placeholder, so
        # h.get("re_score", <fallback>) never reached the fallback: without a
        # reranker every score was multiplied into 0.0 and the ranking signal
        # was lost. Treat a missing/zero re_score as "not reranked".
        # NOTE(review): a genuine rerank score of exactly 0.0 also falls back.
        base = h.get("re_score") or h.get("score", 0.0)
        g = dict(h)
        g["syn_prior"] = sp
        g["score"] = base * (1.0 + alpha * sp)
        out.append(g)
    return sorted(out, key=lambda r: r["score"], reverse=True)

def select_by_coverage_balanced(pool_u, topk, w_gain=0.8, w_base=1.0, w_rerank=1.5,
                                w_div_file=0.15, w_div_sym=0.10, pen_overlap=0.10):
    """Greedily pick up to *topk* hits, balancing relevance, new coverage and diversity.

    Each round scores every unselected hit as
        w_gain * (share of its lines not yet covered)
      + w_base * (min-max normalised "score") + w_rerank * "re_score"
      + w_div_file / w_div_sym bonuses for an unseen file / symbol
      - pen_overlap * (share of its lines already covered)
    and selects the best one.

    Side effect: every hit in *pool_u* gets "_bn" (normalised score) and
    "_rn" (raw re_score) entries.

    Returns:
        The selected hit dicts in selection order; [] for an empty pool.
    """
    if not pool_u:
        # min()/max() of an empty array raise; nothing to select anyway.
        return []
    sel, covered = [], set()
    seen_files, seen_syms = set(), set()
    base = np.array([h.get("score", 0.0) for h in pool_u], float)
    bn = (base - base.min()) / (base.max() - base.min() + 1e-9)
    rn = np.array([h.get("re_score", 0.0) for h in pool_u], float)
    for h, b, r in zip(pool_u, bn, rn):
        h["_bn"] = float(b)
        h["_rn"] = float(r)

    def span_of(h):
        # Coverage is tracked per (file, line). Bare line numbers made a
        # selected span in one file count as "already covered" for the same
        # line numbers in every other file.
        return {(h["file"], ln) for ln in range(h["start"], h["end"] + 1)}

    for _ in range(min(topk, len(pool_u))):
        best, best_s = None, float("-inf")
        for h in pool_u:
            if h in sel:
                continue
            size = max(1, h["end"] - h["start"] + 1)
            gain_norm = len(span_of(h) - covered) / size
            overlap = 1.0 - gain_norm
            s = w_gain * gain_norm + w_base * h["_bn"] + w_rerank * h["_rn"]
            s += (w_div_file if h["file"] not in seen_files else 0.0)
            s += (w_div_sym if h["symbol"] not in seen_syms else 0.0)
            s -= pen_overlap * overlap
            if s > best_s:
                best, best_s = h, s
        if best is None:
            break
        sel.append(best)
        covered |= span_of(best)
        seen_files.add(best["file"])
        seen_syms.add(best["symbol"])
    return sel

def select_by_coverage_old(hits, topk, w_new_file=10.0, w_new_symbol=6.0, w_rerank=2.0):
    """Greedily pick up to *topk* hits by newly covered lines plus novelty bonuses.

    Each round scores every unselected hit as
        (number of its lines not yet covered)
      + w_new_file / w_new_symbol for an unseen file / symbol
      + w_rerank * ("re_score" if the key is present, else "score")
    and selects the best one.

    Returns:
        The selected hit dicts in selection order.
    """
    selected, covered = [], set()
    seen_files, seen_syms = set(), set()

    def span_of(h):
        # Coverage is tracked per (file, line) so equal line numbers in
        # different files are not treated as the same code.
        return {(h["file"], ln) for ln in range(h["start"], h["end"] + 1)}

    for _ in range(min(topk, len(hits))):
        # Start from -inf: scores go negative when re_score is negative, and
        # the former -1.0 sentinel then rejected every candidate and ended
        # selection early.
        best, score = None, float("-inf")
        for h in hits:
            if h in selected:
                continue
            gain = len(span_of(h) - covered)
            tie = h.get("re_score", h.get("score", 0.0))
            s = (gain
                 + (w_new_file if h["file"] not in seen_files else 0.0)
                 + (w_new_symbol if h["symbol"] not in seen_syms else 0.0)
                 + (w_rerank * tie))
            if s > score:
                best, score = h, s
        if best is None:
            break
        selected.append(best)
        covered |= span_of(best)
        seen_files.add(best["file"])
        seen_syms.add(best["symbol"])
    return selected

def _kendall_tau(keys_a, keys_b):
    n=min(len(keys_a), len(keys_b), 25)
    if n<3: return np.nan
    a=keys_a[:n]; b=keys_b[:n]; pos={k:i for i,k in enumerate(b)}
    conc=disc=0
    for i in range(n):
        for j in range(i+1,n):
            ki,kj=a[i],a[j]
            if ki not in pos or kj not in pos: continue
            conc += int(pos[ki] < pos[kj]); disc += int(pos[ki] > pos[kj])
    denom=conc+disc
    return (conc-disc)/denom if denom>0 else np.nan

def _quantum_density(texts):
    """Return (count, share) of tokens in *texts* that are members of Q_TOKENS.

    The share is the count divided by the total number of tokens (or by 1
    when there are no tokens).
    """
    n_tokens = 0
    n_quantum = 0
    for text in texts:
        tokens = tokenize(text)
        n_tokens += len(tokens)
        n_quantum += sum(1 for tok in tokens if tok in Q_TOKENS)
    return n_quantum, n_quantum / max(1, n_tokens)

# ------------------------------- build chunks & meta -------------------------------
chunker = ASTChunker()
all_chunks_ast, all_chunks_win, meta = [], [], {}
for cid, case_dir, bug_f, fix_f in iter_cases(DB_ROOT):
    # AST chunks of the buggy file; paths are prefixed with the case id so a
    # hit can later be attributed to its case.
    for ch in chunker.chunk_file(case_dir, bug_f, repo_key=cid):
        ch.file_path = f"{cid}/{ch.file_path}"
        all_chunks_ast.append(ch)

    # Fixed sliding windows over the same buggy file (80 lines, 10 overlap).
    txt = safe_read(bug_f)
    lines = txt.splitlines()
    win, overlap = 80, 10
    step = max(1, win - overlap)
    i = 0
    while i < len(lines):
        s = i + 1
        e = min(i + win, len(lines))
        all_chunks_win.append(CodeChunk(md5(f"{cid}/{bug_f.name}:{s}-{e}".encode()).hexdigest()[:12],
                                        cid, f"{cid}/{bug_f.name}", s, e, f"<win@{s}-{e}>", "module",
                                        "\n".join(lines[s - 1:e])))
        i += step

    # Per-case metadata: gold = buggy lines replaced/deleted by the fix;
    # query = first six tokens of the buggy file.
    meta[cid] = {"gold": changed_lines_in_A(txt, safe_read(fix_f)), "query": " ".join(tokenize(txt)[:6]),
                 "project": cid.split("/")[0], "bug_text": txt}

# stable 70/25/5 split (ordered by the md5 of the case id, so independent of scan order)
ALL = sorted(meta.keys(), key=lambda k: md5(k.encode()).hexdigest())
n = len(ALL)
n_train = int(round(0.70 * n))
n_val = int(round(0.25 * n))
TRAIN, VAL, TEST = ALL[:n_train], ALL[n_train:n_train + n_val], ALL[n_train + n_val:]
# write_text closes the file; the previous json.dump(..., open(...)) never
# closed its handle, so the JSON was not guaranteed to be flushed.
(SAVE / "splits_70_25_5.json").write_text(
    json.dumps({"train": TRAIN, "val": VAL, "test": TEST, "n": n}, indent=2), encoding="utf-8")
k_train = max(1, int(math.ceil(len(TRAIN) * DATA_PERCENT / 100.0)))
TRAIN_SUB = TRAIN[:k_train]
print(f"[SPLIT] train={len(TRAIN)} val={len(VAL)} test={len(TEST)} ; grid uses TRAIN_SUB={len(TRAIN_SUB)}")

# ------------------------------- indices -------------------------------
def build_index(chunks, use_boost=False):
    """Build a HybridIndex over *chunks*, optionally with the quantum token boost (alpha=1.8)."""
    boost_map = quantum_boost_map(1.8) if use_boost else {}
    index = HybridIndex(boost_map=boost_map, include_paths=False)
    index.build(chunks)
    return index

def _keep_cases(chunks, subset):
    keep=set(subset); return [c for c in chunks if str(c.repo_key) in keep]

# Restrict the chunks behind the selection indices to TRAIN_SUB cases when the
# scope is "train"; any other value indexes every case.
if INDEX_SCOPE_FOR_ABLATION=="train":
    chunks_ast_TR=_keep_cases(all_chunks_ast, TRAIN_SUB)
    chunks_win_TR=_keep_cases(all_chunks_win, TRAIN_SUB)
else:
    chunks_ast_TR=all_chunks_ast; chunks_win_TR=all_chunks_win

# train-scope indices (for selection)
# Four variants: {AST, window} chunking x {no boost, quantum token boost}.
idx_ast_base_TR   = build_index(chunks_ast_TR,  use_boost=False)
idx_ast_q_TR      = build_index(chunks_ast_TR,  use_boost=True)
idx_win_base_TR   = build_index(chunks_win_TR,  use_boost=False)
idx_win_q_TR      = build_index(chunks_win_TR,  use_boost=True)

# all-scope indices (for reporting best cfg across splits)
idx_ast_base_ALL  = build_index(all_chunks_ast, use_boost=False)
idx_ast_q_ALL     = build_index(all_chunks_ast, use_boost=True)
idx_win_base_ALL  = build_index(all_chunks_win, use_boost=False)
idx_win_q_ALL     = build_index(all_chunks_win, use_boost=True)

# ------------------------------- evaluation core (with diagnostics) -------------------------------
# Rerankers keyed by model name. Loading a cross-encoder is expensive and
# eval_config runs once per grid configuration, so the model is created once
# and shared (a failed load is cached too, so the warning prints once).
_RERANKER_CACHE = {}


def _get_reranker(model_name):
    """Return the shared CrossEncoderReranker for *model_name*, creating it on first use."""
    if model_name not in _RERANKER_CACHE:
        _RERANKER_CACHE[model_name] = CrossEncoderReranker(model_name)
    return _RERANKER_CACHE[model_name]


def eval_config(index, cases, use_hints, use_reranker, selector, use_syntax, name_for_tqdm="cfg"):
    """Evaluate one retrieval configuration on *cases*; return one row per case.

    Per case: build the query (optionally with hints), over-retrieve a pool
    from *index*, optionally rerank it, optionally apply the syntax prior,
    select TOPK hits and score them against the case's gold lines.

    Args:
        index: object with ``search(query, topk)`` returning hit dicts.
        cases: iterable of case ids (keys of ``meta``).
        use_hints: extend the seed query via build_hinted_query.
        use_reranker: rerank the pool with the cross-encoder.
        selector: "old" for select_by_coverage_old, anything else for
            select_by_coverage_balanced.
        use_syntax: apply apply_syntax_prior (alpha=0.5) before selection.
        name_for_tqdm: label for the progress bar.

    Returns:
        DataFrame of metrics and diagnostics per case. An empty, column-less
        DataFrame is returned when *use_reranker* is set but the model is
        unavailable.
    """
    rr = _get_reranker(RERANK_MODEL) if use_reranker else None
    if use_reranker and not rr.enabled:
        return pd.DataFrame()  # skip 'on' if model missing

    select_fn = select_by_coverage_old if selector == "old" else select_by_coverage_balanced
    rows = []
    for cid in tqdm(cases, desc=f"[{name_for_tqdm}] cases", leave=False):
        case_meta = meta[cid]
        gold = case_meta["gold"]
        seed_q = case_meta["query"]
        bug_txt = case_meta["bug_text"]

        if use_hints:
            q, hint_count = build_hinted_query(seed_q, bug_txt)
        else:
            q, hint_count = seed_q, 0
        query_len = len(tokenize(q))

        # Over-retrieve, then drop duplicate spans (same file and line range).
        pool = index.search(q, topk=max(OVERRETRIEVE, 6 * TOPK))
        seen = set()
        pool_u = []
        for h in pool:
            key = (h["file"], h["start"], h["end"])
            if key in seen:
                continue
            seen.add(key)
            pool_u.append(h)

        base_sorted = sorted(pool_u, key=lambda r: r.get("score", 0.0), reverse=True)
        base_keys = [(h["file"], h["start"], h["end"]) for h in base_sorted]

        # rerank (diagnostics)
        if rr is not None:
            pool_rr = apply_rerank(q, pool_u, rr)
            rr_keys = [(h["file"], h["start"], h["end"]) for h in pool_rr]
            top_shift = 1.0 if (base_keys[:1] != rr_keys[:1]) else 0.0
            kendall25 = _kendall_tau(base_keys, rr_keys)
            n_corr = min(50, len(pool_rr))
            base_scores = [h.get("score", 0.0) for h in pool_rr[:n_corr]]
            re_scores = [h.get("re_score", 0.0) for h in pool_rr[:n_corr]]
            if n_corr >= 3 and np.std(base_scores) > 0 and np.std(re_scores) > 0:
                # Rank correlation: Pearson on the double-argsort ranks.
                spearman = np.corrcoef(np.argsort(np.argsort(base_scores)),
                                       np.argsort(np.argsort(re_scores)))[0, 1]
            else:
                spearman = np.nan
        else:
            pool_rr = pool_u
            top_shift = np.nan
            kendall25 = np.nan
            spearman = np.nan

        pool_for_sel = apply_syntax_prior(pool_rr, alpha=0.5) if use_syntax else pool_rr
        selected = select_fn(pool_for_sel, TOPK)

        # graded relevance + diagnostics
        rel, same_scores, covered = [], [], set()
        uniq_files = set()
        uniq_syms = set()
        span_lens = []
        synpriors = []
        previews = []
        for h in selected:
            span = max(1, h["end"] - h["start"] + 1)
            span_lens.append(span)
            uniq_files.add(h["file"])
            uniq_syms.add(h["symbol"])
            previews.append(h.get("preview", ""))
            synpriors.append(syntax_prior_of(h))
            # Only hits from the case's own files can overlap its gold lines.
            same_case = h["file"].startswith(cid + "/")
            overlap = sum(1 for ln in gold if h["start"] <= ln <= h["end"])
            frac = overlap / span
            rel.append(frac if same_case else 0.0)
            if same_case:
                same_scores.append(frac)
                covered.update(range(h["start"], h["end"] + 1))

        hit = 1.0 if any(x > 0 for x in rel) else 0.0
        first_rank = next((i + 1 for i, x in enumerate(rel) if x > 0), None)
        mrr = 1.0 / first_rank if first_rank is not None else 0.0
        ideal = dcg(sorted(rel, reverse=True))
        ndcg = (dcg(rel) / ideal) if ideal > 0 else 0.0
        line_recall = len({ln for ln in covered if ln in gold}) / max(1, len(gold))
        prec_proxy = float(np.mean(same_scores)) if same_scores else 0.0

        q_cnt, q_density = _quantum_density(previews)
        synprior_mean = float(np.mean(synpriors)) if synpriors else np.nan

        rows.append({
            "case": cid,
            # base metrics
            "Hit@K_global": hit, "MRR_line_global": mrr, "nDCG@K_global": ndcg,
            "LineRecall@K": line_recall, "WindowPrecProxy": prec_proxy,
            # extras for your plots
            "Hit@1": 1.0 if any(x > 0 for x in rel[:1]) else 0.0,
            "Hit@2": 1.0 if any(x > 0 for x in rel[:2]) else 0.0,
            "MeanRank_ifHit": (1.0 / mrr) if mrr > 0 else np.nan,
            "SpanOverlap@K": float(np.mean(rel)) if rel else 0.0,
            "UniqueFiles@K": len(uniq_files), "UniqueSymbols@K": len(uniq_syms),
            "AvgSpanLen@K": float(np.mean(span_lens)) if span_lens else 0.0,
            # rerank diagnostics
            "ReRankSpearman": spearman,
            "TopShift@Pool": top_shift,
            "ReRankKendallTop25": kendall25,
            # hint/syntax diagnostics
            "QueryLen": query_len,
            "HintTokensAppended": hint_count,
            "QuantumTokensInSelected": q_cnt,
            "QuantumDensitySelected": q_density,
            "SynPriorMeanSel": synprior_mean,
        })
    return pd.DataFrame(rows)

# ------------------------------- grid (TRAIN-only) -------------------------------
def build_grid(idx_ast, idx_ast_q, idx_win, idx_win_q):
    cfgs=[]
    for chunking, idx in [("AST_base", idx_ast), ("AST_q", idx_ast_q), ("WIN_base", idx_win), ("WIN_q", idx_win_q)]:
        for hints in [False, True]:
            for selector in ["old","balanced"]:
                for rerank in [False, True]:
                    for syntax_on in [False, True]:
                        cfgs.append((f"{chunking}__{'hint' if hints else 'nohint'}__{selector}__{'rerank' if rerank else 'noR'}__{'syntax' if syntax_on else 'nosyntax'}",
                                     idx, hints, rerank, selector, syntax_on))
    return cfgs

# indices used for selection (train-scoped)
idxs_TR = {
    "AST_base": idx_ast_base_TR, "AST_q": idx_ast_q_TR,
    "WIN_base": idx_win_base_TR, "WIN_q": idx_win_q_TR
}

# Evaluate every grid configuration on TRAIN_SUB and stack the per-case rows.
df_list = []
for name, idx, hints, rerank, selector, syntax_on in tqdm(
        build_grid(idxs_TR["AST_base"], idxs_TR["AST_q"], idxs_TR["WIN_base"], idxs_TR["WIN_q"]),
        desc="[Ablation|TRAIN] configs"):
    dfc = eval_config(idx, TRAIN_SUB, use_hints=hints, use_reranker=rerank, selector=selector,
                      use_syntax=syntax_on, name_for_tqdm=name)
    if dfc.empty:  # reranker gated off
        continue
    dfc["config"] = name
    df_list.append(dfc)
if not df_list:
    # pd.concat([]) would fail with an unhelpful "No objects to concatenate".
    raise RuntimeError(f"No configuration produced any rows; check that {DB_ROOT} contains "
                       "buggy.py/fixed.py cases and that TRAIN_SUB is not empty.")
df_all = pd.concat(df_list, ignore_index=True)
df_all.to_csv(SAVE/"ablation_all_raw.csv", index=False)

# ------------------------------- aggregations (TRAIN) -------------------------------
# Column groups of the per-case frame produced by eval_config.
base_metrics  = ["Hit@K_global","MRR_line_global","nDCG@K_global","LineRecall@K","WindowPrecProxy"]
extra_metrics = ["Hit@1","Hit@2","MeanRank_ifHit","SpanOverlap@K","UniqueFiles@K","UniqueSymbols@K","AvgSpanLen@K"]
diag_metrics  = ["ReRankSpearman","TopShift@Pool","ReRankKendallTop25","QueryLen","HintTokensAppended","QuantumTokensInSelected","QuantumDensitySelected","SynPriorMeanSel"]

# Macro means: one row per config, averaged over cases.
macro     = df_all.groupby("config")[base_metrics+extra_metrics].mean().sort_index()
diag_macro= df_all.groupby("config")[diag_metrics].mean().sort_index()
macro.to_csv(SAVE/"ablation_macro_means.csv")
diag_macro.to_csv(SAVE/"ablation_diag_macro_means.csv")

# Reference config; falls back to the first config (alphabetically) if it is absent.
baseline = "AST_base__nohint__old__noR__nosyntax"
if baseline not in macro.index: baseline = macro.index[0]
# Best config = highest macro nDCG on the cases evaluated above.
best_cfg = macro["nDCG@K_global"].idxmax()
(SAVE/"best_config.txt").write_text(best_cfg+"\n", encoding="utf-8")

# Per-config difference to the baseline row (the baseline itself is dropped).
delta = (macro - macro.loc[baseline]).drop(index=baseline)
delta.to_csv(SAVE/"ablation_delta_vs_baseline.csv")

# ------------------------------- plots (exactly like before) -------------------------------
def savefig(path):
    """Save the current matplotlib figure to *path* (creating parent dirs) and close it."""
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    plt.tight_layout()
    plt.savefig(path)
    plt.close()

# heatmaps
# Each figure below uses pyplot's implicit current figure; savefig() closes it.
# Rows are configs, columns are metrics.
# NOTE(review): the figure *width* scales with the number of configs although
# configs are laid out on the y axis — confirm this is intended.
plt.figure(figsize=(min(22, 2+0.24*len(macro.index)), 10))
im=plt.imshow(macro[base_metrics].values, aspect="auto", vmin=0, vmax=1, cmap="viridis")
plt.colorbar(im, fraction=0.02, pad=0.02).set_label("macro mean")
plt.yticks(range(len(macro.index)), macro.index, fontsize=7)
plt.xticks(range(len(base_metrics)), base_metrics, rotation=25, ha="right")
plt.title("Ablation: macro means (base metrics)")
savefig(SAVE/"abl_macro_heatmap.png")

# Extra metrics: infinities and NaNs are rendered as 0; no fixed colour range.
plt.figure(figsize=(min(22, 2+0.24*len(macro.index)), 10))
im=plt.imshow(macro[extra_metrics].replace([np.inf,-np.inf],np.nan).fillna(0.0).values, aspect="auto", cmap="magma")
plt.colorbar(im, fraction=0.02, pad=0.02).set_label("macro mean")
plt.yticks(range(len(macro.index)), macro.index, fontsize=7)
plt.xticks(range(len(extra_metrics)), extra_metrics, rotation=25, ha="right")
plt.title("Ablation: macro means (extra metrics)")
savefig(SAVE/"abl_macro_heatmap_extra.png")

# Delta heatmap with a colour range symmetric around 0 (v = largest |delta|).
plt.figure(figsize=(min(22, 2+0.24*len(delta.index)), 10))
v=np.nanmax(np.abs(delta[base_metrics].values)); im=plt.imshow(delta[base_metrics].values, aspect="auto", vmin=-v, vmax=+v, cmap="coolwarm")
plt.colorbar(im, fraction=0.02, pad=0.02).set_label(f"Δ vs {baseline}")
plt.yticks(range(len(delta.index)), delta.index, fontsize=7)
plt.xticks(range(len(base_metrics)), base_metrics, rotation=25, ha="right")
plt.title(f"Ablation: Δ vs baseline ({baseline})")
savefig(SAVE/"abl_delta_heatmap.png")

# Grouped bars for a readable subset: AST_base without hints and AST_q with
# hints, across selector x rerank x syntax; configs missing from macro are dropped.
subset = [f"AST_base__nohint__{sel}__{rr}__{sx}" for sel in ["old","balanced"] for rr in ["noR","rerank"] for sx in ["nosyntax","syntax"]] + \
         [f"AST_q__hint__{sel}__{rr}__{sx}"      for sel in ["old","balanced"] for rr in ["noR","rerank"] for sx in ["nosyntax","syntax"]]
subset=[c for c in subset if c in macro.index]
plt.figure(figsize=(14,6))
x=np.arange(len(base_metrics)); w=0.06
for i,cfg in enumerate(subset):
    plt.bar(x+i*w, macro.loc[cfg, base_metrics].values, width=w, label=cfg)
# Centre each metric label under its group of bars.
plt.xticks(x+(len(subset)-1)*w/2, base_metrics, rotation=20, ha="right")
plt.ylabel("macro mean"); plt.title("Readable subset: AST only (base metrics)")
plt.legend(fontsize=7, ncol=3)
savefig(SAVE/"abl_ast_subset_bars.png")

# main effects (base)
def mean_by_factor(df_all, factor_fn):
    """Return the mean of the base metrics per factor level.

    *factor_fn* maps a config name to its factor level; rows are grouped by
    that level. The input frame is not modified.
    """
    labelled = df_all.assign(factor=df_all["config"].map(factor_fn))
    return labelled.groupby("factor")[base_metrics].mean()

# Main effects: each factor level is read back from the config name, whose
# parts are joined with "__" (see build_grid).
effects = {
    "selector": mean_by_factor(df_all, lambda c: "old" if "__old__" in c else "balanced"),
    "reranker": mean_by_factor(df_all, lambda c: "on" if "__rerank__" in c else "off"),
    "hints":    mean_by_factor(df_all, lambda c: "hint" if "__hint__" in c else "nohint"),
    "boost":    mean_by_factor(df_all, lambda c: "boost" if c.startswith("AST_q") or c.startswith("WIN_q") else "noboost"),
    "chunking": mean_by_factor(df_all, lambda c: "AST" if c.startswith("AST_") else "WIN"),
    "syntax":   mean_by_factor(df_all, lambda c: "syntax" if c.endswith("__syntax") else "nosyntax"),
}
# One grouped-bar figure and one CSV per factor.
for name, dfm in effects.items():
    plt.figure(figsize=(7.5,4))
    for j,m in enumerate(base_metrics):
        plt.bar(np.arange(len(dfm.index))+j*0.18, dfm[m].values, width=0.18, label=m)
    # 0.36 = 2 * 0.18: centres the tick under the five metric bars.
    plt.xticks(np.arange(len(dfm.index))+0.36, dfm.index)
    plt.ylabel("macro mean"); plt.title(f"Main effect: {name} (base metrics)")
    plt.legend(fontsize=7, ncol=3)
    savefig(SAVE/f"abl_main_effect_{name}.png")
    dfm.to_csv(SAVE/f"main_effect_{name}.csv")

# ECDF baseline vs best (TRAIN)
for metric in ["MRR_line_global","nDCG@K_global"]:
    plt.figure(figsize=(6,4))
    xs, ys = ecdf(df_all.loc[df_all["config"]==baseline, metric].values); plt.plot(xs, ys, label=baseline)
    xs, ys = ecdf(df_all.loc[df_all["config"]==best_cfg, metric].values); plt.plot(xs, ys, label=best_cfg)
    plt.xlabel(metric); plt.ylabel("ECDF"); plt.title(f"ECDF — baseline vs best ({best_cfg})")
    plt.legend(fontsize=7); savefig(SAVE/f"abl_ecdf_{metric}_baseline_vs_best.png")

def _bucket_mrr(mrr):
    if mrr <= 0 or np.isnan(mrr): return "Miss"
    r=int(round(1.0/mrr))
    return "Top-1" if r<=1 else "Top-2" if r==2 else "Top-3" if r==3 else "Top-4+"
def _rank_breakdown(df, cfg):
    """Return the share of cases of config *cfg* in each rank bucket.

    The result is indexed by Top-1, Top-2, Top-3, Top-4+, Miss in that order;
    buckets without cases are 0.0.
    """
    bucket_order = ["Top-1", "Top-2", "Top-3", "Top-4+", "Miss"]
    mrr_values = df.loc[df["config"] == cfg, "MRR_line_global"]
    buckets = pd.Series([_bucket_mrr(value) for value in mrr_values])
    shares = buckets.value_counts(normalize=True)
    return shares.reindex(bucket_order).fillna(0.0)
# Stacked bars: share of cases per rank bucket, baseline vs best config.
rb_base=_rank_breakdown(df_all, baseline); rb_best=_rank_breakdown(df_all, best_cfg)
plt.figure(figsize=(7,4))
bottom=np.zeros(2); labels=["Top-1","Top-2","Top-3","Top-4+","Miss"]
for lab in labels:
    vals=[rb_base[lab], rb_best[lab]]
    # Each bucket is drawn on top of the running total of the previous ones.
    plt.bar(["baseline","best"], vals, bottom=bottom, label=lab); bottom += vals
plt.ylabel("share"); plt.title(f"Rank breakdown — baseline vs best ({best_cfg})")
plt.legend(ncol=5, fontsize=7); savefig(SAVE/"abl_rank_breakdown_baseline_vs_best.png")

# diagnostics heatmap (z-scored) + hints diagnostics main-effect
# z-score per column across configs; NaNs (e.g. from a constant or all-NaN
# column) are drawn as 0.
diag_z = (diag_macro - diag_macro.mean())/diag_macro.std(ddof=0)
plt.figure(figsize=(min(22, 2+0.24*len(diag_z.index)), 10))
im=plt.imshow(diag_z[diag_metrics].fillna(0.0).values, aspect="auto", cmap="coolwarm")
plt.colorbar(im, fraction=0.02, pad=0.02).set_label("z-score")
plt.yticks(range(len(diag_z.index)), diag_z.index, fontsize=7)
plt.xticks(range(len(diag_metrics)), diag_metrics, rotation=25, ha="right")
plt.title("Ablation diagnostics (z-scored)")
savefig(SAVE/"abl_diag_heatmap_z.png")

# Diagnostic means for hint vs nohint configs, as grouped bars.
tmp=df_all.copy(); tmp["factor"]=tmp["config"].map(lambda c: "hint" if "__hint__" in c else "nohint")
hints_diag = tmp.groupby("factor")[diag_metrics].mean()
plt.figure(figsize=(10,4))
for j,m in enumerate(diag_metrics):
    plt.bar(np.arange(len(hints_diag.index))+j*0.1, hints_diag[m].values, width=0.1, label=m)
plt.xticks(np.arange(len(hints_diag.index))+0.35, hints_diag.index)
plt.ylabel("mean"); plt.title("Main effect: hints (diagnostics)")
plt.legend(fontsize=7, ncol=3)
savefig(SAVE/"abl_main_effect_hints_diag.png")

# best-config scatters
# One point per case of the best config.
best_rows = df_all[df_all["config"]==best_cfg]
if not best_rows.empty:
    plt.figure(figsize=(6,4))
    plt.scatter(best_rows["SynPriorMeanSel"], best_rows["WindowPrecProxy"])
    plt.xlabel("SynPriorMeanSel"); plt.ylabel("WindowPrecProxy")
    plt.title("Syntax prior vs precision proxy (best config)")
    savefig(SAVE/"scatter_synprior_vs_precproxy_best.png")

    plt.figure(figsize=(6,4))
    plt.scatter(best_rows["QuantumDensitySelected"], best_rows["nDCG@K_global"])
    plt.xlabel("QuantumDensitySelected"); plt.ylabel("nDCG@K_global")
    plt.title("Quantum density vs retrieval nDCG (best config)")
    savefig(SAVE/"scatter_qdensity_vs_ndcg_best.png")

# ------------------------------- best cfg across splits (reporting) -------------------------------
def idx_from_name(name: str):
    """Return the all-scope index named by the first "__"-separated part of *name*.

    Raises KeyError for an unknown index label.
    """
    head, _, _ = name.partition("__")
    all_scope = {
        "AST_base": idx_ast_base_ALL,
        "AST_q": idx_ast_q_ALL,
        "WIN_base": idx_win_base_ALL,
        "WIN_q": idx_win_q_ALL,
    }
    return all_scope[head]
def parse_cfg(name: str):
    """Decode a config name into its hints/selector/rerank/syntax settings.

    Parts 1-4 of the "__"-separated name are read; part 0 (the index label)
    is ignored here.
    """
    parts = name.split("__")
    return dict(
        hints=(parts[1] == "hint"),
        selector=parts[2],
        rerank=(parts[3] == "rerank"),
        syntax=(parts[4] == "syntax"),
    )

# Re-evaluate the TRAIN-selected best config on each split using the
# all-scope index, and write per-case rows plus one macro row per split.
cfg = parse_cfg(best_cfg)
idx_best = idx_from_name(best_cfg)
rows = []
for split, CASES in [("TRAIN", TRAIN), ("VAL", VAL), ("TEST", TEST)]:
    df = eval_config(idx_best, CASES, use_hints=cfg["hints"], use_reranker=cfg["rerank"],
                     selector=cfg["selector"], use_syntax=cfg["syntax"], name_for_tqdm=f"best|{split}")
    if df.empty:
        # An empty split (tiny datasets can leave TEST empty) or an
        # unavailable reranker yields a frame without columns; report NaN
        # instead of failing with a KeyError on df[base_metrics].
        m = pd.Series(np.nan, index=base_metrics)
    else:
        m = df[base_metrics].mean()
    rows.append({"split": split, **{k: float(m.get(k, np.nan)) for k in m.index}})
    df.to_csv(SAVE/f"percase__{split}.csv", index=False)
pd.DataFrame(rows).set_index("split").to_csv(SAVE/"macro_by_split.csv")

print("\nArtifacts saved to:", SAVE.resolve())
print("Best config (picked on TRAIN only):", best_cfg)
print("Index scope for ablation:", INDEX_SCOPE_FOR_ABLATION, "| reranker on-configs are skipped if model missing.")


# GRAP-Q Run 2

In [None]:
# %% FULL PIPELINE — GRAP-Q vs Pure-LLM (Validation-only, Leak-free 70/25/5 split)
# Evaluates GRAP-Q agent vs Pure-LLM on the *validation* set only.
# Leak-free donor policy:
#   • Retrieval index is built from buggy.py for all cases (no labels used).
#   • Cross-case donor windows are allowed ONLY from TRAIN cases.
#   • Optional donor filter excludes TRAIN donor windows overlapping their own gold changes.
#
# Artifacts:
#   results/grap_vs_llm_deep/
#       ├─ splits_70_25_5.json
#       ├─ grap_results_val.csv
#       ├─ llm_results_val.csv
#       ├─ combined_results_val.csv
#       ├─ grap_logs_val.json
#       ├─ llm_logs_val.json
#       ├─ (all plots)*.png   # validation-only
#
# Requirements:
#   - Bugs4Q at data/bugs4q/Bugs4Q-Database/**/buggy.py + fixed.py|fix.py
#   - Ollama available (HTTP or CLI fallback)
#   - pip: pandas numpy matplotlib sentence-transformers requests tqdm pytest

%pip install -q pandas numpy matplotlib sentence-transformers requests tqdm

import os, re, json, math, difflib, shutil, subprocess, sys, ast, random, traceback
import numpy as np, pandas as pd, matplotlib.pyplot as plt
from pathlib import Path
from dataclasses import dataclass
from typing import List, Dict, Tuple, Optional, Any
from hashlib import md5
from tqdm import tqdm

# ------------------------------- USER KNOBS -------------------------------
# DATA_PERCENT is validated further down: any value other than 1 or 100 raises ValueError.
DATA_PERCENT     = 100   # MUST be 1 or 100 (applied after split, to VAL only)
# Only the first line of this file is read as the configuration name.
BEST_CONFIG_PATH = "results/qeval_ablation_plus/best_config.txt"  # <- train-only best config recommended

# Donor policy knobs (leak-free):
ALLOW_TRAIN_DONORS            = True    # allow cross-case donors only from TRAIN
EXCLUDE_TRAIN_DONOR_CHANGED   = True    # exclude TRAIN donors whose window overlaps their own gold-changed lines

# ------------------------------- PATHS / CONFIG -------------------------------
# NOTE: OUT_DIR and WORK_DIR are created as a side effect of running this cell.
DB_ROOT   = Path("data/bugs4q/Bugs4Q-Database")
OUT_DIR   = Path("results/grap_vs_llm_deep"); OUT_DIR.mkdir(parents=True, exist_ok=True)
WORK_DIR  = Path(".work/grap_vs_llm_deep");   WORK_DIR.mkdir(parents=True, exist_ok=True)

# Retrieval / selection
TOPK             = 2      # number of hits kept by the selector
OVERRETRIEVE     = 80     # candidate pool size requested from the index before filtering
RERANK_MODEL     = "cross-encoder/ms-marco-MiniLM-L-6-v2"

# LLM (Ollama) for the agent run
# Every value read from os.environ below falls back to the literal default shown.
OLLAMA_URL      = os.environ.get("OLLAMA_BASE_URL", "http://localhost:11434").rstrip("/")
MODEL_REWRITE   = os.environ.get("OLLAMA_MODEL_REWRITE", "llama3.1:8b")
MODEL_PATCH     = os.environ.get("OLLAMA_MODEL_PATCH",   "qwen2.5-coder:14b-instruct")
# NOTE(review): REW_FALLBACKS, PATCH_FALLBACKS and AUTO_PULL are not referenced in the
# part of the file visible here — confirm they are used further down.
REW_FALLBACKS   = ["llama3.1:8b", "mistral:7b-instruct"]
PATCH_FALLBACKS = ["qwen2.5-coder:14b-instruct","qwen2.5-coder:7b-instruct","deepseek-coder:6.7b-instruct","mistral:7b-instruct"]
NUM_CTX_REWRITE = int(os.environ.get("NUM_CTX_REWRITE", "8192"))
NUM_CTX_PATCH   = int(os.environ.get("NUM_CTX_PATCH",  "12288"))
TEMP_REWRITE    = float(os.environ.get("TEMP_REWRITE", "0.2"))
TEMP_PATCH      = float(os.environ.get("TEMP_PATCH",   "0.0"))
AUTO_PULL       = True
ALLOW_CLI_FALLBACK = True   # let ollama_chat fall back to the `ollama` CLI when HTTP fails
MAX_REFINES     = 2         # the refine loop runs MAX_REFINES+1 patch attempts per case
PYTEST_TIMEOUT  = 90        # default timeout passed to subprocess.run by run_pytest

# Seed both the stdlib and NumPy generators.
SEED = 7
random.seed(SEED); np.random.seed(SEED)
plt.rcParams["figure.dpi"] = 150
plt.rcParams.update({"axes.spines.top": False, "axes.spines.right": False})

# ------------------------------- TEXT/UTILITY -------------------------------
# Identifier-like words: a letter or underscore followed by letters, digits or underscores.
WORD_RE   = re.compile(r"[A-Za-z_][A-Za-z_0-9]*")
# Lower-case words dropped by tokenize(); mixes English function words with Python keywords.
STOPWORDS = set("a an and are as at be by for from has have in is it its of on or that the to was were will with not this self none true false return def class if elif else try except finally while for".split())
# Quantum-related vocabulary. Entries keep their original casing here;
# callers that compare against lower-cased tokens must lower-case them first
# (quantum_boost_map does). Duplicates in the literal collapse in the set.
Q_TOKENS  = set("""
x y z h s sdg t tdg rx ry rz rzz rzx rxy sx cx ccx cnot cz swap cswap iswap ecr u u1 u2 u3
measure barrier qreg creg backend provider aer terra pulse schedule bind assign_parameters
QuantumCircuit QuantumRegister ClassicalRegister Parameter ParameterVector
DAGCircuit PassManager layout mapper transpile basis_gates optimization_level qasm dag layout pass
CouplingMap AncillaAllocation NoiseModel Calibrations LayoutPass Unroller
""".split())

def safe_read(p: Path) -> str:
    """Return the UTF-8 text of *p*, or ``""`` if reading fails for any reason.

    Undecodable bytes are substituted instead of raising.
    """
    try:
        content = p.read_text(encoding="utf-8", errors="replace")
    except Exception:
        content = ""
    return content

def tokenize(s: str) -> List[str]:
    """Split *s* into lower-cased identifier words, dropping STOPWORDS.

    Order and duplicates are preserved.
    """
    tokens = []
    for word in WORD_RE.findall(s):
        lowered = word.lower()
        if word and lowered not in STOPWORDS:
            tokens.append(lowered)
    return tokens

def changed_lines_in_A(a_text: str, b_text: str) -> set[int]:
    """Return the 1-based line numbers of *a_text* that are replaced or deleted in *b_text*.

    Pure insertions in *b_text* touch no line of *a_text* and contribute nothing.
    """
    matcher = difflib.SequenceMatcher(
        None, a_text.splitlines(), b_text.splitlines(), autojunk=False
    )
    return {
        lineno
        for op, a_lo, a_hi, _b_lo, _b_hi in matcher.get_opcodes()
        if op in ("replace", "delete")
        for lineno in range(a_lo + 1, a_hi + 1)
    }

def dcg(scores):
    """Discounted cumulative gain: each score is divided by ``log2(position + 2)`` (0-based position)."""
    total = 0
    for position, gain in enumerate(scores):
        total += gain / math.log2(position + 2)
    return total
def ecdf(arr):
    """Return ``(x, y)`` for the empirical CDF of *arr*, ignoring NaNs.

    ``x`` is the sorted non-NaN values; ``y`` is ``rank / count`` with ranks from 1.
    """
    values = np.asarray(arr, float)
    kept = values[~np.isnan(values)]
    x = np.sort(kept)
    count = len(x)
    y = np.arange(1, count + 1) / max(1, count)
    return x, y

# ------------------------------- DATASET -------------------------------
def iter_cases(db_root: Path):
    """Yield ``(case_id, case_dir, buggy_path, fixed_path)`` for every usable case.

    A case is any directory under *db_root* that contains ``buggy.py`` together
    with ``fixed.py`` or, failing that, ``fix.py``. Directories without a fixed
    file are skipped. ``case_id`` is the directory path relative to *db_root*
    using ``/`` separators.

    Cases are yielded in sorted path order. ``Path.rglob`` order depends on the
    filesystem, and the order in which chunks enter the index decides how equal
    scores are ordered, so an unsorted walk makes runs differ between machines.
    """
    for buggy in sorted(db_root.rglob("buggy.py")):
        case_dir = buggy.parent
        fixed = next(
            (case_dir / nm for nm in ("fixed.py", "fix.py") if (case_dir / nm).exists()),
            None,
        )
        if fixed is None:
            continue
        cid = str(case_dir.relative_to(db_root)).replace(os.sep, "/")
        yield cid, case_dir, buggy, fixed

def top_tokens_query_from_text(text: str, k: int = 6) -> str:
    """Build a space-separated query from the *k* most frequent tokens of *text*.

    Tokens are lower-cased identifier words with STOPWORDS removed. A fixed
    list of structural keywords is excluded, and the counts of a subset of
    quantum tokens are doubled before ranking.

    Differences from the previous implementation:
      * excluded and boosted words are never inserted with a zero count, so a
        short file no longer produces a query padded with words such as
        ``import`` that do not occur in it;
      * the boosted subset is the first 20 lower-cased Q_TOKENS in sorted
        order, rather than the first 20 in set-iteration order (which changes
        with string hash randomisation and contained mixed-case entries that
        could never match the lower-cased counts).
    """
    from collections import Counter

    counts = Counter(
        w.lower() for w in WORD_RE.findall(text) if w and w.lower() not in STOPWORDS
    )
    for w in ("def", "class", "import", "return", "from", "if", "else", "raise", "assert", "self"):
        counts.pop(w, None)
    for t in sorted({t.lower() for t in Q_TOKENS})[:20]:
        if t in counts:
            counts[t] *= 2
    return " ".join(w for w, n in counts.most_common(k) if n > 0)

# ------------------------------- CHUNKING / INDEX -------------------------------
@dataclass
class CodeChunk:
    """One retrievable span of a source file."""

    chunk_id: str    # short hash identifying the span
    repo_key: str    # case id the chunk belongs to
    file_path: str   # path of the source file
    start_line: int  # first line of the span, 1-based
    end_line: int    # last line of the span, inclusive
    symbol: str      # function/class name or a window label
    kind: str        # "function", "class" or "module"
    text: str        # source text of the span

class ASTChunker:
    """Split a Python file into function/class chunks, with a line-window fallback."""

    def __init__(self, window_fallback=80, window_overlap=10):
        """Store the fallback window length and the overlap between windows (in lines)."""
        self.window_fallback = window_fallback
        self.window_overlap = window_overlap

    def chunk_file(self, case_dir: Path, file_path: Path, repo_key: str) -> List[CodeChunk]:
        """Return the chunks of *file_path*.

        One chunk is emitted per function, async function and class found by
        walking the AST; nested definitions are emitted in addition to their
        enclosing definition. When the file does not parse, or contains no such
        definition, the file is cut into overlapping line windows instead.

        ``file_path`` must be located under ``case_dir``; chunk paths are stored
        relative to it.
        """
        rel = str(file_path.relative_to(case_dir))
        src = safe_read(file_path)
        lines = src.splitlines()
        try:
            root = ast.parse(src)
        except Exception:
            root = None

        chunks = []

        def add(s, e, sym, kind):
            s = max(1, int(s))
            e = max(s, int(e))
            chunks.append(CodeChunk(
                # The case id is part of the hashed key: `rel` alone is the same
                # for every case, which made ids collide across cases.
                chunk_id=md5(f"{repo_key}/{rel}:{s}-{e}".encode()).hexdigest()[:12],
                repo_key=repo_key, file_path=rel, start_line=s, end_line=e,
                symbol=sym, kind=kind, text="\n".join(lines[s - 1:e]),
            ))

        if root is not None:
            for node in ast.walk(root):
                if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
                    s = getattr(node, "lineno", 1)
                    e = getattr(node, "end_lineno", s)
                    sym = getattr(node, "name", "<sym>")
                    add(s, e, sym, "class" if isinstance(node, ast.ClassDef) else "function")

        if not chunks:
            # A step below 1 (overlap >= window) would never advance the loop.
            step = max(1, self.window_fallback - self.window_overlap)
            n = len(lines)
            i = 0
            while i < n:
                add(i + 1, min(i + self.window_fallback, n), "<module>", "module")
                i += step
        return chunks

class _MiniBM25:
    """Minimal BM25 scorer over pre-tokenised documents (k1=1.5, b=0.75)."""

    def __init__(self, docs):
        """Record corpus size, document lengths, mean length and document frequencies."""
        from collections import Counter
        self.docs = docs
        self.N = len(docs)
        self.lens = [len(tokens) for tokens in docs]
        self.avg = sum(self.lens) / max(1, self.N)
        doc_freq = Counter()
        for tokens in docs:
            # set(): a term counts once per document it appears in
            doc_freq.update(set(tokens))
        self.df = dict(doc_freq)

    def idf(self, t):
        """Inverse document frequency of term *t*; 0.0 for terms absent from the corpus."""
        n_with_term = self.df.get(t, 0)
        if n_with_term == 0:
            return 0.0
        return math.log(1 + (self.N - n_with_term + 0.5) / (n_with_term + 0.5))

    def score(self, q, doc, dl):
        """BM25 score of token list *doc* (length *dl*) for query tokens *q*."""
        from collections import Counter
        k1, b = 1.5, 0.75
        term_freq = Counter(doc)
        total = 0.0
        for term in q:
            tf = term_freq.get(term, 0) if term in self.df else 0
            if tf == 0:
                continue
            denom = tf + k1 * (1 - b + b * dl / max(1, self.avg))
            total += self.idf(term) * (tf * (k1 + 1)) / denom
        return total

class HybridIndex:
    """BM25 index whose scores get an additive, query-independent token boost."""

    def __init__(self, boost_map: Optional[Dict[str,float]]=None, include_paths: bool=False):
        """Store lower-cased boost weights and whether file paths join the indexed text."""
        self.boost_map = {}
        for key, weight in (boost_map or {}).items():
            self.boost_map[key.lower()] = float(weight)
        self.include_paths = include_paths
        self.records = []
        self.docs = []
        self.bm25 = None

    def build(self, chunks: List[CodeChunk]):
        """Tokenise every chunk (symbol, kind, optional path, text) and fit BM25."""
        self.records = []
        self.docs = []
        for chunk in chunks:
            header = f"{chunk.symbol} {chunk.kind} "
            if self.include_paths:
                header += chunk.file_path + " "
            toks = tokenize(header + "\n" + chunk.text)
            # Sum of boost weights over all token occurrences in the chunk.
            boost_total = sum(self.boost_map.get(t, 0.0) for t in toks)
            self.records.append({"chunk": chunk, "tokens": toks, "boost_sum": float(boost_total)})
            self.docs.append(toks)
        self.bm25 = _MiniBM25(self.docs)

    def search(self, query: str, topk: int = 10):
        """Score every record against *query* and return the *topk* best as hit dicts.

        Each score is the BM25 score plus ``0.02 * boost_sum``. Equal scores are
        ordered by descending record position (tuple sort, reversed).
        """
        q_tokens = tokenize(query)
        scored = [
            (self.bm25.score(q_tokens, rec["tokens"], len(rec["tokens"]))
             + 0.02 * rec.get("boost_sum", 0.0), pos)
            for pos, rec in enumerate(self.records)
        ]
        scored.sort(reverse=True)
        hits = []
        for value, pos in scored[:topk]:
            chunk = self.records[pos]["chunk"]
            hits.append({
                "score": float(value), "re_score": 0.0,
                "file": chunk.file_path, "symbol": chunk.symbol, "kind": chunk.kind,
                "start": int(chunk.start_line), "end": int(chunk.end_line),
                # preview is capped at the first 120 lines of the chunk
                "preview": "\n".join(chunk.text.splitlines()[:120]),
            })
        return hits

def quantum_boost_map(alpha: float = 1.8) -> Dict[str, float]:
    """Map every lower-cased entry of Q_TOKENS to the weight *alpha*."""
    weights = {}
    for token in Q_TOKENS:
        weights[token.lower()] = alpha
    return weights

# ------------------------------- RERANK -------------------------------
class CrossEncoderReranker:
    """Thin wrapper around a sentence-transformers CrossEncoder that degrades to a no-op."""

    def __init__(self, model_name: str):
        """Try to load *model_name*; on any failure print a warning and stay disabled."""
        self.model = None
        self.enabled = False
        try:
            from sentence_transformers import CrossEncoder
            self.model = CrossEncoder(model_name)
            self.enabled = True
        except Exception as e:
            print("[WARN] CrossEncoder unavailable:", e)
            self.model = None
            self.enabled = False

    def score_pairs(self, pairs: List[Tuple[str,str]]) -> np.ndarray:
        """Return one float score per pair; an all-zero array when disabled."""
        if self.enabled:
            return np.asarray(self.model.predict(pairs), dtype=float)
        return np.zeros(len(pairs))

def apply_rerank(query: str, pool_u: List[Dict], rr: Optional[CrossEncoderReranker]):
    """Score each hit's preview against *query* and return the hits sorted by that score.

    The score is written into each hit as ``re_score`` (the hit dicts are
    mutated). With no reranker, or a disabled one, *pool_u* is returned as is.
    """
    if rr is None or not rr.enabled:
        return pool_u
    pairs = [(query, hit.get("preview", "")) for hit in pool_u]
    for hit, value in zip(pool_u, rr.score_pairs(pairs)):
        hit["re_score"] = float(value)
    return sorted(pool_u, key=lambda hit: hit.get("re_score", 0.0), reverse=True)

# ------------------------------- SELECTORS / PRIORS -------------------------------
def syntax_prior_of(hit: Dict) -> float:
    """Return a small additive prior for *hit* based on its preview and symbol text.

    Contributions (all on the lower-cased text):
      +0.10  an error-handling keyword occurs (substring match)
      +0.15  an identifier equals a quantum token
      +0.12  the word ``run`` or ``apply`` occurs
      +0.08  ``dag`` or ``layout`` occurs (substring match)

    The quantum test compares whole identifiers. The previous substring test
    was satisfied by the single-letter gate names in Q_TOKENS (``x``, ``h``,
    ``s``, ``t`` ...) for practically any text, so the term did not
    discriminate between hits.
    """
    txt = (hit.get("preview", "") + " " + hit.get("symbol", "")).lower()
    identifiers = set(WORD_RE.findall(txt))
    prior = 0.0
    if any(t in txt for t in ("assert", "raise", "error", "exception")):
        prior += 0.10
    if any(t.lower() in identifiers for t in Q_TOKENS):
        prior += 0.15
    if re.search(r'\b(run|apply)\b', txt):
        prior += 0.12
    if "dag" in txt or "layout" in txt:
        prior += 0.08
    return prior

def apply_syntax_prior(pool_u: List[Dict], alpha: float = 0.5):
    """Return copies of the hits with ``syn_prior`` set and ``score`` boosted by it, best first.

    The boost is applied to the rerank score when the pool has been reranked,
    otherwise to the retrieval score. Input dicts are not modified.

    Two defects of the previous version are fixed:
      * hits carry ``re_score=0.0`` even when no reranker ran, so
        preferring ``re_score`` whenever the key exists turned every score
        into 0. The pool is treated as reranked only if some hit has a
        non-zero ``re_score``.
      * ``base * (1 + alpha*prior)`` pushed negative scores further down for a
        higher prior. The boost is now ``alpha * prior * abs(base)``, which is
        the same value for positive scores and raises negative ones.
    """
    reranked = any(h.get("re_score", 0.0) != 0.0 for h in pool_u)
    out = []
    for h in pool_u:
        sp = syntax_prior_of(h)
        base = h.get("re_score", 0.0) if reranked else h.get("score", 0.0)
        h2 = dict(h)
        h2["syn_prior"] = sp
        h2["score"] = base + alpha * sp * abs(base)
        out.append(h2)
    return sorted(out, key=lambda r: r.get("score", 0.0), reverse=True)

def select_by_coverage_balanced(pool_u, topk, w_gain=0.8, w_base=1.0, w_rerank=1.5,
                                w_div_file=0.15, w_div_sym=0.10, pen_overlap=0.10):
    """Greedily pick up to *topk* hits balancing relevance, new line coverage and diversity.

    Per round, each unselected hit is scored as
    ``w_gain*new_line_fraction + w_base*normalised_score + w_rerank*re_score``
    plus ``w_div_file`` / ``w_div_sym`` for an unseen file / symbol, minus
    ``pen_overlap`` times the already-covered fraction. The first hit with the
    highest score wins the round.

    Returns ``[]`` for an empty pool (previously ``base.min()`` raised on a
    zero-size array, e.g. after a filter removed every hit).

    Side effect: adds ``_bn`` and ``_rn`` keys to every hit in *pool_u*.
    NOTE(review): coverage is tracked by line number only, so equal line
    numbers in different files count as overlapping — confirm this is intended.
    """
    if not pool_u:
        return []
    sel, covered = [], set()
    seen_files, seen_syms = set(), set()
    base = np.array([h.get("score", 0.0) for h in pool_u], dtype=float)
    # min-max normalise retrieval scores; the epsilon avoids 0/0 for a constant pool
    bn = (base - base.min()) / (base.max() - base.min() + 1e-9)
    # rerank scores are used raw (not normalised)
    rn = np.array([h.get("re_score", 0.0) for h in pool_u], dtype=float)
    for h, b, r in zip(pool_u, bn, rn):
        h["_bn"] = float(b)
        h["_rn"] = float(r)
    for _ in range(min(topk, len(pool_u))):
        best, best_score = None, -1e9
        for h in pool_u:
            if h in sel:
                continue
            rng = set(range(h["start"], h["end"] + 1))
            gain = len(rng - covered)
            size = max(1, h["end"] - h["start"] + 1)
            gain_norm = gain / size
            overlap_frac = 1.0 - gain_norm
            s = w_gain * gain_norm + w_base * h["_bn"] + w_rerank * h["_rn"]
            s += (w_div_file if h["file"] not in seen_files else 0.0)
            s += (w_div_sym if h["symbol"] not in seen_syms else 0.0)
            s -= pen_overlap * overlap_frac
            if s > best_score:
                best, best_score = h, s
        if best is None:
            break
        sel.append(best)
        covered |= set(range(best["start"], best["end"] + 1))
        seen_files.add(best["file"])
        seen_syms.add(best["symbol"])
    return sel

def select_by_coverage_old(hits, topk, w_new_file=10.0, w_new_symbol=6.0, w_rerank=2.0):
    """Greedily pick up to *topk* hits, favouring uncovered lines and unseen files/symbols.

    Per round, each unselected hit scores ``new_lines + w_new_file (unseen
    file) + w_new_symbol (unseen symbol) + w_rerank * tie`` where ``tie`` is
    the hit's ``re_score`` if present, else its ``score``. The first hit with
    the highest score wins the round.

    The running best starts at negative infinity. The previous ``-1.0``
    sentinel silently dropped every candidate scoring below -1 — reachable
    with strongly negative rerank scores — and ended selection early.
    """
    selected, covered = [], set()
    seen_files, seen_symbols = set(), set()
    pool = hits[:]
    for _ in range(min(topk, len(pool))):
        best, best_score = None, float("-inf")
        for h in pool:
            if h in selected:
                continue
            rng = set(range(h["start"], h["end"] + 1))
            gain = len(rng - covered)
            tie = h.get("re_score", h.get("score", 0.0))
            s = gain + (w_new_file if h["file"] not in seen_files else 0.0) \
                     + (w_new_symbol if h["symbol"] not in seen_symbols else 0.0) \
                     + (w_rerank * tie)
            if s > best_score:
                best, best_score = h, s
        if best is None:
            break
        selected.append(best)
        covered |= set(range(best["start"], best["end"] + 1))
        seen_files.add(best["file"])
        seen_symbols.add(best["symbol"])
    return selected

# ------------------------------- BUILD CHUNKS & INDICES -------------------------------
# Build two chunk collections over every case's buggy file (AST-based and
# fixed line windows) plus per-case metadata keyed by case id.
all_chunks_ast, all_chunks_win, meta = [], [], {}
chunker = ASTChunker()

for cid, case_dir, bug_f, fix_f in iter_cases(DB_ROOT):
    # AST chunks: prefix the case id so hit paths identify the owning case.
    for ch in chunker.chunk_file(case_dir, bug_f, repo_key=cid):
        ch.file_path = f"{cid}/{ch.file_path}"
        all_chunks_ast.append(ch)
    text = safe_read(bug_f); lines=text.splitlines()
    # Window chunks: 80-line windows advancing by 70 lines (10 lines of overlap).
    win, overlap = 80, 10; step=max(1,win-overlap); i=0
    while i < len(lines):
        s=i+1; e=min(i+win, len(lines))
        all_chunks_win.append(
            CodeChunk(
                chunk_id=md5(f"{cid}/{bug_f.name}:{s}-{e}".encode()).hexdigest()[:12],
                repo_key=cid, file_path=f"{cid}/{bug_f.name}",
                start_line=s, end_line=e, symbol=f"<win@{s}-{e}>", kind="module",
                text="\n".join(lines[s-1:e])
            )
        )
        i+=step
    bug_txt = text; fix_txt = safe_read(fix_f)
    # "gold": buggy-file lines replaced or deleted by the fix (uses the fixed file).
    # "query": seed query derived from the buggy text only.
    # "project": first path component of the case id.
    meta[cid] = {"gold": changed_lines_in_A(bug_txt, fix_txt),
                 "query": top_tokens_query_from_text(bug_txt, k=6),
                 "project": cid.split("/")[0],
                 "paths": {"bug": bug_f, "fix": fix_f}}

def build_index(chunks, use_boost: bool):
    """Create a HybridIndex over *chunks*, with the quantum boost map when *use_boost* is set.

    Falls back to ``HybridIndex()`` with no arguments if the keyword
    constructor raises ``TypeError``.
    """
    boost = {}
    if use_boost:
        boost = quantum_boost_map(1.8)
    try:
        idx = HybridIndex(boost_map=boost, include_paths=False)
    except TypeError:
        idx = HybridIndex()
    idx.build(chunks)
    return idx

# Four indices: {AST chunks, window chunks} x {plain BM25, BM25 + quantum boost}.
# They cover every case; the per-case donor filter is applied at query time.
idx_ast_base   = build_index(all_chunks_ast, use_boost=False)
idx_ast_q      = build_index(all_chunks_ast, use_boost=True)
idx_win_base   = build_index(all_chunks_win, use_boost=False)
idx_win_q      = build_index(all_chunks_win, use_boost=True)

# ------------------------------- SPLIT: 70/25/5 (train/val/test) -------------------------------
ALL_CASES = sorted(meta.keys())
# stable hashed order
# Cases are ordered by the MD5 hex digest of their id, so the split depends
# only on the set of case ids and not on SEED or discovery order.
order = [ (c, md5(c.encode()).hexdigest()) for c in ALL_CASES ]
order.sort(key=lambda t: t[1])
ordered_cases = [c for c,_ in order]

n = len(ordered_cases)
n_train = int(round(0.70 * n))
n_val   = int(round(0.25 * n))
n_test  = max(0, n - n_train - n_val)  # ~5%

# TRAIN is a set (used for membership tests); VAL and TEST stay ordered lists.
TRAIN_CIDS = set(ordered_cases[:n_train])
VAL_CIDS   = ordered_cases[n_train:n_train+n_val]
TEST_CIDS  = ordered_cases[n_train+n_val:]

# Persist the split (counts and ids) next to the other artifacts.
with open(OUT_DIR/"splits_70_25_5.json","w",encoding="utf-8") as f:
    json.dump({"n":n,"train":len(TRAIN_CIDS),"val":len(VAL_CIDS),"test":len(TEST_CIDS),
               "train_ids":sorted(TRAIN_CIDS),"val_ids":VAL_CIDS,"test_ids":TEST_CIDS}, f, indent=2)

print(f"Split -> train={len(TRAIN_CIDS)} | val={len(VAL_CIDS)} | test={len(TEST_CIDS)}")

# Validation subset size knob (optional)
if DATA_PERCENT not in (1, 100):
    raise ValueError("DATA_PERCENT must be either 1 or 100.")
# Keep the first ceil(len * DATA_PERCENT / 100) validation cases in hash order, at least one.
# The JSON written above records the full validation list, not this subset.
num_keep_val = max(1, int(math.ceil(len(VAL_CIDS)*DATA_PERCENT/100.0)))
VAL_CIDS = [c for c,_h in sorted(((c, md5(c.encode()).hexdigest()) for c in VAL_CIDS),
                                 key=lambda t:t[1])][:num_keep_val]
print(f"Validation cases usable: {len(VAL_CIDS)} (DATA_PERCENT={DATA_PERCENT})")

# ------------------------------- READ BEST CONFIG -------------------------------
# Only the first line of the file is used. A missing file raises from read_text;
# an empty file raises IndexError on the [0].
BEST_CONFIG = Path(BEST_CONFIG_PATH).read_text(encoding="utf-8").strip().splitlines()[0]
print("BEST_CONFIG (from file):", BEST_CONFIG, "(expect this to be TRAIN-only)")

def parse_cfg_name(name: str):
    """Split a ``__``-separated config name into ``(chunking, hints, selector, rerank, syntax)``.

    ``hints``, ``rerank`` and ``syntax`` are booleans, true when the field is
    exactly ``"hint"``, ``"rerank"`` and ``"syntax"`` respectively. Raises
    ``IndexError`` when the name has fewer than five fields.
    """
    fields = name.split("__")
    chunking = fields[0]
    use_hints = fields[1] == "hint"
    selector = fields[2]
    use_rerank = fields[3] == "rerank"
    use_syntax = fields[4] == "syntax"
    return chunking, use_hints, selector, use_rerank, use_syntax

def pick_index(chunking: str):
    """Return the module-level index registered under *chunking*; ``KeyError`` for unknown names."""
    by_name = {
        "AST_base": idx_ast_base,
        "AST_q": idx_ast_q,
        "WIN_base": idx_win_base,
        "WIN_q": idx_win_q,
    }
    return by_name[chunking]

def select_fn_from_name(selector: str):
    """Return the selector function: the legacy one for ``"old"``, the balanced one otherwise."""
    if selector == "old":
        return select_by_coverage_old
    return select_by_coverage_balanced

# Resolve the configuration name into the concrete index, selector and reranker.
best_chunking, best_hints, best_selector, best_rerank, best_syntax = parse_cfg_name(BEST_CONFIG)
best_index  = pick_index(best_chunking)
best_select = select_fn_from_name(best_selector)
# The cross-encoder is only loaded when the configuration asks for reranking.
rr_global   = CrossEncoderReranker(RERANK_MODEL) if best_rerank else None
# A reranker that failed to load is dropped, so the run proceeds without reranking.
if rr_global is not None and not rr_global.enabled: rr_global=None

# ------------------------------- FOCUS (tighten spans) -------------------------------
# Maximum focused span length (lines) and padding kept around matched lines.
FOCUS_MAX = 24; FOCUS_PAD = 3
# Lines considered interesting inside a hit: error keywords and quantum/transpiler terms.
FOCUS_PAT = re.compile(
    r"(assert|raise|error|exception|todo|fixme|bug|fail|"
    r"cx|rz|swap|measure|quantumcircuit|dagcircuit|layout|transpile|run\(|apply\()",
    re.I
)
def focus_span(hit: Dict, full_path: Path) -> Tuple[int,int,List[int]]:
    """Narrow a hit's line range to at most FOCUS_MAX lines around FOCUS_PAT matches.

    Returns ``(lo, hi, matched_lines)`` with 1-based inclusive bounds.
      * file unreadable: the hit's own range and no matches;
      * no matching line in the range: a FOCUS_MAX-line window starting
        FOCUS_MAX//2 lines before the range's midpoint, clipped to the file;
      * otherwise: first to last match padded by FOCUS_PAD, clipped to the
        file and truncated from the end to FOCUS_MAX lines.
    """
    s, e = int(hit["start"]), int(hit["end"])
    try:
        lines = full_path.read_text(encoding="utf-8", errors="replace").splitlines()
    except Exception:
        return s, e, []
    total = len(lines)
    matches = [s + offset for offset, ln in enumerate(lines[s-1:e]) if FOCUS_PAT.search(ln)]
    if not matches:
        lo = max(1, (s + e) // 2 - FOCUS_MAX // 2)
        hi = min(total, lo + FOCUS_MAX - 1)
        return lo, hi, []
    lo = max(1, min(matches) - FOCUS_PAD)
    hi = min(total, max(matches) + FOCUS_PAD)
    if hi - lo + 1 > FOCUS_MAX:
        hi = lo + FOCUS_MAX - 1
    kept = [m for m in matches if lo <= m <= hi]
    return lo, hi, kept

# ------------------------------- LLM INFRA -------------------------------
import requests
def run(cmd, **kw):
    """Run *cmd* via ``subprocess.run`` in text mode with output captured; extra kwargs are forwarded."""
    completed = subprocess.run(cmd, capture_output=True, text=True, **kw)
    return completed
def have_ollama_cli():
    """Return True when ``ollama --version`` exits with status 0 within 5 seconds."""
    try:
        ok = run(["ollama", "--version"], timeout=5).returncode == 0
    except Exception:
        ok = False
    return ok
def _to_prompt(msgs):
    """Flatten chat messages into ``(system_text, transcript)``.

    System messages are stripped and joined with newlines. Every other message
    becomes a ``USER:`` or ``ASSISTANT:`` turn (a missing role counts as user);
    the transcript ends with an open ``ASSISTANT:`` turn.
    """
    system_parts = []
    turns = []
    for msg in msgs:
        role = (msg.get("role") or "user").lower()
        content = msg.get("content") or ""
        if role == "system":
            system_parts.append(content.strip())
            continue
        speaker = "USER" if role == "user" else "ASSISTANT"
        turns.append(f"{speaker}:\n{content}\n")
    system_text = "\n".join(system_parts).strip()
    return (system_text, "".join(turns) + "ASSISTANT:\n")
def _http_json(url, payload, timeout=180):
    """POST *payload* as JSON to *url* and return the decoded JSON body.

    Raises for HTTP error statuses (via ``raise_for_status``).
    """
    response = requests.post(url, json=payload, timeout=timeout)
    response.raise_for_status()
    return response.json()
def _ollama_cli(msgs, model, temperature=0.2, num_ctx=8192, timeout=180):
    """Run one completion through the ``ollama run`` command line and return its stdout.

    The chat messages are flattened into a single prompt. The system text is
    placed at the top of that prompt; previously it was computed and then
    dropped, so the CLI path ran without the system instructions.

    ``temperature`` is accepted for signature parity with the HTTP path but is
    not forwarded to the command. ``num_ctx`` is exported as ``OLLAMA_NUM_CTX``
    in the child environment.
    NOTE(review): confirm the installed ollama CLI honours ``OLLAMA_NUM_CTX``.

    Raises ``RuntimeError`` (carrying stderr) on a non-zero exit status.
    """
    sys_txt, prompt = _to_prompt(msgs)
    if sys_txt:
        prompt = f"SYSTEM:\n{sys_txt}\n\n{prompt}"
    env = os.environ.copy()
    env["OLLAMA_NUM_CTX"] = str(num_ctx)
    p = run(["ollama", "run", model, prompt], timeout=timeout, env=env)
    if p.returncode != 0:
        raise RuntimeError(p.stderr)
    return p.stdout.strip()
def ollama_chat(msgs, *, model, temperature, num_ctx, timeout=180):
    """Return the model's reply to *msgs*, trying three transports in order.

    1. ``POST {OLLAMA_URL}/api/chat`` with the messages as given.
    2. ``POST {OLLAMA_URL}/api/generate`` with the messages flattened into a prompt.
    3. The ``ollama`` CLI, when ALLOW_CLI_FALLBACK is set and the CLI is present.

    A failing HTTP attempt falls through to the next transport; its error is
    kept and included in the final ``RuntimeError`` instead of being discarded,
    so an unreachable server can be told apart from, say, a missing model.
    Errors raised by the CLI attempt propagate unchanged.
    """
    options = {"temperature": temperature, "num_ctx": num_ctx}
    failures = []
    try:
        data = _http_json(
            f"{OLLAMA_URL}/api/chat",
            {"model": model, "messages": msgs, "stream": False, "options": options},
            timeout=timeout,
        )
        return data.get("message",{}).get("content") or data.get("response","") or "".join(m.get("content","") for m in data.get("messages",[]))
    except Exception as exc:
        failures.append(f"/api/chat: {exc!r}")
    try:
        sys_txt, prompt = _to_prompt(msgs)
        payload = {"model": model, "prompt": prompt, "stream": False, "options": options}
        if sys_txt:
            payload["system"] = sys_txt
        data = _http_json(f"{OLLAMA_URL}/api/generate", payload, timeout=timeout)
        return data.get("response", "")
    except Exception as exc:
        failures.append(f"/api/generate: {exc!r}")
    if ALLOW_CLI_FALLBACK and have_ollama_cli():
        return _ollama_cli(msgs, model=model, temperature=temperature, num_ctx=num_ctx, timeout=timeout)
    raise RuntimeError("Ollama not reachable (API and CLI failed). " + "; ".join(failures))

# ------------------------------- PROMPTS -------------------------------
# System prompt for the query-rewriting model.
# NOTE(review): the schema example below uses single quotes; a model that copies
# it literally produces text that json.loads rejects.
REWRITE_SYS = (
    "You are a software search assistant. Produce 3–8 SHORT queries (<=6 words) "
    "to retrieve the buggy code. Prefer function/class names, module names, error keywords, "
    "and quantum terms (cx, rz, swap, dag, layout, qasm, QuantumCircuit, DAGCircuit) only if relevant. "
    "Return JSON: {'queries':['...']}. No prose."
)
# System prompt for the patch model: defines the edit schema (1-based line
# ranges with full replacement text) and the constraints the patch must respect.
# NOTE(review): same single-quote caveat as above for the schema example.
PATCH_SYS = (
    "You are a senior Python engineer. Return STRICT JSON ONLY:\n"
    "{'edits':[{'file':'<rel path>','start':<int 1-based>,'end':<int>,'replacement':'<new full text lines start..end>'}],"
    " 'rationale':'<one paragraph>'}\n"
    "HARD CONSTRAINTS:\n"
    " • Edit ONLY within the allowed line ranges provided.\n"
    " • Do NOT add new files; keep imports unless the context explicitly requires a change.\n"
    " • Keep changes minimal; preserve public APIs.\n"
    "QUANTUM GUARDRAILS:\n"
    " • Preserve qubit order and register semantics; do not swap classical/quantum registers.\n"
    " • Do not change pass interfaces (e.g., run(self, dag)).\n"
    " • Do not silently alter layout or coupling behavior.\n"
    "JSON only. No code fences."
)
def extract_json(s: str) -> dict:
    """Parse the JSON value contained in an LLM reply.

    Accepted forms, tried in this order for the fenced block (if any) and then
    for the whole reply:
      1. the text is valid JSON as is;
      2. valid JSON starts at the first ``{`` or ``[`` (leading prose and any
         trailing text are ignored);
      3. the span from the first opener to the last closer is a Python literal
         dict/list (covers single-quoted keys and strings).

    Code fences with or without a ``json`` tag are recognised.

    Raises ``ValueError`` when nothing parses (``json.JSONDecodeError`` is a
    subclass, so callers catching either keep working).
    """
    text = (s or "").strip()
    candidates = []
    fenced = re.search(r"```(?:json)?\s*(.*?)\s*```", text, re.S | re.I)
    if fenced:
        candidates.append(fenced.group(1).strip())
    candidates.append(text)

    decoder = json.JSONDecoder()
    last_err = None
    for cand in candidates:
        try:
            return json.loads(cand)
        except ValueError as err:
            last_err = err
        openers = [i for i in (cand.find("{"), cand.find("[")) if i >= 0]
        if not openers:
            continue
        start = min(openers)
        try:
            obj, _end = decoder.raw_decode(cand[start:])
            return obj
        except ValueError as err:
            last_err = err
        stop = max(cand.rfind("}"), cand.rfind("]"))
        if stop > start:
            try:
                # literal_eval only evaluates literals; it never executes code.
                obj = ast.literal_eval(cand[start:stop + 1])
            except (ValueError, SyntaxError, MemoryError, RecursionError) as err:
                last_err = err
            else:
                if isinstance(obj, (dict, list)):
                    return obj
    raise ValueError("No JSON object found in model output") from last_err

# ------------------------------- GUARDRAILS (deterministic checks) -------------------------------
def _ast_ok(src: str) -> Tuple[bool, str]:
    """Return ``(True, "")`` if *src* parses as Python, else ``(False, message)``.

    Only ``SyntaxError`` is handled; the message carries its text and line number.
    """
    try:
        ast.parse(src)
    except SyntaxError as err:
        return False, f"SyntaxError: {err.msg} at line {err.lineno}"
    return True, ""

def _find_registers(src: str):
    """Return ``(quantum_names, classical_names)``: names assigned from a register constructor.

    Detection is textual: ``<name> = QuantumRegister(`` / ``<name> = ClassicalRegister(``.
    """
    def _assigned_from(ctor):
        pattern = r'(\w+)\s*=\s*' + ctor + r'\('
        return {m.group(1) for m in re.finditer(pattern, src)}

    return _assigned_from("QuantumRegister"), _assigned_from("ClassicalRegister")

def _pass_interface_ok(before_src: str, after_src: str) -> Tuple[bool,str]:
    """Check that the positional-parameter names of every ``run`` function are unchanged.

    Signatures are compared as sets of name tuples across all (non-async)
    functions called ``run``. Source that fails to parse contributes no
    signatures. Passes trivially when *before_src* defines no ``run``.
    """
    def _run_signatures(code):
        found = set()
        try:
            for node in ast.walk(ast.parse(code)):
                if isinstance(node, ast.FunctionDef) and node.name == "run":
                    found.add(tuple(arg.arg for arg in node.args.args))
        except Exception:
            pass
        return found

    b = _run_signatures(before_src)
    a = _run_signatures(after_src)
    if b and b != a:
        return False, f"Pass interface changed: {b} -> {a}"
    return True, ""

def _no_reg_mix_ok(src: str) -> Tuple[bool,str]:
    """Flag calls whose first argument is a classical register where a quantum one is expected.

    Textual check: the first identifier after ``measure(`` and after
    ``cx(/cz(/rz(/rx(/ry(/swap(`` must not be a name assigned from
    ``ClassicalRegister(``. ``measure`` calls are checked before gate calls.
    """
    _q_regs, c_regs = _find_registers(src)
    first_arg_of_measure = re.finditer(r'measure\s*\(\s*([A-Za-z_]\w*)', src)
    for call in first_arg_of_measure:
        name = call.group(1)
        if name in c_regs:
            return False, f"measure() uses classical register '{name}' as quantum"
    first_arg_of_gate = re.finditer(r'(cx|cz|rz|rx|ry|swap)\s*\(\s*([A-Za-z_]\w*)', src)
    for call in first_arg_of_gate:
        gate, name = call.group(1), call.group(2)
        if name in c_regs:
            return False, f"{gate}() uses classical register '{name}' as quantum"
    return True, ""

def _qubit_order_heuristic_ok(before_src: str, after_src: str, edited_ranges: List[Tuple[int,int]]) -> Tuple[bool,str]:
    """Heuristically detect a ``q[0], q[1]`` -> ``q[1], q[0]`` swap inside the edited ranges.

    The same 1-based ranges are used to slice both the old and the new source.
    Fails only when the old slice contains ``q[0], q[1]`` and the new slice
    contains ``q[1], q[0]``.
    """
    def _excerpt(source, ranges):
        lines = source.splitlines()
        picked = []
        for first, last in ranges:
            first = max(1, first)
            last = min(len(lines), max(first, last))
            picked.extend(lines[first - 1:last])
        return "\n".join(picked)

    old_part = _excerpt(before_src, edited_ranges)
    new_part = _excerpt(after_src, edited_ranges)
    swapped_now = re.search(r'\bq\[\s*1\s*\]\s*,\s*q\[\s*0\s*\]', new_part)
    ordered_before = re.search(r'\bq\[\s*0\s*\]\s*,\s*q\[\s*1\s*\]', old_part)
    if swapped_now and ordered_before:
        return False, "Potential qubit order swap in edited lines"
    return True, ""

def guardrail_validate_patch(bug_file: Path, edits: List[Dict]) -> Tuple[bool, List[str]]:
    """Apply *edits* to the text of *bug_file* in memory and run the deterministic checks.

    Returns ``(all_passed, messages)`` where *messages* holds one entry per
    failed check, in this order: syntax, ``run`` signature, register mix-up,
    qubit-order heuristic.

    Edit line numbers refer to the original file, so edits are applied from
    the bottom of the file upwards. Applying them top-down (as before) shifted
    every later edit whenever an earlier replacement changed the line count.
    An ``end`` below ``start - 1`` is treated as an insertion at ``start``
    rather than duplicating lines.
    """
    before = safe_read(bug_file)
    after = before.splitlines()
    ranges = []
    parsed = []
    for e in edits or []:
        s = max(1, int(e.get("start", 1)))
        en = int(e.get("end", s))
        parsed.append((s, en, str(e.get("replacement", "")).splitlines()))
        ranges.append((s, en))
    for s, en, replacement in sorted(parsed, key=lambda item: item[0], reverse=True):
        after = after[:s - 1] + replacement + after[max(s - 1, en):]
    after_src = "\n".join(after)

    checks = (
        _ast_ok(after_src),
        _pass_interface_ok(before, after_src),
        _no_reg_mix_ok(after_src),
        # ranges are in original-file coordinates; the heuristic slices both versions with them
        _qubit_order_heuristic_ok(before, after_src, ranges),
    )
    msgs = [msg for ok, msg in checks if not ok]
    return all(ok for ok, _msg in checks), msgs

# ------------------------------- DONOR FILTER (leak-free) -------------------------------
def _case_from_hitfile(path_str: str) -> Optional[str]:
    """Return the case id encoded in a hit's file path, or ``None`` if there is none.

    Hit paths have the form ``<cid>/<file name>`` where ``<cid>`` may itself
    contain any number of ``/``-separated components. The case id is
    everything before the last ``/``. The previous version always took the
    first two components, which is only right for case ids of depth two.
    """
    if not path_str:
        return None
    cid, sep, _file_name = path_str.rpartition("/")
    if not sep or not cid:
        return None
    return cid

def donor_is_allowed_for_case(hit: Dict, current_cid: str) -> bool:
    """Decide whether *hit* may be shown as context while repairing *current_cid*.

    Hits from the case itself are always allowed. Hits from other cases are
    allowed only when ALLOW_TRAIN_DONORS is set and the donor is a TRAIN case;
    with EXCLUDE_TRAIN_DONOR_CHANGED set, such a hit is additionally rejected
    if its line range contains any of the donor's gold-changed lines.
    """
    donor_cid = _case_from_hitfile(hit.get("file", ""))
    if donor_cid is None:
        return False
    if donor_cid == current_cid:
        return True
    if not (ALLOW_TRAIN_DONORS and donor_cid in TRAIN_CIDS):
        return False
    if not EXCLUDE_TRAIN_DONOR_CHANGED:
        return True
    # Reject windows that overlap the lines the donor's own fix touches.
    gold = meta.get(donor_cid, {}).get("gold", set())
    first, last = int(hit.get("start", 1)), int(hit.get("end", 1))
    for line_no in range(first, last + 1):
        if line_no in gold:
            return False
    return True

# ------------------------------- EDIT HELPERS & RUN -------------------------------
def enforce_in_region(edits: List[Dict], allowed: List[Tuple[int,int]]) -> List[Dict]:
    """Keep only the edits that lie entirely inside one of the *allowed* line ranges.

    Each kept edit is normalised to ``{"file", "start", "end", "replacement"}``
    with integer bounds; ``file`` defaults to ``"buggy.py"``.

    Edits come from model output, so malformed entries are dropped instead of
    raising: non-dict items, bounds that are not integers, and inverted ranges
    (``end < start``), which later code would otherwise apply as a
    line-duplicating splice.
    """
    ok = []
    for e in edits or []:
        if not isinstance(e, dict):
            continue
        try:
            st = int(e.get("start", 1))
            en = int(e.get("end", st))
        except (TypeError, ValueError):
            continue
        if en < st:
            continue
        if any(st >= a and en <= b for (a, b) in allowed):
            ok.append({"file": e.get("file", "buggy.py"), "start": st, "end": en,
                       "replacement": e.get("replacement", "")})
    return ok

def apply_edits(src_repo: Path, edits: List[Dict], out_repo: Path) -> Path:
    """Copy *src_repo* to *out_repo* (replacing it) and apply *edits* to its ``buggy.py``.

    Each edit replaces the 1-based inclusive line range ``start..end`` with the
    lines of ``replacement``. Returns *out_repo*; when the copy has no
    ``buggy.py`` nothing is edited. The ``file`` key of an edit is not used.

    Line numbers refer to the original file, so edits are applied from the
    bottom upwards; the earlier top-down order shifted every later edit once a
    replacement changed the number of lines. An ``end`` below ``start - 1`` is
    treated as an insertion at ``start`` instead of duplicating lines.
    """
    if out_repo.exists():
        shutil.rmtree(out_repo)
    shutil.copytree(src_repo, out_repo)
    p = out_repo / "buggy.py"
    if not p.exists():
        return out_repo
    lines = p.read_text(encoding="utf-8", errors="replace").splitlines()
    for e in sorted(edits or [], key=lambda item: int(item["start"]), reverse=True):
        st = max(1, int(e["start"]))
        en = max(st - 1, min(len(lines), int(e["end"])))
        lines[st - 1:en] = str(e.get("replacement", "")).splitlines()
    p.write_text("\n".join(lines), encoding="utf-8")
    return out_repo

def run_pytest(path: Path, timeout=PYTEST_TIMEOUT):
    """Run ``pytest -q`` with *path* as working directory.

    Returns ``(returncode, stdout + "\\n" + stderr)``. Any failure to run the
    process, including a timeout, yields ``(99, "(pytest error) ...")``.
    """
    command = [sys.executable, "-m", "pytest", "-q"]
    try:
        proc = subprocess.run(command, cwd=path, text=True, capture_output=True, timeout=timeout)
    except Exception as e:
        return 99, f"(pytest error) {e}"
    combined = (proc.stdout or "") + "\n" + (proc.stderr or "")
    return proc.returncode, combined

def last_failing_assert(trace: str) -> str:
    """Extract a short failure excerpt from pytest output.

    Looks at the last 120 lines. If an ``E   AssertionError`` line is found,
    returns it plus up to six following lines; otherwise returns the last 400
    characters of that tail. The result is stripped.
    """
    recent = trace.splitlines()[-120:]
    tail = "\n".join(recent)
    found = re.search(r"(E\s+AssertionError[^\n]*\n(?:[^\n]*\n){0,6})", tail)
    if found:
        return found.group(1).strip()
    return tail[-400:].strip()

def llm_rewrite_queries(seed_query: str) -> List[str]:
    """Ask the rewrite model for short search queries derived from *seed_query*.

    The reply is parsed as JSON: either ``{"queries": [...]}`` or a bare list.
    Non-string and blank entries are dropped. If parsing fails or the shape is
    unexpected, the first six non-blank lines of the raw reply are used with
    leading/trailing bullets, dashes and spaces removed.
    """
    def _clean(items):
        return [q.strip() for q in items if isinstance(q, str) and q.strip()]

    user_payload = json.dumps({"seed_query":seed_query, "rules":["<=6 words/query","no quotes/paths"]})
    msgs = [
        {"role":"system","content":REWRITE_SYS},
        {"role":"user","content":user_payload},
    ]
    out = ollama_chat(msgs, model=MODEL_REWRITE, temperature=TEMP_REWRITE, num_ctx=NUM_CTX_REWRITE)
    try:
        obj = extract_json(out)
        if isinstance(obj, dict) and "queries" in obj:
            return _clean(obj["queries"])
        if isinstance(obj, list):
            return _clean(obj)
    except Exception:
        pass
    fallback = [line.strip("-• ").strip() for line in out.splitlines() if line.strip()]
    return fallback[:6]

def llm_patch_once(cid: str, focused_ctx: List[Dict], allowed_ranges: List[Tuple[int,int]], extra_feedback: str = "") -> dict:
    """Request one patch proposal from the patch model and return the parsed JSON.

    The user message is a JSON payload with the case id, the allowed line
    ranges, the focused context and any feedback from a previous attempt. If
    the first reply is not parseable, a corrective system message is appended
    and the model is asked once more at temperature 0.0; a parse failure on
    that second reply propagates to the caller.
    """
    request = {
        "case":cid,
        "allowed_ranges":allowed_ranges,
        "context":focused_ctx,
        "instruction":"Return strict JSON only. No markdown fences.",
        "feedback": extra_feedback,
    }
    msgs = [
        {"role":"system","content":PATCH_SYS},
        {"role":"user","content":json.dumps(request)},
    ]
    first_reply = ollama_chat(msgs, model=MODEL_PATCH, temperature=TEMP_PATCH, num_ctx=NUM_CTX_PATCH)
    try:
        return extract_json(first_reply)
    except Exception:
        msgs.append({"role":"system","content":"Your previous output was not valid JSON. Return ONLY JSON now."})
        second_reply = ollama_chat(msgs, model=MODEL_PATCH, temperature=0.0, num_ctx=NUM_CTX_PATCH)
        return extract_json(second_reply)

# ------------------------------- SCORING / DIAGNOSTICS -------------------------------
def evaluate_candidate(bug_repo: Path, fix_repo: Path, cand_repo: Optional[Path]) -> Dict[str, Any]:
    """Line-level precision/recall/F1 of a candidate patch against the reference fix.

    All three repos are expected to hold their version of the file as
    ``buggy.py``. "Touched" lines are the buggy-file lines that a diff marks
    as replaced or deleted. A missing candidate counts as touching nothing.
    """
    def _lines_of(repo):
        return safe_read(repo/"buggy.py").splitlines()

    def _touched(old, new):
        matcher = difflib.SequenceMatcher(None, old, new, autojunk=False)
        return {
            line_no
            for op, lo, hi, _j1, _j2 in matcher.get_opcodes()
            if op in ("replace", "delete")
            for line_no in range(lo + 1, hi + 1)
        }

    bug_lines = _lines_of(bug_repo)
    fix_lines = _lines_of(fix_repo)
    if cand_repo and (cand_repo/"buggy.py").exists():
        cand_lines = _lines_of(cand_repo)
    else:
        cand_lines = []

    gold = _touched(bug_lines, fix_lines)
    pred = _touched(bug_lines, cand_lines) if cand_lines else set()
    common = len(gold & pred)
    precision = common / max(1, len(pred))
    recall = common / max(1, len(gold))
    if precision + recall == 0:
        f1 = 0.0
    else:
        f1 = 2 * precision * recall / (precision + recall)
    return {"lines_p":precision,"lines_r":recall,"lines_f1":f1}

def count_lines_edited(bug_repo: Path, edits: List[Dict]) -> Tuple[int,int]:
    """Return ``(lines_touched, abs_line_delta)`` summed over *edits*.

    ``lines_touched`` is the number of original lines each edit replaces;
    ``abs_line_delta`` is the absolute difference between replacement length
    and replaced length. Both are derived from the edits alone; *bug_repo* is
    kept for interface compatibility and is not read (the previous version
    read the file and discarded the result).

    An inverted range (``end < start``) counts as replacing zero lines;
    previously its negative length inflated the delta.
    """
    touched = 0
    delta = 0
    for e in edits or []:
        st = max(1, int(e.get("start", 1)))
        en = int(e.get("end", st))
        repl = str(e.get("replacement", "")).splitlines()
        old_len = max(0, en - st + 1)
        touched += old_len
        delta += abs(len(repl) - old_len)
    return touched, delta

def api_drift_score(before: str, after: str) -> float:
    """Return ``1 - Jaccard`` between the API surfaces of two source texts.

    The surface is the set of ``("fun", name, positional_arg_count)`` for every
    (non-async) function and ``("cls", name, 0)`` for every class found
    anywhere in the module. Source that does not parse has an empty surface.
    Returns 0.0 when both surfaces are empty.

    Classes are keyed by name. The previous version stored the AST node
    itself, so the same class never compared equal between the two texts and
    any file containing a class reported drift even when nothing changed.
    """
    def names(s):
        try:
            tree = ast.parse(s)
        except Exception:
            return set()
        out = set()
        for n in ast.walk(tree):
            if isinstance(n, ast.FunctionDef):
                out.add(("fun", n.name, len(n.args.args)))
            elif isinstance(n, ast.ClassDef):
                out.add(("cls", n.name, 0))
        return out

    b = names(before)
    a = names(after)
    if not b and not a:
        return 0.0
    j = len(b & a) / max(1, len(b | a))
    return 1.0 - j

def identifier_jaccard(before: str, after: str) -> float:
    """Jaccard similarity of the token sets of two texts; 1.0 when both are empty."""
    old_ids = set(tokenize(before))
    new_ids = set(tokenize(after))
    if not new_ids and not old_ids:
        return 1.0
    shared = new_ids & old_ids
    combined = new_ids | old_ids
    return len(shared) / max(1, len(combined))

def distortion_flags(bug_repo: Path, edits: List[Dict], cand_repo: Optional[Path], lines_f1: float) -> Dict[str, Any]:
    """Summarise how far a candidate patch distorts the original file.

    Compares ``buggy.py`` in *bug_repo* with the one in *cand_repo*. When the
    candidate file is missing or empty, ``ast_parse_fail`` is True, the drift
    and Jaccard values are NaN and their threshold flags are False (a NaN
    comparison is false).
    """
    before = safe_read(bug_repo/"buggy.py")
    has_candidate = bool(cand_repo) and (cand_repo/"buggy.py").exists()
    after = safe_read(cand_repo/"buggy.py") if has_candidate else ""

    if after:
        parses, _msg = _ast_ok(after)
        drift = api_drift_score(before, after)
        jacc = identifier_jaccard(before, after)
    else:
        parses = False
        drift = np.nan
        jacc = np.nan

    touched, delta = count_lines_edited(bug_repo, edits)
    # Many changed lines with no overlap with the reference fix.
    excessive_no_gain = (lines_f1==0.0 and delta>=5)
    return {
        "ast_parse_fail": (not parses),
        "api_drift_gt40": bool(drift > 0.40),
        "id_jacc_lt60":   bool(jacc < 0.60),
        "excessive_no_gain": excessive_no_gain,
        "drift": float(drift),
        "id_jacc": float(jacc),
        "delta_abs_lines": int(delta),
        "lines_touched": int(touched)
    }

# ------------------------------- GRAP-Q AGENT RUN (VAL only) -------------------------------
def run_grap_once(case_ids: List[str],label: str = "VAL"):
    """Run the retrieval-guided (GRAP-Q) patch loop over ``case_ids`` and persist the results.

    Per case: search ``best_index``, drop disallowed donors, rerank, optionally
    apply the syntax prior (``best_syntax``), pick ``TOPK`` hits with
    ``best_select``, derive focus windows on the buggy file, then ask the LLM
    for a patch up to ``MAX_REFINES + 1`` times. Guardrail violations, syntax
    errors and pytest output are fed back as feedback for the next attempt.

    Args:
        case_ids: Case identifiers; each must be a key of the module-level ``meta``.
        label: Split name, used in the progress bar and in the output file names.

    Returns:
        DataFrame with one row per case (Lines-F1/P/R, edit counts, autofill
        marker and the ``distortion_flags`` fields).

    Side effects:
        Recreates per-case scratch directories under ``WORK_DIR`` and writes
        ``grap_results_<label>.csv`` and ``grap_logs_<label>.json`` into
        ``OUT_DIR``. The CSV uses its own ``grap_`` prefix so that
        ``run_pure_llm_once`` (which writes ``llm_results_<label>.csv``) no
        longer overwrites it.
    """
    tag = label.lower()
    rows=[]; logs_all=[]
    for cid in tqdm(case_ids, desc=f"[GRAP-Q|{label}] cases"):
        q0 = meta[cid]["query"]
        q  = (q0 + " cx rz dag") if best_hints else q0
        # retrieval pool (global index, then donor filter)
        pool = best_index.search(q, topk=max(OVERRETRIEVE, 6*TOPK))
        pool = [h for h in pool if donor_is_allowed_for_case(h, cid)]
        # rerank
        pool = apply_rerank(q, pool, rr_global)
        # syntax prior (if config says syntax)
        if best_syntax:
            pool = apply_syntax_prior(pool, alpha=0.5)
        # select
        selected = best_select(pool, TOPK)
        # focus windows (read the buggy file once, not once per hit)
        bug_path  = meta[cid]["paths"]["bug"]
        bug_lines = safe_read(bug_path).splitlines()
        focused_ctx=[]; allowed=[]
        for rank, h in enumerate(selected, 1):
            lo, hi, _ = focus_span(h, bug_path)
            allowed.append((lo,hi))
            focused_ctx.append({"rank":rank,"file":h["file"],"span":f"{lo}-{hi}","symbol":h["symbol"],
                                "code":"\n".join(bug_lines[lo-1:hi])})
        # tiny repos: fresh copies of the buggy and fixed file for this case
        case_tag = cid.replace('/','__')
        tiny_b = WORK_DIR / f"{case_tag}__g_bug"
        tiny_f = WORK_DIR / f"{case_tag}__g_fix"
        if tiny_b.exists(): shutil.rmtree(tiny_b)
        if tiny_f.exists(): shutil.rmtree(tiny_f)
        tiny_b.mkdir(parents=True, exist_ok=True); tiny_f.mkdir(parents=True, exist_ok=True)
        shutil.copy(meta[cid]["paths"]["bug"], tiny_b/"buggy.py")
        shutil.copy(meta[cid]["paths"]["fix"], tiny_f/"buggy.py")
        # refine loop with guardrails
        feedback=""; patch={"edits":[],"rationale":""}; cand_repo=None
        guard_notes=[]; rationale_autofill=False; autofill_reason=""
        for it in range(MAX_REFINES+1):
            proposal = llm_patch_once(cid, focused_ctx, allowed, extra_feedback=feedback)
            # rationale auto-fill tracker
            if not isinstance(proposal.get("rationale",""), str) or not proposal.get("rationale","").strip():
                proposal["rationale"] = "Autofill: minimal, localized fix within allowed span; keep APIs/layout/register semantics; address failure indicated by guardrails/tests."
                rationale_autofill=True; autofill_reason="missing_or_empty"
            edits = enforce_in_region(proposal.get("edits",[]), allowed)
            ok, reasons = guardrail_validate_patch(tiny_b/"buggy.py", edits)
            if not ok:
                feedback = "Guardrail violations:\n- " + "\n- ".join(reasons) + "\nFix minimally within allowed ranges."
                guard_notes.extend(reasons)
                # On the last iteration ``continue`` simply ends the loop.
                continue
            # ``patch`` and ``cand_repo`` are only updated together, so they
            # always describe the same (latest guardrail-passing) attempt.
            patch={"edits":edits,"rationale":proposal.get("rationale","")}
            cand_repo = WORK_DIR / f"{case_tag}__g_cand"
            apply_edits(tiny_b, edits, cand_repo)
            # syntax check
            src = safe_read(cand_repo/"buggy.py")
            ok,_ = _ast_ok(src)
            if not ok:
                feedback = "Your edit produced a SyntaxError. Repair minimally."
                guard_notes.append("syntax_fail_after_apply")
                continue
            # pytest run for feedback (even if tests absent)
            rc, out = run_pytest(cand_repo, timeout=PYTEST_TIMEOUT)
            if rc==0:
                break
            if rc in (5, 4):
                feedback="No runnable tests. Ensure edit compiles and is minimal."
            else:
                feedback="Last failing assertion/stack:\n"+last_failing_assert(out)

        rep = evaluate_candidate(tiny_b, tiny_f, cand_repo)
        touched, delta = count_lines_edited(tiny_b, patch.get("edits",[]))
        flags = distortion_flags(tiny_b, patch.get("edits",[]), cand_repo, rep["lines_f1"])
        rows.append({
            "case": cid, "method":"GRAP", "lines_f1":rep["lines_f1"], "lines_p":rep["lines_p"], "lines_r":rep["lines_r"],
            "num_edits": len(patch.get("edits",[])), "lines_touched": touched, "delta_abs_lines": delta,
            "rationale_autofill": bool(rationale_autofill),
            **flags
        })
        logs_all.append({"case":cid,"guardrail_notes":guard_notes,"selected":selected,"allowed":allowed,
                         "patch":patch,"rationale_autofill":rationale_autofill,"autofill_reason":autofill_reason})
    df=pd.DataFrame(rows)
    df.to_csv(OUT_DIR / f"grap_results_{tag}.csv", index=False)
    with open(OUT_DIR / f"grap_logs_{tag}.json","w",encoding="utf-8") as f:
        json.dump(logs_all, f, indent=2)
    return df

# ------------------------------- PURE LLM RUN (VAL only) -------------------------------
def run_pure_llm_once(case_ids: List[str],label: str = "VAL"):
    """Run the retrieval-free baseline: one LLM patch attempt per case on the file head.

    The model sees only the first 220 lines of the buggy file, is called once
    (no refinement loop, no guardrails), and any returned edits are applied and
    scored against the fixed file.

    Args:
        case_ids: Case identifiers; each must be a key of the module-level ``meta``.
        label: Split name, used in the progress bar and in the output file names.
            With the default ``"VAL"`` the file names are the same as before.

    Returns:
        DataFrame with one row per case (Lines-F1/P/R, edit counts, autofill
        marker and the ``distortion_flags`` fields).

    Side effects:
        Recreates per-case scratch directories under ``WORK_DIR`` and writes
        ``llm_results_<label>.csv`` and ``llm_logs_<label>.json`` into ``OUT_DIR``.
    """
    tag = label.lower()
    rows=[]; logs_all=[]
    for cid in tqdm(case_ids, desc=f"[Pure-LLM|{label}] cases"):
        bug_path = meta[cid]["paths"]["bug"]; fix_path = meta[cid]["paths"]["fix"]
        code = "\n".join(safe_read(bug_path).splitlines()[:220])
        ctx = [{"rank":1,"file":f"{cid}/buggy.py","span":"1-220","symbol":"<file>","code":code}]
        msgs=[{"role":"system","content":PATCH_SYS},
              {"role":"user","content":json.dumps({"case":cid,"context":ctx,"instruction":"Return strict JSON only."})}]
        rationale_autofill=False; autofill_reason=""
        try:
            out=ollama_chat(msgs, model=MODEL_PATCH, temperature=TEMP_PATCH, num_ctx=NUM_CTX_PATCH)
            patch=extract_json(out)
        except Exception as e:
            # Best effort: a failed call or unparsable reply counts as "no patch".
            patch={"edits":[],"rationale":f"error: {e}"}
        if not isinstance(patch, dict):
            # A JSON reply that is not an object (list, string, null) cannot
            # carry edits; treat it like a failed call instead of crashing on .get().
            patch={"edits":[],"rationale":f"error: expected JSON object, got {type(patch).__name__}"}
        if not isinstance(patch.get("rationale",""), str) or not patch.get("rationale","").strip():
            patch["rationale"] = "Autofill: file-level attempt based on first 220 lines; keep APIs/layout/register semantics; apply smallest plausible fix."
            rationale_autofill=True; autofill_reason="missing_or_empty"
        edits = patch.get("edits",[]) or []
        # tiny repos: fresh copies of the buggy and fixed file for this case
        case_tag = cid.replace('/','__')
        tiny_b = WORK_DIR / f"{case_tag}__p_bug"
        tiny_f = WORK_DIR / f"{case_tag}__p_fix"
        if tiny_b.exists(): shutil.rmtree(tiny_b)
        if tiny_f.exists(): shutil.rmtree(tiny_f)
        tiny_b.mkdir(parents=True, exist_ok=True); tiny_f.mkdir(parents=True, exist_ok=True)
        shutil.copy(bug_path, tiny_b/"buggy.py"); shutil.copy(fix_path, tiny_f/"buggy.py")
        # No edits -> no candidate repo; downstream helpers receive None.
        cand_repo=None
        if edits:
            cand_repo = WORK_DIR / f"{case_tag}__p_cand"
            apply_edits(tiny_b, edits, cand_repo)
        rep = evaluate_candidate(tiny_b, tiny_f, cand_repo)
        touched, delta = count_lines_edited(tiny_b, edits)
        flags = distortion_flags(tiny_b, edits, cand_repo, rep["lines_f1"])
        rows.append({
            "case": cid, "method":"LLM", "lines_f1":rep["lines_f1"], "lines_p":rep["lines_p"], "lines_r":rep["lines_r"],
            "num_edits": len(edits), "lines_touched": touched, "delta_abs_lines": delta,
            "rationale_autofill": bool(rationale_autofill),
            **flags
        })
        logs_all.append({"case":cid,"patch":patch,"rationale_autofill":rationale_autofill,"autofill_reason":autofill_reason})
    df=pd.DataFrame(rows)
    df.to_csv(OUT_DIR / f"llm_results_{tag}.csv", index=False)
    with open(OUT_DIR / f"llm_logs_{tag}.json","w",encoding="utf-8") as f:
        json.dump(logs_all, f, indent=2)
    return df

# ------------------------------- RUN BOTH (VAL) & MERGE -------------------------------
# Run both pipelines on the same validation cases (both use their default label "VAL").
df_grap = run_grap_once(VAL_CIDS)
df_llm  = run_pure_llm_once(VAL_CIDS)
# Long format: one row per (case, method), fresh 0..n-1 index.
df_all  = pd.concat([df_grap, df_llm], ignore_index=True)
# Wide format: one row per case, one Lines-F1 column per "method" value ("GRAP", "LLM").
df_wide = df_all.pivot(index="case", columns="method", values="lines_f1")
df_all.to_csv(OUT_DIR/"combined_results_val.csv", index=False)

# ------------------------------- STATS & PLOTS (Validation only) -------------------------------
def savefig(path):
    """Write the current pyplot figure to ``path`` and close it.

    Creates the parent directory if needed and applies ``tight_layout``
    before saving.
    """
    out_dir = Path(path).parent
    out_dir.mkdir(parents=True, exist_ok=True)
    plt.tight_layout()
    plt.savefig(path)
    plt.close()
def mean_ci95(a):
    """Return the mean of ``a`` and a normal-approximation 95% confidence interval.

    Args:
        a: A pandas Series or any iterable of values (list, tuple, ndarray).
            Entries that cannot be converted to a number are dropped.

    Returns:
        ``(mean, (low, high))`` with ``low/high = mean -/+ 1.96 * standard error``
        (sample standard deviation, ``ddof=1``). With a single usable value the
        interval collapses to the mean; with none, all three values are NaN.
    """
    # pd.to_numeric only returns a Series (with .dropna) for Series input;
    # for lists/arrays it returns an ndarray, so wrap those first.
    series = a if isinstance(a, pd.Series) else pd.Series(list(a), dtype=object)
    vals = np.asarray(pd.to_numeric(series, errors="coerce").dropna(), float)
    if len(vals) == 0:
        return np.nan, (np.nan, np.nan)
    m = vals.mean()
    se = vals.std(ddof=1) / np.sqrt(len(vals)) if len(vals) > 1 else 0.0
    return m, (m - 1.96*se, m + 1.96*se)

# 1) Macro bar (mean ± 95% CI) Lines-F1
m_g, ci_g = mean_ci95(df_grap["lines_f1"])
m_l, ci_l = mean_ci95(df_llm["lines_f1"])
means = [m_g, m_l]
cis = [ci_g, ci_l]
xs = np.arange(2)
# Asymmetric error bars: distances from each mean down to / up to its CI bounds.
err_below = [mean - low for mean, (low, _) in zip(means, cis)]
err_above = [high - mean for mean, (_, high) in zip(means, cis)]
plt.figure(figsize=(6,4))
plt.bar(xs, means, yerr=[err_below, err_above], capsize=6)
plt.xticks(xs, ["GRAP-Q","Pure-LLM"])
plt.ylim(0,1)
plt.ylabel("Lines-F1")
plt.title("Macro comparison (VAL, mean ± 95% CI)")
savefig(OUT_DIR/"macro_linesf1_bar_val.png")

# 2) ECDF Lines-F1
plt.figure(figsize=(6,4))
# One ECDF curve per method, GRAP-Q first so the legend order is stable.
x, y = ecdf(df_grap["lines_f1"])
plt.plot(x, y, label="GRAP-Q")
x, y = ecdf(df_llm["lines_f1"])
plt.plot(x, y, label="Pure-LLM")
plt.xlabel("Lines-F1")
plt.ylabel("ECDF")
plt.title("Distribution of effectiveness (VAL)")
plt.legend()
savefig(OUT_DIR/"ecdf_linesf1_val.png")

# 3) Patch minimality: Δ lines edited vs Lines-F1
plt.figure(figsize=(6,4))
plt.scatter(df_grap["delta_abs_lines"], df_grap["lines_f1"], label="GRAP-Q")
# Different marker so overlapping points from the two methods stay distinguishable.
plt.scatter(df_llm["delta_abs_lines"], df_llm["lines_f1"], label="Pure-LLM", marker="x")
plt.xlabel("Δ lines edited (abs)")
plt.ylabel("Lines-F1")
plt.title("Patch minimality vs correctness (VAL)")
plt.legend()
savefig(OUT_DIR/"scatter_minimality_val.png")

# 4) Edit efficiency: Lines-F1 per 10 edited lines
def efficiency(df):
    """Return Lines-F1 per 10 edited lines for every row of ``df``.

    Reads the ``lines_f1`` and ``delta_abs_lines`` columns. Non-numeric or
    missing values count as 0; rows whose delta is 0 yield NaN rather than
    a division by zero.
    """
    delta = pd.to_numeric(df["delta_abs_lines"], errors="coerce").fillna(0.0)
    f1 = pd.to_numeric(df["lines_f1"], errors="coerce").fillna(0.0)
    tens_of_lines = delta.replace(0, np.nan) / 10.0
    return f1 / tens_of_lines
eff_g = efficiency(df_grap); eff_l = efficiency(df_llm)
plt.figure(figsize=(6,4))
# NaN rows (zero-line patches) are dropped before plotting.
plt.boxplot([eff_g.dropna(), eff_l.dropna()], showmeans=True)
# Boxes sit at x = 1, 2 by default. Tick labels are set explicitly because the
# boxplot ``labels=`` keyword is deprecated in newer Matplotlib releases.
plt.xticks([1, 2], ["GRAP-Q","Pure-LLM"])
plt.ylabel("Lines-F1 per 10 edited lines"); plt.title("Edit efficiency (VAL, higher is better)")
savefig(OUT_DIR/"box_efficiency_val.png")

# 5) Distortion rates (stacked)
def rate(df, col):
    """Return the fraction of rows in ``df[col]`` that are truthy.

    Values are coerced to numbers first; anything non-numeric or missing
    counts as 0 (False).
    """
    as_flags = pd.to_numeric(df[col], errors="coerce").fillna(0).astype(bool)
    return float(as_flags.mean())
# Per-method share of cases that trip each distortion flag.
rates = pd.DataFrame({
    "syntax_fail":[rate(df_grap,"ast_parse_fail"), rate(df_llm,"ast_parse_fail")],
    "api_drift>0.40":[rate(df_grap,"api_drift_gt40"), rate(df_llm,"api_drift_gt40")],
    "id_jacc<0.60":[rate(df_grap,"id_jacc_lt60"), rate(df_llm,"id_jacc_lt60")],
    "excessive_no_gain":[rate(df_grap,"excessive_no_gain"), rate(df_llm,"excessive_no_gain")]
}, index=["GRAP-Q","Pure-LLM"])
bottom=np.zeros(2)
plt.figure(figsize=(7.2,4.2))
for col in rates.columns:
    plt.bar(["GRAP-Q","Pure-LLM"], rates[col].values, bottom=bottom, label=col)
    # Rebind instead of mutating in place, so the array handed to plt.bar is never altered.
    bottom = bottom + rates[col].values
# The flags are not mutually exclusive, so a stacked bar can exceed 1.0.
# Grow the axis in that case instead of clipping the bar at 1.
plt.ylim(0, max(1.0, float(bottom.max())))
plt.ylabel("share of cases"); plt.title("Distortion/Failure modes (VAL, lower is better)"); plt.legend(fontsize=8, ncol=2)
savefig(OUT_DIR/"distortion_rates_stacked_val.png")

# 6) Hist overlays: Δ lines edited
plt.figure(figsize=(6,4))
delta_grap = pd.to_numeric(df_grap["delta_abs_lines"], errors="coerce")
delta_llm = pd.to_numeric(df_llm["delta_abs_lines"], errors="coerce")
# Semi-transparent so the overlapping histograms both remain visible.
plt.hist(delta_grap, bins=20, alpha=0.6, label="GRAP-Q")
plt.hist(delta_llm, bins=20, alpha=0.6, label="Pure-LLM")
plt.xlabel("Δ lines edited (abs)")
plt.ylabel("count")
plt.title("Patch size distribution (VAL)")
plt.legend()
savefig(OUT_DIR/"hist_patch_size_val.png")

# 7) Boxplots of precision/recall
# Tick labels are set via plt.xticks (boxes sit at x = 1, 2) because the
# boxplot ``labels=`` keyword is deprecated in newer Matplotlib releases.
plt.figure(figsize=(8,4))
plt.subplot(1,2,1)
plt.boxplot([pd.to_numeric(df_grap["lines_p"], errors="coerce").dropna(),
             pd.to_numeric(df_llm["lines_p"],  errors="coerce").dropna()], showmeans=True)
plt.xticks([1, 2], ["GRAP-Q","Pure-LLM"])
plt.title("Line-level Precision (VAL)")
plt.subplot(1,2,2)
plt.boxplot([pd.to_numeric(df_grap["lines_r"], errors="coerce").dropna(),
             pd.to_numeric(df_llm["lines_r"],  errors="coerce").dropna()], showmeans=True)
plt.xticks([1, 2], ["GRAP-Q","Pure-LLM"])
plt.title("Line-level Recall (VAL)"); plt.tight_layout()
savefig(OUT_DIR/"box_precision_recall_val.png")

# 8) Win-rate (per-case head-to-head Lines-F1)
# Only cases that have a score for both methods take part in the comparison.
joined = df_wide.dropna()
grap_scores = joined["GRAP"]
llm_scores = joined["LLM"]
wins = float((grap_scores > llm_scores).mean())
loss = float((grap_scores < llm_scores).mean())
ties = float((grap_scores == llm_scores).mean())
plt.figure(figsize=(6,4))
plt.bar(["GRAP better","LLM better","Tie"], [wins, loss, ties])
plt.ylim(0,1)
plt.ylabel("share of cases")
plt.title("Head-to-head win-rate (VAL)")
savefig(OUT_DIR/"winrate_val.png")

# 9) Paired Lines-F1 (sorted by GRAP−LLM)
per_case_gap = joined["GRAP"] - joined["LLM"]
diff = per_case_gap.sort_values()
plt.figure(figsize=(7,4))
plt.plot(range(len(diff)), diff.values)
# Zero line: points above it are cases where GRAP scored higher.
plt.axhline(0, linestyle="--")
plt.xlabel("cases (sorted)")
plt.ylabel("GRAP − LLM (Lines-F1)")
plt.title("Per-case advantage (VAL)")
savefig(OUT_DIR/"paired_diff_curve_val.png")

# 10) API drift & identifier Jaccard distributions
# Tick labels are set via plt.xticks (boxes sit at x = 1, 2) because the
# boxplot ``labels=`` keyword is deprecated in newer Matplotlib releases.
plt.figure(figsize=(8,4))
plt.subplot(1,2,1)
plt.boxplot([pd.to_numeric(df_grap["drift"], errors="coerce").dropna(),
             pd.to_numeric(df_llm["drift"],  errors="coerce").dropna()], showmeans=True)
plt.xticks([1, 2], ["GRAP-Q","Pure-LLM"])
plt.title("API drift (1−Jaccard of API, VAL)")
plt.subplot(1,2,2)
plt.boxplot([pd.to_numeric(df_grap["id_jacc"], errors="coerce").dropna(),
             pd.to_numeric(df_llm["id_jacc"],  errors="coerce").dropna()], showmeans=True)
plt.xticks([1, 2], ["GRAP-Q","Pure-LLM"])
plt.title("Identifier Jaccard (VAL)"); plt.tight_layout()
savefig(OUT_DIR/"box_api_id_jacc_val.png")

# 11) Efficiency scatter with LSQ trend lines
def scatter_with_trend(x, y, label):
    """Scatter ``y`` against ``x`` on the current axes and overlay a least-squares line.

    Both inputs are coerced to float with non-numeric/missing entries set
    to 0. The line is only drawn when there are at least two points, and it
    spans from 0 to ``max(x) + 1``.
    """
    x_vals = pd.to_numeric(x, errors="coerce").fillna(0).values.astype(float)
    y_vals = pd.to_numeric(y, errors="coerce").fillna(0).values.astype(float)
    plt.scatter(x_vals, y_vals, label=label, alpha=0.7)
    if len(x_vals) < 2:
        return
    # Design matrix [x, 1] so lstsq returns (slope, intercept).
    design = np.vstack([x_vals, np.ones(len(x_vals))]).T
    slope, intercept = np.linalg.lstsq(design, y_vals, rcond=None)[0]
    grid = np.linspace(0, max(x_vals)+1, 100)
    plt.plot(grid, slope*grid+intercept)
plt.figure(figsize=(7,4))
# One scatter + trend line per method on the same axes.
scatter_with_trend(df_grap["delta_abs_lines"], df_grap["lines_f1"], "GRAP-Q")
scatter_with_trend(df_llm["delta_abs_lines"], df_llm["lines_f1"], "Pure-LLM")
plt.xlabel("Δ lines edited")
plt.ylabel("Lines-F1")
plt.title("Efficiency trendlines (VAL)")
plt.legend()
savefig(OUT_DIR/"scatter_trend_efficiency_val.png")

# Final run summary: where the artifacts went and which settings produced them.
print("\nVALIDATION comparison artifacts saved to:", OUT_DIR.resolve())
print("BEST_CONFIG used:", BEST_CONFIG)
print("Donor policy -> TRAIN only:", ALLOW_TRAIN_DONORS, "| exclude donor-changed windows:", EXCLUDE_TRAIN_DONOR_CHANGED)



[notice] A new release of pip is available: 24.0 -> 25.2
[notice] To update, run: python.exe -m pip install --upgrade pip


Note: you may need to restart the kernel to use updated packages.
Split -> train=33 | val=12 | test=2
Validation cases usable: 12 (DATA_PERCENT=100)
BEST_CONFIG (from file): WIN_base__hint__balanced__rerank__nosyntax (expect this to be TRAIN-only)


[GRAP-Q|VAL] cases: 100%|██████████| 12/12 [32:04<00:00, 160.37s/it]
[Pure-LLM|VAL] cases: 100%|██████████| 12/12 [07:57<00:00, 39.81s/it]
  plt.boxplot([eff_g.dropna(), eff_l.dropna()], labels=["GRAP-Q","Pure-LLM"], showmeans=True)
  plt.boxplot([pd.to_numeric(df_grap["lines_p"], errors="coerce").dropna(),
  plt.boxplot([pd.to_numeric(df_grap["lines_r"], errors="coerce").dropna(),
  plt.boxplot([pd.to_numeric(df_grap["drift"], errors="coerce").dropna(),
  plt.boxplot([pd.to_numeric(df_grap["id_jacc"], errors="coerce").dropna(),



VALIDATION comparison artifacts saved to: C:\Users\Alberto\Desktop\Quantum\LAST\results\grap_vs_llm_deep
BEST_CONFIG used: WIN_base__hint__balanced__rerank__nosyntax
Donor policy -> TRAIN only: True | exclude donor-changed windows: True


# RUN


In [None]:
# Diagnostic Mode
!python GRAP-Q.py \
  --mode diagnostic \
  --best_config results/qeval_ablation_plus/best_config.txt \
  --db_root data/bugs4q/Bugs4Q-Database \
  --out_dir results/infer \
  --work_dir .work/infer \
  --data_percent_test 100

# Test Mode
!python GRAP-Q.py \
  --mode test \
  --best_config results/qeval_ablation_plus/best_config.txt \
  --db_root data/bugs4q/Bugs4Q-Database \
  --out_dir results/infer \
  --work_dir .work/infer \
  --data_percent_test 100
