
Commit

Support size_per_head=112 (#660)
* Fix the multi-GPU build

* Add support for size_per_head=112 for the GPT decoder
dskhudia committed Jun 29, 2023
1 parent eb9b81b commit 7777ff1
Showing 6 changed files with 113 additions and 9 deletions.
@@ -41,6 +41,9 @@ void multihead_attention_(const KERNEL_PARAMS_TYPE& params, const cudaStream_t&
        case 96:
            mmha_launch_kernel<T, 96, 128, KERNEL_PARAMS_TYPE>(params, stream);
            break;
+       case 112:
+           mmha_launch_kernel<T, 112, 128, KERNEL_PARAMS_TYPE>(params, stream);
+           break;
        case 128:
            mmha_launch_kernel<T, 128, 128, KERNEL_PARAMS_TYPE>(params, stream);
            break;
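The dispatch above routes each supported head size Dh to a kernel instantiation whose second template argument, Dh_MAX, is the head size rounded up to the next power of two, so the new 112 case pairs with Dh_MAX = 128; the kernel presumably sizes its internal per-head buffers by that power-of-two bound. A minimal standalone sketch of that padding rule (the helper name below is hypothetical and not part of the repository):

#include <cstdio>

// Hypothetical helper, illustration only: round a head size up to the next
// power of two, mirroring the Dh -> Dh_MAX pairing in the dispatch above.
constexpr int next_pow2(int v)
{
    int p = 1;
    while (p < v) {
        p *= 2;
    }
    return p;
}

int main()
{
    constexpr int size_per_head = 112;
    constexpr int dh_max        = next_pow2(size_per_head);
    static_assert(dh_max == 128, "size_per_head=112 should pad to Dh_MAX=128");
    std::printf("Dh=%d dispatches the kernel compiled with Dh_MAX=%d\n", size_per_head, dh_max);
    return 0;
}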
@@ -0,0 +1,101 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "decoder_masked_multihead_attention_template.hpp"
#include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h"
#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h"
#include "src/fastertransformer/utils/cuda_bf16_wrapper.h"
#include <assert.h>
#include <float.h>
#include <type_traits>

////////////////////////////////////////////////////////////////////////////////////////////////////

#define MMHA_LAUNCH_KERNEL( \
T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, DO_CROSS_ATTENTION, HAS_BEAMS, stream) \
size_t smem_sz = mmha::smem_size_in_bytes<T, DO_CROSS_ATTENTION>(params, THDS_PER_VALUE, THDS_PER_BLOCK); \
dim3 grid(params.num_heads, params.batch_size); \
mmha::masked_multihead_attention_kernel<T, \
Dh, \
Dh_MAX, \
THDS_PER_KEY, \
THDS_PER_VALUE, \
THDS_PER_BLOCK, \
DO_CROSS_ATTENTION, \
HAS_BEAMS><<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params)

////////////////////////////////////////////////////////////////////////////////////////////////////

// !!! Specialize the launcher for Cross attention
template<typename T, int Dh, int Dh_MAX, typename KERNEL_PARAMS_TYPE>
void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream)
{
    constexpr int  THREADS_PER_VALUE  = threads_per_value_t<T, Dh_MAX>::value;
    constexpr bool DO_CROSS_ATTENTION = std::is_same<KERNEL_PARAMS_TYPE, Cross_multihead_attention_params<T>>::value;
    int            tlength            = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep;
    if (params.cache_indir == nullptr) {
        if (tlength < 32) {
            MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, false, stream);
        }
        else if (tlength < 2048) {
            MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, false, stream);
        }
        else {
            MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, false, stream);
        }
    }
    else {
        if (tlength < 32) {
            MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, true, stream);
        }
        else if (tlength < 2048) {
            MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, true, stream);
        }
        else {
            MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, true, stream);
        }
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template void mmha_launch_kernel<float, 112, 128, Masked_multihead_attention_params<float>>(
const Masked_multihead_attention_params<float>& params, const cudaStream_t& stream);
template void mmha_launch_kernel<uint16_t, 112, 128, Masked_multihead_attention_params<uint16_t>>(
const Masked_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream);
#ifdef ENABLE_BF16
template void mmha_launch_kernel<__nv_bfloat16, 112, 128, Masked_multihead_attention_params<__nv_bfloat16>>(
const Masked_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream);
#endif
#ifdef ENABLE_FP8
template void mmha_launch_kernel<__nv_fp8_e4m3, 112, 128, Masked_multihead_attention_params<__nv_fp8_e4m3>>(
const Masked_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream);
#endif

template void mmha_launch_kernel<float, 112, 128, Cross_multihead_attention_params<float>>(
const Cross_multihead_attention_params<float>& params, const cudaStream_t& stream);
template void mmha_launch_kernel<uint16_t, 112, 128, Cross_multihead_attention_params<uint16_t>>(
const Cross_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream);
#ifdef ENABLE_BF16
template void mmha_launch_kernel<__nv_bfloat16, 112, 128, Cross_multihead_attention_params<__nv_bfloat16>>(
const Cross_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream);
#endif
#ifdef ENABLE_FP8
template void mmha_launch_kernel<__nv_fp8_e4m3, 112, 128, Cross_multihead_attention_params<__nv_fp8_e4m3>>(
const Cross_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream);
#endif

#undef MMHA_LAUNCH_KERNEL
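The launcher above picks its thread-block size from the current sequence length (64 threads below 32 steps, 128 below 2048, 256 beyond that) and takes THREADS_PER_VALUE from the threads_per_value_t trait defined elsewhere in the kernels. A plausible standalone sketch of that thread-per-value idea, assuming the trait assigns one thread per 16-byte slice of a head's value vector (an assumption, not a quote of the trait):

#include <cstdint>
#include <cstdio>

// Sketch only: one thread per 16-byte slice of the value vector, so the
// thread count depends on the padded head size Dh_MAX and on sizeof(T).
template<typename T, int Dh_MAX>
struct threads_per_value_sketch {
    static constexpr int value = Dh_MAX * static_cast<int>(sizeof(T)) / 16;
};

int main()
{
    // For the new size_per_head=112 path, the padded head size Dh_MAX is 128:
    // fp32 gives 32 threads per value, 2-byte types (fp16/bf16) give 16.
    std::printf("float:    %d threads per value\n", threads_per_value_sketch<float, 128>::value);
    std::printf("uint16_t: %d threads per value\n", threads_per_value_sketch<std::uint16_t, 128>::value);
    return 0;
}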
@@ -796,8 +796,8 @@ DecoderCrossAttentionLayer<T>::DecoderCrossAttentionLayer(size_t max_b
    q_scaling_(q_scaling)
{
    FT_CHECK(size_per_head_ == 32 || size_per_head_ == 48 || size_per_head_ == 64 || size_per_head_ == 80
-            || size_per_head_ == 96 || size_per_head_ == 128 || size_per_head_ == 144 || size_per_head_ == 160
-            || size_per_head_ == 192 || size_per_head_ == 224 || size_per_head_ == 256);
+            || size_per_head_ == 96 || size_per_head_ == 112 || size_per_head_ == 128 || size_per_head_ == 144
+            || size_per_head_ == 160 || size_per_head_ == 192 || size_per_head_ == 224 || size_per_head_ == 256);
}

template<typename T>
@@ -1030,4 +1030,4 @@ template class DecoderCrossAttentionLayer<half>;
template class DecoderCrossAttentionLayer<__nv_bfloat16>;
#endif

-} // namespace fastertransformer
\ No newline at end of file
+} // namespace fastertransformer
@@ -278,8 +278,8 @@ DecoderSelfAttentionLayer<T>::DecoderSelfAttentionLayer(size_t max_bat
    int8_mode_(int8_mode)
{
    FT_CHECK(size_per_head_ == 32 || size_per_head_ == 48 || size_per_head_ == 64 || size_per_head_ == 80
-            || size_per_head_ == 96 || size_per_head_ == 128 || size_per_head_ == 144 || size_per_head_ == 160
-            || size_per_head_ == 192 || size_per_head_ == 224 || size_per_head_ == 256);
+            || size_per_head_ == 96 || size_per_head_ == 112 || size_per_head_ == 128 || size_per_head_ == 144
+            || size_per_head_ == 160 || size_per_head_ == 192 || size_per_head_ == 224 || size_per_head_ == 256);
    if (int8_mode_ == 1) {
        FT_CHECK_WITH_INFO(!(std::is_same<T, float>::value), "Weight only quant not supported for fp32.");
        weight_only_int8_fc_runner_ = std::make_shared<CutlassFpAIntBGemmRunner<T, uint8_t>>();
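Both decoder attention layers guard size_per_head with the same whitelist, which now admits 112; for example, a hypothetical model with hidden size 7168 split across 64 heads has size_per_head = 7168 / 64 = 112 and would have failed this check before the change. A small sketch of the same validation written as a table lookup (an illustrative rewrite, not the form FasterTransformer uses):

#include <algorithm>
#include <array>
#include <cassert>

int main()
{
    // Hypothetical model shape: hidden size 7168 split across 64 attention heads.
    const int hidden_units  = 7168;
    const int head_num      = 64;
    const int size_per_head = hidden_units / head_num;  // 112

    // Head sizes accepted by the FT_CHECK above, now including 112.
    const std::array<int, 12> supported = {32, 48, 64, 80, 96, 112, 128, 144, 160, 192, 224, 256};
    assert(std::find(supported.begin(), supported.end(), size_per_head) != supported.end());
    return 0;
}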
2 changes: 1 addition & 1 deletion src/fastertransformer/utils/CMakeLists.txt
@@ -57,7 +57,7 @@ add_library(mpi_utils STATIC mpi_utils.cc)
set_property(TARGET mpi_utils PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET mpi_utils PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
if (BUILD_MULTI_GPU)
-   target_link_libraries(mpi_utils PUBLIC -lmpi logger)
+   target_link_libraries(mpi_utils PUBLIC -lmpi -lmpi_cxx logger)
endif()

add_library(nccl_utils STATIC nccl_utils.cc)
6 changes: 3 additions & 3 deletions tests/decoding/tf_fused_self_multihead_attention_unit_test.py
@@ -56,12 +56,12 @@ def test_attn_head_fp16(self):
            self.run_attn(4, 128, head, 64, tf.float16)

    def test_attn_size_fp32(self):
-       for size in [32, 64, 80, 96, 128, 144, 160, 192, 224, 256]:
+       for size in [32, 64, 80, 96, 112, 128, 144, 160, 192, 224, 256]:
            tf.reset_default_graph()
            self.run_attn(4, 128, 12, size, tf.float32)

    def test_attn_size_fp16(self):
-       for size in [32, 64, 80, 96, 128, 144, 160, 192, 224, 256]:
+       for size in [32, 64, 80, 96, 112, 128, 144, 160, 192, 224, 256]:
            tf.reset_default_graph()
            self.run_attn(4, 128, 12, size, tf.float16)

@@ -171,4 +171,4 @@ def run_attn(self, batch_size, seq_len, head_num, size_per_head, data_type):
        assert(v_cache_max_diff < threshold)

if __name__ == "__main__":
-   unittest.main()
\ No newline at end of file
+   unittest.main()
