diff --git a/csrc/flash_attn/src/fmha_bwd_hdim128.cu b/csrc/flash_attn/src/fmha_bwd_hdim128.cu index d171b3c2d..138dcaafa 100644 --- a/csrc/flash_attn/src/fmha_bwd_hdim128.cu +++ b/csrc/flash_attn/src/fmha_bwd_hdim128.cu @@ -5,7 +5,7 @@ #include "fmha_bwd_launch_template.h" void run_fmha_bwd_hdim128(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure) { - FP16_SWITCH(params.is_bf16, ({ + FP16_SWITCH(params.is_bf16, ([&] { using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u, elem_type>; run_fmha_bwd_loop(params, stream, configure); })); diff --git a/csrc/flash_attn/src/fmha_bwd_hdim32.cu b/csrc/flash_attn/src/fmha_bwd_hdim32.cu index 06c6e4846..a09ebac2b 100644 --- a/csrc/flash_attn/src/fmha_bwd_hdim32.cu +++ b/csrc/flash_attn/src/fmha_bwd_hdim32.cu @@ -5,7 +5,7 @@ #include "fmha_bwd_launch_template.h" void run_fmha_bwd_hdim32(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure) { - FP16_SWITCH(params.is_bf16, ({ + FP16_SWITCH(params.is_bf16, ([&] { if (params.seqlen_k == 128) { using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 8, 0x08u, elem_type>; run_fmha_bwd_loop(params, stream, configure); diff --git a/csrc/flash_attn/src/fmha_bwd_hdim64.cu b/csrc/flash_attn/src/fmha_bwd_hdim64.cu index 7dd8650bf..918b98cab 100644 --- a/csrc/flash_attn/src/fmha_bwd_hdim64.cu +++ b/csrc/flash_attn/src/fmha_bwd_hdim64.cu @@ -5,7 +5,7 @@ #include "fmha_bwd_launch_template.h" void run_fmha_bwd_hdim64(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure) { - FP16_SWITCH(params.is_bf16, ({ + FP16_SWITCH(params.is_bf16, ([&] { auto dprops = at::cuda::getCurrentDeviceProperties(); if (params.seqlen_k == 128) { using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>; run_fmha_bwd_loop(params, stream, configure); diff --git a/csrc/flash_attn/src/fmha_bwd_launch_template.h b/csrc/flash_attn/src/fmha_bwd_launch_template.h index 1e5be0207..032c4a11d 100644 --- a/csrc/flash_attn/src/fmha_bwd_launch_template.h +++ 
b/csrc/flash_attn/src/fmha_bwd_launch_template.h @@ -61,7 +61,7 @@ void run_fmha_bwd_loop(FMHA_dgrad_params &params, cudaStream_t stream, const boo bool is_dropout = params.p_dropout < 1.f; // params.p_dropout is the probability of "keeping" // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH. - BOOL_SWITCH(is_dropout, IsDropoutConst, ({ + BOOL_SWITCH(is_dropout, IsDropoutConst, ([&] { auto kernel = params.is_causal ? &fmha_bwd_dq_dk_dv_loop_kernel : &fmha_bwd_dq_dk_dv_loop_kernel; diff --git a/csrc/flash_attn/src/fmha_fwd_hdim128.cu b/csrc/flash_attn/src/fmha_fwd_hdim128.cu index 8d4477fcb..66532e651 100644 --- a/csrc/flash_attn/src/fmha_fwd_hdim128.cu +++ b/csrc/flash_attn/src/fmha_fwd_hdim128.cu @@ -5,7 +5,7 @@ #include "fmha_fwd_launch_template.h" void run_fmha_fwd_hdim128(Launch_params &launch_params) { - FP16_SWITCH(launch_params.params.is_bf16, ({ + FP16_SWITCH(launch_params.params.is_bf16, ([&] { using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>; run_fmha_fwd_loop(launch_params); })); diff --git a/csrc/flash_attn/src/fmha_fwd_hdim32.cu b/csrc/flash_attn/src/fmha_fwd_hdim32.cu index 5fa48eb17..f569ca5f6 100644 --- a/csrc/flash_attn/src/fmha_fwd_hdim32.cu +++ b/csrc/flash_attn/src/fmha_fwd_hdim32.cu @@ -5,7 +5,7 @@ #include "fmha_fwd_launch_template.h" void run_fmha_fwd_hdim32(Launch_params &launch_params) { - FP16_SWITCH(launch_params.params.is_bf16, ({ + FP16_SWITCH(launch_params.params.is_bf16, ([&] { if (launch_params.params.seqlen_k == 128) { using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 4, 0x08u, elem_type>; run_fmha_fwd_loop(launch_params); diff --git a/csrc/flash_attn/src/fmha_fwd_hdim64.cu b/csrc/flash_attn/src/fmha_fwd_hdim64.cu index 9776c6d17..134efa63b 100644 --- a/csrc/flash_attn/src/fmha_fwd_hdim64.cu +++ b/csrc/flash_attn/src/fmha_fwd_hdim64.cu @@ -5,7 +5,7 @@ #include "fmha_fwd_launch_template.h" void run_fmha_fwd_hdim64(Launch_params &launch_params) { - FP16_SWITCH(launch_params.params.is_bf16, ({ 
+ FP16_SWITCH(launch_params.params.is_bf16, ([&] { if (launch_params.params.seqlen_k == 128) { using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>; run_fmha_fwd_loop(launch_params); diff --git a/csrc/flash_attn/src/fmha_fwd_launch_template.h b/csrc/flash_attn/src/fmha_fwd_launch_template.h index 1b013753e..ec1d3df0a 100644 --- a/csrc/flash_attn/src/fmha_fwd_launch_template.h +++ b/csrc/flash_attn/src/fmha_fwd_launch_template.h @@ -56,7 +56,7 @@ void run_fmha_fwd_loop(Launch_params &launch_params) { // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH. // https://github.com/kokkos/kokkos-kernels/issues/349 // https://github.com/HazyResearch/flash-attention/issues/21 - BOOL_SWITCH(launch_params.is_dropout, IsDropoutConst, ({ + BOOL_SWITCH(launch_params.is_dropout, IsDropoutConst, ([&] { auto kernel = launch_params.params.is_causal ? (launch_params.return_softmax ? &fmha_fwd_loop_kernel diff --git a/csrc/flash_attn/src/static_switch.h b/csrc/flash_attn/src/static_switch.h index a77bae6fb..53bcf35d6 100644 --- a/csrc/flash_attn/src/static_switch.h +++ b/csrc/flash_attn/src/static_switch.h @@ -1,5 +1,6 @@ // Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h // and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h +// and https://github.com/facebookresearch/xformers/blob/main/xformers/csrc/attention/cuda/fmha/gemm_kernel_utils.h#L8 #pragma once @@ -9,27 +10,31 @@ /// /// Usage: /// ``` -/// BOOL_SWITCH(flag, BoolConst, ({ +/// BOOL_SWITCH(flag, BoolConst, ([&] { /// some_function(...); /// })); /// ``` /// We need "({" and "})" to make sure that the code is a single argument being passed to the macro. 
-#define BOOL_SWITCH(COND, CONST_NAME, CODE) \ - if (COND) { \ - constexpr bool CONST_NAME = true; \ - CODE; \ - } else { \ - constexpr bool CONST_NAME = false; \ - CODE; \ +#define BOOL_SWITCH(COND, CONST_NAME, F) \ + { \ + if (COND) { \ + constexpr bool CONST_NAME = true; \ + F(); \ + } else { \ + constexpr bool CONST_NAME = false; \ + F(); \ + } \ } // modified from BOOL_SWITCH // because MSVC cannot handle std::conditional with constexpr variable -#define FP16_SWITCH(COND, CODE) \ - if (COND) { \ - using elem_type = __nv_bfloat16; \ - CODE; \ - } else { \ - using elem_type = __half; \ - CODE; \ - } \ +#define FP16_SWITCH(COND, F) \ + { \ + if (COND) { \ + using elem_type = __nv_bfloat16; \ + F(); \ + } else { \ + using elem_type = __half; \ + F(); \ + } \ + }