In [None]:
import ast
import os
from typing import List
from typing import TypedDict

import colorama
from langchain.chat_models import init_chat_model
from langchain_core.messages import HumanMessage
from langchain_core.messages import SystemMessage
from langchain_core.runnables import RunnableConfig
from langchain_openai import ChatOpenAI
from langgraph.graph import END
from langgraph.graph import StateGraph
from langgraph.pregel import GraphRecursionError

In [None]:
# Define the paths. Components are joined separately so the separator is
# correct on every platform.
search_path = os.path.join(os.getcwd(), "app")
code_file = os.path.join(search_path, "src", "crud.py")
test_file = os.path.join(search_path, "test", "test_crud.py")

# Create the folders if necessary. Each folder is ensured on its own
# (exist_ok=True), so a partially created "app" tree, e.g. one without "test",
# is completed instead of causing a FileNotFoundError when a file is written.
os.makedirs(os.path.dirname(code_file), exist_ok=True)
os.makedirs(os.path.dirname(test_file), exist_ok=True)

In [None]:
# crud.py
# Source of the sample module under test: an `Item` record and a `CRUDApp`
# that keeps items in a plain list. The cell below writes it to `code_file`,
# and the agent then generates pytest tests for its methods.
code = """class Item:
    def __init__(self, id, name, description=None):
        self.id = id
        self.name = name
        self.description = description

    def __repr__(self):
        return f"Item(id={self.id}, name={self.name}, description={self.description})"

class CRUDApp:
    def __init__(self):
        self.items = []

    def create_item(self, id, name, description=None):
        item = Item(id, name, description)
        self.items.append(item)
        return item

    def read_item(self, id):
        for item in self.items:
            if item.id == id:
                return item
        return None

    def update_item(self, id, name=None, description=None):
        for item in self.items:
            if item.id == id:
                if name:
                    item.name = name
                if description:
                    item.description = description
                return item
        return None

    def delete_item(self, id):
        for index, item in enumerate(self.items):
            if item.id == id:
                return self.items.pop(index)
        return None

    def list_items(self):
        return self.items"""

# Put the sample module on disk so the "discover" node can read it back.
with open(code_file, mode="w") as crud_module:
    crud_module.write(code)

In [None]:
# Chat model shared by every node of the graph; change the model here.
MODEL_NAME = "gpt-4o-mini"
MODEL_PROVIDER = "openai"

llm = init_chat_model(MODEL_NAME, model_provider=MODEL_PROVIDER)

In [None]:
# State shared by the graph nodes:
# - class_source:  full text of the module under test
# - class_methods: names still waiting for a test, consumed front to back
# - tests_source:  accumulated text of the generated test file
AgentState = TypedDict(
    "AgentState",
    {
        "class_source": str,
        "class_methods": List[str],
        "tests_source": str,
    },
)

In [None]:
# Build the graph on top of AgentState: every node below receives and returns it.
workflow = StateGraph(state_schema=AgentState)

In [None]:
def extract_code_from_message(message):
    """Return the text inside the fenced (```) code blocks of an LLM reply.

    A line counts as a fence only when, ignoring leading whitespace, it starts
    with three backticks (as in Markdown). Backticks that occur inside the code
    itself (e.g. in a string literal) therefore do not toggle the state. The
    contents of several fenced blocks are concatenated, and every returned line
    ends with a newline.

    Args:
        message: The reply text. ``None`` or ``""`` yields ``""``.

    Returns:
        The fenced code, or ``""`` when the reply contains no fenced block.
    """
    if not message:
        return ""

    code_lines = []
    in_code = False
    # splitlines() also drops the "\r" of Windows-style line endings.
    for line in message.splitlines():
        if line.lstrip().startswith("```"):
            in_code = not in_code
        elif in_code:
            code_lines.append(line + "\n")
    return "".join(code_lines)

In [None]:
import_prompt_template = """Here is a path of a file with code: {code_file}.
Here is the path of a file with tests: {test_file}.
The file with code defines these classes: {class_names}.
Write a proper import statement for all of these classes.
Reply with the import statement only, inside a ``` code block.
"""


def _discover_test_targets(source):
    """Find the classes and the callables to test in a Python module.

    Only top-level definitions are inspected, so nested helper functions are
    not picked up. Methods are qualified with their class ("CRUDApp.__init__"),
    so methods sharing a name across classes (e.g. two ``__init__``) become
    distinct targets instead of producing duplicate, shadowing tests.

    Args:
        source: Python source code of the module.

    Returns:
        A tuple ``(class_names, targets)``: two lists of strings in source
        order, ``targets`` without duplicates.

    Raises:
        SyntaxError: If ``source`` is not valid Python.
    """
    class_names = []
    targets = []
    for node in ast.parse(source).body:
        if isinstance(node, ast.ClassDef):
            class_names.append(node.name)
            targets.extend(
                f"{node.name}.{member.name}"
                for member in node.body
                if isinstance(member, (ast.FunctionDef, ast.AsyncFunctionDef))
            )
        elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            targets.append(node.name)
    # A property getter/setter pair shares one name; keep a single target for it.
    return class_names, list(dict.fromkeys(targets))


def _import_statement_from_reply(message):
    """Return the import statement(s) from the LLM reply to the import prompt.

    The fenced code block is preferred. If the model answered without a fence,
    the reply's lines that start with ``import `` or ``from `` are used, so the
    generated test file does not silently lose its imports.
    """
    code = extract_code_from_message(message)
    if code.strip():
        return code
    import_lines = [
        line.strip()
        for line in (message or "").splitlines()
        if line.strip().startswith(("import ", "from "))
    ]
    return "".join(line + "\n" for line in import_lines)


# Discover the classes and their methods.
def discover_function(state: AgentState):
    """Graph node: read the code file and seed the state for test writing.

    Sets ``class_source`` to the module text and ``class_methods`` to the
    (class-qualified) names of the methods/functions to test. Sets
    ``tests_source`` to the import statement(s) produced by the LLM, followed
    by a blank line.

    Raises:
        FileNotFoundError: If ``code_file`` does not exist.
    """
    if not os.path.exists(code_file):
        # Explicit error rather than `assert`, which `python -O` strips.
        raise FileNotFoundError(code_file)
    with open(code_file, "r") as f:
        source = f.read()
    state["class_source"] = source

    # Get the classes and the methods to test.
    class_names, methods = _discover_test_targets(source)
    state["class_methods"] = methods

    # Generate the import statement and start the code.
    import_prompt = import_prompt_template.format(
        code_file=code_file,
        test_file=test_file,
        class_names=", ".join(class_names),
    )
    message = llm.invoke([HumanMessage(content=import_prompt)]).content
    state["tests_source"] = _import_statement_from_reply(message) + "\n\n"

    return state


# Add a node for discovery.
workflow.add_node("discover", discover_function)

<langgraph.graph.state.StateGraph at 0x10dd2e450>

In [None]:
# System message template. The reply must be fenced, because only fenced code
# is extracted from it. Test names must be distinct, because all tests end up
# in one module, where a repeated name silently shadows the earlier test.
system_message_template = """You are a smart developer. You can do this! You will write unit
tests that have a high quality. Use pytest.

Reply with the source code for the test only, inside a ``` code block.
Do not include the class or any import statements in your response. I will add the imports myself.
Name each test function after the class and the method under test, for example
test_<class>_<method>, so that it does not clash with the tests for other methods.
If there is no test to write, reply with "# No test to write" and
nothing more.

Example:

```
def test_function():
    ...
```

I will give you 200 EUR if you adhere to the instructions and write a high quality test.
Do not write test classes, only top-level test functions.
"""

# Write the tests template. Filled with the module source and the name of the
# method to test.
write_test_template = """Here is the code under test:
'''
{class_source}
'''

Implement a test for the method "{class_method}".
"""

In [None]:
# This function writes one test per call.
def write_tests_function(state: AgentState):
    """Graph node: ask the LLM for a test of the next pending method.

    Pops the first entry of ``class_methods`` and prompts the LLM with the
    module in ``class_source``. The fenced code of the reply is appended to
    ``tests_source``. When no method is pending (e.g. the code file defines no
    functions), the state is returned unchanged instead of raising IndexError.
    """
    pending = state["class_methods"]
    if not pending:
        return state

    # Get the next method to write a test for.
    class_method = pending.pop(0)
    print(f"Writing test for {class_method}.")

    # Create the prompt.
    write_test_prompt = write_test_template.format(
        class_source=state["class_source"], class_method=class_method
    )
    print(colorama.Fore.CYAN + write_test_prompt + colorama.Style.RESET_ALL)

    # Get the test source code.
    messages = [SystemMessage(system_message_template), HumanMessage(write_test_prompt)]
    reply = llm.invoke(messages).content
    test_source = extract_code_from_message(reply)
    print(colorama.Fore.GREEN + test_source + colorama.Style.RESET_ALL)
    # An empty extraction (e.g. an unfenced "no test" reply) adds nothing.
    if test_source.strip():
        state["tests_source"] += test_source + "\n\n"

    return state


# Add the node.
workflow.add_node("write_tests", write_tests_function)

<langgraph.graph.state.StateGraph at 0x10dd2e450>

In [None]:
# Fixed part of the flow: discovery always hands over to test writing.
workflow.add_edge("discover", "write_tests")

# The run starts with discovery.
workflow.set_entry_point("discover")

<langgraph.graph.state.StateGraph at 0x10dd2e450>

In [None]:
# Write the file.
def write_file(state: AgentState):
    """Graph node: write the accumulated ``tests_source`` to ``test_file``.

    The parent folder is created if it is missing. The file is written as
    UTF-8, because LLM-generated tests may contain non-ASCII characters that the
    platform default encoding (e.g. cp1252 on Windows) cannot encode.
    Returns the state unchanged.
    """
    test_dir = os.path.dirname(test_file)
    if test_dir:
        os.makedirs(test_dir, exist_ok=True)
    with open(test_file, "w", encoding="utf-8") as f:
        f.write(state["tests_source"])
    return state


# Add a node to write the file.
workflow.add_node("write_file", write_file)

<langgraph.graph.state.StateGraph at 0x10dd2e450>

In [None]:
# Decide whether another test still has to be written.
def should_continue(state: AgentState):
    """Route after "write_tests": "continue" while methods remain, else "end"."""
    remaining = len(state["class_methods"])
    return "continue" if remaining else "end"


# Loop on "write_tests" until every method is handled, then write the file.
workflow.add_conditional_edges(
    "write_tests",
    should_continue,
    {"continue": "write_tests", "end": "write_file"},
)

<langgraph.graph.state.StateGraph at 0x10dd2e450>

In [None]:
# Writing the file is the last step: the run finishes right after it.
workflow.add_edge(start_key="write_file", end_key=END)

<langgraph.graph.state.StateGraph at 0x10dd2e450>

In [None]:
# Create the app and run it.
app = workflow.compile()
# Every key is overwritten by the "discover" node; start from empty,
# type-correct values (the previous dict misspelled "class_source").
inputs = {"class_source": "", "class_methods": [], "tests_source": ""}
config = RunnableConfig(recursion_limit=100)
# Keep `result` defined even when the run stops at the recursion limit.
result = None
try:
    result = app.invoke(inputs, config)
    print(result)
except GraphRecursionError:
    print("Graph recursion limit reached.")

Writing test for __init__.
[36mHere is a class:
'''
class Item:
    def __init__(self, id, name, description=None):
        self.id = id
        self.name = name
        self.description = description

    def __repr__(self):
        return f"Item(id={self.id}, name={self.name}, description={self.description})"

class CRUDApp:
    def __init__(self):
        self.items = []

    def create_item(self, id, name, description=None):
        item = Item(id, name, description)
        self.items.append(item)
        return item

    def read_item(self, id):
        for item in self.items:
            if item.id == id:
                return item
        return None

    def update_item(self, id, name=None, description=None):
        for item in self.items:
            if item.id == id:
                if name:
                    item.name = name
                if description:
                    item.description = description
                return item
        return None

    def delete

In [None]:
generated_tests = result["tests_source"]
print(generated_tests)



def test_item_init():
    item = Item(1, "Test Item", "This is a test item")
    assert item.id == 1
    assert item.name == "Test Item"
    assert item.description == "This is a test item"


def test_item_repr():
    item = Item(1, "Test Item", "Description")
    expected_output = "Item(id=1, name=Test Item, description=Description)"
    assert repr(item) == expected_output


def test_item_init():
    item = Item(1, "Book", "A great book")
    assert item.id == 1
    assert item.name == "Book"
    assert item.description == "A great book"


def test_create_item():
    app = CRUDApp()
    
    item = app.create_item(1, "Item 1", "Description for Item 1")
    
    assert item.id == 1
    assert item.name == "Item 1"
    assert item.description == "Description for Item 1"
    assert item in app.items


def test_read_item():
    app = CRUDApp()
    item1 = app.create_item(1, 'Item 1', 'Description 1')
    item2 = app.create_item(2, 'Item 2', 'Description 2')

    assert app.read_item(1)