From b3546a8d44220a3e241d8ba8286fb6d288734200 Mon Sep 17 00:00:00 2001 From: enitrat Date: Fri, 3 Oct 2025 15:02:21 +0200 Subject: [PATCH 1/2] feat: emit sources in streaming mode --- API_DOCUMENTATION.md | 19 +++++++ .../src/ingesters/MarkdownIngester.ts | 5 +- .../src/ingesters/StarknetDocsIngester.ts | 23 ++++---- python/optimizers/results/optimized_rag.json | 2 +- python/src/cairo_coder/core/rag_pipeline.py | 56 ++++++++++++------- python/src/cairo_coder/core/types.py | 42 +++++++++++--- .../cairo_coder/dspy/generation_program.py | 17 +++--- python/src/cairo_coder/server/app.py | 12 ++-- python/tests/conftest.py | 16 ++---- .../integration/test_server_integration.py | 46 +++++++++++++-- python/tests/unit/test_document_retriever.py | 6 +- python/tests/unit/test_generation_program.py | 14 +---- python/tests/unit/test_rag_pipeline.py | 42 ++++++++++++-- 13 files changed, 208 insertions(+), 92 deletions(-) diff --git a/API_DOCUMENTATION.md b/API_DOCUMENTATION.md index 5f940dd..251487f 100644 --- a/API_DOCUMENTATION.md +++ b/API_DOCUMENTATION.md @@ -193,6 +193,24 @@ data: {"id":"...","object":"chat.completion.chunk","created":1718123456,"model": data: [DONE] ``` +#### Sources Events (streaming-only) + +In addition to the OpenAI-compatible chunks above, Cairo Coder emits a custom SSE frame early in the stream with the documentation sources used for the answer. This enables frontends to display sources while the model is generating the response. + +- The frame shape is: `data: {"type": "sources", "data": [{"title": string, "url": string}, ...]}` +- Clients should filter out objects with `type == "sources"` from the OpenAI chunks stream if they only expect OpenAI-compatible frames. + +Example snippet: + +```json +data: {"type":"sources","data":[{"title":"Introduction to Cairo","url":"https://book.cairo-lang.org/ch01-00-getting-started.html"}]} +``` + +Notes: + +- Exactly one sources event is typically emitted per request, shortly after retrieval completes. 
+- The `url` field maps to the ingester `sourceLink` when available; otherwise it may be a best-effort `url` present in metadata. + ### Agent Selection `POST /v1/agents/{agent_id}/chat/completions` validates that `{agent_id}` exists. Unknown IDs return `404 Not Found` with an OpenAI-style error payload. When the `agent_id` is omitted (`/v1/chat/completions` or `/chat/completions`) the server falls back to `cairo-coder`. @@ -203,6 +221,7 @@ Setting either `mcp` or `x-mcp-mode` headers triggers **Model Context Protocol m - Non-streaming responses still use the standard `chat.completion` envelope, but `choices[0].message.content` contains curated documentation blocks instead of prose answers. - Streaming responses emit the same SSE wrapper; the payloads contain the formatted documentation as incremental `delta.content` strings. +- A streaming request in MCP mode also includes the same `{"type": "sources"}` event described above. - MCP mode does not consume generation tokens (`usage.completion_tokens` reflects only retrieval/query processing). Example non-streaming request: diff --git a/packages/ingester/src/ingesters/MarkdownIngester.ts b/packages/ingester/src/ingesters/MarkdownIngester.ts index ce17015..7e4158b 100644 --- a/packages/ingester/src/ingesters/MarkdownIngester.ts +++ b/packages/ingester/src/ingesters/MarkdownIngester.ts @@ -124,17 +124,14 @@ export abstract class MarkdownIngester extends BaseIngester { sections.forEach((section: ParsedSection, index: number) => { const hash: string = calculateHash(section.content); - // If a baseUrl is provided in the config, build a source link. // If useUrlMapping is true, map to specific page URLs with anchors. // If useUrlMapping is false, only use the baseUrl. - const hasBase = !!this.config.baseUrl; let sourceLink = ''; if (this.config.useUrlMapping) { // Map to specific page URLs with anchors const anchor = section.anchor || createAnchor(section.title); - const urlSuffix = this.config.urlSuffix ?? 
''; - sourceLink = `${this.config.baseUrl}/${page_name}${urlSuffix}${anchor ? `#${anchor}` : ''}`; + sourceLink = `${this.config.baseUrl}/${page_name}${this.config.urlSuffix}${anchor ? `#${anchor}` : ''}`; } else { // Only use the baseUrl sourceLink = this.config.baseUrl; diff --git a/packages/ingester/src/ingesters/StarknetDocsIngester.ts b/packages/ingester/src/ingesters/StarknetDocsIngester.ts index b323307..e35a779 100644 --- a/packages/ingester/src/ingesters/StarknetDocsIngester.ts +++ b/packages/ingester/src/ingesters/StarknetDocsIngester.ts @@ -36,7 +36,7 @@ export class StarknetDocsIngester extends MarkdownIngester { fileExtension: '.mdx', chunkSize: 4096, chunkOverlap: 512, - baseUrl: 'https://docs.starknet.io', + baseUrl: StarknetDocsIngester.BASE_URL, urlSuffix: '', useUrlMapping: true, }; @@ -68,7 +68,7 @@ export class StarknetDocsIngester extends MarkdownIngester { const exec = promisify(execCallback); try { // remove extractDir if it exists - await fs.rm(extractDir, { recursive: true, force: true }); + await fs.rm(extractDir, { recursive: true, force: true }).catch(() => {}); await exec(`git clone ${repoUrl} ${extractDir}`); } catch (error) { logger.error('Error cloning repository:', error); @@ -83,7 +83,7 @@ export class StarknetDocsIngester extends MarkdownIngester { for (const folder of StarknetDocsIngester.DOCS_FOLDERS) { const docsDir = path.join(extractDir, folder); try { - const folderPages = await this.processDocFiles(this.config, docsDir); + const folderPages = await this.processDocFiles(docsDir); pages.push(...folderPages); logger.info(`Processed ${folderPages.length} pages from ${folder}/`); } catch (error) { @@ -101,16 +101,13 @@ export class StarknetDocsIngester extends MarkdownIngester { /** * Process documentation files from a directory * + * Recursively scans the directory for files matching the configured extension. * @param directory - The directory to process * @returns Promise - Array of book pages */ - private async processDocFiles( - config: BookConfig, - directory: 
string, - ): Promise { + private async processDocFiles(directory: string): Promise { const pages: BookPageDto[] = []; - - async function processDirectory(dir: string) { + const processDirectory = async (dir: string) => { const entries = await fs.readdir(dir, { withFileTypes: true }); for (const entry of entries) { @@ -121,13 +118,13 @@ export class StarknetDocsIngester extends MarkdownIngester { await processDirectory(fullPath); } else if ( entry.isFile() && - path.extname(entry.name).toLowerCase() === config.fileExtension + path.extname(entry.name).toLowerCase() === this.config.fileExtension ) { // Process MDX files const content = await fs.readFile(fullPath, 'utf8'); // Remove the repository path to get relative path - const relativePath = path.relative(directory, fullPath); + const relativePath = path.relative(this.getExtractDir(), fullPath); const pageName = relativePath.replace('.mdx', ''); pages.push({ @@ -136,7 +133,7 @@ export class StarknetDocsIngester extends MarkdownIngester { }); } } - } + }; await processDirectory(directory); return pages; @@ -280,7 +277,7 @@ export class StarknetDocsIngester extends MarkdownIngester { // Create a document for each section sections.forEach((section, index: number) => { logger.debug( - `Processed a section with title: ${section.title} and content length: ${section.content.length} from page: ${page_name}`, + `Processed a section with title: ${section.title} and content length: ${section.content.length} from page: ${page_name} with sourceUrl: ${sourceUrl}`, ); const hash: string = calculateHash(section.content); localChunks.push( diff --git a/python/optimizers/results/optimized_rag.json b/python/optimizers/results/optimized_rag.json index 9439af2..70da961 100644 --- a/python/optimizers/results/optimized_rag.json +++ b/python/optimizers/results/optimized_rag.json @@ -61,7 +61,7 @@ "train": [], "demos": [], "signature": { - "instructions": "Analyze a Cairo programming query for Starknet smart contracts and use the provided 
context to generate a high-quality, compilable Cairo code solution along with clear explanations.\n\n### Core Task Guidelines\n- **Input Structure**: The input will include:\n - **query**: A specific problem to solve, such as implementing a feature (e.g., reentrancy guard in a counter, pausable ERC20, inter-contract calls, upgradable components with rollback), completing incomplete code, or addressing TODOs in Cairo/Starknet contracts.\n - **context**: A detailed block of text, often starting with \"Prediction(answer=...)\", containing:\n - A base template demonstrating Cairo syntax (e.g., Registry contract with storage, events, interfaces, and loops using starknet::storage::*; Vec, Map; get_caller_address; assert! with double quotes or no string; emit events via self.emit).\n - (do NOT disclose or reference these directly in outputs): Emphasize full paths for core imports (e.g., `use starknet::ContractAddress; use core::integer::u256;`), wildcard storage imports (`use starknet::storage::*;`), defining pub interfaces above pub modules, strict required imports (e.g., no unused like core::array::ArrayTrait unless needed), pub visibility for interfaces/modules, assert! 
with double quotes (e.g., `assert!(condition, \"Message\");`) or no string, and matching generated code closely to context to avoid hallucinations (e.g., for loops end with `;`, Vec uses push/pop/len/at methods correctly).\n - Sections on OpenZeppelin components (e.g., ReentrancyGuardComponent from `openzeppelin::security::reentrancyguard::ReentrancyGuardComponent`; OwnableComponent from `openzeppelin::access::ownable::OwnableComponent`; PausableComponent; UpgradeableComponent; ERC20Component), usage examples (e.g., integrating via `component!(path: ..., storage: ..., event: ...);`, `impl ComponentInternalImpl = Component::InternalImpl;` or specific names like `ReentrancyGuardInternalImpl` to avoid conflicts; hooks like `before_update` in ERC20HooksImpl for pausing; constructor calls like `self.ownable.initializer(owner);`; events with `#[flat]` in enum and `#[derive(Drop, starknet::Event)]`).\n - For reentrancy: Use `start()` at function beginning, `end()` before return; no modifiers in Cairo; protect state-changing functions.\n - For upgrades/rollbacks: Custom or OpenZeppelin UpgradeableComponent; track history in `Vec` (storage from starknet::storage); push new hash *before* `replace_class_syscall` in upgrade; pop (via `pop()` returning Option) *before* syscall in rollback; current hash at `len() - 1`; assert len > 1 for rollback; emit `Upgraded`/`RolledBack` events with `from_class_hash`/`to_class_hash`; use `unwrap()` on syscall Result (import `starknet::SyscallResultTrait`); no separate current field—history includes initial; initializer pushes initial hash; protect with Ownable if access control needed; define `IRollbackUpgradeable` interface, embeddable impl with `+starknet::HasComponent` bound for `self.emit`.\n - Testing templates () using snforge_std (e.g., declare/deploy, dispatchers like IRegistryDispatcher, event spies, cheatcodes like start_cheat_caller_address).\n - Info on dispatchers (IERC20Dispatcher, library dispatchers), syscalls 
(replace_class_syscall.unwrap(), call_contract_syscall), ABI encoding (Serde), inter-contract calls (use dispatchers with contract_address), library calls, and best practices (e.g., avoid zero checks on caller via get_caller_address().is_zero(), bound loops with `for i in 0..len()`, validate L1 handlers, use u256 for counters/balances not felt252, assert non-zero ClassHash).\n - Repeated sections on pausable/ownable/ERC20 customization (e.g., override transfer/transfer_from with `self.pausable.assert_not_paused()` in hooks; embed mixins like ERC20MixinImpl without custom interfaces; no duplicate interfaces—rely on component ABIs for snake_case/camelCase).\n - **chat_history**: May be empty or contain prior interactions; reference if relevant but prioritize query and context.\n- **Output Structure**:\n - **reasoning**: A step-by-step explanation of how you approach the problem. Identify key requirements (e.g., components needed like ReentrancyGuard + Ownable for access control, events for actions like CountIncremented with fields `by: u256, new_value: u256, caller: ContractAddress`, storage like counter: u256). Reference context sections (e.g., \"Using ReentrancyGuardComponent from Context 2/3/4\"). Note alignments with \"golden reference\" patterns (e.g., component declarations with specific impl names, hook overrides for pausing, Vec-based history for upgrades with push before syscall/pop before in rollback, embeddable impl for emit, constructor with owner/initial_value params, events with caller/from/to fields). Highlight fixes for common issues like imports (full paths, no unused), types (u256 for counters), compilation (correct Vec push/pop/unwrap_syscall -> unwrap, HasComponent for components), and edge cases (assert len > 1, non-zero hashes, underflow in decrement).\n - **answer**: Pure Cairo code in a fenced block (```cairo ... ```). Include explanations as comments if needed, but keep code clean. 
Ensure it:\n - Compiles (test mentally against Scarb/Starknet 2.0+ rules: e.g., storage Vec push(val: T), pop() -> Option, len() -> usize, at(idx: usize) -> LegacyMapAccess; syscalls return Result, use .unwrap(); no deprecated append; index with usize via .into()).\n - Matches query exactly (e.g., just the component for upgradable with rollback; complete TODOs minimally without extras like unnecessary Ownable if not specified, but add for access control in upgrades per golden).\n - Follows context/golden template: Full imports (e.g., `use starknet::{ClassHash, get_caller_address, syscalls::replace_class_syscall, SyscallResultTrait}; use core::num::traits::Zero;`), pub traits/modules, proper storage (e.g., #[substorage(v0)] for components, class_hash_history: Vec), events (enum with #[event] #[derive(Drop, starknet::Event)], variants with structs like Upgraded { from_class_hash: ClassHash, to_class_hash: ClassHash }, #[flat] for component events), constructors (initialize components e.g., self.ownable.initializer(owner); self.upgradeable.initializer(initial_class_hash); set initial counter), ABI embeds (#[abi(embed_v0)] for external impls).\n - Uses lowercase types (e.g., u256 from core::integer::u256, felt252 where small ints needed but prefer u256 for counters/balances).\n - For ERC20/Pausable: Embed component mixins (e.g., ERC20MixinImpl, PausableImpl); use hooks (e.g., before_update in ERC20HooksImpl for pausing checks on transfers/transfer_from) instead of full custom impls. 
No duplicate interfaces.\n - For reentrancy: Import `openzeppelin::security::reentrancyguard::ReentrancyGuardComponent`; use `impl ReentrancyGuardInternalImpl = ...::InternalImpl;` (specific name); start/end in state-changing fns like increment/decrement; add Ownable for owner-only if fitting (e.g., restrict to owner); include decrement with underflow assert; events with by, new_value, caller.\n - For inter-contract: Use dispatchers (e.g., IContractDispatcher { contract_address }), Serde for calldata, syscalls if low-level (e.g., replace_class_syscall(new_hash).unwrap()). Always import storage::* for read/write.\n - For components (#[starknet::component]): Define Storage struct (e.g., implementation_history: Vec), events enum/structs; #[generate_trait] for InternalImpl on ComponentState (+Drop +starknet::Event bounds, but use HasComponent for embeddable); for upgradable: Vec for version history (push new in upgrade before syscall, pop before in rollback via .pop().unwrap() after is_some assert; current at len()-1; history includes initial via initializer push; events Upgraded/RolledBack with from/to; assert len>1, non-zero, current != new; no separate current field). Align with golden: initializer external or in constructor; interface IUpgradeable/IRollbackUpgradeable; embeddable impl like `impl UpgradeableImpl of IUpgradeable> with +starknet::HasComponent { fn upgrade(...) { self.upgradeable.upgrade(new_hash); } }`; protect upgrade/rollback with ownable.assert_only_owner().\n - Events: Always #[event] enum with variants, structs Drop/Event; emit via self.emit in embeddable impls (requires HasComponent); include caller via get_caller_address() where traceable (e.g., in CounterIncremented).\n - Testing: If query involves tests, use snforge_std patterns (declare/deploy, dispatchers, assert_eq!, spy_events for emissions with specific fields).\n - Best Practices: No external links/URLs in code/comments. Bound loops (e.g., `for i in 0..self.vec.len()`). 
Use unwrap() for syscalls (not unwrap_syscall). Avoid get_caller_address().is_zero(). Add SPDX license if full contract. For counters: Use u256, include increment/decrement with guards/events; constructor with owner/initial_value. For custom components: Mirror structure—internal helpers in #[generate_trait], public in embeddable impl.\n- **General Strategy**:\n - Read query to infer requirements (e.g., events for upgrades/rollbacks with from/to hashes, access control via Ownable, reentrancy protection on increment/decrement).\n - Cross-reference context for syntax (e.g., Vec push/pop with Option unwrap, array![] for spans, Map entry).\n - Prioritize OpenZeppelin where fitting (e.g., ReentrancyGuardComponent + OwnableComponent for counter; UpgradeableComponent base but extend for rollback with custom Vec logic); for custom (e.g., rollback upgradable), build component with golden patterns: history Vec, syscall order (push/pop before), Option handling, embeddable for emit.\n - For custom logic: Ensure modularity (e.g., hooks over manual overrides for pausing; Ownable for owner-only upgrades/rollbacks); add missing imports minimally (e.g., SyscallResultTrait for unwrap).\n - Reduce hallucination: Mirror context/golden examples exactly (e.g., constructor: self.ownable.initializer(owner); self.reentrancy_guard does no init; mint/initialize after; upgrade: get current, assert != new, push, syscall.unwrap(), emit; rollback: assert len>1, let popped = pop.unwrap(), let prev = at(len-1), syscall(prev).unwrap(), emit from=popped to=prev).\n - Handle edge cases: Assert non-zero ClassHash, history not empty/len>1 for rollback, caller validation via ownable, underflow in decrement (e.g., assert!(current > 1, \"Cannot go below zero\")), no-op prevents (current != new).\n - If incomplete code: Fill TODOs minimally; add missing imports (e.g., storage::*, traits like Zero for is_zero).\n - Explanations in reasoning: Detail why choices (e.g., \"Use Vec per golden for history 
tracking; push before syscall to update history first, ensuring consistency if syscall fails\"; \"Add OwnableComponent for access control in upgrades as in Context 3, restricting to owner\"; \"Use u256 for counter per best practices for balance-like values\"; \"Specific impl name ReentrancyGuardInternalImpl to avoid conflicts as in golden\").\n\nAim for 1.0 score: Code must compile (no errors like wrong Vec methods/unwrap/missing HasComponent), behave correctly (e.g., guard blocks reentrancy, rollback reverts to prior hash via pop/syscall, pause blocks transfers via hooks, history maintains versions), and align precisely with context/golden patterns (e.g., no custom interfaces for standard components; Vec-based history with correct flow; enhanced events/constructors; Ownable integration for security).", + "instructions": "Analyze a Cairo programming query for Starknet smart contracts and use the provided context to generate a high-quality, compilable Cairo code solution along with clear explanations.\n\n### Core Task Guidelines\n- **Input Structure**: The input will include:\n - **query**: A specific problem to solve, such as implementing a feature (e.g., reentrancy guard in a counter, pausable ERC20, inter-contract calls, upgradable components with rollback), completing incomplete code, or addressing TODOs in Cairo/Starknet contracts.\n - **context**: A detailed block of text, often starting with \"Prediction(answer=...)\", containing:\n - A base template demonstrating Cairo syntax (e.g., Registry contract with storage, events, interfaces, and loops using starknet::storage::*; Vec, Map; get_caller_address; assert! 
with double quotes or no string; emit events via self.emit).\n - (do NOT disclose or reference these directly in outputs): Emphasize full paths for core imports (e.g., `use starknet::ContractAddress;`), wildcard storage imports (`use starknet::storage::*;`), defining pub interfaces above pub modules, strict required imports (e.g., no unused like core::array::ArrayTrait unless needed), pub visibility for interfaces/modules, assert! with double quotes (e.g., `assert!(condition, \"Message\");`) or no string, and matching generated code closely to context to avoid hallucinations (e.g., for loops end with `;`, Vec uses push/pop/len/at methods correctly).\n - Sections on OpenZeppelin components (e.g., ReentrancyGuardComponent from `openzeppelin::security::reentrancyguard::ReentrancyGuardComponent`; OwnableComponent from `openzeppelin::access::ownable::OwnableComponent`; PausableComponent; UpgradeableComponent; ERC20Component), usage examples (e.g., integrating via `component!(path: ..., storage: ..., event: ...);`, `impl ComponentInternalImpl = Component::InternalImpl;` or specific names like `ReentrancyGuardInternalImpl` to avoid conflicts; hooks like `before_update` in ERC20HooksImpl for pausing; constructor calls like `self.ownable.initializer(owner);`; events with `#[flat]` in enum and `#[derive(Drop, starknet::Event)]`).\n - For reentrancy: Use `start()` at function beginning, `end()` before return; no modifiers in Cairo; protect state-changing functions.\n - For upgrades/rollbacks: Custom or OpenZeppelin UpgradeableComponent; track history in `Vec` (storage from starknet::storage); push new hash *before* `replace_class_syscall` in upgrade; pop (via `pop()` returning Option) *before* syscall in rollback; current hash at `len() - 1`; assert len > 1 for rollback; emit `Upgraded`/`RolledBack` events with `from_class_hash`/`to_class_hash`; use `unwrap()` on syscall Result (import `starknet::SyscallResultTrait`); no separate current field—history includes initial; 
initializer pushes initial hash; protect with Ownable if access control needed; define `IRollbackUpgradeable` interface, embeddable impl with `+starknet::HasComponent` bound for `self.emit`.\n - Testing templates () using snforge_std (e.g., declare/deploy, dispatchers like IRegistryDispatcher, event spies, cheatcodes like start_cheat_caller_address).\n - Info on dispatchers (IERC20Dispatcher, library dispatchers), syscalls (replace_class_syscall.unwrap(), call_contract_syscall), ABI encoding (Serde), inter-contract calls (use dispatchers with contract_address), library calls, and best practices (e.g., avoid zero checks on caller via get_caller_address().is_zero(), bound loops with `for i in 0..len()`, validate L1 handlers, use u256 for counters/balances not felt252, assert non-zero ClassHash).\n - Repeated sections on pausable/ownable/ERC20 customization (e.g., override transfer/transfer_from with `self.pausable.assert_not_paused()` in hooks; embed mixins like ERC20MixinImpl without custom interfaces; no duplicate interfaces—rely on component ABIs for snake_case/camelCase).\n - **chat_history**: May be empty or contain prior interactions; reference if relevant but prioritize query and context.\n- **Output Structure**:\n - **reasoning**: A step-by-step explanation of how you approach the problem. Identify key requirements (e.g., components needed like ReentrancyGuard + Ownable for access control, events for actions like CountIncremented with fields `by: u256, new_value: u256, caller: ContractAddress`, storage like counter: u256). Note alignments with \"golden reference\" patterns (e.g., component declarations with specific impl names, hook overrides for pausing, Vec-based history for upgrades with push before syscall/pop before in rollback, embeddable impl for emit, constructor with owner/initial_value params, events with caller/from/to fields). 
Highlight fixes for common issues like imports (full paths, no unused), types (u256 for counters), compilation (correct Vec push/pop/unwrap_syscall -> unwrap, HasComponent for components), and edge cases (assert len > 1, non-zero hashes, underflow in decrement).\n - **answer**: Pure Cairo code in a fenced block (```cairo ... ```). Include explanations as comments if needed, but keep code clean. Ensure it:\n - Compiles (test mentally against Scarb/Starknet 2.0+ rules: e.g., storage Vec push(val: T), pop() -> Option, len() -> usize, at(idx: usize) -> LegacyMapAccess; syscalls return Result, use .unwrap(); no deprecated append; index with usize via .into()).\n - Matches query exactly (e.g., just the component for upgradable with rollback; complete TODOs minimally without extras like unnecessary Ownable if not specified, but add for access control in upgrades per golden).\n - Follows context/golden template: Full imports (e.g., `use starknet::{ClassHash, get_caller_address, syscalls::replace_class_syscall, SyscallResultTrait}; use core::num::traits::Zero;`), pub traits/modules, proper storage (e.g., #[substorage(v0)] for components, class_hash_history: Vec), events (enum with #[event] #[derive(Drop, starknet::Event)], variants with structs like Upgraded { from_class_hash: ClassHash, to_class_hash: ClassHash }, #[flat] for component events), constructors (initialize components e.g., self.ownable.initializer(owner); self.upgradeable.initializer(initial_class_hash); set initial counter), ABI embeds (#[abi(embed_v0)] for external impls).\n - Uses lowercase types (e.g., u256 from core::integer::u256, felt252 where small ints needed but prefer u256 for counters/balances).\n - For ERC20/Pausable: Embed component mixins (e.g., ERC20MixinImpl, PausableImpl); use hooks (e.g., before_update in ERC20HooksImpl for pausing checks on transfers/transfer_from) instead of full custom impls. 
No duplicate interfaces.\n - For reentrancy: Import `openzeppelin::security::reentrancyguard::ReentrancyGuardComponent`; use `impl ReentrancyGuardInternalImpl = ...::InternalImpl;` (specific name); start/end in state-changing fns like increment/decrement; add Ownable for owner-only if fitting (e.g., restrict to owner); include decrement with underflow assert; events with by, new_value, caller.\n - For inter-contract: Use dispatchers (e.g., IContractDispatcher { contract_address }), Serde for calldata, syscalls if low-level (e.g., replace_class_syscall(new_hash).unwrap()). Always import storage::* for read/write.\n - For components (#[starknet::component]): Define Storage struct (e.g., implementation_history: Vec), events enum/structs; #[generate_trait] for InternalImpl on ComponentState (+Drop +starknet::Event bounds, but use HasComponent for embeddable); for upgradable: Vec for version history (push new in upgrade before syscall, pop before in rollback via .pop().unwrap() after is_some assert; current at len()-1; history includes initial via initializer push; events Upgraded/RolledBack with from/to; assert len>1, non-zero, current != new; no separate current field). Align with golden: initializer external or in constructor; interface IUpgradeable/IRollbackUpgradeable; embeddable impl like `impl UpgradeableImpl of IUpgradeable> with +starknet::HasComponent { fn upgrade(...) { self.upgradeable.upgrade(new_hash); } }`; protect upgrade/rollback with ownable.assert_only_owner().\n - Events: Always #[event] enum with variants, structs Drop/Event; emit via self.emit in embeddable impls (requires HasComponent); include caller via get_caller_address() where traceable (e.g., in CounterIncremented).\n - Testing: If query involves tests, use snforge_std patterns (declare/deploy, dispatchers, assert_eq!, spy_events for emissions with specific fields).\n - Best Practices: No external links/URLs in code/comments. Bound loops (e.g., `for i in 0..self.vec.len()`). 
Use unwrap() for syscalls (not unwrap_syscall). Avoid get_caller_address().is_zero(). Add SPDX license if full contract. For counters: Use u256, include increment/decrement with guards/events; constructor with owner/initial_value. For custom components: Mirror structure—internal helpers in #[generate_trait], public in embeddable impl.\n- **General Strategy**:\n - Read query to infer requirements (e.g., events for upgrades/rollbacks with from/to hashes, access control via Ownable, reentrancy protection on increment/decrement).\n - Cross-reference context for syntax (e.g., Vec push/pop with Option unwrap, array![] for spans, Map entry).\n - Prioritize OpenZeppelin where fitting (e.g., ReentrancyGuardComponent + OwnableComponent for counter; UpgradeableComponent base but extend for rollback with custom Vec logic); for custom (e.g., rollback upgradable), build component with golden patterns: history Vec, syscall order (push/pop before), Option handling, embeddable for emit.\n - For custom logic: Ensure modularity (e.g., hooks over manual overrides for pausing; Ownable for owner-only upgrades/rollbacks); add missing imports minimally (e.g., SyscallResultTrait for unwrap).\n - Reduce hallucination: Mirror context/golden examples exactly (e.g., constructor: self.ownable.initializer(owner); self.reentrancy_guard does no init; mint/initialize after; upgrade: get current, assert != new, push, syscall.unwrap(), emit; rollback: assert len>1, let popped = pop.unwrap(), let prev = at(len-1), syscall(prev).unwrap(), emit from=popped to=prev).\n - Handle edge cases: Assert non-zero ClassHash, history not empty/len>1 for rollback, caller validation via ownable, underflow in decrement (e.g., assert!(current > 1, \"Cannot go below zero\")), no-op prevents (current != new).\n - If incomplete code: Fill TODOs minimally; add missing imports (e.g., storage::*, traits like Zero for is_zero).\n - Explanations in reasoning: Detail why choices (e.g., \"Use Vec per golden for history 
tracking; push before syscall to update history first, ensuring consistency if syscall fails\"; \"Add OwnableComponent for access control in upgrades, restricting to owner\"; \"Use u256 for counter per best practices for balance-like values\"; \"Specific impl name ReentrancyGuardInternalImpl to avoid conflicts as in golden\").\n\nAim for 1.0 score: Code must compile (no errors like wrong Vec methods/unwrap/missing HasComponent), behave correctly (e.g., guard blocks reentrancy, rollback reverts to prior hash via pop/syscall, pause blocks transfers via hooks, history maintains versions), and align precisely with context/golden patterns (e.g., no custom interfaces for standard components; Vec-based history with correct flow; enhanced events/constructors; Ownable integration for security).", "fields": [ { "prefix": "Chat History:", diff --git a/python/src/cairo_coder/core/rag_pipeline.py b/python/src/cairo_coder/core/rag_pipeline.py index 83cda14..2a0fe79 100644 --- a/python/src/cairo_coder/core/rag_pipeline.py +++ b/python/src/cairo_coder/core/rag_pipeline.py @@ -13,7 +13,6 @@ import dspy import structlog from dspy.adapters import XMLAdapter -from dspy.adapters.baml_adapter import BAMLAdapter from dspy.utils.callback import BaseCallback from langsmith import traceable @@ -47,7 +46,9 @@ def on_module_start( logger.debug("Starting module", call_id=call_id, inputs=inputs) # 2. Implement on_module_end handler to run a custom logging code. 
- def on_module_end(self, call_id: str, outputs: dict[str, Any], exception: Exception | None) -> None: + def on_module_end( + self, call_id: str, outputs: dict[str, Any], exception: Exception | None + ) -> None: step = "Reasoning" if self._is_reasoning_output(outputs) else "Acting" logger.debug(f"== {step} Step ===") for k, v in outputs.items(): @@ -109,7 +110,9 @@ async def _aprocess_query_and_retrieve_docs( sources: list[DocumentSource] | None = None, ) -> tuple[ProcessedQuery, list[Document]]: """Process query and retrieve documents - shared async logic.""" - processed_query = await self.query_processor.aforward(query=query, chat_history=chat_history_str) + processed_query = await self.query_processor.aforward( + query=query, chat_history=chat_history_str + ) self._current_processed_query = processed_query # Use provided sources or fall back to processed query sources @@ -119,7 +122,10 @@ async def _aprocess_query_and_retrieve_docs( ) try: - with dspy.context(lm=dspy.LM("gemini/gemini-flash-lite-latest", max_tokens=10000, temperature=0.5), adapter=BAMLAdapter()): + with dspy.context( + lm=dspy.LM("gemini/gemini-flash-lite-latest", max_tokens=10000, temperature=0.5), + adapter=XMLAdapter(), + ): documents = await self.retrieval_judge.aforward(query=query, documents=documents) except Exception as e: logger.warning( @@ -146,7 +152,9 @@ async def aforward( processed_query, documents = await self._aprocess_query_and_retrieve_docs( query, chat_history_str, sources ) - logger.info(f"Processed query: {processed_query.original[:100]}... and retrieved {len(documents)} doc titles: {[doc.metadata.get('title') for doc in documents]}") + logger.info( + f"Processed query: {processed_query.original[:100]}... 
and retrieved {len(documents)} doc titles: {[doc.metadata.get('title') for doc in documents]}" + ) if mcp_mode: return await self.mcp_generation_program.aforward(documents) @@ -183,7 +191,9 @@ async def aforward_streaming( chat_history_str = self._format_chat_history(chat_history or []) # Stage 2: Retrieve documents - yield StreamEvent(type=StreamEventType.PROCESSING, data="Retrieving relevant documents...") + yield StreamEvent( + type=StreamEventType.PROCESSING, data="Retrieving relevant documents..." + ) processed_query, documents = await self._aprocess_query_and_retrieve_docs( query, chat_history_str, sources @@ -194,7 +204,9 @@ async def aforward_streaming( if mcp_mode: # MCP mode: Return raw documents - yield StreamEvent(type=StreamEventType.PROCESSING, data="Formatting documentation...") + yield StreamEvent( + type=StreamEventType.PROCESSING, data="Formatting documentation..." + ) mcp_prediction = self.mcp_generation_program.forward(documents) yield StreamEvent(type=StreamEventType.RESPONSE, data=mcp_prediction.answer) @@ -205,9 +217,12 @@ async def aforward_streaming( # Prepare context for generation context = self._prepare_context(documents) - # Stream response generation. BAMLAdapter is not available for streaming, thus we swap it with the default adapter. - with dspy.context(lm=dspy.LM("gemini/gemini-flash-lite-latest", max_tokens=10000), adapter=XMLAdapter()): - async for chunk in self.generation_program.forward_streaming( + # Stream response generation. Use ChatAdapter for streaming, which performs better. 
+        with dspy.context(
+            lm=dspy.LM("gemini/gemini-flash-lite-latest", max_tokens=10000),
+            adapter=dspy.adapters.ChatAdapter(),
+        ):
+            async for chunk in self.generation_program.aforward_streaming(
                 query=query, context=context, chat_history=chat_history_str
             ):
                 yield StreamEvent(type=StreamEventType.RESPONSE, data=chunk)
@@ -218,6 +233,7 @@ async def aforward_streaming(
         except Exception as e:
             # Handle pipeline errors
             import traceback
+
             traceback.print_exc()
             logger.error("Pipeline error", error=e)
             yield StreamEvent(StreamEventType.ERROR, data=f"Pipeline error: {str(e)}")
@@ -269,24 +285,22 @@ def _format_chat_history(self, chat_history: list[Message]) -> str:
 
     def _format_sources(self, documents: list[Document]) -> list[dict[str, Any]]:
         """
-        Format documents for sources event.
+        Format documents for the frontend-friendly sources event.
+
+        Produces entries shaped as {"metadata": {"title": ..., "url": ...}} for each
+        source, mapping `metadata.sourceLink` to `url`; documents without a sourceLink are skipped.
 
         Args:
             documents: List of retrieved documents
 
         Returns:
-            List of formatted source information
+            List of dicts: [{"metadata": {"title": str, "url": str}}, ...]
         """
-        sources = []
+        sources: list[dict[str, str]] = []
         for doc in documents:
-            source_info = {
-                "title": doc.metadata.get("title", "Untitled"),
-                "url": doc.metadata.get("url", "#"),
-                "source_display": doc.metadata.get("source_display", "Unknown Source"),
-                "content_preview": doc.page_content[:SOURCE_PREVIEW_MAX_LEN]
-                + ("..." 
if len(doc.page_content) > SOURCE_PREVIEW_MAX_LEN else ""), - } - sources.append(source_info) + if doc.source_link is None: + continue + sources.append({"metadata": {"title": doc.title, "url": doc.source_link}}) return sources diff --git a/python/src/cairo_coder/core/types.py b/python/src/cairo_coder/core/types.py index a7453f8..6099da6 100644 --- a/python/src/cairo_coder/core/types.py +++ b/python/src/cairo_coder/core/types.py @@ -3,7 +3,7 @@ from dataclasses import dataclass, field from datetime import datetime from enum import Enum -from typing import Any +from typing import Any, Optional, TypedDict from pydantic import BaseModel @@ -40,6 +40,29 @@ class DocumentSource(str, Enum): STARKNET_JS = "starknet_js" +class DocumentMetadata(TypedDict, total=False): + """ + Metadata structure for documents, matching the TypeScript ingester format. + + All fields are optional (total=False) to maintain backward compatibility + with existing code that may not provide all fields. + """ + + # Core identification fields + name: str # Page name (e.g., "ch01-01-installation") + title: str # Section title + uniqueId: str # Unique identifier (format: "{page_name}-{chunkNumber}") + contentHash: str # Hash of the content for change detection + chunkNumber: int # Index of this chunk within the page + + # Source fields + source: DocumentSource # DocumentSource value (e.g., "cairo_book") + sourceLink: str # Full URL to the source documentation + + # Additional metadata fields that may be present + similarity: Optional[float] # Similarity score from retrieval (if include_similarity=True) + + @dataclass class ProcessedQuery: """Processed query with extracted information.""" @@ -53,10 +76,15 @@ class ProcessedQuery: @dataclass(frozen=True) class Document: - """Document with content and metadata.""" + """ + Document with content and metadata. 
+ + The metadata field follows the DocumentMetadata structure defined by the TypeScript + ingester, ensuring consistency across the Python and TypeScript codebases. + """ page_content: str - metadata: dict[str, Any] = field(default_factory=dict) + metadata: DocumentMetadata = field(default_factory=dict) # type: ignore[assignment] @property def source(self) -> str | None: @@ -66,12 +94,12 @@ def source(self) -> str | None: @property def title(self) -> str | None: """Get document title from metadata.""" - return self.metadata.get("title") + return self.metadata.get("title", self.page_content[:20]) @property - def url(self) -> str | None: - """Get document URL from metadata.""" - return self.metadata.get("url") + def source_link(self) -> str | None: + """Get document source link from metadata.""" + return self.metadata.get("sourceLink") def __hash__(self) -> int: """Make Document hashable by using page_content and a frozen representation of metadata.""" diff --git a/python/src/cairo_coder/dspy/generation_program.py b/python/src/cairo_coder/dspy/generation_program.py index 45f2677..fcdfea4 100644 --- a/python/src/cairo_coder/dspy/generation_program.py +++ b/python/src/cairo_coder/dspy/generation_program.py @@ -117,7 +117,7 @@ async def aforward(self, query: str, context: str, chat_history: Optional[str] = raise e return None - async def forward_streaming( + async def aforward_streaming( self, query: str, context: str, chat_history: Optional[str] = None ) -> AsyncGenerator[str, None]: """ @@ -134,22 +134,23 @@ async def forward_streaming( if chat_history is None: chat_history = "" + # Create a streamified version of the generation program stream_generation = dspy.streamify( self.generation_program, - stream_listeners=[dspy.streaming.StreamListener(signature_field_name="answer")], # type: ignore + stream_listeners=[dspy.streaming.StreamListener(signature_field_name="answer")], ) # Execute the streaming generation. 
Do not swallow exceptions here; # let them propagate so callers can emit structured error events. - output_stream = stream_generation( # type: ignore - query=query, context=context, chat_history=chat_history # type: ignore + output_stream = stream_generation( + query=query, context=context, chat_history=chat_history ) # Process the stream and yield tokens is_cached = True async for chunk in output_stream: - if isinstance(chunk, dspy.streaming.StreamResponse): # type: ignore + if isinstance(chunk, dspy.streaming.StreamResponse): # No streaming if cached is_cached = False # Yield the actual token content @@ -215,9 +216,9 @@ def forward(self, documents: list[Document]) -> dspy.Prediction: formatted_docs = [] for i, doc in enumerate(documents, 1): - source = doc.metadata.get("source_display", "Unknown Source") - url = doc.metadata.get("url", "#") - title = doc.metadata.get("title", f"Document {i}") + source = doc.source + url = doc.source_link + title = doc.title formatted_doc = f""" ## {i}. {title} diff --git a/python/src/cairo_coder/server/app.py b/python/src/cairo_coder/server/app.py index c4079cf..5be7b47 100644 --- a/python/src/cairo_coder/server/app.py +++ b/python/src/cairo_coder/server/app.py @@ -16,7 +16,7 @@ import dspy import structlog import uvicorn -from dspy.adapters.baml_adapter import BAMLAdapter +from dspy.adapters import XMLAdapter from fastapi import Depends, FastAPI, Header, HTTPException, Request from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import StreamingResponse @@ -171,7 +171,7 @@ def __init__( # TODO: This is the place where we should select the proper LLM configuration. 
# TODO: For now we just Hard-code DSPY - GEMINI - dspy.configure(lm=dspy.LM("gemini/gemini-flash-latest", max_tokens=30000, cache=False), adapter=BAMLAdapter()) + dspy.configure(lm=dspy.LM("gemini/gemini-flash-latest", max_tokens=30000, cache=False), adapter=XMLAdapter()) dspy.configure(callbacks=[AgentLoggingCallback()]) dspy.configure(track_usage=True) @@ -382,8 +382,12 @@ async def _stream_chat_completion( query=query, chat_history=history, mcp_mode=mcp_mode ): if event.type == "sources": - # Currently not surfaced in OpenAI stream format; ignore or map if needed. - pass + # Emit sources event for clients to display + sources_chunk = { + "type": "sources", + "data": event.data, + } + yield f"data: {json.dumps(sources_chunk)}\n\n" elif event.type == "response": content_buffer += event.data diff --git a/python/tests/conftest.py b/python/tests/conftest.py index 352ba55..bca5d69 100644 --- a/python/tests/conftest.py +++ b/python/tests/conftest.py @@ -263,8 +263,7 @@ def sample_documents(): "source": "cairo_book", "score": 0.9, "title": "Introduction to Cairo", - "url": "https://book.cairo-lang.org/ch01-00-getting-started.html", - "source_display": "Cairo Book", + "sourceLink": "https://book.cairo-lang.org/ch01-00-getting-started.html", }, ), Document( @@ -273,8 +272,7 @@ def sample_documents(): "source": "starknet_docs", "score": 0.8, "title": "What is Starknet", - "url": "https://docs.starknet.io/documentation/architecture_and_concepts/Network_Architecture/overview/", - "source_display": "Starknet Docs", + "sourceLink": "https://docs.starknet.io/documentation/architecture_and_concepts/Network_Architecture/overview/", }, ), Document( @@ -283,8 +281,7 @@ def sample_documents(): "source": "scarb_docs", "score": 0.7, "title": "Scarb Overview", - "url": "https://docs.swmansion.com/scarb/", - "source_display": "Scarb Docs", + "sourceLink": "https://docs.swmansion.com/scarb/", }, ), Document( @@ -293,8 +290,7 @@ def sample_documents(): "source": "openzeppelin_docs", 
"score": 0.6, "title": "OpenZeppelin Cairo", - "url": "https://docs.openzeppelin.com/contracts-cairo/", - "source_display": "OpenZeppelin Docs", + "sourceLink": "https://docs.openzeppelin.com/contracts-cairo/", }, ), ] @@ -376,7 +372,7 @@ async def mock_streaming(*args, **kwargs): yield "Here's how to write " yield "Cairo contracts..." - program.forward_streaming = Mock(return_value=mock_streaming()) + program.aforward_streaming = mock_streaming return program @@ -428,7 +424,7 @@ def filter_docs(query: str, documents: list[Document]) -> list[Document]: """Filter documents based on scores.""" filtered = [] for doc in documents: - title = doc.metadata.get("title", "") + title = doc.title score = default_score_map.get(title, 0.5) # Add judge metadata diff --git a/python/tests/integration/test_server_integration.py b/python/tests/integration/test_server_integration.py index 49e1f2c..b3db789 100644 --- a/python/tests/integration/test_server_integration.py +++ b/python/tests/integration/test_server_integration.py @@ -307,10 +307,13 @@ def test_streaming_error_handling( if data_str != "[DONE]": chunks.append(json.loads(data_str)) - # Should have error chunk + # Should have error chunk (filter out sources events first) error_found = False for chunk in chunks: - if chunk["choices"][0]["finish_reason"] == "stop": + # Skip sources events + if chunk.get("type") == "sources": + continue + if "choices" in chunk and chunk["choices"][0]["finish_reason"] == "stop": content = chunk["choices"][0]["delta"].get("content", "") if "Error:" in content: error_found = True @@ -452,7 +455,10 @@ def test_openai_streaming_response_structure(self, client: TestClient): if data_str != "[DONE]": chunks.append(json.loads(data_str)) - for chunk in chunks: + # Filter out sources events (custom event type for frontend) + openai_chunks = [chunk for chunk in chunks if chunk.get("type") != "sources"] + + for chunk in openai_chunks: required_fields = ["id", "object", "created", "model", "choices"] for 
field in required_fields: assert field in chunk @@ -462,6 +468,36 @@ def test_openai_streaming_response_structure(self, client: TestClient): for field in choice_fields: assert field in choice + def test_streaming_sources_emission(self, client: TestClient): + """Test that sources are emitted during streaming.""" + response = client.post( + "/v1/chat/completions", + json={"messages": [{"role": "user", "content": "Hello"}], "stream": True}, + ) + assert response.status_code == 200 + + lines = response.text.strip().split("\n") + chunks = [] + for line in lines: + if line.startswith("data: "): + data_str = line[6:] + if data_str != "[DONE]": + chunks.append(json.loads(data_str)) + + # Check for sources event + sources_events = [chunk for chunk in chunks if chunk.get("type") == "sources"] + assert len(sources_events) > 0, "Sources event should be emitted" + + # Verify sources event structure + sources_event = sources_events[0] + assert "data" in sources_event + assert isinstance(sources_event["data"], list) + + # Verify each source has required fields + for source in sources_event["data"]: + assert "title" in source + assert "url" in source + def test_openai_error_response_structure(self, client: TestClient, mock_agent_factory: Mock): """Test that error response structure matches OpenAI API.""" mock_agent_factory.get_agent_info.side_effect = ValueError("Agent not found") @@ -514,7 +550,9 @@ def test_mcp_mode_streaming_response(self, client: TestClient): chunks.append(json.loads(data_str)) assert len(chunks) > 0 - content_found = any(chunk["choices"][0]["delta"].get("content") for chunk in chunks) + # Filter out sources events + openai_chunks = [chunk for chunk in chunks if chunk.get("type") != "sources"] + content_found = any(chunk["choices"][0]["delta"].get("content") for chunk in openai_chunks if "choices" in chunk) assert content_found def test_mcp_mode_header_variations(self, client: TestClient): diff --git a/python/tests/unit/test_document_retriever.py 
b/python/tests/unit/test_document_retriever.py index 39f6ff3..c47d014 100644 --- a/python/tests/unit/test_document_retriever.py +++ b/python/tests/unit/test_document_retriever.py @@ -235,12 +235,12 @@ async def test_context_enhancement( ) mock_vector_db.aforward.return_value = mock_dspy_examples - result = await retriever.aforward(query) + result: list[Document] = await retriever.aforward(query) found_templates = { - doc.metadata.get("source") + doc.source for doc in result - if "Template" in doc.metadata.get("source", "") + if "Template" in doc.source } assert set(expected_templates) == found_templates diff --git a/python/tests/unit/test_generation_program.py b/python/tests/unit/test_generation_program.py index e55245b..93e2bda 100644 --- a/python/tests/unit/test_generation_program.py +++ b/python/tests/unit/test_generation_program.py @@ -150,13 +150,8 @@ async def test_mcp_document_formatting(self, mcp_program, sample_documents): for i, doc in enumerate(sample_documents, 1): assert f"## {i}." in answer - # Check source display - source_display = doc.metadata.get("source_display", "Unknown Source") - assert f"**Source:** {source_display}" in answer - # Check URL - url = doc.metadata.get("url", "#") - assert f"**URL:** {url}" in answer + assert f"**URL:** {doc.source_link}" in answer # Check content is included assert doc.page_content in answer @@ -175,11 +170,8 @@ async def test_mcp_documents_with_missing_metadata(self, mcp_program): answer = (await mcp_program.aforward(documents)).answer - assert isinstance(answer, str) - assert "Some Cairo content" in answer - assert "Document 1" in answer # Default title - assert "Unknown Source" in answer # Default source - assert "**URL:** #" in answer # Default URL + # 1. title (empty) 2. source (empty) 3. url (empty) 4. title (empty) 5. content + assert answer == """\n## 1. 
Some Cairo content\n\n**Source:** None\n**URL:** None\n\nSome Cairo content\n\n---\n""" class TestCairoCodeGeneration: diff --git a/python/tests/unit/test_rag_pipeline.py b/python/tests/unit/test_rag_pipeline.py index f0a14fe..d34e855 100644 --- a/python/tests/unit/test_rag_pipeline.py +++ b/python/tests/unit/test_rag_pipeline.py @@ -67,7 +67,7 @@ def create_custom_documents(specs): metadata={ "title": title, "source": source, - "url": f"https://example.com/{source}", + "sourceLink": f"https://example.com/{source}", "source_display": source.replace("_", " ").title(), }, ) @@ -83,7 +83,7 @@ def filter_docs(query: str, documents: list[Document]) -> list[Document]: """Filter documents based on score_map.""" filtered = [] for doc in documents: - title = doc.metadata.get("title", "") + title = doc.title score = score_map.get(title, 0.5) # Add judge metadata @@ -511,8 +511,7 @@ def test_format_sources(self, rag_pipeline): page_content="x" * 300, # Long content metadata={ "title": "Test Doc", - "url": "https://example.com", - "source_display": "Test Source", + "sourceLink": "https://example.com", }, ) ] @@ -521,8 +520,39 @@ def test_format_sources(self, rag_pipeline): assert len(sources) == 1 assert sources[0]["title"] == "Test Doc" - assert len(sources[0]["content_preview"]) == 203 # 200 + "..." 
- assert sources[0]["content_preview"].endswith("...") + assert sources[0]["url"] == "https://example.com" + + def test_format_sources_with_sourcelink(self, rag_pipeline): + """Test that sourceLink is properly mapped to url for frontend compatibility.""" + docs = [ + Document( + page_content="Test content", + metadata={ + "title": "Cairo Book - Getting Started", + "sourceLink": "https://book.cairo-lang.org/ch01-01-installation.html#installation", + "source": "cairo_book", + }, + ), + Document( + page_content="Another doc", + metadata={ + "title": "No SourceLink Doc", + "sourceLink": "https://example.com", + "source": "starknet_docs", + }, + ), + ] + + sources = rag_pipeline._format_sources(docs) + + assert len(sources) == 2 + # First doc should have url mapped from sourceLink + assert sources[0]["url"] == "https://book.cairo-lang.org/ch01-01-installation.html#installation" + assert sources[0]["title"] == "Cairo Book - Getting Started" + + # Second doc should have fallback url + assert sources[1]["url"] == "https://example.com" + assert sources[1]["title"] == "No SourceLink Doc" def test_get_current_state(self, sample_documents, sample_processed_query, pipeline): """Test pipeline state retrieval.""" From 5e56bacb09030709da05a7f74fe650ec803fbf2b Mon Sep 17 00:00:00 2001 From: enitrat Date: Fri, 3 Oct 2025 16:37:36 +0200 Subject: [PATCH 2/2] fix tests --- python/tests/integration/conftest.py | 19 +++++++++++++++++-- .../integration/test_server_integration.py | 12 ++++++------ python/tests/unit/test_rag_pipeline.py | 14 +++++++------- 3 files changed, 30 insertions(+), 15 deletions(-) diff --git a/python/tests/integration/conftest.py b/python/tests/integration/conftest.py index 16cd286..5d65812 100644 --- a/python/tests/integration/conftest.py +++ b/python/tests/integration/conftest.py @@ -47,8 +47,11 @@ async def gen(): @pytest.fixture -def patch_dspy_streaming_error(monkeypatch): - """Patch dspy.streamify to raise an error mid-stream and provide StreamListener.""" 
+def patch_dspy_streaming_error(monkeypatch, real_pipeline): + """Patch dspy.streamify to raise an error mid-stream and provide StreamListener. + + Also patches the real_pipeline's generation_program.aforward_streaming to raise errors. + """ import dspy class FakeStreamResponse: # unused but parity if code inspects it @@ -77,6 +80,13 @@ async def gen(): monkeypatch.setattr(dspy, "streamify", fake_streamify) + # Also patch the real_pipeline's streaming method to raise an error + async def _error_streaming(query: str, context: str, chat_history: str | None = None): + raise RuntimeError("unhandled errors in a TaskGroup (1 sub-exception)") + yield "unreachable" # pragma: no cover + + real_pipeline.generation_program.aforward_streaming = _error_streaming + @pytest.fixture def real_pipeline(mock_query_processor, mock_vector_store_config, mock_vector_db): @@ -126,7 +136,12 @@ async def _fake_gen_aforward(query: str, context: str, chat_history: str | None idx = min((len(lines)) // 2, len(responses) - 1) return _dspy.Prediction(answer=responses[idx]) + async def _fake_gen_aforward_streaming(query: str, context: str, chat_history: str | None = None): + yield "Hello! I'm Cairo Coder, " + yield "ready to help with Cairo programming." + pipeline.generation_program.aforward = AsyncMock(side_effect=_fake_gen_aforward) + pipeline.generation_program.aforward_streaming =_fake_gen_aforward_streaming pipeline.generation_program.get_lm_usage = Mock(return_value={}) # Patch MCP generation to a deterministic simple string as tests expect diff --git a/python/tests/integration/test_server_integration.py b/python/tests/integration/test_server_integration.py index b3db789..18f8376 100644 --- a/python/tests/integration/test_server_integration.py +++ b/python/tests/integration/test_server_integration.py @@ -138,7 +138,7 @@ def test_streaming_integration( if "content" in delta: chunks.append(delta["content"]) - assert "".join(chunks) == "Hello world" + assert "".join(chunks) == "Hello! 
I'm Cairo Coder, ready to help with Cairo programming." def test_error_handling_integration(self, client: TestClient, mock_agent_factory: Mock): """Test error handling in integration context.""" @@ -289,12 +289,12 @@ def test_streaming_error_handling( client: TestClient, patch_dspy_streaming_error, ): - """Test that streaming errors surface as an SSE error chunk using a real pipeline.""" - response = client.post( "/v1/chat/completions", - json={"messages": [{"role": "user", "content": "Hello"}], "stream": True}, + json={"messages": [{"role": "user", "content": "hello"}], "stream": True}, ) + assert response.status_code == 200 + assert "text/event-stream" in response.headers.get("content-type", "") assert response.status_code == 200 @@ -495,8 +495,8 @@ def test_streaming_sources_emission(self, client: TestClient): # Verify each source has required fields for source in sources_event["data"]: - assert "title" in source - assert "url" in source + assert "title" in source['metadata'] + assert "url" in source['metadata'] def test_openai_error_response_structure(self, client: TestClient, mock_agent_factory: Mock): """Test that error response structure matches OpenAI API.""" diff --git a/python/tests/unit/test_rag_pipeline.py b/python/tests/unit/test_rag_pipeline.py index d34e855..e229c09 100644 --- a/python/tests/unit/test_rag_pipeline.py +++ b/python/tests/unit/test_rag_pipeline.py @@ -405,7 +405,7 @@ async def test_streaming_with_judge(self, mock_judge_class, pipeline, mock_retri sources_event = next(e for e in events if e.type == "sources") # Should only have 1 doc (Introduction to Cairo with score 0.9) assert len(sources_event.data) == 1 - assert sources_event.data[0]["title"] == "Introduction to Cairo" + assert sources_event.data[0]["metadata"]["title"] == "Introduction to Cairo" @patch("cairo_coder.core.rag_pipeline.RetrievalJudge") @pytest.mark.asyncio @@ -519,8 +519,8 @@ def test_format_sources(self, rag_pipeline): sources = rag_pipeline._format_sources(docs) 
assert len(sources) == 1 - assert sources[0]["title"] == "Test Doc" - assert sources[0]["url"] == "https://example.com" + assert sources[0]["metadata"]["title"] == "Test Doc" + assert sources[0]["metadata"]["url"] == "https://example.com" def test_format_sources_with_sourcelink(self, rag_pipeline): """Test that sourceLink is properly mapped to url for frontend compatibility.""" @@ -547,12 +547,12 @@ def test_format_sources_with_sourcelink(self, rag_pipeline): assert len(sources) == 2 # First doc should have url mapped from sourceLink - assert sources[0]["url"] == "https://book.cairo-lang.org/ch01-01-installation.html#installation" - assert sources[0]["title"] == "Cairo Book - Getting Started" + assert sources[0]["metadata"]["url"] == "https://book.cairo-lang.org/ch01-01-installation.html#installation" + assert sources[0]["metadata"]["title"] == "Cairo Book - Getting Started" # Second doc should have fallback url - assert sources[1]["url"] == "https://example.com" - assert sources[1]["title"] == "No SourceLink Doc" + assert sources[1]["metadata"]["url"] == "https://example.com" + assert sources[1]["metadata"]["title"] == "No SourceLink Doc" def test_get_current_state(self, sample_documents, sample_processed_query, pipeline): """Test pipeline state retrieval."""