Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions pathwaysutils/collect_profile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for collecting JAX profiles for Pathways on Cloud.

This is a replacement for the `collect_profile` script in JAX that works with
Pathways on Cloud.
"""

import argparse
import logging

from pathwaysutils import profiling

_logger = logging.getLogger(__name__)


_DESCRIPTION = """
To profile running JAX programs, you first need to start the profiler server
in the program of interest. You can do this via
`jax.profiler.start_server(<port>)`. Once the program is running and the
profiler server has started, you can run `collect_profile` to trace the execution
for a provided duration. The trace file will be dumped into a GCS bucket
(determined by `--log_dir`).
"""
parser = argparse.ArgumentParser(description=_DESCRIPTION)
parser.add_argument(
"--log_dir",
required=True,
help="GCS path to store log files.",
type=str,
)
parser.add_argument("port", help="Port to collect trace", type=int)
parser.add_argument(
"duration_ms", help="Duration to collect trace in milliseconds", type=int
)
parser.add_argument(
"--host",
default="127.0.0.1",
help=(
"Host to collect trace. This host IP/DNS address should be accessible"
" from where this API is being called. Defaults to 127.0.0.1"
),
type=str,
)


def main(args):
if profiling.collect_profile(
args.port, args.duration_ms, args.host, args.log_dir
):
_logger.info("Dumped profiling information in: %s", args.log_dir)
else:
_logger.error("Failed to collect profiling information.")

if __name__ == "__main__":
main(parser.parse_args())
69 changes: 59 additions & 10 deletions pathwaysutils/profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,28 @@

import dataclasses
import logging
import os
import pathlib
import tempfile
import threading
import time
import urllib.parse

import fastapi
import jax
from jax import numpy as jnp
from pathwaysutils import plugin_executable
import requests
import uvicorn

logger = logging.getLogger(__name__)

_logger = logging.getLogger(__name__)


class _ProfileState:
executable: plugin_executable.PluginExecutable | None = None
lock: threading.Lock

def __init__(self):
self.executable = None
self.lock = threading.Lock()
Expand Down Expand Up @@ -88,7 +97,7 @@ def stop_trace():
_original_stop_trace()


_profiler_thread = None
_profiler_thread: threading.Thread | None = None


def start_server(port: int):
Expand All @@ -102,7 +111,7 @@ def start_server(port: int):
port : The port to start the server on.
"""
def server_loop(port: int):
logger.debug("Starting JAX profiler server on port %s", port)
_logger.debug("Starting JAX profiler server on port %s", port)
app = fastapi.FastAPI()

@dataclasses.dataclass
Expand All @@ -112,14 +121,14 @@ class ProfilingConfig:

@app.post("/profiling")
async def profiling(pc: ProfilingConfig): # pylint: disable=unused-variable
logger.debug("Capturing profiling data for %s ms", pc.duration_ms)
logger.debug("Writing profiling data to %s", pc.repository_path)
_logger.debug("Capturing profiling data for %s ms", pc.duration_ms)
_logger.debug("Writing profiling data to %s", pc.repository_path)
jax.profiler.start_trace(pc.repository_path)
time.sleep(pc.duration_ms / 1e3)
jax.profiler.stop_trace()
return {"response": "profiling completed"}

uvicorn.run(app, host="0.0.0.0", port=port, log_level="info")
uvicorn.run(app, host="0.0.0.0", port=port, log_level="debug")

global _profiler_thread
if _profiler_thread is not None:
Expand All @@ -138,6 +147,44 @@ def stop_server():
raise ValueError("No active profiler server.")


def collect_profile(
port: int,
duration_ms: int,
host: str,
log_dir: str,
) -> bool:
"""Collects a JAX profile and saves it to the specified directory.

Args:
port: The port on which the JAX profiler server is running.
duration_ms: The duration in milliseconds for which to collect the profile.
host: The host on which the JAX profiler server is running.
log_dir: The GCS path to save the profile data.

Returns:
True if the profile was collected successfully, False otherwise.

Raises:
ValueError: If the log_dir is not a GCS path.
"""
if not log_dir.startswith("gs://"):
raise ValueError("log_dir must be a GCS path.")

json = {
"duration_ms": duration_ms,
"repository_path": log_dir,
}
address = urllib.parse.urljoin(f"http://{host}:{port}", "profiling")
try:
response = requests.post(address, json=json)
response.raise_for_status()
except requests.exceptions.RequestException as e:
_logger.error("Failed to collect profiling data: %s", e)
return False

return True


def monkey_patch_jax():
"""Monkey patches JAX with Pathways versions of functions.

Expand All @@ -158,25 +205,27 @@ def start_trace_patch(
create_perfetto_link: bool = False, # pylint: disable=unused-argument
create_perfetto_trace: bool = False, # pylint: disable=unused-argument
) -> None:
logger.debug("jax.profile.start_trace patched with pathways' start_trace")
_logger.debug("jax.profile.start_trace patched with pathways' start_trace")
return start_trace(log_dir)

jax.profiler.start_trace = start_trace_patch

def stop_trace_patch() -> None:
logger.debug("jax.profile.stop_trace patched with pathways' stop_trace")
_logger.debug("jax.profile.stop_trace patched with pathways' stop_trace")
return stop_trace()

jax.profiler.stop_trace = stop_trace_patch

def start_server_patch(port: int):
logger.debug("jax.profile.start_server patched with pathways' start_server")
_logger.debug(
"jax.profile.start_server patched with pathways' start_server"
)
return start_server(port)

jax.profiler.start_server = start_server_patch

def stop_server_patch():
logger.debug("jax.profile.stop_server patched with pathways' stop_server")
_logger.debug("jax.profile.stop_server patched with pathways' stop_server")
return stop_server()

jax.profiler.stop_server = stop_server_patch
147 changes: 147 additions & 0 deletions pathwaysutils/test/profiling_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest import mock

from pathwaysutils import profiling
import requests

from absl.testing import absltest
from absl.testing import parameterized


class ProfilingTest(parameterized.TestCase):
"""Tests for Pathways on Cloud profiling."""

def setUp(self):
super().setUp()
self.mock_post = self.enter_context(
mock.patch.object(requests, "post", autospec=True)
)

@parameterized.parameters(8000, 1234)
def test_collect_profile_port(self, port):
profiling.collect_profile(
port=port,
duration_ms=1000,
host="127.0.0.1",
log_dir="gs://test_bucket/test_dir",
)

self.mock_post.assert_called_once_with(
f"http://127.0.0.1:{port}/profiling",
json={
"duration_ms": 1000,
"repository_path": "gs://test_bucket/test_dir",
},
)

@parameterized.parameters(1000, 1234)
def test_collect_profile_duration_ms(self, duration_ms):
profiling.collect_profile(
port=8000,
duration_ms=duration_ms,
host="127.0.0.1",
log_dir="gs://test_bucket/test_dir",
)

self.mock_post.assert_called_once_with(
"http://127.0.0.1:8000/profiling",
json={
"duration_ms": duration_ms,
"repository_path": "gs://test_bucket/test_dir",
},
)

@parameterized.parameters("127.0.0.1", "localhost", "192.168.1.1")
def test_collect_profile_host(self, host):
profiling.collect_profile(
port=8000,
duration_ms=1000,
host=host,
log_dir="gs://test_bucket/test_dir",
)

self.mock_post.assert_called_once_with(
f"http://{host}:8000/profiling",
json={
"duration_ms": 1000,
"repository_path": "gs://test_bucket/test_dir",
},
)

@parameterized.parameters(
"gs://test_bucket/test_log_dir",
"gs://test_bucket2",
"gs://test_bucket3/test/log/dir",
)
def test_collect_profile_log_dir(self, log_dir):
profiling.collect_profile(
port=8000, duration_ms=1000, host="127.0.0.1", log_dir=log_dir
)

self.mock_post.assert_called_once_with(
"http://127.0.0.1:8000/profiling",
json={
"duration_ms": 1000,
"repository_path": log_dir,
},
)

@parameterized.parameters("/logs/test_log_dir", "relative_path/my_log_dir")
def test_collect_profile_log_dir_error(self, log_dir):
with self.assertRaises(ValueError):
profiling.collect_profile(
port=8000, duration_ms=1000, host="127.0.0.1", log_dir=log_dir
)

@parameterized.parameters(
requests.exceptions.ConnectionError,
requests.exceptions.Timeout,
requests.exceptions.TooManyRedirects,
requests.exceptions.RequestException,
requests.exceptions.HTTPError,
)
def test_collect_profile_request_error(self, exception_type):
self.mock_post.side_effect = exception_type

result = profiling.collect_profile(
port=8000,
duration_ms=1000,
host="127.0.0.1",
log_dir="gs://test_bucket/test_dir",
)

self.assertFalse(result)
self.mock_post.assert_called_once()

def test_collect_profile_success(self):
mock_response = mock.Mock()
mock_response.raise_for_status.return_value = None
self.mock_post.return_value = mock_response

result = profiling.collect_profile(
port=8000,
duration_ms=1000,
host="127.0.0.1",
log_dir="gs://test_bucket/test_dir",
)

self.assertTrue(result)
self.mock_post.assert_called_once()
mock_response.raise_for_status.assert_called_once()


if __name__ == "__main__":
absltest.main()