From 56c90a2ea4796fda906a604126f213c872f0248e Mon Sep 17 00:00:00 2001 From: Max Enfayz Date: Sat, 20 Sep 2025 21:18:26 +0000 Subject: [PATCH 1/5] Update environment configuration and enhance API routing - Added web search configuration variables to .env.example - Updated .gitignore to include a new line for clarity - Refactored import statement in indexing_router.py for consistency - Included web answering router in main.py for improved routing structure - Updated requirements.txt to include trafilatura and ensure httpx uses HTTP/2 --- .env.example | 8 +++++++- .gitignore | 1 + api/indexing_router.py | 2 +- api/main.py | 2 ++ requirements.txt | 6 ++++-- 5 files changed, 15 insertions(+), 4 deletions(-) diff --git a/.env.example b/.env.example index 6b1ab04..5d0055f 100644 --- a/.env.example +++ b/.env.example @@ -4,4 +4,10 @@ VECTOR_DB_ENV= POSTGRES_URI= SECRET_KEY= # Comma-separated list of allowed CORS origins, e.g. http://localhost:3000,https://yourdomain.com -CORS_ALLOW_ORIGINS= \ No newline at end of file +CORS_ALLOW_ORIGINS= + +# Web search configuration +WEB_SEARCH_ENGINE=tavily +TAVILY_API_KEY=your_tavily_api_key_here +MAX_FETCH_CONCURRENCY=4 +DEFAULT_TOP_K_RESULTS=8 \ No newline at end of file diff --git a/.gitignore b/.gitignore index 583d9e7..c936975 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,7 @@ authormaton/ experimentalCode/.env .env + # Ignore Python cache __pycache__/ diff --git a/api/indexing_router.py b/api/indexing_router.py index e2861c8..58b7e08 100644 --- a/api/indexing_router.py +++ b/api/indexing_router.py @@ -4,7 +4,7 @@ from fastapi import APIRouter, HTTPException, status, Request from pydantic import BaseModel from config.settings import settings -from services.vector_db_service import VectorDBService +from services.vector_db_service import VectorDBClient as VectorDBService from services.embedding_service import embed_texts_batched from services.chunking_service import chunk_text from services.parsing_service import extract_text_from_pdf, extract_text_from_docx diff --git a/api/main.py b/api/main.py index af0c939..25f60a1 100644 --- a/api/main.py +++ b/api/main.py @@ -46,8 +46,10 @@ def read_root(): # Register routers from api.endpoints.upload import router as upload_router from api.endpoints.internal import router as internal_router +from api.endpoints.web_answering import router as web_answering_router app.include_router(upload_router, prefix="/upload") app.include_router(internal_router) +app.include_router(web_answering_router, prefix="/internal", tags=["websearch"]) app.include_router(indexing_router) @app.get("/health") diff --git a/requirements.txt b/requirements.txt index 4e168c8..06d7131 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,7 +4,7 @@ python-multipart>=0.0.6,<1.0.0 fastapi>=0.110.0,<1.0.0 uvicorn[standard]>=0.29.0,<1.0.0 pytest>=8.2.0,<9.0.0 -httpx>=0.27.0,<1.0.0 +httpx[http2]>=0.27.0,<1.0.0 python-dotenv>=1.0.0,<2.0.0 PyPDF2>=3.0.0,<4.0.0 requests>=2.31.0,<3.0.0 @@ -13,4 +13,6 @@ pinecone-client>=3.0.0,<4.0.0 weaviate-client>=4.4.0,<5.0.0 transformers>=4.40.0,<5.0.0 torch>=2.2.0,<3.0.0 -pydantic>=2.6.0,<3.0.0 \ No newline at end of file +pydantic>=2.6.0,<3.0.0 +trafilatura>=1.6.0,<2.0.0 +numpy>=1.26.0,<2.0.0 \ No newline at end of file From 6a0e924393721a8df347d68cbb65776a43a62385 Mon Sep 17 00:00:00 2001 From: Payton Zuniga Date: Sat, 20 Sep 2025 21:25:11 +0000 Subject: [PATCH 2/5] Add web search service with Tavily and dummy providers; enhance settings for search configuration --- config/settings.py | 10 +- 
services/web_search_service.py | 298 +++++++++++++++++++++++++++++++++ 2 files changed, 307 insertions(+), 1 deletion(-) create mode 100644 services/web_search_service.py diff --git a/config/settings.py b/config/settings.py index d5cea49..6ccc3ad 100644 --- a/config/settings.py +++ b/config/settings.py @@ -5,7 +5,8 @@ """ import os from pydantic_settings import BaseSettings -from pydantic import SecretStr, ValidationError +from pydantic import SecretStr, ValidationError, Field +from typing import Optional import sys try: from dotenv import load_dotenv @@ -24,6 +25,13 @@ class Settings(BaseSettings): embedding_dimension: int = 3072 embed_batch_size: int = 128 max_upload_mb: int = 25 + + # Web search settings + web_search_engine: str = os.environ.get("WEB_SEARCH_ENGINE", "dummy") # Default to dummy provider if not specified + tavily_api_key: Optional[SecretStr] = None + bing_api_key: Optional[SecretStr] = None + max_fetch_concurrency: int = 4 + default_top_k_results: int = 8 try: settings = Settings() diff --git a/services/web_search_service.py b/services/web_search_service.py new file mode 100644 index 0000000..4c0ac51 --- /dev/null +++ b/services/web_search_service.py @@ -0,0 +1,298 @@ +""" +Service for performing web searches via different providers. +Currently supports Tavily, with provider-agnostic wrapper. +""" + +from __future__ import annotations + +import asyncio +import logging +import random +import time +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Protocol, Type, ClassVar, Mapping + +import httpx +from config.settings import settings + +logger = logging.getLogger(__name__) + +@dataclass +class SearchResult: + """Represents a search result from any provider.""" + url: str + title: Optional[str] = None + site_name: Optional[str] = None + snippet: Optional[str] = None + published_at: Optional[str] = None # ISO format date string + score: Optional[float] = None + provider_meta: Dict = field(default_factory=dict) + +class SearchProvider(ABC): + """Abstract base class for search providers.""" + + @abstractmethod + async def search(self, query: str, k: int, region: str, language: str, timeout_seconds: int) -> List[SearchResult]: + """ + Perform a search with the given parameters. + + Args: + query: The search query + k: Number of results to return + region: Region code (e.g., "us", "eu", "auto") + language: Language code (e.g., "en", "fr", "de") + timeout_seconds: Timeout in seconds + + Returns: + List of SearchResult objects + """ + pass + +class DummySearchProvider(SearchProvider): + """Dummy search provider for testing or when no API keys are available.""" + + async def search(self, query: str, k: int, region: str, language: str, timeout_seconds: int) -> List[SearchResult]: + """ + Return dummy search results for testing. + + Args: + query: The search query + k: Number of results to return + region: Region code (ignored) + language: Language code (ignored) + timeout_seconds: Timeout in seconds (ignored) + + Returns: + List of dummy SearchResult objects + """ + # Create k dummy results + results = [] + + for i in range(min(k, 5)): # Cap at 5 results + result = SearchResult( + url=f"https://example.com/result-{i+1}", + title=f"Dummy Result {i+1} for '{query}'", + site_name="Example.com", + snippet=f"This is a dummy search result #{i+1} for the query: '{query}'. 
" + f"This is placeholder text and does not contain real information.", + published_at="2025-09-01T12:00:00Z", + score=1.0 - (i * 0.1), + provider_meta={"provider": "dummy", "result_id": i+1} + ) + results.append(result) + + # Simulate network delay + await asyncio.sleep(0.5) + + logger.warning(f"Using DummySearchProvider for query: {query}. Configure a real search provider for production.") + return results + +class TavilySearchProvider(SearchProvider): + """Tavily search provider implementation.""" + + def __init__(self, api_key: Optional[str] = None): + """ + Initialize with Tavily API key. + + Args: + api_key: Tavily API key (defaults to settings.tavily_api_key) + """ + self.api_key = api_key or settings.tavily_api_key.get_secret_value() if settings.tavily_api_key else None + if not self.api_key: + raise ValueError("Tavily API key is required but not provided in settings or constructor") + + # Tavily API configuration + self.api_url = "https://api.tavily.com/search" + + async def search(self, query: str, k: int, region: str, language: str, timeout_seconds: int) -> List[SearchResult]: + """ + Perform a search using Tavily API. + + Args: + query: The search query + k: Number of results to return (max 15) + region: Region hint (Tavily handles this internally) + language: Language code + timeout_seconds: Timeout in seconds + + Returns: + List of SearchResult objects + """ + if not query: + raise ValueError("Query cannot be empty") + + # Clamp k between 3 and 15 + k = max(3, min(15, k)) + + params = { + "query": query, + "search_depth": "advanced", + "include_domains": [], + "exclude_domains": [], + "max_results": k, + "include_answer": False, + "include_raw_content": False, # We'll fetch the content separately + } + + # Add language if not auto + if language and language.lower() != "auto": + params["language"] = language + + # Create HTTP client with timeout and retry logic + async with httpx.AsyncClient(timeout=timeout_seconds) as client: + max_retries = 3 + + # Set headers with API key + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {self.api_key}" + } + + for attempt in range(max_retries): + try: + response = await client.post(self.api_url, json=params, headers=headers) + response.raise_for_status() + break + except httpx.HTTPStatusError as e: + if e.response.status_code == 429: # Rate limit + if attempt < max_retries - 1: + wait_time = (2 ** attempt) + random.random() + logger.warning(f"Rate limited by Tavily. Retrying in {wait_time:.2f}s") + await asyncio.sleep(wait_time) + else: + logger.error(f"Tavily search failed after {max_retries} attempts: rate limited") + raise + else: + logger.error(f"Tavily search failed with status {e.response.status_code}: {e.response.text}") + raise + except httpx.RequestError as e: + logger.error(f"Tavily request failed: {str(e)}") + if attempt < max_retries - 1: + wait_time = (2 ** attempt) + random.random() + logger.warning(f"Request error. 
Retrying in {wait_time:.2f}s") + await asyncio.sleep(wait_time) + else: + raise + + data = response.json() + results = data.get("results", []) + + search_results = [] + for idx, result in enumerate(results): + search_result = SearchResult( + url=result.get("url", ""), + title=result.get("title"), + site_name=result.get("domain"), + snippet=result.get("description"), + published_at=result.get("published_date"), + # Use position as score if not provided + score=result.get("relevance_score", 1.0 - (idx / (len(results) or 1))), + provider_meta={"tavily_id": result.get("id")} if "id" in result else {} + ) + search_results.append(search_result) + + return search_results + +class WebSearchService: + """ + Service for performing web searches across different providers. + Uses the provider specified in settings.web_search_engine. + """ + + # Registry of available providers + _providers: ClassVar[Dict[str, Type[SearchProvider]]] = { + "tavily": TavilySearchProvider, + "dummy": DummySearchProvider, + } + + # LRU cache for search results (module-level, simple implementation) + _cache: ClassVar[Dict[str, tuple[float, List[SearchResult]]]] = {} + _cache_ttl: ClassVar[int] = 300 # 5 minutes in seconds + _cache_max_size: ClassVar[int] = 100 + + def __init__(self, provider_name: Optional[str] = None): + """ + Initialize the web search service. + + Args: + provider_name: Name of the provider to use (defaults to settings.web_search_engine) + """ + self.provider_name = provider_name or settings.web_search_engine + + # If no provider name is set, use the dummy provider + if not self.provider_name or self.provider_name == "none": + logger.warning("No search provider specified. Using DummySearchProvider for development.") + self.provider_name = "dummy" + self.provider = DummySearchProvider() + return + + if self.provider_name not in self._providers: + raise ValueError(f"Unsupported search provider: {self.provider_name}") + + try: + self.provider = self._providers[self.provider_name]() + logger.info(f"Initialized WebSearchService with provider: {self.provider_name}") + except Exception as e: + logger.error(f"Failed to initialize {self.provider_name} search provider: {str(e)}") + logger.warning("Falling back to DummySearchProvider due to initialization failure.") + self.provider_name = "dummy" + self.provider = DummySearchProvider() + + async def search(self, query: str, k: int = None, region: str = "auto", + language: str = "en", timeout_seconds: int = 15, + use_cache: bool = True) -> List[SearchResult]: + """ + Perform a web search using the configured provider. 
+ + Args: + query: The search query + k: Number of results to return (defaults to settings.default_top_k_results) + region: Region code or "auto" + language: Language code (default "en") + timeout_seconds: Timeout in seconds + use_cache: Whether to use the cache + + Returns: + List of SearchResult objects + """ + if not query: + return [] + + # Use default from settings if not specified + k = k or settings.default_top_k_results + + # Check cache first if enabled + cache_key = f"{self.provider_name}:{query}:{k}:{region}:{language}" + if use_cache and cache_key in self._cache: + timestamp, results = self._cache[cache_key] + if time.time() - timestamp <= self._cache_ttl: + logger.debug(f"Cache hit for query: {query}") + return results + + # Cache miss or disabled, perform actual search + try: + start_time = time.time() + results = await self.provider.search( + query=query, + k=k, + region=region, + language=language, + timeout_seconds=timeout_seconds + ) + duration_ms = (time.time() - start_time) * 1000 + logger.info(f"Search completed in {duration_ms:.0f}ms. Query: {query}, Provider: {self.provider_name}, Results: {len(results)}") + + # Cache results if caching is enabled + if use_cache: + # If cache is full, remove the oldest entry + if len(self._cache) >= self._cache_max_size: + oldest_key = min(self._cache.keys(), key=lambda k: self._cache[k][0]) + self._cache.pop(oldest_key) + + self._cache[cache_key] = (time.time(), results) + + return results + except Exception as e: + logger.error(f"Search failed: {str(e)}") + raise \ No newline at end of file From e01739f4c1f546e3c6b898837e64aaa3ecaaa338 Mon Sep 17 00:00:00 2001 From: Payton Zuniga Date: Sat, 20 Sep 2025 21:35:15 +0000 Subject: [PATCH 3/5] Implement WebFetchService for concurrent web content fetching with error handling and HTML extraction --- services/web_fetch_service.py | 274 ++++++++++++++++++++++++++++++++++ 1 file changed, 274 insertions(+) create mode 100644 services/web_fetch_service.py diff --git a/services/web_fetch_service.py b/services/web_fetch_service.py new file mode 100644 index 0000000..6e5f15c --- /dev/null +++ b/services/web_fetch_service.py @@ -0,0 +1,274 @@ +""" +Service for fetching and extracting content from web pages. +Concurrently fetches pages with semaphore-based rate limiting. +""" + +from __future__ import annotations + +import asyncio +import logging +import re +import time +from dataclasses import dataclass +from typing import Dict, List, Optional, Set +from urllib.parse import urlparse + +import httpx +from config.settings import settings + +logger = logging.getLogger(__name__) + +# Import trafilatura if available, otherwise provide a fallback +try: + import trafilatura + TRAFILATURA_AVAILABLE = True +except ImportError: + TRAFILATURA_AVAILABLE = False + logger.warning("trafilatura package not found, using simple HTML extraction fallback") + +@dataclass +class FetchedDoc: + """Represents a fetched web document.""" + url: str + title: Optional[str] = None + site_name: Optional[str] = None + text: str = "" + published_at: Optional[str] = None # ISO format date string + fetch_ms: int = 0 # Time taken to fetch in milliseconds + +class WebFetchService: + """ + Service for fetching and extracting content from web pages. + Uses asyncio for concurrent fetching with rate limiting via semaphore. + """ + + def __init__(self, max_concurrency: Optional[int] = None): + """ + Initialize the web fetch service. 
+
+        Args:
+            max_concurrency: Maximum number of concurrent requests
+                (defaults to settings.max_fetch_concurrency)
+        """
+        self.max_concurrency = max_concurrency or settings.max_fetch_concurrency
+        self.semaphore = asyncio.Semaphore(self.max_concurrency)
+
+        # Common HTML headers for the fetch requests
+        self.headers = {
+            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
+            "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
+            "Accept-Language": "en-US,en;q=0.5",
+            "Accept-Encoding": "gzip, deflate, br",
+            "Connection": "keep-alive",
+            "Upgrade-Insecure-Requests": "1",
+            "Sec-Fetch-Dest": "document",
+            "Sec-Fetch-Mode": "navigate",
+            "Sec-Fetch-Site": "none",
+            "Sec-Fetch-User": "?1",
+            "DNT": "1",
+        }
+
+    def _extract_site_name(self, url: str) -> str:
+        """Extract site name from URL."""
+        try:
+            parsed = urlparse(url)
+            domain = parsed.netloc
+            # Remove www. prefix if present
+            if domain.startswith("www."):
+                domain = domain[4:]
+            return domain
+        except Exception:
+            return ""
+
+    def _sanitize_html_fallback(self, html: str) -> str:
+        """
+        Simple HTML sanitizer fallback when trafilatura is not available.
+        Removes HTML tags, scripts, styles, and excessive whitespace.
+        """
+        # Remove scripts and styles
+        html = re.sub(r'<script[^>]*>.*?</script>', ' ', html, flags=re.DOTALL)
+        html = re.sub(r'<style[^>]*>.*?</style>', ' ', html, flags=re.DOTALL)
+
+        # Remove all tags
+        html = re.sub(r'<[^>]+>', ' ', html)
+
+        # Replace entities
+        html = re.sub(r'&nbsp;', ' ', html)
+        html = re.sub(r'&amp;', '&', html)
+        html = re.sub(r'&lt;', '<', html)
+        html = re.sub(r'&gt;', '>', html)
+        html = re.sub(r'&quot;', '"', html)
+        html = re.sub(r'&#\d+;', ' ', html)
+
+        # Normalize whitespace
+        html = re.sub(r'\s+', ' ', html)
+
+        return html.strip()
+
+    def _extract_title(self, html: str) -> Optional[str]:
+        """Extract title from HTML."""
+        title_match = re.search(r'<title[^>]*>(.*?)</title>', html, re.IGNORECASE | re.DOTALL)
+        if title_match:
+            title = title_match.group(1).strip()
+            # Clean up title
+            title = re.sub(r'\s+', ' ', title)
+            return title
+        return None
+
+    def _extract_text_from_html(self, html: str) -> str:
+        """
+        Extract readable text from HTML using trafilatura if available,
+        otherwise fallback to simple regex-based extraction.
+        """
+        if not html:
+            return ""
+
+        if TRAFILATURA_AVAILABLE:
+            try:
+                text = trafilatura.extract(html, include_comments=False, include_tables=False,
+                                           favor_precision=True, include_formatting=False)
+                if text:
+                    return text
+                # If trafilatura returns None, fall back to simple extraction
+                logger.warning("trafilatura extraction failed, falling back to simple extraction")
+            except Exception as e:
+                logger.warning(f"trafilatura extraction error: {e}, falling back to simple extraction")
+
+        # Fallback: simple HTML tag removal
+        return self._sanitize_html_fallback(html)
+
+    async def _fetch_url(self, url: str, timeout_seconds: int = 10) -> FetchedDoc:
+        """
+        Fetch a single URL and extract its content.
+ + Args: + url: The URL to fetch + timeout_seconds: Timeout in seconds + + Returns: + FetchedDoc object with extracted content + """ + async with self.semaphore: + start_time = time.time() + site_name = self._extract_site_name(url) + + # Initialize with empty/default values + fetched_doc = FetchedDoc( + url=url, + site_name=site_name, + fetch_ms=0 + ) + + try: + async with httpx.AsyncClient(timeout=timeout_seconds, follow_redirects=True) as client: + response = await client.get(url, headers=self.headers) + response.raise_for_status() + + html = response.text + + # Extract title if not already provided + title = self._extract_title(html) + + # Extract text content + text = self._extract_text_from_html(html) + + # Update the fetched document + fetched_doc.title = title + fetched_doc.text = text + + # Calculate fetch time + fetch_ms = int((time.time() - start_time) * 1000) + fetched_doc.fetch_ms = fetch_ms + + logger.info(f"Fetched {url} in {fetch_ms}ms, extracted {len(text)} chars") + return fetched_doc + + except httpx.HTTPStatusError as e: + status = e.response.status_code + logger.warning(f"HTTP error {status} fetching {url}: {str(e)}") + except httpx.RequestError as e: + logger.warning(f"Request error fetching {url}: {str(e)}") + except Exception as e: + logger.warning(f"Error fetching {url}: {str(e)}") + + # If we got here, there was an error + fetch_ms = int((time.time() - start_time) * 1000) + fetched_doc.fetch_ms = fetch_ms + logger.warning(f"Failed to fetch {url} after {fetch_ms}ms") + return fetched_doc + + async def fetch_urls(self, urls: List[str], timeout_seconds: int = 10) -> List[FetchedDoc]: + """ + Fetch multiple URLs concurrently. + + Args: + urls: List of URLs to fetch + timeout_seconds: Timeout in seconds per request + + Returns: + List of FetchedDoc objects + """ + if not urls: + return [] + + # Create fetch tasks for all URLs + tasks = [self._fetch_url(url, timeout_seconds) for url in urls] + + # Execute all tasks concurrently + try: + results = await asyncio.gather(*tasks, return_exceptions=True) + except Exception as e: + logger.error(f"Error in fetch_urls: {str(e)}") + return [] + + # Filter out exceptions and empty results + fetched_docs = [] + for result in results: + if isinstance(result, Exception): + logger.warning(f"Exception during fetch: {str(result)}") + elif isinstance(result, FetchedDoc) and result.text: + fetched_docs.append(result) + + logger.info(f"Fetched {len(fetched_docs)}/{len(urls)} URLs successfully") + return fetched_docs + + async def fetch_search_results(self, search_results: List, timeout_seconds: int = 10, + preserve_snippets: bool = True) -> List[FetchedDoc]: + """ + Fetch content for search results. 
+ + Args: + search_results: List of SearchResult objects + timeout_seconds: Timeout in seconds per request + preserve_snippets: Whether to use snippets as fallback when fetch fails + + Returns: + List of FetchedDoc objects + """ + # Extract URLs from search results + urls = [result.url for result in search_results if result.url] + + # Create a mapping of URL to search result for later use + url_to_result = {result.url: result for result in search_results if result.url} + + # Fetch all URLs + fetched_docs = await self.fetch_urls(urls, timeout_seconds) + + # If preserve_snippets is True, create FetchedDocs for failed fetches using snippets + if preserve_snippets: + fetched_urls = {doc.url for doc in fetched_docs} + for url in urls: + if url not in fetched_urls and url in url_to_result: + result = url_to_result[url] + if result.snippet: + logger.info(f"Using snippet as fallback for {url}") + fetched_docs.append(FetchedDoc( + url=url, + title=result.title, + site_name=result.site_name or self._extract_site_name(url), + text=result.snippet, + published_at=result.published_at, + fetch_ms=0 # Indicate that this wasn't actually fetched + )) + + return fetched_docs \ No newline at end of file From b8664ff1d193a752bb5eb021143511e5c2664a75 Mon Sep 17 00:00:00 2001 From: Tahbit Fehran Date: Sun, 21 Sep 2025 06:55:52 -0600 Subject: [PATCH 4/5] Refactor HTTP headers in WebFetchService to simplify configuration and allow httpx to manage encoding automatically --- services/web_fetch_service.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/services/web_fetch_service.py b/services/web_fetch_service.py index 6e5f15c..c2515fa 100644 --- a/services/web_fetch_service.py +++ b/services/web_fetch_service.py @@ -57,15 +57,8 @@ def __init__(self, max_concurrency: Optional[int] = None): self.headers = { "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36", "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8", - "Accept-Language": "en-US,en;q=0.5", - "Accept-Encoding": "gzip, deflate, br", - "Connection": "keep-alive", - "Upgrade-Insecure-Requests": "1", - "Sec-Fetch-Dest": "document", - "Sec-Fetch-Mode": "navigate", - "Sec-Fetch-Site": "none", - "Sec-Fetch-User": "?1", - "DNT": "1", + "Accept-Language": "en-US,en;q=0.5" + # Let httpx handle Accept-Encoding and compression automatically } def _extract_site_name(self, url: str) -> str: From b45c8da4dc065e98ff7176bd6b335986bbeb8923 Mon Sep 17 00:00:00 2001 From: Tahbit Fehran Date: Sun, 21 Sep 2025 07:01:06 -0600 Subject: [PATCH 5/5] Add SSRF protection to WebFetchService by validating URLs before fetching --- services/web_fetch_service.py | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/services/web_fetch_service.py b/services/web_fetch_service.py index c2515fa..9604b4d 100644 --- a/services/web_fetch_service.py +++ b/services/web_fetch_service.py @@ -6,6 +6,7 @@ from __future__ import annotations import asyncio +import ipaddress import logging import re import time @@ -73,6 +74,33 @@ def _extract_site_name(self, url: str) -> str: except Exception: return "" + def _is_url_allowed(self, url: str) -> bool: + """ + Check if a URL is safe to fetch (SSRF protection). 
+ + Args: + url: The URL to validate + + Returns: + True if URL is safe to fetch, False otherwise + """ + try: + p = urlparse(url) + if p.scheme not in ("http", "https"): + return False + if not p.hostname or p.username or p.password: + return False + try: + ip = ipaddress.ip_address(p.hostname) + if not ip.is_global: + return False + except ValueError: + # Hostname; DNS resolution checks can be added later if needed. + pass + return True + except Exception: + return False + def _sanitize_html_fallback(self, html: str) -> str: """ Simple HTML sanitizer fallback when trafilatura is not available. @@ -152,6 +180,12 @@ async def _fetch_url(self, url: str, timeout_seconds: int = 10) -> FetchedDoc: fetch_ms=0 ) + # SSRF guard + if not self._is_url_allowed(url): + fetched_doc.fetch_ms = int((time.time() - start_time) * 1000) + logger.warning("Blocked potentially unsafe URL: %s", url) + return fetched_doc + try: async with httpx.AsyncClient(timeout=timeout_seconds, follow_redirects=True) as client: response = await client.get(url, headers=self.headers)
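
For reference, a minimal usage sketch of how the two services introduced above compose end to end. It is illustrative only and not part of the patch series: the module paths, class names, and method signatures are taken from the diffs in patches 2/5 and 3/5, while the answer_from_web helper, the example query, and the printed output are hypothetical, and it assumes the services package is importable from the repository root.

# Illustrative sketch (not part of the patches): search the web, then fetch and
# extract page text, using the services added in patches 2/5 and 3/5.
import asyncio

from services.web_search_service import WebSearchService
from services.web_fetch_service import WebFetchService


async def answer_from_web(query: str):
    # Provider comes from settings.web_search_engine (falls back to the dummy provider).
    search_service = WebSearchService()
    # Concurrency comes from settings.max_fetch_concurrency.
    fetch_service = WebFetchService()

    # 1) Find candidate sources for the query.
    results = await search_service.search(query, k=5)

    # 2) Fetch and extract readable text, keeping snippets when a fetch fails.
    docs = await fetch_service.fetch_search_results(results, preserve_snippets=True)

    return [(doc.url, doc.title, len(doc.text)) for doc in docs]


if __name__ == "__main__":
    for url, title, n_chars in asyncio.run(answer_from_web("example query")):
        print(f"{url} | {title} | {n_chars} chars extracted")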