# Spring Boot Contract-Driven TDD Agent

## Requirements
- Microsoft C++ Build Tools
- Anaconda (Python 3.12)
- [OpenAI Key](https://platform.openai.com/)
- [Anthropic Key](https://console.anthropic.com/)
- [Gemini Key](https://ai.google.dev/gemini-api/)

## Getting Started
- Open **Anaconda Prompt** from **Anaconda Navigator**
- Navigate to the project's root directory
- Create the Anaconda environment:
  - `conda env create -f anaconda.yml`
- Activate the Anaconda environment:
  - `conda activate tdd-agent`
- Start Jupyter Lab:
  - `jupyter lab`
- Create the `.env` file containing the API keys:
  - `OPENAI_API_KEY=sk-proj-???`
  - `CLAUDE_API_KEY=sk-proj-???`
  - `GEMINI_API_KEY=sk-proj-???`
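- A minimal sketch for checking that the keys were picked up, assuming the variable names above. `missing_keys` and `REQUIRED_KEYS` are hypothetical names, not part of the notebook:

```python
import os

# Assumption: only the two keys the notebook actually reads are required
REQUIRED_KEYS = ("OPENAI_API_KEY", "GEMINI_API_KEY")


def missing_keys(env, required=REQUIRED_KEYS):
    """Return the names in `required` that are absent or empty in `env`."""
    return [name for name in required if not env.get(name)]


# Example with an explicit mapping; pass os.environ after load_dotenv() in practice
print(missing_keys({"OPENAI_API_KEY": "sk-proj-example"}))
```

- An empty list means every required key is set.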
 
## Updating Anaconda Environment
- In case new dependencies have been added to `anaconda.yml`:
  - `conda env update -f anaconda.yml`
 
## Goal
- Build a TDD agent that generates the code needed to make failing `RestAssured` integration tests pass.

In [1]:
import os
import json
import platform
import subprocess
from pathlib import Path
from typing import Optional
from openai import OpenAI
from dotenv import load_dotenv

load_dotenv(override=True)

openai_api_key = os.getenv('OPENAI_API_KEY')
gemini_api_key = os.getenv('GEMINI_API_KEY')

openai = OpenAI()
gemini = OpenAI(
    api_key=gemini_api_key, 
    base_url="https://generativelanguage.googleapis.com/v1beta/openai/"
)

- The `tree()` function renders the directory structure as text, in the same format as the Linux `tree` command, so it can be injected into the LLM's context window.

In [2]:
def tree(directory=".", prefix="", max_depth=None, current_depth=0, show_hidden=False, exclude=("target",)) -> str:
    """
    Display directory structure in tree format like Linux tree command.
    
    Args:
        directory: Path to the directory to display (default: current directory)
        prefix: Internal parameter for formatting (don't modify)
        max_depth: Maximum depth to traverse (None for unlimited)
        current_depth: Internal parameter for tracking depth (don't modify)
        show_hidden: Whether to show hidden files/directories (default: False)
        exclude: Names of files/directories to exclude from the output (default: ('target',))

    Returns:
        String containing the tree structure
    """
    if max_depth is not None and current_depth >= max_depth:
        return ""

    path = Path(directory)
    if not path.exists() or not path.is_dir():
        return f'Error: "{directory}" does not exist or it is not a directory.'

    # Apply the exclude list unconditionally, then hide dotfiles unless requested
    items = [item for item in path.iterdir() if item.name not in exclude]
    if not show_hidden:
        items = [item for item in items if not item.name.startswith('.')]

    items.sort(key=lambda x: (x.is_file(), x.name.lower()))

    result = []
    for i, item in enumerate(items):
        is_last = i == len(items) - 1
        if is_last:
            current_prefix = "└── "
            next_prefix = "    "
        else:
            current_prefix = "├── "
            next_prefix = "│   "

        result.append(f"{prefix}{current_prefix}{item.name}")

        if item.is_dir():
            subtree = tree(item, prefix + next_prefix, max_depth, current_depth + 1, show_hidden, exclude)
            if subtree:
                result.append(subtree)
    return "\n".join(result)
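
- The ordering used by `tree()` can be seen in isolation. The sketch below applies the same sort key to a throwaway directory. The file and directory names are made up for the example:

```python
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "src").mkdir()
    (root / "Docs").mkdir()
    (root / "pom.xml").write_text("<project/>", encoding="utf-8")
    (root / "README.md").write_text("# demo", encoding="utf-8")

    # Same key as tree(): False (directory) sorts before True (file),
    # then names are compared case-insensitively
    items = sorted(root.iterdir(), key=lambda p: (p.is_file(), p.name.lower()))
    ordered_names = [p.name for p in items]

print(ordered_names)
```

- Directories come first, so the LLM sees the package layout before the individual files.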

- The `find_file()` helper finds a file by name anywhere in a directory tree.

In [3]:
def find_file(filename: str, search_dir: str = ".") -> Optional[Path]:
    """
    Find a file by name in a directory tree.

    Args:
        filename: Name of the file to find
        search_dir: Directory to search in (default: current directory)

    Returns:
        Path object of the first matching file, or None if not found
    """
    path = Path(search_dir)
    if not path.exists() or not path.is_dir():
        # Keep the return type consistent: callers expect a Path or None
        return None

    for item in path.rglob(filename):
        if item.is_file():
            return item

    return None
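
- A small sketch of the `Path.rglob()` lookup that `find_file()` relies on, run against a temporary directory. The layout below is only an example:

```python
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    nested = root / "src" / "test" / "resources" / "sql"
    nested.mkdir(parents=True)
    (nested / "seed-data.sql").write_text("-- seed", encoding="utf-8")

    # rglob() walks the whole tree, so the bare file name is enough
    matches = [p for p in root.rglob("seed-data.sql") if p.is_file()]
    relative = matches[0].relative_to(root).as_posix()
    missing = list(root.rglob("does-not-exist.sql"))

print(relative)
print(missing)
```

- When several files share a name, the first match depends on traversal order, so unique file names give the most predictable results.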

- Tool that gives the LLM the capability to load a file's contents into its context window.

In [4]:
def cat_by_filename(filename: str, search_dir: str = ".") -> str:
    """
    Find a file by name and return its contents.
    
    Args:
        filename: Name of the file to read
        search_dir: Directory to search in (default: current directory)
        
    Returns:
        String containing the file contents
    """
    file_path = find_file(filename, search_dir)
    if file_path is None:
        return f'File "{filename}" not found in "{search_dir}".'

    return file_path.read_text(encoding='utf-8', errors='replace')

In [5]:
cat_by_filename_json = {
    "name": "cat_by_filename",
    "description": "Find a file by name and get its contents",
    "parameters": {
        "type": "object",
        "properties": {
            "filename": {
                "type": "string",
                "description": "Name of the file to read"
            },
            "search_dir": {
                "type": "string",
                "description": "Directory to search in",
                "default": "."
            }
        },
        "required": ["filename"],
        "additionalProperties": False
    }
}

- Tool that gives the LLM the capability to delete a file under `src/main/java`.

In [6]:
def rm_by_filename(filename: str, search_dir: str = ".") -> str:
    """
    Find a file by name and remove it.
    
    Args:
        filename: Name of the file to remove
        search_dir: Directory to search in (default: current directory)

    Returns:
        String describing whether the file was removed
    """
    file_path = find_file(filename, search_dir)
    if file_path is None:
        return f'File "{filename}" not found in "{search_dir}".'

    parts = file_path.parts
    if not ('src' in parts and 'main' in parts and 'java' in parts):
        return f'File "{filename}" found but not under the expected "src/main/java" directory.'

    file_path.unlink()
    return f'Removed file: {file_path}'
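
- The membership check above accepts any path that contains `src`, `main`, and `java` somewhere, in any order. A stricter variant, sketched here as an optional hardening, requires the three components to appear consecutively. `has_segment_run` is a hypothetical helper, and `PurePosixPath` is used only to keep the example platform-independent:

```python
from pathlib import PurePosixPath


def has_segment_run(path, run=("src", "main", "java")):
    """Return True if `run` appears as consecutive components of `path`."""
    parts = PurePosixPath(path).parts
    width = len(run)
    return any(parts[i:i + width] == tuple(run) for i in range(len(parts) - width + 1))


print(has_segment_run("app/src/main/java/com/example/Foo.java"))
print(has_segment_run("java/main/src/Foo.java"))
```

- The second path passes the original check but fails this one.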

In [7]:
rm_by_filename_json = {
    "name": "rm_by_filename",
    "description": "Find a file by name and remove it (only if under src/main/java)",
    "parameters": {
        "type": "object",
        "properties": {
            "filename": {
                "type": "string",
                "description": "Name of the file to remove"
            },
            "search_dir": {
                "type": "string",
                "description": "Directory to search in",
                "default": "."
            }
        },
        "required": ["filename"],
        "additionalProperties": False
    }
}

- Tool that gives the LLM the capability to create or edit a file.

In [8]:
def create_file(file_path: str, content: str, overwrite: bool = False) -> str:
    """
    Create or edit a file with the given content. Only allows creation/editing of files in:
    - src/main/* (source code and resources including DB migrations)
    - src/test/resources/* (test resources including test data)
    - pom.xml (Maven configuration)

    Args:
        file_path: Full path where to create the file
        content: Content to write to the file
        overwrite: Whether to overwrite if file exists (default: False)
    """
    path = Path(file_path)
    if path.exists() and not overwrite:
        return f'File "{file_path}" already exists. Use overwrite=True to replace it.'

    allowed = False
    parts = path.parts
    if 'src' in parts and 'main' in parts:
        allowed = True
    elif 'src' in parts and 'test' in parts and 'resources' in parts:
        allowed = True
    elif path.name == 'pom.xml':
        allowed = True

    if not allowed:
        return f'Error: Cannot create/edit file "{file_path}". Only allowed in src/main/*, src/test/resources/*, or pom.xml'

    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(content, encoding='utf-8')
    return f'Created: {path}'
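
- The path check in `create_file()` looks only at component names, so a path such as `../outside/src/main/Evil.java` would be accepted. One possible safeguard, sketched under the assumption that all writes should stay inside the project root, is to resolve the path first. `is_inside` is a hypothetical helper:

```python
import tempfile
from pathlib import Path


def is_inside(root, candidate):
    """Return True if `candidate` resolves to a location under `root`."""
    root_path = Path(root).resolve()
    target = (root_path / candidate).resolve()
    # Path.is_relative_to() is available from Python 3.9
    return target.is_relative_to(root_path)


with tempfile.TemporaryDirectory() as project_root:
    allowed = is_inside(project_root, "src/main/java/App.java")
    escaped = is_inside(project_root, "../outside/src/main/Evil.java")

print(allowed, escaped)
```

- Such a check would run before the `src/main` and `src/test/resources` tests, not replace them.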

In [9]:
create_file_json = {
    "name": "create_file",
    "description": "Create a new file or edit an existing file by overwriting it with new content. Can only create/edit files in: src/main/* (source code and resources including DB migrations), src/test/resources/* (test resources including test data), or pom.xml (Maven configuration).",
    "parameters": {
        "type": "object",
        "properties": {
            "file_path": {
                "type": "string",
                "description": "Full path where to create the file or path of existing file to edit. Must be under src/main/*, src/test/resources/*, or be pom.xml"
            },
            "content": {
                "type": "string",
                "description": "Complete content to write to the file (will replace all existing content if file exists)"
            },
            "overwrite": {
                "type": "boolean",
                "description": "Whether to overwrite/edit if file already exists. Set to true to edit existing files.",
                "default": False
            }
        },
        "required": ["file_path", "content"],
        "additionalProperties": False
    }
}

- Tools available to the LLM.

In [10]:
tools = [
    {"type": "function", "function": cat_by_filename_json},
    {"type": "function", "function": rm_by_filename_json},
    {"type": "function", "function": create_file_json}
]

- The `handle_tool_calls()` function executes the tools requested by the LLM and returns the results in the format the `openai` client expects, so they can be appended to the context window.

In [11]:
def handle_tool_calls(tool_calls):
    results = []
    for tool_call in tool_calls:
        tool_name = tool_call.function.name
        arguments = json.loads(tool_call.function.arguments)
        print(f"Tool called: {tool_name}", flush=True)

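        if tool_name == "cat_by_filename":
            result = cat_by_filename(**arguments)
        elif tool_name == "rm_by_filename":
            result = rm_by_filename(**arguments)
        elif tool_name == "create_file":
            result = create_file(**arguments)
        else:
            # Report unknown tools back to the model instead of raising NameError
            result = f'Error: Unknown tool "{tool_name}".'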

        results.append({"role": "tool", "content": json.dumps(result), "tool_call_id": tool_call.id})
    return results
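
- An alternative to the `if`/`elif` chain is a dispatch table. The sketch below is self-contained, so it uses a stand-in `echo` tool and a fake tool call built with `SimpleNamespace`. `TOOL_REGISTRY` and `dispatch` are hypothetical names:

```python
import json
from types import SimpleNamespace


def echo(text: str) -> str:
    """Stand-in tool used only for this example."""
    return text.upper()


# Hypothetical registry; the notebook would map its three tool functions here
TOOL_REGISTRY = {"echo": echo}


def dispatch(tool_call):
    """Run one tool call and return the tool message for the chat history."""
    name = tool_call.function.name
    handler = TOOL_REGISTRY.get(name)
    if handler is None:
        result = f'Error: unknown tool "{name}".'
    else:
        try:
            result = handler(**json.loads(tool_call.function.arguments))
        except Exception as exc:
            # Report failures to the model instead of crashing the loop
            result = f"Error running {name}: {exc}"
    return {"role": "tool", "content": json.dumps(result), "tool_call_id": tool_call.id}


# Fake tool call with the same attributes the loop reads from the client response
call = SimpleNamespace(
    id="call_1",
    function=SimpleNamespace(name="echo", arguments='{"text": "hi"}'),
)
print(dispatch(call))
```

- Adding a tool then means adding one registry entry, and malformed arguments are returned to the model as an error message.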

In [12]:
system_prompt = f"""
You are a Spring Boot Contract-Driven TDD Agent. Your role is to implement features 
by making failing integration tests pass through minimal, incremental code changes.

## ENVIRONMENT CONTEXT
You are working in a Spring Boot project with:
- **Architecture**: Classic layered architecture (Controller → Service → Repository)
- **Database**: Real database with Testcontainers
- **Migrations**: Flyway for schema changes
- **Testing**: RestAssured integration tests
- **Test Data**: Located in src/test/resources/sql/seed-data.sql

## YOUR WORKFLOW
When you receive failing test output, follow this TDD cycle:

### 1. ANALYZE THE FAILURES
- Read the test failure messages carefully
- Identify what the test expects (endpoints, data, behavior)
- Determine what's missing (database schema, test data, or code)

### 2. CHECK DATABASE SCHEMA
- Look for schema-related changes (missing tables, columns, constraints)
- If schema changes are needed, create Flyway migration files in `src/main/resources/db/migration/`
- Follow naming convention: `V{{number}}__{{description}}.sql`

### 3. CHECK TEST DATA FIRST
Before writing any code, ensure the test data exists:
- Use `cat_by_filename` to read `seed-data.sql`
- Verify that the data the test expects is present
- If missing, update `seed-data.sql` with the required test data

### 4. IMPLEMENT MINIMAL CODE
Follow the layered architecture and implement only what's needed:

**Repository Layer** (`src/main/java/.../persistence/`):
- Spring Data JPA repositories
- Custom query methods if needed

**Domain Layer** (`src/main/java/.../domain/`):
- Entity classes with JPA annotations
- Enums and value objects

**Service Layer** (`src/main/java/.../application/`):
- Business logic
- Transaction management with `@Transactional`

**Controller Layer** (`src/main/java/.../api/`):
- REST endpoints with proper HTTP methods and status codes
- Request/Response DTOs in `schemas/request/` and `schemas/response/`
- Validation annotations

**Exception Handling** (`src/main/java/.../exception/`):
- Custom exceptions in `custom/`
- Global exception handler

## IMPLEMENTATION PRINCIPLES
1. **Clarity Over Cleverness**: Write code that clearly expresses intent - prefer simple,
readable solutions over complex optimizations
2. **YAGNI**: Implement only what the current test requires - avoid adding unnecessary abstractions
or patterns for hypothetical future needs
3. **Meaningful Names**: Use descriptive, searchable names that reveal intent - 
`calculateTotalOrderAmount()` not `calc()`, `OrderStatus` not `Status`
4. **High Cohesion, Loose Coupling**: Keep related functionality together in the same class/package, 
aim for locality of behavior, and minimize dependencies between conceptually unrelated concerns

## DEPENDENCY MANAGEMENT
When tests require new capabilities:
- Always check if new dependencies are needed BEFORE implementing
- For example, if Authentication/JWT are required → Add Spring Security dependencies to pom.xml

## AVAILABLE TOOLS
- `cat_by_filename()`: Load file contents into your context window
- `create_file()`: Create/edit source code and migrations
- `rm_by_filename()`: Remove files (only under src/main/java)

## TEST RESULTS
If you need to see updated test results after making changes, respond without any tool calls.

## KEY REMINDERS
- Create database migrations for schema changes
- Always check test data BEFORE implementing code
- Follow Spring Boot best practices and conventions
- Use proper HTTP status codes and error handling
- Keep implementations simple and concise - only what's needed to pass the test
- Use `overwrite=True` when editing existing files

The current project structure tree is:
{tree()}
"""

In [13]:
def generate_user_prompt(test_output: str) -> str:
    return (
        "## TEST RESULTS\n"
        "Here are the results from running `mvn test`:\n"
        f"{test_output}"
    )

- The `run_maven_tests()` function runs the integration tests and returns the output so it can be added to the LLM's context window.

In [14]:
def run_maven_tests(project_dir: str = ".") -> str:
    """
    Execute Maven tests using Maven wrapper and return the output as a string.

    Args:
        project_dir: Directory containing the pom.xml (default: current directory)

    Returns:
        String containing the complete Maven test output (stdout and stderr combined)
    """
    try:
        mvn_cmd = "mvnw.cmd" if platform.system() == "Windows" else "./mvnw"

        project_path = Path(project_dir)
        if not project_path.exists():
            return f'Error: Directory "{project_dir}" does not exist'

        result = subprocess.run(
            [mvn_cmd, "test"],
            cwd=str(project_path),
            capture_output=True,
            text=True,
            timeout=300,  # 5 minute timeout
            shell=platform.system() == "Windows"  # mvnw.cmd needs a shell on Windows
        )

        output = ""
        if result.stdout:
            output += result.stdout
        if result.stderr:
            output += "\n" + result.stderr

        return output
    except subprocess.TimeoutExpired:
        return "Error: Maven test execution timed out after 5 minutes."
    except FileNotFoundError:
        return f'Error: Maven wrapper not found in "{project_dir}". Make sure ./mvnw and mvnw.cmd exist.'
    except Exception as e:
        return f'Error executing Maven tests: {str(e)}'

- The agent loop should continue to run until all integration tests pass.

In [None]:
test_output = run_maven_tests()

messages = [
    {"role": "system", "content": system_prompt},
    {"role": "user", "content": generate_user_prompt(test_output)}
]

completed = False
while not completed:
    response = gemini.chat.completions.create(
        model="gemini-2.0-flash",
        messages=messages,
        tools=tools
    )
    message = response.choices[0].message
    # Keep every assistant turn in the history, with or without tool calls
    messages.append(message)
    if message.tool_calls:
        print(message.content)
        results = handle_tool_calls(message.tool_calls)
        for result in results:
            print(result["content"])
        messages.extend(results)
    else:
        # No tool calls: the model is asking for fresh test results
        test_output = run_maven_tests()
        # Maven prints "BUILD SUCCESS" only when the build, including the tests, succeeds;
        # the helper's own "Error: ..." strings therefore never count as a pass
        if "BUILD SUCCESS" in test_output:
            completed = True
        else:
            messages.append({"role": "user", "content": generate_user_prompt(test_output)})
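
- The loop above has no upper bound, so a model that never converges keeps consuming API calls. A minimal sketch of an iteration budget follows. `run_until` and `fake_step` are hypothetical, and the limit of 20 is an arbitrary assumption:

```python
def run_until(step, max_iterations=20):
    """Call `step()` until it returns True or the iteration budget is spent.

    Returns the number of iterations used, or None if the budget ran out.
    """
    for iteration in range(1, max_iterations + 1):
        if step():
            return iteration
    return None


# Stand-in for one agent turn; it reports success on its third call
calls = {"count": 0}


def fake_step():
    calls["count"] += 1
    return calls["count"] >= 3


print(run_until(fake_step))
```

- In the notebook, `step` would wrap one model call plus the follow-up tool execution or test run.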