Merged

33 commits
d5ac715
add support for flux vae. ~ wip
jfacevedo-google Jan 14, 2025
394ebd1
test for flux vae both encoding and decoding.
jfacevedo-google Jan 14, 2025
025642b
add clip text encoder test
jfacevedo-google Jan 15, 2025
a2b7f82
remove transformers inside maxdiffusion, add transformers dependency.…
jfacevedo-google Jan 22, 2025
2b83d5c
add double block to flux
jfacevedo-google Jan 22, 2025
37d9f00
forward pass for single double block.
jfacevedo-google Jan 22, 2025
8785d00
trying to use scan.
jfacevedo-google Jan 23, 2025
cb91d5e
add single stream block
jfacevedo-google Jan 24, 2025
bb71982
finish transformer
jfacevedo-google Jan 29, 2025
3eb5729
convert pt weights to flax and load transformer state.
jfacevedo-google Jan 30, 2025
956341e
apply fsdp sharding, do one forward pass in the transformer.
jfacevedo-google Jan 30, 2025
4b64f5d
wip - generate fn
jfacevedo-google Jan 30, 2025
860e76e
working loop, bad generation
jfacevedo-google Jan 30, 2025
93a3bb6
e2e, encoder offloading.
jfacevedo-google Jan 30, 2025
601f40c
add missing conversions of pt to jax weights.
jfacevedo-google Jan 31, 2025
d16c020
support both dev and schnell loading. Images still incorrect.
jfacevedo-google Feb 1, 2025
4a12b39
flux schnell working
jfacevedo-google Feb 3, 2025
9871c7d
removed unused code.
jfacevedo-google Feb 3, 2025
a75a125
support dev
jfacevedo-google Feb 3, 2025
05b6fc8
add sentencepiece requirement
jfacevedo-google Feb 4, 2025
df25e47
fix repeated double and single blocks.
jfacevedo-google Feb 4, 2025
587bc6a
optimized flash block sizes for trillium.
jfacevedo-google Feb 4, 2025
8905362
Merge branch 'main' into flux_impl
jfacevedo-google Feb 4, 2025
b87443f
clean up code and lint
jfacevedo-google Feb 4, 2025
37df8b9
fix sdxl generate smoke tests.
jfacevedo-google Feb 5, 2025
e56825f
fix rest of unit tests.
jfacevedo-google Feb 5, 2025
064a3a7
update readme and some dependencies.
entrpn Feb 5, 2025
fa1c23b
remove unused dependencies.
entrpn Feb 5, 2025
b4d0502
initial lora implementation for flux
jfacevedo-google Feb 6, 2025
9e07358
adding another format lora support.
jfacevedo-google Feb 12, 2025
4c68d53
Merge branch 'main' into flux_lora
jfacevedo-google Feb 12, 2025
1f2e65c
Support other format loras. update readme. Run code_style.
jfacevedo-google Feb 13, 2025
24ee4cc
ruff
jfacevedo-google Feb 13, 2025
22 changes: 21 additions & 1 deletion README.md
@@ -17,6 +17,7 @@
[![Unit Tests](https://github.com/google/maxtext/actions/workflows/UnitTests.yml/badge.svg)](https://github.com/google/maxdiffusion/actions/workflows/UnitTests.yml)

# What's new?
- **`2025/02/12`**: Flux LoRA for inference.
- **`2025/02/08`**: Flux schnell & dev inference.
- **`2024/12/12`**: Load multiple LoRAs for inference.
- **`2024/10/22`**: LoRA support for Hyper SDXL.
@@ -47,7 +48,8 @@ MaxDiffusion supports
* [Training](#training)
* [Dreambooth](#dreambooth)
* [Inference](#inference)
* [Flux](#flux)
* [Flux](#flux)
* [Flux LoRA](#flux-lora)
* [Hyper-SD XL LoRA](#hyper-sdxl-lora)
* [Load Multiple LoRA](#load-multiple-lora)
* [SDXL Lightning](#sdxl-lightning)
@@ -169,6 +171,24 @@ To generate images, run the following command:
python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_schnell.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt="photograph of an electronics chip in the shape of a race car with trillium written on its side" per_device_batch_size=1 ici_data_parallelism=1 ici_fsdp_parallelism=-1 offload_encoders=False
```

## Flux LoRA

Disclaimer: not all LoRA formats have been tested. If there is a specific LoRA that doesn't load, please let us know.

Tested with [Amateur Photography](https://civitai.com/models/652699/amateur-photography-flux-dev) and [XLabs-AI](https://huggingface.co/XLabs-AI/flux-lora-collection/tree/main) LoRA collection.

First download the LoRA file to a local directory, for example, `/home/jfacevedo/anime_lora.safetensors`. Then run as follows:

```bash
python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_dev.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt='A cute corgi lives in a house made out of sushi, anime' num_inference_steps=28 ici_data_parallelism=1 ici_fsdp_parallelism=-1 split_head_dim=True lora_config='{"lora_model_name_or_path" : ["/home/jfacevedo/anime_lora.safetensors"], "weight_name" : ["anime_lora.safetensors"], "adapter_name" : ["anime"], "scale": [0.8], "from_pt": ["true"]}'
```

Loading multiple LoRAs is supported as follows:

```bash
python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_dev.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt='A cute corgi lives in a house made out of sushi, anime' num_inference_steps=28 ici_data_parallelism=1 ici_fsdp_parallelism=-1 split_head_dim=True lora_config='{"lora_model_name_or_path" : ["/home/jfacevedo/anime_lora.safetensors", "/home/jfacevedo/amateurphoto-v6-forcu.safetensors"], "weight_name" : ["anime_lora.safetensors","amateurphoto-v6-forcu.safetensors"], "adapter_name" : ["anime","realistic"], "scale": [0.6, 0.6], "from_pt": ["true","true"]}'
```
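The `lora_config` value is a JSON object of parallel lists: index `i` of each list describes the i-th adapter. A minimal sketch of a pre-flight check for such a config (the `validate_lora_config` helper is hypothetical, not part of maxdiffusion):

```python
import json

def validate_lora_config(raw: str) -> dict:
    # Parse the lora_config JSON and check that the per-adapter lists line up.
    cfg = json.loads(raw)
    n = len(cfg["lora_model_name_or_path"])
    for key in ("weight_name", "adapter_name", "scale", "from_pt"):
        assert len(cfg[key]) == n, f"{key} needs one entry per LoRA path"
    return cfg

cfg = validate_lora_config(
    '{"lora_model_name_or_path": ["/tmp/anime_lora.safetensors"],'
    ' "weight_name": ["anime_lora.safetensors"],'
    ' "adapter_name": ["anime"], "scale": [0.8], "from_pt": ["true"]}'
)
```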

## Hyper SDXL LoRA

Supports Hyper-SDXL models from [ByteDance](https://huggingface.co/ByteDance/Hyper-SD)
70 changes: 54 additions & 16 deletions src/maxdiffusion/generate_flux.py
@@ -16,6 +16,7 @@

from typing import Callable, List, Union, Sequence
from absl import app
from contextlib import ExitStack
import functools
import math
import time
@@ -24,6 +25,7 @@
import jax
from jax.sharding import Mesh, PositionalSharding, PartitionSpec as P
import jax.numpy as jnp
import flax.linen as nn
from chex import Array
from einops import rearrange
from flax.linen import partitioning as nn_partitioning
@@ -39,6 +41,28 @@
get_precision,
setup_initial_state,
)
from maxdiffusion.loaders.flux_lora_pipeline import FluxLoraLoaderMixin


def maybe_load_flux_lora(config, lora_loader, params):
  def _noop_interceptor(next_fn, args, kwargs, context):
    return next_fn(*args, **kwargs)

  lora_config = config.lora_config
  interceptors = [_noop_interceptor]
  if len(lora_config["lora_model_name_or_path"]) > 0:
    interceptors = []
    for i in range(len(lora_config["lora_model_name_or_path"])):
      params, rank, network_alphas = lora_loader.load_lora_weights(
          config,
          lora_config["lora_model_name_or_path"][i],
          weight_name=lora_config["weight_name"][i],
          params=params,
          adapter_name=lora_config["adapter_name"][i],
      )
      interceptor = lora_loader.make_lora_interceptor(params, rank, network_alphas, lora_config["adapter_name"][i])
      interceptors.append(interceptor)
  return params, interceptors
Comment on lines +47 to +65

Collaborator: Nit: Why not use maybe_load_lora() from maxdiffusion_utils?

Collaborator (Author): maybe_use_lora() in maxdiffusion_utils is specific to sdxl and won't work with flux. I should rename that method to maybe_load_sdxl_lora. I will create an issue to track this and add it on a different commit. Thanks for the review.

def unpack(x: Array, height: int, width: int) -> Array:
@@ -97,7 +121,6 @@ def prepare_latent_image_ids(height, width):
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape

latent_image_ids = latent_image_ids.reshape(latent_image_id_height * latent_image_id_width, latent_image_id_channels)

return latent_image_ids.astype(jnp.bfloat16)


@@ -127,7 +150,6 @@ def run_inference(
vec=vec,
guidance_vec=guidance_vec,
)

vae_decode_p = functools.partial(vae_decode, vae=vae, state=vae_state, config=config)

with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
@@ -373,21 +395,29 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep

# loads pretrained weights
transformer_params = load_flow_model(config.flux_name, transformer_eval_params, "cpu")
params = {}
params["transformer"] = transformer_params
# maybe load lora and create interceptor
lora_loader = FluxLoraLoaderMixin()
params, lora_interceptors = maybe_load_flux_lora(config, lora_loader, params)
transformer_params = params["transformer"]
# create transformer state
weights_init_fn = functools.partial(
transformer.init_weights, rngs=rng, max_sequence_length=config.max_sequence_length, eval_only=False
)
transformer_state, transformer_state_shardings = setup_initial_state(
model=transformer,
tx=None,
config=config,
mesh=mesh,
weights_init_fn=weights_init_fn,
model_params=None,
training=False,
)
transformer_state = transformer_state.replace(params=transformer_params)
transformer_state = jax.device_put(transformer_state, transformer_state_shardings)
with ExitStack() as stack:
_ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors]
transformer_state, transformer_state_shardings = setup_initial_state(
model=transformer,
tx=None,
config=config,
mesh=mesh,
weights_init_fn=weights_init_fn,
model_params=None,
training=False,
)
transformer_state = transformer_state.replace(params=transformer_params)
transformer_state = jax.device_put(transformer_state, transformer_state_shardings)
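Because the number of active adapters is only known at runtime, the script uses `contextlib.ExitStack` to enter one `nn.intercept_methods` context per interceptor. The stacking behaviour is plain standard-library semantics and can be illustrated in isolation:

```python
from contextlib import ExitStack, contextmanager

events = []

@contextmanager
def tracked(name):
    # Stand-in for nn.intercept_methods(interceptor): record enter/exit order.
    events.append(f"enter:{name}")
    try:
        yield
    finally:
        events.append(f"exit:{name}")

adapters = ["anime", "realistic"]
with ExitStack() as stack:
    for name in adapters:
        stack.enter_context(tracked(name))
    events.append("run_inference")
# ExitStack unwinds in LIFO order, exactly like nested `with` blocks.
```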
get_memory_allocations()

states = {}
@@ -432,17 +462,23 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
out_shardings=None,
)
t0 = time.perf_counter()
p_run_inference(states).block_until_ready()
with ExitStack() as stack:
_ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors]
p_run_inference(states).block_until_ready()
t1 = time.perf_counter()
max_logging.log(f"Compile time: {t1 - t0:.1f}s.")

t0 = time.perf_counter()
imgs = p_run_inference(states).block_until_ready()
with ExitStack() as stack, jax.profiler.trace("/home/jfacevedo/trace/"):
_ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors]
imgs = p_run_inference(states).block_until_ready()
t1 = time.perf_counter()
max_logging.log(f"Inference time: {t1 - t0:.1f}s.")

t0 = time.perf_counter()
imgs = p_run_inference(states).block_until_ready()
with ExitStack() as stack:
_ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors]
imgs = p_run_inference(states).block_until_ready()
imgs = jax.experimental.multihost_utils.process_allgather(imgs, tiled=True)
t1 = time.perf_counter()
max_logging.log(f"Inference time: {t1 - t0:.1f}s.")
@@ -453,6 +489,8 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
for i, image in enumerate(imgs):
Image.fromarray(image).save(f"flux_{i}.png")

return imgs


def main(argv: Sequence[str]) -> None:
pyconfig.initialize(argv)
1 change: 1 addition & 0 deletions src/maxdiffusion/loaders/__init__.py
@@ -13,3 +13,4 @@
# limitations under the License.

from .lora_pipeline import StableDiffusionLoraLoaderMixin
from .flux_lora_pipeline import FluxLoraLoaderMixin
144 changes: 144 additions & 0 deletions src/maxdiffusion/loaders/flux_lora_pipeline.py
@@ -0,0 +1,144 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Union, Dict
from .lora_base import LoRABaseMixin
from ..models.lora import LoRALinearLayer, BaseLoRALayer
import jax.numpy as jnp
from flax.traverse_util import flatten_dict
from ..models.modeling_flax_pytorch_utils import convert_flux_lora_pytorch_state_dict_to_flax
from huggingface_hub.utils import validate_hf_hub_args


class FluxLoraLoaderMixin(LoRABaseMixin):

  _lora_loadable_modules = ["transformer", "text_encoder"]

  def load_lora_weights(
      self,
      config,
      pretrained_model_name_or_path_or_dict: Union[str, Dict[str, jnp.ndarray]],
      params,
      adapter_name=None,
      **kwargs,
  ):
    state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

    params, rank, network_alphas = self.load_lora(
        config,
        state_dict,
        params=params,
        adapter_name=adapter_name,
    )

    return params, rank, network_alphas

  @staticmethod
  def rename_for_interceptor(params_keys, network_alphas, adapter_name):
    new_params_keys = []
    new_network_alphas = {}
    lora_name = f"lora-{adapter_name}"
    for layer_lora in params_keys:
      if lora_name in layer_lora:
        new_layer_lora = layer_lora[: layer_lora.index(lora_name)]
        if new_layer_lora not in new_params_keys:
          new_params_keys.append(new_layer_lora)
          network_alpha = network_alphas.get(layer_lora, None)
          new_network_alphas[new_layer_lora] = network_alpha
    return new_params_keys, new_network_alphas
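`rename_for_interceptor` trims each flattened parameter key at the `lora-{adapter_name}` segment, recovering the path of the base layer that the interceptor should wrap. The trimming logic can be exercised standalone (the key tuples below are made up for illustration):

```python
def rename_for_interceptor(params_keys, network_alphas, adapter_name):
    # Same logic as the mixin method: cut each key tuple at the lora segment
    # and carry the layer's network alpha (if any) over to the trimmed key.
    new_params_keys = []
    new_network_alphas = {}
    lora_name = f"lora-{adapter_name}"
    for layer_lora in params_keys:
        if lora_name in layer_lora:
            new_layer_lora = layer_lora[: layer_lora.index(lora_name)]
            if new_layer_lora not in new_params_keys:
                new_params_keys.append(new_layer_lora)
                new_network_alphas[new_layer_lora] = network_alphas.get(layer_lora, None)
    return new_params_keys, new_network_alphas

keys = [
    ("double_blocks_0", "img_attn", "qkv", "lora-anime", "up", "kernel"),
    ("double_blocks_0", "img_attn", "qkv", "lora-anime", "down", "kernel"),
    ("double_blocks_0", "img_attn", "qkv", "kernel"),  # no LoRA segment: skipped
]
alphas = {keys[0]: 16.0, keys[1]: 16.0}
lora_keys, lora_alphas = rename_for_interceptor(keys, alphas, "anime")
```

Both the `up` and `down` keys collapse onto the same base-layer path, so each wrapped layer appears once in `lora_keys`.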

  @classmethod
  def make_lora_interceptor(cls, params, rank, network_alphas, adapter_name):
    network_alphas_for_interceptor = {}

    transformer_keys = flatten_dict(params["transformer"]).keys()
    lora_keys, transformer_alphas = cls.rename_for_interceptor(transformer_keys, network_alphas, adapter_name)
    network_alphas_for_interceptor.update(transformer_alphas)

    def _intercept(next_fn, args, kwargs, context):
      mod = context.module
      while mod is not None:
        if isinstance(mod, BaseLoRALayer):
          return next_fn(*args, **kwargs)
        mod = mod.parent
      h = next_fn(*args, **kwargs)
      if context.method_name == "__call__":
        module_path = context.module.path
        if module_path in lora_keys:
          lora_layer = cls._get_lora_layer(module_path, context.module, rank, network_alphas_for_interceptor, adapter_name)
          return lora_layer(h, *args, **kwargs)
      return h

    return _intercept

  @classmethod
  def _get_lora_layer(cls, module_path, module, rank, network_alphas, adapter_name):
    network_alpha = network_alphas.get(module_path, None)
    lora_module = LoRALinearLayer(
        out_features=module.features,
        rank=rank,
        network_alpha=network_alpha,
        dtype=module.dtype,
        weights_dtype=module.param_dtype,
        precision=module.precision,
        name=f"lora-{adapter_name}",
    )
    return lora_module
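`_get_lora_layer` forwards `rank` and `network_alpha` into `LoRALinearLayer`. Under the common LoRA convention — which this sketch assumes rather than reads from the layer's implementation — the adapter output is scaled by `network_alpha / rank`, defaulting to 1.0 when no alpha was stored with the checkpoint:

```python
def lora_scale(rank, network_alpha=None):
    # alpha / rank when an alpha was stored with the checkpoint, else 1.0.
    return network_alpha / rank if network_alpha is not None else 1.0

def lora_delta(x, down, up, rank, network_alpha=None):
    # Scalar toy case of the assumed LoRA update: scale * up * (down * x).
    return lora_scale(rank, network_alpha) * up * (down * x)
```

With this convention, a checkpoint whose alpha equals its rank contributes at exactly the strength it was trained at (scale 1.0).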

  @classmethod
  @validate_hf_hub_args
  def lora_state_dict(cls, pretrained_model_name_or_path: str, **kwargs):
    cache_dir = kwargs.pop("cache_dir", None)
    force_download = kwargs.pop("force_download", False)
    proxies = kwargs.pop("proxies", None)
    local_files_only = kwargs.pop("local_files_only", None)
    use_auth_token = kwargs.pop("use_auth_token", None)
    revision = kwargs.pop("revision", None)
    subfolder = kwargs.pop("subfolder", None)
    weight_name = kwargs.pop("weight_name", None)
    use_safetensors = kwargs.pop("use_safetensors", None)
    resume_download = kwargs.pop("resume_download", False)

    allow_pickle = False
    if use_safetensors is None:
      use_safetensors = True
      allow_pickle = True

    user_agent = {
        "file_type": "attn_procs_weights",
        "framework": "pytorch",
    }

    state_dict = cls._fetch_state_dict(
        pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path,
        weight_name=weight_name,
        use_safetensors=use_safetensors,
        local_files_only=local_files_only,
        cache_dir=cache_dir,
        force_download=force_download,
        resume_download=resume_download,
        proxies=proxies,
        use_auth_token=use_auth_token,
        revision=revision,
        subfolder=subfolder,
        user_agent=user_agent,
        allow_pickle=allow_pickle,
    )

    return state_dict

  @classmethod
  def load_lora(cls, config, state_dict, params, adapter_name=None):
    params, rank, network_alphas = convert_flux_lora_pytorch_state_dict_to_flax(config, state_dict, params, adapter_name)
    return params, rank, network_alphas
2 changes: 1 addition & 1 deletion src/maxdiffusion/loaders/lora_pipeline.py
@@ -134,7 +134,7 @@ def rename_for_interceptor(params_keys, network_alphas, adapter_name):

@classmethod
def make_lora_interceptor(cls, params, rank, network_alphas, adapter_name):
# Only unet interceptor supported for now.

network_alphas_for_interceptor = {}

unet_lora_keys = flax.traverse_util.flatten_dict(params["unet"]).keys()
@@ -144,7 +144,7 @@ def __call__(self, hidden_states, temb, image_rotary_emb=None):
hidden_states = self.linear2(attn_mlp)
hidden_states = gate * hidden_states
hidden_states = residual + hidden_states
if hidden_states.dtype == jnp.float16 or hidden_states.dtype == jnp.bfloat16:
if hidden_states.dtype == jnp.float16:
hidden_states = jnp.clip(hidden_states, -65504, 65504)

return hidden_states, temb, image_rotary_emb
@@ -294,7 +294,7 @@ def __call__(self, hidden_states, encoder_hidden_states, temb, image_rotary_emb=

context_ff_output = self.txt_mlp(norm_encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output
if encoder_hidden_states.dtype == jnp.float16 or encoder_hidden_states.dtype == jnp.bfloat16:
if encoder_hidden_states.dtype == jnp.float16:
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
return hidden_states, encoder_hidden_states, temb, image_rotary_emb
