In [1]:
import os
import re
import pathlib
from datasets import load_dataset
import google.generativeai as genai

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
MODEL = "gemini-2.5-flash"
OUT_DIR = pathlib.Path("generated_solutions")
OUT_DIR.mkdir(parents=True, exist_ok=True)

In [3]:
def extract_docstring(prompt_text: str) -> list:
    """
    Extract all triple-quoted docstrings from the input text.
    Returns a list of docstring contents as strings.
    """
    matches = re.findall(r'("""|\'\'\')(.*?)(\1)', prompt_text, flags=re.DOTALL)
    # Join all docstrings into a single string (if needed)
    doc_text = "\n\n".join([m[1].strip() for m in matches])

    return doc_text

In [4]:
api_key = os.environ.get("GOOGLE_API_KEY")
if not api_key:
    raise RuntimeError("Please set GOOGLE_API_KEY in your environment.")
genai.configure(api_key=api_key)
model = genai.GenerativeModel(MODEL)

ds = load_dataset("openai/openai_humaneval")["test"]  # 164 items
print(f"Loaded HumanEval with {len(ds)} problems.")

Loaded HumanEval with 164 problems.


In [8]:
for idx, item in enumerate(ds):
    doc = extract_docstring(item["prompt"])
    prompt = f'"""\n{doc}\n"""\nPlease generate a plausible solution for this problem.'
    # if(idx == 32):
    #     print(prompt)
    try:
        resp = model.generate_content(prompt)
        text = (resp.text or "").strip()
        # If fences sneak in, strip them.
        text = re.sub(r"^```(?:python)?\s*", "", text)
        text = re.sub(r"\s*```$", "", text)

        out_path = OUT_DIR / f"Solution_{idx}.py"
        out_path.write_text(text, encoding="utf-8")
    except Exception as e:
        print(f"[WARN] Problem {idx} failed: {e}")
        
print("Generation Complete")

Generation Complete


In [20]:
# Remove everything above the first ```python fence (inclusive) in each .py file under DIR
# Usage: set DIR below and run this cell.

from pathlib import Path
import re

# <<< EDIT THIS >>>
DIR = Path("generated_solutions")  # e.g., Path("/home/jason/project")
# <<< ----------- >>>

FENCE_RE = re.compile(r"```[ \t]*python", re.IGNORECASE)

def clean_file(p: Path) -> bool:
    """
    Find the first ```python fence in file p and remove everything above it
    (including the fence itself). Returns True if the file was modified.
    """
    try:
        text = p.read_text(encoding="utf-8", errors="ignore")
    except Exception as e:
        print(f"[skip read error] {p}: {e}")
        return False

    m = FENCE_RE.search(text)
    if not m:
        return False  # no fence, leave file as-is

    new_text = text[m.end():].lstrip("\r\n")  # drop fence + any immediate blank lines
    if new_text == text:
        return False

    try:
        p.write_text(new_text, encoding="utf-8")
        return True
    except Exception as e:
        print(f"[skip write error] {p}: {e}")
        return False

def run(dir_path: Path):
    if not dir_path.exists():
        raise FileNotFoundError(f"Directory not found: {dir_path}")

    py_files = [p for p in dir_path.rglob("*.py") if p.is_file()]
    changed = 0
    for p in py_files:
        if clean_file(p):
            changed += 1
            print(f"[modified] {p}")

    print(f"\nDone. Scanned: {len(py_files)} .py files; Modified: {changed}")


Done. Scanned: 164 .py files; Modified: 0


# -------------------------------------------------------------------------------------

In [None]:
# === CONFIG ===
import os, re, io, sys, pathlib, types, importlib.util, builtins, json, textwrap
import unittest
import google.generativeai as genai
import pandas as pd

GEN_SOL_DIR = pathlib.Path("generated_solutions").resolve()     # Solution_<i>.py
TESTS_DIR   = pathlib.Path("generated_tests").resolve()         # HumanEval_<i>.py
CANON_DIR   = pathlib.Path("canonical_solutions").resolve()     # 000_<entry>.py  (to infer entry_point)
MAX_FIX_ATTEMPTS_PER_TEST = 3
MODEL = "gemini-2.5-flash"

assert GEN_SOL_DIR.is_dir(), f"Missing {GEN_SOL_DIR}"
assert TESTS_DIR.is_dir(),   f"Missing {TESTS_DIR}"
assert CANON_DIR.is_dir(),   f"Missing {CANON_DIR}"

# --- Gemini client ---
api_key = os.environ.get("GOOGLE_API_KEY")
if not api_key:
    raise RuntimeError("Set GOOGLE_API_KEY before running.")
genai.configure(api_key=api_key)
gmodel = genai.GenerativeModel(MODEL)

# --- Map id -> entry_point from canonical_solutions (000_funcname.py)
ID_TO_ENTRY = {}
for f in sorted(CANON_DIR.glob("*.py")):
    m = re.match(r"^(\d{3})_(.+)\.py$", f.name)
    if m:
        ID_TO_ENTRY[int(m.group(1))] = m.group(2)

# ========== Helpers ==========
def read_text(p: pathlib.Path) -> str:
    return p.read_text(encoding="utf-8")

def write_text(p: pathlib.Path, s: str):
    p.write_text(s, encoding="utf-8")

def load_solution_module(sol_path: pathlib.Path) -> types.ModuleType:
    """Load generated solution into a fresh module, also expose as 'candidate'."""
    code = read_text(sol_path)
    mod = types.ModuleType("candidate")
    exec(compile(code, str(sol_path), "exec"), mod.__dict__)
    sys.modules["candidate"] = mod
    return mod

def load_test_module_with_injection(test_path: pathlib.Path, entry_point: str, func_obj) -> types.ModuleType:
    """
    Load test file as a fresh module and inject the function:
    - in builtins (bare calls)
    - in test module globals
    - keep sys.modules clean between runs
    """
    mod_name = f"test_generated_{hash((str(test_path), os.urandom(4)))}"
    spec = importlib.util.spec_from_file_location(mod_name, str(test_path))
    if spec is None or spec.loader is None:
        raise RuntimeError(f"Cannot load tests from {test_path}")

    tmod = importlib.util.module_from_spec(spec)
    # injection
    setattr(builtins, entry_point, func_obj)
    tmod.__dict__[entry_point] = func_obj
    sys.modules[mod_name] = tmod
    spec.loader.exec_module(tmod)
    return tmod

def run_tests(test_module: types.ModuleType) -> unittest.TestResult:
    stream = io.StringIO()
    suite = unittest.defaultTestLoader.loadTestsFromModule(test_module)
    runner = unittest.TextTestRunner(stream=stream, verbosity=0)
    result = runner.run(suite)
    result._stdout = stream.getvalue()
    return result

def extract_failed_tests(result: unittest.TestResult):
    """Return list of dicts: {'id': id_str, 'trace': tb_str} for failures+errors."""
    fails = []
    for tc, tb in list(result.failures) + list(result.errors):
        fails.append({"id": tc.id(), "trace": tb})
    return fails

def parse_test_id(test_id: str):
    """
    unittest id() format: 'module.ClassName.test_method'
    Returns (class_name, method_name).
    """
    parts = test_id.split(".")
    method = parts[-1]
    cls = parts[-2] if len(parts) >= 2 else None
    return cls, method

def extract_test_method_code(test_path: pathlib.Path, class_name: str, method_name: str) -> tuple[str, tuple[int,int]]:
    """
    Extract the exact text of a test method by scanning the file.
    Returns (method_text, (start_idx, end_idx)) with indices into the file's lines list.
    """
    src = read_text(test_path)
    lines = src.splitlines(keepends=True)

    # Find "class ClassName(" line
    class_pat = re.compile(rf'^\s*class\s+{re.escape(class_name)}\s*\(', re.M)
    class_match = None
    for i, ln in enumerate(lines):
        if class_pat.match(ln):
            class_match = i
            break
    if class_match is None:
        raise ValueError(f"Class {class_name} not found in {test_path.name}")

    # Within class block, find 'def method_name(' at greater indent
    # Determine class indent
    class_indent = len(lines[class_match]) - len(lines[class_match].lstrip())
    method_pat = re.compile(rf'^\s*def\s+{re.escape(method_name)}\s*\(', re.M)

    start = end = None
    for i in range(class_match + 1, len(lines)):
        ln = lines[i]
        indent = len(ln) - len(ln.lstrip())
        if indent <= class_indent and ln.strip():  # class block ended
            break
        if start is None and method_pat.match(ln):
            start = i
            # capture until next def at same indent or class end
            for j in range(i + 1, len(lines)):
                ln2 = lines[j]
                ind2 = len(ln2) - len(ln2.lstrip())
                if re.match(r'^\s*def\s+\w+\s*\(', ln2) and ind2 == indent:
                    end = j
                    break
                # class end
                if ind2 <= class_indent and ln2.strip():
                    end = j
                    break
            if end is None:
                end = len(lines)
            break

    if start is None:
        raise ValueError(f"Method {method_name} not found in {test_path.name}")
    method_text = "".join(lines[start:end])
    return method_text, (start, end)

def replace_block_in_file(test_path: pathlib.Path, span: tuple[int,int], new_block: str):
    src = read_text(test_path)
    lines = src.splitlines(keepends=True)
    start, end = span
    # Ensure trailing newline
    if not new_block.endswith("\n"):
        new_block += "\n"
    lines[start:end] = [new_block]
    write_text(test_path, "".join(lines))

def delete_block_in_file(test_path: pathlib.Path, span: tuple[int,int]):
    replace_block_in_file(test_path, span, "")

def prompt_gemini_fix(function_src: str, failing_test_src: str, error_text: str, method_name: str) -> str:
    """
    Ask Gemini to fix the failing test method. Returns ONLY the method code text.
    """
    prompt = f"""
You are given a Python function under test and a single failing unittest method.
Fix ONLY the test method so that it correctly tests the intended behavior of the function.

Rules:
- Do NOT change the function under test.
- Keep it in Python's standard unittest style.
- Prefer adjusting inputs/expected values or using correct assertions.
- Keep the method name the same: {method_name}.
- Return ONLY the updated method code that starts with `def {method_name}(` and ends at the end of the method. No extra text, no backticks, no comment explainations.

Function under test:
```python
{function_src}
```
Failing test method:
```python
{failing_test_src}
```
Error message/trace:
```python
{error_text}
```
"""
    resp = gmodel.generate_content(prompt)
    text = (resp.text or "").strip()
    # strip possible fences
    text = re.sub(r"^\s*(?:python)?\s*", "", text) 
    text = re.sub(r"```", "", text)
    text = re.sub(r"\s*\s*$", "", text)
    print(text.strip())
    return text.strip()

def gather_function_source(sol_path: pathlib.Path, entry_point: str) -> str:
    code = read_text(sol_path)
    # Try to extract the function block; fallback to whole file
    # This regex matches the function definition and its body (non-greedy, up to next def or end of file)
    pattern = rf"^def\s+{re.escape(entry_point)}\s*\([^\)]*\):(?:\n(?:[ \t]+.*\n?)*)*"
    m = re.search(pattern, code, flags=re.M)
    if m:
        return m.group(0)
    return code



In [33]:
# ========== Main loop ==========
rows = []
for sol_file in sorted(GEN_SOL_DIR.glob("Solution_*.py"), key=lambda p: int(re.search(r"(\d+)", p.stem).group(1))):
    idx = int(re.search(r"(\d+)", sol_file.stem).group(1))
    test_file = TESTS_DIR / f"HumanEval_{idx}.py"
    if not test_file.exists():
        print(f"[SKIP] {idx:03d}: no test file {test_file.name}")
        continue
    if idx not in ID_TO_ENTRY:
        print(f"[SKIP] {idx:03d}: no entry_point mapping from canonical filenames.")
        continue
    entry_point = ID_TO_ENTRY[idx]
    print(f"\n=== [{idx:03d}] {entry_point} ===")

    original_test_src = read_text(test_file)  # keep original for span math
    func_mod = load_solution_module(sol_file)
    if not hasattr(func_mod, entry_point):
        print(f"  [WARN] Function '{entry_point}' not found in {sol_file.name}; skipping.")
        rows.append({"id": idx, "entry_point": entry_point, "initial_failures": None, "fixed": 0, "removed": 0, "final_passed": 0, "final_total": 0})
        continue

    # -- First run --
    tmod = load_test_module_with_injection(test_file, entry_point, getattr(func_mod, entry_point))
    result = run_tests(tmod)
    total = result.testsRun
    fails = extract_failed_tests(result)
    print(f"  First run: {total - len(fails)}/{total} passed ({len(fails)} failing)")

    fixed_count = 0
    removed_count = 0

    # Feedback loop: fix each failing test method up to 3 attempts; if still failing -> delete
    for fail in fails:
        cls_name, meth_name = parse_test_id(fail["id"])
        if not cls_name or not meth_name:
            continue

        try:
            failing_src, span = extract_test_method_code(test_file, cls_name, meth_name)
        except Exception as e:
            print(f"  [WARN] Cannot extract test {meth_name}: {e}")
            continue

        func_src = gather_function_source(sol_file, entry_point)
        attempt = 0
        success = False
        while attempt < MAX_FIX_ATTEMPTS_PER_TEST:
            attempt += 1
            print(f"    - Fixing {meth_name} (attempt {attempt})")
            try:
                new_method = prompt_gemini_fix(func_src, failing_src, fail["trace"], meth_name)
                if not new_method.strip().startswith(f"def {meth_name}("):
                    # If model changed the name or wrapped content, try to salvage by regex
                    m = re.search(rf'^\s*def\s+{re.escape(meth_name)}\s*\(.*', new_method, flags=re.M)
                    if m:
                        new_method = new_method[m.start():]
                replace_block_in_file(test_file, span, new_method)
            except Exception as e:
                print(f"      [LLM/Patch error] {e}")
                break

            # Re-run tests after patch
            func_mod = load_solution_module(sol_file)  # reload solution into 'candidate'
            tmod = load_test_module_with_injection(test_file, entry_point, getattr(func_mod, entry_point))
            res2 = run_tests(tmod)
            current_fails = extract_failed_tests(res2)
            still_failing_ids = {f["id"] for f in current_fails}
            if fail["id"] not in still_failing_ids:
                print("      ✓ fixed")
                fixed_count += 1
                success = True
                # refresh original_test_src and span indices for further patches in same file
                original_test_src = read_text(test_file)
                # recompute span for safety if another failure in same method appears (rare)
                break
            else:
                # update failing_src / span from the latest test file (may have changed formatting)
                try:
                    failing_src, span = extract_test_method_code(test_file, cls_name, meth_name)
                except Exception:
                    pass  # keep previous span

        if not success:
            print("      ✗ still failing after 3 attempts → deleting test")
            delete_block_in_file(test_file, span)
            removed_count += 1

    # Final run after all fixes/removals
    func_mod = load_solution_module(sol_file)
    tmod = load_test_module_with_injection(test_file, entry_point, getattr(func_mod, entry_point))
    final_res = run_tests(tmod)
    final_total = final_res.testsRun
    final_fails = len(extract_failed_tests(final_res))
    final_passed = final_total - final_fails
    print(f"  Final: {final_passed}/{final_total} passed (fixed {fixed_count}, removed {removed_count})")

    rows.append({
        "id": idx,
        "entry_point": entry_point,
        "initial_failures": len(fails),
        "fixed": fixed_count,
        "removed": removed_count,
        "final_passed": final_passed,
        "final_total": final_total,
    })
    if idx == 6:
        break  # DEBUG

df_fix = pd.DataFrame(rows).sort_values("id").reset_index(drop=True)
df_fix



=== [000] has_close_elements ===
  First run: 10/10 passed (0 failing)
  Final: 10/10 passed (fixed 0, removed 0)

=== [001] separate_paren_groups ===
  First run: 10/10 passed (0 failing)
  Final: 10/10 passed (fixed 0, removed 0)

=== [002] truncate_number ===
  First run: 10/10 passed (0 failing)
  Final: 10/10 passed (fixed 0, removed 0)

=== [003] below_zero ===
  First run: 10/10 passed (0 failing)
  Final: 10/10 passed (fixed 0, removed 0)

=== [004] mean_absolute_deviation ===
  First run: 10/10 passed (0 failing)
  Final: 10/10 passed (fixed 0, removed 0)

=== [005] intersperse ===
  First run: 10/10 passed (0 failing)
  Final: 10/10 passed (fixed 0, removed 0)

=== [006] parse_nested_parens ===
  First run: 9/10 passed (1 failing)
    - Fixing test_very_deep_nesting_in_one_group (attempt 1)
prompt:
 
You are given a Python function under test and a single failing unittest method.
Fix ONLY the test method so that it correctly tests the intended behavior of the function.

Rule

SyntaxError: invalid syntax (HumanEval_6.py, line 63)