2 changes: 2 additions & 0 deletions README.md
@@ -773,3 +773,5 @@ This script will automatically format your code with `pyink` and help you identi

The full suite of end-to-end tests is in `tests` and `src/maxdiffusion/tests`. We run them on a nightly cadence.

## Profiling
To learn how to enable ML Diagnostics and XProf profiling for your runs, please see our [ML Diagnostics Guide](docs/profiling.md).
34 changes: 34 additions & 0 deletions docs/profiling.md
@@ -0,0 +1,34 @@
# ML Diagnostics and Profiling

MaxDiffusion supports automated profiling and performance tracking via [Google Cloud ML Diagnostics](https://docs.cloud.google.com/tpu/docs/ml-diagnostics/sdk).

## 1. Manual Installation
To keep the core MaxDiffusion repository lightweight and avoid extra dependencies for users who don't need profiling, the ML Diagnostics packages are **not** installed by default.

To use this feature, you must manually install the required package in your environment:
```bash
pip install google-cloud-mldiagnostics
```

## 2. Configuration Settings
To enable ML Diagnostics for your training or generation jobs, update your configuration. You can either add these settings directly to your `.yml` config file or pass them as command-line arguments:

```yaml
# ML Diagnostics settings
enable_ml_diagnostics: True
profiler_gcs_path: "gs://<your-bucket-name>/profiler/ml_diagnostics"
enable_ondemand_xprof: True
```
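The same settings can also be supplied as `key=value` overrides on the command line. The script and config paths below are illustrative; substitute your own entry point and config:

```bash
python src/maxdiffusion/generate.py src/maxdiffusion/configs/base_xl.yml \
  enable_ml_diagnostics=True \
  profiler_gcs_path="gs://<your-bucket-name>/profiler/ml_diagnostics" \
  enable_ondemand_xprof=True
```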

## 3. GCS Bucket Permissions (Troubleshooting)
The GCS bucket you provide in `profiler_gcs_path` **must** have the correct IAM permissions to allow the Hypercompute Cluster service account to write data.

If permissions are not configured correctly, your job will fail with an error similar to this:
> `message: 'service-32478767326@gcp-sa-hypercomputecluster.iam.gserviceaccount.com does not have storage.buckets.get access to the GCS bucket <your-bucket>: permission denied'`

**Fix:** Grant the required Storage role (e.g., `Storage Object Admin`) on your GCS bucket to the service account named in the error message.
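For example, the role can be granted with `gcloud`; the bucket name is a placeholder, and the service account should be the one from your own error message:

```bash
gcloud storage buckets add-iam-policy-binding gs://<your-bucket> \
  --member="serviceAccount:service-32478767326@gcp-sa-hypercomputecluster.iam.gserviceaccount.com" \
  --role="roles/storage.objectAdmin"
```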

## 4. Viewing Your Runs
Once your job is running with diagnostics enabled, you can monitor the profiles, execution times, and metrics in the Cluster Director console here:

🔗 **https://pantheon.corp.google.com/cluster-director/diagnostics**
5 changes: 5 additions & 0 deletions src/maxdiffusion/configs/base14.yml
@@ -247,3 +247,8 @@ quantization: ''
quantization_local_shard_count: -1
use_qwix_quantization: False
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.

# ML Diagnostics settings
enable_ml_diagnostics: False
profiler_gcs_path: ""
enable_ondemand_xprof: False
7 changes: 6 additions & 1 deletion src/maxdiffusion/configs/base21.yml
@@ -247,4 +247,9 @@ quantization: ''
# Shard the range finding operation for quantization. By default this is set to number of slices.
quantization_local_shard_count: -1
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.
use_qwix_quantization: False
use_qwix_quantization: False

# ML Diagnostics settings
enable_ml_diagnostics: False
profiler_gcs_path: ""
enable_ondemand_xprof: False
5 changes: 5 additions & 0 deletions src/maxdiffusion/configs/base_2_base.yml
@@ -263,3 +263,8 @@ quantization: ''
quantization_local_shard_count: -1
use_qwix_quantization: False
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.

# ML Diagnostics settings
enable_ml_diagnostics: False
profiler_gcs_path: ""
enable_ondemand_xprof: False
4 changes: 4 additions & 0 deletions src/maxdiffusion/configs/base_flux_dev.yml
@@ -306,3 +306,7 @@ quantization_local_shard_count: -1
use_qwix_quantization: False
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.

# ML Diagnostics settings
enable_ml_diagnostics: False
profiler_gcs_path: ""
enable_ondemand_xprof: False
4 changes: 4 additions & 0 deletions src/maxdiffusion/configs/base_flux_dev_multi_res.yml
@@ -291,3 +291,7 @@ quantization_local_shard_count: -1
use_qwix_quantization: False
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.

# ML Diagnostics settings
enable_ml_diagnostics: False
profiler_gcs_path: ""
enable_ondemand_xprof: False
7 changes: 6 additions & 1 deletion src/maxdiffusion/configs/base_flux_schnell.yml
@@ -300,4 +300,9 @@ quantization_local_shard_count: -1
use_qwix_quantization: False
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.

save_final_checkpoint: False
save_final_checkpoint: False

# ML Diagnostics settings
enable_ml_diagnostics: False
profiler_gcs_path: ""
enable_ondemand_xprof: False
7 changes: 6 additions & 1 deletion src/maxdiffusion/configs/base_wan_14b.yml
@@ -395,4 +395,9 @@ eval_data_dir: ""
enable_generate_video_for_eval: False # This will increase the used TPU memory.
eval_max_number_of_samples_in_bucket: 60 # The number of samples per bucket for evaluation. This is calculated by num_eval_samples / len(timesteps_list).

enable_ssim: False
enable_ssim: False

# ML Diagnostics settings
enable_ml_diagnostics: False
profiler_gcs_path: ""
enable_ondemand_xprof: False
5 changes: 5 additions & 0 deletions src/maxdiffusion/configs/base_wan_1_3b.yml
@@ -337,3 +337,8 @@ enable_generate_video_for_eval: False # This will increase the used TPU memory.
eval_max_number_of_samples_in_bucket: 60 # The number of samples per bucket for evaluation. This is calculated by num_eval_samples / len(timesteps_list).

enable_ssim: False

# ML Diagnostics settings
enable_ml_diagnostics: False
profiler_gcs_path: ""
enable_ondemand_xprof: False
7 changes: 6 additions & 1 deletion src/maxdiffusion/configs/base_wan_27b.yml
@@ -367,4 +367,9 @@ eval_data_dir: ""
enable_generate_video_for_eval: False # This will increase the used TPU memory.
eval_max_number_of_samples_in_bucket: 60 # The number of samples per bucket for evaluation. This is calculated by num_eval_samples / len(timesteps_list).

enable_ssim: False
enable_ssim: False

# ML Diagnostics settings
enable_ml_diagnostics: False
profiler_gcs_path: ""
enable_ondemand_xprof: False
7 changes: 6 additions & 1 deletion src/maxdiffusion/configs/base_wan_i2v_14b.yml
@@ -360,4 +360,9 @@ enable_ssim: False
# i2v specific parameters
# I2V Input Image
# URL or local path to the conditioning image
image_url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
image_url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"

# ML Diagnostics settings
enable_ml_diagnostics: False
profiler_gcs_path: ""
enable_ondemand_xprof: False
7 changes: 6 additions & 1 deletion src/maxdiffusion/configs/base_wan_i2v_27b.yml
@@ -367,4 +367,9 @@ enable_ssim: False
# i2v specific parameters
# I2V Input Image
# URL or local path to the conditioning image
image_url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
image_url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"

# ML Diagnostics settings
enable_ml_diagnostics: False
profiler_gcs_path: ""
enable_ondemand_xprof: False
5 changes: 5 additions & 0 deletions src/maxdiffusion/configs/base_xl.yml
@@ -264,3 +264,8 @@ quantization: ''
quantization_local_shard_count: -1
use_qwix_quantization: False
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.

# ML Diagnostics settings
enable_ml_diagnostics: False
profiler_gcs_path: ""
enable_ondemand_xprof: False
5 changes: 5 additions & 0 deletions src/maxdiffusion/configs/base_xl_lightning.yml
@@ -214,3 +214,8 @@ quantization: ''
quantization_local_shard_count: -1
use_qwix_quantization: False
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.

# ML Diagnostics settings
enable_ml_diagnostics: False
profiler_gcs_path: ""
enable_ondemand_xprof: False
7 changes: 6 additions & 1 deletion src/maxdiffusion/configs/ltx2_video.yml
@@ -79,6 +79,11 @@ ici_context_parallelism: -1 # recommended ICI axis to be auto-sharded
ici_tensor_parallelism: 1
enable_profiler: False

# ML Diagnostics settings
enable_ml_diagnostics: True
profiler_gcs_path: "gs://mehdy/profiler/ml_diagnostics"
enable_ondemand_xprof: True

replicate_vae: False

allow_split_physical_axes: False
@@ -134,4 +139,4 @@ upsampler_temporal_patch_size: 1
upsampler_adain_factor: 0.0
upsampler_tone_map_compression_ratio: 0.0
upsampler_rational_spatial_scale: 2.0
upsampler_output_type: "pil"
upsampler_output_type: "pil"
7 changes: 6 additions & 1 deletion src/maxdiffusion/configs/ltx_video.yml
@@ -103,4 +103,9 @@ compile_topology_num_slices: -1
quantization_local_shard_count: -1
use_qwix_quantization: False
jit_initializers: True
enable_single_replica_ckpt_restoring: False
enable_single_replica_ckpt_restoring: False

# ML Diagnostics settings
enable_ml_diagnostics: False
profiler_gcs_path: ""
enable_ondemand_xprof: False
@@ -22,7 +22,7 @@
import jax.numpy as jnp
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from maxdiffusion import pyconfig
from maxdiffusion import pyconfig, max_utils
from maxdiffusion.utils import load_image
from maxdiffusion import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel

@@ -76,6 +76,7 @@ def run(config):

def main(argv: Sequence[str]) -> None:
pyconfig.initialize(argv)
max_utils.ensure_machinelearning_job_runs(pyconfig.config)
run(pyconfig.config)


@@ -24,7 +24,7 @@
from flax.training.common_utils import shard
from maxdiffusion.utils import load_image
from PIL import Image
from maxdiffusion import pyconfig
from maxdiffusion import pyconfig, max_utils
from maxdiffusion import FlaxStableDiffusionXLControlNetPipeline, FlaxControlNetModel
import cv2

@@ -91,6 +91,7 @@ def run(config):

def main(argv: Sequence[str]) -> None:
pyconfig.initialize(argv)
max_utils.ensure_machinelearning_job_runs(pyconfig.config)
run(pyconfig.config)


2 changes: 2 additions & 0 deletions src/maxdiffusion/dreambooth/train_dreambooth.py
@@ -19,6 +19,7 @@
import jax
from absl import app
from maxdiffusion import (
max_utils,
max_logging,
pyconfig,
)
@@ -38,6 +39,7 @@ def train(config):
def main(argv: Sequence[str]) -> None:
pyconfig.initialize(argv)
config = pyconfig.config
max_utils.ensure_machinelearning_job_runs(config)
validate_train_config(config)
max_logging.log(f"Found {jax.device_count()} devices.")
train(config)
1 change: 1 addition & 0 deletions src/maxdiffusion/generate.py
@@ -258,6 +258,7 @@ def run(config):

def main(argv: Sequence[str]) -> None:
pyconfig.initialize(argv)
max_utils.ensure_machinelearning_job_runs(pyconfig.config)
run(pyconfig.config)


3 changes: 2 additions & 1 deletion src/maxdiffusion/generate_flux.py
@@ -31,7 +31,7 @@
from flax.linen import partitioning as nn_partitioning
from transformers import (CLIPTokenizer, FlaxCLIPTextModel, T5EncoderModel, FlaxT5EncoderModel, AutoTokenizer)

from maxdiffusion import FlaxAutoencoderKL, pyconfig, max_logging
from maxdiffusion import FlaxAutoencoderKL, pyconfig, max_logging, max_utils
from maxdiffusion.models.flux.transformers.transformer_flux_flax import FluxTransformer2DModel
from maxdiffusion.train_utils import transformer_engine_context
from maxdiffusion.max_utils import (
@@ -489,6 +489,7 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep

def main(argv: Sequence[str]) -> None:
pyconfig.initialize(argv)
max_utils.ensure_machinelearning_job_runs(pyconfig.config)
run(pyconfig.config)


3 changes: 2 additions & 1 deletion src/maxdiffusion/generate_flux_multi_res.py
@@ -30,7 +30,7 @@
from flax.linen import partitioning as nn_partitioning
from transformers import (CLIPTokenizer, FlaxCLIPTextModel, T5EncoderModel, FlaxT5EncoderModel, AutoTokenizer)

from maxdiffusion import FlaxAutoencoderKL, pyconfig, max_logging
from maxdiffusion import FlaxAutoencoderKL, pyconfig, max_logging, max_utils
from maxdiffusion.models.flux.transformers.transformer_flux_flax import FluxTransformer2DModel
from maxdiffusion.max_utils import (
device_put_replicated,
@@ -571,6 +571,7 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep

def main(argv: Sequence[str]) -> None:
pyconfig.initialize(argv)
max_utils.ensure_machinelearning_job_runs(pyconfig.config)
run(pyconfig.config)


1 change: 1 addition & 0 deletions src/maxdiffusion/generate_flux_pipeline.py
@@ -120,6 +120,7 @@ def run(config):

def main(argv: Sequence[str]) -> None:
pyconfig.initialize(argv)
max_utils.ensure_machinelearning_job_runs(pyconfig.config)
run(pyconfig.config)


8 changes: 4 additions & 4 deletions src/maxdiffusion/generate_ltx2.py
@@ -226,10 +226,9 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
max_logging.log("Warning: Number of videos is zero, cannot calculate generation_time_per_video.")

s0 = time.perf_counter()
if getattr(config, "enable_profiler", False):
max_utils.activate_profiler(config)
call_pipeline(config, pipeline, prompt, negative_prompt)
max_utils.deactivate_profiler(config)
if max_utils.profiler_enabled(config):
with max_utils.Profiler(config):
call_pipeline(config, pipeline, prompt, negative_prompt)
generation_time_with_profiler = time.perf_counter() - s0
max_logging.log(f"generation_time_with_profiler: {generation_time_with_profiler}")
if writer and jax.process_index() == 0:
@@ -245,6 +244,7 @@ def main(argv: Sequence[str]) -> None:
flax.config.update("flax_always_shard_variable", False)
except LookupError:
pass
max_utils.ensure_machinelearning_job_runs(pyconfig.config)
run(pyconfig.config, commit_hash=commit_hash)


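The hunks in `generate_ltx2.py` (and `generate_wan.py` below) replace paired `activate_profiler`/`deactivate_profiler` calls with a `max_utils.Profiler` context manager. A minimal sketch of what such a context manager guarantees follows; the internals here are hypothetical, not the actual `max_utils` implementation, which presumably wraps JAX/XProf trace start and stop calls:

```python
class Profiler:
    """Sketch of a profiling context manager (hypothetical internals)."""

    def __init__(self, config, trace_log=None):
        self.config = config
        # trace_log stands in for the real tracing backend so the
        # enter/exit behavior can be observed directly.
        self.trace_log = trace_log if trace_log is not None else []

    def __enter__(self):
        # A real implementation would start an XProf trace here, e.g.
        # jax.profiler.start_trace(self.config["profiler_gcs_path"])
        self.trace_log.append("start")
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Runs even if the profiled call raises, unlike the previous
        # activate/deactivate pattern, which could leak an open trace.
        self.trace_log.append("stop")
        return False  # do not swallow exceptions


events = []
with Profiler({"profiler_gcs_path": "gs://example-bucket/profiles"},
              trace_log=events):
    pass  # call_pipeline(...) would run here
print(events)  # -> ['start', 'stop']
```

The design win is exception safety: the trace is always closed, so a failed generation run still uploads a usable profile.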
3 changes: 2 additions & 1 deletion src/maxdiffusion/generate_ltx_video.py
@@ -20,7 +20,7 @@
from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXVideoPipeline
from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXMultiScalePipeline, ConditioningItem
import maxdiffusion.pipelines.ltx_video.crf_compressor as crf_compressor
from maxdiffusion import pyconfig, max_logging
from maxdiffusion import pyconfig, max_logging, max_utils
from maxdiffusion.train_utils import transformer_engine_context
import torchvision.transforms.functional as TVF
import imageio
@@ -264,6 +264,7 @@ def run(config):

def main(argv: Sequence[str]) -> None:
pyconfig.initialize(argv)
max_utils.ensure_machinelearning_job_runs(pyconfig.config)
run(pyconfig.config)


1 change: 1 addition & 0 deletions src/maxdiffusion/generate_sdxl.py
@@ -319,6 +319,7 @@ def run(config):

def main(argv: Sequence[str]) -> None:
pyconfig.initialize(argv)
max_utils.ensure_machinelearning_job_runs(pyconfig.config)
run(pyconfig.config)


8 changes: 4 additions & 4 deletions src/maxdiffusion/generate_wan.py
@@ -303,10 +303,9 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
)

s0 = time.perf_counter()
if config.enable_profiler:
max_utils.activate_profiler(config)
videos = call_pipeline(config, pipeline, prompt, negative_prompt)
max_utils.deactivate_profiler(config)
if max_utils.profiler_enabled(config):
with max_utils.Profiler(config):
videos = call_pipeline(config, pipeline, prompt, negative_prompt)
generation_time_with_profiler = time.perf_counter() - s0
max_logging.log(f"generation_time_with_profiler: {generation_time_with_profiler}")
if writer and jax.process_index() == 0:
@@ -322,6 +321,7 @@ def main(argv: Sequence[str]) -> None:
flax.config.update("flax_always_shard_variable", False)
except LookupError:
pass
max_utils.ensure_machinelearning_job_runs(pyconfig.config)
run(pyconfig.config, commit_hash=commit_hash)

