Adding support for FlashAttention (flash-attn) optimization #11902

Closed
wants to merge 1 commit
1 change: 1 addition & 0 deletions modules/cmd_args.py
@@ -53,6 +53,7 @@
parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers")
parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
parser.add_argument("--xformers-flash-attention", action='store_true', help="enable xformers with Flash Attention to improve reproducibility (supported for SD2.x or variant only)")
parser.add_argument("--flash-attn", action='store_true', help="use FlashAttention(-2) for cross attention layers (flash_attn package required)")
parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything")
parser.add_argument("--opt-split-attention", action='store_true', help="prefer Doggettx's cross-attention layer optimization for automatic choice of optimization")
parser.add_argument("--opt-sub-quad-attention", action='store_true', help="prefer memory efficient sub-quadratic cross-attention layer optimization for automatic choice of optimization")
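Downstream, this flag surfaces as shared.cmd_opts.flash_attn and gates the flash_attn import (and therefore whether the new optimizer reports itself as available) in the sd_hijack_optimizations.py hunks below. A minimal, self-contained sketch of the argparse behaviour being relied on here (illustrative only, not part of this diff):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--flash-attn", action='store_true',
                    help="use FlashAttention(-2) for cross attention layers (flash_attn package required)")

# store_true defaults to False and flips to True when the flag is passed;
# argparse maps the dashes to an underscore, so it is read back as flash_attn
cmd_opts = parser.parse_args(["--flash-attn"])
assert cmd_opts.flash_attn is True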
86 changes: 86 additions & 0 deletions modules/sd_hijack_optimizations.py
@@ -91,6 +91,20 @@ def apply(self):
        sgm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward


class SdOptimizationFlashAttn(SdOptimization):
    name = "flash_attn"
    label = "use flash_attn package"
    cmd_opt = "flash_attn"
    priority = 70

    def is_available(self):
        return flash_attn_available

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = flash_attn_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = flash_attn_attnblock_forward


class SdOptimizationSubQuad(SdOptimization):
    name = "sub-quadratic"
    cmd_opt = "opt_sub_quad_attention"
@@ -144,6 +158,7 @@ def list_optimizers(res):
        SdOptimizationXformers(),
        SdOptimizationSdpNoMem(),
        SdOptimizationSdp(),
        SdOptimizationFlashAttn(),
        SdOptimizationSubQuad(),
        SdOptimizationV1(),
        SdOptimizationInvokeAI(),
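For context on priority = 70 above: roughly speaking, the webui ranks the registered optimizations that report themselves available and, when the choice is left on automatic, the highest-priority candidate wins. A simplified, hypothetical stand-in for that selection (only flash_attn's priority comes from this PR; the other names and numbers are invented for illustration):

from dataclasses import dataclass

@dataclass
class Opt:
    name: str
    priority: int
    available: bool = True

candidates = [
    Opt("some_higher_priority_opt", 100, available=False),
    Opt("flash_attn", 70),
    Opt("some_low_priority_fallback", 10),
]

# pick the best available candidate, mirroring a priority-ordered automatic choice
best = max((o for o in candidates if o.available), key=lambda o: o.priority)
print(best.name)  # flash_attn, since the higher-priority option is unavailable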
@@ -158,6 +173,16 @@ def list_optimizers(res):
    except Exception:
        errors.report("Cannot import xformers", exc_info=True)

flash_attn_available = False
if shared.cmd_opts.flash_attn:
    try:
        from flash_attn import flash_attn_func
        if shared.cmd_opts.no_half or shared.cmd_opts.upcast_sampling:
            print("Warning: flash_attn only supports fp16 or bf16; it does not help with fp32")
        flash_attn_available = True
    except Exception:
        errors.report("Cannot import flash_attn", exc_info=True)


def get_available_vram():
    if shared.device.type == 'cuda':
@@ -542,6 +567,44 @@ def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None
        return scaled_dot_product_attention_forward(self, x, context, mask)


# provide a fallback when flash_attn does not support the input (e.g. fp32 tensors or an unsupported head size)
def flash_attn_func_compat(q, k, v):
    try:
        return flash_attn_func(q, k, v)
    except RuntimeError:
        # fall back to PyTorch attention, which expects (batch, heads, seqlen, head_dim)
        q, k, v = (rearrange(t, 'b n h d -> b h n d') for t in (q, k, v))
        if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
            out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
        else:
            out = torch.softmax((q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))), dim=-1) @ v
        out = rearrange(out, 'b h n d -> b n h d')
        return out
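For reference, the layout contract behind this wrapper (illustrative, not part of the diff): flash_attn_func consumes (batch, seqlen, heads, head_dim) tensors in fp16/bf16, while torch.nn.functional.scaled_dot_product_attention expects (batch, heads, seqlen, head_dim), which is why the fallback rearranges before and after the call. A small standalone check, assuming a CUDA device and the flash_attn package are present:

import torch
from einops import rearrange
from flash_attn import flash_attn_func

# (batch=2, seqlen=77, heads=8, head_dim=64) in fp16, the layout flash_attn expects
q, k, v = (torch.randn(2, 77, 8, 64, dtype=torch.float16, device="cuda") for _ in range(3))

out_flash = flash_attn_func(q, k, v)  # same shape as q

# the same computation through the fallback path used above
q2, k2, v2 = (rearrange(t, 'b n h d -> b h n d') for t in (q, k, v))
out_sdp = torch.nn.functional.scaled_dot_product_attention(q2, k2, v2)
out_sdp = rearrange(out_sdp, 'b h n d -> b n h d')

torch.testing.assert_close(out_flash, out_sdp, atol=1e-2, rtol=1e-2)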


def flash_attn_attention_forward(self, x, context=None, mask=None):
    h = self.heads
    q_in = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k_in = self.to_k(context_k)
    v_in = self.to_v(context_v)

    q, k, v = (rearrange(t, 'b n (h d) -> b n h d', h=h) for t in (q_in, k_in, v_in))
    del q_in, k_in, v_in

    dtype = q.dtype
    if shared.opts.upcast_attn:
        # flash_attn itself rejects fp32 inputs, so upcasting routes through the fallback in flash_attn_func_compat
        q, k, v = q.float(), k.float(), v.float()

    out = flash_attn_func_compat(q, k, v)

    out = out.to(dtype)

    out = rearrange(out, 'b n h d -> b n (h d)', h=h)
    return self.to_out(out)


def cross_attention_attnblock_forward(self, x):
    h_ = x
    h_ = self.norm(h_)
@@ -651,6 +714,29 @@ def sdp_no_mem_attnblock_forward(self, x):
        return sdp_attnblock_forward(self, x)


def flash_attn_attnblock_forward(self, x):
    h_ = x
    h_ = self.norm(h_)
    q = self.q(h_)
    k = self.k(h_)
    v = self.v(h_)
    b, c, h, w = q.shape
    # flatten the feature map into a single-head sequence of length h*w with head size c
    q, k, v = (rearrange(t, 'b c h w -> b (h w) 1 c') for t in (q, k, v))
    dtype = q.dtype
    if shared.opts.upcast_attn:
        q, k, v = q.float(), k.float(), v.float()
    q = q.contiguous()
    k = k.contiguous()
    v = v.contiguous()

    out = flash_attn_func_compat(q, k, v)

    out = out.to(dtype)
    out = rearrange(out, 'b (h w) 1 c -> b c h w', h=h)
    out = self.proj_out(out)
    return x + out
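A note on this path (illustrative, not part of the diff): the AttnBlock has no multi-head split, so the (b, c, h, w) feature map becomes a single-head sequence of length h*w with head size c. For typical channel counts that head size may exceed what flash_attn supports, in which case flash_attn_func_compat quietly takes the PyTorch fallback above. A quick shape illustration with hypothetical numbers:

import torch
from einops import rearrange

# hypothetical 64x64 feature map with 512 channels
x = torch.randn(1, 512, 64, 64)
q = rearrange(x, 'b c h w -> b (h w) 1 c')
print(q.shape)  # torch.Size([1, 4096, 1, 512]): seqlen 4096, 1 head, head size 512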


def sub_quad_attnblock_forward(self, x):
    h_ = x
    h_ = self.norm(h_)