Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 77 additions & 5 deletions end_to_end/tpu/eval_assert.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,15 @@
limitations under the License.
"""

"""
Example to run
python end_to_end/tpu/eval_assert.py avg_tflops metrics.txt 100
python end_to_end/tpu/eval_assert.py avg_step_time metrics.txt 0.5 100
python end_to_end/tpu/eval_assert.py avg_step_time metrics.txt 0.5 100
"""



# pylint: skip-file
"""Reads and asserts over target values"""
from absl import app
Expand All @@ -34,26 +43,89 @@ def get_last_n_data(metrics_file, target, n=10):
return last_n_data


def test_final_loss(metrics_file, target_loss):
def test_final_loss(metrics_file, target_loss, num_samples_str="10"):
target_loss = float(target_loss)
num_samples = int(num_samples_str)
with open(metrics_file, "r", encoding="utf8") as _:
use_last_n_data = 10
last_n_data = get_last_n_data(metrics_file, "learning/loss", use_last_n_data)
last_n_data = get_last_n_data(metrics_file, "learning/loss",num_samples)
avg_last_n_data = sum(last_n_data) / len(last_n_data)
print(f"Mean of last {len(last_n_data)} losses is {avg_last_n_data}")
print(f"Target loss is {target_loss}")
assert avg_last_n_data < target_loss
print("Final loss test passed.")


def test_avg_step_time(metrics_file, max_avg_step_time_str, num_samples_str="10"):
"""Tests if the average of the last N step times is below a maximum threshold."""
max_avg_step_time = float(max_avg_step_time_str)
num_samples = int(num_samples_str)
metric_key = "perf/step_time_seconds"
last_n_step_times = get_last_n_data(metrics_file, metric_key, num_samples)

if not last_n_step_times:
raise ValueError(f"Metric '{metric_key}' not found or no data points in {metrics_file}.")

avg_last_n_step_time = sum(last_n_step_times) / len(last_n_step_times)

print(f"Found {len(last_n_step_times)} data points for '{metric_key}'.")
print(f"Mean of last {len(last_n_step_times)} step times is {avg_last_n_step_time:.4f} s")

assert (
avg_last_n_step_time < max_avg_step_time
), f"Average step time {avg_last_n_step_time:.4f}s is not less than target {max_avg_step_time}s."
print("Average step time test passed.")


def test_avg_tflops(metrics_file, min_avg_tflops_str, num_samples_str="10"):
"""Tests if the average of the last N TFLOPs/sec values is above a minimum threshold."""
min_avg_tflops = float(min_avg_tflops_str)
num_samples = int(num_samples_str)
metric_key = "perf/per_device_tflops_per_sec"

last_n_tflops = get_last_n_data(metrics_file, metric_key, num_samples)

if not last_n_tflops:
raise ValueError(f"Metric '{metric_key}' not found or no data points in {metrics_file}.")

avg_last_n_tflops = sum(last_n_tflops) / len(last_n_tflops)

print(f"Found {len(last_n_tflops)} data points for '{metric_key}'.")
print(f"Mean of last {len(last_n_tflops)} steps TFLOPs/sec is {avg_last_n_tflops:.2f}")

assert (
avg_last_n_tflops > min_avg_tflops
), f"Average TFLOPs/sec {avg_last_n_tflops:.2f} is not greater than target {min_avg_tflops}."
print("Average TFLOPs/sec test passed.")


def main(argv: Sequence[str]) -> None:
if len(argv) < 2:
print("Usage: python script.py <test_scenario> [test_vars...]")
print("Available scenarios: final_loss, avg_step_time, avg_tflops")
raise ValueError("Test scenario not specified.")

_, test_scenario, *test_vars = argv

if test_scenario == "final_loss":
test_final_loss(*test_vars)
if len(test_vars) < 2:
raise ValueError("Usage: final_loss <metrics_file> <target_loss> [num_samples]")
metrics_file, target_loss, *num_samples_opt = test_vars
num_samples = num_samples_opt[0] if num_samples_opt else "10"
test_final_loss(metrics_file, target_loss, num_samples)
elif test_scenario == "avg_step_time":
if len(test_vars) < 2:
raise ValueError("Usage: avg_step_time <metrics_file> <max_avg_step_time> [num_samples]")
metrics_file, max_avg_step_time, *num_samples_opt = test_vars
num_samples = num_samples_opt[0] if num_samples_opt else "10"
test_avg_step_time(metrics_file, max_avg_step_time, num_samples)
elif test_scenario == "avg_tflops":
if len(test_vars) < 2:
raise ValueError("Usage: avg_tflops <metrics_file> <min_avg_tflops> [num_samples]")
metrics_file, min_avg_tflops, *num_samples_opt = test_vars
num_samples = num_samples_opt[0] if num_samples_opt else "10"
test_avg_tflops(metrics_file, min_avg_tflops, num_samples)
else:
raise ValueError(f"Unrecognized test_scenario {test_scenario}")
raise ValueError(f"Unrecognized test_scenario '{test_scenario}'. Available: final_loss, avg_step_time, avg_tflops")


if __name__ == "__main__":
Expand Down
Empty file modified end_to_end/tpu/test_sdxl_training_loss.sh
100644 → 100755
Empty file.
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@ Pillow
pylint
pyink
pytest==8.2.2
tensorflow==2.17.0
tensorflow>=2.17.0
tensorflow-datasets>=4.9.6
ruff>=0.1.5,<=0.2
git+https://github.com/mlperf/logging.git
opencv-python-headless==4.10.0.84
orbax-checkpoint==0.10.3
tokenizers==0.21.0
huggingface_hub==0.24.7
huggingface_hub==0.30.2
transformers==4.48.1
einops==0.8.0
sentencepiece
Expand Down
2 changes: 1 addition & 1 deletion requirements_with_jax_stable_stack.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ ftfy
git+https://github.com/mlperf/logging.git
google-cloud-storage==2.17.0
grain-nightly==0.0.10
huggingface_hub==0.24.7
huggingface_hub==0.30.2
jax>=0.4.30
jaxlib>=0.4.30
Jinja2
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@
"filelock",
"flax>=0.4.1",
"hf-doc-builder>=0.3.0",
"huggingface-hub==0.24.7",
"huggingface-hub==0.30.0",
"requests-mock==1.10.0",
"importlib_metadata",
"invisible-watermark>=0.2.0",
Expand Down
4 changes: 4 additions & 0 deletions src/maxdiffusion/configs/base14.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ metrics_file: "" # for testing, local file that stores scalar metrics. If empty,
# If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/
write_metrics: True
gcs_metrics: True

timing_metrics_file: "" # for testing, local file that stores function timing metrics such as state creation, compilation. If empty, no metrics are written.
write_timing_metrics: True

# If true save config to GCS in {base_output_directory}/{run_name}/
save_config_to_gcs: False
log_period: 10000000000 # Flushes Tensorboard
Expand Down
4 changes: 4 additions & 0 deletions src/maxdiffusion/configs/base21.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ metrics_file: "" # for testing, local file that stores scalar metrics. If empty,
# If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/
write_metrics: True
gcs_metrics: True

timing_metrics_file: "" # for testing, local file that stores function timing metrics such as state creation, compilation. If empty, no metrics are written.
write_timing_metrics: True

# If true save config to GCS in {base_output_directory}/{run_name}/
save_config_to_gcs: False
log_period: 10000000000 # Flushes Tensorboard
Expand Down
4 changes: 4 additions & 0 deletions src/maxdiffusion/configs/base_2_base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ metrics_file: "" # for testing, local file that stores scalar metrics. If empty,
# If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/
write_metrics: True
gcs_metrics: False

timing_metrics_file: "" # for testing, local file that stores function timing metrics such as state creation, compilation. If empty, no metrics are written.
write_timing_metrics: True

# If true save config to GCS in {base_output_directory}/{run_name}/
save_config_to_gcs: False
log_period: 10000000000 # Flushes Tensorboard
Expand Down
4 changes: 4 additions & 0 deletions src/maxdiffusion/configs/base_flux_dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ run_name: ''
metrics_file: "" # for testing, local file that stores scalar metrics. If empty, no metrics are written.
# If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/
write_metrics: True

timing_metrics_file: "" # for testing, local file that stores function timing metrics such as state creation, compilation. If empty, no metrics are written.
write_timing_metrics: True

gcs_metrics: False
# If true save config to GCS in {base_output_directory}/{run_name}/
save_config_to_gcs: False
Expand Down
4 changes: 4 additions & 0 deletions src/maxdiffusion/configs/base_flux_dev_multi_res.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ metrics_file: "" # for testing, local file that stores scalar metrics. If empty,
# If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/
write_metrics: True
gcs_metrics: False

timing_metrics_file: "" # for testing, local file that stores function timing metrics such as state creation, compilation. If empty, no metrics are written.
write_timing_metrics: True

# If true save config to GCS in {base_output_directory}/{run_name}/
save_config_to_gcs: False
log_period: 100
Expand Down
4 changes: 4 additions & 0 deletions src/maxdiffusion/configs/base_flux_schnell.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ metrics_file: "" # for testing, local file that stores scalar metrics. If empty,
# If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/
write_metrics: True
gcs_metrics: False

timing_metrics_file: "" # for testing, local file that stores function timing metrics such as state creation, compilation. If empty, no metrics are written.
write_timing_metrics: True

# If true save config to GCS in {base_output_directory}/{run_name}/
save_config_to_gcs: False
log_period: 100
Expand Down
4 changes: 4 additions & 0 deletions src/maxdiffusion/configs/base_xl.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ run_name: ''
metrics_file: "" # for testing, local file that stores scalar metrics. If empty, no metrics are written.
# If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/
write_metrics: True

timing_metrics_file: "" # for testing, local file that stores function timing metrics such as state creation, compilation. If empty, no metrics are written.
write_timing_metrics: True

gcs_metrics: False
# If true save config to GCS in {base_output_directory}/{run_name}/
save_config_to_gcs: False
Expand Down
4 changes: 4 additions & 0 deletions src/maxdiffusion/configs/base_xl_lightning.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ run_name: ''
metrics_file: "" # for testing, local file that stores scalar metrics. If empty, no metrics are written.
# If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/
gcs_metrics: True

timing_metrics_file: "" # for testing, local file that stores function timing metrics such as state creation, compilation. If empty, no metrics are written.
write_timing_metrics: True

# If true save config to GCS in {base_output_directory}/{run_name}/
save_config_to_gcs: False
log_period: 100
Expand Down
9 changes: 8 additions & 1 deletion src/maxdiffusion/max_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@
# pylint: disable=bare-except, consider-using-generator
""" Common Max Utils needed by multiple modules"""
import functools
from functools import reduce
from functools import partial, reduce
from contextlib import nullcontext
from typing import Dict, Callable
import json
import yaml
import os
Expand Down Expand Up @@ -564,6 +565,12 @@ def calculate_model_tflops(module: module_lib.Module, rngs: Union[PRNGKey, RNGSe
return total_flops


def get_train_step_partial_with_signature(train_step: Callable, pipeline: object, params: Dict, config: object) -> Callable:
partial_train = partial(train_step, pipeline=pipeline, params=params, config=config)
partial_train.__name__ = "train_step"
return partial_train


def calculate_num_params_from_pytree(params):
"""Calculates number of parameters from a pytree"""
params_sizes = jax.tree_util.tree_map(jax.numpy.size, params)
Expand Down
79 changes: 65 additions & 14 deletions src/maxdiffusion/trainers/base_stable_diffusion_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,25 @@
"""

from abc import abstractmethod
import time
from typing import Any, Callable
import jax
from maxdiffusion import (max_utils, maxdiffusion_utils, max_logging)

from maxdiffusion.checkpointing.base_stable_diffusion_checkpointer import (BaseStableDiffusionCheckpointer)

# Define a filename for logging


def _log_to_file(message: str, log_file: str = ""):
"""Appends a message to the global log file with a timestamp."""
timestamp = time.strftime("%Y-%m-%d %H:%M:%S %Z", time.localtime())
full_message = f"[{timestamp}] {message}\n"
if log_file:
with open(log_file, "a") as f:
f.write(full_message)
max_logging.log(full_message.strip())


class BaseStableDiffusionTrainer(BaseStableDiffusionCheckpointer):

Expand Down Expand Up @@ -67,6 +81,29 @@ def get_data_shardings(self):
def create_scheduler(self, pipeline, params):
pass

def _time_and_log_call(
self, func_obj: Callable[..., Any], *func_args: Any, description: str = "", **func_kwargs: Any
) -> Any:
"""
Times a function call, logs its duration, and returns its result.
"""
if not description:
if hasattr(func_obj, "__name__"):
description = func_obj.__name__
elif hasattr(func_obj, "__call__") and hasattr(type(func_obj), "__name__"):
description = type(func_obj).__name__
log_file = ""

if self.config.write_timing_metrics and self.config.timing_metrics_file:
log_file = self.config.get.timing_metrics_file
_log_to_file(f"Starting: {description}...", log_file=log_file)
start_time = time.perf_counter() # Use perf_counter for more precise duration measurement
result = func_obj(*func_args, **func_kwargs)
end_time = time.perf_counter()
duration = end_time - start_time
_log_to_file(f"Finished: {description} - Duration: {duration:.4f} seconds", log_file=log_file)
return result

def calculate_tflops(self, pipeline, params):
per_device_tflops = maxdiffusion_utils.calculate_unet_tflops(
self.config, pipeline, (self.config.per_device_batch_size * jax.local_device_count()), self.rng, train=True
Expand All @@ -75,22 +112,28 @@ def calculate_tflops(self, pipeline, params):
return per_device_tflops

def start_training(self):

# Hook
self.pre_training_steps()
# Load checkpoint - will load or create states
pipeline, params = self.load_checkpoint()
pipeline, params = self._time_and_log_call(self.load_checkpoint)
# create train states
train_states = {}
state_shardings = {}
vae_state, vae_state_mesh_shardings = self.create_vae_state(
pipeline=pipeline, params=params, checkpoint_item_name="vae_state", is_training=False
vae_state, vae_state_mesh_shardings = self._time_and_log_call(
self.create_vae_state,
# Arguments for create_vae_state
pipeline=pipeline,
params=params,
checkpoint_item_name="vae_state",
is_training=False,
)

train_states["vae_state"] = vae_state
state_shardings["vae_state_shardings"] = vae_state_mesh_shardings

text_encoder_state, text_encoder_state_mesh_shardings = self.create_text_encoder_state(
text_encoder_state, text_encoder_state_mesh_shardings = self._time_and_log_call(
self.create_text_encoder_state,
# Arguments for create_text_encoder_state
pipeline=pipeline,
params=params,
checkpoint_item_name="text_encoder_state",
Expand All @@ -99,8 +142,13 @@ def start_training(self):
train_states["text_encoder_state"] = text_encoder_state
state_shardings["text_encoder_state_shardings"] = text_encoder_state_mesh_shardings
if hasattr(pipeline, "text_encoder_2"):
text_encoder_2_state, text_encoder_2_state_mesh_shardings = self.create_text_encoder_2_state(
pipeline, params, "text_encoder_2_state", is_training=self.config.train_text_encoder
text_encoder_2_state, text_encoder_2_state_mesh_shardings = self._time_and_log_call(
self.create_text_encoder_2_state,
# Arguments for create_text_encoder_2_state
pipeline=pipeline,
params=params,
checkpoint_item_name="text_encoder_2_state",
is_training=self.config.train_text_encoder,
)
train_states["text_encoder_2_state"] = text_encoder_2_state
state_shardings["text_encoder_2_state_shardings"] = text_encoder_2_state_mesh_shardings
Expand All @@ -115,11 +163,12 @@ def start_training(self):
self.per_device_tflops = per_device_tflops

# Load dataset
data_iterator = self.load_dataset(pipeline, params, train_states)
data_iterator = self._time_and_log_call(self.load_dataset, pipeline, params, train_states)
if self.config.dataset_type == "grain":
data_iterator = self.restore_data_iterator_state(data_iterator)
data_iterator = self._time_and_log_call(self.restore_data_iterator_state, data_iterator=data_iterator)

unet_state, unet_state_mesh_shardings, unet_learning_rate_scheduler = self.create_unet_state(
unet_state, unet_state_mesh_shardings, unet_learning_rate_scheduler = self._time_and_log_call(
self.create_unet_state,
# ambiguous here, but if self.params.get("unet") doesn't exist
# Then its 1 of 2 scenarios:
# 1. unet state will be loaded directly from orbax
Expand All @@ -134,11 +183,13 @@ def start_training(self):

data_shardings = self.get_data_shardings()
# Compile train_step
p_train_step = self.compile_train_step(pipeline, params, train_states, state_shardings, data_shardings)
p_train_step = self._time_and_log_call(
self.compile_train_step, pipeline, params, train_states, state_shardings, data_shardings
)
# Start training
train_states = self.training_loop(
p_train_step, pipeline, params, train_states, data_iterator, unet_learning_rate_scheduler
train_states = self._time_and_log_call(
self.training_loop, p_train_step, pipeline, params, train_states, data_iterator, unet_learning_rate_scheduler
)
# 6. save final checkpoint
# Hook
self.post_training_steps(pipeline, params, train_states)
self._time_and_log_call(self.post_training_steps, pipeline, params, train_states)
Loading
Loading