diff --git a/api/endpoints/web_answering.py b/api/endpoints/web_answering.py
new file mode 100644
index 0000000..3894730
--- /dev/null
+++ b/api/endpoints/web_answering.py
@@ -0,0 +1,299 @@
+"""
+Web search and answering endpoint for Perplexity-style answers with citations.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import logging
+import os
+import secrets
+import time
+import uuid
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, HTTPException, Request, status
+from fastapi.security import APIKeyHeader
+from pydantic import BaseModel, field_validator
+
+from config.settings import settings
+from services.web_search_service import WebSearchService
+from services.web_fetch_service import WebFetchService
+from services.ranking_service import RankingService
+from services.synthesis_service import SynthesisService
+
+logger = logging.getLogger(__name__)
+
+# Reuse internal API key authentication
+INTERNAL_API_KEY = os.environ.get("INTERNAL_API_KEY")
+api_key_header = APIKeyHeader(name="X-Internal-API-Key", auto_error=False)
+
+# Create router
+router = APIRouter(tags=["websearch"])
+
+
+# Request model
+class WebSearchAnswerRequest(BaseModel):
+    query: str
+    top_k_results: int = 8
+    max_context_chars: int = 20000
+    region: str = "auto"
+    language: str = "en"
+    style_profile_id: Optional[str] = None
+    answer_tokens: int = 800
+    include_snippets: bool = True
+    timeout_seconds: int = 25
+
+    # Validators clamp out-of-range values instead of rejecting the request
+    @field_validator("top_k_results")
+    @classmethod
+    def validate_top_k_results(cls, v):
+        return max(3, min(15, v))  # Clamp between 3 and 15
+
+    @field_validator("max_context_chars")
+    @classmethod
+    def validate_max_context_chars(cls, v):
+        return max(1000, min(50000, v))  # Clamp between 1000 and 50000
+
+    @field_validator("answer_tokens")
+    @classmethod
+    def validate_answer_tokens(cls, v):
+        return max(100, min(2000, v))  # Clamp between 100 and 2000
+
+    @field_validator("timeout_seconds")
+    @classmethod
+    def validate_timeout_seconds(cls, v):
+        return max(5, min(60, v))  # Clamp between 5 and 60 seconds
+
+
+# Citation model
+class Citation(BaseModel):
+    id: int
+    url: str
+    title: Optional[str] = None
+    site_name: Optional[str] = None
+    published_at: Optional[str] = None
+    snippet: Optional[str] = None
+    score: float
+
+
+# Timings model
+class Timings(BaseModel):
+    search: int = 0
+    fetch: int = 0
+    rank: int = 0
+    generate: int = 0
+    total: int = 0
+
+
+# Metadata model
+class Meta(BaseModel):
+    engine: str
+    region: str
+    language: str
+    style_profile_id: Optional[str] = None
+
+
+# Response model
+class WebSearchAnswerResponse(BaseModel):
+    query: str
+    answer_markdown: str
+    citations: List[Citation]
+    used_sources_count: int
+    timings_ms: Timings
+    meta: Meta
+
+
+# Dependency for internal authentication
+def verify_internal_api_key(api_key: str = Depends(api_key_header)):
+    # Do not log or expose the secret
+    if not api_key or not INTERNAL_API_KEY or not secrets.compare_digest(api_key, INTERNAL_API_KEY):
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail="Invalid or missing internal API key."
+        )
+
+
+@router.post("/websearch/answer", response_model=WebSearchAnswerResponse, status_code=200)
+async def web_search_answer(
+    request: WebSearchAnswerRequest,
+    req: Request,
+    api_key: str = Depends(verify_internal_api_key)
+):
+    """
+    Perform a web search and generate an answer with citations.
+    """
+    start_time = time.time()
+
+    # Generate request ID
+    request_id = req.headers.get("X-Request-Id", str(uuid.uuid4()))
+    logger_ctx = {"request_id": request_id}
+
+    # Log request
+    logger.info(
+        f"Web search answer request: query='{request.query}', "
+        f"top_k={request.top_k_results}, timeout={request.timeout_seconds}s",
+        extra=logger_ctx
+    )
+
+    # Initialize timings
+    timings = {
+        "search": 0,
+        "fetch": 0,
+        "rank": 0,
+        "generate": 0,
+        "total": 0
+    }
+
+    # Initialize services
+    try:
+        search_service = WebSearchService()
+        fetch_service = WebFetchService()
+        ranking_service = RankingService()
+        synthesis_service = SynthesisService()
+    except Exception as e:
+        logger.error("Failed to initialize services: %s", e, extra=logger_ctx)
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail="Service initialization error."
+        )
+
+    # 1. Search phase
+    search_start = time.time()
+    try:
+        search_results = await asyncio.wait_for(
+            search_service.search(
+                query=request.query,
+                k=request.top_k_results,
+                region=request.region,
+                language=request.language,
+                timeout_seconds=min(request.timeout_seconds * 0.4, 10)  # Allocate 40% of timeout
+            ),
+            timeout=request.timeout_seconds * 0.4
+        )
+    except asyncio.TimeoutError:
+        logger.warning(f"Search timed out after {request.timeout_seconds * 0.4}s", extra=logger_ctx)
+        raise HTTPException(
+            status_code=status.HTTP_504_GATEWAY_TIMEOUT,
+            detail="Search phase timed out."
+        )
+    timings["search"] = int((time.time() - search_start) * 1000)
+
+    if not search_results:
+        logger.warning("No search results found", extra=logger_ctx)
+        raise HTTPException(
+            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+            detail="No search results found for the query."
+        )
+    logger.info(f"Search completed with {len(search_results)} results", extra=logger_ctx)
+
+    # 2. Fetch phase
+    fetch_start = time.time()
+    try:
+        fetched_docs = await asyncio.wait_for(
+            fetch_service.fetch_search_results(
+                search_results=search_results,
+                timeout_seconds=min(request.timeout_seconds * 0.3, 8),  # Allocate 30% of timeout
+                preserve_snippets=request.include_snippets
+            ),
+            timeout=request.timeout_seconds * 0.3
+        )
+    except asyncio.TimeoutError:
+        logger.warning(f"Fetch timed out after {request.timeout_seconds * 0.3}s", extra=logger_ctx)
+        # Continue with whatever we got
+        fetched_docs = []
+    timings["fetch"] = int((time.time() - fetch_start) * 1000)
+
+    if not fetched_docs and not request.include_snippets:
+        logger.warning("No documents fetched successfully", extra=logger_ctx)
+        raise HTTPException(
+            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+            detail="Failed to fetch content from search results."
+        )
+    logger.info(f"Fetch completed with {len(fetched_docs)} documents", extra=logger_ctx)
+
+    # 3. Ranking phase
+    rank_start = time.time()
+    try:
+        ranked_evidence = await asyncio.wait_for(
+            ranking_service.rank_documents(
+                query=request.query,
+                docs=fetched_docs,
+                max_context_chars=request.max_context_chars
+            ),
+            timeout=request.timeout_seconds * 0.15  # Allocate 15% of timeout
+        )
+    except asyncio.TimeoutError:
+        logger.warning(f"Ranking timed out after {request.timeout_seconds * 0.15}s", extra=logger_ctx)
+        raise HTTPException(
+            status_code=status.HTTP_504_GATEWAY_TIMEOUT,
+            detail="Ranking phase timed out."
+        )
+    timings["rank"] = int((time.time() - rank_start) * 1000)
+
+    if not ranked_evidence:
+        logger.warning("No evidence ranked for the query", extra=logger_ctx)
+        raise HTTPException(
+            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+            detail="No relevant evidence found for the query."
+        )
+    logger.info(f"Ranking completed with {len(ranked_evidence)} evidence passages", extra=logger_ctx)
+
+    # 4. Synthesis phase
+    generate_start = time.time()
+    try:
+        synthesis_result = await asyncio.wait_for(
+            synthesis_service.generate_answer(
+                query=request.query,
+                evidence_list=ranked_evidence,
+                answer_tokens=request.answer_tokens,
+                style_profile_id=request.style_profile_id
+            ),
+            timeout=request.timeout_seconds * 0.15  # Allocate remaining time
+        )
+    except asyncio.TimeoutError:
+        logger.warning(f"Synthesis timed out after {request.timeout_seconds * 0.15}s", extra=logger_ctx)
+        raise HTTPException(
+            status_code=status.HTTP_504_GATEWAY_TIMEOUT,
+            detail="Synthesis phase timed out."
+        )
+    timings["generate"] = int((time.time() - generate_start) * 1000)
+
+    # Calculate total time
+    timings["total"] = int((time.time() - start_time) * 1000)
+
+    # Check if any citations were used
+    if not synthesis_result.used_citation_ids:
+        logger.warning("No citations used in the answer", extra=logger_ctx)
+        raise HTTPException(
+            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+            detail="No evidence-based answer could be produced within constraints."
+        )
+
+    # Create citations list, including only those actually used in the answer
+    citations = []
+    for evidence in ranked_evidence:
+        if evidence.id in synthesis_result.used_citation_ids:
+            citations.append(Citation(
+                id=evidence.id,
+                url=evidence.url,
+                title=evidence.title,
+                site_name=evidence.site_name,
+                published_at=evidence.published_at,
+                snippet=evidence.passage[:200] + "..." if len(evidence.passage) > 200 else evidence.passage,
+                score=evidence.score
+            ))
+
+    # Sort citations by ID for consistency
+    citations.sort(key=lambda c: c.id)
+
+    # Build response
+    response = WebSearchAnswerResponse(
+        query=request.query,
+        answer_markdown=synthesis_result.answer_markdown,
+        citations=citations,
+        used_sources_count=len(citations),
+        timings_ms=Timings(**timings),
+        meta=Meta(
+            engine=settings.web_search_engine,
+            region=request.region,
+            language=request.language,
+            style_profile_id=request.style_profile_id
+        )
+    )
+
+    logger.info(
+        f"Answer generated successfully in {timings['total']}ms with {len(citations)} citations",
+        extra=logger_ctx
+    )
+
+    return response
\ No newline at end of file
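A quick way to exercise the endpoint above, assuming the router is mounted at the application root of a locally running instance. The base URL, port, and payload values below are illustrative assumptions, not taken from this diff; the X-Internal-API-Key header must match the INTERNAL_API_KEY environment variable the service reads.

```python
# Hypothetical smoke test for POST /websearch/answer.
# Assumes the API runs locally on port 8000 with the router mounted at the app root;
# adjust BASE_URL to the actual deployment.
import os

import httpx

BASE_URL = "http://localhost:8000"  # assumption, not taken from the diff

payload = {
    "query": "What is retrieval-augmented generation?",
    "top_k_results": 8,
    "answer_tokens": 600,
    "timeout_seconds": 25,
}

response = httpx.post(
    f"{BASE_URL}/websearch/answer",
    json=payload,
    headers={"X-Internal-API-Key": os.environ.get("INTERNAL_API_KEY", "")},
    timeout=30.0,
)
response.raise_for_status()
data = response.json()
print(data["answer_markdown"])
print([c["url"] for c in data["citations"]])
print(data["timings_ms"])
```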
+""" + +from __future__ import annotations + +import logging +import numpy as np +import time +from dataclasses import dataclass +from typing import List, Optional + +from services.chunking_service import chunk_text +from services.embedding_service import embed_texts +from services.web_fetch_service import FetchedDoc + +logger = logging.getLogger(__name__) + +@dataclass +class RankedEvidence: + """Represents a ranked evidence passage with metadata.""" + id: int # 1-based ID used for citations + url: str + title: Optional[str] = None + site_name: Optional[str] = None + passage: str = "" + score: float = 0.0 + published_at: Optional[str] = None # ISO format date string + +class RankingService: + """ + Service for ranking passages by relevance to a query. + Uses embedding-based similarity scoring. + """ + + def __init__(self, ideal_passage_length: int = 1000, overlap: int = 200): + """ + Initialize the ranking service. + + Args: + ideal_passage_length: Target length of passages in characters + overlap: Overlap between passages in characters + """ + self.ideal_passage_length = ideal_passage_length + self.overlap = overlap + + def _split_into_passages(self, doc: FetchedDoc) -> List[tuple[str, str, Optional[str], Optional[str], Optional[str]]]: + """ + Split a document into passages with metadata. + + Args: + doc: The document to split + + Returns: + List of tuples: (passage, url, title, site_name, published_at) + """ + if not doc.text: + return [] + + # Use chunking service to split the text + passages = chunk_text( + doc.text, + max_length=self.ideal_passage_length, + overlap=self.overlap + ) + + # Return passages with metadata + return [(p, doc.url, doc.title, doc.site_name, doc.published_at) for p in passages] + + def _compute_similarity(self, query_embedding: List[float], passage_embeddings: List[List[float]]) -> List[float]: + """ + Compute cosine similarity between query and passages. + + Args: + query_embedding: Query embedding vector + passage_embeddings: List of passage embedding vectors + + Returns: + List of similarity scores + """ + # Convert to numpy arrays for efficient computation + query_vec = np.array(query_embedding) + passage_vecs = np.array(passage_embeddings) + + # Normalize vectors with zero-norm guards + q_norm = np.linalg.norm(query_vec) + if q_norm == 0: + return [0.0] * len(passage_embeddings) + query_vec = query_vec / q_norm + p_norms = np.linalg.norm(passage_vecs, axis=1, keepdims=True) + # Avoid divide-by-zero for degenerate embeddings + p_norms[p_norms == 0] = 1.0 + passage_vecs = passage_vecs / p_norms + + # Compute cosine similarity + similarities = np.dot(passage_vecs, query_vec) + + return similarities.tolist() + + async def rank_documents(self, query: str, docs: List[FetchedDoc], + max_context_chars: int = 20000) -> List[RankedEvidence]: + """ + Split documents into passages, rank them by relevance to query, + and return top passages within context budget. 
+ + Args: + query: The search query + docs: List of fetched documents + max_context_chars: Maximum total character budget for evidence + + Returns: + List of RankedEvidence objects + """ + if not query or not docs: + return [] + + # Split all documents into passages + all_passages = [] + for doc in docs: + passages = self._split_into_passages(doc) + all_passages.extend(passages) + + if not all_passages: + logger.warning("No passages extracted from documents") + return [] + + # Unpack passages and metadata + passages = [p[0] for p in all_passages] + urls = [p[1] for p in all_passages] + titles = [p[2] for p in all_passages] + site_names = [p[3] for p in all_passages] + published_dates = [p[4] for p in all_passages] + + # Get embeddings for query and passages + start_time = time.time() + try: + # Embed query + query_embedding = embed_texts([query])[0] + + # Embed passages + passage_embeddings = embed_texts(passages) + + # Compute similarity scores + similarity_scores = self._compute_similarity(query_embedding, passage_embeddings) + except Exception as e: + logger.error(f"Error during embedding or similarity computation: {str(e)}") + # Fallback: assign decreasing scores based on original order + similarity_scores = [1.0 - (i / len(passages)) for i in range(len(passages))] + + # Create scored passages + scored_passages = [] + for i, (passage, url, title, site_name, published_at, score) in enumerate( + zip(passages, urls, titles, site_names, published_dates, similarity_scores) + ): + # Create RankedEvidence with 1-based ID + evidence = RankedEvidence( + id=i + 1, # 1-based ID for citations + url=url, + title=title, + site_name=site_name, + passage=passage, + score=score, + published_at=published_at + ) + scored_passages.append(evidence) + + # Sort by score (highest first) + scored_passages.sort(key=lambda p: p.score, reverse=True) + + # Select top passages within context budget + selected_passages = [] + current_budget = 0 + + for passage in scored_passages: + passage_length = len(passage.passage) + if current_budget + passage_length <= max_context_chars: + selected_passages.append(passage) + current_budget += passage_length + else: + # If we can't fit the full passage, check if we can fit a truncated version + remaining_budget = max_context_chars - current_budget + if remaining_budget >= 200: # Only truncate if we can fit a meaningful chunk + truncated_passage = passage.passage[:remaining_budget - 3] + "..." + passage.passage = truncated_passage + selected_passages.append(passage) + break + + # Re-assign IDs to ensure they're sequential + for i, passage in enumerate(selected_passages): + passage.id = i + 1 + + duration_ms = int((time.time() - start_time) * 1000) + logger.info(f"Ranked {len(all_passages)} passages in {duration_ms}ms, selected {len(selected_passages)} within budget") + + return selected_passages \ No newline at end of file diff --git a/services/synthesis_service.py b/services/synthesis_service.py new file mode 100644 index 0000000..c251365 --- /dev/null +++ b/services/synthesis_service.py @@ -0,0 +1,227 @@ +""" +Service for synthesizing answers from ranked evidence passages. +Generates Markdown answers with citations and references. 
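RankingService imports chunk_text and embed_texts from services that are not part of this diff. The sketch below is a hedged guess at the call contracts the ranking code relies on (chunk_text(text, max_length=..., overlap=...) returning a list of strings; embed_texts(texts) returning one vector per input), with deliberately naive stand-in bodies for local experimentation. The project's real chunking and embedding services are presumably more sophisticated.

```python
# Minimal stand-ins matching only the call signatures RankingService uses.
# These are illustrative sketches, NOT the project's actual implementations.
from typing import List


def chunk_text(text: str, max_length: int = 1000, overlap: int = 200) -> List[str]:
    """Naive sliding-window chunker over characters."""
    if not text:
        return []
    step = max(1, max_length - overlap)
    return [text[i:i + max_length] for i in range(0, len(text), step)]


def embed_texts(texts: List[str]) -> List[List[float]]:
    """Toy bag-of-characters embedding; one fixed-size vector per input text."""
    dim = 64
    vectors = []
    for text in texts:
        vec = [0.0] * dim
        for ch in text.lower():
            vec[ord(ch) % dim] += 1.0
        vectors.append(vec)
    return vectors
```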
+""" + +from __future__ import annotations + +import logging +import os +import re +import time +from dataclasses import dataclass +from typing import Dict, List, Optional, Set, Tuple, Any + +from openai import OpenAI +from services.ranking_service import RankedEvidence +from openai import AsyncOpenAI + +logger = logging.getLogger(__name__) + +@dataclass +class SynthesisResult: + """Result of synthesis with answer and metadata.""" + answer_markdown: str + used_citation_ids: Set[int] + prompt_tokens: int = 0 + completion_tokens: int = 0 + generation_ms: int = 0 + +class SynthesisService: + """ + Service for synthesizing answers from ranked evidence passages. + Generates answers with citation markers and reference lists. + """ + + def __init__(self, api_key: Optional[str] = None, model: str = "gpt-4o"): + """ + Initialize the synthesis service. + + Args: + api_key: OpenAI API key (defaults to OPENAI_API_KEY env var) + model: LLM model to use + """ + self.api_key = api_key or os.getenv("OPENAI_API_KEY") + if not self.api_key: + raise ValueError("OpenAI API key is required but not provided") + + self.model = model + # Set a sane default timeout; endpoint also wraps with asyncio.wait_for + self.client = AsyncOpenAI(api_key=self.api_key, timeout=30.0) + def _build_system_prompt(self) -> str: + """Build the system prompt for the synthesis.""" + return ( + "You are a factual technical writer. Your job is to answer questions based ONLY on the provided evidence. " + "Be accurate, clear, and concise. Never include information not supported by the evidence." + ) + + def _build_evidence_block(self, evidence_list: List[RankedEvidence]) -> str: + """ + Build a numbered evidence block with metadata. + + Args: + evidence_list: List of ranked evidence + + Returns: + Formatted evidence block + """ + if not evidence_list: + return "NO EVIDENCE PROVIDED." + + evidence_block = "# EVIDENCE\n\n" + + for evidence in evidence_list: + # Format title + title = evidence.title or "Untitled Document" + + # Format site info + site_info = f" ({evidence.site_name})" if evidence.site_name else "" + + # Format evidence block + evidence_block += f"[{evidence.id}] {title}{site_info}\n" + evidence_block += f"URL: {evidence.url}\n" + if evidence.published_at: + evidence_block += f"Published: {evidence.published_at}\n" + evidence_block += f"Excerpt: {evidence.passage}\n\n" + + return evidence_block + + def _build_user_prompt(self, query: str, evidence_list: List[RankedEvidence], + answer_tokens: int, style_profile_id: Optional[str]) -> str: + """ + Build the user prompt with query, evidence, and instructions. + + Args: + query: The query to answer + evidence_list: List of ranked evidence + answer_tokens: Target token count for the answer + style_profile_id: Optional style profile ID for tone/voice + + Returns: + Complete user prompt + """ + # Build evidence block + evidence_block = self._build_evidence_block(evidence_list) + + # Build instructions + instructions = ( + f"Question: {query}\n\n" + "Instructions:\n" + "1. Answer the question using ONLY the provided evidence.\n" + "2. Add citation markers [^i] after each sentence, where i is the evidence number.\n" + "3. If a claim lacks evidence, either mark it (citation needed) or omit it entirely.\n" + "4. End with a References section listing each source you actually cited:\n" + " [^i]: Title — URL (site), date\n" + f"5. Be concise yet thorough. 
Target length: {answer_tokens} tokens.\n" + ) + + # Add style profile if provided + if style_profile_id: + if style_profile_id == "academic": + instructions += "\nStyle: Academic. Formal, precise language with scholarly tone. Use technical terminology where appropriate.\n" + elif style_profile_id == "simple": + instructions += "\nStyle: Simple. Clear, straightforward language accessible to non-experts. Avoid jargon.\n" + elif style_profile_id == "journalist": + instructions += "\nStyle: Journalistic. Informative, engaging, with clear explanations. Prioritize key facts.\n" + elif style_profile_id == "technical": + instructions += "\nStyle: Technical. Detailed, precise, assuming domain knowledge. Include technical specifics.\n" + else: + instructions += f"\nStyle Profile: {style_profile_id}\n" + + # Combine everything + return f"{instructions}\n\n{evidence_block}" + + def _extract_citation_ids(self, text: str) -> Set[int]: + """ + Extract citation IDs from text with [^i] markers. + + Args: + text: Text with citation markers + + Returns: + Set of citation IDs + """ + # Find all [^i] citations + citation_markers = re.findall(r'\[\^(\d+)\]', text) + + # Convert to integers and return as a set + return {int(id) for id in citation_markers if id.isdigit()} + + async def generate_answer(self, query: str, evidence_list: List[RankedEvidence], + answer_tokens: int = 800, + style_profile_id: Optional[str] = None) -> SynthesisResult: + """ + Generate an answer from evidence with citation markers. + + Args: + query: The query to answer + evidence_list: List of ranked evidence + answer_tokens: Target token count for the answer + style_profile_id: Optional style profile ID + + Returns: + SynthesisResult with answer and metadata + """ + if not query or not evidence_list: + return SynthesisResult( + answer_markdown="No evidence available to answer this question.", + used_citation_ids=set() + ) + + # Build prompts + system_prompt = self._build_system_prompt() + user_prompt = self._build_user_prompt(query, evidence_list, answer_tokens, style_profile_id) + + start_time = time.time() + try: + # Call OpenAI API + response = self.client.chat.completions.create( + model=self.model, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt} + ], + temperature=0.2, + max_tokens=answer_tokens, + ) + + # Extract answer + answer = response.choices[0].message.content + + # Calculate token usage + prompt_tokens = getattr(response.usage, "prompt_tokens", 0) or 0 + completion_tokens = getattr(response.usage, "completion_tokens", 0) or 0 + + # Extract citation IDs + used_citation_ids = self._extract_citation_ids(answer) + + # Calculate generation time + generation_ms = int((time.time() - start_time) * 1000) + + logger.info( + "Generated answer in %dms, %d tokens, %d citations", + generation_ms, completion_tokens, len(used_citation_ids) + ) + + return SynthesisResult( + answer_markdown=answer, + used_citation_ids=used_citation_ids, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + generation_ms=generation_ms + ) + + except Exception: + logger.exception("Error generating answer") + + # Return a fallback result + fallback_answer = ( + "I encountered an internal error while generating the answer. " + "Please try again or contact support if the problem persists." 
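The prompt construction and citation bookkeeping in SynthesisService can be exercised offline, without calling the OpenAI API. A minimal sketch, assuming only what is defined in this diff; the API key below is a placeholder and no request is sent.

```python
# Offline exercise of SynthesisService prompt building and citation extraction.
# The api_key is a dummy placeholder; no network call is made here.
from services.ranking_service import RankedEvidence
from services.synthesis_service import SynthesisService

service = SynthesisService(api_key="sk-placeholder", model="gpt-4o")

evidence = [
    RankedEvidence(
        id=1,
        url="https://example.com/rag",
        title="Intro to RAG",
        site_name="example.com",
        passage="Retrieval-augmented generation combines search with generation.",
        score=0.91,
    )
]

prompt = service._build_user_prompt(
    query="What is retrieval-augmented generation?",
    evidence_list=evidence,
    answer_tokens=400,
    style_profile_id="simple",
)
print(prompt[:300])

# The same regex-based extractor the service applies to the model output:
sample_answer = "RAG pairs retrieval with generation. [^1]"
print(service._extract_citation_ids(sample_answer))  # -> {1}
```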
diff --git a/tests/test_tavily_search.py b/tests/test_tavily_search.py
new file mode 100644
index 0000000..4ffbdd1
--- /dev/null
+++ b/tests/test_tavily_search.py
@@ -0,0 +1,22 @@
+import os
+
+import pytest
+
+from services.web_search_service import WebSearchService
+
+
+@pytest.mark.asyncio
+@pytest.mark.integration
+async def test_tavily_search_returns_results():
+    if not os.getenv("TAVILY_API_KEY"):
+        pytest.skip("TAVILY_API_KEY not set")
+    search_service = WebSearchService(provider_name="tavily")
+    results = await search_service.search(
+        query="OpenAI API documentation",
+        k=3,
+        region="auto",
+        language="en",
+        timeout_seconds=15,
+    )
+    assert isinstance(results, list)
+    assert len(results) > 0
+
+# No __main__ runner in test files
\ No newline at end of file
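Alongside the provider-dependent integration test above, the ranking math can be covered by a purely local unit test. A sketch under the assumption that no such test exists yet; the file name and vectors are illustrative.

```python
# tests/test_ranking_similarity.py (hypothetical): offline unit test for the
# cosine-similarity helper; no search provider or API key required.
import pytest

from services.ranking_service import RankingService


def test_compute_similarity_orders_passages_as_expected():
    service = RankingService()
    query_embedding = [1.0, 0.0]
    passage_embeddings = [
        [1.0, 0.0],  # identical direction -> similarity 1.0
        [0.0, 1.0],  # orthogonal -> similarity 0.0
        [0.0, 0.0],  # zero vector -> guarded, similarity 0.0
    ]
    scores = service._compute_similarity(query_embedding, passage_embeddings)
    assert scores[0] == pytest.approx(1.0)
    assert scores[1] == pytest.approx(0.0)
    assert scores[2] == pytest.approx(0.0)
```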