From 7d1bfb9456931288f52c3f03b101ef5b7691bdf0 Mon Sep 17 00:00:00 2001 From: Ebba Alva Date: Sat, 20 Sep 2025 21:42:47 +0000 Subject: [PATCH 1/3] Implement web search and answering endpoint with citation support; add ranking and synthesis services for enhanced query responses --- api/endpoints/web_answering.py | 319 +++++++++++++++++++++++++++++++++ services/ranking_service.py | 190 ++++++++++++++++++++ services/synthesis_service.py | 224 +++++++++++++++++++++++ tests/test_tavily_search.py | 47 +++++ 4 files changed, 780 insertions(+) create mode 100644 api/endpoints/web_answering.py create mode 100644 services/ranking_service.py create mode 100644 services/synthesis_service.py create mode 100644 tests/test_tavily_search.py diff --git a/api/endpoints/web_answering.py b/api/endpoints/web_answering.py new file mode 100644 index 0000000..2496501 --- /dev/null +++ b/api/endpoints/web_answering.py @@ -0,0 +1,319 @@ +""" +Web search and answering endpoint for Perplexity-style answers with citations. 
"""
Web search and answering endpoint for Perplexity-style answers with citations.
"""

import asyncio
import logging
import os
import secrets
import time
import uuid
from dataclasses import dataclass
from typing import Dict, List, Optional, Set, Any

from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.security import APIKeyHeader
from pydantic import BaseModel, Field, validator

from config.settings import settings
from services.web_search_service import WebSearchService
from services.web_fetch_service import WebFetchService
from services.ranking_service import RankingService
from services.synthesis_service import SynthesisService

logger = logging.getLogger(__name__)

# Reuse internal API key authentication
INTERNAL_API_KEY = os.environ.get("INTERNAL_API_KEY")
api_key_header = APIKeyHeader(name="X-Internal-API-Key", auto_error=False)

# Create router
router = APIRouter(tags=["websearch"])


class WebSearchAnswerRequest(BaseModel):
    """Request payload; out-of-range numeric fields are clamped, not rejected."""

    query: str
    top_k_results: int = 8
    max_context_chars: int = 20000
    region: str = "auto"
    language: str = "en"
    style_profile_id: Optional[str] = None
    answer_tokens: int = 800
    include_snippets: bool = True
    timeout_seconds: int = 25

    @validator("top_k_results")
    def validate_top_k_results(cls, v):
        return max(3, min(15, v))  # Clamp between 3 and 15

    @validator("max_context_chars")
    def validate_max_context_chars(cls, v):
        return max(1000, min(50000, v))  # Clamp between 1000 and 50000

    @validator("answer_tokens")
    def validate_answer_tokens(cls, v):
        return max(100, min(2000, v))  # Clamp between 100 and 2000

    @validator("timeout_seconds")
    def validate_timeout_seconds(cls, v):
        return max(5, min(60, v))  # Clamp between 5 and 60 seconds


class Citation(BaseModel):
    """One evidence passage that the generated answer actually cites."""

    id: int
    url: str
    title: Optional[str] = None
    site_name: Optional[str] = None
    published_at: Optional[str] = None
    snippet: Optional[str] = None
    score: float


class Timings(BaseModel):
    """Per-phase durations in milliseconds."""

    search: int = 0
    fetch: int = 0
    rank: int = 0
    generate: int = 0
    total: int = 0


class Meta(BaseModel):
    """Echo of the request settings plus the configured search engine."""

    engine: str
    region: str
    language: str
    style_profile_id: Optional[str] = None


class WebSearchAnswerResponse(BaseModel):
    """Response payload: the Markdown answer and the citations it uses."""

    query: str
    answer_markdown: str
    citations: List[Citation]
    used_sources_count: int
    timings_ms: Timings
    meta: Meta


def verify_internal_api_key(api_key: str = Depends(api_key_header)):
    """
    Reject the request unless the X-Internal-API-Key header matches INTERNAL_API_KEY.

    Raises:
        HTTPException: 401 when the header is missing, the server has no key
            configured, or the two values differ.
    """
    # Do not log or expose the secret.
    # Compare as bytes: secrets.compare_digest() raises TypeError for str
    # arguments containing non-ASCII characters, which would turn a bad header
    # into a 500 instead of a 401.
    authorized = (
        bool(api_key)
        and bool(INTERNAL_API_KEY)
        and secrets.compare_digest(
            api_key.encode("utf-8"), INTERNAL_API_KEY.encode("utf-8")
        )
    )
    if not authorized:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or missing internal API key."
        )


def _elapsed_ms(started: float) -> int:
    """Return whole milliseconds elapsed since the time.monotonic() value `started`."""
    return int((time.monotonic() - started) * 1000)


def _make_snippet(passage: str, limit: int = 200) -> str:
    """Return `passage` unchanged, or its first `limit` characters plus "..." if longer."""
    if len(passage) > limit:
        return passage[:limit] + "..."
    return passage


@router.post("/websearch/answer", response_model=WebSearchAnswerResponse, status_code=200)
async def web_search_answer(
    request: WebSearchAnswerRequest,
    req: Request,
    api_key: str = Depends(verify_internal_api_key)
):
    """
    Perform web search and generate an answer with citations.

    Runs four phases in order: search, fetch, rank, synthesize. Search, fetch
    and rank get 40%, 30% and 15% of `request.timeout_seconds`; synthesis gets
    whatever is left of the overall budget, but never less than 15%.

    Raises:
        HTTPException: 504 when the search, ranking or synthesis phase times
            out; 422 when a phase yields nothing usable (no results, no
            documents, no evidence, no cited answer); 500 for service
            initialization failures and any unexpected error. Internal error
            text is logged, never returned to the client.
    """
    # Generate request ID
    request_id = req.headers.get("X-Request-Id", str(uuid.uuid4()))
    logger_ctx = {"request_id": request_id}

    logger.info(
        "Web search answer request: query='%s', top_k=%s, timeout=%ss",
        request.query, request.top_k_results, request.timeout_seconds,
        extra=logger_ctx
    )

    timings = {
        "search": 0,
        "fetch": 0,
        "rank": 0,
        "generate": 0,
        "total": 0
    }

    budget = request.timeout_seconds
    # Monotonic clock: wall-clock adjustments must not distort durations.
    start_time = time.monotonic()

    try:
        # Initialize services
        try:
            search_service = WebSearchService()
            fetch_service = WebFetchService()
            ranking_service = RankingService()
            synthesis_service = SynthesisService()
        except Exception:
            logger.exception("Failed to initialize services", extra=logger_ctx)
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Service initialization error."
            )

        # 1. Search phase (40% of the budget; the provider itself is capped at 10s)
        search_timeout = budget * 0.4
        search_start = time.monotonic()
        try:
            search_results = await asyncio.wait_for(
                search_service.search(
                    query=request.query,
                    k=request.top_k_results,
                    region=request.region,
                    language=request.language,
                    timeout_seconds=min(search_timeout, 10)
                ),
                timeout=search_timeout
            )
        except asyncio.TimeoutError:
            logger.warning("Search timed out after %ss", search_timeout, extra=logger_ctx)
            raise HTTPException(
                status_code=status.HTTP_504_GATEWAY_TIMEOUT,
                detail="Search phase timed out."
            )

        timings["search"] = _elapsed_ms(search_start)

        if not search_results:
            logger.warning("No search results found", extra=logger_ctx)
            raise HTTPException(
                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                detail="No search results found for the query."
            )

        logger.info("Search completed with %d results", len(search_results), extra=logger_ctx)

        # 2. Fetch phase (30% of the budget; the fetcher itself is capped at 8s)
        fetch_timeout = budget * 0.3
        fetch_start = time.monotonic()
        try:
            fetched_docs = await asyncio.wait_for(
                fetch_service.fetch_search_results(
                    search_results=search_results,
                    timeout_seconds=min(fetch_timeout, 8),
                    preserve_snippets=request.include_snippets
                ),
                timeout=fetch_timeout
            )
        except asyncio.TimeoutError:
            logger.warning("Fetch timed out after %ss", fetch_timeout, extra=logger_ctx)
            # A timed-out fetch returns nothing, so carry on with an empty list;
            # the ranking phase then reports the lack of evidence.
            fetched_docs = []

        timings["fetch"] = _elapsed_ms(fetch_start)

        if not fetched_docs and not request.include_snippets:
            logger.warning("No documents fetched successfully", extra=logger_ctx)
            raise HTTPException(
                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                detail="Failed to fetch content from search results."
            )

        logger.info("Fetch completed with %d documents", len(fetched_docs), extra=logger_ctx)

        # 3. Ranking phase (15% of the budget)
        rank_timeout = budget * 0.15
        rank_start = time.monotonic()
        try:
            ranked_evidence = await asyncio.wait_for(
                ranking_service.rank_documents(
                    query=request.query,
                    docs=fetched_docs,
                    max_context_chars=request.max_context_chars
                ),
                timeout=rank_timeout
            )
        except asyncio.TimeoutError:
            logger.warning("Ranking timed out after %ss", rank_timeout, extra=logger_ctx)
            raise HTTPException(
                status_code=status.HTTP_504_GATEWAY_TIMEOUT,
                detail="Ranking phase timed out."
            )

        timings["rank"] = _elapsed_ms(rank_start)

        if not ranked_evidence:
            logger.warning("No evidence ranked for the query", extra=logger_ctx)
            raise HTTPException(
                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                detail="No relevant evidence found for the query."
            )

        logger.info(
            "Ranking completed with %d evidence passages",
            len(ranked_evidence), extra=logger_ctx
        )

        # 4. Synthesis phase: use the time actually left in the overall budget.
        # The floor of 15% keeps it at least as generous as a fixed share, and
        # since the earlier phases are capped at 85% the total never exceeds
        # the requested timeout.
        remaining = budget - (time.monotonic() - start_time)
        synthesis_timeout = max(remaining, budget * 0.15)
        generate_start = time.monotonic()
        try:
            synthesis_result = await asyncio.wait_for(
                synthesis_service.generate_answer(
                    query=request.query,
                    evidence_list=ranked_evidence,
                    answer_tokens=request.answer_tokens,
                    style_profile_id=request.style_profile_id
                ),
                timeout=synthesis_timeout
            )
        except asyncio.TimeoutError:
            logger.warning(
                "Synthesis timed out after %.2fs", synthesis_timeout, extra=logger_ctx
            )
            raise HTTPException(
                status_code=status.HTTP_504_GATEWAY_TIMEOUT,
                detail="Synthesis phase timed out."
            )

        timings["generate"] = _elapsed_ms(generate_start)

        # Calculate total time
        timings["total"] = _elapsed_ms(start_time)

        # Check if any citations were used
        if not synthesis_result.used_citation_ids:
            logger.warning("No citations used in the answer", extra=logger_ctx)
            raise HTTPException(
                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                detail="No evidence-based answer could be produced within constraints."
            )

        # Keep only the evidence the answer actually cites, ordered by ID.
        citations = sorted(
            (
                Citation(
                    id=evidence.id,
                    url=evidence.url,
                    title=evidence.title,
                    site_name=evidence.site_name,
                    published_at=evidence.published_at,
                    snippet=_make_snippet(evidence.passage),
                    score=evidence.score
                )
                for evidence in ranked_evidence
                if evidence.id in synthesis_result.used_citation_ids
            ),
            key=lambda c: c.id
        )

        # Build response
        response = WebSearchAnswerResponse(
            query=request.query,
            answer_markdown=synthesis_result.answer_markdown,
            citations=citations,
            used_sources_count=len(citations),
            timings_ms=Timings(**timings),
            meta=Meta(
                engine=settings.web_search_engine,
                region=request.region,
                language=request.language,
                style_profile_id=request.style_profile_id
            )
        )

        logger.info(
            "Answer generated successfully in %dms with %d citations",
            timings["total"], len(citations),
            extra=logger_ctx
        )

        return response

    except HTTPException:
        # Re-raise HTTP exceptions
        raise
    except Exception:
        # Full details go to the log; the client only gets a generic message.
        logger.exception("Error processing web search answer", extra=logger_ctx)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal error while processing the request."
        )
"""
Service for ranking text passages by relevance to a query.
Uses embedding-based similarity scoring.
"""

import logging
import time
from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, List, Optional, Sequence, Tuple

import numpy as np

if TYPE_CHECKING:  # used in annotations only
    from services.web_fetch_service import FetchedDoc

logger = logging.getLogger(__name__)

# (passage, url, title, site_name, published_at)
PassageRecord = Tuple[str, str, Optional[str], Optional[str], Optional[str]]


@dataclass
class RankedEvidence:
    """Represents a ranked evidence passage with metadata."""
    id: int  # 1-based ID used for citations
    url: str
    title: Optional[str] = None
    site_name: Optional[str] = None
    passage: str = ""
    score: float = 0.0
    published_at: Optional[str] = None  # ISO format date string


class RankingService:
    """
    Service for ranking passages by relevance to a query.
    Uses embedding-based similarity scoring.
    """

    def __init__(self, ideal_passage_length: int = 1000, overlap: int = 200, *,
                 chunker: Optional[Callable[..., List[str]]] = None,
                 embedder: Optional[Callable[[List[str]], Sequence[Sequence[float]]]] = None):
        """
        Initialize the ranking service.

        Args:
            ideal_passage_length: Target length of passages in characters
            overlap: Overlap between passages in characters
            chunker: Callable invoked as chunker(text, max_length=..., overlap=...)
                and iterated for passages. Defaults to
                services.chunking_service.chunk_text.
            embedder: Callable taking a list of texts and returning one vector
                per text. Defaults to services.embedding_service.embed_texts.
        """
        self.ideal_passage_length = ideal_passage_length
        self.overlap = overlap

        # Defaults are resolved here rather than at import time so the service
        # can be built with substitutes without importing the project services.
        if chunker is None:
            from services.chunking_service import chunk_text
            chunker = chunk_text
        if embedder is None:
            from services.embedding_service import embed_texts
            embedder = embed_texts
        self._chunk_text = chunker
        self._embed_texts = embedder

    def _split_into_passages(self, doc: "FetchedDoc") -> List[PassageRecord]:
        """
        Split a document into passages with metadata.

        Args:
            doc: The document to split

        Returns:
            List of tuples: (passage, url, title, site_name, published_at);
            empty when the document has no text.
        """
        if not doc.text:
            return []

        passages = self._chunk_text(
            doc.text,
            max_length=self.ideal_passage_length,
            overlap=self.overlap
        )

        return [(p, doc.url, doc.title, doc.site_name, doc.published_at) for p in passages]

    def _compute_similarity(self, query_embedding: Sequence[float],
                            passage_embeddings: Sequence[Sequence[float]]) -> List[float]:
        """
        Compute cosine similarity between query and passages.

        Args:
            query_embedding: Query embedding vector
            passage_embeddings: List of passage embedding vectors

        Returns:
            One similarity score per passage, in input order. A zero-length
            query or passage vector scores 0.0; an empty passage list gives [].
        """
        query_vec = np.asarray(query_embedding, dtype=float)
        passage_vecs = np.asarray(passage_embeddings, dtype=float)
        if passage_vecs.size == 0:
            return []

        dots = passage_vecs @ query_vec
        denominators = np.linalg.norm(passage_vecs, axis=1) * np.linalg.norm(query_vec)

        # Dividing by a zero norm would yield NaN, and NaN scores make the
        # later sort order meaningless. Leave those entries at 0.0 instead.
        similarities = np.divide(
            dots, denominators, out=np.zeros_like(dots), where=denominators > 0
        )
        return similarities.tolist()

    async def rank_documents(self, query: str, docs: "List[FetchedDoc]",
                             max_context_chars: int = 20000) -> List[RankedEvidence]:
        """
        Split documents into passages, rank them by relevance to query,
        and return top passages within context budget.

        If embedding or scoring fails, passages keep their original order
        (scores decrease from 1.0) instead of the call failing.

        Args:
            query: The search query
            docs: List of fetched documents
            max_context_chars: Maximum total character budget for evidence

        Returns:
            RankedEvidence objects, best first, with sequential 1-based IDs
        """
        if not query or not docs:
            return []

        all_passages = [record for doc in docs for record in self._split_into_passages(doc)]
        if not all_passages:
            logger.warning("No passages extracted from documents")
            return []

        passages = [record[0] for record in all_passages]

        start_time = time.monotonic()
        try:
            query_embedding = self._embed_texts([query])[0]
            passage_embeddings = self._embed_texts(passages)
            if len(passage_embeddings) != len(passages):
                # zip() below would silently drop passages on a mismatch.
                raise ValueError(
                    f"expected {len(passages)} embeddings, got {len(passage_embeddings)}"
                )
            similarity_scores = self._compute_similarity(query_embedding, passage_embeddings)
        except Exception as e:
            logger.error(f"Error during embedding or similarity computation: {str(e)}")
            # Fallback: assign decreasing scores based on original order
            similarity_scores = [1.0 - (i / len(passages)) for i in range(len(passages))]

        # IDs here are provisional; they are re-assigned after selection.
        scored_passages = [
            RankedEvidence(
                id=index,
                url=url,
                title=title,
                site_name=site_name,
                passage=passage,
                score=score,
                published_at=published_at
            )
            for index, ((passage, url, title, site_name, published_at), score) in enumerate(
                zip(all_passages, similarity_scores), start=1
            )
        ]

        # Sort by score (highest first); the sort is stable, so ties keep document order.
        scored_passages.sort(key=lambda evidence: evidence.score, reverse=True)

        # Take passages in score order until the character budget is exhausted.
        selected_passages = []
        current_budget = 0
        for evidence in scored_passages:
            passage_length = len(evidence.passage)
            if current_budget + passage_length <= max_context_chars:
                selected_passages.append(evidence)
                current_budget += passage_length
                continue

            # First passage that does not fit: keep a truncated version only
            # if a meaningful chunk (>= 200 chars) still fits, then stop.
            remaining_budget = max_context_chars - current_budget
            if remaining_budget >= 200:
                evidence.passage = evidence.passage[:remaining_budget - 3] + "..."
                selected_passages.append(evidence)
            break

        # Re-assign IDs so citations are sequential in ranked order.
        for position, evidence in enumerate(selected_passages, start=1):
            evidence.id = position

        duration_ms = int((time.monotonic() - start_time) * 1000)
        logger.info(
            f"Ranked {len(all_passages)} passages in {duration_ms}ms, "
            f"selected {len(selected_passages)} within budget"
        )

        return selected_passages
"""
Service for synthesizing answers from ranked evidence passages.
Generates Markdown answers with citations and references.
"""

import asyncio
import functools
import logging
import os
import re
import time
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple

if TYPE_CHECKING:  # used in annotations only
    from services.ranking_service import RankedEvidence

logger = logging.getLogger(__name__)

# Matches citation markers of the form [^12].
_CITATION_PATTERN = re.compile(r"\[\^(\d+)\]")

# Prompt suffix for each known style profile ID.
_STYLE_INSTRUCTIONS = {
    "academic": "\nStyle: Academic. Formal, precise language with scholarly tone. Use technical terminology where appropriate.\n",
    "simple": "\nStyle: Simple. Clear, straightforward language accessible to non-experts. Avoid jargon.\n",
    "journalist": "\nStyle: Journalistic. Informative, engaging, with clear explanations. Prioritize key facts.\n",
    "technical": "\nStyle: Technical. Detailed, precise, assuming domain knowledge. Include technical specifics.\n",
}


@dataclass
class SynthesisResult:
    """Result of synthesis with answer and metadata."""
    answer_markdown: str
    used_citation_ids: Set[int]
    prompt_tokens: int = 0
    completion_tokens: int = 0
    generation_ms: int = 0


class SynthesisService:
    """
    Service for synthesizing answers from ranked evidence passages.
    Generates answers with citation markers and reference lists.
    """

    def __init__(self, api_key: Optional[str] = None, model: str = "gpt-4o", *,
                 client: Optional[Any] = None):
        """
        Initialize the synthesis service.

        Args:
            api_key: OpenAI API key (defaults to OPENAI_API_KEY env var)
            model: LLM model to use
            client: Pre-built client exposing chat.completions.create(...).
                When given, no API key is required and no OpenAI client is built.

        Raises:
            ValueError: If no client is given and no API key is available.
        """
        self.model = model

        if client is not None:
            self.api_key = api_key
            self.client = client
            return

        self.api_key = api_key or os.getenv("OPENAI_API_KEY")
        if not self.api_key:
            raise ValueError("OpenAI API key is required but not provided")

        # Imported here so the module can be imported (and the service used
        # with an injected client) without the openai package.
        from openai import OpenAI
        self.client = OpenAI(api_key=self.api_key)

    def _build_system_prompt(self) -> str:
        """Build the system prompt for the synthesis."""
        return (
            "You are a factual technical writer. Your job is to answer questions based ONLY on the provided evidence. "
            "Be accurate, clear, and concise. Never include information not supported by the evidence."
        )

    def _build_evidence_block(self, evidence_list: "List[RankedEvidence]") -> str:
        """
        Build a numbered evidence block with metadata.

        Args:
            evidence_list: List of ranked evidence

        Returns:
            Formatted evidence block, or "NO EVIDENCE PROVIDED." when empty
        """
        if not evidence_list:
            return "NO EVIDENCE PROVIDED."

        parts = ["# EVIDENCE\n\n"]
        for evidence in evidence_list:
            title = evidence.title or "Untitled Document"
            site_info = f" ({evidence.site_name})" if evidence.site_name else ""

            parts.append(f"[{evidence.id}] {title}{site_info}\n")
            parts.append(f"URL: {evidence.url}\n")
            if evidence.published_at:
                parts.append(f"Published: {evidence.published_at}\n")
            parts.append(f"Excerpt: {evidence.passage}\n\n")

        return "".join(parts)

    def _build_user_prompt(self, query: str, evidence_list: "List[RankedEvidence]",
                           answer_tokens: int, style_profile_id: Optional[str]) -> str:
        """
        Build the user prompt with query, evidence, and instructions.

        Args:
            query: The query to answer
            evidence_list: List of ranked evidence
            answer_tokens: Target token count for the answer
            style_profile_id: Optional style profile ID for tone/voice; unknown
                IDs are passed to the model verbatim

        Returns:
            Complete user prompt
        """
        evidence_block = self._build_evidence_block(evidence_list)

        instructions = (
            f"Question: {query}\n\n"
            "Instructions:\n"
            "1. Answer the question using ONLY the provided evidence.\n"
            "2. Add citation markers [^i] after each sentence, where i is the evidence number.\n"
            "3. If a claim lacks evidence, either mark it (citation needed) or omit it entirely.\n"
            "4. End with a References section listing each source you actually cited:\n"
            "   [^i]: Title — URL (site), date\n"
            f"5. Be concise yet thorough. Target length: {answer_tokens} tokens.\n"
        )

        if style_profile_id:
            instructions += _STYLE_INSTRUCTIONS.get(
                style_profile_id, f"\nStyle Profile: {style_profile_id}\n"
            )

        return f"{instructions}\n\n{evidence_block}"

    def _extract_citation_ids(self, text: str) -> Set[int]:
        """
        Extract citation IDs from text with [^i] markers.

        Args:
            text: Text with citation markers

        Returns:
            Set of citation IDs
        """
        return {int(marker) for marker in _CITATION_PATTERN.findall(text)}

    async def generate_answer(self, query: str, evidence_list: "List[RankedEvidence]",
                              answer_tokens: int = 800,
                              style_profile_id: Optional[str] = None) -> SynthesisResult:
        """
        Generate an answer from evidence with citation markers.

        The blocking client call runs in the event loop's default executor, so
        the loop stays responsive and a caller's asyncio.wait_for() timeout can
        fire. The worker thread is not interrupted by that cancellation; it
        ends when the client call returns.

        Args:
            query: The query to answer
            evidence_list: List of ranked evidence
            answer_tokens: Target token count for the answer
            style_profile_id: Optional style profile ID

        Returns:
            SynthesisResult with answer and metadata. `used_citation_ids` holds
            only IDs that exist in `evidence_list`. On any failure a generic
            apology is returned with an empty `used_citation_ids`.
        """
        if not query or not evidence_list:
            return SynthesisResult(
                answer_markdown="No evidence available to answer this question.",
                used_citation_ids=set()
            )

        system_prompt = self._build_system_prompt()
        user_prompt = self._build_user_prompt(query, evidence_list, answer_tokens, style_profile_id)

        start_time = time.monotonic()
        try:
            request_completion = functools.partial(
                self.client.chat.completions.create,
                model=self.model,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt}
                ],
                temperature=0.2,
                max_tokens=answer_tokens,
            )
            response = await asyncio.get_running_loop().run_in_executor(
                None, request_completion
            )

            # The message content can be None (e.g. a refused or empty completion).
            answer = response.choices[0].message.content or ""

            usage = getattr(response, "usage", None)
            prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0
            completion_tokens = getattr(usage, "completion_tokens", 0) or 0

            # Drop markers that point at evidence numbers we never supplied.
            known_ids = {evidence.id for evidence in evidence_list}
            used_citation_ids = self._extract_citation_ids(answer) & known_ids

            generation_ms = int((time.monotonic() - start_time) * 1000)

            logger.info(f"Generated answer in {generation_ms}ms, {completion_tokens} tokens, {len(used_citation_ids)} citations")

            return SynthesisResult(
                answer_markdown=answer,
                used_citation_ids=used_citation_ids,
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                generation_ms=generation_ms
            )

        except Exception:
            # Details (which may include request data) go to the log only.
            logger.exception("Error generating answer")

            fallback_answer = (
                f"I apologize, but I encountered an error while generating an answer to your question: \"{query}\"\n\n"
                "Please try again with a different query or contact support if the problem persists."
            )

            return SynthesisResult(
                answer_markdown=fallback_answer,
                used_citation_ids=set(),
                generation_ms=int((time.monotonic() - start_time) * 1000)
            )
import sys
import asyncio
import os
import traceback
from pathlib import Path

# Import the WebSearchService
from services.web_search_service import WebSearchService


async def test_tavily_search():
    """
    Run one live search through the Tavily provider and print the results.

    Failures propagate (service errors, or AssertionError when nothing is
    returned) so a test runner can report them instead of seeing a silent pass.
    """
    # Initialize the search service with Tavily (not dummy)
    search_service = WebSearchService(provider_name="tavily")

    # Perform a search
    query = "WHat is the result of the asia cup match between ban vs sri today?"
    print(f"Searching for: {query}")

    results = await search_service.search(
        query=query,
        k=5,
        region="auto",
        language="en",
        timeout_seconds=15
    )

    print(f"\nFound {len(results)} results:")
    for idx, result in enumerate(results, 1):
        print(f"\n--- Result {idx} ---")
        print(f"Title: {result.title}")
        print(f"URL: {result.url}")
        print(f"Site: {result.site_name}")
        if result.snippet:
            snippet = result.snippet[:100] + "..." if len(result.snippet) > 100 else result.snippet
            print(f"Snippet: {snippet}")
        print(f"Score: {result.score}")

    assert len(results) > 0, "No search results returned"


def main() -> int:
    """Run the search as a script; return 0 on success and 1 on failure."""
    try:
        asyncio.run(test_tavily_search())
    except Exception as e:
        print(f"Error: {str(e)}")
        traceback.print_exc()
        print("\nTest failed")
        return 1

    print("\nTest succeeded")
    return 0


if __name__ == "__main__":
    sys.exit(main())
WebSearchAnswerRequest(BaseModel): answer_tokens: int = 800 include_snippets: bool = True timeout_seconds: int = 25 - + # Validators - @validator("top_k_results") + @field_validator("top_k_results") + @classmethod def validate_top_k_results(cls, v): return max(3, min(15, v)) # Clamp between 3 and 15 - - @validator("max_context_chars") + + @field_validator("max_context_chars") + @classmethod def validate_max_context_chars(cls, v): return max(1000, min(50000, v)) # Clamp between 1000 and 50000 - - @validator("answer_tokens") + + @field_validator("answer_tokens") + @classmethod def validate_answer_tokens(cls, v): return max(100, min(2000, v)) # Clamp between 100 and 2000 - - @validator("timeout_seconds") + + @field_validator("timeout_seconds") + @classmethod def validate_timeout_seconds(cls, v): return max(5, min(60, v)) # Clamp between 5 and 60 seconds @@ -110,6 +114,7 @@ async def web_search_answer( req: Request, api_key: str = Depends(verify_internal_api_key) ): + start_time = time.time() """ Perform web search and generate an answer with citations. """ @@ -132,12 +137,6 @@ async def web_search_answer( "generate": 0, "total": 0 } - - # Track overall execution time - start_time = time.time() - - try: - # Initialize services # Initialize services try: search_service = WebSearchService() @@ -145,128 +144,120 @@ async def web_search_answer( ranking_service = RankingService() synthesis_service = SynthesisService() except Exception as e: - logger.error(f"Failed to initialize services: {str(e)}", extra=logger_ctx) + logger.error("Failed to initialize services: %s", e, extra=logger_ctx) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Service initialization error: {str(e)}" - ) # 1. 
Search phase - search_start = time.time() - try: - search_results = await asyncio.wait_for( - search_service.search( - query=request.query, - k=request.top_k_results, - region=request.region, - language=request.language, - timeout_seconds=min(request.timeout_seconds * 0.4, 10) # Allocate 40% of timeout - ), - timeout=request.timeout_seconds * 0.4 - ) - except asyncio.TimeoutError: - logger.warning(f"Search timed out after {request.timeout_seconds * 0.4}s", extra=logger_ctx) - raise HTTPException( - status_code=status.HTTP_504_GATEWAY_TIMEOUT, - detail="Search phase timed out." - ) - - timings["search"] = int((time.time() - search_start) * 1000) - - if not search_results: - logger.warning("No search results found", extra=logger_ctx) - raise HTTPException( - status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, - detail="No search results found for the query." - ) - - logger.info(f"Search completed with {len(search_results)} results", extra=logger_ctx) - - # 2. Fetch phase - fetch_start = time.time() - try: - fetched_docs = await asyncio.wait_for( - fetch_service.fetch_search_results( - search_results=search_results, - timeout_seconds=min(request.timeout_seconds * 0.3, 8), # Allocate 30% of timeout - preserve_snippets=request.include_snippets - ), - timeout=request.timeout_seconds * 0.3 - ) - except asyncio.TimeoutError: - logger.warning(f"Fetch timed out after {request.timeout_seconds * 0.3}s", extra=logger_ctx) - # Continue with whatever we got - fetched_docs = [] - - timings["fetch"] = int((time.time() - fetch_start) * 1000) - - if not fetched_docs and not request.include_snippets: - logger.warning("No documents fetched successfully", extra=logger_ctx) - raise HTTPException( - status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, - detail="Failed to fetch content from search results." - ) - - logger.info(f"Fetch completed with {len(fetched_docs)} documents", extra=logger_ctx) - - # 3. 
Ranking phase - rank_start = time.time() - try: - ranked_evidence = await asyncio.wait_for( - ranking_service.rank_documents( - query=request.query, - docs=fetched_docs, - max_context_chars=request.max_context_chars - ), - timeout=request.timeout_seconds * 0.15 # Allocate 15% of timeout - ) - except asyncio.TimeoutError: - logger.warning(f"Ranking timed out after {request.timeout_seconds * 0.15}s", extra=logger_ctx) - raise HTTPException( - status_code=status.HTTP_504_GATEWAY_TIMEOUT, - detail="Ranking phase timed out." - ) - - timings["rank"] = int((time.time() - rank_start) * 1000) - - if not ranked_evidence: - logger.warning("No evidence ranked for the query", extra=logger_ctx) - raise HTTPException( - status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, - detail="No relevant evidence found for the query." - ) - - logger.info(f"Ranking completed with {len(ranked_evidence)} evidence passages", extra=logger_ctx) - - # 4. Synthesis phase - generate_start = time.time() - try: - synthesis_result = await asyncio.wait_for( - synthesis_service.generate_answer( - query=request.query, - evidence_list=ranked_evidence, - answer_tokens=request.answer_tokens, - style_profile_id=request.style_profile_id - ), - timeout=request.timeout_seconds * 0.15 # Allocate remaining time - ) - except asyncio.TimeoutError: - logger.warning(f"Synthesis timed out after {request.timeout_seconds * 0.15}s", extra=logger_ctx) - raise HTTPException( - status_code=status.HTTP_504_GATEWAY_TIMEOUT, - detail="Synthesis phase timed out." - ) - - timings["generate"] = int((time.time() - generate_start) * 1000) - - # Calculate total time - timings["total"] = int((time.time() - start_time) * 1000) - - # Check if any citations were used - if not synthesis_result.used_citation_ids: - logger.warning("No citations used in the answer", extra=logger_ctx) - raise HTTPException( - status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, - detail="No evidence-based answer could be produced within constraints." 
- ) + detail="Service initialization error." + ) + + # 1. Search phase + search_start = time.time() + try: + search_results = await asyncio.wait_for( + search_service.search( + query=request.query, + k=request.top_k_results, + region=request.region, + language=request.language, + timeout_seconds=min(request.timeout_seconds * 0.4, 10) # Allocate 40% of timeout + ), + timeout=request.timeout_seconds * 0.4 + ) + except asyncio.TimeoutError: + logger.warning(f"Search timed out after {request.timeout_seconds * 0.4}s", extra=logger_ctx) + raise HTTPException( + status_code=status.HTTP_504_GATEWAY_TIMEOUT, + detail="Search phase timed out." + ) + timings["search"] = int((time.time() - search_start) * 1000) + if not search_results: + logger.warning("No search results found", extra=logger_ctx) + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail="No search results found for the query." + ) + logger.info(f"Search completed with {len(search_results)} results", extra=logger_ctx) + + # 2. Fetch phase + fetch_start = time.time() + try: + fetched_docs = await asyncio.wait_for( + fetch_service.fetch_search_results( + search_results=search_results, + timeout_seconds=min(request.timeout_seconds * 0.3, 8), # Allocate 30% of timeout + preserve_snippets=request.include_snippets + ), + timeout=request.timeout_seconds * 0.3 + ) + except asyncio.TimeoutError: + logger.warning(f"Fetch timed out after {request.timeout_seconds * 0.3}s", extra=logger_ctx) + # Continue with whatever we got + fetched_docs = [] + timings["fetch"] = int((time.time() - fetch_start) * 1000) + if not fetched_docs and not request.include_snippets: + logger.warning("No documents fetched successfully", extra=logger_ctx) + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail="Failed to fetch content from search results." + ) + logger.info(f"Fetch completed with {len(fetched_docs)} documents", extra=logger_ctx) + + # 3. 
Ranking phase + rank_start = time.time() + try: + ranked_evidence = await asyncio.wait_for( + ranking_service.rank_documents( + query=request.query, + docs=fetched_docs, + max_context_chars=request.max_context_chars + ), + timeout=request.timeout_seconds * 0.15 # Allocate 15% of timeout + ) + except asyncio.TimeoutError: + logger.warning(f"Ranking timed out after {request.timeout_seconds * 0.15}s", extra=logger_ctx) + raise HTTPException( + status_code=status.HTTP_504_GATEWAY_TIMEOUT, + detail="Ranking phase timed out." + ) + timings["rank"] = int((time.time() - rank_start) * 1000) + if not ranked_evidence: + logger.warning("No evidence ranked for the query", extra=logger_ctx) + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail="No relevant evidence found for the query." + ) + logger.info(f"Ranking completed with {len(ranked_evidence)} evidence passages", extra=logger_ctx) + + # 4. Synthesis phase + generate_start = time.time() + try: + synthesis_result = await asyncio.wait_for( + synthesis_service.generate_answer( + query=request.query, + evidence_list=ranked_evidence, + answer_tokens=request.answer_tokens, + style_profile_id=request.style_profile_id + ), + timeout=request.timeout_seconds * 0.15 # Allocate remaining time + ) + except asyncio.TimeoutError: + logger.warning(f"Synthesis timed out after {request.timeout_seconds * 0.15}s", extra=logger_ctx) + raise HTTPException( + status_code=status.HTTP_504_GATEWAY_TIMEOUT, + detail="Synthesis phase timed out." + ) + timings["generate"] = int((time.time() - generate_start) * 1000) + + # Calculate total time + timings["total"] = int((time.time() - start_time) * 1000) + + # Check if any citations were used + if not synthesis_result.used_citation_ids: + logger.warning("No citations used in the answer", extra=logger_ctx) + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail="No evidence-based answer could be produced within constraints." 
+ ) # Create citations list, including only those actually used in the answer citations = [] @@ -312,8 +303,8 @@ async def web_search_answer( raise except Exception as e: # Log error and return 500 - logger.exception(f"Error processing web search answer: {str(e)}", extra=logger_ctx) + logger.exception(f"Error processing web search answer", extra=logger_ctx) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error processing request: {str(e)}" + detail="Internal server error" ) \ No newline at end of file diff --git a/services/ranking_service.py b/services/ranking_service.py index 4a4bf7f..1cc313b 100644 --- a/services/ranking_service.py +++ b/services/ranking_service.py @@ -83,13 +83,19 @@ def _compute_similarity(self, query_embedding: List[float], passage_embeddings: query_vec = np.array(query_embedding) passage_vecs = np.array(passage_embeddings) - # Normalize vectors - query_vec = query_vec / np.linalg.norm(query_vec) - passage_vecs = passage_vecs / np.linalg.norm(passage_vecs, axis=1, keepdims=True) - + # Normalize vectors with zero-norm guards + q_norm = np.linalg.norm(query_vec) + if q_norm == 0: + return [0.0] * len(passage_embeddings) + query_vec = query_vec / q_norm + p_norms = np.linalg.norm(passage_vecs, axis=1, keepdims=True) + # Avoid divide-by-zero for degenerate embeddings + p_norms[p_norms == 0] = 1.0 + passage_vecs = passage_vecs / p_norms + # Compute cosine similarity similarities = np.dot(passage_vecs, query_vec) - + return similarities.tolist() async def rank_documents(self, query: str, docs: List[FetchedDoc], diff --git a/services/synthesis_service.py b/services/synthesis_service.py index e5b8e43..c251365 100644 --- a/services/synthesis_service.py +++ b/services/synthesis_service.py @@ -14,6 +14,7 @@ from openai import OpenAI from services.ranking_service import RankedEvidence +from openai import AsyncOpenAI logger = logging.getLogger(__name__) @@ -45,8 +46,8 @@ def __init__(self, api_key: Optional[str] = None, 
model: str = "gpt-4o"): raise ValueError("OpenAI API key is required but not provided") self.model = model - self.client = OpenAI(api_key=self.api_key) - + # Set a sane default timeout; endpoint also wraps with asyncio.wait_for + self.client = AsyncOpenAI(api_key=self.api_key, timeout=30.0) def _build_system_prompt(self) -> str: """Build the system prompt for the synthesis.""" return ( @@ -188,17 +189,20 @@ async def generate_answer(self, query: str, evidence_list: List[RankedEvidence], answer = response.choices[0].message.content # Calculate token usage - prompt_tokens = response.usage.prompt_tokens - completion_tokens = response.usage.completion_tokens - + prompt_tokens = getattr(response.usage, "prompt_tokens", 0) or 0 + completion_tokens = getattr(response.usage, "completion_tokens", 0) or 0 + # Extract citation IDs used_citation_ids = self._extract_citation_ids(answer) - + # Calculate generation time generation_ms = int((time.time() - start_time) * 1000) - - logger.info(f"Generated answer in {generation_ms}ms, {completion_tokens} tokens, {len(used_citation_ids)} citations") - + + logger.info( + "Generated answer in %dms, %d tokens, %d citations", + generation_ms, completion_tokens, len(used_citation_ids) + ) + return SynthesisResult( answer_markdown=answer, used_citation_ids=used_citation_ids, @@ -206,17 +210,16 @@ async def generate_answer(self, query: str, evidence_list: List[RankedEvidence], completion_tokens=completion_tokens, generation_ms=generation_ms ) - - except Exception as e: - logger.error(f"Error generating answer: {str(e)}") - + + except Exception: + logger.exception("Error generating answer") + # Return a fallback result fallback_answer = ( - f"I apologize, but I encountered an error while generating an answer to your question: \"{query}\"\n\n" - f"Error: {str(e)}\n\n" - f"Please try again with a different query or contact support if the problem persists." + "I encountered an internal error while generating the answer. 
" + "Please try again or contact support if the problem persists." ) - + return SynthesisResult( answer_markdown=fallback_answer, used_citation_ids=set(), diff --git a/tests/test_tavily_search.py b/tests/test_tavily_search.py index 0a6f31a..4ffbdd1 100644 --- a/tests/test_tavily_search.py +++ b/tests/test_tavily_search.py @@ -1,47 +1,22 @@ -import sys -import asyncio import os -from pathlib import Path - -# Import the WebSearchService +import asyncio +import pytest from services.web_search_service import WebSearchService -async def test_tavily_search(): - try: - # Initialize the search service with Tavily (not dummy) - search_service = WebSearchService(provider_name="tavily") - - # Perform a search - query = "WHat is the result of the asia cup match between ban vs sri today?" - print(f"Searching for: {query}") - - results = await search_service.search( - query=query, - k=5, - region="auto", - language="en", - timeout_seconds=15 - ) - - print(f"\nFound {len(results)} results:") - for idx, result in enumerate(results, 1): - print(f"\n--- Result {idx} ---") - print(f"Title: {result.title}") - print(f"URL: {result.url}") - print(f"Site: {result.site_name}") - if result.snippet: - snippet = result.snippet[:100] + "..." 
if len(result.snippet) > 100 else result.snippet - print(f"Snippet: {snippet}") - print(f"Score: {result.score}") - - assert len(results) > 0, "No search results returned" - return True - except Exception as e: - print(f"Error: {str(e)}") - import traceback - traceback.print_exc() - return False +@pytest.mark.asyncio +@pytest.mark.integration +async def test_tavily_search_returns_results(): + if not os.getenv("TAVILY_API_KEY"): + pytest.skip("TAVILY_API_KEY not set") + search_service = WebSearchService(provider_name="tavily") + results = await search_service.search( + query="OpenAI API documentation", + k=3, + region="auto", + language="en", + timeout_seconds=15, + ) + assert isinstance(results, list) + assert len(results) > 0 -if __name__ == "__main__": - success = asyncio.run(test_tavily_search()) - print(f"\nTest {'succeeded' if success else 'failed'}") \ No newline at end of file +# No __main__ runner in test files \ No newline at end of file From fad8a54844ec95d0a48083bdc634d24bc125290b Mon Sep 17 00:00:00 2001 From: Ebba Alva Date: Mon, 22 Sep 2025 23:02:06 +0600 Subject: [PATCH 3/3] Refactor citation handling and response building in web search answer endpoint for improved clarity and logging --- api/endpoints/web_answering.py | 87 +++++++++++++++------------------- 1 file changed, 38 insertions(+), 49 deletions(-) diff --git a/api/endpoints/web_answering.py b/api/endpoints/web_answering.py index 60cfe53..3894730 100644 --- a/api/endpoints/web_answering.py +++ b/api/endpoints/web_answering.py @@ -258,53 +258,42 @@ async def web_search_answer( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="No evidence-based answer could be produced within constraints." 
) - - # Create citations list, including only those actually used in the answer - citations = [] - for evidence in ranked_evidence: - if evidence.id in synthesis_result.used_citation_ids: - citations.append(Citation( - id=evidence.id, - url=evidence.url, - title=evidence.title, - site_name=evidence.site_name, - published_at=evidence.published_at, - snippet=evidence.passage[:200] + "..." if len(evidence.passage) > 200 else evidence.passage, - score=evidence.score - )) - - # Sort citations by ID for consistency - citations.sort(key=lambda c: c.id) - - # Build response - response = WebSearchAnswerResponse( - query=request.query, - answer_markdown=synthesis_result.answer_markdown, - citations=citations, - used_sources_count=len(citations), - timings_ms=Timings(**timings), - meta=Meta( - engine=settings.web_search_engine, - region=request.region, - language=request.language, - style_profile_id=request.style_profile_id - ) - ) - - logger.info( - f"Answer generated successfully in {timings['total']}ms with {len(citations)} citations", - extra=logger_ctx + + # Create citations list, including only those actually used in the answer + citations = [] + for evidence in ranked_evidence: + if evidence.id in synthesis_result.used_citation_ids: + citations.append(Citation( + id=evidence.id, + url=evidence.url, + title=evidence.title, + site_name=evidence.site_name, + published_at=evidence.published_at, + snippet=evidence.passage[:200] + "..." 
if len(evidence.passage) > 200 else evidence.passage, + score=evidence.score + )) + + # Sort citations by ID for consistency + citations.sort(key=lambda c: c.id) + + # Build response + response = WebSearchAnswerResponse( + query=request.query, + answer_markdown=synthesis_result.answer_markdown, + citations=citations, + used_sources_count=len(citations), + timings_ms=Timings(**timings), + meta=Meta( + engine=settings.web_search_engine, + region=request.region, + language=request.language, + style_profile_id=request.style_profile_id ) - - return response - - except HTTPException: - # Re-raise HTTP exceptions - raise - except Exception as e: - # Log error and return 500 - logger.exception(f"Error processing web search answer", extra=logger_ctx) - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Internal server error" - ) \ No newline at end of file + ) + + logger.info( + f"Answer generated successfully in {timings['total']}ms with {len(citations)} citations", + extra=logger_ctx + ) + + return response \ No newline at end of file