diff --git a/CLAUDE.md b/CLAUDE.md index 085d16d2..ba1cc9d5 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -2,6 +2,8 @@ This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. +For information on how to work in the Python part of this project, see `python/CLAUDE.md`. + ## Project Overview Cairo Coder is an open-source Cairo language code generation service using Retrieval-Augmented Generation (RAG) to transform natural language requests into functional Cairo smart contracts and programs. It was adapted from the Starknet Agent project. diff --git a/python/CLAUDE.md b/python/CLAUDE.md new file mode 100644 index 00000000..c50231eb --- /dev/null +++ b/python/CLAUDE.md @@ -0,0 +1,187 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Project Overview + +Cairo Coder is an open-source Cairo language code generation service using Retrieval-Augmented Generation (RAG) with the DSPy framework. It transforms natural language requests into functional Cairo smart contracts and programs. + +## Essential Commands + +### Installation and Setup + +- `curl -LsSf https://astral.sh/uv/install.sh | sh` - Install uv package manager +- `uv sync` - Install all dependencies +- `cp sample.config.toml config.toml` - Create configuration file +- `cp .env.example .env` - Set up environment variables (if .env.example exists) + +### Development + +- `uv run cairo-coder` - Start the FastAPI server +- `uv run pytest` - Run all tests +- `uv run pytest tests/unit/test_query_processor.py::test_specific` - Run specific test +- `uv run pytest -k "test_name"` - Run tests matching pattern +- `uv run pytest --cov=src/cairo_coder` - Run tests with coverage +- `trunk check --fix` - Run linting and auto-fix issues +- `uv run ty check` - Run type checking + +### Docker Operations + +- `docker compose up postgres` - Start PostgreSQL database +- `docker compose up backend` - Start the API server +- `docker compose run ingester` - Run documentation ingestion + +### Optimization and Evaluation + +- `marimo run optimizers/generation_optimizer.py` - Run generation optimizer notebook +- `marimo run optimizers/rag_pipeline_optimizer.py` - Run full pipeline optimizer +- `uv run starklings_evaluate` - Evaluate against Starklings dataset +- `uv run cairo-coder-summarize ` - Summarize documentation + +## High-Level Architecture + +### DSPy-Based RAG Pipeline + +Cairo Coder uses a three-stage RAG pipeline implemented with DSPy modules: + +1. **Query Processing** (`src/cairo_coder/dspy/query_processor.py`): + + - Uses `CairoQueryAnalysis` signature with ChainOfThought + - Extracts search terms and identifies relevant documentation sources + - Detects if query is contract/test related + +2. **Document Retrieval** (`src/cairo_coder/dspy/document_retriever.py`): + + - Custom `SourceFilteredPgVectorRM` extends DSPy's retriever + - Queries PostgreSQL with pgvector for similarity search + - Supports source filtering and metadata extraction + +3. **Answer Generation** (`src/cairo_coder/dspy/generation_program.py`): + - `CairoCodeGeneration` signature for code synthesis + - Streaming support via async generators + - MCP mode for raw documentation retrieval + +### Agent-Based Architecture + +- **Agent Factory** (`src/cairo_coder/core/agent_factory.py`): Creates specialized agents from TOML configs +- **Agents**: General, Scarb-specific, or custom agents with filtered sources +- **Pipeline Factory**: Creates optimized RAG pipelines loading from `optimizers/results/` + +### FastAPI Server + +- **OpenAI-Compatible API** (`src/cairo_coder/server/app.py`): + - `/v1/chat/completions` - Legacy endpoint + - `/v1/agents/{agent_id}/chat/completions` - Agent-specific + - Supports streaming (SSE) and MCP mode via headers +- **Lifecycle Management**: Connection pooling, resource cleanup +- **Error Handling**: OpenAI-compatible error responses + +### Optimization Framework + +- **DSPy Optimizers**: MIPROv2 for prompt tuning +- **Datasets**: Generated from Starklings exercises +- **Metrics**: Code compilation success, relevance scores +- **Marimo Notebooks**: Reactive optimization workflows with MLflow tracking + +## Development Guidelines + +### Code Organization + +- Follow DSPy patterns: Signatures → Modules → Programs +- Use dependency injection for testability (e.g., vector_db parameter) +- Prefer async/await for I/O operations +- Type hints required (enforced by mypy) + +### Adding New Features + +1. **New Agent**: Add configuration to `config.toml`, extend `AgentConfiguration` +2. **New DSPy Module**: Create signature, implement forward/aforward methods +3. **New Optimizer**: Create Marimo notebook, define metrics, use MIPROv2 + +### Configuration Management + +- `ConfigManager` loads from `config.toml` and environment +- Vector store config in `[VECTOR_DB]` section +- LLM providers in `[PROVIDERS]` section +- Agent definitions in `[[AGENTS]]` array + +## Important Notes + +- Always load optimized programs from `optimizers/results/` in production +- Use `uv` for all dependency management (not pip/poetry) +- Structlog for JSON logging (`get_logger(__name__)`) +- DSPy tracks token usage via `lm.get_usage()` +- MLflow experiments logged to `mlruns/` directory + +## Working with the test suite + +This document provides guidelines for interacting with the Python test suite. Adhering to these patterns is crucial for maintaining a clean, efficient, and scalable testing environment. + +### 1. Running Tests + +All test commands should be run from the `python/` directory. + +- **Run all tests:** + + ```bash + uv run pytest + ``` + +- **Run tests in a specific file:** + + ```bash + uv run pytest tests/unit/test_rag_pipeline.py + ``` + +- **Run a specific test by name (using `-k`):** + ```bash + uv run pytest -k "test_mcp_mode_pipeline_execution" + ``` + +### 2. Test Architecture + +The test suite is divided into two main categories: + +- `tests/unit/`: For testing individual classes or functions in isolation. These tests should be fast and rely heavily on mocks to prevent external dependencies (like databases or APIs). +- `tests/integration/`: For testing how multiple components work together. This is primarily for testing the FastAPI server endpoints using `fastapi.testclient.TestClient`. These tests are slower and verify the contracts between different parts of the application. + +### 3. The Golden Rule: `conftest.py` is King + +**`python/tests/conftest.py` is the single source of truth for all shared fixtures, mocks, and test data.** + +- **Before adding any new mock or helper, check `conftest.py` first.** It is highly likely a suitable fixture already exists. +- **NEVER define a reusable fixture in an individual test file.** All shared fixtures **must** reside in `conftest.py`. This is non-negotiable for maintainability. + +### 4. Key Fixtures to Leverage + +Familiarize yourself with these core fixtures defined in `conftest.py`. Use them whenever possible. + +- `client`: An instance of `TestClient` for making requests to the FastAPI app in **integration tests**. +- `mock_agent`: A powerful, pre-configured mock of a RAG pipeline agent. It has mock implementations for `forward`, `aforward`, and `forward_streaming`. +- `mock_agent_factory`: A mock of the `AgentFactory` used in server tests to control which agent is created. +- `mock_vector_db`: A mock of `SourceFilteredPgVectorRM` for testing the document retrieval layer without a real database. +- `mock_lm`: A mock of a `dspy` language model for testing DSPy programs (`QueryProcessorProgram`, `GenerationProgram`) without making real API calls. +- `sample_documents`, `sample_agent_configs`, `sample_processed_query`: Consistent, reusable data fixtures for your tests. +- `sample_config_file`: A temporary, valid `config.toml` file for integration testing the configuration manager. + +### 5. Guidelines for Adding & Modifying Tests + +- **Adding a New Test File:** + + - If you are testing a single class's methods or a utility function, create a new file in `tests/unit/`. + - If you are testing a new API endpoint or a flow that involves multiple components, add it to the appropriate file in `tests/integration/`. + +- **Avoiding Code Duplication (DRY):** + + - If you find yourself writing several tests that differ only by their input values, you **must** use parametrization. + - **Pattern:** Use `@pytest.mark.parametrize`. See `tests/unit/test_document_retriever.py` for a canonical example of how this is done effectively. + +- **Adding New Mocks or Test Data:** + + - If the mock or data will be used in more than one test function, add it to `conftest.py` as a new fixture. + - If it's truly single-use, you may define it within the test function itself, but be certain it won't be needed elsewhere. + +- **Things to Be Careful About:** + - **Fixture Dependencies:** Understand that some fixtures depend on others (e.g., `client` depends on `mock_agent_factory`). Modifying a base fixture can have cascading effects on tests that use dependent fixtures. + - **Unit vs. Integration Mocks:** Do not use `TestClient` (`client` fixture) in unit tests. Unit tests should mock the direct dependencies of the class they are testing, not the entire application. + - **Removing Tests:** Only remove tests for code that has been removed. If you are refactoring, ensure that the new tests provide equivalent or better coverage than the ones being replaced. The recent refactoring that merged `test_server.py` into `test_openai_server.py` and `test_server_integration.py` is a key example of this pattern. diff --git a/python/src/cairo_coder/dspy/document_retriever.py b/python/src/cairo_coder/dspy/document_retriever.py index e7a1e929..990cfb7f 100644 --- a/python/src/cairo_coder/dspy/document_retriever.py +++ b/python/src/cairo_coder/dspy/document_retriever.py @@ -574,7 +574,7 @@ async def aforward( return [] # Step 2: Enrich context with appropriate templates based on query type. - return self._enhance_context(processed_query.original, documents) + return self._enhance_context(processed_query, documents) def forward( self, processed_query: ProcessedQuery, sources: list[DocumentSource] | None = None @@ -670,7 +670,7 @@ async def _afetch_documents( logger.error(f"Error fetching documents: {traceback.format_exc()}") raise e - def _enhance_context(self, query: str, context: list[Document]) -> list[Document]: + def _enhance_context(self, processed_query: ProcessedQuery, context: list[Document]) -> list[Document]: """ Enhance context with appropriate templates based on query type. @@ -681,12 +681,12 @@ def _enhance_context(self, query: str, context: list[Document]) -> list[Document Returns: Enhanced context with relevant templates """ - query_lower = query.lower() + query_lower = processed_query.original.lower() # Add contract template for contract-related queries if any( keyword in query_lower for keyword in ["contract", "storage", "external", "interface"] - ): + ) or processed_query.is_contract_related: context.append( Document( page_content=CONTRACT_TEMPLATE, @@ -695,7 +695,7 @@ def _enhance_context(self, query: str, context: list[Document]) -> list[Document ) # Add test template for test-related queries - if any(keyword in query_lower for keyword in ["test", "testing", "assert", "mock"]): + if any(keyword in query_lower for keyword in ["test", "testing", "assert", "mock"]) or processed_query.is_test_related: context.append( Document( page_content=TEST_TEMPLATE, diff --git a/python/src/cairo_coder/server/app.py b/python/src/cairo_coder/server/app.py index 773d8530..b8e44682 100644 --- a/python/src/cairo_coder/server/app.py +++ b/python/src/cairo_coder/server/app.py @@ -166,9 +166,6 @@ def __init__( allow_headers=["*"], ) - # Token tracking for usage statistics - self.token_tracker = TokenTracker() - # Setup routes self._setup_routes() @@ -490,32 +487,6 @@ async def _generate_chat_completion( ) -class TokenTracker: - """Simple token tracker for usage statistics.""" - - def __init__(self): - self.sessions = {} - - def track_tokens(self, session_id: str, prompt_tokens: int, completion_tokens: int): - """Track token usage for a session.""" - if session_id not in self.sessions: - self.sessions[session_id] = { - "prompt_tokens": 0, - "completion_tokens": 0, - "total_tokens": 0, - } - - self.sessions[session_id]["prompt_tokens"] += prompt_tokens - self.sessions[session_id]["completion_tokens"] += completion_tokens - self.sessions[session_id]["total_tokens"] += prompt_tokens + completion_tokens - - def get_session_usage(self, session_id: str) -> dict[str, int]: - """Get session token usage.""" - return self.sessions.get( - session_id, {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0} - ) - - def create_app( vector_store_config: VectorStoreConfig, config_manager: ConfigManager | None = None ) -> FastAPI: diff --git a/python/tests/conftest.py b/python/tests/conftest.py index 9423ceb6..a3f3cf2e 100644 --- a/python/tests/conftest.py +++ b/python/tests/conftest.py @@ -6,10 +6,15 @@ """ import asyncio -from collections.abc import AsyncGenerator -from unittest.mock import AsyncMock, Mock +import os +import tempfile +from collections.abc import AsyncGenerator, Generator +from pathlib import Path +from unittest.mock import AsyncMock, Mock, patch +import dspy import pytest +import toml from fastapi.testclient import TestClient from cairo_coder.config.manager import ConfigManager @@ -19,6 +24,7 @@ Document, DocumentSource, Message, + ProcessedQuery, Role, StreamEvent, StreamEventType, @@ -30,9 +36,26 @@ # Common Mock Fixtures # ============================================================================= - @pytest.fixture(scope="session") -def mock_vector_db(): +def mock_returned_documents(): + """Create a mock vector database instance for dependency injection.""" + return [ + dspy.Example( + content="Cairo is a programming language for writing provable programs.", + metadata={"source": "cairo_book", "score": 0.9, "chapter": 1}, + ), + dspy.Example( + content="Starknet is a validity rollup (also known as a ZK rollup).", + metadata={"source": "starknet_docs", "score": 0.8, "section": "overview"}, + ), + dspy.Example( + content="OpenZeppelin provides secure smart contract libraries for Cairo.", + metadata={"source": "openzeppelin_docs", "score": 0.7}, + ), + ] + +@pytest.fixture(scope="function") +def mock_vector_db(mock_returned_documents): """Create a mock vector database for dependency injection.""" mock_db = Mock(spec=SourceFilteredPgVectorRM) @@ -41,16 +64,17 @@ def mock_vector_db(): mock_db._ensure_pool = AsyncMock() # Mock the forward method - mock_db.forward = Mock(return_value=[]) + mock_db.forward = Mock(return_value=mock_returned_documents) # Mock the async forward method - mock_db.aforward = AsyncMock(return_value=[]) + mock_db.aforward = AsyncMock(return_value=mock_returned_documents) # Mock sources attribute mock_db.sources = [] return mock_db + @pytest.fixture(scope="session") def mock_vector_store_config(): """ @@ -92,49 +116,66 @@ def mock_config_manager(): return manager -@pytest.fixture +@pytest.fixture(scope="function") def mock_lm(): """ Create a mock language model for DSPy programs. This fixture provides a mock LM that can be used with DSPy programs - for testing without making actual API calls. + for testing without making actual API calls. It patches `dspy.ChainOfThought` + and returns a configurable mock. """ - mock_lm = Mock() - mock_lm.generate = Mock(return_value=["Generated response"]) - mock_lm.__call__ = Mock(return_value=["Generated response"]) - return mock_lm + with patch("dspy.ChainOfThought") as mock_cot: + mock_program = Mock() + # Mock for sync calls + mock_program.forward.return_value = dspy.Prediction( + answer="Here's a Cairo contract example:\n\n```cairo\n#[starknet::contract]\nmod SimpleContract {\n // Contract implementation\n}\n```\n\nThis contract demonstrates basic Cairo syntax." + ) + mock_program.return_value = dspy.Prediction( + answer="Here's a Cairo contract example:\n\n```cairo\n#[starknet::contract]\nmod SimpleContract {\n // Contract implementation\n}\n```\n\nThis contract demonstrates basic Cairo syntax." + ) + # Mock for async calls - use AsyncMock for coroutine + mock_program.aforward = AsyncMock( + return_value=dspy.Prediction( + answer="Here's a Cairo contract example:\n\n```cairo\n#[starknet::contract]\nmod SimpleContract {\n // Contract implementation\n}\n```\n\nThis contract demonstrates basic Cairo syntax." + ) + ) + mock_cot.return_value = mock_program + yield mock_program @pytest.fixture -def mock_agent_factory(): +def mock_agent_factory(mock_agent: Mock, sample_agent_configs: dict[str, AgentConfiguration]): """ Create a mock agent factory with standard agent configurations. Returns a mock AgentFactory with common agent configurations. """ factory = Mock(spec=AgentFactory) - factory.get_available_agents.return_value = [ - "default", - "scarb-assistant", - "starknet_assistant", - "openzeppelin_assistant", - ] - factory.get_agent_info.return_value = { - "id": "default", - "name": "Cairo Coder", - "description": "General Cairo programming assistant", - "sources": ["cairo_book", "cairo_docs"], - "max_source_count": 10, - "similarity_threshold": 0.4, - } - factory.create_agent = Mock() - factory.get_or_create_agent = Mock() + factory.get_available_agents.return_value = list(sample_agent_configs.keys()) + + def get_agent_info(agent_id, **kwargs): + if agent_id in sample_agent_configs: + agent_config = sample_agent_configs[agent_id] + return { + "id": agent_config.id, + "name": agent_config.name, + "description": agent_config.description, + "sources": [s.value for s in agent_config.sources], + "max_source_count": agent_config.max_source_count, + "similarity_threshold": agent_config.similarity_threshold, + } + raise ValueError(f"Agent '{agent_id}' not found") + + factory.get_agent_info.side_effect = get_agent_info + + factory.create_agent.return_value = mock_agent + factory.get_or_create_agent.return_value = mock_agent factory.clear_cache = Mock() return factory -@pytest.fixture(autouse=True) +@pytest.fixture def mock_agent(): """Create a mock agent with OpenAI-specific forward method.""" mock_agent = AsyncMock() @@ -223,6 +264,7 @@ def server(mock_vector_store_config, mock_config_manager, mock_agent_factory): """Create a CairoCoderServer instance for testing.""" return CairoCoderServer(mock_vector_store_config, mock_config_manager) + @pytest.fixture def client(server, mock_agent_factory): """Create a test client for the server.""" @@ -242,12 +284,34 @@ async def mock_get_agent_factory(): server.app.dependency_overrides[get_agent_factory] = mock_get_agent_factory return TestClient(server.app) + +@pytest.fixture(scope="session") +def mock_embedder(): + """Mock the embedder.""" + with patch("cairo_coder.dspy.document_retriever.dspy.Embedder") as mock_embedder: + mock_embedder.return_value = Mock() + yield mock_embedder + + # ============================================================================= # Sample Data Fixtures # ============================================================================= -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") +def sample_processed_query(): + """Create a sample processed query.""" + return ProcessedQuery( + original="How do I create a Cairo contract?", + search_queries=["cairo", "contract", "create"], + reasoning="I need to create a Cairo contract", + is_contract_related=True, + is_test_related=False, + resources=[DocumentSource.CAIRO_BOOK, DocumentSource.STARKNET_DOCS], + ) + + +@pytest.fixture(scope="session") def sample_documents(): """ Create a collection of sample documents for testing. @@ -297,6 +361,7 @@ def sample_documents(): ), ] + @pytest.fixture def sample_messages(): """ @@ -307,7 +372,9 @@ def sample_messages(): return [ Message(role=Role.SYSTEM, content="You are a helpful Cairo programming assistant."), Message(role=Role.USER, content="How do I create a smart contract in Cairo?"), - Message(role=Role.ASSISTANT, content="To create a smart contract in Cairo, you need to..."), + Message( + role=Role.ASSISTANT, content="To create a smart contract in Cairo, you need to..." + ), Message(role=Role.USER, content="Can you show me an example?"), ] @@ -320,6 +387,14 @@ def sample_agent_configs(): Returns a dictionary of AgentConfiguration objects. """ return { + "cairo-coder": AgentConfiguration( + id="cairo-coder", + name="Cairo Coder", + description="Cairo programming assistant", + sources=[DocumentSource.CAIRO_BOOK, DocumentSource.STARKNET_DOCS], + max_source_count=10, + similarity_threshold=0.4, + ), "default": AgentConfiguration( id="default", name="Cairo Coder", @@ -370,50 +445,35 @@ def sample_agent_configs(): ), } + # ============================================================================= # Test Configuration Fixtures # ============================================================================= -@pytest.fixture -def temp_config_file(tmp_path): - """ - Create a temporary configuration file for testing. +@pytest.fixture(scope="session") +def sample_config_file() -> Generator[Path, None, None]: + """Create a temporary config file for testing.""" + config_data = { + "VECTOR_DB": { + "POSTGRES_HOST": "test-db.example.com", + "POSTGRES_PORT": 5433, + "POSTGRES_DB": "test_cairo", + "POSTGRES_USER": "test_user", + "POSTGRES_PASSWORD": "test_password", + "POSTGRES_TABLE_NAME": "test_documents", + "SIMILARITY_MEASURE": "cosine", + }, + } - Returns the path to a temporary TOML configuration file. - """ - config_content = """ -[providers.openai] -api_key = "test-openai-key" -model = "gpt-4" - -[providers.anthropic] -api_key = "test-anthropic-key" -model = "claude-3-sonnet" - -[providers] -default_provider = "openai" - -[vector_db] -host = "localhost" -port = 5432 -database = "cairo_coder_test" -user = "test_user" -password = "test_password" - -[agents.default] -sources = ["cairo_book", "starknet_docs"] -max_source_count = 10 -similarity_threshold = 0.4 - -[logging] -level = "INFO" -format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" -""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".toml", delete=False) as f: + toml.dump(config_data, f) + temp_path = Path(f.name) + + yield temp_path - config_file = tmp_path / "test_config.toml" - config_file.write_text(config_content) - return config_file + # Cleanup + os.unlink(temp_path) @pytest.fixture diff --git a/python/tests/integration/test_config_integration.py b/python/tests/integration/test_config_integration.py index 88e6bb56..4ddc50a5 100644 --- a/python/tests/integration/test_config_integration.py +++ b/python/tests/integration/test_config_integration.py @@ -1,79 +1,39 @@ """Integration tests for configuration management.""" -import os -import tempfile -from collections.abc import Generator from pathlib import Path import pytest -import toml from cairo_coder.config.manager import ConfigManager +@pytest.fixture(scope="function", autouse=True) +def clear_env_vars(monkeypatch: pytest.MonkeyPatch): + """Clear all environment variables before each test.""" + import os + + for var in [ + "POSTGRES_HOST", + "POSTGRES_PORT", + "POSTGRES_DB", + "POSTGRES_USER", + "POSTGRES_PASSWORD", + "OPENAI_API_KEY", + "ANTHROPIC_API_KEY", + "GEMINI_API_KEY", + ]: + os.environ.pop(var, None) + monkeypatch.delenv(var, raising=False) + + yield + class TestConfigIntegration: """Test configuration integration with real files and environment.""" - @pytest.fixture - def sample_config_file(self) -> Generator[Path, None, None]: - """Create a temporary config file for testing.""" - config_data = { - "VECTOR_DB": { - "POSTGRES_HOST": "test-db.example.com", - "POSTGRES_PORT": 5433, - "POSTGRES_DB": "test_cairo", - "POSTGRES_USER": "test_user", - "POSTGRES_PASSWORD": "test_password", - "POSTGRES_TABLE_NAME": "test_documents", - "SIMILARITY_MEASURE": "cosine", - }, - "providers": { - "default": "openai", - "embedding_model": "text-embedding-3-large", - "openai": {"api_key": "test-openai-key", "model": "gpt-4"}, - "anthropic": {"api_key": "test-anthropic-key", "model": "claude-3-sonnet"}, - }, - "logging": {"level": "DEBUG", "format": "json"}, - "monitoring": {"enable_metrics": True, "metrics_port": 9191}, - "agents": { - "test-agent": { - "name": "Test Agent", - "description": "Integration test agent", - "sources": ["cairo_book", "starknet_docs"], - "max_source_count": 5, - "similarity_threshold": 0.5, - "contract_template": "Test contract template", - "test_template": "Test template", - } - }, - } - - with tempfile.NamedTemporaryFile(mode="w", suffix=".toml", delete=False) as f: - toml.dump(config_data, f) - temp_path = Path(f.name) - - yield temp_path - - # Cleanup - os.unlink(temp_path) - def test_load_full_configuration( - self, sample_config_file: Path, monkeypatch: pytest.MonkeyPatch + self, sample_config_file: Path, clear_env_vars ) -> None: """Test loading a complete configuration file.""" - # Clear any existing environment variables - for var in [ - "POSTGRES_HOST", - "POSTGRES_PORT", - "POSTGRES_DB", - "POSTGRES_USER", - "POSTGRES_PASSWORD", - "OPENAI_API_KEY", - "ANTHROPIC_API_KEY", - "GEMINI_API_KEY", - ]: - monkeypatch.delenv(var, raising=False) - config = ConfigManager.load_config(sample_config_file) # Verify database settings diff --git a/python/tests/integration/test_server_integration.py b/python/tests/integration/test_server_integration.py index 530bac9d..050ed5ce 100644 --- a/python/tests/integration/test_server_integration.py +++ b/python/tests/integration/test_server_integration.py @@ -1,77 +1,65 @@ """ Integration tests for OpenAI-compatible FastAPI server. -This module tests the FastAPI server with more realistic scenarios, -including actual vector store and config manager integration. +This module tests the FastAPI server with realistic scenarios, +including vector store and config manager integration, API contract +verification, and OpenAI compatibility checks. """ import concurrent.futures +import json +import uuid from unittest.mock import AsyncMock, Mock, patch -import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient from cairo_coder.config.manager import ConfigManager -from cairo_coder.core.agent_factory import AgentFactory from cairo_coder.core.config import VectorStoreConfig -from cairo_coder.server.app import create_app, get_vector_store_config +from cairo_coder.core.types import StreamEvent, StreamEventType +from cairo_coder.server.app import CairoCoderServer, create_app class TestServerIntegration: """Integration tests for the server.""" - @pytest.fixture(scope="function") - def mock_agent_factory(self, mock_agent): - """Patch create_agent_factory and return the mock factory.""" - with patch("cairo_coder.server.app.create_agent_factory") as mock_factory_creator: - factory = Mock(spec=AgentFactory) - agents_data = { - "default": { - "id": "default", - "name": "Cairo Coder", - "description": "General Cairo programming assistant", - "sources": ["cairo_book", "cairo_docs"], - }, - "scarb-assistant": { - "id": "scarb-assistant", - "name": "Scarb Assistant", - "description": "Starknet-specific programming help", - "sources": ["scarb_docs"], - }, - } - factory.get_available_agents.return_value = list(agents_data.keys()) - - def get_agent_info(agent_id, **kwargs): - if agent_id in agents_data: - return agents_data[agent_id] - raise ValueError(f"Agent {agent_id} not found") - - factory.get_agent_info.side_effect = get_agent_info - factory.create_agent.return_value = mock_agent - factory.get_or_create_agent = Mock(return_value=mock_agent) - mock_factory_creator.return_value = factory - yield factory - - @pytest.fixture(scope="function") - def app(self, mock_vector_store_config, mock_config_manager, mock_agent_factory): - """Create a test FastAPI application.""" - app = create_app(mock_vector_store_config, mock_config_manager) - app.dependency_overrides[get_vector_store_config] = lambda: mock_vector_store_config - return app - - def test_health_check_integration(self, client): + def test_health_check_integration(self, client: TestClient): """Test health check endpoint in integration context.""" response = client.get("/") assert response.status_code == 200 assert response.json() == {"status": "ok"} - def test_full_agent_workflow(self, client, mock_agent_factory): + def test_list_agents(self, client: TestClient, sample_agent_configs: dict): + """Test listing available agents.""" + response = client.get("/v1/agents") + assert response.status_code == 200 + + data = response.json() + assert len(data) == len(sample_agent_configs) + agent_ids = {agent["id"] for agent in data} + assert "cairo-coder" in agent_ids + assert "default" in agent_ids + assert "scarb-assistant" in agent_ids + + def test_list_agents_error_handling(self, client: TestClient, mock_agent_factory: Mock): + """Test error handling in list agents endpoint.""" + mock_agent_factory.get_available_agents.side_effect = Exception("Database error") + + response = client.get("/v1/agents") + assert response.status_code == 500 + + data = response.json() + assert "detail" in data + assert data["detail"]["error"]["message"] == "Failed to list agents" + assert data["detail"]["error"]["type"] == "server_error" + + def test_full_agent_workflow(self, client: TestClient, mock_agent: Mock): """Test complete agent workflow from listing to chat.""" # First, list available agents response = client.get("/v1/agents") assert response.status_code == 200 agents = response.json() - assert len(agents) == 2 assert any(agent["id"] == "default" for agent in agents) assert any(agent["id"] == "scarb-assistant" for agent in agents) @@ -85,9 +73,7 @@ def test_full_agent_workflow(self, client, mock_agent_factory): "total_tokens": 30, } } - mock_agent = Mock() mock_agent.aforward = AsyncMock(return_value=mock_response) - mock_agent_factory.create_agent.return_value = mock_agent # Test chat completion with default agent response = client.post( @@ -101,7 +87,7 @@ def test_full_agent_workflow(self, client, mock_agent_factory): data = response.json() assert data["choices"][0]["message"]["content"] == "Smart contract response." - def test_multiple_conversation_turns(self, client, mock_agent_factory, mock_agent): + def test_multiple_conversation_turns(self, client: TestClient, mock_agent: Mock): """Test handling multiple conversation turns.""" conversation_responses = [ "Hello! I'm Cairo Coder, ready to help with Cairo programming.", @@ -119,7 +105,6 @@ async def mock_aforward(query: str, chat_history=None, mcp_mode=False, **kwargs) return mock_response mock_agent.aforward = mock_aforward - mock_agent_factory.create_agent.return_value = mock_agent # Test conversation flow messages = [{"role": "user", "content": "Hello"}] @@ -135,7 +120,7 @@ async def mock_aforward(query: str, chat_history=None, mcp_mode=False, **kwargs) data = response.json() assert data["choices"][0]["message"]["content"] == conversation_responses[1] - def test_streaming_integration(self, client, mock_agent_factory, mock_agent): + def test_streaming_integration(self, client: TestClient, mock_agent: Mock): """Test streaming response integration.""" async def mock_forward_streaming(query: str, chat_history=None, mcp_mode=False, **kwargs): @@ -150,7 +135,6 @@ async def mock_forward_streaming(query: str, chat_history=None, mcp_mode=False, yield {"type": "end", "data": ""} mock_agent.forward_streaming = mock_forward_streaming - mock_agent_factory.create_agent.return_value = mock_agent response = client.post( "/v1/chat/completions", @@ -162,7 +146,7 @@ async def mock_forward_streaming(query: str, chat_history=None, mcp_mode=False, assert response.status_code == 200 assert "text/event-stream" in response.headers.get("content-type", "") - def test_error_handling_integration(self, client, mock_agent_factory): + def test_error_handling_integration(self, client: TestClient, mock_agent_factory: Mock): """Test error handling in integration context.""" mock_agent_factory.get_agent_info.side_effect = ValueError("Agent not found") response = client.post( @@ -181,13 +165,13 @@ def test_error_handling_integration(self, client, mock_agent_factory): ) assert response.status_code == 422 # Validation error - def test_cors_integration(self, client): + def test_cors_integration(self, client: TestClient): """Test CORS headers in integration context.""" response = client.get("/", headers={"Origin": "https://example.com"}) assert response.status_code == 200 assert "access-control-allow-origin" in response.headers - def test_mcp_mode_integration(self, client, mock_agent_factory, mock_agent): + def test_mcp_mode_integration(self, client: TestClient, mock_agent: Mock): """Test MCP mode in integration context.""" async def mock_forward_streaming(query: str, chat_history=None, mcp_mode=False, **kwargs): if mcp_mode: @@ -205,7 +189,6 @@ async def mock_forward_streaming(query: str, chat_history=None, mcp_mode=False, yield {"type": "end", "data": ""} mock_agent.forward_streaming = mock_forward_streaming - mock_agent_factory.create_agent.return_value = mock_agent response = client.post( "/v1/chat/completions", @@ -214,7 +197,7 @@ async def mock_forward_streaming(query: str, chat_history=None, mcp_mode=False, ) assert response.status_code == 200 - def test_concurrent_requests(self, client): + def test_concurrent_requests(self, client: TestClient): """Test handling concurrent requests.""" def make_request(request_id): @@ -236,7 +219,7 @@ def make_request(request_id): for status_code, _request_id in results: assert status_code == 200 - def test_large_request_handling(self, client): + def test_large_request_handling(self, client: TestClient): """Test handling of large requests.""" large_content = "How do I create a contract? " * 1000 # Large query @@ -246,11 +229,153 @@ def test_large_request_handling(self, client): ) assert response.status_code in [200, 413] + def test_chat_completions_validation_empty_messages(self, client: TestClient): + """Test validation of empty messages array.""" + response = client.post("/v1/chat/completions", json={"messages": []}) + assert response.status_code == 422 # Pydantic validation error + + def test_chat_completions_validation_last_message_not_user(self, client: TestClient): + """Test validation that last message must be from user.""" + response = client.post( + "/v1/chat/completions", + json={ + "messages": [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + ] + }, + ) + assert response.status_code == 422 # Pydantic validation error + + def test_agent_chat_completions_valid_agent(self, client: TestClient): + """Test agent-specific chat completions with valid agent.""" + response = client.post( + "/v1/agents/cairo-coder/chat/completions", + json={"messages": [{"role": "user", "content": "Hello"}], "stream": False}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["model"] == "cairo-coder" + assert len(data["choices"]) == 1 + + def test_agent_chat_completions_invalid_agent(self, client: TestClient, mock_agent_factory: Mock): + """Test agent-specific chat completions with invalid agent.""" + mock_agent_factory.get_agent_info.side_effect = ValueError("Agent not found") + + response = client.post( + "/v1/agents/unknown-agent/chat/completions", + json={"messages": [{"role": "user", "content": "Hello"}]}, + ) + + assert response.status_code == 404 + data = response.json() + assert "detail" in data + assert "Agent 'unknown-agent' not found" in data["detail"]["error"]["message"] + assert data["detail"]["error"]["type"] == "invalid_request_error" + assert data["detail"]["error"]["code"] == "agent_not_found" + + def test_error_handling_agent_creation_failure(self, client: TestClient, mock_agent_factory: Mock): + """Test error handling when agent creation fails.""" + mock_agent_factory.create_agent.side_effect = Exception("Agent creation failed") + + response = client.post( + "/v1/chat/completions", json={"messages": [{"role": "user", "content": "Hello"}]} + ) + + assert response.status_code == 500 + data = response.json() + assert "detail" in data + assert data["detail"]["error"]["type"] == "server_error" + + def test_message_conversion(self, client: TestClient, mock_agent_factory: Mock, mock_agent: Mock): + """Test proper conversion of messages to internal format.""" + mock_agent_factory.create_agent.return_value = mock_agent + + client.post( + "/v1/chat/completions", + json={ + "messages": [ + {"role": "system", "content": "You are a helpful assistant"}, + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + {"role": "user", "content": "How are you?"}, + ] + }, + ) + + # Verify agent was called with proper message conversion + mock_agent_factory.create_agent.assert_called_once() + call_args, call_kwargs = mock_agent_factory.create_agent.call_args + + # Check that history excludes the last message + history = call_kwargs.get("history", []) + assert len(history) == 3 # Excludes last user message + + # Check query is the last user message + query = call_kwargs.get("query") + assert query == "How are you?" + + def test_streaming_error_handling(self, client: TestClient, mock_agent_factory: Mock, mock_agent: Mock): + """Test error handling during streaming.""" + + async def mock_forward_streaming_error(*args, **kwargs): + yield StreamEvent(type=StreamEventType.RESPONSE, data="Starting response...") + raise Exception("Stream error") + + mock_agent.forward_streaming = mock_forward_streaming_error + mock_agent_factory.create_agent.return_value = mock_agent + + response = client.post( + "/v1/chat/completions", + json={"messages": [{"role": "user", "content": "Hello"}], "stream": True}, + ) + + assert response.status_code == 200 + + # Parse streaming response to check error handling + lines = response.text.strip().split("\n") + chunks = [] + for line in lines: + if line.startswith("data: "): + data_str = line[6:] + if data_str != "[DONE]": + chunks.append(json.loads(data_str)) + + # Should have error chunk + error_found = False + for chunk in chunks: + if chunk["choices"][0]["finish_reason"] == "stop": + content = chunk["choices"][0]["delta"].get("content", "") + if "Error:" in content: + error_found = True + break + assert error_found + + def test_request_id_generation(self, client: TestClient): + """Test that unique request IDs are generated.""" + response1 = client.post( + "/v1/chat/completions", json={"messages": [{"role": "user", "content": "Hello"}]} + ) + response2 = client.post( + "/v1/chat/completions", json={"messages": [{"role": "user", "content": "Hello"}]} + ) + + assert response1.status_code == 200 + assert response2.status_code == 200 + + data1 = response1.json() + data2 = response2.json() + + assert data1["id"] != data2["id"] + uuid.UUID(data1["id"]) # Should not raise exception + uuid.UUID(data2["id"]) # Should not raise exception + class TestServerStartup: """Test server startup and configuration.""" - def test_server_startup_with_mocked_dependencies(self, mock_vector_store_config): + def test_server_startup_with_mocked_dependencies(self, mock_vector_store_config: Mock): """Test that server can start with mocked dependencies.""" mock_config_manager = Mock(spec=ConfigManager) @@ -262,15 +387,8 @@ def test_server_startup_with_mocked_dependencies(self, mock_vector_store_config) def test_server_main_function_configuration(self): """Test the server's main function configuration.""" - from cairo_coder.server.app import ( - CairoCoderServer, - TokenTracker, - create_app, - ) - assert create_app is not None assert CairoCoderServer is not None - assert TokenTracker is not None # Test that we can create an app instance with patch("cairo_coder.server.app.create_agent_factory"), patch( @@ -278,8 +396,187 @@ def test_server_main_function_configuration(self): ) as mock_get_config: mock_get_config.return_value = Mock(spec=VectorStoreConfig) app = create_app(mock_get_config()) + assert isinstance(app, FastAPI) - # Verify the app is a FastAPI instance - from fastapi import FastAPI + def test_create_app_with_defaults(self, mock_vector_store_config: Mock): + """Test create_app with default config manager.""" + with ( + patch("cairo_coder.server.app.create_agent_factory"), + patch("cairo_coder.config.manager.ConfigManager") as mock_config_class, + ): + mock_config_class.return_value = Mock() + app = create_app(mock_vector_store_config) - assert isinstance(app, FastAPI) + assert isinstance(app, FastAPI) + + def test_cors_configuration(self, mock_vector_store_config: Mock): + """Test CORS configuration.""" + with patch("cairo_coder.server.app.create_agent_factory"): + app = create_app(mock_vector_store_config) + client = TestClient(app) + + # Test CORS headers + response = client.options( + "/v1/chat/completions", + headers={"Origin": "https://example.com", "Access-Control-Request-Method": "POST"}, + ) + + assert response.status_code in [200, 204] + + def test_app_middleware(self, mock_vector_store_config: Mock): + """Test that app has proper middleware configuration.""" + with patch("cairo_coder.server.app.create_agent_factory"): + app = create_app(mock_vector_store_config) + assert hasattr(app, "middleware_stack") + assert hasattr(app, "middleware") + + def test_app_routes(self, mock_vector_store_config: Mock): + """Test that app has expected routes.""" + with patch("cairo_coder.server.app.create_agent_factory"): + app = create_app(mock_vector_store_config) + routes = [route.path for route in app.routes] # type: ignore + assert "/" in routes + assert "/v1/agents" in routes + assert "/v1/chat/completions" in routes + + +class TestOpenAICompatibility: + """Test suite for OpenAI API compatibility.""" + + def test_openai_chat_completion_response_structure(self, client: TestClient): + """Test that response structure matches OpenAI API.""" + response = client.post( + "/v1/chat/completions", + json={"messages": [{"role": "user", "content": "Hello"}], "stream": False}, + ) + assert response.status_code == 200 + data = response.json() + + required_fields = ["id", "object", "created", "model", "choices", "usage"] + for field in required_fields: + assert field in data + + choice = data["choices"][0] + choice_fields = ["index", "message", "finish_reason"] + for field in choice_fields: + assert field in choice + + message = choice["message"] + message_fields = ["role", "content"] + for field in message_fields: + assert field in message + + usage = data["usage"] + usage_fields = ["prompt_tokens", "completion_tokens", "total_tokens"] + for field in usage_fields: + assert field in usage + + def test_openai_streaming_response_structure(self, client: TestClient): + """Test that streaming response structure matches OpenAI API.""" + response = client.post( + "/v1/chat/completions", + json={"messages": [{"role": "user", "content": "Hello"}], "stream": True}, + ) + assert response.status_code == 200 + + lines = response.text.strip().split("\n") + chunks = [] + for line in lines: + if line.startswith("data: "): + data_str = line[6:] + if data_str != "[DONE]": + chunks.append(json.loads(data_str)) + + for chunk in chunks: + required_fields = ["id", "object", "created", "model", "choices"] + for field in required_fields: + assert field in chunk + assert chunk["object"] == "chat.completion.chunk" + choice = chunk["choices"][0] + choice_fields = ["index", "delta", "finish_reason"] + for field in choice_fields: + assert field in choice + + def test_openai_error_response_structure(self, client: TestClient, mock_agent_factory: Mock): + """Test that error response structure matches OpenAI API.""" + mock_agent_factory.get_agent_info.side_effect = ValueError("Agent not found") + response = client.post( + "/v1/agents/invalid/chat/completions", + json={"messages": [{"role": "user", "content": "Hello"}]}, + ) + assert response.status_code == 404 + data = response.json() + assert "detail" in data + error = data["detail"]["error"] + error_fields = ["message", "type", "code"] + for field in error_fields: + assert field in error + assert error["type"] == "invalid_request_error" + assert error["code"] == "agent_not_found" + + +class TestMCPModeCompatibility: + """Test suite for MCP mode compatibility with TypeScript backend.""" + + def test_mcp_mode_non_streaming_response(self, client: TestClient): + """Test MCP mode returns sources in non-streaming response.""" + response = client.post( + "/v1/chat/completions", + json={"messages": [{"role": "user", "content": "Test"}], "stream": False}, + headers={"x-mcp-mode": "true"}, + ) + assert response.status_code == 200 + data = response.json() + assert "choices" in data + assert data["choices"][0]["message"]["content"] == "Cairo is a programming language" + + def test_mcp_mode_streaming_response(self, client: TestClient): + """Test MCP mode with streaming response.""" + response = client.post( + "/v1/chat/completions", + json={"messages": [{"role": "user", "content": "Test"}], "stream": True}, + headers={"x-mcp-mode": "true"}, + ) + assert response.status_code == 200 + assert "text/event-stream" in response.headers.get("content-type", "") + + lines = response.text.strip().split("\n") + chunks = [] + for line in lines: + if line.startswith("data: "): + data_str = line[6:] + if data_str != "[DONE]": + chunks.append(json.loads(data_str)) + + assert len(chunks) > 0 + content_found = any(chunk["choices"][0]["delta"].get("content") for chunk in chunks) + assert content_found + + def test_mcp_mode_header_variations(self, client: TestClient): + """Test different MCP mode header variations.""" + # Test x-mcp-mode header + response = client.post( + "/v1/chat/completions", + json={"messages": [{"role": "user", "content": "Test"}]}, + headers={"x-mcp-mode": "true"}, + ) + assert response.status_code == 200 + + # Test mcp header + response = client.post( + "/v1/chat/completions", + json={"messages": [{"role": "user", "content": "Test"}]}, + headers={"mcp": "true"}, + ) + assert response.status_code == 200 + + def test_mcp_mode_agent_specific_endpoint(self, client: TestClient): + """Test MCP mode with agent-specific endpoint.""" + response = client.post( + "/v1/agents/cairo-coder/chat/completions", + json={"messages": [{"role": "user", "content": "Cairo is a programming language"}]}, + headers={"x-mcp-mode": "true"}, + ) + assert response.status_code == 200 + data = response.json() + assert data["choices"][0]["message"]["content"] == "Cairo is a programming language" diff --git a/python/tests/unit/test_agent_factory.py b/python/tests/unit/test_agent_factory.py index 1db22571..f02cad1f 100644 --- a/python/tests/unit/test_agent_factory.py +++ b/python/tests/unit/test_agent_factory.py @@ -96,8 +96,7 @@ def test_create_agent_with_custom_sources(self, mock_vector_store_config): similarity_threshold=0.6, ) - @pytest.mark.asyncio - async def test_create_agent_by_id(self, mock_vector_store_config, mock_config_manager): + def test_create_agent_by_id(self, mock_vector_store_config, mock_config_manager): """Test creating agent by ID.""" query = "How do I create a contract?" history = [Message(role=Role.USER, content="Hello")] @@ -123,8 +122,7 @@ async def test_create_agent_by_id(self, mock_vector_store_config, mock_config_ma mock_config_manager.get_agent_config.assert_called_once_with(config, agent_id) mock_create.assert_called_once() - @pytest.mark.asyncio - async def test_create_agent_by_id_not_found( + def test_create_agent_by_id_not_found( self, mock_vector_store_config, mock_config_manager ): """Test creating agent by ID when agent not found.""" @@ -143,7 +141,6 @@ async def test_create_agent_by_id_not_found( config_manager=mock_config_manager, ) - @pytest.mark.asyncio def test_get_or_create_agent_cache_miss(self, agent_factory): """Test get_or_create_agent with cache miss.""" query = "Test query" @@ -174,8 +171,7 @@ def test_get_or_create_agent_cache_miss(self, agent_factory): assert cache_key in agent_factory._agent_cache assert agent_factory._agent_cache[cache_key] == mock_pipeline - @pytest.mark.asyncio - async def test_get_or_create_agent_cache_hit(self, agent_factory): + def test_get_or_create_agent_cache_hit(self, agent_factory): """Test get_or_create_agent with cache hit.""" query = "Test query" history = [] @@ -229,51 +225,36 @@ def test_get_agent_info_not_found(self, agent_factory): with pytest.raises(ValueError, match="Agent not found"): agent_factory.get_agent_info("nonexistent_agent") - def test_infer_sources_from_query_scarb(self): - """Test inferring sources from Scarb-related query.""" - query = "How do I configure Scarb for my project?" - - sources = AgentFactory._infer_sources_from_query(query) - - assert DocumentSource.SCARB_DOCS in sources - - def test_infer_sources_from_query_foundry(self): - """Test inferring sources from Foundry-related query.""" - query = "How do I use forge test command?" - - sources = AgentFactory._infer_sources_from_query(query) - - assert DocumentSource.STARKNET_FOUNDRY in sources - - def test_infer_sources_from_query_openzeppelin(self): - """Test inferring sources from OpenZeppelin-related query.""" - query = "How do I implement ERC20 token with OpenZeppelin?" - + @pytest.mark.parametrize( + "query, expected_sources", + [ + ("How do I configure Scarb for my project?", [DocumentSource.SCARB_DOCS]), + ("How do I use forge test command?", [DocumentSource.STARKNET_FOUNDRY]), + ( + "How do I implement ERC20 token with OpenZeppelin?", + [DocumentSource.OPENZEPPELIN_DOCS], + ), + ( + "How do I create a function?", + [DocumentSource.CAIRO_BOOK, DocumentSource.STARKNET_DOCS], + ), + ( + "How do I test Cairo contracts with Foundry and OpenZeppelin?", + [ + DocumentSource.STARKNET_FOUNDRY, + DocumentSource.OPENZEPPELIN_DOCS, + DocumentSource.CAIRO_BOOK, + ], + ), + ], + ) + def test_infer_sources_from_query(self, query, expected_sources): + """Test inferring sources from various queries.""" sources = AgentFactory._infer_sources_from_query(query) + for expected in expected_sources: + assert expected in sources - assert DocumentSource.OPENZEPPELIN_DOCS in sources - - def test_infer_sources_from_query_default(self): - """Test inferring sources from generic query.""" - query = "How do I create a function?" - - sources = AgentFactory._infer_sources_from_query(query) - - assert DocumentSource.CAIRO_BOOK in sources - assert DocumentSource.STARKNET_DOCS in sources - - def test_infer_sources_from_query_multiple(self): - """Test inferring sources from query with multiple relevant sources.""" - query = "How do I test Cairo contracts with Foundry and OpenZeppelin?" - - sources = AgentFactory._infer_sources_from_query(query) - - assert DocumentSource.STARKNET_FOUNDRY in sources - assert DocumentSource.OPENZEPPELIN_DOCS in sources - assert DocumentSource.CAIRO_BOOK in sources - - @pytest.mark.asyncio - async def test_create_pipeline_from_config_general(self, mock_vector_store_config): + def test_create_pipeline_from_config_general(self, mock_vector_store_config): """Test creating pipeline from general agent configuration.""" agent_config = AgentConfiguration( id="general_agent", @@ -309,8 +290,7 @@ async def test_create_pipeline_from_config_general(self, mock_vector_store_confi vector_db=None, ) - @pytest.mark.asyncio - async def test_create_pipeline_from_config_scarb(self, mock_vector_store_config): + def test_create_pipeline_from_config_scarb(self, mock_vector_store_config): """Test creating pipeline from Scarb agent configuration.""" agent_config = AgentConfiguration( id="scarb-assistant", diff --git a/python/tests/unit/test_config.py b/python/tests/unit/test_config.py index e883e0eb..4ab6fff5 100644 --- a/python/tests/unit/test_config.py +++ b/python/tests/unit/test_config.py @@ -1,12 +1,8 @@ """Tests for configuration management.""" -import os -import tempfile -from collections.abc import Generator from pathlib import Path import pytest -import toml from cairo_coder.config.manager import ConfigManager from cairo_coder.core.config import AgentConfiguration @@ -16,57 +12,14 @@ class TestConfigManager: """Test configuration manager functionality.""" - @pytest.fixture(autouse=True) - def mock_config_file(self) -> Generator[Path, None, None]: - """Create a sample config file for testing.""" - config_data = { - "VECTOR_DB": { - "POSTGRES_HOST": "db.example.com", - "POSTGRES_PORT": 5433, - "POSTGRES_DB": "test_db", - "POSTGRES_USER": "test_user", - "POSTGRES_PASSWORD": "test_password", - "POSTGRES_TABLE_NAME": "test_table", - "SIMILARITY_MEASURE": "cosine", - }, - "providers": { - "default": "anthropic", - "anthropic": { - "api_key": "test-key", - "model": "claude-3-opus", - }, - }, - "agents": { - # "cairo-coder": { - # "id": "cairo-coder", - # "name": "Cairo Coder", - # "description": "General Cairo programming assistant", - # "sources": [ - # DocumentSource.CAIRO_BOOK.value, - # "starknet-docs", - # "cairo-by-example", - # "corelib-docs", - # ], - # "contract_template": "You are helping write a Cairo smart contract. Consider: - Contract structure with #[contract] attribute - Storage variables and access patterns - External/view functions and their signatures - Event definitions and emissions - Error handling and custom errors - Interface implementations", - # }, - }, - } - - with tempfile.NamedTemporaryFile(mode="w", suffix=".toml", delete=False) as f: - toml.dump(config_data, f) - temp_path = Path(f.name) - - yield temp_path - - # Cleanup - os.unlink(temp_path) - def test_load_config_fails_if_no_config_file(self) -> None: """Test loading configuration with no config file.""" with pytest.raises(FileNotFoundError, match="Configuration file not found at"): ConfigManager.load_config(Path("nonexistent.toml")) - def test_load_toml_config(self, monkeypatch: pytest.MonkeyPatch) -> None: + def test_load_toml_config( + self, monkeypatch: pytest.MonkeyPatch, sample_config_file: Path + ) -> None: """Test loading configuration from TOML file.""" # Clear environment variables that might interfere monkeypatch.delenv("POSTGRES_HOST", raising=False) @@ -78,40 +31,14 @@ def test_load_toml_config(self, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.delenv("OPENAI_API_KEY", raising=False) monkeypatch.delenv("GOOGLE_API_KEY", raising=False) - config_data = { - "VECTOR_DB": { - "POSTGRES_HOST": "db.example.com", - "POSTGRES_PORT": 5433, - "POSTGRES_DB": "test_db", - "POSTGRES_USER": "test_user", - "POSTGRES_PASSWORD": "test_password", - "POSTGRES_TABLE_NAME": "test_table", - "SIMILARITY_MEASURE": "cosine", - }, - "providers": { - "default": "anthropic", - "anthropic": { - "api_key": "test-key", - "model": "claude-3-opus", - }, - }, - } - - with tempfile.NamedTemporaryFile(mode="w", suffix=".toml", delete=False) as f: - toml.dump(config_data, f) - temp_path = Path(f.name) - - try: - config = ConfigManager.load_config(temp_path) - - assert config.vector_store.host == "db.example.com" - assert config.vector_store.port == 5433 - assert config.vector_store.database == "test_db" - finally: - temp_path.unlink() + config = ConfigManager.load_config(sample_config_file) + + assert config.vector_store.host == "test-db.example.com" + assert config.vector_store.port == 5433 + assert config.vector_store.database == "test_cairo" def test_environment_override( - self, monkeypatch: pytest.MonkeyPatch, mock_config_file: Path + self, monkeypatch: pytest.MonkeyPatch, sample_config_file: Path ) -> None: """Test environment variable overrides.""" # Set environment variables @@ -124,7 +51,7 @@ def test_environment_override( monkeypatch.setenv("ANTHROPIC_API_KEY", "env-anthropic-key") monkeypatch.setenv("GEMINI_API_KEY", "env-gemini-key") - config = ConfigManager.load_config(mock_config_file) + config = ConfigManager.load_config(sample_config_file) # Check environment overrides assert config.vector_store.host == "env-host" @@ -133,9 +60,9 @@ def test_environment_override( assert config.vector_store.user == "env-user" assert config.vector_store.password == "env-pass" - def test_get_agent_config(self, mock_config_file: Path) -> None: + def test_get_agent_config(self, sample_config_file: Path) -> None: """Test retrieving agent configuration.""" - config = ConfigManager.load_config(mock_config_file) + config = ConfigManager.load_config(sample_config_file) # Get default agent agent = ConfigManager.get_agent_config(config, "cairo-coder") @@ -153,21 +80,21 @@ def test_get_agent_config(self, mock_config_file: Path) -> None: with pytest.raises(ValueError, match="Agent 'unknown' not found"): ConfigManager.get_agent_config(config, "unknown") - def test_validate_config(self, mock_config_file: Path) -> None: + def test_validate_config(self, sample_config_file: Path) -> None: """Test configuration validation.""" # Valid config - config = ConfigManager.load_config(mock_config_file) + config = ConfigManager.load_config(sample_config_file) config.vector_store.password = "test-pass" ConfigManager.validate_config(config) # No database password - config = ConfigManager.load_config(mock_config_file) + config = ConfigManager.load_config(sample_config_file) config.vector_store.password = "" with pytest.raises(ValueError, match="Database password is required"): ConfigManager.validate_config(config) # Agent without sources - config = ConfigManager.load_config(mock_config_file) + config = ConfigManager.load_config(sample_config_file) config.vector_store.password = "test-pass" config.agents["test"] = AgentConfiguration( id="test", name="Test", description="Test agent", sources=[] @@ -176,16 +103,16 @@ def test_validate_config(self, mock_config_file: Path) -> None: ConfigManager.validate_config(config) # Invalid default agent - config = ConfigManager.load_config(mock_config_file) + config = ConfigManager.load_config(sample_config_file) config.vector_store.password = "test-pass" config.default_agent_id = "unknown" config.agents = {} # No agents with pytest.raises(ValueError, match="Default agent 'unknown' not found"): ConfigManager.validate_config(config) - def test_dsn_property(self, mock_config_file: Path) -> None: + def test_dsn_property(self, sample_config_file: Path) -> None: """Test PostgreSQL DSN generation.""" - config = ConfigManager.load_config(mock_config_file) + config = ConfigManager.load_config(sample_config_file) config.vector_store.user = "testuser" config.vector_store.password = "testpass" config.vector_store.host = "testhost" diff --git a/python/tests/unit/test_document_retriever.py b/python/tests/unit/test_document_retriever.py index a4ac747f..bbbd0cd8 100644 --- a/python/tests/unit/test_document_retriever.py +++ b/python/tests/unit/test_document_retriever.py @@ -14,67 +14,22 @@ from cairo_coder.dspy.document_retriever import DocumentRetrieverProgram -@pytest.fixture(scope='function') -def mock_pgvector_rm(mock_dspy_examples: list[dspy.Example]): - """Patch the vector database for the document retriever.""" - with patch("cairo_coder.dspy.document_retriever.SourceFilteredPgVectorRM") as mock_pgvector_rm: - mock_instance = Mock() - mock_instance.aforward = AsyncMock(return_value=mock_dspy_examples) - mock_instance.forward = Mock(return_value=mock_dspy_examples) - mock_pgvector_rm.return_value = mock_instance - yield mock_pgvector_rm - - -@pytest.fixture(scope='session') -def mock_embedder(): - """Mock the embedder.""" - with patch("cairo_coder.dspy.document_retriever.dspy.Embedder") as mock_embedder: - mock_embedder.return_value = Mock() - yield mock_embedder - class TestDocumentRetrieverProgram: """Test suite for DocumentRetrieverProgram.""" - @pytest.fixture(scope='session') - def enhanced_sample_documents(self): - """Create enhanced sample documents for testing with additional metadata.""" - return [ - Document( - page_content="Cairo is a programming language for writing provable programs.", - metadata={"source": "cairo_book", "score": 0.9, "chapter": 1}, - ), - Document( - page_content="Starknet is a validity rollup (also known as a ZK rollup).", - metadata={"source": "starknet_docs", "score": 0.8, "section": "overview"}, - ), - Document( - page_content="OpenZeppelin provides secure smart contract libraries for Cairo.", - metadata={"source": "openzeppelin_docs", "score": 0.7}, - ), - ] - - @pytest.fixture(scope='session') - def sample_processed_query(self): - """Create a sample processed query.""" - return ProcessedQuery( - original="How do I create a Cairo contract?", - search_queries=["cairo", "contract", "create"], - reasoning="I need to create a Cairo contract", - is_contract_related=True, - is_test_related=False, - resources=[DocumentSource.CAIRO_BOOK, DocumentSource.STARKNET_DOCS], - ) - - @pytest.fixture(scope='function') - def retriever(self, mock_vector_store_config: VectorStoreConfig, mock_pgvector_rm: Mock) -> DocumentRetrieverProgram: + @pytest.fixture(scope="function") + def retriever( + self, mock_vector_store_config: VectorStoreConfig, mock_vector_db: Mock + ) -> DocumentRetrieverProgram: """Create a DocumentRetrieverProgram instance.""" return DocumentRetrieverProgram( vector_store_config=mock_vector_store_config, + vector_db=mock_vector_db, max_source_count=5, similarity_threshold=0.4, ) - @pytest.fixture(scope='session') + @pytest.fixture(scope="session") def mock_dspy_examples(self, sample_documents: list[Document]) -> list[dspy.Example]: """Create mock DSPy Example objects from sample documents.""" examples = [] @@ -89,48 +44,31 @@ def mock_dspy_examples(self, sample_documents: list[Document]) -> list[dspy.Exam async def test_basic_document_retrieval( self, retriever: DocumentRetrieverProgram, - mock_vector_store_config: VectorStoreConfig, mock_dspy_examples: list[dspy.Example], sample_processed_query: ProcessedQuery, - mock_pgvector_rm: Mock, - mock_embedder: Mock, ): """Test basic document retrieval using DSPy PgVectorRM.""" + retriever.vector_db.aforward.return_value = mock_dspy_examples - # Mock dspy module - mock_dspy = Mock() - - with patch("cairo_coder.dspy.document_retriever.dspy", mock_dspy): - # Execute retrieval - use async version since we're in async test - result = await retriever.aforward(sample_processed_query) - - # Verify results - assert len(result) != 0 - assert all(isinstance(doc, Document) for doc in result) - - # Verify SourceFilteredPgVectorRM was instantiated correctly - mock_pgvector_rm.assert_called_once_with( - db_url=mock_vector_store_config.dsn, - pg_table_name=mock_vector_store_config.table_name, - embedding_func=mock_embedder.return_value, - content_field="content", - fields=["id", "content", "metadata"], - k=5, # max_source_count - embedding_model='text-embedding-3-large', - include_similarity=True, - ) + # Execute retrieval - use async version since we're in async test + result = await retriever.aforward(sample_processed_query) - # Verify retriever was called with proper query - # Since we're using async, check aforward was called - assert mock_pgvector_rm().aforward.call_count == len(sample_processed_query.search_queries) - # Check it was called with each search query - for query in sample_processed_query.search_queries: - mock_pgvector_rm().aforward.assert_any_call(query=query, sources=sample_processed_query.resources) + # Verify results + assert len(result) != 0 + assert all(isinstance(doc, Document) for doc in result) + + # Verify retriever was called with proper query + assert retriever.vector_db.aforward.call_count == len( + sample_processed_query.search_queries + ) + # Check it was called with each search query + for query in sample_processed_query.search_queries: + retriever.vector_db.aforward.assert_any_call( + query=query, sources=sample_processed_query.resources + ) @pytest.mark.asyncio - async def test_retrieval_with_empty_transformed_terms( - self, retriever: DocumentRetrieverProgram, mock_pgvector_rm: Mock - ): + async def test_retrieval_with_empty_transformed_terms(self, retriever: DocumentRetrieverProgram): """Test retrieval when transformed terms list is empty.""" query = ProcessedQuery( original="Simple query", @@ -141,66 +79,43 @@ async def test_retrieval_with_empty_transformed_terms( resources=[DocumentSource.CAIRO_BOOK], ) - # Mock dspy module - mock_dspy = Mock() - mock_settings = Mock() - mock_settings.configure = Mock() - mock_dspy.settings = mock_settings - - with patch("cairo_coder.dspy.document_retriever.dspy", mock_dspy): - result = await retriever.aforward(query) + result = await retriever.aforward(query) - # Should still work with empty transformed terms - assert len(result) != 0 + # Should still work with empty transformed terms + assert len(result) != 0 - # Query should just be the reasoning with empty tags - expected_query = "Simple reasoning" - mock_pgvector_rm().aforward.assert_called_once_with(query=expected_query, sources=query.resources) + # Query should just be the reasoning with empty tags + expected_query = query.reasoning + retriever.vector_db.aforward.assert_called_with( + query=expected_query, sources=query.resources + ) @pytest.mark.asyncio - async def test_retrieval_with_custom_sources( - self, retriever, sample_processed_query, mock_pgvector_rm: Mock - ): + async def test_retrieval_with_custom_sources(self, retriever, sample_processed_query): """Test retrieval with custom source filtering.""" # Override sources custom_sources = [DocumentSource.SCARB_DOCS, DocumentSource.OPENZEPPELIN_DOCS] - # Mock dspy module - mock_dspy = Mock() - mock_settings = Mock() - mock_settings.configure = Mock() - mock_dspy.settings = mock_settings + result = await retriever.aforward(sample_processed_query, sources=custom_sources) - with patch("cairo_coder.dspy.document_retriever.dspy", mock_dspy): - result = await retriever.aforward(sample_processed_query, sources=custom_sources) + # Verify result + assert len(result) != 0 - # Verify result - assert len(result) != 0 - - # Note: sources filtering is not currently implemented in PgVectorRM call - # This test ensures the method still works when sources are provided - mock_pgvector_rm().aforward.assert_called() + # Note: sources filtering is not currently implemented in PgVectorRM call + # This test ensures the method still works when sources are provided + retriever.vector_db.aforward.assert_called() @pytest.mark.asyncio - async def test_empty_document_handling(self, retriever, sample_processed_query, mock_pgvector_rm: Mock): + async def test_empty_document_handling(self, retriever, sample_processed_query): """Test handling of empty document results.""" retriever.vector_db.aforward = AsyncMock(return_value=[]) - # Mock dspy module - mock_dspy = Mock() - mock_settings = Mock() - mock_settings.configure = Mock() - mock_dspy.settings = mock_settings - - with patch("cairo_coder.dspy.document_retriever.dspy", mock_dspy): - result = await retriever.aforward(sample_processed_query) + result = await retriever.aforward(sample_processed_query) - assert result == [] + assert result == [] @pytest.mark.asyncio - async def test_pgvector_rm_error_handling( - self, retriever, sample_processed_query - ): + async def test_pgvector_rm_error_handling(self, retriever, sample_processed_query): """Test handling of PgVectorRM instantiation errors.""" # Mock PgVectorRM to raise an exception retriever.vector_db.aforward.side_effect = Exception("Database connection error") @@ -211,24 +126,15 @@ async def test_pgvector_rm_error_handling( assert "Database connection error" in str(exc_info.value) @pytest.mark.asyncio - async def test_retriever_call_error_handling( - self, retriever, sample_processed_query, mock_pgvector_rm: Mock - ): + async def test_retriever_call_error_handling(self, retriever, sample_processed_query): """Test handling of retriever call errors.""" retriever.vector_db.aforward.side_effect = Exception("Query execution error") - # Mock dspy module - mock_dspy = Mock() - mock_settings = Mock() - mock_settings.configure = Mock() - mock_dspy.settings = mock_settings - - with patch("cairo_coder.dspy.document_retriever.dspy", mock_dspy): - with pytest.raises(Exception) as exc_info: - await retriever.aforward(sample_processed_query) + with pytest.raises(Exception) as exc_info: + await retriever.aforward(sample_processed_query) - assert "Query execution error" in str(exc_info.value) + assert "Query execution error" in str(exc_info.value) @pytest.mark.asyncio async def test_max_source_count_configuration( @@ -250,7 +156,7 @@ async def test_document_conversion( self, retriever: DocumentRetrieverProgram, sample_processed_query: ProcessedQuery, - mock_pgvector_rm: Mock + mock_vector_db: Mock, ): """Test conversion from DSPy Examples to Document objects.""" @@ -269,238 +175,76 @@ async def test_document_conversion( retriever.vector_db.aforward = AsyncMock(return_value=mock_examples) - # Mock dspy module - mock_dspy = Mock() - mock_settings = Mock() - mock_settings.configure = Mock() - mock_dspy.settings = mock_settings - - with patch("cairo_coder.dspy.document_retriever.dspy", mock_dspy): - result = await retriever.aforward(sample_processed_query) + result = await retriever.aforward(sample_processed_query) - # Verify conversion to Document objects - # Ran 3 times the query, returned 2 docs each - but de-duped - mock_pgvector_rm().aforward.assert_has_calls( - [call(query=query, sources=sample_processed_query.resources) for query in sample_processed_query.search_queries], - any_order=True, - ) - - # Verify conversion to Document objects - assert len(result) == len(expected_docs) + 1 # (Contract template) - - # Convert result to (content, metadata) tuples for comparison - result_tuples = [(doc.page_content, doc.metadata) for doc in result] - - # Check that all expected documents are present (order doesn't matter) - for expected_content, expected_metadata in expected_docs: - assert (expected_content, expected_metadata) in result_tuples - - @pytest.mark.asyncio - async def test_contract_context_enhancement( - self, retriever, mock_vector_store_config, mock_dspy_examples - ): - """Test context enhancement for contract-related queries.""" - # Create a contract-related query - query = ProcessedQuery( - original="How do I create a contract with storage?", - search_queries=["contract", "storage"], - reasoning="I need to create a contract with storage", - is_contract_related=True, - is_test_related=False, - resources=[DocumentSource.CAIRO_BOOK], + # Verify conversion to Document objects + # Ran 3 times the query, returned 2 docs each - but de-duped + mock_vector_db.aforward.assert_has_calls( + [ + call(query=query, sources=sample_processed_query.resources) + for query in sample_processed_query.search_queries + ], + any_order=True, ) - # Mock Embedder - with patch("cairo_coder.dspy.document_retriever.dspy.Embedder") as mock_embedder: - mock_embedder.return_value = Mock() - - with patch( - "cairo_coder.dspy.document_retriever.SourceFilteredPgVectorRM" - ) as mock_pgvector_rm: - mock_retriever_instance = Mock(return_value=mock_dspy_examples) - mock_retriever_instance.forward = Mock(return_value=mock_dspy_examples) - mock_retriever_instance.aforward = AsyncMock(return_value=mock_dspy_examples) - mock_pgvector_rm.return_value = mock_retriever_instance - - # Mock dspy module - mock_dspy = Mock() - mock_settings = Mock() - mock_settings.configure = Mock() - mock_dspy.settings = mock_settings - - with patch("cairo_coder.dspy.document_retriever.dspy", mock_dspy): - result = await retriever.aforward(query) - - # Verify contract template was added to context - contract_template_found = False - for doc in result: - if doc.metadata.get("source") == "contract_template": - contract_template_found = True - # Verify it contains the contract template content - assert "The content inside the tag" in doc.page_content - assert "#[starknet::contract]" in doc.page_content - assert "#[storage]" in doc.page_content - break - - assert contract_template_found, ( - "Contract template should be added for contract-related queries" - ) + # Verify conversion to Document objects + assert len(result) == len(expected_docs) + 1 # (Contract template) - @pytest.mark.asyncio - async def test_test_context_enhancement( - self, retriever, mock_vector_store_config, mock_dspy_examples - ): - """Test context enhancement for test-related queries.""" - # Create a test-related query - query = ProcessedQuery( - original="How do I write tests for Cairo contracts?", - search_queries=["test", "cairo"], - reasoning="I need to write tests for a Cairo contract", - is_contract_related=False, - is_test_related=True, - resources=[DocumentSource.CAIRO_BOOK], - ) + # Convert result to (content, metadata) tuples for comparison + result_tuples = [(doc.page_content, doc.metadata) for doc in result] - # Mock Embedder - with patch("cairo_coder.dspy.document_retriever.dspy.Embedder") as mock_embedder: - mock_embedder.return_value = Mock() - - with patch( - "cairo_coder.dspy.document_retriever.SourceFilteredPgVectorRM" - ) as mock_pgvector_rm: - mock_retriever_instance = Mock(return_value=mock_dspy_examples) - mock_retriever_instance.forward = Mock(return_value=mock_dspy_examples) - mock_retriever_instance.aforward = AsyncMock(return_value=mock_dspy_examples) - mock_pgvector_rm.return_value = mock_retriever_instance - - # Mock dspy module - mock_dspy = Mock() - mock_settings = Mock() - mock_settings.configure = Mock() - mock_dspy.settings = mock_settings - - with patch("cairo_coder.dspy.document_retriever.dspy", mock_dspy): - result = await retriever.aforward(query) - - # Verify test template was added to context - test_template_found = False - for doc in result: - if doc.metadata.get("source") == "test_template": - test_template_found = True - # Verify it contains the test template content - assert ( - "The content inside the tag is the test code for the 'Registry' contract. It is assumed" - in doc.page_content - ) - assert ( - "that the contract is part of a package named 'registry'. When writing tests, follow the important rules." - in doc.page_content - ) - assert "#[test]" in doc.page_content - assert "assert(" in doc.page_content - break - - assert test_template_found, ( - "Test template should be added for test-related queries" - ) + # Check that all expected documents are present (order doesn't matter) + for expected_content, expected_metadata in expected_docs: + assert (expected_content, expected_metadata) in result_tuples + @pytest.mark.parametrize( + "query_str, query_details, expected_templates", + [ + ( + "Some query", + {"is_contract_related": True, "is_test_related": False}, + ["contract_template"], + ), + ( + "Some query", + {"is_contract_related": False, "is_test_related": True}, + ["test_template"], + ), + ( + "Some query", + {"is_contract_related": True, "is_test_related": True}, + ["contract_template", "test_template"], + ), + ( + "Some other query", + {"is_contract_related": False, "is_test_related": False}, + [], + ), + ("Query with contract and test in string", {"is_contract_related": False, "is_test_related": False}, ["contract_template", "test_template"]), + ], + ) @pytest.mark.asyncio - async def test_both_templates_enhancement( - self, retriever, mock_vector_store_config, mock_dspy_examples + async def test_context_enhancement( + self, retriever, mock_vector_db, mock_dspy_examples, query_str, query_details, expected_templates ): - """Test context enhancement when query relates to both contracts and tests.""" - # Create a query that mentions both contracts and tests + """Test context enhancement for contract-related and test-related queries.""" query = ProcessedQuery( - original="How do I create a contract and write tests for it?", - search_queries=["contract", "test"], - reasoning="I need to create a contract and write tests for it", - is_contract_related=True, - is_test_related=True, + original=query_str, + search_queries=["None"], + reasoning="Some reasoning", resources=[DocumentSource.CAIRO_BOOK], + **query_details, ) + mock_vector_db.aforward.return_value = mock_dspy_examples - # Mock Embedder - with patch("cairo_coder.dspy.document_retriever.dspy.Embedder") as mock_embedder: - mock_embedder.return_value = Mock() - - with patch( - "cairo_coder.dspy.document_retriever.SourceFilteredPgVectorRM" - ) as mock_pgvector_rm: - mock_retriever_instance = Mock(return_value=mock_dspy_examples) - mock_retriever_instance.forward = Mock(return_value=mock_dspy_examples) - mock_retriever_instance.aforward = AsyncMock(return_value=mock_dspy_examples) - mock_pgvector_rm.return_value = mock_retriever_instance - - # Mock dspy module - mock_dspy = Mock() - mock_settings = Mock() - mock_settings.configure = Mock() - mock_dspy.settings = mock_settings - - with patch("cairo_coder.dspy.document_retriever.dspy", mock_dspy): - result = await retriever.aforward(query) - - # Verify both templates were added - contract_template_found = False - test_template_found = False - - for doc in result: - if doc.metadata.get("source") == "contract_template": - contract_template_found = True - elif doc.metadata.get("source") == "test_template": - test_template_found = True - - assert contract_template_found, ( - "Contract template should be added for contract-related queries" - ) - assert test_template_found, ( - "Test template should be added for test-related queries" - ) - - @pytest.mark.asyncio - async def test_no_template_enhancement( - self, retriever, mock_vector_store_config, mock_dspy_examples - ): - """Test that no templates are added for unrelated queries.""" - # Create a query that's not related to contracts or tests - query = ProcessedQuery( - original="What is Cairo programming language?", - search_queries=["cairo", "programming"], - reasoning="I need to know what Cairo is", - is_contract_related=False, - is_test_related=False, - resources=[DocumentSource.CAIRO_BOOK], - ) + result = await retriever.aforward(query) - # Mock Embedder - with patch("cairo_coder.dspy.document_retriever.dspy.Embedder") as mock_embedder: - mock_embedder.return_value = Mock() - - with patch( - "cairo_coder.dspy.document_retriever.SourceFilteredPgVectorRM" - ) as mock_pgvector_rm: - mock_retriever_instance = Mock(return_value=mock_dspy_examples) - mock_retriever_instance.forward = Mock(return_value=mock_dspy_examples) - mock_retriever_instance.aforward = AsyncMock(return_value=mock_dspy_examples) - mock_pgvector_rm.return_value = mock_retriever_instance - - # Mock dspy module - mock_dspy = Mock() - mock_settings = Mock() - mock_settings.configure = Mock() - mock_dspy.settings = mock_settings - - with patch("cairo_coder.dspy.document_retriever.dspy", mock_dspy): - result = await retriever.aforward(query) - - # Verify no templates were added - template_sources = [doc.metadata.get("source") for doc in result] - assert "contract_template" not in template_sources, ( - "Contract template should not be added for non-contract queries" - ) - assert "test_template" not in template_sources, ( - "Test template should not be added for non-test queries" - ) + found_templates = { + doc.metadata.get("source") + for doc in result + if "template" in doc.metadata.get("source", "") + } + assert set(expected_templates) == found_templates class TestDocumentRetrieverFactory: diff --git a/python/tests/unit/test_generation_program.py b/python/tests/unit/test_generation_program.py index c7f2d08a..43784c26 100644 --- a/python/tests/unit/test_generation_program.py +++ b/python/tests/unit/test_generation_program.py @@ -5,7 +5,7 @@ Scarb configuration, and MCP mode document formatting. """ -from unittest.mock import AsyncMock, Mock, patch +from unittest.mock import AsyncMock, patch import dspy import pytest @@ -22,27 +22,6 @@ ) -@pytest.fixture(scope="function") -def mock_lm(): - """Configure DSPy with a mock language model for testing.""" - mock = Mock() - # Mock for sync calls - mock.forward.return_value = dspy.Prediction( - answer="Here's a Cairo contract example:\n\n```cairo\n#[starknet::contract]\nmod SimpleContract {\n // Contract implementation\n}\n```\n\nThis contract demonstrates basic Cairo syntax." - ) - mock.return_value = dspy.Prediction( - answer="Here's a Cairo contract example:\n\n```cairo\n#[starknet::contract]\nmod SimpleContract {\n // Contract implementation\n}\n```\n\nThis contract demonstrates basic Cairo syntax." - ) - # Mock for async calls - use AsyncMock for coroutine - mock.aforward = AsyncMock(return_value=dspy.Prediction( - answer="Here's a Cairo contract example:\n\n```cairo\n#[starknet::contract]\nmod SimpleContract {\n // Contract implementation\n}\n```\n\nThis contract demonstrates basic Cairo syntax." - )) - - with patch("dspy.ChainOfThought") as mock_cot: - mock_cot.return_value = mock - yield mock - - async def call_program(program, method, *args, **kwargs): """Helper to call sync or async method on a program.""" if method == "aforward": @@ -68,30 +47,6 @@ def mcp_generation_program(self): """Create an MCP GenerationProgram instance.""" return McpGenerationProgram() - @pytest.fixture - def sample_documents(self): - """Create sample documents for testing.""" - return [ - Document( - page_content="Cairo contracts are defined using #[starknet::contract] attribute.", - metadata={ - "source": "cairo_book", - "title": "Cairo Contracts", - "url": "https://book.cairo-lang.org/contracts", - "source_display": "Cairo Book", - }, - ), - Document( - page_content="Storage variables are defined with #[storage] attribute.", - metadata={ - "source": "starknet_docs", - "title": "Storage Variables", - "url": "https://docs.starknet.io/storage", - "source_display": "Starknet Documentation", - }, - ), - ] - @pytest.mark.parametrize("call_method", ["forward", "aforward"]) @pytest.mark.asyncio async def test_general_code_generation(self, generation_program, call_method): @@ -205,30 +160,6 @@ def mcp_program(self): """Create an MCP GenerationProgram instance.""" return McpGenerationProgram() - @pytest.fixture - def sample_documents(self): - """Create sample documents for testing.""" - return [ - Document( - page_content="Cairo contracts are defined using #[starknet::contract] attribute.", - metadata={ - "source": "cairo_book", - "title": "Cairo Contracts", - "url": "https://book.cairo-lang.org/contracts", - "source_display": "Cairo Book", - }, - ), - Document( - page_content="Storage variables are defined with #[storage] attribute.", - metadata={ - "source": "starknet_docs", - "title": "Storage Variables", - "url": "https://docs.starknet.io/storage", - "source_display": "Starknet Documentation", - }, - ), - ] - def test_mcp_document_formatting(self, mcp_program, sample_documents): """Test MCP mode document formatting.""" answer = mcp_program.forward(sample_documents).answer @@ -236,17 +167,20 @@ def test_mcp_document_formatting(self, mcp_program, sample_documents): assert isinstance(answer, str) assert len(answer) > 0 - # Verify document structure - assert "## 1. Cairo Contracts" in answer - assert "## 2. Storage Variables" in answer - assert "**Source:** Cairo Book" in answer - assert "**Source:** Starknet Documentation" in answer - assert "**URL:** https://book.cairo-lang.org/contracts" in answer - assert "**URL:** https://docs.starknet.io/storage" in answer - - # Verify content is included - assert "starknet::contract" in answer - assert "#[storage]" in answer + # Verify document structure is present + for i, doc in enumerate(sample_documents, 1): + assert f"## {i}." in answer + + # Check source display + source_display = doc.metadata.get("source_display", "Unknown Source") + assert f"**Source:** {source_display}" in answer + + # Check URL + url = doc.metadata.get("url", "#") + assert f"**URL:** {url}" in answer + + # Check content is included + assert doc.page_content in answer def test_mcp_empty_documents(self, mcp_program): """Test MCP mode with empty documents.""" diff --git a/python/tests/unit/test_openai_server.py b/python/tests/unit/test_openai_server.py index 3c117830..9e515c3d 100644 --- a/python/tests/unit/test_openai_server.py +++ b/python/tests/unit/test_openai_server.py @@ -9,10 +9,9 @@ import uuid from unittest.mock import Mock, patch -import pytest from fastapi import FastAPI +from fastapi.testclient import TestClient -from cairo_coder.core.agent_factory import AgentFactory from cairo_coder.core.types import StreamEvent, StreamEventType from cairo_coder.server.app import create_app @@ -20,42 +19,21 @@ class TestCairoCoderServer: """Test suite for CairoCoderServer class.""" - @pytest.fixture - def mock_agent_factory(self, mock_agent): - """Patch create_agent_factory and return the mock factory.""" - with patch("cairo_coder.server.app.create_agent_factory") as mock_factory_creator: - factory = Mock(spec=AgentFactory) - factory.get_available_agents.return_value = ["cairo-coder"] - factory.get_agent_info.return_value = { - "id": "cairo-coder", - "name": "Cairo Coder", - "description": "Cairo programming assistant", - "sources": ["cairo-docs"], - } - factory.get_or_create_agent.return_value = mock_agent - factory.create_agent.return_value = mock_agent - factory.get_or_create_agent = Mock(return_value=mock_agent) - mock_factory_creator.return_value = factory - - yield factory - def test_health_check(self, client): """Test health check endpoint.""" response = client.get("/") assert response.status_code == 200 assert response.json() == {"status": "ok"} - def test_list_agents(self, client): + def test_list_agents(self, client, sample_agent_configs): """Test listing available agents.""" response = client.get("/v1/agents") assert response.status_code == 200 data = response.json() - assert len(data) == 1 - assert data[0]["id"] == "cairo-coder" - assert data[0]["name"] == "Cairo Coder" - assert data[0]["description"] == "Cairo programming assistant" - assert data[0]["sources"] == ["cairo-docs"] + assert len(data) == len(sample_agent_configs) + agent_ids = {agent["id"] for agent in data} + assert "cairo-coder" in agent_ids def test_list_agents_error_handling(self, client, mock_agent_factory): """Test error handling in list agents endpoint.""" @@ -337,67 +315,48 @@ def test_create_app_with_defaults(self, mock_vector_store_config): assert isinstance(app, FastAPI) + def test_cors_configuration(self, mock_vector_store_config): + """Test CORS configuration.""" + with patch("cairo_coder.server.app.create_agent_factory"): + app = create_app(mock_vector_store_config) + client = TestClient(app) -class TestTokenTracker: - """Test suite for TokenTracker class.""" - - def test_track_tokens_new_session(self): - """Test tracking tokens for a new session.""" - from cairo_coder.server.app import TokenTracker - - tracker = TokenTracker() - tracker.track_tokens("session1", 10, 20) - - usage = tracker.get_session_usage("session1") - assert usage["prompt_tokens"] == 10 - assert usage["completion_tokens"] == 20 - assert usage["total_tokens"] == 30 + # Test CORS headers + response = client.options( + "/v1/chat/completions", + headers={"Origin": "https://example.com", "Access-Control-Request-Method": "POST"}, + ) - def test_track_tokens_existing_session(self): - """Test tracking tokens for an existing session.""" - from cairo_coder.server.app import TokenTracker + assert response.status_code in [200, 204] - tracker = TokenTracker() - tracker.track_tokens("session1", 10, 20) - tracker.track_tokens("session1", 5, 15) + def test_app_middleware(self, mock_vector_store_config): + """Test that app has proper middleware configuration.""" + with patch("cairo_coder.server.app.create_agent_factory"): + app = create_app(mock_vector_store_config) - usage = tracker.get_session_usage("session1") - assert usage["prompt_tokens"] == 15 - assert usage["completion_tokens"] == 35 - assert usage["total_tokens"] == 50 + # Check that middleware is properly configured + # FastAPI apps have middleware, but middleware_stack might be None until build + assert hasattr(app, "middleware_stack") + # Check that CORS middleware was added by verifying the middleware property exists + assert hasattr(app, "middleware") - def test_get_session_usage_nonexistent(self): - """Test getting usage for non-existent session.""" - from cairo_coder.server.app import TokenTracker + def test_app_routes(self, mock_vector_store_config): + """Test that app has expected routes.""" + with patch("cairo_coder.server.app.create_agent_factory"): + app = create_app(mock_vector_store_config) - tracker = TokenTracker() - usage = tracker.get_session_usage("nonexistent") + # Get all routes + routes = [route.path for route in app.routes] # type: ignore - assert usage["prompt_tokens"] == 0 - assert usage["completion_tokens"] == 0 - assert usage["total_tokens"] == 0 + # Check expected routes exist + assert "/" in routes + assert "/v1/agents" in routes + assert "/v1/chat/completions" in routes class TestOpenAICompatibility: """Test suite for OpenAI API compatibility.""" - @pytest.fixture - def mock_agent_factory(self, mock_agent): - """Patch create_agent_factory and return the mock factory.""" - with patch("cairo_coder.server.app.create_agent_factory") as mock_factory_creator: - factory = Mock(spec=AgentFactory) - factory.get_available_agents.return_value = ["cairo-coder"] - factory.get_agent_info.return_value = { - "id": "cairo-coder", - "name": "Cairo Coder", - "description": "Cairo programming assistant", - "sources": ["cairo-docs"], - } - factory.create_agent.return_value = mock_agent - factory.get_or_create_agent = Mock(return_value=mock_agent) - mock_factory_creator.return_value = factory - yield factory - def test_openai_chat_completion_response_structure(self, client): """Test that response structure matches OpenAI API.""" response = client.post( @@ -490,26 +449,6 @@ def test_openai_error_response_structure(self, client, mock_agent_factory): class TestMCPModeCompatibility: """Test suite for MCP mode compatibility with TypeScript backend.""" - @pytest.fixture - def mock_agent_factory(self, mock_agent): - """Setup mocks for MCP mode tests.""" - with patch("cairo_coder.server.app.create_agent_factory") as mock_factory_creator: - factory = Mock(spec=AgentFactory) - factory.get_available_agents = Mock(return_value=["cairo-coder"]) - factory.get_agent_info = Mock( - return_value={ - "id": "cairo-coder", - "name": "Cairo Coder", - "description": "Cairo programming assistant", - "sources": ["cairo-docs"], - } - ) - factory.create_agent.return_value = mock_agent - factory.get_or_create_agent = Mock(return_value=mock_agent) - mock_factory_creator.return_value = factory - yield factory - - def test_mcp_mode_non_streaming_response(self, client): """Test MCP mode returns sources in non-streaming response.""" response = client.post( diff --git a/python/tests/unit/test_query_processor.py b/python/tests/unit/test_query_processor.py index 44ded750..366c903d 100644 --- a/python/tests/unit/test_query_processor.py +++ b/python/tests/unit/test_query_processor.py @@ -5,7 +5,7 @@ resource identification, and query categorization. """ -from unittest.mock import AsyncMock, Mock, patch +from unittest.mock import AsyncMock, patch import dspy import pytest @@ -17,32 +17,21 @@ class TestQueryProcessorProgram: """Test suite for QueryProcessorProgram.""" - @pytest.fixture - def mock_lm(self): - """Configure DSPy with a mock language model for testing.""" - mock = Mock() - mock.forward.return_value = dspy.Prediction( - search_queries=["cairo, contract, storage, variable"], - resources=["cairo_book", "starknet_docs"], - reasoning="I need to create a Cairo contract", - ) - mock.aforward = AsyncMock(return_value=dspy.Prediction( - search_queries=["cairo, contract, storage, variable"], - resources=["cairo_book", "starknet_docs"], - reasoning="I need to create a Cairo contract", - )) - - with patch("dspy.ChainOfThought") as mock_cot: - mock_cot.return_value = mock - yield mock - - @pytest.fixture + @pytest.fixture(scope="function") def processor(self, mock_lm): """Create a QueryProcessorProgram instance with mocked LM.""" return QueryProcessorProgram() - def test_contract_query_processing(self, processor): + def test_contract_query_processing(self, mock_lm, processor): """Test processing of contract-related queries.""" + prediction = dspy.Prediction( + search_queries=["cairo, contract, storage, variable"], + resources=["cairo_book", "starknet_docs"], + reasoning="I need to create a Cairo contract", + ) + mock_lm.forward.return_value = prediction + mock_lm.aforward.return_value = prediction + query = "How do I define storage variables in a Cairo contract?" result = processor.forward(query) @@ -80,36 +69,31 @@ def test_resource_validation(self, processor: QueryProcessorProgram): assert DocumentSource.STARKNET_DOCS in validated assert len(validated) == 2 - def test_test_detection(self, processor): + @pytest.mark.parametrize( + "query, expected", + [ + ("How do I write tests for Cairo?", True), + ("Unit testing best practices", True), + ("How to assert in Cairo tests?", True), + ("Mock setup for integration tests", True), + ("Test fixture configuration", True), + ("How to create a contract?", False), + ("What are Cairo data types?", False), + ("StarkNet deployment guide", False), + ], + ) + def test_test_detection(self, processor, query, expected): """Test detection of test-related queries.""" - test_queries = [ - "How do I write tests for Cairo?", - "Unit testing best practices", - "How to assert in Cairo tests?", - "Mock setup for integration tests", - "Test fixture configuration", - ] - - for query in test_queries: - assert processor._is_test_query(query) is True - - non_test_queries = [ - "How to create a contract?", - "What are Cairo data types?", - "StarkNet deployment guide", - ] - - for query in non_test_queries: - assert processor._is_test_query(query) is False + assert processor._is_test_query(query) is expected def test_empty_query_handling(self, processor): """Test handling of empty or whitespace queries.""" with patch.object(processor, "retrieval_program") as mock_program: - mock_program.aforward = AsyncMock(return_value=dspy.Prediction( - search_queries=[], - resources=[], - reasoning="Empty query" - )) + mock_program.aforward = AsyncMock( + return_value=dspy.Prediction( + search_queries=[], resources=[], reasoning="Empty query" + ) + ) result = processor.forward("") diff --git a/python/tests/unit/test_rag_pipeline.py b/python/tests/unit/test_rag_pipeline.py index 58e3d34c..0b49f0ad 100644 --- a/python/tests/unit/test_rag_pipeline.py +++ b/python/tests/unit/test_rag_pipeline.py @@ -34,24 +34,6 @@ def merge_usage_dict(sources: list[dict]) -> dict: merged_usage[model_name][metric_name] = merged_usage[model_name].get(metric_name, 0) + value return merged_usage -@pytest.fixture(scope='function') -def mock_pgvector_rm(): - """Patch the vector database for the document retriever.""" - with patch("cairo_coder.dspy.document_retriever.SourceFilteredPgVectorRM") as mock_pgvector_rm: - mock_instance = Mock() - mock_instance.aforward = AsyncMock(return_value=[]) - mock_instance.forward = Mock(return_value=[]) - mock_pgvector_rm.return_value = mock_instance - yield mock_pgvector_rm - - -@pytest.fixture(scope='session') -def mock_embedder(): - """Mock the embedder.""" - with patch("cairo_coder.dspy.document_retriever.dspy.Embedder") as mock_embedder: - mock_embedder.return_value = Mock() - yield mock_embedder - class TestRagPipeline: """Test suite for RagPipeline.""" @@ -582,9 +564,11 @@ def test_create_pipeline_with_custom_components(self, mock_vector_store_config): assert pipeline.config.contract_template == "Custom contract template" assert pipeline.config.test_template == "Custom test template" - def test_create_scarb_pipeline(self, mock_vector_store_config, mock_pgvector_rm: Mock): + def test_create_scarb_pipeline(self, mock_vector_store_config): """Test creating Scarb-specific pipeline.""" - with patch("cairo_coder.dspy.create_generation_program") as mock_create_gp: + with patch("cairo_coder.dspy.create_generation_program") as mock_create_gp, patch( + "cairo_coder.dspy.document_retriever.SourceFilteredPgVectorRM" + ): mock_scarb_program = Mock() mock_create_gp.return_value = mock_scarb_program diff --git a/python/tests/unit/test_server.py b/python/tests/unit/test_server.py deleted file mode 100644 index 8d849d87..00000000 --- a/python/tests/unit/test_server.py +++ /dev/null @@ -1,318 +0,0 @@ -""" -Unit tests for FastAPI server. - -Tests the FastAPI application endpoints and server functionality. -This test file is for the OpenAI-compatible server implementation. -""" - -from unittest.mock import Mock, patch - -import pytest -from fastapi.testclient import TestClient - -from cairo_coder.config.manager import ConfigManager -from cairo_coder.core.agent_factory import AgentFactory -from cairo_coder.server.app import CairoCoderServer, TokenTracker - - -class TestCairoCoderServer: - """Test suite for CairoCoderServer.""" - - @pytest.fixture - def mock_agent_factory(self, mock_agent): - """Patch create_agent_factory and return the mock factory.""" - with patch("cairo_coder.server.app.create_agent_factory") as mock_create_factory: - factory = Mock(spec=AgentFactory) - factory.get_available_agents.return_value = ["default"] - factory.get_agent_info.return_value = { - "id": "default", - "name": "Default Agent", - "description": "Default Cairo assistant", - "sources": ["cairo_book"], - } - factory.get_or_create_agent.return_value = mock_agent - factory.create_agent.return_value = mock_agent - mock_create_factory.return_value = factory - yield factory - - def test_health_check(self, client): - """Test health check endpoint.""" - response = client.get("/") - - assert response.status_code == 200 - data = response.json() - assert data["status"] == "ok" - - def test_list_agents(self, client): - """Test list agents endpoint.""" - response = client.get("/v1/agents") - - assert response.status_code == 200 - data = response.json() - assert isinstance(data, list) - assert len(data) >= 1 - - def test_chat_completions_basic(self, client): - """Test basic chat completions endpoint.""" - response = client.post( - "/v1/chat/completions", - json={"messages": [{"role": "user", "content": "Hello"}], "stream": False}, - ) - - assert response.status_code == 200 - data = response.json() - assert "choices" in data - assert "usage" in data - assert data["model"] == "cairo-coder" - - def test_chat_completions_validation(self, client): - """Test chat completions validation.""" - # Test empty messages - response = client.post("/v1/chat/completions", json={"messages": []}) - assert response.status_code == 422 - - # Test last message not from user - response = client.post( - "/v1/chat/completions", - json={ - "messages": [ - {"role": "user", "content": "Hello"}, - {"role": "assistant", "content": "Hi"}, - ] - }, - ) - assert response.status_code == 422 - - def test_agent_specific_completions(self, client, mock_agent_factory, mock_agent): - """Test agent-specific chat completions.""" - mock_agent_factory.get_agent_info.return_value = { - "id": "default", - "name": "Default Agent", - "description": "Default Cairo assistant", - "sources": ["cairo_book"], - } - mock_agent_factory.get_or_create_agent = Mock(return_value=mock_agent) - - response = client.post( - "/v1/agents/default/chat/completions", - json={"messages": [{"role": "user", "content": "Hello"}], "stream": False}, - ) - - assert response.status_code == 200 - data = response.json() - assert "choices" in data - - def test_agent_not_found(self, client, mock_agent_factory): - """Test agent not found error.""" - mock_agent_factory.get_agent_info.side_effect = ValueError("Agent not found") - - response = client.post( - "/v1/agents/nonexistent/chat/completions", - json={"messages": [{"role": "user", "content": "Hello"}]}, - ) - - assert response.status_code == 404 - data = response.json() - assert "detail" in data - assert "error" in data["detail"] - - def test_streaming_response(self, client): - """Test streaming chat completions.""" - response = client.post( - "/v1/chat/completions", - json={"messages": [{"role": "user", "content": "Hello"}], "stream": True}, - ) - - assert response.status_code == 200 - assert "text/event-stream" in response.headers["content-type"] - - def test_mcp_mode(self, client): - """Test MCP mode functionality.""" - response = client.post( - "/v1/chat/completions", - json={"messages": [{"role": "user", "content": "Test"}]}, - headers={"x-mcp-mode": "true"}, - ) - - assert response.status_code == 200 - - def test_error_handling(self, client, mock_agent_factory): - """Test error handling in chat completions.""" - mock_agent_factory.create_agent.side_effect = Exception("Agent creation failed") - - response = client.post( - "/v1/chat/completions", json={"messages": [{"role": "user", "content": "Hello"}]} - ) - - assert response.status_code == 500 - data = response.json() - assert "detail" in data - assert "error" in data["detail"] - - -class TestTokenTracker: - """Test suite for TokenTracker.""" - - def test_track_tokens(self): - """Test token tracking functionality.""" - tracker = TokenTracker() - - tracker.track_tokens("session1", 10, 20) - usage = tracker.get_session_usage("session1") - - assert usage["prompt_tokens"] == 10 - assert usage["completion_tokens"] == 20 - assert usage["total_tokens"] == 30 - - def test_multiple_sessions(self): - """Test tracking multiple sessions.""" - tracker = TokenTracker() - - tracker.track_tokens("session1", 10, 20) - tracker.track_tokens("session2", 15, 25) - - usage1 = tracker.get_session_usage("session1") - usage2 = tracker.get_session_usage("session2") - - assert usage1["total_tokens"] == 30 - assert usage2["total_tokens"] == 40 - - def test_session_accumulation(self): - """Test token accumulation within a session.""" - tracker = TokenTracker() - - tracker.track_tokens("session1", 10, 20) - tracker.track_tokens("session1", 5, 15) - - usage = tracker.get_session_usage("session1") - - assert usage["prompt_tokens"] == 15 - assert usage["completion_tokens"] == 35 - assert usage["total_tokens"] == 50 - - def test_nonexistent_session(self): - """Test getting usage for nonexistent session.""" - tracker = TokenTracker() - - usage = tracker.get_session_usage("nonexistent") - - assert usage["prompt_tokens"] == 0 - assert usage["completion_tokens"] == 0 - assert usage["total_tokens"] == 0 - - -class TestCreateApp: - """Test suite for create_app function.""" - - def test_create_app_basic(self, mock_vector_store_config): - """Test basic app creation.""" - from cairo_coder.server.app import create_app - - mock_config_manager = Mock(spec=ConfigManager) - - with patch("cairo_coder.server.app.create_agent_factory"): - app = create_app(mock_vector_store_config, mock_config_manager) - - assert app is not None - assert app.title == "Cairo Coder" - assert app.version == "1.0.0" - - def test_create_app_with_defaults(self, mock_vector_store_config): - """Test app creation with default config manager.""" - from cairo_coder.server.app import create_app - - with ( - patch("cairo_coder.server.app.create_agent_factory"), - patch("cairo_coder.config.manager.ConfigManager"), - ): - app = create_app(mock_vector_store_config) - - assert app is not None - - def test_cors_configuration(self, mock_vector_store_config): - """Test CORS configuration.""" - from cairo_coder.server.app import create_app - - with patch("cairo_coder.server.app.create_agent_factory"): - app = create_app(mock_vector_store_config) - client = TestClient(app) - - # Test CORS headers - response = client.options( - "/v1/chat/completions", - headers={"Origin": "https://example.com", "Access-Control-Request-Method": "POST"}, - ) - - assert response.status_code in [200, 204] - - def test_app_middleware(self, mock_vector_store_config): - """Test that app has proper middleware configuration.""" - from cairo_coder.server.app import create_app - - with patch("cairo_coder.server.app.create_agent_factory"): - app = create_app(mock_vector_store_config) - - # Check that middleware is properly configured - # FastAPI apps have middleware, but middleware_stack might be None until build - assert hasattr(app, "middleware_stack") - # Check that CORS middleware was added by verifying the middleware property exists - assert hasattr(app, "middleware") - - def test_app_routes(self, mock_vector_store_config): - """Test that app has expected routes.""" - from cairo_coder.server.app import create_app - - with patch("cairo_coder.server.app.create_agent_factory"): - app = create_app(mock_vector_store_config) - - # Get all routes - routes = [route.path for route in app.routes] # type: ignore - - # Check expected routes exist - assert "/" in routes - assert "/v1/agents" in routes - assert "/v1/chat/completions" in routes - - -class TestServerConfiguration: - """Test suite for server configuration.""" - - def test_server_initialization(self, mock_vector_store_config): - """Test server initialization.""" - mock_config_manager = Mock(spec=ConfigManager) - - with patch("cairo_coder.server.app.create_agent_factory"): - server = CairoCoderServer(mock_vector_store_config, mock_config_manager) - - assert server.vector_store_config == mock_vector_store_config - assert server.config_manager == mock_config_manager - assert server.app is not None - assert server.token_tracker is not None - - def test_server_dependencies(self, mock_vector_store_config): - """Test server dependency injection.""" - mock_config_manager = Mock(spec=ConfigManager) - - with patch("cairo_coder.server.app.create_agent_factory") as mock_create_factory: - mock_factory = Mock() - mock_create_factory.return_value = mock_factory - - CairoCoderServer(mock_vector_store_config, mock_config_manager) - - # This test now verifies that the factory is not a member of the server, - # but is created inside the handlers. - pass - - def test_server_app_configuration(self, mock_vector_store_config): - """Test server app configuration.""" - mock_config_manager = Mock(spec=ConfigManager) - - with patch("cairo_coder.server.app.create_agent_factory"): - server = CairoCoderServer(mock_vector_store_config, mock_config_manager) - - # Check FastAPI app configuration - assert server.app.title == "Cairo Coder" - assert server.app.version == "1.0.0" - assert ( - server.app.description == "OpenAI-compatible API for Cairo programming assistance" - )