Skip to content

Commit

Permalink
feat: Add ASGI root path parameter to Phoenix server (#3186)
Browse files Browse the repository at this point in the history
* Add ASGI root path parameter to Phoenix server

Allow for a root path parameter so that Phoenix can be hosted behind a proxy with a path prefix

* Name root path parameter more specifically

* Fix welcome message

* Enable client code to handle endpoints with path prefix

* Fix wrong endpoint in evaluation module

* Guarantee training slashes in host base URLs

* Remove host root path CLI parameter

CLI parameters for settings that are naturally adjusted via environment variables are discouraged for Phoenix

* Apply pre-commit fixes

* Fix and adjust unit tests
  • Loading branch information
sbachstein committed May 17, 2024
1 parent 9f229b6 commit e27cc5d
Show file tree
Hide file tree
Showing 7 changed files with 66 additions and 32 deletions.
7 changes: 7 additions & 0 deletions src/phoenix/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
ENV_PHOENIX_PORT = "PHOENIX_PORT"
ENV_PHOENIX_GRPC_PORT = "PHOENIX_GRPC_PORT"
ENV_PHOENIX_HOST = "PHOENIX_HOST"
ENV_PHOENIX_HOST_ROOT_PATH = "PHOENIX_HOST_ROOT_PATH"
ENV_NOTEBOOK_ENV = "PHOENIX_NOTEBOOK_ENV"
ENV_PHOENIX_COLLECTOR_ENDPOINT = "PHOENIX_COLLECTOR_ENDPOINT"
"""
Expand Down Expand Up @@ -98,6 +99,8 @@ def get_working_dir() -> Path:
"""The host the server will run on after launch_app is called."""
PORT = 6006
"""The port the server will run on after launch_app is called."""
HOST_ROOT_PATH = ""
"""The ASGI root path of the server, i.e. the root path where the web application is mounted"""
GRPC_PORT = 4317
"""The port the gRPC server will run on after launch_app is called.
The default network port for OTLP/gRPC is 4317.
Expand Down Expand Up @@ -183,6 +186,10 @@ def get_env_host() -> str:
return os.getenv(ENV_PHOENIX_HOST) or HOST


def get_env_host_root_path() -> str:
return os.getenv(ENV_PHOENIX_HOST_ROOT_PATH) or HOST_ROOT_PATH


def get_env_collector_endpoint() -> Optional[str]:
return os.getenv(ENV_PHOENIX_COLLECTOR_ENDPOINT)

Expand Down
10 changes: 6 additions & 4 deletions src/phoenix/server/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
get_env_enable_prometheus,
get_env_grpc_port,
get_env_host,
get_env_host_root_path,
get_env_port,
get_pids_path,
get_working_dir,
Expand Down Expand Up @@ -67,10 +68,10 @@
| https://docs.arize.com/phoenix
|
| 🚀 Phoenix Server 🚀
| Phoenix UI: http://{host}:{port}
| Phoenix UI: http://{host}:{port}/{root_path}
| Log traces:
| - gRPC: http://{host}:{grpc_port}
| - HTTP: http://{host}:{port}/v1/traces
| - HTTP: http://{host}:{port}/{root_path}/v1/traces
| Storage: {storage}
"""

Expand Down Expand Up @@ -197,7 +198,7 @@ def _get_pid_file() -> Path:
host = None

port = args.port or get_env_port()

host_root_path = get_env_host_root_path()
model = create_model_from_datasets(
primary_dataset,
reference_dataset,
Expand Down Expand Up @@ -251,7 +252,7 @@ def _get_pid_file() -> Path:
initial_spans=fixture_spans,
initial_evaluations=fixture_evals,
)
server = Server(config=Config(app, host=host, port=port)) # type: ignore
server = Server(config=Config(app, host=host, port=port, root_path=host_root_path)) # type: ignore
Thread(target=_write_pid_file_when_ready, args=(server,), daemon=True).start()

# Print information about the server
Expand All @@ -260,6 +261,7 @@ def _get_pid_file() -> Path:
"version": phoenix_version,
"host": display_host,
"port": port,
"root_path": host_root_path.strip("/"),
"grpc_port": get_env_grpc_port(),
"storage": get_printable_db_url(db_connection_str),
}
Expand Down
16 changes: 8 additions & 8 deletions src/phoenix/session/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ def __init__(
host = get_env_host()
if host == "0.0.0.0":
host = "127.0.0.1"
self._base_url = (
endpoint or get_env_collector_endpoint() or f"http://{host}:{get_env_port()}"
)
base_url = endpoint or get_env_collector_endpoint() or f"http://{host}:{get_env_port()}"
self._base_url = base_url if base_url.endswith("/") else base_url + "/"

self._session = Session()
weakref.finalize(self, self._session.close)
if warn_if_server_not_running:
Expand Down Expand Up @@ -99,7 +99,7 @@ def query_spans(
)
end_time = end_time or stop_time
response = self._session.post(
url=urljoin(self._base_url, "/v1/spans"),
url=urljoin(self._base_url, "v1/spans"),
params={"project-name": project_name},
json={
"queries": [q.to_dict() for q in queries],
Expand Down Expand Up @@ -146,7 +146,7 @@ def get_evaluations(
"""
project_name = project_name or get_env_project_name()
response = self._session.get(
urljoin(self._base_url, "/v1/evaluations"),
urljoin(self._base_url, "v1/evaluations"),
params={"project-name": project_name},
)
if response.status_code == 404:
Expand All @@ -167,7 +167,7 @@ def get_evaluations(

def _warn_if_phoenix_is_not_running(self) -> None:
try:
self._session.get(urljoin(self._base_url, "/arize_phoenix_version")).raise_for_status()
self._session.get(urljoin(self._base_url, "arize_phoenix_version")).raise_for_status()
except Exception:
logger.warning(
f"Arize Phoenix is not running on {self._base_url}. Launch Phoenix "
Expand Down Expand Up @@ -198,7 +198,7 @@ def log_evaluations(self, *evals: Evaluations, **kwargs: Any) -> None:
with pa.ipc.new_stream(sink, table.schema) as writer:
writer.write_table(table)
self._session.post(
urljoin(self._base_url, "/v1/evaluations"),
urljoin(self._base_url, "v1/evaluations"),
data=cast(bytes, sink.getvalue().to_pybytes()),
headers=headers,
).raise_for_status()
Expand Down Expand Up @@ -241,7 +241,7 @@ def log_traces(self, trace_dataset: TraceDataset, project_name: Optional[str] =
serialized = otlp_span.SerializeToString()
data = gzip.compress(serialized)
self._session.post(
urljoin(self._base_url, "/v1/traces"),
urljoin(self._base_url, "v1/traces"),
data=data,
headers={
"content-type": "application/x-protobuf",
Expand Down
6 changes: 1 addition & 5 deletions src/phoenix/session/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
Union,
cast,
)
from urllib.parse import urljoin

import pandas as pd
from google.protobuf.wrappers_pb2 import DoubleValue, StringValue
Expand Down Expand Up @@ -147,8 +146,5 @@ def log_evaluations(
if host == "0.0.0.0":
host = "127.0.0.1"
port = port or get_env_port()
endpoint = endpoint or urljoin(
get_env_collector_endpoint() or f"http://{host}:{port}",
"/v1/traces",
)
endpoint = endpoint or get_env_collector_endpoint() or f"http://{host}:{port}"
Client(endpoint=endpoint).log_evaluations(*evals)
21 changes: 11 additions & 10 deletions src/phoenix/trace/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@ def __init__(self) -> None:
host = get_env_host()
if host == "0.0.0.0":
host = "127.0.0.1"
endpoint = urljoin(
get_env_collector_endpoint() or f"http://{host}:{get_env_port()}",
"/v1/traces",
)
_warn_if_phoenix_is_not_running(endpoint)
base_url = get_env_collector_endpoint() or f"http://{host}:{get_env_port()}"
base_url = base_url if base_url.endswith("/") else base_url + "/"
_warn_if_phoenix_is_not_running(base_url)

endpoint = urljoin(base_url, "v1/traces")
super().__init__(endpoint)


Expand Down Expand Up @@ -68,11 +68,12 @@ def __init__(
"""
self._host = host or get_env_host()
self._port = port or get_env_port()
self._base_url = (
base_url = (
endpoint
or get_env_collector_endpoint()
or f"http://{'127.0.0.1' if self._host == '0.0.0.0' else self._host}:{self._port}"
)
self._base_url = base_url if base_url.endswith("/") else base_url + "/"
_warn_if_phoenix_is_not_running(self._base_url)
self._session = Session()
weakref.finalize(self, self._session.close)
Expand Down Expand Up @@ -117,16 +118,16 @@ def _send(self, message: Message) -> None:

def _url(self, message: Message) -> str:
if isinstance(message, pb.Evaluation):
return urljoin(self._base_url, "/v1/evaluations")
return urljoin(self._base_url, "v1/evaluations")
logger.exception(f"unrecognized message type: {type(message)}")
assert_never(message)


def _warn_if_phoenix_is_not_running(endpoint: str) -> None:
def _warn_if_phoenix_is_not_running(base_url: str) -> None:
try:
requests.get(urljoin(endpoint, "/arize_phoenix_version")).raise_for_status()
requests.get(urljoin(base_url, "arize_phoenix_version")).raise_for_status()
except Exception:
logger.warning(
f"Arize Phoenix is not running on {endpoint}. Launch Phoenix "
f"Arize Phoenix is not running on {base_url}. Launch Phoenix "
f"with `import phoenix as px; px.launch_app()`"
)
23 changes: 23 additions & 0 deletions tests/session/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,29 @@
from phoenix.trace.trace_dataset import TraceDataset


def test_base_path(monkeypatch: pytest.MonkeyPatch):
# Reset environment variables
monkeypatch.delenv("PHOENIX_HOST", False)
monkeypatch.delenv("PHOENIX_PORT", False)
monkeypatch.delenv("PHOENIX_COLLECTOR_ENDPOINT", False)

# Test that host and port environment variables are interpreted correctly
monkeypatch.setenv("PHOENIX_HOST", "my-host")
monkeypatch.setenv("PHOENIX_PORT", "1234")
client = Client()
assert client._base_url == "http://my-host:1234/"

# Test that a collector endpoint environment variables takes precedence
monkeypatch.setenv("PHOENIX_COLLECTOR_ENDPOINT", "http://my-collector-endpoint/with/prefix")
client = Client()
assert client._base_url == "http://my-collector-endpoint/with/prefix/"

# Test a given endpoint takes precedence over environment variables
endpoint = "https://other-collector-endpoint/with/other/prefix"
client = Client(endpoint=endpoint)
assert client._base_url == "https://other-collector-endpoint/with/other/prefix/"


@responses.activate
def test_get_spans_dataframe(client: Client, endpoint: str, dataframe: pd.DataFrame):
url = urljoin(endpoint, "v1/spans")
Expand Down
15 changes: 10 additions & 5 deletions tests/trace/test_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,24 @@ def test_exporter(monkeypatch: pytest.MonkeyPatch):
# Test that it defaults to local
monkeypatch.delenv("PHOENIX_COLLECTOR_ENDPOINT", False)
exporter = HttpExporter()
assert exporter._base_url == f"http://127.0.0.1:{PORT}"
assert exporter._base_url == f"http://127.0.0.1:{PORT}/"

# Test that you can configure host and port
host, port = "abcd", 1234
exporter = HttpExporter(host=host, port=port)
assert exporter._base_url == f"http://{host}:{port}"
assert exporter._base_url == f"http://{host}:{port}/"

# Test that you can configure an endpoint
endpoint = "https://my-phoenix-server.com/"
endpoint = "https://my-phoenix-server.com"
exporter = HttpExporter(endpoint=endpoint)
assert exporter._base_url == endpoint
assert exporter._base_url == "https://my-phoenix-server.com/"

# Test that it supports environment variables
monkeypatch.setenv("PHOENIX_COLLECTOR_ENDPOINT", endpoint)
exporter = HttpExporter()
assert exporter._base_url == endpoint
assert exporter._base_url == "https://my-phoenix-server.com/"

# Test that an endpoint with a root path prefix is interpreted correctly
endpoint = "https://my-phoenix-server.com/my/root/path"
exporter = HttpExporter(endpoint=endpoint)
assert exporter._base_url == "https://my-phoenix-server.com/my/root/path/"

0 comments on commit e27cc5d

Please sign in to comment.