make sure dout is contiguous #1539
Merged
Description
The loss curves for CP>1 and CP=1 do not match with cuDNN fused attention when batch size > 1. This is caused by a non-contiguous dout; this PR fixes it.
cuDNN Fused Attn and Tri Dao's Flash Attn have different requirements on tensor memory format. Tri Dao's Attn only requires the last dim to be contiguous. The TE integration of cuDNN Attn specifies qkv_format (bshd, sbhd, thd), and tensors need to be contiguous in the specified format, which is a stronger requirement than Tri Dao's Attn. This is why the issue only shows up with cuDNN Attn. A minimal sketch of the idea is shown below.
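The sketch below illustrates the fix in spirit: ensure dout is contiguous before handing it to the cuDNN fused-attention backward path. The helper name `_prepare_dout_for_cudnn` is hypothetical and not part of the actual TE API; it only demonstrates the contiguity check under the assumption described above.

```python
import torch

def _prepare_dout_for_cudnn(dout: torch.Tensor, qkv_format: str = "bshd") -> torch.Tensor:
    # Hypothetical helper for illustration only.
    # cuDNN fused attention expects the gradient to be contiguous in the
    # declared qkv_format (bshd / sbhd / thd). Under context parallelism with
    # batch size > 1, dout may arrive from upstream ops (transpose, slice)
    # with only its last dim contiguous, which satisfies Tri Dao's flash
    # attention but not the cuDNN path.
    if not dout.is_contiguous():
        dout = dout.contiguous()
    return dout

# Usage sketch: normalize dout before invoking the fused-attention backward kernel.
# dout = _prepare_dout_for_cudnn(dout, qkv_format="bshd")
```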
Type of change
Checklist: