In [3]:
from huggingface_hub import login
login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [13]:
import os
import time
import tempfile
import subprocess
import textwrap
import re
from typing import Tuple, Dict, Optional

import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    pipeline,
    AutoModelForSeq2SeqLM,
)

In [14]:
# --------------------------
# 0. USER CONFIG
# --------------------------
GENERATOR_MODEL = "Qwen/Qwen2.5-Coder-1.5B-Instruct"
CRITIC_MODEL = "google/flan-t5-small"
GENERATOR_LOCAL_PATH = None
CRITIC_LOCAL_PATH = None
PYTHON_RUN_TIMEOUT = 6
MAX_ITERS = 4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DEBUG_HARMONIZE = True

# --------------------------
# 1. Load models
# --------------------------
print("Loading generator model:", GENERATOR_MODEL)
if GENERATOR_LOCAL_PATH:
    gen_tokenizer = AutoTokenizer.from_pretrained(GENERATOR_LOCAL_PATH, local_files_only=True)
    gen_model = AutoModelForCausalLM.from_pretrained(GENERATOR_LOCAL_PATH, local_files_only=True)
else:
    gen_tokenizer = AutoTokenizer.from_pretrained(GENERATOR_MODEL)
    gen_model = AutoModelForCausalLM.from_pretrained(GENERATOR_MODEL, device_map="auto")

gen_pipe = pipeline("text-generation", model=gen_model, tokenizer=gen_tokenizer)

print("Loading critic model:", CRITIC_MODEL)
if CRITIC_LOCAL_PATH:
    critic_tokenizer = AutoTokenizer.from_pretrained(CRITIC_LOCAL_PATH, local_files_only=True)
    critic_model = AutoModelForSeq2SeqLM.from_pretrained(CRITIC_LOCAL_PATH, local_files_only=True).to(DEVICE)
else:
    critic_tokenizer = AutoTokenizer.from_pretrained(CRITIC_MODEL)
    critic_model = AutoModelForSeq2SeqLM.from_pretrained(CRITIC_MODEL).to(DEVICE)

critic_pipe = pipeline("text2text-generation", model=critic_model, tokenizer=critic_tokenizer)

print("Environment setup complete and models loaded.")

Loading generator model: Qwen/Qwen2.5-Coder-1.5B-Instruct


Device set to use cuda:0


Loading critic model: google/flan-t5-small


Device set to use cuda:0


Environment setup complete and models loaded.


In [15]:
# --------------------------
# 2.1 Utility: Harmonize function names
# --------------------------
def harmonize_function_name(candidate_code: str, tests_code: str) -> str:
    """
    If the tests call a function with a specific name but the candidate
    defines a different top-level function name, rename the candidate's
    function to the expected name. This addresses the common MBPP case
    where a single top-level function is expected.
    """
    def_match = re.search(r'^\s*def\s+([A-Za-z_]\w*)\s*\(', candidate_code, flags=re.MULTILINE)
    current_name = def_match.group(1) if def_match else None

    builtins_ignore = set([
        "assert", "print", "len", "range", "int", "str", "float", "list", "dict", "set", "tuple",
        "max", "min", "sum", "any", "all", "zip", "map", "filter", "open", "enumerate", "sorted",
        "reversed", "isinstance", "type", "input", "True", "False", "None", "pytest", "unittest"
    ])

    call_candidates = re.findall(r'\b([A-Za-z_]\w*)\s*\(', tests_code)
    desired_name = None
    for cname in call_candidates:
        if cname not in builtins_ignore:
            desired_name = cname
            break

    if not current_name or not desired_name or current_name == desired_name:
        return candidate_code

    candidate_code_modified = re.sub(
        r'(^\s*def\s+)' + re.escape(current_name) + r'(\s*\([^)]*\)\s*:)',
        r'\1' + desired_name + r'\2',
        candidate_code,
        flags=re.MULTILINE
    )

    candidate_code_modified = re.sub(
        r'(?<!def\s)' + r'\b' + re.escape(current_name) + r'\b',
        desired_name,
        candidate_code_modified
    )

    if DEBUG_HARMONIZE:
        print(f"[harmonize] Renamed function '{current_name}' -> '{desired_name}'")

    return candidate_code_modified

In [16]:
# ------------------------------------------
# 2.2 Function: Run code + tests safely (adapted for MBPP test_list)
# ------------------------------------------
def run_python_tests(candidate_code: str, test_imports: str, test_list: list[str], timeout: int = PYTHON_RUN_TIMEOUT) -> dict:
    """Execute candidate code + tests in isolated subprocess, with function name harmonization."""
    tests = "\n".join(test_list)
    full_test_code = f"{test_imports}\n\n{tests}"

    # Harmonize function name before writing to file
    harmonized_code = harmonize_function_name(candidate_code, full_test_code)

    full_code_to_run = f"{harmonized_code}\n\n{full_test_code}"

    with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
        fname = f.name
        f.write(full_code_to_run)
        f.flush()

    try:
        result = subprocess.run(
            ["python3", fname],
            capture_output=True,
            text=True,
            timeout=timeout
        )
        stdout = result.stdout.strip()
        stderr = result.stderr.strip()
        rc = result.returncode

        return {
            "returncode": rc,
            "stdout": stdout,
            "stderr": stderr,
            "passed": (rc == 0 and not stderr)
        }
    except subprocess.TimeoutExpired:
        return {"returncode": 1, "stdout": "", "stderr": f"TIMEOUT after {timeout}s", "passed": False}
    finally:
        try:
            os.unlink(fname)
        except Exception:
            pass

In [17]:
# --------------------------
# 3. Heuristic fallback critic (quick) — parse stderr and create reflection
# --------------------------
def heuristic_reflection_from_error(code: str, stderr: str, stdout: str) -> str:
    """Simple parsing of error traces to produce short reflection text."""
    if not stderr:
        msg = stdout.strip().splitlines()[-1] if stdout.strip() else "Test failed with no stderr."
        return f"Test failure: {msg}"
    if "AssertionError" in stderr:
        m = re.search(r'AssertionError(?:\: (.*))?', stderr, re.S)
        detail = m.group(1).strip() if m and m.group(1) else "Assertion failed"
        return f"Assertion failed: {detail}"
    if "TypeError" in stderr:
        return "TypeError encountered. Check argument types or function signature."
    if "IndexError" in stderr:
        return "IndexError encountered (likely out-of-range index).";
    if "NameError" in stderr:
        return "NameError: undefined variable or function name used. (Consider function name harmonization if problem persists)"
    if "SyntaxError" in stderr:
        m = re.search(r'SyntaxError: (.*)', stderr)
        return f"Syntax error: {m.group(1) if m else 'Syntax error detected.'}"
    last_line = stderr.strip().splitlines()[-1]
    return f"Runtime error: {last_line}"

In [18]:
# --------------------------
# 4. Critic function (uses seq2seq critic model)
# --------------------------
def critic_reflection(code: str, test_code: str, stderr: str, stdout: str, use_heuristic_if_needed=True) -> str:
    """
    Generate a short natural-language reflection explaining the failure.
    By default uses the critic seq2seq model (flan-t5-small). If model fails to produce useful content,
    fall back to a heuristic parser.
    """
    prompt = f"""You are a python debugging assistant. Read the code and the failing test output and explain in 1-3 short sentences what likely went wrong and what to fix.
Code:
```python
{code}
```

Test:
```python
{test_code}
```

Error/Traceback:
```
{stderr}
{stdout}
```

Explain briefly (1-3 sentences):"""
    try:
        out = critic_pipe(prompt, max_new_tokens=128, do_sample=False)[0]["generated_text"]
        out = out.strip()
        if use_heuristic_if_needed and (len(out) < 10 or "error" not in out.lower() and "fix" not in out.lower()):
            return heuristic_reflection_from_error(code, stderr, stdout)
        return out
    except Exception as e:
        return heuristic_reflection_from_error(code, stderr, stdout)

In [19]:
# ------------------------------------------
# 5. Function: Generate code from model (adapted for MBPP prompt format)
# ------------------------------------------
def generate_candidate_code(task_prompt, reflection=None, max_new_tokens=256):
    """Generate code given the task and optional reflection feedback for MBPP."""
    gen_input = f"# Task:\n{task_prompt.strip()}\n"
    if reflection:
        gen_input += f"# Reflection hint:\n{reflection.strip()}\n"
    gen_input += "\n# Write only Python code below:\n"

    outputs = gen_pipe(
        gen_input,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=0.6,
        top_p=0.9,
        repetition_penalty=1.05,
        pad_token_id=gen_tokenizer.eos_token_id,
    )[0]["generated_text"]

    raw_generated = outputs.split("# Write only Python code below:")[-1].strip()

    code_block_match = re.search(r'```python\n(.*?)\n```', raw_generated, re.DOTALL)
    if code_block_match:
        generated = code_block_match.group(1).strip()
    else:
        generated = raw_generated.replace("```python", "").replace("```", "").strip()

        filtered_lines = []
        in_function_body = False
        for line in generated.splitlines():
            stripped_line = line.strip()
            if stripped_line.startswith('def ') or stripped_line.startswith('class '):
                filtered_lines.append(line)
                in_function_body = True
            elif in_function_body and (line.startswith(' ') or line.startswith('\t') or not stripped_line or stripped_line.startswith('#')):
                filtered_lines.append(line)
            elif not in_function_body and (not stripped_line or stripped_line.startswith('#')):
                filtered_lines.append(line)
            elif stripped_line:
                break
        generated = '\n'.join(filtered_lines).strip()

    return generated

In [20]:
# ------------------------------------------
# 6. Function: Clean reflection text
# ------------------------------------------
def clean_reflection_text(reflection: str) -> str:
    """Remove code/test junk from critic reflection."""
    reflection = re.sub(r"```.*?```", "", reflection, flags=re.DOTALL)
    reflection = re.sub(r"def\s+\w+\(.*?\):.*", "", reflection)
    reflection = re.sub(r"assert .*", "", reflection)
    reflection = reflection.strip().replace('\n', ' ')
    return reflection[:250]

# ------------------------------------------
# 7. Function: Run Reflexion loop (MBPP style)
# ------------------------------------------
def reflexion_loop(task_prompt, tests, test_imports, max_iters=3):
    """Iteratively refine code with critic feedback."""
    reflection = ""
    reflections = []
    for i in range(max_iters):
        print(f"\n===============================")
        print(f"ITERATION {i+1} — reflection: '{reflection}'")

        code = generate_candidate_code(task_prompt, reflection)
        result = run_python_tests(code, test_imports, tests)

        if result["passed"]:
            print("✅ Tests passed!")
            return code, reflections

        print(f"❌ Tests failed. stderr:\n{textwrap.indent(result['stderr'], '   ')}")

        # Prepare critic input with raw test code from MBPP ex['test_list']
        full_test_code_for_critic = f"{test_imports}\n\n{'\n'.join(tests)}"
        reflection_raw = critic_reflection(code, full_test_code_for_critic, stderr=result['stderr'], stdout=result['stdout'])
        reflection = clean_reflection_text(reflection_raw)
        print(f"\n--- Critic reflection: {reflection}")
        reflections.append(reflection)

    return code, reflections

print("Helper functions defined.")

Helper functions defined.


In [21]:
import pandas as pd
from datasets import load_dataset

# MBPP-specific config
MBPP_TEST_SLICE = "test[:25]" # Evaluate on first 25 tasks
MBPP_MAX_REFLEXION_ITERS = 3 # 3 iterations per task

# Load dataset subset
dataset = load_dataset("mbpp", "sanitized", split=MBPP_TEST_SLICE)
print(f"✅ Loaded {len(dataset)} MBPP tasks for evaluation.")

# ------------------------------------------
# Run benchmark
# ------------------------------------------
results = []
for i, ex in enumerate(dataset):
    print(f"\n============================================================")
    print(f"Task {i+1}/{len(dataset)}: {ex['prompt']}")
    print("============================================================")

    # The reflexion_loop function is already defined in the previous step
    code, reflections = reflexion_loop(
        task_prompt=ex["prompt"],
        tests=ex["test_list"],
        test_imports=ex["test_imports"],
        max_iters=MBPP_MAX_REFLEXION_ITERS # Use MBPP specific max_iters
    )

    # Run final test on the best candidate code (or last generated if none passed)
    test_result = run_python_tests(code, ex["test_imports"], ex["test_list"])
    results.append({
        "task_id": ex["task_id"],
        "prompt": ex["prompt"],
        "passed": test_result["passed"],
        "iterations": len(reflections),
        "reflections": reflections,
        "stderr": test_result["stderr"]
    })

# ------------------------------------------
# Analyze results
# ------------------------------------------
df_results = pd.DataFrame(results)
pass_rate = df_results["passed"].mean() * 100
print(f"\n✅ Reflexion Pass@1: {pass_rate:.2f}% ({df_results['passed'].sum()}/{len(df_results)})")

display(df_results[["task_id", "passed", "iterations", "prompt"]])


README.md: 0.00B [00:00, ?B/s]

sanitized/train-00000-of-00001.parquet:   0%|          | 0.00/33.9k [00:00<?, ?B/s]

sanitized/test-00000-of-00001.parquet:   0%|          | 0.00/60.9k [00:00<?, ?B/s]

sanitized/validation-00000-of-00001.parq(…):   0%|          | 0.00/14.0k [00:00<?, ?B/s]

sanitized/prompt-00000-of-00001.parquet:   0%|          | 0.00/6.72k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/120 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/257 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/43 [00:00<?, ? examples/s]

Generating prompt split:   0%|          | 0/7 [00:00<?, ? examples/s]

✅ Loaded 25 MBPP tasks for evaluation.

Task 1/25: Write a python function to remove first and last occurrence of a given character from the string.

ITERATION 1 — reflection: ''
[harmonize] Renamed function 'remove_char' -> 'remove_Occ'
✅ Tests passed!
[harmonize] Renamed function 'remove_char' -> 'remove_Occ'

Task 2/25: Write a function to sort a given matrix in ascending order according to the sum of its rows.

ITERATION 1 — reflection: ''
✅ Tests passed!

Task 3/25: Write a python function to find the volume of a triangular prism.

ITERATION 1 — reflection: ''
[harmonize] Renamed function 'calculate_prism_volume' -> 'find_Volume'
✅ Tests passed!
[harmonize] Renamed function 'calculate_prism_volume' -> 'find_Volume'

Task 4/25: Write a function to that returns true if the input string contains sequences of lowercase letters joined with an underscore and false otherwise.

ITERATION 1 — reflection: ''
[harmonize] Renamed function 'check_lowercase_underscore_sequence' -> 'text_lowerca

You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset


[harmonize] Renamed function 'is_one_less_than_twice_reverse' -> 'check'
✅ Tests passed!
[harmonize] Renamed function 'is_one_less_than_twice_reverse' -> 'check'

Task 10/25: Write a python function to find the largest number that can be formed with the given list of digits.

ITERATION 1 — reflection: ''
[harmonize] Renamed function 'largest_number' -> 'find_Max_Num'
❌ Tests failed. stderr:
   Traceback (most recent call last):
     File "/tmp/tmpaxow_apw.py", line 12, in <module>
       assert find_Max_Num([1,2,3]) == 321
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
   AssertionError

--- Critic reflection: Assertion failed: Assertion failed

ITERATION 2 — reflection: 'Assertion failed: Assertion failed'
[harmonize] Renamed function 'largest_number' -> 'find_Max_Num'
❌ Tests failed. stderr:
   Traceback (most recent call last):
     File "/tmp/tmpzn0bs3kn.py", line 18, in <module>
       assert find_Max_Num([1,2,3]) == 321
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
   AssertionError

--

Token indices sequence length is longer than the specified maximum sequence length for this model (658 > 512). Running this sequence through the model will result in indexing errors


❌ Tests failed. stderr:
   File "/tmp/tmpy8muzp67.py", line 2
       The `sort_tuples` function takes a list of tuples as input. It uses the built-in `sorted` function with a lambda function as the key argument. The lambda function extracts the second element (index 1) from each tuple, which is used for sorting. The `sorted` function returns a new list of tuples sorted based on the second value of each tuple. This approach is efficient and concise for sorting lists of tuples. ```python
           ^
   SyntaxError: invalid syntax

--- Critic reflection: Syntax error: invalid syntax

ITERATION 2 — reflection: 'Syntax error: invalid syntax'
[harmonize] Renamed function 'sort_tuples' -> 'subject_marks'
❌ Tests failed. stderr:
   Traceback (most recent call last):
     File "/tmp/tmpx4mmoa8x.py", line 8, in <module>
       assert subject_marks([('English', 88), ('Science', 90), ('Maths', 97), ('Social sciences', 82)])==[('Social sciences', 82), ('English', 88), ('Science', 90), ('Maths', 97

Unnamed: 0,task_id,passed,iterations,prompt
0,11,True,0,Write a python function to remove first and la...
1,12,True,0,Write a function to sort a given matrix in asc...
2,14,True,0,Write a python function to find the volume of ...
3,16,True,0,Write a function to that returns true if the i...
4,17,True,0,Write a function that returns the perimeter of...
5,18,True,0,Write a function to remove characters from the...
6,19,True,0,Write a function to find whether a given array...
7,20,True,1,Write a function to check if the given number ...
8,56,True,0,Write a python function to check if a given nu...
9,57,False,3,Write a python function to find the largest nu...
