From a935c3968ceeb32db1e13e4f7e5a3a6bbb9d8bac Mon Sep 17 00:00:00 2001
From: Xuefeng Gu
Date: Mon, 4 May 2026 19:00:45 +0000
Subject: [PATCH] Respect enable_tensorboard=False

Promote the no-op SummaryWriter stub in gcloud_stub.py from a private class
defined inside the import-fallback branch to a module-level, exported
StubSummaryWriter, and reuse it for the fallback `writer.SummaryWriter`.

Pass config.enable_tensorboard from MetricLogger into
max_utils.initialize_summary_writer (new parameter, default True, so existing
two-argument callers keep their behavior). When it is False the function logs
a message and returns a StubSummaryWriter instead of a real writer.

MetricLogger.write_setup_info_to_tensorboard now returns immediately when
enable_tensorboard is False.

---
NOTE(review): the line structure of this patch was reconstructed from a
whitespace-collapsed copy. Indentation and blank context lines are inferred
from the hunk header counts; run `git apply --check` before applying. The
author e-mail address was not present in that copy.

NOTE(review): the early return added to write_setup_info_to_tensorboard sits
above the lines that populate self.metadata[MetadataKey.PER_DEVICE_TFLOPS]
and self.metadata[MetadataKey.PER_DEVICE_TOKENS]. Confirm nothing reads those
metadata entries when TensorBoard is disabled, or move the guard below them.

 src/maxtext/common/gcloud_stub.py   | 47 +++++++++++++++--------------
 src/maxtext/common/metric_logger.py |  4 ++-
 src/maxtext/utils/max_utils.py      |  8 +++--
 3 files changed, 34 insertions(+), 25 deletions(-)

diff --git a/src/maxtext/common/gcloud_stub.py b/src/maxtext/common/gcloud_stub.py
index a102629ca5..2ecc96bac7 100644
--- a/src/maxtext/common/gcloud_stub.py
+++ b/src/maxtext/common/gcloud_stub.py
@@ -609,6 +609,29 @@ def _import():
 
 
 # ------------------------- TensorBoardX --------------------------
+
+class StubSummaryWriter:
+  """Stubbed TensorBoardX SummaryWriter replacement."""
+
+  def __init__(self, *args, **kwargs):  # pylint: disable=unused-argument
+    del args, kwargs
+
+  def add_text(self, *args, **kwargs):
+    pass
+
+  def add_scalar(self, *args, **kwargs):
+    pass
+
+  def add_histogram(self, *args, **kwargs):
+    pass
+
+  def flush(self):
+    pass
+
+  def close(self):
+    pass
+
+
 try:
   if not is_decoupled():  # Only attempt real import when not decoupled
     from tensorboardX import writer  # type: ignore  # pylint: disable=import-outside-toplevel,unused-import
@@ -619,30 +642,10 @@ def _import():
 except Exception:  # pragma: no cover - provide stub fallback  # pylint: disable=broad-exception-caught
   _TENSORBOARDX_AVAILABLE = False
 
-  class _StubSummaryWriter:
-    """Stubbed TensorBoardX SummaryWriter replacement."""
-
-    def __init__(self, *args, **kwargs):  # pylint: disable=unused-argument
-      del args, kwargs
-
-    def add_text(self, *args, **kwargs):
-      pass
-
-    def add_scalar(self, *args, **kwargs):
-      pass
-
-    def add_histogram(self, *args, **kwargs):
-      pass
-
-    def flush(self):
-      pass
-
-    def close(self):
-      pass
-
   class writer:  # pylint: disable=too-few-public-methods
-    SummaryWriter = _StubSummaryWriter
+    SummaryWriter = StubSummaryWriter
 
 
 __all__.append("writer")
 __all__.append("_TENSORBOARDX_AVAILABLE")
+__all__.append("StubSummaryWriter")
diff --git a/src/maxtext/common/metric_logger.py b/src/maxtext/common/metric_logger.py
index 9c119881ab..dc90becbb8 100644
--- a/src/maxtext/common/metric_logger.py
+++ b/src/maxtext/common/metric_logger.py
@@ -92,7 +92,7 @@ class MetricLogger:
   """
 
   def __init__(self, config, learning_rate_schedule):
-    self.writer = max_utils.initialize_summary_writer(config.tensorboard_dir, config.run_name)
+    self.writer = max_utils.initialize_summary_writer(config.tensorboard_dir, config.run_name, config.enable_tensorboard)
     self.config = config
     self.metadata = {}
     self.running_gcs_metrics = [] if config.gcs_metrics else None
@@ -295,6 +295,8 @@ def write_metrics_to_managed_mldiagnostics(self, metrics, step):
 
   def write_setup_info_to_tensorboard(self, params):
     """Writes setup information like train config params, num model params, and XLA flags to TensorBoard."""
+    if not self.config.enable_tensorboard:
+      return
     num_model_parameters = max_utils.calculate_num_params_from_pytree(params)
     self.metadata[MetadataKey.PER_DEVICE_TFLOPS], _, _ = maxtext_utils.calculate_tflops_training_per_device(self.config)
     self.metadata[MetadataKey.PER_DEVICE_TOKENS] = maxtext_utils.calculate_tokens_training_per_device(self.config)
diff --git a/src/maxtext/utils/max_utils.py b/src/maxtext/utils/max_utils.py
index f7536bef0b..daaf2d8904 100644
--- a/src/maxtext/utils/max_utils.py
+++ b/src/maxtext/utils/max_utils.py
@@ -46,5 +46,5 @@ from maxtext.utils import elastic_utils
 from maxtext.common.gcloud_stub import is_decoupled
-from maxtext.common.gcloud_stub import writer, _TENSORBOARDX_AVAILABLE
+from maxtext.common.gcloud_stub import writer, _TENSORBOARDX_AVAILABLE, StubSummaryWriter
 from maxtext.utils import max_logging
 from maxtext.common.common_types import MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE, MODEL_MODE_TRAIN
 
@@ -182,7 +182,7 @@ def summarize_size_from_pytree(params):
   return num_params, num_bytes, num_bytes / num_params
 
 
-def initialize_summary_writer(tensorboard_dir, run_name):
+def initialize_summary_writer(tensorboard_dir, run_name, enable_tensorboard=True):
   """Return a tensorboardX SummaryWriter or a no-op stub.
 
   In decoupled mode (no Google Cloud), this prefers a repo-local
@@ -191,6 +191,10 @@ def initialize_summary_writer(tensorboard_dir, run_name):
   if jax.process_index() != 0:
     return None
 
+  if not enable_tensorboard:
+    max_logging.log("TensorBoard disabled; using no-op SummaryWriter.")
+    return StubSummaryWriter()
+
   if not _TENSORBOARDX_AVAILABLE:
     max_logging.log("tensorboardX not available; using no-op SummaryWriter.")
     return writer.SummaryWriter()