diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/test_ai_base_url_validation.py b/tests/test_ai_base_url_validation.py
new file mode 100644
index 00000000..63161747
--- /dev/null
+++ b/tests/test_ai_base_url_validation.py
@@ -0,0 +1,270 @@
+"""Tests for the SSRF guard around the user-supplied AI provider baseUrl.
+
+Regression tests for issue #393. Without the guard, any caller can point
+``/api/ai/test`` or ``/api/ai/chat`` at arbitrary internal hosts (cloud
+metadata services, RFC1918 networks, link-local addresses, ...) and have
+the UltraRAG backend fetch them.
+"""
+
+from __future__ import annotations
+
+import socket
+
+import pytest
+
+from ui.backend._ai_base_url import validate_ai_base_url
+
+
+# --- helpers ---------------------------------------------------------------
+
+
+def _patch_resolver(monkeypatch, mapping):
+    """Force ``socket.getaddrinfo`` to return the given hostname → IPs map.
+
+    Each value is a list of address strings (v4 and/or v6). Anything not in
+    ``mapping`` raises ``socket.gaierror`` so we never hit real DNS.
+    """
+
+    def fake_getaddrinfo(host, *_args, **_kwargs):
+        if host not in mapping:
+            raise socket.gaierror(8, f"nodename nor servname provided ({host})")
+        out = []
+        for ip in mapping[host]:
+            family = socket.AF_INET6 if ":" in ip else socket.AF_INET
+            sockaddr = (ip, 0, 0, 0) if family == socket.AF_INET6 else (ip, 0)
+            out.append((family, socket.SOCK_STREAM, socket.IPPROTO_TCP, "", sockaddr))
+        return out
+
+    monkeypatch.setattr(socket, "getaddrinfo", fake_getaddrinfo)
+
+
+@pytest.fixture(autouse=True)
+def _clean_env(monkeypatch):
+    """Each test runs with a clean policy by default; tests opt into env knobs."""
+    monkeypatch.delenv("ULTRARAG_AI_BASE_URL_BLOCK_PRIVATE", raising=False)
+    monkeypatch.delenv("ULTRARAG_AI_BASE_URL_ALLOWLIST", raising=False)
+
+
+# --- empty / shape errors --------------------------------------------------
+
+
+@pytest.mark.parametrize("value", ["", None, " "])
+def test_empty_base_url_rejected(value):
+    err = validate_ai_base_url(value)
+    assert err is not None
+    assert any(token in err.lower() for token in ("required", "scheme", "hostname"))
+
+
+@pytest.mark.parametrize(
+    "value",
+    [
+        "file:///etc/passwd",
+        "ftp://example.com/x",
+        "data:text/plain,hi",
+        "gopher://example.com/",
+        "javascript:alert(1)",
+        "//api.openai.com/v1",  # missing scheme
+    ],
+)
+def test_non_http_scheme_rejected(value):
+    err = validate_ai_base_url(value)
+    assert err is not None
+    assert "scheme" in err.lower() or "hostname" in err.lower()
+
+
+def test_missing_hostname_rejected():
+    # urlparse can produce empty hostname for things like "http:///foo"
+    err = validate_ai_base_url("http:///models")
+    assert err is not None
+
+
+def test_malformed_bracketed_host_rejected():
+    # urlparse raises ValueError here; the validator must return an error.
+    err = validate_ai_base_url("http://[::1/v1")
+    assert err is not None
+    assert "valid url" in err.lower()
+
+
+# --- IP literals (the core SSRF scenarios) ---------------------------------
+
+
+@pytest.mark.parametrize(
+    "url",
+    [
+        # Issue #393's exact attack: AWS / GCP / Azure IMDS.
+        "http://169.254.169.254/latest",
+        "https://169.254.169.254/computeMetadata/v1/",
+        # Other link-local IPv4.
+        "http://169.254.5.5/",
+    ],
+)
+def test_imds_and_link_local_v4_rejected(url):
+    err = validate_ai_base_url(url)
+    assert err is not None
+    assert "disallowed" in err.lower()
+
+
+def test_link_local_v6_literal_rejected():
+    err = validate_ai_base_url("http://[fe80::1]/")
+    assert err is not None
+    assert "disallowed" in err.lower()
+
+
+@pytest.mark.parametrize(
+    "url",
+    [
+        # IPv4-mapped IPv6 spellings of the IMDS address.
+        "http://[::ffff:169.254.169.254]/latest",
+        "http://[::ffff:a9fe:a9fe]/latest",
+    ],
+)
+def test_ipv4_mapped_imds_literal_rejected(url):
+    err = validate_ai_base_url(url)
+    assert err is not None
+    assert "disallowed" in err.lower()
+
+
+@pytest.mark.parametrize(
+    "url",
+    [
+        "http://224.0.0.1/",  # multicast
+        "http://0.0.0.0/",  # unspecified
+        "http://[::]/",  # unspecified v6
+        "http://[ff02::1]/",  # v6 multicast
+    ],
+)
+def test_multicast_and_unspecified_rejected(url):
+    err = validate_ai_base_url(url)
+    assert err is not None
+
+
+def test_loopback_allowed_by_default(monkeypatch):
+    """Self-hosted RAG users run Ollama / vLLM / LM Studio on localhost."""
+    _patch_resolver(monkeypatch, {})
+    assert validate_ai_base_url("http://127.0.0.1:11434/v1") is None
+    assert validate_ai_base_url("http://[::1]:11434/v1") is None
+
+
+def test_private_v4_allowed_by_default(monkeypatch):
+    _patch_resolver(monkeypatch, {})
+    assert validate_ai_base_url("http://10.0.0.5/v1") is None
+    assert validate_ai_base_url("http://192.168.1.10/v1") is None
+    assert validate_ai_base_url("http://172.16.0.1/v1") is None
+
+
+# --- DNS resolution (anti-rebinding) ---------------------------------------
+
+
+def test_public_hostname_allowed(monkeypatch):
+    _patch_resolver(monkeypatch, {"api.openai.com": ["104.18.6.192"]})
+    assert validate_ai_base_url("https://api.openai.com/v1") is None
+
+
+def test_hostname_resolving_to_imds_rejected(monkeypatch):
+    """DNS rebinding: attacker controls a public hostname, points it at IMDS."""
+    _patch_resolver(monkeypatch, {"evil.example.com": ["169.254.169.254"]})
+    err = validate_ai_base_url("https://evil.example.com/v1")
+    assert err is not None
+    assert "disallowed" in err.lower()
+
+
+def test_hostname_with_one_safe_one_unsafe_record_rejected(monkeypatch):
+    """If ANY resolved address is unsafe, reject — defends against
+    DNS rebinding where the response rotates between a public IP and
+    an internal one."""
+    _patch_resolver(
+        monkeypatch,
+        {"mixed.example.com": ["1.1.1.1", "169.254.169.254"]},
+    )
+    err = validate_ai_base_url("https://mixed.example.com/")
+    assert err is not None
+
+
+def test_hostname_with_aaaa_link_local_rejected(monkeypatch):
+    _patch_resolver(monkeypatch, {"v6evil.example.com": ["fe80::1"]})
+    err = validate_ai_base_url("https://v6evil.example.com/")
+    assert err is not None
+
+
+def test_hostname_resolving_to_ipv4_mapped_imds_rejected(monkeypatch):
+    _patch_resolver(
+        monkeypatch,
+        {"mapped.example.com": ["::ffff:169.254.169.254"]},
+    )
+    err = validate_ai_base_url("https://mapped.example.com/")
+    assert err is not None
+    assert "disallowed" in err.lower()
+
+
+def test_unresolvable_hostname_rejected(monkeypatch):
+    _patch_resolver(monkeypatch, {})
+    err = validate_ai_base_url("https://nope.nonexistent.invalid/v1")
+    assert err is not None
+    assert "could not be resolved" in err.lower()
+
+
+# --- env-var policy knobs --------------------------------------------------
+
+
+def test_block_private_mode_rejects_loopback(monkeypatch):
+    monkeypatch.setenv("ULTRARAG_AI_BASE_URL_BLOCK_PRIVATE", "1")
+    _patch_resolver(monkeypatch, {})
+    err = validate_ai_base_url("http://127.0.0.1:11434/v1")
+    assert err is not None
+
+
+def test_block_private_mode_rejects_rfc1918_via_dns(monkeypatch):
+    monkeypatch.setenv("ULTRARAG_AI_BASE_URL_BLOCK_PRIVATE", "1")
+    _patch_resolver(monkeypatch, {"intranet.local": ["10.0.0.5"]})
+    err = validate_ai_base_url("http://intranet.local/v1")
+    assert err is not None
+
+
+def test_block_private_mode_still_allows_public(monkeypatch):
+    monkeypatch.setenv("ULTRARAG_AI_BASE_URL_BLOCK_PRIVATE", "1")
+    _patch_resolver(monkeypatch, {"api.openai.com": ["104.18.6.192"]})
+    assert validate_ai_base_url("https://api.openai.com/v1") is None
+
+
+def test_allowlist_mode_accepts_listed_hostname(monkeypatch):
+    monkeypatch.setenv(
+        "ULTRARAG_AI_BASE_URL_ALLOWLIST",
+        "api.openai.com, api.anthropic.com",
+    )
+    _patch_resolver(monkeypatch, {"api.openai.com": ["104.18.6.192"]})
+    assert validate_ai_base_url("https://api.openai.com/v1") is None
+
+
+def test_allowlist_mode_rejects_unlisted_hostname(monkeypatch):
+    monkeypatch.setenv(
+        "ULTRARAG_AI_BASE_URL_ALLOWLIST",
+        "api.openai.com",
+    )
+    err = validate_ai_base_url("https://api.anthropic.com/v1")
+    assert err is not None
+    assert "ULTRARAG_AI_BASE_URL_ALLOWLIST" in err
+
+
+def test_allowlist_mode_runs_before_dns(monkeypatch):
+    """Allowlist short-circuits — no DNS lookup for unlisted hosts.
+
+    Important so that a hostile caller can't trigger DNS exfiltration
+    or pin the worker thread on a slow resolver in strict-mode deployments.
+    """
+    called = {"n": 0}
+
+    def boom(*args, **kwargs):
+        called["n"] += 1
+        raise AssertionError("DNS resolver must not be called")
+
+    monkeypatch.setenv("ULTRARAG_AI_BASE_URL_ALLOWLIST", "api.openai.com")
+    monkeypatch.setattr(socket, "getaddrinfo", boom)
+    err = validate_ai_base_url("https://attacker.example.com/")
+    assert err is not None
+    assert called["n"] == 0
+
+
+def test_allowlist_is_case_insensitive(monkeypatch):
+    monkeypatch.setenv("ULTRARAG_AI_BASE_URL_ALLOWLIST", "api.openai.com")
+    _patch_resolver(monkeypatch, {"api.openai.com": ["104.18.6.192"]})
+    assert validate_ai_base_url("https://API.OpenAI.com/v1") is None
diff --git a/ui/backend/_ai_base_url.py b/ui/backend/_ai_base_url.py
new file mode 100644
index 00000000..dea3e64b
--- /dev/null
+++ b/ui/backend/_ai_base_url.py
@@ -0,0 +1,181 @@
+"""SSRF guards for the user-supplied AI provider ``baseUrl``.
+
+The ``/api/ai/test`` and ``/api/ai/chat`` endpoints accept a ``baseUrl`` from
+the request body and use it directly to construct outbound HTTP requests.
+Without validation, a caller can point UltraRAG at:
+
+* loopback / private addresses on the host running UltraRAG,
+* cloud instance metadata services (AWS / GCP / Azure ``169.254.169.254``,
+  link-local ``fe80::/10``),
+* arbitrary internal hosts unreachable from the public internet.
+
+This module exposes :func:`validate_ai_base_url`, used by both endpoints.
+
+Default policy (matches self-hosted RAG expectations):
+
+* Only ``http://`` and ``https://`` schemes are accepted.
+* Link-local, multicast, reserved and unspecified addresses are always
+  rejected (closing the IMDS attack from issue #393).
+* IPv4-mapped IPv6 addresses (``::ffff:a.b.c.d``) are classified by the
+  IPv4 address they embed, so ``[::ffff:169.254.169.254]`` is rejected too.
+* Loopback and RFC1918 private addresses are allowed by default so that
+  Ollama / vLLM / LM Studio at ``localhost`` keep working.
+* All A/AAAA records resolved for the hostname are checked, defending
+  against DNS rebinding where one record points at a public IP and a
+  second points at IMDS.
+
+Two opt-in env vars tighten the policy for production deployments:
+
+* ``ULTRARAG_AI_BASE_URL_BLOCK_PRIVATE=1`` — also reject loopback /
+  private / shared / site-local addresses.
+* ``ULTRARAG_AI_BASE_URL_ALLOWLIST=api.openai.com,api.anthropic.com,...``
+  — only hostnames in this CSV are accepted; everything else fails fast
+  before DNS resolution.
+"""
+
+from __future__ import annotations
+
+import ipaddress
+import os
+import socket
+from typing import Optional, Sequence
+from urllib.parse import urlparse
+
+ALLOWED_SCHEMES = ("http", "https")
+
+
+def _is_unsafe_address(ip: ipaddress._BaseAddress, *, block_private: bool) -> bool:
+    """Decide whether ``ip`` should be blocked from outbound AI requests.
+
+    An IPv4-mapped IPv6 address (``::ffff:a.b.c.d``) is first unwrapped to
+    the IPv4 address it embeds, since that is the address a dual-stack
+    socket connects to. Without this, ``::ffff:169.254.169.254`` is not
+    link-local as an IPv6 address on older Python releases, yet it counts
+    as private, so the default policy would let mapped IMDS through.
+
+    Order matters:
+
+    1. Always reject link-local, multicast and the unspecified address —
+       these are never legitimate AI provider destinations and link-local
+       in particular is the IMDS attack from issue #393.
+    2. In ``block_private`` (strict) mode, also reject loopback and
+       RFC1918 private — operators set this when the host has no
+       legitimate sibling AI service.
+    3. Otherwise, *allow* loopback and private explicitly so that
+       Ollama / vLLM / LM Studio at ``localhost`` keep working. We do
+       this before the ``is_reserved`` check because in Python 3.12
+       ``IPv6Address('::1').is_reserved`` is True (``::1`` sits inside
+       the reserved ``0::/8`` block) — without this short-circuit,
+       legitimate IPv6 loopback would be rejected.
+    4. Reject any other reserved address (e.g. IPv4 240.0.0.0/4).
+    """
+    # IPv4Address has no ``ipv4_mapped`` attribute, hence the getattr default.
+    mapped = getattr(ip, "ipv4_mapped", None)
+    if mapped is not None:
+        ip = mapped
+
+    if ip.is_link_local or ip.is_multicast or ip.is_unspecified:
+        return True
+    if block_private and (ip.is_loopback or ip.is_private):
+        return True
+    if ip.is_loopback or ip.is_private:
+        return False
+    if ip.is_reserved:
+        return True
+    return False
+
+
+def _resolve_host(host: str) -> Sequence[ipaddress._BaseAddress]:
+    """Return all A/AAAA addresses for ``host``.
+
+    Raises ``socket.gaierror`` on resolution failure so the caller can map
+    it to a user-facing error.
+    """
+    infos = socket.getaddrinfo(host, None, proto=socket.IPPROTO_TCP)
+    addrs: list[ipaddress._BaseAddress] = []
+    for info in infos:
+        sockaddr = info[4]
+        # IPv4: (ip, port) ; IPv6: (ip, port, flowinfo, scopeid)
+        ip_str = sockaddr[0]
+        try:
+            addrs.append(ipaddress.ip_address(ip_str))
+        except ValueError:
+            continue
+    return addrs
+
+
+def _read_allowlist() -> Optional[set[str]]:
+    """Return the lowercased allowlist hostnames, or ``None`` when unset/empty."""
+    raw = os.environ.get("ULTRARAG_AI_BASE_URL_ALLOWLIST", "")
+    items = {h.strip().lower() for h in raw.split(",") if h.strip()}
+    return items or None
+
+
+def _read_block_private() -> bool:
+    """Return True when ``ULTRARAG_AI_BASE_URL_BLOCK_PRIVATE`` is 1/true/yes/on."""
+    return os.environ.get("ULTRARAG_AI_BASE_URL_BLOCK_PRIVATE", "").lower() in (
+        "1",
+        "true",
+        "yes",
+        "on",
+    )
+
+
+def validate_ai_base_url(base_url: str) -> Optional[str]:
+    """Return ``None`` when ``base_url`` is safe to fetch, else a reason string.
+
+    The returned string is suitable for surfacing in an API ``error`` field.
+    """
+    if not base_url or not isinstance(base_url, str):
+        return "baseUrl is required"
+
+    # urlparse raises ValueError on malformed bracketed hosts (for example
+    # "http://[::1"); report that as a validation failure instead of letting
+    # it escape as an unhandled exception in the request handler.
+    try:
+        parsed = urlparse(base_url.strip())
+    except ValueError:
+        return "baseUrl is not a valid URL"
+
+    scheme = (parsed.scheme or "").lower()
+    if scheme not in ALLOWED_SCHEMES:
+        return (
+            f"baseUrl scheme must be http or https (got {parsed.scheme or 'empty'!r})"
+        )
+
+    host = parsed.hostname
+    if not host:
+        return "baseUrl is missing a hostname"
+
+    allowlist = _read_allowlist()
+    if allowlist is not None and host.lower() not in allowlist:
+        return f"baseUrl host is not in ULTRARAG_AI_BASE_URL_ALLOWLIST: {host}"
+
+    block_private = _read_block_private()
+
+    # If the host is already an IP literal, validate it directly without
+    # touching DNS — bracketed IPv6 hosts come back from urlparse without
+    # the brackets, so ip_address() is happy.
+    try:
+        literal = ipaddress.ip_address(host)
+    except ValueError:
+        literal = None
+
+    if literal is not None:
+        if _is_unsafe_address(literal, block_private=block_private):
+            return f"baseUrl host {host} resolves to a disallowed address ({literal})"
+        return None
+
+    try:
+        addrs = _resolve_host(host)
+    except socket.gaierror as exc:
+        return f"baseUrl host {host} could not be resolved: {exc}"
+
+    if not addrs:
+        return f"baseUrl host {host} did not resolve to any address"
+
+    for ip in addrs:
+        if _is_unsafe_address(ip, block_private=block_private):
+            return f"baseUrl host {host} resolves to a disallowed address ({ip})"
+
+    return None
diff --git a/ui/backend/app.py b/ui/backend/app.py
index 9051b007..afbf1260 100644
--- a/ui/backend/app.py
+++ b/ui/backend/app.py
@@ -29,6 +29,7 @@
 from . import chat_store as chat_store_backend
 from . import kb_visibility_store as kb_visibility_backend
 from . import pipeline_manager as pm
+from ._ai_base_url import validate_ai_base_url
 from .storage_paths import (
     UI_MEMORY_ROOT_DIR,
     UI_STORAGE_ROOT,
@@ -2363,6 +2364,10 @@ def test_ai_connection():
     if not api_key:
         return jsonify({"success": False, "error": "API key is required"})
 
+    url_error = validate_ai_base_url(base_url)
+    if url_error:
+        return jsonify({"success": False, "error": url_error})
+
     try:
         if provider == "openai" or provider == "custom":
             # OpenAI-compatible API
@@ -2497,6 +2502,10 @@ def decode_sse_line(raw_line: Any) -> str:
     if not api_key:
         return jsonify({"error": "API key is required"})
 
+    url_error = validate_ai_base_url(base_url)
+    if url_error:
+        return jsonify({"error": url_error})
+
     # Build system prompt with context
     system_prompt = build_ai_system_prompt(context)
 