Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added tests/__init__.py
Empty file.
239 changes: 239 additions & 0 deletions tests/test_ai_base_url_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,239 @@
"""Tests for the SSRF guard around the user-supplied AI provider baseUrl.

Regression tests for issue #393. Without the guard, any caller can point
``/api/ai/test`` or ``/api/ai/chat`` at arbitrary internal hosts (cloud
metadata services, RFC1918 networks, link-local addresses, ...) and have
the UltraRAG backend fetch them.
"""

from __future__ import annotations

import socket

import pytest

from ui.backend._ai_base_url import validate_ai_base_url


# --- helpers ---------------------------------------------------------------


def _patch_resolver(monkeypatch, mapping):
"""Force ``socket.getaddrinfo`` to return the given hostname → IPs map.

Each value is a list of address strings (v4 and/or v6). Anything not in
``mapping`` raises ``socket.gaierror`` so we never hit real DNS.
"""

def fake_getaddrinfo(host, *_args, **_kwargs):
if host not in mapping:
raise socket.gaierror(8, f"nodename nor servname provided ({host})")
out = []
for ip in mapping[host]:
family = socket.AF_INET6 if ":" in ip else socket.AF_INET
sockaddr = (ip, 0, 0, 0) if family == socket.AF_INET6 else (ip, 0)
out.append((family, socket.SOCK_STREAM, socket.IPPROTO_TCP, "", sockaddr))
return out

monkeypatch.setattr(socket, "getaddrinfo", fake_getaddrinfo)


@pytest.fixture(autouse=True)
def _clean_env(monkeypatch):
"""Each test runs with a clean policy by default; tests opt into env knobs."""
monkeypatch.delenv("ULTRARAG_AI_BASE_URL_BLOCK_PRIVATE", raising=False)
monkeypatch.delenv("ULTRARAG_AI_BASE_URL_ALLOWLIST", raising=False)


# --- empty / shape errors --------------------------------------------------


@pytest.mark.parametrize("value", ["", None, " "])
def test_empty_base_url_rejected(value):
err = validate_ai_base_url(value)
assert err is not None
assert any(token in err.lower() for token in ("required", "scheme", "hostname"))


@pytest.mark.parametrize(
"value",
[
"file:///etc/passwd",
"ftp://example.com/x",
"data:text/plain,hi",
"gopher://example.com/",
"javascript:alert(1)",
"//api.openai.com/v1", # missing scheme
],
)
def test_non_http_scheme_rejected(value):
err = validate_ai_base_url(value)
assert err is not None
assert "scheme" in err.lower() or "hostname" in err.lower()


def test_missing_hostname_rejected():
# urlparse can produce empty hostname for things like "http:///foo"
err = validate_ai_base_url("http:///models")
assert err is not None


# --- IP literals (the core SSRF scenarios) ---------------------------------


@pytest.mark.parametrize(
"url",
[
# Issue #393's exact attack: AWS / GCP / Azure IMDS.
"http://169.254.169.254/latest",
"https://169.254.169.254/computeMetadata/v1/",
# Other link-local IPv4.
"http://169.254.5.5/",
],
)
def test_imds_and_link_local_v4_rejected(url):
err = validate_ai_base_url(url)
assert err is not None
assert "disallowed" in err.lower()


def test_link_local_v6_literal_rejected():
err = validate_ai_base_url("http://[fe80::1]/")
assert err is not None
assert "disallowed" in err.lower()


@pytest.mark.parametrize(
"url",
[
"http://224.0.0.1/", # multicast
"http://0.0.0.0/", # unspecified
"http://[::]/", # unspecified v6
"http://[ff02::1]/", # v6 multicast
],
)
def test_multicast_and_unspecified_rejected(url):
err = validate_ai_base_url(url)
assert err is not None


def test_loopback_allowed_by_default(monkeypatch):
"""Self-hosted RAG users run Ollama / vLLM / LM Studio on localhost."""
_patch_resolver(monkeypatch, {})
assert validate_ai_base_url("http://127.0.0.1:11434/v1") is None
assert validate_ai_base_url("http://[::1]:11434/v1") is None


def test_private_v4_allowed_by_default(monkeypatch):
_patch_resolver(monkeypatch, {})
assert validate_ai_base_url("http://10.0.0.5/v1") is None
assert validate_ai_base_url("http://192.168.1.10/v1") is None
assert validate_ai_base_url("http://172.16.0.1/v1") is None


# --- DNS resolution (anti-rebinding) ---------------------------------------


def test_public_hostname_allowed(monkeypatch):
_patch_resolver(monkeypatch, {"api.openai.com": ["104.18.6.192"]})
assert validate_ai_base_url("https://api.openai.com/v1") is None


def test_hostname_resolving_to_imds_rejected(monkeypatch):
"""DNS rebinding: attacker controls a public hostname, points it at IMDS."""
_patch_resolver(monkeypatch, {"evil.example.com": ["169.254.169.254"]})
err = validate_ai_base_url("https://evil.example.com/v1")
assert err is not None
assert "disallowed" in err.lower()


def test_hostname_with_one_safe_one_unsafe_record_rejected(monkeypatch):
"""If ANY resolved address is unsafe, reject — defends against
DNS rebinding where the response rotates between a public IP and
an internal one."""
_patch_resolver(
monkeypatch,
{"mixed.example.com": ["1.1.1.1", "169.254.169.254"]},
)
err = validate_ai_base_url("https://mixed.example.com/")
assert err is not None


def test_hostname_with_aaaa_link_local_rejected(monkeypatch):
_patch_resolver(monkeypatch, {"v6evil.example.com": ["fe80::1"]})
err = validate_ai_base_url("https://v6evil.example.com/")
assert err is not None


def test_unresolvable_hostname_rejected(monkeypatch):
_patch_resolver(monkeypatch, {})
err = validate_ai_base_url("https://nope.nonexistent.invalid/v1")
assert err is not None
assert "could not be resolved" in err.lower()


# --- env-var policy knobs --------------------------------------------------


def test_block_private_mode_rejects_loopback(monkeypatch):
monkeypatch.setenv("ULTRARAG_AI_BASE_URL_BLOCK_PRIVATE", "1")
_patch_resolver(monkeypatch, {})
err = validate_ai_base_url("http://127.0.0.1:11434/v1")
assert err is not None


def test_block_private_mode_rejects_rfc1918_via_dns(monkeypatch):
monkeypatch.setenv("ULTRARAG_AI_BASE_URL_BLOCK_PRIVATE", "1")
_patch_resolver(monkeypatch, {"intranet.local": ["10.0.0.5"]})
err = validate_ai_base_url("http://intranet.local/v1")
assert err is not None


def test_block_private_mode_still_allows_public(monkeypatch):
monkeypatch.setenv("ULTRARAG_AI_BASE_URL_BLOCK_PRIVATE", "1")
_patch_resolver(monkeypatch, {"api.openai.com": ["104.18.6.192"]})
assert validate_ai_base_url("https://api.openai.com/v1") is None


def test_allowlist_mode_accepts_listed_hostname(monkeypatch):
monkeypatch.setenv(
"ULTRARAG_AI_BASE_URL_ALLOWLIST",
"api.openai.com, api.anthropic.com",
)
_patch_resolver(monkeypatch, {"api.openai.com": ["104.18.6.192"]})
assert validate_ai_base_url("https://api.openai.com/v1") is None


def test_allowlist_mode_rejects_unlisted_hostname(monkeypatch):
monkeypatch.setenv(
"ULTRARAG_AI_BASE_URL_ALLOWLIST",
"api.openai.com",
)
err = validate_ai_base_url("https://api.anthropic.com/v1")
assert err is not None
assert "ULTRARAG_AI_BASE_URL_ALLOWLIST" in err


def test_allowlist_mode_runs_before_dns(monkeypatch):
"""Allowlist short-circuits — no DNS lookup for unlisted hosts.

Important so that a hostile caller can't trigger DNS exfiltration
or pin the worker thread on a slow resolver in strict-mode deployments.
"""
called = {"n": 0}

def boom(*args, **kwargs):
called["n"] += 1
raise AssertionError("DNS resolver must not be called")

monkeypatch.setenv("ULTRARAG_AI_BASE_URL_ALLOWLIST", "api.openai.com")
monkeypatch.setattr(socket, "getaddrinfo", boom)
err = validate_ai_base_url("https://attacker.example.com/")
assert err is not None
assert called["n"] == 0


def test_allowlist_is_case_insensitive(monkeypatch):
monkeypatch.setenv("ULTRARAG_AI_BASE_URL_ALLOWLIST", "api.openai.com")
_patch_resolver(monkeypatch, {"api.openai.com": ["104.18.6.192"]})
assert validate_ai_base_url("https://API.OpenAI.com/v1") is None
159 changes: 159 additions & 0 deletions ui/backend/_ai_base_url.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
"""SSRF guards for the user-supplied AI provider ``baseUrl``.

The ``/api/ai/test`` and ``/api/ai/chat`` endpoints accept a ``baseUrl`` from
the request body and use it directly to construct outbound HTTP requests.
Without validation, a caller can point UltraRAG at:

* loopback / private addresses on the host running UltraRAG,
* cloud instance metadata services (AWS / GCP / Azure ``169.254.169.254``,
link-local ``fe80::/10``),
* arbitrary internal hosts unreachable from the public internet.

This module exposes :func:`validate_ai_base_url`, used by both endpoints.

Default policy (matches self-hosted RAG expectations):

* Only ``http://`` and ``https://`` schemes are accepted.
* Link-local, multicast, reserved and unspecified addresses are always
rejected (closing the IMDS attack from issue #393).
* Loopback and RFC1918 private addresses are allowed by default so that
Ollama / vLLM / LM Studio at ``localhost`` keep working.
* All A/AAAA records resolved for the hostname are checked, defending
against DNS rebinding where one record points at a public IP and a
second points at IMDS.

Two opt-in env vars tighten the policy for production deployments:

* ``ULTRARAG_AI_BASE_URL_BLOCK_PRIVATE=1`` — also reject loopback /
private / shared / site-local addresses.
* ``ULTRARAG_AI_BASE_URL_ALLOWLIST=api.openai.com,api.anthropic.com,...``
— only hostnames in this CSV are accepted; everything else fails fast
before DNS resolution.
"""

from __future__ import annotations

import ipaddress
import os
import socket
from typing import Optional, Sequence
from urllib.parse import urlparse

ALLOWED_SCHEMES = ("http", "https")


def _is_unsafe_address(ip: ipaddress._BaseAddress, *, block_private: bool) -> bool:
"""Decide whether ``ip`` should be blocked from outbound AI requests.

Order matters:

1. Always reject link-local, multicast and the unspecified address —
these are never legitimate AI provider destinations and link-local
in particular is the IMDS attack from issue #393.
2. In ``block_private`` (strict) mode, also reject loopback and
RFC1918 private — operators set this when the host has no
legitimate sibling AI service.
3. Otherwise, *allow* loopback and private explicitly so that
Ollama / vLLM / LM Studio at ``localhost`` keep working. We do
this before the ``is_reserved`` check because in Python 3.12
``IPv6Address('::1').is_reserved`` is True (``::1`` sits inside
the reserved ``0::/8`` block) — without this short-circuit,
legitimate IPv6 loopback would be rejected.
4. Reject any other reserved address (e.g. IPv4 240.0.0.0/4).
"""
if ip.is_link_local or ip.is_multicast or ip.is_unspecified:
return True
if block_private and (ip.is_loopback or ip.is_private):
return True
if ip.is_loopback or ip.is_private:
return False
if ip.is_reserved:
return True
return False


def _resolve_host(host: str) -> Sequence[ipaddress._BaseAddress]:
"""Return all A/AAAA addresses for ``host``.

Raises ``socket.gaierror`` on resolution failure so the caller can map
it to a user-facing error.
"""
infos = socket.getaddrinfo(host, None, proto=socket.IPPROTO_TCP)
addrs: list[ipaddress._BaseAddress] = []
for info in infos:
sockaddr = info[4]
# IPv4: (ip, port) ; IPv6: (ip, port, flowinfo, scopeid)
ip_str = sockaddr[0]
try:
addrs.append(ipaddress.ip_address(ip_str))
except ValueError:
continue
return addrs


def _read_allowlist() -> Optional[set[str]]:
raw = os.environ.get("ULTRARAG_AI_BASE_URL_ALLOWLIST", "")
items = {h.strip().lower() for h in raw.split(",") if h.strip()}
return items or None


def _read_block_private() -> bool:
return os.environ.get("ULTRARAG_AI_BASE_URL_BLOCK_PRIVATE", "").lower() in (
"1",
"true",
"yes",
"on",
)


def validate_ai_base_url(base_url: str) -> Optional[str]:
"""Return ``None`` when ``base_url`` is safe to fetch, else a reason string.

The returned string is suitable for surfacing in an API ``error`` field.
"""
if not base_url or not isinstance(base_url, str):
return "baseUrl is required"

parsed = urlparse(base_url.strip())
scheme = (parsed.scheme or "").lower()
if scheme not in ALLOWED_SCHEMES:
return (
f"baseUrl scheme must be http or https (got {parsed.scheme or 'empty'!r})"
)

host = parsed.hostname
if not host:
return "baseUrl is missing a hostname"

allowlist = _read_allowlist()
if allowlist is not None and host.lower() not in allowlist:
return f"baseUrl host is not in ULTRARAG_AI_BASE_URL_ALLOWLIST: {host}"

block_private = _read_block_private()

# If the host is already an IP literal, validate it directly without
# touching DNS — bracketed IPv6 hosts come back from urlparse without
# the brackets, so ip_address() is happy.
try:
literal = ipaddress.ip_address(host)
except ValueError:
literal = None

if literal is not None:
if _is_unsafe_address(literal, block_private=block_private):
return f"baseUrl host {host} resolves to a disallowed address ({literal})"
return None

try:
addrs = _resolve_host(host)
except socket.gaierror as exc:
return f"baseUrl host {host} could not be resolved: {exc}"

if not addrs:
return f"baseUrl host {host} did not resolve to any address"

for ip in addrs:
if _is_unsafe_address(ip, block_private=block_private):
return f"baseUrl host {host} resolves to a disallowed address ({ip})"

return None
Loading