diff --git a/README.md b/README.md index 9e85c6c85..f9be0f3fc 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,7 @@ [![Unit Tests](https://github.com/google/maxtext/actions/workflows/UnitTests.yml/badge.svg)](https://github.com/google/maxdiffusion/actions/workflows/UnitTests.yml) # What's new? +- **`2025/02/12`**: Flux LoRA for inference. - **`2025/02/08`**: Flux schnell & dev inference. - **`2024/12/12`**: Load multiple LoRAs for inference. - **`2024/10/22`**: LoRA support for Hyper SDXL. @@ -47,7 +48,8 @@ MaxDiffusion supports * [Training](#training) * [Dreambooth](#dreambooth) * [Inference](#inference) - * [Flux](#flux) + * [Flux](#flux) + * [Flux LoRA](#flux-lora) * [Hyper-SD XL LoRA](#hyper-sdxl-lora) * [Load Multiple LoRA](#load-multiple-lora) * [SDXL Lightning](#sdxl-lightning) @@ -169,6 +171,24 @@ To generate images, run the following command: python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_schnell.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt="photograph of an electronics chip in the shape of a race car with trillium written on its side" per_device_batch_size=1 ici_data_parallelism=1 ici_fsdp_parallelism=-1 offload_encoders=False ``` + ## Flux LoRA + + Disclaimer: not all LoRA formats have been tested. If there is a specific LoRA that doesn't load, please let us know. + + Tested with [Amateur Photography](https://civitai.com/models/652699/amateur-photography-flux-dev) and [XLabs-AI](https://huggingface.co/XLabs-AI/flux-lora-collection/tree/main) LoRA collection. + + First download the LoRA file to a local directory, for example, `/home/jfacevedo/anime_lora.safetensors`. 
Then run as follows: + + ```bash + python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_dev.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt='A cute corgi lives in a house made out of sushi, anime' num_inference_steps=28 ici_data_parallelism=1 ici_fsdp_parallelism=-1 split_head_dim=True lora_config='{"lora_model_name_or_path" : ["/home/jfacevedo/anime_lora.safetensors"], "weight_name" : ["anime_lora.safetensors"], "adapter_name" : ["anime"], "scale": [0.8], "from_pt": ["true"]}' + ``` + + Loading multiple LoRAs is supported as follows: + + ```bash + python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_dev.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt='A cute corgi lives in a house made out of sushi, anime' num_inference_steps=28 ici_data_parallelism=1 ici_fsdp_parallelism=-1 split_head_dim=True lora_config='{"lora_model_name_or_path" : ["/home/jfacevedo/anime_lora.safetensors", "/home/jfacevedo/amateurphoto-v6-forcu.safetensors"], "weight_name" : ["anime_lora.safetensors","amateurphoto-v6-forcu.safetensors"], "adapter_name" : ["anime","realistic"], "scale": [0.6, 0.6], "from_pt": ["true","true"]}' + ``` + ## Hyper SDXL LoRA Supports Hyper-SDXL models from [ByteDance](https://huggingface.co/ByteDance/Hyper-SD) diff --git a/src/maxdiffusion/generate_flux.py b/src/maxdiffusion/generate_flux.py index 6026dd992..1c221ee0b 100644 --- a/src/maxdiffusion/generate_flux.py +++ b/src/maxdiffusion/generate_flux.py @@ -16,6 +16,7 @@ from typing import Callable, List, Union, Sequence from absl import app +from contextlib import ExitStack import functools import math import time @@ -24,6 +25,7 @@ import jax from jax.sharding import Mesh, PositionalSharding, PartitionSpec as P import jax.numpy as jnp +import flax.linen as nn from chex import Array from einops import rearrange from flax.linen import partitioning as nn_partitioning @@ -39,6 +41,28 @@ get_precision, 
setup_initial_state, ) +from maxdiffusion.loaders.flux_lora_pipeline import FluxLoraLoaderMixin + + +def maybe_load_flux_lora(config, lora_loader, params): + def _noop_interceptor(next_fn, args, kwargs, context): + return next_fn(*args, **kwargs) + + lora_config = config.lora_config + interceptors = [_noop_interceptor] + if len(lora_config["lora_model_name_or_path"]) > 0: + interceptors = [] + for i in range(len(lora_config["lora_model_name_or_path"])): + params, rank, network_alphas = lora_loader.load_lora_weights( + config, + lora_config["lora_model_name_or_path"][i], + weight_name=lora_config["weight_name"][i], + params=params, + adapter_name=lora_config["adapter_name"][i], + ) + interceptor = lora_loader.make_lora_interceptor(params, rank, network_alphas, lora_config["adapter_name"][i]) + interceptors.append(interceptor) + return params, interceptors def unpack(x: Array, height: int, width: int) -> Array: @@ -97,7 +121,6 @@ def prepare_latent_image_ids(height, width): latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape latent_image_ids = latent_image_ids.reshape(latent_image_id_height * latent_image_id_width, latent_image_id_channels) - return latent_image_ids.astype(jnp.bfloat16) @@ -127,7 +150,6 @@ def run_inference( vec=vec, guidance_vec=guidance_vec, ) - vae_decode_p = functools.partial(vae_decode, vae=vae, state=vae_state, config=config) with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): @@ -373,21 +395,29 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep # loads pretrained weights transformer_params = load_flow_model(config.flux_name, transformer_eval_params, "cpu") + params = {} + params["transformer"] = transformer_params + # maybe load lora and create interceptor + lora_loader = FluxLoraLoaderMixin() + params, lora_interceptors = maybe_load_flux_lora(config, lora_loader, params) + transformer_params = params["transformer"] # create transformer state 
weights_init_fn = functools.partial( transformer.init_weights, rngs=rng, max_sequence_length=config.max_sequence_length, eval_only=False ) - transformer_state, transformer_state_shardings = setup_initial_state( - model=transformer, - tx=None, - config=config, - mesh=mesh, - weights_init_fn=weights_init_fn, - model_params=None, - training=False, - ) - transformer_state = transformer_state.replace(params=transformer_params) - transformer_state = jax.device_put(transformer_state, transformer_state_shardings) + with ExitStack() as stack: + _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors] + transformer_state, transformer_state_shardings = setup_initial_state( + model=transformer, + tx=None, + config=config, + mesh=mesh, + weights_init_fn=weights_init_fn, + model_params=None, + training=False, + ) + transformer_state = transformer_state.replace(params=transformer_params) + transformer_state = jax.device_put(transformer_state, transformer_state_shardings) get_memory_allocations() states = {} @@ -432,17 +462,23 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep out_shardings=None, ) t0 = time.perf_counter() - p_run_inference(states).block_until_ready() + with ExitStack() as stack: + _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors] + p_run_inference(states).block_until_ready() t1 = time.perf_counter() max_logging.log(f"Compile time: {t1 - t0:.1f}s.") t0 = time.perf_counter() - imgs = p_run_inference(states).block_until_ready() + with ExitStack() as stack: + _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors] + imgs = p_run_inference(states).block_until_ready() t1 = time.perf_counter() max_logging.log(f"Inference time: {t1 - t0:.1f}s.") t0 = time.perf_counter() - imgs = p_run_inference(states).block_until_ready() + with ExitStack() as stack: + _ =
[stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors] + imgs = p_run_inference(states).block_until_ready() imgs = jax.experimental.multihost_utils.process_allgather(imgs, tiled=True) t1 = time.perf_counter() max_logging.log(f"Inference time: {t1 - t0:.1f}s.") @@ -453,6 +489,8 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep for i, image in enumerate(imgs): Image.fromarray(image).save(f"flux_{i}.png") + return imgs + def main(argv: Sequence[str]) -> None: pyconfig.initialize(argv) diff --git a/src/maxdiffusion/loaders/__init__.py b/src/maxdiffusion/loaders/__init__.py index 34133182e..2c9e973d1 100644 --- a/src/maxdiffusion/loaders/__init__.py +++ b/src/maxdiffusion/loaders/__init__.py @@ -13,3 +13,4 @@ # limitations under the License. from .lora_pipeline import StableDiffusionLoraLoaderMixin +from .flux_lora_pipeline import FluxLoraLoaderMixin diff --git a/src/maxdiffusion/loaders/flux_lora_pipeline.py b/src/maxdiffusion/loaders/flux_lora_pipeline.py new file mode 100644 index 000000000..5f449ee9a --- /dev/null +++ b/src/maxdiffusion/loaders/flux_lora_pipeline.py @@ -0,0 +1,144 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Union, Dict +from .lora_base import LoRABaseMixin +from ..models.lora import LoRALinearLayer, BaseLoRALayer +import jax.numpy as jnp +from flax.traverse_util import flatten_dict +from ..models.modeling_flax_pytorch_utils import convert_flux_lora_pytorch_state_dict_to_flax +from huggingface_hub.utils import validate_hf_hub_args + + +class FluxLoraLoaderMixin(LoRABaseMixin): + + _lora_loadable_modules = ["transformer", "text_encoder"] + + def load_lora_weights( + self, + config, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, jnp.ndarray]], + params, + adapter_name=None, + **kwargs, + ): + state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + + params, rank, network_alphas = self.load_lora( + config, + state_dict, + params=params, + adapter_name=adapter_name, + ) + + return params, rank, network_alphas + + def rename_for_interceptor(params_keys, network_alphas, adapter_name): + new_params_keys = [] + new_network_alphas = {} + lora_name = f"lora-{adapter_name}" + for layer_lora in params_keys: + if lora_name in layer_lora: + new_layer_lora = layer_lora[: layer_lora.index(lora_name)] + if new_layer_lora not in new_params_keys: + new_params_keys.append(new_layer_lora) + network_alpha = network_alphas.get(layer_lora, None) + new_network_alphas[new_layer_lora] = network_alpha + return new_params_keys, new_network_alphas + + @classmethod + def make_lora_interceptor(cls, params, rank, network_alphas, adapter_name): + network_alphas_for_interceptor = {} + + transformer_keys = flatten_dict(params["transformer"]).keys() + lora_keys, transformer_alphas = cls.rename_for_interceptor(transformer_keys, network_alphas, adapter_name) + network_alphas_for_interceptor.update(transformer_alphas) + + def _intercept(next_fn, args, kwargs, context): + mod = context.module + while mod is not None: + if isinstance(mod, BaseLoRALayer): + return next_fn(*args, **kwargs) + mod = mod.parent + h = next_fn(*args, **kwargs) + if
context.method_name == "__call__": + module_path = context.module.path + if module_path in lora_keys: + lora_layer = cls._get_lora_layer(module_path, context.module, rank, network_alphas_for_interceptor, adapter_name) + return lora_layer(h, *args, **kwargs) + return h + + return _intercept + + @classmethod + def _get_lora_layer(cls, module_path, module, rank, network_alphas, adapter_name): + network_alpha = network_alphas.get(module_path, None) + lora_module = LoRALinearLayer( + out_features=module.features, + rank=rank, + network_alpha=network_alpha, + dtype=module.dtype, + weights_dtype=module.param_dtype, + precision=module.precision, + name=f"lora-{adapter_name}", + ) + return lora_module + + @classmethod + @validate_hf_hub_args + def lora_state_dict(cls, pretrained_model_name_or_path: str, **kwargs): + + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + resume_download = kwargs.pop("resume_download", False) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + + state_dict = cls._fetch_state_dict( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path, + weight_name=weight_name, + use_safetensors=use_safetensors, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + allow_pickle=allow_pickle, + ) + + return state_dict + + @classmethod 
+ def load_lora(cls, config, state_dict, params, adapter_name=None): + params, rank, network_alphas = convert_flux_lora_pytorch_state_dict_to_flax(config, state_dict, params, adapter_name) + return params, rank, network_alphas diff --git a/src/maxdiffusion/loaders/lora_pipeline.py b/src/maxdiffusion/loaders/lora_pipeline.py index 812a3ff4a..7feb20ca5 100644 --- a/src/maxdiffusion/loaders/lora_pipeline.py +++ b/src/maxdiffusion/loaders/lora_pipeline.py @@ -134,7 +134,7 @@ def rename_for_interceptor(params_keys, network_alphas, adapter_name): @classmethod def make_lora_interceptor(cls, params, rank, network_alphas, adapter_name): - # Only unet interceptor supported for now. + network_alphas_for_interceptor = {} unet_lora_keys = flax.traverse_util.flatten_dict(params["unet"]).keys() diff --git a/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py b/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py index bff07988a..5035e36e4 100644 --- a/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py +++ b/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py @@ -144,7 +144,7 @@ def __call__(self, hidden_states, temb, image_rotary_emb=None): hidden_states = self.linear2(attn_mlp) hidden_states = gate * hidden_states hidden_states = residual + hidden_states - if hidden_states.dtype == jnp.float16 or hidden_states.dtype == jnp.bfloat16: + if hidden_states.dtype == jnp.float16: hidden_states = jnp.clip(hidden_states, -65504, 65504) return hidden_states, temb, image_rotary_emb @@ -294,7 +294,7 @@ def __call__(self, hidden_states, encoder_hidden_states, temb, image_rotary_emb= context_ff_output = self.txt_mlp(norm_encoder_hidden_states) encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output - if encoder_hidden_states.dtype == jnp.float16 or encoder_hidden_states.dtype == jnp.bfloat16: + if encoder_hidden_states.dtype == jnp.float16: encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) 
return hidden_states, encoder_hidden_states, temb, image_rotary_emb diff --git a/src/maxdiffusion/models/modeling_flax_pytorch_utils.py b/src/maxdiffusion/models/modeling_flax_pytorch_utils.py index 86095b8ed..9552c69f1 100644 --- a/src/maxdiffusion/models/modeling_flax_pytorch_utils.py +++ b/src/maxdiffusion/models/modeling_flax_pytorch_utils.py @@ -223,6 +223,61 @@ def create_flax_params_from_pytorch_state( return unet_state_dict, text_encoder_state_dict, text_encoder_2_state_dict, rank, renamed_network_alphas +def convert_flux_lora_pytorch_state_dict_to_flax(config, pt_state_dict, params, adapter_name): + pt_state_dict = {k: v.float().numpy() for k, v in pt_state_dict.items()} + transformer_params = flatten_dict(unfreeze(params["transformer"])) + network_alphas = {} + rank = None + for pt_key, tensor in pt_state_dict.items(): + renamed_pt_key = rename_key(pt_key) + renamed_pt_key = renamed_pt_key.replace("lora_unet_", "") + renamed_pt_key = renamed_pt_key.replace("lora_down", f"lora-{adapter_name}.down") + renamed_pt_key = renamed_pt_key.replace("lora_up", f"lora-{adapter_name}.up") + + if "double_blocks" in renamed_pt_key: + renamed_pt_key = renamed_pt_key.replace("double_blocks.", "double_blocks_") + renamed_pt_key = renamed_pt_key.replace("processor.proj_lora1.down", f"attn.i_proj.lora-{adapter_name}.down") + renamed_pt_key = renamed_pt_key.replace("processor.proj_lora1.up", f"attn.i_proj.lora-{adapter_name}.up") + renamed_pt_key = renamed_pt_key.replace("processor.proj_lora2.down", f"attn.e_proj.lora-{adapter_name}.down") + renamed_pt_key = renamed_pt_key.replace("processor.proj_lora2.up", f"attn.e_proj.lora-{adapter_name}.up") + renamed_pt_key = renamed_pt_key.replace("processor.qkv_lora1.down", f"attn.i_qkv.lora-{adapter_name}.down") + renamed_pt_key = renamed_pt_key.replace("processor.qkv_lora1.up", f"attn.i_qkv.lora-{adapter_name}.up") + renamed_pt_key = renamed_pt_key.replace("processor.qkv_lora2.down", f"attn.e_qkv.lora-{adapter_name}.down") + 
renamed_pt_key = renamed_pt_key.replace("processor.qkv_lora2.up", f"attn.e_qkv.lora-{adapter_name}.up") + + renamed_pt_key = renamed_pt_key.replace("_img_attn_proj", ".attn.i_proj") + renamed_pt_key = renamed_pt_key.replace("_img_attn_qkv", ".attn.i_qkv") + renamed_pt_key = renamed_pt_key.replace("_img_mlp_0", ".img_mlp.layers_0") + renamed_pt_key = renamed_pt_key.replace("_img_mlp_2", ".img_mlp.layers_2") + renamed_pt_key = renamed_pt_key.replace("_img_mod_lin", ".img_norm1.lin") + renamed_pt_key = renamed_pt_key.replace("_txt_attn_proj", ".attn.e_proj") + renamed_pt_key = renamed_pt_key.replace("_txt_attn_qkv", ".attn.e_qkv") + renamed_pt_key = renamed_pt_key.replace("_txt_mlp_0", ".txt_mlp.layers_0") + renamed_pt_key = renamed_pt_key.replace("_txt_mlp_2", ".txt_mlp.layers_2") + renamed_pt_key = renamed_pt_key.replace("_txt_mod_lin", ".txt_norm1.lin") + elif "single_blocks" in renamed_pt_key: + renamed_pt_key = renamed_pt_key.replace("_linear1", ".linear1") + renamed_pt_key = renamed_pt_key.replace("_linear2", ".linear2") + renamed_pt_key = renamed_pt_key.replace("_modulation_lin", ".norm.lin") + + renamed_pt_key = renamed_pt_key.replace("weight", "kernel") + + pt_tuple_key = tuple(renamed_pt_key.split(".")) + if "alpha" in pt_tuple_key: + pt_tuple_key = pt_tuple_key[:-1] + (f"lora-{adapter_name}", "down", "kernel") + network_alphas[tuple([*pt_tuple_key])] = tensor.item() # noqa: C409 + pt_tuple_key = pt_tuple_key[:-1] + (f"lora-{adapter_name}", "up", "kernel") + network_alphas[tuple([*pt_tuple_key])] = tensor.item() # noqa: C409 + else: + if pt_tuple_key[-2] == "up": + rank = tensor.shape[1] + transformer_params[tuple([*pt_tuple_key])] = jnp.asarray(tensor.T, dtype=config.weights_dtype) # noqa: C409 + + params["transformer"] = unflatten_dict(transformer_params) + + return params, rank, network_alphas + + def convert_lora_pytorch_state_dict_to_flax(pt_state_dict, params, network_alphas, adapter_name): # Step 1: Convert pytorch tensor to numpy # sometimes we load 
weights in bf16 and numpy doesn't support it diff --git a/src/maxdiffusion/tests/generate_flux_smoke_test.py b/src/maxdiffusion/tests/generate_flux_smoke_test.py new file mode 100644 index 000000000..b8ee06f9c --- /dev/null +++ b/src/maxdiffusion/tests/generate_flux_smoke_test.py @@ -0,0 +1,87 @@ +import os +import unittest +import pytest + +import numpy as np + +from .. import pyconfig +from absl.testing import absltest +from maxdiffusion.generate_flux import run as generate_flux +from PIL import Image +from skimage.metrics import structural_similarity as ssim +from google.cloud import storage + +IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true" + +THIS_DIR = os.path.dirname(os.path.abspath(__file__)) + +JAX_CACHE_DIR = "gs://maxdiffusion-github-runner-test-assets/cache_dir" + + +def download_blob(gcs_file, local_file): + gcs_dir_arr = gcs_file.replace("gs://", "").split("/") + storage_client = storage.Client() + bucket = storage_client.get_bucket(gcs_dir_arr[0]) + blob = bucket.blob("/".join(gcs_dir_arr[1:])) + blob.download_to_filename(local_file) + + +class GenerateFlux(unittest.TestCase): + """Smoke test.""" + + def setUp(self): + GenerateFlux.dummy_data = {} + + @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions") + def test_flux_dev(self): + img_url = os.path.join(THIS_DIR, "images", "test_flux_dev.png") + base_image = np.array(Image.open(img_url)).astype(np.uint8) + pyconfig.initialize( + [ + None, + os.path.join(THIS_DIR, "..", "configs", "base_flux_dev.yml"), + "run_name=flux_test", + "output_dir=/tmp/", + "jax_cache_dir=/tmp/cache_dir", + 'prompt="A cute corgi lives in a house made out of sushi, anime"', + ], + unittest=True, + ) + + images = generate_flux(pyconfig.config) + test_image = np.array(images[0]).astype(np.uint8) + ssim_compare = ssim(base_image, test_image, multichannel=True, channel_axis=-1, data_range=255) + assert base_image.shape == test_image.shape + assert ssim_compare >= 0.80 + + 
@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions") + def test_flux_dev_lora(self): + img_url = os.path.join(THIS_DIR, "images", "test_flux_dev_lora.png") + base_image = np.array(Image.open(img_url)).astype(np.uint8) + + gcs_lora_path = "gs://maxdiffusion-github-runner-test-assets/flux/lora/anime_lora.safetensors" + local_path = "/tmp/anime_lora.safetensors" + download_blob(gcs_lora_path, local_path) + + pyconfig.initialize( + [ + None, + os.path.join(THIS_DIR, "..", "configs", "base_flux_dev.yml"), + "run_name=flux_test", + "output_dir=/tmp/", + "jax_cache_dir=/tmp/cache_dir", + 'prompt="A cute corgi lives in a house made out of sushi, anime"', + 'lora_config={"lora_model_name_or_path" : ["/tmp/anime_lora.safetensors"], "weight_name" : ["anime_lora.safetensors"], "adapter_name" : ["anime"], "scale": [0.8], "from_pt": ["true"]}', + ], + unittest=True, + ) + + images = generate_flux(pyconfig.config) + test_image = np.array(images[1]).astype(np.uint8) + ssim_compare = ssim(base_image, test_image, multichannel=True, channel_axis=-1, data_range=255) + assert base_image.shape == test_image.shape + assert ssim_compare >= 0.80 + + +if __name__ == "__main__": + absltest.main() diff --git a/src/maxdiffusion/tests/images/test_flux_dev.png b/src/maxdiffusion/tests/images/test_flux_dev.png new file mode 100644 index 000000000..26d5af076 Binary files /dev/null and b/src/maxdiffusion/tests/images/test_flux_dev.png differ diff --git a/src/maxdiffusion/tests/images/test_flux_dev_lora.png b/src/maxdiffusion/tests/images/test_flux_dev_lora.png new file mode 100644 index 000000000..68e96c1cb Binary files /dev/null and b/src/maxdiffusion/tests/images/test_flux_dev_lora.png differ