15 changes: 12 additions & 3 deletions python/src/cairo_coder/core/rag_pipeline.py
@@ -210,7 +210,9 @@ async def aforward_streaming(
             )
 
             mcp_prediction = self.mcp_generation_program(documents)
+            # Emit a single response plus a final response event for clients that rely on it
             yield StreamEvent(type=StreamEventType.RESPONSE, data=mcp_prediction.answer)
+            yield StreamEvent(type=StreamEventType.FINAL_RESPONSE, data=mcp_prediction.answer)
         else:
             # Normal mode: Generate response
             yield StreamEvent(type=StreamEventType.PROCESSING, data="Generating response...")
@@ -223,12 +225,19 @@ async def aforward_streaming(
                 adapter=dspy.adapters.ChatAdapter()
             ), ls.trace(name="GenerationProgramStreaming", run_type="llm", inputs={"query": query, "chat_history": chat_history_str, "context": context}) as rt:
                 chunk_accumulator = ""
+                final_text: str | None = None
                 async for chunk in self.generation_program.aforward_streaming(
                     query=query, context=context, chat_history=chat_history_str
                 ):
-                    chunk_accumulator += chunk
-                    yield StreamEvent(type=StreamEventType.RESPONSE, data=chunk)
-                rt.end(outputs={"output": chunk_accumulator})
+                    if isinstance(chunk, dspy.streaming.StreamResponse):
+                        # Incremental token
+                        chunk_accumulator += chunk.chunk
+                        yield StreamEvent(type=StreamEventType.RESPONSE, data=chunk.chunk)
+                    elif isinstance(chunk, dspy.Prediction):
+                        # Final complete answer
+                        final_text = getattr(chunk, "answer", None) or chunk_accumulator
+                        yield StreamEvent(type=StreamEventType.FINAL_RESPONSE, data=final_text)
+                rt.end(outputs={"output": final_text})
 
         # Pipeline completed
         yield StreamEvent(type=StreamEventType.END, data=None)
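For reference, a minimal consumer of the updated stream could look like the sketch below. This is illustrative code, not part of the diff; it assumes a `RagPipeline` instance wired as above and the `StreamEventType` members added in `types.py`.

```python
from cairo_coder.core.types import StreamEventType

# Hypothetical consumer: accumulate incremental RESPONSE tokens, but prefer
# the FINAL_RESPONSE payload as the authoritative answer (DSPy may yield only
# the final Prediction on cache hits, in which case no incremental tokens arrive).
async def collect_answer(pipeline, query: str) -> str:
    partial, final = "", None
    async for event in pipeline.aforward_streaming(query=query, chat_history=[]):
        if event.type == StreamEventType.RESPONSE:
            partial += event.data
        elif event.type == StreamEventType.FINAL_RESPONSE:
            final = event.data
        elif event.type == StreamEventType.END:
            break
    return final if final is not None else partial
```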
1 change: 1 addition & 0 deletions python/src/cairo_coder/core/types.py
@@ -113,6 +113,7 @@ class StreamEventType(str, Enum):
     SOURCES = "sources"
     PROCESSING = "processing"
     RESPONSE = "response"
+    FINAL_RESPONSE = "final_response"
     END = "end"
     ERROR = "error"
 
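Because `StreamEventType` subclasses `str`, the new member compares equal to its literal value; this is what lets `app.py` below switch to enum comparisons without breaking producers that still emit plain strings. A quick illustration (not from the diff):

```python
from cairo_coder.core.types import StreamEventType

# str-mixin enum members compare equal to their raw string values.
assert StreamEventType.FINAL_RESPONSE == "final_response"
assert StreamEventType.RESPONSE == "response"
```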
14 changes: 2 additions & 12 deletions python/src/cairo_coder/dspy/generation_program.py
@@ -227,7 +227,7 @@ async def aforward(self, query: str, context: str, chat_history: Optional[str] =
 
     async def aforward_streaming(
         self, query: str, context: str, chat_history: Optional[str] = None
-    ) -> AsyncGenerator[str, None]:
+    ) -> AsyncGenerator[object, None]:
         """
         Generate Cairo code response with streaming support using DSPy's native streaming.
 
@@ -255,18 +255,8 @@ async def aforward_streaming(
             query=query, context=context, chat_history=chat_history
         )
 
-        # Process the stream and yield tokens
-        is_cached = True
         async for chunk in output_stream:
-            if isinstance(chunk, dspy.streaming.StreamResponse):
-                # No streaming if cached
-                is_cached = False
-                # Yield the actual token content
-                yield chunk.chunk
-            elif isinstance(chunk, dspy.Prediction):
-                if is_cached:
-                    yield chunk.answer
-                # Final output received - streaming is complete
+            yield chunk
 
     def _format_chat_history(self, chat_history: list[Message]) -> str:
         """
19 changes: 13 additions & 6 deletions python/src/cairo_coder/server/app.py
@@ -30,7 +30,7 @@
     AgentLoggingCallback,
     RagPipeline,
 )
-from cairo_coder.core.types import Message, Role
+from cairo_coder.core.types import Message, Role, StreamEventType
 from cairo_coder.dspy.document_retriever import SourceFilteredPgVectorRM
 from cairo_coder.dspy.suggestion_program import SuggestionGeneration
 from cairo_coder.utils.logging import setup_logging
@@ -401,7 +401,7 @@ async def _handle_chat_completion(
             ) from e
 
     async def _stream_chat_completion(
-        self, agent, query: str, history: list[Message], mcp_mode: bool
+        self, agent: RagPipeline, query: str, history: list[Message], mcp_mode: bool
     ) -> AsyncGenerator[str, None]:
         """Stream chat completion response - replicates TypeScript streaming."""
         response_id = str(uuid.uuid4())
@@ -425,14 +425,14 @@ async def _stream_chat_completion(
         async for event in agent.aforward_streaming(
             query=query, chat_history=history, mcp_mode=mcp_mode
         ):
-            if event.type == "sources":
+            if event.type == StreamEventType.SOURCES:
                 # Emit sources event for clients to display
                 sources_chunk = {
                     "type": "sources",
                     "data": event.data,
                 }
                 yield f"data: {json.dumps(sources_chunk)}\n\n"
-            elif event.type == "response":
+            elif event.type == StreamEventType.RESPONSE:
                 content_buffer += event.data
 
                 # Send content chunk
@@ -446,7 +446,14 @@ async def _stream_chat_completion(
                     ],
                 }
                 yield f"data: {json.dumps(chunk)}\n\n"
-            elif event.type == "error":
+            elif event.type == StreamEventType.FINAL_RESPONSE:
+                # Emit an explicit final response event for clients
+                final_event = {
+                    "type": "final_response",
+                    "data": event.data,
+                }
+                yield f"data: {json.dumps(final_event)}\n\n"
+            elif event.type == StreamEventType.ERROR:
                 # Emit an error as a final delta and stop
                 error_chunk = {
                     "id": response_id,
@@ -463,7 +470,7 @@ async def _stream_chat_completion(
                 }
                 yield f"data: {json.dumps(error_chunk)}\n\n"
                 break
-            elif event.type == "end":
+            elif event.type == StreamEventType.END:
                 break
         rt.end(outputs={"output": content_buffer})
 
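On the wire, the stream now interleaves standard OpenAI chat-completion chunks with two custom JSON events (`sources` and `final_response`). A minimal client-side parse might look like this sketch (the helper and `state` dict are hypothetical, assuming only the SSE framing emitted above):

```python
import json

def handle_sse_line(line: str, state: dict) -> None:
    """Parse one SSE line of the form 'data: <json>' or the 'data: [DONE]' terminator."""
    if not line.startswith("data: "):
        return
    payload = line[len("data: "):]
    if payload == "[DONE]":
        return
    event = json.loads(payload)
    if event.get("type") == "sources":
        state["sources"] = event["data"]  # custom frontend event
    elif event.get("type") == "final_response":
        state["answer"] = event["data"]  # authoritative final text
    elif "choices" in event:  # standard OpenAI chunk (including error deltas)
        delta = event["choices"][0].get("delta", {})
        state["buffer"] = state.get("buffer", "") + (delta.get("content") or "")
```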
7 changes: 5 additions & 2 deletions python/tests/conftest.py
@@ -168,10 +168,12 @@ async def mock_aforward_streaming(
                 ],
             )
             yield StreamEvent(type=StreamEventType.RESPONSE, data="Cairo is a programming language")
+            yield StreamEvent(type=StreamEventType.FINAL_RESPONSE, data="Cairo is a programming language")
         else:
             # Normal mode returns response
             yield StreamEvent(type=StreamEventType.RESPONSE, data="Hello! I'm Cairo Coder.")
             yield StreamEvent(type=StreamEventType.RESPONSE, data=" How can I help you?")
+            yield StreamEvent(type=StreamEventType.FINAL_RESPONSE, data="Hello! I'm Cairo Coder. How can I help you?")
         yield StreamEvent(type=StreamEventType.END, data="")
 
     def mock_forward(query: str, chat_history: list[Message] | None = None, mcp_mode: bool = False):
@@ -369,8 +371,9 @@ def mock_generation_program():
     program.get_lm_usage = Mock(return_value={})
 
     async def mock_streaming(*args, **kwargs):
-        yield "Here's how to write "
-        yield "Cairo contracts..."
+        yield dspy.streaming.StreamResponse(predict_name="GenerationProgram", signature_field_name="answer", chunk="Here's how to write ", is_last_chunk=False)
+        yield dspy.streaming.StreamResponse(predict_name="GenerationProgram", signature_field_name="answer", chunk="Cairo contracts...", is_last_chunk=True)
+        yield dspy.Prediction(answer=answer)
 
     program.aforward_streaming = mock_streaming
     return program
41 changes: 4 additions & 37 deletions python/tests/integration/conftest.py
@@ -7,48 +7,14 @@
 
 from unittest.mock import AsyncMock, Mock
 
+import dspy
 import pytest
 from fastapi.testclient import TestClient
 
 from cairo_coder.agents.registry import AgentId
 from cairo_coder.server.app import get_agent_factory, get_vector_db
 
 
-@pytest.fixture
-def patch_dspy_streaming_success(monkeypatch):
-    """Patch dspy.streamify to emit token-like chunks and provide StreamListener.
-
-    Yields two chunks: "Hello " and "world".
-    """
-    import dspy
-
-    class FakeStreamResponse:
-        def __init__(self, chunk: str):
-            self.chunk = chunk
-
-    class FakeStreamListener:
-        def __init__(self, signature_field_name: str):  # noqa: ARG002
-            pass
-
-    monkeypatch.setattr(
-        dspy,
-        "streaming",
-        type("S", (), {"StreamResponse": FakeStreamResponse, "StreamListener": FakeStreamListener}),
-    )
-
-    def fake_streamify(_program, stream_listeners=None):  # noqa: ARG001
-        def runner(**kwargs):  # noqa: ARG001
-            async def gen():
-                yield FakeStreamResponse("Hello ")
-                yield FakeStreamResponse("world")
-
-            return gen()
-
-        return runner
-
-    monkeypatch.setattr(dspy, "streamify", fake_streamify)
-
-
 @pytest.fixture
 def patch_dspy_streaming_error(monkeypatch, real_pipeline):
     """Patch dspy.streamify to raise an error mid-stream and provide StreamListener.
@@ -140,8 +106,9 @@ async def _fake_gen_aforward(query: str, context: str, chat_history: str | None
         return _dspy.Prediction(answer=responses[idx])
 
     async def _fake_gen_aforward_streaming(query: str, context: str, chat_history: str | None = None):
-        yield "Hello! I'm Cairo Coder, "
-        yield "ready to help with Cairo programming."
+        yield dspy.streaming.StreamResponse(predict_name="GenerationProgram", signature_field_name="answer", chunk="Hello! I'm Cairo Coder, ", is_last_chunk=False)
+        yield dspy.streaming.StreamResponse(predict_name="GenerationProgram", signature_field_name="answer", chunk="ready to help with Cairo programming.", is_last_chunk=True)
+        yield dspy.Prediction(answer="Hello! I'm Cairo Coder, ready to help with Cairo programming.")
 
     pipeline.generation_program.aforward = AsyncMock(side_effect=_fake_gen_aforward)
     pipeline.generation_program.aforward_streaming = _fake_gen_aforward_streaming
7 changes: 4 additions & 3 deletions python/tests/integration/test_server_integration.py
@@ -115,7 +115,6 @@ async def mock_aforward(query: str, chat_history=None, mcp_mode=False, **kwargs)
     def test_streaming_integration(
         self,
         client: TestClient,
-        patch_dspy_streaming_success,
     ):
         """Test streaming response end-to-end using a real pipeline with low-level patches."""
 
@@ -457,8 +456,10 @@ def test_openai_streaming_response_structure(self, client: TestClient):
                 if data_str != "[DONE]":
                     chunks.append(json.loads(data_str))
 
-        # Filter out sources events (custom event type for frontend)
-        openai_chunks = [chunk for chunk in chunks if chunk.get("type") != "sources"]
+        # Filter out custom frontend events (sources, final_response)
+        openai_chunks = [
+            chunk for chunk in chunks if chunk.get("type") not in ("sources", "final_response")
+        ]
 
         for chunk in openai_chunks:
             required_fields = ["id", "object", "created", "model", "choices"]
11 changes: 6 additions & 5 deletions python/tests/unit/test_rag_pipeline.py
@@ -14,7 +14,7 @@
     RagPipelineConfig,
     RagPipelineFactory,
 )
-from cairo_coder.core.types import Document, DocumentSource, Message, Role
+from cairo_coder.core.types import Document, DocumentSource, Message, Role, StreamEventType
 from cairo_coder.dspy.retrieval_judge import RetrievalJudge
 
 
@@ -148,10 +148,11 @@ async def test_streaming_pipeline_execution(self, pipeline):
 
         # Verify event sequence
         event_types = [e.type for e in events]
-        assert "processing" in event_types
-        assert "sources" in event_types
-        assert "response" in event_types
-        assert "end" in event_types
+        assert StreamEventType.PROCESSING in event_types
+        assert StreamEventType.SOURCES in event_types
+        assert StreamEventType.RESPONSE in event_types
+        assert StreamEventType.FINAL_RESPONSE in event_types
+        assert StreamEventType.END in event_types
 
     @pytest.mark.asyncio
     async def test_mcp_mode_execution(self, pipeline):