-
Notifications
You must be signed in to change notification settings - Fork 23
Current scaling: two-stage amax kernel #369
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from all commits
c15d93b
51fab36
ae35e4c
77a68a7
c0d8e73
6c3507d
3c9de07
91249cc
be0e0c8
bce34da
8c388cc
6388604
9e6586f
18292bf
a389455
d87ab8a
fd5dead
16d3bf9
50b34aa
ef532b1
f933ef3
7d4054e
63cff98
c7d44a7
f92b926
eba552e
0d6a177
8eda427
6990928
9ee618f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -26,12 +26,39 @@ using bf16__ = __nv_bfloat16; | |
| using bf16__ = __hip_bfloat16; | ||
| #endif //__HIP_PLATFORM_AMD__ | ||
|
|
||
| constexpr int amax_kernel_threads = 512; | ||
|
|
||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Let's guard our ROCm-specific code changes by the macro __HIP_PLATFORM_AMD__
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Done in c7d44a7 |
||
| #ifdef __HIP_PLATFORM_AMD__ | ||
|
|
||
| template <int BLOCK_THREADS> | ||
| __global__ void amax_final_reduce(const float* __restrict__ block_amax, | ||
| float* __restrict__ global_amax, | ||
| int num_blocks) { | ||
| float val = 0.f; | ||
|
|
||
| for (int i = threadIdx.x; i < num_blocks; i += BLOCK_THREADS) { | ||
| val = fmaxf(val, block_amax[i]); | ||
| } | ||
|
|
||
| const int warp_id = threadIdx.x / THREADS_PER_WARP; | ||
| const float block_max = | ||
| reduce_max<BLOCK_THREADS / THREADS_PER_WARP>(val, warp_id); | ||
|
|
||
| if (threadIdx.x == 0) { | ||
| *global_amax = block_max; | ||
| } | ||
| } | ||
|
|
||
| #endif | ||
|
|
||
| template <int nvec, bool aligned, typename InputType> | ||
| __launch_bounds__(amax_kernel_threads) __global__ | ||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| void amax_kernel(const InputType *input, float *amax, float* __restrict__ block_amax, const size_t N, | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Guard the API change so the NV upstream flow can remain unchanged
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Done in c7d44a7 |
||
| const size_t num_aligned_elements) { | ||
| #else | ||
| void amax_kernel(const InputType *input, float *amax, const size_t N, | ||
| const size_t num_aligned_elements) { | ||
| #endif | ||
| VectorizedLoader<InputType, nvec, aligned> loader(input, N); | ||
| InputType max{0.f}; | ||
| const int warp_id = threadIdx.x / THREADS_PER_WARP; | ||
|
|
@@ -65,12 +92,23 @@ __launch_bounds__(amax_kernel_threads) __global__ | |
| // Reduce amax over block | ||
| max = reduce_max<amax_kernel_threads / THREADS_PER_WARP>(max, warp_id); | ||
| if (threadIdx.x == 0) { | ||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| if (block_amax != nullptr) { | ||
| // 2-stage: write per-block result | ||
| block_amax[blockIdx.x] = max; | ||
| } else { | ||
| // Atomic path: directly update global amax | ||
| atomicMaxFloat(amax, max); | ||
| } | ||
| #else | ||
| atomicMaxFloat(amax, max); | ||
| #endif | ||
| } | ||
| } | ||
|
|
||
| template <int nvec, typename InputType> | ||
| void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cudaStream_t stream) { | ||
| void launch_amax_kernel(const InputType *input, float *amax, const size_t N, float *block_amax, | ||
| size_t block_capacity, cudaStream_t stream) { | ||
| // Zero out amax so we can update with atomic max | ||
| (void)cudaMemsetAsync(amax, 0, sizeof(float), stream); | ||
|
|
||
|
|
@@ -89,24 +127,54 @@ void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cud | |
| constexpr size_t max_blocks = 65535; | ||
| num_blocks = std::min(num_blocks, max_blocks); | ||
|
|
||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| if (block_capacity < num_blocks) | ||
| block_amax = nullptr; | ||
| #endif | ||
|
|
||
| // Launch kernel | ||
| switch (align) { | ||
| case Alignment::SAME_ALIGNED: | ||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| amax_kernel<nvec, true, InputType> | ||
| <<<num_blocks, threads, 0, stream>>>(input, amax, block_amax, N, num_aligned_elements); | ||
| #else | ||
| amax_kernel<nvec, true, InputType> | ||
| <<<num_blocks, threads, 0, stream>>>(input, amax, N, num_aligned_elements); | ||
| #endif | ||
| break; | ||
| case Alignment::SAME_UNALIGNED: | ||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| amax_kernel<nvec, false, InputType> | ||
| <<<num_blocks, threads, 0, stream>>>(input, amax, block_amax, N, num_aligned_elements); | ||
| #else | ||
| amax_kernel<nvec, false, InputType> | ||
| <<<num_blocks, threads, 0, stream>>>(input, amax, N, num_aligned_elements); | ||
| #endif | ||
| break; | ||
| case Alignment::DIFFERENT: { | ||
| // This case is a logic error, since there is only one pointer (input) | ||
| // in the alignment check. Still safe to process without vectorization. | ||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| amax_kernel<1, true, InputType><<<num_blocks, threads, 0, stream>>>(input, amax, block_amax, N, N); | ||
| #else | ||
| amax_kernel<1, true, InputType><<<num_blocks, threads, 0, stream>>>(input, amax, N, N); | ||
| #endif | ||
| break; | ||
| } | ||
| } | ||
|
|
||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| if (block_amax != nullptr) { | ||
| constexpr int FINAL_REDUCE_THREADS = 256; | ||
| dim3 fr_block(FINAL_REDUCE_THREADS); | ||
| dim3 fr_grid(1); | ||
|
|
||
| amax_final_reduce<FINAL_REDUCE_THREADS> | ||
| <<<fr_grid, fr_block, 0, stream>>>(block_amax, amax, static_cast<int>(num_blocks)); | ||
| } | ||
| #endif | ||
|
|
||
| // Check results | ||
| NVTE_CHECK_CUDA(cudaGetLastError()); | ||
| } | ||
|
|
@@ -115,6 +183,12 @@ void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cud | |
| } // namespace transformer_engine | ||
|
|
||
| void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaStream_t stream) { | ||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| nvte_compute_amax_with_workspace(input_, output_, /*workspace=*/nullptr, stream); | ||
| } | ||
|
|
||
| void nvte_compute_amax_with_workspace(const NVTETensor input_, const NVTETensor output_, const NVTETensor workspace_, cudaStream_t stream) { | ||
| #endif | ||
| NVTE_API_CALL(nvte_compute_amax); | ||
| using namespace transformer_engine; | ||
|
|
||
|
|
@@ -150,11 +224,31 @@ void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaSt | |
| to_string(output.amax.dtype), ")"); | ||
| CheckOutputTensor(output, "output_compute_amax", true); | ||
|
|
||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| // Optional workspace | ||
| float* block_amax = nullptr; | ||
| size_t block_capacity = 0; | ||
|
|
||
| if (workspace_ != nullptr) { | ||
| auto &workspace = *reinterpret_cast<Tensor *>(workspace_); | ||
| if (workspace.data.dptr != nullptr) { | ||
| NVTE_CHECK(workspace.data.dtype == DType::kFloat32, | ||
| "Workspace tensor for amax computation must be FP32, got dtype=", | ||
| to_string(workspace.data.dtype)); | ||
| block_amax = reinterpret_cast<float*>(workspace.data.dptr); | ||
| block_capacity = workspace.data.numel(); | ||
| } | ||
| } | ||
| #endif | ||
|
|
||
| // Compute amax | ||
| TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( | ||
| input.data.dtype, IType, constexpr int nvec = 32 / sizeof(IType); | ||
| launch_amax_kernel<nvec>(reinterpret_cast<const IType *>(input.data.dptr), | ||
| reinterpret_cast<float *>(output.amax.dptr), input.data.numel(), | ||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| block_amax, block_capacity, | ||
| #endif | ||
| stream);); // NOLINT(*) | ||
| } | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's have a brief doc just like the nvte_compute_amax above
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done in c7d44a7