From d706152add2a4097afcac0b8e65c0fe00f299729 Mon Sep 17 00:00:00 2001
From: Juan Acevedo
Date: Sun, 6 Oct 2024 22:59:07 +0000
Subject: [PATCH 01/12] lora wip - skeleton

---
 src/maxdiffusion/loaders/__init__.py          |  0
 src/maxdiffusion/loaders/lora_base.py         | 17 ++++
 src/maxdiffusion/loaders/lora_loader.py       | 33 ++++++++
 src/maxdiffusion/models/__init__.py           |  1 +
 src/maxdiffusion/models/lora.py               | 84 +++++++++++++++++++
 .../models/modeling_flax_pytorch_utils.py     | 46 ++++++++++
 src/maxdiffusion/models/modeling_utils.py     |  5 ++
 7 files changed, 186 insertions(+)
 create mode 100644 src/maxdiffusion/loaders/__init__.py
 create mode 100644 src/maxdiffusion/loaders/lora_base.py
 create mode 100644 src/maxdiffusion/loaders/lora_loader.py
 create mode 100644 src/maxdiffusion/models/lora.py

diff --git a/src/maxdiffusion/loaders/__init__.py b/src/maxdiffusion/loaders/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/maxdiffusion/loaders/lora_base.py b/src/maxdiffusion/loaders/lora_base.py
new file mode 100644
index 000000000..cd0cea969
--- /dev/null
+++ b/src/maxdiffusion/loaders/lora_base.py
@@ -0,0 +1,17 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+class LoRABase:
+    def load_lora_weights(self, **kwargs):
+        raise NotImplementedError("`load_lora_weights()` is not implemented.")
\ No newline at end of file
diff --git a/src/maxdiffusion/loaders/lora_loader.py b/src/maxdiffusion/loaders/lora_loader.py
new file mode 100644
index 000000000..cc8c3e7ef
--- /dev/null
+++ b/src/maxdiffusion/loaders/lora_loader.py
@@ -0,0 +1,33 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+from typing import Union, Any, Tuple, Dict
+
+class LoRABaseLoader:
+    lora_modules = []
+
+    def load_lora_weights(self, **kwargs):
+        raise NotImplementedError("`load_lora_weights()` is not implemented.")
+
+class StableDiffusionLoRALoader(LoRABaseLoader):
+    def load_lora_weights(
+        self,
+        pretrained_model_name_or_path: Union[str, os.PathLike],
+        from_pt=True,
+        **kwargs):
+        assert (
+            from_pt == True
+        ), "Only Pytorch LoRA is supported right now."
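
The loader entry points above are stubs for now. The computation they will eventually wire up is the standard LoRA decomposition: each frozen linear projection y = W x gains a trainable low-rank branch, y = W x + scale * (alpha / rank) * up(down(x)), which is what the `LoRALinearLayer` added below in src/maxdiffusion/models/lora.py implements (it returns only the low-rank branch; the caller adds it to the frozen projection's output). A minimal, self-contained Flax sketch of the idea (illustration only, not part of the patch; `ToyLoRADense` is a made-up name):

import jax
import jax.numpy as jnp
import flax.linen as nn

class ToyLoRADense(nn.Module):
    features: int
    rank: int = 4
    alpha: float = 4.0

    @nn.compact
    def __call__(self, x, scale=1.0):
        # Frozen base projection W (stands in for an existing q/k/v/out Dense).
        base = nn.Dense(self.features, use_bias=False, name="base")(x)
        # Low-rank factors: "down" projects to `rank` dims, "up" projects back out.
        down = nn.Dense(self.rank, use_bias=False, name="down")(x)
        # "up" starts at zero, so the LoRA branch is a no-op until trained.
        up = nn.Dense(self.features, use_bias=False,
                      kernel_init=nn.initializers.zeros_init(), name="up")(down)
        return base + scale * (self.alpha / self.rank) * up

x = jnp.ones((1, 8))
model = ToyLoRADense(features=8)
variables = model.init(jax.random.PRNGKey(0), x)
y = model.apply(variables, x, scale=0.7)  # `scale` plays the role of cross_attention_kwargs["scale"]
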
diff --git a/src/maxdiffusion/models/__init__.py b/src/maxdiffusion/models/__init__.py index e459be8b8..ec09a5eb1 100644 --- a/src/maxdiffusion/models/__init__.py +++ b/src/maxdiffusion/models/__init__.py @@ -30,6 +30,7 @@ from .controlnet_flax import FlaxControlNetModel from .unet_2d_condition_flax import FlaxUNet2DConditionModel from .vae_flax import FlaxAutoencoderKL + from .lora import * else: import sys diff --git a/src/maxdiffusion/models/lora.py b/src/maxdiffusion/models/lora.py new file mode 100644 index 000000000..d33f9c676 --- /dev/null +++ b/src/maxdiffusion/models/lora.py @@ -0,0 +1,84 @@ +""" + Copyright 2024 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" +import os + +from typing import Union, Any, Tuple, Dict +import jax +import jax.numpy as jnp +import flax.linen as nn + +from .modeling_utils import load_state_dict +from .modeling_flax_utils import FlaxModelMixin + +class LoRAConfig(FlaxModelMixin): + loras: dict = {} + + def load(self, + pretrained_model_name_or_path: Union[str, os.PathLike], + from_pt=True): + assert ( + from_pt == True + ), "Only Pytorch LoRA is supported right now." + + pt_state_dict = load_state_dict(pretrained_model_name_or_path) + breakpoint() + +class BaseLoRALayer(): + """ + Base LoRA layer class for all LoRA layer implementation + """ + pass + +class LoRALinearLayer(nn.Module, BaseLoRALayer): + """ + Implements LoRA linear layer + """ + in_features: int + out_features: int + rank: int = 0 + network_alpha: float = None + mesh: jax.sharding.Mesh = None + dtype: jnp.dtype = jnp.float32 + precision: jax.lax.Precision = None + + @nn.compact + def __call__(self, hidden_states, scale): + if self.rank > min(self.in_features, self.out_features): + raise ValueError(f"LoRA rank {self.rank} mulst be less or equl to {min(self.in_features, self.out_features)}") + + down_hidden_states = nn.Dense( + features=self.rank, + use_bias=False, + kernel_init=nn.initializers.normal(stddev=1.0/self.rank), + dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision, + name="down" + )(hidden_states) + up_hidden_states = nn.Dense( + features=self.out_features, + use_bias=False, + kernel_init=nn.initializers.zeros_init(), + dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision, + name="up" + )(down_hidden_states) + if self.network_alpha: + up_hidden_states *= self.network_alpha / self.rank + + return up_hidden_states * scale + diff --git a/src/maxdiffusion/models/modeling_flax_pytorch_utils.py b/src/maxdiffusion/models/modeling_flax_pytorch_utils.py index 54da2fbc8..d335bd7f8 100644 --- a/src/maxdiffusion/models/modeling_flax_pytorch_utils.py +++ b/src/maxdiffusion/models/modeling_flax_pytorch_utils.py @@ -19,6 +19,7 @@ import jax.numpy as jnp from flax.linen import Partitioned from flax.traverse_util import flatten_dict, unflatten_dict +from flax.core.frozen_dict import unfreeze, freeze from jax.random import PRNGKey from ..utils import logging @@ -54,6 +55,11 @@ def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, 
random_flax_state_dic ("to_k", "key"), ("to_v", "value"), ("to_q", "query"), + ("to_k_lora", "to_k_lora"), + ("to_k_lora", "to_k_lora"), + ("to_q_lora", "to_q_lora"), + ("to_v_lora", "to_v_lora"), + ("to_out_lora", "to_out_lora") ): if pt_tuple_key[-2] == rename_from: weight_name = pt_tuple_key[-1] @@ -107,6 +113,46 @@ def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dic return pt_tuple_key, pt_tensor +def create_flax_params_from_pytorch_state(pt_state_dict, flax_state_dict, is_lora=False): + rank = None + # Need to change some parameters name to match Flax names + for pt_key, pt_tensor in pt_state_dict.items(): + renamed_pt_key = rename_key(pt_key) + pt_tuple_key = tuple(renamed_pt_key.split(".")) + + # Correctly rename weight parameters + flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, flax_state_dict) + + if is_lora: + if "lora.up" in renamed_pt_key: + rank = pt_tensor.shape[1] + + flax_key_list = list(flax_key) + flax_key_list.remove("processor") + flax_key_list.remove("unet") + flax_key = tuple(flax_key_list) + + + if flax_key in flax_state_dict: + if flax_tensor.shape != flax_state_dict[flax_key].shape: + raise ValueError( + f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape " + f"{flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}." + ) + + # also add unexpected weight so that warning is thrown + flax_state_dict[flax_key] = jnp.asarray(flax_tensor) + + return flax_state_dict, rank + +def convert_lora_pytorch_state_dict_to_flax(pt_state_dict, unet_params): + # Step 1: Convert pytorch tensor to numpy + pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()} + + unet_params = flatten_dict(unfreeze(unet_params)) + flax_state_dict, rank = create_flax_params_from_pytorch_state(pt_state_dict, unet_params,is_lora=True) + + return freeze(unflatten_dict(flax_state_dict)), rank def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model, init_key=42): # Step 1: Convert pytorch tensor to numpy diff --git a/src/maxdiffusion/models/modeling_utils.py b/src/maxdiffusion/models/modeling_utils.py index f7771a02f..de42a0e71 100644 --- a/src/maxdiffusion/models/modeling_utils.py +++ b/src/maxdiffusion/models/modeling_utils.py @@ -99,6 +99,11 @@ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]: first_tuple = next(gen) return first_tuple[1].dtype +# def load_lora_state_dict(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs): +# """ +# Load LoRA +# """ + def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None): """ From f02a3725c9d4a2e3ee10cca6f2a4e7523062a809 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Mon, 7 Oct 2024 17:54:18 +0000 Subject: [PATCH 02/12] add lora layers to sd models. 
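
This patch threads `lora_rank` / `lora_network_alpha` from `FlaxUNet2DConditionModel` down through the cross-attention down/mid/up blocks into `FlaxAttention`, which instantiates `LoRALinearLayer` branches beside its query/key/value/out projections and reads the LoRA strength from `cross_attention_kwargs["scale"]` at call time. On the checkpoint side, the new conversion utilities first translate Kohya/sd-scripts key names into diffusers-style names. A representative mapping, traced through `_convert_unet_lora_key` below (the key is illustrative, not taken from a real checkpoint):

    lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q.lora_down.weight
      -> unet.down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor.to_q_lora.down.weight

The `unet` and `processor` segments are later stripped again by `create_flax_params_from_pytorch_state`; the `to_*_lora` rename entries it relies on were added to modeling_flax_pytorch_utils.py in the previous patch.
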
---
 src/maxdiffusion/loaders/__init__.py          |  13 +
 src/maxdiffusion/loaders/lora_base.py         |  91 ++-
 .../loaders/lora_conversion_utils.py          | 621 ++++++++++++++++++
 src/maxdiffusion/loaders/lora_loader.py       |  33 -
 src/maxdiffusion/loaders/lora_pipeline.py     | 210 ++++++
 src/maxdiffusion/models/attention_flax.py     |  98 ++-
 src/maxdiffusion/models/lora.py               |  30 +-
 .../models/unet_2d_blocks_flax.py             |  44 +-
 .../models/unet_2d_condition_flax.py          |  32 +-
 9 files changed, 1099 insertions(+), 73 deletions(-)
 create mode 100644 src/maxdiffusion/loaders/lora_conversion_utils.py
 delete mode 100644 src/maxdiffusion/loaders/lora_loader.py
 create mode 100644 src/maxdiffusion/loaders/lora_pipeline.py

diff --git a/src/maxdiffusion/loaders/__init__.py b/src/maxdiffusion/loaders/__init__.py
index e69de29bb..03f797967 100644
--- a/src/maxdiffusion/loaders/__init__.py
+++ b/src/maxdiffusion/loaders/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
\ No newline at end of file
diff --git a/src/maxdiffusion/loaders/lora_base.py b/src/maxdiffusion/loaders/lora_base.py
index cd0cea969..69aa6f6eb 100644
--- a/src/maxdiffusion/loaders/lora_base.py
+++ b/src/maxdiffusion/loaders/lora_base.py
@@ -12,6 +12,93 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-class LoRABase:
+from ..models.modeling_utils import load_state_dict
+from ..utils import _get_model_file
+
+import safetensors
+
+class LoRABaseMixin:
+    """Utility class for handling LoRAs"""
+
+    _lora_loadable_modules = []
+    num_fused_loras = 0
+
     def load_lora_weights(self, **kwargs):
-        raise NotImplementedError("`load_lora_weights()` is not implemented.")
\ No newline at end of file
+        raise NotImplementedError("`load_lora_weights()` is not implemented.")
+
+    @classmethod
+    def _fetch_state_dict(
+        cls,
+        pretrained_model_name_or_path_or_dict,
+        weight_name,
+        use_safetensors,
+        local_files_only,
+        cache_dir,
+        force_download,
+        proxies,
+        token,
+        revision,
+        subfolder,
+        user_agent,
+        allow_pickle,
+    ):
+        from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
+
+        model_file = None
+        if not isinstance(pretrained_model_name_or_path_or_dict, dict):
+            # Let's first try to load .safetensors weights
+            if (use_safetensors and weight_name is None) or (
+                weight_name is not None and weight_name.endswith(".safetensors")
+            ):
+                try:
+                    # Here we're relaxing the loading check to enable more Inference API
+                    # friendliness where sometimes, it's not at all possible to automatically
+                    # determine `weight_name`.
+ if weight_name is None: + weight_name = cls._best_guess_weight_name( + pretrained_model_name_or_path_or_dict, + file_extension=".safetensors", + local_files_only=local_files_only, + ) + model_file = _get_model_file( + pretrained_model_name_or_path_or_dict, + weights_name=weight_name or LORA_WEIGHT_NAME_SAFE, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + state_dict = safetensors.torch.load_file(model_file, device="cpu") + except (IOError, safetensors.SafetensorError) as e: + if not allow_pickle: + raise e + # try loading non-safetensors weights + model_file = None + pass + + if model_file is None: + if weight_name is None: + weight_name = cls._best_guess_weight_name( + pretrained_model_name_or_path_or_dict, file_extension=".bin", local_files_only=local_files_only + ) + model_file = _get_model_file( + pretrained_model_name_or_path_or_dict, + weights_name=weight_name or LORA_WEIGHT_NAME, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + state_dict = load_state_dict(model_file) + else: + state_dict = pretrained_model_name_or_path_or_dict + + return state_dict \ No newline at end of file diff --git a/src/maxdiffusion/loaders/lora_conversion_utils.py b/src/maxdiffusion/loaders/lora_conversion_utils.py new file mode 100644 index 000000000..b47e03562 --- /dev/null +++ b/src/maxdiffusion/loaders/lora_conversion_utils.py @@ -0,0 +1,621 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re + +import torch + +from ..utils import is_peft_version, logging + +logger = logging.get_logger(__name__) + +def _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config, delimiter="_", block_slice_pos=5): + # 1. get all state_dict_keys + all_keys = list(state_dict.keys()) + sgm_patterns = ["input_blocks", "middle_block", "output_blocks"] + + # 2. check if needs remapping, if not return original dict + is_in_sgm_format = False + for key in all_keys: + if any(p in key for p in sgm_patterns): + is_in_sgm_format = True + break + + if not is_in_sgm_format: + return state_dict + + # 3. 
Else remap from SGM patterns + new_state_dict = {} + inner_block_map = ["resnets", "attentions", "upsamplers"] + + # Retrieves # of down, mid and up blocks + input_block_ids, middle_block_ids, output_block_ids = set(), set(), set() + + for layer in all_keys: + if "text" in layer: + new_state_dict[layer] = state_dict.pop(layer) + else: + layer_id = int(layer.split(delimiter)[:block_slice_pos][-1]) + if sgm_patterns[0] in layer: + input_block_ids.add(layer_id) + elif sgm_patterns[1] in layer: + middle_block_ids.add(layer_id) + elif sgm_patterns[2] in layer: + output_block_ids.add(layer_id) + else: + raise ValueError(f"Checkpoint not supported because layer {layer} not supported.") + + input_blocks = { + layer_id: [key for key in state_dict if f"input_blocks{delimiter}{layer_id}" in key] + for layer_id in input_block_ids + } + middle_blocks = { + layer_id: [key for key in state_dict if f"middle_block{delimiter}{layer_id}" in key] + for layer_id in middle_block_ids + } + output_blocks = { + layer_id: [key for key in state_dict if f"output_blocks{delimiter}{layer_id}" in key] + for layer_id in output_block_ids + } + + # Rename keys accordingly + for i in input_block_ids: + block_id = (i - 1) // (unet_config.layers_per_block + 1) + layer_in_block_id = (i - 1) % (unet_config.layers_per_block + 1) + + for key in input_blocks[i]: + inner_block_id = int(key.split(delimiter)[block_slice_pos]) + inner_block_key = inner_block_map[inner_block_id] if "op" not in key else "downsamplers" + inner_layers_in_block = str(layer_in_block_id) if "op" not in key else "0" + new_key = delimiter.join( + key.split(delimiter)[: block_slice_pos - 1] + + [str(block_id), inner_block_key, inner_layers_in_block] + + key.split(delimiter)[block_slice_pos + 1 :] + ) + new_state_dict[new_key] = state_dict.pop(key) + + for i in middle_block_ids: + key_part = None + if i == 0: + key_part = [inner_block_map[0], "0"] + elif i == 1: + key_part = [inner_block_map[1], "0"] + elif i == 2: + key_part = [inner_block_map[0], "1"] + else: + raise ValueError(f"Invalid middle block id {i}.") + + for key in middle_blocks[i]: + new_key = delimiter.join( + key.split(delimiter)[: block_slice_pos - 1] + key_part + key.split(delimiter)[block_slice_pos:] + ) + new_state_dict[new_key] = state_dict.pop(key) + + for i in output_block_ids: + block_id = i // (unet_config.layers_per_block + 1) + layer_in_block_id = i % (unet_config.layers_per_block + 1) + + for key in output_blocks[i]: + inner_block_id = int(key.split(delimiter)[block_slice_pos]) + inner_block_key = inner_block_map[inner_block_id] + inner_layers_in_block = str(layer_in_block_id) if inner_block_id < 2 else "0" + new_key = delimiter.join( + key.split(delimiter)[: block_slice_pos - 1] + + [str(block_id), inner_block_key, inner_layers_in_block] + + key.split(delimiter)[block_slice_pos + 1 :] + ) + new_state_dict[new_key] = state_dict.pop(key) + + if len(state_dict) > 0: + raise ValueError("At this point all state dict entries have to be converted.") + + return new_state_dict + + +def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_name="text_encoder"): + """ + Converts a non-Diffusers LoRA state dict to a Diffusers compatible state dict. + + Args: + state_dict (`dict`): The state dict to convert. + unet_name (`str`, optional): The name of the U-Net module in the Diffusers model. Defaults to "unet". + text_encoder_name (`str`, optional): The name of the text encoder module in the Diffusers model. Defaults to + "text_encoder". 
+ + Returns: + `tuple`: A tuple containing the converted state dict and a dictionary of alphas. + """ + unet_state_dict = {} + te_state_dict = {} + te2_state_dict = {} + network_alphas = {} + + # Check for DoRA-enabled LoRAs. + dora_present_in_unet = any("dora_scale" in k and "lora_unet_" in k for k in state_dict) + dora_present_in_te = any("dora_scale" in k and ("lora_te_" in k or "lora_te1_" in k) for k in state_dict) + dora_present_in_te2 = any("dora_scale" in k and "lora_te2_" in k for k in state_dict) + if dora_present_in_unet or dora_present_in_te or dora_present_in_te2: + if is_peft_version("<", "0.9.0"): + raise ValueError( + "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." + ) + + # Iterate over all LoRA weights. + all_lora_keys = list(state_dict.keys()) + for key in all_lora_keys: + if not key.endswith("lora_down.weight"): + continue + + # Extract LoRA name. + lora_name = key.split(".")[0] + + # Find corresponding up weight and alpha. + lora_name_up = lora_name + ".lora_up.weight" + lora_name_alpha = lora_name + ".alpha" + + # Handle U-Net LoRAs. + if lora_name.startswith("lora_unet_"): + diffusers_name = _convert_unet_lora_key(key) + + # Store down and up weights. + unet_state_dict[diffusers_name] = state_dict.pop(key) + unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) + + # Store DoRA scale if present. + if dora_present_in_unet: + dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down." + unet_state_dict[ + diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.") + ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale")) + + # Handle text encoder LoRAs. + elif lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")): + diffusers_name = _convert_text_encoder_lora_key(key, lora_name) + + # Store down and up weights for te or te2. + if lora_name.startswith(("lora_te_", "lora_te1_")): + te_state_dict[diffusers_name] = state_dict.pop(key) + te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) + else: + te2_state_dict[diffusers_name] = state_dict.pop(key) + te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) + + # Store DoRA scale if present. + if dora_present_in_te or dora_present_in_te2: + dora_scale_key_to_replace_te = ( + "_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer." + ) + if lora_name.startswith(("lora_te_", "lora_te1_")): + te_state_dict[ + diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.") + ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale")) + elif lora_name.startswith("lora_te2_"): + te2_state_dict[ + diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.") + ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale")) + + # Store alpha if present. + if lora_name_alpha in state_dict: + alpha = state_dict.pop(lora_name_alpha).item() + network_alphas.update(_get_alpha_name(lora_name_alpha, diffusers_name, alpha)) + + # Check if any keys remain. + if len(state_dict) > 0: + raise ValueError(f"The following keys have not been correctly renamed: \n\n {', '.join(state_dict.keys())}") + + logger.info("Non-diffusers checkpoint detected.") + + # Construct final state dict. 
+ unet_state_dict = {f"{unet_name}.{module_name}": params for module_name, params in unet_state_dict.items()} + te_state_dict = {f"{text_encoder_name}.{module_name}": params for module_name, params in te_state_dict.items()} + te2_state_dict = ( + {f"text_encoder_2.{module_name}": params for module_name, params in te2_state_dict.items()} + if len(te2_state_dict) > 0 + else None + ) + if te2_state_dict is not None: + te_state_dict.update(te2_state_dict) + + new_state_dict = {**unet_state_dict, **te_state_dict} + return new_state_dict, network_alphas + + +def _convert_unet_lora_key(key): + """ + Converts a U-Net LoRA key to a Diffusers compatible key. + """ + diffusers_name = key.replace("lora_unet_", "").replace("_", ".") + + # Replace common U-Net naming patterns. + diffusers_name = diffusers_name.replace("input.blocks", "down_blocks") + diffusers_name = diffusers_name.replace("down.blocks", "down_blocks") + diffusers_name = diffusers_name.replace("middle.block", "mid_block") + diffusers_name = diffusers_name.replace("mid.block", "mid_block") + diffusers_name = diffusers_name.replace("output.blocks", "up_blocks") + diffusers_name = diffusers_name.replace("up.blocks", "up_blocks") + diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks") + diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora") + diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora") + diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora") + diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora") + diffusers_name = diffusers_name.replace("proj.in", "proj_in") + diffusers_name = diffusers_name.replace("proj.out", "proj_out") + diffusers_name = diffusers_name.replace("emb.layers", "time_emb_proj") + + # SDXL specific conversions. + if "emb" in diffusers_name and "time.emb.proj" not in diffusers_name: + pattern = r"\.\d+(?=\D*$)" + diffusers_name = re.sub(pattern, "", diffusers_name, count=1) + if ".in." in diffusers_name: + diffusers_name = diffusers_name.replace("in.layers.2", "conv1") + if ".out." in diffusers_name: + diffusers_name = diffusers_name.replace("out.layers.3", "conv2") + if "downsamplers" in diffusers_name or "upsamplers" in diffusers_name: + diffusers_name = diffusers_name.replace("op", "conv") + if "skip" in diffusers_name: + diffusers_name = diffusers_name.replace("skip.connection", "conv_shortcut") + + # LyCORIS specific conversions. + if "time.emb.proj" in diffusers_name: + diffusers_name = diffusers_name.replace("time.emb.proj", "time_emb_proj") + if "conv.shortcut" in diffusers_name: + diffusers_name = diffusers_name.replace("conv.shortcut", "conv_shortcut") + + # General conversions. + if "transformer_blocks" in diffusers_name: + if "attn1" in diffusers_name or "attn2" in diffusers_name: + diffusers_name = diffusers_name.replace("attn1", "attn1.processor") + diffusers_name = diffusers_name.replace("attn2", "attn2.processor") + elif "ff" in diffusers_name: + pass + elif any(key in diffusers_name for key in ("proj_in", "proj_out")): + pass + else: + pass + + return diffusers_name + + +def _convert_text_encoder_lora_key(key, lora_name): + """ + Converts a text encoder LoRA key to a Diffusers compatible key. 
+ """ + if lora_name.startswith(("lora_te_", "lora_te1_")): + key_to_replace = "lora_te_" if lora_name.startswith("lora_te_") else "lora_te1_" + else: + key_to_replace = "lora_te2_" + + diffusers_name = key.replace(key_to_replace, "").replace("_", ".") + diffusers_name = diffusers_name.replace("text.model", "text_model") + diffusers_name = diffusers_name.replace("self.attn", "self_attn") + diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora") + diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora") + diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora") + diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora") + diffusers_name = diffusers_name.replace("text.projection", "text_projection") + + if "self_attn" in diffusers_name or "text_projection" in diffusers_name: + pass + elif "mlp" in diffusers_name: + # Be aware that this is the new diffusers convention and the rest of the code might + # not utilize it yet. + diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.") + return diffusers_name + + +def _get_alpha_name(lora_name_alpha, diffusers_name, alpha): + """ + Gets the correct alpha name for the Diffusers model. + """ + if lora_name_alpha.startswith("lora_unet_"): + prefix = "unet." + elif lora_name_alpha.startswith(("lora_te_", "lora_te1_")): + prefix = "text_encoder." + else: + prefix = "text_encoder_2." + new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha" + return {new_name: alpha} + + +# The utilities under `_convert_kohya_flux_lora_to_diffusers()` +# are taken from https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py +# All credits go to `kohya-ss`. +def _convert_kohya_flux_lora_to_diffusers(state_dict): + def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key): + if sds_key + ".lora_down.weight" not in sds_sd: + return + down_weight = sds_sd.pop(sds_key + ".lora_down.weight") + + # scale weight by alpha and dim + rank = down_weight.shape[0] + alpha = sds_sd.pop(sds_key + ".alpha").item() # alpha is scalar + scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here + + # calculate scale_down and scale_up to keep the same value. 
if scale is 4, scale_down is 2 and scale_up is 2 + scale_down = scale + scale_up = 1.0 + while scale_down * 2 < scale_up: + scale_down *= 2 + scale_up /= 2 + + ait_sd[ait_key + ".lora_A.weight"] = down_weight * scale_down + ait_sd[ait_key + ".lora_B.weight"] = sds_sd.pop(sds_key + ".lora_up.weight") * scale_up + + def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None): + if sds_key + ".lora_down.weight" not in sds_sd: + return + down_weight = sds_sd.pop(sds_key + ".lora_down.weight") + up_weight = sds_sd.pop(sds_key + ".lora_up.weight") + sd_lora_rank = down_weight.shape[0] + + # scale weight by alpha and dim + alpha = sds_sd.pop(sds_key + ".alpha") + scale = alpha / sd_lora_rank + + # calculate scale_down and scale_up + scale_down = scale + scale_up = 1.0 + while scale_down * 2 < scale_up: + scale_down *= 2 + scale_up /= 2 + + down_weight = down_weight * scale_down + up_weight = up_weight * scale_up + + # calculate dims if not provided + num_splits = len(ait_keys) + if dims is None: + dims = [up_weight.shape[0] // num_splits] * num_splits + else: + assert sum(dims) == up_weight.shape[0] + + # check upweight is sparse or not + is_sparse = False + if sd_lora_rank % num_splits == 0: + ait_rank = sd_lora_rank // num_splits + is_sparse = True + i = 0 + for j in range(len(dims)): + for k in range(len(dims)): + if j == k: + continue + is_sparse = is_sparse and torch.all( + up_weight[i : i + dims[j], k * ait_rank : (k + 1) * ait_rank] == 0 + ) + i += dims[j] + if is_sparse: + logger.info(f"weight is sparse: {sds_key}") + + # make ai-toolkit weight + ait_down_keys = [k + ".lora_A.weight" for k in ait_keys] + ait_up_keys = [k + ".lora_B.weight" for k in ait_keys] + if not is_sparse: + # down_weight is copied to each split + ait_sd.update({k: down_weight for k in ait_down_keys}) + + # up_weight is split to each split + ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416 + else: + # down_weight is chunked to each split + ait_sd.update({k: v for k, v in zip(ait_down_keys, torch.chunk(down_weight, num_splits, dim=0))}) # noqa: C416 + + # up_weight is sparse: only non-zero values are copied to each split + i = 0 + for j in range(len(dims)): + ait_sd[ait_up_keys[j]] = up_weight[i : i + dims[j], j * ait_rank : (j + 1) * ait_rank].contiguous() + i += dims[j] + + def _convert_sd_scripts_to_ai_toolkit(sds_sd): + ait_sd = {} + for i in range(19): + _convert_to_ai_toolkit( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_img_attn_proj", + f"transformer.transformer_blocks.{i}.attn.to_out.0", + ) + _convert_to_ai_toolkit_cat( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_img_attn_qkv", + [ + f"transformer.transformer_blocks.{i}.attn.to_q", + f"transformer.transformer_blocks.{i}.attn.to_k", + f"transformer.transformer_blocks.{i}.attn.to_v", + ], + ) + _convert_to_ai_toolkit( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_img_mlp_0", + f"transformer.transformer_blocks.{i}.ff.net.0.proj", + ) + _convert_to_ai_toolkit( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_img_mlp_2", + f"transformer.transformer_blocks.{i}.ff.net.2", + ) + _convert_to_ai_toolkit( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_img_mod_lin", + f"transformer.transformer_blocks.{i}.norm1.linear", + ) + _convert_to_ai_toolkit( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_txt_attn_proj", + f"transformer.transformer_blocks.{i}.attn.to_add_out", + ) + _convert_to_ai_toolkit_cat( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_txt_attn_qkv", + [ 
+ f"transformer.transformer_blocks.{i}.attn.add_q_proj", + f"transformer.transformer_blocks.{i}.attn.add_k_proj", + f"transformer.transformer_blocks.{i}.attn.add_v_proj", + ], + ) + _convert_to_ai_toolkit( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_txt_mlp_0", + f"transformer.transformer_blocks.{i}.ff_context.net.0.proj", + ) + _convert_to_ai_toolkit( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_txt_mlp_2", + f"transformer.transformer_blocks.{i}.ff_context.net.2", + ) + _convert_to_ai_toolkit( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_txt_mod_lin", + f"transformer.transformer_blocks.{i}.norm1_context.linear", + ) + + for i in range(38): + _convert_to_ai_toolkit_cat( + sds_sd, + ait_sd, + f"lora_unet_single_blocks_{i}_linear1", + [ + f"transformer.single_transformer_blocks.{i}.attn.to_q", + f"transformer.single_transformer_blocks.{i}.attn.to_k", + f"transformer.single_transformer_blocks.{i}.attn.to_v", + f"transformer.single_transformer_blocks.{i}.proj_mlp", + ], + dims=[3072, 3072, 3072, 12288], + ) + _convert_to_ai_toolkit( + sds_sd, + ait_sd, + f"lora_unet_single_blocks_{i}_linear2", + f"transformer.single_transformer_blocks.{i}.proj_out", + ) + _convert_to_ai_toolkit( + sds_sd, + ait_sd, + f"lora_unet_single_blocks_{i}_modulation_lin", + f"transformer.single_transformer_blocks.{i}.norm.linear", + ) + + if len(sds_sd) > 0: + logger.warning(f"Unsuppored keys for ai-toolkit: {sds_sd.keys()}") + + return ait_sd + + return _convert_sd_scripts_to_ai_toolkit(state_dict) + + +# Adapted from https://gist.github.com/Leommm-byte/6b331a1e9bd53271210b26543a7065d6 +# Some utilities were reused from +# https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py +def _convert_xlabs_flux_lora_to_diffusers(old_state_dict): + new_state_dict = {} + orig_keys = list(old_state_dict.keys()) + + def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None): + down_weight = sds_sd.pop(sds_key) + up_weight = sds_sd.pop(sds_key.replace(".down.weight", ".up.weight")) + + # calculate dims if not provided + num_splits = len(ait_keys) + if dims is None: + dims = [up_weight.shape[0] // num_splits] * num_splits + else: + assert sum(dims) == up_weight.shape[0] + + # make ai-toolkit weight + ait_down_keys = [k + ".lora_A.weight" for k in ait_keys] + ait_up_keys = [k + ".lora_B.weight" for k in ait_keys] + + # down_weight is copied to each split + ait_sd.update({k: down_weight for k in ait_down_keys}) + + # up_weight is split to each split + ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416 + + for old_key in orig_keys: + # Handle double_blocks + if old_key.startswith(("diffusion_model.double_blocks", "double_blocks")): + block_num = re.search(r"double_blocks\.(\d+)", old_key).group(1) + new_key = f"transformer.transformer_blocks.{block_num}" + + if "processor.proj_lora1" in old_key: + new_key += ".attn.to_out.0" + elif "processor.proj_lora2" in old_key: + new_key += ".attn.to_add_out" + # Handle text latents. + elif "processor.qkv_lora2" in old_key and "up" not in old_key: + handle_qkv( + old_state_dict, + new_state_dict, + old_key, + [ + f"transformer.transformer_blocks.{block_num}.attn.add_q_proj", + f"transformer.transformer_blocks.{block_num}.attn.add_k_proj", + f"transformer.transformer_blocks.{block_num}.attn.add_v_proj", + ], + ) + # continue + # Handle image latents. 
+ elif "processor.qkv_lora1" in old_key and "up" not in old_key: + handle_qkv( + old_state_dict, + new_state_dict, + old_key, + [ + f"transformer.transformer_blocks.{block_num}.attn.to_q", + f"transformer.transformer_blocks.{block_num}.attn.to_k", + f"transformer.transformer_blocks.{block_num}.attn.to_v", + ], + ) + # continue + + if "down" in old_key: + new_key += ".lora_A.weight" + elif "up" in old_key: + new_key += ".lora_B.weight" + + # Handle single_blocks + elif old_key.startswith("diffusion_model.single_blocks", "single_blocks"): + block_num = re.search(r"single_blocks\.(\d+)", old_key).group(1) + new_key = f"transformer.single_transformer_blocks.{block_num}" + + if "proj_lora1" in old_key or "proj_lora2" in old_key: + new_key += ".proj_out" + elif "qkv_lora1" in old_key or "qkv_lora2" in old_key: + new_key += ".norm.linear" + + if "down" in old_key: + new_key += ".lora_A.weight" + elif "up" in old_key: + new_key += ".lora_B.weight" + + else: + # Handle other potential key patterns here + new_key = old_key + + # Since we already handle qkv above. + if "qkv" not in old_key: + new_state_dict[new_key] = old_state_dict.pop(old_key) + + if len(old_state_dict) > 0: + raise ValueError(f"`old_state_dict` should be at this point but has: {list(old_state_dict.keys())}.") + + return new_state_dict diff --git a/src/maxdiffusion/loaders/lora_loader.py b/src/maxdiffusion/loaders/lora_loader.py deleted file mode 100644 index cc8c3e7ef..000000000 --- a/src/maxdiffusion/loaders/lora_loader.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os - -from typing import Union, Any, Tuple, Dict - -class LoRABaseLoader: - lora_modules = [] - - def load_lora_weights(self, **kwargs): - raise NotImplementedError("`load_lora_weights()` is not implemented.") - -class StableDiffusionLoRALoader(LoRABaseLoader): - def load_lora_weights( - self, - pretrained_model_name_or_path: Union[str, os.PathLike], - from_pt=True, - **kwargs): - assert ( - from_pt == True - ), "Only Pytorch LoRA is supported right now." diff --git a/src/maxdiffusion/loaders/lora_pipeline.py b/src/maxdiffusion/loaders/lora_pipeline.py new file mode 100644 index 000000000..8cbec4359 --- /dev/null +++ b/src/maxdiffusion/loaders/lora_pipeline.py @@ -0,0 +1,210 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Union, Dict +import jax.numpy as jnp +from .lora_base import LoRABaseMixin +from .lora_conversion_utils import ( + _convert_non_diffusers_lora_to_diffusers, + _maybe_map_sgm_blocks_to_diffusers, +) +from huggingface_hub.utils import validate_hf_hub_args + +TEXT_ENCODER_NAME = "text_encoder" +UNET_NAME = "unet" +TRANSFORMER_NAME = "transformer" + +LORA_WEIGHT_NAME = "pytorch_lora_weights.bin" +LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors" + +class StableDiffusionLoraLoaderMixin(LoRABaseMixin): + def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, jnp.ndarray]], adapter_name=None, **kwargs): + """ + Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and + `self.text_encoder`. + + All kwargs are forwarded to `self.lora_state_dict`. + + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is + loaded. + + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details on how the state dict is + loaded into `self.unet`. + + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder`] for more details on how the state + dict is loaded into `self.text_encoder`. + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + kwargs (`dict`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + """ + # if a dict is passed, copy it instead of modifying it inplace + if isinstance(pretrained_model_name_or_path_or_dict, dict): + pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() + + state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + + is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys()) + if not is_correct_format: + raise ValueError("Invalid LoRA checkpoint.") + + self.load_lora_into_unet( + state_dict, + network_alphas=network_alphas, + unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet, + adapter_name=adapter_name, + _pipeline=self, + ) + + @classmethod + @validate_hf_hub_args + def lora_state_dict( + cls, + pretrained_model_name_or_path: str, + **kwargs + ): + r""" + Return state dict for lora weights and the network alphas. + + + + We support loading A1111 formatted LoRA checkpoints in a limited capacity. + + This function is experimental and might change in the future. + + + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`ModelMixin.save_pretrained`]. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. 
+ force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + weight_name (`str`, *optional*, defaults to None): + Name of the serialized state dict file. + """ + # Load the main state dict first which has the LoRA layers for either of + # UNet and text encoder or both. + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + unet_config = kwargs.pop("unet_config", None) + use_safetensors = kwargs.pop("use_safetensors", None) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + + state_dict = cls._fetch_state_dict( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path, + weight_name=weight_name, + use_safetensors=use_safetensors, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + allow_pickle=allow_pickle, + ) + + network_alphas = None + if all( + ( + k.startswith("lora_te_") + or k.startswith("lora_unet_") + or k.startswith("lora_te1_") + or k.startswith("lora_te2_") + ) + for k in state_dict.keys() + ): + # Map SDXL blocks correctly. + if unet_config is not None: + # use unet config to remap block numbers + state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config) + state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict) + + return state_dict, network_alphas + + @classmethod + def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None): + """ + This will load the LoRA layers specified in `state_dict` into `unet`. + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. The keys can either be indexed directly + into the unet or prefixed with an additional `unet` which can be used to distinguish between text + encoder lora layers. + network_alphas (`Dict[str, float]`): + The value of the network alpha used for stable learning and preventing underflow. 
This value has the + same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this + link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning). + unet (`UNet2DConditionModel`): + The UNet model to load the LoRA layers into. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + """ + keys = list(state_dict.keys()) + only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys) + if not only_text_encoder: + # Load the layers corresponding to Unet. + unet_params, rank = convert_lora_pytorch_state_dict_to_flax(state_dict, unet_params) + unet_config["lora_rank"] = rank + unet_model = FlaxUNet2DConditionModel.from_config(unet_config) \ No newline at end of file diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index 1aa800c21..7ac1b00dc 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -23,6 +23,7 @@ from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel from .. import common_types, max_logging +from .lora import LoRALinearLayer Array = common_types.Array Mesh = common_types.Mesh @@ -372,6 +373,8 @@ class FlaxAttention(nn.Module): value_axis_names: AxisNames = (BATCH, LENGTH, HEAD) out_axis_names: AxisNames = (BATCH, LENGTH, HEAD) precision: jax.lax.Precision = None + lora_rank: int = 0 + lora_network_alpha: float = None def setup(self): @@ -435,7 +438,49 @@ def setup(self): ) self.dropout_layer = nn.Dropout(rate=self.dropout) - def __call__(self, hidden_states, context=None, deterministic=True): + if self.lora_rank > 0: + self.to_q_lora = LoRALinearLayer( + in_features=inner_dim, + out_features=inner_dim, + rank=self.lora_rank, + network_alpha=self.lora_network_alpha, + mesh=self.mesh, + dtype=self.dtype, + weights_dtype=self.weights_dtype, + precision=self.precision + ) + self.to_k_lora = LoRALinearLayer( + in_features=inner_dim, + out_features=inner_dim, + rank=self.lora_rank, + network_alpha=self.lora_network_alpha, + mesh=self.mesh, + dtype=self.dtype, + weights_dtype=self.weights_dtype, + precision=self.precision + ) + self.to_v_lora = LoRALinearLayer( + in_features=inner_dim, + out_features=inner_dim, + rank=self.lora_rank, + network_alpha=self.lora_network_alpha, + mesh=self.mesh, + dtype=self.dtype, + weights_dtype=self.weights_dtype, + precision=self.precision + ) + self.to_out_lora = LoRALinearLayer( + in_features=inner_dim, + out_features=inner_dim, + rank=self.lora_rank, + network_alpha=self.lora_network_alpha, + mesh=self.mesh, + dtype=self.dtype, + weights_dtype=self.weights_dtype, + precision=self.precision + ) + + def __call__(self, hidden_states, context=None, deterministic=True, cross_attention_kwargs=None): context = hidden_states if context is None else context query_proj = self.query(hidden_states) key_proj = self.key(context) @@ -445,13 +490,18 @@ def __call__(self, hidden_states, context=None, deterministic=True): key_proj = nn.with_logical_constraint(key_proj, self.key_axis_names) value_proj = nn.with_logical_constraint(value_proj, self.value_axis_names) + if self.lora_rank > 0: + lora_scale = cross_attention_kwargs.get("scale", 0.0) + query_proj = query_proj + self.to_q_lora(hidden_states, lora_scale) + key_proj = key_proj + self.to_k_lora(context, lora_scale) + value_proj = value_proj + 
self.to_v_lora(context, lora_scale)
+
         hidden_states = self.attention_op.apply_attention(query_proj, key_proj, value_proj)
 
-        hidden_states = self.proj_attn(hidden_states)
+        hidden_states = self.proj_attn(hidden_states) + (0 if self.lora_rank <= 0 else self.to_out_lora(hidden_states, lora_scale))
         hidden_states = nn.with_logical_constraint(hidden_states, (BATCH, LENGTH, HEAD))
         return self.dropout_layer(hidden_states, deterministic=deterministic)
 
-
 class FlaxBasicTransformerBlock(nn.Module):
     r"""
     A Flax transformer block layer with `GLU` (Gated Linear Unit) activation function as described in:
@@ -500,6 +550,8 @@ class FlaxBasicTransformerBlock(nn.Module):
     flash_block_sizes: BlockSizes = None
     mesh: jax.sharding.Mesh = None
     precision: jax.lax.Precision = None
+    lora_rank: int = 0
+    lora_network_alpha: float = None
 
     def setup(self):
         # self attention (or cross_attention if only_cross_attention is True)
@@ -517,6 +569,8 @@ def setup(self):
             dtype=self.dtype,
             weights_dtype=self.weights_dtype,
             precision=self.precision,
+            lora_rank=self.lora_rank,
+            lora_network_alpha=self.lora_network_alpha
         )
         # cross attention
         self.attn2 = FlaxAttention(
@@ -533,6 +587,8 @@ def setup(self):
             dtype=self.dtype,
             weights_dtype=self.weights_dtype,
             precision=self.precision,
+            lora_rank=self.lora_rank,
+            lora_network_alpha=self.lora_network_alpha
         )
         self.ff = FlaxFeedForward(
             dim=self.dim, dropout=self.dropout, dtype=self.dtype, weights_dtype=self.weights_dtype, precision=self.precision
@@ -542,18 +598,33 @@ def setup(self):
         self.norm3 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype, param_dtype=self.weights_dtype)
         self.dropout_layer = nn.Dropout(rate=self.dropout)
 
-    def __call__(self, hidden_states, context, deterministic=True):
+    def __call__(self, hidden_states, context, deterministic=True, cross_attention_kwargs=None):
         # self attention
         residual = hidden_states
         if self.only_cross_attention:
-            hidden_states = self.attn1(self.norm1(hidden_states), context, deterministic=deterministic)
+            hidden_states = self.attn1(
+                self.norm1(hidden_states),
+                context,
+                deterministic=deterministic,
+                cross_attention_kwargs=cross_attention_kwargs
+            )
         else:
-            hidden_states = self.attn1(self.norm1(hidden_states), deterministic=deterministic)
+            hidden_states = self.attn1(
+                self.norm1(hidden_states),
+                deterministic=deterministic,
+                cross_attention_kwargs=cross_attention_kwargs
+            )
+
         hidden_states = hidden_states + residual
 
         # cross attention
         residual = hidden_states
-        hidden_states = self.attn2(self.norm2(hidden_states), context, deterministic=deterministic)
+        hidden_states = self.attn2(
+            self.norm2(hidden_states),
+            context,
+            deterministic=deterministic,
+            cross_attention_kwargs=cross_attention_kwargs
+        )
         hidden_states = hidden_states + residual
 
         # feed forward
@@ -618,6 +689,8 @@ class FlaxTransformer2DModel(nn.Module):
     norm_num_groups: int = 32
     precision: jax.lax.Precision = None
     hidden_state_axis_names: AxisNames = (BATCH, LENGTH, D_KV)
+    lora_rank: int = 0
+    lora_network_alpha: float = None
 
     def setup(self):
         self.norm = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-5, dtype=self.dtype, param_dtype=self.weights_dtype)
@@ -663,6 +736,8 @@ def setup(self):
                 flash_block_sizes=self.flash_block_sizes,
                 mesh=self.mesh,
                 precision=self.precision,
+                lora_rank=self.lora_rank,
+                lora_network_alpha=self.lora_network_alpha
             )
             for _ in range(self.depth)
         ]
@@ -689,7 +764,7 @@ def setup(self):
 
         self.dropout_layer = nn.Dropout(rate=self.dropout)
 
-    def __call__(self, hidden_states, context, deterministic=True):
+    def __call__(self, hidden_states, context,
deterministic=True, cross_attention_kwargs=None): batch, height, width, channels = hidden_states.shape residual = hidden_states hidden_states = self.norm(hidden_states) @@ -701,7 +776,12 @@ def __call__(self, hidden_states, context, deterministic=True): hidden_states = hidden_states.reshape(batch, height * width, channels) for transformer_block in self.transformer_blocks: - hidden_states = transformer_block(hidden_states, context, deterministic=deterministic) + hidden_states = transformer_block( + hidden_states, + context, + deterministic=deterministic, + cross_attention_kwargs=cross_attention_kwargs + ) if self.use_linear_projection: hidden_states = self.proj_out(hidden_states) diff --git a/src/maxdiffusion/models/lora.py b/src/maxdiffusion/models/lora.py index d33f9c676..90aa79cc9 100644 --- a/src/maxdiffusion/models/lora.py +++ b/src/maxdiffusion/models/lora.py @@ -15,27 +15,10 @@ """ import os -from typing import Union, Any, Tuple, Dict import jax import jax.numpy as jnp import flax.linen as nn -from .modeling_utils import load_state_dict -from .modeling_flax_utils import FlaxModelMixin - -class LoRAConfig(FlaxModelMixin): - loras: dict = {} - - def load(self, - pretrained_model_name_or_path: Union[str, os.PathLike], - from_pt=True): - assert ( - from_pt == True - ), "Only Pytorch LoRA is supported right now." - - pt_state_dict = load_state_dict(pretrained_model_name_or_path) - breakpoint() - class BaseLoRALayer(): """ Base LoRA layer class for all LoRA layer implementation @@ -52,17 +35,21 @@ class LoRALinearLayer(nn.Module, BaseLoRALayer): network_alpha: float = None mesh: jax.sharding.Mesh = None dtype: jnp.dtype = jnp.float32 + weights_dtype: jnp.dtype = jnp.float32 precision: jax.lax.Precision = None @nn.compact def __call__(self, hidden_states, scale): if self.rank > min(self.in_features, self.out_features): - raise ValueError(f"LoRA rank {self.rank} mulst be less or equl to {min(self.in_features, self.out_features)}") + raise ValueError(f"LoRA rank {self.rank} must be less or equal to {min(self.in_features, self.out_features)}") down_hidden_states = nn.Dense( features=self.rank, use_bias=False, - kernel_init=nn.initializers.normal(stddev=1.0/self.rank), + kernel_init=nn.with_logical_partitioning( + nn.initializers.normal(stddev=1.0/self.rank), + ('embed', 'heads') + ), dtype=self.dtype, param_dtype=self.weights_dtype, precision=self.precision, @@ -71,7 +58,10 @@ def __call__(self, hidden_states, scale): up_hidden_states = nn.Dense( features=self.out_features, use_bias=False, - kernel_init=nn.initializers.zeros_init(), + kernel_init=nn.with_logical_partitioning( + nn.initializers.zeros_init(), + ('embed', 'heads') + ), dtype=self.dtype, param_dtype=self.weights_dtype, precision=self.precision, diff --git a/src/maxdiffusion/models/unet_2d_blocks_flax.py b/src/maxdiffusion/models/unet_2d_blocks_flax.py index 08a64497b..bef9eef2c 100644 --- a/src/maxdiffusion/models/unet_2d_blocks_flax.py +++ b/src/maxdiffusion/models/unet_2d_blocks_flax.py @@ -54,6 +54,10 @@ class FlaxCrossAttnDownBlock2D(nn.Module): jax mesh is required if attention is set to flash. dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` + lora_rank (`int`, *optional*, defaults to 0): + The dimension of the LoRA update matrices. + lora_network_alpha(`float`, *optional*, defaults to None) + Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs. 
""" in_channels: int @@ -75,6 +79,8 @@ class FlaxCrossAttnDownBlock2D(nn.Module): transformer_layers_per_block: int = 1 norm_num_groups: int = 32 precision: jax.lax.Precision = None + lora_rank: int = 0 + lora_network_alpha: float = None def setup(self): resnets = [] @@ -90,7 +96,7 @@ def setup(self): dtype=self.dtype, weights_dtype=self.weights_dtype, norm_num_groups=self.norm_num_groups, - precision=self.precision, + precision=self.precision ) resnets.append(res_block) @@ -111,6 +117,8 @@ def setup(self): weights_dtype=self.weights_dtype, norm_num_groups=self.norm_num_groups, precision=self.precision, + lora_rank=self.lora_rank, + lora_network_alpha=self.lora_network_alpha ) attentions.append(attn_block) @@ -120,12 +128,13 @@ def setup(self): if self.add_downsample: self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype, weights_dtype=self.weights_dtype) - def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True): + def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True, cross_attention_kwargs=None): output_states = () for resnet, attn in zip(self.resnets, self.attentions): hidden_states = resnet(hidden_states, temb, deterministic=deterministic) - hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic) + hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic, + cross_attention_kwargs=cross_attention_kwargs) output_states += (hidden_states,) if self.add_downsample: @@ -232,6 +241,10 @@ class FlaxCrossAttnUpBlock2D(nn.Module): jax mesh is required if attention is set to flash. dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` + lora_rank (`int`, *optional*, defaults to 0): + The dimension of the LoRA update matrices. + lora_network_alpha(`float`, *optional*, defaults to None) + Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs. 
""" in_channels: int @@ -254,6 +267,8 @@ class FlaxCrossAttnUpBlock2D(nn.Module): transformer_layers_per_block: int = 1 norm_num_groups: int = 32 precision: jax.lax.Precision = None + lora_rank: int = 0 + lora_network_alpha: float = None def setup(self): resnets = [] @@ -291,6 +306,8 @@ def setup(self): weights_dtype=self.weights_dtype, norm_num_groups=self.norm_num_groups, precision=self.precision, + lora_rank=self.lora_rank, + lora_network_alpha=self.lora_network_alpha ) attentions.append(attn_block) @@ -300,7 +317,7 @@ def setup(self): if self.add_upsample: self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype, weights_dtype=self.weights_dtype) - def __call__(self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_states, deterministic=True): + def __call__(self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_states, deterministic=True, cross_attention_kwargs=None): for resnet, attn in zip(self.resnets, self.attentions): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] @@ -308,7 +325,7 @@ def __call__(self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_ hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis=-1) hidden_states = resnet(hidden_states, temb, deterministic=deterministic) - hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic) + hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic, cross_attention_kwargs=cross_attention_kwargs) if self.add_upsample: hidden_states = self.upsamplers_0(hidden_states) @@ -414,6 +431,10 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module): jax mesh is required if attention is set to flash. dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` + lora_rank (`int`, *optional*, defaults to 0): + The dimension of the LoRA update matrices. + lora_network_alpha(`float`, *optional*, defaults to None) + Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs. 
""" in_channels: int @@ -432,6 +453,8 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module): transformer_layers_per_block: int = 1 norm_num_groups: int = 32 precision: jax.lax.Precision = None + lora_rank: int = 0 + lora_network_alpha: float = None def setup(self): # there is always at least one resnet @@ -466,6 +489,8 @@ def setup(self): weights_dtype=self.weights_dtype, norm_num_groups=self.norm_num_groups, precision=self.precision, + lora_rank=self.lora_rank, + lora_network_alpha=self.lora_network_alpha ) attentions.append(attn_block) @@ -483,10 +508,15 @@ def setup(self): self.resnets = resnets self.attentions = attentions - def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True): + def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True, cross_attention_kwargs=None): hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): - hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic) + hidden_states = attn( + hidden_states, + encoder_hidden_states, + deterministic=deterministic, + cross_attention_kwargs=cross_attention_kwargs + ) hidden_states = resnet(hidden_states, temb, deterministic=deterministic) return hidden_states diff --git a/src/maxdiffusion/models/unet_2d_condition_flax.py b/src/maxdiffusion/models/unet_2d_condition_flax.py index a93af0ac0..41aa2b4ad 100644 --- a/src/maxdiffusion/models/unet_2d_condition_flax.py +++ b/src/maxdiffusion/models/unet_2d_condition_flax.py @@ -105,6 +105,10 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): Overrides default block sizes for flash attention. mesh (`jax.sharding.mesh`, *optional*, defaults to `None`): jax mesh is required if attention is set to flash. + lora_rank (`int`, *optional*, defaults to 0): + The dimension of the LoRA update matrices. + lora_network_alpha(`float`, *optional*, defaults to None) + Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs. """ sample_size: int = 32 @@ -142,6 +146,8 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): projection_class_embeddings_input_dim: Optional[int] = None norm_num_groups: int = 32 precision: jax.lax.Precision = None + lora_rank: Optional[int] = 0 + lora_network_alpha: Optional[float] = None def init_weights(self, rng: jax.Array, eval_only: bool = False) -> FrozenDict: # init input tensors @@ -280,6 +286,8 @@ def setup(self): dtype=self.dtype, weights_dtype=self.weights_dtype, precision=self.precision, + lora_rank=self.lora_rank, + lora_network_alpha=self.lora_network_alpha ) else: down_block = FlaxDownBlock2D( @@ -312,6 +320,8 @@ def setup(self): dtype=self.dtype, weights_dtype=self.weights_dtype, precision=self.precision, + lora_rank=self.lora_rank, + lora_network_alpha=self.lora_network_alpha ) # up @@ -349,6 +359,8 @@ def setup(self): dtype=self.dtype, weights_dtype=self.weights_dtype, precision=self.precision, + lora_rank=self.lora_rank, + lora_network_alpha=self.lora_network_alpha ) else: up_block = FlaxUpBlock2D( @@ -392,6 +404,7 @@ def __call__( mid_block_additional_residual=None, return_dict: bool = True, train: bool = False, + cross_attention_kwargs: Optional[Union[Dict, FrozenDict]] = None, ) -> Union[FlaxUNet2DConditionOutput, Tuple]: r""" Args: @@ -410,6 +423,8 @@ def __call__( plain tuple. train (`bool`, *optional*, defaults to `False`): Use deterministic functions and disable dropout when not training. 
+ cross_attention_kwargs: (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to FlaxAttention. Returns: [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`: @@ -461,7 +476,13 @@ def __call__( down_block_res_samples = (sample,) for down_block in self.down_blocks: if isinstance(down_block, FlaxCrossAttnDownBlock2D): - sample, res_samples = down_block(sample, t_emb, encoder_hidden_states, deterministic=not train) + sample, res_samples = down_block( + sample, + t_emb, + encoder_hidden_states, + deterministic=not train, + cross_attention_kwargs=cross_attention_kwargs + ) else: sample, res_samples = down_block(sample, t_emb, deterministic=not train) down_block_res_samples += res_samples @@ -478,7 +499,13 @@ def __call__( down_block_res_samples = new_down_block_res_samples # 4. mid - sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train) + sample = self.mid_block( + sample, + t_emb, + encoder_hidden_states, + deterministic=not train, + cross_attention_kwargs=cross_attention_kwargs + ) if mid_block_additional_residual is not None: sample += mid_block_additional_residual @@ -494,6 +521,7 @@ def __call__( encoder_hidden_states=encoder_hidden_states, res_hidden_states_tuple=res_samples, deterministic=not train, + cross_attention_kwargs=cross_attention_kwargs ) else: sample = up_block(sample, temb=t_emb, res_hidden_states_tuple=res_samples, deterministic=not train) From b73222fe95f85af7589ac4353e494db505be3ff9 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Fri, 11 Oct 2024 00:48:44 +0000 Subject: [PATCH 03/12] prototype lora with flax interceptors. --- src/maxdiffusion/loaders/__init__.py | 4 +- src/maxdiffusion/loaders/lora_base.py | 33 ++++--- .../loaders/lora_conversion_utils.py | 19 ++-- src/maxdiffusion/loaders/lora_pipeline.py | 96 +++++++++++++++++-- src/maxdiffusion/models/lora.py | 81 +++++++++++++--- .../models/modeling_flax_pytorch_utils.py | 45 +++++---- .../pipeline_flax_stable_diffusion.py | 5 +- .../pipeline_flax_stable_diffusion_xl.py | 5 +- 8 files changed, 222 insertions(+), 66 deletions(-) diff --git a/src/maxdiffusion/loaders/__init__.py b/src/maxdiffusion/loaders/__init__.py index 03f797967..b4f994eb0 100644 --- a/src/maxdiffusion/loaders/__init__.py +++ b/src/maxdiffusion/loaders/__init__.py @@ -10,4 +10,6 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License. \ No newline at end of file +# limitations under the License. 
+
+from .lora_pipeline import StableDiffusionLoraLoaderMixin
\ No newline at end of file
diff --git a/src/maxdiffusion/loaders/lora_base.py b/src/maxdiffusion/loaders/lora_base.py
index 69aa6f6eb..33cb885d3 100644
--- a/src/maxdiffusion/loaders/lora_base.py
+++ b/src/maxdiffusion/loaders/lora_base.py
@@ -35,12 +35,13 @@ def _fetch_state_dict(
 local_files_only,
 cache_dir,
 force_download,
+ resume_download,
 proxies,
- token,
+ use_auth_token,
 revision,
 subfolder,
 user_agent,
- allow_pickle,
+ allow_pickle
 ):
 from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
@@ -65,12 +66,13 @@ def _fetch_state_dict(
 weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
 cache_dir=cache_dir,
 force_download=force_download,
+ resume_download=resume_download,
 proxies=proxies,
 local_files_only=local_files_only,
- token=token,
+ use_auth_token=use_auth_token,
 revision=revision,
 subfolder=subfolder,
- user_agent=user_agent,
+ user_agent=user_agent
 )
 state_dict = safetensors.torch.load_file(model_file, device="cpu")
 except (IOError, safetensors.SafetensorError) as e:
@@ -86,17 +88,18 @@ def _fetch_state_dict(
 pretrained_model_name_or_path_or_dict, file_extension=".bin", local_files_only=local_files_only
 )
 model_file = _get_model_file(
- pretrained_model_name_or_path_or_dict,
- weights_name=weight_name or LORA_WEIGHT_NAME,
- cache_dir=cache_dir,
- force_download=force_download,
- proxies=proxies,
- local_files_only=local_files_only,
- token=token,
- revision=revision,
- subfolder=subfolder,
- user_agent=user_agent,
- )
+ pretrained_model_name_or_path_or_dict,
+ weights_name=weight_name or LORA_WEIGHT_NAME,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent
+ )
 state_dict = load_state_dict(model_file)
 else:
 state_dict = pretrained_model_name_or_path_or_dict
diff --git a/src/maxdiffusion/loaders/lora_conversion_utils.py b/src/maxdiffusion/loaders/lora_conversion_utils.py
index b47e03562..6ed6350b8 100644
--- a/src/maxdiffusion/loaders/lora_conversion_utils.py
+++ b/src/maxdiffusion/loaders/lora_conversion_utils.py
@@ -14,11 +14,9 @@
 import re
-import torch
-
-from ..utils import is_peft_version, logging
+from .. import max_logging
-logger = logging.get_logger(__name__)
+import torch
 def _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config, delimiter="_", block_slice_pos=5):
 # 1. get all state_dict_keys
@@ -146,10 +144,9 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_
 dora_present_in_te = any("dora_scale" in k and ("lora_te_" in k or "lora_te1_" in k) for k in state_dict)
 dora_present_in_te2 = any("dora_scale" in k and "lora_te2_" in k for k in state_dict)
 if dora_present_in_unet or dora_present_in_te or dora_present_in_te2:
- if is_peft_version("<", "0.9.0"):
- raise ValueError(
- "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
- )
+ raise ValueError(
+ "DoRA is not currently supported"
+ )
 # Iterate over all LoRA weights.
all_lora_keys = list(state_dict.keys())
@@ -214,7 +211,7 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_
 if len(state_dict) > 0:
 raise ValueError(f"The following keys have not been correctly renamed: \n\n {', '.join(state_dict.keys())}")
- logger.info("Non-diffusers checkpoint detected.")
+ max_logging.log("Non-diffusers checkpoint detected.")
 # Construct final state dict.
 unet_state_dict = {f"{unet_name}.{module_name}": params for module_name, params in unet_state_dict.items()}
@@ -395,7 +392,7 @@ def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
 )
 i += dims[j]
 if is_sparse:
- logger.info(f"weight is sparse: {sds_key}")
+ max_logging.log(f"weight is sparse: {sds_key}")
 # make ai-toolkit weight
 ait_down_keys = [k + ".lora_A.weight" for k in ait_keys]
@@ -515,7 +512,7 @@ def _convert_sd_scripts_to_ai_toolkit(sds_sd):
 )
 if len(sds_sd) > 0:
- logger.warning(f"Unsuppored keys for ai-toolkit: {sds_sd.keys()}")
+ max_logging.log(f"Unsupported keys for ai-toolkit: {sds_sd.keys()}")
 return ait_sd
diff --git a/src/maxdiffusion/loaders/lora_pipeline.py b/src/maxdiffusion/loaders/lora_pipeline.py
index 8cbec4359..9b4d132cc 100644
--- a/src/maxdiffusion/loaders/lora_pipeline.py
+++ b/src/maxdiffusion/loaders/lora_pipeline.py
@@ -14,11 +14,14 @@
 from typing import Union, Dict
 import jax.numpy as jnp
+from flax.core.frozen_dict import unfreeze
 from .lora_base import LoRABaseMixin
+from ..models.lora import LoRALinearLayer, LoRAConv2DLayer
 from .lora_conversion_utils import (
 _convert_non_diffusers_lora_to_diffusers,
 _maybe_map_sgm_blocks_to_diffusers,
 )
+from ..models.modeling_flax_pytorch_utils import convert_lora_pytorch_state_dict_to_flax
 from huggingface_hub.utils import validate_hf_hub_args
 TEXT_ENCODER_NAME = "text_encoder"
@@ -29,7 +32,21 @@
 LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
 class StableDiffusionLoraLoaderMixin(LoRABaseMixin):
- def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, jnp.ndarray]], adapter_name=None, **kwargs):
+ r"""
+ Load LoRA layers into Stable Diffusion [`UNet2DConditionModel`] and
+ [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
+ """
+
+ _lora_loadable_modules = ["unet", "text_encoder"]
+ unet_name = UNET_NAME
+ text_encoder_name = TEXT_ENCODER_NAME
+
+ def load_lora_weights(
+ self,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, jnp.ndarray]],
+ params,
+ adapter_name=None,
+ **kwargs):
 """
 Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
 `self.text_encoder`.
@@ -64,13 +81,67 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di if not is_correct_format: raise ValueError("Invalid LoRA checkpoint.") - self.load_lora_into_unet( + unet_lora_params, rank = self.load_lora_into_unet( state_dict, network_alphas=network_alphas, unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet, + params=params, adapter_name=adapter_name, _pipeline=self, ) + return unfreeze(unet_lora_params), rank + + @classmethod + def _get_lora_layer(cls, module_path, module, rank): + # TODO - here we create either Linear or Conv layers + is_conv = any('conv' in str_ for str_ in module_path) + if is_conv: + lora_module = LoRAConv2DLayer( + out_features=module.features, + rank=rank, + kernel_size=module.kernel_size, + strides=module.strides, + padding=module.padding, + dtype=module.dtype, + weights_dtype=module.param_dtype, + precision=module.precision, + name="lora" + ) + else: + lora_module = LoRALinearLayer( + out_features=module.features, + rank=rank, + dtype=module.dtype, + weights_dtype=module.param_dtype, + precision=module.precision, + name="lora" + ) + return lora_module + + @classmethod + def make_lora_interceptor( + cls, + params_keys, + rank + ): + tmp = [] + for layer_lora in params_keys: + if 'lora' in layer_lora: + print(layer_lora) + new_layer_lora = layer_lora[:layer_lora.index('lora')] + if new_layer_lora not in tmp: + tmp.append(new_layer_lora) + params_keys = tmp + def _intercept(next_fn, args, kwargs, context): + h = next_fn(*args, **kwargs) + if context.method_name == '__call__': + module_path = context.module.path + if module_path in params_keys: + print(f"module_path: {module_path}") + lora_layer = cls._get_lora_layer(module_path, context.module, rank) + return lora_layer(h, *args, **kwargs) + return h + return _intercept @classmethod @validate_hf_hub_args @@ -131,12 +202,13 @@ def lora_state_dict( force_download = kwargs.pop("force_download", False) proxies = kwargs.pop("proxies", None) local_files_only = kwargs.pop("local_files_only", None) - token = kwargs.pop("token", None) + use_auth_token = kwargs.pop("use_auth_token", None) revision = kwargs.pop("revision", None) subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) unet_config = kwargs.pop("unet_config", None) use_safetensors = kwargs.pop("use_safetensors", None) + resume_download = kwargs.pop("resume_download", False) allow_pickle = False if use_safetensors is None: @@ -155,8 +227,9 @@ def lora_state_dict( local_files_only=local_files_only, cache_dir=cache_dir, force_download=force_download, + resume_download=resume_download, proxies=proxies, - token=token, + use_auth_token=use_auth_token, revision=revision, subfolder=subfolder, user_agent=user_agent, @@ -182,7 +255,15 @@ def lora_state_dict( return state_dict, network_alphas @classmethod - def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None): + def load_lora_into_unet( + cls, + state_dict, + network_alphas, + unet, + params, + adapter_name=None, + _pipeline=None + ): """ This will load the LoRA layers specified in `state_dict` into `unet`. @@ -205,6 +286,5 @@ def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys) if not only_text_encoder: # Load the layers corresponding to Unet. 
- unet_params, rank = convert_lora_pytorch_state_dict_to_flax(state_dict, unet_params)
- unet_config["lora_rank"] = rank
- unet_model = FlaxUNet2DConditionModel.from_config(unet_config)
\ No newline at end of file
+ unet_lora_params, rank = convert_lora_pytorch_state_dict_to_flax(state_dict, params["unet"])
+ return unet_lora_params, rank
\ No newline at end of file
diff --git a/src/maxdiffusion/models/lora.py b/src/maxdiffusion/models/lora.py
index 90aa79cc9..249065c59 100644
--- a/src/maxdiffusion/models/lora.py
+++ b/src/maxdiffusion/models/lora.py
@@ -15,6 +15,7 @@
 """
 import os
+from typing import Union, Tuple, Optional
 import jax
 import jax.numpy as jnp
 import flax.linen as nn
@@ -29,27 +30,27 @@ class LoRALinearLayer(nn.Module, BaseLoRALayer):
 """
 Implements LoRA linear layer
 """
- in_features: int
 out_features: int
 rank: int = 0
 network_alpha: float = None
- mesh: jax.sharding.Mesh = None
 dtype: jnp.dtype = jnp.float32
 weights_dtype: jnp.dtype = jnp.float32
 precision: jax.lax.Precision = None
+ axis_names: Tuple[str, str] = ('embed', 'heads')  # default for qkv and to_out; set it otherwise.
 @nn.compact
- def __call__(self, hidden_states, scale):
- if self.rank > min(self.in_features, self.out_features):
- raise ValueError(f"LoRA rank {self.rank} must be less or equal to {min(self.in_features, self.out_features)}")
+ def __call__(self, h, hidden_states):
+ if self.rank > self.out_features:
+ raise ValueError(f"LoRA rank {self.rank} must be less than or equal to {self.out_features}")
 down_hidden_states = nn.Dense(
 features=self.rank,
 use_bias=False,
- kernel_init=nn.with_logical_partitioning(
- nn.initializers.normal(stddev=1.0/self.rank),
- ('embed', 'heads')
- ),
+ kernel_init=nn.initializers.kaiming_uniform(),
+ # kernel_init=nn.with_logical_partitioning(
+ # nn.initializers.kaiming_uniform(),
+ # self.axis_names
+ # ),
 dtype=self.dtype,
 param_dtype=self.weights_dtype,
 precision=self.precision,
@@ -58,10 +59,11 @@ def __call__(self, hidden_states, scale):
 up_hidden_states = nn.Dense(
 features=self.out_features,
 use_bias=False,
- kernel_init=nn.with_logical_partitioning(
- nn.initializers.zeros_init(),
- ('embed', 'heads')
- ),
+ kernel_init=nn.initializers.zeros_init(),
+ # kernel_init=nn.with_logical_partitioning(
+ # nn.initializers.zeros_init(),
+ # self.axis_names
+ # ),
 dtype=self.dtype,
 param_dtype=self.weights_dtype,
 precision=self.precision,
@@ -70,5 +72,58 @@ def __call__(self, hidden_states, scale):
 if self.network_alpha:
 up_hidden_states *= self.network_alpha / self.rank
- return up_hidden_states * scale
+ return h + up_hidden_states
+class LoRAConv2DLayer(nn.Module):
+ """
+ Implements LoRA Conv layer
+ """
+ out_features: int
+ rank: int = 4
+ kernel_size: Union[int, Tuple[int, int]] = (1,1)
+ strides: Union[int, Tuple[int, int]] = (1, 1)
+ padding: Union[int, Tuple[int, int], str] = 0
+ network_alpha: Optional[float] = None
+ dtype: jnp.dtype = jnp.float32
+ weights_dtype: jnp.dtype = jnp.float32
+ precision: jax.lax.Precision = None
+
+ @nn.compact
+ def __call__(self, h, hidden_states):
+ down_hidden_states = nn.Conv(
+ self.rank,
+ kernel_size=self.kernel_size,
+ strides=self.strides,
+ padding=self.padding,
+ use_bias=False,
+ dtype=self.dtype,
+ param_dtype=self.weights_dtype,
+ kernel_init=nn.initializers.kaiming_uniform(),
+ # kernel_init=nn.with_logical_partitioning(
+ # nn.initializers.kaiming_uniform(),
+ # ("keep_1", "keep_2", "conv_in", "conv_out")
+ # ),
+ precision=self.precision,
+ name="down"
+ )(hidden_states)
+
+ up_hidden_states = nn.Conv(
+ self.out_features,
+ use_bias=False,
+ kernel_size=(1, 1),
+ strides=(1, 1),
+ dtype=self.dtype,
+
param_dtype=self.weights_dtype, + kernel_init=nn.initializers.zeros_init(), + # kernel_init=nn.with_logical_partitioning( + # nn.initializers.zeros_init(), + # ("keep_1", "keep_2", "conv_in", "conv_out") + # ), + precision=self.precision, + name="up" + )(down_hidden_states) + + if self.network_alpha: + up_hidden_states *= self.network_alpha / self.rank + + return h + up_hidden_states \ No newline at end of file diff --git a/src/maxdiffusion/models/modeling_flax_pytorch_utils.py b/src/maxdiffusion/models/modeling_flax_pytorch_utils.py index d335bd7f8..8b8861efe 100644 --- a/src/maxdiffusion/models/modeling_flax_pytorch_utils.py +++ b/src/maxdiffusion/models/modeling_flax_pytorch_utils.py @@ -55,11 +55,6 @@ def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dic ("to_k", "key"), ("to_v", "value"), ("to_q", "query"), - ("to_k_lora", "to_k_lora"), - ("to_k_lora", "to_k_lora"), - ("to_q_lora", "to_q_lora"), - ("to_v_lora", "to_v_lora"), - ("to_out_lora", "to_out_lora") ): if pt_tuple_key[-2] == rename_from: weight_name = pt_tuple_key[-1] @@ -119,19 +114,38 @@ def create_flax_params_from_pytorch_state(pt_state_dict, flax_state_dict, is_lor for pt_key, pt_tensor in pt_state_dict.items(): renamed_pt_key = rename_key(pt_key) pt_tuple_key = tuple(renamed_pt_key.split(".")) + flax_key_list = [*pt_tuple_key] + for rename_from, rename_to in ( + ("to_k_lora", ("to_k", "lora")), + ("to_q_lora", ("to_q", "lora")), + ("to_v_lora", ("to_v", "lora")), + ("to_out_lora", ("to_out_0", "lora")), + ("weight", "kernel") + ): + # for readability + tmp = [] + for s in flax_key_list: + if s == rename_from: + if type(rename_to) is tuple: + for s_in_tuple in rename_to: + tmp.append(s_in_tuple) + else: + tmp.append(rename_to) + else: + tmp.append(s) + flax_key_list = tmp + + flax_tensor = pt_tensor - # Correctly rename weight parameters - flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, flax_state_dict) - if is_lora: if "lora.up" in renamed_pt_key: rank = pt_tensor.shape[1] - flax_key_list = list(flax_key) - flax_key_list.remove("processor") - flax_key_list.remove("unet") - flax_key = tuple(flax_key_list) - + if "processor" in flax_key_list: + flax_key_list.remove("processor") + if "unet" in flax_key_list: + flax_key_list.remove("unet") + flax_key = tuple(flax_key_list) if flax_key in flax_state_dict: if flax_tensor.shape != flax_state_dict[flax_key].shape: @@ -139,7 +153,6 @@ def create_flax_params_from_pytorch_state(pt_state_dict, flax_state_dict, is_lor f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape " f"{flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}." 
)
- # also add unexpected weight so that warning is thrown
 flax_state_dict[flax_key] = jnp.asarray(flax_tensor)
@@ -147,11 +160,11 @@
 def convert_lora_pytorch_state_dict_to_flax(pt_state_dict, unet_params):
 # Step 1: Convert pytorch tensor to numpy
- pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
+ # sometimes we load weights in bf16 and numpy doesn't support it
+ pt_state_dict = {k: v.float().numpy() for k, v in pt_state_dict.items()}
 unet_params = flatten_dict(unfreeze(unet_params))
 flax_state_dict, rank = create_flax_params_from_pytorch_state(pt_state_dict, unet_params,is_lora=True)
-
 return freeze(unflatten_dict(flax_state_dict)), rank
 def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model, init_key=42):
diff --git a/src/maxdiffusion/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/maxdiffusion/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
index c32872c4a..54411f790 100644
--- a/src/maxdiffusion/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
+++ b/src/maxdiffusion/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
@@ -30,6 +30,9 @@
 FlaxLMSDiscreteScheduler,
 FlaxPNDMScheduler,
 )
+from ...loaders import (
+ StableDiffusionLoraLoaderMixin
+)
 from ...utils import deprecate, logging, replace_example_docstring
 from ..pipeline_flax_utils import FlaxDiffusionPipeline
 from .pipeline_output import FlaxStableDiffusionPipelineOutput
@@ -73,7 +76,7 @@
 """
-class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
+class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline, StableDiffusionLoraLoaderMixin):
 r"""
 Flax-based pipeline for text-to-image generation using Stable Diffusion.
diff --git a/src/maxdiffusion/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py b/src/maxdiffusion/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py
index ca410e939..362ccefb9 100644
--- a/src/maxdiffusion/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py
+++ b/src/maxdiffusion/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py
@@ -29,6 +29,9 @@
 FlaxLMSDiscreteScheduler,
 FlaxPNDMScheduler,
 )
+from ...loaders import (
+ StableDiffusionLoraLoaderMixin
+)
 from ..pipeline_flax_utils import FlaxDiffusionPipeline
 from .pipeline_output import FlaxStableDiffusionXLPipelineOutput
@@ -39,7 +42,7 @@
 DEBUG = False
-class FlaxStableDiffusionXLPipeline(FlaxDiffusionPipeline):
+class FlaxStableDiffusionXLPipeline(FlaxDiffusionPipeline, StableDiffusionLoraLoaderMixin):
 def __init__(
 self,
From d1614858ba7bcfeab43e5704aef864f507a14458 Mon Sep 17 00:00:00 2001
From: Juan Acevedo
Date: Fri, 11 Oct 2024 09:26:31 +0000
Subject: [PATCH 04/12] clean up unneeded layers.
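
[Editor's note] With the per-layer `lora_rank`/`lora_network_alpha` plumbing removed
here, LoRA is applied entirely through the Flax method interceptor introduced in the
previous commit. A minimal sketch of how the interceptor is expected to be wired
around a UNet call at inference time; `unet`, `unet_params`, `latents`, `timesteps`,
`encoder_hidden_states` and `params_keys` are illustrative names, not part of this
patch:

    import flax.linen as nn
    from maxdiffusion.loaders import StableDiffusionLoraLoaderMixin

    # params_keys holds the flattened parameter key tuples after the LoRA
    # weights were merged in; rank comes from the loaded checkpoint.
    interceptor = StableDiffusionLoraLoaderMixin.make_lora_interceptor(params_keys, rank)
    with nn.intercept_methods(interceptor):
        out = unet.apply({"params": unet_params}, latents, timesteps, encoder_hidden_states)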
--- src/maxdiffusion/configs/base_xl.yml | 2 +- src/maxdiffusion/loaders/lora_pipeline.py | 6 +- src/maxdiffusion/max_utils.py | 17 ++++- src/maxdiffusion/models/attention_flax.py | 68 ++----------------- src/maxdiffusion/models/lora.py | 6 ++ src/maxdiffusion/models/resnet_flax.py | 17 +++-- .../models/unet_2d_blocks_flax.py | 32 ++------- .../models/unet_2d_condition_flax.py | 16 +---- 8 files changed, 47 insertions(+), 117 deletions(-) diff --git a/src/maxdiffusion/configs/base_xl.yml b/src/maxdiffusion/configs/base_xl.yml index abce1aeec..cb957bf7b 100644 --- a/src/maxdiffusion/configs/base_xl.yml +++ b/src/maxdiffusion/configs/base_xl.yml @@ -188,7 +188,7 @@ profiler_steps: 10 prompt: "A magical castle in the middle of a forest, artistic drawing" negative_prompt: "purple, red" do_classifier_free_guidance: True -guidance_scale: 9 +guidance_scale: 9.0 # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf guidance_rescale: 0.0 num_inference_steps: 20 diff --git a/src/maxdiffusion/loaders/lora_pipeline.py b/src/maxdiffusion/loaders/lora_pipeline.py index 9b4d132cc..e78ba0ea1 100644 --- a/src/maxdiffusion/loaders/lora_pipeline.py +++ b/src/maxdiffusion/loaders/lora_pipeline.py @@ -93,7 +93,6 @@ def load_lora_weights( @classmethod def _get_lora_layer(cls, module_path, module, rank): - # TODO - here we create either Linear or Conv layers is_conv = any('conv' in str_ for str_ in module_path) if is_conv: lora_module = LoRAConv2DLayer( @@ -102,6 +101,9 @@ def _get_lora_layer(cls, module_path, module, rank): kernel_size=module.kernel_size, strides=module.strides, padding=module.padding, + input_dilation=module.input_dilation, + kernel_dilation=module.kernel_dilation, + feature_group_count=module.feature_group_count, dtype=module.dtype, weights_dtype=module.param_dtype, precision=module.precision, @@ -127,7 +129,6 @@ def make_lora_interceptor( tmp = [] for layer_lora in params_keys: if 'lora' in layer_lora: - print(layer_lora) new_layer_lora = layer_lora[:layer_lora.index('lora')] if new_layer_lora not in tmp: tmp.append(new_layer_lora) @@ -137,7 +138,6 @@ def _intercept(next_fn, args, kwargs, context): if context.method_name == '__call__': module_path = context.module.path if module_path in params_keys: - print(f"module_path: {module_path}") lora_layer = cls._get_lora_layer(module_path, context.module, rank) return lora_layer(h, *args, **kwargs) return h diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index 62d8a4eec..326e40c46 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -360,7 +360,15 @@ def get_abstract_state(model, tx, config, mesh, weights_init_fn, training=True): def setup_initial_state( - model, tx, config, mesh, weights_init_fn, model_params=None, checkpoint_manager=None, checkpoint_item=None, training=True + model, + tx, + config, + mesh, + weights_init_fn, + model_params=None, + checkpoint_manager=None, + checkpoint_item=None, + training=True ): """We initialize the model and optimizer state, and optionally load from a checkpoint as necessary. 
@@ -396,7 +404,6 @@ def setup_initial_state( state = state[checkpoint_item] if not state: max_logging.log(f"Could not find the item in orbax, creating state...") - init_train_state_partial = functools.partial( init_train_state, model=model, @@ -407,7 +414,11 @@ def setup_initial_state( eval_only=False, ) - state = jax.jit(init_train_state_partial, in_shardings=None, out_shardings=state_mesh_shardings)() + state = jax.jit( + init_train_state_partial, + in_shardings=None, + out_shardings=state_mesh_shardings + )() state = unbox_logicallypartioned_trainstate(state) diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index 7ac1b00dc..bbe6b8824 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -373,8 +373,6 @@ class FlaxAttention(nn.Module): value_axis_names: AxisNames = (BATCH, LENGTH, HEAD) out_axis_names: AxisNames = (BATCH, LENGTH, HEAD) precision: jax.lax.Precision = None - lora_rank: int = 0 - lora_network_alpha: float = None def setup(self): @@ -438,48 +436,6 @@ def setup(self): ) self.dropout_layer = nn.Dropout(rate=self.dropout) - if self.lora_rank > 0: - self.to_q_lora = LoRALinearLayer( - in_features=inner_dim, - out_features=inner_dim, - rank=self.lora_rank, - network_alpha=self.lora_network_alpha, - mesh=self.mesh, - dtype=self.dtype, - weights_dtype=self.weights_dtype, - precision=self.precision - ) - self.to_k_lora = LoRALinearLayer( - in_features=inner_dim, - out_features=inner_dim, - rank=self.lora_rank, - network_alpha=self.lora_network_alpha, - mesh=self.mesh, - dtype=self.dtype, - weights_dtype=self.weights_dtype, - precision=self.precision - ) - self.to_v_lora = LoRALinearLayer( - in_features=inner_dim, - out_features=inner_dim, - rank=self.lora_rank, - network_alpha=self.lora_network_alpha, - mesh=self.mesh, - dtype=self.dtype, - weights_dtype=self.weights_dtype, - precision=self.precision - ) - self.to_out_lora = LoRALinearLayer( - in_features=inner_dim, - out_features=inner_dim, - rank=self.lora_rank, - network_alpha=self.lora_network_alpha, - mesh=self.mesh, - dtype=self.dtype, - weights_dtype=self.weights_dtype, - precision=self.precision - ) - def __call__(self, hidden_states, context=None, deterministic=True, cross_attention_kwargs=None): context = hidden_states if context is None else context query_proj = self.query(hidden_states) @@ -490,15 +446,9 @@ def __call__(self, hidden_states, context=None, deterministic=True, cross_attent key_proj = nn.with_logical_constraint(key_proj, self.key_axis_names) value_proj = nn.with_logical_constraint(value_proj, self.value_axis_names) - if self.lora_rank > 0: - lora_scale = cross_attention_kwargs.get("scale", 0.0) - query_proj = query_proj + self.to_q_lora(hidden_states, lora_scale) - key_proj = key_proj + self.to_k_lora(context, lora_scale) - value_proj = value_proj + self.to_v_lora(context, lora_scale) - hidden_states = self.attention_op.apply_attention(query_proj, key_proj, value_proj) - hidden_states = self.proj_attn(hidden_states) + 0 if self.lora_rank <=0 else self.to_out_lora(hidden_states, lora_scale) + hidden_states = self.proj_attn(hidden_states) hidden_states = nn.with_logical_constraint(hidden_states, (BATCH, LENGTH, HEAD)) return self.dropout_layer(hidden_states, deterministic=deterministic) @@ -550,8 +500,6 @@ class FlaxBasicTransformerBlock(nn.Module): flash_block_sizes: BlockSizes = None mesh: jax.sharding.Mesh = None precision: jax.lax.Precision = None - lora_rank: int = 0 - lora_network_alpha: float = None def 
setup(self): # self attention (or cross_attention if only_cross_attention is True) @@ -568,9 +516,7 @@ def setup(self): mesh=self.mesh, dtype=self.dtype, weights_dtype=self.weights_dtype, - precision=self.precision, - lora_rank=self.lora_rank, - lora_network_alpha=self.lora_network_alpha + precision=self.precision ) # cross attention self.attn2 = FlaxAttention( @@ -586,9 +532,7 @@ def setup(self): mesh=self.mesh, dtype=self.dtype, weights_dtype=self.weights_dtype, - precision=self.precision, - lora_rank=self.lora_rank, - lora_network_alpha=self.lora_network_alpha + precision=self.precision ) self.ff = FlaxFeedForward( dim=self.dim, dropout=self.dropout, dtype=self.dtype, weights_dtype=self.weights_dtype, precision=self.precision @@ -689,8 +633,6 @@ class FlaxTransformer2DModel(nn.Module): norm_num_groups: int = 32 precision: jax.lax.Precision = None hidden_state_axis_names: AxisNames = (BATCH, LENGTH, D_KV) - lora_rank: int = 0 - lora_network_alpha: float = None def setup(self): self.norm = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-5, dtype=self.dtype, param_dtype=self.weights_dtype) @@ -735,9 +677,7 @@ def setup(self): flash_min_seq_length=self.flash_min_seq_length, flash_block_sizes=self.flash_block_sizes, mesh=self.mesh, - precision=self.precision, - lora_rank=self.lora_rank, - lora_network_alpha=self.lora_network_alpha + precision=self.precision ) for _ in range(self.depth) ] diff --git a/src/maxdiffusion/models/lora.py b/src/maxdiffusion/models/lora.py index 249065c59..30c55f3bf 100644 --- a/src/maxdiffusion/models/lora.py +++ b/src/maxdiffusion/models/lora.py @@ -83,6 +83,9 @@ class LoRAConv2DLayer(nn.Module): kernel_size: Union[int, Tuple[int, int]] = (1,1) strides: Union[int, Tuple[int, int]] = (1, 1) padding: Union[int, Tuple[int, int], str] = 0 + input_dilation: int = 1 + kernel_dilation: int = 1 + feature_group_count: int = 1 network_alpha: Optional[float] = None dtype: jnp.dtype = jnp.float32 weights_dtype: jnp.dtype = jnp.float32 @@ -95,6 +98,9 @@ def __call__(self, h, hidden_states): kernel_size=self.kernel_size, strides=self.strides, padding=self.padding, + input_dilation=self.input_dilation, + kernel_dilation=self.kernel_dilation, + feature_group_count=self.feature_group_count, use_bias=False, dtype=self.dtype, param_dtype=self.weights_dtype, diff --git a/src/maxdiffusion/models/resnet_flax.py b/src/maxdiffusion/models/resnet_flax.py index b9e9ac062..3ab9434e8 100644 --- a/src/maxdiffusion/models/resnet_flax.py +++ b/src/maxdiffusion/models/resnet_flax.py @@ -147,7 +147,13 @@ def setup(self): precision=self.precision, ) - self.time_emb_proj = nn.Dense(out_channels, dtype=self.dtype, param_dtype=self.weights_dtype, precision=self.precision) + self.time_emb_proj = nn.Dense( + out_channels, + dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision + ) + self.conv2 = nn.Conv( out_channels, kernel_size=(3, 3), @@ -161,15 +167,18 @@ def setup(self): precision=self.precision, ) - def __call__(self, hidden_states, temb, deterministic=True): + def __call__(self, hidden_states, temb, deterministic=True, cross_attention_kwargs={}): + lora_scale = cross_attention_kwargs.get("scale", 0.0) residual = hidden_states hidden_states = self.norm1(hidden_states) hidden_states = nn.swish(hidden_states) hidden_states = self.conv1(hidden_states) hidden_states = nn.with_logical_constraint(hidden_states, ("conv_batch", "height", "keep_2", "out_channels")) - - temb = self.time_emb_proj(nn.swish(temb)) + temb = nn.swish(temb) + temb = self.time_emb_proj(temb) + 
temb = jnp.expand_dims(jnp.expand_dims(temb, 1), 1) + hidden_states = hidden_states + temb hidden_states = self.norm2(hidden_states) diff --git a/src/maxdiffusion/models/unet_2d_blocks_flax.py b/src/maxdiffusion/models/unet_2d_blocks_flax.py index bef9eef2c..ead4091a7 100644 --- a/src/maxdiffusion/models/unet_2d_blocks_flax.py +++ b/src/maxdiffusion/models/unet_2d_blocks_flax.py @@ -54,10 +54,6 @@ class FlaxCrossAttnDownBlock2D(nn.Module): jax mesh is required if attention is set to flash. dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` - lora_rank (`int`, *optional*, defaults to 0): - The dimension of the LoRA update matrices. - lora_network_alpha(`float`, *optional*, defaults to None) - Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs. """ in_channels: int @@ -79,8 +75,6 @@ class FlaxCrossAttnDownBlock2D(nn.Module): transformer_layers_per_block: int = 1 norm_num_groups: int = 32 precision: jax.lax.Precision = None - lora_rank: int = 0 - lora_network_alpha: float = None def setup(self): resnets = [] @@ -116,9 +110,7 @@ def setup(self): dtype=self.dtype, weights_dtype=self.weights_dtype, norm_num_groups=self.norm_num_groups, - precision=self.precision, - lora_rank=self.lora_rank, - lora_network_alpha=self.lora_network_alpha + precision=self.precision ) attentions.append(attn_block) @@ -128,7 +120,7 @@ def setup(self): if self.add_downsample: self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype, weights_dtype=self.weights_dtype) - def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True, cross_attention_kwargs=None): + def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True, cross_attention_kwargs={}): output_states = () for resnet, attn in zip(self.resnets, self.attentions): @@ -241,10 +233,6 @@ class FlaxCrossAttnUpBlock2D(nn.Module): jax mesh is required if attention is set to flash. dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` - lora_rank (`int`, *optional*, defaults to 0): - The dimension of the LoRA update matrices. - lora_network_alpha(`float`, *optional*, defaults to None) - Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs. """ in_channels: int @@ -267,8 +255,6 @@ class FlaxCrossAttnUpBlock2D(nn.Module): transformer_layers_per_block: int = 1 norm_num_groups: int = 32 precision: jax.lax.Precision = None - lora_rank: int = 0 - lora_network_alpha: float = None def setup(self): resnets = [] @@ -305,9 +291,7 @@ def setup(self): dtype=self.dtype, weights_dtype=self.weights_dtype, norm_num_groups=self.norm_num_groups, - precision=self.precision, - lora_rank=self.lora_rank, - lora_network_alpha=self.lora_network_alpha + precision=self.precision ) attentions.append(attn_block) @@ -431,10 +415,6 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module): jax mesh is required if attention is set to flash. dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` - lora_rank (`int`, *optional*, defaults to 0): - The dimension of the LoRA update matrices. - lora_network_alpha(`float`, *optional*, defaults to None) - Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs. 
""" in_channels: int @@ -453,8 +433,6 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module): transformer_layers_per_block: int = 1 norm_num_groups: int = 32 precision: jax.lax.Precision = None - lora_rank: int = 0 - lora_network_alpha: float = None def setup(self): # there is always at least one resnet @@ -488,9 +466,7 @@ def setup(self): dtype=self.dtype, weights_dtype=self.weights_dtype, norm_num_groups=self.norm_num_groups, - precision=self.precision, - lora_rank=self.lora_rank, - lora_network_alpha=self.lora_network_alpha + precision=self.precision ) attentions.append(attn_block) diff --git a/src/maxdiffusion/models/unet_2d_condition_flax.py b/src/maxdiffusion/models/unet_2d_condition_flax.py index 41aa2b4ad..a52696a06 100644 --- a/src/maxdiffusion/models/unet_2d_condition_flax.py +++ b/src/maxdiffusion/models/unet_2d_condition_flax.py @@ -105,10 +105,6 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): Overrides default block sizes for flash attention. mesh (`jax.sharding.mesh`, *optional*, defaults to `None`): jax mesh is required if attention is set to flash. - lora_rank (`int`, *optional*, defaults to 0): - The dimension of the LoRA update matrices. - lora_network_alpha(`float`, *optional*, defaults to None) - Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs. """ sample_size: int = 32 @@ -146,8 +142,6 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): projection_class_embeddings_input_dim: Optional[int] = None norm_num_groups: int = 32 precision: jax.lax.Precision = None - lora_rank: Optional[int] = 0 - lora_network_alpha: Optional[float] = None def init_weights(self, rng: jax.Array, eval_only: bool = False) -> FrozenDict: # init input tensors @@ -285,9 +279,7 @@ def setup(self): mesh=self.mesh, dtype=self.dtype, weights_dtype=self.weights_dtype, - precision=self.precision, - lora_rank=self.lora_rank, - lora_network_alpha=self.lora_network_alpha + precision=self.precision ) else: down_block = FlaxDownBlock2D( @@ -320,8 +312,6 @@ def setup(self): dtype=self.dtype, weights_dtype=self.weights_dtype, precision=self.precision, - lora_rank=self.lora_rank, - lora_network_alpha=self.lora_network_alpha ) # up @@ -358,9 +348,7 @@ def setup(self): mesh=self.mesh, dtype=self.dtype, weights_dtype=self.weights_dtype, - precision=self.precision, - lora_rank=self.lora_rank, - lora_network_alpha=self.lora_network_alpha + precision=self.precision ) else: up_block = FlaxUpBlock2D( From 14ed2ee7df952661ca745281e320ed13539f997b Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Tue, 15 Oct 2024 17:27:52 +0000 Subject: [PATCH 05/12] lora wip - todo - conv layer issue. 
--- src/maxdiffusion/loaders/lora_pipeline.py | 9 ++++- src/maxdiffusion/models/lora.py | 49 ++++++++++++++++++++++- 2 files changed, 56 insertions(+), 2 deletions(-) diff --git a/src/maxdiffusion/loaders/lora_pipeline.py b/src/maxdiffusion/loaders/lora_pipeline.py index e78ba0ea1..7daa2d192 100644 --- a/src/maxdiffusion/loaders/lora_pipeline.py +++ b/src/maxdiffusion/loaders/lora_pipeline.py @@ -16,7 +16,7 @@ import jax.numpy as jnp from flax.core.frozen_dict import unfreeze from .lora_base import LoRABaseMixin -from ..models.lora import LoRALinearLayer, LoRAConv2DLayer +from ..models.lora import LoRALinearLayer, LoRAConv2DLayer, BaseLoRALayer from .lora_conversion_utils import ( _convert_non_diffusers_lora_to_diffusers, _maybe_map_sgm_blocks_to_diffusers, @@ -134,10 +134,17 @@ def make_lora_interceptor( tmp.append(new_layer_lora) params_keys = tmp def _intercept(next_fn, args, kwargs, context): + mod = context.module + while mod is not None: + if isinstance(mod, BaseLoRALayer): + return next_fn(*args, **kwargs) + mod = mod.parent h = next_fn(*args, **kwargs) if context.method_name == '__call__': module_path = context.module.path if module_path in params_keys: + print(h) + print(module_path) lora_layer = cls._get_lora_layer(module_path, context.module, rank) return lora_layer(h, *args, **kwargs) return h diff --git a/src/maxdiffusion/models/lora.py b/src/maxdiffusion/models/lora.py index 30c55f3bf..e5fa78f3f 100644 --- a/src/maxdiffusion/models/lora.py +++ b/src/maxdiffusion/models/lora.py @@ -74,7 +74,54 @@ def __call__(self, h, hidden_states): return h + up_hidden_states -class LoRAConv2DLayer(nn.Module): +# class LoRAConv2DLayer(nn.Module): +# """ +# Implements LoRA Conv layer +# """ +# out_features: int +# rank: int = 4 +# kernel_size: Union[int, Tuple[int, int]] = (1,1) +# strides: Union[int, Tuple[int, int]] = (1, 1) +# padding: Union[int, Tuple[int, int], str] = 0 +# input_dilation: int = 1 +# kernel_dilation: int = 1 +# feature_group_count: int = 1 +# network_alpha: Optional[float] = None +# dtype: jnp.dtype = jnp.float32 +# weights_dtype: jnp.dtype = jnp.float32 +# precision: jax.lax.Precision = None + +# @nn.compact +# def __call__(self, h, hidden_states): +# breakpoint() +# print("out_features: ", self.out_features) +# print("h.shape: ", h.shape) +# print("hidden_states.shape: ", hidden_states.shape) +# lora_a = self.param('down', nn.initializers.kaiming_uniform(), (hidden_states.shape[-1], self.rank)) +# lora_b = self.param('up', jax.nn.initializers.zeros, (self.rank, self.out_features)) + +# # Compute LoRA contribution +# lora_out = hidden_states @ lora_a @ lora_b +# if self.network_alpha: +# lora_out = lora_out * (self.lora / self.rank) +# print("lora_out: ", lora_out) +# return h + jax.lax.conv_general_dilated( +# lora_out, +# lora_out, +# window_strides=self.strides, +# padding=self.padding, +# dimension_numbers=('NHWC', 'HWIO', 'NHWC'), +# ) + # return h + nn.Conv( + # features=self.out_features, + # kernel_size=self.kernel_size, + # strides=self.strides, + # padding=self.padding, + # dtype=self.dtype, + # param_dtype=self.weights_dtype, + # )(lora_out) + +class LoRAConv2DLayer(nn.Module, BaseLoRALayer): """ Implements LoRA Conv layer """ From 7db33b46f2888efeb75e399f06ddacbb51ff9c23 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Wed, 16 Oct 2024 01:00:10 +0000 Subject: [PATCH 06/12] load lora params --- src/maxdiffusion/loaders/lora_pipeline.py | 1 - .../models/modeling_flax_pytorch_utils.py | 49 +++++++++++-------- 2 files changed, 28 insertions(+), 22 deletions(-) 
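
[Editor's note, placed outside the commit message] The substance of this patch is a
tensor layout conversion: PyTorch Conv2d kernels are stored as (out, in, h, w) while
Flax `nn.Conv` expects (h, w, in, out), which is what `transpose(2, 3, 1, 0)` does,
and PyTorch Linear weights are (out, in) while Flax `nn.Dense` kernels are (in, out),
which is what `pt_tensor.T` does. With made-up shapes:

    import numpy as np

    pt_conv = np.zeros((320, 4, 3, 3))         # PyTorch Conv2d: (O, I, H, W)
    flax_conv = pt_conv.transpose(2, 3, 1, 0)  # Flax nn.Conv:   (H, W, I, O)

    pt_linear = np.zeros((768, 320))           # PyTorch Linear: (out, in)
    flax_dense = pt_linear.T                   # Flax nn.Dense:  (in, out)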
diff --git a/src/maxdiffusion/loaders/lora_pipeline.py b/src/maxdiffusion/loaders/lora_pipeline.py index 7daa2d192..6cdffeaaa 100644 --- a/src/maxdiffusion/loaders/lora_pipeline.py +++ b/src/maxdiffusion/loaders/lora_pipeline.py @@ -76,7 +76,6 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) - is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys()) if not is_correct_format: raise ValueError("Invalid LoRA checkpoint.") diff --git a/src/maxdiffusion/models/modeling_flax_pytorch_utils.py b/src/maxdiffusion/models/modeling_flax_pytorch_utils.py index 8b8861efe..7ad0d2eb2 100644 --- a/src/maxdiffusion/models/modeling_flax_pytorch_utils.py +++ b/src/maxdiffusion/models/modeling_flax_pytorch_utils.py @@ -114,28 +114,35 @@ def create_flax_params_from_pytorch_state(pt_state_dict, flax_state_dict, is_lor for pt_key, pt_tensor in pt_state_dict.items(): renamed_pt_key = rename_key(pt_key) pt_tuple_key = tuple(renamed_pt_key.split(".")) - flax_key_list = [*pt_tuple_key] - for rename_from, rename_to in ( - ("to_k_lora", ("to_k", "lora")), - ("to_q_lora", ("to_q", "lora")), - ("to_v_lora", ("to_v", "lora")), - ("to_out_lora", ("to_out_0", "lora")), - ("weight", "kernel") - ): - # for readability - tmp = [] - for s in flax_key_list: - if s == rename_from: - if type(rename_to) is tuple: - for s_in_tuple in rename_to: - tmp.append(s_in_tuple) + #breakpoint() + # conv + if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4: + pt_tensor = pt_tensor.transpose(2, 3, 1, 0) + pt_tuple_key = pt_tuple_key[:-1] + ("kernel",) + flax_key_list = [*pt_tuple_key] + flax_tensor = pt_tensor + else: + flax_key_list = [*pt_tuple_key] + for rename_from, rename_to in ( + ("to_k_lora", ("to_k", "lora")), + ("to_q_lora", ("to_q", "lora")), + ("to_v_lora", ("to_v", "lora")), + ("to_out_lora", ("to_out_0", "lora")), + ("weight", "kernel") + ): + tmp = [] + for s in flax_key_list: + if s == rename_from: + if type(rename_to) is tuple: + for s_in_tuple in rename_to: + tmp.append(s_in_tuple) + else: + tmp.append(rename_to) else: - tmp.append(rename_to) - else: - tmp.append(s) - flax_key_list = tmp + tmp.append(s) + flax_key_list = tmp - flax_tensor = pt_tensor + flax_tensor = pt_tensor.T if is_lora: if "lora.up" in renamed_pt_key: @@ -164,7 +171,7 @@ def convert_lora_pytorch_state_dict_to_flax(pt_state_dict, unet_params): pt_state_dict = {k: v.float().numpy() for k, v in pt_state_dict.items()} unet_params = flatten_dict(unfreeze(unet_params)) - flax_state_dict, rank = create_flax_params_from_pytorch_state(pt_state_dict, unet_params,is_lora=True) + flax_state_dict, rank = create_flax_params_from_pytorch_state(pt_state_dict, unet_params, is_lora=True) return freeze(unflatten_dict(flax_state_dict)), rank def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model, init_key=42): From db76c7cf811df9616ebc52f4702d2a3a190d0f8f Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Wed, 16 Oct 2024 23:23:44 +0000 Subject: [PATCH 07/12] attempt to inject lora layers for text encoders. 
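
[Editor's note] This commit stops assuming every converted key belongs to the UNet
and instead routes each key into the `unet`, `text_encoder` or `text_encoder_2`
parameter tree based on its leading segment. A sketch of the routing; the key
string below is an illustrative example, not taken from a real checkpoint:

    key = "text_encoder.text_model.encoder.layers.0.self_attn.q_proj.lora.down.kernel"
    parts = key.split(".")
    bucket = parts[0]             # "unet", "text_encoder" or "text_encoder_2"
    flax_key = tuple(parts[1:])   # index into that bucket's flattened params dict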
--- src/maxdiffusion/loaders/lora_pipeline.py | 63 +++++++++++++++---- .../models/modeling_flax_pytorch_utils.py | 48 ++++++++------ 2 files changed, 81 insertions(+), 30 deletions(-) diff --git a/src/maxdiffusion/loaders/lora_pipeline.py b/src/maxdiffusion/loaders/lora_pipeline.py index 6cdffeaaa..18acf3c54 100644 --- a/src/maxdiffusion/loaders/lora_pipeline.py +++ b/src/maxdiffusion/loaders/lora_pipeline.py @@ -31,6 +31,52 @@ LORA_WEIGHT_NAME = "pytorch_lora_weights.bin" LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors" +# class StableDiffusionXLLoraLoaderMixin(LoRABaseMixin): +# r""" +# Load LoRA layers into Stable Diffusion XL +# """ + +# _lora_loadable_modules = ["unet", "text_encoder", "text_encoder_2"] +# unet_name = UNET_NAME +# text_encoder_name = TEXT_ENCODER_NAME + +# def load_lora_weights( +# self, +# pretrained_model_name_or_path_or_dict: Union[str, Dict[str, jnp.ndarray]], +# adapter_name = None, +# **kwargs,): +# """ +# Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and +# `self.text_encoder`. + +# All kwargs are forwarded to `self.lora_state_dict`. + +# See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is +# loaded. + +# See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details on how the state dict is +# loaded into `self.unet`. + +# See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder`] for more details on how the state +# dict is loaded into `self.text_encoder`. + +# Parameters: +# pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): +# See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. +# kwargs (`dict`, *optional*): +# See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. +# adapter_name (`str`, *optional*): +# Adapter name to be used for referencing the loaded adapter model. If not specified, it will use +# `default_{i}` where i is the total number of adapters being loaded. +# """ + +# # if a dict is passed, copy it instead of modifying it inplace +# if isinstance(pretrained_model_name_or_path_or_dict, dict): +# pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() + +# self. + + class StableDiffusionLoraLoaderMixin(LoRABaseMixin): r""" Load LoRA layers into Stable Diffusion [`UNet2DConditionModel`] and @@ -80,15 +126,14 @@ def load_lora_weights( if not is_correct_format: raise ValueError("Invalid LoRA checkpoint.") - unet_lora_params, rank = self.load_lora_into_unet( + params, rank = self.load_lora_into_unet( state_dict, network_alphas=network_alphas, - unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet, params=params, adapter_name=adapter_name, _pipeline=self, ) - return unfreeze(unet_lora_params), rank + return params, rank @classmethod def _get_lora_layer(cls, module_path, module, rank): @@ -142,8 +187,6 @@ def _intercept(next_fn, args, kwargs, context): if context.method_name == '__call__': module_path = context.module.path if module_path in params_keys: - print(h) - print(module_path) lora_layer = cls._get_lora_layer(module_path, context.module, rank) return lora_layer(h, *args, **kwargs) return h @@ -265,7 +308,6 @@ def load_lora_into_unet( cls, state_dict, network_alphas, - unet, params, adapter_name=None, _pipeline=None @@ -288,9 +330,6 @@ def load_lora_into_unet( Adapter name to be used for referencing the loaded adapter model. 
If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. """ - keys = list(state_dict.keys()) - only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys) - if not only_text_encoder: - # Load the layers corresponding to Unet. - unet_lora_params, rank = convert_lora_pytorch_state_dict_to_flax(state_dict, params["unet"]) - return unet_lora_params, rank \ No newline at end of file + # Load the layers corresponding to Unet. + params, rank = convert_lora_pytorch_state_dict_to_flax(state_dict, params) + return params, rank \ No newline at end of file diff --git a/src/maxdiffusion/models/modeling_flax_pytorch_utils.py b/src/maxdiffusion/models/modeling_flax_pytorch_utils.py index 7ad0d2eb2..f3756972b 100644 --- a/src/maxdiffusion/models/modeling_flax_pytorch_utils.py +++ b/src/maxdiffusion/models/modeling_flax_pytorch_utils.py @@ -108,13 +108,18 @@ def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dic return pt_tuple_key, pt_tensor -def create_flax_params_from_pytorch_state(pt_state_dict, flax_state_dict, is_lora=False): +def create_flax_params_from_pytorch_state( + pt_state_dict, + unet_state_dict, + text_encoder_state_dict, + text_encoder_2_state_dict, + is_lora=False + ): rank = None # Need to change some parameters name to match Flax names for pt_key, pt_tensor in pt_state_dict.items(): renamed_pt_key = rename_key(pt_key) pt_tuple_key = tuple(renamed_pt_key.split(".")) - #breakpoint() # conv if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4: pt_tensor = pt_tensor.transpose(2, 3, 1, 0) @@ -147,32 +152,39 @@ def create_flax_params_from_pytorch_state(pt_state_dict, flax_state_dict, is_lor if is_lora: if "lora.up" in renamed_pt_key: rank = pt_tensor.shape[1] - if "processor" in flax_key_list: flax_key_list.remove("processor") if "unet" in flax_key_list: flax_key_list.remove("unet") - flax_key = tuple(flax_key_list) - - if flax_key in flax_state_dict: - if flax_tensor.shape != flax_state_dict[flax_key].shape: - raise ValueError( - f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape " - f"{flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}." 
- ) - # also add unexpected weight so that warning is thrown - flax_state_dict[flax_key] = jnp.asarray(flax_tensor) + unet_state_dict[tuple(flax_key_list)] = jnp.asarray(flax_tensor) + if "text_encoder" in flax_key_list: + flax_key_list.remove("text_encoder") + text_encoder_state_dict[tuple(flax_key_list)] = jnp.asarray(flax_tensor) + if "text_encoder_2" in flax_key_list: + flax_key_list.remove("text_encoder_2") + text_encoder_2_state_dict[tuple(flax_key_list)] = jnp.asarray(flax_tensor) - return flax_state_dict, rank + return unet_state_dict, text_encoder_state_dict, text_encoder_2_state_dict, rank -def convert_lora_pytorch_state_dict_to_flax(pt_state_dict, unet_params): +def convert_lora_pytorch_state_dict_to_flax(pt_state_dict, params): # Step 1: Convert pytorch tensor to numpy # sometimes we load weights in bf16 and numpy doesn't support it pt_state_dict = {k: v.float().numpy() for k, v in pt_state_dict.items()} - unet_params = flatten_dict(unfreeze(unet_params)) - flax_state_dict, rank = create_flax_params_from_pytorch_state(pt_state_dict, unet_params, is_lora=True) - return freeze(unflatten_dict(flax_state_dict)), rank + unet_params = flatten_dict(unfreeze(params["unet"])) + text_encoder_params = flatten_dict(unfreeze(params["text_encoder"])) + if "text_encoder_2" in params.keys(): + text_encoder_2_params = flatten_dict(unfreeze(params["text_encoder_2"])) + else: + text_encoder_2_params = None + unet_state_dict, text_encoder_state_dict, text_encoder_2_state_dict, rank = create_flax_params_from_pytorch_state( + pt_state_dict, unet_params, text_encoder_params, text_encoder_2_params, is_lora=True) + params["unet"] = unflatten_dict(unet_state_dict) + params["text_encoder"] = unflatten_dict(text_encoder_state_dict) + if text_encoder_2_state_dict is not None: + params["text_encoder_2"] = unflatten_dict(text_encoder_2_state_dict) + + return params, rank def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model, init_key=42): # Step 1: Convert pytorch tensor to numpy From df117146ad009ea3fe95b3a9f13decc7ea420916 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Mon, 21 Oct 2024 19:10:54 +0000 Subject: [PATCH 08/12] load network_alphas from lora ckpt. --- src/maxdiffusion/loaders/lora_pipeline.py | 84 ++++++------------- src/maxdiffusion/models/attention_flax.py | 1 - .../models/modeling_flax_pytorch_utils.py | 52 ++++++++++-- 3 files changed, 69 insertions(+), 68 deletions(-) diff --git a/src/maxdiffusion/loaders/lora_pipeline.py b/src/maxdiffusion/loaders/lora_pipeline.py index 18acf3c54..567be491a 100644 --- a/src/maxdiffusion/loaders/lora_pipeline.py +++ b/src/maxdiffusion/loaders/lora_pipeline.py @@ -31,52 +31,6 @@ LORA_WEIGHT_NAME = "pytorch_lora_weights.bin" LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors" -# class StableDiffusionXLLoraLoaderMixin(LoRABaseMixin): -# r""" -# Load LoRA layers into Stable Diffusion XL -# """ - -# _lora_loadable_modules = ["unet", "text_encoder", "text_encoder_2"] -# unet_name = UNET_NAME -# text_encoder_name = TEXT_ENCODER_NAME - -# def load_lora_weights( -# self, -# pretrained_model_name_or_path_or_dict: Union[str, Dict[str, jnp.ndarray]], -# adapter_name = None, -# **kwargs,): -# """ -# Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and -# `self.text_encoder`. - -# All kwargs are forwarded to `self.lora_state_dict`. - -# See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is -# loaded. 
- -# See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details on how the state dict is -# loaded into `self.unet`. - -# See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder`] for more details on how the state -# dict is loaded into `self.text_encoder`. - -# Parameters: -# pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): -# See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. -# kwargs (`dict`, *optional*): -# See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. -# adapter_name (`str`, *optional*): -# Adapter name to be used for referencing the loaded adapter model. If not specified, it will use -# `default_{i}` where i is the total number of adapters being loaded. -# """ - -# # if a dict is passed, copy it instead of modifying it inplace -# if isinstance(pretrained_model_name_or_path_or_dict, dict): -# pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() - -# self. - - class StableDiffusionLoraLoaderMixin(LoRABaseMixin): r""" Load LoRA layers into Stable Diffusion [`UNet2DConditionModel`] and @@ -126,22 +80,24 @@ def load_lora_weights( if not is_correct_format: raise ValueError("Invalid LoRA checkpoint.") - params, rank = self.load_lora_into_unet( + params, rank, network_alphas = self.load_lora_into_unet( state_dict, network_alphas=network_alphas, params=params, adapter_name=adapter_name, _pipeline=self, ) - return params, rank + return params, rank, network_alphas @classmethod - def _get_lora_layer(cls, module_path, module, rank): + def _get_lora_layer(cls, module_path, module, rank, network_alphas): is_conv = any('conv' in str_ for str_ in module_path) + network_alpha = network_alphas.get(module_path, None) if is_conv: lora_module = LoRAConv2DLayer( out_features=module.features, rank=rank, + network_alpha=network_alpha, kernel_size=module.kernel_size, strides=module.strides, padding=module.padding, @@ -157,6 +113,7 @@ def _get_lora_layer(cls, module_path, module, rank): lora_module = LoRALinearLayer( out_features=module.features, rank=rank, + network_alpha=network_alpha, dtype=module.dtype, weights_dtype=module.param_dtype, precision=module.precision, @@ -164,19 +121,26 @@ def _get_lora_layer(cls, module_path, module, rank): ) return lora_module + def rename_for_interceptor(params_keys, network_alphas): + new_params_keys = [] + for layer_lora in params_keys: + if 'lora' in layer_lora: + new_layer_lora = layer_lora[:layer_lora.index('lora')] + if new_layer_lora not in new_params_keys: + new_params_keys.append(new_layer_lora) + network_alpha = network_alphas[layer_lora] + del network_alphas[layer_lora] + network_alphas[new_layer_lora] = network_alpha + return new_params_keys, network_alphas + @classmethod def make_lora_interceptor( cls, params_keys, - rank + rank, + network_alphas ): - tmp = [] - for layer_lora in params_keys: - if 'lora' in layer_lora: - new_layer_lora = layer_lora[:layer_lora.index('lora')] - if new_layer_lora not in tmp: - tmp.append(new_layer_lora) - params_keys = tmp + params_keys, network_alphas = cls.rename_for_interceptor(params_keys, network_alphas) def _intercept(next_fn, args, kwargs, context): mod = context.module while mod is not None: @@ -187,7 +151,7 @@ def _intercept(next_fn, args, kwargs, context): if context.method_name == '__call__': module_path = context.module.path if module_path in params_keys: - lora_layer = cls._get_lora_layer(module_path, context.module, rank) + lora_layer = cls._get_lora_layer(module_path, 
context.module, rank, network_alphas) return lora_layer(h, *args, **kwargs) return h return _intercept @@ -331,5 +295,5 @@ def load_lora_into_unet( `default_{i}` where i is the total number of adapters being loaded. """ # Load the layers corresponding to Unet. - params, rank = convert_lora_pytorch_state_dict_to_flax(state_dict, params) - return params, rank \ No newline at end of file + params, rank, network_alphas = convert_lora_pytorch_state_dict_to_flax(state_dict, params, network_alphas) + return params, rank, network_alphas \ No newline at end of file diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index bbe6b8824..aed95cf6b 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -23,7 +23,6 @@ from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel from .. import common_types, max_logging -from .lora import LoRALinearLayer Array = common_types.Array Mesh = common_types.Mesh diff --git a/src/maxdiffusion/models/modeling_flax_pytorch_utils.py b/src/maxdiffusion/models/modeling_flax_pytorch_utils.py index f3756972b..0e237341b 100644 --- a/src/maxdiffusion/models/modeling_flax_pytorch_utils.py +++ b/src/maxdiffusion/models/modeling_flax_pytorch_utils.py @@ -108,16 +108,38 @@ def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dic return pt_tuple_key, pt_tensor +def get_network_alpha_value(pt_key, network_alphas): + network_alpha_value = -1 + network_alpha_key = tuple(pt_key.split(".")) + for item in network_alpha_key: + # alpha names for LoRA follow different convention for qkv values. + # Ex: + # conv layer - unet.down_blocks.0.downsamplers.0.conv.alpha + # to_k_lora - unet.down_blocks.1.attentions.0.transformer_blocks.1.attn1.processor.to_k_lora.down.weight.alpha + if "lora" == item: + network_alpha_key = network_alpha_key[:network_alpha_key.index(item)] + ("alpha",) + break + elif "lora" in item: + network_alpha_key = network_alpha_key + ("alpha",) + break + network_alpha_key = ".".join(network_alpha_key) + if network_alpha_key in network_alphas: + network_alpha_value = network_alphas[network_alpha_key] + return network_alpha_value + def create_flax_params_from_pytorch_state( pt_state_dict, unet_state_dict, text_encoder_state_dict, text_encoder_2_state_dict, + network_alphas, is_lora=False ): rank = None + renamed_network_alphas = {} # Need to change some parameters name to match Flax names for pt_key, pt_tensor in pt_state_dict.items(): + network_alpha_value = get_network_alpha_value(pt_key, network_alphas) renamed_pt_key = rename_key(pt_key) pt_tuple_key = tuple(renamed_pt_key.split(".")) # conv @@ -150,23 +172,28 @@ def create_flax_params_from_pytorch_state( flax_tensor = pt_tensor.T if is_lora: - if "lora.up" in renamed_pt_key: - rank = pt_tensor.shape[1] + if "lora.up" in renamed_pt_key: + rank = pt_tensor.shape[1] if "processor" in flax_key_list: flax_key_list.remove("processor") if "unet" in flax_key_list: flax_key_list.remove("unet") unet_state_dict[tuple(flax_key_list)] = jnp.asarray(flax_tensor) + if "text_encoder" in flax_key_list: flax_key_list.remove("text_encoder") text_encoder_state_dict[tuple(flax_key_list)] = jnp.asarray(flax_tensor) + if "text_encoder_2" in flax_key_list: flax_key_list.remove("text_encoder_2") text_encoder_2_state_dict[tuple(flax_key_list)] = jnp.asarray(flax_tensor) + + if network_alpha_value >= 0: + renamed_network_alphas[tuple(flax_key_list)] = network_alpha_value - return unet_state_dict, 
text_encoder_state_dict, text_encoder_2_state_dict, rank + return unet_state_dict, text_encoder_state_dict, text_encoder_2_state_dict, rank, renamed_network_alphas -def convert_lora_pytorch_state_dict_to_flax(pt_state_dict, params): +def convert_lora_pytorch_state_dict_to_flax(pt_state_dict, params, network_alphas): # Step 1: Convert pytorch tensor to numpy # sometimes we load weights in bf16 and numpy doesn't support it pt_state_dict = {k: v.float().numpy() for k, v in pt_state_dict.items()} @@ -177,14 +204,25 @@ def convert_lora_pytorch_state_dict_to_flax(pt_state_dict, params): text_encoder_2_params = flatten_dict(unfreeze(params["text_encoder_2"])) else: text_encoder_2_params = None - unet_state_dict, text_encoder_state_dict, text_encoder_2_state_dict, rank = create_flax_params_from_pytorch_state( - pt_state_dict, unet_params, text_encoder_params, text_encoder_2_params, is_lora=True) + (unet_state_dict, + text_encoder_state_dict, + text_encoder_2_state_dict, + rank, + network_alphas + ) = create_flax_params_from_pytorch_state( + pt_state_dict, + unet_params, + text_encoder_params, + text_encoder_2_params, + network_alphas, + is_lora=True + ) params["unet"] = unflatten_dict(unet_state_dict) params["text_encoder"] = unflatten_dict(text_encoder_state_dict) if text_encoder_2_state_dict is not None: params["text_encoder_2"] = unflatten_dict(text_encoder_2_state_dict) - return params, rank + return params, rank, network_alphas def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model, init_key=42): # Step 1: Convert pytorch tensor to numpy From 1aae85df3c1b037ce354203a0f1a486ad42a8458 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Mon, 21 Oct 2024 23:38:18 +0000 Subject: [PATCH 09/12] load lora from config. --- src/maxdiffusion/configs/base_xl.yml | 18 +++++ .../configs/base_xl_lightning.yml | 18 +++++ src/maxdiffusion/loaders/lora_pipeline.py | 15 ++-- src/maxdiffusion/models/lora.py | 70 ++----------------- 4 files changed, 49 insertions(+), 72 deletions(-) diff --git a/src/maxdiffusion/configs/base_xl.yml b/src/maxdiffusion/configs/base_xl.yml index af8003a16..188f6e8cc 100644 --- a/src/maxdiffusion/configs/base_xl.yml +++ b/src/maxdiffusion/configs/base_xl.yml @@ -206,6 +206,24 @@ lightning_repo: "" # Empty or "sdxl_lightning_4step_unet.safetensors" to enable lightning. lightning_ckpt: "" +# LoRA parameters +# Values are lists to support multiple LoRA loading during inference in the future. +lora_config: { + lora_model_name_or_path: [], + weight_name: [], + adapter_name: [], + scale: [], + from_pt: [] +} +# Ex with values: +# lora_config : { +# lora_model_name_or_path: ["ByteDance/Hyper-SD"], +# weight_name: ["Hyper-SDXL-2steps-lora.safetensors"], +# adapter_name: ["hyper-sdxl"], +# scale: [0.7], +# from_pt: [True] +# } + enable_mllog: False #controlnet diff --git a/src/maxdiffusion/configs/base_xl_lightning.yml b/src/maxdiffusion/configs/base_xl_lightning.yml index f8db2f6e5..2275b375d 100644 --- a/src/maxdiffusion/configs/base_xl_lightning.yml +++ b/src/maxdiffusion/configs/base_xl_lightning.yml @@ -162,4 +162,22 @@ lightning_from_pt: True lightning_repo: "ByteDance/SDXL-Lightning" lightning_ckpt: "sdxl_lightning_4step_unet.safetensors" +# LoRA parameters +# Values are lists to support multiple LoRA loading during inference in the future. 
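To make the alpha-key convention in `get_network_alpha_value` above concrete, here is a small self-contained mirror of its lookup. `alpha_key_for` is an illustrative name; the real logic lives in `modeling_flax_pytorch_utils.py`, and the example keys come from the comment in that patch:

```python
# Mirror of get_network_alpha_value's key surgery (illustrative only).
def alpha_key_for(pt_key: str) -> str:
    parts = tuple(pt_key.split("."))
    for item in parts:
        if item == "lora":    # bare "lora" segment, e.g. ...conv.lora.down.weight
            parts = parts[:parts.index(item)] + ("alpha",)
            break
        elif "lora" in item:  # fused segment, e.g. ...to_k_lora.down.weight
            parts = parts + ("alpha",)
            break
    return ".".join(parts)

assert alpha_key_for(
    "unet.down_blocks.0.downsamplers.0.conv.lora.down.weight"
) == "unet.down_blocks.0.downsamplers.0.conv.alpha"
assert alpha_key_for(
    "unet.down_blocks.1.attentions.0.transformer_blocks.1.attn1.processor.to_k_lora.down.weight"
) == "unet.down_blocks.1.attentions.0.transformer_blocks.1.attn1.processor.to_k_lora.down.weight.alpha"
```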
+lora_config: { + lora_model_name_or_path: [], + weight_name: [], + adapter_name: [], + scale: [], + from_pt: [] +} +# Ex with values: +# lora_config : { +# lora_model_name_or_path: ["ByteDance/Hyper-SD"], +# weight_name: ["Hyper-SDXL-2steps-lora.safetensors"], +# adapter_name: ["hyper-sdxl"], +# scale: [0.7], +# from_pt: [True] +# } + enable_mllog: False diff --git a/src/maxdiffusion/loaders/lora_pipeline.py b/src/maxdiffusion/loaders/lora_pipeline.py index 567be491a..3f22fa43b 100644 --- a/src/maxdiffusion/loaders/lora_pipeline.py +++ b/src/maxdiffusion/loaders/lora_pipeline.py @@ -13,8 +13,8 @@ # limitations under the License. from typing import Union, Dict +import flax import jax.numpy as jnp -from flax.core.frozen_dict import unfreeze from .lora_base import LoRABaseMixin from ..models.lora import LoRALinearLayer, LoRAConv2DLayer, BaseLoRALayer from .lora_conversion_utils import ( @@ -80,7 +80,7 @@ def load_lora_weights( if not is_correct_format: raise ValueError("Invalid LoRA checkpoint.") - params, rank, network_alphas = self.load_lora_into_unet( + params, rank, network_alphas = self.load_lora( state_dict, network_alphas=network_alphas, params=params, @@ -136,11 +136,14 @@ def rename_for_interceptor(params_keys, network_alphas): @classmethod def make_lora_interceptor( cls, - params_keys, + params, rank, network_alphas ): - params_keys, network_alphas = cls.rename_for_interceptor(params_keys, network_alphas) + # Only unet interceptor supported for now. + unet_lora_keys = flax.traverse_util.flatten_dict(params["unet"]).keys() + unet_lora_keys, network_alphas = cls.rename_for_interceptor(unet_lora_keys, network_alphas) + def _intercept(next_fn, args, kwargs, context): mod = context.module while mod is not None: @@ -150,7 +153,7 @@ def _intercept(next_fn, args, kwargs, context): h = next_fn(*args, **kwargs) if context.method_name == '__call__': module_path = context.module.path - if module_path in params_keys: + if module_path in unet_lora_keys: lora_layer = cls._get_lora_layer(module_path, context.module, rank, network_alphas) return lora_layer(h, *args, **kwargs) return h @@ -268,7 +271,7 @@ def lora_state_dict( return state_dict, network_alphas @classmethod - def load_lora_into_unet( + def load_lora( cls, state_dict, network_alphas, diff --git a/src/maxdiffusion/models/lora.py b/src/maxdiffusion/models/lora.py index e5fa78f3f..3c9c131fb 100644 --- a/src/maxdiffusion/models/lora.py +++ b/src/maxdiffusion/models/lora.py @@ -36,7 +36,7 @@ class LoRALinearLayer(nn.Module, BaseLoRALayer): dtype: jnp.dtype = jnp.float32 weights_dtype: jnp.dtype = jnp.float32 precision: jax.lax.Precision = None - axis_names=('embed', 'heads') #default for qkv and to_out, set it otherwise. 
+ lora_scale: float = 1.0 @nn.compact def __call__(self, h, hidden_states): @@ -47,10 +47,6 @@ def __call__(self, h, hidden_states): features=self.rank, use_bias=False, kernel_init=nn.initializers.kaiming_uniform(), - # kernel_init=nn.with_logical_partitioning( - # nn.initializers.kaiming_uniform(), - # self.axis_names - # ), dtype=self.dtype, param_dtype=self.weights_dtype, precision=self.precision, @@ -60,10 +56,6 @@ def __call__(self, h, hidden_states): features=self.out_features, use_bias=False, kernel_init=nn.initializers.zeros_init(), - # kernel_init=nn.with_logical_partitioning( - # nn.initializers.zeros_init(), - # self.axis_names - # ), dtype=self.dtype, param_dtype=self.weights_dtype, precision=self.precision, @@ -72,54 +64,7 @@ def __call__(self, h, hidden_states): if self.network_alpha: up_hidden_states *= self.network_alpha / self.rank - return h + up_hidden_states - -# class LoRAConv2DLayer(nn.Module): -# """ -# Implements LoRA Conv layer -# """ -# out_features: int -# rank: int = 4 -# kernel_size: Union[int, Tuple[int, int]] = (1,1) -# strides: Union[int, Tuple[int, int]] = (1, 1) -# padding: Union[int, Tuple[int, int], str] = 0 -# input_dilation: int = 1 -# kernel_dilation: int = 1 -# feature_group_count: int = 1 -# network_alpha: Optional[float] = None -# dtype: jnp.dtype = jnp.float32 -# weights_dtype: jnp.dtype = jnp.float32 -# precision: jax.lax.Precision = None - -# @nn.compact -# def __call__(self, h, hidden_states): -# breakpoint() -# print("out_features: ", self.out_features) -# print("h.shape: ", h.shape) -# print("hidden_states.shape: ", hidden_states.shape) -# lora_a = self.param('down', nn.initializers.kaiming_uniform(), (hidden_states.shape[-1], self.rank)) -# lora_b = self.param('up', jax.nn.initializers.zeros, (self.rank, self.out_features)) - -# # Compute LoRA contribution -# lora_out = hidden_states @ lora_a @ lora_b -# if self.network_alpha: -# lora_out = lora_out * (self.lora / self.rank) -# print("lora_out: ", lora_out) -# return h + jax.lax.conv_general_dilated( -# lora_out, -# lora_out, -# window_strides=self.strides, -# padding=self.padding, -# dimension_numbers=('NHWC', 'HWIO', 'NHWC'), -# ) - # return h + nn.Conv( - # features=self.out_features, - # kernel_size=self.kernel_size, - # strides=self.strides, - # padding=self.padding, - # dtype=self.dtype, - # param_dtype=self.weights_dtype, - # )(lora_out) + return h + (up_hidden_states * self.lora_scale) class LoRAConv2DLayer(nn.Module, BaseLoRALayer): """ @@ -137,6 +82,7 @@ class LoRAConv2DLayer(nn.Module, BaseLoRALayer): dtype: jnp.dtype = jnp.float32 weights_dtype: jnp.dtype = jnp.float32 precision: jax.lax.Precision = None + lora_scale: float = 1.0 @nn.compact def __call__(self, h, hidden_states): @@ -152,10 +98,6 @@ def __call__(self, h, hidden_states): dtype=self.dtype, param_dtype=self.weights_dtype, kernel_init=nn.initializers.kaiming_uniform(), - # kernel_init=nn.with_logical_partitioning( - # nn.initializers.kaiming_uniform(),, - # ("keep_1", "keep_2", "conv_in", "conv_out") - # ), precision=self.precision, name="down" )(hidden_states) @@ -168,10 +110,6 @@ def __call__(self, h, hidden_states): dtype=self.dtype, param_dtype=self.weights_dtype, kernel_init=nn.initializers.zeros_init(), - # kernel_init=nn.with_logical_partitioning( - # nn.initializers.zeros_init(), - # ("keep_1", "keep_2", "conv_in", "conv_out") - # ), precision=self.precision, name="up" )(down_hidden_states) @@ -179,4 +117,4 @@ def __call__(self, h, hidden_states): if self.network_alpha: up_hidden_states *= self.network_alpha / 
self.rank - return h + up_hidden_states \ No newline at end of file + return h + (up_hidden_states * self.lora_scale) \ No newline at end of file From cf669c473f3c464f35068c01123d9c39cf96a9f7 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Tue, 22 Oct 2024 22:05:04 +0000 Subject: [PATCH 10/12] update generate code to support hypersdxl lora. --- README.md | 14 +++- src/maxdiffusion/configs/base_xl.yml | 7 +- .../configs/base_xl_lightning.yml | 4 +- src/maxdiffusion/generate_sdxl.py | 61 ++++++++++-------- src/maxdiffusion/maxdiffusion_utils.py | 19 +++++- src/maxdiffusion/models/modeling_utils.py | 6 -- .../scheduling_euler_discrete_flax.py | 1 - .../tests/images/test_lightning.png | Bin 1704573 -> 1343016 bytes 8 files changed, 68 insertions(+), 44 deletions(-) diff --git a/README.md b/README.md index 84be41de4..e4f6a021e 100644 --- a/README.md +++ b/README.md @@ -17,8 +17,8 @@ [![Unit Tests](https://github.com/google/maxtext/actions/workflows/UnitTests.yml/badge.svg)](https://github.com/google/maxdiffusion/actions/workflows/UnitTests.yml) # What's new? - -- **`2024/8/1`**: Orbax is the new default checkpointer for Stable Diffusion 1.X, 2.x. You can still use `pipeline.save_pretrained` after training to save in diffusers format. +- **`2024/10/22`**: LoRA support for Hyper SDXL. +- **`2024/8/1`**: Orbax is the new default checkpointer. You can still use `pipeline.save_pretrained` after training to save in diffusers format. - **`2024/7/20`**: Dreambooth training for Stable Diffusion 1.x,2.x is now supported. # Overview @@ -32,6 +32,7 @@ MaxDiffusion supports * Stable Diffusion 2.1 (training and inference) * Stable Diffusion XL (training and inference). * Stable Diffusion Lightning (inference). +* Hyper-SD XL LoRA loading (inference). * ControlNet inference (Stable Diffusion 1.4 & SDXL). * Dreambooth training support for Stable Diffusion 1.x,2.x. @@ -43,6 +44,7 @@ MaxDiffusion supports * [Training](#training) * [Dreambooth](#dreambooth) * [Inference](#inference) + * [Hyper-SD XL LoRA](#hyper-sdxl-lora) * [SDXL Lightning](#sdxl-lightning) * [ControlNet](#controlnet) * [Comparison To Alternatives](#comparison-to-alternatives) @@ -129,6 +131,14 @@ To generate images, run the following command: python -m src.maxdiffusion.generate src/maxdiffusion/configs/base21.yml run_name="my_run" ``` + ## Hyper SDXL LoRA + + Supports Hyper-SDXL models from [ByteDance](https://huggingface.co/ByteDance/Hyper-SD) + + ```bash + python src/maxdiffusion/generate_sdxl.py src/maxdiffusion/configs/base_xl.yml run_name="test-lora" output_dir=/tmp/ jax_cache_dir=/tmp/cache_dir/ num_inference_steps=2 do_classifier_free_guidance=False prompt="a photograph of a cat wearing a hat riding a skateboard in a park." per_device_batch_size=1 pretrained_model_name_or_path="Lykon/AAM_XL_AnimeMix" from_pt=True revision=main diffusion_scheduler_config='{"_class_name" : "FlaxDDIMScheduler", "timestep_spacing" : "trailing"}' lora_config='{"lora_model_name_or_path" : ["ByteDance/Hyper-SD"], "weight_name" : ["Hyper-SDXL-2steps-lora.safetensors"], "adapter_name" : ["hyper-sdxl"], "scale": [0.7], "from_pt": ["true"]}' + ``` + ## SDXL Lightning Single and Multi host inference is supported with sharding annotations: diff --git a/src/maxdiffusion/configs/base_xl.yml b/src/maxdiffusion/configs/base_xl.yml index 188f6e8cc..ca8fb8ee9 100644 --- a/src/maxdiffusion/configs/base_xl.yml +++ b/src/maxdiffusion/configs/base_xl.yml @@ -75,11 +75,10 @@ timestep_bias: { # Override parameters from checkpoints's scheduler. 
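The `diffusion_scheduler_config` mapping that follows overrides individual fields of the scheduler config stored in the checkpoint; per the comments in these configs, an empty value means "use the checkpoint scheduler's default". A minimal sketch of that merge rule, using a hypothetical helper (not the repo's `create_scheduler`):

```python
# Hypothetical helper showing the override rule for diffusion_scheduler_config:
# non-empty values win, empty ones keep the checkpoint's setting.
def apply_scheduler_overrides(ckpt_config: dict, overrides: dict) -> dict:
    merged = dict(ckpt_config)
    for key, value in overrides.items():
        if value not in ("", None):
            merged[key] = value
    return merged

merged = apply_scheduler_overrides(
    {"_class_name": "FlaxDDPMScheduler", "prediction_type": "epsilon"},
    {"_class_name": "FlaxEulerDiscreteScheduler", "timestep_spacing": "trailing"},
)
assert merged["_class_name"] == "FlaxEulerDiscreteScheduler"
assert merged["prediction_type"] == "epsilon"  # untouched checkpoint default
```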
diffusion_scheduler_config: { - _class_name: '', - # values are v_prediction or leave empty to use scheduler's default. - prediction_type: '', + _class_name: 'FlaxEulerDiscreteScheduler', + prediction_type: 'epsilon', rescale_zero_terminal_snr: False, - timestep_spacing: '' + timestep_spacing: 'trailing' } # Output directory diff --git a/src/maxdiffusion/configs/base_xl_lightning.yml b/src/maxdiffusion/configs/base_xl_lightning.yml index 2275b375d..9ef228e9a 100644 --- a/src/maxdiffusion/configs/base_xl_lightning.yml +++ b/src/maxdiffusion/configs/base_xl_lightning.yml @@ -56,7 +56,7 @@ text_encoder_learning_rate: 4.25e-6 diffusion_scheduler_config: { _class_name: 'DDIMScheduler', # values are v_prediction or leave empty to use scheduler's default. - prediction_type: '', + prediction_type: 'epsilon', rescale_zero_terminal_snr: False, timestep_spacing: 'trailing' } @@ -153,7 +153,7 @@ profiler_steps: 5 prompt: "portrait photo of muscular bearded guy in a worn mech suit, light bokeh, intricate, steel metal, elegant, sharp focus, soft lighting, vibrant colors" negative_prompt: "purple, red" do_classifier_free_guidance: False -guidance_scale: 2 +guidance_scale: 2.0 guidance_rescale: -1 num_inference_steps: 4 diff --git a/src/maxdiffusion/generate_sdxl.py b/src/maxdiffusion/generate_sdxl.py index 271d02ed1..ac91973dd 100644 --- a/src/maxdiffusion/generate_sdxl.py +++ b/src/maxdiffusion/generate_sdxl.py @@ -23,16 +23,18 @@ import jax import jax.numpy as jnp from jax.sharding import PartitionSpec as P +import flax.linen as nn from flax.linen import partitioning as nn_partitioning -from maxdiffusion import ( - FlaxEulerDiscreteScheduler, -) - - from maxdiffusion import pyconfig, max_utils from maxdiffusion.image_processor import VaeImageProcessor -from maxdiffusion.maxdiffusion_utils import (get_add_time_ids, rescale_noise_cfg, load_sdxllightning_unet) +from maxdiffusion.maxdiffusion_utils import ( + get_add_time_ids, + rescale_noise_cfg, + load_sdxllightning_unet, + maybe_load_lora, + create_scheduler +) from maxdiffusion.trainers.sdxl_trainer import (StableDiffusionXLTrainer) @@ -82,8 +84,11 @@ def apply_classifier_free_guidance(noise_pred, guidance_scale): lambda _: noise_pred, operand=None, ) - - latents, scheduler_state = pipeline.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() + latents, scheduler_state = pipeline.scheduler.step( + scheduler_state, + noise_pred, + t, + latents).to_tuple() return latents, scheduler_state, state @@ -217,6 +222,8 @@ def run(config): checkpoint_loader = GenerateSDXL(config) pipeline, params = checkpoint_loader.load_checkpoint() + noise_scheduler, noise_scheduler_state = create_scheduler(pipeline.scheduler.config, config) + weights_init_fn = functools.partial(pipeline.unet.init_weights, rng=checkpoint_loader.rng) unboxed_abstract_state, _, _ = max_utils.get_abstract_state( pipeline.unet, None, config, checkpoint_loader.mesh, weights_init_fn, False @@ -227,21 +234,25 @@ def run(config): ) if unet_params: params["unet"] = unet_params + + # maybe load lora and create interceptor + params, lora_interceptor = maybe_load_lora(config, pipeline, params) if config.lightning_repo: pipeline, params = load_sdxllightning_unet(config, pipeline, params) - # Don't restore the train state to save memory, just restore params + # Don't restore the full train state, instead, just restore params # and create an inference state. 
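The `nn.intercept_methods` wrappers introduced below rely on Flax's method-interception API: while the context manager is active, every module method call is routed through the interceptor, which may rewrite the output. That is how the LoRA deltas are injected without editing the UNet definition; the interceptor built by `make_lora_interceptor` matches modules by `context.module.path` against the truncated LoRA parameter keys. A toy, self-contained illustration of the mechanism (independent of maxdiffusion):

```python
# Toy flax.linen.intercept_methods demo: double every nn.Dense output, using
# the same hook point where the LoRA interceptor adds its low-rank delta.
import jax
import jax.numpy as jnp
import flax.linen as nn

class Tiny(nn.Module):
  @nn.compact
  def __call__(self, x):
    return nn.Dense(4)(x)

def double_dense(next_fn, args, kwargs, context):
  out = next_fn(*args, **kwargs)
  if context.method_name == "__call__" and isinstance(context.module, nn.Dense):
    out = out * 2  # post-process here; the LoRA interceptor returns lora_layer(out, ...) instead
  return out

x = jnp.ones((1, 3))
model = Tiny()
params = model.init(jax.random.PRNGKey(0), x)
with nn.intercept_methods(double_dense):
  y = model.apply(params, x)  # Dense outputs are doubled while intercepting
```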
- unet_state, unet_state_shardings = max_utils.setup_initial_state( - model=pipeline.unet, - tx=None, - config=config, - mesh=checkpoint_loader.mesh, - weights_init_fn=weights_init_fn, - model_params=params.get("unet", None), - training=False, - ) + with nn.intercept_methods(lora_interceptor): + unet_state, unet_state_shardings = max_utils.setup_initial_state( + model=pipeline.unet, + tx=None, + config=config, + mesh=checkpoint_loader.mesh, + weights_init_fn=weights_init_fn, + model_params=params.get("unet", None), + training=False, + ) vae_state, vae_state_shardings = checkpoint_loader.create_vae_state( pipeline, params, checkpoint_item_name="vae_state", is_training=False @@ -267,14 +278,6 @@ def run(config): states["text_encoder_state"] = text_encoder_state states["text_encoder_2_state"] = text_encoder_2_state - noise_scheduler, noise_scheduler_state = FlaxEulerDiscreteScheduler.from_pretrained( - config.pretrained_model_name_or_path, - revision=config.revision, - subfolder="scheduler", - dtype=jnp.float32, - timestep_spacing="trailing", - ) - pipeline.scheduler = noise_scheduler params["scheduler"] = noise_scheduler_state @@ -293,10 +296,12 @@ def run(config): ) s = time.time() - p_run_inference(states).block_until_ready() + with nn.intercept_methods(lora_interceptor): + p_run_inference(states).block_until_ready() print("compile time: ", (time.time() - s)) s = time.time() - images = p_run_inference(states).block_until_ready() + with nn.intercept_methods(lora_interceptor): + images = p_run_inference(states).block_until_ready() print("inference time: ", (time.time() - s)) images = jax.experimental.multihost_utils.process_allgather(images) numpy_images = np.array(images) diff --git a/src/maxdiffusion/maxdiffusion_utils.py b/src/maxdiffusion/maxdiffusion_utils.py index 7a0337d4e..6813e7c1b 100644 --- a/src/maxdiffusion/maxdiffusion_utils.py +++ b/src/maxdiffusion/maxdiffusion_utils.py @@ -28,7 +28,6 @@ from .models.modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax - def load_sdxllightning_unet(config, pipeline, params): """Load lightning""" if not config.lightning_from_pt: @@ -38,6 +37,24 @@ def load_sdxllightning_unet(config, pipeline, params): params["unet"] = flax_unet_dict return pipeline, params +def maybe_load_lora(config, pipeline, params): + + def _noop_interceptor(next_fn, args, kwargs, context): + return next_fn(*args, **kwargs) + lora_config = config.lora_config + interceptor = _noop_interceptor + if len(lora_config["lora_model_name_or_path"]) > 0: + # For now only first lora supported. In the future, they will be merged + # before being loaded. 
+    params, rank, network_alphas = pipeline.load_lora_weights(
+      lora_config["lora_model_name_or_path"][0],
+      weight_name=lora_config["weight_name"][0],
+      params=params,
+      adapter_name=lora_config["adapter_name"][0]
+    )
+    interceptor = pipeline.make_lora_interceptor(params, rank, network_alphas)
+
+  return params, interceptor
 
 def vae_apply(images, sample_rng, vae, vae_params):
   """Apply vae encoder to images."""
diff --git a/src/maxdiffusion/models/modeling_utils.py b/src/maxdiffusion/models/modeling_utils.py
index de42a0e71..6846ce417 100644
--- a/src/maxdiffusion/models/modeling_utils.py
+++ b/src/maxdiffusion/models/modeling_utils.py
@@ -99,12 +99,6 @@ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
   first_tuple = next(gen)
   return first_tuple[1].dtype
 
-# def load_lora_state_dict(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs):
-#   """
-#   Load LoRA
-#   """
-
-
 def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
   """
   Reads a checkpoint file, returning properly formatted errors if they arise.
diff --git a/src/maxdiffusion/schedulers/scheduling_euler_discrete_flax.py b/src/maxdiffusion/schedulers/scheduling_euler_discrete_flax.py
index 5b739a209..bc9013cbf 100644
--- a/src/maxdiffusion/schedulers/scheduling_euler_discrete_flax.py
+++ b/src/maxdiffusion/schedulers/scheduling_euler_discrete_flax.py
@@ -222,7 +222,6 @@ def step(
     step_index = step_index[0]
 
     sigma = state.sigmas[step_index]
-
     # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
     if self.config.prediction_type == "epsilon":
       pred_original_sample = sample - sigma * model_output
diff --git a/src/maxdiffusion/tests/images/test_lightning.png b/src/maxdiffusion/tests/images/test_lightning.png
index 0f50b358e547fb0a6f8494a4a2c663f93c7a798f..36e844cc573bbb6e3f89cf1dfc77ec19cb4f2f09 100644
Binary files a/src/maxdiffusion/tests/images/test_lightning.png and b/src/maxdiffusion/tests/images/test_lightning.png differ