In [13]:
# User provides: 
    ## a Python function (possibly buggy), 
    ## a short description of intended behavior, 
    ## and optionally an example unit test to use as a style/reference.

# Planner: LLM (CodeLlama served through Ollama) generates unit tests for the function (using the example test if provided).

# Executor: Run the tests locally with pytest and capture results (pass/fail, tracebacks).

# Reflector: LLM reads test failures and suggests a corrected function.

# Loop: Repeat (generate new tests / run / fix) until tests pass or max iterations reached.

# Return: the final (hopefully fixed) function and the output of the last test run.

In [14]:
from langgraph.graph import StateGraph, END
from langchain_community.llms import Ollama
from typing import TypedDict
import subprocess, tempfile, os
import textwrap
import re

# Ollama-served model shared by the planner and reflector nodes.
OLLAMA_MODEL = "codellama:latest"
llm = Ollama(model=OLLAMA_MODEL)

In [15]:

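# NOTE(review): unused alternative to the langchain `Ollama` wrapper above; re-enabling it also needs `import requests`.
# def ask_ollama(prompt, model="codellama:latest"):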
#     """
#     Sends a prompt to Ollama and returns the model's text response.
#     Works with any installed Ollama model (e.g., codellama, mistral, llama3, etc.)
#     """
#     url = "http://localhost:11434/api/generate"
#     payload = {"model": model, "prompt": prompt, "stream": False}
#     response = requests.post(url, json=payload)

#     if response.status_code != 200:
#         raise Exception(f"Ollama error: {response.text}")
    
#     data = response.json()
#     return data.get("response", "").strip()


In [16]:
# Shared state handed from node to node in the planner -> executor -> reflector loop.
# Functional TypedDict syntax: equivalent to the class form, every key required.
AgentState = TypedDict(
    "AgentState",
    {
        "func_code": str,     # source of the function under test (rewritten by the reflector)
        "description": str,   # natural-language statement of the intended behaviour
        "example_test": str,  # user-supplied reference pytest module
        "tests": str,         # full pytest module run in the current iteration
        "results": str,       # captured pytest stdout + stderr
        "iteration": int,     # number of reflector passes so far
    },
)


def run_tests_and_capture(test_code, func_code):
    """Run pytest and capture results, safely dedented."""
    import textwrap
    import re

    # Remove Markdown-style ``` and any language hint
    test_code = re.sub(r"^```.*\n|```$", "", test_code.strip(), flags=re.MULTILINE)
    
    func_code = textwrap.dedent(func_code)
    test_code = textwrap.dedent(test_code)

    with tempfile.TemporaryDirectory() as tmpdir:
        func_path = os.path.join(tmpdir, "target.py")
        test_path = os.path.join(tmpdir, "test_target.py")

        with open(func_path, "w") as f:
            f.write(func_code)
        with open(test_path, "w") as f:
            f.write(test_code)

        result = subprocess.run(
            ["pytest", "-q", test_path],
            cwd=tmpdir,
            text=True,
            capture_output=True,
            timeout=10
        )

        return result.stdout + "\n" + result.stderr

In [17]:
def planner_node(state: AgentState):
    """Generate extra pytest tests with the LLM and store them in ``state["tests"]``.

    Builds a prompt from ``state["example_test"]``, ``state["description"]`` and
    ``state["func_code"]``. Leading prose and Markdown fences are stripped from
    the reply. ``state["tests"]`` is set to the example tests followed by the
    generated ones, and the same (mutated) state dict is returned.
    """
    # Generated tests are appended to the example module, so a reused test name
    # would silently replace the example's test; the prompt asks for new names.
    prompt = f"""
    You are a helpful Python unit test generator.
    Generate **additional pytest unit tests** for the function.
    The function is defined in a module named `target`; import it with `from target import <function name>`.
    Use new test function names; do not repeat the example tests.
    Take into account the Description to understand the intent of the function.
    Follow the style, structure, and tone of this example test.
    
    --- Example Test ---
    {state['example_test']}
    ---------------------
    
    Description:
    {state['description']}
    
    Function:
    {state['func_code']}
    
    Return **ONLY** valid pytest code, **NO EXPLANATION, NO MARKDOWN, NO TEXT EXPLANATION**
    """
    raw_generated = llm.invoke(prompt)

    # Skip leading prose: start at the first line that looks like Python code.
    # Import lines count as code, so modules the generated tests need are kept.
    code_line = re.compile(r"(import \w|from [\w.]+ import |@|def |class )")
    lines = raw_generated.splitlines()
    start_idx = next(
        (i for i, line in enumerate(lines) if code_line.match(line.strip())), 0
    )
    # Remove Markdown fences (```python / ```) wherever they appear.
    cleaned_generated = re.sub(
        r"```.*$", "", "\n".join(lines[start_idx:]).strip(), flags=re.MULTILINE
    )

    # Combine example tests + cleaned generated tests
    state["tests"] = state["example_test"] + "\n\n" + cleaned_generated

    print(f"\n--- Iteration {state['iteration']} Planner Generated Tests ---\n")
    print(cleaned_generated)
    return state


In [18]:
def executor_node(state: AgentState):
    """Run ``state["tests"]`` against ``state["func_code"]``, record and print the pytest output."""
    state["results"] = run_tests_and_capture(state["tests"], state["func_code"])
    banner = f"\n--- Iteration {state['iteration']} Test Results ---\n"
    print(banner)
    print(state["results"])
    return state

In [19]:
def reflector_node(state: AgentState):
    """Ask the LLM to repair ``state["func_code"]`` using the latest pytest output.

    The prompt includes the intended-behaviour description, the current
    function and the pytest output. The reply is reduced to its first function,
    plus any import lines that precede it. If the reply contains no ``def``, the
    current function is kept. Increments ``state["iteration"]`` and returns the
    same (mutated) state dict.
    """
    prompt = f"""
    The following function failed some tests.
    Please correct it based on the description of the intended behavior and the pytest output below.
    Preserve the function name and signature.
    If the corrected function needs extra modules, put the import lines directly above it.
    
    Description:
    {state['description']}
    
    Function:
    {state['func_code']}
    
    Pytest Results:
    {state['results']}
    
    Return **ONLY** corrected Python code, **NO EXPLANATION, NO MARKDOWN, NO TEXT EXPLANATION**.
    """
    raw_code = llm.invoke(prompt)

    cleaned_code = _extract_first_function(raw_code)
    # An answer with no function at all would replace working code with prose.
    if cleaned_code:
        state["func_code"] = cleaned_code
    state["iteration"] += 1

    print(f"\n--- Iteration {state['iteration']} Corrected Function ---\n")
    print(state["func_code"])
    return state


def _extract_first_function(raw_code):
    """Return the first function in an LLM reply, preceded by any import lines before it ("" if there is no def)."""
    # Blank out Markdown fences so they are never mistaken for code.
    lines = re.sub(r"```.*$", "", raw_code, flags=re.MULTILINE).splitlines()

    start = next(
        (i for i, line in enumerate(lines) if line.lstrip().startswith("def ")), None
    )
    if start is None:
        return ""

    # The function ends at the first later code line indented no deeper than its
    # `def` (e.g. trailing prose, example calls, or another function). Blank
    # lines, comments and the closing bracket of a wrapped signature do not end
    # it; nested defs are indented deeper, so they stay inside the body.
    def_indent = len(lines[start]) - len(lines[start].lstrip())
    end = len(lines)
    for i in range(start + 1, len(lines)):
        stripped = lines[i].strip()
        if not stripped or stripped.startswith("#") or stripped[0] in ")]}":
            continue
        if len(lines[i]) - len(lines[i].lstrip()) <= def_indent:
            end = i
            break

    body = textwrap.dedent("\n".join(lines[start:end])).rstrip()

    # Keep imports the model placed before the function (the prompt asks for them).
    import_line = re.compile(r"(import \w|from [\w.]+ import )")
    imports = [line.strip() for line in lines[:start] if import_line.match(line.strip())]
    return "\n".join(imports + ["", body]) if imports else body


In [20]:
def should_continue(state, max_iterations=10):
    """Decide whether another repair round is needed.

    Returns "planner" while the last pytest output shows a problem and
    ``state["iteration"]`` is below ``max_iterations``; otherwise returns END.
    A problem is either a failing test ("failed" in the output) or an "N error(s)"
    summary. pytest reports the latter when collection or import fails, e.g. a
    SyntaxError in target.py, and no test gets to "fail".

    Args:
        state: Graph state; reads ``results`` and ``iteration`` (both optional).
        max_iterations: Upper bound on repair rounds (default 10).
    """
    results = (state.get("results") or "").lower()
    # No leading \b: with colour enabled the count is glued to an ANSI code ("...m1 error").
    broken = "failed" in results or re.search(r"\d+ errors?\b", results) is not None
    if broken and state.get("iteration", 0) < max_iterations:
        return "planner"
    return END

In [21]:
# Wiring: planner -> executor, then branch on the test results:
#   tests fail and iterations remain -> reflector -> planner (new tests) -> executor
#   tests pass, or budget exhausted  -> END
# Branching after the executor (not after the reflector) means the function in
# the final state is always the one the last test run exercised, and the
# reflector is only asked to fix a function that actually failed.
graph = StateGraph(AgentState)
graph.add_node("planner", planner_node)
graph.add_node("executor", executor_node)
graph.add_node("reflector", reflector_node)

graph.set_entry_point("planner")
graph.add_edge("planner", "executor")
# should_continue answers "planner" when another repair round is needed; a
# round starts by repairing the function, so that answer is mapped to the reflector.
graph.add_conditional_edges("executor", should_continue, {"planner": "reflector", END: END})
graph.add_edge("reflector", "planner")
compiled_graph = graph.compile()


In [22]:
# Demo input: a deliberately buggy is_prime. It returns True for n <= 1, even
# though 0, 1 and negative numbers are not prime; the agent is expected to fix it.
input_function_code_ = """


def is_prime(n):
    
    if n <= 1:
        return True # Correct for 0, negative, and 1
    
    
    for i in range(2, n):
        if n % i == 0:
            return False
    return True
"""

input_function_description = "To return True if n is a prime number, and False otherwise."

# Reference tests: always included in the run and used by the planner as a style guide.
example_test = """
import pytest
from target import is_prime

def test_simple_is_prime():
    assert is_prime(17) == True


def test_edge_cases_is_prime():
    # 0, 1 and negative numbers are not prime
    assert is_prime(0) == False
    assert is_prime(1) == False
    assert is_prime(-7) == False


def test_composite_numbers():
    assert is_prime(4) == False


def test_strings_and_invalid_inputs():
    # For strings, function should ideally raise an exception
    with pytest.raises(TypeError):
        is_prime("abc")
    
"""

initial_state = {
    "func_code": input_function_code_,
    "description": input_function_description,
    "example_test": example_test,
    "tests": "",
    "results": "",
    "iteration": 0
}

In [None]:
# Each repair round spans several graph steps (worst case ~3 per iteration over
# 10 iterations), which exceeds LangGraph's default recursion_limit of 25.
# Raise it so should_continue's iteration cap, not GraphRecursionError, ends the run.
final_state = compiled_graph.invoke(initial_state, config={"recursion_limit": 100})


--- Iteration 0 Planner Generated Tests ---

def test_simple_is_prime():
    assert is_prime(17) == True


def test_edge_cases_is_prime():
    # 0, 1 and negative numbers are not prime
    assert is_prime(0) == False
    

def test_composite_numbers():
    assert is_prime(4) == False
    

def test_strings_and_invalid_inputs():
    # For strings, function should ideally raise an exception
    with pytest.raises(TypeError):
        is_prime("abc")

--- Iteration 0 Test Results ---

[32m.[0m[31mF[0m[32m.[0m[32m.[0m[31m                                                                     [100%][0m
[31m[1m__________________________ test_edge_cases_is_prime ___________________________[0m

    [0m[94mdef[39;49;00m[90m [39;49;00m[92mtest_edge_cases_is_prime[39;49;00m():[90m[39;49;00m
        [90m# 0, 1 and negative numbers are not prime[39;49;00m[90m[39;49;00m
>       [94massert[39;49;00m is_prime([94m0[39;49;00m) == [94mFalse[39;49;00m[90m[39;49;00m
[1m[31

In [None]:
_RULE = "-----------------------------------------------------------------------------------------------"


def _print_section(title, body, trailing_gap=True):
    """Print ``title`` framed by rule lines, then ``body``; optionally follow it with a blank gap."""
    print(_RULE)
    print(title)
    print(_RULE)
    print(body)
    if trailing_gap:
        print("\n")


# Sections are printed one call at a time, so each value is read only after
# the previous section has been printed (same order as printing inline).
_print_section(" Input Function:\n", input_function_code_)
_print_section(" Final Function:\n", final_state["func_code"])
_print_section("\n Test Results:\n", final_state["results"], trailing_gap=False)

