# MACO-Style Multi-Agent Content Optimization (Paper-Aligned)

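This notebook implements the full pipeline: a frozen query/context corpus, an LLM evaluator reporting per-metric MIS (mean score), ISR (share of queries scoring at least the success threshold τ), and MIV (score variance), an iterative optimization loop, analyst/editor agents, and a hybrid (rule-based + LLM) selector.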

## 0) Setup & config

In [1]:
# Load API keys (GOOGLE_API_KEY, GENSEE_API_KEY, ...) from a local .env file into
# os.environ before any client is created. The `True` shown below is load_dotenv()'s
# boolean return value.
from dotenv import load_dotenv
load_dotenv() 

True

In [2]:
import os, json, time, hashlib, re, textwrap
from typing import List, Dict, Any, Tuple
from dataclasses import dataclass
from datetime import datetime
import sys

# NOTE: API keys and secrets are loaded from environment variables (see .env or your shell config)
# Example expected variables (DO NOT hard-code real values here):
#   GOOGLE_API_KEY="***"
#   LANGSMITH_API_KEY="***"
#   LANGSMITH_ENDPOINT="https://api.smith.langchain.com"
#   GENSEE_API_KEY="***"

# Model choices & constants.
# MODEL_EVAL: judge (and LLM selector); MODEL_ANALYST: analyst and query generation;
# MODEL_EDITOR: LLM fallback for edits the local editor cannot apply.
MODEL_EVAL     = "gemini-2.5-flash"
MODEL_ANALYST  = "gemini-2.5-flash"
MODEL_EDITOR   = "gemini-2.5-flash-lite"
TEMPERATURE_EVAL    = 0.0
TEMPERATURE_ANALYST = 0.6
TEMPERATURE_EDITOR  = 0.1

N_QUERIES   = 3        # queries generated per article; the paper uses 10 (5–10 recommended)
MAX_CTX     = 3       # max external contexts shown to the judge per query; the paper uses 10
SUCCESS_TAU = 0.75     # ISR threshold: a per-query metric score >= SUCCESS_TAU counts as a success
N_ITERS     = 5       # evaluate -> edit rounds per article; the paper uses 10
RANDOM_SEED = 42       # seeds Python's `random` below; the LLM calls take no seed here

# Score grid: the judge's raw scores are snapped to the nearest anchor. Keep in sync
# with the anchors listed in PROMPT_EVAL_SYSTEM.
# TODO: update the anchors, the paper uses [0,10]
ANCHORS = [0.00, 0.17, 0.33, 0.50, 0.67, 0.83, 1.00]
# Metric keys, in the column order used by every score table and CSV.
METRICS = ["CP","AA","FA","KC","SC","AD"]

# Optional: tag detection for edits (baseline-style labeling).
# (tag, regex) pairs are tried IN ORDER, case-insensitively, against each proposed
# patch; the first match wins, so "Fluent" (`.`) tags any non-empty patch that no
# earlier pattern matched.
# TODO: align these tags with the baseline's labeling scheme.
TAG_PATTERNS = [
    ("Statistics",    r"\b\d{1,3}(,\d{3})*(\.\d+)?\s?%|\b(?:million|billion|thousand)\b"),
    ("More Quotes",   r"[\"“][^\"”]{8,}[\"”]"),
    ("Citing Sources",r"\b(?:According to|Source:|cited by|as reported by)\b"),
    ("Technical Terms", r"\b(latency|throughput|gradient|API|OAuth|schema|vector|embedding|protocol|REST|GraphQL)\b"),
    ("Authoritative", r"\b(must|should|undoubtedly|certainly|we recommend)\b"),
    ("Fluent",        r"."),  # fallback: any edit without the above
]

# Reproducibility tweaks where applicable (only Python's `random` module is seeded)
import random
random.seed(RANDOM_SEED)


In [3]:
DEBUG = True  # verbose tracing switch for log_heading/log_json (log_info always prints)

def log_heading(h: str):
    """Print a banner-style section heading (only when DEBUG is on).

    Output goes to ``sys.stdout``; once setup_logging() has installed its tee,
    that output is mirrored into the log file as well.
    """
    if DEBUG:
        print("\n" + "="*8 + " " + h + " " + "="*8)

def log_json(name: str, obj):
    """Pretty-print ``obj`` as indented JSON under a ``[name]`` label (DEBUG only).

    Values JSON cannot encode natively (datetimes, sets, numpy integers, ...)
    are rendered via ``str()`` so the structured view is kept. Only when
    serialization still fails (e.g. circular references, unsupported key types)
    does it fall back to the first 2000 characters of ``str(obj)``.
    """
    if DEBUG:
        print(f"\n[{name}]")
        try:
            print(json.dumps(obj, ensure_ascii=False, indent=2, default=str))
        except Exception:
            # Best effort: logging must never break the pipeline.
            print(str(obj)[:2000])

def log_info(message: str):
    """Print ``message`` prefixed with an ``[HH:MM:SS]`` timestamp (always, regardless of DEBUG)."""
    timestamp = datetime.now().strftime('%H:%M:%S')
    print(f"[{timestamp}] {message}")


In [4]:
# ===== LOGGING SETUP =====
class TeeOutput:
    """Stream wrapper that mirrors everything written to ``terminal`` into a log file.

    Args:
        terminal: the stream being wrapped (e.g. the original ``sys.stdout``);
            every write is forwarded to it unchanged.
        log_file: path of the log file to mirror into (opened as UTF-8).
        mode: ``open()`` mode for the log file; the default ``'w'`` truncates it,
            ``'a'`` appends.

    Attributes TeeOutput does not define itself (``encoding``, ``isatty``,
    ``fileno``, ...) are looked up on ``terminal``, so code that probes
    ``sys.stdout``/``sys.stderr`` for them keeps working after redirection.
    """
    def __init__(self, terminal, log_file, mode='w'):
        self.terminal = terminal
        self.log_file = log_file
        self.file_mode = mode
        self.file_handle = None
        self._open_file()

    def _open_file(self):
        """Open the log file handle for writing."""
        self.file_handle = open(self.log_file, self.file_mode, encoding='utf-8')

    def _file_is_open(self):
        """True when there is a log handle that is still open (another tee sharing it may have closed it)."""
        return self.file_handle is not None and not self.file_handle.closed

    def write(self, message):
        """Write ``message`` to the terminal and, if non-empty, to the log file.

        Returns the number of characters written, like ``io.TextIOBase.write``.
        """
        self.terminal.write(message)
        if message and self._file_is_open():
            self.file_handle.write(message)
            # Flush immediately so the log stays complete even if the kernel dies.
            self.file_handle.flush()
        return len(message)

    def flush(self):
        """Flush both the terminal and the log file."""
        self.terminal.flush()
        if self._file_is_open():
            self.file_handle.flush()

    def close(self):
        """Close the log file handle; the wrapped terminal stream stays open."""
        if self.file_handle:
            self.file_handle.close()
            self.file_handle = None

    def __getattr__(self, name):
        # Only reached for attributes not found on the instance or class. Read
        # `terminal` through __dict__ so a half-initialised instance cannot
        # recurse back into __getattr__ forever.
        terminal = self.__dict__.get('terminal')
        if terminal is None:
            raise AttributeError(name)
        return getattr(terminal, name)

def setup_logging():
    """Tee stdout/stderr into a timestamped log file under ``logs/``. Returns the file path.

    The file is named ``logs/YYYY_MM_DD_HH_MM.txt``. stdout and stderr write
    through ONE shared file handle. Two separately opened ``'w'`` handles on
    the same path each keep their own write offset, so stderr output would
    overwrite stdout output already in the file (and vice versa).

    Re-running this (e.g. re-executing the notebook cell) is safe: any
    TeeOutput already installed on sys.stdout/sys.stderr is closed and
    unwrapped first, so output is not also copied into older log files.
    """
    def _unwrap(stream):
        # Compare by class name rather than isinstance: re-running the TeeOutput
        # cell creates a new class object, and old tees must still be recognised.
        while type(stream).__name__ == "TeeOutput" and "terminal" in vars(stream):
            stream.close()
            stream = vars(stream)["terminal"]
        return stream

    now = datetime.now()
    # Create logs directory if it doesn't exist
    log_dir = "logs"
    os.makedirs(log_dir, exist_ok=True)

    # Create filename: YYYY_MM_DD_HH_MM.txt (e.g., 2025_01_15_14_30.txt)
    log_filename = f"{now.year:04d}_{now.month:02d}_{now.day:02d}_{now.hour:02d}_{now.minute:02d}.txt"
    log_path = os.path.join(log_dir, log_filename)

    # The real (non-tee) streams underneath any previous setup_logging() call
    original_stdout = _unwrap(sys.stdout)
    original_stderr = _unwrap(sys.stderr)

    tee_stdout = TeeOutput(original_stdout, log_path)
    tee_stderr = TeeOutput(original_stderr, log_path)
    # Make stderr write through stdout's handle (see docstring); its own,
    # still-empty handle is discarded.
    tee_stderr.close()
    tee_stderr.file_handle = tee_stdout.file_handle

    # Redirect stdout and stderr to TeeOutput
    sys.stdout = tee_stdout
    sys.stderr = tee_stderr

    # Write header (this will go through TeeOutput, so no duplication)
    print('='*80)
    print(f"MACO Pipeline Log - Started at {now.strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"Log file: {log_path}")
    print('='*80 + '\n')

    return log_path

# Initialize logging
# Install the stdout/stderr tee near the top of the notebook so the output of every
# later cell is also captured in logs/<timestamp>.txt.
LOG_FILE_PATH = setup_logging()
print(f"[LOG] All output will be saved to: {LOG_FILE_PATH}\n")



MACO Pipeline Log - Started at 2025-12-08 02:25:59
Log file: logs/2025_12_08_02_25.txt

[LOG] All output will be saved to: logs/2025_12_08_02_25.txt



🚀 Starting Experiment at 2025-12-08 02:26:00

📋 Configuration:
  - TSV Path: ../data/ariticles/target_articles.tsv
  - Articles to process: 30
  - Start index: 0
  - Iterations per article: 8
  - Generated queries per article: 8
  - Output directory: data
  - Resume from existing: True


[02:26:00] Loaded 446 valid articles from ../data/ariticles/target_articles.tsv
[02:26:00] 📊 Total articles in TSV: 446
[02:26:00] 🔍 Checking for processed queries in: data/metrics.csv
[02:26:00] 📂 Loaded 3 processed articles from: data/metrics.csv
[02:26:00] ⏭️  Skipping 3 already processed queries
[02:26:00] 📋 443 queries remaining to process
[02:26:00] 📊 Processing queries: 30 remaining

--------------------------------------------------------------------------------


📝 Article ID: Is email marketing good for small businesses? | Preview: Reddit - The heart of the internet Skip to main content Open menu Log In Go to R...


Article ID: Is email marketing good for small businesses?
Article length: 105

## 1) LLM client (LangChain Google GenAI)

In [5]:
from langchain_google_genai import ChatGoogleGenerativeAI

def make_llm(model: str, temperature: float, max_retries: int = 0, max_output_tokens: int = 8192):
    """Build a LangChain Gemini chat client.

    Args:
        model: Gemini model name (e.g. MODEL_EVAL / MODEL_ANALYST / MODEL_EDITOR).
        temperature: sampling temperature.
        max_retries: client-side retries on failed API calls. The default 0
            means no retries are attempted, so one transient API error reaches
            the caller; raise it for long unattended batch runs.
        max_output_tokens: cap on generated tokens per call.

    Authentication relies on the GOOGLE_API_KEY environment variable.
    """
    return ChatGoogleGenerativeAI(
        model=model,
        temperature=temperature,
        max_retries=max_retries,
        max_output_tokens=max_output_tokens,
    )

def _llm_message_text(msg) -> str:
    """Flatten a LangChain message's ``content`` into plain text.

    Gemini may return ``content`` as a list of parts: plain strings and/or dict
    content blocks such as ``{"type": "text", "text": "..."}``. Only the text of
    those parts is kept; ``str(dict)`` would inject Python reprs that can never
    parse as JSON. Falls back to ``str(msg)`` when content is empty.
    """
    content = getattr(msg, "content", "")
    if isinstance(content, list):
        pieces = []
        for part in content:
            if isinstance(part, str):
                pieces.append(part)
            elif isinstance(part, dict):
                pieces.append(str(part.get("text") or ""))
            else:
                pieces.append(str(part))
        return "".join(pieces)
    return str(content) if content else str(msg)


def _strip_json_fences(text: str) -> str:
    """Trim whitespace and a leading/trailing Markdown code fence (```json ... ```)."""
    text = text.strip()
    text = re.sub(r"^```(?:json)?\s*|\s*```$", "", text, flags=re.I)
    return text.strip()


def _parse_json_object(text: str):
    """Parse ``text`` as a JSON object; return the dict, or None on failure.

    Tries the whole string first, then the outermost ``{...}`` span, so a stray
    preamble/postscript around otherwise valid JSON is tolerated. Non-object
    JSON (lists, strings, numbers) counts as failure because every caller
    expects a dict.
    """
    candidates = [text]
    start, end = text.find("{"), text.rfind("}")
    if 0 <= start < end:
        candidates.append(text[start:end + 1])
    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except ValueError:  # json.JSONDecodeError subclasses ValueError
            continue
        if isinstance(parsed, dict):
            return parsed
    return None


def call_llm_json(llm, system: str, user: str, retry: int = 1) -> Dict[str, Any]:
    """
    Call an LLM with system+user text and parse its reply as a JSON object.

    Args:
        llm: a LangChain chat model (anything with ``.invoke(messages)``).
        system: system prompt.
        user: human/user prompt.
        retry: how many times to re-ask the model for strict JSON after an
            unparseable reply (default 1).

    Returns:
        The parsed JSON object (a dict). If no attempt yields a JSON object,
        returns ``{"__SCHEMA_ERROR__": <last raw reply text>}``.
    """
    reply = llm.invoke([("system", system), ("human", user)])
    text = _strip_json_fences(_llm_message_text(reply))
    parsed = _parse_json_object(text)

    attempts_left = retry
    while parsed is None and attempts_left > 0:
        attempts_left -= 1
        nudged = ("Your previous reply was not valid JSON. Reprint ONLY strict JSON, "
                  f"no commentary. Original reply: {text}")
        reply = llm.invoke([("system", system), ("human", nudged)])
        text = _strip_json_fences(_llm_message_text(reply))
        parsed = _parse_json_object(text)

    if parsed is None:
        return {"__SCHEMA_ERROR__": text}
    return parsed


## 2) Retrieval (Gensee AI)

In [6]:
import os
import requests
from typing import List

def gensee_ai_retrieve(query: str, max_results: int = 3) -> List[str]:
    """
    Retrieve context snippets for ``query`` from the Gensee AI search API.

    Args:
        query: search query text.
        max_results: maximum number of snippets requested and returned.

    Returns:
        Up to ``max_results`` non-empty snippet strings, taken from the
        ``content`` field of each item under ``search_response``. Returns an
        empty list when ``GENSEE_API_KEY`` is unset, the request fails, or the
        response body does not have the expected shape.
    """
    api_key = os.getenv("GENSEE_API_KEY")
    if not api_key:
        print("[WARN] Missing GENSEE_API_KEY environment variable — returning empty list.")
        return []

    url = 'https://platform.gensee.ai/tool/search'
    payload = {'query': query, 'max_results': max_results}
    headers = {
        'Content-Type': 'application/json',
        'Authorization': f'Bearer {api_key}',  # key loaded from env, never hard-coded
    }

    try:
        response = requests.post(url, json=payload, headers=headers, timeout=60)
        response.raise_for_status()  # 401/403/5xx -> RequestException
        body = response.json()
    except (requests.exceptions.RequestException, ValueError) as e:
        # ValueError covers a non-JSON body on older requests versions, where
        # the JSON decode error is not a RequestException subclass.
        print(f"[WARN] Gensee AI request failed: {e}")
        return []

    # Expected shape: {"search_response": [{"content": "..."}, ...]}
    results = body.get("search_response") if isinstance(body, dict) else None
    if not isinstance(results, list):
        print("[WARN] Gensee AI response has no 'search_response' list — returning empty list.")
        return []

    contexts = []
    for item in results:
        snippet = item.get("content") if isinstance(item, dict) else None
        if isinstance(snippet, str) and snippet:
            contexts.append(snippet)
    return contexts[:max_results]

## 3) Prompts (Query, Evaluator, Analyst, Editor, Selector)

In [7]:
# System prompt for the query agent: content-centric queries answerable from the
# article alone, spread over several intents, returned as
# {"queries": [{"intent": ..., "q": ...}, ...]} (parsed by generate_queries_from_doc).
PROMPT_QUERY_SYSTEM = """You generate user queries for evaluating ONE article.
Produce content-centric queries:
- Each query MUST be answerable using the article alone.
- Cover varied user intents: {definition, learn/explanation, how-to/guide, compare, list/examples}.
- Avoid brand bait, clickbait, or unanswerable questions.

CRITICAL FORMATTING RULES:
1. Return ONLY the raw JSON object.
2. DO NOT use Markdown fencing (no ```json ... ```).
3. DO NOT add any conversational text (e.g., "Here are the queries...").

Example Output:
{"queries":[
  {"intent":"definition","q":"..."},
  {"intent":"learn","q":"..."},
  {"intent":"how-to","q":"..."},
  {"intent":"compare","q":"..."},
  {"intent":"list","q":"..."}
]}
"""

def prompt_query_user(doc: str, n_queries: int = N_QUERIES) -> str:
    """Build the query agent's human message: the article, then how many queries to return."""
    instruction = f"Return {n_queries} queries spread across the intents."
    return "\n".join(["[ARTICLE]", f"{doc}", "", instruction])

# System prompt for the judge (evaluator_score): answer the query RAG-style, then score
# the candidate article on the six METRICS using the ANCHORS grid. Returns
# {"answer", "scores": {metric: value}, "why": {metric: rationale}}; scores are
# snapped to the nearest anchor afterwards.
PROMPT_EVAL_SYSTEM = """You are an evaluation judge. Given a candidate article and external contexts:
1) Answer the user query (RAG style).
2) Score the CANDIDATE ARTICLE on SIX dimensions in [0,1] using anchors {0,.17,.33,.5,.67,.83,1}:
   - CP (Citation Prominence): clear, prominent citation/attribution of the candidate article in the final answer.
   - AA (Attribution Accuracy): statements attributed to the article truly originate from it.
   - FA (Faithfulness): answer remains faithful to the article’s meaning (no distortions).
   - KC (Key Concepts): article covers essential concepts needed for this query.
   - SC (Semantic Contribution): article contributes unique/central meaning vs other contexts.
   - AD (Answer Dominance): overall share of answer content deriving from the article vs other contexts.
Rules:
- Judge ONLY the candidate article’s contribution; do not reward contexts.
- If the answer can be formed without the article, penalize SC and AD.
- If external contexts are absent or minimal relative to the answer, DO NOT award SC or AD above 0.33 unless you explicitly justify why the article itself supplies the necessary unique content.
- If the article is very short/sparse and lacks definitions/examples/comparisons needed by the query, reduce KC and FA accordingly.

CRITICAL FORMATTING RULES:
1. Return ONLY the raw JSON object.
2. DO NOT use Markdown fencing (no ```json ... ```).
3. DO NOT include any preamble or postscript.

Example Output:
{
 "answer": "...",
 "scores": {"CP":0.83,"AA":0.67,"FA":0.83,"KC":0.67,"SC":0.50,"AD":0.50},
 "why": {
   "CP":"...", "AA":"...", "FA":"...", "KC":"...", "SC":"...", "AD":"..."
 }
}
"""

def prompt_eval_user(query: str, doc: str, contexts: List[str]) -> str:
    """Build the judge's human message: the query, the candidate article, and at most MAX_CTX contexts."""
    if contexts:
        context_block = "\n---\n".join(contexts[:MAX_CTX])
    else:
        context_block = "(no external contexts)"
    sections = (("QUERY", query), ("CANDIDATE_ARTICLE", doc), ("CONTEXTS", context_block))
    return "\n\n".join(f"[{label}]\n{body}" for label, body in sections)

# System prompt for the analyst (analyst_propose_edits): given the article, per-query
# scores and MIS/ISR/MIV aggregates, propose up to 3 edits as
# {"edits": [{target_metric, reason, location_hint, operation, patch}, ...]}.
# The `operation` values listed here are the ones _apply_edit_locally understands.
PROMPT_ANALYST_SYSTEM = """You propose targeted edits to improve the article’s weakest metrics.
Inputs: (1) article, (2) per-query scores with brief rationales, (3) aggregate MIS/ISR/MIV.
Find the single weakest metric by MIS; break ties by high MIV and low ISR.
Propose up to 3 precise edits. For EACH edit include:
- target_metric: one of {CP,AA,FA,KC,SC,AD}
- reason: ≤2 sentences
- location_hint: exact anchor text or section title
- operation: one of {"insert_after","replace_span","append_section","delete_span","merge_sections"}
- patch: exact text to insert/replace (≤180 words)

CRITICAL FORMATTING RULES:
1. Return ONLY the raw JSON object.
2. DO NOT use Markdown fencing (no ```json ... ```).
3. DO NOT add conversational filler.

Example Output:
{"edits":[{"target_metric":"CP","reason":"...","location_hint":"...","operation":"insert_after","patch":"..."}]}
"""

def prompt_analyst_user(doc: str, per_query: List[Dict[str, Any]], agg: Dict[str, Any]) -> str:
    """Serialize the analyst's inputs (article, per-query scores, aggregates) into one JSON string."""
    bundle = dict(article=doc, per_query=per_query, aggregate=agg)
    return json.dumps(bundle, ensure_ascii=False)

# System prompt for the LLM editor, used by editor_apply_edit only when the local
# deterministic editor could not apply the edit. The model must reply with the full
# revised article as plain text (no JSON).
PROMPT_EDITOR_SYSTEM = """Apply ONE provided edit to the article faithfully. 
Do NOT rewrite unrelated text. If location_hint not found, place patch in the nearest logical spot.
Return the FULL revised article only. No explanations.
"""

def prompt_editor_user(doc: str, json_edit: Dict[str, Any]) -> str:
    """Serialize the article and the single edit to apply into one JSON string."""
    message = {}
    message["article"] = doc
    message["edit"] = json_edit
    return json.dumps(message, ensure_ascii=False)

# System prompt for the LLM tie-breaker in select_best_version. Candidates arrive as
# [{"idx", "agg", "snippet"}, ...]; winner_index must be one of those idx values
# (0-based position in the candidate list), which the caller maps back to an iteration.
PROMPT_SELECTOR_SYSTEM = """You are a selector comparing multiple article versions evaluated on the SAME query+context corpus.
Given MIS, ISR, MIV per version, pick the version that maximizes:
score = sum(MIS[m] for m in [CP,AA,FA,KC,SC,AD]) - 0.2 * sum(MIV[m] for m in [CP,AA,FA,KC,SC,AD]).
"winner_index" MUST be the "idx" value of the chosen candidate.
Return ONLY the raw JSON object (no Markdown fencing, no commentary) in STRICT JSON: {"winner_index": k, "reason":"≤2 sentences"}
"""

def prompt_selector_user(history_summary: List[Dict[str, Any]]) -> str:
    """Wrap the candidate summaries (``[{"idx", "agg", "snippet"}, ...]``) as JSON for the selector."""
    return json.dumps(dict(candidates=history_summary), ensure_ascii=False)


## 4) Query generation + frozen corpus

In [8]:
# Corpus (build once, then freeze) 
def generate_queries_from_doc(doc_text: str, n_queries: int = N_QUERIES) -> List[str]:
    """Ask the query agent for up to ``n_queries`` distinct queries about ``doc_text``.

    Args:
        doc_text: the article the queries must be answerable from.
        n_queries: maximum number of queries to return.

    Returns:
        A de-duplicated list (original order, whitespace-stripped) of at most
        ``n_queries`` query strings. If the reply is not valid JSON, or it holds
        no usable queries, a fixed list of generic queries is returned instead.
    """
    llm = make_llm(MODEL_ANALYST, temperature=0.3)  # tiny diversity, still on-topic
    payload = call_llm_json(llm, PROMPT_QUERY_SYSTEM, prompt_query_user(doc_text, n_queries))

    raw_items = []
    if isinstance(payload, dict) and "__SCHEMA_ERROR__" not in payload:
        raw_items = payload.get("queries") or []
    if not isinstance(raw_items, list):
        raw_items = []

    qs = []
    for item in raw_items:
        # Expected shape is {"intent": ..., "q": ...}; tolerate bare strings too.
        text = item.get("q") if isinstance(item, dict) else item
        if isinstance(text, str) and text.strip():
            qs.append(text.strip())
    uniq = list(dict.fromkeys(qs))[:n_queries]  # order-preserving dedupe, then cap

    if not uniq:
        # Fallback so the pipeline can proceed. These queries do not mention the
        # article's topic, so retrieval for them will be generic.
        if DEBUG:
            log_heading("Query Agent: no usable queries, using generic fallback")
        base = [
            "Give a concise definition.",
            "Explain the key benefits.",
            "Provide a simple example.",
            "Compare it with an alternative.",
            "Give a short step-by-step guide."
        ]
        return [f"{q} (based on the article above)" for q in base][:n_queries]

    if DEBUG:
        log_heading("Query Agent: generated queries")
        for i, q in enumerate(uniq):
            print(f"{i+1}. {q}")

    return uniq

def build_corpus_for_doc(doc_text: str, retriever=gensee_ai_retrieve,
                         n_queries=N_QUERIES, max_ctx=MAX_CTX) -> Dict[str, Any]:
    """Generate queries for ``doc_text``, retrieve contexts, and freeze the corpus to disk.

    Args:
        doc_text: the target article.
        retriever: callable ``retriever(query)`` returning a list of context snippets.
        n_queries: number of queries to request from the query agent.
        max_ctx: maximum contexts kept per query.

    Returns:
        ``{"queries": [{"q": str, "ctx": [str, ...]}, ...], "path": str}``, where
        ``path`` is the UTF-8 JSON file ``corpus_<md5(doc)[:10]>.json`` written to
        the current working directory.

    Raises:
        RuntimeError: if fewer than 2 queries end up with >= 2 distinct contexts.
    """
    queries = generate_queries_from_doc(doc_text, n_queries=n_queries)
    pairs = []
    for q in queries:
        try:
            ctxs = retriever(q)[:max_ctx]
        except Exception as e:
            # Best effort: a failed retrieval just drops this query below,
            # but report it instead of failing silently.
            print(f"[WARN] Retrieval failed for query {q[:80]!r}: {e}")
            ctxs = []
        # Normalize whitespace and de-duplicate, preserving order.
        cleaned = []
        for c in ctxs:
            if not isinstance(c, str):
                continue
            c = re.sub(r"\s+", " ", c.strip())
            if c and c not in cleaned:
                cleaned.append(c)
        # keep only queries with at least 2 contexts (so the judge can compare)
        if len(cleaned) >= 2:
            pairs.append({"q": q, "ctx": cleaned})
    if DEBUG:
        log_heading("Retrieval: per-query context counts")
        for p in pairs:
            print(f"- {p['q'][:80]}...  | ctx={len(p['ctx'])}")
        log_heading("Retrieved Contexts (Full Content)")
        for i, p in enumerate(pairs):
            print(f"\n--- Query {i+1}: {p['q']} ---")
            for j, ctx in enumerate(p['ctx']):
                print(f"\n[Context {j+1}]")
                print(ctx[:500] + ("..." if len(ctx) > 500 else ""))

    # require minimum coverage
    if len(pairs) < 2:
        raise RuntimeError(f"Corpus too small ({len(pairs)} with >=2 contexts). "
                           f"Set GENSEE_API_KEY and retry, or reduce filters.")
    key = hashlib.md5(doc_text.encode("utf-8")).hexdigest()[:10]
    path = f"corpus_{key}.json"
    # ensure_ascii=False writes raw non-ASCII characters, so the file must be
    # opened as UTF-8 rather than with the platform's default encoding.
    with open(path, "w", encoding="utf-8") as f:
        json.dump({"queries": pairs, "created_at": time.time()}, f, ensure_ascii=False, indent=2)
    return {"queries": pairs, "path": path}

    
def load_corpus(path: str) -> Dict[str, Any]:
    """Load a frozen corpus JSON file (as written by build_corpus_for_doc).

    The corpus is dumped with ``ensure_ascii=False`` and may contain raw
    non-ASCII text, so it is read as UTF-8 explicitly. The platform default
    encoding could garble or reject it.
    """
    with open(path, encoding="utf-8") as f:
        return json.load(f)


## 5) Evaluator (per-query + MIS/ISR/MIV)

In [9]:
import math
import numpy as np

def _nearest_anchor(x: float) -> float:
    # snap to anchor grid
    if x is None: return 0.0
    try: x = float(x)
    except: return 0.0
    return min(ANCHORS, key=lambda a: abs(a - x))

def evaluator_score(document: str, query: str, contexts: List[str]) -> Dict[str, Any]:
    """Ask the judge LLM to answer ``query`` and score ``document`` on every metric.

    Args:
        document: the candidate article version being evaluated.
        query: one frozen-corpus query.
        contexts: external contexts for that query (the prompt shows at most MAX_CTX).

    Returns:
        ``{"query", "scores", "why", "answer"}``. ``scores`` has one entry per
        metric in METRICS, each snapped to ANCHORS. Metrics the judge omitted
        score 0.0, as do all metrics when its reply was not valid JSON (a
        warning is printed in that case).
    """
    if DEBUG:
        log_heading("Evaluator: Query & Contexts")
        print(f"Query: {query}")
        print(f"\nNumber of contexts: {len(contexts)}")
        for i, ctx in enumerate(contexts):
            print(f"\n[Context {i+1}]")
            print(ctx[:500] + ("..." if len(ctx) > 500 else ""))

    llm = make_llm(MODEL_EVAL, TEMPERATURE_EVAL)
    payload = call_llm_json(llm, PROMPT_EVAL_SYSTEM, prompt_eval_user(query, document, contexts))
    if DEBUG:
        log_heading("Evaluator Output")
        log_json("payload", payload)

    if not isinstance(payload, dict):
        payload = {"__SCHEMA_ERROR__": str(payload)}
    if "__SCHEMA_ERROR__" in payload:
        # Otherwise an unparseable verdict silently turns into all-zero scores.
        print(f"[WARN] Evaluator reply was not valid JSON for query {query[:80]!r}; all metrics scored 0.0")

    answer = payload.get("answer", "")
    raw_scores = payload.get("scores")
    if not isinstance(raw_scores, dict):
        # e.g. a list or string where a {metric: score} object was expected
        raw_scores = {}
    why = payload.get("why") or {}
    # coerce to anchors & fill missing
    scores = {m: _nearest_anchor(raw_scores.get(m)) for m in METRICS}
    return {"query": query, "scores": scores, "why": why, "answer": answer}

def aggregate_scores(per_query_scores: List[Dict[str, Any]], tau: float = SUCCESS_TAU) -> Dict[str, Dict[str, float]]:
    """Aggregate per-query metric scores into MIS / ISR / MIV.

    For each metric m over the Q evaluated queries:
      - MIS[m]: mean score,
      - ISR[m]: fraction of queries whose score is >= ``tau``,
      - MIV[m]: population variance (ddof=0) of the scores.
    All values are rounded to 4 decimals.

    Raises:
        ValueError: if ``per_query_scores`` is empty. This replaces an opaque
            TypeError that NumPy's NaN result on empty input used to cause.
    """
    if not per_query_scores:
        raise ValueError("aggregate_scores() needs at least one per-query score; the corpus has no queries")
    arr = np.array([[pq["scores"][m] for m in METRICS] for pq in per_query_scores], dtype=float)  # shape Q x len(METRICS)
    mis = dict(zip(METRICS, arr.mean(axis=0).round(4).tolist()))
    isr = dict(zip(METRICS, (arr >= tau).mean(axis=0).round(4).tolist()))
    miv = dict(zip(METRICS, arr.var(axis=0, ddof=0).round(4).tolist()))
    return {"MIS": mis, "ISR": isr, "MIV": miv}


## 6) Analyst (edits) + tag detection

In [10]:
def analyst_propose_edits(doc: str, per_query: List[Dict[str, Any]], agg: Dict[str, Any]) -> Dict[str, Any]:
    """Ask the analyst LLM for targeted edits and auto-tag each proposed patch.

    Args:
        doc: current article version.
        per_query: evaluator_score results for this version.
        agg: aggregate_scores output (MIS/ISR/MIV) for this version.

    Returns:
        A dict whose ``"edits"`` is a list of edit dicts (fields described in
        PROMPT_ANALYST_SYSTEM); non-dict entries are dropped. Each edit whose
        patch matches a TAG_PATTERNS regex gets a ``"tag"`` (first match wins).
        If the reply is not valid JSON, a single conservative fallback edit
        (a generic benefits sentence targeting SC) is returned instead.
    """
    llm = make_llm(MODEL_ANALYST, TEMPERATURE_ANALYST)
    payload = call_llm_json(llm, PROMPT_ANALYST_SYSTEM, prompt_analyst_user(doc, per_query, agg))
    if DEBUG:
        log_heading("Analyst: proposed edits")
        log_json("edits", payload)

    if not isinstance(payload, dict) or "__SCHEMA_ERROR__" in payload:
        # conservative fallback: add benefits sentence (improves SC/KC)
        return {"edits": [{
            "target_metric": "SC",
            "reason": "Add explicit benefits to improve semantic contribution and sufficiency.",
            "location_hint": "After introduction",
            "operation": "insert_after",
            "patch": "Key benefits include clarity, coverage of essential concepts, and concrete examples that distinguish this article from generic sources."
        }]}

    # Keep only well-formed edit objects so callers can rely on e.get(...).
    raw_edits = payload.get("edits")
    payload["edits"] = [e for e in raw_edits if isinstance(e, dict)] if isinstance(raw_edits, list) else []

    # auto-tag the proposed patches
    for e in payload["edits"]:
        patch = e.get("patch") or ""
        if not isinstance(patch, str):
            patch = str(patch)
        for tag, pat in TAG_PATTERNS:
            if re.search(pat, patch, flags=re.I):
                e["tag"] = tag
                break
    return payload


## 7) Editor (apply one edit)

In [11]:
def _apply_edit_locally(doc: str, edit: Dict[str, Any]) -> str:
    """Lightweight, deterministic local editor for simple ops before LLM."""
    op = edit.get("operation")
    hint = edit.get("location_hint","")
    patch = edit.get("patch","").strip()

    if not patch and op != "delete_span":
        return doc

    if op == "insert_after":
        idx = doc.find(hint) if hint else -1
        if idx >= 0:
            cut = idx + len(hint)
            return doc[:cut] + ("\n" if doc[cut:cut+1] != "\n" else "") + patch + "\n" + doc[cut:]
        else:
            # append near end
            return doc.rstrip() + "\n\n" + patch + "\n"

    if op == "replace_span":
        if hint and hint in doc:
            return doc.replace(hint, patch, 1)
        return doc  # fallback: no-op

    if op == "append_section":
        return doc.rstrip() + "\n\n" + patch + "\n"

    if op == "delete_span":
        if hint and hint in doc:
            return doc.replace(hint, "", 1)
        return doc

    if op == "merge_sections":
        # naive: remove duplicate consecutive blank lines (simplify structure)
        merged = re.sub(r"\n{3,}", "\n\n", doc)
        return merged

    return doc

def editor_apply_edit(doc: str, chosen_edit: Dict[str, Any]) -> str:
    """
    Apply one analyst edit to ``doc`` and return the revised document.

    The deterministic local editor (_apply_edit_locally) runs first. Its
    result is final when it changed the text, or when the operation is
    append_section, merge_sections or delete_span. Otherwise (e.g. replace_span
    whose location_hint was not found, or an unknown operation) the LLM editor
    rewrites the article. If the LLM reply contains no text, ``doc`` is
    returned unchanged instead of being replaced by an empty/placeholder string.
    """
    # Try local
    new_doc = _apply_edit_locally(doc, chosen_edit)
    if new_doc != doc or chosen_edit.get("operation") in ("append_section", "merge_sections", "delete_span"):
        return new_doc

    # Fallback to LLM editor for tougher cases
    llm = make_llm(MODEL_EDITOR, TEMPERATURE_EDITOR)
    out = llm.invoke([("system", PROMPT_EDITOR_SYSTEM),
                      ("human", prompt_editor_user(doc, chosen_edit))])
    content = getattr(out, "content", "")
    if isinstance(content, list):
        # Gemini may return a list of parts: plain strings or
        # {"type": "text", "text": ...} blocks. Keep only their text.
        pieces = []
        for part in content:
            if isinstance(part, dict):
                pieces.append(str(part.get("text") or ""))
            else:
                pieces.append(str(part))
        content = "".join(pieces)
    text = str(content).strip() if content else ""
    if not text:
        print("[WARN] Editor LLM returned no text; keeping the document unchanged.")
        return doc
    return text


## 8) Optimize loop + hybrid selector

In [12]:
def _history_summary_for_selector(history: List[Tuple[str, List[Dict[str,Any]], Dict[str,Any]]]) -> List[Dict[str, Any]]:
    summ = []
    for i, (doc, perq, agg) in enumerate(history):
        # a short snippet for context
        snippet = (doc[:220] + "…") if len(doc) > 220 else doc
        summ.append({"idx": i, "agg": agg, "snippet": snippet})
    return summ

def score_scalar(agg: Dict[str, Dict[str, float]], lam: float = 0.2) -> float:
    """Scalar ranking score: total MIS minus ``lam`` times total MIV, rounded to 4 decimals.

    Higher is better; the variance penalty favours versions that score
    consistently across queries.
    """
    total_mis = sum(agg["MIS"][metric] for metric in METRICS)
    total_miv = sum(agg["MIV"][metric] for metric in METRICS)
    return round(float(total_mis - lam * total_miv), 4)

def select_best_version(history: List[Tuple[str, List[Dict[str,Any]], Dict[str,Any]]]) -> Dict[str, Any]:
    """Pick the best article version from the optimization history (hybrid selector).

    1. Rank every version by score_scalar (sum of MIS minus 0.2 * sum of MIV).
    2. When there is more than one version, let the selector LLM choose among
       the top 3. Fall back to the top-ranked version when its reply is not
       valid JSON or its winner_index is not a usable integer.

    Returns:
        ``{"index", "doc", "agg", "score_scalar"}`` for the chosen version, where
        ``index`` is its position in ``history``.

    Raises:
        ValueError: if ``history`` is empty.
    """
    if not history:
        raise ValueError("select_best_version() needs at least one evaluated version")

    # 1) rule-based ranking (stable sort: equal scores keep the earlier iteration first)
    with_scores = [(i, score_scalar(agg)) for i, (_, _, agg) in enumerate(history)]
    with_scores.sort(key=lambda x: x[1], reverse=True)
    top = [i for i, _ in with_scores[:3]]
    best_idx = top[0]

    # 2) LLM selector tie-breaker among top-3 (skipped when there is nothing to choose between)
    if len(top) > 1:
        llm = make_llm(MODEL_EVAL, 0.0)
        summary = _history_summary_for_selector([history[i] for i in top])
        payload = call_llm_json(llm, PROMPT_SELECTOR_SYSTEM, prompt_selector_user(summary))
        if isinstance(payload, dict) and "__SCHEMA_ERROR__" not in payload:
            try:
                k = int(payload.get("winner_index", 0))
            except (TypeError, ValueError):
                k = 0  # e.g. null or a non-numeric string -> keep the rule-based winner
            # winner_index refers to the candidate's position in `top`; clamp it into range
            best_idx = top[min(max(k, 0), len(top) - 1)]

    doc, _, agg = history[best_idx]
    return {"index": best_idx, "doc": doc, "agg": agg, "score_scalar": score_scalar(agg)}

def optimize_doc(doc_text: str, corpus: Dict[str, Any], n_iters: int = N_ITERS):
    """Run the evaluate -> analyze -> edit loop on a frozen corpus.

    Each iteration scores the current version on every corpus query and
    records ``(doc, per_query_scores, agg)`` in the history. Except on the last
    iteration, it then asks the analyst for edits and applies the one that
    targets the weakest-MIS metric (or the first edit if none targets it).

    Args:
        doc_text: starting article.
        corpus: ``{"queries": [{"q": str, "ctx": [str, ...]}, ...]}``.
        n_iters: maximum number of evaluation rounds.

    Returns:
        The history list (pass it to select_best_version). The loop stops early
        when the analyst proposes no edits.
    """
    history = []
    D = doc_text
    for t in range(n_iters):
        # Evaluate on the frozen corpus
        per_query_scores = [evaluator_score(D, item["q"], item["ctx"]) for item in corpus["queries"]]
        agg = aggregate_scores(per_query_scores, tau=SUCCESS_TAU)
        history.append((D, per_query_scores, agg))

        # An edited version is only scored in the NEXT iteration, so editing after
        # the final evaluation would waste LLM calls on a version that is discarded.
        if t == n_iters - 1:
            break

        # Analyze & choose an edit
        plan = analyst_propose_edits(D, per_query_scores, agg)
        edits = [e for e in (plan.get("edits") or []) if isinstance(e, dict)]
        if not edits:
            # nothing to do -> early stop
            break
        # Choose the edit most aligned with weakest metric (by MIS; ties -> METRICS order)
        mis = agg["MIS"]
        weakest = min(METRICS, key=lambda m: mis[m])
        chosen = next((e for e in edits if e.get("target_metric") == weakest), edits[0])
        if DEBUG:
            log_heading(f"ITER {t} — Chosen edit")
            log_json("chosen_edit", chosen)

        # Apply
        D = editor_apply_edit(D, chosen)

        if DEBUG:
            log_heading(f"ITER {t} — Editor Output (New Document)")
            print(D)
            print("="*80)

    return history


## 9) Data Generation Experiment Setup

In [13]:
import pandas as pd
import csv

def load_articles_from_tsv(tsv_path: str) -> List[Dict[str, str]]:
    """
    Read target articles from a tab-separated file.

    The file needs at least a ``query`` and a ``clean_content`` column (the usual
    layout is query, source_url, se_rank, ge_rank, clean_content). Each kept row
    becomes ``{'article_id': str(query), 'article_text': stripped clean_content}``.
    Rows whose stripped content is shorter than 50 characters are skipped.

    Raises:
        ValueError: if either required column is missing.
    """
    frame = pd.read_csv(tsv_path, sep='\t')

    required = ('query', 'clean_content')
    if any(col not in frame.columns for col in required):
        raise ValueError(f"TSV must have 'query' and 'clean_content' columns. Found: {frame.columns.tolist()}")

    articles = []
    for raw_id, raw_text in zip(frame['query'], frame['clean_content']):
        text = str(raw_text).strip()
        if len(text) < 50:
            continue  # empty / too short to be a real article
        articles.append({'article_id': str(raw_id), 'article_text': text})

    log_info(f"Loaded {len(articles)} valid articles from {tsv_path}")
    return articles

def run_single_article_experiment(article_id: str, article_text: str, n_iters: int = 5, n_generated_queries: int = 5,
                                  max_ctx_per_query: int = 4) -> Dict[str, Any]:
    """
    Run the optimize-and-select experiment for a single article.

    Steps: generate queries from the article, retrieve contexts for each query
    (this forms the frozen corpus), run ``n_iters`` evaluate -> analyze -> edit
    rounds, then pick the best version with select_best_version.

    Args:
        article_id: Unique identifier for the article
        article_text: The target article text to optimize
        n_iters: Number of evaluation rounds (must be >= 1). Edits are applied
            between rounds, so at most ``n_iters - 1`` edits are made.
        n_generated_queries: Number of queries to generate from target article (default: 5)
        max_ctx_per_query: Number of contexts requested from the retriever per
            query (default: 4). The judge prompt only shows the first MAX_CTX.

    Returns:
        On success, a dict with keys:
            - 'article_id': str
            - 'original_article': str
            - 'best_edited_article': str
            - 'best_iteration': int (index of the chosen evaluated version; 0 = original)
            - 'iteration_metrics': List[Dict[str, float]] (MIS of the 6 metrics per iteration)
            - 'generated_queries': List[str] (queries generated from target article)
            - 'success': True
        On failure: {'article_id': str, 'success': False, 'error': str}.
    """
    log_heading(f"Processing Article: {article_id}")

    # Validate input
    if not article_text or len(article_text.strip()) < 50:
        return {'article_id': article_id, 'success': False, 'error': 'Article text too short or empty'}
    if n_iters < 1:
        # With no evaluation round the history would be empty and
        # select_best_version would have nothing to choose from.
        return {'article_id': article_id, 'success': False, 'error': f'n_iters must be >= 1, got {n_iters}'}

    target_article = article_text.strip()

    # Log the original article
    log_heading("Original Target Article")
    print(f"Article ID: {article_id}")
    print(f"Article length: {len(target_article)} chars")
    print(f"Article preview (first 500 chars):\n{target_article[:500]}...\n")

    log_info(f"Target article length: {len(target_article)} chars")

    # Step 1: Generate queries from target article
    log_heading("Generating Queries from Target Article")
    try:
        generated_queries = generate_queries_from_doc(target_article, n_queries=n_generated_queries)
        log_info(f"✓ Generated {len(generated_queries)} queries:")
        for i, gq in enumerate(generated_queries):
            print(f"  {i+1}. {gq[:100]}{'...' if len(gq) > 100 else ''}")
    except Exception as e:
        log_info(f"✗ Failed to generate queries: {e}")
        return {'article_id': article_id, 'success': False, 'error': f'Query generation failed: {e}'}

    if len(generated_queries) < 2:
        return {'article_id': article_id, 'success': False, 'error': f'Too few queries generated: {len(generated_queries)}'}

    # Step 2: Retrieve contexts for each generated query (this is the frozen corpus)
    log_heading("Retrieving Contexts for Generated Queries")
    corpus_queries = []

    for i, gq in enumerate(generated_queries):
        log_info(f"Retrieving contexts for generated query {i+1}/{len(generated_queries)}")
        try:
            gq_contexts = gensee_ai_retrieve(gq, max_results=max_ctx_per_query)

            # Basic cleaning: keep non-empty stripped strings, de-duplicated in order
            cleaned_ctxs = list(dict.fromkeys(
                c.strip() for c in gq_contexts if isinstance(c, str) and c.strip()
            ))

            log_info(f"  -> Retrieved {len(cleaned_ctxs)} contexts")
            corpus_queries.append({'q': gq, 'ctx': cleaned_ctxs})

        except Exception as e:
            log_info(f"  -> ✗ Error retrieving contexts: {e}")
            corpus_queries.append({'q': gq, 'ctx': []})

    corpus = {
        'queries': corpus_queries
    }

    log_info(f"Corpus created with {len(corpus['queries'])} query-context pairs")

    # Step 3: Run optimization iterations
    history = []
    D = target_article

    for t in range(n_iters):
        log_heading(f"Iteration {t+1}/{n_iters}")

        # Evaluate current document
        log_heading(f"Evaluating Article - Iteration {t}")
        print(f"Article length: {len(D)} chars")
        print(f"Article preview (first 300 chars):\n{D[:300]}...\n")

        # Evaluate on each generated query
        per_query_scores = []
        for i, item in enumerate(corpus["queries"]):
            scores = evaluator_score(D, item["q"], item["ctx"])
            per_query_scores.append(scores)
            # Log individual query metrics
            query_metrics = " ".join([f"{m}:{scores['scores'][m]:.2f}" for m in METRICS])
            log_info(f"  Query {i+1}/{len(corpus['queries'])} | {query_metrics}")

        # Aggregate metrics across all generated queries
        agg = aggregate_scores(per_query_scores, tau=SUCCESS_TAU)
        history.append((D, per_query_scores, agg))

        # Log aggregated metrics for this iteration
        log_heading(f"Iteration {t} - Aggregated Metrics")
        mis_line = " ".join([f"{m}:{agg['MIS'][m]:.2f}" for m in METRICS])
        isr_line = " ".join([f"{m}:{agg['ISR'][m]:.2f}" for m in METRICS])
        miv_line = " ".join([f"{m}:{agg['MIV'][m]:.4f}" for m in METRICS])
        log_info(f"MIS | {mis_line}")
        log_info(f"ISR | {isr_line}")
        log_info(f"MIV | {miv_line}")

        # Analyze and apply edit (skip on last iteration: that version would never be evaluated)
        if t < n_iters - 1:
            plan = analyst_propose_edits(D, per_query_scores, agg)
            edits = [e for e in (plan.get("edits") or []) if isinstance(e, dict)]

            if not edits:
                log_info("No edits proposed, stopping early")
                break

            # Choose edit targeting weakest metric (ties -> METRICS order)
            mis = agg["MIS"]
            weakest = min(METRICS, key=lambda m: mis[m])
            chosen = next((e for e in edits if e.get("target_metric") == weakest), edits[0])

            log_info(f"Applying edit targeting: {chosen.get('target_metric')}")

            # Apply edit
            D = editor_apply_edit(D, chosen)

            # Log the edited article
            log_heading(f"Edited Article - After Iteration {t}")
            print(f"New length: {len(D)} chars")
            print(f"Preview (first 300 chars):\n{D[:300]}...\n")

    # Step 4: Select best version
    best = select_best_version(history)

    # Log the best selected article
    log_heading("Best Article Selected")
    print(f"Best iteration: {best['index']}")
    print(f"Best score: {best['score_scalar']:.3f}")
    print(f"Best article length: {len(best['doc'])} chars")
    print(f"Best article preview (first 500 chars):\n{best['doc'][:500]}...\n")

    # Step 5: Extract MIS metrics for each iteration
    iteration_metrics = [{m: agg['MIS'][m] for m in METRICS} for _, _, agg in history]

    log_info(f"Best iteration: {best['index']} with score: {best['score_scalar']:.3f}")
    log_info(f"Generated {len(generated_queries)} queries for evaluation")

    return {
        'article_id': article_id,
        'original_article': target_article,
        'best_edited_article': best['doc'],
        'best_iteration': best['index'],
        'iteration_metrics': iteration_metrics,
        'generated_queries': generated_queries,
        'success': True
    }

def _csv_header(path: str):
    """Return the first non-blank row of the CSV file at ``path``, or None if it is missing or empty."""
    import csv  # local import so this helper does not depend on an earlier cell

    if not os.path.exists(path):
        return None
    with open(path, newline='', encoding='utf-8') as f:
        for row in csv.reader(f):
            if row:  # csv yields [] for blank lines
                return row
    return None


def save_experiment_results(results: List[Dict[str, Any]], output_dir: str = "data"):
    """
    Append the successful experiment results to two CSV files in ``output_dir``.

    File 1: articles.csv - article_id, original_article, best_edited_article, best_iteration
    File 2: metrics.csv  - article_id followed by iter{i}_{metric} columns, one per
            (iteration i, metric in METRICS) pair

    Results whose ``success`` is falsy are ignored. A header row is written only
    when a file is missing or empty. When appending to an existing metrics.csv,
    every new row is padded with empty cells to the width of the header already
    on disk, so articles that stopped early do not produce ragged rows.

    Args:
        results: Result dicts as produced by run_single_article_experiment.
        output_dir: Directory for both CSV files (created if missing).

    Returns:
        (articles_file, metrics_file) paths, or (None, None) when ``results``
        contains no successful entry. In that case no file is created or touched.
    """
    import csv  # local import: `csv` is not imported in the setup cell

    successes = [r for r in results if r.get('success')]
    if not successes:
        # Return before opening anything. Opening metrics.csv in 'a' mode would
        # create an empty file, and the next call would then skip the header.
        log_info("No successful results, skipping CSV write")
        return None, None

    os.makedirs(output_dir, exist_ok=True)

    # File 1: Articles
    articles_file = os.path.join(output_dir, 'articles.csv')
    # "Exists" means "has a header row": an empty file still needs one.
    articles_exists = _csv_header(articles_file) is not None
    with open(articles_file, 'a', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)
        if not articles_exists:
            writer.writerow(['article_id', 'original_article', 'best_edited_article', 'best_iteration'])
        for r in successes:
            writer.writerow([
                r['article_id'],
                r['original_article'],
                r['best_edited_article'],
                r['best_iteration']
            ])

    log_info(f"{'Appended to' if articles_exists else 'Created'} articles file: {articles_file}")

    # File 2: Metrics
    metrics_file = os.path.join(output_dir, 'metrics.csv')
    existing_header = _csv_header(metrics_file)
    metrics_exists = existing_header is not None

    # Header: article_id, iter0_CP, iter0_AA, ..., iter{max_iters-1}_AD
    max_iters = max(len(r.get('iteration_metrics', [])) for r in successes)
    header = ['article_id'] + [f'iter{i}_{metric}' for i in range(max_iters) for metric in METRICS]
    # Rows must line up with whichever header is actually in the file.
    width = len(existing_header) if metrics_exists else len(header)

    with open(metrics_file, 'a', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)
        if not metrics_exists:
            writer.writerow(header)

        for r in successes:
            row = [r['article_id']]
            for iter_metrics in r.get('iteration_metrics', []):
                row.extend(iter_metrics.get(metric, '') for metric in METRICS)
            if len(row) > width:
                # Cannot widen an existing header in append mode; flag it instead of silently
                # producing a row that strict CSV readers reject.
                log_info(f"⚠️  metrics row for article {r['article_id']} has {len(row)} columns; header has {width}")
            # Pad early-stopped articles so the file stays rectangular.
            row.extend([''] * (width - len(row)))
            writer.writerow(row)

    log_info(f"{'Appended to' if metrics_exists else 'Created'} metrics file: {metrics_file}")

    return articles_file, metrics_file

def run_full_experiment(tsv_path: str, n_articles: int = None, n_iters: int = 5):
    """
    Run the optimization pipeline on articles from a TSV file and save the results.

    Each selected article is processed with run_single_article_experiment. If one
    article raises, the exception is recorded as a failed result and the loop
    continues. All results are written once, at the end, via
    save_experiment_results. No resume filtering is done here.

    Args:
        tsv_path: Path to the TSV file read by load_articles_from_tsv.
        n_articles: Process only the first ``n_articles`` articles. None = all;
            0 processes none.
        n_iters: Number of optimization iterations per article.

    Returns:
        (results, articles_file, metrics_file): one result dict per processed
        article, plus the two paths returned by save_experiment_results.
    """
    log_heading("Starting Data Generation Experiment")

    # Load articles
    all_articles = load_articles_from_tsv(tsv_path)

    # Explicit None check: a plain truthiness test would turn n_articles=0 into "all".
    if n_articles is None:
        articles_to_process = all_articles
    else:
        articles_to_process = all_articles[:n_articles]

    log_info(f"Processing {len(articles_to_process)} articles with {n_iters} iterations each")

    # Run experiments
    results = []
    for i, article in enumerate(articles_to_process):
        log_heading(f"Article {i+1}/{len(articles_to_process)}")
        try:
            result = run_single_article_experiment(
                article['article_id'],
                article['article_text'],
                n_iters=n_iters
            )
        except Exception as e:
            log_info(f"✗ Exception processing article {i+1}: {e}")
            results.append({'article_id': article['article_id'], 'success': False, 'error': str(e)})
            continue

        results.append(result)
        if result.get('success'):
            log_info(f"✓ Successfully processed article {i+1}")
        else:
            log_info(f"✗ Failed to process article {i+1}: {result.get('error')}")

    # Save results
    articles_file, metrics_file = save_experiment_results(results)

    # Summary
    successful = sum(1 for r in results if r.get('success'))
    log_heading("Experiment Complete")
    log_info(f"Successfully processed: {successful}/{len(results)} articles")
    if articles_file is None:
        # save_experiment_results returns (None, None) when there was nothing to save.
        log_info("No successful articles; no result rows were saved")
    else:
        log_info("Results saved to:")
        log_info(f"  - Articles: {articles_file}")
        log_info(f"  - Metrics: {metrics_file}")

    return results, articles_file, metrics_file


## 10) Run Experiment

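Run the cell below to start the experiment. Nothing needs to be uncommented: the last line of the cell calls `run_experiment_with_config()` directly, so adjust the configuration block at the top of the cell before running it.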

In [14]:
# ==================== EXPERIMENT CONFIGURATION ====================

# Input/Output Settings
TSV_PATH = "../data/ariticles/target_articles.tsv"  # Path to the target-articles TSV file
                                                    # NOTE(review): directory is spelled "ariticles" — confirm it matches the folder on disk
OUTPUT_DIR = "data"                                 # Directory for articles.csv and metrics.csv

# Experiment Settings
N_ARTICLES_TO_PROCESS = 30   # Max articles to process in this run (None = all remaining)

START_INDEX = 0              # Offset into the NOT-yet-processed articles (applied after resume filtering)

N_ITERATIONS = 8             # Number of optimization iterations per article

N_GENERATED_QUERIES = 8      # Number of queries to generate from target article (3-10 recommended)

# Resume Settings
RESUME_FROM_EXISTING = True   # Skip articles whose article_id already appears in the metrics file
EXISTING_METRICS_FILE = None  # Explicit metrics file path, or None to use OUTPUT_DIR/metrics.csv when it has data rows

# Retrieval Settings
MIN_CONTEXTS_REQUIRED = 2    # Minimum contexts needed to process article

# Model Settings: these REASSIGN the constants from the setup cell.
# MODEL_EDITOR changes from "gemini-2.5-flash-lite" (setup cell) to "gemini-2.5-flash" here.
# NOTE(review): presumably only effective if the agent helpers read these globals at call time — confirm.
MODEL_EVAL = "gemini-2.5-flash"
MODEL_ANALYST = "gemini-2.5-flash"
MODEL_EDITOR = "gemini-2.5-flash"

# Debug Settings
VERBOSE = True              # Print detailed progress
SAVE_INTERMEDIATE = True    # Save after each article (slower but safer)

# ==================================================================


# ==================== HELPER FUNCTIONS ====================
import glob
import pandas as pd

def find_latest_metrics_file(output_dir: str) -> str:
    """
    Return ``<output_dir>/metrics.csv`` if it exists and holds at least one data row.

    No search for a "latest" file is done: the only candidate is metrics.csv
    inside ``output_dir``.

    The file is streamed with the stdlib ``csv`` reader, which stops at the first
    non-blank row after the header. This avoids loading the whole file. It also
    tolerates rows whose column count differs from the header (for example an
    article that stopped early); ``pd.read_csv`` rejects such files.

    Args:
        output_dir: Directory expected to contain metrics.csv.

    Returns:
        The path to metrics.csv, or None if the file is missing, empty,
        header-only, or unreadable.
    """
    import csv  # local import so this helper does not depend on an earlier cell

    metrics_file = os.path.join(output_dir, "metrics.csv")

    if not os.path.exists(metrics_file):
        return None

    try:
        with open(metrics_file, newline='', encoding='utf-8') as f:
            # csv yields [] for blank lines; skip them as pandas' default does.
            non_blank_rows = (row for row in csv.reader(f) if row)
            header = next(non_blank_rows, None)
            first_data_row = next(non_blank_rows, None)
    except (OSError, UnicodeDecodeError, csv.Error):
        # Best effort: an unreadable file is treated as "no resume data".
        return None

    if header is None or first_data_row is None:
        # Empty file, or only a header: treat as non-existent.
        return None
    return metrics_file

def load_processed_articles(metrics_file: str) -> set:
    """
    Return the set of article_ids already recorded in ``metrics_file``.

    Only the ``article_id`` column is parsed (``usecols``). Rows with a different
    number of iteration columns than the header, such as an article that stopped
    early next to one that ran every iteration, therefore do not make the whole
    read fail and silently disable resume. Empty article_id cells are ignored.

    Args:
        metrics_file: Path to a metrics CSV, or None / "".

    Returns:
        Set of processed article ids. Empty if the path is falsy, the file is
        missing, or it cannot be read.
    """
    if not metrics_file or not os.path.exists(metrics_file):
        return set()

    try:
        article_ids = pd.read_csv(metrics_file, usecols=['article_id'])['article_id']
        processed = set(article_ids.dropna().tolist())
        log_info(f"📂 Loaded {len(processed)} processed articles from: {metrics_file}")
        return processed
    except Exception as e:
        # Best effort: without resume data every article is simply reprocessed.
        log_info(f"⚠️  Could not load existing metrics file: {e}")
        return set()

def filter_unprocessed_articles(all_articles: list, processed_articles: set) -> list:
    """
    Keep only the articles whose ``article_id`` is not in ``processed_articles``.

    The input order is preserved. When at least one article is dropped, the
    number skipped and the number remaining are logged.
    """
    remaining = []
    for candidate in all_articles:
        if candidate["article_id"] in processed_articles:
            continue
        remaining.append(candidate)

    n_skipped = len(all_articles) - len(remaining)
    if n_skipped <= 0:
        return remaining

    log_info(f"⏭️  Skipping {n_skipped} already processed queries")
    log_info(f"📋 {len(remaining)} queries remaining to process")
    return remaining


# ==================== RUN EXPERIMENT ====================
import time
from datetime import datetime

def _save_results_safely(batch: list) -> bool:
    """Append ``batch`` to the output CSVs in OUTPUT_DIR; log instead of raising on failure. Returns True if saved."""
    try:
        save_experiment_results(batch, output_dir=OUTPUT_DIR)
        return True
    except Exception as e:
        # A failed write must not abort the run or lose the in-memory results.
        log_info(f"⚠️  Could not save results: {e}")
        return False


def run_experiment_with_config():
    """
    Run the experiment described by the configuration constants at the top of this cell.

    Pipeline:
      1. Load every article from TSV_PATH.
      2. If RESUME_FROM_EXISTING, drop articles whose article_id already appears in
         the metrics file (EXISTING_METRICS_FILE, or the auto-detected
         OUTPUT_DIR/metrics.csv).
      3. Slice the remaining articles with START_INDEX and N_ARTICLES_TO_PROCESS
         (None = all). The slice is taken AFTER resume filtering, so START_INDEX
         is an offset into the not-yet-processed articles, not into the TSV.
      4. Run run_single_article_experiment on each selected article. Successful
         results are appended to the CSVs after every article when
         SAVE_INTERMEDIATE is true, and otherwise once at the end. Failed articles
         are not written, so a later resumed run retries them.

    Returns:
        (results, articles_file, metrics_file): the result dicts produced by this
        run (one per selected article) and the two output CSV paths.
    """
    print("="*80)
    print(f"🚀 Starting Experiment at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print("="*80)
    print(f"\n📋 Configuration:")
    print(f"  - TSV Path: {TSV_PATH}")
    print(f"  - Articles to process: {N_ARTICLES_TO_PROCESS if N_ARTICLES_TO_PROCESS is not None else 'ALL'}")
    print(f"  - Start index: {START_INDEX}")
    print(f"  - Iterations per article: {N_ITERATIONS}")
    print(f"  - Generated queries per article: {N_GENERATED_QUERIES}")
    print(f"  - Output directory: {OUTPUT_DIR}")
    print(f"  - Resume from existing: {RESUME_FROM_EXISTING}")
    print(f"  - Save after each article: {SAVE_INTERMEDIATE}")
    print("\n" + "="*80 + "\n")

    start_time = time.time()

    # Output paths are fixed up front so they can be returned even when nothing runs.
    articles_file = os.path.join(OUTPUT_DIR, 'articles.csv')
    metrics_file = os.path.join(OUTPUT_DIR, 'metrics.csv')

    all_articles = load_articles_from_tsv(TSV_PATH)
    log_info(f"📊 Total articles in TSV: {len(all_articles)}")

    # Resume: collect the ids already present in the metrics file.
    processed_articles = set()
    if RESUME_FROM_EXISTING:
        metrics_file_to_check = EXISTING_METRICS_FILE or find_latest_metrics_file(OUTPUT_DIR)
        if metrics_file_to_check:
            log_info(f"🔍 Checking for processed queries in: {metrics_file_to_check}")
            processed_articles = load_processed_articles(metrics_file_to_check)
        else:
            log_info(f"ℹ️  No existing metrics file found, starting fresh")

    # Filter out already-processed articles FIRST, then apply START_INDEX / N_ARTICLES_TO_PROCESS.
    all_unprocessed = filter_unprocessed_articles(all_articles, processed_articles)
    n_already_done = len(all_articles) - len(all_unprocessed)

    # Explicit None check so N_ARTICLES_TO_PROCESS=0 is not treated as "all";
    # slicing clamps at the list end on its own.
    if N_ARTICLES_TO_PROCESS is not None:
        articles_to_process = all_unprocessed[START_INDEX:START_INDEX + N_ARTICLES_TO_PROCESS]
    else:
        articles_to_process = all_unprocessed[START_INDEX:]

    if not articles_to_process:
        print("\n" + "="*80)
        print("✅ All queries already processed! Nothing to do.")
        print("="*80)
        return [], articles_file, metrics_file

    log_info(f"📊 Processing queries: {len(articles_to_process)} remaining")
    print("\n" + "-"*80 + "\n")

    # Map article_id -> first position in the TSV once, instead of a linear scan per article.
    tsv_position = {}
    for idx, a in enumerate(all_articles):
        tsv_position.setdefault(a["article_id"], idx)

    results = []
    successful = 0
    failed = 0

    for i, article in enumerate(articles_to_process):
        article_id = article["article_id"]
        actual_idx = tsv_position.get(article_id, i)

        log_heading(f"Query {actual_idx + 1}/{len(all_articles)} (Processing {i+1}/{len(articles_to_process)})")

        if VERBOSE:
            article_text = article["article_text"]
            preview = article_text[:80]
            suffix = '...' if len(article_text) > 80 else ''
            print(f"📝 Article ID: {article_id} | Preview: {preview}{suffix}")

        try:
            result = run_single_article_experiment(
                article_id,
                article["article_text"],
                n_iters=N_ITERATIONS,
                n_generated_queries=N_GENERATED_QUERIES
            )
        except Exception as e:
            failed += 1
            log_info(f"❌ Exception: {str(e)[:100]}")
            results.append({
                'article_id': article_id,
                'success': False,
                'error': str(e)
            })
        else:
            # Bookkeeping and saving sit outside the `try`. A save error must not
            # re-count an already-successful article as a failure.
            results.append(result)
            if result.get('success'):
                successful += 1
                log_info(f"✅ Success | Total: {successful}/{i+1}")
                if SAVE_INTERMEDIATE:
                    # Only successes are saved, so failed articles are retried on resume.
                    log_info(f"💾 Saving intermediate results...")
                    _save_results_safely([result])
            else:
                failed += 1
                log_info(f"❌ Failed: {result.get('error')} | Total failures: {failed}/{i+1}")

        # Progress summary every 10 articles
        if (i + 1) % 10 == 0:
            elapsed = time.time() - start_time
            avg_time = elapsed / (i + 1)
            remaining = avg_time * (len(articles_to_process) - i - 1)
            print("\n" + "="*80)
            print(f"📊 Progress: {i+1}/{len(articles_to_process)} articles processed")
            print(f"✅ Successful: {successful} | ❌ Failed: {failed}")
            print(f"⏱️  Elapsed: {elapsed/60:.1f} min | Estimated remaining: {remaining/60:.1f} min")
            print("="*80 + "\n")

    if not SAVE_INTERMEDIATE:
        to_save = [r for r in results if r.get('success')]
        if to_save:
            log_info(f"💾 Saving {len(to_save)} successful results...")
            _save_results_safely(to_save)

    # Final summary (results is non-empty here: every selected article appends one entry)
    total_time = time.time() - start_time
    print("\n" + "="*80)
    print("🎉 EXPERIMENT COMPLETE!")
    print("="*80)
    print(f"\n📊 Summary:")
    print(f"  - Total processed this run: {len(results)}")
    print(f"  - Successful: {successful} ({successful/len(results)*100:.1f}%)")
    print(f"  - Failed: {failed} ({failed/len(results)*100:.1f}%)")
    print(f"  - Skipped (already done): {n_already_done}")
    print(f"  - Total time: {total_time/60:.1f} minutes ({total_time/3600:.2f} hours)")
    print(f"  - Average per article: {total_time/len(results):.1f} seconds")

    print(f"\n📁 Output Files:")
    print(f"  - Articles: {articles_file}")
    print(f"  - Metrics:  {metrics_file}")

    return results, articles_file, metrics_file


# ==================== RUN EXPERIMENT ====================

# Execute the configured run. The per-article results and both output paths are
# kept as notebook globals so later cells can inspect them.
(results, articles_file, metrics_file) = run_experiment_with_config()