[Unity][CUTLASS] Fixed stacked attention offload when QKV reshape uses the same shape expression #14728

Merged
masahi merged 2 commits into apache:unity from masahi:cutlass-stacked-attention-fix on Apr 27, 2023
Conversation

masahi (Member) commented Apr 26, 2023

When a single expression, such as R.shape([2, 4096, 8, 40]), is used by all of the reshape ops for QKV, the generated composite function receives only one shape parameter, as in the example below. Codegen currently does not handle this case properly.

@R.function
def fused_relax_split_relax_reshape_relax_reshape_relax_reshape_relax_nn_attention_cutlass(qkv: R.Tensor((4, 8, 6144), dtype="float32"), param_0: R.Shape([4, 8, 32, 64])) -> R.Tensor((4, 8, 32, 64), dtype="float32"):
    R.func_attr({"Codegen": "cutlass", "global_symbol": "fused_relax_split_relax_reshape_relax_reshape_relax_reshape_relax_nn_attention_cutlass"})
    # from tvm.script import relax as R
    
    @R.function
    def gv_1(qkv_1: R.Tensor((4, 8, 6144), dtype="float32"), param_0_1: R.Shape([4, 8, 32, 64])) -> R.Tensor((4, 8, 32, 64), dtype="float32"):
        R.func_attr({"Composite": "cutlass.stacked_attention", "Primitive": 1})
        with R.dataflow():
            lv: R.Tuple(R.Tensor((4, 8, 2048), dtype="float32"), R.Tensor((4, 8, 2048), dtype="float32"), R.Tensor((4, 8, 2048), dtype="float32")) = R.split(qkv_1, indices_or_sections=[2048, 4096], axis=2)
            lv1: R.Tensor((4, 8, 2048), dtype="float32") = lv[0]
            lv2: R.Tensor((4, 8, 32, 64), dtype="float32") = R.reshape(lv1, param_0_1)
            lv3: R.Tensor((4, 8, 2048), dtype="float32") = lv[1]
            lv4: R.Tensor((4, 8, 32, 64), dtype="float32") = R.reshape(lv3, param_0_1)
            lv5: R.Tensor((4, 8, 2048), dtype="float32") = lv[2]
            lv6: R.Tensor((4, 8, 32, 64), dtype="float32") = R.reshape(lv5, param_0_1)
            gv_2: R.Tensor((4, 8, 32, 64), dtype="float32") = R.nn.attention(lv2, lv4, lv6, scale=None)
            R.output(gv_2)
        return gv_2

    gv1: R.Tensor((4, 8, 32, 64), dtype="float32") = gv_1(qkv, param_0)
    return gv1

Apparently, this is caused by the EliminateCommonSubexpr() pass that I'm using, which turns the original three reshape ops below into ones that share the same R.shape([2, 4096, 8, 40]).

lv_3: R.Tensor((2, 4096, 8, 40), dtype="float32") = R.reshape(lv_2, R.shape([2, 4096, 8, 40]))
lv1_2: R.Tensor((2, 4096, 8, 40), dtype="float32") = R.reshape(lv1_1, R.shape([2, 4096, 8, 40]))
lv2_3: R.Tensor((2, 4096, 8, 40), dtype="float32") = R.reshape(lv2_2, R.shape([2, 4096, 8, 40]))
lv_4: R.Tensor((2, 4096, 8, 40), dtype="float32") = cls.fused_relax_nn_attention_cutlass(lv_3, lv1_2, lv2_3)
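The effect of the pass can be sketched in plain Python. This is a hypothetical toy model, not TVM's actual implementation: the point is that deduplicating structurally equal shape expressions into one shared instance is exactly what leaves the fused composite function with a single shape parameter.

```python
# Toy model (not TVM code) of common-subexpression elimination over shape
# expressions. Structurally equal expressions are deduplicated through a
# hash table, so every reshape ends up referencing the same shared object.

class Shape:
    """A stand-in for an R.shape expression: just a tuple of dims."""
    def __init__(self, dims):
        self.dims = tuple(dims)

    def __eq__(self, other):
        return isinstance(other, Shape) and self.dims == other.dims

    def __hash__(self):
        return hash(self.dims)


def eliminate_common_subexpr(exprs):
    """Replace structurally equal expressions with one shared instance."""
    seen = {}
    return [seen.setdefault(e, e) for e in exprs]


# Three independent shape expressions, as written in the original program.
q_shape = Shape([2, 4096, 8, 40])
k_shape = Shape([2, 4096, 8, 40])
v_shape = Shape([2, 4096, 8, 40])
assert q_shape is not k_shape  # distinct objects before the pass

q2, k2, v2 = eliminate_common_subexpr([q_shape, k_shape, v_shape])
# After the pass all three reshapes share one expression, so pattern-based
# fusion lifts only a single shape parameter into the composite function.
assert q2 is k2 is v2
```

Since the three reshape call sites now point at one expression, the fusion pass has only one free shape value to lift, which is why the composite function above has a single param_0 instead of three shape parameters.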

tvm-bot (Collaborator) commented Apr 26, 2023

Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.

Generated by tvm-bot

cyx-6 (Contributor) left a comment
Thanks for reporting and fixing!

@masahi masahi merged commit 94e3d51 into apache:unity Apr 27, 2023