diff --git a/python/src/cairo_coder/core/rag_pipeline.py b/python/src/cairo_coder/core/rag_pipeline.py
index 8bfd11d..ce498de 100644
--- a/python/src/cairo_coder/core/rag_pipeline.py
+++ b/python/src/cairo_coder/core/rag_pipeline.py
@@ -210,7 +210,9 @@ async def aforward_streaming(
             )
             mcp_prediction = self.mcp_generation_program(documents)
 
+            # Emit single response plus a final response event for clients that rely on it
             yield StreamEvent(type=StreamEventType.RESPONSE, data=mcp_prediction.answer)
+            yield StreamEvent(type=StreamEventType.FINAL_RESPONSE, data=mcp_prediction.answer)
         else:
             # Normal mode: Generate response
             yield StreamEvent(type=StreamEventType.PROCESSING, data="Generating response...")
@@ -223,12 +225,19 @@ async def aforward_streaming(
                 adapter=dspy.adapters.ChatAdapter()
             ), ls.trace(name="GenerationProgramStreaming", run_type="llm", inputs={"query": query, "chat_history": chat_history_str, "context": context}) as rt:
                 chunk_accumulator = ""
+                final_text: str | None = None
                 async for chunk in self.generation_program.aforward_streaming(
                     query=query, context=context, chat_history=chat_history_str
                 ):
-                    chunk_accumulator += chunk
-                    yield StreamEvent(type=StreamEventType.RESPONSE, data=chunk)
-                rt.end(outputs={"output": chunk_accumulator})
+                    if isinstance(chunk, dspy.streaming.StreamResponse):
+                        # Incremental token
+                        chunk_accumulator += chunk.chunk
+                        yield StreamEvent(type=StreamEventType.RESPONSE, data=chunk.chunk)
+                    elif isinstance(chunk, dspy.Prediction):
+                        # Final complete answer
+                        final_text = getattr(chunk, "answer", None) or chunk_accumulator
+                        yield StreamEvent(type=StreamEventType.FINAL_RESPONSE, data=final_text)
+                rt.end(outputs={"output": final_text})
 
         # Pipeline completed
         yield StreamEvent(type=StreamEventType.END, data=None)
diff --git a/python/src/cairo_coder/core/types.py b/python/src/cairo_coder/core/types.py
index 6099da6..d76c7cb 100644
--- a/python/src/cairo_coder/core/types.py
+++ b/python/src/cairo_coder/core/types.py
@@ -113,6 +113,7 @@ class StreamEventType(str, Enum):
     SOURCES = "sources"
     PROCESSING = "processing"
     RESPONSE = "response"
+    FINAL_RESPONSE = "final_response"
    END = "end"
     ERROR = "error"
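Not part of the patch, but for orientation: with the new event type, a consumer of `aforward_streaming` no longer has to re-assemble the answer from `RESPONSE` tokens. A minimal sketch of the intended consumption pattern (the `collect_answer` helper is hypothetical, not code from this repository):

```python
from cairo_coder.core.types import StreamEventType

async def collect_answer(pipeline, query: str) -> str:
    """Prefer the FINAL_RESPONSE payload; fall back to joined tokens."""
    tokens: list[str] = []
    final: str | None = None
    async for event in pipeline.aforward_streaming(query=query, chat_history=[]):
        if event.type == StreamEventType.RESPONSE:
            tokens.append(event.data)  # incremental token (or the full answer in MCP mode)
        elif event.type == StreamEventType.FINAL_RESPONSE:
            final = event.data  # authoritative complete answer
        elif event.type == StreamEventType.END:
            break
    return final if final is not None else "".join(tokens)
```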
diff --git a/python/src/cairo_coder/dspy/generation_program.py b/python/src/cairo_coder/dspy/generation_program.py
index a9f2e08..ea79876 100644
--- a/python/src/cairo_coder/dspy/generation_program.py
+++ b/python/src/cairo_coder/dspy/generation_program.py
@@ -227,7 +227,7 @@ async def aforward(self, query: str, context: str, chat_history: Optional[str] =
 
     async def aforward_streaming(
         self, query: str, context: str, chat_history: Optional[str] = None
-    ) -> AsyncGenerator[str, None]:
+    ) -> AsyncGenerator[object, None]:
         """
         Generate Cairo code response with streaming support using DSPy's native streaming.
@@ -255,18 +255,8 @@ async def aforward_streaming(
             query=query, context=context, chat_history=chat_history
         )
 
-        # Process the stream and yield tokens
-        is_cached = True
         async for chunk in output_stream:
-            if isinstance(chunk, dspy.streaming.StreamResponse):
-                # No streaming if cached
-                is_cached = False
-                # Yield the actual token content
-                yield chunk.chunk
-            elif isinstance(chunk, dspy.Prediction):
-                if is_cached:
-                    yield chunk.answer
-                # Final output received - streaming is complete
+            yield chunk
 
     def _format_chat_history(self, chat_history: list[Message]) -> str:
         """
diff --git a/python/src/cairo_coder/server/app.py b/python/src/cairo_coder/server/app.py
index 8d202e8..98f464b 100644
--- a/python/src/cairo_coder/server/app.py
+++ b/python/src/cairo_coder/server/app.py
@@ -30,7 +30,7 @@
     AgentLoggingCallback,
     RagPipeline,
 )
-from cairo_coder.core.types import Message, Role
+from cairo_coder.core.types import Message, Role, StreamEventType
 from cairo_coder.dspy.document_retriever import SourceFilteredPgVectorRM
 from cairo_coder.dspy.suggestion_program import SuggestionGeneration
 from cairo_coder.utils.logging import setup_logging
@@ -401,7 +401,7 @@ async def _handle_chat_completion(
         ) from e
 
     async def _stream_chat_completion(
-        self, agent, query: str, history: list[Message], mcp_mode: bool
+        self, agent: RagPipeline, query: str, history: list[Message], mcp_mode: bool
     ) -> AsyncGenerator[str, None]:
         """Stream chat completion response - replicates TypeScript streaming."""
         response_id = str(uuid.uuid4())
@@ -425,14 +425,14 @@ async def _stream_chat_completion(
         async for event in agent.aforward_streaming(
             query=query, chat_history=history, mcp_mode=mcp_mode
         ):
-            if event.type == "sources":
+            if event.type == StreamEventType.SOURCES:
                 # Emit sources event for clients to display
                 sources_chunk = {
                     "type": "sources",
                     "data": event.data,
                 }
                 yield f"data: {json.dumps(sources_chunk)}\n\n"
-            elif event.type == "response":
+            elif event.type == StreamEventType.RESPONSE:
                 content_buffer += event.data
 
                 # Send content chunk
@@ -446,7 +446,14 @@
                     ],
                 }
                 yield f"data: {json.dumps(chunk)}\n\n"
-            elif event.type == "error":
+            elif event.type == StreamEventType.FINAL_RESPONSE:
+                # Emit an explicit final response event for clients
+                final_event = {
+                    "type": "final_response",
+                    "data": event.data,
+                }
+                yield f"data: {json.dumps(final_event)}\n\n"
+            elif event.type == StreamEventType.ERROR:
                 # Emit an error as a final delta and stop
                 error_chunk = {
                     "id": response_id,
@@ -463,7 +470,7 @@
                 }
                 yield f"data: {json.dumps(error_chunk)}\n\n"
                 break
-            elif event.type == "end":
+            elif event.type == StreamEventType.END:
                 break
 
         rt.end(outputs={"output": content_buffer})
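Again for orientation only: on the wire, the new `final_response` frame travels alongside the OpenAI-style deltas emitted above. A rough client-side sketch, assuming an SSE-over-POST endpoint and using `httpx` (the URL, payload shape, and `read_final_answer` helper are assumptions, not part of this diff):

```python
import json

import httpx

async def read_final_answer(url: str, payload: dict) -> str:
    """Consume the SSE stream and return the final_response payload."""
    final = ""
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream("POST", url, json=payload) as resp:
            async for line in resp.aiter_lines():
                if not line.startswith("data: "):
                    continue
                data_str = line[len("data: "):]
                if data_str == "[DONE]":
                    break
                chunk = json.loads(data_str)
                if chunk.get("type") == "final_response":
                    final = chunk["data"]  # complete answer in a single frame
    return final
```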
How can I help you?") yield StreamEvent(type=StreamEventType.END, data="") def mock_forward(query: str, chat_history: list[Message] | None = None, mcp_mode: bool = False): @@ -369,8 +371,9 @@ def mock_generation_program(): program.get_lm_usage = Mock(return_value={}) async def mock_streaming(*args, **kwargs): - yield "Here's how to write " - yield "Cairo contracts..." + yield dspy.streaming.StreamResponse(predict_name="GenerationProgram", signature_field_name="answer", chunk="Here's how to write ", is_last_chunk=False) + yield dspy.streaming.StreamResponse(predict_name="GenerationProgram", signature_field_name="answer", chunk="Cairo contracts...", is_last_chunk=True) + yield dspy.Prediction(answer=answer) program.aforward_streaming = mock_streaming return program diff --git a/python/tests/integration/conftest.py b/python/tests/integration/conftest.py index d60d4af..cba2057 100644 --- a/python/tests/integration/conftest.py +++ b/python/tests/integration/conftest.py @@ -7,6 +7,7 @@ from unittest.mock import AsyncMock, Mock +import dspy import pytest from fastapi.testclient import TestClient @@ -14,41 +15,6 @@ from cairo_coder.server.app import get_agent_factory, get_vector_db -@pytest.fixture -def patch_dspy_streaming_success(monkeypatch): - """Patch dspy.streamify to emit token-like chunks and provide StreamListener. - - Yields two chunks: "Hello " and "world". - """ - import dspy - - class FakeStreamResponse: - def __init__(self, chunk: str): - self.chunk = chunk - - class FakeStreamListener: - def __init__(self, signature_field_name: str): # noqa: ARG002 - pass - - monkeypatch.setattr( - dspy, - "streaming", - type("S", (), {"StreamResponse": FakeStreamResponse, "StreamListener": FakeStreamListener}), - ) - - def fake_streamify(_program, stream_listeners=None): # noqa: ARG001 - def runner(**kwargs): # noqa: ARG001 - async def gen(): - yield FakeStreamResponse("Hello ") - yield FakeStreamResponse("world") - - return gen() - - return runner - - monkeypatch.setattr(dspy, "streamify", fake_streamify) - - @pytest.fixture def patch_dspy_streaming_error(monkeypatch, real_pipeline): """Patch dspy.streamify to raise an error mid-stream and provide StreamListener. @@ -140,8 +106,9 @@ async def _fake_gen_aforward(query: str, context: str, chat_history: str | None return _dspy.Prediction(answer=responses[idx]) async def _fake_gen_aforward_streaming(query: str, context: str, chat_history: str | None = None): - yield "Hello! I'm Cairo Coder, " - yield "ready to help with Cairo programming." + yield dspy.streaming.StreamResponse(predict_name="GenerationProgram", signature_field_name="answer", chunk="Hello! I'm Cairo Coder, ", is_last_chunk=False) + yield dspy.streaming.StreamResponse(predict_name="GenerationProgram", signature_field_name="answer", chunk="ready to help with Cairo programming.", is_last_chunk=True) + yield dspy.Prediction(answer="Hello! 
diff --git a/python/tests/integration/test_server_integration.py b/python/tests/integration/test_server_integration.py
index aebbcb7..bf691ce 100644
--- a/python/tests/integration/test_server_integration.py
+++ b/python/tests/integration/test_server_integration.py
@@ -115,7 +115,6 @@ async def mock_aforward(query: str, chat_history=None, mcp_mode=False, **kwargs)
     def test_streaming_integration(
         self,
         client: TestClient,
-        patch_dspy_streaming_success,
     ):
         """Test streaming response end-to-end using a real pipeline with low-level patches."""
 
@@ -457,8 +456,10 @@ def test_openai_streaming_response_structure(self, client: TestClient):
                 if data_str != "[DONE]":
                     chunks.append(json.loads(data_str))
 
-        # Filter out sources events (custom event type for frontend)
-        openai_chunks = [chunk for chunk in chunks if chunk.get("type") != "sources"]
+        # Filter out custom frontend events (sources, final_response)
+        openai_chunks = [
+            chunk for chunk in chunks if chunk.get("type") not in ("sources", "final_response")
+        ]
 
         for chunk in openai_chunks:
             required_fields = ["id", "object", "created", "model", "choices"]
diff --git a/python/tests/unit/test_rag_pipeline.py b/python/tests/unit/test_rag_pipeline.py
index b4f4d94..e425aab 100644
--- a/python/tests/unit/test_rag_pipeline.py
+++ b/python/tests/unit/test_rag_pipeline.py
@@ -14,7 +14,7 @@
     RagPipelineConfig,
     RagPipelineFactory,
 )
-from cairo_coder.core.types import Document, DocumentSource, Message, Role
+from cairo_coder.core.types import Document, DocumentSource, Message, Role, StreamEventType
 from cairo_coder.dspy.retrieval_judge import RetrievalJudge
 
 
@@ -148,10 +148,11 @@ async def test_streaming_pipeline_execution(self, pipeline):
         # Verify event sequence
         event_types = [e.type for e in events]
-        assert "processing" in event_types
-        assert "sources" in event_types
-        assert "response" in event_types
-        assert "end" in event_types
+        assert StreamEventType.PROCESSING in event_types
+        assert StreamEventType.SOURCES in event_types
+        assert StreamEventType.RESPONSE in event_types
+        assert StreamEventType.FINAL_RESPONSE in event_types
+        assert StreamEventType.END in event_types
 
     @pytest.mark.asyncio
     async def test_mcp_mode_execution(self, pipeline):
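A follow-up test could pin down the new frame explicitly, along these lines (hypothetical test, not in this patch; the endpoint path and payload mirror the conventions of `test_openai_streaming_response_structure`, and the exact route is an assumption):

```python
import json

def test_final_response_frame(client):
    # Hypothetical route; adjust to the server's actual chat completions path.
    response = client.post(
        "/v1/chat/completions",
        json={"messages": [{"role": "user", "content": "Hi"}], "stream": True},
    )
    finals = []
    for line in response.iter_lines():
        if not line.startswith("data: "):
            continue
        data_str = line[len("data: "):]
        if data_str == "[DONE]":
            break
        chunk = json.loads(data_str)
        if chunk.get("type") == "final_response":
            finals.append(chunk["data"])
    assert len(finals) == 1  # exactly one final answer per stream
```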