Merged
31 commits
d5ac715
add support for flux vae. ~ wip
jfacevedo-google Jan 14, 2025
394ebd1
test for flux vae both encoding and decoding.
jfacevedo-google Jan 14, 2025
025642b
add clip text encoder test
jfacevedo-google Jan 15, 2025
a2b7f82
remove transformers inside maxdiffusion, add transformers dependency.…
jfacevedo-google Jan 22, 2025
2b83d5c
add double block to flux
jfacevedo-google Jan 22, 2025
37d9f00
forward pass for single double block.
jfacevedo-google Jan 22, 2025
8785d00
trying to use scan.
jfacevedo-google Jan 23, 2025
cb91d5e
add single stream block
jfacevedo-google Jan 24, 2025
bb71982
finish transformer
jfacevedo-google Jan 29, 2025
3eb5729
convert pt weights to flax and load transformer state.
jfacevedo-google Jan 30, 2025
956341e
apply fsdp sharding, do one forward pass in the transformer.
jfacevedo-google Jan 30, 2025
4b64f5d
wip - generate fn
jfacevedo-google Jan 30, 2025
860e76e
working loop, bad generation
jfacevedo-google Jan 30, 2025
93a3bb6
e2e, encoder offloading.
jfacevedo-google Jan 30, 2025
601f40c
add missing conversions of pt to jax weights.
jfacevedo-google Jan 31, 2025
d16c020
support both dev and schnell loading. Images still incorrect.
jfacevedo-google Feb 1, 2025
4a12b39
flux schnell working
jfacevedo-google Feb 3, 2025
9871c7d
removed unused code.
jfacevedo-google Feb 3, 2025
a75a125
support dev
jfacevedo-google Feb 3, 2025
05b6fc8
add sentencepiece requirement
jfacevedo-google Feb 4, 2025
df25e47
fix repeated double and single blocks.
jfacevedo-google Feb 4, 2025
587bc6a
optimized flash block sizes for trillium.
jfacevedo-google Feb 4, 2025
8905362
Merge branch 'main' into flux_impl
jfacevedo-google Feb 4, 2025
b87443f
clean up code and lint
jfacevedo-google Feb 4, 2025
37df8b9
fix sdxl generate smoke tests.
jfacevedo-google Feb 5, 2025
e56825f
fix rest of unit tests.
jfacevedo-google Feb 5, 2025
064a3a7
update readme and some dependencies.
entrpn Feb 5, 2025
fa1c23b
remove unused dependencies.
entrpn Feb 5, 2025
a774fb1
fix typo in readme.
entrpn Feb 5, 2025
2b56b7a
fixes issue with bad generations.
jfacevedo-google Feb 12, 2025
50315c7
linting.
jfacevedo-google Feb 12, 2025
35 changes: 35 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
[![Unit Tests](https://github.com/google/maxtext/actions/workflows/UnitTests.yml/badge.svg)](https://github.com/google/maxdiffusion/actions/workflows/UnitTests.yml)

# What's new?
- **`2025/02/08`**: Flux schnell & dev inference.
- **`2024/12/12`**: Load multiple LoRAs for inference.
- **`2024/10/22`**: LoRA support for Hyper SDXL.
- **`2024/8/1`**: Orbax is the new default checkpointer. You can still use `pipeline.save_pretrained` after training to save in diffusers format.
Expand Down Expand Up @@ -46,6 +47,7 @@ MaxDiffusion supports
* [Training](#training)
* [Dreambooth](#dreambooth)
* [Inference](#inference)
* [Flux](#flux)
* [Hyper-SD XL LoRA](#hyper-sdxl-lora)
* [Load Multiple LoRA](#load-multiple-lora)
* [SDXL Lightning](#sdxl-lightning)
Expand Down Expand Up @@ -133,6 +135,39 @@ To generate images, run the following command:
```bash
python -m src.maxdiffusion.generate src/maxdiffusion/configs/base21.yml run_name="my_run"
```
## Flux

First, make sure you have permission to access the Flux repositories on Hugging Face.

Expected results when generating 1024 x 1024 images with flash attention and bfloat16:

| Model | Accelerator | Sharding Strategy | Batch Size | Steps | Time (secs) |
| --- | --- | --- | --- | --- | --- |
| Flux-dev | v4-8 | DDP | 4 | 28 | 23 |
| Flux-schnell | v4-8 | DDP | 4 | 4 | 2.2 |
| Flux-dev | v6e-4 | DDP | 4 | 28 | 5.5 |
| Flux-schnell | v6e-4 | DDP | 4 | 4 | 0.8 |
| Flux-schnell | v6e-4 | FSDP | 4 | 4 | 1.2 |
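The table reports wall-clock time for one batch; per-image throughput follows directly. A small sketch that derives it from the numbers above (run names are labels of ours, not maxdiffusion identifiers):

```python
# Per-image latency and throughput derived from the benchmark table above.
# Each row is (batch_size, denoising_steps, total_seconds), copied from the table.
runs = {
    "Flux-dev v4-8 DDP": (4, 28, 23.0),
    "Flux-schnell v4-8 DDP": (4, 4, 2.2),
    "Flux-dev v6e-4 DDP": (4, 28, 5.5),
    "Flux-schnell v6e-4 DDP": (4, 4, 0.8),
    "Flux-schnell v6e-4 FSDP": (4, 4, 1.2),
}

def throughput(batch: int, total_secs: float) -> float:
    """Images generated per second for one batch."""
    return batch / total_secs

for name, (batch, steps, secs) in runs.items():
    print(f"{name}: {throughput(batch, secs):.2f} img/s, {secs / batch:.2f} s/img")
```

For example, Flux-schnell on v6e-4 with DDP works out to 5 images per second at 4 steps.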

Schnell:

```bash
python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_schnell.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt="photograph of an electronics chip in the shape of a race car with trillium written on its side" per_device_batch_size=1
```

Dev:

```bash
python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_dev.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt="photograph of an electronics chip in the shape of a race car with trillium written on its side" per_device_batch_size=1
```

If you are using a TPU v6e (Trillium), you can enable optimized flash attention block sizes for faster inference by uncommenting the corresponding settings in the Flux-dev [config](src/maxdiffusion/configs/base_flux_dev.yml#60) and the Flux-schnell [config](src/maxdiffusion/configs/base_flux_schnell.yml#68).
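The fragment below is only an illustration of what such a tuned section looks like — the key names and values here are assumptions, so copy the real values that already sit (commented out) in the linked config files rather than these placeholders:

```yaml
# Hypothetical shape of the Trillium-tuned settings; use the values from the
# linked configs, not these placeholder numbers.
flash_block_sizes:
  block_q: 256
  block_kv_compute: 256
  block_kv: 256
```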

To keep the text encoders, VAE, and transformer in HBM at all times, the following command shards the model across devices:

```bash
python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_schnell.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt="photograph of an electronics chip in the shape of a race car with trillium written on its side" per_device_batch_size=1 ici_data_parallelism=1 ici_fsdp_parallelism=-1 offload_encoders=False
```
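The `ici_data_parallelism` / `ici_fsdp_parallelism` flags describe the device mesh, where `-1` means "fill this axis with all remaining devices" and the product of the axis sizes must equal the chip count. A minimal pure-Python sketch of that resolution convention, assuming this reading of the flags (the function name is ours, not maxdiffusion's):

```python
import math

def resolve_mesh_axes(num_devices: int, requested: dict) -> dict:
    """Resolve mesh axis sizes where -1 means 'fill with the remaining devices'.

    `requested` maps axis name -> size; at most one axis may be -1, and the
    product of the resolved sizes must equal num_devices.
    """
    fixed = math.prod(v for v in requested.values() if v != -1)
    resolved = {}
    for name, size in requested.items():
        if size == -1:
            size, remainder = divmod(num_devices, fixed)
            assert remainder == 0, "device count not divisible by fixed axes"
        resolved[name] = size
    assert math.prod(resolved.values()) == num_devices
    return resolved

# The FSDP command above on a v6e-4 (4 chips):
print(resolve_mesh_axes(4, {"data": 1, "fsdp": -1}))  # → {'data': 1, 'fsdp': 4}
```

With `ici_data_parallelism=1 ici_fsdp_parallelism=-1`, all four chips land on the FSDP axis, so parameters are sharded instead of replicated — which is why the encoders no longer need to be offloaded.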

## Hyper SDXL LoRA

Expand Down
13 changes: 8 additions & 5 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ absl-py
datasets
flax>=0.10.2
optax>=0.2.3
torch>=2.3.1
torchvision>=0.18.1
torch==2.5.1
torchvision==0.20.1
ftfy
tensorboard==2.17.0
tensorboard>=2.17.0
tensorboardx==2.6.2.2
tensorboard-plugin-profile==2.15.2
Jinja2
Expand All @@ -25,5 +25,8 @@ ruff>=0.1.5,<=0.2
git+https://github.com/mlperf/logging.git
opencv-python-headless==4.10.0.84
orbax-checkpoint==0.10.2
tokenizers==0.20.0
huggingface_hub==0.24.7
tokenizers==0.21.0
huggingface_hub==0.24.7
transformers==4.48.1
einops==0.8.0
sentencepiece
2 changes: 2 additions & 0 deletions src/maxdiffusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,7 @@
_import_structure["models.controlnet_flax"] = ["FlaxControlNetModel"]
_import_structure["models.modeling_flax_utils"] = ["FlaxModelMixin"]
_import_structure["models.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
_import_structure["models.flux.transformers.transformer_flux_flax"] = ["FluxTransformer2DModel"]
_import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"]
_import_structure["pipelines"].extend(["FlaxDiffusionPipeline"])
_import_structure["schedulers"].extend(
Expand Down Expand Up @@ -451,6 +452,7 @@
from .models.controlnet_flax import FlaxControlNetModel
from .models.modeling_flax_utils import FlaxModelMixin
from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
from .models.flux.transformers.transformer_flux_flax import FluxTransformer2DModel
from .models.vae_flax import FlaxAutoencoderKL
from .pipelines import FlaxDiffusionPipeline
from .schedulers import (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
max_logging,
)

from maxdiffusion.transformers import (CLIPTokenizer, FlaxCLIPTextModel, CLIPTextConfig, FlaxCLIPTextModelWithProjection)
from transformers import (CLIPTokenizer, FlaxCLIPTextModel, CLIPTextConfig, FlaxCLIPTextModelWithProjection)

from maxdiffusion.checkpointing.checkpointing_utils import (
create_orbax_checkpoint_manager,
Expand Down Expand Up @@ -88,11 +88,14 @@ def create_unet_state(self, pipeline, params, checkpoint_item_name, is_training)
config=self.config,
mesh=self.mesh,
weights_init_fn=weights_init_fn,
model_params=None if self.config.train_new_unet else params.get("unet", None),
model_params=None,
checkpoint_manager=self.checkpoint_manager,
checkpoint_item=checkpoint_item_name,
training=is_training,
)
if not self.config.train_new_unet:
unet_state = unet_state.replace(params=params.get("unet", None))
unet_state = jax.device_put(unet_state, state_mesh_shardings)
return unet_state, state_mesh_shardings, learning_rate_scheduler

def create_vae_state(self, pipeline, params, checkpoint_item_name, is_training=False):
Expand Down Expand Up @@ -150,17 +153,20 @@ def create_text_encoder_2_state(self, pipeline, params, checkpoint_item_name, is
input_shape=(self.total_train_batch_size, pipeline.tokenizer.model_max_length),
)

return max_utils.setup_initial_state(
state, state_mesh_shardings = max_utils.setup_initial_state(
model=pipeline.text_encoder_2,
tx=tx,
config=self.config,
mesh=self.mesh,
weights_init_fn=weights_init_fn,
model_params=params.get("text_encoder_2", None),
model_params=None,
checkpoint_manager=self.checkpoint_manager,
checkpoint_item=checkpoint_item_name,
training=is_training,
)
state = state.replace(params=params.get("text_encoder_2", None))
state = jax.device_put(state, state_mesh_shardings)
return state, state_mesh_shardings

def restore_data_iterator_state(self, data_iterator):
if (
Expand Down Expand Up @@ -302,15 +308,16 @@ def load_checkpoint(self, step=None, scheduler_class=None):
tokenizer_path = os.path.join(tokenizer_path, "tokenizer")
tokenizer_path = max_utils.download_blobs(tokenizer_path, "/tmp")
tokenizer = CLIPTokenizer.from_pretrained(
tokenizer_path, subfolder="tokenizer", dtype=self.config.activations_dtype, weights_dtype=self.config.weights_dtype
tokenizer_path,
subfolder="tokenizer",
dtype=self.config.activations_dtype,
)

te_pretrained_config = CLIPTextConfig(**model_configs[0]["text_encoder_config"])
text_encoder = FlaxCLIPTextModel(
te_pretrained_config,
seed=self.config.seed,
dtype=self.config.activations_dtype,
weights_dtype=self.config.weights_dtype,
_do_init=False,
)

Expand Down
1 change: 1 addition & 0 deletions src/maxdiffusion/common_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@

BATCH = "activation_batch"
LENGTH = "activation_length"
EMBED = "activation_embed"
HEAD = "activation_heads"
D_KV = "activation_kv"
KEEP_1 = "activation_keep_1"
Expand Down