diff --git a/python/pyproject.toml b/python/pyproject.toml index fb60673f..efdbef5d 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -37,7 +37,7 @@ dependencies = [ "tenacity>=8.0.0", "prometheus-client>=0.20.0", "python-multipart>=0.0.6", - "dspy>=2.6.27", + "dspy>=3.0.0", "psycopg2>=2.9.10", "pgvector>=0.4.1", "marimo>=0.14.11", diff --git a/python/src/cairo_coder/core/rag_pipeline.py b/python/src/cairo_coder/core/rag_pipeline.py index 1f7882d4..8f27bd20 100644 --- a/python/src/cairo_coder/core/rag_pipeline.py +++ b/python/src/cairo_coder/core/rag_pipeline.py @@ -11,6 +11,7 @@ from typing import Any import dspy +from dspy.adapters.baml_adapter import BAMLAdapter from dspy.utils.callback import BaseCallback from langsmith import traceable @@ -110,38 +111,6 @@ def __init__(self, config: RagPipelineConfig): self._current_processed_query: ProcessedQuery | None = None self._current_documents: list[Document] = [] - def _process_query_and_retrieve_docs( - self, - query: str, - chat_history_str: str, - sources: list[DocumentSource] | None = None, - ) -> tuple[ProcessedQuery, list[Document]]: - processed_query = self.query_processor.forward(query=query, chat_history=chat_history_str) - self._current_processed_query = processed_query - - # Use provided sources or fall back to processed query sources - retrieval_sources = sources or processed_query.resources - documents = self.document_retriever.forward( - processed_query=processed_query, sources=retrieval_sources - ) - - # Apply LLM judge if enabled - try: - with dspy.context(lm=dspy.LM("gemini/gemini-2.5-flash-lite", max_tokens=10000)): - documents = self.retrieval_judge.forward(query=query, documents=documents) - except Exception as e: - logger.warning( - "Retrieval judge failed (sync), using all documents", - error=str(e), - exc_info=True, - ) - # documents already contains all retrieved docs, no action needed - - self._current_documents = documents - - return processed_query, documents - - async def _aprocess_query_and_retrieve_docs( self, query: str, @@ -159,7 +128,7 @@ async def _aprocess_query_and_retrieve_docs( ) try: - with dspy.context(lm=dspy.LM("gemini/gemini-2.5-flash-lite", max_tokens=10000)): + with dspy.context(lm=dspy.LM("gemini/gemini-2.5-flash-lite", max_tokens=10000), adapter=BAMLAdapter()): documents = await self.retrieval_judge.aforward(query=query, documents=documents) except Exception as e: logger.warning( @@ -173,30 +142,6 @@ async def _aprocess_query_and_retrieve_docs( return processed_query, documents - # Waits for streaming to finish before returning the response - @traceable(name="RagPipeline", run_type="chain") - def forward( - self, - query: str, - chat_history: list[Message] | None = None, - mcp_mode: bool = False, - sources: list[DocumentSource] | None = None, - ) -> dspy.Prediction: - chat_history_str = self._format_chat_history(chat_history or []) - processed_query, documents = self._process_query_and_retrieve_docs( - query, chat_history_str, sources - ) - logger.info(f"Processed query: {processed_query.original} and retrieved {len(documents)} doc titles: {[doc.metadata.get('title') for doc in documents]}") - - if mcp_mode: - return self.mcp_generation_program.forward(documents) - - context = self._prepare_context(documents, processed_query) - - return self.generation_program.forward( - query=query, context=context, chat_history=chat_history_str - ) - # Waits for streaming to finish before returning the response @traceable(name="RagPipeline", run_type="chain") async def aforward( @@ -213,7 +158,7 
@@ async def aforward(
         logger.info(f"Processed query: {processed_query.original[:100]}... and retrieved {len(documents)} doc titles: {[doc.metadata.get('title') for doc in documents]}")
 
         if mcp_mode:
-            return self.mcp_generation_program.forward(documents)
+            return await self.mcp_generation_program.aforward(documents)
 
         context = self._prepare_context(documents, processed_query)
 
@@ -221,7 +166,7 @@ async def aforward(
             query=query, context=context, chat_history=chat_history_str
         )
 
-    async def forward_streaming(
+    async def aforward_streaming(
         self,
         query: str,
         chat_history: list[Message] | None = None,
diff --git a/python/src/cairo_coder/dspy/document_retriever.py b/python/src/cairo_coder/dspy/document_retriever.py
index 7d74bf10..3825cf95 100644
--- a/python/src/cairo_coder/dspy/document_retriever.py
+++ b/python/src/cairo_coder/dspy/document_retriever.py
@@ -14,7 +14,7 @@ import dspy
 
 from cairo_coder.core.config import VectorStoreConfig
 from cairo_coder.core.types import Document, DocumentSource, ProcessedQuery
-from dspy.retrieve.pgvector_rm import PgVectorRM
+from cairo_coder.dspy.pgvector_rm import PgVectorRM
 
 logger = structlog.get_logger()
 
@@ -134,6 +134,9 @@
 - Always import strictly the required types in the module the interface is implemented in.
 - Always import the required types of the contract inside the contract module.
 - Always make the interface and the contract module 'pub'
+- In assert! macros, string messages use double quotes ("), not single quotes ('); e.g.: assert!(caller == owner,
+"Caller is not owner"). Alternatively, the string message can be omitted from assert! macros entirely.
+- Always match the generated code against context-provided code to reduce hallucination risk.
 
 The content inside the tag is the contract code for a 'Registry' contract, demonstrating
diff --git a/python/src/cairo_coder/dspy/generation_program.py b/python/src/cairo_coder/dspy/generation_program.py
index 481d9b78..434f31be 100644
--- a/python/src/cairo_coder/dspy/generation_program.py
+++ b/python/src/cairo_coder/dspy/generation_program.py
@@ -98,36 +98,6 @@ def get_lm_usage(self) -> dict[str, int]:
         """
         return self.generation_program.get_lm_usage()
 
-    @traceable(name="GenerationProgram", run_type="llm")
-    def forward(self, query: str, context: str, chat_history: Optional[str] = None) -> dspy.Prediction | None :
-        """
-        Generate Cairo code response based on query and context.
-
-        Args:
-            query: User's Cairo programming question
-            context: Retrieved documentation and examples
-            chat_history: Previous conversation context (optional)
-
-        Returns:
-            Generated Cairo code response with explanations
-        """
-        if chat_history is None:
-            chat_history = ""
-
-        # Execute the generation program
-        max_retries = 3
-        for attempt in range(max_retries):
-            try:
-                return self.generation_program.forward(query=query, context=context, chat_history=chat_history)
-            except AdapterParseError as e:
-                if attempt < max_retries - 1:
-                    continue
-                code = self._try_extract_code_from_response(e.lm_response)
-                if code:
-                    return dspy.Prediction(answer=code)
-                raise e
-        return None
-
     @traceable(name="GenerationProgram", run_type="llm")
     async def aforward(self, query: str, context: str, chat_history: Optional[str] = None) -> dspy.Prediction | None :
         """
@@ -269,6 +239,12 @@ def forward(self, documents: list[Document]) -> dspy.Prediction:
 
         return dspy.Prediction(answer='\n'.join(formatted_docs))
 
+    async def aforward(self, documents: list[Document]) -> dspy.Prediction:
+        """
+        Format documents for MCP mode response.
+ """ + return self.forward(documents) + def get_lm_usage(self) -> dict[str, int]: """ Get the total number of tokens used by the LLM. diff --git a/python/src/cairo_coder/dspy/pgvector_rm.py b/python/src/cairo_coder/dspy/pgvector_rm.py new file mode 100644 index 00000000..747c4f84 --- /dev/null +++ b/python/src/cairo_coder/dspy/pgvector_rm.py @@ -0,0 +1,157 @@ +import warnings +from collections.abc import Callable +from typing import Optional + +import dspy + +try: + import psycopg2 + from pgvector.psycopg2 import register_vector + from psycopg2 import sql +except ImportError as e: + raise ImportError( + "The 'pgvector' extra is required to use PgVectorRM. Install it with `pip install dspy-ai[pgvector]`. Also, try `pip install pgvector psycopg2`.", + ) from e +try: + import openai +except ImportError: + warnings.warn( + "`openai` is not installed. Install it with `pip install openai` to use OpenAI embedding models.", + stacklevel=2, category=ImportWarning, + ) + + +class PgVectorRM(dspy.Retrieve): + """ + Implements a retriever that (as the name suggests) uses pgvector to retrieve passages, + using a raw SQL query and a postgresql connection managed by psycopg2. + + It needs to register the pgvector extension with the psycopg2 connection + + Returns a list of dspy.Example objects + + Args: + db_url (str): A PostgreSQL database URL in psycopg2's DSN format + pg_table_name (Optional[str]): name of the table containing passages + openai_client (openai.OpenAI): OpenAI client to use for computing query embeddings. Either openai_client or embedding_func must be provided. + embedding_func (Callable): A function to use for computing query embeddings. Either openai_client or embedding_func must be provided. + content_field (str = "text"): Field containing the passage text. Defaults to "text" + k (Optional[int]): Default number of top passages to retrieve. Defaults to 20 + embedding_field (str = "embedding"): Field containing passage embeddings. Defaults to "embedding" + fields (List[str] = ['text']): Fields to retrieve from the table. Defaults to "text" + embedding_model (str = "text-embedding-ada-002"): Field containing the OpenAI embedding model to use. 
Defaults to "text-embedding-ada-002" + + Examples: + Below is a code snippet that shows how to use PgVector as the default retriever + + ```python + import dspy + import openai + import psycopg2 + + openai.api_key = os.environ.get("OPENAI_API_KEY", None) + openai_client = openai.OpenAI() + + llm = dspy.OpenAI(model="gpt-3.5-turbo") + + DATABASE_URL should be in the format postgresql://user:password@host/database + db_url=os.getenv("DATABASE_URL") + + retriever_model = PgVectorRM(conn, openai_client=openai_client, "paragraphs", fields=["text", "document_id"], k=20) + dspy.settings.configure(lm=llm, rm=retriever_model) + ``` + + Below is a code snippet that shows how to use PgVector in the forward() function of a module + ```python + self.retrieve = PgVectorRM(db_url, openai_client=openai_client, "paragraphs", fields=["text", "document_id"], k=20) + ``` + """ + + def __init__( + self, + db_url: str, + pg_table_name: str, + openai_client: Optional[openai.OpenAI] = None, + embedding_func: Optional[Callable] = None, + k: int = 20, + embedding_field: str = "embedding", + fields: Optional[list[str]] = None, + content_field: str = "text", + embedding_model: str = "text-embedding-ada-002", + include_similarity: bool = False, + ): + """ + k = 20 is the number of paragraphs to retrieve + """ + assert ( + openai_client or embedding_func + ), "Either openai_client or embedding_func must be provided." + self.openai_client = openai_client + self.embedding_func = embedding_func + + self.conn = psycopg2.connect(db_url) + register_vector(self.conn) + self.pg_table_name = pg_table_name + self.fields = fields or ["text"] + self.content_field = content_field + self.embedding_field = embedding_field + self.embedding_model = embedding_model + self.include_similarity = include_similarity + + super().__init__(k=k) + + def forward(self, query: str, k: int = None): + """Search with PgVector for k top passages for query using cosine similarity + + Args: + query (str): The query to search for + k (int): The number of top passages to retrieve. Defaults to the value set in the constructor. + Returns: + dspy.Prediction: an object containing the retrieved passages. 
+ """ + # Embed query + query_embedding = self._get_embeddings(query) + + retrieved_docs = [] + + fields = sql.SQL(",").join([sql.Identifier(f) for f in self.fields]) + if self.include_similarity: + similarity_field = sql.SQL(",") + sql.SQL( + "1 - ({embedding_field} <=> %s::vector) AS similarity", + ).format(embedding_field=sql.Identifier(self.embedding_field)) + fields += similarity_field + args = (query_embedding, query_embedding, k if k else self.k) + else: + args = (query_embedding, k if k else self.k) + + sql_query = sql.SQL( + "select {fields} from {table} order by {embedding_field} <=> %s::vector limit %s", + ).format( + fields=fields, + table=sql.Identifier(self.pg_table_name), + embedding_field=sql.Identifier(self.embedding_field), + ) + + with self.conn as conn, conn.cursor() as cur: + cur.execute(sql_query, args) + rows = cur.fetchall() + columns = [descrip[0] for descrip in cur.description] + for row in rows: + data = dict(zip(columns, row, strict=False)) + data["long_text"] = data[self.content_field] + retrieved_docs.append(dspy.Example(**data)) + # Return Prediction + return retrieved_docs + + def _get_embeddings(self, query: str) -> list[float]: + if self.openai_client is not None: + return ( + self.openai_client.embeddings.create( + model=self.embedding_model, + input=query, + encoding_format="float", + ) + .data[0] + .embedding + ) + return self.embedding_func(query) diff --git a/python/src/cairo_coder/dspy/query_processor.py b/python/src/cairo_coder/dspy/query_processor.py index 17016b39..fd4eb171 100644 --- a/python/src/cairo_coder/dspy/query_processor.py +++ b/python/src/cairo_coder/dspy/query_processor.py @@ -109,35 +109,6 @@ def __init__(self): "foundry", } - @traceable(name="QueryProcessorProgram", run_type="llm") - def forward(self, query: str, chat_history: Optional[str] = None) -> ProcessedQuery: - """ - Process a user query into a structured format for document retrieval. 
- - Args: - query: The user's Cairo/Starknet programming question - chat_history: Previous conversation context (optional) - - Returns: - ProcessedQuery with search terms, resource identification, and categorization - """ - # Execute the DSPy retrieval program - result = self.retrieval_program.forward(query=query, chat_history=chat_history) - - # Parse and validate the results - search_queries = result.search_queries - resources = self._validate_resources(result.resources) - - # Build structured query result - return ProcessedQuery( - original=query, - search_queries=search_queries, - reasoning=result.reasoning, - is_contract_related=self._is_contract_query(query), - is_test_related=self._is_test_query(query), - resources=resources, - ) - @traceable(name="QueryProcessorProgram", run_type="llm") async def aforward(self, query: str, chat_history: Optional[str] = None) -> ProcessedQuery: """ diff --git a/python/src/cairo_coder/dspy/retrieval_judge.py b/python/src/cairo_coder/dspy/retrieval_judge.py index a9defd2b..5396a2d9 100644 --- a/python/src/cairo_coder/dspy/retrieval_judge.py +++ b/python/src/cairo_coder/dspy/retrieval_judge.py @@ -70,40 +70,6 @@ def __init__(self): self.parallel_threads = DEFAULT_PARALLEL_THREADS self.threshold = DEFAULT_THRESHOLD - # ========================= - # Public API - # ========================= - @traceable(name="RetrievalJudge", run_type="llm") - def forward(self, query: str, documents: list[Document]) -> list[Document]: - """Sync judge.""" - if not documents: - return documents - - keep_docs, judged_indices, judged_payloads = self._split_templates_and_prepare_docs(documents) - - if judged_payloads: - try: - # Build batches for dspy.Parallel exactly as tests expect - parallel = dspy.Parallel(num_threads=self.parallel_threads) - batches = [] - for doc_string in judged_payloads: - example = dspy.Example(query=query, system_resource=doc_string).with_inputs("query", "system_resource") - batches.append((self.rater, example)) - - results = parallel(batches) - self._attach_scores_and_filter( - query=query, - documents=documents, - judged_indices=judged_indices, - results=results, - keep_docs=keep_docs, - ) - except Exception as e: - logger.error("Retrieval judge failed (sync), returning all docs", error=str(e), exc_info=True) - return documents - - return keep_docs - @traceable(name="RetrievalJudge", run_type="llm") async def aforward(self, query: str, documents: list[Document]) -> list[Document]: """Async judge.""" @@ -112,6 +78,7 @@ async def aforward(self, query: str, documents: list[Document]) -> list[Document keep_docs, judged_indices, judged_payloads = self._split_templates_and_prepare_docs(documents) + # TODO: can we use dspy.Parallel here instead of asyncio gather? if judged_payloads: try: # Judge concurrently @@ -218,7 +185,8 @@ def _attach_scores_and_filter_async( ) doc.metadata[LLM_JUDGE_SCORE_KEY] = 1.0 doc.metadata[LLM_JUDGE_REASON_KEY] = "Could not judge document. Keeping it." - # Do not append to keep_docs + # Actually keep it, as the log claims. 
+ keep_docs.append(doc) continue self._process_single_result(doc, result, keep_docs) diff --git a/python/src/cairo_coder/optimizers/generation/generate_starklings_dataset.py b/python/src/cairo_coder/optimizers/generation/generate_starklings_dataset.py index 7a1c21b9..27c95ecf 100644 --- a/python/src/cairo_coder/optimizers/generation/generate_starklings_dataset.py +++ b/python/src/cairo_coder/optimizers/generation/generate_starklings_dataset.py @@ -39,7 +39,7 @@ def get_context_for_query(full_query: str, config) -> str: try: # Create instances per task to avoid shared state issues document_retriever = DocumentRetrieverProgram(vector_store_config=config.vector_store) - query_processor = QueryProcessorProgram() + query_processor = dspy.syncify(QueryProcessorProgram()) context_summarizer = dspy.ChainOfThought(CairoContextSummarization) processed_query = query_processor.forward(query=full_query) diff --git a/python/src/cairo_coder/optimizers/mcp_optimizer.py b/python/src/cairo_coder/optimizers/mcp_optimizer.py index c689d91c..9a54cb2d 100644 --- a/python/src/cairo_coder/optimizers/mcp_optimizer.py +++ b/python/src/cairo_coder/optimizers/mcp_optimizer.py @@ -14,7 +14,6 @@ def _(): import dspy import psycopg2 import structlog - from dspy import MIPROv2 from psycopg2 import OperationalError from cairo_coder.config.manager import ConfigManager @@ -57,7 +56,7 @@ def _(): dspy.settings.configure(lm=lm) logger.info("Configured DSPy with Gemini 2.5 Flash") - return ConfigManager, MIPROv2, Path, dspy, json, lm, logger, time + return ConfigManager, Path, dspy, json, lm, logger, time @app.cell @@ -104,10 +103,11 @@ def load_dataset(dataset_path: str) -> list[dspy.Example]: @app.cell -def _(Path, ConfigManager, dspy): +def _(ConfigManager, Path, dspy): """Initialize the generation program.""" # Initialize program + from cairo_coder.core.types import DocumentSource, Message from cairo_coder.dspy.document_retriever import DocumentRetrieverProgram from cairo_coder.dspy.query_processor import QueryProcessorProgram @@ -116,35 +116,35 @@ class QueryAndRetrieval(dspy.Module): def __init__(self): try: config = ConfigManager.load_config() + db_config = config.vector_store except FileNotFoundError: # Running in test environment without config.toml - config = None + db_config = None - self.processor = QueryProcessorProgram() + self.processor = dspy.syncify(QueryProcessorProgram()) if Path("optimizers/results/optimized_mcp_program.json").exists(): self.processor.load("optimizers/results/optimized_mcp_program.json") - self.document_retriever = DocumentRetrieverProgram(vector_store_config=config.vector_store if config else None) + self.document_retriever = DocumentRetrieverProgram(vector_store_config=db_config) - def forward( + async def aforward( self, query: str, chat_history: list[Message] | None = None, sources: list[DocumentSource] | None = None, ) -> dspy.Prediction: - processed_query = self.processor.forward(query=query, chat_history=chat_history) - document_list = self.document_retriever.forward(processed_query=processed_query) + processed_query = await self.processor.aforward(query=query, chat_history=chat_history) + document_list = self.document_retriever(processed_query=processed_query) return dspy.Prediction(answer=document_list) - query_retrieval_program = QueryAndRetrieval() + query_retrieval_program = dspy.syncify(QueryAndRetrieval()) return (query_retrieval_program,) @app.cell def _(dspy): # Defining our metrics here. 
- class RetrievalRecallPrecision(dspy.Signature): """ Compare a system's retrieval response to the query and to compute recall and precision. @@ -161,7 +161,7 @@ def __init__(self, threshold=0.33, decompositional=False): self.threshold = threshold self.rater = dspy.Predict(RetrievalRecallPrecision) - def forward(self, example, pred, trace=None): + def forward(self, example, pred, trace=None, pred_name=None, pred_trace=None): parallel = dspy.Parallel(num_threads=10) batches = [] for resource in pred.answer: @@ -189,34 +189,37 @@ def _(): metric = RetrievalF1() # You can use this cell to run more comprehensive evaluation - evaluator__ = Evaluate(devset=valset, num_threads=12, display_progress=True) + evaluator__ = Evaluate(devset=valset, num_threads=12, display_progress=True, provide_traceback=True) return evaluator__(query_retrieval_program, metric=metric) baseline_score = _() return (baseline_score,) + @app.cell def test_notebook(query_retrieval_program): assert query_retrieval_program is not None + return + @app.cell def _( - MIPROv2, RetrievalF1, + dspy, logger, query_retrieval_program, time, trainset, valset, ): - """Run optimization using MIPROv2.""" - + """Run optimization.""" + from dspy import MIPROv2 metric = RetrievalF1() def run_optimization(trainset, valset): - """Run the optimization process using MIPROv2.""" + """Run the optimization process.""" logger.info("Starting optimization process") # Configure optimizer @@ -224,7 +227,6 @@ def run_optimization(trainset, valset): metric=metric, auto="light", num_threads=12, - ) # Run optimization diff --git a/python/src/cairo_coder/server/app.py b/python/src/cairo_coder/server/app.py index 788589fe..2b54e18b 100644 --- a/python/src/cairo_coder/server/app.py +++ b/python/src/cairo_coder/server/app.py @@ -15,6 +15,7 @@ import dspy import uvicorn +from dspy.adapters.baml_adapter import BAMLAdapter from fastapi import Depends, FastAPI, Header, HTTPException, Request from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import StreamingResponse @@ -170,7 +171,7 @@ def __init__( # TODO: This is the place where we should select the proper LLM configuration. # TODO: For now we just Hard-code DSPY - GEMINI - dspy.configure(lm=dspy.LM("gemini/gemini-2.5-flash", max_tokens=30000)) + dspy.configure(lm=dspy.LM("gemini/gemini-2.5-flash", max_tokens=30000), adapter=BAMLAdapter()) dspy.configure(callbacks=[AgentLoggingCallback(), LangsmithTracingCallback()]) dspy.configure(track_usage=True) @@ -562,7 +563,7 @@ async def lifespan(app: FastAPI): # Load config once config = ConfigManager.load_config() vector_store_config = config.vector_store - + # TODO: These should not be literal constants like this. embedder = dspy.Embedder("openai/text-embedding-3-large", dimensions=1536, batch_size=512) diff --git a/python/tests/conftest.py b/python/tests/conftest.py index c1ae9399..0ef9cb22 100644 --- a/python/tests/conftest.py +++ b/python/tests/conftest.py @@ -95,10 +95,6 @@ def mock_lm(): """ with patch("dspy.ChainOfThought") as mock_cot: mock_program = Mock() - # Mock for sync calls - mock_program.forward.return_value = dspy.Prediction( - answer="Here's a Cairo contract example:\n\n```cairo\n#[starknet::contract]\nmod SimpleContract {\n // Contract implementation\n}\n```\n\nThis contract demonstrates basic Cairo syntax." 
- ) mock_program.return_value = dspy.Prediction( answer="Here's a Cairo contract example:\n\n```cairo\n#[starknet::contract]\nmod SimpleContract {\n // Contract implementation\n}\n```\n\nThis contract demonstrates basic Cairo syntax." ) @@ -383,15 +379,15 @@ def clean_config_env_vars(monkeypatch): "LANGSMITH_ENDPOINT", "LANGSMITH_OTEL_ENABLED", ] - + # Store original values original_values = {} for var in env_vars_to_clean: original_values[var] = os.environ.get(var) monkeypatch.delenv(var, raising=False) - + yield - + # Restore original values after test for var, value in original_values.items(): if value is not None: @@ -588,7 +584,7 @@ def mock_mcp_generation_program(): Storage variables use #[storage] attribute. """ - program.forward = Mock(return_value=dspy.Prediction(answer=mcp_answer)) + program.aforward = AsyncMock(return_value=dspy.Prediction(answer=mcp_answer)) program.get_lm_usage = Mock(return_value={}) return program diff --git a/python/tests/unit/test_generation_program.py b/python/tests/unit/test_generation_program.py index 43784c26..c50f7e3f 100644 --- a/python/tests/unit/test_generation_program.py +++ b/python/tests/unit/test_generation_program.py @@ -22,13 +22,6 @@ ) -async def call_program(program, method, *args, **kwargs): - """Helper to call sync or async method on a program.""" - if method == "aforward": - return await program.aforward(*args, **kwargs) - return getattr(program, method)(*args, **kwargs) - - @pytest.fixture(scope="function") def generation_program(mock_lm): """Create a GenerationProgram instance.""" @@ -47,14 +40,13 @@ def mcp_generation_program(self): """Create an MCP GenerationProgram instance.""" return McpGenerationProgram() - @pytest.mark.parametrize("call_method", ["forward", "aforward"]) @pytest.mark.asyncio - async def test_general_code_generation(self, generation_program, call_method): + async def test_general_code_generation(self, generation_program): """Test general Cairo code generation for both sync and async.""" query = "How do I create a simple Cairo contract?" context = "Cairo contracts use #[starknet::contract] attribute..." - result = await call_program(generation_program, call_method, query, context) + result = await generation_program.aforward(query, context) # Result should be a dspy.Predict object with an answer attribute assert hasattr(result, "answer") @@ -62,39 +54,31 @@ async def test_general_code_generation(self, generation_program, call_method): assert len(result.answer) > 0 assert "cairo" in result.answer.lower() - # Verify the generation program was called with correct parameters - mocked_method = getattr(generation_program.generation_program, call_method) - mocked_method.assert_called_once() - call_args = mocked_method.call_args[1] + generation_program.generation_program.aforward.assert_called_once() + call_args = generation_program.generation_program.aforward.call_args[1] assert call_args["query"] == query assert "cairo" in call_args["context"].lower() assert call_args["chat_history"] == "" - @pytest.mark.parametrize("call_method", ["forward", "aforward"]) @pytest.mark.asyncio - async def test_generation_with_chat_history(self, generation_program, call_method): + async def test_generation_with_chat_history(self, generation_program): """Test code generation with chat history for both sync and async.""" query = "How do I add storage to that contract?" context = "Storage variables are defined with #[storage]..." 
chat_history = "Previous conversation about contracts" - result = await call_program( - generation_program, call_method, query, context, chat_history - ) + result = await generation_program.aforward(query, context, chat_history) # Result should be a dspy.Predict object with an answer attribute assert hasattr(result, "answer") assert isinstance(result.answer, str) assert len(result.answer) > 0 - # Verify chat history was passed - mocked_method = getattr(generation_program.generation_program, call_method) - call_args = mocked_method.call_args[1] - assert call_args["chat_history"] == chat_history + # Verify chat history was passed to the mocked inner program + assert generation_program.generation_program.aforward.call_args[1]["chat_history"] == chat_history - @pytest.mark.parametrize("call_method", ["forward", "aforward"]) @pytest.mark.asyncio - async def test_scarb_generation_program(self, scarb_generation_program, call_method): + async def test_scarb_generation_program(self, scarb_generation_program): """Test Scarb-specific code generation for both sync and async.""" with patch.object( scarb_generation_program, "generation_program" @@ -102,16 +86,10 @@ async def test_scarb_generation_program(self, scarb_generation_program, call_met mock_program.aforward = AsyncMock(return_value=dspy.Prediction( answer='Here\'s your Scarb configuration:\n\n```toml\n[package]\nname = "my-project"\nversion = "0.1.0"\n```' )) - mock_program.forward.return_value = dspy.Prediction( - answer='Here\'s your Scarb configuration:\n\n```toml\n[package]\nname = "my-project"\nversion = "0.1.0"\n```' - ) - query = "How do I configure Scarb for my project?" context = "Scarb configuration documentation..." - result = await call_program( - scarb_generation_program, call_method, query, context - ) + result = await scarb_generation_program.aforward(query, context) # Result should be a dspy.Predict object with an answer attribute assert hasattr(result, "answer") @@ -119,7 +97,7 @@ async def test_scarb_generation_program(self, scarb_generation_program, call_met assert ( "scarb" in result.answer.lower() or "toml" in result.answer.lower() ) - getattr(mock_program, call_method).assert_called_once() + mock_program.aforward.assert_called_once() def test_format_chat_history(self, generation_program): """Test chat history formatting.""" @@ -160,9 +138,10 @@ def mcp_program(self): """Create an MCP GenerationProgram instance.""" return McpGenerationProgram() - def test_mcp_document_formatting(self, mcp_program, sample_documents): + @pytest.mark.asyncio + async def test_mcp_document_formatting(self, mcp_program, sample_documents): """Test MCP mode document formatting.""" - answer = mcp_program.forward(sample_documents).answer + answer = (await mcp_program.aforward(sample_documents)).answer assert isinstance(answer, str) assert len(answer) > 0 @@ -182,17 +161,19 @@ def test_mcp_document_formatting(self, mcp_program, sample_documents): # Check content is included assert doc.page_content in answer - def test_mcp_empty_documents(self, mcp_program): + @pytest.mark.asyncio + async def test_mcp_empty_documents(self, mcp_program): """Test MCP mode with empty documents.""" - result = mcp_program.forward([]) + result = await mcp_program.aforward([]) assert result.answer == "No relevant documentation found." 
- def test_mcp_documents_with_missing_metadata(self, mcp_program): + @pytest.mark.asyncio + async def test_mcp_documents_with_missing_metadata(self, mcp_program): """Test MCP mode with documents missing metadata.""" documents = [Document(page_content="Some Cairo content", metadata={})] # Missing metadata - answer = mcp_program.forward(documents).answer + answer = (await mcp_program.aforward(documents)).answer assert isinstance(answer, str) assert "Some Cairo content" in answer @@ -301,9 +282,8 @@ def test_create_mcp_generation_program(self): class TestForwardRetries: """Test suite for forward retry logic.""" - @pytest.mark.parametrize("call_method", ["forward", "aforward"]) @pytest.mark.asyncio - async def test_forward_retry_logic(self, call_method, generation_program): + async def test_forward_retry_logic(self, generation_program): """Test that forward retries AdapterParseError up to 3 times.""" # Mock the generation_program to raise AdapterParseError side_effect = [ @@ -315,19 +295,18 @@ async def test_forward_retry_logic(self, call_method, generation_program): ), dspy.Prediction(answer="Success"), ] - getattr(generation_program.generation_program, call_method).side_effect = side_effect + generation_program.generation_program.aforward = AsyncMock(side_effect=side_effect) # Should succeed after 2 retries - result = await call_program(generation_program, call_method, "test query", "test context") + result = await generation_program.aforward("test query", "test context") # Verify forward was called 3 times (2 failures + 1 success) - assert getattr(generation_program.generation_program, call_method).call_count == 3 + assert generation_program.generation_program.aforward.call_count == 3 assert result is not None assert result.answer == "Success" - @pytest.mark.parametrize("call_method", ["forward", "aforward"]) @pytest.mark.asyncio - async def test_forward_max_retries_exceeded(self, call_method, generation_program): + async def test_forward_max_retries_exceeded(self, generation_program): """Test that forward raises AdapterParseError after max retries.""" # Mock the generation_program to always raise AdapterParseError @@ -345,37 +324,35 @@ async def test_forward_max_retries_exceeded(self, call_method, generation_progra "Parse error", CairoCodeGeneration, "", "test response", None ), ] - getattr(generation_program.generation_program, call_method).side_effect = side_effect + generation_program.generation_program.aforward = AsyncMock(side_effect=side_effect) # Should raise after 3 attempts with pytest.raises(AdapterParseError): - await call_program(generation_program, call_method, "test query", "test context") + await generation_program.aforward("test query", "test context") # Verify forward was called exactly 3 times - assert getattr(generation_program.generation_program, call_method).call_count == 3 + assert generation_program.generation_program.aforward.call_count == 3 - @pytest.mark.parametrize("call_method", ["forward", "aforward"]) @pytest.mark.asyncio - async def test_forward_other_exceptions_not_retried(self, call_method, generation_program): + async def test_forward_other_exceptions_not_retried(self, generation_program): """Test that forward doesn't retry non-AdapterParseError exceptions.""" # Mock the generation_program to raise a different exception side_effect = [ ValueError("Some other error"), ] - getattr(generation_program.generation_program, call_method).side_effect = side_effect + generation_program.generation_program.aforward = AsyncMock(side_effect=side_effect) # Should raise 
immediately without retries with pytest.raises(ValueError): - await call_program(generation_program, call_method, "test query", "test context") + await generation_program.aforward("test query", "test context") # Verify forward was called only once - assert getattr(generation_program.generation_program, call_method).call_count == 1 + assert generation_program.generation_program.aforward.call_count == 1 - @pytest.mark.parametrize("call_method", ["forward", "aforward"]) @pytest.mark.asyncio - async def test_should_extract_code_before_raising(self, generation_program, call_method): + async def test_should_extract_code_before_raising(self, generation_program): """Test that code is extracted before raising AdapterParseError.""" # Mock the generation_program to raise AdapterParseError side_effect = [ @@ -389,7 +366,7 @@ async def test_should_extract_code_before_raising(self, generation_program, call "Parse error", CairoCodeGeneration, "```cairo\nfn main() {}\n```", "test response", None ), ] - generation_program.generation_program.aforward.side_effect = side_effect + generation_program.generation_program.aforward = AsyncMock(side_effect=side_effect) - response = await call_program(generation_program, "aforward", "test query", "test context") + response = await generation_program.aforward("test query", "test context") assert response.answer == "\nfn main() {}\n" diff --git a/python/tests/unit/test_query_processor.py b/python/tests/unit/test_query_processor.py index 366c903d..26d9b79f 100644 --- a/python/tests/unit/test_query_processor.py +++ b/python/tests/unit/test_query_processor.py @@ -22,19 +22,19 @@ def processor(self, mock_lm): """Create a QueryProcessorProgram instance with mocked LM.""" return QueryProcessorProgram() - def test_contract_query_processing(self, mock_lm, processor): + @pytest.mark.asyncio + async def test_contract_query_processing(self, mock_lm, processor): """Test processing of contract-related queries.""" prediction = dspy.Prediction( search_queries=["cairo, contract, storage, variable"], resources=["cairo_book", "starknet_docs"], reasoning="I need to create a Cairo contract", ) - mock_lm.forward.return_value = prediction mock_lm.aforward.return_value = prediction query = "How do I define storage variables in a Cairo contract?" 
- result = processor.forward(query) + result = await processor.aforward(query) assert isinstance(result, ProcessedQuery) assert result.original == query @@ -86,7 +86,8 @@ def test_test_detection(self, processor, query, expected): """Test detection of test-related queries.""" assert processor._is_test_query(query) is expected - def test_empty_query_handling(self, processor): + @pytest.mark.asyncio + async def test_empty_query_handling(self, processor): """Test handling of empty or whitespace queries.""" with patch.object(processor, "retrieval_program") as mock_program: mock_program.aforward = AsyncMock( @@ -95,7 +96,7 @@ def test_empty_query_handling(self, processor): ) ) - result = processor.forward("") + result = await processor.aforward("") assert result.original == "" assert result.resources == [DocumentSource.CAIRO_BOOK] # Default fallback diff --git a/python/tests/unit/test_rag_pipeline.py b/python/tests/unit/test_rag_pipeline.py index d8230a8b..1cc2e056 100644 --- a/python/tests/unit/test_rag_pipeline.py +++ b/python/tests/unit/test_rag_pipeline.py @@ -7,7 +7,6 @@ from unittest.mock import AsyncMock, Mock, patch -import dspy import pytest from cairo_coder.core.rag_pipeline import ( @@ -47,8 +46,7 @@ def pipeline(pipeline_config): with patch("cairo_coder.core.rag_pipeline.RetrievalJudge") as mock_judge_class: mock_judge = Mock() mock_judge.get_lm_usage.return_value = {} - mock_judge.forward.return_value = dspy.Prediction() - mock_judge.aforward = AsyncMock(return_value=dspy.Prediction()) + mock_judge.aforward = AsyncMock(side_effect=lambda query, documents: documents) mock_judge_class.return_value = mock_judge return RagPipeline(pipeline_config) @@ -143,7 +141,7 @@ async def test_async_pipeline_execution(self, pipeline): async def test_streaming_pipeline_execution(self, pipeline): """Test streaming pipeline execution.""" events = [] - async for event in pipeline.forward_streaming("How to write Cairo contracts?"): + async for event in pipeline.aforward_streaming("How to write Cairo contracts?"): events.append(event) # Verify event sequence @@ -153,47 +151,50 @@ async def test_streaming_pipeline_execution(self, pipeline): assert "response" in event_types assert "end" in event_types - def test_mcp_mode_execution(self, pipeline): + @pytest.mark.asyncio + async def test_mcp_mode_execution(self, pipeline): """Test MCP mode pipeline execution.""" - result = pipeline.forward("How to write Cairo contracts?", mcp_mode=True) + result = await pipeline.aforward("How to write Cairo contracts?", mcp_mode=True) # Verify MCP program was used - pipeline.mcp_generation_program.forward.assert_called_once() + pipeline.mcp_generation_program.aforward.assert_called_once() assert "Cairo contracts are defined using #[starknet::contract]" in result.answer - def test_pipeline_with_chat_history(self, pipeline): + @pytest.mark.asyncio + async def test_pipeline_with_chat_history(self, pipeline): """Test pipeline with chat history.""" chat_history = [ Message(role=Role.USER, content="Previous question"), Message(role=Role.ASSISTANT, content="Previous answer"), ] - pipeline.forward("Follow-up question", chat_history=chat_history) + await pipeline.aforward("Follow-up question", chat_history=chat_history) # Verify chat history was formatted and passed - call_args = pipeline.query_processor.forward.call_args + call_args = pipeline.query_processor.aforward.call_args assert "User: Previous question" in call_args[1]["chat_history"] assert "Assistant: Previous answer" in call_args[1]["chat_history"] - def 
test_pipeline_with_custom_sources(self, pipeline): + @pytest.mark.asyncio + async def test_pipeline_with_custom_sources(self, pipeline): """Test pipeline with custom sources.""" sources = [DocumentSource.SCARB_DOCS] - pipeline.forward("Scarb question", sources=sources) + await pipeline.aforward("Scarb question", sources=sources) # Verify sources were passed to retriever - call_args = pipeline.document_retriever.forward.call_args[1] + call_args = pipeline.document_retriever.aforward.call_args[1] assert call_args["sources"] == sources - def test_empty_documents_handling(self, pipeline, mock_document_retriever): + @pytest.mark.asyncio + async def test_empty_documents_handling(self, pipeline, mock_document_retriever): """Test pipeline handling of empty document list.""" # Configure retriever to return empty list - mock_document_retriever.forward.return_value = [] mock_document_retriever.aforward.return_value = [] - pipeline.forward("test query") + await pipeline.aforward("test query") # Verify generation was called with "No relevant documentation found" - call_args = pipeline.generation_program.forward.call_args + call_args = pipeline.generation_program.aforward.call_args assert "No relevant documentation found" in call_args[1]["context"] @pytest.mark.asyncio @@ -203,7 +204,7 @@ async def test_pipeline_error_handling(self, pipeline, mock_document_retriever): mock_document_retriever.aforward.side_effect = Exception("Retrieval error") events = [] - async for event in pipeline.forward_streaming("test query"): + async for event in pipeline.aforward_streaming("test query"): events.append(event) # Should have an error event @@ -216,7 +217,8 @@ class TestRagPipelineWithJudge: """Tests for RAG Pipeline with Retrieval Judge feature.""" @patch("cairo_coder.core.rag_pipeline.RetrievalJudge") - def test_judge_enabled_filters_documents( + @pytest.mark.asyncio + async def test_judge_enabled_filters_documents( self, mock_judge_class, dspy_env_patched, patch_dspy_parallel, pipeline, mock_document_retriever ): @@ -227,7 +229,7 @@ def test_judge_enabled_filters_documents( ("Python Guide", "Python content", "python_docs"), ("Cairo Storage", "Cairo storage content", "cairo_book"), ]) - mock_document_retriever.forward.return_value = docs + mock_document_retriever.aforward.return_value = docs # Setup judge with specific scores judge = create_custom_retrieval_judge({ @@ -236,50 +238,50 @@ def test_judge_enabled_filters_documents( "Cairo Storage": 0.7, }) # Configure the mock instance that the pipeline will use - pipeline.retrieval_judge.forward.side_effect = judge.forward pipeline.retrieval_judge.aforward.side_effect = judge.aforward pipeline.retrieval_judge.threshold = judge.threshold - pipeline.forward("Cairo question") + await pipeline.aforward("Cairo question") # Verify judge was called - pipeline.retrieval_judge.forward.assert_called_once() + pipeline.retrieval_judge.aforward.assert_called_once() # Verify context only contains high-scoring docs - call_args = pipeline.generation_program.forward.call_args + call_args = pipeline.generation_program.aforward.call_args context = call_args[1]["context"] assert "Cairo contract content" in context assert "Cairo storage content" in context assert "Python content" not in context @patch("cairo_coder.core.rag_pipeline.RetrievalJudge") - def test_judge_disabled_passes_all_documents( + @pytest.mark.asyncio + async def test_judge_disabled_passes_all_documents( self, mock_judge_class, dspy_env_patched, sample_documents, pipeline ): """Test that when judge fails, all documents 
are passed through.""" # Mock the judge to fail - pipeline.retrieval_judge.forward.side_effect = Exception("Judge failed") pipeline.retrieval_judge.aforward.side_effect = Exception("Judge failed") - pipeline.forward("test query") + await pipeline.aforward("test query") # Verify judge exists assert pipeline.retrieval_judge is not None # All documents should be in context (because judge failed) - call_args = pipeline.generation_program.forward.call_args + call_args = pipeline.generation_program.aforward.call_args context = call_args[1]["context"] for doc in sample_documents: assert doc.page_content in context @pytest.mark.parametrize("threshold", [0.0, 0.4, 0.6, 0.9]) @patch("cairo_coder.core.rag_pipeline.RetrievalJudge") - def test_judge_threshold_parameterization( + @pytest.mark.asyncio + async def test_judge_threshold_parameterization( self, mock_judge_class, dspy_env_patched, patch_dspy_parallel, threshold, sample_documents, pipeline, mock_document_retriever ): """Test different judge thresholds.""" - mock_document_retriever.forward.return_value = sample_documents + mock_document_retriever.aforward.return_value = sample_documents # Judge with scores: 0.9, 0.8, 0.7, 0.6 (based on sample_documents) score_map = { @@ -290,17 +292,17 @@ def test_judge_threshold_parameterization( } judge = create_custom_retrieval_judge(score_map, threshold=threshold) - pipeline.retrieval_judge.forward.side_effect = judge.forward + pipeline.retrieval_judge.aforward.side_effect = judge.aforward pipeline.retrieval_judge.threshold = judge.threshold - pipeline.forward("test query") + await pipeline.aforward("test query") # Count filtered docs based on threshold scores = [0.9, 0.8, 0.7, 0.6] expected_count = sum(1 for score in scores if score >= threshold) # Verify judge was called - pipeline.retrieval_judge.forward.assert_called_once() + pipeline.retrieval_judge.aforward.assert_called_once() # Check that the pipeline stored the correct number of filtered documents assert hasattr(pipeline, "_current_documents") @@ -312,25 +314,26 @@ def test_judge_threshold_parameterization( assert doc.metadata.get("llm_judge_score", 0) >= threshold @patch("cairo_coder.core.rag_pipeline.RetrievalJudge") - def test_judge_failure_fallback( + @pytest.mark.asyncio + async def test_judge_failure_fallback( self, mock_judge_class, dspy_env_patched, sample_documents, pipeline ): """Test fallback when judge fails.""" # Create failing judge - pipeline.retrieval_judge.forward.side_effect = Exception("Judge failed") pipeline.retrieval_judge.aforward.side_effect = Exception("Judge failed") # Should not raise, should use all docs - pipeline.forward("test query") + await pipeline.aforward("test query") # All documents should be passed through - call_args = pipeline.generation_program.forward.call_args + call_args = pipeline.generation_program.aforward.call_args context = call_args[1]["context"] for doc in sample_documents: assert doc.page_content in context @patch("cairo_coder.core.rag_pipeline.RetrievalJudge") - def test_judge_parse_error_handling( + @pytest.mark.asyncio + async def test_judge_parse_error_handling( self, mock_judge_class, dspy_env_patched, patch_dspy_parallel, pipeline, mock_document_retriever ): @@ -339,7 +342,7 @@ def test_judge_parse_error_handling( ("Doc1", "Content1", "source1"), ("Doc2", "Content2", "source2"), ]) - mock_document_retriever.forward.return_value = docs + mock_document_retriever.aforward.return_value = docs # Create judge that returns invalid score judge = Mock(spec=RetrievalJudge) @@ -356,16 +359,16 @@ def 
filter_with_parse_error(query, documents): # The mock's side effect must replicate the real judge's behavior. return [documents[1]] - judge.forward = Mock(side_effect=filter_with_parse_error) + judge.aforward = AsyncMock(side_effect=filter_with_parse_error) judge.threshold = 0.5 - pipeline.retrieval_judge.forward.side_effect = judge.forward + pipeline.retrieval_judge.aforward.side_effect = judge.aforward pipeline.retrieval_judge.threshold = judge.threshold - pipeline.forward("test query") + await pipeline.aforward("test query") # The doc with the parse error ("Content1") should be dropped and not in the context. - call_args = pipeline.generation_program.forward.call_args + call_args = pipeline.generation_program.aforward.call_args context = call_args[1]["context"] assert "Content1" not in context assert "Content2" in context @@ -385,8 +388,8 @@ async def test_async_judge_execution( pipeline.retrieval_judge.aforward.assert_called_once() assert result.answer == "Here's how to write Cairo contracts..." - @pytest.mark.asyncio @patch("cairo_coder.core.rag_pipeline.RetrievalJudge") + @pytest.mark.asyncio async def test_streaming_with_judge( self, mock_judge_class, dspy_env_patched, patch_dspy_parallel, pipeline, mock_retrieval_judge @@ -395,7 +398,7 @@ async def test_streaming_with_judge( pipeline.retrieval_judge.aforward.side_effect = mock_retrieval_judge.aforward events = [] - async for event in pipeline.forward_streaming("test query"): + async for event in pipeline.aforward_streaming("test query"): events.append(event) # Verify judge was called @@ -408,24 +411,25 @@ async def test_streaming_with_judge( assert sources_event.data[0]["title"] == "Introduction to Cairo" @patch("cairo_coder.core.rag_pipeline.RetrievalJudge") - def test_judge_metadata_enrichment( + @pytest.mark.asyncio + async def test_judge_metadata_enrichment( self, mock_judge_class, dspy_env_patched, patch_dspy_parallel, pipeline, mock_document_retriever ): """Test that judge adds metadata to documents.""" docs = create_custom_documents([("Test Doc", "Test content", "test_source")]) - mock_document_retriever.forward.return_value = docs + mock_document_retriever.aforward.return_value = docs judge = create_custom_retrieval_judge({"Test Doc": 0.75}) - pipeline.retrieval_judge.forward.side_effect = judge.forward + pipeline.retrieval_judge.aforward.side_effect = judge.aforward - pipeline.forward("test query") + await pipeline.aforward("test query") # Check that judge was called and documents have metadata - pipeline.retrieval_judge.forward.assert_called_once() + pipeline.retrieval_judge.aforward.assert_called_once() # Verify that generation received the filtered document with metadata - gen_call_args = pipeline.generation_program.forward.call_args[1] + gen_call_args = pipeline.generation_program.aforward.call_args[1] context = gen_call_args["context"] # The document should be in the context (score 0.75 is above threshold) @@ -633,7 +637,7 @@ async def test_get_lm_usage_after_streaming( pipeline = RagPipeline(pipeline_config) # Execute the pipeline to ensure the full flow is invoked. 
- async for _ in pipeline.forward_streaming( + async for _ in pipeline.aforward_streaming( query="How do I create a Cairo contract?", mcp_mode=mcp_mode ): pass diff --git a/python/tests/unit/test_retrieval_judge.py b/python/tests/unit/test_retrieval_judge.py index fa01dab7..f2d885d5 100644 --- a/python/tests/unit/test_retrieval_judge.py +++ b/python/tests/unit/test_retrieval_judge.py @@ -1,6 +1,6 @@ """Unit tests for RetrievalJudge module.""" -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import AsyncMock, MagicMock import dspy import pytest @@ -31,107 +31,81 @@ def sample_documents(self): ), ] - def test_retrieval_judge_initialization(self): + @pytest.mark.asyncio + async def test_retrieval_judge_initialization(self): """Test RetrievalJudge initialization.""" judge = RetrievalJudge() assert judge.threshold == 0.4 assert judge.parallel_threads == 5 assert isinstance(judge.rater, dspy.Predict) - def test_forward_empty_documents(self): + @pytest.mark.asyncio + async def test_aforward_empty_documents(self): """Test forward with empty document list.""" judge = RetrievalJudge() - result = judge.forward("test query", []) + result = await judge.aforward("test query", []) assert result == [] - @patch("dspy.Parallel") - def test_forward_with_mocked_parallel(self, mock_parallel_class, sample_documents): - """Test forward method with mocked parallel execution.""" - # Setup mock - mock_parallel_instance = MagicMock() - mock_parallel_class.return_value = mock_parallel_instance - - # Create mock results - mock_results = [ - MagicMock( - resource_note=0.8, reasoning="Resource Cairo Introduction is highly relevant" - ), - MagicMock( - resource_note=0.3, reasoning="Resource Starknet Overview is somewhat relevant" - ), + @pytest.mark.asyncio + async def test_aforward_with_mocked_rater(self, sample_documents): + """Test forward method with mocked rater execution.""" + # Mock rater.acall to return per-document results + judge = RetrievalJudge() + judge.rater.acall = AsyncMock(side_effect=[ + MagicMock(resource_note=0.8, reasoning="Resource Cairo Introduction is highly relevant"), + MagicMock(resource_note=0.3, reasoning="Resource Starknet Overview is somewhat relevant"), MagicMock(resource_note=0.1, reasoning="Resource Python Guide is not relevant"), - ] - mock_parallel_instance.return_value = mock_results + ]) - # Test - judge = RetrievalJudge() documents = sample_documents - filtered_docs = judge.forward("How to write Cairo programs?", documents) + filtered_docs = await judge.aforward("How to write Cairo programs?", documents) # Assertions assert len(filtered_docs) == 1 # Only first doc passes threshold assert filtered_docs[0].metadata["llm_judge_score"] == 0.8 assert "highly relevant" in filtered_docs[0].metadata["llm_judge_reason"] - # Verify parallel was called correctly - mock_parallel_class.assert_called_once_with(num_threads=5) - mock_parallel_instance.assert_called_once() + # rater.acall was invoked for each doc + assert judge.rater.acall.await_count == 3 - @patch("dspy.Parallel") - def test_forward_with_parse_error(self, mock_parallel_class, sample_documents): + @pytest.mark.asyncio + async def test_forward_with_parse_error(self, sample_documents): """Test forward handling parse errors gracefully by dropping the invalid doc.""" - # Setup mock - mock_parallel_instance = MagicMock() - mock_parallel_class.return_value = mock_parallel_instance - - # Create results with parse error - mock_results = [ + judge = RetrievalJudge() + judge.rater.acall = AsyncMock(side_effect=[ 
MagicMock(resource_note="invalid", reasoning="Some reasoning"), # Invalid score MagicMock(resource_note=0.7, reasoning="Valid result"), - ] - mock_parallel_instance.return_value = mock_results - - # Test - judge = RetrievalJudge() + ]) documents = sample_documents[:2] - filtered_docs = judge.forward("test query", documents) + filtered_docs = await judge.aforward("test query", documents) # Should only keep the doc that was successfully parsed and scored above threshold. assert len(filtered_docs) == 1 assert filtered_docs[0] is documents[1] # Check it's the second document assert filtered_docs[0].metadata["llm_judge_score"] == 0.7 - # The document that failed parsing should have error metadata but not be in the final list. + # The doc that failed parsing should have error metadata but not be in the final list. assert "llm_judge_score" in documents[0].metadata assert documents[0].metadata["llm_judge_score"] == 0.0 assert documents[0].metadata["llm_judge_reason"] == "Parse error" - @patch("dspy.Parallel") - def test_forward_with_exception(self, mock_parallel_class, sample_documents): + @pytest.mark.asyncio + async def test_aforward_with_exception(self, sample_documents): """Test forward handling exceptions by returning all documents.""" - # Setup mock to raise exception - mock_parallel_class.side_effect = Exception("Parallel execution failed") - - # Test judge = RetrievalJudge() + judge.rater.acall = AsyncMock(side_effect=Exception("Parallel execution failed")) documents = sample_documents - filtered_docs = judge.forward("test query", documents) + filtered_docs = await judge.aforward("test query", documents) # Should return all documents on failure assert len(filtered_docs) == len(documents) assert filtered_docs == documents @pytest.mark.asyncio - async def test_aforward_empty_documents(self): - """Test async forward with empty document list.""" - judge = RetrievalJudge() - result = await judge.aforward("test query", []) - assert result == [] - - def test_forward_with_contract_and_test_templates(self, sample_documents): + async def test_aforward_with_contract_and_test_templates(self, sample_documents): """Test forward with contract template.""" judge = RetrievalJudge() - result = judge.forward( + result = await judge.aforward( "test query", [ Document( @@ -183,101 +157,29 @@ async def test_aforward_with_contract_template(self, sample_documents): ), ] - @pytest.mark.asyncio - async def test_aforward_with_mocked_rater(self, sample_documents): - """Test async forward method with mocked rater.""" - judge = Mock() - judge.threshold = 0.5 - - # Mock the aforward method directly - async def mock_aforward(query, documents): - filtered = [] - for doc in documents: - # Simulate scoring based on content - if "Cairo" in doc.page_content: - doc.metadata["llm_judge_score"] = 0.9 - doc.metadata["llm_judge_reason"] = "Highly relevant to Cairo" - filtered.append(doc) - elif "Starknet" in doc.page_content: - doc.metadata["llm_judge_score"] = 0.6 - doc.metadata["llm_judge_reason"] = "Somewhat relevant" - filtered.append(doc) - else: - doc.metadata["llm_judge_score"] = 0.2 - doc.metadata["llm_judge_reason"] = "Not relevant" - return filtered - - judge.aforward = mock_aforward - - # Test - documents = sample_documents - filtered_docs = await judge.aforward("Cairo programming query", documents) - - # Should filter out Python doc (score 0.2 < threshold 0.5) - assert len(filtered_docs) == 2 - assert all(doc.metadata["llm_judge_score"] >= 0.5 for doc in filtered_docs) @pytest.mark.asyncio - async def 
test_aforward_with_exception(self, sample_documents): - """Test async forward handling exceptions during judgment of a single document.""" - judge = Mock() - judge.threshold = 0.5 - - # Mock aforward to simulate exception for one document - async def mock_aforward(query, documents): - filtered = [] - for doc in documents: - if "Starknet" in doc.page_content: - # Simulate failure - document is dropped - doc.metadata["llm_judge_score"] = 0.0 - doc.metadata["llm_judge_reason"] = "Error during judgment" - else: - doc.metadata["llm_judge_score"] = 0.9 - doc.metadata["llm_judge_reason"] = "Highly relevant" - filtered.append(doc) - return filtered - - judge.aforward = mock_aforward - - # Test - documents = sample_documents[:2] # Using two docs for the test - filtered_docs = await judge.aforward("test query", documents) - - # Should keep the document that was judged successfully and drop the one with an error. - assert len(filtered_docs) == 1 - assert filtered_docs[0] is documents[0] - assert "llm_judge_reason" in documents[1].metadata - assert documents[1].metadata["llm_judge_reason"] == "Error during judgment" - - def test_score_clamping(self, sample_documents): + async def test_score_clamping(self, sample_documents): """Test that scores are properly clamped to [0,1] range.""" judge = RetrievalJudge() + judge.rater.acall = AsyncMock(side_effect=[ + MagicMock(resource_note=1.5, reasoning="Score too high"), + MagicMock(resource_note=-0.3, reasoning="Score too low"), + MagicMock(resource_note=0.5, reasoning="Valid score"), + ]) + documents = sample_documents + filtered_docs = await judge.aforward("test", documents) - # Mock parallel execution with out-of-range scores - with patch("dspy.Parallel") as mock_parallel_class: - mock_parallel_instance = MagicMock() - mock_parallel_class.return_value = mock_parallel_instance - - mock_results = [ - MagicMock(resource_note=1.5, reasoning="Score too high"), - MagicMock(resource_note=-0.3, reasoning="Score too low"), - MagicMock(resource_note=0.5, reasoning="Valid score"), - ] - mock_parallel_instance.return_value = mock_results - - documents = sample_documents - filtered_docs = judge.forward("test", documents) - - # Check scores are clamped and filtering works - assert len(filtered_docs) == 2 # Only 2 docs pass threshold of 0.4 + # Check scores are clamped and filtering works + assert len(filtered_docs) == 2 # Only 2 docs pass threshold of 0.4 - # Check the scores of all documents (including filtered out ones) - assert documents[0].metadata["llm_judge_score"] == 1.0 # Clamped from 1.5 - assert documents[1].metadata["llm_judge_score"] == 0.0 # Clamped from -0.3 - assert documents[2].metadata["llm_judge_score"] == 0.5 # Valid score + # Check the scores of all documents (including filtered out ones) + assert documents[0].metadata["llm_judge_score"] == 1.0 # Clamped from 1.5 + assert documents[1].metadata["llm_judge_score"] == 0.0 # Clamped from -0.3 + assert documents[2].metadata["llm_judge_score"] == 0.5 # Valid score - # Check filtered docs only contain those above threshold - assert all(doc.metadata["llm_judge_score"] >= 0.4 for doc in filtered_docs) + # Check filtered docs only contain those above threshold + assert all(doc.metadata["llm_judge_score"] >= 0.4 for doc in filtered_docs) def test_document_string_preparation(self): """Test document string preparation includes title and truncates content.""" @@ -285,21 +187,9 @@ def test_document_string_preparation(self): # Create document with long content long_content = "x" * 2000 - doc = 
@@ -285,21 +187,9 @@ def test_document_string_preparation(self):
         """Test document string preparation includes title and truncates content."""
         judge = RetrievalJudge()
         # Create document with long content
         long_content = "x" * 2000
-        doc = Document(page_content=long_content, metadata={"title": "Test Doc"})
-
-        with patch("dspy.Parallel") as mock_parallel_class:
-            mock_parallel_instance = MagicMock()
-            mock_parallel_class.return_value = mock_parallel_instance
-            mock_parallel_instance.return_value = [MagicMock(resource_note=0.5, reasoning="test")]
-
-            judge.forward("test", [doc])
-
-            # Get the batches that were created
-            call_args = mock_parallel_instance.call_args[0][0]
-            example = call_args[0][1]
-            # Check document string format
-            doc_string = example.system_resource
-            assert "Title: Test Doc" in doc_string
-            assert len(doc_string) < 1200  # Should be truncated
-            assert doc_string.endswith("...")
+        # Directly test the string builder
+        doc_string = judge._document_to_string("Test Doc", long_content)
+        assert "Title: Test Doc" in doc_string
+        assert len(doc_string) < 1200  # Should be truncated
+        assert doc_string.endswith("...")
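The rewritten test above exercises a `_document_to_string` helper directly instead of fishing prompt examples out of a `dspy.Parallel` mock. A minimal sketch of what such a helper plausibly looks like; the 1000-character cutoff and the standalone signature are assumptions read off the assertions ("Title: ..." prefix, `len(doc_string) < 1200`, trailing "..."), not the actual implementation:

def document_to_string(title: str, content: str, max_len: int = 1000) -> str:
    # Assumed shape: title header plus content, truncating long content with a
    # trailing ellipsis so the judge prompt stays bounded.
    body = content if len(content) <= max_len else content[:max_len] + "..."
    return f"Title: {title}\n\n{body}"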
"https://files.pythonhosted.org/packages/c9/7a/cef76fd8438a42f96db64ddaa85280485a9c395e7df3db8158cfec1eee34/dill-0.3.8-py3-none-any.whl", hash = "sha256:c36ca9ffb54365bdd2f8eb3eff7d2a21237f8452b57ace88b1ac615b7e815bd7", size = 116252, upload-time = "2024-01-27T23:42:14.239Z" }, -] - [[package]] name = "diskcache" version = "5.6.3" @@ -869,7 +835,7 @@ wheels = [ [[package]] name = "dspy" -version = "2.6.27" +version = "3.0.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, @@ -877,8 +843,8 @@ dependencies = [ { name = "backoff" }, { name = "cachetools" }, { name = "cloudpickle" }, - { name = "datasets" }, { name = "diskcache" }, + { name = "gepa" }, { name = "joblib" }, { name = "json-repair" }, { name = "litellm" }, @@ -887,7 +853,6 @@ dependencies = [ { name = "numpy", version = "2.3.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "openai" }, { name = "optuna" }, - { name = "pandas" }, { name = "pydantic" }, { name = "regex" }, { name = "requests" }, @@ -895,10 +860,11 @@ dependencies = [ { name = "tenacity" }, { name = "tqdm" }, { name = "ujson" }, + { name = "xxhash" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/38/8a/f7ff1a6d3b5294678f13d17ecfc596f49a59e494b190e4e30f7dea7df1dc/dspy-2.6.27.tar.gz", hash = "sha256:de1c4f6f6d127e0efed894e1915dac40f5d5623e7f1cf3d749c98d790066477a", size = 234604, upload-time = "2025-06-03T17:47:13.411Z" } +sdist = { url = "https://files.pythonhosted.org/packages/70/cb/4bfb5345e230e33b0fa4f18c16fe646395a081a48c6feb314e6993a86bb1/dspy-3.0.1.tar.gz", hash = "sha256:92220584eb7c3587746cac76209f7f167dbf6f38f5f05a7019d610ededc1eede", size = 213285, upload-time = "2025-08-14T17:39:32.415Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/5a/bb/8a75d44bc1b54dea0fa0428eb52b13e7ee533b85841d2c53a53dfc360646/dspy-2.6.27-py3-none-any.whl", hash = "sha256:54e55fd6999b6a46e09b0e49e8c4b71be7dd56a881e66f7a60b8d657650c1a74", size = 297296, upload-time = "2025-06-03T17:47:11.526Z" }, + { url = "https://files.pythonhosted.org/packages/c1/b4/ef2706be57daf78562b8aa811cdfe184616becb6659522ace85919202b21/dspy-3.0.1-py3-none-any.whl", hash = "sha256:a9afb6eedaab063e9ca6d46840fad85b97ab45e79b4bf9371e6bf3a5666ef5c6", size = 259011, upload-time = "2025-08-14T17:39:30.901Z" }, ] [[package]] @@ -1109,9 +1075,13 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/56/53/eb690efa8513166adef3e0669afd31e95ffde69fb3c52ec2ac7223ed6018/fsspec-2025.3.0-py3-none-any.whl", hash = "sha256:efb87af3efa9103f94ca91a7f8cb7a4df91af9f74fc106c9c7ea0efd7277c1b3", size = 193615, upload-time = "2025-03-07T21:47:54.809Z" }, ] -[package.optional-dependencies] -http = [ - { name = "aiohttp" }, +[[package]] +name = "gepa" +version = "0.0.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/0c/0d/aa6065d7d59b3f10ff6818d527dada5a7179ac5643b666b6b6b71d11dab4/gepa-0.0.4.tar.gz", hash = "sha256:b3e020124c7d8a80c07595aca3b73647ec9151203d7166915ad62492b8459bd6", size = 32957, upload-time = "2025-08-14T05:08:36.792Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ce/c0/836c79f05113c96155e8de1bb8bf3631a9e7b3b75238c592d39460141ea8/gepa-0.0.4-py3-none-any.whl", hash = "sha256:53d275490d644855e90adf4eba1e3ace5c414c76ba0c0f22760b99a0e43984f9", size = 35191, upload-time = "2025-08-14T05:08:35.558Z" }, ] [[package]] @@ -2369,24 +2339,6 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/d8/30/9aec301e9772b098c1f5c0ca0279237c9766d94b97802e9888010c64b0ed/multidict-6.6.3-py3-none-any.whl", hash = "sha256:8db10f29c7541fc5da4defd8cd697e1ca429db743fa716325f236079b96f775a", size = 12313, upload-time = "2025-06-30T15:53:45.437Z" }, ] -[[package]] -name = "multiprocess" -version = "0.70.16" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "dill" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/b5/ae/04f39c5d0d0def03247c2893d6f2b83c136bf3320a2154d7b8858f2ba72d/multiprocess-0.70.16.tar.gz", hash = "sha256:161af703d4652a0e1410be6abccecde4a7ddffd19341be0a7011b94aeb171ac1", size = 1772603, upload-time = "2024-01-28T18:52:34.85Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/ef/76/6e712a2623d146d314f17598df5de7224c85c0060ef63fd95cc15a25b3fa/multiprocess-0.70.16-pp310-pypy310_pp73-macosx_10_13_x86_64.whl", hash = "sha256:476887be10e2f59ff183c006af746cb6f1fd0eadcfd4ef49e605cbe2659920ee", size = 134980, upload-time = "2024-01-28T18:52:15.731Z" }, - { url = "https://files.pythonhosted.org/packages/0f/ab/1e6e8009e380e22254ff539ebe117861e5bdb3bff1fc977920972237c6c7/multiprocess-0.70.16-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:d951bed82c8f73929ac82c61f01a7b5ce8f3e5ef40f5b52553b4f547ce2b08ec", size = 134982, upload-time = "2024-01-28T18:52:17.783Z" }, - { url = "https://files.pythonhosted.org/packages/bc/f7/7ec7fddc92e50714ea3745631f79bd9c96424cb2702632521028e57d3a36/multiprocess-0.70.16-py310-none-any.whl", hash = "sha256:c4a9944c67bd49f823687463660a2d6daae94c289adff97e0f9d696ba6371d02", size = 134824, upload-time = "2024-01-28T18:52:26.062Z" }, - { url = "https://files.pythonhosted.org/packages/50/15/b56e50e8debaf439f44befec5b2af11db85f6e0f344c3113ae0be0593a91/multiprocess-0.70.16-py311-none-any.whl", hash = "sha256:af4cabb0dac72abfb1e794fa7855c325fd2b55a10a44628a3c1ad3311c04127a", size = 143519, upload-time = "2024-01-28T18:52:28.115Z" }, - { url = "https://files.pythonhosted.org/packages/0a/7d/a988f258104dcd2ccf1ed40fdc97e26c4ac351eeaf81d76e266c52d84e2f/multiprocess-0.70.16-py312-none-any.whl", hash = "sha256:fc0544c531920dde3b00c29863377f87e1632601092ea2daca74e4beb40faa2e", size = 146741, upload-time = "2024-01-28T18:52:29.395Z" }, - { url = "https://files.pythonhosted.org/packages/ea/89/38df130f2c799090c978b366cfdf5b96d08de5b29a4a293df7f7429fa50b/multiprocess-0.70.16-py38-none-any.whl", hash = "sha256:a71d82033454891091a226dfc319d0cfa8019a4e888ef9ca910372a446de4435", size = 132628, upload-time = "2024-01-28T18:52:30.853Z" }, - { url = "https://files.pythonhosted.org/packages/da/d9/f7f9379981e39b8c2511c9e0326d212accacb82f12fbfdc1aa2ce2a7b2b6/multiprocess-0.70.16-py39-none-any.whl", hash = "sha256:a0bafd3ae1b732eac64be2e72038231c1ba97724b60b09400d68f229fcc2fbf3", size = 133351, upload-time = "2024-01-28T18:52:31.981Z" }, -] - [[package]] name = "mypy" version = "1.17.0"