Remove unnecessary concatenation using the zipping approach. #1768

Closed
wujingyue opened this issue Feb 15, 2024 · 1 comment
wujingyue commented Feb 15, 2024

A spin-off from #1502 (comment). Created for tracking progress.

Problem

Below is a common pattern in nanoGPT's backprop.

# Shapes below use B=16, S=128, H=12, D=64.
dQ, dK, dV = scaled_dot_product_attention_backprop(...)  # each bf16[16,12,128,64]
dQ = transpose(dQ, [0, 2, 1, 3])  # [16, 128, 12, 64]
dQ = reshape(dQ, [16, 128, 768])  # [16, 128, 768]
dK = reshape(transpose(dK, [0, 2, 1, 3]), [16, 128, 768])  # same transpose+reshape as dQ
dV = reshape(transpose(dV, [0, 2, 1, 3]), [16, 128, 768])  # same transpose+reshape as dQ

concatenated = cat([dQ, dK, dV], axis=-1)  # [16, 128, 2304]

dQKV_sum = sum(concatenated, ...)  # omitting a round trip to float
dQKV_view = reshape(concatenated, [B*S, H*D*3])
dQKV_permute = transpose(dQKV_view, [1, 0])

return dQKV_sum, dQKV_view, dQKV_permute

Because nvFuser doesn't take sdpa_backward, it sees three unconnected input tensors (dQ, dK, and dV) and therefore has to materialize dQKV_view and dQKV_permute.
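To see the cost concretely, here is a minimal PyTorch sketch (shapes taken from the example above) of what happens when cat's inputs are unrelated buffers:

import torch

# Three unrelated gradient buffers, as nvFuser sees them.
dQ, dK, dV = (torch.randn(16, 128, 768, dtype=torch.bfloat16) for _ in range(3))

# cat over unrelated buffers must allocate and copy: its output cannot
# alias any of its inputs.
concatenated = torch.cat([dQ, dK, dV], dim=-1)  # [16, 128, 2304]
assert concatenated.data_ptr() not in (dQ.data_ptr(), dK.data_ptr(), dV.data_ptr())

# So the downstream reshape and permute operate on the freshly
# materialized copy rather than aliasing an existing tensor.
dQKV_view = concatenated.reshape(16 * 128, 768 * 3)
dQKV_permute = dQKV_view.permute(1, 0)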

Solution

TL;DR: change Thunder's cudnnex to feed nvFuser a concatenated tensor that contains dQ, dK and dV, so nvFuser realizes that the existing cat is unnecessary and removes it.

  1. cudnnex will convert the SDPA backward op into a cudnn sdpa_backward kernel (which outputs one packed dQKV tensor) followed by a split.
  2. cudnnex will give that split to nvFuser, so nvFuser will see the following pattern:
    dQKV = fd.define_tensor([B, S, H, D*3])
    dQ = fd.ops.slice(dQKV, ...)
    dQ = fd.ops.view(dQ, ...)
    dQ = fd.ops.permute(dQ, ...)
    dK = ...  # the same slice-view-permute pattern
    dV = ...  # the same slice-view-permute pattern

    concatenated = fd.ops.cat([dQ, dK, dV], dim=-1)

    dQKV_sum = fd.ops.sum(concatenated, ...)  # omitting a round trip to float
    dQKV_view = fd.ops.view(concatenated, [B*S, H*D*3])
    dQKV_permute = fd.ops.permute(dQKV_view, [1, 0])
    
  3. nvFuser will cancel the slices and the cat and merge all the views and permutes between them, so the above will become:
    dQKV = fd.define_tensor([B, S, H, D*3])
    concatenated = fd.ops.permute(fd.ops.view(dQKV, ...), ...)
    dQKV_sum = fd.ops.sum(concatenated, ...)  # omitting a round trip to float
    dQKV_view = fd.ops.view(concatenated, [B*S, H*D*3])
    dQKV_permute = fd.ops.permute(dQKV_view, [1, 0])

    As a result, dQKV_view and dQKV_permute will become aliases of dQKV, and the fusion will boil down to a ReduceSum kernel that sums [B,S,H,D*3] to [H*D*3]. (A standalone sketch of this cancellation follows the list.)
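Below is a minimal PyTorch sketch of the cancellation in step 3. The packed layout [B, S, 3, H, D] and the slice positions are assumptions made only for this illustration; the actual slice dimensions are elided above.

import torch

B, S, H, D = 16, 128, 12, 64

# Hypothetical packed layout, assumed only for this illustration.
dQKV = torch.randn(B, S, 3, H, D, dtype=torch.bfloat16)

# The <split, cat> side: slice out dQ/dK/dV, flatten each, and concatenate.
dQ, dK, dV = (dQKV[:, :, i].reshape(B, S, H * D) for i in range(3))
concatenated = torch.cat([dQ, dK, dV], dim=-1)  # [B, S, 3*H*D]

# The zipped side: the same values are already a view of dQKV, so the
# slices and the cat cancel without copying anything.
zipped = dQKV.reshape(B, S, 3 * H * D)
assert torch.equal(concatenated, zipped)

# The remaining outputs are metadata-only ops on dQKV plus one reduction.
dQKV_view = zipped.reshape(B * S, 3 * H * D)
dQKV_permute = dQKV_view.permute(1, 0)
dQKV_sum = zipped.float().sum(dim=(0, 1))  # the ReduceSum kernel: [3*H*D]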
wujingyue added a commit that referenced this issue Feb 15, 2024
For #1768.

ghstack-source-id: 61a20eef5efa80ac2726fe5ca1c059e47ba55d58
Pull Request resolved: #1771
wujingyue added a commit that referenced this issue Feb 15, 2024
For #1768.

ghstack-source-id: 3e24dbf6e795801bb57ae562ea1133678907dceb
Pull Request resolved: #1771
wujingyue added a commit that referenced this issue Feb 16, 2024
For #1768.

`ghstack land https://github.com/NVIDIA/Fuser/pull/1771` failed for
reasons that I don't understand. I'm trying to land it again without
`ghstack`. See #1771 for review comments.
wujingyue added a commit that referenced this issue Feb 20, 2024
With this PR, MoveSplitCatPass can cancel the <split,cat> pair with
`permute`s in between and horizontally merge those `permute`s. See code
comments for details.

For #1768.
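A toy PyTorch illustration of the rewrite this pass performs (shapes are arbitrary; this mimics the transform, not the pass's implementation):

import torch

t = torch.randn(2, 3, 4, 6)

# Before: a split (three slices of t), identical permutes in between,
# and a cat that stitches the pieces back together along the split dim.
parts = [t[..., i:i + 2].permute(0, 2, 1, 3) for i in range(0, 6, 2)]
merged = torch.cat(parts, dim=-1)

# After: the <split, cat> pair cancels, and the permutes horizontally
# merge into a single permute of the unsplit tensor.
assert torch.equal(merged, t.permute(0, 2, 1, 3))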
wujingyue added a commit that referenced this issue Feb 26, 2024
wujingyue added a commit that referenced this issue Feb 27, 2024
wujingyue self-assigned this Mar 1, 2024
wujingyue added a commit that referenced this issue Mar 1, 2024
This makes it convenient to use an IdModel as a class member without having to pass it through many functions.

I examined the NVFUSER_TRACE: FusionKernelRuntime::FusionKernelRuntime is bottlenecked by "Finding valid fusion segment solutions", not by pre-segmenter passes. I added a FUSER_PERF_SCOPE for pre-segmenter passes anyway.

For #1768
wujingyue commented

This optimization has been implemented but is not turned on by default. The nvFuser part is on unconditionally; the Thunder part is behind a flag. However, even with that flag on, the optimization won't kick in, because Thunder gives cat to a special executor by default.
