Skip to content

Commit 203b9b3

Browse files
authored
[FA3] Allow returning LSE via kwarg (#1851)
* lse output
* style
* style
* revert test changes, introduce optional kwarg to output lse
1 parent d0ed097 commit 203b9b3

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

hopper/flash_attn_interface.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,7 @@ def forward(
168168
deterministic=False,
169169
num_heads_q=None,
170170
sm_margin=0,
171+
return_softmax=False,
171172
):
172173
if softmax_scale is None:
173174
softmax_scale = qkv.shape[-1] ** (-0.5)
@@ -210,8 +211,7 @@ def forward(
210211
ctx.deterministic = deterministic
211212
ctx.ndim = qkv.dim()
212213
ctx.sm_margin = sm_margin
213-
# return out, softmax_lse
214-
return out
214+
return (out, softmax_lse) if return_softmax else out
215215

216216
@staticmethod
217217
def backward(ctx, dout, *args):
@@ -270,6 +270,7 @@ def forward(
270270
pack_gqa=None,
271271
deterministic=False,
272272
sm_margin=0,
273+
return_softmax=False,
273274
):
274275
if softmax_scale is None:
275276
softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
@@ -305,7 +306,7 @@ def forward(
305306
ctx.softcap = softcap
306307
ctx.deterministic = deterministic
307308
ctx.sm_margin = sm_margin
308-
return out
309+
return (out, softmax_lse) if return_softmax else out
309310

310311
@staticmethod
311312
def backward(ctx, dout, *args):
@@ -363,6 +364,7 @@ def forward(
363364
pack_gqa=None,
364365
deterministic=False,
365366
sm_margin=0,
367+
return_softmax=False,
366368
):
367369
if softmax_scale is None:
368370
softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
@@ -404,7 +406,7 @@ def forward(
404406
ctx.softcap = softcap
405407
ctx.deterministic = deterministic
406408
ctx.sm_margin = sm_margin
407-
return out
409+
return (out, softmax_lse) if return_softmax else out
408410

409411
@staticmethod
410412
def backward(ctx, dout, *args):
@@ -451,6 +453,7 @@ def flash_attn_qkvpacked_func(
451453
deterministic=False,
452454
num_heads_q=None,
453455
sm_margin=0,
456+
return_attn_probs=False,
454457
):
455458
"""dropout_p should be set to 0.0 during evaluation
456459
If Q, K, V are already stacked into 1 tensor, this function will be faster than
@@ -497,6 +500,7 @@ def flash_attn_qkvpacked_func(
497500
deterministic,
498501
num_heads_q,
499502
sm_margin,
503+
return_attn_probs,
500504
)
501505

502506

@@ -515,6 +519,7 @@ def flash_attn_func(
515519
pack_gqa=None,
516520
deterministic=False,
517521
sm_margin=0,
522+
return_attn_probs=False,
518523
):
519524
"""dropout_p should be set to 0.0 during evaluation
520525
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
@@ -576,6 +581,7 @@ def flash_attn_func(
576581
pack_gqa,
577582
deterministic,
578583
sm_margin,
584+
return_attn_probs,
579585
)
580586

581587

@@ -600,6 +606,7 @@ def flash_attn_varlen_func(
600606
pack_gqa=None,
601607
deterministic=False,
602608
sm_margin=0,
609+
return_attn_probs=False,
603610
):
604611
return FlashAttnVarlenFunc.apply(
605612
q,
@@ -622,6 +629,7 @@ def flash_attn_varlen_func(
622629
pack_gqa,
623630
deterministic,
624631
sm_margin,
632+
return_attn_probs,
625633
)
626634

627635

0 commit comments

Comments (0)