From ae99ce8d9c9bc27edf60a98422cf64d139d3cf9a Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Tue, 18 Apr 2023 16:53:04 -0600 Subject: [PATCH 1/3] Check for xformers install when using flash attention in diffusion unet --- generative/networks/nets/diffusion_model_unet.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/generative/networks/nets/diffusion_model_unet.py b/generative/networks/nets/diffusion_model_unet.py index aafde6b4..e0361c80 100644 --- a/generative/networks/nets/diffusion_model_unet.py +++ b/generative/networks/nets/diffusion_model_unet.py @@ -1702,6 +1702,9 @@ def __init__( "`num_channels`." ) + if use_flash_attention and not has_xformers: + raise ValueError("use_flash_attention is True but xformers is not installed.") + if use_flash_attention is True and not torch.cuda.is_available(): raise ValueError( "torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU." From 8a8d0ace15d445997217f4b873f041eba5ac96f7 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Wed, 19 Apr 2023 16:28:37 -0600 Subject: [PATCH 2/3] Ensure tensors are contiguous --- generative/networks/nets/diffusion_model_unet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/generative/networks/nets/diffusion_model_unet.py b/generative/networks/nets/diffusion_model_unet.py index e0361c80..85cf5c26 100644 --- a/generative/networks/nets/diffusion_model_unet.py +++ b/generative/networks/nets/diffusion_model_unet.py @@ -334,9 +334,9 @@ def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch x = block(x, context=context) if self.spatial_dims == 2: - x = x.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2) + x = x.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() if self.spatial_dims == 3: - x = x.reshape(batch, height, width, depth, inner_dim).permute(0, 4, 1, 2, 3) + x = x.reshape(batch, height, width, depth, inner_dim).permute(0, 4, 1, 2, 3).contiguous() x = self.proj_out(x) return x + residual From 5bb28f7ab6ce164bd6b2fefa68d72ba35ae4b785 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Wed, 19 Apr 2023 16:29:19 -0600 Subject: [PATCH 3/3] Formatting fix --- generative/networks/nets/diffusion_model_unet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/generative/networks/nets/diffusion_model_unet.py b/generative/networks/nets/diffusion_model_unet.py index 85cf5c26..216a6cc3 100644 --- a/generative/networks/nets/diffusion_model_unet.py +++ b/generative/networks/nets/diffusion_model_unet.py @@ -1704,7 +1704,7 @@ def __init__( if use_flash_attention and not has_xformers: raise ValueError("use_flash_attention is True but xformers is not installed.") - + if use_flash_attention is True and not torch.cuda.is_available(): raise ValueError( "torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU."