In [1]:
import os
import re
import pathlib
from datasets import load_dataset
import google.generativeai as genai

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
MODEL = "gemini-2.5-flash"
OUT_DIR = pathlib.Path("generated_solutions")
OUT_DIR.mkdir(parents=True, exist_ok=True)

In [3]:
def extract_docstring(prompt_text: str) -> list:
    """
    Extract all triple-quoted docstrings from the input text.
    Returns a list of docstring contents as strings.
    """
    matches = re.findall(r'("""|\'\'\')(.*?)(\1)', prompt_text, flags=re.DOTALL)
    # Join all docstrings into a single string (if needed)
    doc_text = "\n\n".join([m[1].strip() for m in matches])

    return doc_text

In [4]:
api_key = os.environ.get("GOOGLE_API_KEY")
if not api_key:
    raise RuntimeError("Please set GOOGLE_API_KEY in your environment.")
genai.configure(api_key=api_key)
model = genai.GenerativeModel(MODEL)

ds = load_dataset("openai/openai_humaneval")["test"]  # 164 items
print(f"Loaded HumanEval with {len(ds)} problems.")

Loaded HumanEval with 164 problems.


In [8]:
for idx, item in enumerate(ds):
    doc = extract_docstring(item["prompt"])
    prompt = f'"""\n{doc}\n"""\nPlease generate a plausible solution for this problem.'
    # if(idx == 32):
    #     print(prompt)
    try:
        resp = model.generate_content(prompt)
        text = (resp.text or "").strip()
        # If fences sneak in, strip them.
        text = re.sub(r"^```(?:python)?\s*", "", text)
        text = re.sub(r"\s*```$", "", text)

        out_path = OUT_DIR / f"Solution_{idx}.py"
        out_path.write_text(text, encoding="utf-8")
    except Exception as e:
        print(f"[WARN] Problem {idx} failed: {e}")
        
print("Generation Complete")

Generation Complete


In [20]:
# Remove everything above the first ```python fence (inclusive) in each .py file under DIR
# Usage: set DIR below and run this cell.

from pathlib import Path
import re

# <<< EDIT THIS >>>
DIR = Path("generated_solutions")  # e.g., Path("/home/jason/project")
# <<< ----------- >>>

FENCE_RE = re.compile(r"```[ \t]*python", re.IGNORECASE)

def clean_file(p: Path) -> bool:
    """
    Find the first ```python fence in file p and remove everything above it
    (including the fence itself). Returns True if the file was modified.
    """
    try:
        text = p.read_text(encoding="utf-8", errors="ignore")
    except Exception as e:
        print(f"[skip read error] {p}: {e}")
        return False

    m = FENCE_RE.search(text)
    if not m:
        return False  # no fence, leave file as-is

    new_text = text[m.end():].lstrip("\r\n")  # drop fence + any immediate blank lines
    if new_text == text:
        return False

    try:
        p.write_text(new_text, encoding="utf-8")
        return True
    except Exception as e:
        print(f"[skip write error] {p}: {e}")
        return False

def run(dir_path: Path):
    if not dir_path.exists():
        raise FileNotFoundError(f"Directory not found: {dir_path}")

    py_files = [p for p in dir_path.rglob("*.py") if p.is_file()]
    changed = 0
    for p in py_files:
        if clean_file(p):
            changed += 1
            print(f"[modified] {p}")

    print(f"\nDone. Scanned: {len(py_files)} .py files; Modified: {changed}")


Done. Scanned: 164 .py files; Modified: 0


# -------------------------------------------------------------------------------------

In [75]:
# === CONFIG ===
import os, re, io, sys, pathlib, types, importlib.util, builtins, json, textwrap
import unittest
import google.generativeai as genai
import pandas as pd
import textwrap

GEN_SOL_DIR = pathlib.Path("generated_solutions").resolve()     # Solution_<i>.py
TESTS_DIR   = pathlib.Path("generated_tests").resolve()         # HumanEval_<i>.py
CANON_DIR   = pathlib.Path("canonical_solutions").resolve()     # 000_<entry>.py  (to infer entry_point)
MAX_FIX_ATTEMPTS_PER_TEST = 3
MODEL = "gemini-2.5-flash"

assert GEN_SOL_DIR.is_dir(), f"Missing {GEN_SOL_DIR}"
assert TESTS_DIR.is_dir(),   f"Missing {TESTS_DIR}"
assert CANON_DIR.is_dir(),   f"Missing {CANON_DIR}"

# --- Gemini client ---
api_key = os.environ.get("GOOGLE_API_KEY")
if not api_key:
    raise RuntimeError("Set GOOGLE_API_KEY before running.")
genai.configure(api_key=api_key)
gmodel = genai.GenerativeModel(MODEL)

# --- Map id -> entry_point from canonical_solutions (000_funcname.py)
ID_TO_ENTRY = {}
for f in sorted(CANON_DIR.glob("*.py")):
    m = re.match(r"^(\d{3})_(.+)\.py$", f.name)
    if m:
        ID_TO_ENTRY[int(m.group(1))] = m.group(2)

# ========== Helpers ==========
def read_text(p: pathlib.Path) -> str:
    return p.read_text(encoding="utf-8")

def write_text(p: pathlib.Path, s: str):
    p.write_text(s, encoding="utf-8")

def load_solution_module(sol_path: pathlib.Path) -> types.ModuleType:
    """Load generated solution into a fresh module, also expose as 'candidate'."""
    code = read_text(sol_path)
    mod = types.ModuleType("candidate")
    exec(compile(code, str(sol_path), "exec"), mod.__dict__)
    sys.modules["candidate"] = mod
    return mod

def load_test_module_with_injection(test_path: pathlib.Path, entry_point: str, func_obj) -> types.ModuleType:
    """
    Load test file as a fresh module and inject the function:
    - in builtins (bare calls)
    - in test module globals
    - keep sys.modules clean between runs
    """
    mod_name = f"test_generated_{hash((str(test_path), os.urandom(4)))}"
    spec = importlib.util.spec_from_file_location(mod_name, str(test_path))
    if spec is None or spec.loader is None:
        raise RuntimeError(f"Cannot load tests from {test_path}")

    tmod = importlib.util.module_from_spec(spec)
    # injection
    setattr(builtins, entry_point, func_obj)
    tmod.__dict__[entry_point] = func_obj
    sys.modules[mod_name] = tmod
    spec.loader.exec_module(tmod)
    return tmod

def run_tests(test_module: types.ModuleType) -> unittest.TestResult:
    stream = io.StringIO()
    suite = unittest.defaultTestLoader.loadTestsFromModule(test_module)
    runner = unittest.TextTestRunner(stream=stream, verbosity=0)
    result = runner.run(suite)
    result._stdout = stream.getvalue()
    return result

def extract_failed_tests(result: unittest.TestResult):
    """Return list of dicts: {'id': id_str, 'trace': tb_str} for failures+errors."""
    fails = []
    for tc, tb in list(result.failures) + list(result.errors):
        fails.append({"id": tc.id(), "trace": tb})
    return fails

def parse_test_id(test_id: str):
    """
    unittest id() format: 'module.ClassName.test_method'
    Returns (class_name, method_name).
    """
    parts = test_id.split(".")
    method = parts[-1]
    cls = parts[-2] if len(parts) >= 2 else None
    return cls, method

def extract_test_method_code(test_path: pathlib.Path, class_name: str, method_name: str) -> tuple[str, tuple[int,int]]:
    """
    Extract the exact text of a test method by scanning the file.
    Returns (method_text, (start_idx, end_idx)) with indices into the file's lines list.
    """
    src = read_text(test_path)
    lines = src.splitlines(keepends=True)

    # Find "class ClassName(" line
    class_pat = re.compile(rf'^\s*class\s+{re.escape(class_name)}\s*\(', re.M)
    class_match = None
    for i, ln in enumerate(lines):
        if class_pat.match(ln):
            class_match = i
            break
    if class_match is None:
        raise ValueError(f"Class {class_name} not found in {test_path.name}")

    # Within class block, find 'def method_name(' at greater indent
    # Determine class indent
    class_indent = len(lines[class_match]) - len(lines[class_match].lstrip())
    method_pat = re.compile(rf'^\s*def\s+{re.escape(method_name)}\s*\(', re.M)

    start = end = None
    for i in range(class_match + 1, len(lines)):
        ln = lines[i]
        indent = len(ln) - len(ln.lstrip())
        if indent <= class_indent and ln.strip():  # class block ended
            break
        if start is None and method_pat.match(ln):
            start = i
            # capture until next def at same indent or class end
            for j in range(i + 1, len(lines)):
                ln2 = lines[j]
                ind2 = len(ln2) - len(ln2.lstrip())
                if re.match(r'^\s*def\s+\w+\s*\(', ln2) and ind2 == indent:
                    end = j
                    break
                # class end
                if ind2 <= class_indent and ln2.strip():
                    end = j
                    break
            if end is None:
                end = len(lines)
            break

    if start is None:
        raise ValueError(f"Method {method_name} not found in {test_path.name}")
    method_text = "".join(lines[start:end])
    return method_text, (start, end)

def replace_block_in_file(test_path: pathlib.Path, span: tuple[int,int], new_block: str):
    src = read_text(test_path)
    lines = src.splitlines(keepends=True)
    start, end = span
    # Ensure trailing newline
    if not new_block.endswith("\n"):
        new_block += "\n"
    lines[start:end] = [new_block]
    write_text(test_path, "".join(lines))

def delete_block_in_file(test_path: pathlib.Path, span: tuple[int,int]):
    replace_block_in_file(test_path, span, "")

import re, textwrap

def clean_and_align_method_code(raw_text: str, target_indent: str, method_name: str) -> str:
    """
    Normalize a test method snippet and make sure the *first def line* is indented
    to `target_indent`, with the body keeping its relative indentation.
    """
    # remove code fences / normalize whitespace
    text = (raw_text or "")
    text = re.sub(r"```(?:python)?", "", text)
    text = re.sub(r"```", "", text)
    text = text.replace("\r\n", "\n").replace("\r", "\n").expandtabs(4).strip()

    # keep only the requested method if extra prose is present
    m = re.search(rf'^\s*def\s+{re.escape(method_name)}\s*\(.*', text, flags=re.M)
    if m:
        text = text[m.start():]

    # dedent everything so relative body indentation is preserved
    text = textwrap.dedent(text).strip("\n")

    # split, then force-indent the first def line
    lines = text.split("\n")
    if lines:
        lines[0] = target_indent + lines[0].lstrip()  # <-- ensure the first def is indented

    # indent all remaining lines by target_indent too (keeps relative indents from dedent)
    for i in range(1, len(lines)):
        if lines[i].strip():
            lines[i] = target_indent + lines[i]

    out = "\n".join(lines)
    if not out.endswith("\n"):
        out += "\n"
    return out


def prompt_gemini_fix(function_src: str, failing_test_src: str, error_text: str, method_name: str) -> str:
    """
    Ask Gemini to fix the failing test method. Returns ONLY the method code text.
    """
    prompt = f"""
You are given a Python function under test and a single failing unittest method.
Fix ONLY the test method so that it correctly tests the intended behavior of the function.

Rules:
- Do NOT change the function under test.
- Keep it in Python's standard unittest style.
- Prefer adjusting inputs/expected values or using correct assertions.
- Keep the method name the same: {method_name}.
- Return ONLY the updated method code that starts with `def {method_name}(` and ends at the end of the method. No extra text, no backticks, no comment explainations.

Function under test:
```python
{function_src}
```
Failing test method:
```python
{failing_test_src}
```
Error message/trace:
```python
{error_text}
```
"""
    resp = gmodel.generate_content(prompt)
    text = (resp.text or "").strip()
    # strip possible fences
    text = re.sub(r"^\s*(?:python)?\s*", "", text) 
    text = re.sub(r"```", "", text)
    text = re.sub(r"\s*\s*$", "", text)
    return text.strip()

def gather_function_source(sol_path: pathlib.Path, entry_point: str) -> str:
    code = read_text(sol_path)
    # Try to extract the function block; fallback to whole file
    # This regex matches the function definition and its body (non-greedy, up to next def or end of file)
    pattern = rf"^def\s+{re.escape(entry_point)}\s*\([^\)]*\):(?:\n(?:[ \t]+.*\n?)*)*"
    m = re.search(pattern, code, flags=re.M)
    if m:
        return m.group(0)
    return code



In [None]:
# ========== Main loop ==========
rows = []
for sol_file in sorted(GEN_SOL_DIR.glob("Solution_*.py"), key=lambda p: int(re.search(r"(\d+)", p.stem).group(1))):

    idx = int(re.search(r"(\d+)", sol_file.stem).group(1))
    # if (idx != 147):
    #     continue # DEBUG
    test_file = TESTS_DIR / f"HumanEval_{idx}.py"
    if not test_file.exists():
        print(f"[SKIP] {idx:03d}: no test file {test_file.name}")
        continue
    if idx not in ID_TO_ENTRY:
        print(f"[SKIP] {idx:03d}: no entry_point mapping from canonical filenames.")
        continue
    entry_point = ID_TO_ENTRY[idx]
    print(f"\n=== [{idx:03d}] {entry_point} ===")

    original_test_src = read_text(test_file)  # keep original for span math
    func_mod = load_solution_module(sol_file)
    if not hasattr(func_mod, entry_point):
        print(f"  [WARN] Function '{entry_point}' not found in {sol_file.name}; skipping.")
        rows.append({"id": idx, "entry_point": entry_point, "initial_failures": None, "fixed": 0, "removed": 0, "final_passed": 0, "final_total": 0})
        continue
    
    # -- First run --
    tmod = load_test_module_with_injection(test_file, entry_point, getattr(func_mod, entry_point))
    result = run_tests(tmod)
    total = result.testsRun
    fails = extract_failed_tests(result)
    print(f"  First run: {total - len(fails)}/{total} passed ({len(fails)} failing)")

    fixed_count = 0
    removed_count = 0

    # Feedback loop: fix each failing test method up to 3 attempts; if still failing -> delete
    for fail in fails:
        cls_name, meth_name = parse_test_id(fail["id"])
        if not cls_name or not meth_name:
            continue

        try:
            failing_src, span = extract_test_method_code(test_file, cls_name, meth_name)
        except Exception as e:
            print(f"  [WARN] Cannot extract test {meth_name}: {e}")
            continue

        func_src = gather_function_source(sol_file, entry_point)
        attempt = 0
        success = False
        while attempt < MAX_FIX_ATTEMPTS_PER_TEST:
            attempt += 1
            print(f"    - Fixing {meth_name} (attempt {attempt})")
            try:
                new_method = prompt_gemini_fix(func_src, failing_src, fail["trace"], meth_name)
                if not new_method.strip().startswith(f"def {meth_name}("):
                    # If model changed the name or wrapped content, try to salvage by regex
                    m = re.search(rf'^\s*def\s+{re.escape(meth_name)}\s*\(.*', new_method, flags=re.M)
                    if m:
                        new_method = new_method[m.start():]

                original_src_lines = read_text(test_file).splitlines(keepends=True)
                target_indent = re.match(r'^(\s*)', original_src_lines[span[0]]).group(1)
                new_method_aligned = clean_and_align_method_code(new_method, target_indent, meth_name)
                replace_block_in_file(test_file, span, new_method_aligned)
            except Exception as e:
                print(f"      [LLM/Patch error] {e}")
                break

            # Re-run tests after patch
            func_mod = load_solution_module(sol_file)  # reload solution into 'candidate'
            tmod = load_test_module_with_injection(test_file, entry_point, getattr(func_mod, entry_point))
            res2 = run_tests(tmod)
            current_fails = extract_failed_tests(res2)
            still_failing_ids = {f["id"] for f in current_fails}
            if fail["id"] not in still_failing_ids:
                print("      ✓ fixed")
                fixed_count += 1
                success = True
                # refresh original_test_src and span indices for further patches in same file
                original_test_src = read_text(test_file)
                # recompute span for safety if another failure in same method appears (rare)
                break
            else:
                # update failing_src / span from the latest test file (may have changed formatting)
                try:
                    failing_src, span = extract_test_method_code(test_file, cls_name, meth_name)
                except Exception:
                    pass  # keep previous span

        if not success:
            print("      ✗ still failing after 3 attempts → deleting test")
            delete_block_in_file(test_file, span)
            removed_count += 1

    # Final run after all fixes/removals
    func_mod = load_solution_module(sol_file)
    tmod = load_test_module_with_injection(test_file, entry_point, getattr(func_mod, entry_point))
    final_res = run_tests(tmod)
    final_total = final_res.testsRun
    final_fails = len(extract_failed_tests(final_res))
    final_passed = final_total - final_fails
    print(f"  Final: {final_passed}/{final_total} passed (fixed {fixed_count}, removed {removed_count})")

    rows.append({
        "id": idx,
        "entry_point": entry_point,
        "initial_failures": len(fails),
        "fixed": fixed_count,
        "removed": removed_count,
        "final_passed": final_passed,
        "final_total": final_total,
    })
    # if idx == 7:
    #     break  # DEBUG

df_fix = pd.DataFrame(rows).sort_values("id").reset_index(drop=True)
df_fix


=== [147] get_max_triples ===
  First run: 0/8 passed (8 failing)
    - Fixing test_n_100_medium_large (attempt 1)
      ✓ fixed
    - Fixing test_n_2 (attempt 1)
      ✓ fixed
    - Fixing test_n_3 (attempt 1)
      ✓ fixed
    - Fixing test_n_4 (attempt 1)
      ✓ fixed
    - Fixing test_n_5_example (attempt 1)
      ✓ fixed
    - Fixing test_n_6 (attempt 1)
      ✓ fixed
    - Fixing test_n_8 (attempt 1)
      ✓ fixed
    - Fixing test_n_9 (attempt 1)
      ✓ fixed
  Final: 6/8 passed (fixed 8, removed 0)


In [None]:
# df_fix.to_csv("C:\\Users\\zhang\\Downloads\\fix.csv", index=False)

# ----------------------------------------------------------------------

In [99]:
import builtins, importlib.util, io, pathlib, re, sys, types, unittest, warnings, textwrap
import pandas as pd
from coverage import Coverage
from coverage.exceptions import CoverageWarning

# Paths
GEN_SOL_DIR = pathlib.Path("generated_solutions").resolve()
TESTS_DIR   = pathlib.Path("generated_tests").resolve()

warnings.filterwarnings("ignore", category=CoverageWarning)

def read_text(p: pathlib.Path) -> str:
    return p.read_text(encoding="utf-8")

def load_solution_as_module(sol_path: pathlib.Path) -> types.ModuleType:
    code = read_text(sol_path)
    mod = types.ModuleType("candidate")
    compiled = compile(code, filename=str(sol_path), mode="exec")
    exec(compiled, mod.__dict__)
    return mod

def infer_entry_point_from_solution(code: str) -> str:
    m = re.search(r'^\s*def\s+([A-Za-z_]\w*)\s*\(', code, flags=re.M)
    return m.group(1) if m else ""

def load_test_module(test_path: pathlib.Path, entry_point_name: str, func_obj) -> types.ModuleType:
    setattr(builtins, entry_point_name, func_obj)
    spec = importlib.util.spec_from_file_location("test_generated", str(test_path))
    tmod = importlib.util.module_from_spec(spec)
    tmod.__dict__[entry_point_name] = func_obj
    sys.modules["test_generated"] = tmod
    spec.loader.exec_module(tmod)
    return tmod

def run_suite_under_coverage(sol_path: pathlib.Path, suite: unittest.TestSuite):
    cov = Coverage(include=[str(sol_path)], branch=True)
    cov.erase()
    cov.start()
    result = unittest.TextTestRunner(stream=io.StringIO(), verbosity=0).run(suite)
    cov.stop()
    cov.save()
    return cov, result

def coverage_metrics_for_file(cov: Coverage, filename: str):
    try:
        data = cov.get_data()
        if not data:
            return 0.0, [], [], []
        measured = data.measured_files()
        if filename not in measured:
            for f in measured:
                if pathlib.Path(f).name == pathlib.Path(filename).name:
                    filename = f
                    break
            else:
                return 0.0, [], [], []
        _, statements, _, missing, partial = cov.analysis2(filename)
        pct = 0.0 if not statements else round(100.0 * (len(statements) - len(missing)) / len(statements), 2)
        return pct, statements, missing, partial
    except Exception:
        return 0.0, [], [], []

def make_augmentation_prompt(solution_code: str, missing_lines, partial_lines) -> str:
    feedback = []
    if missing_lines:
        feedback.append(f"Uncovered lines: {sorted(set(missing_lines))}")
    if partial_lines:
        feedback.append(f"Lines with partial (branch) coverage: {sorted(set(partial_lines))}")
    feedback_text = "\n".join(feedback) if feedback else "No uncovered lines detected. Try to add more edge-case tests."
    example = f"""
    def some_test_cases(self):
        self.assertFalse(methodname())
    """
    prompt = f"""
You are given a Python function (solution code) and coverage feedback from running our current unittest file.

Goal:
- Generate additional **Python unittest** test cases to improve code coverage, focusing on the uncovered or partially covered lines listed below.
- Append tests that are **compatible** with the existing test file's style (standard unittest) and **do not alter the solution code**.

Solution code:
```python
{solution_code}
```
Coverage feedback:
{feedback_text}

Instructions:
- Write only valid Python unittest test cases that can be appended to the existing test class. For example:
```python
{example}
```
- Target the uncovered/partial lines specifically (exercise missing branches, edge inputs, error paths).
- Do NOT include triple backticks in your output.
- Return only the additional test methods like the example.
"""
    return prompt

In [85]:
rows = []

for sol_file in sorted(GEN_SOL_DIR.glob("Solution_*.py"), key=lambda p: int(re.search(r"(\d+)", p.stem).group(1))):
    idx = int(re.search(r"(\d+)", sol_file.stem).group(1))
    test_file = TESTS_DIR / f"HumanEval_{idx}.py"
    if not test_file.exists():
        print(f"[SKIP] {idx:03d} missing test file")
        continue
    sol_code = read_text(sol_file)
    entry_point = infer_entry_point_from_solution(sol_code)
    if not entry_point:
        print(f"[WARN] {idx:03d} could not infer entry point in {sol_file.name}")
        continue

    sol_mod = load_solution_as_module(sol_file)
    if not hasattr(sol_mod, entry_point):
        print(f"[WARN] {idx:03d} entry point {entry_point} not found in {sol_file.name}")
        continue

    func_obj = getattr(sol_mod, entry_point)
    test_mod = load_test_module(test_file, entry_point, func_obj)
    suite = unittest.defaultTestLoader.loadTestsFromModule(test_mod)

    cov, result = run_suite_under_coverage(sol_file, suite)
    pct, statements, missing, partial = coverage_metrics_for_file(cov, str(sol_file))

    # print(f"\n[{idx:03d}] entry: {entry_point}, coverage: {pct:5.1f}%, statement: {len(statements)}, missing: {len(missing)}, partial: {len(partial)}")
    if missing or partial:
        prompt = make_augmentation_prompt(sol_code, missing, partial)
        # print("Generated augmentation prompt:\n", prompt)

    rows.append({
        "id": idx,
        "entry_point": entry_point,
        "coverage": pct,
        "statements": len(statements),
        "missing": missing,
        "partial": partial
    })

df_cov_aug = pd.DataFrame(rows).sort_values("id").reset_index(drop=True)
print(f"Average Coverage Rate: {df_cov_aug['coverage'].mean():.2f}%")
df_cov_aug

count_up_to(5) => [2, 3]
count_up_to(11) => [2, 3, 5, 7]
count_up_to(0) => []
count_up_to(20) => [2, 3, 5, 7, 11, 13, 17, 19]
count_up_to(1) => []
count_up_to(18) => [2, 3, 5, 7, 11, 13, 17]
count_up_to(2) => []
count_up_to(3) => [2]
s='abcde', c='ae' -> ('bcd', False)
s='abcdef', c='b' -> ('acdef', False)
s='abcdedcba', c='ab' -> ('cdedc', True)
s='racecar', c='c' -> ('raear', True)
s='madam', c='' -> ('madam', True)
s='hello', c='l' -> ('heo', False)
s='a', c='a' -> ('', True)
s='a', c='b' -> ('a', True)
s='', c='xyz' -> ('', True)
x_or_y(7, 34, 12) == 34
x_or_y(15, 8, 5) == 5
x_or_y(2, 100, 200) == 100
x_or_y(4, 'apple', 'banana') == banana
x_or_y(1, 'prime', 'not prime') == not prime
x_or_y(11, 'found', 'missed') == found
Average Coverage Rate: 73.35%


Unnamed: 0,id,entry_point,coverage,statements,missing,partial
0,0,has_close_elements,80.00,10,"[1, 3]",1-3
1,1,separate_paren_groups,93.33,15,[1],1
2,2,truncate_number,60.00,5,"[1, 3]",1-3
3,3,below_zero,85.71,7,[1],1
4,4,mean_absolute_deviation,77.78,9,"[1, 3]",1-3
...,...,...,...,...,...,...
159,159,eat,80.00,5,[1],1
160,160,do_algebra,87.50,8,[1],1
161,161,solve,93.75,16,[1],1
162,162,string_to_md5,57.14,7,"[1, 2, 4]",1-4


In [None]:
# df_cov_aug.to_csv("C:\\Users\\zhang\\Downloads\\cov_aug.csv", index=False)

In [102]:
# === Augment tests with Gemini based on coverage gaps, then re-measure coverage ===
import os, re, io, sys, pathlib, types, importlib.util, builtins, unittest, textwrap, warnings
import google.generativeai as genai
import pandas as pd
from coverage.exceptions import CoverageWarning

# Reuse your existing helpers & imports already defined in your notebook:
# - read_text, load_solution_as_module, infer_entry_point_from_solution,
#   load_test_module, run_suite_under_coverage, coverage_metrics_for_file,
#   make_augmentation_prompt, GEN_SOL_DIR, TESTS_DIR
# Assumes: warnings.filterwarnings("ignore", category=CoverageWarning) already set

# --- Configure Gemini ---
api_key = os.environ.get("GOOGLE_API_KEY")
if not api_key:
    raise RuntimeError("GOOGLE_API_KEY not set. `export GOOGLE_API_KEY=...`")
genai.configure(api_key=api_key)
gmodel = genai.GenerativeModel("gemini-2.5-flash")

# ---------- small utilities for patching tests ----------

def first_testcase_class_span(test_src: str):
    """
    Find the first `class XXX(unittest.TestCase):` block and return:
      (class_name, start_idx, end_idx, class_indent_str, inner_indent_str)
    inner_indent_str is the indentation used for a method inside the class
    (e.g., a tab or 4 spaces). If no method is found, defaults to 4 spaces.
    """
    lines = test_src.splitlines(keepends=True)
    class_pat = re.compile(r'^\s*class\s+([A-Za-z_]\w*)\s*\(\s*unittest\.TestCase\s*\)\s*:\s*$', re.M)

    cls_name = None
    start = end = None
    class_indent = ""
    for i, ln in enumerate(lines):
        m = class_pat.match(ln)
        if m:
            cls_name = m.group(1)
            class_indent = re.match(r'^(\s*)', ln).group(1)
            # find end of class (next non-blank line with indent <= class_indent)
            for j in range(i + 1, len(lines)):
                ln2 = lines[j]
                ind2 = len(ln2) - len(ln2.lstrip())
                if ln2.strip() and ind2 <= len(class_indent):
                    end = j
                    break
            if end is None:
                end = len(lines)
            start = i
            break

    if cls_name is None:
        return None, None, None, None, None

    # Detect inner indent by finding the first method def in the class
    inner_indent = "    "  # default 4 spaces
    method_pat = re.compile(r'^\s*def\s+\w+\s*\(', re.M)
    for k in range(start + 1, end):
        ln = lines[k]
        if method_pat.match(ln):
            line_indent = re.match(r'^(\s*)', ln).group(1)
            # inner indent is what's beyond the class indent
            inner_indent = line_indent[len(class_indent):]
            if not inner_indent:
                inner_indent = "    "
            break

    return cls_name, start, end, class_indent, inner_indent

def clean_llm_methods(raw_text: str, test_name_prefix: str = "test_") -> str:
    """
    Robustly sanitize LLM output and return only *valid* unittest methods:

    - remove code fences / markdown
    - normalize newlines and tabs
    - strip ANY junk before 'def test_...:' on a line (e.g., ': def ...', ', def ...')
    - extract complete method blocks from each 'def test_*' to the line before the next 'def ' at the same indent
    - dedent each method independently to fix body indentation
    - join the cleaned methods with a single blank line

    Returns an empty string if nothing usable is found.
    """
    if not raw_text:
        return ""

    # 1) strip markdown/code fences & normalize whitespace
    text = re.sub(r"```(?:python)?", "", raw_text)
    text = re.sub(r"```", "", text)
    text = text.replace("\r\n", "\n").replace("\r", "\n").expandtabs(4)

    # 2) line-level cleanup: if a line contains 'def test_', drop everything before the 'def'
    cleaned_lines = []
    for ln in text.split("\n"):
        i = ln.find("def ")
        j = ln.find(f"def {test_name_prefix}")
        if j != -1:
            ln = ln[j:]               # drop junk like ': ' or ', ' before def
        elif i != -1 and re.search(rf"def\s+{re.escape(test_name_prefix)}", ln[i:]):
            ln = ln[i:]
        cleaned_lines.append(ln)
    text = "\n".join(cleaned_lines)

    # 3) extract method blocks that start with 'def test_*(' and end before the next top-level 'def '
    methods = []
    lines = text.split("\n")
    n = len(lines)
    i = 0
    def_line_re = re.compile(rf'^\s*def\s+{re.escape(test_name_prefix)}[\w]*\s*\(')
    any_def_re  = re.compile(r'^\s*def\s+\w+\s*\(')

    while i < n:
        if def_line_re.match(lines[i]):
            # capture block
            start = i
            start_indent = len(lines[i]) - len(lines[i].lstrip())
            i += 1
            while i < n:
                # next method at same or smaller indent ends this block
                if any_def_re.match(lines[i]) and (len(lines[i]) - len(lines[i].lstrip())) <= start_indent:
                    break
                i += 1
            block = "\n".join(lines[start:i])
            # dedent and trim trailing whitespace to normalize indentation
            block = textwrap.dedent(block).strip("\n")
            # ensure a colon at end of signature line
            if not block.split("\n", 1)[0].rstrip().endswith(":"):
                head, *rest = block.split("\n")
                block = head.rstrip() + ":\n" + ("\n".join(rest) if rest else "    pass")
            methods.append(block)
        else:
            i += 1

    return ("\n\n".join(methods)).strip()

def indent_into_class(methods_block: str, class_indent: str, inner_indent: str) -> str:
    """
    Dedent the given test methods and re-indent so they sit inside a class
    at indentation: class_indent + inner_indent.
    Ensures trailing newline.
    """
    if not methods_block:
        return ""
    # Normalize + dedent (preserve relative indentation)
    ded = textwrap.dedent(methods_block.replace("\r\n", "\n").replace("\r", "\n")).strip("\n")
    lines = ded.split("\n")
    prefix = class_indent + inner_indent
    out_lines = []
    for ln in lines:
        if ln.strip():
            out_lines.append(prefix + ln)
        else:
            out_lines.append("")  # keep blank lines
    out = "\n".join(out_lines)
    if not out.endswith("\n"):
        out += "\n"
    return out

def append_methods_into_test_file(test_path: pathlib.Path, methods_block: str, *, force_spaces: bool = False) -> bool:
    """
    Append methods_block inside the first unittest.TestCase class.
    - If no class exists, scaffold one.
    - If force_spaces=True, class/inner indents are converted to spaces equivalents.

    Returns True if the file was modified (methods added), False if methods_block is empty.
    """
    methods_block = (methods_block or "").strip()
    if not methods_block:
        return False

    src = read_text(test_path)
    cls_name, start, end, class_indent, inner_indent = first_testcase_class_span(src)

    # If requested, convert indent tokens to pure spaces (fallback mode)
    if force_spaces:
        # Convert detected indents to width-based spaces
        class_spaces = " " * len(class_indent.expandtabs(4))
        inner_spaces = " " * max(4, len(inner_indent.expandtabs(4)) or 4)
        class_indent, inner_indent = class_spaces, inner_spaces

    if cls_name is None:
        # No TestCase class found: scaffold one and insert methods inside it
        body = indent_into_class(methods_block, class_indent="", inner_indent="    ")
        scaffold = (
            "\n\nimport unittest\n\n"
            "class GeneratedAugmentedTests(unittest.TestCase):\n"
            f"{body}"
        )
        test_path.write_text(src + scaffold, encoding="utf-8")
        return True

    # Insert inside the first class, right before its end
    lines = src.splitlines(keepends=True)
    insert_block = indent_into_class(methods_block, class_indent, inner_indent)

    # Ensure there's a blank line before insertion if needed
    if end > 0 and lines[end-1].strip():
        insert_block = "\n" + insert_block

    lines[end:end] = [insert_block]
    test_path.write_text("".join(lines), encoding="utf-8")
    return True

def try_load_tests(test_path: pathlib.Path, entry_point: str, func_obj):
    """
    Try importing the test file with function injection. Return (ok: bool, module_or_err: Any).
    """
    try:
        tmod = load_test_module(test_path, entry_point, func_obj)
        return True, tmod
    except Exception as e:
        return False, e
    


In [107]:
rows = []
for sol_file in sorted(GEN_SOL_DIR.glob("Solution_*.py"), key=lambda p: int(re.search(r"(\d+)", p.stem).group(1))):
    idx = int(re.search(r"(\d+)", sol_file.stem).group(1))
    test_file = TESTS_DIR / f"HumanEval_{idx}.py"
    if not test_file.exists():
        print(f"[SKIP] {idx:03d} missing test file: {test_file.name}")
        continue

    solution_code = read_text(sol_file)
    entry_point = infer_entry_point_from_solution(solution_code)
    if not entry_point:
        print(f"[WARN] {idx:03d} could not infer entry point in {sol_file.name}")
        rows.append({"id": idx, "entry_point": None, "coverage_before": 0.0, "coverage_after": 0.0, "added_methods": 0})
        continue

    # Load solution and baseline coverage
    sol_mod = load_solution_as_module(sol_file)
    if not hasattr(sol_mod, entry_point):
        print(f"[WARN] {idx:03d} entry point '{entry_point}' not found in solution.")
        rows.append({"id": idx, "entry_point": entry_point, "coverage_before": 0.0, "coverage_after": 0.0, "added_methods": 0})
        continue

    func = getattr(sol_mod, entry_point)
    test_mod = load_test_module(test_file, entry_point, func)
    suite = unittest.defaultTestLoader.loadTestsFromModule(test_mod)

    cov_before, _ = run_suite_under_coverage(sol_file, suite)
    pct_before, statements, missing, partial = coverage_metrics_for_file(cov_before, str(sol_file))

    if idx <= 1:
        print(pct_before, statements, missing, partial)
        continue
    # print(f"\n[{idx:03d}] entry={entry_point:25s} coverage_before={pct_before:5.1f}%  missing={len(missing)}  partial={len(partial)}")

    # Build prompt and print it (as requested)
    prompt_text = make_augmentation_prompt(solution_code, missing, partial)
    # print("\n--- Augmentation prompt ---\n")
    # print(prompt_text)
    # print("\n--- End prompt ---\n")

    # If nothing missing/partial, skip augmentation
    if not missing and not partial:
        rows.append({
            "id": idx, "entry_point": entry_point,
            "coverage_before": pct_before, "coverage_after": pct_before,
            "added_methods": 0
        })
        continue

    # Ask Gemini for additional test methods
    resp = gmodel.generate_content(prompt_text)
    add_methods_raw = (getattr(resp, "text", None) or "").strip()
    add_methods_clean = clean_llm_methods(add_methods_raw)
    # print("clean method:\n", add_methods_clean or "[no methods generated]")

    # Keep a backup
    original_test_src = test_file.read_text(encoding="utf-8")

    added = 0
    if add_methods_clean:
        modified = append_methods_into_test_file(test_file, add_methods_clean)
        if modified:
            # Validate by importing tests; if broken, revert
            ok, mod_or_err = try_load_tests(test_file, entry_point, func)
            if not ok:
                print(f"  [REVERT] Added methods caused import error: {mod_or_err}")
                test_file.write_text(original_test_src, encoding="utf-8")
            else:
                added = len(re.findall(r'^\s*def\s+test_', add_methods_clean, flags=re.M))

    # Re-run coverage after augmentation (or after revert)
    sol_mod = load_solution_as_module(sol_file)  # ensure fresh module
    func = getattr(sol_mod, entry_point, func)
    test_mod = load_test_module(test_file, entry_point, func)
    suite = unittest.defaultTestLoader.loadTestsFromModule(test_mod)
    cov_after, result_after = run_suite_under_coverage(sol_file, suite)
    pctNew, statementsNew, missingNew, partialNew = coverage_metrics_for_file(cov_after, str(sol_file))

    # print(f"  coverage_after={pctNew:5.1f}%  (+{pctNew - pct_before:+.1f} pp)  added_methods={added}")

    rows.append({
        "id": idx,
        "entry_point": entry_point,
        "statements:": len(statementsNew),
        "missing": len(missingNew),
        "partial": len(partialNew),
        "coverage_current": pctNew,
        "coverage_before": pct_before,
        "added_methods": added,
    })
    # break  # DEBUG

df_aug_cov2 = pd.DataFrame(rows).sort_values("id").reset_index(drop=True)
df_aug_cov2

80.0 [1, 3, 31, 34, 35, 40, 43, 44, 45, 48] [1, 3] 1-3
93.33 [1, 28, 30, 31, 32, 35, 37, 40, 41, 42, 43, 47, 49, 51, 61] [1] 1
  [REVERT] Added methods caused import error: invalid syntax (HumanEval_2.py, line 74)
count_up_to(5) => [2, 3]
count_up_to(11) => [2, 3, 5, 7]
count_up_to(0) => []
count_up_to(20) => [2, 3, 5, 7, 11, 13, 17, 19]
count_up_to(1) => []
count_up_to(18) => [2, 3, 5, 7, 11, 13, 17]
count_up_to(2) => []
count_up_to(3) => [2]
count_up_to(5) => [2, 3]
count_up_to(11) => [2, 3, 5, 7]
count_up_to(0) => []
count_up_to(20) => [2, 3, 5, 7, 11, 13, 17, 19]
count_up_to(1) => []
count_up_to(18) => [2, 3, 5, 7, 11, 13, 17]
count_up_to(2) => []
count_up_to(3) => [2]
  [REVERT] Added methods caused import error: invalid syntax (HumanEval_103.py, line 60)
s='abcde', c='ae' -> ('bcd', False)
s='abcdef', c='b' -> ('acdef', False)
s='abcdedcba', c='ab' -> ('cdedc', True)
s='racecar', c='c' -> ('raear', True)
s='madam', c='' -> ('madam', True)
s='hello', c='l' -> ('heo', False)
s='a',

Unnamed: 0,id,entry_point,statements:,missing,partial,coverage_current,coverage_before,added_methods,coverage_after
0,2,truncate_number,5.0,2.0,3.0,60.00,60.00,0,
1,3,below_zero,7.0,1.0,1.0,85.71,85.71,5,
2,4,mean_absolute_deviation,9.0,2.0,3.0,77.78,77.78,3,
3,5,intersperse,7.0,1.0,1.0,85.71,85.71,7,
4,6,parse_nested_parens,14.0,1.0,1.0,92.86,92.86,4,
...,...,...,...,...,...,...,...,...,...
157,159,eat,5.0,1.0,1.0,80.00,80.00,4,
158,160,do_algebra,8.0,1.0,1.0,87.50,87.50,7,
159,161,solve,16.0,1.0,1.0,93.75,93.75,15,
160,162,string_to_md5,7.0,3.0,3.0,57.14,57.14,6,
