18 changes: 11 additions & 7 deletions python/src/cairo_coder/core/rag_pipeline.py
@@ -10,6 +10,7 @@
from typing import Any

import dspy
import langsmith as ls
import structlog
from dspy.adapters import XMLAdapter
from dspy.utils.callback import BaseCallback
@@ -164,6 +165,7 @@ async def aforward(
query=query, context=context, chat_history=chat_history_str
)


async def aforward_streaming(
self,
query: str,
@@ -218,13 +220,15 @@ async def aforward_streaming(

# Stream response generation. Use ChatAdapter for streaming, which performs better.
with dspy.context(
lm=dspy.LM("gemini/gemini-flash-lite-latest", max_tokens=10000),
adapter=dspy.adapters.XMLAdapter(),
):
async for chunk in self.generation_program.aforward_streaming(
query=query, context=context, chat_history=chat_history_str
):
yield StreamEvent(type=StreamEventType.RESPONSE, data=chunk)
adapter=dspy.adapters.ChatAdapter()
), ls.trace(name="GenerationProgramStreaming", run_type="llm", inputs={"query": query, "chat_history": chat_history_str, "context": context}) as rt:
chunk_accumulator = ""
async for chunk in self.generation_program.aforward_streaming(
query=query, context=context, chat_history=chat_history_str
):
chunk_accumulator += chunk
yield StreamEvent(type=StreamEventType.RESPONSE, data=chunk)
rt.end(outputs={"output": chunk_accumulator})

# Pipeline completed
yield StreamEvent(type=StreamEventType.END, data=None)
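Note on the streaming change above: generation now runs inside a LangSmith `ls.trace` span and accumulates the streamed chunks so the complete answer can be attached to the run via `rt.end`. A minimal sketch of that pattern outside the pipeline is shown below; the helper name `stream_with_trace` and the `generate_chunks` callable are illustrative, not part of this PR.

```python
from collections.abc import AsyncIterator, Callable

import langsmith as ls


async def stream_with_trace(
    generate_chunks: Callable[[str], AsyncIterator[str]], query: str
) -> AsyncIterator[str]:
    # Open a LangSmith run for the whole stream, yield chunks as they arrive,
    # and record the concatenated output once streaming finishes.
    with ls.trace(
        name="GenerationProgramStreaming", run_type="llm", inputs={"query": query}
    ) as rt:
        accumulated = ""
        async for chunk in generate_chunks(query):
            accumulated += chunk
            yield chunk
        rt.end(outputs={"output": accumulated})
```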
2 changes: 1 addition & 1 deletion python/src/cairo_coder/dspy/generation_program.py
@@ -203,7 +203,7 @@ def get_lm_usage(self) -> dict[str, int]:
"""
return self.generation_program.get_lm_usage()

@traceable(name="GenerationProgram", run_type="llm")
@traceable(name="GenerationProgram", run_type="llm", metadata={"llm_provider": dspy.settings.lm})
async def aforward(self, query: str, context: str, chat_history: Optional[str] = None) -> dspy.Prediction | None :
"""
Generate Cairo code response based on query and context - async
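The only change here is the extra `metadata={"llm_provider": dspy.settings.lm}` on the `@traceable` decorator. Since the metadata dict is evaluated when the decorator runs at definition time, the LM has to be configured before this module is imported. A hedged, standalone sketch of the same pattern; the function name and model below are placeholders, not code from this repository.

```python
import dspy
from langsmith import traceable

# Configure the LM first so dspy.settings.lm is populated when the
# decorator's metadata dict is built at definition time.
dspy.configure(lm=dspy.LM("gemini/gemini-flash-latest"))


@traceable(name="ExampleProgram", run_type="llm", metadata={"llm_provider": dspy.settings.lm})
async def example_program(query: str) -> str:
    # Placeholder body; the real programs delegate to their dspy predictors.
    return f"answer for: {query}"
```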
2 changes: 1 addition & 1 deletion python/src/cairo_coder/dspy/query_processor.py
@@ -123,7 +123,7 @@ def __init__(self):
"foundry",
}

@traceable(name="QueryProcessorProgram", run_type="llm")
@traceable(name="QueryProcessorProgram", run_type="llm", metadata={"llm_provider": dspy.settings.lm})
async def aforward(self, query: str, chat_history: Optional[str] = None) -> ProcessedQuery:
"""
Process a user query into a structured format for document retrieval.
24 changes: 19 additions & 5 deletions python/src/cairo_coder/dspy/retrieval_judge.py
@@ -47,17 +47,21 @@ class RetrievalRecallPrecision(dspy.Signature):
"""

query: str = dspy.InputField()
system_resource: str = dspy.InputField(desc="Single resource text (content + minimal metadata/title)")
system_resource: str = dspy.InputField(
desc="Single resource text (content + minimal metadata/title)"
)
reasoning: str = dspy.OutputField(
desc="A short sentence, on why a selected resource will be useful. If it's not selected, reason about why it's not going to be useful. Start by Resource <resource_title>..."
)
resource_note: float = dspy.OutputField(
desc="A note between 0 and 1.0 on how useful the resource is to directly answer the query. 0 being completely unrelated, 1.0 being very relevant, 0.5 being 'not directly related but still informative and can be useful for context'."
)


DEFAULT_THRESHOLD = 0.4
DEFAULT_PARALLEL_THREADS = 5


class RetrievalJudge(dspy.Module):
"""
LLM-based judge that scores retrieved documents for relevance to a query.
@@ -88,13 +92,17 @@ def __init__(self):
raise FileNotFoundError(f"{compiled_program_path} not found")
self.rater.load(compiled_program_path)

@traceable(name="RetrievalJudge", run_type="llm")
@traceable(
name="RetrievalJudge", run_type="llm", metadata={"llm_provider": dspy.settings.lm}
)
async def aforward(self, query: str, documents: list[Document]) -> list[Document]:
"""Async judge."""
if not documents:
return documents

keep_docs, judged_indices, judged_payloads = self._split_templates_and_prepare_docs(documents)
keep_docs, judged_indices, judged_payloads = self._split_templates_and_prepare_docs(
documents
)

# TODO: can we use dspy.Parallel here instead of asyncio gather?
if judged_payloads:
@@ -114,7 +122,11 @@ async def judge_one(doc_string: str):
keep_docs=keep_docs,
)
except Exception as e:
logger.error("Retrieval judge failed (async), returning all docs", error=str(e), exc_info=True)
logger.error(
"Retrieval judge failed (async), returning all docs",
error=str(e),
exc_info=True,
)
return documents

return keep_docs
@@ -155,7 +167,9 @@ def _split_templates_and_prepare_docs(
return keep_docs, judged_indices, judged_payloads

@staticmethod
def _document_to_string(title: str, content: str, max_len: int = JUDGE_DOCUMENT_PREVIEW_MAX_LEN) -> str:
def _document_to_string(
title: str, content: str, max_len: int = JUDGE_DOCUMENT_PREVIEW_MAX_LEN
) -> str:
"""Build the string seen by the judge for one doc."""
preview = content[:max_len]
if len(content) > max_len:
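For context on the `judge_one` hunk and the TODO about `dspy.Parallel`: the judge rates each prepared document concurrently and keeps those whose `resource_note` clears `DEFAULT_THRESHOLD`. A rough sketch of that fan-out with `asyncio.gather`, assuming an async `rate_resource` callable returning a prediction with a `resource_note` field (the callable name is an assumption, not the module's API):

```python
import asyncio

DEFAULT_THRESHOLD = 0.4


async def judge_documents(query: str, doc_strings: list[str], rate_resource) -> list[int]:
    # Score every prepared document concurrently, then keep the indices
    # whose relevance note meets the threshold.
    async def judge_one(doc_string: str) -> float:
        prediction = await rate_resource(query=query, system_resource=doc_string)
        return float(prediction.resource_note)

    scores = await asyncio.gather(*(judge_one(d) for d in doc_strings))
    return [i for i, score in enumerate(scores) if score >= DEFAULT_THRESHOLD]
```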
93 changes: 48 additions & 45 deletions python/src/cairo_coder/server/app.py
@@ -14,9 +14,10 @@
from contextlib import asynccontextmanager

import dspy
import langsmith as ls
import structlog
import uvicorn
from dspy.adapters import XMLAdapter
from dspy.adapters import ChatAdapter, XMLAdapter
from fastapi import Depends, FastAPI, Header, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
@@ -185,7 +186,7 @@ def __init__(
embedder = dspy.Embedder("gemini/gemini-embedding-001", dimensions=3072, batch_size=512)
dspy.configure(
lm=dspy.LM("gemini/gemini-flash-latest", max_tokens=30000, cache=False),
adapter=XMLAdapter(),
adapter=ChatAdapter(),
embedder=embedder,
callbacks=[AgentLoggingCallback()],
track_usage=True,
@@ -420,49 +421,51 @@ async def _stream_chat_completion(
content_buffer = ""

try:
async for event in agent.aforward_streaming(
query=query, chat_history=history, mcp_mode=mcp_mode
):
if event.type == "sources":
# Emit sources event for clients to display
sources_chunk = {
"type": "sources",
"data": event.data,
}
yield f"data: {json.dumps(sources_chunk)}\n\n"
elif event.type == "response":
content_buffer += event.data

# Send content chunk
chunk = {
"id": response_id,
"object": "chat.completion.chunk",
"created": created,
"model": "cairo-coder",
"choices": [
{"index": 0, "delta": {"content": event.data}, "finish_reason": None}
],
}
yield f"data: {json.dumps(chunk)}\n\n"
elif event.type == "error":
# Emit an error as a final delta and stop
error_chunk = {
"id": response_id,
"object": "chat.completion.chunk",
"created": created,
"model": "cairo-coder",
"choices": [
{
"index": 0,
"delta": {"content": f"\n\nError: {event.data}"},
"finish_reason": "stop",
}
],
}
yield f"data: {json.dumps(error_chunk)}\n\n"
break
elif event.type == "end":
break
with ls.trace(name="RagPipelineStreaming", run_type="chain", inputs={"query": query, "chat_history": history, "mcp_mode": mcp_mode}) as rt:
async for event in agent.aforward_streaming(
query=query, chat_history=history, mcp_mode=mcp_mode
):
if event.type == "sources":
# Emit sources event for clients to display
sources_chunk = {
"type": "sources",
"data": event.data,
}
yield f"data: {json.dumps(sources_chunk)}\n\n"
elif event.type == "response":
content_buffer += event.data

# Send content chunk
chunk = {
"id": response_id,
"object": "chat.completion.chunk",
"created": created,
"model": "cairo-coder",
"choices": [
{"index": 0, "delta": {"content": event.data}, "finish_reason": None}
],
}
yield f"data: {json.dumps(chunk)}\n\n"
elif event.type == "error":
# Emit an error as a final delta and stop
error_chunk = {
"id": response_id,
"object": "chat.completion.chunk",
"created": created,
"model": "cairo-coder",
"choices": [
{
"index": 0,
"delta": {"content": f"\n\nError: {event.data}"},
"finish_reason": "stop",
}
],
}
yield f"data: {json.dumps(error_chunk)}\n\n"
break
elif event.type == "end":
break
rt.end(outputs={"output": content_buffer})

except Exception as e:
logger.error("Error during agent streaming", error=str(e), exc_info=True)
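For reference, the generator above emits server-sent events of two shapes: a "sources" event and OpenAI-style chat.completion.chunk deltas, each serialized as a `data: <json>` line followed by a blank line. A rough client-side sketch of consuming that stream with `httpx`; the endpoint path and the `[DONE]` guard are assumptions, not taken from this diff.

```python
import json

import httpx


async def consume_stream(base_url: str, payload: dict) -> str:
    # Collect the assistant text from the streamed chat.completion.chunk deltas,
    # skipping the "sources" metadata events.
    content = ""
    async with httpx.AsyncClient(base_url=base_url, timeout=None) as client:
        async with client.stream("POST", "/v1/chat/completions", json=payload) as response:
            async for line in response.aiter_lines():
                if not line.startswith("data: "):
                    continue
                data = line[len("data: "):].strip()
                if not data or data == "[DONE]":
                    continue
                event = json.loads(data)
                if event.get("type") == "sources":
                    continue
                for choice in event.get("choices", []):
                    content += choice.get("delta", {}).get("content") or ""
    return content
```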