From b6d3eeabf2ad1f87440f93d5b0b450e7ec86b6d3 Mon Sep 17 00:00:00 2001
From: SiYu Wu
Date: Fri, 21 Jul 2023 02:47:16 +0800
Subject: [PATCH] Add support for FlashAttention(-2) (flash-attn) optimization

---
 modules/cmd_args.py                |  1 +
 modules/sd_hijack_optimizations.py | 86 ++++++++++++++++++++++++++++++
 2 files changed, 87 insertions(+)

diff --git a/modules/cmd_args.py b/modules/cmd_args.py
index e401f6413a4..5523ad07947 100644
--- a/modules/cmd_args.py
+++ b/modules/cmd_args.py
@@ -53,6 +53,7 @@
 parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers")
 parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
 parser.add_argument("--xformers-flash-attention", action='store_true', help="enable xformers with Flash Attention to improve reproducibility (supported for SD2.x or variant only)")
+parser.add_argument("--flash-attn", action='store_true', help="use FlashAttention(-2) for cross attention layers (flash_attn package required)")
 parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything")
 parser.add_argument("--opt-split-attention", action='store_true', help="prefer Doggettx's cross-attention layer optimization for automatic choice of optimization")
 parser.add_argument("--opt-sub-quad-attention", action='store_true', help="prefer memory efficient sub-quadratic cross-attention layer optimization for automatic choice of optimization")
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index b5f85ba519e..732915aa17d 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -91,6 +91,20 @@ def apply(self):
         sgm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward
 
 
+class SdOptimizationFlashAttn(SdOptimization):
+    name = "flash_attn"
+    label = "use flash_attn package"
+    cmd_opt = "flash_attn"
+    priority = 70
+
+    def is_available(self):
+        return flash_attn_available
+
+    def apply(self):
+        ldm.modules.attention.CrossAttention.forward = flash_attn_attention_forward
+        ldm.modules.diffusionmodules.model.AttnBlock.forward = flash_attn_attnblock_forward
+
+
 class SdOptimizationSubQuad(SdOptimization):
     name = "sub-quadratic"
     cmd_opt = "opt_sub_quad_attention"
@@ -144,6 +158,7 @@ def list_optimizers(res):
         SdOptimizationXformers(),
         SdOptimizationSdpNoMem(),
         SdOptimizationSdp(),
+        SdOptimizationFlashAttn(),
         SdOptimizationSubQuad(),
         SdOptimizationV1(),
         SdOptimizationInvokeAI(),
@@ -158,6 +173,16 @@ def list_optimizers(res):
     except Exception:
         errors.report("Cannot import xformers", exc_info=True)
 
+flash_attn_available = False
+if shared.cmd_opts.flash_attn:
+    try:
+        from flash_attn import flash_attn_func
+        if shared.cmd_opts.no_half or shared.cmd_opts.upcast_sampling:
+            print("Warning: flash_attn only supports fp16 and bf16; it will not help with fp32")
+        flash_attn_available = True
+    except Exception:
+        errors.report("Cannot import flash_attn", exc_info=True)
+
 
 def get_available_vram():
     if shared.device.type == 'cuda':
@@ -542,6 +567,44 @@ def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None
     return scaled_dot_product_attention_forward(self, x, context, mask)
 
 
+# fall back to standard attention when flash_attn cannot handle the input (e.g. fp32 or an unsupported head size)
+def flash_attn_func_compat(q, k, v):
+    try:
+        return flash_attn_func(q, k, v)
+    except RuntimeError:
+        q, k, v = (rearrange(t, 'b n h d -> b h n d') for t in (q, k, v))
+        if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
+            out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
+        else:
+            out = torch.softmax((q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))), dim=-1) @ v
+        out = rearrange(out, 'b h n d -> b n h d')
+        return out
+
+
+def flash_attn_attention_forward(self, x, context=None, mask=None):
+    h = self.heads
+    q_in = self.to_q(x)
+    context = default(context, x)
+
+    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
+    k_in = self.to_k(context_k)
+    v_in = self.to_v(context_v)
+
+    q, k, v = (rearrange(t, 'b n (h d) -> b n h d', h=h) for t in (q_in, k_in, v_in))
+    del q_in, k_in, v_in
+
+    dtype = q.dtype
+    if shared.opts.upcast_attn:
+        q, k, v = q.float(), k.float(), v.float()
+
+    out = flash_attn_func_compat(q, k, v)
+
+    out = out.to(dtype)
+
+    out = rearrange(out, 'b n h d -> b n (h d)', h=h)
+    return self.to_out(out)
+
+
 def cross_attention_attnblock_forward(self, x):
     h_ = x
     h_ = self.norm(h_)
@@ -651,6 +714,29 @@ def sdp_no_mem_attnblock_forward(self, x):
     return sdp_attnblock_forward(self, x)
 
 
+def flash_attn_attnblock_forward(self, x):
+    h_ = x
+    h_ = self.norm(h_)
+    q = self.q(h_)
+    k = self.k(h_)
+    v = self.v(h_)
+    b, c, h, w = q.shape
+    q, k, v = (rearrange(t, 'b c h w -> b (h w) 1 c') for t in (q, k, v))
+    dtype = q.dtype
+    if shared.opts.upcast_attn:
+        q, k, v = q.float(), k.float(), v.float()
+    q = q.contiguous()
+    k = k.contiguous()
+    v = v.contiguous()
+
+    out = flash_attn_func_compat(q, k, v)
+
+    out = out.to(dtype)
+    out = rearrange(out, 'b (h w) 1 c -> b c h w', h=h)
+    out = self.proj_out(out)
+    return x + out
+
+
 def sub_quad_attnblock_forward(self, x):
     h_ = x
     h_ = self.norm(h_)
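
Note for reviewers (not part of the patch): flash_attn_func from the flash_attn package takes q, k, v as (batch, seqlen, nheads, headdim) tensors in fp16 or bf16 on CUDA, while torch.nn.functional.scaled_dot_product_attention takes (batch, nheads, seqlen, headdim); that layout difference is what the rearrange calls in flash_attn_func_compat bridge when the fast path raises. The snippet below is a minimal standalone sketch of that contract, assuming a CUDA device with flash_attn installed and using made-up toy sizes:

    # Standalone sketch: the flash_attn fast path and the rearranged SDPA fallback
    # used by flash_attn_func_compat compute the same attention result.
    import torch
    from einops import rearrange
    from flash_attn import flash_attn_func

    b, n, h, d = 2, 77, 8, 64  # toy batch, sequence length, heads, head dim
    q, k, v = (torch.randn(b, n, h, d, device="cuda", dtype=torch.float16) for _ in range(3))

    # Fast path: flash_attn expects (batch, seqlen, nheads, headdim) in fp16/bf16.
    out_flash = flash_attn_func(q, k, v)

    # Fallback path: PyTorch SDPA expects (batch, nheads, seqlen, headdim), so rearrange first.
    q2, k2, v2 = (rearrange(t, 'b n h d -> b h n d') for t in (q, k, v))
    out_sdpa = rearrange(torch.nn.functional.scaled_dot_product_attention(q2, k2, v2),
                         'b h n d -> b n h d')

    # Both paths compute softmax(QK^T / sqrt(d)) V, so they agree up to fp16 rounding.
    print(torch.allclose(out_flash, out_sdpa, atol=1e-2, rtol=1e-2))

With the patch applied, the new path is opted into by launching with --flash-attn while keeping fp16/bf16 weights (i.e. without --no-half); it is also registered as "flash_attn" through list_optimizers(), so it should appear in the cross-attention optimization selection alongside the existing options.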