Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions axlearn/cloud/gcp/measurement.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,8 @@ class Config(measurement.Recorder.Config):
upload_interval: Required[int] = REQUIRED
rolling_window_size: Sequence[int] = []
jax_backend: Optional[str] = None
# Disabled by default because of performance degradation. This doesn't disable goodput
# recording.
# TODO (apolloreno): once the performance degradation is fixed, will change default to True
enable_monitoring: bool = False
# Enable or disable monitoring. Recording is always enabled.
enable_monitoring: bool = True
Comment thread
dipannita08 marked this conversation as resolved.
Comment thread
dipannita08 marked this conversation as resolved.

@classmethod
def from_flags(cls, fv: flags.FlagValues) -> "GoodputRecorder":
Expand All @@ -77,7 +75,7 @@ def from_flags(cls, fv: flags.FlagValues) -> "GoodputRecorder":
- rolling_window_size: Comma-separated list of integers representing rolling window
sizes in seconds.
- jax_backend: The type of jax backend.
- enable_monitoring: Boolean to enable/disable goodput monitoring (default: false).
- enable_monitoring: Boolean to enable/disable goodput monitoring (default: true).
"""
cfg: measurement.Recorder.Config = cls.default_config()
parsed_flags = parse_kv_flags(fv.recorder_spec, delimiter="=")
Expand Down
24 changes: 12 additions & 12 deletions axlearn/cloud/gcp/measurement_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def test_from_flags(
recorder_spec,
expected_rolling_window_size,
expected_jax_backend,
expected_enable_monitoring=False,
expected_enable_monitoring=True,
):
"""Tests that flags are correctly parsed into the config."""
mock_fv = mock.MagicMock(spec=flags.FlagValues)
Expand Down Expand Up @@ -403,35 +403,35 @@ def test_maybe_monitor_all(
mock_monitor_instance.stop_rolling_window_goodput_uploader.assert_not_called()

@mock.patch("jax.process_index", return_value=0)
def test_enable_monitoring_disabled_by_default(self, _):
"""Tests that monitoring is disabled by default (enable_monitoring=False)."""
def test_enable_monitoring_enabled_by_default(self, _):
"""Tests that monitoring is enabled by default (enable_monitoring=True)."""
cfg = GoodputRecorder.default_config().set(
name="test-disabled",
name="test-default-enabled",
upload_dir="/test",
upload_interval=30,
# enable_monitoring defaults to False
# Enable_monitoring defaults to True.
rolling_window_size=[10, 20],
)
recorder = GoodputRecorder(cfg)

# Verify the flag defaults to False
self.assertFalse(recorder.config.enable_monitoring)
# Verify the flag defaults to True
self.assertTrue(recorder.config.enable_monitoring)

with mock.patch("ml_goodput_measurement.monitoring.GoodputMonitor") as mock_monitor_cls:
# Test that cumulative goodput monitoring is skipped
# Test that cumulative goodput monitoring is active by default.
with recorder._maybe_monitor_goodput():
pass
mock_monitor_cls.assert_not_called()
mock_monitor_cls.assert_called()

# Test that rolling window monitoring is skipped
# Test that rolling window monitoring is active by default.
with recorder._maybe_monitor_rolling_window_goodput():
pass
mock_monitor_cls.assert_not_called()
mock_monitor_cls.assert_called()

# Test that maybe_monitor_all is skipped
with recorder.maybe_monitor_all():
pass
mock_monitor_cls.assert_not_called()
mock_monitor_cls.assert_called()

@mock.patch("jax.process_index", return_value=0)
def test_enable_monitoring_explicitly_disabled(self, _):
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ gcp = [
"google-cloud-compute==1.19.2", # Needed for region discovery for CloudBuild API access.
"google-cloud-core==2.3.3",
"google-cloud-build==3.24.1",
"ml-goodput-measurement==0.0.14",
"ml-goodput-measurement==0.0.15",
"pika==1.3.2", # used by event queue
"pyOpenSSL>=22.1.0", # compat with cryptography version.
"tpu-info==0.2.0", # For TPU monitoring from libtpu. https://github.com/AI-Hypercomputer/cloud-accelerator-diagnostics/tree/main/tpu_info
Expand Down