Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/UploadDockerImages.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ jobs:
run: docker system prune --all --force
- name: build maxdiffusion jax stable stack gpu image
run: |
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_stable_stack_gpu MODE=stable_stack PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxdiffusion_jax_stable_stack_gpu BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:latest DEVICE=gpu
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_stable_gpu MODE=stable PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxdiffusion_jax_stable_gpu DEVICE=gpu
- name: build maxdiffusion jax nightly image
run: |
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_nightly_gpu MODE=nightly PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxdiffusion_jax_nightly DEVICE=gpu
1 change: 1 addition & 0 deletions src/maxdiffusion/configs/base14.yml
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ diffusion_scheduler_config: {

# Hardware
hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
skip_jax_distributed_system: False

base_output_directory: ""

Expand Down
1 change: 1 addition & 0 deletions src/maxdiffusion/configs/base21.yml
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ diffusion_scheduler_config: {

# Hardware
hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
skip_jax_distributed_system: False

# Output directory
# Create a GCS bucket, e.g. my-maxtext-outputs and set this to "gs://my-maxtext-outputs/"
Expand Down
1 change: 1 addition & 0 deletions src/maxdiffusion/configs/base_2_base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ diffusion_scheduler_config: {

# Hardware
hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
skip_jax_distributed_system: False

# Output directory
# Create a GCS bucket, e.g. my-maxtext-outputs and set this to "gs://my-maxtext-outputs/"
Expand Down
1 change: 1 addition & 0 deletions src/maxdiffusion/configs/base_flux_dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ base_output_directory: ""

# Hardware
hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
skip_jax_distributed_system: False

# Parallelism
mesh_axes: ['data', 'fsdp', 'tensor']
Expand Down
1 change: 1 addition & 0 deletions src/maxdiffusion/configs/base_flux_dev_multi_res.yml
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ base_output_directory: ""

# Hardware
hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
skip_jax_distributed_system: False

# Parallelism
mesh_axes: ['data', 'fsdp', 'tensor']
Expand Down
1 change: 1 addition & 0 deletions src/maxdiffusion/configs/base_flux_schnell.yml
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ base_output_directory: ""

# Hardware
hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
skip_jax_distributed_system: False

# Parallelism
mesh_axes: ['data', 'fsdp', 'tensor']
Expand Down
1 change: 1 addition & 0 deletions src/maxdiffusion/configs/base_wan_t2v.yml
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ base_output_directory: ""

# Hardware
hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
skip_jax_distributed_system: False

# Parallelism
mesh_axes: ['data', 'fsdp', 'tensor']
Expand Down
2 changes: 1 addition & 1 deletion src/maxdiffusion/configs/base_xl.yml
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ base_output_directory: ""

# Hardware
hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'

skip_jax_distributed_system: False
# Parallelism
mesh_axes: ['data', 'fsdp', 'tensor']

Expand Down
2 changes: 1 addition & 1 deletion src/maxdiffusion/configs/base_xl_lightning.yml
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ diffusion_scheduler_config: {

# Hardware
hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'

skip_jax_distributed_system: False
# Output directory
# Create a GCS bucket, e.g. my-maxtext-outputs and set this to "gs://my-maxtext-outputs/"
base_output_directory: ""
Expand Down
3 changes: 3 additions & 0 deletions src/maxdiffusion/max_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,9 @@ def initialize_jax_for_gpu():


def maybe_initialize_jax_distributed_system(raw_keys):
if raw_keys["skip_jax_distributed_system"]:
max_logging.log("Skipping jax distributed system due to skip_jax_distributed_system=True flag.")
return
if is_gpu_backend(raw_keys):
max_logging.log("Attempting to initialize the jax distributed system for GPU backend...")
initialize_jax_for_gpu()
Expand Down
Loading