In [1]:
from dotenv import load_dotenv
# Load environment variables from the .env file one directory above the
# current working directory (presumably credentials the agent needs -- confirm).
load_dotenv("../.env")

True

In [2]:
from pprint import pprint
from datasets import load_dataset
from tqdm.notebook import tqdm

from graph import create_debug_agent_graph

In [3]:
# Fetch the Python configuration of the HumanEvalPack benchmark.
dataset = load_dataset("bigcode/humanevalpack", "python")

# Materialise the "test" split and keep only its first 50 problems so that
# experiments run quickly.
problems = list(dataset["test"])[:50]

print("problems len: " , len(problems))
print(problems[0].keys())

problems len:  50
dict_keys(['task_id', 'prompt', 'declaration', 'canonical_solution', 'buggy_solution', 'bug_type', 'failure_symptoms', 'entry_point', 'import', 'test_setup', 'test', 'example_test', 'signature', 'docstring', 'instruction'])


In [4]:
# Show the id, the buggy solution and the tests of the first problem so the
# input format can be checked by eye.
experiment_problem = problems[0]
for label, field in (
    ("Problem ID:", "task_id"),
    ("Problem buggy_solution: ", "buggy_solution"),
    ("Problem tests: ", "test"),
):
    print(label, experiment_problem[field])

Problem ID: Python/0
Problem buggy_solution:      for idx, elem in enumerate(numbers):
        for idx2, elem2 in enumerate(numbers):
            if idx != idx2:
                distance = elem - elem2
                if distance < threshold:
                    return True

    return False

Problem tests:  




def check(has_close_elements):
    assert has_close_elements([1.0, 2.0, 3.9, 4.0, 5.0, 2.2], 0.3) == True
    assert has_close_elements([1.0, 2.0, 3.9, 4.0, 5.0, 2.2], 0.05) == False
    assert has_close_elements([1.0, 2.0, 5.9, 4.0, 5.0], 0.95) == True
    assert has_close_elements([1.0, 2.0, 5.9, 4.0, 5.0], 0.8) == False
    assert has_close_elements([1.0, 2.0, 3.0, 4.0, 5.0, 2.0], 0.1) == True
    assert has_close_elements([1.1, 2.2, 3.1, 4.1, 5.1], 1.0) == True
    assert has_close_elements([1.1, 2.2, 3.1, 4.1, 5.1], 0.5) == False

check(has_close_elements)


In [5]:
# Build the debug agent graph once; fix_code invokes this module-level object
# for every problem.
graph = create_debug_agent_graph()

In [6]:
from agent.state import DebugAgentState
from langchain_core.messages import HumanMessage

from typing import Any, Dict

def fix_code(buggy_code, test_code = "", max_iterations = 7) -> Dict[str, Any]:
    """
    Run the debug agent graph on one buggy snippet and report the outcome.

    A single human message embedding ``buggy_code`` and ``test_code`` seeds the
    agent state, which is then handed to the module-level ``graph``.

    Args:
        buggy_code: Source text of the code to repair.
        test_code: Test source included in the prompt and stored in the state.
        max_iterations: Value stored under ``"max_iterations"`` in the state
            (presumably the agent's iteration cap -- confirm against the graph).

    Returns:
        A dict holding the ``fixed_code``, ``is_fixed``, ``iterations``,
        ``messages``, ``submissions`` and ``first_pass`` entries copied from
        the final graph state.

    Raises:
        KeyError: If the final state lacks any of the keys listed above.
    """
    prompt = f"Fix the following Python code:\n{buggy_code} and for testing use this code: \n{test_code}"
    seed_message = HumanMessage(content=prompt)

    # Fresh state: nothing submitted yet, no fix recorded.
    starting_state: DebugAgentState = {
        "messages": [seed_message],
        "iterations": 0,
        "max_iterations": max_iterations,
        "original_buggy_code": buggy_code,
        "test_code": test_code,
        "fixed_code": "",
        "is_fixed": False,
        "submit_idx": -1,
        "submissions": [],
        "first_pass": None
    }

    finished_state = graph.invoke(starting_state)

    # Copy out a fixed subset of the final state, in this order.
    reported_keys = (
        "fixed_code",
        "is_fixed",
        "iterations",
        "messages",
        "submissions",
        "first_pass",
    )
    return {key: finished_state[key] for key in reported_keys}


In [7]:
# Run the debug agent on every selected problem, keeping outcomes in dataset
# order. Appending inside the loop means already-collected results remain in
# `results` if a later problem raises.
results = []

for problem in tqdm(problems):
    # Missing fields fall back to an empty string.
    buggy_solution = problem.get("buggy_solution", "")
    test_code = problem.get("test", "")

    outcome = fix_code(buggy_solution, test_code=test_code)
    results.append(outcome)

  0%|          | 0/50 [00:00<?, ?it/s]

In [8]:
# Inspect the complete result record of the first problem as a sanity check.
pprint(results[0])

{'first_pass': True,
 'fixed_code': 'def has_close_elements(numbers, threshold):\n'
               '    for idx, elem in enumerate(numbers):\n'
               '        for idx2, elem2 in enumerate(numbers):\n'
               '            if idx != idx2:\n'
               '                distance = abs(elem - elem2)\n'
               '                if distance < threshold:\n'
               '                    return True\n'
               '\n'
               '    return False',
 'is_fixed': True,
 'iterations': 2,
 'messages': [HumanMessage(content='Fix the following Python code:\n    for idx, elem in enumerate(numbers):\n        for idx2, elem2 in enumerate(numbers):\n            if idx != idx2:\n                distance = elem - elem2\n                if distance < threshold:\n                    return True\n\n    return False\n and for testing use this code: \n\n\n\n\n\ndef check(has_close_elements):\n    assert has_close_elements([1.0, 2.0, 3.9, 4.0, 5.0, 2.2], 0.3) == True\n 

In [9]:
from metrics.pass_at_k import estimate_pass_at_1, estimate_first_submission_accuracy

# Compute summary metrics over `results`. Each entry is the dict returned by
# fix_code, with the following keys:
# {
#     "fixed_code": "str",
#     "is_fixed": "bool",
#     "iterations": "int",
#     "messages": "List[Message]",
#     "submissions": "List[Submission]",
#     "first_pass": "bool"  (True when the first submission passed the tests)
# }
# NOTE(review): `passed_at_k` holds the Pass@1 estimate despite its name.
passed_at_k = estimate_pass_at_1(results)
first_pass_acc = estimate_first_submission_accuracy(results)
print(f"Pass@1: {passed_at_k:.2f}")
print(f"First Pass Accuracy: {first_pass_acc:.2f}")

Pass@1: 0.42
First Pass Accuracy: 0.38
