In [None]:
# ==============================
# HumanEval Multi-Stage Orchestration (with Syntax-Only Debugger Agent + AST check)
# Planner: microsoft/phi-1_5
# Coder: deepseek-ai/deepseek-coder-1.3b-instruct
# Debugger: google/codegemma-2b (accept fix ONLY if AST unchanged)
# ==============================

import ast
import json
import os
import py_compile
import re
import subprocess
import sys
import tempfile
import time

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

# Upper bound on how many HumanEval tasks to run; None runs the whole split.
MAX_TASKS = None

# Run on the GPU whenever PyTorch reports one is available.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

def load_model_and_tokenizer(model_name):
    """
    Load a Hugging Face causal LM and its tokenizer, ready for inference.

    If the tokenizer has no pad token, its EOS token is reused for padding.
    When DEVICE is "cuda" the weights are placed with device_map="auto";
    otherwise device_map is None. The model is put in eval mode.

    Returns:
        A (tokenizer, model) tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        # Reuse EOS so that padding-aware code paths have a valid id.
        tokenizer.pad_token = tokenizer.eos_token

    placement = "auto" if DEVICE == "cuda" else None
    lm = AutoModelForCausalLM.from_pretrained(model_name, device_map=placement)
    lm.eval()
    return tokenizer, lm

# Load the three agents up front; each call yields a (tokenizer, model) pair.
print("Loading Planner (microsoft/phi-1_5)...")
(planner_tokenizer,
 planner_model) = load_model_and_tokenizer("microsoft/phi-1_5")

print("Loading Coder (deepseek-ai/deepseek-coder-1.3b-instruct)...")
(coder_tokenizer,
 coder_model) = load_model_and_tokenizer("deepseek-ai/deepseek-coder-1.3b-instruct")

print("Loading Debugger (google/codegemma-2b)...")
(debugger_tokenizer,
 debugger_model) = load_model_and_tokenizer("google/codegemma-2b")

def generate_text(model, tokenizer, prompt, max_new_tokens=512):
    """
    Greedily generate a continuation of ``prompt`` and return only the new text.

    For decoder-only models ``generate`` returns the prompt tokens followed by
    the generated ones. The prompt part is sliced off before decoding, so
    callers (e.g. ``extract_code``) never see code fences that were part of
    the prompt itself.

    Args:
        model: A causal LM exposing ``.device`` and ``.generate``.
        tokenizer: The matching tokenizer.
        prompt: Input text.
        max_new_tokens: Upper bound on generated tokens.

    Returns:
        The decoded continuation, with special tokens removed.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_len = inputs["input_ids"].shape[-1]
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            # Greedy decoding. ``temperature`` is not passed because it only
            # applies when sampling, and setting it here only triggers a warning.
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )
    return tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True)

# Opening fence with an optional python-ish tag, then the body up to the
# closing fence -- or to end-of-text when generation was cut off mid-block.
_CODE_FENCE_RE = re.compile(
    r"```(?:python3?|py)?[ \t]*\n?(.*?)(?:```|\Z)",
    re.DOTALL | re.IGNORECASE,
)


def extract_code(text):
    """
    Return the first non-empty fenced code block in ``text``.

    An optional ``python`` / ``python3`` / ``py`` tag (any case) after the
    opening fence is recognised. A block whose closing fence is missing, e.g.
    because generation hit ``max_new_tokens``, runs to the end of the text.
    Empty fences, such as the literal ```python``` used in this script's
    prompts, are skipped.

    Returns:
        The stripped block contents; "" if fences exist but are all empty;
        ``text.strip()`` if the text contains no fence at all.
    """
    saw_fence = False
    for match in _CODE_FENCE_RE.finditer(text):
        saw_fence = True
        body = match.group(1).strip()
        if body:
            return body
    return "" if saw_fence else text.strip()

def run_candidate(code, test_code, timeout=10.0):
    """
    Run ``code`` followed by ``test_code`` in a fresh Python subprocess.

    The two sources are concatenated into a temporary ``.py`` file. The file
    is executed with the interpreter running this script and then deleted.

    WARNING: the code is model-generated and runs unsandboxed with this
    user's privileges. The separate process and the timeout are the only guards.

    Args:
        code: Candidate implementation source.
        test_code: Test source appended after ``code``.
        timeout: Wall-clock limit in seconds (None disables it). Keeps one
            non-terminating candidate from hanging the whole evaluation.

    Returns:
        "PASS" if the process exits with status 0. Otherwise "FAIL: " followed
        by one of: the last stderr line; "exit code N" when stderr is empty;
        "Timeout after Ns" when the time limit is hit.
    """
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".py", delete=False, encoding="utf-8"
    ) as f:
        f.write(code + "\n" + test_code)
        path = f.name
    try:
        try:
            r = subprocess.run(
                # sys.executable rather than "python3": in a notebook the
                # kernel's interpreter may differ from python3 on PATH.
                [sys.executable, path],
                capture_output=True,
                text=True,
                errors="replace",
                timeout=timeout,
            )
        except subprocess.TimeoutExpired:
            return f"FAIL: Timeout after {timeout}s"
        if r.returncode == 0:
            return "PASS"
        err_lines = r.stderr.strip().splitlines()
        if err_lines:
            return "FAIL: " + err_lines[-1]
        return f"FAIL: exit code {r.returncode}"
    finally:
        try:
            os.remove(path)
        except OSError:
            pass  # best-effort cleanup of the temp file

def ast_equivalent(src1, src2):
    """
    Return True if ``src1`` and ``src2`` parse to structurally identical ASTs.

    ``ast.dump(..., include_attributes=False)`` omits position attributes
    (lineno, col_offset, end_lineno, end_col_offset). Layout-only differences
    therefore compare equal: spacing between tokens, blank lines, comments.
    Any change to names, literals (including docstrings) or statement
    structure does not compare equal.

    Returns:
        False if either source fails to parse; otherwise whether the dumps match.
    """
    try:
        tree1 = ast.parse(src1)
        tree2 = ast.parse(src2)
    except (SyntaxError, ValueError, RecursionError, MemoryError):
        # SyntaxError: invalid code; ValueError: null bytes (Python < 3.12);
        # RecursionError/MemoryError: pathologically deep nesting.
        return False
    return (ast.dump(tree1, include_attributes=False)
            == ast.dump(tree2, include_attributes=False))

def debug_code_with_gemma(buggy_code):
    """
    Ask CodeGemma-2b to fix syntax-only issues in ``buggy_code``.

    The model's proposal is accepted ONLY if:
      - it compiles (checked in memory with the built-in ``compile``), AND
      - ``ast_equivalent(buggy_code, fixed)`` is True, i.e. the AST is
        unchanged (no logic/structure changes).

    Returns:
        The accepted fixed source, or None to indicate rejection.

    NOTE(review): ``ast_equivalent`` returns False whenever either source
    fails to parse. In this script the function is only called after a
    syntax-like failure. So whenever ``buggy_code`` itself does not parse,
    the AST check cannot succeed and every fix is rejected. An acceptance rule
    that works for unparseable originals (e.g. a token-level diff) is TODO.
    """
    prompt = f"""Fix ONLY syntax errors in the following Python code.
Do NOT change logic or function signatures. Return ONLY corrected code inside one ```python``` block.

```python
{buggy_code}
```"""
    fixed_raw = generate_text(debugger_model, debugger_tokenizer, prompt, max_new_tokens=512)
    fixed_code = extract_code(fixed_raw)

    # Syntax-check in memory. This avoids a temp file to clean up and avoids
    # leaving a stray .pyc in <tmpdir>/__pycache__ as py_compile would.
    try:
        compile(fixed_code, "<debugger-fix>", "exec", dont_inherit=True)
    except (SyntaxError, ValueError):
        # ValueError: source containing null bytes on Python < 3.12.
        return None

    # Reject anything that changes program structure.
    if ast_equivalent(buggy_code, fixed_code):
        return fixed_code
    return None

dataset = load_dataset("openai_humaneval", split="test")
results = []
passed = 0
start_time = time.time()

# Lower-cased substrings of the final stderr line that mark a syntax-level
# failure, i.e. one the syntax-only Debugger is allowed to attempt.
SYNTAX_FAILURE_MARKERS = (
    "syntaxerror", "invalid syntax", "indentationerror", "taberror",
    "unexpected eof", "eof while parsing", "unterminated", "invalid token",
)

for idx, task in enumerate(dataset):
    # None runs every task; any integer (including 0) is a hard cap.
    if MAX_TASKS is not None and idx >= MAX_TASKS:
        break

    prompt = task["prompt"]
    task_id = task["task_id"]
    # HumanEval's "test" field only *defines* check(candidate). It has to be
    # called on the entry point explicitly. Otherwise no assertion ever runs,
    # and any candidate that merely executes without error counts as PASS.
    test_code = f"{task['test']}\n\ncheck({task['entry_point']})\n"

    print("="*80)
    print("Task:", task_id)
    print(prompt)

    # Planner
    plan_prompt = f"Provide a clear plan to implement the following:\n\n{prompt}"
    planner_output = generate_text(planner_model, planner_tokenizer, plan_prompt, max_new_tokens=256)

    # Coder
    coder_prompt = (
        f"Using this plan, write correct Python.\nPLAN:\n{planner_output}\nPROMPT:\n{prompt}\n"
        "Return only code inside ```python``` block."
    )
    coder_output = generate_text(coder_model, coder_tokenizer, coder_prompt, max_new_tokens=512)
    candidate_code = extract_code(coder_output)

    # Evaluate
    verdict = run_candidate(candidate_code, test_code)

    # Debugger only for syntax-like failures
    verdict_lower = verdict.lower()
    if verdict_lower.startswith("fail") and any(
        marker in verdict_lower for marker in SYNTAX_FAILURE_MARKERS
    ):
        print("\n🐞 Syntax-like failure detected → attempting syntax-only fix with Debugger...\n")
        fixed_code = debug_code_with_gemma(candidate_code)
        if fixed_code is not None:
            # accepted because it compiled and is AST-equivalent
            print("🔧 Debugger provided a syntax-only fix. Re-running tests...")
            candidate_code = fixed_code
            verdict = run_candidate(candidate_code, test_code)
        else:
            print("⚠️ Debugger fix rejected (either didn't compile or changed AST). Keeping original code.")

    if verdict.startswith("PASS"):
        passed += 1

    results.append({
        "task_id": task_id,
        "planner_output": planner_output,
        "coder_output": coder_output,
        "candidate_code": candidate_code,
        "verdict": verdict
    })

    print("Verdict:", verdict)
    print(f"Pass@1: {passed}/{len(results)} = {passed/len(results)*100:.2f}%")

elapsed = time.time() - start_time
# Guard against ZeroDivisionError when no task ran (e.g. MAX_TASKS = 0).
pass_rate = passed / len(results) * 100 if results else 0.0
print("="*80)
print(f"Finished {len(results)} tasks | Passed {passed} | Pass@1 = {pass_rate:.2f}%")
print(f"Time: {elapsed/60:.2f} min")

with open("humaneval_orchestration_results.json", "w", encoding="utf-8") as f:
    json.dump(results, f, indent=2)


Loading Planner (microsoft/phi-1_5)...
Loading Coder (deepseek-ai/deepseek-coder-1.3b-instruct)...
Loading Debugger (google/codegemma-2b)...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Task: HumanEval/0
from typing import List


def has_close_elements(numbers: List[float], threshold: float) -> bool:
    """ Check if in given list of numbers, are any two numbers closer to each other than
    given threshold.
    >>> has_close_elements([1.0, 2.0, 3.0], 0.5)
    False
    >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)
    True
    """

Verdict: PASS
Pass@1: 1/1 = 100.00%
Task: HumanEval/1
from typing import List


def separate_paren_groups(paren_string: str) -> List[str]:
    """ Input to this function is a string containing multiple groups of nested parentheses. Your goal is to
    separate those group into separate strings and return the list of those.
    Separate groups are balanced (each open brace is properly closed) and not nested within each other
    Ignore any spaces in the input string.
    >>> separate_paren_groups('( ) (( )) (( )( ))')
    ['()', '(())', '(()())']
    """

Verdict: PASS
Pass@1: 2/2 = 100.00%
Task: HumanEval/2


def truncate_num