301 changes: 301 additions & 0 deletions services/web_fetch_service.py
@@ -0,0 +1,301 @@
"""
Service for fetching and extracting content from web pages.
Concurrently fetches pages with semaphore-based rate limiting.
"""

from __future__ import annotations

import asyncio
import ipaddress
import logging
import re
import time
from dataclasses import dataclass
from typing import Dict, List, Optional, Set
from urllib.parse import urlparse

import httpx
from config.settings import settings

logger = logging.getLogger(__name__)

# Import trafilatura if available, otherwise provide a fallback
try:
    import trafilatura
    TRAFILATURA_AVAILABLE = True
except ImportError:
    TRAFILATURA_AVAILABLE = False
    logger.warning("trafilatura package not found, using simple HTML extraction fallback")

@dataclass
class FetchedDoc:
    """Represents a fetched web document."""
    url: str
    title: Optional[str] = None
    site_name: Optional[str] = None
    text: str = ""
    published_at: Optional[str] = None  # ISO format date string
    fetch_ms: int = 0  # Time taken to fetch in milliseconds

class WebFetchService:
    """
    Service for fetching and extracting content from web pages.
    Uses asyncio for concurrent fetching with rate limiting via semaphore.
    """

    def __init__(self, max_concurrency: Optional[int] = None):
        """
        Initialize the web fetch service.

        Args:
            max_concurrency: Maximum number of concurrent requests
                (defaults to settings.max_fetch_concurrency)
        """
        self.max_concurrency = max_concurrency or settings.max_fetch_concurrency
        self.semaphore = asyncio.Semaphore(self.max_concurrency)

Comment on lines +46 to +56
⚠️ Potential issue

Block SSRF via open redirects; disable env proxies; reuse a single AsyncClient with limits.

Currently follow_redirects=True allows 3xx hops to localhost/IMDS/etc, bypassing your pre-check. Also, creating a new client per request is wasteful and inherits proxy env by default. Reuse one AsyncClient, set trust_env=False, and manually validate each redirect hop.

Apply this diff:

@@
-from urllib.parse import urlparse
+from urllib.parse import urlparse, urljoin
@@ class WebFetchService:
-    def __init__(self, max_concurrency: Optional[int] = None):
+    def __init__(self, max_concurrency: Optional[int] = None):
@@
-        self.max_concurrency = max_concurrency or settings.max_fetch_concurrency
-        self.semaphore = asyncio.Semaphore(self.max_concurrency)
+        raw = max_concurrency or getattr(settings, "max_fetch_concurrency", 10)
+        self.max_concurrency = max(1, raw)
+        self.semaphore = asyncio.Semaphore(self.max_concurrency)
@@
-        # Common HTML headers for the fetch requests
+        # Common HTML headers for the fetch requests
         self.headers = {
             "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
             "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
             "Accept-Language": "en-US,en;q=0.5"
             # Let httpx handle Accept-Encoding and compression automatically
         }
+
+        # Reuse a single client; disable env proxies; bound pool size
+        self._limits = httpx.Limits(
+            max_connections=self.max_concurrency,
+            max_keepalive_connections=self.max_concurrency,
+        )
+        self._client = httpx.AsyncClient(
+            headers=self.headers,
+            http2=True,
+            limits=self._limits,
+            follow_redirects=False,  # we validate/handle redirects manually
+            trust_env=False,         # ignore HTTP(S)_PROXY/NO_PROXY
+        )
+        self._max_redirects = getattr(settings, "fetch_max_redirects", 5)
+        self._max_content_bytes = getattr(settings, "fetch_max_content_bytes", 2_000_000)
+
+    async def aclose(self) -> None:
+        await self._client.aclose()
@@
-            try:
-                async with httpx.AsyncClient(timeout=timeout_seconds, follow_redirects=True) as client:
-                    response = await client.get(url, headers=self.headers)
-                    response.raise_for_status()
-                    
-                    html = response.text
-                    
-                    # Extract title if not already provided
-                    title = self._extract_title(html)
-                    
-                    # Extract text content
-                    text = self._extract_text_from_html(html)
-                    
-                    # Update the fetched document
-                    fetched_doc.title = title
-                    fetched_doc.text = text
-                    
-                    # Calculate fetch time
-                    fetch_ms = int((time.time() - start_time) * 1000)
-                    fetched_doc.fetch_ms = fetch_ms
-                    
-                    logger.info(f"Fetched {url} in {fetch_ms}ms, extracted {len(text)} chars")
-                    return fetched_doc
+            try:
+                current_url = url
+                html = ""
+                # Manually follow and validate up to N redirects
+                for _ in range(self._max_redirects + 1):
+                    if not self._is_url_allowed(current_url):
+                        raise httpx.HTTPStatusError("Unsafe URL after redirect validation",
+                                                    request=None, response=None)
+                    # Stream to cap payload size
+                    async with self._client.stream("GET", current_url, timeout=timeout_seconds) as response:
+                        if 300 <= response.status_code < 400 and "location" in response.headers:
+                            next_url = urljoin(current_url, response.headers["location"])
+                            current_url = next_url
+                            continue
+                        response.raise_for_status()
+                        ctype = response.headers.get("content-type", "")
+                        if "html" not in ctype and "xml" not in ctype:
+                            logger.debug("Skipping non-HTML content: %s (%s)", current_url, ctype)
+                            break
+                        buf = bytearray()
+                        async for chunk in response.aiter_bytes():
+                            buf.extend(chunk)
+                            if len(buf) > self._max_content_bytes:
+                                logger.warning("Aborting %s: response exceeded %d bytes", current_url, self._max_content_bytes)
+                                break
+                    if buf:
+                        html = buf.decode("utf-8", errors="replace")
+                    break
+
+                # Extract title/text if we have HTML
+                if html:
+                    title = self._extract_title(html)
+                    text = self._extract_text_from_html(html)
+                    fetched_doc.title = title
+                    fetched_doc.text = text
+
+                # Calculate fetch time
+                fetch_ms = int((time.time() - start_time) * 1000)
+                fetched_doc.fetch_ms = fetch_ms
+                logger.debug("Fetched %s in %dms, extracted %d chars", current_url, fetch_ms, len(fetched_doc.text))
+                return fetched_doc

Also applies to: 15-16, 161-213

🤖 Prompt for AI Agents
In services/web_fetch_service.py around lines 46-56 (also applies to lines 15-16
and 161-213): the service currently creates a new HTTP client per request,
inherits proxy environment variables, and uses follow_redirects=True which
allows open-redirect SSRF to internal addresses; fix by instantiating a single
httpx.AsyncClient in __init__ with appropriate limits (httpx.Limits or
equivalent) and trust_env=False, set follow_redirects=False, and reuse that
client for all requests; implement manual redirect handling when a 3xx response
is returned: read the Location header, resolve and validate each redirect hop
against the same SSRF/internal-host checks used for the initial URL (reject
local/IMDS/private IPs), enforce a maximum redirect count, and only follow
validated redirects using the single client; ensure the client is properly
closed on service shutdown.
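If the single-client variant from the diff above is adopted, the service also needs an explicit shutdown hook so pooled connections are released. A minimal sketch follows, assuming the aclose() coroutine introduced in the suggested diff; the asynccontextmanager-based lifespan wrapper shown here is illustrative, not part of this PR.

# Minimal sketch, assuming the reused-AsyncClient variant with aclose() from the
# suggested diff is applied; the lifespan helper itself is hypothetical.
from contextlib import asynccontextmanager

@asynccontextmanager
async def web_fetch_service_lifespan(max_concurrency: int = 10):
    service = WebFetchService(max_concurrency=max_concurrency)
    try:
        yield service
    finally:
        # Release pooled connections when the application shuts down.
        await service.aclose()

# Usage:
#     async with web_fetch_service_lifespan() as fetcher:
#         docs = await fetcher.fetch_urls(["https://example.com"])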

        # Common HTML headers for the fetch requests
        self.headers = {
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
            "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
            "Accept-Language": "en-US,en;q=0.5"
            # Let httpx handle Accept-Encoding and compression automatically
        }

    def _extract_site_name(self, url: str) -> str:
        """Extract site name from URL."""
        try:
            parsed = urlparse(url)
            domain = parsed.netloc
            # Remove www. prefix if present
            if domain.startswith("www."):
                domain = domain[4:]
            return domain
        except Exception:
            return ""

    def _is_url_allowed(self, url: str) -> bool:
        """
        Check if a URL is safe to fetch (SSRF protection).

        Args:
            url: The URL to validate

        Returns:
            True if URL is safe to fetch, False otherwise
        """
        try:
            p = urlparse(url)
            if p.scheme not in ("http", "https"):
                return False
            if not p.hostname or p.username or p.password:
                return False
            try:
                ip = ipaddress.ip_address(p.hostname)
                if not ip.is_global:
                    return False
            except ValueError:
                # Hostname; DNS resolution checks can be added later if needed.
                pass
            return True
        except Exception:
            return False

Comment on lines +77 to +103
⚠️ Potential issue

SSRF guard is a good start; consider tightening hostname cases.

Add explicit blocks for localhost-style hostnames and common link-local names; document future DNS resolution plans to catch rebinds.

Example tweak:

@@ def _is_url_allowed(self, url: str) -> bool:
-            if not p.hostname or p.username or p.password:
+            if not p.hostname or p.username or p.password:
                 return False
+            host = p.hostname.lower()
+            if host in {"localhost", "localhost.localdomain"} or host.endswith(".local"):
+                return False
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change

    def _is_url_allowed(self, url: str) -> bool:
        """
        Check if a URL is safe to fetch (SSRF protection).

        Args:
            url: The URL to validate

        Returns:
            True if URL is safe to fetch, False otherwise
        """
        try:
            p = urlparse(url)
            if p.scheme not in ("http", "https"):
                return False
            if not p.hostname or p.username or p.password:
                return False
            host = p.hostname.lower()
            if host in {"localhost", "localhost.localdomain"} or host.endswith(".local"):
                return False
            try:
                ip = ipaddress.ip_address(p.hostname)
                if not ip.is_global:
                    return False
            except ValueError:
                # Hostname; DNS resolution checks can be added later if needed.
                pass
            return True
        except Exception:
            return False
🧰 Tools
🪛 Ruff (0.13.1)

100-100: Consider moving this statement to an else block

(TRY300)


101-101: Do not catch blind exception: Exception

(BLE001)

🤖 Prompt for AI Agents
In services/web_fetch_service.py around lines 77 to 103, the SSRF guard
currently allows hostnames that could still resolve to
loopback/link-local/private addresses; update the function to explicitly reject
common localhost-style hostnames and numeric edge cases by normalizing
p.hostname (lowercase, strip surrounding brackets for IPv6) and returning False
for literal names like "localhost", "ip6-localhost", "0.0.0.0" and for IPv4/IPv6
addresses in loopback, link-local (169.254.0.0/16 and fe80::/10), multicast, and
private ranges (10/8, 172.16/12, 192.168/16) using ipaddress checks
(is_loopback, is_link_local, is_private, is_multicast) after parsing the
hostname to an ip object; if hostname is non-numeric keep the existing behavior
but add a clear TODO comment that DNS resolution and rebind protection will be
implemented later.
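For illustration, a rough sketch of the DNS-resolution check the prompt describes; the helper name _resolves_to_public_ip and the getaddrinfo-based approach are assumptions for this sketch, not code from this PR, and a resolve-then-connect check still leaves a rebinding window unless the connect path pins the resolved address.

# Hypothetical helper sketching the future DNS-resolution check described above.
import ipaddress
import socket

def _resolves_to_public_ip(hostname: str, port: int = 443) -> bool:
    """Return True only if every resolved address is globally routable."""
    try:
        infos = socket.getaddrinfo(hostname, port, proto=socket.IPPROTO_TCP)
    except socket.gaierror:
        return False
    addresses = {info[4][0] for info in infos}
    if not addresses:
        return False
    for addr in addresses:
        ip = ipaddress.ip_address(addr)
        # is_global rejects loopback, link-local, private, and multicast ranges in one check.
        if not ip.is_global:
            return False
    return True

In an async code path, the blocking getaddrinfo call would need to go through loop.getaddrinfo or a thread executor.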

    def _sanitize_html_fallback(self, html: str) -> str:
        """
        Simple HTML sanitizer fallback when trafilatura is not available.
        Removes HTML tags, scripts, styles, and excessive whitespace.
        """
        # Remove scripts and styles
        html = re.sub(r'<script.*?</script>', ' ', html, flags=re.DOTALL)
        html = re.sub(r'<style.*?</style>', ' ', html, flags=re.DOTALL)

        # Remove all tags
        html = re.sub(r'<[^>]+>', ' ', html)

        # Replace entities
        html = re.sub(r'&nbsp;', ' ', html)
        html = re.sub(r'&amp;', '&', html)
        html = re.sub(r'&lt;', '<', html)
        html = re.sub(r'&gt;', '>', html)
        html = re.sub(r'&quot;', '"', html)
        html = re.sub(r'&#\d+;', ' ', html)

        # Normalize whitespace
        html = re.sub(r'\s+', ' ', html)

        return html.strip()

    def _extract_title(self, html: str) -> Optional[str]:
        """Extract title from HTML."""
        title_match = re.search(r'<title[^>]*>(.*?)</title>', html, re.IGNORECASE | re.DOTALL)
        if title_match:
            title = title_match.group(1).strip()
            # Clean up title
            title = re.sub(r'\s+', ' ', title)
            return title
        return None

    def _extract_text_from_html(self, html: str) -> str:
        """
        Extract readable text from HTML using trafilatura if available,
        otherwise fallback to simple regex-based extraction.
        """
        if not html:
            return ""

        if TRAFILATURA_AVAILABLE:
            try:
                text = trafilatura.extract(html, include_comments=False, include_tables=False,
                                           favor_precision=True, include_formatting=False)
                if text:
                    return text
                # If trafilatura returns None, fall back to simple extraction
                logger.warning("trafilatura extraction failed, falling back to simple extraction")
            except Exception as e:
                logger.warning(f"trafilatura extraction error: {e}, falling back to simple extraction")

        # Fallback: simple HTML tag removal
        return self._sanitize_html_fallback(html)

    async def _fetch_url(self, url: str, timeout_seconds: int = 10) -> FetchedDoc:
        """
        Fetch a single URL and extract its content.

        Args:
            url: The URL to fetch
            timeout_seconds: Timeout in seconds

        Returns:
            FetchedDoc object with extracted content
        """
        async with self.semaphore:
            start_time = time.time()
            site_name = self._extract_site_name(url)

            # Initialize with empty/default values
            fetched_doc = FetchedDoc(
                url=url,
                site_name=site_name,
                fetch_ms=0
            )

            # SSRF guard
            if not self._is_url_allowed(url):
                fetched_doc.fetch_ms = int((time.time() - start_time) * 1000)
                logger.warning("Blocked potentially unsafe URL: %s", url)
                return fetched_doc

            try:
                async with httpx.AsyncClient(timeout=timeout_seconds, follow_redirects=True) as client:
                    response = await client.get(url, headers=self.headers)
                    response.raise_for_status()

                    html = response.text

                    # Extract title if not already provided
                    title = self._extract_title(html)

                    # Extract text content
                    text = self._extract_text_from_html(html)

                    # Update the fetched document
                    fetched_doc.title = title
                    fetched_doc.text = text

                    # Calculate fetch time
                    fetch_ms = int((time.time() - start_time) * 1000)
                    fetched_doc.fetch_ms = fetch_ms

                    logger.info(f"Fetched {url} in {fetch_ms}ms, extracted {len(text)} chars")
                    return fetched_doc

            except httpx.HTTPStatusError as e:
                status = e.response.status_code
                logger.warning(f"HTTP error {status} fetching {url}: {str(e)}")
            except httpx.RequestError as e:
                logger.warning(f"Request error fetching {url}: {str(e)}")
            except Exception as e:
                logger.warning(f"Error fetching {url}: {str(e)}")

            # If we got here, there was an error
            fetch_ms = int((time.time() - start_time) * 1000)
            fetched_doc.fetch_ms = fetch_ms
            logger.warning(f"Failed to fetch {url} after {fetch_ms}ms")
            return fetched_doc

    async def fetch_urls(self, urls: List[str], timeout_seconds: int = 10) -> List[FetchedDoc]:
        """
        Fetch multiple URLs concurrently.

        Args:
            urls: List of URLs to fetch
            timeout_seconds: Timeout in seconds per request

        Returns:
            List of FetchedDoc objects
        """
        if not urls:
            return []

        # Create fetch tasks for all URLs
        tasks = [self._fetch_url(url, timeout_seconds) for url in urls]

        # Execute all tasks concurrently
        try:
            results = await asyncio.gather(*tasks, return_exceptions=True)
        except Exception as e:
            logger.error(f"Error in fetch_urls: {str(e)}")
            return []

        # Filter out exceptions and empty results
        fetched_docs = []
        for result in results:
            if isinstance(result, Exception):
                logger.warning(f"Exception during fetch: {str(result)}")
            elif isinstance(result, FetchedDoc) and result.text:
                fetched_docs.append(result)

        logger.info(f"Fetched {len(fetched_docs)}/{len(urls)} URLs successfully")
        return fetched_docs

    async def fetch_search_results(self, search_results: List, timeout_seconds: int = 10,
                                   preserve_snippets: bool = True) -> List[FetchedDoc]:
        """
        Fetch content for search results.

        Args:
            search_results: List of SearchResult objects
            timeout_seconds: Timeout in seconds per request
            preserve_snippets: Whether to use snippets as fallback when fetch fails

        Returns:
            List of FetchedDoc objects
        """
        # Extract URLs from search results
        urls = [result.url for result in search_results if result.url]

        # Create a mapping of URL to search result for later use
        url_to_result = {result.url: result for result in search_results if result.url}

        # Fetch all URLs
        fetched_docs = await self.fetch_urls(urls, timeout_seconds)

        # If preserve_snippets is True, create FetchedDocs for failed fetches using snippets
        if preserve_snippets:
            fetched_urls = {doc.url for doc in fetched_docs}
            for url in urls:
                if url not in fetched_urls and url in url_to_result:
                    result = url_to_result[url]
                    if result.snippet:
                        logger.info(f"Using snippet as fallback for {url}")
                        fetched_docs.append(FetchedDoc(
                            url=url,
                            title=result.title,
                            site_name=result.site_name or self._extract_site_name(url),
                            text=result.snippet,
                            published_at=result.published_at,
                            fetch_ms=0  # Indicate that this wasn't actually fetched
                        ))

        return fetched_docs
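For reference, a minimal usage sketch of the service as it lands in this PR; the URLs and concurrency value are placeholders, and fetch_urls drops documents whose extraction produced no text.

# Minimal usage sketch; example URLs and max_concurrency are placeholders.
import asyncio

async def main() -> None:
    fetcher = WebFetchService(max_concurrency=5)
    docs = await fetcher.fetch_urls(
        ["https://example.com", "https://www.python.org"],
        timeout_seconds=10,
    )
    for doc in docs:
        print(doc.site_name, doc.title, f"{doc.fetch_ms}ms", len(doc.text), "chars")

if __name__ == "__main__":
    asyncio.run(main())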