Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 25 additions & 22 deletions src/maxtext/common/gcloud_stub.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,29 @@ def _import():

# ------------------------- TensorBoardX --------------------------


class StubSummaryWriter:
"""Stubbed TensorBoardX SummaryWriter replacement."""

def __init__(self, *args, **kwargs): # pylint: disable=unused-argument
del args, kwargs

def add_text(self, *args, **kwargs):
pass

def add_scalar(self, *args, **kwargs):
pass

def add_histogram(self, *args, **kwargs):
pass

def flush(self):
pass

def close(self):
pass


try:
if not is_decoupled(): # Only attempt real import when not decoupled
from tensorboardX import writer # type: ignore # pylint: disable=import-outside-toplevel,unused-import
Expand All @@ -619,30 +642,10 @@ def _import():
except Exception: # pragma: no cover - provide stub fallback # pylint: disable=broad-exception-caught
_TENSORBOARDX_AVAILABLE = False

class _StubSummaryWriter:
"""Stubbed TensorBoardX SummaryWriter replacement."""

def __init__(self, *args, **kwargs): # pylint: disable=unused-argument
del args, kwargs

def add_text(self, *args, **kwargs):
pass

def add_scalar(self, *args, **kwargs):
pass

def add_histogram(self, *args, **kwargs):
pass

def flush(self):
pass

def close(self):
pass

class writer: # pylint: disable=too-few-public-methods
SummaryWriter = _StubSummaryWriter
SummaryWriter = StubSummaryWriter


__all__.append("writer")
__all__.append("_TENSORBOARDX_AVAILABLE")
__all__.append("StubSummaryWriter")
4 changes: 3 additions & 1 deletion src/maxtext/common/metric_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ class MetricLogger:
"""

def __init__(self, config, learning_rate_schedule):
self.writer = max_utils.initialize_summary_writer(config.tensorboard_dir, config.run_name)
self.writer = max_utils.initialize_summary_writer(config.tensorboard_dir, config.run_name, config.enable_tensorboard)
self.config = config
self.metadata = {}
self.running_gcs_metrics = [] if config.gcs_metrics else None
Expand Down Expand Up @@ -295,6 +295,8 @@ def write_metrics_to_managed_mldiagnostics(self, metrics, step):

def write_setup_info_to_tensorboard(self, params):
"""Writes setup information like train config params, num model params, and XLA flags to TensorBoard."""
if not self.config.enable_tensorboard:
return
num_model_parameters = max_utils.calculate_num_params_from_pytree(params)
self.metadata[MetadataKey.PER_DEVICE_TFLOPS], _, _ = maxtext_utils.calculate_tflops_training_per_device(self.config)
self.metadata[MetadataKey.PER_DEVICE_TOKENS] = maxtext_utils.calculate_tokens_training_per_device(self.config)
Expand Down
8 changes: 6 additions & 2 deletions src/maxtext/utils/max_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@

from maxtext.utils import elastic_utils
from maxtext.common.gcloud_stub import is_decoupled
from maxtext.common.gcloud_stub import writer, _TENSORBOARDX_AVAILABLE
from maxtext.common.gcloud_stub import writer, _TENSORBOARDX_AVAILABLE, StubSummaryWriter
from maxtext.utils import max_logging
from maxtext.common.common_types import MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE, MODEL_MODE_TRAIN

Expand Down Expand Up @@ -182,7 +182,7 @@ def summarize_size_from_pytree(params):
return num_params, num_bytes, num_bytes / num_params


def initialize_summary_writer(tensorboard_dir, run_name):
def initialize_summary_writer(tensorboard_dir, run_name, enable_tensorboard=True):
"""Return a tensorboardX SummaryWriter or a no-op stub.

In decoupled mode (no Google Cloud), this prefers a repo-local
Expand All @@ -191,6 +191,10 @@ def initialize_summary_writer(tensorboard_dir, run_name):
if jax.process_index() != 0:
return None

if not enable_tensorboard:
max_logging.log("TensorBoard disabled; using no-op SummaryWriter.")
return StubSummaryWriter()

if not _TENSORBOARDX_AVAILABLE:
max_logging.log("tensorboardX not available; using no-op SummaryWriter.")
return writer.SummaryWriter()
Expand Down
Loading