Merged
2 changes: 1 addition & 1 deletion python/pyproject.toml
@@ -37,7 +37,7 @@ dependencies = [
"tenacity>=8.0.0",
"prometheus-client>=0.20.0",
"python-multipart>=0.0.6",
"dspy>=2.6.27",
"dspy>=3.0.0",
"psycopg2>=2.9.10",
"pgvector>=0.4.1",
"marimo>=0.14.11",
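The only packaging change here is the dspy pin moving from the 2.6 line to 3.x. A minimal, illustrative check that an environment actually picked up the new constraint (assuming the `packaging` helper is available; this snippet is not part of the PR):

```python
# Illustrative sanity check for the dspy>=3.0.0 pin.
import dspy
from packaging.version import Version

assert Version(dspy.__version__) >= Version("3.0.0"), dspy.__version__
```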
63 changes: 4 additions & 59 deletions python/src/cairo_coder/core/rag_pipeline.py
@@ -11,6 +11,7 @@
from typing import Any

import dspy
from dspy.adapters.baml_adapter import BAMLAdapter
from dspy.utils.callback import BaseCallback
from langsmith import traceable

@@ -110,38 +111,6 @@ def __init__(self, config: RagPipelineConfig):
self._current_processed_query: ProcessedQuery | None = None
self._current_documents: list[Document] = []

def _process_query_and_retrieve_docs(
self,
query: str,
chat_history_str: str,
sources: list[DocumentSource] | None = None,
) -> tuple[ProcessedQuery, list[Document]]:
processed_query = self.query_processor.forward(query=query, chat_history=chat_history_str)
self._current_processed_query = processed_query

# Use provided sources or fall back to processed query sources
retrieval_sources = sources or processed_query.resources
documents = self.document_retriever.forward(
processed_query=processed_query, sources=retrieval_sources
)

# Apply LLM judge if enabled
try:
with dspy.context(lm=dspy.LM("gemini/gemini-2.5-flash-lite", max_tokens=10000)):
documents = self.retrieval_judge.forward(query=query, documents=documents)
except Exception as e:
logger.warning(
"Retrieval judge failed (sync), using all documents",
error=str(e),
exc_info=True,
)
# documents already contains all retrieved docs, no action needed

self._current_documents = documents

return processed_query, documents


async def _aprocess_query_and_retrieve_docs(
self,
query: str,
@@ -159,7 +128,7 @@ async def _aprocess_query_and_retrieve_docs(
)

try:
with dspy.context(lm=dspy.LM("gemini/gemini-2.5-flash-lite", max_tokens=10000)):
with dspy.context(lm=dspy.LM("gemini/gemini-2.5-flash-lite", max_tokens=10000), adapter=BAMLAdapter()):
documents = await self.retrieval_judge.aforward(query=query, documents=documents)
except Exception as e:
logger.warning(
@@ -173,30 +142,6 @@

return processed_query, documents

# Waits for streaming to finish before returning the response
@traceable(name="RagPipeline", run_type="chain")
def forward(
self,
query: str,
chat_history: list[Message] | None = None,
mcp_mode: bool = False,
sources: list[DocumentSource] | None = None,
) -> dspy.Prediction:
chat_history_str = self._format_chat_history(chat_history or [])
processed_query, documents = self._process_query_and_retrieve_docs(
query, chat_history_str, sources
)
logger.info(f"Processed query: {processed_query.original} and retrieved {len(documents)} doc titles: {[doc.metadata.get('title') for doc in documents]}")

if mcp_mode:
return self.mcp_generation_program.forward(documents)

context = self._prepare_context(documents, processed_query)

return self.generation_program.forward(
query=query, context=context, chat_history=chat_history_str
)

# Waits for streaming to finish before returning the response
@traceable(name="RagPipeline", run_type="chain")
async def aforward(
@@ -213,15 +158,15 @@ async def aforward(
logger.info(f"Processed query: {processed_query.original[:100]}... and retrieved {len(documents)} doc titles: {[doc.metadata.get('title') for doc in documents]}")

if mcp_mode:
return self.mcp_generation_program.forward(documents)
return await self.mcp_generation_program.aforward(documents)

context = self._prepare_context(documents, processed_query)

return await self.generation_program.aforward(
query=query, context=context, chat_history=chat_history_str
)

async def forward_streaming(
async def aforward_streaming(
self,
query: str,
chat_history: list[Message] | None = None,
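For reference, a condensed sketch of the judge call as it stands after this change: the dspy 3.x BAMLAdapter is scoped to the judge's LM via `dspy.context`, and a judge failure still falls back to the unfiltered documents. The wrapper function name below is illustrative; the real logic lives in `_aprocess_query_and_retrieve_docs` above.

```python
import dspy
from dspy.adapters.baml_adapter import BAMLAdapter

async def judge_or_fallback(retrieval_judge, query, documents):
    # Both the judge LM and the BAML adapter apply only inside this context block.
    try:
        with dspy.context(
            lm=dspy.LM("gemini/gemini-2.5-flash-lite", max_tokens=10000),
            adapter=BAMLAdapter(),
        ):
            return await retrieval_judge.aforward(query=query, documents=documents)
    except Exception:
        # Judge failure is non-fatal: keep every retrieved document.
        return documents
```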
5 changes: 4 additions & 1 deletion python/src/cairo_coder/dspy/document_retriever.py
@@ -14,7 +14,7 @@
import dspy
from cairo_coder.core.config import VectorStoreConfig
from cairo_coder.core.types import Document, DocumentSource, ProcessedQuery
from dspy.retrieve.pgvector_rm import PgVectorRM
from cairo_coder.dspy.pgvector_rm import PgVectorRM

logger = structlog.get_logger()

@@ -134,6 +134,9 @@
- Always import strictly the required types in the module the interface is implemented in.
- Always import the required types of the contract inside the contract module.
- Always make the interface and the contract module 'pub'
- In assert! macros, the message string uses double quotes (\"), not single quotes (\'); e.g.: assert!(caller == owner,
"Caller is not owner"). The string literal can also be omitted entirely in assert! macros.
- Always match the generated code against context-provided code to reduce hallucination risk.
</important_rules>

The content inside the <contract> tag is the contract code for a 'Registry' contract, demonstrating
36 changes: 6 additions & 30 deletions python/src/cairo_coder/dspy/generation_program.py
@@ -98,36 +98,6 @@ def get_lm_usage(self) -> dict[str, int]:
"""
return self.generation_program.get_lm_usage()

@traceable(name="GenerationProgram", run_type="llm")
def forward(self, query: str, context: str, chat_history: Optional[str] = None) -> dspy.Prediction | None :
"""
Generate Cairo code response based on query and context.

Args:
query: User's Cairo programming question
context: Retrieved documentation and examples
chat_history: Previous conversation context (optional)

Returns:
Generated Cairo code response with explanations
"""
if chat_history is None:
chat_history = ""

# Execute the generation program
max_retries = 3
for attempt in range(max_retries):
try:
return self.generation_program.forward(query=query, context=context, chat_history=chat_history)
except AdapterParseError as e:
if attempt < max_retries - 1:
continue
code = self._try_extract_code_from_response(e.lm_response)
if code:
return dspy.Prediction(answer=code)
raise e
return None

@traceable(name="GenerationProgram", run_type="llm")
async def aforward(self, query: str, context: str, chat_history: Optional[str] = None) -> dspy.Prediction | None :
"""
@@ -269,6 +239,12 @@ def forward(self, documents: list[Document]) -> dspy.Prediction:

return dspy.Prediction(answer='\n'.join(formatted_docs))

async def aforward(self, documents: list[Document]) -> dspy.Prediction:
"""
Format documents for MCP mode response.
"""
return self.forward(documents)

def get_lm_usage(self) -> dict[str, int]:
"""
Get the total number of tokens used by the LLM.
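The new `MCPGenerationProgram.aforward` is a thin awaitable wrapper over the existing synchronous formatting, which keeps the MCP path usable from the now all-async pipeline. A condensed sketch of the pattern (the document rendering is simplified to plain strings for illustration):

```python
import dspy

class McpFormatterSketch(dspy.Module):
    """Sketch only: formatting is cheap and synchronous, so the async entry
    point simply delegates to forward() instead of duplicating the logic."""

    def forward(self, documents: list[str]) -> dspy.Prediction:
        return dspy.Prediction(answer="\n".join(documents))

    async def aforward(self, documents: list[str]) -> dspy.Prediction:
        return self.forward(documents)
```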
157 changes: 157 additions & 0 deletions python/src/cairo_coder/dspy/pgvector_rm.py
@@ -0,0 +1,157 @@
import warnings
from collections.abc import Callable
from typing import Optional

import dspy

try:
import psycopg2
from pgvector.psycopg2 import register_vector
from psycopg2 import sql
except ImportError as e:
raise ImportError(
"The 'pgvector' extra is required to use PgVectorRM. Install it with `pip install dspy-ai[pgvector]`. Also, try `pip install pgvector psycopg2`.",
) from e
try:
import openai
except ImportError:
warnings.warn(
"`openai` is not installed. Install it with `pip install openai` to use OpenAI embedding models.",
stacklevel=2, category=ImportWarning,
)


class PgVectorRM(dspy.Retrieve):
"""
Implements a retriever that (as the name suggests) uses pgvector to retrieve passages,
using a raw SQL query and a postgresql connection managed by psycopg2.

It needs to register the pgvector extension with the psycopg2 connection

Returns a list of dspy.Example objects

Args:
db_url (str): A PostgreSQL database URL in psycopg2's DSN format
pg_table_name (Optional[str]): name of the table containing passages
openai_client (openai.OpenAI): OpenAI client to use for computing query embeddings. Either openai_client or embedding_func must be provided.
embedding_func (Callable): A function to use for computing query embeddings. Either openai_client or embedding_func must be provided.
content_field (str = "text"): Field containing the passage text. Defaults to "text"
k (Optional[int]): Default number of top passages to retrieve. Defaults to 20
embedding_field (str = "embedding"): Field containing passage embeddings. Defaults to "embedding"
fields (List[str] = ['text']): Fields to retrieve from the table. Defaults to "text"
embedding_model (str = "text-embedding-ada-002"): Name of the OpenAI embedding model to use. Defaults to "text-embedding-ada-002"

Examples:
Below is a code snippet that shows how to use PgVector as the default retriever

```python
import os

import dspy
import openai
import psycopg2

openai.api_key = os.environ.get("OPENAI_API_KEY", None)
openai_client = openai.OpenAI()

llm = dspy.LM("openai/gpt-3.5-turbo")

# DATABASE_URL should be in the format postgresql://user:password@host/database
db_url = os.getenv("DATABASE_URL")

retriever_model = PgVectorRM(db_url, "paragraphs", openai_client=openai_client, fields=["text", "document_id"], k=20)
dspy.settings.configure(lm=llm, rm=retriever_model)
```

Below is a code snippet that shows how to use PgVector in the forward() function of a module
```python
self.retrieve = PgVectorRM(db_url, "paragraphs", openai_client=openai_client, fields=["text", "document_id"], k=20)
```
"""

def __init__(
self,
db_url: str,
pg_table_name: str,
openai_client: Optional[openai.OpenAI] = None,
embedding_func: Optional[Callable] = None,
k: int = 20,
embedding_field: str = "embedding",
fields: Optional[list[str]] = None,
content_field: str = "text",
embedding_model: str = "text-embedding-ada-002",
include_similarity: bool = False,
):
"""
k = 20 is the number of paragraphs to retrieve
"""
assert (
openai_client or embedding_func
), "Either openai_client or embedding_func must be provided."
self.openai_client = openai_client
self.embedding_func = embedding_func

self.conn = psycopg2.connect(db_url)
register_vector(self.conn)
self.pg_table_name = pg_table_name
self.fields = fields or ["text"]
self.content_field = content_field
self.embedding_field = embedding_field
self.embedding_model = embedding_model
self.include_similarity = include_similarity

super().__init__(k=k)

def forward(self, query: str, k: int = None):
"""Search with PgVector for k top passages for query using cosine similarity

Args:
query (str): The query to search for
k (int): The number of top passages to retrieve. Defaults to the value set in the constructor.
Returns:
list[dspy.Example]: the retrieved passages, one Example per matching row.
"""
# Embed query
query_embedding = self._get_embeddings(query)

retrieved_docs = []

fields = sql.SQL(",").join([sql.Identifier(f) for f in self.fields])
if self.include_similarity:
similarity_field = sql.SQL(",") + sql.SQL(
"1 - ({embedding_field} <=> %s::vector) AS similarity",
).format(embedding_field=sql.Identifier(self.embedding_field))
fields += similarity_field
args = (query_embedding, query_embedding, k if k else self.k)
else:
args = (query_embedding, k if k else self.k)

sql_query = sql.SQL(
"select {fields} from {table} order by {embedding_field} <=> %s::vector limit %s",
).format(
fields=fields,
table=sql.Identifier(self.pg_table_name),
embedding_field=sql.Identifier(self.embedding_field),
)

with self.conn as conn, conn.cursor() as cur:
cur.execute(sql_query, args)
rows = cur.fetchall()
columns = [descrip[0] for descrip in cur.description]
for row in rows:
data = dict(zip(columns, row, strict=False))
data["long_text"] = data[self.content_field]
retrieved_docs.append(dspy.Example(**data))
# Return the retrieved examples
return retrieved_docs

def _get_embeddings(self, query: str) -> list[float]:
if self.openai_client is not None:
return (
self.openai_client.embeddings.create(
model=self.embedding_model,
input=query,
encoding_format="float",
)
.data[0]
.embedding
)
return self.embedding_func(query)
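Since `PgVectorRM` is now vendored under `cairo_coder.dspy`, here is a short usage sketch against this copy. The DSN, table name, fields, and embedding function below are placeholders, not values from the codebase:

```python
from cairo_coder.dspy.pgvector_rm import PgVectorRM

def embed(text: str) -> list[float]:
    # Placeholder embedder; swap in the project's real embedding function.
    return [0.0] * 1536

retriever = PgVectorRM(
    db_url="postgresql://user:password@localhost/cairo_coder",  # placeholder DSN
    pg_table_name="documents",                                   # placeholder table
    embedding_func=embed,
    fields=["text", "metadata"],
    content_field="text",
    k=5,
    include_similarity=True,
)

for example in retriever.forward("How do I define a Starknet interface?"):
    print(example.long_text[:80], getattr(example, "similarity", None))
```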
29 changes: 0 additions & 29 deletions python/src/cairo_coder/dspy/query_processor.py
@@ -109,35 +109,6 @@ def __init__(self):
"foundry",
}

@traceable(name="QueryProcessorProgram", run_type="llm")
def forward(self, query: str, chat_history: Optional[str] = None) -> ProcessedQuery:
"""
Process a user query into a structured format for document retrieval.

Args:
query: The user's Cairo/Starknet programming question
chat_history: Previous conversation context (optional)

Returns:
ProcessedQuery with search terms, resource identification, and categorization
"""
# Execute the DSPy retrieval program
result = self.retrieval_program.forward(query=query, chat_history=chat_history)

# Parse and validate the results
search_queries = result.search_queries
resources = self._validate_resources(result.resources)

# Build structured query result
return ProcessedQuery(
original=query,
search_queries=search_queries,
reasoning=result.reasoning,
is_contract_related=self._is_contract_query(query),
is_test_related=self._is_test_query(query),
resources=resources,
)

@traceable(name="QueryProcessorProgram", run_type="llm")
async def aforward(self, query: str, chat_history: Optional[str] = None) -> ProcessedQuery:
"""
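With the synchronous `forward` removed, query processing is reached only through `aforward`. An illustrative call site follows; the event-loop setup and the bare instantiation are assumptions for the sketch, not code from this PR:

```python
import asyncio

from cairo_coder.dspy.query_processor import QueryProcessorProgram

async def main() -> None:
    processor = QueryProcessorProgram()
    processed = await processor.aforward(
        query="How do I emit an event from a Starknet contract?", chat_history=""
    )
    print(processed.search_queries, processed.resources)

# asyncio.run(main())  # requires a configured dspy LM
```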