14 changes: 12 additions & 2 deletions README.md
@@ -17,8 +17,8 @@
[![Unit Tests](https://github.com/google/maxtext/actions/workflows/UnitTests.yml/badge.svg)](https://github.com/google/maxdiffusion/actions/workflows/UnitTests.yml)

# What's new?

- **`2024/8/1`**: Orbax is the new default checkpointer for Stable Diffusion 1.X, 2.x. You can still use `pipeline.save_pretrained` after training to save in diffusers format.
- **`2024/10/22`**: LoRA support for Hyper SDXL.
- **`2024/8/1`**: Orbax is the new default checkpointer. You can still use `pipeline.save_pretrained` after training to save in diffusers format.
- **`2024/7/20`**: Dreambooth training for Stable Diffusion 1.x,2.x is now supported.

# Overview
@@ -32,6 +32,7 @@ MaxDiffusion supports
* Stable Diffusion 2.1 (training and inference)
* Stable Diffusion XL (training and inference).
* Stable Diffusion Lightning (inference).
* Hyper-SD XL LoRA loading (inference).
* ControlNet inference (Stable Diffusion 1.4 & SDXL).
* Dreambooth training support for Stable Diffusion 1.x,2.x.

@@ -43,6 +44,7 @@ MaxDiffusion supports
* [Training](#training)
* [Dreambooth](#dreambooth)
* [Inference](#inference)
* [Hyper-SD XL LoRA](#hyper-sdxl-lora)
* [SDXL Lightning](#sdxl-lightning)
* [ControlNet](#controlnet)
* [Comparison To Alternatives](#comparison-to-alternatives)
@@ -129,6 +131,14 @@ To generate images, run the following command:
python -m src.maxdiffusion.generate src/maxdiffusion/configs/base21.yml run_name="my_run"
```

## Hyper SDXL LoRA

MaxDiffusion supports loading Hyper-SDXL LoRA weights from [ByteDance](https://huggingface.co/ByteDance/Hyper-SD) for inference:

```bash
python src/maxdiffusion/generate_sdxl.py src/maxdiffusion/configs/base_xl.yml run_name="test-lora" output_dir=/tmp/ jax_cache_dir=/tmp/cache_dir/ num_inference_steps=2 do_classifier_free_guidance=False prompt="a photograph of a cat wearing a hat riding a skateboard in a park." per_device_batch_size=1 pretrained_model_name_or_path="Lykon/AAM_XL_AnimeMix" from_pt=True revision=main diffusion_scheduler_config='{"_class_name" : "FlaxDDIMScheduler", "timestep_spacing" : "trailing"}' lora_config='{"lora_model_name_or_path" : ["ByteDance/Hyper-SD"], "weight_name" : ["Hyper-SDXL-2steps-lora.safetensors"], "adapter_name" : ["hyper-sdxl"], "scale": [0.7], "from_pt": ["true"]}'
```
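The `lora_config` argument in the command above is a JSON object whose values are parallel lists, one entry per LoRA. The sketch below shows one way such a string could be parsed and checked for consistent list lengths. `parse_lora_config` and `LORA_KEYS` are hypothetical names used for illustration only; they are not part of MaxDiffusion, and the project's own parsing may differ.

```python
import json

# Keys taken from the lora_config example in the command above.
LORA_KEYS = ("lora_model_name_or_path", "weight_name", "adapter_name", "scale", "from_pt")


def parse_lora_config(raw):
    """Parse a lora_config JSON string into a list of per-LoRA dicts (hypothetical helper)."""
    cfg = json.loads(raw)
    missing = [key for key in LORA_KEYS if key not in cfg]
    if missing:
        raise ValueError(f"lora_config is missing keys: {missing}")
    # Every list must describe the same number of LoRAs.
    lengths = {len(cfg[key]) for key in LORA_KEYS}
    if len(lengths) != 1:
        raise ValueError("all lora_config lists must have the same length")
    count = lengths.pop()
    return [{key: cfg[key][i] for key in LORA_KEYS} for i in range(count)]


raw = (
    '{"lora_model_name_or_path" : ["ByteDance/Hyper-SD"], '
    '"weight_name" : ["Hyper-SDXL-2steps-lora.safetensors"], '
    '"adapter_name" : ["hyper-sdxl"], "scale": [0.7], "from_pt": ["true"]}'
)
loras = parse_lora_config(raw)
print(loras[0]["adapter_name"], loras[0]["scale"])
```

Keeping the values as parallel lists means a second LoRA would be added by appending one element to each list.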

## SDXL Lightning

Single and Multi host inference is supported with sharding annotations:
Binary file removed generated_image.png
Binary file not shown.
2 changes: 2 additions & 0 deletions requirements.txt
@@ -26,4 +26,6 @@ git+https://github.com/mlperf/logging.git
opencv-python-headless==4.10.0.84
orbax-checkpoint>=0.5.20
tokenizers==0.20.0
huggingface_hub==0.24.7
Repeated line?


huggingface_hub==0.24.7
27 changes: 22 additions & 5 deletions src/maxdiffusion/configs/base_xl.yml
@@ -75,11 +75,10 @@ timestep_bias: {

# Override parameters from the checkpoint's scheduler.
diffusion_scheduler_config: {
_class_name: '',
# values are v_prediction or leave empty to use scheduler's default.
prediction_type: '',
_class_name: 'FlaxEulerDiscreteScheduler',
prediction_type: 'epsilon',
rescale_zero_terminal_snr: False,
timestep_spacing: ''
timestep_spacing: 'trailing'
}

# Output directory
@@ -197,7 +196,7 @@ profiler_steps: 10
prompt: "A magical castle in the middle of a forest, artistic drawing"
negative_prompt: "purple, red"
do_classifier_free_guidance: True
guidance_scale: 9
guidance_scale: 9.0
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
guidance_rescale: 0.0
num_inference_steps: 20
@@ -209,6 +208,24 @@ lightning_repo: ""
# Empty or "sdxl_lightning_4step_unet.safetensors" to enable lightning.
lightning_ckpt: ""

# LoRA parameters
# Values are lists so that multiple LoRAs can be loaded during inference in the future.
lora_config: {
lora_model_name_or_path: [],
weight_name: [],
adapter_name: [],
scale: [],
from_pt: []
}
# Ex with values:
# lora_config : {
# lora_model_name_or_path: ["ByteDance/Hyper-SD"],
# weight_name: ["Hyper-SDXL-2steps-lora.safetensors"],
# adapter_name: ["hyper-sdxl"],
# scale: [0.7],
# from_pt: [True]
# }

enable_mllog: False

#controlnet
Expand Down
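The `diffusion_scheduler_config` block above overrides values from the checkpoint's scheduler, and the original comment says a field can be left empty to keep the scheduler's default. A minimal sketch of that merge rule follows. `merge_scheduler_overrides` is a hypothetical helper, and treating empty strings as "no override" is an assumption drawn from that comment; the actual `create_scheduler` in `maxdiffusion_utils` is not shown in this diff and may behave differently.

```python
def merge_scheduler_overrides(checkpoint_config, overrides):
    """Return checkpoint_config updated with every non-empty override (hypothetical helper)."""
    merged = dict(checkpoint_config)
    for key, value in overrides.items():
        # An empty string or None means "keep the checkpoint's value".
        if value is None or value == "":
            continue
        merged[key] = value
    return merged


# Illustrative values only; a real checkpoint config has more fields.
checkpoint_config = {
    "_class_name": "FlaxEulerDiscreteScheduler",
    "prediction_type": "epsilon",
    "timestep_spacing": "leading",
}
overrides = {
    "_class_name": "",
    "prediction_type": "",
    "rescale_zero_terminal_snr": False,
    "timestep_spacing": "trailing",
}
merged = merge_scheduler_overrides(checkpoint_config, overrides)
print(merged)
```

Note that `False` is a real override here and is kept, since only `None` and the empty string are skipped.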
22 changes: 20 additions & 2 deletions src/maxdiffusion/configs/base_xl_lightning.yml
@@ -56,7 +56,7 @@ text_encoder_learning_rate: 4.25e-6
diffusion_scheduler_config: {
_class_name: 'DDIMScheduler',
# values are v_prediction or leave empty to use scheduler's default.
prediction_type: '',
prediction_type: 'epsilon',
rescale_zero_terminal_snr: False,
timestep_spacing: 'trailing'
}
@@ -156,7 +156,7 @@ profiler_steps: 5
prompt: "portrait photo of muscular bearded guy in a worn mech suit, light bokeh, intricate, steel metal, elegant, sharp focus, soft lighting, vibrant colors"
negative_prompt: "purple, red"
do_classifier_free_guidance: False
guidance_scale: 2
guidance_scale: 2.0
guidance_rescale: -1
num_inference_steps: 4

@@ -165,4 +165,22 @@ lightning_from_pt: True
lightning_repo: "ByteDance/SDXL-Lightning"
lightning_ckpt: "sdxl_lightning_4step_unet.safetensors"

# LoRA parameters
# Values are lists so that multiple LoRAs can be loaded during inference in the future.
lora_config: {
lora_model_name_or_path: [],
weight_name: [],
adapter_name: [],
scale: [],
from_pt: []
}
# Ex with values:
# lora_config : {
# lora_model_name_or_path: ["ByteDance/Hyper-SD"],
# weight_name: ["Hyper-SDXL-2steps-lora.safetensors"],
# adapter_name: ["hyper-sdxl"],
# scale: [0.7],
# from_pt: [True]
# }

enable_mllog: False
55 changes: 28 additions & 27 deletions src/maxdiffusion/generate_sdxl.py
@@ -23,16 +23,18 @@
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P
import flax.linen as nn
from flax.linen import partitioning as nn_partitioning

from maxdiffusion import (
FlaxEulerDiscreteScheduler,
)


from maxdiffusion import pyconfig, max_utils
from maxdiffusion.image_processor import VaeImageProcessor
from maxdiffusion.maxdiffusion_utils import (get_add_time_ids, rescale_noise_cfg, load_sdxllightning_unet)
from maxdiffusion.maxdiffusion_utils import (
get_add_time_ids,
rescale_noise_cfg,
load_sdxllightning_unet,
maybe_load_lora,
create_scheduler,
)

from maxdiffusion.trainers.sdxl_trainer import (StableDiffusionXLTrainer)

@@ -82,7 +84,6 @@ def apply_classifier_free_guidance(noise_pred, guidance_scale):
lambda _: noise_pred,
operand=None,
)

latents, scheduler_state = pipeline.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()

return latents, scheduler_state, state
@@ -217,6 +218,8 @@ def run(config):
checkpoint_loader = GenerateSDXL(config)
pipeline, params = checkpoint_loader.load_checkpoint()

noise_scheduler, noise_scheduler_state = create_scheduler(pipeline.scheduler.config, config)

weights_init_fn = functools.partial(pipeline.unet.init_weights, rng=checkpoint_loader.rng)
unboxed_abstract_state, _, _ = max_utils.get_abstract_state(
pipeline.unet, None, config, checkpoint_loader.mesh, weights_init_fn, False
@@ -228,20 +231,24 @@ def run(config):
if unet_params:
params["unet"] = unet_params

# maybe load lora and create interceptor
params, lora_interceptor = maybe_load_lora(config, pipeline, params)

if config.lightning_repo:
pipeline, params = load_sdxllightning_unet(config, pipeline, params)

# Don't restore the train state to save memory, just restore params
# Don't restore the full train state, instead, just restore params
# and create an inference state.
unet_state, unet_state_shardings = max_utils.setup_initial_state(
model=pipeline.unet,
tx=None,
config=config,
mesh=checkpoint_loader.mesh,
weights_init_fn=weights_init_fn,
model_params=params.get("unet", None),
training=False,
)
with nn.intercept_methods(lora_interceptor):
unet_state, unet_state_shardings = max_utils.setup_initial_state(
model=pipeline.unet,
tx=None,
config=config,
mesh=checkpoint_loader.mesh,
weights_init_fn=weights_init_fn,
model_params=params.get("unet", None),
training=False,
)

vae_state, vae_state_shardings = checkpoint_loader.create_vae_state(
pipeline, params, checkpoint_item_name="vae_state", is_training=False
@@ -267,14 +274,6 @@ def run(config):
states["text_encoder_state"] = text_encoder_state
states["text_encoder_2_state"] = text_encoder_2_state

noise_scheduler, noise_scheduler_state = FlaxEulerDiscreteScheduler.from_pretrained(
config.pretrained_model_name_or_path,
revision=config.revision,
subfolder="scheduler",
dtype=jnp.float32,
timestep_spacing="trailing",
)

pipeline.scheduler = noise_scheduler
params["scheduler"] = noise_scheduler_state

@@ -293,10 +292,12 @@ def run(config):
)

s = time.time()
p_run_inference(states).block_until_ready()
with nn.intercept_methods(lora_interceptor):
p_run_inference(states).block_until_ready()
print("compile time: ", (time.time() - s))
s = time.time()
images = p_run_inference(states).block_until_ready()
with nn.intercept_methods(lora_interceptor):
images = p_run_inference(states).block_until_ready()
print("inference time: ", (time.time() - s))
images = jax.experimental.multihost_utils.process_allgather(images)
numpy_images = np.array(images)
Expand Down
15 changes: 15 additions & 0 deletions src/maxdiffusion/loaders/__init__.py
@@ -0,0 +1,15 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .lora_pipeline import StableDiffusionLoraLoaderMixin
106 changes: 106 additions & 0 deletions src/maxdiffusion/loaders/lora_base.py
@@ -0,0 +1,106 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ..models.modeling_utils import load_state_dict
from ..utils import _get_model_file

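# Import the submodule explicitly: `import safetensors` alone does not
# guarantee that `safetensors.torch` (used below) is loaded.
import safetensors.torch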


class LoRABaseMixin:
"""Utility class for handling LoRAs"""

_lora_lodable_modules = []
num_fused_loras = 0

def load_lora_weights(self, **kwargs):
raise NotImplementedError("`load_lora_weights()` is not implemented.")

@classmethod
def _fetch_state_dict(
cls,
pretrained_model_name_or_path_or_dict,
weight_name,
use_safetensors,
local_files_only,
cache_dir,
force_download,
resume_download,
proxies,
use_auth_token,
revision,
subfolder,
user_agent,
allow_pickle,
):
from .lora_pipeline import LORA_WEIGHT_NAME_SAFE

model_file = None
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
# Let's first try to load .safetensors weights
if (use_safetensors and weight_name is None) or (weight_name is not None and weight_name.endswith(".safetensors")):
try:
# Here we're relaxing the loading check to enable more Inference API
# friendliness where sometimes, it's not at all possible to automatically
# determine `weight_name`.
if weight_name is None:
weight_name = cls._best_guess_weight_name(
pretrained_model_name_or_path_or_dict,
file_extension=".safetensors",
local_files_only=local_files_only,
)
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
state_dict = safetensors.torch.load_file(model_file, device="cpu")
except (IOError, safetensors.SafetensorError) as e:
if not allow_pickle:
raise e
# try loading non-safetensors weights
model_file = None
pass

if model_file is None:
if weight_name is None:
weight_name = cls._best_guess_weight_name(
pretrained_model_name_or_path_or_dict, file_extension=".bin", local_files_only=local_files_only
)
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
state_dict = load_state_dict(model_file)
else:
state_dict = pretrained_model_name_or_path_or_dict

return state_dict
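The control flow of `_fetch_state_dict` above is: try the `.safetensors` path first, and fall back to the non-safetensors loader only when `allow_pickle` is set. A stripped-down sketch of that pattern follows. `load_with_fallback` and the two loader callables are hypothetical stand-ins for illustration; the real method also resolves file names and download options.

```python
def load_with_fallback(load_safe, load_pickle, allow_pickle):
    """Try the safe loader first; fall back to the pickle loader only if allowed."""
    try:
        return load_safe()
    except IOError:
        if not allow_pickle:
            # Re-raise the original error when the fallback is not permitted.
            raise
        return load_pickle()


def failing_safe_loader():
    # Stand-in for a missing or unreadable .safetensors file.
    raise IOError("no .safetensors file found")


def pickle_loader():
    # Stand-in for loading a .bin checkpoint.
    return {"source": "pickle"}


state_dict = load_with_fallback(failing_safe_loader, pickle_loader, allow_pickle=True)
print(state_dict)
```

With `allow_pickle=False`, the same call re-raises the `IOError` instead of silently loading a pickle file.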