Make sure we also change the config when setting `encoder_hid_dim_type=="text_proj"` and allow xformers (huggingface#3615)

* fix if

* make style

* make style

* add tests for xformers

* make style

* update
patrickvonplaten authored and Jimmy committed Apr 26, 2024
1 parent afce020 commit 225e140
Showing 11 changed files with 141 additions and 25 deletions.
7 changes: 2 additions & 5 deletions examples/community/mixture.py
@@ -215,11 +215,8 @@ def __call__(
raise ValueError(f"`seed_tiles_mode` has to be a string or list of lists but is {type(prompt)}")
if isinstance(seed_tiles_mode, str):
seed_tiles_mode = [[seed_tiles_mode for _ in range(len(row))] for row in prompt]
if any(
mode not in (modes := [mode.value for mode in self.SeedTilesMode])
for row in seed_tiles_mode
for mode in row
):
modes = [mode.value for mode in self.SeedTilesMode]
if any(mode not in modes for row in seed_tiles_mode for mode in row):
raise ValueError(f"Seed tiles mode must be one of {modes}")
if seed_reroll_regions is None:
seed_reroll_regions = []
Binary file added frog.png
97 changes: 89 additions & 8 deletions src/diffusers/models/attention_processor.py
@@ -166,22 +166,28 @@ def set_use_memory_efficient_attention_xformers(
self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
):
is_lora = hasattr(self, "processor") and isinstance(
self.processor, (LoRAAttnProcessor, LoRAXFormersAttnProcessor)
self.processor, (LoRAAttnProcessor, LoRAXFormersAttnProcessor, LoRAAttnAddedKVProcessor)
)
is_custom_diffusion = hasattr(self, "processor") and isinstance(
self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)
)
is_added_kv_processor = hasattr(self, "processor") and isinstance(
self.processor,
(
AttnAddedKVProcessor,
AttnAddedKVProcessor2_0,
SlicedAttnAddedKVProcessor,
XFormersAttnAddedKVProcessor,
LoRAAttnAddedKVProcessor,
),
)

if use_memory_efficient_attention_xformers:
if self.added_kv_proj_dim is not None:
# TODO(Anton, Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
# which uses this type of cross attention ONLY because the attention mask of format
# [0, ..., -10.000, ..., 0, ...,] is not supported
if is_added_kv_processor and (is_lora or is_custom_diffusion):
raise NotImplementedError(
"Memory efficient attention with `xformers` is currently not supported when"
" `self.added_kv_proj_dim` is defined."
f"Memory efficient attention is currently not supported for LoRA or custom diffuson for attention processor type {self.processor}"
)
elif not is_xformers_available():
if not is_xformers_available():
raise ModuleNotFoundError(
(
"Refer to https://github.com/facebookresearch/xformers for more information on how to install"
@@ -233,6 +239,15 @@ def set_use_memory_efficient_attention_xformers(
processor.load_state_dict(self.processor.state_dict())
if hasattr(self.processor, "to_k_custom_diffusion"):
processor.to(self.processor.to_k_custom_diffusion.weight.device)
elif is_added_kv_processor:
# TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
# which uses this type of cross attention ONLY because the attention mask of format
# [0, ..., -10.000, ..., 0, ...,] is not supported
# throw warning
logger.info(
"Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
)
processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
else:
processor = XFormersAttnProcessor(attention_op=attention_op)
else:
@@ -889,6 +904,71 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
return hidden_states


class XFormersAttnAddedKVProcessor:
r"""
Processor for implementing memory efficient attention using xFormers.
Args:
attention_op (`Callable`, *optional*, defaults to `None`):
The base
[operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
operator.
"""

def __init__(self, attention_op: Optional[Callable] = None):
self.attention_op = attention_op

def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
residual = hidden_states
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
batch_size, sequence_length, _ = hidden_states.shape

attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

query = attn.to_q(hidden_states)
query = attn.head_to_batch_dim(query)

encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)

if not attn.only_cross_attention:
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
else:
key = encoder_hidden_states_key_proj
value = encoder_hidden_states_value_proj

hidden_states = xformers.ops.memory_efficient_attention(
query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
)
hidden_states = hidden_states.to(query.dtype)
hidden_states = attn.batch_to_head_dim(hidden_states)

# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)

hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
hidden_states = hidden_states + residual

return hidden_states
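
For context, a minimal usage sketch (not part of the diff; the checkpoint name is only an illustrative example of a pipeline whose UNet uses added-KV cross attention, such as DeepFloyd IF / UnCLIP):

```python
# Illustrative sketch only: with this commit, enabling xformers on an added-KV
# attention UNet swaps in XFormersAttnAddedKVProcessor instead of raising
# NotImplementedError. LoRA / custom-diffusion processors on top of added-KV
# attention still raise, as guarded above.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "DeepFloyd/IF-I-XL-v1.0",  # example added-KV model; requires accepting the license
    torch_dtype=torch.float16,
)
pipe.to("cuda")
pipe.enable_xformers_memory_efficient_attention()  # previously raised NotImplementedError
image = pipe("a photo of a frog", num_inference_steps=25).images[0]
```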


class XFormersAttnProcessor:
r"""
Processor for implementing memory efficient attention using xFormers.
@@ -1428,6 +1508,7 @@ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None,
AttnAddedKVProcessor,
SlicedAttnAddedKVProcessor,
AttnAddedKVProcessor2_0,
XFormersAttnAddedKVProcessor,
LoRAAttnProcessor,
LoRAXFormersAttnProcessor,
LoRAAttnAddedKVProcessor,
1 change: 1 addition & 0 deletions src/diffusers/models/unet_2d_condition.py
@@ -261,6 +261,7 @@ def __init__(

if encoder_hid_dim_type is None and encoder_hid_dim is not None:
encoder_hid_dim_type = "text_proj"
self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")

if encoder_hid_dim is None and encoder_hid_dim_type is not None:
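Why the `register_to_config` call matters, in a rough sketch (toy dimensions, not taken from the diff): before this change, the defaulted `encoder_hid_dim_type` was used at runtime but never written to the model config, so it was dropped when the model was saved and reloaded.

```python
import tempfile
from diffusers import UNet2DConditionModel

# Toy config (dimensions are illustrative): encoder_hid_dim is set while
# encoder_hid_dim_type is left None, so __init__ defaults it to "text_proj".
unet = UNet2DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    block_out_channels=(32, 64),
    layers_per_block=1,
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
    encoder_hid_dim=32,  # triggers the "text_proj" default
)

# With the register_to_config call above, the defaulted value is persisted:
assert unet.config.encoder_hid_dim_type == "text_proj"

with tempfile.TemporaryDirectory() as tmp:
    unet.save_pretrained(tmp)
    reloaded = UNet2DConditionModel.from_pretrained(tmp)
assert reloaded.config.encoder_hid_dim_type == "text_proj"  # no longer lost on reload
```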
@@ -364,6 +364,7 @@ def __init__(

if encoder_hid_dim_type is None and encoder_hid_dim is not None:
encoder_hid_dim_type = "text_proj"
self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")

if encoder_hid_dim is None and encoder_hid_dim_type is not None:
10 changes: 8 additions & 2 deletions tests/pipelines/deepfloyd_if/test_if.py
@@ -28,6 +28,7 @@
IFSuperResolutionPipeline,
)
from diffusers.models.attention_processor import AttnAddedKVProcessor
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import floats_tensor, load_numpy, require_torch_gpu, skip_mps, slow, torch_device

from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
@@ -42,8 +43,6 @@ class IFPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMixin, unittest.T
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}

test_xformers_attention = False

def get_dummy_components(self):
return self._get_dummy_components()

@@ -81,6 +80,13 @@ def test_inference_batch_single_identical(self):
expected_max_diff=1e-2,
)

@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_xformers_attention_forwardGenerator_pass(self):
self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=1e-3)


@slow
@require_torch_gpu
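The helper `_test_xformers_attention_forwardGenerator_pass` comes from the shared pipeline test mixin; roughly (an assumed sketch, not copied from the test suite), it runs the pipeline with and without xformers attention and asserts the outputs stay within `expected_max_diff`:

```python
# Assumed shape of the shared check (illustrative names and structure):
import numpy as np
import torch

def xformers_forward_pass_check(pipe, inputs, expected_max_diff=1e-3):
    pipe.to("cuda")

    torch.manual_seed(0)
    out_default = pipe(**inputs)[0]  # hypothetical: first output is the image batch

    pipe.enable_xformers_memory_efficient_attention()
    torch.manual_seed(0)
    out_xformers = pipe(**inputs)[0]

    max_diff = np.abs(np.asarray(out_default) - np.asarray(out_xformers)).max()
    assert max_diff < expected_max_diff
```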
10 changes: 8 additions & 2 deletions tests/pipelines/deepfloyd_if/test_if_img2img.py
@@ -20,6 +20,7 @@

from diffusers import IFImg2ImgPipeline
from diffusers.utils import floats_tensor
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import skip_mps, torch_device

from ..pipeline_params import (
Expand All @@ -37,8 +38,6 @@ class IFImg2ImgPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMixin, uni
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}

test_xformers_attention = False

def get_dummy_components(self):
return self._get_dummy_components()

@@ -63,6 +62,13 @@ def get_dummy_inputs(self, device, seed=0):
def test_save_load_optional_components(self):
self._test_save_load_optional_components()

@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_xformers_attention_forwardGenerator_pass(self):
self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=1e-3)

@unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA")
def test_save_load_float16(self):
# Due to non-determinism in save load of the hf-internal-testing/tiny-random-t5 text encoder
10 changes: 8 additions & 2 deletions tests/pipelines/deepfloyd_if/test_if_img2img_superresolution.py
@@ -20,6 +20,7 @@

from diffusers import IFImg2ImgSuperResolutionPipeline
from diffusers.utils import floats_tensor
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import skip_mps, torch_device

from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
@@ -34,8 +35,6 @@ class IFImg2ImgSuperResolutionPipelineFastTests(PipelineTesterMixin, IFPipelineT
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS.union({"original_image"})
required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}

test_xformers_attention = False

def get_dummy_components(self):
return self._get_superresolution_dummy_components()

@@ -59,6 +58,13 @@ def get_dummy_inputs(self, device, seed=0):

return inputs

@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_xformers_attention_forwardGenerator_pass(self):
self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=1e-3)

def test_save_load_optional_components(self):
self._test_save_load_optional_components()

10 changes: 8 additions & 2 deletions tests/pipelines/deepfloyd_if/test_if_inpainting.py
@@ -20,6 +20,7 @@

from diffusers import IFInpaintingPipeline
from diffusers.utils import floats_tensor
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import skip_mps, torch_device

from ..pipeline_params import (
@@ -37,8 +38,6 @@ class IFInpaintingPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMixin,
batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS
required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}

test_xformers_attention = False

def get_dummy_components(self):
return self._get_dummy_components()

@@ -62,6 +61,13 @@ def get_dummy_inputs(self, device, seed=0):

return inputs

@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_xformers_attention_forwardGenerator_pass(self):
self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=1e-3)

def test_save_load_optional_components(self):
self._test_save_load_optional_components()

@@ -20,6 +20,7 @@

from diffusers import IFInpaintingSuperResolutionPipeline
from diffusers.utils import floats_tensor
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import skip_mps, torch_device

from ..pipeline_params import (
@@ -37,8 +38,6 @@ class IFInpaintingSuperResolutionPipelineFastTests(PipelineTesterMixin, IFPipeli
batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS.union({"original_image"})
required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}

test_xformers_attention = False

def get_dummy_components(self):
return self._get_superresolution_dummy_components()

@@ -64,6 +63,13 @@ def get_dummy_inputs(self, device, seed=0):

return inputs

@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_xformers_attention_forwardGenerator_pass(self):
self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=1e-3)

def test_save_load_optional_components(self):
self._test_save_load_optional_components()

10 changes: 8 additions & 2 deletions tests/pipelines/deepfloyd_if/test_if_superresolution.py
@@ -20,6 +20,7 @@

from diffusers import IFSuperResolutionPipeline
from diffusers.utils import floats_tensor
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import skip_mps, torch_device

from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
@@ -34,8 +35,6 @@ class IFSuperResolutionPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMi
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}

test_xformers_attention = False

def get_dummy_components(self):
return self._get_superresolution_dummy_components()

@@ -57,6 +56,13 @@ def get_dummy_inputs(self, device, seed=0):

return inputs

@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_xformers_attention_forwardGenerator_pass(self):
self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=1e-3)

def test_save_load_optional_components(self):
self._test_save_load_optional_components()

