@@ -168,6 +168,7 @@ def forward(
         deterministic=False,
         num_heads_q=None,
         sm_margin=0,
+        return_softmax=False,
     ):
         if softmax_scale is None:
             softmax_scale = qkv.shape[-1] ** (-0.5)
@@ -210,8 +211,7 @@ def forward(
         ctx.deterministic = deterministic
         ctx.ndim = qkv.dim()
         ctx.sm_margin = sm_margin
-        # return out, softmax_lse
-        return out
+        return (out, softmax_lse) if return_softmax else out
 
     @staticmethod
     def backward(ctx, dout, *args):
@@ -270,6 +270,7 @@ def forward(
         pack_gqa=None,
         deterministic=False,
         sm_margin=0,
+        return_softmax=False,
     ):
         if softmax_scale is None:
             softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
@@ -305,7 +306,7 @@ def forward(
         ctx.softcap = softcap
         ctx.deterministic = deterministic
         ctx.sm_margin = sm_margin
-        return out
+        return (out, softmax_lse) if return_softmax else out
 
     @staticmethod
     def backward(ctx, dout, *args):
@@ -363,6 +364,7 @@ def forward(
         pack_gqa=None,
         deterministic=False,
         sm_margin=0,
+        return_softmax=False,
     ):
         if softmax_scale is None:
             softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
@@ -404,7 +406,7 @@ def forward(
         ctx.softcap = softcap
         ctx.deterministic = deterministic
         ctx.sm_margin = sm_margin
-        return out
+        return (out, softmax_lse) if return_softmax else out
 
     @staticmethod
     def backward(ctx, dout, *args):
@@ -451,6 +453,7 @@ def flash_attn_qkvpacked_func(
     deterministic=False,
     num_heads_q=None,
     sm_margin=0,
+    return_attn_probs=False,
 ):
     """dropout_p should be set to 0.0 during evaluation
     If Q, K, V are already stacked into 1 tensor, this function will be faster than
@@ -497,6 +500,7 @@ def flash_attn_qkvpacked_func(
         deterministic,
         num_heads_q,
         sm_margin,
+        return_attn_probs,
     )
 
 
@@ -515,6 +519,7 @@ def flash_attn_func(
     pack_gqa=None,
     deterministic=False,
     sm_margin=0,
+    return_attn_probs=False,
 ):
     """dropout_p should be set to 0.0 during evaluation
     Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
@@ -576,6 +581,7 @@ def flash_attn_func(
         pack_gqa,
         deterministic,
         sm_margin,
+        return_attn_probs,
     )
 
 
@@ -600,6 +606,7 @@ def flash_attn_varlen_func(
     pack_gqa=None,
     deterministic=False,
     sm_margin=0,
+    return_attn_probs=False,
 ):
     return FlashAttnVarlenFunc.apply(
         q,
@@ -622,6 +629,7 @@ def flash_attn_varlen_func(
         pack_gqa,
         deterministic,
         sm_margin,
+        return_attn_probs,
     )
 
 