From 2989e5cac673172f36338c0e75f5842b415606d7 Mon Sep 17 00:00:00 2001 From: Jiang Shao Date: Sun, 16 Jun 2024 14:27:37 +0000 Subject: [PATCH 01/33] Add submodule cutlass v3.5.0 Co-authored-by: Qi Zhang Signed-off-by: Jiang Shao --- .gitmodules | 3 +++ 3rdparty/cutlass | 1 + 2 files changed, 4 insertions(+) create mode 160000 3rdparty/cutlass diff --git a/.gitmodules b/.gitmodules index 21492db5ef..4b188d6bb1 100644 --- a/.gitmodules +++ b/.gitmodules @@ -4,3 +4,6 @@ [submodule "3rdparty/cudnn-frontend"] path = 3rdparty/cudnn-frontend url = https://github.com/NVIDIA/cudnn-frontend.git +[submodule "3rdparty/cutlass"] + path = 3rdparty/cutlass + url = https://github.com/NVIDIA/cutlass.git diff --git a/3rdparty/cutlass b/3rdparty/cutlass new file mode 160000 index 0000000000..7d49e6c7e2 --- /dev/null +++ b/3rdparty/cutlass @@ -0,0 +1 @@ +Subproject commit 7d49e6c7e2f8896c47f586706e67e1fb215529dc From 0dc63b200cda21fe20f9927a6b8e739038d03cee Mon Sep 17 00:00:00 2001 From: Jiang Shao Date: Mon, 17 Jun 2024 08:52:46 +0000 Subject: [PATCH 02/33] Add permutation functions Signed-off-by: Jiang Shao --- .../include/transformer_engine/permutation.h | 26 + .../common/permutation/permutation.cu | 413 +++++++++++++++ transformer_engine/pytorch/csrc/common.h | 7 + transformer_engine/pytorch/csrc/extensions.h | 25 + .../pytorch/csrc/extensions/permutation.cu | 500 ++++++++++++++++++ .../pytorch/csrc/extensions/pybind.cpp | 5 + 6 files changed, 976 insertions(+) create mode 100644 transformer_engine/common/include/transformer_engine/permutation.h create mode 100644 transformer_engine/common/permutation/permutation.cu create mode 100644 transformer_engine/pytorch/csrc/extensions/permutation.cu diff --git a/transformer_engine/common/include/transformer_engine/permutation.h b/transformer_engine/common/include/transformer_engine/permutation.h new file mode 100644 index 0000000000..74ba6c4036 --- /dev/null +++ b/transformer_engine/common/include/transformer_engine/permutation.h @@ -0,0 
+1,26 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_PERMUTATION_H_ +#define TRANSFORMER_ENGINE_PERMUTATION_H_ + +#include "transformer_engine.h" + +template +void moe_permute_topK_kernel_launcher(const T *input, + T *output, + const int *sorted_row_id, + int *row_id_map, + const float *prob, + const int num_rows, + const int num_topK, + const int num_cols, + const int num_out_tokens, + cudaStream_t stream, + float *prob_grad = nullptr, + const T *input_fwd = nullptr); + +#endif // TRANSFORMER_ENGINE_PERMUTATION_H_ diff --git a/transformer_engine/common/permutation/permutation.cu b/transformer_engine/common/permutation/permutation.cu new file mode 100644 index 0000000000..cac3d225c1 --- /dev/null +++ b/transformer_engine/common/permutation/permutation.cu @@ -0,0 +1,413 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#pragma once + +#include "cutlass/arch/memory.h" +#include "cutlass/arch/cache_operation.h" +#include "cutlass/array.h" +#include "cutlass/numeric_conversion.h" + +#include + +static __global__ void moe_permute_topK_row_map( + const int *sorted_row_id, + int *row_id_map, + const int num_rows, + const int num_topK, + const int num_out_tokens) +{ + // Each block corresponds to one source token + // row_id_map[num_topK][num_rows] + const int bid = blockIdx.x; + const int tid = threadIdx.x; + const int idx = bid * blockDim.x + tid; + + if (idx >= num_rows * num_topK) + return; + + int source_row = sorted_row_id[idx]; + int source_token_id = source_row / num_topK; + int source_topK_id = source_row % num_topK; + + if (idx >= num_out_tokens) + { + row_id_map[source_topK_id * num_rows + source_token_id] = -1; + } + else + { + row_id_map[source_topK_id * num_rows + source_token_id] = idx; + } +} + +template +__global__ void moe_recover_topK_kernel(const T *input, + T *unpermuted_output, + const int *row_id_map, + const float *prob, + const int num_rows, + const int num_topK, + const int num_cols) +{ + extern __shared__ int8_t s_mem[]; + TCompute *s_prob = reinterpret_cast(s_mem); + + using FragmentLoadStore = cutlass::Array; + using FragmentCompute = cutlass::Array; + + cutlass::NumericArrayConverter src_converter; + cutlass::NumericArrayConverter dst_converter; + + // each block corresponds to one source token + const int source_token = blockIdx.x; + const int tid = threadIdx.x; + + if (hasProb) + { + for (int i = tid; i < num_topK; i += blockDim.x * blockDim.y) + { + s_prob[i] = TCompute(prob[source_token * num_topK + i]); + } + __syncthreads(); + } + + for (int i = tid * kElementsPerAccess; i < num_cols; i += blockDim.x * kElementsPerAccess) + { + FragmentLoadStore frag_load_store; + FragmentCompute frag_elem; + FragmentCompute frag_sum; + + int source_row = row_id_map[source_token]; + + if 
(source_row != -1) + { + const T *source_row_ptr = input + source_row * num_cols; + + cutlass::arch::global_load( + frag_load_store, (source_row_ptr + i), true); + frag_sum = src_converter(frag_load_store); + + if (hasProb) + { + frag_sum = frag_sum * s_prob[0]; + } + } + else + { + frag_sum.clear(); + } + + for (int k = 1; k < num_topK; k++) + { + source_row = row_id_map[k * num_rows + source_token]; + + if (source_row == -1) + continue; + + const T *source_row_ptr = input + source_row * num_cols; + + cutlass::arch::global_load( + frag_load_store, (source_row_ptr + i), true); + frag_elem = src_converter(frag_load_store); + + if (hasProb) + { + frag_elem = frag_elem * s_prob[k]; + } + + for (int e = 0; e < kElementsPerAccess; e++) + { + frag_sum.at(e) = frag_sum.at(e) + frag_elem.at(e); + } + } + + T *dest_row_ptr = unpermuted_output + source_token * num_cols; + frag_load_store = dst_converter(frag_sum); + *(float4 *)(dest_row_ptr + i) = *(float4 *)(frag_load_store.data()); + } +} + +template +__global__ void moe_permute_topK_kernel(const T *input_bwd, + const T *input_fwd, + T *act_grad, + const float *prob, + float *prob_grad, + const int *row_id_map, + const int num_rows, + const int num_topK, + const int num_cols) +{ + extern __shared__ int8_t s_mem[]; + TCompute *s_prob = reinterpret_cast(s_mem); + + using FragmentLoadStore = cutlass::Array; + using FragmentCompute = cutlass::Array; + + cutlass::NumericArrayConverter src_converter; + cutlass::NumericArrayConverter dst_converter; + + const int source_token = blockIdx.x; + const int tid = threadIdx.x; + + if (hasProb) + { + for (int i = tid; i < num_topK; i += blockDim.x) + { + s_prob[i] = TCompute(prob[source_token * num_topK + i]); + } + __syncthreads(); + } + + float accum[topKTile] = {0.0f}; + FragmentLoadStore frag_load_store; + + const T *source_row_ptr = input_bwd + source_token * num_cols; + for (int i = tid * kElementsPerAccess; i < num_cols; i += blockDim.x * kElementsPerAccess) + { + 
cutlass::arch::global_load( + frag_load_store, (source_row_ptr + i), true); + FragmentCompute frag_src = src_converter(frag_load_store); + + int index = source_token; + + for (int k = 0; k < topKTile; k++) + { + if (k == num_topK) break; + + int dest_row = row_id_map[index]; + index += num_rows; + + if (dest_row != -1) + { + if (hasProb) + { + frag_load_store = dst_converter(frag_src * s_prob[k]); + } + else + { + frag_load_store = dst_converter(frag_src); + } + + T *dest_row_ptr = act_grad + dest_row * num_cols; + *(float4 *)(dest_row_ptr + i) = *(float4 *)(frag_load_store.data()); + + if (hasProb) + { + const T *input_fwd_ptr = input_fwd + dest_row * num_cols; + cutlass::arch::global_load( + frag_load_store, (input_fwd_ptr + i), true); + FragmentCompute frag_input_fwd = src_converter(frag_load_store); + + for (int e = 0; e < kElementsPerAccess; e++) + { + accum[k] += float(frag_src.at(e) * frag_input_fwd.at(e)); + } + } + } + } + } + + if (hasProb) + { + for (int k = 0; k < topKTile; k++) + { + if (k == num_topK) break; + + for (int mask = 16; mask > 0; mask /= 2) + { + accum[k] = accum[k] + __shfl_xor_sync(0xffffffff, accum[k], mask, 32); + } + } + + if (tid == 0) + { + for (int k = 0; k < topKTile; k++) + { + if (k == num_topK) break; + prob_grad[source_token * num_topK + k] = accum[k]; + } + } + } +} + +template +void moe_permute_topK_kernel_launcher( + const T *input, + T *output, + const int *sorted_row_id, + int *row_id_map, + const float *prob, + const int num_rows, + const int num_topK, + const int num_cols, + const int num_out_tokens, + cudaStream_t stream, + float *prob_grad = nullptr, + const T *input_fwd = nullptr) +{ + if (FWD) + { + if (prob == nullptr) + { + if (input_fwd == nullptr) + { + // permute_topK fwd + int threads = 64; + int blocks = (num_rows * num_topK + threads - 1) / threads; + moe_permute_topK_row_map<<>>( + sorted_row_id, + row_id_map, + num_rows, + num_topK, + num_out_tokens); + + blocks = num_rows; + threads = std::min(num_cols / 
kElementsPerAccess, 1024); + moe_permute_topK_kernel<<>>( + input, + nullptr, + output, + nullptr, + nullptr, + row_id_map, + num_rows, + num_topK, + num_cols); + } + else + { + // unpermute_topK bwd without probs for topK == 1 + int blocks = num_rows; + int threads = 32; + + moe_permute_topK_kernel<<>>( + input, + input_fwd, + output, + prob, + prob_grad, + row_id_map, + num_rows, + num_topK, + num_cols); + } + } + else + { + // unpermute_topK bwd with probs + int blocks = num_rows; + int threads = 32; + size_t smem_bytes = num_topK * sizeof(TCompute); + + if (num_topK <= 8) + { + moe_permute_topK_kernel<<>>( + input, + input_fwd, + output, + prob, + prob_grad, + row_id_map, + num_rows, + num_topK, + num_cols); + } + else if (num_topK <= 16) + { + moe_permute_topK_kernel<<>>( + input, + input_fwd, + output, + prob, + prob_grad, + row_id_map, + num_rows, + num_topK, + num_cols); + } + else if (num_topK <= 32) + { + moe_permute_topK_kernel<<>>( + input, + input_fwd, + output, + prob, + prob_grad, + row_id_map, + num_rows, + num_topK, + num_cols); + } + else if (num_topK <= 64) + { + moe_permute_topK_kernel<<>>( + input, + input_fwd, + output, + prob, + prob_grad, + row_id_map, + num_rows, + num_topK, + num_cols); + } + else if (num_topK <= 128) + { + moe_permute_topK_kernel<<>>( + input, + input_fwd, + output, + prob, + prob_grad, + row_id_map, + num_rows, + num_topK, + num_cols); + } + else + { + throw std::runtime_error("num_topK cannot exceed 128."); + } + } + } + else + { + int blocks = num_rows; + int threads = std::min(num_cols / kElementsPerAccess, 1024); + size_t smem_bytes = num_topK * sizeof(TCompute); + + if (prob == nullptr) + { + // permute_topK bwd + // unpermute_topK fwd without probs + moe_recover_topK_kernel<<>>( + input, + output, + row_id_map, + prob, + num_rows, + num_topK, + num_cols); + } + else + { + // unpermute_topK fwd with probs + moe_recover_topK_kernel<<>>( + input, + output, + row_id_map, + prob, + num_rows, + num_topK, + num_cols); + } 
+ } +} diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index aac693a430..e3e9957357 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -28,6 +28,7 @@ #include #include #include +#include #include #include #include @@ -166,4 +167,10 @@ at::Tensor allocateTorchTensor(int M, transformer_engine::DType dtype); void* getDataPtr(at::Tensor tensor, int offset = 0); +template +inline T *get_ptr(torch::Tensor &t) +{ + return reinterpret_cast(t.data_ptr()); +} + #endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_COMMON_H_ diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 445271598c..9a670e16ab 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -10,6 +10,31 @@ #include "common.h" #include "common/common.h" + +/*************************************************************************************************** + * permute + **************************************************************************************************/ + +std::tuple> moe_permute_topK_op( + at::Tensor input, + at::Tensor indices, + int64_t num_out_tokens, + std::vector workspace, + int64_t max_expanded_token_num); + +at::Tensor moe_recover_topK_op( + at::Tensor input, + at::Tensor row_id_map, + at::Tensor prob, + int64_t num_tokens, + int64_t num_topK); + +std::tuple moe_recover_topK_bwd_op( + at::Tensor input_bwd, + at::Tensor input_fwd, + at::Tensor row_id_map, + at::Tensor prob); + /*************************************************************************************************** * Attention **************************************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/permutation.cu b/transformer_engine/pytorch/csrc/extensions/permutation.cu new file mode 100644 index 0000000000..ba449cf07f --- /dev/null +++ 
b/transformer_engine/pytorch/csrc/extensions/permutation.cu @@ -0,0 +1,500 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include + +#include "extensions.h" + +using torch::Tensor; + +std::tuple> moe_permute_topK_op( + Tensor input, + Tensor indices, + int64_t num_out_tokens, + std::vector workspace, + int64_t max_expanded_token_num) +{ + const int num_tokens = input.size(0); + const int num_cols = input.size(1); + const int num_topK = indices.size(1); + + // initialize the workspace on the first run + if (workspace.empty()) { + auto options = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false); + + Tensor sorted_indices = torch::empty(max_expanded_token_num, options); + Tensor row_id = torch::range(0, max_expanded_token_num - 1, 1, options); + Tensor sorted_row_id = + torch::empty(max_expanded_token_num, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false)); + + size_t temp_storage_bytes = 0; + int *temp_ptr = nullptr; + cub::DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, + temp_ptr, temp_ptr, + temp_ptr, temp_ptr, max_expanded_token_num); + Tensor temp_storage = + torch::empty(temp_storage_bytes, torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false)); + + workspace.push_back(sorted_indices); + workspace.push_back(row_id); + workspace.push_back(sorted_row_id); + workspace.push_back(temp_storage); + } + + int *indices_ptr = get_ptr(indices); + int *sorted_indices_ptr = get_ptr(workspace[0]); + int *row_id_ptr = get_ptr(workspace[1]); + int *sorted_row_id_ptr = get_ptr(workspace[2]); + + void *d_temp_storage = get_ptr(workspace[3]); + size_t temp_storage_bytes = std::numeric_limits::max(); + + cub::DeviceRadixSort::SortPairs(d_temp_storage, 
temp_storage_bytes, + indices_ptr, sorted_indices_ptr, + row_id_ptr, sorted_row_id_ptr, num_tokens * num_topK); + + // activations type + const at::ScalarType _st = input.scalar_type(); + + // Output buffer alloc + num_out_tokens = (num_out_tokens > 0) ? num_out_tokens : num_tokens * num_topK; + Tensor permuted_output = + torch::empty({num_out_tokens, num_cols}, torch::dtype(_st).device(torch::kCUDA).requires_grad(false)); + Tensor row_id_map = + torch::empty({num_tokens * num_topK}, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false)); + + int *row_id_map_ptr = get_ptr(row_id_map); + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + switch (_st) + { + case at::ScalarType::Float: + { + using dType = float; + using dTypeCompute = float; + + dType *input_ptr = get_ptr(input); + dType *permuted_output_ptr = get_ptr(permuted_output); + + moe_permute_topK_kernel_launcher( + input_ptr, + permuted_output_ptr, + sorted_row_id_ptr, + row_id_map_ptr, + nullptr, + num_tokens, + num_topK, + num_cols, + num_out_tokens, + stream); + + break; + } + case at::ScalarType::Half: + { + using dType = cutlass::half_t; + using dTypeCompute = cutlass::half_t; + + dType *input_ptr = get_ptr(input); + dType *permuted_output_ptr = get_ptr(permuted_output); + + moe_permute_topK_kernel_launcher( + input_ptr, + permuted_output_ptr, + sorted_row_id_ptr, + row_id_map_ptr, + nullptr, + num_tokens, + num_topK, + num_cols, + num_out_tokens, + stream); + + break; + } +#ifdef ENABLE_BF16 + case at::ScalarType::BFloat16: + { + using dType = cutlass::bfloat16_t; + using dTypeCompute = cutlass::bfloat16_t; + + dType *input_ptr = get_ptr(input); + dType *permuted_output_ptr = get_ptr(permuted_output); + + moe_permute_topK_kernel_launcher( + input_ptr, + permuted_output_ptr, + sorted_row_id_ptr, + row_id_map_ptr, + nullptr, + num_tokens, + num_topK, + num_cols, + num_out_tokens, + stream); + + break; + } +#endif +#ifdef ENABLE_FP8 + case at::ScalarType::Float8_e5m2: + { + using 
dType = cutlass::float_e5m2_t; + using dTypeCompute = cutlass::half_t; + + dType *input_ptr = get_ptr(input); + dType *permuted_output_ptr = get_ptr(permuted_output); + + moe_permute_topK_kernel_launcher( + input_ptr, + permuted_output_ptr, + sorted_row_id_ptr, + row_id_map_ptr, + nullptr, + num_tokens, + num_topK, + num_cols, + num_out_tokens, + stream); + + break; + } + case at::ScalarType::Float8_e4m3fn: + { + using dType = cutlass::float_e4m3_t; + using dTypeCompute = cutlass::half_t; + + dType *input_ptr = get_ptr(input); + dType *permuted_output_ptr = get_ptr(permuted_output); + + moe_permute_topK_kernel_launcher( + input_ptr, + permuted_output_ptr, + sorted_row_id_ptr, + row_id_map_ptr, + nullptr, + num_tokens, + num_topK, + num_cols, + num_out_tokens, + stream); + + break; + } +#endif + default: + throw std::runtime_error("Wrong activation tensor type."); + } + + return std::make_tuple(permuted_output, row_id_map, workspace); +} + + +Tensor moe_recover_topK_op( + Tensor input, + Tensor row_id_map, + Tensor prob, + int64_t num_tokens, + int64_t num_topK) +{ + const int num_cols = input.size(1); + + // activations type + const at::ScalarType _st = input.scalar_type(); + + // Output buffer alloc + Tensor unpermuted_output = + torch::empty({num_tokens, num_cols}, torch::dtype(_st).device(torch::kCUDA).requires_grad(false)); + + int *row_id_map_ptr = get_ptr(row_id_map); + float *prob_ptr = (prob.defined()) ? 
get_ptr(prob) : nullptr; + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + switch (_st) + { + case at::ScalarType::Float: + { + using dType = float; + using dTypeCompute = float; + + dType *input_ptr = get_ptr(input); + dType *unpermuted_output_ptr = get_ptr(unpermuted_output); + + moe_permute_topK_kernel_launcher( + input_ptr, + unpermuted_output_ptr, + nullptr, + row_id_map_ptr, + prob_ptr, + num_tokens, + num_topK, + num_cols, + 0, + stream); + + break; + } + case at::ScalarType::Half: + { + using dType = cutlass::half_t; + using dTypeCompute = cutlass::half_t; + + dType *input_ptr = get_ptr(input); + dType *unpermuted_output_ptr = get_ptr(unpermuted_output); + + moe_permute_topK_kernel_launcher( + input_ptr, + unpermuted_output_ptr, + nullptr, + row_id_map_ptr, + prob_ptr, + num_tokens, + num_topK, + num_cols, + 0, + stream); + + break; + } +#ifdef ENABLE_BF16 + case at::ScalarType::BFloat16: + { + using dType = cutlass::bfloat16_t; + using dTypeCompute = cutlass::bfloat16_t; + + dType *input_ptr = get_ptr(input); + dType *unpermuted_output_ptr = get_ptr(unpermuted_output); + + moe_permute_topK_kernel_launcher( + input_ptr, + unpermuted_output_ptr, + nullptr, + row_id_map_ptr, + prob_ptr, + num_tokens, + num_topK, + num_cols, + 0, + stream); + + break; + } +#endif +#ifdef ENABLE_FP8 + case at::ScalarType::Float8_e5m2: + { + using dType = cutlass::float_e5m2_t; + using dTypeCompute = cutlass::half_t; + + dType *input_ptr = get_ptr(input); + dType *unpermuted_output_ptr = get_ptr(unpermuted_output); + + moe_permute_topK_kernel_launcher( + input_ptr, + unpermuted_output_ptr, + nullptr, + row_id_map_ptr, + prob_ptr, + num_tokens, + num_topK, + num_cols, + 0, + stream); + + break; + } + case at::ScalarType::Float8_e4m3fn: + { + using dType = cutlass::float_e4m3_t; + using dTypeCompute = cutlass::half_t; + + dType *input_ptr = get_ptr(input); + dType *unpermuted_output_ptr = get_ptr(unpermuted_output); + + moe_permute_topK_kernel_launcher( + input_ptr, + 
unpermuted_output_ptr, + nullptr, + row_id_map_ptr, + prob_ptr, + num_tokens, + num_topK, + num_cols, + 0, + stream); + + break; + } +#endif + default: + throw std::runtime_error("Wrong activation tensor type."); + } + + return unpermuted_output; +} + +std::tuple moe_recover_topK_bwd_op( + Tensor input_bwd, + Tensor input_fwd, + Tensor row_id_map, + Tensor prob) +{ + const int num_topK = (prob.defined()) ? prob.size(1) : 1; + const int num_tokens = (prob.defined()) ? prob.size(0) : row_id_map.size(0); + const int num_cols = input_bwd.size(1); + + int *row_id_map_ptr = get_ptr(row_id_map); + float *prob_ptr = (prob.defined()) ? get_ptr(prob) : nullptr; + + // activations type + const at::ScalarType _st = input_bwd.scalar_type(); + + // Output buffer alloc + Tensor act_grad = + torch::empty({input_fwd.size(0), num_cols}, torch::dtype(_st).device(torch::kCUDA).requires_grad(false)); + Tensor prob_grad = + torch::empty({num_tokens, num_topK}, torch::dtype(torch::kFloat32).device(torch::kCUDA).requires_grad(false)); + float *prob_grad_ptr = get_ptr(prob_grad); + + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + switch (_st) + { + case at::ScalarType::Float: + { + using dType = float; + using dTypeCompute = float; + + dType *input_bwd_ptr = get_ptr(input_bwd); + dType *input_fwd_ptr = get_ptr(input_fwd); + dType *act_grad_ptr = get_ptr(act_grad); + + moe_permute_topK_kernel_launcher( + input_bwd_ptr, + act_grad_ptr, + nullptr, + row_id_map_ptr, + prob_ptr, + num_tokens, + num_topK, + num_cols, + 0, + stream, + prob_grad_ptr, + input_fwd_ptr); + + break; + } + case at::ScalarType::Half: + { + using dType = cutlass::half_t; + using dTypeCompute = cutlass::half_t; + + dType *input_bwd_ptr = get_ptr(input_bwd); + dType *input_fwd_ptr = get_ptr(input_fwd); + dType *act_grad_ptr = get_ptr(act_grad); + + moe_permute_topK_kernel_launcher( + input_bwd_ptr, + act_grad_ptr, + nullptr, + row_id_map_ptr, + prob_ptr, + num_tokens, + num_topK, + num_cols, + 0, + stream, + 
prob_grad_ptr, + input_fwd_ptr); + + break; + } +#ifdef ENABLE_BF16 + case at::ScalarType::BFloat16: + { + using dType = cutlass::bfloat16_t; + using dTypeCompute = cutlass::bfloat16_t; + + dType *input_bwd_ptr = get_ptr(input_bwd); + dType *input_fwd_ptr = get_ptr(input_fwd); + dType *act_grad_ptr = get_ptr(act_grad); + + moe_permute_topK_kernel_launcher( + input_bwd_ptr, + act_grad_ptr, + nullptr, + row_id_map_ptr, + prob_ptr, + num_tokens, + num_topK, + num_cols, + 0, + stream, + prob_grad_ptr, + input_fwd_ptr); + + break; + } +#endif +#ifdef ENABLE_FP8 + case at::ScalarType::Float8_e5m2: + { + using dType = cutlass::float_e5m2_t; + using dTypeCompute = cutlass::half_t; + + dType *input_bwd_ptr = get_ptr(input_bwd); + dType *input_fwd_ptr = get_ptr(input_fwd); + dType *act_grad_ptr = get_ptr(act_grad); + + moe_permute_topK_kernel_launcher( + input_bwd_ptr, + act_grad_ptr, + nullptr, + row_id_map_ptr, + prob_ptr, + num_tokens, + num_topK, + num_cols, + 0, + stream, + prob_grad_ptr, + input_fwd_ptr); + + break; + } + case at::ScalarType::Float8_e4m3fn: + { + using dType = cutlass::float_e4m3_t; + using dTypeCompute = cutlass::half_t; + + dType *input_bwd_ptr = get_ptr(input_bwd); + dType *input_fwd_ptr = get_ptr(input_fwd); + dType *act_grad_ptr = get_ptr(act_grad); + + moe_permute_topK_kernel_launcher( + input_bwd_ptr, + act_grad_ptr, + nullptr, + row_id_map_ptr, + prob_ptr, + num_tokens, + num_topK, + num_cols, + 0, + stream, + prob_grad_ptr, + input_fwd_ptr); + + break; + } +#endif + default: + throw std::runtime_error("Wrong activation tensor type."); + } + + return std::make_tuple(act_grad, prob_grad); +} \ No newline at end of file diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 86641ea0a7..d85a4cb5a2 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -10,6 +10,11 @@ #include "../extensions.h" 
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + // Permutation functions + m.def("moe_permute_topK_op", moe_permute_topK_op); + m.def("moe_recover_topK_op", moe_recover_topK_op); + m.def("moe_recover_topK_bwd_op", moe_recover_topK_bwd_op); + // Softmax functions m.def("scaled_softmax_forward", &scaled_softmax_forward, "Scaled Softmax FWD", py::call_guard()); From 4b4930a9ec10d2c3df29c1678d12c5bcde350fd6 Mon Sep 17 00:00:00 2001 From: Jiang Shao Date: Mon, 17 Jun 2024 13:41:52 +0000 Subject: [PATCH 03/33] Building pass Signed-off-by: Jiang Shao --- transformer_engine/common/CMakeLists.txt | 1 + .../include/transformer_engine/permutation.h | 8 +- .../common/permutation/permutation.cu | 32 ++++- .../pytorch/csrc/extensions/permutation.cu | 135 ++++-------------- 4 files changed, 55 insertions(+), 121 deletions(-) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index c0d18141b4..e99e9c1912 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -57,6 +57,7 @@ list(APPEND transformer_engine_SOURCES layer_norm/ln_api.cpp layer_norm/ln_bwd_semi_cuda_kernel.cu layer_norm/ln_fwd_cuda_kernel.cu + permutation/permutation.cu rmsnorm/rmsnorm_api.cpp rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu rmsnorm/rmsnorm_fwd_cuda_kernel.cu diff --git a/transformer_engine/common/include/transformer_engine/permutation.h b/transformer_engine/common/include/transformer_engine/permutation.h index 74ba6c4036..574333b4c5 100644 --- a/transformer_engine/common/include/transformer_engine/permutation.h +++ b/transformer_engine/common/include/transformer_engine/permutation.h @@ -9,9 +9,9 @@ #include "transformer_engine.h" -template -void moe_permute_topK_kernel_launcher(const T *input, - T *output, +template +void moe_permute_topK_kernel_launcher(const void *input, + void *output, const int *sorted_row_id, int *row_id_map, const float *prob, @@ -21,6 +21,6 @@ void moe_permute_topK_kernel_launcher(const T *input, 
const int num_out_tokens, cudaStream_t stream, float *prob_grad = nullptr, - const T *input_fwd = nullptr); + const void *input_fwd = nullptr); #endif // TRANSFORMER_ENGINE_PERMUTATION_H_ diff --git a/transformer_engine/common/permutation/permutation.cu b/transformer_engine/common/permutation/permutation.cu index cac3d225c1..121bc2c219 100644 --- a/transformer_engine/common/permutation/permutation.cu +++ b/transformer_engine/common/permutation/permutation.cu @@ -4,8 +4,6 @@ * See LICENSE for license information. ************************************************************************/ -#pragma once - #include "cutlass/arch/memory.h" #include "cutlass/arch/cache_operation.h" #include "cutlass/array.h" @@ -238,10 +236,10 @@ __global__ void moe_permute_topK_kernel(const T *input_bwd, } } -template +template void moe_permute_topK_kernel_launcher( - const T *input, - T *output, + const void *input, + void *output, const int *sorted_row_id, int *row_id_map, const float *prob, @@ -250,9 +248,29 @@ void moe_permute_topK_kernel_launcher( const int num_cols, const int num_out_tokens, cudaStream_t stream, - float *prob_grad = nullptr, - const T *input_fwd = nullptr) + float *prob_grad, + const void *input_fwd) { + // Convert to cutlass type + using T_fp16 = typename cutlass::platform::conditional< + cutlass::platform::is_same::value, + cutlass::half_t, TInput>::type; + using T_bf16 = typename cutlass::platform::conditional< + cutlass::platform::is_same::value, + cutlass::bfloat16_t, T_fp16>::type; + using T_fp8e5m2 = typename cutlass::platform::conditional< + cutlass::platform::is_same::value, + cutlass::float_e5m2_t, T_bf16>::type; + using T_fp8e4m3 = typename cutlass::platform::conditional< + cutlass::platform::is_same::value, + cutlass::float_e4m3_t, T_fp8e5m2>::type; + using T = T_fp8e4m3; + + using TCompute = typename cutlass::platform::conditional< + (cutlass::platform::is_same::value || + cutlass::platform::is_same::value), + cutlass::half_t, T>::type; + if (FWD) { if 
(prob == nullptr) diff --git a/transformer_engine/pytorch/csrc/extensions/permutation.cu b/transformer_engine/pytorch/csrc/extensions/permutation.cu index ba449cf07f..16651594dc 100644 --- a/transformer_engine/pytorch/csrc/extensions/permutation.cu +++ b/transformer_engine/pytorch/csrc/extensions/permutation.cu @@ -69,17 +69,14 @@ std::tuple> moe_permute_topK_op( int *row_id_map_ptr = get_ptr(row_id_map); auto stream = at::cuda::getCurrentCUDAStream().stream(); + void *input_ptr = getDataPtr(input, 0); + void *permuted_output_ptr = getDataPtr(permuted_output, 0); + switch (_st) { case at::ScalarType::Float: { - using dType = float; - using dTypeCompute = float; - - dType *input_ptr = get_ptr(input); - dType *permuted_output_ptr = get_ptr(permuted_output); - - moe_permute_topK_kernel_launcher( + moe_permute_topK_kernel_launcher( input_ptr, permuted_output_ptr, sorted_row_id_ptr, @@ -95,13 +92,7 @@ std::tuple> moe_permute_topK_op( } case at::ScalarType::Half: { - using dType = cutlass::half_t; - using dTypeCompute = cutlass::half_t; - - dType *input_ptr = get_ptr(input); - dType *permuted_output_ptr = get_ptr(permuted_output); - - moe_permute_topK_kernel_launcher( + moe_permute_topK_kernel_launcher( input_ptr, permuted_output_ptr, sorted_row_id_ptr, @@ -118,13 +109,7 @@ std::tuple> moe_permute_topK_op( #ifdef ENABLE_BF16 case at::ScalarType::BFloat16: { - using dType = cutlass::bfloat16_t; - using dTypeCompute = cutlass::bfloat16_t; - - dType *input_ptr = get_ptr(input); - dType *permuted_output_ptr = get_ptr(permuted_output); - - moe_permute_topK_kernel_launcher( + moe_permute_topK_kernel_launcher<__nv_bfloat16, true, 8>( input_ptr, permuted_output_ptr, sorted_row_id_ptr, @@ -142,13 +127,7 @@ std::tuple> moe_permute_topK_op( #ifdef ENABLE_FP8 case at::ScalarType::Float8_e5m2: { - using dType = cutlass::float_e5m2_t; - using dTypeCompute = cutlass::half_t; - - dType *input_ptr = get_ptr(input); - dType *permuted_output_ptr = get_ptr(permuted_output); - - 
moe_permute_topK_kernel_launcher( + moe_permute_topK_kernel_launcher<__nv_fp8_e5m2, true, 16>( input_ptr, permuted_output_ptr, sorted_row_id_ptr, @@ -164,13 +143,7 @@ std::tuple> moe_permute_topK_op( } case at::ScalarType::Float8_e4m3fn: { - using dType = cutlass::float_e4m3_t; - using dTypeCompute = cutlass::half_t; - - dType *input_ptr = get_ptr(input); - dType *permuted_output_ptr = get_ptr(permuted_output); - - moe_permute_topK_kernel_launcher( + moe_permute_topK_kernel_launcher<__nv_fp8_e4m3, true, 16>( input_ptr, permuted_output_ptr, sorted_row_id_ptr, @@ -213,17 +186,14 @@ Tensor moe_recover_topK_op( float *prob_ptr = (prob.defined()) ? get_ptr(prob) : nullptr; auto stream = at::cuda::getCurrentCUDAStream().stream(); + void *input_ptr = getDataPtr(input, 0); + void *unpermuted_output_ptr = getDataPtr(unpermuted_output, 0); + switch (_st) { case at::ScalarType::Float: { - using dType = float; - using dTypeCompute = float; - - dType *input_ptr = get_ptr(input); - dType *unpermuted_output_ptr = get_ptr(unpermuted_output); - - moe_permute_topK_kernel_launcher( + moe_permute_topK_kernel_launcher( input_ptr, unpermuted_output_ptr, nullptr, @@ -239,13 +209,7 @@ Tensor moe_recover_topK_op( } case at::ScalarType::Half: { - using dType = cutlass::half_t; - using dTypeCompute = cutlass::half_t; - - dType *input_ptr = get_ptr(input); - dType *unpermuted_output_ptr = get_ptr(unpermuted_output); - - moe_permute_topK_kernel_launcher( + moe_permute_topK_kernel_launcher( input_ptr, unpermuted_output_ptr, nullptr, @@ -262,13 +226,7 @@ Tensor moe_recover_topK_op( #ifdef ENABLE_BF16 case at::ScalarType::BFloat16: { - using dType = cutlass::bfloat16_t; - using dTypeCompute = cutlass::bfloat16_t; - - dType *input_ptr = get_ptr(input); - dType *unpermuted_output_ptr = get_ptr(unpermuted_output); - - moe_permute_topK_kernel_launcher( + moe_permute_topK_kernel_launcher<__nv_bfloat16, false, 8>( input_ptr, unpermuted_output_ptr, nullptr, @@ -286,13 +244,7 @@ Tensor 
moe_recover_topK_op( #ifdef ENABLE_FP8 case at::ScalarType::Float8_e5m2: { - using dType = cutlass::float_e5m2_t; - using dTypeCompute = cutlass::half_t; - - dType *input_ptr = get_ptr(input); - dType *unpermuted_output_ptr = get_ptr(unpermuted_output); - - moe_permute_topK_kernel_launcher( + moe_permute_topK_kernel_launcher<__nv_fp8_e5m2, false, 16>( input_ptr, unpermuted_output_ptr, nullptr, @@ -308,13 +260,7 @@ Tensor moe_recover_topK_op( } case at::ScalarType::Float8_e4m3fn: { - using dType = cutlass::float_e4m3_t; - using dTypeCompute = cutlass::half_t; - - dType *input_ptr = get_ptr(input); - dType *unpermuted_output_ptr = get_ptr(unpermuted_output); - - moe_permute_topK_kernel_launcher( + moe_permute_topK_kernel_launcher<__nv_fp8_e4m3, false, 16>( input_ptr, unpermuted_output_ptr, nullptr, @@ -361,18 +307,15 @@ std::tuple moe_recover_topK_bwd_op( auto stream = at::cuda::getCurrentCUDAStream().stream(); + void *input_bwd_ptr = getDataPtr(input_bwd, 0); + void *input_fwd_ptr = getDataPtr(input_fwd, 0); + void *act_grad_ptr = getDataPtr(act_grad, 0); + switch (_st) { case at::ScalarType::Float: { - using dType = float; - using dTypeCompute = float; - - dType *input_bwd_ptr = get_ptr(input_bwd); - dType *input_fwd_ptr = get_ptr(input_fwd); - dType *act_grad_ptr = get_ptr(act_grad); - - moe_permute_topK_kernel_launcher( + moe_permute_topK_kernel_launcher( input_bwd_ptr, act_grad_ptr, nullptr, @@ -390,14 +333,7 @@ std::tuple moe_recover_topK_bwd_op( } case at::ScalarType::Half: { - using dType = cutlass::half_t; - using dTypeCompute = cutlass::half_t; - - dType *input_bwd_ptr = get_ptr(input_bwd); - dType *input_fwd_ptr = get_ptr(input_fwd); - dType *act_grad_ptr = get_ptr(act_grad); - - moe_permute_topK_kernel_launcher( + moe_permute_topK_kernel_launcher( input_bwd_ptr, act_grad_ptr, nullptr, @@ -416,14 +352,7 @@ std::tuple moe_recover_topK_bwd_op( #ifdef ENABLE_BF16 case at::ScalarType::BFloat16: { - using dType = cutlass::bfloat16_t; - using dTypeCompute = 
cutlass::bfloat16_t; - - dType *input_bwd_ptr = get_ptr(input_bwd); - dType *input_fwd_ptr = get_ptr(input_fwd); - dType *act_grad_ptr = get_ptr(act_grad); - - moe_permute_topK_kernel_launcher( + moe_permute_topK_kernel_launcher<__nv_bfloat16, true, 8>( input_bwd_ptr, act_grad_ptr, nullptr, @@ -443,14 +372,7 @@ std::tuple moe_recover_topK_bwd_op( #ifdef ENABLE_FP8 case at::ScalarType::Float8_e5m2: { - using dType = cutlass::float_e5m2_t; - using dTypeCompute = cutlass::half_t; - - dType *input_bwd_ptr = get_ptr(input_bwd); - dType *input_fwd_ptr = get_ptr(input_fwd); - dType *act_grad_ptr = get_ptr(act_grad); - - moe_permute_topK_kernel_launcher( + moe_permute_topK_kernel_launcher<__nv_fp8_e5m2, true, 16>( input_bwd_ptr, act_grad_ptr, nullptr, @@ -468,14 +390,7 @@ std::tuple moe_recover_topK_bwd_op( } case at::ScalarType::Float8_e4m3fn: { - using dType = cutlass::float_e4m3_t; - using dTypeCompute = cutlass::half_t; - - dType *input_bwd_ptr = get_ptr(input_bwd); - dType *input_fwd_ptr = get_ptr(input_fwd); - dType *act_grad_ptr = get_ptr(act_grad); - - moe_permute_topK_kernel_launcher( + moe_permute_topK_kernel_launcher<__nv_fp8_e4m3, true, 16>( input_bwd_ptr, act_grad_ptr, nullptr, From 5f881d9f1d5cd1239fc4a7ea7c537d7508f52f65 Mon Sep 17 00:00:00 2001 From: Jiang Shao Date: Mon, 17 Jun 2024 13:42:33 +0000 Subject: [PATCH 04/33] Add permutation ops Signed-off-by: Jiang Shao --- .../pytorch/module/permutation.py | 203 ++++++++++++++++++ 1 file changed, 203 insertions(+) create mode 100644 transformer_engine/pytorch/module/permutation.py diff --git a/transformer_engine/pytorch/module/permutation.py b/transformer_engine/pytorch/module/permutation.py new file mode 100644 index 0000000000..e11c084e20 --- /dev/null +++ b/transformer_engine/pytorch/module/permutation.py @@ -0,0 +1,203 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +"""Linear API""" +import os +import torch +import warnings + +# TODO by Jiang Shao, add parameter `out` which can be optionally given to be used as output buffers. + +################################################################################################ +## +## PermuteMoE topK +## +################################################################################################ + +class PermuteMoE_topK(torch.autograd.Function): + + workspace_fw=None + dtype=None + max_expanded_token_num=0 + + @staticmethod + def forward(ctx, + input_act: torch.Tensor, + indices: torch.Tensor, + num_out_tokens: int, + max_token_num: int): + # Empty input check + if not input_act.numel(): + return input_act, None + + # Device check + if input_act.is_cpu: + raise RuntimeError("[Error] The input `input_act` of permute_topK op is on the device: CPU!") + if indices.is_cpu: + warnings.warn("[Warning] The input `indices` of permute_topK op is on the device: CPU!") + expert_for_rows = expert_for_rows.cuda() + + # Shape check + if input_act.size(0) != indices.size(0): + raise RuntimeError(f"[Error] permute_topK op input `indices` shape mismatch! " + f"Expect {input_act.size(0)}, but got {indices.size(0)}.") + + # Data type check + if indices.dtype != torch.int32: + warnings.warn(f"[Warning] The data type of the input `indices` of permute_topK op is {indices.dtype}! 
" + "The recommended type is torch.int32.") + indices = indices.to(torch.int32) + + # Contiguous check + if not input_act.is_contiguous(): + warnings.warn("[Warning] The input `input_act` of permute_topK op is discontiguous!") + input_act = input_act.contiguous() + if not indices.is_contiguous(): + warnings.warn("[Warning] The input `indices` of permute_topK op is discontiguous!") + indices = indices.contiguous() + + num_topK = indices.size(1) + + input_max_expanded_token_num = max(max_token_num, input_act.size(0)) * num_topK + if PermuteMoE_topK.max_expanded_token_num < input_max_expanded_token_num: + PermuteMoE_topK.max_expanded_token_num = input_max_expanded_token_num + PermuteMoE_topK.workspace_fw = [] + + if PermuteMoE_topK.dtype != input_act.dtype: + PermuteMoE_topK.dtype = input_act.dtype + PermuteMoE_topK.workspace_fw = [] + + permuted_act, row_id_map, PermuteMoE_topK.workspace_fw = torch.ops.moe_unit_ops.moe_permute_topK_op( + input_act, + indices, + num_out_tokens, + PermuteMoE_topK.workspace_fw, + PermuteMoE_topK.max_expanded_token_num) + + ctx.row_id_map = row_id_map + ctx.num_tokens = indices.size(0) + ctx.num_topK = indices.size(1) + return permuted_act, row_id_map + + + @staticmethod + def backward(ctx, permuted_act_grad, _): + # Empty input check + if not permuted_act_grad.numel(): + return permuted_act_grad, None, None, None + + if not permuted_act_grad.is_contiguous(): + permuted_act_grad = permuted_act_grad.contiguous() + + row_id_map = ctx.row_id_map + num_tokens = ctx.num_tokens + num_topK = ctx.num_topK + + unpermuted_act_grad = torch.ops.moe_unit_ops.moe_recover_topK_op( + permuted_act_grad, + row_id_map, + None, + num_tokens, + num_topK) + return unpermuted_act_grad, None, None, None + +################################################################################################ +## +## UnpermuteMoE topK +## +################################################################################################ + +class 
UnpermuteMoE_topK(torch.autograd.Function): + + @staticmethod + def forward(ctx, + input_act: torch.Tensor, + row_id_map: torch.Tensor, + probs: torch.Tensor): + # Empty input check + if not input_act.numel(): + ctx.probs = probs + return input_act + + # None probs check + if probs is not None: + if probs.is_cpu: + warnings.warn("[Warning] The input `probs` of unpermute_topK op is on the device: CPU!") + probs = probs.cuda() + if probs.dtype != torch.float32: + warnings.warn(f"[Warning] The data type of the input `probs` of unpermute_topK op is {probs.dtype}! " + "The recommended type is torch.float32.") + probs = probs.to(torch.float32) + if not probs.is_contiguous(): + warnings.warn("[Warning] The input `probs` of unpermute_topK op is discontiguous!") + probs = probs.contiguous() + + # Device check + if input_act.is_cpu: + raise RuntimeError("[Error] The input `input_act` of unpermute_topK op is on the device: CPU!") + if row_id_map.is_cpu: + warnings.warn("[Warning] The input `row_id_map` of unpermute_topK op is on the device: CPU!") + row_id_map = row_id_map.cuda() + + # Data type check + if row_id_map.dtype != torch.int32: + warnings.warn(f"[Warning] The data type of the input `row_id_map` of unpermute_topK op is {row_id_map.dtype}! 
" + "The recommended type is torch.int32.") + row_id_map = row_id_map.to(torch.int32) + + # Contiguous check + if not input_act.is_contiguous(): + warnings.warn("[Warning] The input `input_act` of unpermute_topK op is discontiguous!") + input_act = input_act.contiguous() + if not row_id_map.is_contiguous(): + warnings.warn("[Warning] The input `row_id_map` of unpermute_topK op is discontiguous!") + row_id_map = row_id_map.contiguous() + + num_topK = probs.size(1) if probs is not None else 1 + num_tokens = probs.size(0) if probs is not None else row_id_map.size(0) + + unpermuted_output = torch.ops.moe_unit_ops.moe_recover_topK_op( + input_act, + row_id_map, + probs, + num_tokens, + num_topK) + + ctx.save_for_backward(input_act, row_id_map, probs) + return unpermuted_output + + @staticmethod + def backward(ctx, unpermuted_act_grad): + # Empty input check + if not unpermuted_act_grad.numel(): + return unpermuted_act_grad, None, ctx.probs + + if not unpermuted_act_grad.is_contiguous(): + unpermuted_act_grad = unpermuted_act_grad.contiguous() + + input_act, row_id_map, probs = ctx.saved_tensors + + act_grad = None + if ctx.needs_input_grad[0]: + act_grad, prob_grad = torch.ops.moe_unit_ops.moe_recover_topK_bwd_op( + unpermuted_act_grad, + input_act, + row_id_map, + probs) + + if not ctx.needs_input_grad[2]: + prob_grad = None + return act_grad, None, prob_grad + +################################################################################################ +## +## Ops Wrapper +## +################################################################################################ + +def permute(input_act, indices, num_out_tokens=-1, max_token_num=-1): + return PermuteMoE_topK.apply(input_act, indices, num_out_tokens, max_token_num) + +def unpermute(input_act, row_id_map, probs): + return UnpermuteMoE_topK.apply(input_act, row_id_map, probs) From 1ce7fc11256a4590135769f22df3e1b996983987 Mon Sep 17 00:00:00 2001 From: Jiang Shao Date: Mon, 17 Jun 2024 14:59:52 +0000 
Subject: [PATCH 05/33] Everything works fine Signed-off-by: Jiang Shao --- tests/pytorch/test_permutation.py | 348 ++++++++++++++++++ .../include/transformer_engine/permutation.h | 2 +- .../common/permutation/permutation.cu | 42 ++- transformer_engine/pytorch/__init__.py | 1 + .../pytorch/csrc/extensions/permutation.cu | 50 +-- transformer_engine/pytorch/module/__init__.py | 1 + .../pytorch/module/permutation.py | 18 +- 7 files changed, 418 insertions(+), 44 deletions(-) create mode 100644 tests/pytorch/test_permutation.py diff --git a/tests/pytorch/test_permutation.py b/tests/pytorch/test_permutation.py new file mode 100644 index 0000000000..e6367d733b --- /dev/null +++ b/tests/pytorch/test_permutation.py @@ -0,0 +1,348 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import torch +import triton +import torch.cuda.nvtx as nvtx + +from transformer_engine.pytorch import permute as permute_topK, unpermute as unpermute_topK + +def permute(tokens, indices, num_out_tokens: int = 0): + """Permute the tokens based on the indices. Token with the same index will be grouped together. + The input indices shape is [tokens, top_k], it indicates which experts were selected by each token separately. + Args: + tokens (torch.Tensor): The input token tensor. + indices (torch.Tensor): The token to expert indices tensor, should have a shape of [num_tokens, topk]. + topk (int, optional): The topk value. Defaults to 1. + num_out_tokens (int, optional): The effective token count, when enabling the capacity factor, should equal the number of tokens not dropped. By default, set to None, meaning no tokens are dropped. + + Returns: + torch.Tensor: The permuted tensor. + torch.Tensor: The sorted_indices corresponding permuted tensor. 
+ """ + + topk = indices.size(1) + flatten_indices = indices.view(-1) + sorted_indices = torch.argsort(flatten_indices, stable=True) + if num_out_tokens > 0: + sorted_indices = sorted_indices[:num_out_tokens] + permuted_tokens = tokens.index_select(0, sorted_indices // topk) + return permuted_tokens, sorted_indices + + +def unpermute( + permuted_tokens: torch.Tensor, + sorted_indices: torch.Tensor, + probs: torch.Tensor = torch.empty(0), +): + """Unpermute a tensor of permuted tokens based on sorted indices, and optionally merge the tokens with their corresponding probabilities. + + Args: + permuted_tokens (torch.Tensor): The tensor of permuted tokens to be unpermuted. + sorted_indices (torch.Tensor): The tensor of sorted indices used to unpermute the tokens. + probs (torch.Tensor, optional): The tensor of probabilities corresponding to the permuted tokens. If provided, the unpermuted tokens will be merged with their respective probabilities. + + Returns: + torch.Tensor: The unpermuted tokens, optionally merged with probabilities. 
+ """ + num_unpermuted_tokens = probs.numel() + topk = probs.size(1) + + unpermuted_tokens = torch.zeros( + [num_unpermuted_tokens, permuted_tokens.shape[-1]], + dtype=permuted_tokens.dtype, + device=permuted_tokens.device, + ) + unpermuted_tokens.index_copy_(0, sorted_indices, permuted_tokens) + unpermuted_tokens = unpermuted_tokens.reshape(-1, topk, permuted_tokens.size(-1)) + + if probs is not None: + unpermuted_tokens = unpermuted_tokens * probs.unsqueeze(-1) + unpermuted_tokens = unpermuted_tokens.sum(dim=1) + + return unpermuted_tokens + + +def permute_topK_test( + dtype, + num_token, + num_expert, + hidden_size, + num_topK, + num_out_tokens = None, + PRINT = False, + BENCHMARK = False): + + if num_out_tokens == None: + num_out_tokens = num_token * num_topK + + print(f"{dtype} token:{num_token} hidden_size:{hidden_size} expert:{num_expert} topK:{num_topK}") + + is_fp8 = dtype in [torch.float8_e5m2, torch.float8_e4m3fn] + + permute_input = torch.rand((num_token, hidden_size), dtype=torch.float32).cuda() + # for i in range(num_token): + # for j in range(hidden_size): + # permute_input[i][j] = i * 100 + j + permute_input = permute_input.to(dtype) + if is_fp8: + permute_input = permute_input.half() + + permute_input.requires_grad_(True) + + if num_token > 0: + indices = torch.stack([torch.randperm(num_expert)[:num_topK] for _ in range(num_token)]) + else: + indices = torch.empty((num_token, num_topK)) + indices = indices.to(torch.int32).cuda() + + # probs = torch.tensor([[0.1, 0.9], + # [0.2, 0.8], + # [0.3, 0.7]]) + # 0.5 + # probs = torch.ones_like(indices) / 2 + # rand + probs = torch.rand(num_token, num_topK).cuda() + row_sums = probs.sum(dim=1, keepdim=True) + probs = probs / row_sums + probs.requires_grad_(True) + + if PRINT: + print(permute_input) + print(indices) + print(probs) + + ################################################################################################################################### + # + # PyTorch + # + 
################################################################################################################################### + nvtx.range_push("PyTorch permute forward") + permute_output, sorted_indices = permute(permute_input, indices, num_out_tokens) + nvtx.range_pop() + + permute_bwd_input = torch.rand_like(permute_output) + # for i in range(num_token * num_topK): + # for j in range(hidden_size): + # permute_bwd_input[i][j] = i * 100 + j + + nvtx.range_push("PyTorch permute backward") + permute_output.backward(permute_bwd_input, retain_graph=True) + nvtx.range_pop() + + unpermute_input = permute_output.detach() + unpermute_input.requires_grad_(True) + + unpermute_output = unpermute( + unpermute_input, sorted_indices, probs=probs) + + if PRINT: + print("--------------unpermute fwd permute_input--------------") + print(unpermute_input) + print("--------------unpermute fwd output--------------") + print(unpermute_output) + + unpermute_bwd_input = torch.rand_like(unpermute_output) + # for i in range(num_token): + # for j in range(hidden_size): + # unpermute_bwd_input[i][j] = i * 2000 + j * 20 + + if PRINT: + print("--------------unpermute bwd permute_input--------------") + print(unpermute_bwd_input) + + unpermute_output.backward(unpermute_bwd_input, retain_graph=True) + if PRINT: + print("--------------unpermute bwd output act grad--------------") + print(permute_output.grad) + print("--------------unpermute bwd output probs grad--------------") + print(probs.grad) + + ################################################################################################################################### + # + # Mine + # + ################################################################################################################################### + new_permute_input = permute_input.detach().to(dtype) + new_permute_bwd_input = permute_bwd_input.detach().to(dtype) + new_unpermute_bwd_input = unpermute_bwd_input.detach().to(dtype) + 
new_permute_input.requires_grad_(True) + + new_permute_output, row_id_map = permute_topK(new_permute_input, indices, num_out_tokens) + + assert torch.allclose(permute_output.float(), new_permute_output.float()) + + if PRINT: + print("--------------row_id_map--------------") + print(row_id_map) + print("--------------new_permute_input--------------") + print(new_permute_input) + print("--------------new_permute_output--------------") + print(new_permute_output) + + new_permute_output.backward(new_permute_bwd_input, retain_graph=True) + + if torch.allclose(permute_input.grad.float(), new_permute_input.grad.float()) == False: + original_inputs = new_permute_input.grad.float().cpu().numpy().flatten() + original_output = permute_input.grad.float().cpu().numpy().flatten() + max_abs_error = abs(original_inputs - original_output).max() + print(f"permute_topK bwd max error (mine vs pytorch): \t\t\t{max_abs_error:.3e} ({dtype})") + + if PRINT: + print(permute_input.grad) + print(new_permute_input.grad) + + new_probs = probs.detach() + new_probs.requires_grad_(True) + if num_topK == 1: + new_probs = torch.empty(0) + new_unpermute_input = new_permute_output.detach() + new_unpermute_input.requires_grad_(True) + + new_unpermute_output = unpermute_topK(new_unpermute_input, row_id_map, new_probs) + + if torch.allclose(unpermute_output.float(), new_unpermute_output.float()) == False: + original_inputs = unpermute_output.float().cpu().detach().numpy().flatten() + original_output = new_unpermute_output.float().cpu().detach().numpy().flatten() + max_abs_error = abs(original_inputs - original_output).max() + print(f"unpermute_topK fwd max error (mine vs pytorch): \t\t{max_abs_error:.3e} ({dtype})") + + if PRINT: + print(unpermute_output) + print(new_unpermute_output) + + new_unpermute_output.backward(new_unpermute_bwd_input, retain_graph=True) + + if torch.allclose(unpermute_input.grad.float(), new_unpermute_input.grad.float()) == False: + original_inputs = 
unpermute_input.grad.float().cpu().detach().numpy().flatten() + original_output = new_unpermute_input.grad.float().cpu().detach().numpy().flatten() + max_abs_error = abs(original_inputs - original_output).max() + print(f"unpermute_topK bwd act_grad max error (mine vs pytorch): \t{max_abs_error:.3e} ({dtype})") + if PRINT: + print(new_unpermute_input.grad) + print(unpermute_input.grad) + + if num_topK > 1 and torch.allclose(new_probs.grad, probs.grad) == False: + original_inputs = new_probs.grad.float().cpu().detach().numpy().flatten() + original_output = probs.grad.float().cpu().detach().numpy().flatten() + max_abs_error = abs(original_inputs - original_output).max() + print(f"unpermute_topK bwd prob_grad max error (mine vs pytorch): \t{max_abs_error:.3e} ({dtype})") + if PRINT: + print(new_probs.grad) + print(probs.grad) + + if not permute_input.numel(): + print("Empty permute_input activation test passed.") + return + + ################################################################################################################################### + # + # Benchmark + # + ################################################################################################################################### + def backward_wrapper(act, backward_input, forward_input=[], retain_graph=True, accumulate_grad=False): + # Set forward_input.grad to None to avoid grad accumulation. 
+ if accumulate_grad == False: + for i in forward_input: + i.grad = None + return act.backward(backward_input, retain_graph=retain_graph) + + if BENCHMARK: + print(f"----permute topK----") + t = perf_test_cuda_kernel(lambda: permute(permute_input, indices, num_out_tokens)) + print(f"pytorch fwd: {t:.3f} ms") + t = perf_test_cuda_kernel(lambda: permute_topK(new_permute_input, indices, num_out_tokens)) + print(f"new fwd: {t:.3f} ms") + + t = perf_test_cuda_kernel( + lambda: backward_wrapper(permute_output, permute_bwd_input, forward_input=[permute_input], retain_graph=True, accumulate_grad=False)) + print(f"pytorch bwd: {t:.3f} ms") + t = perf_test_cuda_kernel( + lambda: backward_wrapper(new_permute_output, new_permute_bwd_input, forward_input=[new_permute_input], retain_graph=True, accumulate_grad=False)) + print(f"new bwd: {t:.3f} ms") + + print(f"----unpermute topK----") + t = perf_test_cuda_kernel( + lambda: unpermute(unpermute_input, sorted_indices, probs=probs)) + print(f"pytorch fwd: {t:.3f} ms") + t = perf_test_cuda_kernel( + lambda: unpermute_topK(new_unpermute_input, row_id_map, new_probs)) + print(f"new fwd: {t:.3f} ms") + + t = perf_test_cuda_kernel( + lambda: backward_wrapper(unpermute_output, unpermute_bwd_input, forward_input=[unpermute_input, probs], retain_graph=True, accumulate_grad=False)) + print(f"pytorch bwd: {t:.3f} ms") + t = perf_test_cuda_kernel( + lambda: backward_wrapper(new_unpermute_output, new_unpermute_bwd_input, forward_input=[new_unpermute_input, new_probs], retain_graph=True, accumulate_grad=False)) + print(f"new bwd: {t:.3f} ms") + + +def perf_test_cuda_kernel(cuda_kernel_fn): + if torch.cuda.is_available(): + # create CUDA event + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + # warmup + for _ in range(50): + cuda_kernel_fn() + + start_event.record() + for _ in range(100): + cuda_kernel_fn() + end_event.record() + torch.cuda.synchronize() + + elapsed_time_ms = 
start_event.elapsed_time(end_event) + # print(f"Elapsed Time: {elapsed_time_ms / 100} ms") + return elapsed_time_ms / 100 + else: + print("CUDA is not available.") + +def test_permute_topK(): + + torch.manual_seed(1) + + num_token = 4096 * 2 + num_expert = 8 + hidden_size = 4096 + num_topK = 1 + + num_out_tokens = num_token * num_topK - 20 + # num_out_tokens = 0 + + Benchmark = False + print("GPU:", torch.cuda.get_device_name(0)) + + dtype = torch.float32 + permute_topK_test(dtype, num_token, num_expert, + hidden_size, num_topK, num_out_tokens, + False, Benchmark) + dtype = torch.float16 + permute_topK_test(dtype, num_token, num_expert, + hidden_size, num_topK, num_out_tokens, + False, Benchmark) + dtype = torch.bfloat16 + permute_topK_test(dtype, num_token, num_expert, + hidden_size, num_topK, num_out_tokens, + False, Benchmark) + dtype = torch.float8_e5m2 + permute_topK_test(dtype, num_token, num_expert, + hidden_size, num_topK, num_out_tokens, + False, Benchmark) + dtype = torch.float8_e4m3fn + permute_topK_test(dtype, num_token, num_expert, + hidden_size, num_topK, num_out_tokens, + False, Benchmark) + dtype = torch.bfloat16 + permute_topK_test(dtype, num_token, 4, hidden_size, 1, None, False, Benchmark) + permute_topK_test(dtype, num_token, 5, hidden_size, 2, None, False, Benchmark) + permute_topK_test(dtype, num_token, 6, hidden_size, 3, None, False, Benchmark) + permute_topK_test(dtype, num_token, 7, hidden_size, 4, None, False, Benchmark) + permute_topK_test(dtype, num_token, 8, hidden_size, 5, None, False, Benchmark) + num_token = 0 + permute_topK_test(dtype, num_token, 8, hidden_size, 5, None, False, Benchmark) + +if __name__ == "__main__": + test_permute_topK() \ No newline at end of file diff --git a/transformer_engine/common/include/transformer_engine/permutation.h b/transformer_engine/common/include/transformer_engine/permutation.h index 574333b4c5..387bb0817f 100644 --- a/transformer_engine/common/include/transformer_engine/permutation.h +++ 
b/transformer_engine/common/include/transformer_engine/permutation.h @@ -9,7 +9,7 @@ #include "transformer_engine.h" -template +template void moe_permute_topK_kernel_launcher(const void *input, void *output, const int *sorted_row_id, diff --git a/transformer_engine/common/permutation/permutation.cu b/transformer_engine/common/permutation/permutation.cu index 121bc2c219..50e579bd8d 100644 --- a/transformer_engine/common/permutation/permutation.cu +++ b/transformer_engine/common/permutation/permutation.cu @@ -236,10 +236,10 @@ __global__ void moe_permute_topK_kernel(const T *input_bwd, } } -template +template void moe_permute_topK_kernel_launcher( - const void *input, - void *output, + const void *input_, + void *output_, const int *sorted_row_id, int *row_id_map, const float *prob, @@ -249,7 +249,7 @@ void moe_permute_topK_kernel_launcher( const int num_out_tokens, cudaStream_t stream, float *prob_grad, - const void *input_fwd) + const void *input_fwd_) { // Convert to cutlass type using T_fp16 = typename cutlass::platform::conditional< @@ -271,6 +271,12 @@ void moe_permute_topK_kernel_launcher( cutlass::platform::is_same::value), cutlass::half_t, T>::type; + static constexpr int kElementsPerAccess = 128 / cutlass::sizeof_bits::value; + + const T* input = reinterpret_cast(input_); + T* output = reinterpret_cast(output_); + const T* input_fwd = reinterpret_cast(input_fwd_); + if (FWD) { if (prob == nullptr) @@ -429,3 +435,31 @@ void moe_permute_topK_kernel_launcher( } } } + + + +#define FUNCTION_INSTANTIATION(T, FWD) \ +template void moe_permute_topK_kernel_launcher( \ + const void *input, \ + void *output, \ + const int *sorted_row_id, \ + int *row_id_map, \ + const float *prob, \ + const int num_rows, \ + const int num_topK, \ + const int num_cols, \ + const int num_out_tokens, \ + cudaStream_t stream, \ + float *prob_grad, \ + const void *input_fwd); + +FUNCTION_INSTANTIATION(float, true) +FUNCTION_INSTANTIATION(float, false) +FUNCTION_INSTANTIATION(half, true) 
+FUNCTION_INSTANTIATION(half, false) +FUNCTION_INSTANTIATION(__nv_bfloat16, true) +FUNCTION_INSTANTIATION(__nv_bfloat16, false) +FUNCTION_INSTANTIATION(__nv_fp8_e5m2, true) +FUNCTION_INSTANTIATION(__nv_fp8_e5m2, false) +FUNCTION_INSTANTIATION(__nv_fp8_e4m3, true) +FUNCTION_INSTANTIATION(__nv_fp8_e4m3, false) \ No newline at end of file diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index 30cb450976..d93e5ff389 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -36,6 +36,7 @@ def _load_library(): from transformer_engine.pytorch.module import Linear from transformer_engine.pytorch.module import LayerNormMLP from transformer_engine.pytorch.module import LayerNorm +from transformer_engine.pytorch.module import permute, unpermute from transformer_engine.pytorch.module import RMSNorm from transformer_engine.pytorch.module import GroupedLinear from transformer_engine.pytorch.attention import DotProductAttention diff --git a/transformer_engine/pytorch/csrc/extensions/permutation.cu b/transformer_engine/pytorch/csrc/extensions/permutation.cu index 16651594dc..d0deefd8e8 100644 --- a/transformer_engine/pytorch/csrc/extensions/permutation.cu +++ b/transformer_engine/pytorch/csrc/extensions/permutation.cu @@ -76,7 +76,7 @@ std::tuple> moe_permute_topK_op( { case at::ScalarType::Float: { - moe_permute_topK_kernel_launcher( + moe_permute_topK_kernel_launcher( input_ptr, permuted_output_ptr, sorted_row_id_ptr, @@ -92,7 +92,7 @@ std::tuple> moe_permute_topK_op( } case at::ScalarType::Half: { - moe_permute_topK_kernel_launcher( + moe_permute_topK_kernel_launcher( input_ptr, permuted_output_ptr, sorted_row_id_ptr, @@ -106,10 +106,9 @@ std::tuple> moe_permute_topK_op( break; } -#ifdef ENABLE_BF16 case at::ScalarType::BFloat16: { - moe_permute_topK_kernel_launcher<__nv_bfloat16, true, 8>( + moe_permute_topK_kernel_launcher<__nv_bfloat16, true>( input_ptr, permuted_output_ptr, 
sorted_row_id_ptr, @@ -123,11 +122,9 @@ std::tuple> moe_permute_topK_op( break; } -#endif -#ifdef ENABLE_FP8 case at::ScalarType::Float8_e5m2: { - moe_permute_topK_kernel_launcher<__nv_fp8_e5m2, true, 16>( + moe_permute_topK_kernel_launcher<__nv_fp8_e5m2, true>( input_ptr, permuted_output_ptr, sorted_row_id_ptr, @@ -143,7 +140,7 @@ std::tuple> moe_permute_topK_op( } case at::ScalarType::Float8_e4m3fn: { - moe_permute_topK_kernel_launcher<__nv_fp8_e4m3, true, 16>( + moe_permute_topK_kernel_launcher<__nv_fp8_e4m3, true>( input_ptr, permuted_output_ptr, sorted_row_id_ptr, @@ -157,7 +154,6 @@ std::tuple> moe_permute_topK_op( break; } -#endif default: throw std::runtime_error("Wrong activation tensor type."); } @@ -183,7 +179,7 @@ Tensor moe_recover_topK_op( torch::empty({num_tokens, num_cols}, torch::dtype(_st).device(torch::kCUDA).requires_grad(false)); int *row_id_map_ptr = get_ptr(row_id_map); - float *prob_ptr = (prob.defined()) ? get_ptr(prob) : nullptr; + float *prob_ptr = (prob.numel() > 0) ? 
get_ptr(prob) : nullptr; auto stream = at::cuda::getCurrentCUDAStream().stream(); void *input_ptr = getDataPtr(input, 0); @@ -193,7 +189,7 @@ Tensor moe_recover_topK_op( { case at::ScalarType::Float: { - moe_permute_topK_kernel_launcher( + moe_permute_topK_kernel_launcher( input_ptr, unpermuted_output_ptr, nullptr, @@ -209,7 +205,7 @@ Tensor moe_recover_topK_op( } case at::ScalarType::Half: { - moe_permute_topK_kernel_launcher( + moe_permute_topK_kernel_launcher( input_ptr, unpermuted_output_ptr, nullptr, @@ -223,10 +219,9 @@ Tensor moe_recover_topK_op( break; } -#ifdef ENABLE_BF16 case at::ScalarType::BFloat16: { - moe_permute_topK_kernel_launcher<__nv_bfloat16, false, 8>( + moe_permute_topK_kernel_launcher<__nv_bfloat16, false>( input_ptr, unpermuted_output_ptr, nullptr, @@ -240,11 +235,9 @@ Tensor moe_recover_topK_op( break; } -#endif -#ifdef ENABLE_FP8 case at::ScalarType::Float8_e5m2: { - moe_permute_topK_kernel_launcher<__nv_fp8_e5m2, false, 16>( + moe_permute_topK_kernel_launcher<__nv_fp8_e5m2, false>( input_ptr, unpermuted_output_ptr, nullptr, @@ -260,7 +253,7 @@ Tensor moe_recover_topK_op( } case at::ScalarType::Float8_e4m3fn: { - moe_permute_topK_kernel_launcher<__nv_fp8_e4m3, false, 16>( + moe_permute_topK_kernel_launcher<__nv_fp8_e4m3, false>( input_ptr, unpermuted_output_ptr, nullptr, @@ -274,7 +267,6 @@ Tensor moe_recover_topK_op( break; } -#endif default: throw std::runtime_error("Wrong activation tensor type."); } @@ -288,12 +280,12 @@ std::tuple moe_recover_topK_bwd_op( Tensor row_id_map, Tensor prob) { - const int num_topK = (prob.defined()) ? prob.size(1) : 1; - const int num_tokens = (prob.defined()) ? prob.size(0) : row_id_map.size(0); + const int num_topK = (prob.numel() > 0) ? prob.size(1) : 1; + const int num_tokens = (prob.numel() > 0) ? prob.size(0) : row_id_map.size(0); const int num_cols = input_bwd.size(1); int *row_id_map_ptr = get_ptr(row_id_map); - float *prob_ptr = (prob.defined()) ? 
get_ptr(prob) : nullptr; + float *prob_ptr = (prob.numel() > 0) ? get_ptr(prob) : nullptr; // activations type const at::ScalarType _st = input_bwd.scalar_type(); @@ -315,7 +307,7 @@ std::tuple moe_recover_topK_bwd_op( { case at::ScalarType::Float: { - moe_permute_topK_kernel_launcher( + moe_permute_topK_kernel_launcher( input_bwd_ptr, act_grad_ptr, nullptr, @@ -333,7 +325,7 @@ std::tuple moe_recover_topK_bwd_op( } case at::ScalarType::Half: { - moe_permute_topK_kernel_launcher( + moe_permute_topK_kernel_launcher( input_bwd_ptr, act_grad_ptr, nullptr, @@ -349,10 +341,9 @@ std::tuple moe_recover_topK_bwd_op( break; } -#ifdef ENABLE_BF16 case at::ScalarType::BFloat16: { - moe_permute_topK_kernel_launcher<__nv_bfloat16, true, 8>( + moe_permute_topK_kernel_launcher<__nv_bfloat16, true>( input_bwd_ptr, act_grad_ptr, nullptr, @@ -368,11 +359,9 @@ std::tuple moe_recover_topK_bwd_op( break; } -#endif -#ifdef ENABLE_FP8 case at::ScalarType::Float8_e5m2: { - moe_permute_topK_kernel_launcher<__nv_fp8_e5m2, true, 16>( + moe_permute_topK_kernel_launcher<__nv_fp8_e5m2, true>( input_bwd_ptr, act_grad_ptr, nullptr, @@ -390,7 +379,7 @@ std::tuple moe_recover_topK_bwd_op( } case at::ScalarType::Float8_e4m3fn: { - moe_permute_topK_kernel_launcher<__nv_fp8_e4m3, true, 16>( + moe_permute_topK_kernel_launcher<__nv_fp8_e4m3, true>( input_bwd_ptr, act_grad_ptr, nullptr, @@ -406,7 +395,6 @@ std::tuple moe_recover_topK_bwd_op( break; } -#endif default: throw std::runtime_error("Wrong activation tensor type."); } diff --git a/transformer_engine/pytorch/module/__init__.py b/transformer_engine/pytorch/module/__init__.py index 6994f586b1..74af24348a 100644 --- a/transformer_engine/pytorch/module/__init__.py +++ b/transformer_engine/pytorch/module/__init__.py @@ -8,5 +8,6 @@ from .grouped_linear import GroupedLinear from .layernorm_mlp import LayerNormMLP from .layernorm import LayerNorm +from .permutation import permute, unpermute from .rmsnorm import RMSNorm from .base import initialize_ub, 
destroy_ub diff --git a/transformer_engine/pytorch/module/permutation.py b/transformer_engine/pytorch/module/permutation.py index e11c084e20..ad69883560 100644 --- a/transformer_engine/pytorch/module/permutation.py +++ b/transformer_engine/pytorch/module/permutation.py @@ -7,6 +7,8 @@ import torch import warnings +import transformer_engine_torch as tex + # TODO by Jiang Shao, add parameter `out` which can be optionally given to be used as output buffers. ################################################################################################ @@ -68,7 +70,7 @@ def forward(ctx, PermuteMoE_topK.dtype = input_act.dtype PermuteMoE_topK.workspace_fw = [] - permuted_act, row_id_map, PermuteMoE_topK.workspace_fw = torch.ops.moe_unit_ops.moe_permute_topK_op( + permuted_act, row_id_map, PermuteMoE_topK.workspace_fw = tex.moe_permute_topK_op( input_act, indices, num_out_tokens, @@ -94,10 +96,10 @@ def backward(ctx, permuted_act_grad, _): num_tokens = ctx.num_tokens num_topK = ctx.num_topK - unpermuted_act_grad = torch.ops.moe_unit_ops.moe_recover_topK_op( + unpermuted_act_grad = tex.moe_recover_topK_op( permuted_act_grad, row_id_map, - None, + torch.empty(0), num_tokens, num_topK) return unpermuted_act_grad, None, None, None @@ -121,7 +123,7 @@ def forward(ctx, return input_act # None probs check - if probs is not None: + if probs.numel(): if probs.is_cpu: warnings.warn("[Warning] The input `probs` of unpermute_topK op is on the device: CPU!") probs = probs.cuda() @@ -154,10 +156,10 @@ def forward(ctx, warnings.warn("[Warning] The input `row_id_map` of unpermute_topK op is discontiguous!") row_id_map = row_id_map.contiguous() - num_topK = probs.size(1) if probs is not None else 1 - num_tokens = probs.size(0) if probs is not None else row_id_map.size(0) + num_topK = probs.size(1) if probs.numel() else 1 + num_tokens = probs.size(0) if probs.numel() else row_id_map.size(0) - unpermuted_output = torch.ops.moe_unit_ops.moe_recover_topK_op( + unpermuted_output = 
tex.moe_recover_topK_op( input_act, row_id_map, probs, @@ -180,7 +182,7 @@ def backward(ctx, unpermuted_act_grad): act_grad = None if ctx.needs_input_grad[0]: - act_grad, prob_grad = torch.ops.moe_unit_ops.moe_recover_topK_bwd_op( + act_grad, prob_grad = tex.moe_recover_topK_bwd_op( unpermuted_act_grad, input_act, row_id_map, From fb483f8b1c4133f50544d329f0e74b9379f375e2 Mon Sep 17 00:00:00 2001 From: Jiang Shao Date: Mon, 24 Jun 2024 11:44:33 +0000 Subject: [PATCH 06/33] Replace get_ptr with getDataPtr Signed-off-by: Jiang Shao --- transformer_engine/pytorch/csrc/common.h | 6 ----- .../pytorch/csrc/extensions/permutation.cu | 22 +++++++++---------- 2 files changed, 11 insertions(+), 17 deletions(-) diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index e3e9957357..7fb9953f94 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -167,10 +167,4 @@ at::Tensor allocateTorchTensor(int M, transformer_engine::DType dtype); void* getDataPtr(at::Tensor tensor, int offset = 0); -template -inline T *get_ptr(torch::Tensor &t) -{ - return reinterpret_cast(t.data_ptr()); -} - #endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_COMMON_H_ diff --git a/transformer_engine/pytorch/csrc/extensions/permutation.cu b/transformer_engine/pytorch/csrc/extensions/permutation.cu index d0deefd8e8..752ea76a62 100644 --- a/transformer_engine/pytorch/csrc/extensions/permutation.cu +++ b/transformer_engine/pytorch/csrc/extensions/permutation.cu @@ -44,12 +44,12 @@ std::tuple> moe_permute_topK_op( workspace.push_back(temp_storage); } - int *indices_ptr = get_ptr(indices); - int *sorted_indices_ptr = get_ptr(workspace[0]); - int *row_id_ptr = get_ptr(workspace[1]); - int *sorted_row_id_ptr = get_ptr(workspace[2]); + int *indices_ptr = reinterpret_cast(getDataPtr(indices, 0)); + int *sorted_indices_ptr = reinterpret_cast(getDataPtr(workspace[0], 0)); + int *row_id_ptr = reinterpret_cast(getDataPtr(workspace[1], 
0)); + int *sorted_row_id_ptr = reinterpret_cast(getDataPtr(workspace[2], 0)); - void *d_temp_storage = get_ptr(workspace[3]); + void *d_temp_storage = getDataPtr(workspace[3], 0); size_t temp_storage_bytes = std::numeric_limits::max(); cub::DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, @@ -66,7 +66,7 @@ std::tuple> moe_permute_topK_op( Tensor row_id_map = torch::empty({num_tokens * num_topK}, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false)); - int *row_id_map_ptr = get_ptr(row_id_map); + int *row_id_map_ptr = reinterpret_cast(getDataPtr(row_id_map, 0)); auto stream = at::cuda::getCurrentCUDAStream().stream(); void *input_ptr = getDataPtr(input, 0); @@ -178,8 +178,8 @@ Tensor moe_recover_topK_op( Tensor unpermuted_output = torch::empty({num_tokens, num_cols}, torch::dtype(_st).device(torch::kCUDA).requires_grad(false)); - int *row_id_map_ptr = get_ptr(row_id_map); - float *prob_ptr = (prob.numel() > 0) ? get_ptr(prob) : nullptr; + int *row_id_map_ptr = reinterpret_cast(getDataPtr(row_id_map, 0)); + float *prob_ptr = (prob.numel() > 0) ? reinterpret_cast(getDataPtr(prob, 0)) : nullptr; auto stream = at::cuda::getCurrentCUDAStream().stream(); void *input_ptr = getDataPtr(input, 0); @@ -284,8 +284,8 @@ std::tuple moe_recover_topK_bwd_op( const int num_tokens = (prob.numel() > 0) ? prob.size(0) : row_id_map.size(0); const int num_cols = input_bwd.size(1); - int *row_id_map_ptr = get_ptr(row_id_map); - float *prob_ptr = (prob.numel() > 0) ? get_ptr(prob) : nullptr; + int *row_id_map_ptr = reinterpret_cast(getDataPtr(row_id_map, 0)); + float *prob_ptr = (prob.numel() > 0) ? 
reinterpret_cast(getDataPtr(prob, 0)) : nullptr; // activations type const at::ScalarType _st = input_bwd.scalar_type(); @@ -295,7 +295,7 @@ std::tuple moe_recover_topK_bwd_op( torch::empty({input_fwd.size(0), num_cols}, torch::dtype(_st).device(torch::kCUDA).requires_grad(false)); Tensor prob_grad = torch::empty({num_tokens, num_topK}, torch::dtype(torch::kFloat32).device(torch::kCUDA).requires_grad(false)); - float *prob_grad_ptr = get_ptr(prob_grad); + float *prob_grad_ptr = reinterpret_cast(getDataPtr(prob_grad, 0)); auto stream = at::cuda::getCurrentCUDAStream().stream(); From c8f7991d377d874c16eb1650db148b089f13de3d Mon Sep 17 00:00:00 2001 From: Jiang Shao Date: Mon, 24 Jun 2024 13:53:24 +0000 Subject: [PATCH 07/33] Rename some functions and kernels Signed-off-by: Jiang Shao --- .../include/transformer_engine/permutation.h | 24 ++++---- .../common/permutation/permutation.cu | 58 +++++++++---------- transformer_engine/pytorch/csrc/extensions.h | 6 +- .../pytorch/csrc/extensions/permutation.cu | 36 ++++++------ .../pytorch/csrc/extensions/pybind.cpp | 6 +- .../pytorch/module/permutation.py | 25 ++------ 6 files changed, 69 insertions(+), 86 deletions(-) diff --git a/transformer_engine/common/include/transformer_engine/permutation.h b/transformer_engine/common/include/transformer_engine/permutation.h index 387bb0817f..ddd10f84db 100644 --- a/transformer_engine/common/include/transformer_engine/permutation.h +++ b/transformer_engine/common/include/transformer_engine/permutation.h @@ -10,17 +10,17 @@ #include "transformer_engine.h" template -void moe_permute_topK_kernel_launcher(const void *input, - void *output, - const int *sorted_row_id, - int *row_id_map, - const float *prob, - const int num_rows, - const int num_topK, - const int num_cols, - const int num_out_tokens, - cudaStream_t stream, - float *prob_grad = nullptr, - const void *input_fwd = nullptr); +void moe_permutation_launcher(const void *input, + void *output, + const int *sorted_row_id, + int 
*row_id_map, + const float *prob, + const int num_rows, + const int num_topK, + const int num_cols, + const int num_out_tokens, + cudaStream_t stream, + float *prob_grad = nullptr, + const void *input_fwd = nullptr); #endif // TRANSFORMER_ENGINE_PERMUTATION_H_ diff --git a/transformer_engine/common/permutation/permutation.cu b/transformer_engine/common/permutation/permutation.cu index 50e579bd8d..01a5e5765b 100644 --- a/transformer_engine/common/permutation/permutation.cu +++ b/transformer_engine/common/permutation/permutation.cu @@ -11,7 +11,7 @@ #include -static __global__ void moe_permute_topK_row_map( +static __global__ void moe_permute_row_map( const int *sorted_row_id, int *row_id_map, const int num_rows, @@ -42,13 +42,13 @@ static __global__ void moe_permute_topK_row_map( } template -__global__ void moe_recover_topK_kernel(const T *input, - T *unpermuted_output, - const int *row_id_map, - const float *prob, - const int num_rows, - const int num_topK, - const int num_cols) +__global__ void moe_unpermute_kernel(const T *input, + T *unpermuted_output, + const int *row_id_map, + const float *prob, + const int num_rows, + const int num_topK, + const int num_cols) { extern __shared__ int8_t s_mem[]; TCompute *s_prob = reinterpret_cast(s_mem); @@ -133,15 +133,15 @@ template -__global__ void moe_permute_topK_kernel(const T *input_bwd, - const T *input_fwd, - T *act_grad, - const float *prob, - float *prob_grad, - const int *row_id_map, - const int num_rows, - const int num_topK, - const int num_cols) +__global__ void moe_permute_kernel(const T *input_bwd, + const T *input_fwd, + T *act_grad, + const float *prob, + float *prob_grad, + const int *row_id_map, + const int num_rows, + const int num_topK, + const int num_cols) { extern __shared__ int8_t s_mem[]; TCompute *s_prob = reinterpret_cast(s_mem); @@ -237,7 +237,7 @@ __global__ void moe_permute_topK_kernel(const T *input_bwd, } template -void moe_permute_topK_kernel_launcher( +void moe_permutation_launcher( const 
void *input_, void *output_, const int *sorted_row_id, @@ -286,7 +286,7 @@ void moe_permute_topK_kernel_launcher( // permute_topK fwd int threads = 64; int blocks = (num_rows * num_topK + threads - 1) / threads; - moe_permute_topK_row_map<<>>( + moe_permute_row_map<<>>( sorted_row_id, row_id_map, num_rows, @@ -295,7 +295,7 @@ void moe_permute_topK_kernel_launcher( blocks = num_rows; threads = std::min(num_cols / kElementsPerAccess, 1024); - moe_permute_topK_kernel<<>>( + moe_permute_kernel<<>>( input, nullptr, output, @@ -312,7 +312,7 @@ void moe_permute_topK_kernel_launcher( int blocks = num_rows; int threads = 32; - moe_permute_topK_kernel<<>>( + moe_permute_kernel<<>>( input, input_fwd, output, @@ -333,7 +333,7 @@ void moe_permute_topK_kernel_launcher( if (num_topK <= 8) { - moe_permute_topK_kernel<<>>( + moe_permute_kernel<<>>( input, input_fwd, output, @@ -346,7 +346,7 @@ void moe_permute_topK_kernel_launcher( } else if (num_topK <= 16) { - moe_permute_topK_kernel<<>>( + moe_permute_kernel<<>>( input, input_fwd, output, @@ -359,7 +359,7 @@ void moe_permute_topK_kernel_launcher( } else if (num_topK <= 32) { - moe_permute_topK_kernel<<>>( + moe_permute_kernel<<>>( input, input_fwd, output, @@ -372,7 +372,7 @@ void moe_permute_topK_kernel_launcher( } else if (num_topK <= 64) { - moe_permute_topK_kernel<<>>( + moe_permute_kernel<<>>( input, input_fwd, output, @@ -385,7 +385,7 @@ void moe_permute_topK_kernel_launcher( } else if (num_topK <= 128) { - moe_permute_topK_kernel<<>>( + moe_permute_kernel<<>>( input, input_fwd, output, @@ -412,7 +412,7 @@ void moe_permute_topK_kernel_launcher( { // permute_topK bwd // unpermute_topK fwd without probs - moe_recover_topK_kernel<<>>( + moe_unpermute_kernel<<>>( input, output, row_id_map, @@ -424,7 +424,7 @@ void moe_permute_topK_kernel_launcher( else { // unpermute_topK fwd with probs - moe_recover_topK_kernel<<>>( + moe_unpermute_kernel<<>>( input, output, row_id_map, @@ -439,7 +439,7 @@ void 
moe_permute_topK_kernel_launcher( #define FUNCTION_INSTANTIATION(T, FWD) \ -template void moe_permute_topK_kernel_launcher( \ +template void moe_permutation_launcher( \ const void *input, \ void *output, \ const int *sorted_row_id, \ diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 9a670e16ab..5257971ebb 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -15,21 +15,21 @@ * permute **************************************************************************************************/ -std::tuple> moe_permute_topK_op( +std::tuple> moe_permute( at::Tensor input, at::Tensor indices, int64_t num_out_tokens, std::vector workspace, int64_t max_expanded_token_num); -at::Tensor moe_recover_topK_op( +at::Tensor moe_unpermute_fwd( at::Tensor input, at::Tensor row_id_map, at::Tensor prob, int64_t num_tokens, int64_t num_topK); -std::tuple moe_recover_topK_bwd_op( +std::tuple moe_unpermute_bwd( at::Tensor input_bwd, at::Tensor input_fwd, at::Tensor row_id_map, diff --git a/transformer_engine/pytorch/csrc/extensions/permutation.cu b/transformer_engine/pytorch/csrc/extensions/permutation.cu index 752ea76a62..cf43c145d2 100644 --- a/transformer_engine/pytorch/csrc/extensions/permutation.cu +++ b/transformer_engine/pytorch/csrc/extensions/permutation.cu @@ -10,7 +10,7 @@ using torch::Tensor; -std::tuple> moe_permute_topK_op( +std::tuple> moe_permute( Tensor input, Tensor indices, int64_t num_out_tokens, @@ -76,7 +76,7 @@ std::tuple> moe_permute_topK_op( { case at::ScalarType::Float: { - moe_permute_topK_kernel_launcher( + moe_permutation_launcher( input_ptr, permuted_output_ptr, sorted_row_id_ptr, @@ -92,7 +92,7 @@ std::tuple> moe_permute_topK_op( } case at::ScalarType::Half: { - moe_permute_topK_kernel_launcher( + moe_permutation_launcher( input_ptr, permuted_output_ptr, sorted_row_id_ptr, @@ -108,7 +108,7 @@ std::tuple> moe_permute_topK_op( } case 
at::ScalarType::BFloat16: { - moe_permute_topK_kernel_launcher<__nv_bfloat16, true>( + moe_permutation_launcher<__nv_bfloat16, true>( input_ptr, permuted_output_ptr, sorted_row_id_ptr, @@ -124,7 +124,7 @@ std::tuple> moe_permute_topK_op( } case at::ScalarType::Float8_e5m2: { - moe_permute_topK_kernel_launcher<__nv_fp8_e5m2, true>( + moe_permutation_launcher<__nv_fp8_e5m2, true>( input_ptr, permuted_output_ptr, sorted_row_id_ptr, @@ -140,7 +140,7 @@ std::tuple> moe_permute_topK_op( } case at::ScalarType::Float8_e4m3fn: { - moe_permute_topK_kernel_launcher<__nv_fp8_e4m3, true>( + moe_permutation_launcher<__nv_fp8_e4m3, true>( input_ptr, permuted_output_ptr, sorted_row_id_ptr, @@ -162,7 +162,7 @@ std::tuple> moe_permute_topK_op( } -Tensor moe_recover_topK_op( +Tensor moe_unpermute_fwd( Tensor input, Tensor row_id_map, Tensor prob, @@ -189,7 +189,7 @@ Tensor moe_recover_topK_op( { case at::ScalarType::Float: { - moe_permute_topK_kernel_launcher( + moe_permutation_launcher( input_ptr, unpermuted_output_ptr, nullptr, @@ -205,7 +205,7 @@ Tensor moe_recover_topK_op( } case at::ScalarType::Half: { - moe_permute_topK_kernel_launcher( + moe_permutation_launcher( input_ptr, unpermuted_output_ptr, nullptr, @@ -221,7 +221,7 @@ Tensor moe_recover_topK_op( } case at::ScalarType::BFloat16: { - moe_permute_topK_kernel_launcher<__nv_bfloat16, false>( + moe_permutation_launcher<__nv_bfloat16, false>( input_ptr, unpermuted_output_ptr, nullptr, @@ -237,7 +237,7 @@ Tensor moe_recover_topK_op( } case at::ScalarType::Float8_e5m2: { - moe_permute_topK_kernel_launcher<__nv_fp8_e5m2, false>( + moe_permutation_launcher<__nv_fp8_e5m2, false>( input_ptr, unpermuted_output_ptr, nullptr, @@ -253,7 +253,7 @@ Tensor moe_recover_topK_op( } case at::ScalarType::Float8_e4m3fn: { - moe_permute_topK_kernel_launcher<__nv_fp8_e4m3, false>( + moe_permutation_launcher<__nv_fp8_e4m3, false>( input_ptr, unpermuted_output_ptr, nullptr, @@ -274,7 +274,7 @@ Tensor moe_recover_topK_op( return unpermuted_output; } 
-std::tuple moe_recover_topK_bwd_op( +std::tuple moe_unpermute_bwd( Tensor input_bwd, Tensor input_fwd, Tensor row_id_map, @@ -307,7 +307,7 @@ std::tuple moe_recover_topK_bwd_op( { case at::ScalarType::Float: { - moe_permute_topK_kernel_launcher( + moe_permutation_launcher( input_bwd_ptr, act_grad_ptr, nullptr, @@ -325,7 +325,7 @@ std::tuple moe_recover_topK_bwd_op( } case at::ScalarType::Half: { - moe_permute_topK_kernel_launcher( + moe_permutation_launcher( input_bwd_ptr, act_grad_ptr, nullptr, @@ -343,7 +343,7 @@ std::tuple moe_recover_topK_bwd_op( } case at::ScalarType::BFloat16: { - moe_permute_topK_kernel_launcher<__nv_bfloat16, true>( + moe_permutation_launcher<__nv_bfloat16, true>( input_bwd_ptr, act_grad_ptr, nullptr, @@ -361,7 +361,7 @@ std::tuple moe_recover_topK_bwd_op( } case at::ScalarType::Float8_e5m2: { - moe_permute_topK_kernel_launcher<__nv_fp8_e5m2, true>( + moe_permutation_launcher<__nv_fp8_e5m2, true>( input_bwd_ptr, act_grad_ptr, nullptr, @@ -379,7 +379,7 @@ std::tuple moe_recover_topK_bwd_op( } case at::ScalarType::Float8_e4m3fn: { - moe_permute_topK_kernel_launcher<__nv_fp8_e4m3, true>( + moe_permutation_launcher<__nv_fp8_e4m3, true>( input_bwd_ptr, act_grad_ptr, nullptr, diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index d85a4cb5a2..aae323c5f6 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -11,9 +11,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // Permutation functions - m.def("moe_permute_topK_op", moe_permute_topK_op); - m.def("moe_recover_topK_op", moe_recover_topK_op); - m.def("moe_recover_topK_bwd_op", moe_recover_topK_bwd_op); + m.def("moe_permute", moe_permute); + m.def("moe_unpermute_fwd", moe_unpermute_fwd); + m.def("moe_unpermute_bwd", moe_unpermute_bwd); // Softmax functions m.def("scaled_softmax_forward", &scaled_softmax_forward, "Scaled Softmax FWD", diff --git 
a/transformer_engine/pytorch/module/permutation.py b/transformer_engine/pytorch/module/permutation.py index ad69883560..ffff097222 100644 --- a/transformer_engine/pytorch/module/permutation.py +++ b/transformer_engine/pytorch/module/permutation.py @@ -9,13 +9,6 @@ import transformer_engine_torch as tex -# TODO by Jiang Shao, add parameter `out` which can be optionally given to be used as output buffers. - -################################################################################################ -## -## PermuteMoE topK -## -################################################################################################ class PermuteMoE_topK(torch.autograd.Function): @@ -70,7 +63,7 @@ def forward(ctx, PermuteMoE_topK.dtype = input_act.dtype PermuteMoE_topK.workspace_fw = [] - permuted_act, row_id_map, PermuteMoE_topK.workspace_fw = tex.moe_permute_topK_op( + permuted_act, row_id_map, PermuteMoE_topK.workspace_fw = tex.moe_permute( input_act, indices, num_out_tokens, @@ -96,7 +89,7 @@ def backward(ctx, permuted_act_grad, _): num_tokens = ctx.num_tokens num_topK = ctx.num_topK - unpermuted_act_grad = tex.moe_recover_topK_op( + unpermuted_act_grad = tex.moe_unpermute_fwd( permuted_act_grad, row_id_map, torch.empty(0), @@ -104,11 +97,6 @@ def backward(ctx, permuted_act_grad, _): num_topK) return unpermuted_act_grad, None, None, None -################################################################################################ -## -## UnpermuteMoE topK -## -################################################################################################ class UnpermuteMoE_topK(torch.autograd.Function): @@ -159,7 +147,7 @@ def forward(ctx, num_topK = probs.size(1) if probs.numel() else 1 num_tokens = probs.size(0) if probs.numel() else row_id_map.size(0) - unpermuted_output = tex.moe_recover_topK_op( + unpermuted_output = tex.moe_unpermute_fwd( input_act, row_id_map, probs, @@ -182,7 +170,7 @@ def backward(ctx, unpermuted_act_grad): act_grad = None if 
ctx.needs_input_grad[0]: - act_grad, prob_grad = tex.moe_recover_topK_bwd_op( + act_grad, prob_grad = tex.moe_unpermute_bwd( unpermuted_act_grad, input_act, row_id_map, @@ -192,11 +180,6 @@ def backward(ctx, unpermuted_act_grad): prob_grad = None return act_grad, None, prob_grad -################################################################################################ -## -## Ops Wrapper -## -################################################################################################ def permute(input_act, indices, num_out_tokens=-1, max_token_num=-1): return PermuteMoE_topK.apply(input_act, indices, num_out_tokens, max_token_num) From 6b6eb39fc6f63501f52cbe35922d985f75d96aab Mon Sep 17 00:00:00 2001 From: Jiang Shao Date: Mon, 24 Jun 2024 15:02:10 +0000 Subject: [PATCH 08/33] Refactor to fit the TE style. Part I Signed-off-by: Jiang Shao --- tests/pytorch/test_permutation.py | 2 +- transformer_engine/pytorch/__init__.py | 2 +- transformer_engine/pytorch/module/__init__.py | 2 +- .../pytorch/module/permutation.py | 86 ++++++++++--------- 4 files changed, 49 insertions(+), 43 deletions(-) diff --git a/tests/pytorch/test_permutation.py b/tests/pytorch/test_permutation.py index e6367d733b..6a00fb5736 100644 --- a/tests/pytorch/test_permutation.py +++ b/tests/pytorch/test_permutation.py @@ -6,7 +6,7 @@ import triton import torch.cuda.nvtx as nvtx -from transformer_engine.pytorch import permute as permute_topK, unpermute as unpermute_topK +from transformer_engine.pytorch import Permute as permute_topK, Unpermute as unpermute_topK def permute(tokens, indices, num_out_tokens: int = 0): """Permute the tokens based on the indices. Token with the same index will be grouped together. 
diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index d93e5ff389..5d43ddbcd5 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -36,7 +36,7 @@ def _load_library(): from transformer_engine.pytorch.module import Linear from transformer_engine.pytorch.module import LayerNormMLP from transformer_engine.pytorch.module import LayerNorm -from transformer_engine.pytorch.module import permute, unpermute +from transformer_engine.pytorch.module import Permute, Unpermute from transformer_engine.pytorch.module import RMSNorm from transformer_engine.pytorch.module import GroupedLinear from transformer_engine.pytorch.attention import DotProductAttention diff --git a/transformer_engine/pytorch/module/__init__.py b/transformer_engine/pytorch/module/__init__.py index 74af24348a..253ff48f66 100644 --- a/transformer_engine/pytorch/module/__init__.py +++ b/transformer_engine/pytorch/module/__init__.py @@ -8,6 +8,6 @@ from .grouped_linear import GroupedLinear from .layernorm_mlp import LayerNormMLP from .layernorm import LayerNorm -from .permutation import permute, unpermute +from .permutation import Permute, Unpermute from .rmsnorm import RMSNorm from .base import initialize_ub, destroy_ub diff --git a/transformer_engine/pytorch/module/permutation.py b/transformer_engine/pytorch/module/permutation.py index ffff097222..6c7fe5dab3 100644 --- a/transformer_engine/pytorch/module/permutation.py +++ b/transformer_engine/pytorch/module/permutation.py @@ -10,33 +10,39 @@ import transformer_engine_torch as tex -class PermuteMoE_topK(torch.autograd.Function): +__all__ = [ + 'Permute', + 'Unpermute', +] - workspace_fw=None + +class _Permute(torch.autograd.Function): + + workspace=None dtype=None max_expanded_token_num=0 @staticmethod def forward(ctx, - input_act: torch.Tensor, + inp: torch.Tensor, indices: torch.Tensor, num_out_tokens: int, max_token_num: int): # Empty input check - if not 
input_act.numel(): - return input_act, None + if not inp.numel(): + return inp, None # Device check - if input_act.is_cpu: - raise RuntimeError("[Error] The input `input_act` of permute_topK op is on the device: CPU!") + if inp.is_cpu: + raise RuntimeError("[Error] The input `inp` of permute_topK op is on the device: CPU!") if indices.is_cpu: warnings.warn("[Warning] The input `indices` of permute_topK op is on the device: CPU!") expert_for_rows = expert_for_rows.cuda() # Shape check - if input_act.size(0) != indices.size(0): + if inp.size(0) != indices.size(0): raise RuntimeError(f"[Error] permute_topK op input `indices` shape mismatch! " - f"Expect {input_act.size(0)}, but got {indices.size(0)}.") + f"Expect {inp.size(0)}, but got {indices.size(0)}.") # Data type check if indices.dtype != torch.int32: @@ -45,30 +51,30 @@ def forward(ctx, indices = indices.to(torch.int32) # Contiguous check - if not input_act.is_contiguous(): - warnings.warn("[Warning] The input `input_act` of permute_topK op is discontiguous!") - input_act = input_act.contiguous() + if not inp.is_contiguous(): + warnings.warn("[Warning] The input `inp` of permute_topK op is discontiguous!") + inp = inp.contiguous() if not indices.is_contiguous(): warnings.warn("[Warning] The input `indices` of permute_topK op is discontiguous!") indices = indices.contiguous() num_topK = indices.size(1) - input_max_expanded_token_num = max(max_token_num, input_act.size(0)) * num_topK - if PermuteMoE_topK.max_expanded_token_num < input_max_expanded_token_num: - PermuteMoE_topK.max_expanded_token_num = input_max_expanded_token_num - PermuteMoE_topK.workspace_fw = [] + input_max_expanded_token_num = max(max_token_num, inp.size(0)) * num_topK + if _Permute.max_expanded_token_num < input_max_expanded_token_num: + _Permute.max_expanded_token_num = input_max_expanded_token_num + _Permute.workspace = [] - if PermuteMoE_topK.dtype != input_act.dtype: - PermuteMoE_topK.dtype = input_act.dtype - PermuteMoE_topK.workspace_fw 
= [] + if _Permute.dtype != inp.dtype: + _Permute.dtype = inp.dtype + _Permute.workspace = [] - permuted_act, row_id_map, PermuteMoE_topK.workspace_fw = tex.moe_permute( - input_act, + permuted_act, row_id_map, _Permute.workspace = tex.moe_permute( + inp, indices, num_out_tokens, - PermuteMoE_topK.workspace_fw, - PermuteMoE_topK.max_expanded_token_num) + _Permute.workspace, + _Permute.max_expanded_token_num) ctx.row_id_map = row_id_map ctx.num_tokens = indices.size(0) @@ -98,17 +104,17 @@ def backward(ctx, permuted_act_grad, _): return unpermuted_act_grad, None, None, None -class UnpermuteMoE_topK(torch.autograd.Function): +class _Unpermute(torch.autograd.Function): @staticmethod def forward(ctx, - input_act: torch.Tensor, + inp: torch.Tensor, row_id_map: torch.Tensor, probs: torch.Tensor): # Empty input check - if not input_act.numel(): + if not inp.numel(): ctx.probs = probs - return input_act + return inp # None probs check if probs.numel(): @@ -124,8 +130,8 @@ def forward(ctx, probs = probs.contiguous() # Device check - if input_act.is_cpu: - raise RuntimeError("[Error] The input `input_act` of unpermute_topK op is on the device: CPU!") + if inp.is_cpu: + raise RuntimeError("[Error] The input `inp` of unpermute_topK op is on the device: CPU!") if row_id_map.is_cpu: warnings.warn("[Warning] The input `row_id_map` of unpermute_topK op is on the device: CPU!") row_id_map = row_id_map.cuda() @@ -137,9 +143,9 @@ def forward(ctx, row_id_map = row_id_map.to(torch.int32) # Contiguous check - if not input_act.is_contiguous(): - warnings.warn("[Warning] The input `input_act` of unpermute_topK op is discontiguous!") - input_act = input_act.contiguous() + if not inp.is_contiguous(): + warnings.warn("[Warning] The input `inp` of unpermute_topK op is discontiguous!") + inp = inp.contiguous() if not row_id_map.is_contiguous(): warnings.warn("[Warning] The input `row_id_map` of unpermute_topK op is discontiguous!") row_id_map = row_id_map.contiguous() @@ -148,13 +154,13 @@ def 
forward(ctx, num_tokens = probs.size(0) if probs.numel() else row_id_map.size(0) unpermuted_output = tex.moe_unpermute_fwd( - input_act, + inp, row_id_map, probs, num_tokens, num_topK) - ctx.save_for_backward(input_act, row_id_map, probs) + ctx.save_for_backward(inp, row_id_map, probs) return unpermuted_output @staticmethod @@ -166,13 +172,13 @@ def backward(ctx, unpermuted_act_grad): if not unpermuted_act_grad.is_contiguous(): unpermuted_act_grad = unpermuted_act_grad.contiguous() - input_act, row_id_map, probs = ctx.saved_tensors + inp, row_id_map, probs = ctx.saved_tensors act_grad = None if ctx.needs_input_grad[0]: act_grad, prob_grad = tex.moe_unpermute_bwd( unpermuted_act_grad, - input_act, + inp, row_id_map, probs) @@ -181,8 +187,8 @@ def backward(ctx, unpermuted_act_grad): return act_grad, None, prob_grad -def permute(input_act, indices, num_out_tokens=-1, max_token_num=-1): - return PermuteMoE_topK.apply(input_act, indices, num_out_tokens, max_token_num) +def Permute(inp, indices, num_out_tokens=-1, max_token_num=-1): + return _Permute.apply(inp, indices, num_out_tokens, max_token_num) -def unpermute(input_act, row_id_map, probs): - return UnpermuteMoE_topK.apply(input_act, row_id_map, probs) +def Unpermute(inp, row_id_map, probs): + return _Unpermute.apply(inp, row_id_map, probs) From 7135612e5ae7a7b360f8cabfda88ef8f5a4d75ff Mon Sep 17 00:00:00 2001 From: Jiang Shao Date: Tue, 25 Jun 2024 08:18:05 +0000 Subject: [PATCH 09/33] Remove the dependency on cutlass Signed-off-by: Jiang Shao --- .gitmodules | 3 - 3rdparty/cutlass | 1 - .../include/transformer_engine/permutation.h | 2 +- .../common/permutation/permutation.cu | 160 ++++++++---------- 4 files changed, 76 insertions(+), 90 deletions(-) delete mode 160000 3rdparty/cutlass diff --git a/.gitmodules b/.gitmodules index 4b188d6bb1..21492db5ef 100644 --- a/.gitmodules +++ b/.gitmodules @@ -4,6 +4,3 @@ [submodule "3rdparty/cudnn-frontend"] path = 3rdparty/cudnn-frontend url = 
https://github.com/NVIDIA/cudnn-frontend.git -[submodule "3rdparty/cutlass"] - path = 3rdparty/cutlass - url = https://github.com/NVIDIA/cutlass.git diff --git a/3rdparty/cutlass b/3rdparty/cutlass deleted file mode 160000 index 7d49e6c7e2..0000000000 --- a/3rdparty/cutlass +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 7d49e6c7e2f8896c47f586706e67e1fb215529dc diff --git a/transformer_engine/common/include/transformer_engine/permutation.h b/transformer_engine/common/include/transformer_engine/permutation.h index ddd10f84db..c1c33900e6 100644 --- a/transformer_engine/common/include/transformer_engine/permutation.h +++ b/transformer_engine/common/include/transformer_engine/permutation.h @@ -9,7 +9,7 @@ #include "transformer_engine.h" -template +template void moe_permutation_launcher(const void *input, void *output, const int *sorted_row_id, diff --git a/transformer_engine/common/permutation/permutation.cu b/transformer_engine/common/permutation/permutation.cu index 01a5e5765b..e08025f3d3 100644 --- a/transformer_engine/common/permutation/permutation.cu +++ b/transformer_engine/common/permutation/permutation.cu @@ -4,13 +4,10 @@ * See LICENSE for license information. 
************************************************************************/ -#include "cutlass/arch/memory.h" -#include "cutlass/arch/cache_operation.h" -#include "cutlass/array.h" -#include "cutlass/numeric_conversion.h" - #include +#include "../common.h" + static __global__ void moe_permute_row_map( const int *sorted_row_id, int *row_id_map, @@ -41,7 +38,9 @@ static __global__ void moe_permute_row_map( } } -template +template __global__ void moe_unpermute_kernel(const T *input, T *unpermuted_output, const int *row_id_map, @@ -53,12 +52,6 @@ __global__ void moe_unpermute_kernel(const T *input, extern __shared__ int8_t s_mem[]; TCompute *s_prob = reinterpret_cast(s_mem); - using FragmentLoadStore = cutlass::Array; - using FragmentCompute = cutlass::Array; - - cutlass::NumericArrayConverter src_converter; - cutlass::NumericArrayConverter dst_converter; - // each block corresponds to one source token const int source_token = blockIdx.x; const int tid = threadIdx.x; @@ -72,11 +65,16 @@ __global__ void moe_unpermute_kernel(const T *input, __syncthreads(); } + float4 frag_load_store; + T *frag_load_store_ptr = reinterpret_cast(&frag_load_store); + + static constexpr int kElementsPerAccess = 16 / sizeof(T); + for (int i = tid * kElementsPerAccess; i < num_cols; i += blockDim.x * kElementsPerAccess) { - FragmentLoadStore frag_load_store; - FragmentCompute frag_elem; - FragmentCompute frag_sum; + + TCompute frag_elem[kElementsPerAccess]; + TCompute frag_sum[kElementsPerAccess]; int source_row = row_id_map[source_token]; @@ -84,18 +82,20 @@ __global__ void moe_unpermute_kernel(const T *input, { const T *source_row_ptr = input + source_row * num_cols; - cutlass::arch::global_load( - frag_load_store, (source_row_ptr + i), true); - frag_sum = src_converter(frag_load_store); + frag_load_store = __ldlu(reinterpret_cast(source_row_ptr + i)); - if (hasProb) - { - frag_sum = frag_sum * s_prob[0]; + for (int e = 0; e < kElementsPerAccess; e++) { + frag_sum[e] = 
TCompute(frag_load_store_ptr[e]); } + + if (hasProb) { + for (int e = 0; e < kElementsPerAccess; e++) { + frag_sum[e] = frag_sum[e] * s_prob[0]; } } } else { - frag_sum.clear(); + for (int e = 0; e < kElementsPerAccess; e++) { + frag_sum[e] = TCompute(0.0f); } } for (int k = 1; k < num_topK; k++) @@ -107,30 +107,31 @@ __global__ void moe_unpermute_kernel(const T *input, const T *source_row_ptr = input + source_row * num_cols; - cutlass::arch::global_load( - frag_load_store, (source_row_ptr + i), true); - frag_elem = src_converter(frag_load_store); + frag_load_store = __ldlu(reinterpret_cast(source_row_ptr + i)); - if (hasProb) - { - frag_elem = frag_elem * s_prob[k]; + for (int e = 0; e < kElementsPerAccess; e++) { + frag_elem[e] = TCompute(frag_load_store_ptr[e]); } + + if (hasProb) { + for (int e = 0; e < kElementsPerAccess; e++) { + frag_elem[e] = frag_elem[e] * s_prob[k]; } } - for (int e = 0; e < kElementsPerAccess; e++) - { - frag_sum.at(e) = frag_sum.at(e) + frag_elem.at(e); - } + for (int e = 0; e < kElementsPerAccess; e++) { + frag_sum[e] = frag_sum[e] + frag_elem[e]; } } T *dest_row_ptr = unpermuted_output + source_token * num_cols; - frag_load_store = dst_converter(frag_sum); - *(float4 *)(dest_row_ptr + i) = *(float4 *)(frag_load_store.data()); + + for (int e = 0; e < kElementsPerAccess; e++) { + frag_load_store_ptr[e] = T(frag_sum[e]); } + + *(float4 *)(dest_row_ptr + i) = frag_load_store; } } template __global__ void moe_permute_kernel(const T *input_bwd, @@ -146,12 +147,6 @@ __global__ void moe_permute_kernel(const T *input_bwd, extern __shared__ int8_t s_mem[]; TCompute *s_prob = reinterpret_cast(s_mem); - using FragmentLoadStore = cutlass::Array; - using FragmentCompute = cutlass::Array; - - cutlass::NumericArrayConverter src_converter; - cutlass::NumericArrayConverter dst_converter; - const int source_token = blockIdx.x; const int tid = threadIdx.x; @@ -165,14 +160,21 @@ __global__ void moe_permute_kernel(const T *input_bwd, } float 
accum[topKTile] = {0.0f}; - FragmentLoadStore frag_load_store; + + float4 frag_load_store; + T *frag_load_store_ptr = reinterpret_cast(&frag_load_store); + + static constexpr int kElementsPerAccess = 16 / sizeof(T); const T *source_row_ptr = input_bwd + source_token * num_cols; for (int i = tid * kElementsPerAccess; i < num_cols; i += blockDim.x * kElementsPerAccess) { - cutlass::arch::global_load( - frag_load_store, (source_row_ptr + i), true); - FragmentCompute frag_src = src_converter(frag_load_store); + TCompute frag_src[kElementsPerAccess]; + + frag_load_store = __ldlu(reinterpret_cast(source_row_ptr + i)); + + for (int e = 0; e < kElementsPerAccess; e++) + frag_src[e] = TCompute(frag_load_store_ptr[e]); int index = source_token; @@ -187,27 +189,30 @@ __global__ void moe_permute_kernel(const T *input_bwd, { if (hasProb) { - frag_load_store = dst_converter(frag_src * s_prob[k]); + for (int e = 0; e < kElementsPerAccess; e++) + frag_load_store_ptr[e] = T(frag_src[e] * s_prob[k]); } else { - frag_load_store = dst_converter(frag_src); + for (int e = 0; e < kElementsPerAccess; e++) + frag_load_store_ptr[e] = T(frag_src[e]); } T *dest_row_ptr = act_grad + dest_row * num_cols; - *(float4 *)(dest_row_ptr + i) = *(float4 *)(frag_load_store.data()); + *(float4 *)(dest_row_ptr + i) = frag_load_store; if (hasProb) { const T *input_fwd_ptr = input_fwd + dest_row * num_cols; - cutlass::arch::global_load( - frag_load_store, (input_fwd_ptr + i), true); - FragmentCompute frag_input_fwd = src_converter(frag_load_store); + frag_load_store = __ldlu(reinterpret_cast(input_fwd_ptr + i)); + + TCompute frag_input_fwd[kElementsPerAccess]; for (int e = 0; e < kElementsPerAccess; e++) - { - accum[k] += float(frag_src.at(e) * frag_input_fwd.at(e)); - } + frag_input_fwd[e] = TCompute(frag_load_store_ptr[e]); + + for (int e = 0; e < kElementsPerAccess; e++) { + accum[k] += float(frag_src[e] * frag_input_fwd[e]); } } } } @@ -236,7 +241,7 @@ __global__ void moe_permute_kernel(const T 
*input_bwd, } } -template +template void moe_permutation_launcher( const void *input_, void *output_, @@ -251,27 +256,12 @@ void moe_permutation_launcher( float *prob_grad, const void *input_fwd_) { - // Convert to cutlass type - using T_fp16 = typename cutlass::platform::conditional< - cutlass::platform::is_same::value, - cutlass::half_t, TInput>::type; - using T_bf16 = typename cutlass::platform::conditional< - cutlass::platform::is_same::value, - cutlass::bfloat16_t, T_fp16>::type; - using T_fp8e5m2 = typename cutlass::platform::conditional< - cutlass::platform::is_same::value, - cutlass::float_e5m2_t, T_bf16>::type; - using T_fp8e4m3 = typename cutlass::platform::conditional< - cutlass::platform::is_same::value, - cutlass::float_e4m3_t, T_fp8e5m2>::type; - using T = T_fp8e4m3; - - using TCompute = typename cutlass::platform::conditional< - (cutlass::platform::is_same::value || - cutlass::platform::is_same::value), - cutlass::half_t, T>::type; - - static constexpr int kElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using TCompute = typename std::conditional< + (std::is_same::value || + std::is_same::value), + half, T>::type; + + static constexpr int kElementsPerAccess = 16 / sizeof(T); const T* input = reinterpret_cast(input_); T* output = reinterpret_cast(output_); @@ -295,7 +285,7 @@ void moe_permutation_launcher( blocks = num_rows; threads = std::min(num_cols / kElementsPerAccess, 1024); - moe_permute_kernel<<>>( + moe_permute_kernel<<>>( input, nullptr, output, @@ -312,7 +302,7 @@ void moe_permutation_launcher( int blocks = num_rows; int threads = 32; - moe_permute_kernel<<>>( + moe_permute_kernel<<>>( input, input_fwd, output, @@ -333,7 +323,7 @@ void moe_permutation_launcher( if (num_topK <= 8) { - moe_permute_kernel<<>>( + moe_permute_kernel<<>>( input, input_fwd, output, @@ -346,7 +336,7 @@ void moe_permutation_launcher( } else if (num_topK <= 16) { - moe_permute_kernel<<>>( + moe_permute_kernel<<>>( input, input_fwd, output, @@ -359,7 +349,7 
@@ void moe_permutation_launcher( } else if (num_topK <= 32) { - moe_permute_kernel<<>>( + moe_permute_kernel<<>>( input, input_fwd, output, @@ -372,7 +362,7 @@ void moe_permutation_launcher( } else if (num_topK <= 64) { - moe_permute_kernel<<>>( + moe_permute_kernel<<>>( input, input_fwd, output, @@ -385,7 +375,7 @@ void moe_permutation_launcher( } else if (num_topK <= 128) { - moe_permute_kernel<<>>( + moe_permute_kernel<<>>( input, input_fwd, output, @@ -412,7 +402,7 @@ void moe_permutation_launcher( { // permute_topK bwd // unpermute_topK fwd without probs - moe_unpermute_kernel<<>>( + moe_unpermute_kernel<<>>( input, output, row_id_map, @@ -424,7 +414,7 @@ void moe_permutation_launcher( else { // unpermute_topK fwd with probs - moe_unpermute_kernel<<>>( + moe_unpermute_kernel<<>>( input, output, row_id_map, From 1f92e83e9deaa9d7c02f67c96b02c7beacdcdeb3 Mon Sep 17 00:00:00 2001 From: Jiang Shao Date: Tue, 25 Jun 2024 19:20:33 +0000 Subject: [PATCH 10/33] Refactor to fit the TE style. 
Part II Signed-off-by: Jiang Shao --- .../pytorch/module/permutation.py | 117 +++++++++++------- 1 file changed, 69 insertions(+), 48 deletions(-) diff --git a/transformer_engine/pytorch/module/permutation.py b/transformer_engine/pytorch/module/permutation.py index 6c7fe5dab3..beb5f0cbb4 100644 --- a/transformer_engine/pytorch/module/permutation.py +++ b/transformer_engine/pytorch/module/permutation.py @@ -6,6 +6,7 @@ import os import torch import warnings +from typing import Tuple import transformer_engine_torch as tex @@ -17,6 +18,7 @@ class _Permute(torch.autograd.Function): + """functional Permute""" workspace=None dtype=None @@ -27,37 +29,24 @@ def forward(ctx, inp: torch.Tensor, indices: torch.Tensor, num_out_tokens: int, - max_token_num: int): + max_token_num: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: # Empty input check if not inp.numel(): return inp, None # Device check - if inp.is_cpu: - raise RuntimeError("[Error] The input `inp` of permute_topK op is on the device: CPU!") - if indices.is_cpu: - warnings.warn("[Warning] The input `indices` of permute_topK op is on the device: CPU!") - expert_for_rows = expert_for_rows.cuda() - + assert inp.is_cuda, "TransformerEngine needs CUDA." + assert indices.is_cuda, "TransformerEngine needs CUDA." # Shape check - if inp.size(0) != indices.size(0): - raise RuntimeError(f"[Error] permute_topK op input `indices` shape mismatch! " - f"Expect {inp.size(0)}, but got {indices.size(0)}.") + assert inp.size(0) == indices.size(0), "Permute not possible" # Data type check if indices.dtype != torch.int32: - warnings.warn(f"[Warning] The data type of the input `indices` of permute_topK op is {indices.dtype}! " + warnings.warn(f"The data type of the input `indices` of Permute is {indices.dtype}! 
" "The recommended type is torch.int32.") indices = indices.to(torch.int32) - # Contiguous check - if not inp.is_contiguous(): - warnings.warn("[Warning] The input `inp` of permute_topK op is discontiguous!") - inp = inp.contiguous() - if not indices.is_contiguous(): - warnings.warn("[Warning] The input `indices` of permute_topK op is discontiguous!") - indices = indices.contiguous() - num_topK = indices.size(1) input_max_expanded_token_num = max(max_token_num, inp.size(0)) * num_topK @@ -83,7 +72,10 @@ def forward(ctx, @staticmethod - def backward(ctx, permuted_act_grad, _): + def backward(ctx, + permuted_act_grad: torch.Tensor, + _, + ) -> Tuple[torch.Tensor, ...]: # Empty input check if not permuted_act_grad.numel(): return permuted_act_grad, None, None, None @@ -105,12 +97,14 @@ def backward(ctx, permuted_act_grad, _): class _Unpermute(torch.autograd.Function): + """functional Unpermute""" @staticmethod def forward(ctx, inp: torch.Tensor, row_id_map: torch.Tensor, - probs: torch.Tensor): + probs: torch.Tensor = torch.empty(0), + ) -> torch.Tensor: # Empty input check if not inp.numel(): ctx.probs = probs @@ -118,38 +112,23 @@ def forward(ctx, # None probs check if probs.numel(): - if probs.is_cpu: - warnings.warn("[Warning] The input `probs` of unpermute_topK op is on the device: CPU!") - probs = probs.cuda() + assert probs.is_cuda, "TransformerEngine needs CUDA." + if probs.dtype != torch.float32: - warnings.warn(f"[Warning] The data type of the input `probs` of unpermute_topK op is {probs.dtype}! " + warnings.warn(f"The data type of the input `probs` of Unpermute is {probs.dtype}! 
" "The recommended type is torch.float32.") probs = probs.to(torch.float32) - if not probs.is_contiguous(): - warnings.warn("[Warning] The input `probs` of unpermute_topK op is discontiguous!") - probs = probs.contiguous() # Device check - if inp.is_cpu: - raise RuntimeError("[Error] The input `inp` of unpermute_topK op is on the device: CPU!") - if row_id_map.is_cpu: - warnings.warn("[Warning] The input `row_id_map` of unpermute_topK op is on the device: CPU!") - row_id_map = row_id_map.cuda() + assert inp.is_cuda, "TransformerEngine needs CUDA." + assert row_id_map.is_cuda, "TransformerEngine needs CUDA." # Data type check if row_id_map.dtype != torch.int32: - warnings.warn(f"[Warning] The data type of the input `row_id_map` of unpermute_topK op is {row_id_map.dtype}! " + warnings.warn(f"The data type of the input `row_id_map` of Unpermute is {row_id_map.dtype}! " "The recommended type is torch.int32.") row_id_map = row_id_map.to(torch.int32) - # Contiguous check - if not inp.is_contiguous(): - warnings.warn("[Warning] The input `inp` of unpermute_topK op is discontiguous!") - inp = inp.contiguous() - if not row_id_map.is_contiguous(): - warnings.warn("[Warning] The input `row_id_map` of unpermute_topK op is discontiguous!") - row_id_map = row_id_map.contiguous() - num_topK = probs.size(1) if probs.numel() else 1 num_tokens = probs.size(0) if probs.numel() else row_id_map.size(0) @@ -164,7 +143,9 @@ def forward(ctx, return unpermuted_output @staticmethod - def backward(ctx, unpermuted_act_grad): + def backward(ctx, + unpermuted_act_grad: torch.Tensor, + ) -> Tuple[torch.Tensor, None, torch.Tensor]: # Empty input check if not unpermuted_act_grad.numel(): return unpermuted_act_grad, None, ctx.probs @@ -187,8 +168,48 @@ def backward(ctx, unpermuted_act_grad): return act_grad, None, prob_grad -def Permute(inp, indices, num_out_tokens=-1, max_token_num=-1): - return _Permute.apply(inp, indices, num_out_tokens, max_token_num) - -def Unpermute(inp, row_id_map, probs): - 
return _Unpermute.apply(inp, row_id_map, probs) +def Permute(inp: torch.Tensor, + indices: torch.Tensor, + num_out_tokens: int = -1, + max_token_num: int = -1, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Permute the tokens based on the indices. Token with the same index will be grouped together. + + Parameters + ---------- + inp: torch.Tensor + Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied. + indices: torch.Tensor + The token to expert indices tensor of shape [num_tokens, topK] and dtype 'int32'. + num_out_tokens: int, default = -1 + The effective output token count, representing the number of tokens not dropped. + By default, set to '-1', meaning no tokens are dropped. + max_token_num: int, default = -1 + The maximum number of tokens, used for workspace allocation. + By default, set to '-1', meaning the calculation of the size of workspace is + automatically taken over by the operator. + """ + return _Permute.apply(inp, indices, num_out_tokens, max_token_num) + +def Unpermute(inp: torch.Tensor, + row_id_map: torch.Tensor, + probs: torch.Tensor = torch.empty(0), +) -> torch.Tensor: + """ + Unpermute a tensor with permuted tokens, and optionally merge the tokens with their + corresponding probabilities. + + Parameters + ---------- + inp: torch.Tensor + Input tensor with permuted tokens of shape `[num_tokens, hidden_size]` to be unpermuted. + row_id_map: torch.Tensor + The tensor of a mapping table for sorted indices used to unpermute the tokens, + which is the second output tensor of `Permute`. + probs: torch.Tensor + The tensor of probabilities corresponding to the permuted tokens. If provided, + the unpermuted tokens will be merged with their respective probabilities. + By default, set to an empty tensor, which means that the tokens are directly merged by accumulation. 
+ """ + return _Unpermute.apply(inp, row_id_map, probs) From 64986fb82789b009fc102942295d51fd1607a4d0 Mon Sep 17 00:00:00 2001 From: Jiang Shao Date: Tue, 25 Jun 2024 19:28:47 +0000 Subject: [PATCH 11/33] Move permutation.py out of module dir Signed-off-by: Jiang Shao --- transformer_engine/pytorch/__init__.py | 2 +- transformer_engine/pytorch/module/__init__.py | 1 - transformer_engine/pytorch/{module => }/permutation.py | 0 3 files changed, 1 insertion(+), 2 deletions(-) rename transformer_engine/pytorch/{module => }/permutation.py (100%) diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index 5d43ddbcd5..f89b9ed722 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -36,12 +36,12 @@ def _load_library(): from transformer_engine.pytorch.module import Linear from transformer_engine.pytorch.module import LayerNormMLP from transformer_engine.pytorch.module import LayerNorm -from transformer_engine.pytorch.module import Permute, Unpermute from transformer_engine.pytorch.module import RMSNorm from transformer_engine.pytorch.module import GroupedLinear from transformer_engine.pytorch.attention import DotProductAttention from transformer_engine.pytorch.attention import InferenceParams from transformer_engine.pytorch.attention import MultiheadAttention +from transformer_engine.pytorch.permutation import Permute, Unpermute from transformer_engine.pytorch.transformer import TransformerLayer from transformer_engine.pytorch.fp8 import fp8_autocast from transformer_engine.pytorch.fp8 import fp8_model_init diff --git a/transformer_engine/pytorch/module/__init__.py b/transformer_engine/pytorch/module/__init__.py index 253ff48f66..6994f586b1 100644 --- a/transformer_engine/pytorch/module/__init__.py +++ b/transformer_engine/pytorch/module/__init__.py @@ -8,6 +8,5 @@ from .grouped_linear import GroupedLinear from .layernorm_mlp import LayerNormMLP from .layernorm import LayerNorm -from 
.permutation import Permute, Unpermute from .rmsnorm import RMSNorm from .base import initialize_ub, destroy_ub diff --git a/transformer_engine/pytorch/module/permutation.py b/transformer_engine/pytorch/permutation.py similarity index 100% rename from transformer_engine/pytorch/module/permutation.py rename to transformer_engine/pytorch/permutation.py From 115931c66aeec755658d2da9185079551436c31e Mon Sep 17 00:00:00 2001 From: Jiang Shao Date: Tue, 25 Jun 2024 20:02:10 +0000 Subject: [PATCH 12/33] pre-commit reformat Signed-off-by: Jiang Shao --- tests/pytorch/test_permutation.py | 113 ++-- .../include/transformer_engine/permutation.h | 15 +- .../common/permutation/permutation.cu | 603 +++++++----------- transformer_engine/pytorch/csrc/extensions.h | 27 +- .../pytorch/csrc/extensions/permutation.cu | 580 ++++++----------- transformer_engine/pytorch/permutation.py | 311 +++++---- 6 files changed, 676 insertions(+), 973 deletions(-) diff --git a/tests/pytorch/test_permutation.py b/tests/pytorch/test_permutation.py index 6a00fb5736..82924e74ac 100644 --- a/tests/pytorch/test_permutation.py +++ b/tests/pytorch/test_permutation.py @@ -8,9 +8,10 @@ from transformer_engine.pytorch import Permute as permute_topK, Unpermute as unpermute_topK + def permute(tokens, indices, num_out_tokens: int = 0): """Permute the tokens based on the indices. Token with the same index will be grouped together. - The input indices shape is [tokens, top_k], it indicates which experts were selected by each token separately. + The input indices shape is [tokens, top_k], it indicates which experts were selected by each token separately. Args: tokens (torch.Tensor): The input token tensor. indices (torch.Tensor): The token to expert indices tensor, should have a shape of [num_tokens, topk]. 
@@ -70,14 +71,17 @@ def permute_topK_test( num_expert, hidden_size, num_topK, - num_out_tokens = None, - PRINT = False, - BENCHMARK = False): + num_out_tokens=None, + PRINT=False, + BENCHMARK=False, +): if num_out_tokens == None: num_out_tokens = num_token * num_topK - print(f"{dtype} token:{num_token} hidden_size:{hidden_size} expert:{num_expert} topK:{num_topK}") + print( + f"{dtype} token:{num_token} hidden_size:{hidden_size} expert:{num_expert} topK:{num_topK}" + ) is_fp8 = dtype in [torch.float8_e5m2, torch.float8_e4m3fn] @@ -90,7 +94,7 @@ def permute_topK_test( permute_input = permute_input.half() permute_input.requires_grad_(True) - + if num_token > 0: indices = torch.stack([torch.randperm(num_expert)[:num_topK] for _ in range(num_token)]) else: @@ -134,8 +138,7 @@ def permute_topK_test( unpermute_input = permute_output.detach() unpermute_input.requires_grad_(True) - unpermute_output = unpermute( - unpermute_input, sorted_indices, probs=probs) + unpermute_output = unpermute(unpermute_input, sorted_indices, probs=probs) if PRINT: print("--------------unpermute fwd permute_input--------------") @@ -218,7 +221,10 @@ def permute_topK_test( original_inputs = unpermute_input.grad.float().cpu().detach().numpy().flatten() original_output = new_unpermute_input.grad.float().cpu().detach().numpy().flatten() max_abs_error = abs(original_inputs - original_output).max() - print(f"unpermute_topK bwd act_grad max error (mine vs pytorch): \t{max_abs_error:.3e} ({dtype})") + print( + "unpermute_topK bwd act_grad max error (mine vs pytorch):" + f" \t{max_abs_error:.3e} ({dtype})" + ) if PRINT: print(new_unpermute_input.grad) print(unpermute_input.grad) @@ -227,21 +233,26 @@ def permute_topK_test( original_inputs = new_probs.grad.float().cpu().detach().numpy().flatten() original_output = probs.grad.float().cpu().detach().numpy().flatten() max_abs_error = abs(original_inputs - original_output).max() - print(f"unpermute_topK bwd prob_grad max error (mine vs pytorch): 
\t{max_abs_error:.3e} ({dtype})") + print( + "unpermute_topK bwd prob_grad max error (mine vs pytorch):" + f" \t{max_abs_error:.3e} ({dtype})" + ) if PRINT: print(new_probs.grad) print(probs.grad) if not permute_input.numel(): - print("Empty permute_input activation test passed.") - return + print("Empty permute_input activation test passed.") + return ################################################################################################################################### # # Benchmark # ################################################################################################################################### - def backward_wrapper(act, backward_input, forward_input=[], retain_graph=True, accumulate_grad=False): + def backward_wrapper( + act, backward_input, forward_input=[], retain_graph=True, accumulate_grad=False + ): # Set forward_input.grad to None to avoid grad accumulation. if accumulate_grad == False: for i in forward_input: @@ -256,25 +267,53 @@ def backward_wrapper(act, backward_input, forward_input=[], retain_graph=True, a print(f"new fwd: {t:.3f} ms") t = perf_test_cuda_kernel( - lambda: backward_wrapper(permute_output, permute_bwd_input, forward_input=[permute_input], retain_graph=True, accumulate_grad=False)) + lambda: backward_wrapper( + permute_output, + permute_bwd_input, + forward_input=[permute_input], + retain_graph=True, + accumulate_grad=False, + ) + ) print(f"pytorch bwd: {t:.3f} ms") t = perf_test_cuda_kernel( - lambda: backward_wrapper(new_permute_output, new_permute_bwd_input, forward_input=[new_permute_input], retain_graph=True, accumulate_grad=False)) + lambda: backward_wrapper( + new_permute_output, + new_permute_bwd_input, + forward_input=[new_permute_input], + retain_graph=True, + accumulate_grad=False, + ) + ) print(f"new bwd: {t:.3f} ms") print(f"----unpermute topK----") - t = perf_test_cuda_kernel( - lambda: unpermute(unpermute_input, sorted_indices, probs=probs)) + t = perf_test_cuda_kernel(lambda: 
unpermute(unpermute_input, sorted_indices, probs=probs)) print(f"pytorch fwd: {t:.3f} ms") t = perf_test_cuda_kernel( - lambda: unpermute_topK(new_unpermute_input, row_id_map, new_probs)) + lambda: unpermute_topK(new_unpermute_input, row_id_map, new_probs) + ) print(f"new fwd: {t:.3f} ms") t = perf_test_cuda_kernel( - lambda: backward_wrapper(unpermute_output, unpermute_bwd_input, forward_input=[unpermute_input, probs], retain_graph=True, accumulate_grad=False)) + lambda: backward_wrapper( + unpermute_output, + unpermute_bwd_input, + forward_input=[unpermute_input, probs], + retain_graph=True, + accumulate_grad=False, + ) + ) print(f"pytorch bwd: {t:.3f} ms") t = perf_test_cuda_kernel( - lambda: backward_wrapper(new_unpermute_output, new_unpermute_bwd_input, forward_input=[new_unpermute_input, new_probs], retain_graph=True, accumulate_grad=False)) + lambda: backward_wrapper( + new_unpermute_output, + new_unpermute_bwd_input, + forward_input=[new_unpermute_input, new_probs], + retain_graph=True, + accumulate_grad=False, + ) + ) print(f"new bwd: {t:.3f} ms") @@ -300,6 +339,7 @@ def perf_test_cuda_kernel(cuda_kernel_fn): else: print("CUDA is not available.") + def test_permute_topK(): torch.manual_seed(1) @@ -316,25 +356,25 @@ def test_permute_topK(): print("GPU:", torch.cuda.get_device_name(0)) dtype = torch.float32 - permute_topK_test(dtype, num_token, num_expert, - hidden_size, num_topK, num_out_tokens, - False, Benchmark) + permute_topK_test( + dtype, num_token, num_expert, hidden_size, num_topK, num_out_tokens, False, Benchmark + ) dtype = torch.float16 - permute_topK_test(dtype, num_token, num_expert, - hidden_size, num_topK, num_out_tokens, - False, Benchmark) + permute_topK_test( + dtype, num_token, num_expert, hidden_size, num_topK, num_out_tokens, False, Benchmark + ) dtype = torch.bfloat16 - permute_topK_test(dtype, num_token, num_expert, - hidden_size, num_topK, num_out_tokens, - False, Benchmark) + permute_topK_test( + dtype, num_token, num_expert, 
hidden_size, num_topK, num_out_tokens, False, Benchmark + ) dtype = torch.float8_e5m2 - permute_topK_test(dtype, num_token, num_expert, - hidden_size, num_topK, num_out_tokens, - False, Benchmark) + permute_topK_test( + dtype, num_token, num_expert, hidden_size, num_topK, num_out_tokens, False, Benchmark + ) dtype = torch.float8_e4m3fn - permute_topK_test(dtype, num_token, num_expert, - hidden_size, num_topK, num_out_tokens, - False, Benchmark) + permute_topK_test( + dtype, num_token, num_expert, hidden_size, num_topK, num_out_tokens, False, Benchmark + ) dtype = torch.bfloat16 permute_topK_test(dtype, num_token, 4, hidden_size, 1, None, False, Benchmark) permute_topK_test(dtype, num_token, 5, hidden_size, 2, None, False, Benchmark) @@ -344,5 +384,6 @@ def test_permute_topK(): num_token = 0 permute_topK_test(dtype, num_token, 8, hidden_size, 5, None, False, Benchmark) + if __name__ == "__main__": - test_permute_topK() \ No newline at end of file + test_permute_topK() diff --git a/transformer_engine/common/include/transformer_engine/permutation.h b/transformer_engine/common/include/transformer_engine/permutation.h index c1c33900e6..a2e4661883 100644 --- a/transformer_engine/common/include/transformer_engine/permutation.h +++ b/transformer_engine/common/include/transformer_engine/permutation.h @@ -10,17 +10,10 @@ #include "transformer_engine.h" template -void moe_permutation_launcher(const void *input, - void *output, - const int *sorted_row_id, - int *row_id_map, - const float *prob, - const int num_rows, - const int num_topK, - const int num_cols, - const int num_out_tokens, - cudaStream_t stream, - float *prob_grad = nullptr, +void moe_permutation_launcher(const void *input, void *output, const int *sorted_row_id, + int *row_id_map, const float *prob, const int num_rows, + const int num_topK, const int num_cols, const int num_out_tokens, + cudaStream_t stream, float *prob_grad = nullptr, const void *input_fwd = nullptr); #endif // TRANSFORMER_ENGINE_PERMUTATION_H_ 
diff --git a/transformer_engine/common/permutation/permutation.cu b/transformer_engine/common/permutation/permutation.cu index e08025f3d3..d0a63433c3 100644 --- a/transformer_engine/common/permutation/permutation.cu +++ b/transformer_engine/common/permutation/permutation.cu @@ -8,440 +8,281 @@ #include "../common.h" -static __global__ void moe_permute_row_map( - const int *sorted_row_id, - int *row_id_map, - const int num_rows, - const int num_topK, - const int num_out_tokens) -{ - // Each block corresponds to one source token - // row_id_map[num_topK][num_rows] - const int bid = blockIdx.x; - const int tid = threadIdx.x; - const int idx = bid * blockDim.x + tid; - - if (idx >= num_rows * num_topK) - return; - - int source_row = sorted_row_id[idx]; - int source_token_id = source_row / num_topK; - int source_topK_id = source_row % num_topK; - - if (idx >= num_out_tokens) - { - row_id_map[source_topK_id * num_rows + source_token_id] = -1; - } - else - { - row_id_map[source_topK_id * num_rows + source_token_id] = idx; - } +static __global__ void moe_permute_row_map(const int *sorted_row_id, int *row_id_map, + const int num_rows, const int num_topK, + const int num_out_tokens) { + // Each block corresponds to one source token + // row_id_map[num_topK][num_rows] + const int bid = blockIdx.x; + const int tid = threadIdx.x; + const int idx = bid * blockDim.x + tid; + + if (idx >= num_rows * num_topK) return; + + int source_row = sorted_row_id[idx]; + int source_token_id = source_row / num_topK; + int source_topK_id = source_row % num_topK; + + if (idx >= num_out_tokens) { + row_id_map[source_topK_id * num_rows + source_token_id] = -1; + } else { + row_id_map[source_topK_id * num_rows + source_token_id] = idx; + } } -template -__global__ void moe_unpermute_kernel(const T *input, - T *unpermuted_output, - const int *row_id_map, - const float *prob, - const int num_rows, - const int num_topK, - const int num_cols) -{ - extern __shared__ int8_t s_mem[]; - TCompute *s_prob = 
reinterpret_cast(s_mem); - - // each block corresponds to one source token - const int source_token = blockIdx.x; - const int tid = threadIdx.x; - - if (hasProb) - { - for (int i = tid; i < num_topK; i += blockDim.x * blockDim.y) - { - s_prob[i] = TCompute(prob[source_token * num_topK + i]); - } - __syncthreads(); +template +__global__ void moe_unpermute_kernel(const T *input, T *unpermuted_output, const int *row_id_map, + const float *prob, const int num_rows, const int num_topK, + const int num_cols) { + extern __shared__ int8_t s_mem[]; + TCompute *s_prob = reinterpret_cast(s_mem); + + // each block corresponds to one source token + const int source_token = blockIdx.x; + const int tid = threadIdx.x; + + if (hasProb) { + for (int i = tid; i < num_topK; i += blockDim.x * blockDim.y) { + s_prob[i] = TCompute(prob[source_token * num_topK + i]); } + __syncthreads(); + } - float4 frag_load_store; - T *frag_load_store_ptr = reinterpret_cast(&frag_load_store); + float4 frag_load_store; + T *frag_load_store_ptr = reinterpret_cast(&frag_load_store); - static constexpr int kElementsPerAccess = 16 / sizeof(T); + static constexpr int kElementsPerAccess = 16 / sizeof(T); - for (int i = tid * kElementsPerAccess; i < num_cols; i += blockDim.x * kElementsPerAccess) - { + for (int i = tid * kElementsPerAccess; i < num_cols; i += blockDim.x * kElementsPerAccess) { + TCompute frag_elem[kElementsPerAccess]; + TCompute frag_sum[kElementsPerAccess]; - TCompute frag_elem[kElementsPerAccess]; - TCompute frag_sum[kElementsPerAccess]; - - int source_row = row_id_map[source_token]; + int source_row = row_id_map[source_token]; - if (source_row != -1) - { - const T *source_row_ptr = input + source_row * num_cols; + if (source_row != -1) { + const T *source_row_ptr = input + source_row * num_cols; - frag_load_store = __ldlu(reinterpret_cast(source_row_ptr + i)); + frag_load_store = __ldlu(reinterpret_cast(source_row_ptr + i)); - for (int e = 0; e < kElementsPerAccess; e++) { - frag_sum[e] = 
TCompute(frag_load_store_ptr[e]); } + for (int e = 0; e < kElementsPerAccess; e++) { + frag_sum[e] = TCompute(frag_load_store_ptr[e]); + } - if (hasProb) { - for (int e = 0; e < kElementsPerAccess; e++) { - frag_sum[e] = frag_sum[e] * s_prob[0]; } - } - } - else - { - for (int e = 0; e < kElementsPerAccess; e++) { - frag_sum[e] = TCompute(0.0f); } + if (hasProb) { + for (int e = 0; e < kElementsPerAccess; e++) { + frag_sum[e] = frag_sum[e] * s_prob[0]; } + } + } else { + for (int e = 0; e < kElementsPerAccess; e++) { + frag_sum[e] = TCompute(0.0f); + } + } - for (int k = 1; k < num_topK; k++) - { - source_row = row_id_map[k * num_rows + source_token]; + for (int k = 1; k < num_topK; k++) { + source_row = row_id_map[k * num_rows + source_token]; - if (source_row == -1) - continue; + if (source_row == -1) continue; - const T *source_row_ptr = input + source_row * num_cols; + const T *source_row_ptr = input + source_row * num_cols; - frag_load_store = __ldlu(reinterpret_cast(source_row_ptr + i)); + frag_load_store = __ldlu(reinterpret_cast(source_row_ptr + i)); - for (int e = 0; e < kElementsPerAccess; e++) { - frag_elem[e] = TCompute(frag_load_store_ptr[e]); } + for (int e = 0; e < kElementsPerAccess; e++) { + frag_elem[e] = TCompute(frag_load_store_ptr[e]); + } - if (hasProb) { - for (int e = 0; e < kElementsPerAccess; e++) { - frag_elem[e] = frag_elem[e] * s_prob[k]; } - } - - for (int e = 0; e < kElementsPerAccess; e++) { - frag_sum[e] = frag_sum[e] + frag_elem[e]; } + if (hasProb) { + for (int e = 0; e < kElementsPerAccess; e++) { + frag_elem[e] = frag_elem[e] * s_prob[k]; } + } - T *dest_row_ptr = unpermuted_output + source_token * num_cols; + for (int e = 0; e < kElementsPerAccess; e++) { + frag_sum[e] = frag_sum[e] + frag_elem[e]; + } + } - for (int e = 0; e < kElementsPerAccess; e++) { - frag_load_store_ptr[e] = T(frag_sum[e]); } + T *dest_row_ptr = unpermuted_output + source_token * num_cols; - *(float4 *)(dest_row_ptr + i) = frag_load_store; + for (int e = 
0; e < kElementsPerAccess; e++) { + frag_load_store_ptr[e] = T(frag_sum[e]); } + + *(float4 *)(dest_row_ptr + i) = frag_load_store; + } } -template -__global__ void moe_permute_kernel(const T *input_bwd, - const T *input_fwd, - T *act_grad, - const float *prob, - float *prob_grad, - const int *row_id_map, - const int num_rows, - const int num_topK, - const int num_cols) -{ - extern __shared__ int8_t s_mem[]; - TCompute *s_prob = reinterpret_cast(s_mem); - - const int source_token = blockIdx.x; - const int tid = threadIdx.x; - - if (hasProb) - { - for (int i = tid; i < num_topK; i += blockDim.x) - { - s_prob[i] = TCompute(prob[source_token * num_topK + i]); - } - __syncthreads(); +template +__global__ void moe_permute_kernel(const T *input_bwd, const T *input_fwd, T *act_grad, + const float *prob, float *prob_grad, const int *row_id_map, + const int num_rows, const int num_topK, const int num_cols) { + extern __shared__ int8_t s_mem[]; + TCompute *s_prob = reinterpret_cast(s_mem); + + const int source_token = blockIdx.x; + const int tid = threadIdx.x; + + if (hasProb) { + for (int i = tid; i < num_topK; i += blockDim.x) { + s_prob[i] = TCompute(prob[source_token * num_topK + i]); } + __syncthreads(); + } - float accum[topKTile] = {0.0f}; + float accum[topKTile] = {0.0f}; - float4 frag_load_store; - T *frag_load_store_ptr = reinterpret_cast(&frag_load_store); + float4 frag_load_store; + T *frag_load_store_ptr = reinterpret_cast(&frag_load_store); - static constexpr int kElementsPerAccess = 16 / sizeof(T); + static constexpr int kElementsPerAccess = 16 / sizeof(T); - const T *source_row_ptr = input_bwd + source_token * num_cols; - for (int i = tid * kElementsPerAccess; i < num_cols; i += blockDim.x * kElementsPerAccess) - { - TCompute frag_src[kElementsPerAccess]; + const T *source_row_ptr = input_bwd + source_token * num_cols; + for (int i = tid * kElementsPerAccess; i < num_cols; i += blockDim.x * kElementsPerAccess) { + TCompute frag_src[kElementsPerAccess]; - 
frag_load_store = __ldlu(reinterpret_cast(source_row_ptr + i)); + frag_load_store = __ldlu(reinterpret_cast(source_row_ptr + i)); - for (int e = 0; e < kElementsPerAccess; e++) - frag_src[e] = TCompute(frag_load_store_ptr[e]); + for (int e = 0; e < kElementsPerAccess; e++) frag_src[e] = TCompute(frag_load_store_ptr[e]); - int index = source_token; + int index = source_token; - for (int k = 0; k < topKTile; k++) - { - if (k == num_topK) break; + for (int k = 0; k < topKTile; k++) { + if (k == num_topK) break; - int dest_row = row_id_map[index]; - index += num_rows; + int dest_row = row_id_map[index]; + index += num_rows; - if (dest_row != -1) - { - if (hasProb) - { - for (int e = 0; e < kElementsPerAccess; e++) - frag_load_store_ptr[e] = T(frag_src[e] * s_prob[k]); - } - else - { - for (int e = 0; e < kElementsPerAccess; e++) - frag_load_store_ptr[e] = T(frag_src[e]); - } + if (dest_row != -1) { + if (hasProb) { + for (int e = 0; e < kElementsPerAccess; e++) + frag_load_store_ptr[e] = T(frag_src[e] * s_prob[k]); + } else { + for (int e = 0; e < kElementsPerAccess; e++) frag_load_store_ptr[e] = T(frag_src[e]); + } - T *dest_row_ptr = act_grad + dest_row * num_cols; - *(float4 *)(dest_row_ptr + i) = frag_load_store; + T *dest_row_ptr = act_grad + dest_row * num_cols; + *(float4 *)(dest_row_ptr + i) = frag_load_store; - if (hasProb) - { - const T *input_fwd_ptr = input_fwd + dest_row * num_cols; + if (hasProb) { + const T *input_fwd_ptr = input_fwd + dest_row * num_cols; - frag_load_store = __ldlu(reinterpret_cast(input_fwd_ptr + i)); + frag_load_store = __ldlu(reinterpret_cast(input_fwd_ptr + i)); - TCompute frag_input_fwd[kElementsPerAccess]; - for (int e = 0; e < kElementsPerAccess; e++) - frag_input_fwd[e] = TCompute(frag_load_store_ptr[e]); + TCompute frag_input_fwd[kElementsPerAccess]; + for (int e = 0; e < kElementsPerAccess; e++) + frag_input_fwd[e] = TCompute(frag_load_store_ptr[e]); - for (int e = 0; e < kElementsPerAccess; e++) { - accum[k] += 
float(frag_src[e] * frag_input_fwd[e]); } - } - } + for (int e = 0; e < kElementsPerAccess; e++) { + accum[k] += float(frag_src[e] * frag_input_fwd[e]); + } } + } } + } - if (hasProb) - { - for (int k = 0; k < topKTile; k++) - { - if (k == num_topK) break; + if (hasProb) { + for (int k = 0; k < topKTile; k++) { + if (k == num_topK) break; - for (int mask = 16; mask > 0; mask /= 2) - { - accum[k] = accum[k] + __shfl_xor_sync(0xffffffff, accum[k], mask, 32); - } - } + for (int mask = 16; mask > 0; mask /= 2) { + accum[k] = accum[k] + __shfl_xor_sync(0xffffffff, accum[k], mask, 32); + } + } - if (tid == 0) - { - for (int k = 0; k < topKTile; k++) - { - if (k == num_topK) break; - prob_grad[source_token * num_topK + k] = accum[k]; - } - } + if (tid == 0) { + for (int k = 0; k < topKTile; k++) { + if (k == num_topK) break; + prob_grad[source_token * num_topK + k] = accum[k]; + } } + } } template -void moe_permutation_launcher( - const void *input_, - void *output_, - const int *sorted_row_id, - int *row_id_map, - const float *prob, - const int num_rows, - const int num_topK, - const int num_cols, - const int num_out_tokens, - cudaStream_t stream, - float *prob_grad, - const void *input_fwd_) -{ - using TCompute = typename std::conditional< - (std::is_same::value || - std::is_same::value), - half, T>::type; - - static constexpr int kElementsPerAccess = 16 / sizeof(T); - - const T* input = reinterpret_cast(input_); - T* output = reinterpret_cast(output_); - const T* input_fwd = reinterpret_cast(input_fwd_); - - if (FWD) - { - if (prob == nullptr) - { - if (input_fwd == nullptr) - { - // permute_topK fwd - int threads = 64; - int blocks = (num_rows * num_topK + threads - 1) / threads; - moe_permute_row_map<<>>( - sorted_row_id, - row_id_map, - num_rows, - num_topK, - num_out_tokens); - - blocks = num_rows; - threads = std::min(num_cols / kElementsPerAccess, 1024); - moe_permute_kernel<<>>( - input, - nullptr, - output, - nullptr, - nullptr, - row_id_map, - num_rows, - 
num_topK, - num_cols); - } - else - { - // unpermute_topK bwd without probs for topK == 1 - int blocks = num_rows; - int threads = 32; - - moe_permute_kernel<<>>( - input, - input_fwd, - output, - prob, - prob_grad, - row_id_map, - num_rows, - num_topK, - num_cols); - } - } - else - { - // unpermute_topK bwd with probs - int blocks = num_rows; - int threads = 32; - size_t smem_bytes = num_topK * sizeof(TCompute); - - if (num_topK <= 8) - { - moe_permute_kernel<<>>( - input, - input_fwd, - output, - prob, - prob_grad, - row_id_map, - num_rows, - num_topK, - num_cols); - } - else if (num_topK <= 16) - { - moe_permute_kernel<<>>( - input, - input_fwd, - output, - prob, - prob_grad, - row_id_map, - num_rows, - num_topK, - num_cols); - } - else if (num_topK <= 32) - { - moe_permute_kernel<<>>( - input, - input_fwd, - output, - prob, - prob_grad, - row_id_map, - num_rows, - num_topK, - num_cols); - } - else if (num_topK <= 64) - { - moe_permute_kernel<<>>( - input, - input_fwd, - output, - prob, - prob_grad, - row_id_map, - num_rows, - num_topK, - num_cols); - } - else if (num_topK <= 128) - { - moe_permute_kernel<<>>( - input, - input_fwd, - output, - prob, - prob_grad, - row_id_map, - num_rows, - num_topK, - num_cols); - } - else - { - throw std::runtime_error("num_topK cannot exceed 128."); - } - } - } - else - { +void moe_permutation_launcher(const void *input_, void *output_, const int *sorted_row_id, + int *row_id_map, const float *prob, const int num_rows, + const int num_topK, const int num_cols, const int num_out_tokens, + cudaStream_t stream, float *prob_grad, const void *input_fwd_) { + using TCompute = typename std::conditional<(std::is_same::value || + std::is_same::value), + half, T>::type; + + static constexpr int kElementsPerAccess = 16 / sizeof(T); + + const T *input = reinterpret_cast(input_); + T *output = reinterpret_cast(output_); + const T *input_fwd = reinterpret_cast(input_fwd_); + + if (FWD) { + if (prob == nullptr) { + if (input_fwd == nullptr) 
{ + // Permute fwd + int threads = 64; + int blocks = (num_rows * num_topK + threads - 1) / threads; + moe_permute_row_map<<>>(sorted_row_id, row_id_map, num_rows, + num_topK, num_out_tokens); + + blocks = num_rows; + threads = std::min(num_cols / kElementsPerAccess, 1024); + moe_permute_kernel<<>>( + input, nullptr, output, nullptr, nullptr, row_id_map, num_rows, num_topK, num_cols); + } else { + // Unpermute bwd without probs for topK == 1 int blocks = num_rows; - int threads = std::min(num_cols / kElementsPerAccess, 1024); - size_t smem_bytes = num_topK * sizeof(TCompute); - - if (prob == nullptr) - { - // permute_topK bwd - // unpermute_topK fwd without probs - moe_unpermute_kernel<<>>( - input, - output, - row_id_map, - prob, - num_rows, - num_topK, - num_cols); - } - else - { - // unpermute_topK fwd with probs - moe_unpermute_kernel<<>>( - input, - output, - row_id_map, - prob, - num_rows, - num_topK, - num_cols); - } + int threads = 32; + + moe_permute_kernel<<>>( + input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, num_topK, num_cols); + } + } else { + // Unpermute bwd with probs + int blocks = num_rows; + int threads = 32; + size_t smem_bytes = num_topK * sizeof(TCompute); + + if (num_topK <= 8) { + moe_permute_kernel<<>>( + input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, num_topK, num_cols); + } else if (num_topK <= 16) { + moe_permute_kernel<<>>( + input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, num_topK, num_cols); + } else if (num_topK <= 32) { + moe_permute_kernel<<>>( + input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, num_topK, num_cols); + } else if (num_topK <= 64) { + moe_permute_kernel<<>>( + input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, num_topK, num_cols); + } else if (num_topK <= 128) { + moe_permute_kernel<<>>( + input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, num_topK, num_cols); + } else { + throw std::runtime_error("num_topK cannot exceed 
128."); + } } + } else { + int blocks = num_rows; + int threads = std::min(num_cols / kElementsPerAccess, 1024); + size_t smem_bytes = num_topK * sizeof(TCompute); + + if (prob == nullptr) { + // Permute bwd + // Unpermute fwd without probs + moe_unpermute_kernel<<>>( + input, output, row_id_map, prob, num_rows, num_topK, num_cols); + } else { + // Unpermute fwd with probs + moe_unpermute_kernel<<>>( + input, output, row_id_map, prob, num_rows, num_topK, num_cols); + } + } } - - -#define FUNCTION_INSTANTIATION(T, FWD) \ -template void moe_permutation_launcher( \ - const void *input, \ - void *output, \ - const int *sorted_row_id, \ - int *row_id_map, \ - const float *prob, \ - const int num_rows, \ - const int num_topK, \ - const int num_cols, \ - const int num_out_tokens, \ - cudaStream_t stream, \ - float *prob_grad, \ - const void *input_fwd); +#define FUNCTION_INSTANTIATION(T, FWD) \ + template void moe_permutation_launcher( \ + const void *input, void *output, const int *sorted_row_id, int *row_id_map, \ + const float *prob, const int num_rows, const int num_topK, const int num_cols, \ + const int num_out_tokens, cudaStream_t stream, float *prob_grad, const void *input_fwd); FUNCTION_INSTANTIATION(float, true) FUNCTION_INSTANTIATION(float, false) @@ -452,4 +293,4 @@ FUNCTION_INSTANTIATION(__nv_bfloat16, false) FUNCTION_INSTANTIATION(__nv_fp8_e5m2, true) FUNCTION_INSTANTIATION(__nv_fp8_e5m2, false) FUNCTION_INSTANTIATION(__nv_fp8_e4m3, true) -FUNCTION_INSTANTIATION(__nv_fp8_e4m3, false) \ No newline at end of file +FUNCTION_INSTANTIATION(__nv_fp8_e4m3, false) diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 5257971ebb..e11c77b56b 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -10,30 +10,19 @@ #include "common.h" #include "common/common.h" - 
/*************************************************************************************************** * permute **************************************************************************************************/ std::tuple> moe_permute( - at::Tensor input, - at::Tensor indices, - int64_t num_out_tokens, - std::vector workspace, - int64_t max_expanded_token_num); - -at::Tensor moe_unpermute_fwd( - at::Tensor input, - at::Tensor row_id_map, - at::Tensor prob, - int64_t num_tokens, - int64_t num_topK); - -std::tuple moe_unpermute_bwd( - at::Tensor input_bwd, - at::Tensor input_fwd, - at::Tensor row_id_map, - at::Tensor prob); + at::Tensor input, at::Tensor indices, int64_t num_out_tokens, std::vector workspace, + int64_t max_expanded_token_num); + +at::Tensor moe_unpermute_fwd(at::Tensor input, at::Tensor row_id_map, at::Tensor prob, + int64_t num_tokens, int64_t num_topK); + +std::tuple moe_unpermute_bwd(at::Tensor input_bwd, at::Tensor input_fwd, + at::Tensor row_id_map, at::Tensor prob); /*************************************************************************************************** * Attention diff --git a/transformer_engine/pytorch/csrc/extensions/permutation.cu b/transformer_engine/pytorch/csrc/extensions/permutation.cu index cf43c145d2..f7d8663751 100644 --- a/transformer_engine/pytorch/csrc/extensions/permutation.cu +++ b/transformer_engine/pytorch/csrc/extensions/permutation.cu @@ -10,394 +10,236 @@ using torch::Tensor; -std::tuple> moe_permute( - Tensor input, - Tensor indices, - int64_t num_out_tokens, - std::vector workspace, - int64_t max_expanded_token_num) -{ - const int num_tokens = input.size(0); - const int num_cols = input.size(1); - const int num_topK = indices.size(1); - - // initialize the workspace on the first run - if (workspace.empty()) { - auto options = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false); - - Tensor sorted_indices = torch::empty(max_expanded_token_num, options); - Tensor row_id = 
torch::range(0, max_expanded_token_num - 1, 1, options); - Tensor sorted_row_id = - torch::empty(max_expanded_token_num, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false)); - - size_t temp_storage_bytes = 0; - int *temp_ptr = nullptr; - cub::DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, - temp_ptr, temp_ptr, - temp_ptr, temp_ptr, max_expanded_token_num); - Tensor temp_storage = - torch::empty(temp_storage_bytes, torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false)); - - workspace.push_back(sorted_indices); - workspace.push_back(row_id); - workspace.push_back(sorted_row_id); - workspace.push_back(temp_storage); - } - - int *indices_ptr = reinterpret_cast(getDataPtr(indices, 0)); - int *sorted_indices_ptr = reinterpret_cast(getDataPtr(workspace[0], 0)); - int *row_id_ptr = reinterpret_cast(getDataPtr(workspace[1], 0)); - int *sorted_row_id_ptr = reinterpret_cast(getDataPtr(workspace[2], 0)); - - void *d_temp_storage = getDataPtr(workspace[3], 0); - size_t temp_storage_bytes = std::numeric_limits::max(); - - cub::DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, - indices_ptr, sorted_indices_ptr, - row_id_ptr, sorted_row_id_ptr, num_tokens * num_topK); - - // activations type - const at::ScalarType _st = input.scalar_type(); - - // Output buffer alloc - num_out_tokens = (num_out_tokens > 0) ? 
num_out_tokens : num_tokens * num_topK; - Tensor permuted_output = - torch::empty({num_out_tokens, num_cols}, torch::dtype(_st).device(torch::kCUDA).requires_grad(false)); - Tensor row_id_map = - torch::empty({num_tokens * num_topK}, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false)); - - int *row_id_map_ptr = reinterpret_cast(getDataPtr(row_id_map, 0)); - auto stream = at::cuda::getCurrentCUDAStream().stream(); - - void *input_ptr = getDataPtr(input, 0); - void *permuted_output_ptr = getDataPtr(permuted_output, 0); - - switch (_st) - { - case at::ScalarType::Float: - { - moe_permutation_launcher( - input_ptr, - permuted_output_ptr, - sorted_row_id_ptr, - row_id_map_ptr, - nullptr, - num_tokens, - num_topK, - num_cols, - num_out_tokens, - stream); - - break; - } - case at::ScalarType::Half: - { - moe_permutation_launcher( - input_ptr, - permuted_output_ptr, - sorted_row_id_ptr, - row_id_map_ptr, - nullptr, - num_tokens, - num_topK, - num_cols, - num_out_tokens, - stream); - - break; - } - case at::ScalarType::BFloat16: - { - moe_permutation_launcher<__nv_bfloat16, true>( - input_ptr, - permuted_output_ptr, - sorted_row_id_ptr, - row_id_map_ptr, - nullptr, - num_tokens, - num_topK, - num_cols, - num_out_tokens, - stream); - - break; - } - case at::ScalarType::Float8_e5m2: - { - moe_permutation_launcher<__nv_fp8_e5m2, true>( - input_ptr, - permuted_output_ptr, - sorted_row_id_ptr, - row_id_map_ptr, - nullptr, - num_tokens, - num_topK, - num_cols, - num_out_tokens, - stream); - - break; - } - case at::ScalarType::Float8_e4m3fn: - { - moe_permutation_launcher<__nv_fp8_e4m3, true>( - input_ptr, - permuted_output_ptr, - sorted_row_id_ptr, - row_id_map_ptr, - nullptr, - num_tokens, - num_topK, - num_cols, - num_out_tokens, - stream); - - break; +std::tuple> moe_permute(Tensor input, Tensor indices, + int64_t num_out_tokens, + std::vector workspace, + int64_t max_expanded_token_num) { + const int num_tokens = input.size(0); + const int num_cols = 
input.size(1); + const int num_topK = indices.size(1); + + // initialize the workspace on the first run + if (workspace.empty()) { + auto options = + torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false); + + Tensor sorted_indices = torch::empty(max_expanded_token_num, options); + Tensor row_id = torch::range(0, max_expanded_token_num - 1, 1, options); + Tensor sorted_row_id = + torch::empty(max_expanded_token_num, + torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false)); + + size_t temp_storage_bytes = 0; + int *temp_ptr = nullptr; + cub::DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_ptr, temp_ptr, temp_ptr, + temp_ptr, max_expanded_token_num); + Tensor temp_storage = torch::empty( + temp_storage_bytes, torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false)); + + workspace.push_back(sorted_indices); + workspace.push_back(row_id); + workspace.push_back(sorted_row_id); + workspace.push_back(temp_storage); + } + + int *indices_ptr = reinterpret_cast(getDataPtr(indices, 0)); + int *sorted_indices_ptr = reinterpret_cast(getDataPtr(workspace[0], 0)); + int *row_id_ptr = reinterpret_cast(getDataPtr(workspace[1], 0)); + int *sorted_row_id_ptr = reinterpret_cast(getDataPtr(workspace[2], 0)); + + void *d_temp_storage = getDataPtr(workspace[3], 0); + size_t temp_storage_bytes = std::numeric_limits::max(); + + cub::DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, indices_ptr, + sorted_indices_ptr, row_id_ptr, sorted_row_id_ptr, + num_tokens * num_topK); + + // activations type + const at::ScalarType _st = input.scalar_type(); + + // Output buffer alloc + num_out_tokens = (num_out_tokens > 0) ? 
num_out_tokens : num_tokens * num_topK; + Tensor permuted_output = torch::empty( + {num_out_tokens, num_cols}, torch::dtype(_st).device(torch::kCUDA).requires_grad(false)); + Tensor row_id_map = + torch::empty({num_tokens * num_topK}, + torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false)); + + int *row_id_map_ptr = reinterpret_cast(getDataPtr(row_id_map, 0)); + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + void *input_ptr = getDataPtr(input, 0); + void *permuted_output_ptr = getDataPtr(permuted_output, 0); + + switch (_st) { + case at::ScalarType::Float: { + moe_permutation_launcher(input_ptr, permuted_output_ptr, sorted_row_id_ptr, + row_id_map_ptr, nullptr, num_tokens, num_topK, num_cols, + num_out_tokens, stream); + + break; + } + case at::ScalarType::Half: { + moe_permutation_launcher(input_ptr, permuted_output_ptr, sorted_row_id_ptr, + row_id_map_ptr, nullptr, num_tokens, num_topK, num_cols, + num_out_tokens, stream); + + break; + } + case at::ScalarType::BFloat16: { + moe_permutation_launcher<__nv_bfloat16, true>( + input_ptr, permuted_output_ptr, sorted_row_id_ptr, row_id_map_ptr, nullptr, num_tokens, + num_topK, num_cols, num_out_tokens, stream); + + break; + } + case at::ScalarType::Float8_e5m2: { + moe_permutation_launcher<__nv_fp8_e5m2, true>( + input_ptr, permuted_output_ptr, sorted_row_id_ptr, row_id_map_ptr, nullptr, num_tokens, + num_topK, num_cols, num_out_tokens, stream); + + break; + } + case at::ScalarType::Float8_e4m3fn: { + moe_permutation_launcher<__nv_fp8_e4m3, true>( + input_ptr, permuted_output_ptr, sorted_row_id_ptr, row_id_map_ptr, nullptr, num_tokens, + num_topK, num_cols, num_out_tokens, stream); + + break; } default: - throw std::runtime_error("Wrong activation tensor type."); - } + throw std::runtime_error("Wrong activation tensor type."); + } - return std::make_tuple(permuted_output, row_id_map, workspace); + return std::make_tuple(permuted_output, row_id_map, workspace); } +Tensor 
moe_unpermute_fwd(Tensor input, Tensor row_id_map, Tensor prob, int64_t num_tokens, + int64_t num_topK) { + const int num_cols = input.size(1); + + // activations type + const at::ScalarType _st = input.scalar_type(); -Tensor moe_unpermute_fwd( - Tensor input, - Tensor row_id_map, - Tensor prob, - int64_t num_tokens, - int64_t num_topK) -{ - const int num_cols = input.size(1); - - // activations type - const at::ScalarType _st = input.scalar_type(); - - // Output buffer alloc - Tensor unpermuted_output = - torch::empty({num_tokens, num_cols}, torch::dtype(_st).device(torch::kCUDA).requires_grad(false)); - - int *row_id_map_ptr = reinterpret_cast(getDataPtr(row_id_map, 0)); - float *prob_ptr = (prob.numel() > 0) ? reinterpret_cast(getDataPtr(prob, 0)) : nullptr; - auto stream = at::cuda::getCurrentCUDAStream().stream(); - - void *input_ptr = getDataPtr(input, 0); - void *unpermuted_output_ptr = getDataPtr(unpermuted_output, 0); - - switch (_st) - { - case at::ScalarType::Float: - { - moe_permutation_launcher( - input_ptr, - unpermuted_output_ptr, - nullptr, - row_id_map_ptr, - prob_ptr, - num_tokens, - num_topK, - num_cols, - 0, - stream); - - break; + // Output buffer alloc + Tensor unpermuted_output = torch::empty( + {num_tokens, num_cols}, torch::dtype(_st).device(torch::kCUDA).requires_grad(false)); + + int *row_id_map_ptr = reinterpret_cast(getDataPtr(row_id_map, 0)); + float *prob_ptr = (prob.numel() > 0) ? 
reinterpret_cast(getDataPtr(prob, 0)) : nullptr; + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + void *input_ptr = getDataPtr(input, 0); + void *unpermuted_output_ptr = getDataPtr(unpermuted_output, 0); + + switch (_st) { + case at::ScalarType::Float: { + moe_permutation_launcher(input_ptr, unpermuted_output_ptr, nullptr, + row_id_map_ptr, prob_ptr, num_tokens, num_topK, + num_cols, 0, stream); + + break; } - case at::ScalarType::Half: - { - moe_permutation_launcher( - input_ptr, - unpermuted_output_ptr, - nullptr, - row_id_map_ptr, - prob_ptr, - num_tokens, - num_topK, - num_cols, - 0, - stream); - - break; + case at::ScalarType::Half: { + moe_permutation_launcher(input_ptr, unpermuted_output_ptr, nullptr, + row_id_map_ptr, prob_ptr, num_tokens, num_topK, + num_cols, 0, stream); + + break; } - case at::ScalarType::BFloat16: - { - moe_permutation_launcher<__nv_bfloat16, false>( - input_ptr, - unpermuted_output_ptr, - nullptr, - row_id_map_ptr, - prob_ptr, - num_tokens, - num_topK, - num_cols, - 0, - stream); - - break; + case at::ScalarType::BFloat16: { + moe_permutation_launcher<__nv_bfloat16, false>(input_ptr, unpermuted_output_ptr, nullptr, + row_id_map_ptr, prob_ptr, num_tokens, num_topK, + num_cols, 0, stream); + + break; } - case at::ScalarType::Float8_e5m2: - { - moe_permutation_launcher<__nv_fp8_e5m2, false>( - input_ptr, - unpermuted_output_ptr, - nullptr, - row_id_map_ptr, - prob_ptr, - num_tokens, - num_topK, - num_cols, - 0, - stream); - - break; + case at::ScalarType::Float8_e5m2: { + moe_permutation_launcher<__nv_fp8_e5m2, false>(input_ptr, unpermuted_output_ptr, nullptr, + row_id_map_ptr, prob_ptr, num_tokens, num_topK, + num_cols, 0, stream); + + break; } - case at::ScalarType::Float8_e4m3fn: - { - moe_permutation_launcher<__nv_fp8_e4m3, false>( - input_ptr, - unpermuted_output_ptr, - nullptr, - row_id_map_ptr, - prob_ptr, - num_tokens, - num_topK, - num_cols, - 0, - stream); - - break; + case at::ScalarType::Float8_e4m3fn: { + 
moe_permutation_launcher<__nv_fp8_e4m3, false>(input_ptr, unpermuted_output_ptr, nullptr, + row_id_map_ptr, prob_ptr, num_tokens, num_topK, + num_cols, 0, stream); + + break; } default: - throw std::runtime_error("Wrong activation tensor type."); - } + throw std::runtime_error("Wrong activation tensor type."); + } - return unpermuted_output; + return unpermuted_output; } -std::tuple moe_unpermute_bwd( - Tensor input_bwd, - Tensor input_fwd, - Tensor row_id_map, - Tensor prob) -{ - const int num_topK = (prob.numel() > 0) ? prob.size(1) : 1; - const int num_tokens = (prob.numel() > 0) ? prob.size(0) : row_id_map.size(0); - const int num_cols = input_bwd.size(1); - - int *row_id_map_ptr = reinterpret_cast(getDataPtr(row_id_map, 0)); - float *prob_ptr = (prob.numel() > 0) ? reinterpret_cast(getDataPtr(prob, 0)) : nullptr; - - // activations type - const at::ScalarType _st = input_bwd.scalar_type(); - - // Output buffer alloc - Tensor act_grad = - torch::empty({input_fwd.size(0), num_cols}, torch::dtype(_st).device(torch::kCUDA).requires_grad(false)); - Tensor prob_grad = - torch::empty({num_tokens, num_topK}, torch::dtype(torch::kFloat32).device(torch::kCUDA).requires_grad(false)); - float *prob_grad_ptr = reinterpret_cast(getDataPtr(prob_grad, 0)); - - auto stream = at::cuda::getCurrentCUDAStream().stream(); - - void *input_bwd_ptr = getDataPtr(input_bwd, 0); - void *input_fwd_ptr = getDataPtr(input_fwd, 0); - void *act_grad_ptr = getDataPtr(act_grad, 0); - - switch (_st) - { - case at::ScalarType::Float: - { - moe_permutation_launcher( - input_bwd_ptr, - act_grad_ptr, - nullptr, - row_id_map_ptr, - prob_ptr, - num_tokens, - num_topK, - num_cols, - 0, - stream, - prob_grad_ptr, - input_fwd_ptr); - - break; +std::tuple moe_unpermute_bwd(Tensor input_bwd, Tensor input_fwd, Tensor row_id_map, + Tensor prob) { + const int num_topK = (prob.numel() > 0) ? prob.size(1) : 1; + const int num_tokens = (prob.numel() > 0) ? 
prob.size(0) : row_id_map.size(0); + const int num_cols = input_bwd.size(1); + + int *row_id_map_ptr = reinterpret_cast(getDataPtr(row_id_map, 0)); + float *prob_ptr = (prob.numel() > 0) ? reinterpret_cast(getDataPtr(prob, 0)) : nullptr; + + // activations type + const at::ScalarType _st = input_bwd.scalar_type(); + + // Output buffer alloc + Tensor act_grad = torch::empty({input_fwd.size(0), num_cols}, + torch::dtype(_st).device(torch::kCUDA).requires_grad(false)); + Tensor prob_grad = + torch::empty({num_tokens, num_topK}, + torch::dtype(torch::kFloat32).device(torch::kCUDA).requires_grad(false)); + float *prob_grad_ptr = reinterpret_cast(getDataPtr(prob_grad, 0)); + + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + void *input_bwd_ptr = getDataPtr(input_bwd, 0); + void *input_fwd_ptr = getDataPtr(input_fwd, 0); + void *act_grad_ptr = getDataPtr(act_grad, 0); + + switch (_st) { + case at::ScalarType::Float: { + moe_permutation_launcher(input_bwd_ptr, act_grad_ptr, nullptr, row_id_map_ptr, + prob_ptr, num_tokens, num_topK, num_cols, 0, stream, + prob_grad_ptr, input_fwd_ptr); + + break; } - case at::ScalarType::Half: - { - moe_permutation_launcher( - input_bwd_ptr, - act_grad_ptr, - nullptr, - row_id_map_ptr, - prob_ptr, - num_tokens, - num_topK, - num_cols, - 0, - stream, - prob_grad_ptr, - input_fwd_ptr); - - break; + case at::ScalarType::Half: { + moe_permutation_launcher(input_bwd_ptr, act_grad_ptr, nullptr, row_id_map_ptr, + prob_ptr, num_tokens, num_topK, num_cols, 0, stream, + prob_grad_ptr, input_fwd_ptr); + + break; } - case at::ScalarType::BFloat16: - { - moe_permutation_launcher<__nv_bfloat16, true>( - input_bwd_ptr, - act_grad_ptr, - nullptr, - row_id_map_ptr, - prob_ptr, - num_tokens, - num_topK, - num_cols, - 0, - stream, - prob_grad_ptr, - input_fwd_ptr); - - break; + case at::ScalarType::BFloat16: { + moe_permutation_launcher<__nv_bfloat16, true>( + input_bwd_ptr, act_grad_ptr, nullptr, row_id_map_ptr, prob_ptr, num_tokens, num_topK, + 
num_cols, 0, stream, prob_grad_ptr, input_fwd_ptr); + + break; } - case at::ScalarType::Float8_e5m2: - { - moe_permutation_launcher<__nv_fp8_e5m2, true>( - input_bwd_ptr, - act_grad_ptr, - nullptr, - row_id_map_ptr, - prob_ptr, - num_tokens, - num_topK, - num_cols, - 0, - stream, - prob_grad_ptr, - input_fwd_ptr); - - break; + case at::ScalarType::Float8_e5m2: { + moe_permutation_launcher<__nv_fp8_e5m2, true>( + input_bwd_ptr, act_grad_ptr, nullptr, row_id_map_ptr, prob_ptr, num_tokens, num_topK, + num_cols, 0, stream, prob_grad_ptr, input_fwd_ptr); + + break; } - case at::ScalarType::Float8_e4m3fn: - { - moe_permutation_launcher<__nv_fp8_e4m3, true>( - input_bwd_ptr, - act_grad_ptr, - nullptr, - row_id_map_ptr, - prob_ptr, - num_tokens, - num_topK, - num_cols, - 0, - stream, - prob_grad_ptr, - input_fwd_ptr); - - break; + case at::ScalarType::Float8_e4m3fn: { + moe_permutation_launcher<__nv_fp8_e4m3, true>( + input_bwd_ptr, act_grad_ptr, nullptr, row_id_map_ptr, prob_ptr, num_tokens, num_topK, + num_cols, 0, stream, prob_grad_ptr, input_fwd_ptr); + + break; } default: - throw std::runtime_error("Wrong activation tensor type."); - } + throw std::runtime_error("Wrong activation tensor type."); + } - return std::make_tuple(act_grad, prob_grad); -} \ No newline at end of file + return std::make_tuple(act_grad, prob_grad); +} diff --git a/transformer_engine/pytorch/permutation.py b/transformer_engine/pytorch/permutation.py index beb5f0cbb4..612d0d95f0 100644 --- a/transformer_engine/pytorch/permutation.py +++ b/transformer_engine/pytorch/permutation.py @@ -12,166 +12,161 @@ __all__ = [ - 'Permute', - 'Unpermute', + "Permute", + "Unpermute", ] class _Permute(torch.autograd.Function): - """functional Permute""" - - workspace=None - dtype=None - max_expanded_token_num=0 - - @staticmethod - def forward(ctx, - inp: torch.Tensor, - indices: torch.Tensor, - num_out_tokens: int, - max_token_num: int, - ) -> Tuple[torch.Tensor, torch.Tensor]: - # Empty input check - if not 
inp.numel(): - return inp, None - - # Device check - assert inp.is_cuda, "TransformerEngine needs CUDA." - assert indices.is_cuda, "TransformerEngine needs CUDA." - # Shape check - assert inp.size(0) == indices.size(0), "Permute not possible" - - # Data type check - if indices.dtype != torch.int32: - warnings.warn(f"The data type of the input `indices` of Permute is {indices.dtype}! " - "The recommended type is torch.int32.") - indices = indices.to(torch.int32) - - num_topK = indices.size(1) - - input_max_expanded_token_num = max(max_token_num, inp.size(0)) * num_topK - if _Permute.max_expanded_token_num < input_max_expanded_token_num: - _Permute.max_expanded_token_num = input_max_expanded_token_num - _Permute.workspace = [] - - if _Permute.dtype != inp.dtype: - _Permute.dtype = inp.dtype - _Permute.workspace = [] - - permuted_act, row_id_map, _Permute.workspace = tex.moe_permute( - inp, - indices, - num_out_tokens, - _Permute.workspace, - _Permute.max_expanded_token_num) - - ctx.row_id_map = row_id_map - ctx.num_tokens = indices.size(0) - ctx.num_topK = indices.size(1) - return permuted_act, row_id_map - - - @staticmethod - def backward(ctx, - permuted_act_grad: torch.Tensor, - _, - ) -> Tuple[torch.Tensor, ...]: - # Empty input check - if not permuted_act_grad.numel(): - return permuted_act_grad, None, None, None - - if not permuted_act_grad.is_contiguous(): - permuted_act_grad = permuted_act_grad.contiguous() - - row_id_map = ctx.row_id_map - num_tokens = ctx.num_tokens - num_topK = ctx.num_topK - - unpermuted_act_grad = tex.moe_unpermute_fwd( - permuted_act_grad, - row_id_map, - torch.empty(0), - num_tokens, - num_topK) - return unpermuted_act_grad, None, None, None + """functional Permute""" + + workspace = None + dtype = None + max_expanded_token_num = 0 + + @staticmethod + def forward( + ctx, + inp: torch.Tensor, + indices: torch.Tensor, + num_out_tokens: int, + max_token_num: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Empty input check + if not 
inp.numel(): + return inp, None + + # Device check + assert inp.is_cuda, "TransformerEngine needs CUDA." + assert indices.is_cuda, "TransformerEngine needs CUDA." + # Shape check + assert inp.size(0) == indices.size(0), "Permute not possible" + + # Data type check + if indices.dtype != torch.int32: + warnings.warn( + f"The data type of the input `indices` of Permute is {indices.dtype}! " + "The recommended type is torch.int32." + ) + indices = indices.to(torch.int32) + + num_topK = indices.size(1) + + input_max_expanded_token_num = max(max_token_num, inp.size(0)) * num_topK + if _Permute.max_expanded_token_num < input_max_expanded_token_num: + _Permute.max_expanded_token_num = input_max_expanded_token_num + _Permute.workspace = [] + + if _Permute.dtype != inp.dtype: + _Permute.dtype = inp.dtype + _Permute.workspace = [] + + permuted_act, row_id_map, _Permute.workspace = tex.moe_permute( + inp, indices, num_out_tokens, _Permute.workspace, _Permute.max_expanded_token_num + ) + + ctx.row_id_map = row_id_map + ctx.num_tokens = indices.size(0) + ctx.num_topK = indices.size(1) + return permuted_act, row_id_map + + @staticmethod + def backward( + ctx, + permuted_act_grad: torch.Tensor, + _, + ) -> Tuple[torch.Tensor, ...]: + # Empty input check + if not permuted_act_grad.numel(): + return permuted_act_grad, None, None, None + + if not permuted_act_grad.is_contiguous(): + permuted_act_grad = permuted_act_grad.contiguous() + + row_id_map = ctx.row_id_map + num_tokens = ctx.num_tokens + num_topK = ctx.num_topK + + unpermuted_act_grad = tex.moe_unpermute_fwd( + permuted_act_grad, row_id_map, torch.empty(0), num_tokens, num_topK + ) + return unpermuted_act_grad, None, None, None class _Unpermute(torch.autograd.Function): - """functional Unpermute""" - - @staticmethod - def forward(ctx, - inp: torch.Tensor, - row_id_map: torch.Tensor, - probs: torch.Tensor = torch.empty(0), - ) -> torch.Tensor: - # Empty input check - if not inp.numel(): - ctx.probs = probs - return inp - - # 
None probs check - if probs.numel(): - assert probs.is_cuda, "TransformerEngine needs CUDA." - - if probs.dtype != torch.float32: - warnings.warn(f"The data type of the input `probs` of Unpermute is {probs.dtype}! " - "The recommended type is torch.float32.") - probs = probs.to(torch.float32) - - # Device check - assert inp.is_cuda, "TransformerEngine needs CUDA." - assert row_id_map.is_cuda, "TransformerEngine needs CUDA." - - # Data type check - if row_id_map.dtype != torch.int32: - warnings.warn(f"The data type of the input `row_id_map` of Unpermute is {row_id_map.dtype}! " - "The recommended type is torch.int32.") - row_id_map = row_id_map.to(torch.int32) - - num_topK = probs.size(1) if probs.numel() else 1 - num_tokens = probs.size(0) if probs.numel() else row_id_map.size(0) - - unpermuted_output = tex.moe_unpermute_fwd( - inp, - row_id_map, - probs, - num_tokens, - num_topK) - - ctx.save_for_backward(inp, row_id_map, probs) - return unpermuted_output - - @staticmethod - def backward(ctx, - unpermuted_act_grad: torch.Tensor, - ) -> Tuple[torch.Tensor, None, torch.Tensor]: - # Empty input check - if not unpermuted_act_grad.numel(): - return unpermuted_act_grad, None, ctx.probs - - if not unpermuted_act_grad.is_contiguous(): - unpermuted_act_grad = unpermuted_act_grad.contiguous() - - inp, row_id_map, probs = ctx.saved_tensors - - act_grad = None - if ctx.needs_input_grad[0]: - act_grad, prob_grad = tex.moe_unpermute_bwd( - unpermuted_act_grad, - inp, - row_id_map, - probs) - - if not ctx.needs_input_grad[2]: - prob_grad = None - return act_grad, None, prob_grad - - -def Permute(inp: torch.Tensor, - indices: torch.Tensor, - num_out_tokens: int = -1, - max_token_num: int = -1, + """functional Unpermute""" + + @staticmethod + def forward( + ctx, + inp: torch.Tensor, + row_id_map: torch.Tensor, + probs: torch.Tensor = torch.empty(0), + ) -> torch.Tensor: + # Empty input check + if not inp.numel(): + ctx.probs = probs + return inp + + # None probs check + if 
probs.numel(): + assert probs.is_cuda, "TransformerEngine needs CUDA." + + if probs.dtype != torch.float32: + warnings.warn( + f"The data type of the input `probs` of Unpermute is {probs.dtype}! " + "The recommended type is torch.float32." + ) + probs = probs.to(torch.float32) + + # Device check + assert inp.is_cuda, "TransformerEngine needs CUDA." + assert row_id_map.is_cuda, "TransformerEngine needs CUDA." + + # Data type check + if row_id_map.dtype != torch.int32: + warnings.warn( + f"The data type of the input `row_id_map` of Unpermute is {row_id_map.dtype}! " + "The recommended type is torch.int32." + ) + row_id_map = row_id_map.to(torch.int32) + + num_topK = probs.size(1) if probs.numel() else 1 + num_tokens = probs.size(0) if probs.numel() else row_id_map.size(0) + + unpermuted_output = tex.moe_unpermute_fwd(inp, row_id_map, probs, num_tokens, num_topK) + + ctx.save_for_backward(inp, row_id_map, probs) + return unpermuted_output + + @staticmethod + def backward( + ctx, + unpermuted_act_grad: torch.Tensor, + ) -> Tuple[torch.Tensor, None, torch.Tensor]: + # Empty input check + if not unpermuted_act_grad.numel(): + return unpermuted_act_grad, None, ctx.probs + + if not unpermuted_act_grad.is_contiguous(): + unpermuted_act_grad = unpermuted_act_grad.contiguous() + + inp, row_id_map, probs = ctx.saved_tensors + + act_grad = None + if ctx.needs_input_grad[0]: + act_grad, prob_grad = tex.moe_unpermute_bwd(unpermuted_act_grad, inp, row_id_map, probs) + + if not ctx.needs_input_grad[2]: + prob_grad = None + return act_grad, None, prob_grad + + +def Permute( + inp: torch.Tensor, + indices: torch.Tensor, + num_out_tokens: int = -1, + max_token_num: int = -1, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Permute the tokens based on the indices. Token with the same index will be grouped together. 
@@ -192,9 +187,11 @@ def Permute(inp: torch.Tensor, """ return _Permute.apply(inp, indices, num_out_tokens, max_token_num) -def Unpermute(inp: torch.Tensor, - row_id_map: torch.Tensor, - probs: torch.Tensor = torch.empty(0), + +def Unpermute( + inp: torch.Tensor, + row_id_map: torch.Tensor, + probs: torch.Tensor = torch.empty(0), ) -> torch.Tensor: """ Unpermute a tensor with permuted tokens, and optionally merge the tokens with their @@ -208,7 +205,7 @@ def Unpermute(inp: torch.Tensor, The tensor of a mapping table for sorted indices used to unpermute the tokens, which is the second output tensor of `Permute`. probs: torch.Tensor - The tensor of probabilities corresponding to the permuted tokens. If provided, + The tensor of probabilities corresponding to the permuted tokens. If provided, the unpermuted tokens will be merged with their respective probabilities. By default, set to an empty tensor, which means that the tokens are directly merged by accumulation. """ From 816c8f6f59c19c56056413fc651cdb706052abd6 Mon Sep 17 00:00:00 2001 From: Jiang Shao Date: Tue, 9 Jul 2024 08:46:53 +0000 Subject: [PATCH 13/33] Rewrite the unit test Signed-off-by: Jiang Shao --- tests/pytorch/test_permutation.py | 421 +++++++++------------- transformer_engine/pytorch/permutation.py | 28 +- transformer_engine/pytorch/utils.py | 2 + 3 files changed, 195 insertions(+), 256 deletions(-) diff --git a/tests/pytorch/test_permutation.py b/tests/pytorch/test_permutation.py index 82924e74ac..3020c3cdfc 100644 --- a/tests/pytorch/test_permutation.py +++ b/tests/pytorch/test_permutation.py @@ -3,243 +3,232 @@ # See LICENSE for license information. 
import torch -import triton -import torch.cuda.nvtx as nvtx +import pytest +from typing import Dict -from transformer_engine.pytorch import Permute as permute_topK, Unpermute as unpermute_topK +from transformer_engine.pytorch import Permute as te_permute, Unpermute as te_unpermute +from transformer_engine.pytorch.utils import ( + is_bf16_compatible, + is_fp8_compatible, +) -def permute(tokens, indices, num_out_tokens: int = 0): - """Permute the tokens based on the indices. Token with the same index will be grouped together. - The input indices shape is [tokens, top_k], it indicates which experts were selected by each token separately. +def pytorch_permute(tokens, indices, num_out_tokens: int = None): + """ + Permute the tokens based on the indices. Token with the same index will be grouped together. + The input indices shape is [tokens, top_k], it indicates which experts were selected by each token separately. + Args: - tokens (torch.Tensor): The input token tensor. - indices (torch.Tensor): The token to expert indices tensor, should have a shape of [num_tokens, topk]. - topk (int, optional): The topk value. Defaults to 1. - num_out_tokens (int, optional): The effective token count, when enabling the capacity factor, should equal the number of tokens not dropped. By default, set to None, meaning no tokens are dropped. + tokens: torch.Tensor + The input token tensor. + indices: torch.Tensor + The token to expert indices tensor, should have a shape of [num_tokens] or [num_tokens, topk]. + num_out_tokens: int, optional + The effective output token count, when enabling the capacity factor, should equal the number of tokens not dropped. + By default, set to None, meaning no tokens are dropped. Returns: - torch.Tensor: The permuted tensor. - torch.Tensor: The sorted_indices corresponding permuted tensor. + torch.Tensor: + The permuted tensor. + torch.Tensor: + The sorted_indices corresponding permuted tensor. 
""" - - topk = indices.size(1) + if indices.dim() == 1: + topk = 1 + else: + topk = indices.size(1) flatten_indices = indices.view(-1) sorted_indices = torch.argsort(flatten_indices, stable=True) - if num_out_tokens > 0: - sorted_indices = sorted_indices[:num_out_tokens] - permuted_tokens = tokens.index_select(0, sorted_indices // topk) + num_out_tokens = num_out_tokens if num_out_tokens is not None else flatten_indices.size(0) + + permuted_tokens = tokens.index_select(0, sorted_indices[:num_out_tokens] // topk) return permuted_tokens, sorted_indices -def unpermute( +def pytorch_unpermute( permuted_tokens: torch.Tensor, sorted_indices: torch.Tensor, - probs: torch.Tensor = torch.empty(0), + probs: torch.Tensor = None, ): - """Unpermute a tensor of permuted tokens based on sorted indices, and optionally merge the tokens with their corresponding probabilities. + """ + Unpermute a tensor of permuted tokens based on sorted indices, and optionally merge the tokens with their + corresponding probabilities. Args: - permuted_tokens (torch.Tensor): The tensor of permuted tokens to be unpermuted. - sorted_indices (torch.Tensor): The tensor of sorted indices used to unpermute the tokens. - probs (torch.Tensor, optional): The tensor of probabilities corresponding to the permuted tokens. If provided, the unpermuted tokens will be merged with their respective probabilities. + permuted_tokens: torch.Tensor + The tensor of permuted tokens to be unpermuted. + sorted_indices: torch.Tensor + The tensor of sorted indices used to unpermute the tokens. + probs: torch.Tensor, optional + The tensor of probabilities corresponding to the permuted tokens. If provided, the unpermuted tokens will + be merged with their respective probabilities. Returns: - torch.Tensor: The unpermuted tokens, optionally merged with probabilities. + torch.Tensor: + The unpermuted tokens, optionally merged with probabilities. 
""" - num_unpermuted_tokens = probs.numel() - topk = probs.size(1) + if probs is not None: + # Unpermute and merge the tokens with their probabilities + num_unpermuted_tokens = probs.numel() + topk = probs.size(1) + else: + # Unpermute the tokens without merge + num_unpermuted_tokens = sorted_indices.size(0) + topk = 1 unpermuted_tokens = torch.zeros( [num_unpermuted_tokens, permuted_tokens.shape[-1]], dtype=permuted_tokens.dtype, device=permuted_tokens.device, ) - unpermuted_tokens.index_copy_(0, sorted_indices, permuted_tokens) - unpermuted_tokens = unpermuted_tokens.reshape(-1, topk, permuted_tokens.size(-1)) + unpermuted_tokens.index_copy_(0, sorted_indices[:permuted_tokens.size(0)], permuted_tokens) + unpermuted_tokens = unpermuted_tokens.reshape(-1, topk, permuted_tokens.size(-1)) if probs is not None: unpermuted_tokens = unpermuted_tokens * probs.unsqueeze(-1) unpermuted_tokens = unpermuted_tokens.sum(dim=1) - return unpermuted_tokens -def permute_topK_test( +def dtype_tols(dtype: torch.dtype) -> Dict[str, float]: + """Estimated tolerances for a datatype + + Based on tolerances for torch.testing.assert_close. 
+ + """ + if dtype == torch.float32: + return dict(rtol=1.0e-6, atol=1.0e-6) + if dtype == torch.float16: + return dict(rtol=3.0e-3, atol=1.0e-5) + if dtype == torch.bfloat16: + return dict(rtol=2.0e-2, atol=1.0e-5) + if dtype == torch.float8_e5m2 or dtype == torch.float8_e4m3fn: + return dict(rtol=2.0e-1, atol=1.0e-1) + raise ValueError(f"Unsuppored dtype ({dtype})") + + +param_dtypes = [torch.float32, torch.float16] +if is_bf16_compatible(): + param_dtypes.append(torch.bfloat16) +if is_fp8_compatible(): + param_dtypes.extend([torch.float8_e5m2, torch.float8_e4m3fn]) + +@pytest.mark.parametrize("dtype", param_dtypes) +@pytest.mark.parametrize("num_tokens", [4096]) +@pytest.mark.parametrize("num_expert", [8, 16]) +@pytest.mark.parametrize("hidden_size", [4096]) +@pytest.mark.parametrize("topK", [1, 2, 5]) +@pytest.mark.parametrize("num_out_tokens", [None, 4050]) +@pytest.mark.parametrize("with_probs", [True, False]) +def test_permutation( dtype, - num_token, + num_tokens, num_expert, hidden_size, - num_topK, - num_out_tokens=None, - PRINT=False, + topK, + num_out_tokens, + with_probs, BENCHMARK=False, ): + if not with_probs and topK > 1: + print("Only permutations with topK=1 and without probabilities are supported.") + return + + if topK > num_expert: + print("topK should be smaller than the number of experts.") + return if num_out_tokens == None: - num_out_tokens = num_token * num_topK + num_out_tokens = num_tokens * topK print( - f"{dtype} token:{num_token} hidden_size:{hidden_size} expert:{num_expert} topK:{num_topK}" + f"token:{num_tokens} hidden_size:{hidden_size} expert:{num_expert} topK:{topK} {dtype}" ) - is_fp8 = dtype in [torch.float8_e5m2, torch.float8_e4m3fn] - - permute_input = torch.rand((num_token, hidden_size), dtype=torch.float32).cuda() - # for i in range(num_token): - # for j in range(hidden_size): - # permute_input[i][j] = i * 100 + j - permute_input = permute_input.to(dtype) - if is_fp8: + permute_input = torch.rand((num_tokens, hidden_size), 
dtype=torch.float32).cuda().to(dtype) + if dtype in [torch.float8_e5m2, torch.float8_e4m3fn]: permute_input = permute_input.half() - permute_input.requires_grad_(True) - if num_token > 0: - indices = torch.stack([torch.randperm(num_expert)[:num_topK] for _ in range(num_token)]) + if num_tokens > 0: + indices = torch.stack([torch.randperm(num_expert)[:topK] for _ in range(num_tokens)]) else: - indices = torch.empty((num_token, num_topK)) + indices = torch.empty((num_tokens, topK)) indices = indices.to(torch.int32).cuda() - # probs = torch.tensor([[0.1, 0.9], - # [0.2, 0.8], - # [0.3, 0.7]]) - # 0.5 - # probs = torch.ones_like(indices) / 2 - # rand - probs = torch.rand(num_token, num_topK).cuda() - row_sums = probs.sum(dim=1, keepdim=True) - probs = probs / row_sums - probs.requires_grad_(True) - - if PRINT: - print(permute_input) - print(indices) - print(probs) + probs = None + if with_probs: + probs = torch.rand(num_tokens, topK).cuda() + row_sums = probs.sum(dim=1, keepdim=True) + probs = probs / row_sums + probs.requires_grad_(True) ################################################################################################################################### # # PyTorch # ################################################################################################################################### - nvtx.range_push("PyTorch permute forward") - permute_output, sorted_indices = permute(permute_input, indices, num_out_tokens) - nvtx.range_pop() - + permute_output, sorted_indices = pytorch_permute(permute_input, indices, num_out_tokens) permute_bwd_input = torch.rand_like(permute_output) - # for i in range(num_token * num_topK): - # for j in range(hidden_size): - # permute_bwd_input[i][j] = i * 100 + j - - nvtx.range_push("PyTorch permute backward") permute_output.backward(permute_bwd_input, retain_graph=True) - nvtx.range_pop() unpermute_input = permute_output.detach() unpermute_input.requires_grad_(True) - unpermute_output = unpermute(unpermute_input, 
sorted_indices, probs=probs) - - if PRINT: - print("--------------unpermute fwd permute_input--------------") - print(unpermute_input) - print("--------------unpermute fwd output--------------") - print(unpermute_output) - + unpermute_output = pytorch_unpermute(unpermute_input, sorted_indices, probs=probs) unpermute_bwd_input = torch.rand_like(unpermute_output) - # for i in range(num_token): - # for j in range(hidden_size): - # unpermute_bwd_input[i][j] = i * 2000 + j * 20 - - if PRINT: - print("--------------unpermute bwd permute_input--------------") - print(unpermute_bwd_input) - unpermute_output.backward(unpermute_bwd_input, retain_graph=True) - if PRINT: - print("--------------unpermute bwd output act grad--------------") - print(permute_output.grad) - print("--------------unpermute bwd output probs grad--------------") - print(probs.grad) ################################################################################################################################### # - # Mine + # TE # ################################################################################################################################### - new_permute_input = permute_input.detach().to(dtype) - new_permute_bwd_input = permute_bwd_input.detach().to(dtype) - new_unpermute_bwd_input = unpermute_bwd_input.detach().to(dtype) - new_permute_input.requires_grad_(True) - - new_permute_output, row_id_map = permute_topK(new_permute_input, indices, num_out_tokens) - - assert torch.allclose(permute_output.float(), new_permute_output.float()) - - if PRINT: - print("--------------row_id_map--------------") - print(row_id_map) - print("--------------new_permute_input--------------") - print(new_permute_input) - print("--------------new_permute_output--------------") - print(new_permute_output) - - new_permute_output.backward(new_permute_bwd_input, retain_graph=True) - - if torch.allclose(permute_input.grad.float(), new_permute_input.grad.float()) == False: - original_inputs = 
new_permute_input.grad.float().cpu().numpy().flatten() - original_output = permute_input.grad.float().cpu().numpy().flatten() - max_abs_error = abs(original_inputs - original_output).max() - print(f"permute_topK bwd max error (mine vs pytorch): \t\t\t{max_abs_error:.3e} ({dtype})") - - if PRINT: - print(permute_input.grad) - print(new_permute_input.grad) - - new_probs = probs.detach() - new_probs.requires_grad_(True) - if num_topK == 1: - new_probs = torch.empty(0) - new_unpermute_input = new_permute_output.detach() - new_unpermute_input.requires_grad_(True) - - new_unpermute_output = unpermute_topK(new_unpermute_input, row_id_map, new_probs) - - if torch.allclose(unpermute_output.float(), new_unpermute_output.float()) == False: - original_inputs = unpermute_output.float().cpu().detach().numpy().flatten() - original_output = new_unpermute_output.float().cpu().detach().numpy().flatten() - max_abs_error = abs(original_inputs - original_output).max() - print(f"unpermute_topK fwd max error (mine vs pytorch): \t\t{max_abs_error:.3e} ({dtype})") - - if PRINT: - print(unpermute_output) - print(new_unpermute_output) - - new_unpermute_output.backward(new_unpermute_bwd_input, retain_graph=True) - - if torch.allclose(unpermute_input.grad.float(), new_unpermute_input.grad.float()) == False: - original_inputs = unpermute_input.grad.float().cpu().detach().numpy().flatten() - original_output = new_unpermute_input.grad.float().cpu().detach().numpy().flatten() - max_abs_error = abs(original_inputs - original_output).max() - print( - "unpermute_topK bwd act_grad max error (mine vs pytorch):" - f" \t{max_abs_error:.3e} ({dtype})" - ) - if PRINT: - print(new_unpermute_input.grad) - print(unpermute_input.grad) - - if num_topK > 1 and torch.allclose(new_probs.grad, probs.grad) == False: - original_inputs = new_probs.grad.float().cpu().detach().numpy().flatten() - original_output = probs.grad.float().cpu().detach().numpy().flatten() - max_abs_error = abs(original_inputs - 
original_output).max() - print( - "unpermute_topK bwd prob_grad max error (mine vs pytorch):" - f" \t{max_abs_error:.3e} ({dtype})" - ) - if PRINT: - print(new_probs.grad) - print(probs.grad) + te_permute_input = permute_input.detach().to(dtype) + te_permute_bwd_input = permute_bwd_input.detach().to(dtype) + te_unpermute_bwd_input = unpermute_bwd_input.detach().to(dtype) + te_permute_input.requires_grad_(True) + + te_permute_output, row_id_map = te_permute(te_permute_input, indices, num_out_tokens) + te_permute_output.backward(te_permute_bwd_input, retain_graph=True) + + te_probs = None + if with_probs: + te_probs = probs.detach() + te_probs.requires_grad_(True) + te_unpermute_input = te_permute_output.detach() + te_unpermute_input.requires_grad_(True) + + te_unpermute_output = te_unpermute(te_unpermute_input, row_id_map, te_probs) + te_unpermute_output.backward(te_unpermute_bwd_input, retain_graph=True) + + tols = dtype_tols(dtype) + + torch.testing.assert_close( + permute_output.float(), + te_permute_output.float(), + msg=f"Mismatch in te_permute fwd") + torch.testing.assert_close( + permute_input.grad.float(), + te_permute_input.grad.float(), + msg=f"Mismatch in te_permute bwd", + **tols) + torch.testing.assert_close( + unpermute_output.float(), + te_unpermute_output.float(), + msg=f"Mismatch in te_unpermute fwd", + **tols) + torch.testing.assert_close( + unpermute_input.grad.float(), + te_unpermute_input.grad.float(), + msg=f"Mismatch in te_unpermute bwd", + **tols) + if with_probs: + torch.testing.assert_close( + probs.grad.float(), + te_probs.grad.float(), + msg=f"Mismatch in te_unpermute bwd", + **tols) if not permute_input.numel(): print("Empty permute_input activation test passed.") @@ -260,13 +249,11 @@ def backward_wrapper( return act.backward(backward_input, retain_graph=retain_graph) if BENCHMARK: - print(f"----permute topK----") - t = perf_test_cuda_kernel(lambda: permute(permute_input, indices, num_out_tokens)) - print(f"pytorch fwd: {t:.3f} ms") - t 
= perf_test_cuda_kernel(lambda: permute_topK(new_permute_input, indices, num_out_tokens)) - print(f"new fwd: {t:.3f} ms") + t1 = perf_test_cuda_kernel(lambda: pytorch_permute(permute_input, indices, num_out_tokens)) + t2 = perf_test_cuda_kernel(lambda: te_permute(te_permute_input, indices, num_out_tokens)) + print(f"permute\t\tfwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms") - t = perf_test_cuda_kernel( + t1 = perf_test_cuda_kernel( lambda: backward_wrapper( permute_output, permute_bwd_input, @@ -275,46 +262,42 @@ def backward_wrapper( accumulate_grad=False, ) ) - print(f"pytorch bwd: {t:.3f} ms") - t = perf_test_cuda_kernel( + t2 = perf_test_cuda_kernel( lambda: backward_wrapper( - new_permute_output, - new_permute_bwd_input, - forward_input=[new_permute_input], + te_permute_output, + te_permute_bwd_input, + forward_input=[te_permute_input], retain_graph=True, accumulate_grad=False, ) ) - print(f"new bwd: {t:.3f} ms") + print(f"permute\t\tbwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms") - print(f"----unpermute topK----") - t = perf_test_cuda_kernel(lambda: unpermute(unpermute_input, sorted_indices, probs=probs)) - print(f"pytorch fwd: {t:.3f} ms") - t = perf_test_cuda_kernel( - lambda: unpermute_topK(new_unpermute_input, row_id_map, new_probs) + t1 = perf_test_cuda_kernel(lambda: pytorch_unpermute(unpermute_input, sorted_indices, probs=probs)) + t2 = perf_test_cuda_kernel( + lambda: te_unpermute(te_unpermute_input, row_id_map, te_probs) ) - print(f"new fwd: {t:.3f} ms") + print(f"unpermute\tfwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms") - t = perf_test_cuda_kernel( + t1 = perf_test_cuda_kernel( lambda: backward_wrapper( unpermute_output, unpermute_bwd_input, - forward_input=[unpermute_input, probs], + forward_input=[unpermute_input, probs] if with_probs else [unpermute_input], retain_graph=True, accumulate_grad=False, ) ) - print(f"pytorch bwd: {t:.3f} ms") - t = perf_test_cuda_kernel( + t2 = perf_test_cuda_kernel( lambda: backward_wrapper( - new_unpermute_output, - 
new_unpermute_bwd_input, - forward_input=[new_unpermute_input, new_probs], + te_unpermute_output, + te_unpermute_bwd_input, + forward_input=[te_unpermute_input, te_probs] if with_probs else [te_unpermute_input], retain_graph=True, accumulate_grad=False, ) ) - print(f"new bwd: {t:.3f} ms") + print(f"unpermute\tbwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms") def perf_test_cuda_kernel(cuda_kernel_fn): @@ -334,56 +317,6 @@ def perf_test_cuda_kernel(cuda_kernel_fn): torch.cuda.synchronize() elapsed_time_ms = start_event.elapsed_time(end_event) - # print(f"Elapsed Time: {elapsed_time_ms / 100} ms") return elapsed_time_ms / 100 else: print("CUDA is not available.") - - -def test_permute_topK(): - - torch.manual_seed(1) - - num_token = 4096 * 2 - num_expert = 8 - hidden_size = 4096 - num_topK = 1 - - num_out_tokens = num_token * num_topK - 20 - # num_out_tokens = 0 - - Benchmark = False - print("GPU:", torch.cuda.get_device_name(0)) - - dtype = torch.float32 - permute_topK_test( - dtype, num_token, num_expert, hidden_size, num_topK, num_out_tokens, False, Benchmark - ) - dtype = torch.float16 - permute_topK_test( - dtype, num_token, num_expert, hidden_size, num_topK, num_out_tokens, False, Benchmark - ) - dtype = torch.bfloat16 - permute_topK_test( - dtype, num_token, num_expert, hidden_size, num_topK, num_out_tokens, False, Benchmark - ) - dtype = torch.float8_e5m2 - permute_topK_test( - dtype, num_token, num_expert, hidden_size, num_topK, num_out_tokens, False, Benchmark - ) - dtype = torch.float8_e4m3fn - permute_topK_test( - dtype, num_token, num_expert, hidden_size, num_topK, num_out_tokens, False, Benchmark - ) - dtype = torch.bfloat16 - permute_topK_test(dtype, num_token, 4, hidden_size, 1, None, False, Benchmark) - permute_topK_test(dtype, num_token, 5, hidden_size, 2, None, False, Benchmark) - permute_topK_test(dtype, num_token, 6, hidden_size, 3, None, False, Benchmark) - permute_topK_test(dtype, num_token, 7, hidden_size, 4, None, False, Benchmark) - 
permute_topK_test(dtype, num_token, 8, hidden_size, 5, None, False, Benchmark) - num_token = 0 - permute_topK_test(dtype, num_token, 8, hidden_size, 5, None, False, Benchmark) - - -if __name__ == "__main__": - test_permute_topK() diff --git a/transformer_engine/pytorch/permutation.py b/transformer_engine/pytorch/permutation.py index 612d0d95f0..d6c93c1295 100644 --- a/transformer_engine/pytorch/permutation.py +++ b/transformer_engine/pytorch/permutation.py @@ -50,9 +50,9 @@ def forward( ) indices = indices.to(torch.int32) - num_topK = indices.size(1) + topK = indices.size(1) - input_max_expanded_token_num = max(max_token_num, inp.size(0)) * num_topK + input_max_expanded_token_num = max(max_token_num, inp.size(0)) * topK if _Permute.max_expanded_token_num < input_max_expanded_token_num: _Permute.max_expanded_token_num = input_max_expanded_token_num _Permute.workspace = [] @@ -67,7 +67,7 @@ def forward( ctx.row_id_map = row_id_map ctx.num_tokens = indices.size(0) - ctx.num_topK = indices.size(1) + ctx.topK = indices.size(1) return permuted_act, row_id_map @staticmethod @@ -85,10 +85,10 @@ def backward( row_id_map = ctx.row_id_map num_tokens = ctx.num_tokens - num_topK = ctx.num_topK + topK = ctx.topK unpermuted_act_grad = tex.moe_unpermute_fwd( - permuted_act_grad, row_id_map, torch.empty(0), num_tokens, num_topK + permuted_act_grad, row_id_map, torch.empty(0), num_tokens, topK ) return unpermuted_act_grad, None, None, None @@ -101,7 +101,7 @@ def forward( ctx, inp: torch.Tensor, row_id_map: torch.Tensor, - probs: torch.Tensor = torch.empty(0), + probs: torch.Tensor, ) -> torch.Tensor: # Empty input check if not inp.numel(): @@ -109,7 +109,7 @@ def forward( return inp # None probs check - if probs.numel(): + if probs is not None: assert probs.is_cuda, "TransformerEngine needs CUDA." if probs.dtype != torch.float32: @@ -118,6 +118,13 @@ def forward( "The recommended type is torch.float32." 
) probs = probs.to(torch.float32) + + num_tokens = probs.size(0) + topK = probs.size(1) + else: + num_tokens = row_id_map.size(0) + topK = 1 + probs = torch.empty(0) # Device check assert inp.is_cuda, "TransformerEngine needs CUDA." @@ -131,10 +138,7 @@ def forward( ) row_id_map = row_id_map.to(torch.int32) - num_topK = probs.size(1) if probs.numel() else 1 - num_tokens = probs.size(0) if probs.numel() else row_id_map.size(0) - - unpermuted_output = tex.moe_unpermute_fwd(inp, row_id_map, probs, num_tokens, num_topK) + unpermuted_output = tex.moe_unpermute_fwd(inp, row_id_map, probs, num_tokens, topK) ctx.save_for_backward(inp, row_id_map, probs) return unpermuted_output @@ -191,7 +195,7 @@ def Permute( def Unpermute( inp: torch.Tensor, row_id_map: torch.Tensor, - probs: torch.Tensor = torch.empty(0), + probs: torch.Tensor = None, ) -> torch.Tensor: """ Unpermute a tensor with permuted tokens, and optionally merge the tokens with their diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index e83369c671..92db009de3 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -242,6 +242,8 @@ def is_bf16_compatible() -> None: """ return torch.cuda.get_device_capability()[0] >= 8 +def is_fp8_compatible() -> None: + return float(torch.__version__[0:3]) > 2.2 @functools.cache def get_cudnn_version() -> Tuple[int, int, int]: From 67c4764ec81f23634f46db7f16a66208c4ce3507 Mon Sep 17 00:00:00 2001 From: Jiang Shao Date: Tue, 9 Jul 2024 18:52:14 +0000 Subject: [PATCH 14/33] Enable skipping if FP8 is unavailable Signed-off-by: Jiang Shao --- tests/pytorch/test_permutation.py | 17 +++++++++++------ transformer_engine/pytorch/utils.py | 2 -- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/tests/pytorch/test_permutation.py b/tests/pytorch/test_permutation.py index 3020c3cdfc..8f7ea5a0d7 100644 --- a/tests/pytorch/test_permutation.py +++ b/tests/pytorch/test_permutation.py @@ -7,10 +7,15 @@ from 
typing import Dict from transformer_engine.pytorch import Permute as te_permute, Unpermute as te_unpermute -from transformer_engine.pytorch.utils import ( - is_bf16_compatible, - is_fp8_compatible, -) +from transformer_engine.pytorch.utils import is_bf16_compatible +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager + +# Only run FP8 tests on H100. +fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() + +seed = 1234 +torch.manual_seed(seed) +torch.cuda.manual_seed(seed) def pytorch_permute(tokens, indices, num_out_tokens: int = None): @@ -110,7 +115,7 @@ def dtype_tols(dtype: torch.dtype) -> Dict[str, float]: param_dtypes = [torch.float32, torch.float16] if is_bf16_compatible(): param_dtypes.append(torch.bfloat16) -if is_fp8_compatible(): +if fp8_available: param_dtypes.extend([torch.float8_e5m2, torch.float8_e4m3fn]) @pytest.mark.parametrize("dtype", param_dtypes) @@ -146,7 +151,7 @@ def test_permutation( ) permute_input = torch.rand((num_tokens, hidden_size), dtype=torch.float32).cuda().to(dtype) - if dtype in [torch.float8_e5m2, torch.float8_e4m3fn]: + if fp8_available and dtype in [torch.float8_e5m2, torch.float8_e4m3fn]: permute_input = permute_input.half() permute_input.requires_grad_(True) diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 92db009de3..e83369c671 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -242,8 +242,6 @@ def is_bf16_compatible() -> None: """ return torch.cuda.get_device_capability()[0] >= 8 -def is_fp8_compatible() -> None: - return float(torch.__version__[0:3]) > 2.2 @functools.cache def get_cudnn_version() -> Tuple[int, int, int]: From 76c563583c50e457a38926af928d044d804542c4 Mon Sep 17 00:00:00 2001 From: Jiang Shao Date: Tue, 9 Jul 2024 19:05:55 +0000 Subject: [PATCH 15/33] Rename exposed C++ api and reorder its parameters Signed-off-by: Jiang Shao --- tests/pytorch/test_permutation.py | 43 ++++----- 
.../include/transformer_engine/permutation.h | 9 +- .../common/permutation/permutation.cu | 12 +-- .../pytorch/csrc/extensions/permutation.cu | 90 +++++++++---------- transformer_engine/pytorch/permutation.py | 2 +- 5 files changed, 78 insertions(+), 78 deletions(-) diff --git a/tests/pytorch/test_permutation.py b/tests/pytorch/test_permutation.py index 8f7ea5a0d7..48de0f5b94 100644 --- a/tests/pytorch/test_permutation.py +++ b/tests/pytorch/test_permutation.py @@ -29,13 +29,13 @@ def pytorch_permute(tokens, indices, num_out_tokens: int = None): indices: torch.Tensor The token to expert indices tensor, should have a shape of [num_tokens] or [num_tokens, topk]. num_out_tokens: int, optional - The effective output token count, when enabling the capacity factor, should equal the number of tokens not dropped. + The effective output token count, when enabling the capacity factor, should equal the number of tokens not dropped. By default, set to None, meaning no tokens are dropped. Returns: - torch.Tensor: + torch.Tensor: The permuted tensor. - torch.Tensor: + torch.Tensor: The sorted_indices corresponding permuted tensor. 
""" if indices.dim() == 1: @@ -87,7 +87,7 @@ def pytorch_unpermute( device=permuted_tokens.device, ) - unpermuted_tokens.index_copy_(0, sorted_indices[:permuted_tokens.size(0)], permuted_tokens) + unpermuted_tokens.index_copy_(0, sorted_indices[: permuted_tokens.size(0)], permuted_tokens) unpermuted_tokens = unpermuted_tokens.reshape(-1, topk, permuted_tokens.size(-1)) if probs is not None: unpermuted_tokens = unpermuted_tokens * probs.unsqueeze(-1) @@ -118,6 +118,7 @@ def dtype_tols(dtype: torch.dtype) -> Dict[str, float]: if fp8_available: param_dtypes.extend([torch.float8_e5m2, torch.float8_e4m3fn]) + @pytest.mark.parametrize("dtype", param_dtypes) @pytest.mark.parametrize("num_tokens", [4096]) @pytest.mark.parametrize("num_expert", [8, 16]) @@ -146,9 +147,7 @@ def test_permutation( if num_out_tokens == None: num_out_tokens = num_tokens * topK - print( - f"token:{num_tokens} hidden_size:{hidden_size} expert:{num_expert} topK:{topK} {dtype}" - ) + print(f"token:{num_tokens} hidden_size:{hidden_size} expert:{num_expert} topK:{topK} {dtype}") permute_input = torch.rand((num_tokens, hidden_size), dtype=torch.float32).cuda().to(dtype) if fp8_available and dtype in [torch.float8_e5m2, torch.float8_e4m3fn]: @@ -210,30 +209,30 @@ def test_permutation( tols = dtype_tols(dtype) torch.testing.assert_close( - permute_output.float(), - te_permute_output.float(), - msg=f"Mismatch in te_permute fwd") + permute_output.float(), te_permute_output.float(), msg=f"Mismatch in te_permute fwd" + ) torch.testing.assert_close( permute_input.grad.float(), te_permute_input.grad.float(), msg=f"Mismatch in te_permute bwd", - **tols) + **tols, + ) torch.testing.assert_close( unpermute_output.float(), te_unpermute_output.float(), msg=f"Mismatch in te_unpermute fwd", - **tols) + **tols, + ) torch.testing.assert_close( unpermute_input.grad.float(), te_unpermute_input.grad.float(), msg=f"Mismatch in te_unpermute bwd", - **tols) + **tols, + ) if with_probs: torch.testing.assert_close( - 
probs.grad.float(), - te_probs.grad.float(), - msg=f"Mismatch in te_unpermute bwd", - **tols) + probs.grad.float(), te_probs.grad.float(), msg=f"Mismatch in te_unpermute bwd", **tols + ) if not permute_input.numel(): print("Empty permute_input activation test passed.") @@ -278,10 +277,10 @@ def backward_wrapper( ) print(f"permute\t\tbwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms") - t1 = perf_test_cuda_kernel(lambda: pytorch_unpermute(unpermute_input, sorted_indices, probs=probs)) - t2 = perf_test_cuda_kernel( - lambda: te_unpermute(te_unpermute_input, row_id_map, te_probs) + t1 = perf_test_cuda_kernel( + lambda: pytorch_unpermute(unpermute_input, sorted_indices, probs=probs) ) + t2 = perf_test_cuda_kernel(lambda: te_unpermute(te_unpermute_input, row_id_map, te_probs)) print(f"unpermute\tfwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms") t1 = perf_test_cuda_kernel( @@ -297,7 +296,9 @@ def backward_wrapper( lambda: backward_wrapper( te_unpermute_output, te_unpermute_bwd_input, - forward_input=[te_unpermute_input, te_probs] if with_probs else [te_unpermute_input], + forward_input=( + [te_unpermute_input, te_probs] if with_probs else [te_unpermute_input] + ), retain_graph=True, accumulate_grad=False, ) diff --git a/transformer_engine/common/include/transformer_engine/permutation.h b/transformer_engine/common/include/transformer_engine/permutation.h index a2e4661883..fd18669008 100644 --- a/transformer_engine/common/include/transformer_engine/permutation.h +++ b/transformer_engine/common/include/transformer_engine/permutation.h @@ -10,10 +10,9 @@ #include "transformer_engine.h" template -void moe_permutation_launcher(const void *input, void *output, const int *sorted_row_id, - int *row_id_map, const float *prob, const int num_rows, - const int num_topK, const int num_cols, const int num_out_tokens, - cudaStream_t stream, float *prob_grad = nullptr, - const void *input_fwd = nullptr); +void nvte_permutation(const void *input, void *output, const int *sorted_row_id, int 
*row_id_map, + const float *prob, const int num_rows, const int num_topK, const int num_cols, + const int num_out_tokens, float *prob_grad = nullptr, + const void *input_fwd = nullptr, cudaStream_t stream = nullptr); #endif // TRANSFORMER_ENGINE_PERMUTATION_H_ diff --git a/transformer_engine/common/permutation/permutation.cu b/transformer_engine/common/permutation/permutation.cu index d0a63433c3..beccd3ec01 100644 --- a/transformer_engine/common/permutation/permutation.cu +++ b/transformer_engine/common/permutation/permutation.cu @@ -200,10 +200,10 @@ __global__ void moe_permute_kernel(const T *input_bwd, const T *input_fwd, T *ac } template -void moe_permutation_launcher(const void *input_, void *output_, const int *sorted_row_id, - int *row_id_map, const float *prob, const int num_rows, - const int num_topK, const int num_cols, const int num_out_tokens, - cudaStream_t stream, float *prob_grad, const void *input_fwd_) { +void nvte_permutation(const void *input_, void *output_, const int *sorted_row_id, int *row_id_map, + const float *prob, const int num_rows, const int num_topK, const int num_cols, + const int num_out_tokens, float *prob_grad, const void *input_fwd_, + cudaStream_t stream) { using TCompute = typename std::conditional<(std::is_same::value || std::is_same::value), half, T>::type; @@ -279,10 +279,10 @@ void moe_permutation_launcher(const void *input_, void *output_, const int *sort } #define FUNCTION_INSTANTIATION(T, FWD) \ - template void moe_permutation_launcher( \ + template void nvte_permutation( \ const void *input, void *output, const int *sorted_row_id, int *row_id_map, \ const float *prob, const int num_rows, const int num_topK, const int num_cols, \ - const int num_out_tokens, cudaStream_t stream, float *prob_grad, const void *input_fwd); + const int num_out_tokens, float *prob_grad, const void *input_fwd, cudaStream_t stream); FUNCTION_INSTANTIATION(float, true) FUNCTION_INSTANTIATION(float, false) diff --git 
a/transformer_engine/pytorch/csrc/extensions/permutation.cu b/transformer_engine/pytorch/csrc/extensions/permutation.cu index f7d8663751..bb1ed265b1 100644 --- a/transformer_engine/pytorch/csrc/extensions/permutation.cu +++ b/transformer_engine/pytorch/csrc/extensions/permutation.cu @@ -73,37 +73,37 @@ std::tuple> moe_permute(Tensor input, Tensor switch (_st) { case at::ScalarType::Float: { - moe_permutation_launcher(input_ptr, permuted_output_ptr, sorted_row_id_ptr, - row_id_map_ptr, nullptr, num_tokens, num_topK, num_cols, - num_out_tokens, stream); + nvte_permutation(input_ptr, permuted_output_ptr, sorted_row_id_ptr, + row_id_map_ptr, nullptr, num_tokens, num_topK, num_cols, + num_out_tokens, nullptr, nullptr, stream); break; } case at::ScalarType::Half: { - moe_permutation_launcher(input_ptr, permuted_output_ptr, sorted_row_id_ptr, - row_id_map_ptr, nullptr, num_tokens, num_topK, num_cols, - num_out_tokens, stream); + nvte_permutation(input_ptr, permuted_output_ptr, sorted_row_id_ptr, + row_id_map_ptr, nullptr, num_tokens, num_topK, num_cols, + num_out_tokens, nullptr, nullptr, stream); break; } case at::ScalarType::BFloat16: { - moe_permutation_launcher<__nv_bfloat16, true>( - input_ptr, permuted_output_ptr, sorted_row_id_ptr, row_id_map_ptr, nullptr, num_tokens, - num_topK, num_cols, num_out_tokens, stream); + nvte_permutation<__nv_bfloat16, true>(input_ptr, permuted_output_ptr, sorted_row_id_ptr, + row_id_map_ptr, nullptr, num_tokens, num_topK, num_cols, + num_out_tokens, nullptr, nullptr, stream); break; } case at::ScalarType::Float8_e5m2: { - moe_permutation_launcher<__nv_fp8_e5m2, true>( - input_ptr, permuted_output_ptr, sorted_row_id_ptr, row_id_map_ptr, nullptr, num_tokens, - num_topK, num_cols, num_out_tokens, stream); + nvte_permutation<__nv_fp8_e5m2, true>(input_ptr, permuted_output_ptr, sorted_row_id_ptr, + row_id_map_ptr, nullptr, num_tokens, num_topK, num_cols, + num_out_tokens, nullptr, nullptr, stream); break; } case 
at::ScalarType::Float8_e4m3fn: { - moe_permutation_launcher<__nv_fp8_e4m3, true>( - input_ptr, permuted_output_ptr, sorted_row_id_ptr, row_id_map_ptr, nullptr, num_tokens, - num_topK, num_cols, num_out_tokens, stream); + nvte_permutation<__nv_fp8_e4m3, true>(input_ptr, permuted_output_ptr, sorted_row_id_ptr, + row_id_map_ptr, nullptr, num_tokens, num_topK, num_cols, + num_out_tokens, nullptr, nullptr, stream); break; } @@ -134,37 +134,37 @@ Tensor moe_unpermute_fwd(Tensor input, Tensor row_id_map, Tensor prob, int64_t n switch (_st) { case at::ScalarType::Float: { - moe_permutation_launcher(input_ptr, unpermuted_output_ptr, nullptr, - row_id_map_ptr, prob_ptr, num_tokens, num_topK, - num_cols, 0, stream); + nvte_permutation(input_ptr, unpermuted_output_ptr, nullptr, row_id_map_ptr, + prob_ptr, num_tokens, num_topK, num_cols, 0, nullptr, nullptr, + stream); break; } case at::ScalarType::Half: { - moe_permutation_launcher(input_ptr, unpermuted_output_ptr, nullptr, - row_id_map_ptr, prob_ptr, num_tokens, num_topK, - num_cols, 0, stream); + nvte_permutation(input_ptr, unpermuted_output_ptr, nullptr, row_id_map_ptr, + prob_ptr, num_tokens, num_topK, num_cols, 0, nullptr, nullptr, + stream); break; } case at::ScalarType::BFloat16: { - moe_permutation_launcher<__nv_bfloat16, false>(input_ptr, unpermuted_output_ptr, nullptr, - row_id_map_ptr, prob_ptr, num_tokens, num_topK, - num_cols, 0, stream); + nvte_permutation<__nv_bfloat16, false>(input_ptr, unpermuted_output_ptr, nullptr, + row_id_map_ptr, prob_ptr, num_tokens, num_topK, + num_cols, 0, nullptr, nullptr, stream); break; } case at::ScalarType::Float8_e5m2: { - moe_permutation_launcher<__nv_fp8_e5m2, false>(input_ptr, unpermuted_output_ptr, nullptr, - row_id_map_ptr, prob_ptr, num_tokens, num_topK, - num_cols, 0, stream); + nvte_permutation<__nv_fp8_e5m2, false>(input_ptr, unpermuted_output_ptr, nullptr, + row_id_map_ptr, prob_ptr, num_tokens, num_topK, + num_cols, 0, nullptr, nullptr, stream); break; } case 
at::ScalarType::Float8_e4m3fn: { - moe_permutation_launcher<__nv_fp8_e4m3, false>(input_ptr, unpermuted_output_ptr, nullptr, - row_id_map_ptr, prob_ptr, num_tokens, num_topK, - num_cols, 0, stream); + nvte_permutation<__nv_fp8_e4m3, false>(input_ptr, unpermuted_output_ptr, nullptr, + row_id_map_ptr, prob_ptr, num_tokens, num_topK, + num_cols, 0, nullptr, nullptr, stream); break; } @@ -203,37 +203,37 @@ std::tuple moe_unpermute_bwd(Tensor input_bwd, Tensor input_fwd, switch (_st) { case at::ScalarType::Float: { - moe_permutation_launcher(input_bwd_ptr, act_grad_ptr, nullptr, row_id_map_ptr, - prob_ptr, num_tokens, num_topK, num_cols, 0, stream, - prob_grad_ptr, input_fwd_ptr); + nvte_permutation(input_bwd_ptr, act_grad_ptr, nullptr, row_id_map_ptr, prob_ptr, + num_tokens, num_topK, num_cols, 0, prob_grad_ptr, input_fwd_ptr, + stream); break; } case at::ScalarType::Half: { - moe_permutation_launcher(input_bwd_ptr, act_grad_ptr, nullptr, row_id_map_ptr, - prob_ptr, num_tokens, num_topK, num_cols, 0, stream, - prob_grad_ptr, input_fwd_ptr); + nvte_permutation(input_bwd_ptr, act_grad_ptr, nullptr, row_id_map_ptr, prob_ptr, + num_tokens, num_topK, num_cols, 0, prob_grad_ptr, input_fwd_ptr, + stream); break; } case at::ScalarType::BFloat16: { - moe_permutation_launcher<__nv_bfloat16, true>( - input_bwd_ptr, act_grad_ptr, nullptr, row_id_map_ptr, prob_ptr, num_tokens, num_topK, - num_cols, 0, stream, prob_grad_ptr, input_fwd_ptr); + nvte_permutation<__nv_bfloat16, true>(input_bwd_ptr, act_grad_ptr, nullptr, row_id_map_ptr, + prob_ptr, num_tokens, num_topK, num_cols, 0, + prob_grad_ptr, input_fwd_ptr, stream); break; } case at::ScalarType::Float8_e5m2: { - moe_permutation_launcher<__nv_fp8_e5m2, true>( - input_bwd_ptr, act_grad_ptr, nullptr, row_id_map_ptr, prob_ptr, num_tokens, num_topK, - num_cols, 0, stream, prob_grad_ptr, input_fwd_ptr); + nvte_permutation<__nv_fp8_e5m2, true>(input_bwd_ptr, act_grad_ptr, nullptr, row_id_map_ptr, + prob_ptr, num_tokens, num_topK, 
num_cols, 0, + prob_grad_ptr, input_fwd_ptr, stream); break; } case at::ScalarType::Float8_e4m3fn: { - moe_permutation_launcher<__nv_fp8_e4m3, true>( - input_bwd_ptr, act_grad_ptr, nullptr, row_id_map_ptr, prob_ptr, num_tokens, num_topK, - num_cols, 0, stream, prob_grad_ptr, input_fwd_ptr); + nvte_permutation<__nv_fp8_e4m3, true>(input_bwd_ptr, act_grad_ptr, nullptr, row_id_map_ptr, + prob_ptr, num_tokens, num_topK, num_cols, 0, + prob_grad_ptr, input_fwd_ptr, stream); break; } diff --git a/transformer_engine/pytorch/permutation.py b/transformer_engine/pytorch/permutation.py index d6c93c1295..6468494720 100644 --- a/transformer_engine/pytorch/permutation.py +++ b/transformer_engine/pytorch/permutation.py @@ -118,7 +118,7 @@ def forward( "The recommended type is torch.float32." ) probs = probs.to(torch.float32) - + num_tokens = probs.size(0) topK = probs.size(1) else: From 25276a84d1e081c7c1dbc14fb46e74bdede4b12e Mon Sep 17 00:00:00 2001 From: Jiang Shao Date: Tue, 9 Jul 2024 19:32:09 +0000 Subject: [PATCH 16/33] Minor changes Signed-off-by: Jiang Shao --- transformer_engine/common/gemm/cublaslt_gemm.cu | 2 +- transformer_engine/common/permutation/permutation.cu | 2 +- transformer_engine/pytorch/csrc/extensions/permutation.cu | 6 +++--- transformer_engine/pytorch/permutation.py | 1 - 4 files changed, 5 insertions(+), 6 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 30161b68c0..0eb2d84629 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -255,7 +255,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, "Unable to find suitable cuBLAS GEMM algorithm"); NVTE_CHECK_CUBLAS(status); - if (returnedResults == 0) throw std::runtime_error("Unable to find any suitable algorithms"); + if (returnedResults == 0) NVTE_ERROR("Unable to find any suitable algorithms"); // D = alpha * (A * B) + beta * C 
NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, operationDesc, diff --git a/transformer_engine/common/permutation/permutation.cu b/transformer_engine/common/permutation/permutation.cu index beccd3ec01..12262c805f 100644 --- a/transformer_engine/common/permutation/permutation.cu +++ b/transformer_engine/common/permutation/permutation.cu @@ -257,7 +257,7 @@ void nvte_permutation(const void *input_, void *output_, const int *sorted_row_i moe_permute_kernel<<>>( input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, num_topK, num_cols); } else { - throw std::runtime_error("num_topK cannot exceed 128."); + NVTE_ERROR("num_topK cannot exceed 128."); } } } else { diff --git a/transformer_engine/pytorch/csrc/extensions/permutation.cu b/transformer_engine/pytorch/csrc/extensions/permutation.cu index bb1ed265b1..c460cb5c6a 100644 --- a/transformer_engine/pytorch/csrc/extensions/permutation.cu +++ b/transformer_engine/pytorch/csrc/extensions/permutation.cu @@ -108,7 +108,7 @@ std::tuple> moe_permute(Tensor input, Tensor break; } default: - throw std::runtime_error("Wrong activation tensor type."); + NVTE_ERROR("Wrong activation tensor type."); } return std::make_tuple(permuted_output, row_id_map, workspace); @@ -169,7 +169,7 @@ Tensor moe_unpermute_fwd(Tensor input, Tensor row_id_map, Tensor prob, int64_t n break; } default: - throw std::runtime_error("Wrong activation tensor type."); + NVTE_ERROR("Wrong activation tensor type."); } return unpermuted_output; @@ -238,7 +238,7 @@ std::tuple moe_unpermute_bwd(Tensor input_bwd, Tensor input_fwd, break; } default: - throw std::runtime_error("Wrong activation tensor type."); + NVTE_ERROR("Wrong activation tensor type."); } return std::make_tuple(act_grad, prob_grad); diff --git a/transformer_engine/pytorch/permutation.py b/transformer_engine/pytorch/permutation.py index 6468494720..216c345863 100644 --- a/transformer_engine/pytorch/permutation.py +++ b/transformer_engine/pytorch/permutation.py @@ -3,7 +3,6 @@ # See LICENSE for 
license information. """Linear API""" -import os import torch import warnings from typing import Tuple From 57ce3d088d213b5d53c8c10daeb3123aa11acef2 Mon Sep 17 00:00:00 2001 From: Jiang Shao Date: Wed, 10 Jul 2024 10:36:33 +0000 Subject: [PATCH 17/33] Minor changes Signed-off-by: Jiang Shao --- transformer_engine/pytorch/csrc/extensions.h | 5 ++++- .../pytorch/csrc/extensions/permutation.cu | 13 +++++++++---- .../pytorch/csrc/extensions/pybind.cpp | 3 ++- transformer_engine/pytorch/permutation.py | 4 ++-- 4 files changed, 17 insertions(+), 8 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index e11c77b56b..c534dc92a1 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -14,10 +14,13 @@ * permute **************************************************************************************************/ -std::tuple> moe_permute( +std::tuple> moe_permute_fwd( at::Tensor input, at::Tensor indices, int64_t num_out_tokens, std::vector workspace, int64_t max_expanded_token_num); +at::Tensor moe_permute_bwd(at::Tensor input, at::Tensor row_id_map, at::Tensor prob, + int64_t num_tokens, int64_t num_topK); + at::Tensor moe_unpermute_fwd(at::Tensor input, at::Tensor row_id_map, at::Tensor prob, int64_t num_tokens, int64_t num_topK); diff --git a/transformer_engine/pytorch/csrc/extensions/permutation.cu b/transformer_engine/pytorch/csrc/extensions/permutation.cu index c460cb5c6a..493d3a2359 100644 --- a/transformer_engine/pytorch/csrc/extensions/permutation.cu +++ b/transformer_engine/pytorch/csrc/extensions/permutation.cu @@ -10,10 +10,10 @@ using torch::Tensor; -std::tuple> moe_permute(Tensor input, Tensor indices, - int64_t num_out_tokens, - std::vector workspace, - int64_t max_expanded_token_num) { +std::tuple> moe_permute_fwd(Tensor input, Tensor indices, + int64_t num_out_tokens, + std::vector workspace, + int64_t max_expanded_token_num) { const int 
num_tokens = input.size(0); const int num_cols = input.size(1); const int num_topK = indices.size(1); @@ -114,6 +114,11 @@ std::tuple> moe_permute(Tensor input, Tensor return std::make_tuple(permuted_output, row_id_map, workspace); } +Tensor moe_permute_bwd(Tensor input, Tensor row_id_map, Tensor prob, int64_t num_tokens, + int64_t num_topK) { + return moe_unpermute_fwd(input, row_id_map, prob, num_tokens, num_topK); +} + Tensor moe_unpermute_fwd(Tensor input, Tensor row_id_map, Tensor prob, int64_t num_tokens, int64_t num_topK) { const int num_cols = input.size(1); diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index aae323c5f6..5f779bad12 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -11,7 +11,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // Permutation functions - m.def("moe_permute", moe_permute); + m.def("moe_permute_fwd", moe_permute_fwd); + m.def("moe_permute_bwd", moe_permute_bwd); m.def("moe_unpermute_fwd", moe_unpermute_fwd); m.def("moe_unpermute_bwd", moe_unpermute_bwd); diff --git a/transformer_engine/pytorch/permutation.py b/transformer_engine/pytorch/permutation.py index 216c345863..8b12fb274a 100644 --- a/transformer_engine/pytorch/permutation.py +++ b/transformer_engine/pytorch/permutation.py @@ -60,7 +60,7 @@ def forward( _Permute.dtype = inp.dtype _Permute.workspace = [] - permuted_act, row_id_map, _Permute.workspace = tex.moe_permute( + permuted_act, row_id_map, _Permute.workspace = tex.moe_permute_fwd( inp, indices, num_out_tokens, _Permute.workspace, _Permute.max_expanded_token_num ) @@ -86,7 +86,7 @@ def backward( num_tokens = ctx.num_tokens topK = ctx.topK - unpermuted_act_grad = tex.moe_unpermute_fwd( + unpermuted_act_grad = tex.moe_permute_bwd( permuted_act_grad, row_id_map, torch.empty(0), num_tokens, topK ) return unpermuted_act_grad, None, None, None From 
385741fde738ccc941c7abf62d9c4d1536262ecf Mon Sep 17 00:00:00 2001 From: Jiang Shao Date: Wed, 24 Jul 2024 17:50:02 +0000 Subject: [PATCH 18/33] Remove the dependency on pytorch fp8 data type Signed-off-by: Jiang Shao --- tests/pytorch/test_permutation.py | 282 ++++++++++++++---- transformer_engine/pytorch/__init__.py | 2 +- transformer_engine/pytorch/csrc/extensions.h | 15 +- .../pytorch/csrc/extensions/permutation.cu | 97 +++--- transformer_engine/pytorch/permutation.py | 56 ++-- 5 files changed, 322 insertions(+), 130 deletions(-) diff --git a/tests/pytorch/test_permutation.py b/tests/pytorch/test_permutation.py index 48de0f5b94..719e1ca61b 100644 --- a/tests/pytorch/test_permutation.py +++ b/tests/pytorch/test_permutation.py @@ -4,11 +4,12 @@ import torch import pytest -from typing import Dict +from typing import Dict, List -from transformer_engine.pytorch import Permute as te_permute, Unpermute as te_unpermute +from transformer_engine.pytorch import permute as te_permute, unpermute as te_unpermute from transformer_engine.pytorch.utils import is_bf16_compatible from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +import transformer_engine_torch as tex # Only run FP8 tests on H100. fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() @@ -17,6 +18,13 @@ torch.manual_seed(seed) torch.cuda.manual_seed(seed) +# TE tensor dtypes +_te_dtypes: List[tex.DType] = [tex.DType.kFloat32, tex.DType.kFloat16] +if is_bf16_compatible(): + _te_dtypes.append(tex.DType.kBFloat16) +if fp8_available: + _te_dtypes.extend([tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) + def pytorch_permute(tokens, indices, num_out_tokens: int = None): """ @@ -95,31 +103,60 @@ def pytorch_unpermute( return unpermuted_tokens -def dtype_tols(dtype: torch.dtype) -> Dict[str, float]: +def dtype_tols(te_dtype: tex.DType) -> Dict[str, float]: """Estimated tolerances for a datatype Based on tolerances for torch.testing.assert_close. 
""" - if dtype == torch.float32: + if te_dtype == tex.DType.kFloat32: return dict(rtol=1.0e-6, atol=1.0e-6) - if dtype == torch.float16: + if te_dtype == tex.DType.kFloat16: return dict(rtol=3.0e-3, atol=1.0e-5) - if dtype == torch.bfloat16: + if te_dtype == tex.DType.kBFloat16: return dict(rtol=2.0e-2, atol=1.0e-5) - if dtype == torch.float8_e5m2 or dtype == torch.float8_e4m3fn: + if te_dtype == tex.DType.kFloat8E5M2 or te_dtype == tex.DType.kFloat8E4M3: return dict(rtol=2.0e-1, atol=1.0e-1) - raise ValueError(f"Unsuppored dtype ({dtype})") + raise ValueError(f"Unsuppored dtype ({te_dtype})") -param_dtypes = [torch.float32, torch.float16] -if is_bf16_compatible(): - param_dtypes.append(torch.bfloat16) -if fp8_available: - param_dtypes.extend([torch.float8_e5m2, torch.float8_e4m3fn]) +def fp8_to_fp16(uint8_tensor, e4m3: bool = True): + assert uint8_tensor.dtype == torch.uint8, "Input tensor must be uint8" + + float16_tensor = torch.zeros_like(uint8_tensor, dtype=torch.float16) + + sign = (uint8_tensor >> 7) & 1 + exponent_mask = 0xF if e4m3 else 0x1F + if e4m3: + exponent = (uint8_tensor >> 3) & exponent_mask + mantissa = uint8_tensor & 0x7 + else: + exponent = (uint8_tensor >> 2) & exponent_mask + mantissa = uint8_tensor & 0x3 + + exponent_bias = 7 if e4m3 else 15 + mantissa_max = 8.0 if e4m3 else 4.0 + + normal_mask = (exponent != 0) & ~(exponent == exponent_mask) + actual_exponent = exponent[normal_mask].to(torch.float16) - exponent_bias + actual_mantissa = (mantissa[normal_mask].to(torch.float16) + mantissa_max) / mantissa_max + float16_tensor[normal_mask] = ( + ((-1) ** sign[normal_mask].to(torch.float16)) * (2**actual_exponent) * actual_mantissa + ) + + subnormal_mask = (exponent == 0) & (mantissa != 0) + subnormal_exponent = 1 - exponent_bias + subnormal_mantissa = mantissa[subnormal_mask].to(torch.float16) / mantissa_max + float16_tensor[subnormal_mask] = ( + ((-1) ** sign[subnormal_mask].to(torch.float16)) + * (2**subnormal_exponent) + * subnormal_mantissa 
+ ) + return float16_tensor -@pytest.mark.parametrize("dtype", param_dtypes) + +@pytest.mark.parametrize("te_dtype", _te_dtypes) @pytest.mark.parametrize("num_tokens", [4096]) @pytest.mark.parametrize("num_expert", [8, 16]) @pytest.mark.parametrize("hidden_size", [4096]) @@ -127,7 +164,7 @@ def dtype_tols(dtype: torch.dtype) -> Dict[str, float]: @pytest.mark.parametrize("num_out_tokens", [None, 4050]) @pytest.mark.parametrize("with_probs", [True, False]) def test_permutation( - dtype, + te_dtype, num_tokens, num_expert, hidden_size, @@ -137,22 +174,58 @@ def test_permutation( BENCHMARK=False, ): if not with_probs and topK > 1: - print("Only permutations with topK=1 and without probabilities are supported.") - return + pytest.skip("Only permutations with topK=1 and without probabilities are supported.") if topK > num_expert: - print("topK should be smaller than the number of experts.") - return + pytest.skip("topK should be smaller than the number of experts.") if num_out_tokens == None: num_out_tokens = num_tokens * topK - print(f"token:{num_tokens} hidden_size:{hidden_size} expert:{num_expert} topK:{topK} {dtype}") + print( + f"token:{num_tokens} hidden_size:{hidden_size} expert:{num_expert} topK:{topK} {te_dtype}" + ) + + fp8 = False + # Convert TE dtypes to PyTorch dtypes + if te_dtype == tex.DType.kFloat32: + dtype = torch.float32 + elif te_dtype == tex.DType.kFloat16: + dtype = torch.float16 + elif te_dtype == tex.DType.kBFloat16: + dtype = torch.bfloat16 + elif fp8_available and (te_dtype == tex.DType.kFloat8E5M2 or te_dtype == tex.DType.kFloat8E4M3): + dtype = torch.uint8 + fp8 = True + else: + pytest.skip("Invalid dtype.") + + if fp8: + N = 56 if te_dtype == tex.DType.kFloat8E4M3 else 60 + permute_fwd_input = torch.randint( + low=0, high=N + 1, size=(num_tokens, hidden_size), dtype=torch.uint8 + ).cuda() + permute_bwd_input = torch.randint( + low=0, high=N + 1, size=(num_out_tokens, hidden_size), dtype=torch.uint8 + ).cuda() + unpermute_bwd_input = 
torch.randint( + low=0, high=N + 1, size=(num_tokens, hidden_size), dtype=torch.uint8 + ).cuda() + pytorch_permute_fwd_input = fp8_to_fp16( + permute_fwd_input, te_dtype == tex.DType.kFloat8E4M3 + ) + pytorch_permute_bwd_input = fp8_to_fp16( + permute_bwd_input, te_dtype == tex.DType.kFloat8E4M3 + ) + pytorch_unpermute_bwd_input = fp8_to_fp16( + unpermute_bwd_input, te_dtype == tex.DType.kFloat8E4M3 + ) + else: + pytorch_permute_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() + pytorch_permute_bwd_input = torch.rand((num_out_tokens, hidden_size), dtype=dtype).cuda() + pytorch_unpermute_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() - permute_input = torch.rand((num_tokens, hidden_size), dtype=torch.float32).cuda().to(dtype) - if fp8_available and dtype in [torch.float8_e5m2, torch.float8_e4m3fn]: - permute_input = permute_input.half() - permute_input.requires_grad_(True) + pytorch_permute_fwd_input.requires_grad_(True) if num_tokens > 0: indices = torch.stack([torch.randperm(num_expert)[:topK] for _ in range(num_tokens)]) @@ -169,63 +242,99 @@ def test_permutation( ################################################################################################################################### # - # PyTorch + # PyTorch Permutation # ################################################################################################################################### - permute_output, sorted_indices = pytorch_permute(permute_input, indices, num_out_tokens) - permute_bwd_input = torch.rand_like(permute_output) - permute_output.backward(permute_bwd_input, retain_graph=True) + pytorch_permute_output, sorted_indices = pytorch_permute( + pytorch_permute_fwd_input, indices, num_out_tokens + ) + pytorch_permute_output.backward(pytorch_permute_bwd_input, retain_graph=True) - unpermute_input = permute_output.detach() - unpermute_input.requires_grad_(True) + pytorch_unpermute_fwd_input = pytorch_permute_output.detach() + 
pytorch_unpermute_fwd_input.requires_grad_(True) - unpermute_output = pytorch_unpermute(unpermute_input, sorted_indices, probs=probs) - unpermute_bwd_input = torch.rand_like(unpermute_output) - unpermute_output.backward(unpermute_bwd_input, retain_graph=True) + pytorch_unpermute_output = pytorch_unpermute( + pytorch_unpermute_fwd_input, sorted_indices, probs=probs + ) + pytorch_unpermute_output.backward(pytorch_unpermute_bwd_input, retain_graph=True) ################################################################################################################################### # - # TE + # TE Permutation # ################################################################################################################################### - te_permute_input = permute_input.detach().to(dtype) - te_permute_bwd_input = permute_bwd_input.detach().to(dtype) - te_unpermute_bwd_input = unpermute_bwd_input.detach().to(dtype) - te_permute_input.requires_grad_(True) + te_permute_fwd_input = ( + permute_fwd_input.view(torch.float32) if fp8 else pytorch_permute_fwd_input.detach() + ) + te_permute_fwd_input.requires_grad_(True) + te_permute_bwd_input = ( + permute_bwd_input.view(torch.float32) if fp8 else pytorch_permute_bwd_input.detach() + ) - te_permute_output, row_id_map = te_permute(te_permute_input, indices, num_out_tokens) + te_permute_output, row_id_map = te_permute( + te_permute_fwd_input, te_dtype, indices, num_out_tokens + ) te_permute_output.backward(te_permute_bwd_input, retain_graph=True) te_probs = None if with_probs: te_probs = probs.detach() te_probs.requires_grad_(True) - te_unpermute_input = te_permute_output.detach() - te_unpermute_input.requires_grad_(True) + te_unpermute_fwd_input = te_permute_output.detach() + te_unpermute_fwd_input.requires_grad_(True) + te_unpermute_bwd_input = ( + unpermute_bwd_input.view(torch.float32) if fp8 else pytorch_unpermute_bwd_input.detach() + ) - te_unpermute_output = te_unpermute(te_unpermute_input, row_id_map, te_probs) + 
te_unpermute_output = te_unpermute(te_unpermute_fwd_input, te_dtype, row_id_map, te_probs) te_unpermute_output.backward(te_unpermute_bwd_input, retain_graph=True) - tols = dtype_tols(dtype) + ################################################################################################################################### + # + # Results Check + # + ################################################################################################################################### + tols = dtype_tols(te_dtype) + + if fp8: + te_permute_output_ = fp8_to_fp16( + te_permute_output.view(torch.uint8), te_dtype == tex.DType.kFloat8E4M3 + ) + te_permute_fwd_input_grad = fp8_to_fp16( + te_permute_fwd_input.grad.view(torch.uint8), te_dtype == tex.DType.kFloat8E4M3 + ) + te_unpermute_output_ = fp8_to_fp16( + te_unpermute_output.view(torch.uint8), te_dtype == tex.DType.kFloat8E4M3 + ) + te_unpermute_fwd_input_grad = fp8_to_fp16( + te_unpermute_fwd_input.grad.view(torch.uint8), te_dtype == tex.DType.kFloat8E4M3 + ) + else: + te_permute_output_ = te_permute_output + te_permute_fwd_input_grad = te_permute_fwd_input.grad + te_unpermute_output_ = te_unpermute_output + te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad torch.testing.assert_close( - permute_output.float(), te_permute_output.float(), msg=f"Mismatch in te_permute fwd" + pytorch_permute_output.float(), + te_permute_output_.float(), + msg=f"Mismatch in te_permute fwd", ) torch.testing.assert_close( - permute_input.grad.float(), - te_permute_input.grad.float(), + pytorch_permute_fwd_input.grad.float(), + te_permute_fwd_input_grad.float(), msg=f"Mismatch in te_permute bwd", **tols, ) torch.testing.assert_close( - unpermute_output.float(), - te_unpermute_output.float(), + pytorch_unpermute_output.float(), + te_unpermute_output_.float(), msg=f"Mismatch in te_unpermute fwd", **tols, ) torch.testing.assert_close( - unpermute_input.grad.float(), - te_unpermute_input.grad.float(), + pytorch_unpermute_fwd_input.grad.float(), 
+ te_unpermute_fwd_input_grad.float(), msg=f"Mismatch in te_unpermute bwd", **tols, ) @@ -234,8 +343,8 @@ def test_permutation( probs.grad.float(), te_probs.grad.float(), msg=f"Mismatch in te_unpermute bwd", **tols ) - if not permute_input.numel(): - print("Empty permute_input activation test passed.") + if not pytorch_permute_fwd_input.numel(): + print("Empty pytorch_permute_fwd_input activation test passed.") return ################################################################################################################################### @@ -253,15 +362,19 @@ def backward_wrapper( return act.backward(backward_input, retain_graph=retain_graph) if BENCHMARK: - t1 = perf_test_cuda_kernel(lambda: pytorch_permute(permute_input, indices, num_out_tokens)) - t2 = perf_test_cuda_kernel(lambda: te_permute(te_permute_input, indices, num_out_tokens)) + t1 = perf_test_cuda_kernel( + lambda: pytorch_permute(pytorch_permute_fwd_input, indices, num_out_tokens) + ) + t2 = perf_test_cuda_kernel( + lambda: te_permute(te_permute_fwd_input, te_dtype, indices, num_out_tokens) + ) print(f"permute\t\tfwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms") t1 = perf_test_cuda_kernel( lambda: backward_wrapper( - permute_output, - permute_bwd_input, - forward_input=[permute_input], + pytorch_permute_output, + pytorch_permute_bwd_input, + forward_input=[pytorch_permute_fwd_input], retain_graph=True, accumulate_grad=False, ) @@ -270,7 +383,7 @@ def backward_wrapper( lambda: backward_wrapper( te_permute_output, te_permute_bwd_input, - forward_input=[te_permute_input], + forward_input=[te_permute_fwd_input], retain_graph=True, accumulate_grad=False, ) @@ -278,16 +391,22 @@ def backward_wrapper( print(f"permute\t\tbwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms") t1 = perf_test_cuda_kernel( - lambda: pytorch_unpermute(unpermute_input, sorted_indices, probs=probs) + lambda: pytorch_unpermute(pytorch_unpermute_fwd_input, sorted_indices, probs=probs) + ) + t2 = perf_test_cuda_kernel( + lambda: 
te_unpermute(te_unpermute_fwd_input, te_dtype, row_id_map, te_probs) ) - t2 = perf_test_cuda_kernel(lambda: te_unpermute(te_unpermute_input, row_id_map, te_probs)) print(f"unpermute\tfwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms") t1 = perf_test_cuda_kernel( lambda: backward_wrapper( - unpermute_output, - unpermute_bwd_input, - forward_input=[unpermute_input, probs] if with_probs else [unpermute_input], + pytorch_unpermute_output, + pytorch_unpermute_bwd_input, + forward_input=( + [pytorch_unpermute_fwd_input, probs] + if with_probs + else [pytorch_unpermute_fwd_input] + ), retain_graph=True, accumulate_grad=False, ) @@ -297,7 +416,7 @@ def backward_wrapper( te_unpermute_output, te_unpermute_bwd_input, forward_input=( - [te_unpermute_input, te_probs] if with_probs else [te_unpermute_input] + [te_unpermute_fwd_input, te_probs] if with_probs else [te_unpermute_fwd_input] ), retain_graph=True, accumulate_grad=False, @@ -325,4 +444,37 @@ def perf_test_cuda_kernel(cuda_kernel_fn): elapsed_time_ms = start_event.elapsed_time(end_event) return elapsed_time_ms / 100 else: - print("CUDA is not available.") + pytest.skip("CUDA is not available.") + + +def test_permute_single_case(): + print("GPU:", torch.cuda.get_device_name(0)) + + # te_dtype = tex.DType.kFloat32 + # te_dtype = tex.DType.kFloat16 + # te_dtype = tex.DType.kBFloat16 + te_dtype = tex.DType.kFloat8E5M2 + # te_dtype = tex.DType.kFloat8E4M3 + + num_tokens = 10 + num_expert = 4 + hidden_size = 16 + topK = 2 + num_out_tokens = num_tokens * topK - 1 + with_probs = True + Benchmark = True + + test_permutation( + te_dtype=te_dtype, + num_tokens=num_tokens, + num_expert=num_expert, + hidden_size=hidden_size, + topK=topK, + num_out_tokens=num_out_tokens, + with_probs=with_probs, + BENCHMARK=Benchmark, + ) + + +if __name__ == "__main__": + test_permute_single_case() diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index f89b9ed722..6649699054 100644 --- 
a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -41,8 +41,8 @@ def _load_library(): from transformer_engine.pytorch.attention import DotProductAttention from transformer_engine.pytorch.attention import InferenceParams from transformer_engine.pytorch.attention import MultiheadAttention -from transformer_engine.pytorch.permutation import Permute, Unpermute from transformer_engine.pytorch.transformer import TransformerLayer +from transformer_engine.pytorch.permutation import permute, unpermute from transformer_engine.pytorch.fp8 import fp8_autocast from transformer_engine.pytorch.fp8 import fp8_model_init from transformer_engine.pytorch.graph import make_graphed_callables diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index c534dc92a1..d4343bba6d 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -15,16 +15,19 @@ **************************************************************************************************/ std::tuple> moe_permute_fwd( - at::Tensor input, at::Tensor indices, int64_t num_out_tokens, std::vector workspace, - int64_t max_expanded_token_num); + at::Tensor input, const transformer_engine::DType dtype, at::Tensor indices, + int64_t num_out_tokens, std::vector workspace, int64_t max_expanded_token_num); -at::Tensor moe_permute_bwd(at::Tensor input, at::Tensor row_id_map, at::Tensor prob, - int64_t num_tokens, int64_t num_topK); +at::Tensor moe_permute_bwd(at::Tensor input, const transformer_engine::DType dtype, + at::Tensor row_id_map, at::Tensor prob, int64_t num_tokens, + int64_t num_topK); -at::Tensor moe_unpermute_fwd(at::Tensor input, at::Tensor row_id_map, at::Tensor prob, - int64_t num_tokens, int64_t num_topK); +at::Tensor moe_unpermute_fwd(at::Tensor input, const transformer_engine::DType dtype, + at::Tensor row_id_map, at::Tensor prob, int64_t num_tokens, + int64_t num_topK); std::tuple 
moe_unpermute_bwd(at::Tensor input_bwd, at::Tensor input_fwd, + const transformer_engine::DType dtype, at::Tensor row_id_map, at::Tensor prob); /*************************************************************************************************** diff --git a/transformer_engine/pytorch/csrc/extensions/permutation.cu b/transformer_engine/pytorch/csrc/extensions/permutation.cu index 493d3a2359..74452d922a 100644 --- a/transformer_engine/pytorch/csrc/extensions/permutation.cu +++ b/transformer_engine/pytorch/csrc/extensions/permutation.cu @@ -10,12 +10,11 @@ using torch::Tensor; -std::tuple> moe_permute_fwd(Tensor input, Tensor indices, - int64_t num_out_tokens, - std::vector workspace, - int64_t max_expanded_token_num) { +std::tuple> moe_permute_fwd( + Tensor input, const transformer_engine::DType dtype, Tensor indices, int64_t num_out_tokens, + std::vector workspace, int64_t max_expanded_token_num) { const int num_tokens = input.size(0); - const int num_cols = input.size(1); + int num_cols = input.size(1); const int num_topK = indices.size(1); // initialize the workspace on the first run @@ -55,7 +54,12 @@ std::tuple> moe_permute_fwd(Tensor input, Te num_tokens * num_topK); // activations type - const at::ScalarType _st = input.scalar_type(); + at::ScalarType _st; + if (dtype == transformer_engine::DType::kFloat8E4M3 || + dtype == transformer_engine::DType::kFloat8E5M2) + _st = at::ScalarType::Float; + else + _st = input.scalar_type(); // Output buffer alloc num_out_tokens = (num_out_tokens > 0) ? 
num_out_tokens : num_tokens * num_topK; @@ -71,36 +75,40 @@ std::tuple> moe_permute_fwd(Tensor input, Te void *input_ptr = getDataPtr(input, 0); void *permuted_output_ptr = getDataPtr(permuted_output, 0); - switch (_st) { - case at::ScalarType::Float: { + if (dtype == transformer_engine::DType::kFloat8E4M3 || + dtype == transformer_engine::DType::kFloat8E5M2) + num_cols *= 4; + + switch (dtype) { + case transformer_engine::DType::kFloat32: { nvte_permutation(input_ptr, permuted_output_ptr, sorted_row_id_ptr, row_id_map_ptr, nullptr, num_tokens, num_topK, num_cols, num_out_tokens, nullptr, nullptr, stream); break; } - case at::ScalarType::Half: { + case transformer_engine::DType::kFloat16: { nvte_permutation(input_ptr, permuted_output_ptr, sorted_row_id_ptr, row_id_map_ptr, nullptr, num_tokens, num_topK, num_cols, num_out_tokens, nullptr, nullptr, stream); break; } - case at::ScalarType::BFloat16: { + case transformer_engine::DType::kBFloat16: { nvte_permutation<__nv_bfloat16, true>(input_ptr, permuted_output_ptr, sorted_row_id_ptr, row_id_map_ptr, nullptr, num_tokens, num_topK, num_cols, num_out_tokens, nullptr, nullptr, stream); break; } - case at::ScalarType::Float8_e5m2: { + case transformer_engine::DType::kFloat8E5M2: { nvte_permutation<__nv_fp8_e5m2, true>(input_ptr, permuted_output_ptr, sorted_row_id_ptr, row_id_map_ptr, nullptr, num_tokens, num_topK, num_cols, num_out_tokens, nullptr, nullptr, stream); break; } - case at::ScalarType::Float8_e4m3fn: { + case transformer_engine::DType::kFloat8E4M3: { nvte_permutation<__nv_fp8_e4m3, true>(input_ptr, permuted_output_ptr, sorted_row_id_ptr, row_id_map_ptr, nullptr, num_tokens, num_topK, num_cols, num_out_tokens, nullptr, nullptr, stream); @@ -114,17 +122,22 @@ std::tuple> moe_permute_fwd(Tensor input, Te return std::make_tuple(permuted_output, row_id_map, workspace); } -Tensor moe_permute_bwd(Tensor input, Tensor row_id_map, Tensor prob, int64_t num_tokens, - int64_t num_topK) { - return moe_unpermute_fwd(input, 
row_id_map, prob, num_tokens, num_topK); +Tensor moe_permute_bwd(Tensor input, const transformer_engine::DType dtype, Tensor row_id_map, + Tensor prob, int64_t num_tokens, int64_t num_topK) { + return moe_unpermute_fwd(input, dtype, row_id_map, prob, num_tokens, num_topK); } -Tensor moe_unpermute_fwd(Tensor input, Tensor row_id_map, Tensor prob, int64_t num_tokens, - int64_t num_topK) { - const int num_cols = input.size(1); +Tensor moe_unpermute_fwd(Tensor input, const transformer_engine::DType dtype, Tensor row_id_map, + Tensor prob, int64_t num_tokens, int64_t num_topK) { + int num_cols = input.size(1); // activations type - const at::ScalarType _st = input.scalar_type(); + at::ScalarType _st; + if (dtype == transformer_engine::DType::kFloat8E4M3 || + dtype == transformer_engine::DType::kFloat8E5M2) + _st = at::ScalarType::Float; + else + _st = input.scalar_type(); // Output buffer alloc Tensor unpermuted_output = torch::empty( @@ -137,36 +150,40 @@ Tensor moe_unpermute_fwd(Tensor input, Tensor row_id_map, Tensor prob, int64_t n void *input_ptr = getDataPtr(input, 0); void *unpermuted_output_ptr = getDataPtr(unpermuted_output, 0); - switch (_st) { - case at::ScalarType::Float: { + if (dtype == transformer_engine::DType::kFloat8E4M3 || + dtype == transformer_engine::DType::kFloat8E5M2) + num_cols *= 4; + + switch (dtype) { + case transformer_engine::DType::kFloat32: { nvte_permutation(input_ptr, unpermuted_output_ptr, nullptr, row_id_map_ptr, prob_ptr, num_tokens, num_topK, num_cols, 0, nullptr, nullptr, stream); break; } - case at::ScalarType::Half: { + case transformer_engine::DType::kFloat16: { nvte_permutation(input_ptr, unpermuted_output_ptr, nullptr, row_id_map_ptr, prob_ptr, num_tokens, num_topK, num_cols, 0, nullptr, nullptr, stream); break; } - case at::ScalarType::BFloat16: { + case transformer_engine::DType::kBFloat16: { nvte_permutation<__nv_bfloat16, false>(input_ptr, unpermuted_output_ptr, nullptr, row_id_map_ptr, prob_ptr, num_tokens, num_topK, 
num_cols, 0, nullptr, nullptr, stream); break; } - case at::ScalarType::Float8_e5m2: { + case transformer_engine::DType::kFloat8E5M2: { nvte_permutation<__nv_fp8_e5m2, false>(input_ptr, unpermuted_output_ptr, nullptr, row_id_map_ptr, prob_ptr, num_tokens, num_topK, num_cols, 0, nullptr, nullptr, stream); break; } - case at::ScalarType::Float8_e4m3fn: { + case transformer_engine::DType::kFloat8E4M3: { nvte_permutation<__nv_fp8_e4m3, false>(input_ptr, unpermuted_output_ptr, nullptr, row_id_map_ptr, prob_ptr, num_tokens, num_topK, num_cols, 0, nullptr, nullptr, stream); @@ -180,17 +197,23 @@ Tensor moe_unpermute_fwd(Tensor input, Tensor row_id_map, Tensor prob, int64_t n return unpermuted_output; } -std::tuple moe_unpermute_bwd(Tensor input_bwd, Tensor input_fwd, Tensor row_id_map, - Tensor prob) { +std::tuple moe_unpermute_bwd(Tensor input_bwd, Tensor input_fwd, + const transformer_engine::DType dtype, + Tensor row_id_map, Tensor prob) { const int num_topK = (prob.numel() > 0) ? prob.size(1) : 1; const int num_tokens = (prob.numel() > 0) ? prob.size(0) : row_id_map.size(0); - const int num_cols = input_bwd.size(1); + int num_cols = input_bwd.size(1); int *row_id_map_ptr = reinterpret_cast(getDataPtr(row_id_map, 0)); float *prob_ptr = (prob.numel() > 0) ? 
reinterpret_cast(getDataPtr(prob, 0)) : nullptr; // activations type - const at::ScalarType _st = input_bwd.scalar_type(); + at::ScalarType _st; + if (dtype == transformer_engine::DType::kFloat8E4M3 || + dtype == transformer_engine::DType::kFloat8E5M2) + _st = at::ScalarType::Float; + else + _st = input_bwd.scalar_type(); // Output buffer alloc Tensor act_grad = torch::empty({input_fwd.size(0), num_cols}, @@ -206,36 +229,40 @@ std::tuple moe_unpermute_bwd(Tensor input_bwd, Tensor input_fwd, void *input_fwd_ptr = getDataPtr(input_fwd, 0); void *act_grad_ptr = getDataPtr(act_grad, 0); - switch (_st) { - case at::ScalarType::Float: { + if (dtype == transformer_engine::DType::kFloat8E4M3 || + dtype == transformer_engine::DType::kFloat8E5M2) + num_cols *= 4; + + switch (dtype) { + case transformer_engine::DType::kFloat32: { nvte_permutation(input_bwd_ptr, act_grad_ptr, nullptr, row_id_map_ptr, prob_ptr, num_tokens, num_topK, num_cols, 0, prob_grad_ptr, input_fwd_ptr, stream); break; } - case at::ScalarType::Half: { + case transformer_engine::DType::kFloat16: { nvte_permutation(input_bwd_ptr, act_grad_ptr, nullptr, row_id_map_ptr, prob_ptr, num_tokens, num_topK, num_cols, 0, prob_grad_ptr, input_fwd_ptr, stream); break; } - case at::ScalarType::BFloat16: { + case transformer_engine::DType::kBFloat16: { nvte_permutation<__nv_bfloat16, true>(input_bwd_ptr, act_grad_ptr, nullptr, row_id_map_ptr, prob_ptr, num_tokens, num_topK, num_cols, 0, prob_grad_ptr, input_fwd_ptr, stream); break; } - case at::ScalarType::Float8_e5m2: { + case transformer_engine::DType::kFloat8E5M2: { nvte_permutation<__nv_fp8_e5m2, true>(input_bwd_ptr, act_grad_ptr, nullptr, row_id_map_ptr, prob_ptr, num_tokens, num_topK, num_cols, 0, prob_grad_ptr, input_fwd_ptr, stream); break; } - case at::ScalarType::Float8_e4m3fn: { + case transformer_engine::DType::kFloat8E4M3: { nvte_permutation<__nv_fp8_e4m3, true>(input_bwd_ptr, act_grad_ptr, nullptr, row_id_map_ptr, prob_ptr, num_tokens, num_topK, num_cols, 
0, prob_grad_ptr, input_fwd_ptr, stream); diff --git a/transformer_engine/pytorch/permutation.py b/transformer_engine/pytorch/permutation.py index 8b12fb274a..8c401b0218 100644 --- a/transformer_engine/pytorch/permutation.py +++ b/transformer_engine/pytorch/permutation.py @@ -16,7 +16,7 @@ ] -class _Permute(torch.autograd.Function): +class _permute(torch.autograd.Function): """functional Permute""" workspace = None @@ -27,6 +27,7 @@ class _Permute(torch.autograd.Function): def forward( ctx, inp: torch.Tensor, + dtype: tex.DType, indices: torch.Tensor, num_out_tokens: int, max_token_num: int, @@ -52,16 +53,16 @@ def forward( topK = indices.size(1) input_max_expanded_token_num = max(max_token_num, inp.size(0)) * topK - if _Permute.max_expanded_token_num < input_max_expanded_token_num: - _Permute.max_expanded_token_num = input_max_expanded_token_num - _Permute.workspace = [] + if _permute.max_expanded_token_num < input_max_expanded_token_num: + _permute.max_expanded_token_num = input_max_expanded_token_num + _permute.workspace = [] - if _Permute.dtype != inp.dtype: - _Permute.dtype = inp.dtype - _Permute.workspace = [] + if _permute.dtype != dtype: + _permute.dtype = dtype + _permute.workspace = [] - permuted_act, row_id_map, _Permute.workspace = tex.moe_permute_fwd( - inp, indices, num_out_tokens, _Permute.workspace, _Permute.max_expanded_token_num + permuted_act, row_id_map, _permute.workspace = tex.moe_permute_fwd( + inp, dtype, indices, num_out_tokens, _permute.workspace, _permute.max_expanded_token_num ) ctx.row_id_map = row_id_map @@ -86,19 +87,23 @@ def backward( num_tokens = ctx.num_tokens topK = ctx.topK - unpermuted_act_grad = tex.moe_permute_bwd( - permuted_act_grad, row_id_map, torch.empty(0), num_tokens, topK - ) - return unpermuted_act_grad, None, None, None + act_grad = None + if ctx.needs_input_grad[0]: + act_grad = tex.moe_permute_bwd( + permuted_act_grad, _permute.dtype, row_id_map, torch.empty(0), num_tokens, topK + ) + return act_grad, None, None, 
None, None -class _Unpermute(torch.autograd.Function): + +class _unpermute(torch.autograd.Function): """functional Unpermute""" @staticmethod def forward( ctx, inp: torch.Tensor, + dtype: tex.DType, row_id_map: torch.Tensor, probs: torch.Tensor, ) -> torch.Tensor: @@ -137,8 +142,9 @@ def forward( ) row_id_map = row_id_map.to(torch.int32) - unpermuted_output = tex.moe_unpermute_fwd(inp, row_id_map, probs, num_tokens, topK) + unpermuted_output = tex.moe_unpermute_fwd(inp, dtype, row_id_map, probs, num_tokens, topK) + ctx.dtype = dtype ctx.save_for_backward(inp, row_id_map, probs) return unpermuted_output @@ -158,15 +164,18 @@ def backward( act_grad = None if ctx.needs_input_grad[0]: - act_grad, prob_grad = tex.moe_unpermute_bwd(unpermuted_act_grad, inp, row_id_map, probs) - - if not ctx.needs_input_grad[2]: + act_grad, prob_grad = tex.moe_unpermute_bwd( + unpermuted_act_grad, inp, ctx.dtype, row_id_map, probs + ) + if not ctx.needs_input_grad[3]: prob_grad = None - return act_grad, None, prob_grad + + return act_grad, None, None, prob_grad -def Permute( +def permute( inp: torch.Tensor, + dtype: tex.DType, indices: torch.Tensor, num_out_tokens: int = -1, max_token_num: int = -1, @@ -188,11 +197,12 @@ def Permute( By default, set to '-1', meaning the calculation of the size of workspace is automatically taken over by the operator. """ - return _Permute.apply(inp, indices, num_out_tokens, max_token_num) + return _permute.apply(inp, dtype, indices, num_out_tokens, max_token_num) -def Unpermute( +def unpermute( inp: torch.Tensor, + dtype: tex.DType, row_id_map: torch.Tensor, probs: torch.Tensor = None, ) -> torch.Tensor: @@ -212,4 +222,4 @@ def Unpermute( the unpermuted tokens will be merged with their respective probabilities. By default, set to an empty tensor, which means that the tokens are directly merged by accumulation. 
""" - return _Unpermute.apply(inp, row_id_map, probs) + return _unpermute.apply(inp, dtype, row_id_map, probs) From b9ad0aec2855a63f9cd5c4624b4276c665b4a159 Mon Sep 17 00:00:00 2001 From: Jiang Shao Date: Wed, 24 Jul 2024 21:01:39 +0000 Subject: [PATCH 19/33] Move dtype dispatch from pytorch dir to common dir Signed-off-by: Jiang Shao --- .../include/transformer_engine/permutation.h | 14 +- .../common/permutation/permutation.cu | 161 +++++++++--------- .../pytorch/csrc/extensions/permutation.cu | 123 +------------ 3 files changed, 100 insertions(+), 198 deletions(-) diff --git a/transformer_engine/common/include/transformer_engine/permutation.h b/transformer_engine/common/include/transformer_engine/permutation.h index fd18669008..d5ce692eb4 100644 --- a/transformer_engine/common/include/transformer_engine/permutation.h +++ b/transformer_engine/common/include/transformer_engine/permutation.h @@ -9,10 +9,14 @@ #include "transformer_engine.h" -template -void nvte_permutation(const void *input, void *output, const int *sorted_row_id, int *row_id_map, - const float *prob, const int num_rows, const int num_topK, const int num_cols, - const int num_out_tokens, float *prob_grad = nullptr, - const void *input_fwd = nullptr, cudaStream_t stream = nullptr); +void nvte_permute(const void *input, void *output, const transformer_engine::DType dtype, + const int *sorted_row_id, int *row_id_map, const float *prob, const int num_rows, + const int num_topK, const int num_cols, const int num_out_tokens, + float *prob_grad = nullptr, const void *input_fwd = nullptr, + cudaStream_t stream = nullptr); + +void nvte_unpermute(const void *input, void *output, const transformer_engine::DType dtype, + int *row_id_map, const float *prob, const int num_rows, const int num_topK, + const int num_cols, cudaStream_t stream = nullptr); #endif // TRANSFORMER_ENGINE_PERMUTATION_H_ diff --git a/transformer_engine/common/permutation/permutation.cu b/transformer_engine/common/permutation/permutation.cu 
index 12262c805f..37d8ed191d 100644 --- a/transformer_engine/common/permutation/permutation.cu +++ b/transformer_engine/common/permutation/permutation.cu @@ -199,98 +199,107 @@ __global__ void moe_permute_kernel(const T *input_bwd, const T *input_fwd, T *ac } } -template -void nvte_permutation(const void *input_, void *output_, const int *sorted_row_id, int *row_id_map, - const float *prob, const int num_rows, const int num_topK, const int num_cols, - const int num_out_tokens, float *prob_grad, const void *input_fwd_, - cudaStream_t stream) { +template +void nvte_permute_launcher(const T *input, T *output, const int *sorted_row_id, int *row_id_map, + const float *prob, const int num_rows, const int num_topK, + const int num_cols, const int num_out_tokens, float *prob_grad, + const T *input_fwd, cudaStream_t stream) { using TCompute = typename std::conditional<(std::is_same::value || std::is_same::value), half, T>::type; static constexpr int kElementsPerAccess = 16 / sizeof(T); - const T *input = reinterpret_cast(input_); - T *output = reinterpret_cast(output_); - const T *input_fwd = reinterpret_cast(input_fwd_); - - if (FWD) { - if (prob == nullptr) { - if (input_fwd == nullptr) { - // Permute fwd - int threads = 64; - int blocks = (num_rows * num_topK + threads - 1) / threads; - moe_permute_row_map<<>>(sorted_row_id, row_id_map, num_rows, - num_topK, num_out_tokens); - - blocks = num_rows; - threads = std::min(num_cols / kElementsPerAccess, 1024); - moe_permute_kernel<<>>( - input, nullptr, output, nullptr, nullptr, row_id_map, num_rows, num_topK, num_cols); - } else { - // Unpermute bwd without probs for topK == 1 - int blocks = num_rows; - int threads = 32; - - moe_permute_kernel<<>>( - input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, num_topK, num_cols); - } + if (prob == nullptr) { + if (input_fwd == nullptr) { + // Permute fwd + int threads = 64; + int blocks = (num_rows * num_topK + threads - 1) / threads; + 
moe_permute_row_map<<>>(sorted_row_id, row_id_map, num_rows, + num_topK, num_out_tokens); + + blocks = num_rows; + threads = std::min(num_cols / kElementsPerAccess, 1024); + moe_permute_kernel<<>>( + input, nullptr, output, nullptr, nullptr, row_id_map, num_rows, num_topK, num_cols); } else { - // Unpermute bwd with probs + // Unpermute bwd without probs for topK == 1 int blocks = num_rows; int threads = 32; - size_t smem_bytes = num_topK * sizeof(TCompute); - - if (num_topK <= 8) { - moe_permute_kernel<<>>( - input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, num_topK, num_cols); - } else if (num_topK <= 16) { - moe_permute_kernel<<>>( - input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, num_topK, num_cols); - } else if (num_topK <= 32) { - moe_permute_kernel<<>>( - input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, num_topK, num_cols); - } else if (num_topK <= 64) { - moe_permute_kernel<<>>( - input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, num_topK, num_cols); - } else if (num_topK <= 128) { - moe_permute_kernel<<>>( - input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, num_topK, num_cols); - } else { - NVTE_ERROR("num_topK cannot exceed 128."); - } + + moe_permute_kernel<<>>( + input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, num_topK, num_cols); } } else { + // Unpermute bwd with probs int blocks = num_rows; - int threads = std::min(num_cols / kElementsPerAccess, 1024); + int threads = 32; size_t smem_bytes = num_topK * sizeof(TCompute); - if (prob == nullptr) { - // Permute bwd - // Unpermute fwd without probs - moe_unpermute_kernel<<>>( - input, output, row_id_map, prob, num_rows, num_topK, num_cols); + if (num_topK <= 8) { + moe_permute_kernel<<>>( + input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, num_topK, num_cols); + } else if (num_topK <= 16) { + moe_permute_kernel<<>>( + input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, num_topK, 
num_cols); + } else if (num_topK <= 32) { + moe_permute_kernel<<>>( + input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, num_topK, num_cols); + } else if (num_topK <= 64) { + moe_permute_kernel<<>>( + input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, num_topK, num_cols); + } else if (num_topK <= 128) { + moe_permute_kernel<<>>( + input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, num_topK, num_cols); } else { - // Unpermute fwd with probs - moe_unpermute_kernel<<>>( - input, output, row_id_map, prob, num_rows, num_topK, num_cols); + NVTE_ERROR("num_topK cannot exceed 128."); } } } -#define FUNCTION_INSTANTIATION(T, FWD) \ - template void nvte_permutation( \ - const void *input, void *output, const int *sorted_row_id, int *row_id_map, \ - const float *prob, const int num_rows, const int num_topK, const int num_cols, \ - const int num_out_tokens, float *prob_grad, const void *input_fwd, cudaStream_t stream); - -FUNCTION_INSTANTIATION(float, true) -FUNCTION_INSTANTIATION(float, false) -FUNCTION_INSTANTIATION(half, true) -FUNCTION_INSTANTIATION(half, false) -FUNCTION_INSTANTIATION(__nv_bfloat16, true) -FUNCTION_INSTANTIATION(__nv_bfloat16, false) -FUNCTION_INSTANTIATION(__nv_fp8_e5m2, true) -FUNCTION_INSTANTIATION(__nv_fp8_e5m2, false) -FUNCTION_INSTANTIATION(__nv_fp8_e4m3, true) -FUNCTION_INSTANTIATION(__nv_fp8_e4m3, false) +template +void nvte_unpermute_launcher(const T *input, T *output, int *row_id_map, const float *prob, + const int num_rows, const int num_topK, const int num_cols, + cudaStream_t stream) { + using TCompute = typename std::conditional<(std::is_same::value || + std::is_same::value), + half, T>::type; + + static constexpr int kElementsPerAccess = 16 / sizeof(T); + + int blocks = num_rows; + int threads = std::min(num_cols / kElementsPerAccess, 1024); + size_t smem_bytes = num_topK * sizeof(TCompute); + + if (prob == nullptr) { + // Permute bwd + // Unpermute fwd without probs + moe_unpermute_kernel<<>>( + 
input, output, row_id_map, prob, num_rows, num_topK, num_cols); + } else { + // Unpermute fwd with probs + moe_unpermute_kernel<<>>( + input, output, row_id_map, prob, num_rows, num_topK, num_cols); + } +} + +void nvte_permute(const void *input, void *output, const transformer_engine::DType dtype, + const int *sorted_row_id, int *row_id_map, const float *prob, const int num_rows, + const int num_topK, const int num_cols, const int num_out_tokens, + float *prob_grad, const void *input_fwd, cudaStream_t stream) { + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( + dtype, T, + nvte_permute_launcher(reinterpret_cast(input), reinterpret_cast(output), + sorted_row_id, row_id_map, prob, num_rows, num_topK, num_cols, + num_out_tokens, prob_grad, reinterpret_cast(input_fwd), + stream);); +} + +void nvte_unpermute(const void *input, void *output, const transformer_engine::DType dtype, + int *row_id_map, const float *prob, const int num_rows, const int num_topK, + const int num_cols, cudaStream_t stream) { + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( + dtype, T, + nvte_unpermute_launcher(reinterpret_cast(input), reinterpret_cast(output), + row_id_map, prob, num_rows, num_topK, num_cols, stream);); +} diff --git a/transformer_engine/pytorch/csrc/extensions/permutation.cu b/transformer_engine/pytorch/csrc/extensions/permutation.cu index 74452d922a..803d2b4378 100644 --- a/transformer_engine/pytorch/csrc/extensions/permutation.cu +++ b/transformer_engine/pytorch/csrc/extensions/permutation.cu @@ -79,45 +79,8 @@ std::tuple> moe_permute_fwd( dtype == transformer_engine::DType::kFloat8E5M2) num_cols *= 4; - switch (dtype) { - case transformer_engine::DType::kFloat32: { - nvte_permutation(input_ptr, permuted_output_ptr, sorted_row_id_ptr, - row_id_map_ptr, nullptr, num_tokens, num_topK, num_cols, - num_out_tokens, nullptr, nullptr, stream); - - break; - } - case transformer_engine::DType::kFloat16: { - nvte_permutation(input_ptr, permuted_output_ptr, sorted_row_id_ptr, - row_id_map_ptr, nullptr, 
num_tokens, num_topK, num_cols, - num_out_tokens, nullptr, nullptr, stream); - - break; - } - case transformer_engine::DType::kBFloat16: { - nvte_permutation<__nv_bfloat16, true>(input_ptr, permuted_output_ptr, sorted_row_id_ptr, - row_id_map_ptr, nullptr, num_tokens, num_topK, num_cols, - num_out_tokens, nullptr, nullptr, stream); - - break; - } - case transformer_engine::DType::kFloat8E5M2: { - nvte_permutation<__nv_fp8_e5m2, true>(input_ptr, permuted_output_ptr, sorted_row_id_ptr, - row_id_map_ptr, nullptr, num_tokens, num_topK, num_cols, - num_out_tokens, nullptr, nullptr, stream); - - break; - } - case transformer_engine::DType::kFloat8E4M3: { - nvte_permutation<__nv_fp8_e4m3, true>(input_ptr, permuted_output_ptr, sorted_row_id_ptr, - row_id_map_ptr, nullptr, num_tokens, num_topK, num_cols, - num_out_tokens, nullptr, nullptr, stream); - - break; - } - default: - NVTE_ERROR("Wrong activation tensor type."); - } + nvte_permute(input_ptr, permuted_output_ptr, dtype, sorted_row_id_ptr, row_id_map_ptr, nullptr, + num_tokens, num_topK, num_cols, num_out_tokens, nullptr, nullptr, stream); return std::make_tuple(permuted_output, row_id_map, workspace); } @@ -154,45 +117,8 @@ Tensor moe_unpermute_fwd(Tensor input, const transformer_engine::DType dtype, Te dtype == transformer_engine::DType::kFloat8E5M2) num_cols *= 4; - switch (dtype) { - case transformer_engine::DType::kFloat32: { - nvte_permutation(input_ptr, unpermuted_output_ptr, nullptr, row_id_map_ptr, - prob_ptr, num_tokens, num_topK, num_cols, 0, nullptr, nullptr, - stream); - - break; - } - case transformer_engine::DType::kFloat16: { - nvte_permutation(input_ptr, unpermuted_output_ptr, nullptr, row_id_map_ptr, - prob_ptr, num_tokens, num_topK, num_cols, 0, nullptr, nullptr, - stream); - - break; - } - case transformer_engine::DType::kBFloat16: { - nvte_permutation<__nv_bfloat16, false>(input_ptr, unpermuted_output_ptr, nullptr, - row_id_map_ptr, prob_ptr, num_tokens, num_topK, - num_cols, 0, nullptr, nullptr, 
stream); - - break; - } - case transformer_engine::DType::kFloat8E5M2: { - nvte_permutation<__nv_fp8_e5m2, false>(input_ptr, unpermuted_output_ptr, nullptr, - row_id_map_ptr, prob_ptr, num_tokens, num_topK, - num_cols, 0, nullptr, nullptr, stream); - - break; - } - case transformer_engine::DType::kFloat8E4M3: { - nvte_permutation<__nv_fp8_e4m3, false>(input_ptr, unpermuted_output_ptr, nullptr, - row_id_map_ptr, prob_ptr, num_tokens, num_topK, - num_cols, 0, nullptr, nullptr, stream); - - break; - } - default: - NVTE_ERROR("Wrong activation tensor type."); - } + nvte_unpermute(input_ptr, unpermuted_output_ptr, dtype, row_id_map_ptr, prob_ptr, num_tokens, + num_topK, num_cols, stream); return unpermuted_output; } @@ -233,45 +159,8 @@ std::tuple moe_unpermute_bwd(Tensor input_bwd, Tensor input_fwd, dtype == transformer_engine::DType::kFloat8E5M2) num_cols *= 4; - switch (dtype) { - case transformer_engine::DType::kFloat32: { - nvte_permutation(input_bwd_ptr, act_grad_ptr, nullptr, row_id_map_ptr, prob_ptr, - num_tokens, num_topK, num_cols, 0, prob_grad_ptr, input_fwd_ptr, - stream); - - break; - } - case transformer_engine::DType::kFloat16: { - nvte_permutation(input_bwd_ptr, act_grad_ptr, nullptr, row_id_map_ptr, prob_ptr, - num_tokens, num_topK, num_cols, 0, prob_grad_ptr, input_fwd_ptr, - stream); - - break; - } - case transformer_engine::DType::kBFloat16: { - nvte_permutation<__nv_bfloat16, true>(input_bwd_ptr, act_grad_ptr, nullptr, row_id_map_ptr, - prob_ptr, num_tokens, num_topK, num_cols, 0, - prob_grad_ptr, input_fwd_ptr, stream); - - break; - } - case transformer_engine::DType::kFloat8E5M2: { - nvte_permutation<__nv_fp8_e5m2, true>(input_bwd_ptr, act_grad_ptr, nullptr, row_id_map_ptr, - prob_ptr, num_tokens, num_topK, num_cols, 0, - prob_grad_ptr, input_fwd_ptr, stream); - - break; - } - case transformer_engine::DType::kFloat8E4M3: { - nvte_permutation<__nv_fp8_e4m3, true>(input_bwd_ptr, act_grad_ptr, nullptr, row_id_map_ptr, - prob_ptr, num_tokens, 
num_topK, num_cols, 0, - prob_grad_ptr, input_fwd_ptr, stream); - - break; - } - default: - NVTE_ERROR("Wrong activation tensor type."); - } + nvte_permute(input_bwd_ptr, act_grad_ptr, dtype, nullptr, row_id_map_ptr, prob_ptr, num_tokens, + num_topK, num_cols, 0, prob_grad_ptr, input_fwd_ptr, stream); return std::make_tuple(act_grad, prob_grad); } From 97a583a1786ca21734b3238bebf1e05fe196b1fe Mon Sep 17 00:00:00 2001 From: Jiang Shao Date: Wed, 24 Jul 2024 21:06:46 +0000 Subject: [PATCH 20/33] Minor changes Signed-off-by: Jiang Shao --- .../include/transformer_engine/permutation.h | 4 +- .../common/permutation/permutation.cu | 86 +++++++++---------- transformer_engine/pytorch/csrc/extensions.h | 4 +- .../pytorch/csrc/extensions/permutation.cu | 30 +++---- 4 files changed, 61 insertions(+), 63 deletions(-) diff --git a/transformer_engine/common/include/transformer_engine/permutation.h b/transformer_engine/common/include/transformer_engine/permutation.h index d5ce692eb4..a450456e1a 100644 --- a/transformer_engine/common/include/transformer_engine/permutation.h +++ b/transformer_engine/common/include/transformer_engine/permutation.h @@ -11,12 +11,12 @@ void nvte_permute(const void *input, void *output, const transformer_engine::DType dtype, const int *sorted_row_id, int *row_id_map, const float *prob, const int num_rows, - const int num_topK, const int num_cols, const int num_out_tokens, + const int topK, const int num_cols, const int num_out_tokens, float *prob_grad = nullptr, const void *input_fwd = nullptr, cudaStream_t stream = nullptr); void nvte_unpermute(const void *input, void *output, const transformer_engine::DType dtype, - int *row_id_map, const float *prob, const int num_rows, const int num_topK, + int *row_id_map, const float *prob, const int num_rows, const int topK, const int num_cols, cudaStream_t stream = nullptr); #endif // TRANSFORMER_ENGINE_PERMUTATION_H_ diff --git a/transformer_engine/common/permutation/permutation.cu 
b/transformer_engine/common/permutation/permutation.cu index 37d8ed191d..4695e3cab9 100644 --- a/transformer_engine/common/permutation/permutation.cu +++ b/transformer_engine/common/permutation/permutation.cu @@ -9,19 +9,19 @@ #include "../common.h" static __global__ void moe_permute_row_map(const int *sorted_row_id, int *row_id_map, - const int num_rows, const int num_topK, + const int num_rows, const int topK, const int num_out_tokens) { // Each block corresponds to one source token - // row_id_map[num_topK][num_rows] + // row_id_map[topK][num_rows] const int bid = blockIdx.x; const int tid = threadIdx.x; const int idx = bid * blockDim.x + tid; - if (idx >= num_rows * num_topK) return; + if (idx >= num_rows * topK) return; int source_row = sorted_row_id[idx]; - int source_token_id = source_row / num_topK; - int source_topK_id = source_row % num_topK; + int source_token_id = source_row / topK; + int source_topK_id = source_row % topK; if (idx >= num_out_tokens) { row_id_map[source_topK_id * num_rows + source_token_id] = -1; @@ -32,7 +32,7 @@ static __global__ void moe_permute_row_map(const int *sorted_row_id, int *row_id template __global__ void moe_unpermute_kernel(const T *input, T *unpermuted_output, const int *row_id_map, - const float *prob, const int num_rows, const int num_topK, + const float *prob, const int num_rows, const int topK, const int num_cols) { extern __shared__ int8_t s_mem[]; TCompute *s_prob = reinterpret_cast(s_mem); @@ -42,8 +42,8 @@ __global__ void moe_unpermute_kernel(const T *input, T *unpermuted_output, const const int tid = threadIdx.x; if (hasProb) { - for (int i = tid; i < num_topK; i += blockDim.x * blockDim.y) { - s_prob[i] = TCompute(prob[source_token * num_topK + i]); + for (int i = tid; i < topK; i += blockDim.x * blockDim.y) { + s_prob[i] = TCompute(prob[source_token * topK + i]); } __syncthreads(); } @@ -79,7 +79,7 @@ __global__ void moe_unpermute_kernel(const T *input, T *unpermuted_output, const } } - for (int k = 1; k < 
num_topK; k++) { + for (int k = 1; k < topK; k++) { source_row = row_id_map[k * num_rows + source_token]; if (source_row == -1) continue; @@ -116,7 +116,7 @@ __global__ void moe_unpermute_kernel(const T *input, T *unpermuted_output, const template __global__ void moe_permute_kernel(const T *input_bwd, const T *input_fwd, T *act_grad, const float *prob, float *prob_grad, const int *row_id_map, - const int num_rows, const int num_topK, const int num_cols) { + const int num_rows, const int topK, const int num_cols) { extern __shared__ int8_t s_mem[]; TCompute *s_prob = reinterpret_cast(s_mem); @@ -124,8 +124,8 @@ __global__ void moe_permute_kernel(const T *input_bwd, const T *input_fwd, T *ac const int tid = threadIdx.x; if (hasProb) { - for (int i = tid; i < num_topK; i += blockDim.x) { - s_prob[i] = TCompute(prob[source_token * num_topK + i]); + for (int i = tid; i < topK; i += blockDim.x) { + s_prob[i] = TCompute(prob[source_token * topK + i]); } __syncthreads(); } @@ -148,7 +148,7 @@ __global__ void moe_permute_kernel(const T *input_bwd, const T *input_fwd, T *ac int index = source_token; for (int k = 0; k < topKTile; k++) { - if (k == num_topK) break; + if (k == topK) break; int dest_row = row_id_map[index]; index += num_rows; @@ -183,7 +183,7 @@ __global__ void moe_permute_kernel(const T *input_bwd, const T *input_fwd, T *ac if (hasProb) { for (int k = 0; k < topKTile; k++) { - if (k == num_topK) break; + if (k == topK) break; for (int mask = 16; mask > 0; mask /= 2) { accum[k] = accum[k] + __shfl_xor_sync(0xffffffff, accum[k], mask, 32); @@ -192,8 +192,8 @@ __global__ void moe_permute_kernel(const T *input_bwd, const T *input_fwd, T *ac if (tid == 0) { for (int k = 0; k < topKTile; k++) { - if (k == num_topK) break; - prob_grad[source_token * num_topK + k] = accum[k]; + if (k == topK) break; + prob_grad[source_token * topK + k] = accum[k]; } } } @@ -201,7 +201,7 @@ __global__ void moe_permute_kernel(const T *input_bwd, const T *input_fwd, T *ac template void 
nvte_permute_launcher(const T *input, T *output, const int *sorted_row_id, int *row_id_map, - const float *prob, const int num_rows, const int num_topK, + const float *prob, const int num_rows, const int topK, const int num_cols, const int num_out_tokens, float *prob_grad, const T *input_fwd, cudaStream_t stream) { using TCompute = typename std::conditional<(std::is_same::value || @@ -214,52 +214,52 @@ void nvte_permute_launcher(const T *input, T *output, const int *sorted_row_id, if (input_fwd == nullptr) { // Permute fwd int threads = 64; - int blocks = (num_rows * num_topK + threads - 1) / threads; - moe_permute_row_map<<>>(sorted_row_id, row_id_map, num_rows, - num_topK, num_out_tokens); + int blocks = (num_rows * topK + threads - 1) / threads; + moe_permute_row_map<<>>(sorted_row_id, row_id_map, num_rows, topK, + num_out_tokens); blocks = num_rows; threads = std::min(num_cols / kElementsPerAccess, 1024); moe_permute_kernel<<>>( - input, nullptr, output, nullptr, nullptr, row_id_map, num_rows, num_topK, num_cols); + input, nullptr, output, nullptr, nullptr, row_id_map, num_rows, topK, num_cols); } else { // Unpermute bwd without probs for topK == 1 int blocks = num_rows; int threads = 32; moe_permute_kernel<<>>( - input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, num_topK, num_cols); + input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, topK, num_cols); } } else { // Unpermute bwd with probs int blocks = num_rows; int threads = 32; - size_t smem_bytes = num_topK * sizeof(TCompute); + size_t smem_bytes = topK * sizeof(TCompute); - if (num_topK <= 8) { + if (topK <= 8) { moe_permute_kernel<<>>( - input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, num_topK, num_cols); - } else if (num_topK <= 16) { + input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, topK, num_cols); + } else if (topK <= 16) { moe_permute_kernel<<>>( - input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, num_topK, num_cols); - } 
else if (num_topK <= 32) { + input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, topK, num_cols); + } else if (topK <= 32) { moe_permute_kernel<<>>( - input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, num_topK, num_cols); - } else if (num_topK <= 64) { + input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, topK, num_cols); + } else if (topK <= 64) { moe_permute_kernel<<>>( - input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, num_topK, num_cols); - } else if (num_topK <= 128) { + input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, topK, num_cols); + } else if (topK <= 128) { moe_permute_kernel<<>>( - input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, num_topK, num_cols); + input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, topK, num_cols); } else { - NVTE_ERROR("num_topK cannot exceed 128."); + NVTE_ERROR("topK cannot exceed 128."); } } } template void nvte_unpermute_launcher(const T *input, T *output, int *row_id_map, const float *prob, - const int num_rows, const int num_topK, const int num_cols, + const int num_rows, const int topK, const int num_cols, cudaStream_t stream) { using TCompute = typename std::conditional<(std::is_same::value || std::is_same::value), @@ -269,37 +269,37 @@ void nvte_unpermute_launcher(const T *input, T *output, int *row_id_map, const f int blocks = num_rows; int threads = std::min(num_cols / kElementsPerAccess, 1024); - size_t smem_bytes = num_topK * sizeof(TCompute); + size_t smem_bytes = topK * sizeof(TCompute); if (prob == nullptr) { // Permute bwd // Unpermute fwd without probs moe_unpermute_kernel<<>>( - input, output, row_id_map, prob, num_rows, num_topK, num_cols); + input, output, row_id_map, prob, num_rows, topK, num_cols); } else { // Unpermute fwd with probs moe_unpermute_kernel<<>>( - input, output, row_id_map, prob, num_rows, num_topK, num_cols); + input, output, row_id_map, prob, num_rows, topK, num_cols); } } void 
nvte_permute(const void *input, void *output, const transformer_engine::DType dtype, const int *sorted_row_id, int *row_id_map, const float *prob, const int num_rows, - const int num_topK, const int num_cols, const int num_out_tokens, - float *prob_grad, const void *input_fwd, cudaStream_t stream) { + const int topK, const int num_cols, const int num_out_tokens, float *prob_grad, + const void *input_fwd, cudaStream_t stream) { TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( dtype, T, nvte_permute_launcher(reinterpret_cast(input), reinterpret_cast(output), - sorted_row_id, row_id_map, prob, num_rows, num_topK, num_cols, + sorted_row_id, row_id_map, prob, num_rows, topK, num_cols, num_out_tokens, prob_grad, reinterpret_cast(input_fwd), stream);); } void nvte_unpermute(const void *input, void *output, const transformer_engine::DType dtype, - int *row_id_map, const float *prob, const int num_rows, const int num_topK, + int *row_id_map, const float *prob, const int num_rows, const int topK, const int num_cols, cudaStream_t stream) { TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( dtype, T, nvte_unpermute_launcher(reinterpret_cast(input), reinterpret_cast(output), - row_id_map, prob, num_rows, num_topK, num_cols, stream);); + row_id_map, prob, num_rows, topK, num_cols, stream);); } diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index d4343bba6d..5e65c7f482 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -20,11 +20,11 @@ std::tuple> moe_permute_fwd( at::Tensor moe_permute_bwd(at::Tensor input, const transformer_engine::DType dtype, at::Tensor row_id_map, at::Tensor prob, int64_t num_tokens, - int64_t num_topK); + int64_t topK); at::Tensor moe_unpermute_fwd(at::Tensor input, const transformer_engine::DType dtype, at::Tensor row_id_map, at::Tensor prob, int64_t num_tokens, - int64_t num_topK); + int64_t topK); std::tuple moe_unpermute_bwd(at::Tensor input_bwd, at::Tensor 
input_fwd, const transformer_engine::DType dtype, diff --git a/transformer_engine/pytorch/csrc/extensions/permutation.cu b/transformer_engine/pytorch/csrc/extensions/permutation.cu index 803d2b4378..b11666a2e4 100644 --- a/transformer_engine/pytorch/csrc/extensions/permutation.cu +++ b/transformer_engine/pytorch/csrc/extensions/permutation.cu @@ -15,7 +15,7 @@ std::tuple> moe_permute_fwd( std::vector workspace, int64_t max_expanded_token_num) { const int num_tokens = input.size(0); int num_cols = input.size(1); - const int num_topK = indices.size(1); + const int topK = indices.size(1); // initialize the workspace on the first run if (workspace.empty()) { @@ -51,7 +51,7 @@ std::tuple> moe_permute_fwd( cub::DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, indices_ptr, sorted_indices_ptr, row_id_ptr, sorted_row_id_ptr, - num_tokens * num_topK); + num_tokens * topK); // activations type at::ScalarType _st; @@ -62,12 +62,11 @@ std::tuple> moe_permute_fwd( _st = input.scalar_type(); // Output buffer alloc - num_out_tokens = (num_out_tokens > 0) ? num_out_tokens : num_tokens * num_topK; + num_out_tokens = (num_out_tokens > 0) ? 
num_out_tokens : num_tokens * topK; Tensor permuted_output = torch::empty( {num_out_tokens, num_cols}, torch::dtype(_st).device(torch::kCUDA).requires_grad(false)); - Tensor row_id_map = - torch::empty({num_tokens * num_topK}, - torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false)); + Tensor row_id_map = torch::empty( + {num_tokens * topK}, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false)); int *row_id_map_ptr = reinterpret_cast(getDataPtr(row_id_map, 0)); auto stream = at::cuda::getCurrentCUDAStream().stream(); @@ -80,18 +79,18 @@ std::tuple> moe_permute_fwd( num_cols *= 4; nvte_permute(input_ptr, permuted_output_ptr, dtype, sorted_row_id_ptr, row_id_map_ptr, nullptr, - num_tokens, num_topK, num_cols, num_out_tokens, nullptr, nullptr, stream); + num_tokens, topK, num_cols, num_out_tokens, nullptr, nullptr, stream); return std::make_tuple(permuted_output, row_id_map, workspace); } Tensor moe_permute_bwd(Tensor input, const transformer_engine::DType dtype, Tensor row_id_map, - Tensor prob, int64_t num_tokens, int64_t num_topK) { - return moe_unpermute_fwd(input, dtype, row_id_map, prob, num_tokens, num_topK); + Tensor prob, int64_t num_tokens, int64_t topK) { + return moe_unpermute_fwd(input, dtype, row_id_map, prob, num_tokens, topK); } Tensor moe_unpermute_fwd(Tensor input, const transformer_engine::DType dtype, Tensor row_id_map, - Tensor prob, int64_t num_tokens, int64_t num_topK) { + Tensor prob, int64_t num_tokens, int64_t topK) { int num_cols = input.size(1); // activations type @@ -118,7 +117,7 @@ Tensor moe_unpermute_fwd(Tensor input, const transformer_engine::DType dtype, Te num_cols *= 4; nvte_unpermute(input_ptr, unpermuted_output_ptr, dtype, row_id_map_ptr, prob_ptr, num_tokens, - num_topK, num_cols, stream); + topK, num_cols, stream); return unpermuted_output; } @@ -126,7 +125,7 @@ Tensor moe_unpermute_fwd(Tensor input, const transformer_engine::DType dtype, Te std::tuple moe_unpermute_bwd(Tensor input_bwd, Tensor 
input_fwd, const transformer_engine::DType dtype, Tensor row_id_map, Tensor prob) { - const int num_topK = (prob.numel() > 0) ? prob.size(1) : 1; + const int topK = (prob.numel() > 0) ? prob.size(1) : 1; const int num_tokens = (prob.numel() > 0) ? prob.size(0) : row_id_map.size(0); int num_cols = input_bwd.size(1); @@ -144,9 +143,8 @@ std::tuple moe_unpermute_bwd(Tensor input_bwd, Tensor input_fwd, // Output buffer alloc Tensor act_grad = torch::empty({input_fwd.size(0), num_cols}, torch::dtype(_st).device(torch::kCUDA).requires_grad(false)); - Tensor prob_grad = - torch::empty({num_tokens, num_topK}, - torch::dtype(torch::kFloat32).device(torch::kCUDA).requires_grad(false)); + Tensor prob_grad = torch::empty( + {num_tokens, topK}, torch::dtype(torch::kFloat32).device(torch::kCUDA).requires_grad(false)); float *prob_grad_ptr = reinterpret_cast(getDataPtr(prob_grad, 0)); auto stream = at::cuda::getCurrentCUDAStream().stream(); @@ -160,7 +158,7 @@ std::tuple moe_unpermute_bwd(Tensor input_bwd, Tensor input_fwd, num_cols *= 4; nvte_permute(input_bwd_ptr, act_grad_ptr, dtype, nullptr, row_id_map_ptr, prob_ptr, num_tokens, - num_topK, num_cols, 0, prob_grad_ptr, input_fwd_ptr, stream); + topK, num_cols, 0, prob_grad_ptr, input_fwd_ptr, stream); return std::make_tuple(act_grad, prob_grad); } From 31cfa794c2f1e3e4e30074756dbcf47364853460 Mon Sep 17 00:00:00 2001 From: Jiang Shao Date: Thu, 25 Jul 2024 10:34:49 +0000 Subject: [PATCH 21/33] Clear up the code path Signed-off-by: Jiang Shao --- .../common/permutation/permutation.cu | 92 ++++++++++--------- 1 file changed, 49 insertions(+), 43 deletions(-) diff --git a/transformer_engine/common/permutation/permutation.cu b/transformer_engine/common/permutation/permutation.cu index 4695e3cab9..c6c7b6b0d7 100644 --- a/transformer_engine/common/permutation/permutation.cu +++ b/transformer_engine/common/permutation/permutation.cu @@ -210,49 +210,53 @@ void nvte_permute_launcher(const T *input, T *output, const int *sorted_row_id, 
static constexpr int kElementsPerAccess = 16 / sizeof(T); - if (prob == nullptr) { - if (input_fwd == nullptr) { - // Permute fwd - int threads = 64; - int blocks = (num_rows * topK + threads - 1) / threads; - moe_permute_row_map<<>>(sorted_row_id, row_id_map, num_rows, topK, - num_out_tokens); - - blocks = num_rows; - threads = std::min(num_cols / kElementsPerAccess, 1024); - moe_permute_kernel<<>>( - input, nullptr, output, nullptr, nullptr, row_id_map, num_rows, topK, num_cols); - } else { - // Unpermute bwd without probs for topK == 1 - int blocks = num_rows; - int threads = 32; + if (input_fwd == nullptr) { + // permute fwd - moe_permute_kernel<<>>( - input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, topK, num_cols); - } + int threads = 64; + int blocks = (num_rows * topK + threads - 1) / threads; + + moe_permute_row_map<<>>(sorted_row_id, row_id_map, num_rows, topK, + num_out_tokens); + + blocks = num_rows; + threads = std::min(num_cols / kElementsPerAccess, 1024); + moe_permute_kernel<<>>( + input, nullptr, output, nullptr, nullptr, row_id_map, num_rows, topK, num_cols); } else { - // Unpermute bwd with probs - int blocks = num_rows; + // unpermute bwd + int threads = 32; - size_t smem_bytes = topK * sizeof(TCompute); - - if (topK <= 8) { - moe_permute_kernel<<>>( - input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, topK, num_cols); - } else if (topK <= 16) { - moe_permute_kernel<<>>( - input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, topK, num_cols); - } else if (topK <= 32) { - moe_permute_kernel<<>>( - input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, topK, num_cols); - } else if (topK <= 64) { - moe_permute_kernel<<>>( - input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, topK, num_cols); - } else if (topK <= 128) { - moe_permute_kernel<<>>( - input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, topK, num_cols); + int blocks = num_rows; + + if (prob == nullptr) { + // 
unpermute bwd without probs + + moe_permute_kernel<<>>( + input, input_fwd, output, nullptr, nullptr, row_id_map, num_rows, topK, num_cols); } else { - NVTE_ERROR("topK cannot exceed 128."); + // unpermute bwd with probs + + size_t smem_bytes = topK * sizeof(TCompute); + + if (topK <= 8) { + moe_permute_kernel<<>>( + input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, topK, num_cols); + } else if (topK <= 16) { + moe_permute_kernel<<>>( + input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, topK, num_cols); + } else if (topK <= 32) { + moe_permute_kernel<<>>( + input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, topK, num_cols); + } else if (topK <= 64) { + moe_permute_kernel<<>>( + input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, topK, num_cols); + } else if (topK <= 128) { + moe_permute_kernel<<>>( + input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, topK, num_cols); + } else { + NVTE_ERROR("topK cannot exceed 128."); + } } } } @@ -272,12 +276,14 @@ void nvte_unpermute_launcher(const T *input, T *output, int *row_id_map, const f size_t smem_bytes = topK * sizeof(TCompute); if (prob == nullptr) { - // Permute bwd - // Unpermute fwd without probs + // permute bwd + // unpermute fwd without probs + moe_unpermute_kernel<<>>( - input, output, row_id_map, prob, num_rows, topK, num_cols); + input, output, row_id_map, nullptr, num_rows, topK, num_cols); } else { - // Unpermute fwd with probs + // unpermute fwd with probs + moe_unpermute_kernel<<>>( input, output, row_id_map, prob, num_rows, topK, num_cols); } From b37533e7b5ee9bf332c43e2eddc8cd90e8329c87 Mon Sep 17 00:00:00 2001 From: Jiang Shao Date: Thu, 25 Jul 2024 12:13:06 +0000 Subject: [PATCH 22/33] Minor Changes Signed-off-by: Jiang Shao --- .../common/permutation/permutation.cu | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/transformer_engine/common/permutation/permutation.cu 
b/transformer_engine/common/permutation/permutation.cu index c6c7b6b0d7..cefcdd2656 100644 --- a/transformer_engine/common/permutation/permutation.cu +++ b/transformer_engine/common/permutation/permutation.cu @@ -211,7 +211,7 @@ void nvte_permute_launcher(const T *input, T *output, const int *sorted_row_id, static constexpr int kElementsPerAccess = 16 / sizeof(T); if (input_fwd == nullptr) { - // permute fwd + // moe_permute_fwd int threads = 64; int blocks = (num_rows * topK + threads - 1) / threads; @@ -224,18 +224,18 @@ void nvte_permute_launcher(const T *input, T *output, const int *sorted_row_id, moe_permute_kernel<<>>( input, nullptr, output, nullptr, nullptr, row_id_map, num_rows, topK, num_cols); } else { - // unpermute bwd + // moe_unpermute_bwd int threads = 32; int blocks = num_rows; if (prob == nullptr) { - // unpermute bwd without probs + // moe_unpermute_bwd without probs moe_permute_kernel<<>>( input, input_fwd, output, nullptr, nullptr, row_id_map, num_rows, topK, num_cols); } else { - // unpermute bwd with probs + // moe_unpermute_bwd with probs size_t smem_bytes = topK * sizeof(TCompute); @@ -276,13 +276,13 @@ void nvte_unpermute_launcher(const T *input, T *output, int *row_id_map, const f size_t smem_bytes = topK * sizeof(TCompute); if (prob == nullptr) { - // permute bwd - // unpermute fwd without probs + // moe_permute_bwd + // moe_unpermute_fwd without probs moe_unpermute_kernel<<>>( input, output, row_id_map, nullptr, num_rows, topK, num_cols); } else { - // unpermute fwd with probs + // moe_unpermute_fwd with probs moe_unpermute_kernel<<>>( input, output, row_id_map, prob, num_rows, topK, num_cols); From c8a9977c953601dffd968b83d419aa566f226ecb Mon Sep 17 00:00:00 2001 From: Jiang Shao Date: Thu, 25 Jul 2024 20:56:01 +0000 Subject: [PATCH 23/33] Add some comments Signed-off-by: Jiang Shao --- .../common/permutation/permutation.cu | 23 +++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git 
a/transformer_engine/common/permutation/permutation.cu b/transformer_engine/common/permutation/permutation.cu index cefcdd2656..9ceaf2ae5f 100644 --- a/transformer_engine/common/permutation/permutation.cu +++ b/transformer_engine/common/permutation/permutation.cu @@ -24,8 +24,10 @@ static __global__ void moe_permute_row_map(const int *sorted_row_id, int *row_id int source_topK_id = source_row % topK; if (idx >= num_out_tokens) { + // Set the indices of dropped tokens to -1 row_id_map[source_topK_id * num_rows + source_token_id] = -1; } else { + // Create a row id map for subsequent unpermute operation row_id_map[source_topK_id * num_rows + source_token_id] = idx; } } @@ -37,28 +39,33 @@ __global__ void moe_unpermute_kernel(const T *input, T *unpermuted_output, const extern __shared__ int8_t s_mem[]; TCompute *s_prob = reinterpret_cast(s_mem); - // each block corresponds to one source token + // Each block corresponds to one dest token const int source_token = blockIdx.x; const int tid = threadIdx.x; if (hasProb) { for (int i = tid; i < topK; i += blockDim.x * blockDim.y) { + // Load all the topK probs related to the source row into smem s_prob[i] = TCompute(prob[source_token * topK + i]); } __syncthreads(); } + // Register buffers for vector type (float4) memory access float4 frag_load_store; T *frag_load_store_ptr = reinterpret_cast(&frag_load_store); + // Number of elemments in frag_load_store static constexpr int kElementsPerAccess = 16 / sizeof(T); + // Traverse along the hidden dimention for (int i = tid * kElementsPerAccess; i < num_cols; i += blockDim.x * kElementsPerAccess) { TCompute frag_elem[kElementsPerAccess]; TCompute frag_sum[kElementsPerAccess]; int source_row = row_id_map[source_token]; + // source_row == -1 represents a dropped token if (source_row != -1) { const T *source_row_ptr = input + source_row * num_cols; @@ -120,24 +127,32 @@ __global__ void moe_permute_kernel(const T *input_bwd, const T *input_fwd, T *ac extern __shared__ int8_t s_mem[]; 
TCompute *s_prob = reinterpret_cast(s_mem); + // Each block corresponds to one source token const int source_token = blockIdx.x; const int tid = threadIdx.x; if (hasProb) { for (int i = tid; i < topK; i += blockDim.x) { + // Load all the topK probs related to the source row into smem s_prob[i] = TCompute(prob[source_token * topK + i]); } __syncthreads(); } + // Accumulators for the calculation of prob_grad float accum[topKTile] = {0.0f}; + // Register buffers for vector type (float4) memory access float4 frag_load_store; T *frag_load_store_ptr = reinterpret_cast(&frag_load_store); + // Number of elemments in frag_load_store static constexpr int kElementsPerAccess = 16 / sizeof(T); + // The starting address of each source row const T *source_row_ptr = input_bwd + source_token * num_cols; + + // Traverse along the hidden dimention for (int i = tid * kElementsPerAccess; i < num_cols; i += blockDim.x * kElementsPerAccess) { TCompute frag_src[kElementsPerAccess]; @@ -147,6 +162,7 @@ __global__ void moe_permute_kernel(const T *input_bwd, const T *input_fwd, T *ac int index = source_token; + // Process each row in the corresponding topK rows for (int k = 0; k < topKTile; k++) { if (k == topK) break; @@ -155,9 +171,11 @@ __global__ void moe_permute_kernel(const T *input_bwd, const T *input_fwd, T *ac if (dest_row != -1) { if (hasProb) { + // Calculate act_grad in unpermute bwd for (int e = 0; e < kElementsPerAccess; e++) frag_load_store_ptr[e] = T(frag_src[e] * s_prob[k]); } else { + // permute fwd for (int e = 0; e < kElementsPerAccess; e++) frag_load_store_ptr[e] = T(frag_src[e]); } @@ -165,6 +183,7 @@ __global__ void moe_permute_kernel(const T *input_bwd, const T *input_fwd, T *ac *(float4 *)(dest_row_ptr + i) = frag_load_store; if (hasProb) { + // Inner product calculation for prob_grad in unpermute bwd const T *input_fwd_ptr = input_fwd + dest_row * num_cols; frag_load_store = __ldlu(reinterpret_cast(input_fwd_ptr + i)); @@ -184,7 +203,7 @@ __global__ void 
moe_permute_kernel(const T *input_bwd, const T *input_fwd, T *ac if (hasProb) { for (int k = 0; k < topKTile; k++) { if (k == topK) break; - + // Warp-level reduction for (int mask = 16; mask > 0; mask /= 2) { accum[k] = accum[k] + __shfl_xor_sync(0xffffffff, accum[k], mask, 32); } From 83ddfad23cc4709f8c0412c21eb8b99834f4f151 Mon Sep 17 00:00:00 2001 From: Jiang Shao Date: Thu, 25 Jul 2024 20:59:25 +0000 Subject: [PATCH 24/33] Add some comments Signed-off-by: Jiang Shao --- transformer_engine/pytorch/csrc/extensions.h | 2 +- transformer_engine/pytorch/csrc/extensions/permutation.cu | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 5e65c7f482..44a117d581 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -11,7 +11,7 @@ #include "common/common.h" /*************************************************************************************************** - * permute + * Permutation **************************************************************************************************/ std::tuple> moe_permute_fwd( diff --git a/transformer_engine/pytorch/csrc/extensions/permutation.cu b/transformer_engine/pytorch/csrc/extensions/permutation.cu index b11666a2e4..8966189a96 100644 --- a/transformer_engine/pytorch/csrc/extensions/permutation.cu +++ b/transformer_engine/pytorch/csrc/extensions/permutation.cu @@ -17,7 +17,7 @@ std::tuple> moe_permute_fwd( int num_cols = input.size(1); const int topK = indices.size(1); - // initialize the workspace on the first run + // Initialize the workspace on the first run if (workspace.empty()) { auto options = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false); @@ -53,7 +53,7 @@ std::tuple> moe_permute_fwd( sorted_indices_ptr, row_id_ptr, sorted_row_id_ptr, num_tokens * topK); - // activations type + // Activations type 
at::ScalarType _st; if (dtype == transformer_engine::DType::kFloat8E4M3 || dtype == transformer_engine::DType::kFloat8E5M2) @@ -93,7 +93,7 @@ Tensor moe_unpermute_fwd(Tensor input, const transformer_engine::DType dtype, Te Tensor prob, int64_t num_tokens, int64_t topK) { int num_cols = input.size(1); - // activations type + // Activations type at::ScalarType _st; if (dtype == transformer_engine::DType::kFloat8E4M3 || dtype == transformer_engine::DType::kFloat8E5M2) @@ -132,7 +132,7 @@ std::tuple moe_unpermute_bwd(Tensor input_bwd, Tensor input_fwd, int *row_id_map_ptr = reinterpret_cast(getDataPtr(row_id_map, 0)); float *prob_ptr = (prob.numel() > 0) ? reinterpret_cast(getDataPtr(prob, 0)) : nullptr; - // activations type + // Activations type at::ScalarType _st; if (dtype == transformer_engine::DType::kFloat8E4M3 || dtype == transformer_engine::DType::kFloat8E5M2) From f049343cfa683ac7bba851bb7a4f61618930d51b Mon Sep 17 00:00:00 2001 From: Jiang Shao Date: Thu, 25 Jul 2024 21:04:03 +0000 Subject: [PATCH 25/33] Revise function description Signed-off-by: Jiang Shao --- transformer_engine/pytorch/permutation.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/transformer_engine/pytorch/permutation.py b/transformer_engine/pytorch/permutation.py index 8c401b0218..51f1d438c0 100644 --- a/transformer_engine/pytorch/permutation.py +++ b/transformer_engine/pytorch/permutation.py @@ -187,6 +187,8 @@ def permute( ---------- inp: torch.Tensor Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied. + dtype: tex.DType + Data type of the input tensor. indices: torch.Tensor The token to expert indices tensor of shape [num_tokens, topK] and dtype 'int32'. num_out_tokens: int, default = -1 @@ -214,6 +216,8 @@ def unpermute( ---------- inp: torch.Tensor Input tensor with permuted tokens of shape `[num_tokens, hidden_size]` to be unpermuted. + dtype: tex.DType + Data type of the input tensor. 
row_id_map: torch.Tensor The tensor of a mapping table for sorted indices used to unpermute the tokens, which is the second output tensor of `Permute`. From 2b331d48b8c77ae331609a0dec20344e0166fa5a Mon Sep 17 00:00:00 2001 From: Jiang Shao Date: Thu, 1 Aug 2024 16:55:40 +0000 Subject: [PATCH 26/33] Take NVTETensor as inputs Signed-off-by: Jiang Shao --- .../include/transformer_engine/permutation.h | 15 ++- .../common/permutation/permutation.cu | 71 +++++++++--- .../pytorch/csrc/extensions/permutation.cu | 108 ++++++++++-------- 3 files changed, 123 insertions(+), 71 deletions(-) diff --git a/transformer_engine/common/include/transformer_engine/permutation.h b/transformer_engine/common/include/transformer_engine/permutation.h index a450456e1a..c6263bf87e 100644 --- a/transformer_engine/common/include/transformer_engine/permutation.h +++ b/transformer_engine/common/include/transformer_engine/permutation.h @@ -9,14 +9,13 @@ #include "transformer_engine.h" -void nvte_permute(const void *input, void *output, const transformer_engine::DType dtype, - const int *sorted_row_id, int *row_id_map, const float *prob, const int num_rows, - const int topK, const int num_cols, const int num_out_tokens, - float *prob_grad = nullptr, const void *input_fwd = nullptr, - cudaStream_t stream = nullptr); +void nvte_permute(const NVTETensor input, NVTETensor output, const NVTETensor sorted_row_id, + NVTETensor row_id_map, const NVTETensor prob, NVTETensor prob_grad, + const NVTETensor input_fwd, const int num_rows, const int topK, + const int num_cols, const int num_out_tokens, cudaStream_t stream = nullptr); -void nvte_unpermute(const void *input, void *output, const transformer_engine::DType dtype, - int *row_id_map, const float *prob, const int num_rows, const int topK, - const int num_cols, cudaStream_t stream = nullptr); +void nvte_unpermute(const NVTETensor input, NVTETensor output, NVTETensor row_id_map, + const NVTETensor prob, const int num_rows, const int topK, const int 
num_cols, + cudaStream_t stream = nullptr); #endif // TRANSFORMER_ENGINE_PERMUTATION_H_ diff --git a/transformer_engine/common/permutation/permutation.cu b/transformer_engine/common/permutation/permutation.cu index 9ceaf2ae5f..ab1108ed5a 100644 --- a/transformer_engine/common/permutation/permutation.cu +++ b/transformer_engine/common/permutation/permutation.cu @@ -220,9 +220,9 @@ __global__ void moe_permute_kernel(const T *input_bwd, const T *input_fwd, T *ac template void nvte_permute_launcher(const T *input, T *output, const int *sorted_row_id, int *row_id_map, - const float *prob, const int num_rows, const int topK, - const int num_cols, const int num_out_tokens, float *prob_grad, - const T *input_fwd, cudaStream_t stream) { + const float *prob, float *prob_grad, const T *input_fwd, + const int num_rows, const int topK, const int num_cols, + const int num_out_tokens, cudaStream_t stream) { using TCompute = typename std::conditional<(std::is_same::value || std::is_same::value), half, T>::type; @@ -308,23 +308,58 @@ void nvte_unpermute_launcher(const T *input, T *output, int *row_id_map, const f } } -void nvte_permute(const void *input, void *output, const transformer_engine::DType dtype, - const int *sorted_row_id, int *row_id_map, const float *prob, const int num_rows, - const int topK, const int num_cols, const int num_out_tokens, float *prob_grad, - const void *input_fwd, cudaStream_t stream) { +void nvte_permute(const NVTETensor input, NVTETensor output, const NVTETensor sorted_row_id, + NVTETensor row_id_map, const NVTETensor prob, NVTETensor prob_grad, + const NVTETensor input_fwd, const int num_rows, const int topK, + const int num_cols, const int num_out_tokens, cudaStream_t stream) { + NVTE_API_CALL(nvte_permute); + + const transformer_engine::Tensor *input_cu = + reinterpret_cast(input); + const transformer_engine::Tensor *output_cu = + reinterpret_cast(output); + const transformer_engine::Tensor *sorted_row_id_cu = + reinterpret_cast(sorted_row_id); + 
const transformer_engine::Tensor *row_id_map_cu = + reinterpret_cast(row_id_map); + const transformer_engine::Tensor *prob_cu = + reinterpret_cast(prob); + const transformer_engine::Tensor *prob_grad_cu = + reinterpret_cast(prob_grad); + const transformer_engine::Tensor *input_fwd_cu = + reinterpret_cast(input_fwd); + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( - dtype, T, - nvte_permute_launcher(reinterpret_cast(input), reinterpret_cast(output), - sorted_row_id, row_id_map, prob, num_rows, topK, num_cols, - num_out_tokens, prob_grad, reinterpret_cast(input_fwd), - stream);); + input_cu->data.dtype, T, + nvte_permute_launcher(reinterpret_cast(input_cu->data.dptr), + reinterpret_cast(output_cu->data.dptr), + reinterpret_cast(sorted_row_id_cu->data.dptr), + reinterpret_cast(row_id_map_cu->data.dptr), + reinterpret_cast(prob_cu->data.dptr), + reinterpret_cast(prob_grad_cu->data.dptr), + reinterpret_cast(input_fwd_cu->data.dptr), num_rows, topK, + num_cols, num_out_tokens, stream);); } -void nvte_unpermute(const void *input, void *output, const transformer_engine::DType dtype, - int *row_id_map, const float *prob, const int num_rows, const int topK, - const int num_cols, cudaStream_t stream) { +void nvte_unpermute(const NVTETensor input, NVTETensor output, NVTETensor row_id_map, + const NVTETensor prob, const int num_rows, const int topK, const int num_cols, + cudaStream_t stream) { + NVTE_API_CALL(nvte_unpermute); + + const transformer_engine::Tensor *input_cu = + reinterpret_cast(input); + const transformer_engine::Tensor *output_cu = + reinterpret_cast(output); + const transformer_engine::Tensor *row_id_map_cu = + reinterpret_cast(row_id_map); + const transformer_engine::Tensor *prob_cu = + reinterpret_cast(prob); + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( - dtype, T, - nvte_unpermute_launcher(reinterpret_cast(input), reinterpret_cast(output), - row_id_map, prob, num_rows, topK, num_cols, stream);); + input_cu->data.dtype, T, + 
nvte_unpermute_launcher(reinterpret_cast(input_cu->data.dptr), + reinterpret_cast(output_cu->data.dptr), + reinterpret_cast(row_id_map_cu->data.dptr), + reinterpret_cast(prob_cu->data.dptr), num_rows, topK, + num_cols, stream);); } diff --git a/transformer_engine/pytorch/csrc/extensions/permutation.cu b/transformer_engine/pytorch/csrc/extensions/permutation.cu index 8966189a96..d45ddb03da 100644 --- a/transformer_engine/pytorch/csrc/extensions/permutation.cu +++ b/transformer_engine/pytorch/csrc/extensions/permutation.cu @@ -8,11 +8,9 @@ #include "extensions.h" -using torch::Tensor; - -std::tuple> moe_permute_fwd( - Tensor input, const transformer_engine::DType dtype, Tensor indices, int64_t num_out_tokens, - std::vector workspace, int64_t max_expanded_token_num) { +std::tuple> moe_permute_fwd( + at::Tensor input, const transformer_engine::DType dtype, at::Tensor indices, + int64_t num_out_tokens, std::vector workspace, int64_t max_expanded_token_num) { const int num_tokens = input.size(0); int num_cols = input.size(1); const int topK = indices.size(1); @@ -22,9 +20,9 @@ std::tuple> moe_permute_fwd( auto options = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false); - Tensor sorted_indices = torch::empty(max_expanded_token_num, options); - Tensor row_id = torch::range(0, max_expanded_token_num - 1, 1, options); - Tensor sorted_row_id = + at::Tensor sorted_indices = torch::empty(max_expanded_token_num, options); + at::Tensor row_id = torch::range(0, max_expanded_token_num - 1, 1, options); + at::Tensor sorted_row_id = torch::empty(max_expanded_token_num, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false)); @@ -32,7 +30,7 @@ std::tuple> moe_permute_fwd( int *temp_ptr = nullptr; cub::DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_ptr, temp_ptr, temp_ptr, temp_ptr, max_expanded_token_num); - Tensor temp_storage = torch::empty( + at::Tensor temp_storage = torch::empty( temp_storage_bytes, 
torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false)); workspace.push_back(sorted_indices); @@ -63,34 +61,45 @@ std::tuple> moe_permute_fwd( // Output buffer alloc num_out_tokens = (num_out_tokens > 0) ? num_out_tokens : num_tokens * topK; - Tensor permuted_output = torch::empty( + at::Tensor permuted_output = torch::empty( {num_out_tokens, num_cols}, torch::dtype(_st).device(torch::kCUDA).requires_grad(false)); - Tensor row_id_map = torch::empty( + at::Tensor row_id_map = torch::empty( {num_tokens * topK}, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false)); - int *row_id_map_ptr = reinterpret_cast(getDataPtr(row_id_map, 0)); auto stream = at::cuda::getCurrentCUDAStream().stream(); - void *input_ptr = getDataPtr(input, 0); - void *permuted_output_ptr = getDataPtr(permuted_output, 0); - if (dtype == transformer_engine::DType::kFloat8E4M3 || dtype == transformer_engine::DType::kFloat8E5M2) num_cols *= 4; - nvte_permute(input_ptr, permuted_output_ptr, dtype, sorted_row_id_ptr, row_id_map_ptr, nullptr, - num_tokens, topK, num_cols, num_out_tokens, nullptr, nullptr, stream); + auto input_cu = makeTransformerEngineTensor( + input.data_ptr(), {static_cast(input.size(0)), static_cast(num_cols)}, dtype); + auto permuted_output_cu = makeTransformerEngineTensor( + permuted_output.data_ptr(), + {static_cast(permuted_output.size(0)), static_cast(num_cols)}, dtype); + auto sorted_row_id_cu = + makeTransformerEngineTensor(sorted_row_id_ptr, {static_cast(num_tokens * topK)}, + transformer_engine::DType::kInt32); + auto row_id_map_cu = makeTransformerEngineTensor(row_id_map); + + nvte_permute(input_cu.data(), permuted_output_cu.data(), sorted_row_id_cu.data(), + row_id_map_cu.data(), transformer_engine::TensorWrapper().data(), + transformer_engine::TensorWrapper().data(), + transformer_engine::TensorWrapper().data(), num_tokens, topK, num_cols, + num_out_tokens, stream); return std::make_tuple(permuted_output, row_id_map, workspace); } -Tensor 
moe_permute_bwd(Tensor input, const transformer_engine::DType dtype, Tensor row_id_map, - Tensor prob, int64_t num_tokens, int64_t topK) { +at::Tensor moe_permute_bwd(at::Tensor input, const transformer_engine::DType dtype, + at::Tensor row_id_map, at::Tensor prob, int64_t num_tokens, + int64_t topK) { return moe_unpermute_fwd(input, dtype, row_id_map, prob, num_tokens, topK); } -Tensor moe_unpermute_fwd(Tensor input, const transformer_engine::DType dtype, Tensor row_id_map, - Tensor prob, int64_t num_tokens, int64_t topK) { +at::Tensor moe_unpermute_fwd(at::Tensor input, const transformer_engine::DType dtype, + at::Tensor row_id_map, at::Tensor prob, int64_t num_tokens, + int64_t topK) { int num_cols = input.size(1); // Activations type @@ -102,36 +111,36 @@ Tensor moe_unpermute_fwd(Tensor input, const transformer_engine::DType dtype, Te _st = input.scalar_type(); // Output buffer alloc - Tensor unpermuted_output = torch::empty( + at::Tensor unpermuted_output = torch::empty( {num_tokens, num_cols}, torch::dtype(_st).device(torch::kCUDA).requires_grad(false)); - int *row_id_map_ptr = reinterpret_cast(getDataPtr(row_id_map, 0)); - float *prob_ptr = (prob.numel() > 0) ? 
reinterpret_cast(getDataPtr(prob, 0)) : nullptr; auto stream = at::cuda::getCurrentCUDAStream().stream(); - void *input_ptr = getDataPtr(input, 0); - void *unpermuted_output_ptr = getDataPtr(unpermuted_output, 0); - if (dtype == transformer_engine::DType::kFloat8E4M3 || dtype == transformer_engine::DType::kFloat8E5M2) num_cols *= 4; - nvte_unpermute(input_ptr, unpermuted_output_ptr, dtype, row_id_map_ptr, prob_ptr, num_tokens, - topK, num_cols, stream); + auto input_cu = makeTransformerEngineTensor( + input.data_ptr(), {static_cast(input.size(0)), static_cast(num_cols)}, dtype); + auto unpermuted_output_cu = makeTransformerEngineTensor( + unpermuted_output.data_ptr(), + {static_cast(unpermuted_output.size(0)), static_cast(num_cols)}, dtype); + auto row_id_map_cu = makeTransformerEngineTensor(row_id_map); + auto prob_cu = makeTransformerEngineTensor(prob); + + nvte_unpermute(input_cu.data(), unpermuted_output_cu.data(), row_id_map_cu.data(), prob_cu.data(), + num_tokens, topK, num_cols, stream); return unpermuted_output; } -std::tuple moe_unpermute_bwd(Tensor input_bwd, Tensor input_fwd, - const transformer_engine::DType dtype, - Tensor row_id_map, Tensor prob) { +std::tuple moe_unpermute_bwd(at::Tensor input_bwd, at::Tensor input_fwd, + const transformer_engine::DType dtype, + at::Tensor row_id_map, at::Tensor prob) { const int topK = (prob.numel() > 0) ? prob.size(1) : 1; const int num_tokens = (prob.numel() > 0) ? prob.size(0) : row_id_map.size(0); int num_cols = input_bwd.size(1); - int *row_id_map_ptr = reinterpret_cast(getDataPtr(row_id_map, 0)); - float *prob_ptr = (prob.numel() > 0) ? 
reinterpret_cast(getDataPtr(prob, 0)) : nullptr; - // Activations type at::ScalarType _st; if (dtype == transformer_engine::DType::kFloat8E4M3 || @@ -141,24 +150,33 @@ std::tuple moe_unpermute_bwd(Tensor input_bwd, Tensor input_fwd, _st = input_bwd.scalar_type(); // Output buffer alloc - Tensor act_grad = torch::empty({input_fwd.size(0), num_cols}, - torch::dtype(_st).device(torch::kCUDA).requires_grad(false)); - Tensor prob_grad = torch::empty( + at::Tensor act_grad = torch::empty({input_fwd.size(0), num_cols}, + torch::dtype(_st).device(torch::kCUDA).requires_grad(false)); + at::Tensor prob_grad = torch::empty( {num_tokens, topK}, torch::dtype(torch::kFloat32).device(torch::kCUDA).requires_grad(false)); - float *prob_grad_ptr = reinterpret_cast(getDataPtr(prob_grad, 0)); auto stream = at::cuda::getCurrentCUDAStream().stream(); - void *input_bwd_ptr = getDataPtr(input_bwd, 0); - void *input_fwd_ptr = getDataPtr(input_fwd, 0); - void *act_grad_ptr = getDataPtr(act_grad, 0); - if (dtype == transformer_engine::DType::kFloat8E4M3 || dtype == transformer_engine::DType::kFloat8E5M2) num_cols *= 4; - nvte_permute(input_bwd_ptr, act_grad_ptr, dtype, nullptr, row_id_map_ptr, prob_ptr, num_tokens, - topK, num_cols, 0, prob_grad_ptr, input_fwd_ptr, stream); + auto input_bwd_cu = makeTransformerEngineTensor( + input_bwd.data_ptr(), {static_cast(input_bwd.size(0)), static_cast(num_cols)}, + dtype); + auto act_grad_cu = makeTransformerEngineTensor( + act_grad.data_ptr(), {static_cast(act_grad.size(0)), static_cast(num_cols)}, + dtype); + auto input_fwd_cu = makeTransformerEngineTensor( + input_fwd.data_ptr(), {static_cast(input_fwd.size(0)), static_cast(num_cols)}, + dtype); + auto row_id_map_cu = makeTransformerEngineTensor(row_id_map); + auto prob_cu = makeTransformerEngineTensor(prob); + auto prob_grad_cu = makeTransformerEngineTensor(prob_grad); + + nvte_permute(input_bwd_cu.data(), act_grad_cu.data(), transformer_engine::TensorWrapper().data(), + row_id_map_cu.data(), 
prob_cu.data(), prob_grad_cu.data(), input_fwd_cu.data(), + num_tokens, topK, num_cols, 0, stream); return std::make_tuple(act_grad, prob_grad); } From 6d23c988ecdb597f8e60f07750b6f3d00014503e Mon Sep 17 00:00:00 2001 From: Jiang Shao Date: Thu, 1 Aug 2024 19:04:34 +0000 Subject: [PATCH 27/33] Split unit tests Signed-off-by: Jiang Shao --- tests/pytorch/test_permutation.py | 120 ++++++++++++++++++---- transformer_engine/pytorch/permutation.py | 10 ++ 2 files changed, 110 insertions(+), 20 deletions(-) diff --git a/tests/pytorch/test_permutation.py b/tests/pytorch/test_permutation.py index 719e1ca61b..be0d10726a 100644 --- a/tests/pytorch/test_permutation.py +++ b/tests/pytorch/test_permutation.py @@ -11,20 +11,11 @@ from transformer_engine.pytorch.fp8 import FP8GlobalStateManager import transformer_engine_torch as tex -# Only run FP8 tests on H100. -fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() seed = 1234 torch.manual_seed(seed) torch.cuda.manual_seed(seed) -# TE tensor dtypes -_te_dtypes: List[tex.DType] = [tex.DType.kFloat32, tex.DType.kFloat16] -if is_bf16_compatible(): - _te_dtypes.append(tex.DType.kBFloat16) -if fp8_available: - _te_dtypes.extend([tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) - def pytorch_permute(tokens, indices, num_out_tokens: int = None): """ @@ -156,14 +147,7 @@ def fp8_to_fp16(uint8_tensor, e4m3: bool = True): return float16_tensor -@pytest.mark.parametrize("te_dtype", _te_dtypes) -@pytest.mark.parametrize("num_tokens", [4096]) -@pytest.mark.parametrize("num_expert", [8, 16]) -@pytest.mark.parametrize("hidden_size", [4096]) -@pytest.mark.parametrize("topK", [1, 2, 5]) -@pytest.mark.parametrize("num_out_tokens", [None, 4050]) -@pytest.mark.parametrize("with_probs", [True, False]) -def test_permutation( +def _test_permutation( te_dtype, num_tokens, num_expert, @@ -447,7 +431,103 @@ def perf_test_cuda_kernel(cuda_kernel_fn): pytest.skip("CUDA is not available.") -def test_permute_single_case(): +# TE tensor 
dtypes +_te_dtypes: List[tex.DType] = [tex.DType.kFloat32, tex.DType.kFloat16] +if is_bf16_compatible(): + _te_dtypes.append(tex.DType.kBFloat16) + + +@pytest.mark.parametrize("te_dtype", _te_dtypes) +@pytest.mark.parametrize("num_tokens", [4096]) +@pytest.mark.parametrize("num_expert", [8, 16]) +@pytest.mark.parametrize("hidden_size", [4096]) +@pytest.mark.parametrize("topK", [1, 2, 5]) +@pytest.mark.parametrize("num_out_tokens", [None, 2039]) +def test_permutation( + te_dtype, + num_tokens, + num_expert, + hidden_size, + topK, + num_out_tokens, +): + with_probs = True + BENCHMARK = False + + _test_permutation( + te_dtype=te_dtype, + num_tokens=num_tokens, + num_expert=num_expert, + hidden_size=hidden_size, + topK=topK, + num_out_tokens=num_out_tokens, + with_probs=with_probs, + BENCHMARK=BENCHMARK, + ) + + +# Only run FP8 tests on H100. +fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() + + +@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) +@pytest.mark.parametrize("te_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) +@pytest.mark.parametrize("num_tokens", [2048]) +@pytest.mark.parametrize("num_expert", [8, 16]) +@pytest.mark.parametrize("hidden_size", [4096]) +@pytest.mark.parametrize("topK", [1, 2, 5]) +@pytest.mark.parametrize("num_out_tokens", [None, 2039]) +def test_permutation_fp8( + te_dtype, + num_tokens, + num_expert, + hidden_size, + topK, + num_out_tokens, +): + with_probs = True + BENCHMARK = False + + _test_permutation( + te_dtype=te_dtype, + num_tokens=num_tokens, + num_expert=num_expert, + hidden_size=hidden_size, + topK=topK, + num_out_tokens=num_out_tokens, + with_probs=with_probs, + BENCHMARK=BENCHMARK, + ) + + +@pytest.mark.parametrize("te_dtype", _te_dtypes) +@pytest.mark.parametrize("num_tokens", [4096]) +@pytest.mark.parametrize("num_expert", [8, 16]) +@pytest.mark.parametrize("hidden_size", [4096]) +def test_permutation_topk1_no_probs( + te_dtype, + num_tokens, + num_expert, + hidden_size, 
+): + topK = 1 + num_out_tokens = None + with_probs = False + BENCHMARK = False + + _test_permutation( + te_dtype=te_dtype, + num_tokens=num_tokens, + num_expert=num_expert, + hidden_size=hidden_size, + topK=topK, + num_out_tokens=num_out_tokens, + with_probs=with_probs, + BENCHMARK=BENCHMARK, + ) + + +def test_permutation_single_case(): print("GPU:", torch.cuda.get_device_name(0)) # te_dtype = tex.DType.kFloat32 @@ -464,7 +544,7 @@ def test_permute_single_case(): with_probs = True Benchmark = True - test_permutation( + _test_permutation( te_dtype=te_dtype, num_tokens=num_tokens, num_expert=num_expert, @@ -477,4 +557,4 @@ def test_permute_single_case(): if __name__ == "__main__": - test_permute_single_case() + test_permutation_single_case() diff --git a/transformer_engine/pytorch/permutation.py b/transformer_engine/pytorch/permutation.py index 51f1d438c0..f5d8916aeb 100644 --- a/transformer_engine/pytorch/permutation.py +++ b/transformer_engine/pytorch/permutation.py @@ -43,6 +43,11 @@ def forward( assert inp.size(0) == indices.size(0), "Permute not possible" # Data type check + if dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]: + assert inp.dtype == torch.float32, ( + "When using fp8 data as input, the input should be packed to" + f" {torch.float32} data type first." + ) if indices.dtype != torch.int32: warnings.warn( f"The data type of the input `indices` of Permute is {indices.dtype}! " @@ -135,6 +140,11 @@ def forward( assert row_id_map.is_cuda, "TransformerEngine needs CUDA." # Data type check + if dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]: + assert inp.dtype == torch.float32, ( + "When using fp8 data as input, the input should be packed to" + f" {torch.float32} data type first." + ) if row_id_map.dtype != torch.int32: warnings.warn( f"The data type of the input `row_id_map` of Unpermute is {row_id_map.dtype}! 
" From 0e01f8e650c2eabe65423bc06b028691c4b7a093 Mon Sep 17 00:00:00 2001 From: Jiang Shao Date: Mon, 19 Aug 2024 09:26:42 +0000 Subject: [PATCH 28/33] Change names Signed-off-by: Jiang Shao --- tests/pytorch/test_permutation.py | 2 +- transformer_engine/pytorch/__init__.py | 2 +- transformer_engine/pytorch/permutation.py | 34 +++++++++++------------ 3 files changed, 19 insertions(+), 19 deletions(-) diff --git a/tests/pytorch/test_permutation.py b/tests/pytorch/test_permutation.py index be0d10726a..1700d80184 100644 --- a/tests/pytorch/test_permutation.py +++ b/tests/pytorch/test_permutation.py @@ -6,7 +6,7 @@ import pytest from typing import Dict, List -from transformer_engine.pytorch import permute as te_permute, unpermute as te_unpermute +from transformer_engine.pytorch import moe_permute as te_permute, moe_unpermute as te_unpermute from transformer_engine.pytorch.utils import is_bf16_compatible from transformer_engine.pytorch.fp8 import FP8GlobalStateManager import transformer_engine_torch as tex diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index 6649699054..1fbf4f26fe 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -42,7 +42,7 @@ def _load_library(): from transformer_engine.pytorch.attention import InferenceParams from transformer_engine.pytorch.attention import MultiheadAttention from transformer_engine.pytorch.transformer import TransformerLayer -from transformer_engine.pytorch.permutation import permute, unpermute +from transformer_engine.pytorch.permutation import moe_permute, moe_unpermute from transformer_engine.pytorch.fp8 import fp8_autocast from transformer_engine.pytorch.fp8 import fp8_model_init from transformer_engine.pytorch.graph import make_graphed_callables diff --git a/transformer_engine/pytorch/permutation.py b/transformer_engine/pytorch/permutation.py index f5d8916aeb..32efd45ce3 100644 --- a/transformer_engine/pytorch/permutation.py +++ 
b/transformer_engine/pytorch/permutation.py @@ -11,12 +11,12 @@ __all__ = [ - "Permute", - "Unpermute", + "moe_permute", + "moe_unpermute", ] -class _permute(torch.autograd.Function): +class _moe_permute(torch.autograd.Function): """functional Permute""" workspace = None @@ -58,16 +58,16 @@ def forward( topK = indices.size(1) input_max_expanded_token_num = max(max_token_num, inp.size(0)) * topK - if _permute.max_expanded_token_num < input_max_expanded_token_num: - _permute.max_expanded_token_num = input_max_expanded_token_num - _permute.workspace = [] + if _moe_permute.max_expanded_token_num < input_max_expanded_token_num: + _moe_permute.max_expanded_token_num = input_max_expanded_token_num + _moe_permute.workspace = [] - if _permute.dtype != dtype: - _permute.dtype = dtype - _permute.workspace = [] + if _moe_permute.dtype != dtype: + _moe_permute.dtype = dtype + _moe_permute.workspace = [] - permuted_act, row_id_map, _permute.workspace = tex.moe_permute_fwd( - inp, dtype, indices, num_out_tokens, _permute.workspace, _permute.max_expanded_token_num + permuted_act, row_id_map, _moe_permute.workspace = tex.moe_permute_fwd( + inp, dtype, indices, num_out_tokens, _moe_permute.workspace, _moe_permute.max_expanded_token_num ) ctx.row_id_map = row_id_map @@ -95,13 +95,13 @@ def backward( act_grad = None if ctx.needs_input_grad[0]: act_grad = tex.moe_permute_bwd( - permuted_act_grad, _permute.dtype, row_id_map, torch.empty(0), num_tokens, topK + permuted_act_grad, _moe_permute.dtype, row_id_map, torch.empty(0), num_tokens, topK ) return act_grad, None, None, None, None -class _unpermute(torch.autograd.Function): +class _moe_unpermute(torch.autograd.Function): """functional Unpermute""" @staticmethod @@ -183,7 +183,7 @@ def backward( return act_grad, None, None, prob_grad -def permute( +def moe_permute( inp: torch.Tensor, dtype: tex.DType, indices: torch.Tensor, @@ -209,10 +209,10 @@ def permute( By default, set to '-1', meaning the calculation of the size of workspace is 
automatically taken over by the operator. """ - return _permute.apply(inp, dtype, indices, num_out_tokens, max_token_num) + return _moe_permute.apply(inp, dtype, indices, num_out_tokens, max_token_num) -def unpermute( +def moe_unpermute( inp: torch.Tensor, dtype: tex.DType, row_id_map: torch.Tensor, @@ -236,4 +236,4 @@ def unpermute( the unpermuted tokens will be merged with their respective probabilities. By default, set to an empty tensor, which means that the tokens are directly merged by accumulation. """ - return _unpermute.apply(inp, dtype, row_id_map, probs) + return _moe_unpermute.apply(inp, dtype, row_id_map, probs) From 906dbb8ef989099e47eb7f035aeb32d6e4cad27f Mon Sep 17 00:00:00 2001 From: Jiang Shao Date: Mon, 19 Aug 2024 19:50:00 +0000 Subject: [PATCH 29/33] Use Float8Tensor for FP8 input Signed-off-by: Jiang Shao --- tests/pytorch/test_permutation.py | 105 +++++------------- .../pytorch/csrc/extensions/permutation.cu | 18 +-- transformer_engine/pytorch/permutation.py | 61 +++++++++- 3 files changed, 85 insertions(+), 99 deletions(-) diff --git a/tests/pytorch/test_permutation.py b/tests/pytorch/test_permutation.py index 1700d80184..36cced7d7d 100644 --- a/tests/pytorch/test_permutation.py +++ b/tests/pytorch/test_permutation.py @@ -9,6 +9,7 @@ from transformer_engine.pytorch import moe_permute as te_permute, moe_unpermute as te_unpermute from transformer_engine.pytorch.utils import is_bf16_compatible from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.pytorch.float8_tensor import Float8Tensor import transformer_engine_torch as tex @@ -111,42 +112,6 @@ def dtype_tols(te_dtype: tex.DType) -> Dict[str, float]: raise ValueError(f"Unsuppored dtype ({te_dtype})") -def fp8_to_fp16(uint8_tensor, e4m3: bool = True): - assert uint8_tensor.dtype == torch.uint8, "Input tensor must be uint8" - - float16_tensor = torch.zeros_like(uint8_tensor, dtype=torch.float16) - - sign = (uint8_tensor >> 7) & 1 - exponent_mask = 0xF if e4m3 
else 0x1F - if e4m3: - exponent = (uint8_tensor >> 3) & exponent_mask - mantissa = uint8_tensor & 0x7 - else: - exponent = (uint8_tensor >> 2) & exponent_mask - mantissa = uint8_tensor & 0x3 - - exponent_bias = 7 if e4m3 else 15 - mantissa_max = 8.0 if e4m3 else 4.0 - - normal_mask = (exponent != 0) & ~(exponent == exponent_mask) - actual_exponent = exponent[normal_mask].to(torch.float16) - exponent_bias - actual_mantissa = (mantissa[normal_mask].to(torch.float16) + mantissa_max) / mantissa_max - float16_tensor[normal_mask] = ( - ((-1) ** sign[normal_mask].to(torch.float16)) * (2**actual_exponent) * actual_mantissa - ) - - subnormal_mask = (exponent == 0) & (mantissa != 0) - subnormal_exponent = 1 - exponent_bias - subnormal_mantissa = mantissa[subnormal_mask].to(torch.float16) / mantissa_max - float16_tensor[subnormal_mask] = ( - ((-1) ** sign[subnormal_mask].to(torch.float16)) - * (2**subnormal_exponent) - * subnormal_mantissa - ) - - return float16_tensor - - def _test_permutation( te_dtype, num_tokens, @@ -185,25 +150,17 @@ def _test_permutation( pytest.skip("Invalid dtype.") if fp8: - N = 56 if te_dtype == tex.DType.kFloat8E4M3 else 60 - permute_fwd_input = torch.randint( - low=0, high=N + 1, size=(num_tokens, hidden_size), dtype=torch.uint8 - ).cuda() - permute_bwd_input = torch.randint( - low=0, high=N + 1, size=(num_out_tokens, hidden_size), dtype=torch.uint8 - ).cuda() - unpermute_bwd_input = torch.randint( - low=0, high=N + 1, size=(num_tokens, hidden_size), dtype=torch.uint8 - ).cuda() - pytorch_permute_fwd_input = fp8_to_fp16( - permute_fwd_input, te_dtype == tex.DType.kFloat8E4M3 - ) - pytorch_permute_bwd_input = fp8_to_fp16( - permute_bwd_input, te_dtype == tex.DType.kFloat8E4M3 - ) - pytorch_unpermute_bwd_input = fp8_to_fp16( - unpermute_bwd_input, te_dtype == tex.DType.kFloat8E4M3 - ) + permute_fwd_input = torch.rand(size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda") + permute_bwd_input = torch.rand(size=(num_out_tokens, 
hidden_size), dtype=torch.float32, device="cuda") + unpermute_bwd_input = torch.rand(size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda") + + permute_fwd_input = Float8Tensor.to_float8(permute_fwd_input, fp8_dtype=te_dtype, scale=torch.full([1], 1.0)) + permute_bwd_input = Float8Tensor.to_float8(permute_bwd_input, fp8_dtype=te_dtype, scale=torch.full([1], 1.0)) + unpermute_bwd_input = Float8Tensor.to_float8(unpermute_bwd_input, fp8_dtype=te_dtype, scale=torch.full([1], 1.0)) + + pytorch_permute_fwd_input = permute_fwd_input.from_float8(torch.float16) + pytorch_permute_bwd_input = permute_bwd_input.from_float8(torch.float16) + pytorch_unpermute_bwd_input = unpermute_bwd_input.from_float8(torch.float16) else: pytorch_permute_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() pytorch_permute_bwd_input = torch.rand((num_out_tokens, hidden_size), dtype=dtype).cuda() @@ -248,11 +205,11 @@ def _test_permutation( # ################################################################################################################################### te_permute_fwd_input = ( - permute_fwd_input.view(torch.float32) if fp8 else pytorch_permute_fwd_input.detach() + permute_fwd_input if fp8 else pytorch_permute_fwd_input.detach() ) te_permute_fwd_input.requires_grad_(True) te_permute_bwd_input = ( - permute_bwd_input.view(torch.float32) if fp8 else pytorch_permute_bwd_input.detach() + permute_bwd_input if fp8 else pytorch_permute_bwd_input.detach() ) te_permute_output, row_id_map = te_permute( @@ -267,7 +224,7 @@ def _test_permutation( te_unpermute_fwd_input = te_permute_output.detach() te_unpermute_fwd_input.requires_grad_(True) te_unpermute_bwd_input = ( - unpermute_bwd_input.view(torch.float32) if fp8 else pytorch_unpermute_bwd_input.detach() + unpermute_bwd_input if fp8 else pytorch_unpermute_bwd_input.detach() ) te_unpermute_output = te_unpermute(te_unpermute_fwd_input, te_dtype, row_id_map, te_probs) @@ -281,44 +238,36 @@ def 
_test_permutation( tols = dtype_tols(te_dtype) if fp8: - te_permute_output_ = fp8_to_fp16( - te_permute_output.view(torch.uint8), te_dtype == tex.DType.kFloat8E4M3 - ) - te_permute_fwd_input_grad = fp8_to_fp16( - te_permute_fwd_input.grad.view(torch.uint8), te_dtype == tex.DType.kFloat8E4M3 - ) - te_unpermute_output_ = fp8_to_fp16( - te_unpermute_output.view(torch.uint8), te_dtype == tex.DType.kFloat8E4M3 - ) - te_unpermute_fwd_input_grad = fp8_to_fp16( - te_unpermute_fwd_input.grad.view(torch.uint8), te_dtype == tex.DType.kFloat8E4M3 - ) + te_permute_output_ = te_permute_output.from_float8(torch.float32) + te_permute_fwd_input_grad = te_permute_fwd_input.grad.from_float8(torch.float32) + te_unpermute_output_ = te_unpermute_output.from_float8(torch.float32) + te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.from_float8(torch.float32) else: - te_permute_output_ = te_permute_output - te_permute_fwd_input_grad = te_permute_fwd_input.grad - te_unpermute_output_ = te_unpermute_output - te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad + te_permute_output_ = te_permute_output.float() + te_permute_fwd_input_grad = te_permute_fwd_input.grad.float() + te_unpermute_output_ = te_unpermute_output.float() + te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.float() torch.testing.assert_close( pytorch_permute_output.float(), - te_permute_output_.float(), + te_permute_output_, msg=f"Mismatch in te_permute fwd", ) torch.testing.assert_close( pytorch_permute_fwd_input.grad.float(), - te_permute_fwd_input_grad.float(), + te_permute_fwd_input_grad, msg=f"Mismatch in te_permute bwd", **tols, ) torch.testing.assert_close( pytorch_unpermute_output.float(), - te_unpermute_output_.float(), + te_unpermute_output_, msg=f"Mismatch in te_unpermute fwd", **tols, ) torch.testing.assert_close( pytorch_unpermute_fwd_input.grad.float(), - te_unpermute_fwd_input_grad.float(), + te_unpermute_fwd_input_grad, msg=f"Mismatch in te_unpermute bwd", **tols, ) diff --git 
a/transformer_engine/pytorch/csrc/extensions/permutation.cu b/transformer_engine/pytorch/csrc/extensions/permutation.cu index d45ddb03da..0c9bed45e0 100644 --- a/transformer_engine/pytorch/csrc/extensions/permutation.cu +++ b/transformer_engine/pytorch/csrc/extensions/permutation.cu @@ -55,7 +55,7 @@ std::tuple> moe_permute_fwd( at::ScalarType _st; if (dtype == transformer_engine::DType::kFloat8E4M3 || dtype == transformer_engine::DType::kFloat8E5M2) - _st = at::ScalarType::Float; + _st = at::ScalarType::Byte; else _st = input.scalar_type(); @@ -68,10 +68,6 @@ std::tuple> moe_permute_fwd( auto stream = at::cuda::getCurrentCUDAStream().stream(); - if (dtype == transformer_engine::DType::kFloat8E4M3 || - dtype == transformer_engine::DType::kFloat8E5M2) - num_cols *= 4; - auto input_cu = makeTransformerEngineTensor( input.data_ptr(), {static_cast(input.size(0)), static_cast(num_cols)}, dtype); auto permuted_output_cu = makeTransformerEngineTensor( @@ -106,7 +102,7 @@ at::Tensor moe_unpermute_fwd(at::Tensor input, const transformer_engine::DType d at::ScalarType _st; if (dtype == transformer_engine::DType::kFloat8E4M3 || dtype == transformer_engine::DType::kFloat8E5M2) - _st = at::ScalarType::Float; + _st = at::ScalarType::Byte; else _st = input.scalar_type(); @@ -116,10 +112,6 @@ at::Tensor moe_unpermute_fwd(at::Tensor input, const transformer_engine::DType d auto stream = at::cuda::getCurrentCUDAStream().stream(); - if (dtype == transformer_engine::DType::kFloat8E4M3 || - dtype == transformer_engine::DType::kFloat8E5M2) - num_cols *= 4; - auto input_cu = makeTransformerEngineTensor( input.data_ptr(), {static_cast(input.size(0)), static_cast(num_cols)}, dtype); auto unpermuted_output_cu = makeTransformerEngineTensor( @@ -145,7 +137,7 @@ std::tuple moe_unpermute_bwd(at::Tensor input_bwd, at::T at::ScalarType _st; if (dtype == transformer_engine::DType::kFloat8E4M3 || dtype == transformer_engine::DType::kFloat8E5M2) - _st = at::ScalarType::Float; + _st = 
at::ScalarType::Byte; else _st = input_bwd.scalar_type(); @@ -157,10 +149,6 @@ std::tuple moe_unpermute_bwd(at::Tensor input_bwd, at::T auto stream = at::cuda::getCurrentCUDAStream().stream(); - if (dtype == transformer_engine::DType::kFloat8E4M3 || - dtype == transformer_engine::DType::kFloat8E5M2) - num_cols *= 4; - auto input_bwd_cu = makeTransformerEngineTensor( input_bwd.data_ptr(), {static_cast(input_bwd.size(0)), static_cast(num_cols)}, dtype); diff --git a/transformer_engine/pytorch/permutation.py b/transformer_engine/pytorch/permutation.py index 32efd45ce3..9aa219e098 100644 --- a/transformer_engine/pytorch/permutation.py +++ b/transformer_engine/pytorch/permutation.py @@ -8,6 +8,7 @@ from typing import Tuple import transformer_engine_torch as tex +from transformer_engine.pytorch.float8_tensor import Float8Tensor __all__ = [ @@ -43,11 +44,16 @@ def forward( assert inp.size(0) == indices.size(0), "Permute not possible" # Data type check + fp8 = False if dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]: - assert inp.dtype == torch.float32, ( - "When using fp8 data as input, the input should be packed to" - f" {torch.float32} data type first." + fp8 = True + if fp8: + assert isinstance(inp, Float8Tensor), ( + "Input must be in Float8Tensor type for FP8 moe_permute." ) + fp8_dtype = inp._fp8_dtype + fp8_scale_inv = inp._scale_inv + inp = inp._data if indices.dtype != torch.int32: warnings.warn( f"The data type of the input `indices` of Permute is {indices.dtype}! 
" @@ -69,10 +75,16 @@ def forward( permuted_act, row_id_map, _moe_permute.workspace = tex.moe_permute_fwd( inp, dtype, indices, num_out_tokens, _moe_permute.workspace, _moe_permute.max_expanded_token_num ) + + if fp8: + permuted_act = Float8Tensor(data=permuted_act, + fp8_dtype=fp8_dtype, + fp8_scale_inv=fp8_scale_inv) ctx.row_id_map = row_id_map ctx.num_tokens = indices.size(0) ctx.topK = indices.size(1) + ctx.fp8 = fp8 return permuted_act, row_id_map @staticmethod @@ -88,6 +100,15 @@ def backward( if not permuted_act_grad.is_contiguous(): permuted_act_grad = permuted_act_grad.contiguous() + fp8 = ctx.fp8 + if fp8: + assert isinstance(permuted_act_grad, Float8Tensor), ( + "Grad of the output must be in Float8Tensor type for FP8 moe_permute." + ) + fp8_dtype=permuted_act_grad._fp8_dtype + fp8_scale_inv=permuted_act_grad._scale_inv + permuted_act_grad = permuted_act_grad._data + row_id_map = ctx.row_id_map num_tokens = ctx.num_tokens topK = ctx.topK @@ -97,6 +118,10 @@ def backward( act_grad = tex.moe_permute_bwd( permuted_act_grad, _moe_permute.dtype, row_id_map, torch.empty(0), num_tokens, topK ) + if fp8: + act_grad = Float8Tensor(data=act_grad, + fp8_dtype=fp8_dtype, + fp8_scale_inv=fp8_scale_inv) return act_grad, None, None, None, None @@ -140,11 +165,16 @@ def forward( assert row_id_map.is_cuda, "TransformerEngine needs CUDA." # Data type check + fp8 = False if dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]: - assert inp.dtype == torch.float32, ( - "When using fp8 data as input, the input should be packed to" - f" {torch.float32} data type first." + fp8 = True + if fp8: + assert isinstance(inp, Float8Tensor), ( + "Input must be in Float8Tensor type for FP8 moe_unpermute." ) + fp8_dtype = inp._fp8_dtype + fp8_scale_inv = inp._scale_inv + inp = inp._data if row_id_map.dtype != torch.int32: warnings.warn( f"The data type of the input `row_id_map` of Unpermute is {row_id_map.dtype}! 
" @@ -154,8 +184,14 @@ def forward( unpermuted_output = tex.moe_unpermute_fwd(inp, dtype, row_id_map, probs, num_tokens, topK) + if fp8: + unpermuted_output = Float8Tensor(data=unpermuted_output, + fp8_dtype=fp8_dtype, + fp8_scale_inv=fp8_scale_inv) + ctx.dtype = dtype ctx.save_for_backward(inp, row_id_map, probs) + ctx.fp8 = fp8 return unpermuted_output @staticmethod @@ -170,6 +206,15 @@ def backward( if not unpermuted_act_grad.is_contiguous(): unpermuted_act_grad = unpermuted_act_grad.contiguous() + fp8 = ctx.fp8 + if fp8: + assert isinstance(unpermuted_act_grad, Float8Tensor), ( + "Grad of the output must be in Float8Tensor type for FP8 moe_unpermute." + ) + fp8_dtype=unpermuted_act_grad._fp8_dtype + fp8_scale_inv=unpermuted_act_grad._scale_inv + unpermuted_act_grad = unpermuted_act_grad._data + inp, row_id_map, probs = ctx.saved_tensors act_grad = None @@ -177,6 +222,10 @@ def backward( act_grad, prob_grad = tex.moe_unpermute_bwd( unpermuted_act_grad, inp, ctx.dtype, row_id_map, probs ) + if fp8: + act_grad = Float8Tensor(data=act_grad, + fp8_dtype=fp8_dtype, + fp8_scale_inv=fp8_scale_inv) if not ctx.needs_input_grad[3]: prob_grad = None From 67407816278a05e2b2e942188adb8b056c25cddf Mon Sep 17 00:00:00 2001 From: Jiang Shao Date: Mon, 19 Aug 2024 19:56:18 +0000 Subject: [PATCH 30/33] Reformat Signed-off-by: Jiang Shao --- tests/pytorch/test_permutation.py | 36 +++++++------ transformer_engine/pytorch/permutation.py | 65 ++++++++++++----------- 2 files changed, 56 insertions(+), 45 deletions(-) diff --git a/tests/pytorch/test_permutation.py b/tests/pytorch/test_permutation.py index 36cced7d7d..99bd706b45 100644 --- a/tests/pytorch/test_permutation.py +++ b/tests/pytorch/test_permutation.py @@ -150,13 +150,25 @@ def _test_permutation( pytest.skip("Invalid dtype.") if fp8: - permute_fwd_input = torch.rand(size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda") - permute_bwd_input = torch.rand(size=(num_out_tokens, hidden_size), dtype=torch.float32, 
device="cuda") - unpermute_bwd_input = torch.rand(size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda") + permute_fwd_input = torch.rand( + size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda" + ) + permute_bwd_input = torch.rand( + size=(num_out_tokens, hidden_size), dtype=torch.float32, device="cuda" + ) + unpermute_bwd_input = torch.rand( + size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda" + ) - permute_fwd_input = Float8Tensor.to_float8(permute_fwd_input, fp8_dtype=te_dtype, scale=torch.full([1], 1.0)) - permute_bwd_input = Float8Tensor.to_float8(permute_bwd_input, fp8_dtype=te_dtype, scale=torch.full([1], 1.0)) - unpermute_bwd_input = Float8Tensor.to_float8(unpermute_bwd_input, fp8_dtype=te_dtype, scale=torch.full([1], 1.0)) + permute_fwd_input = Float8Tensor.to_float8( + permute_fwd_input, fp8_dtype=te_dtype, scale=torch.full([1], 1.0) + ) + permute_bwd_input = Float8Tensor.to_float8( + permute_bwd_input, fp8_dtype=te_dtype, scale=torch.full([1], 1.0) + ) + unpermute_bwd_input = Float8Tensor.to_float8( + unpermute_bwd_input, fp8_dtype=te_dtype, scale=torch.full([1], 1.0) + ) pytorch_permute_fwd_input = permute_fwd_input.from_float8(torch.float16) pytorch_permute_bwd_input = permute_bwd_input.from_float8(torch.float16) @@ -204,13 +216,9 @@ def _test_permutation( # TE Permutation # ################################################################################################################################### - te_permute_fwd_input = ( - permute_fwd_input if fp8 else pytorch_permute_fwd_input.detach() - ) + te_permute_fwd_input = permute_fwd_input if fp8 else pytorch_permute_fwd_input.detach() te_permute_fwd_input.requires_grad_(True) - te_permute_bwd_input = ( - permute_bwd_input if fp8 else pytorch_permute_bwd_input.detach() - ) + te_permute_bwd_input = permute_bwd_input if fp8 else pytorch_permute_bwd_input.detach() te_permute_output, row_id_map = te_permute( te_permute_fwd_input, te_dtype, indices, 
num_out_tokens @@ -223,9 +231,7 @@ def _test_permutation( te_probs.requires_grad_(True) te_unpermute_fwd_input = te_permute_output.detach() te_unpermute_fwd_input.requires_grad_(True) - te_unpermute_bwd_input = ( - unpermute_bwd_input if fp8 else pytorch_unpermute_bwd_input.detach() - ) + te_unpermute_bwd_input = unpermute_bwd_input if fp8 else pytorch_unpermute_bwd_input.detach() te_unpermute_output = te_unpermute(te_unpermute_fwd_input, te_dtype, row_id_map, te_probs) te_unpermute_output.backward(te_unpermute_bwd_input, retain_graph=True) diff --git a/transformer_engine/pytorch/permutation.py b/transformer_engine/pytorch/permutation.py index 9aa219e098..422694e1f5 100644 --- a/transformer_engine/pytorch/permutation.py +++ b/transformer_engine/pytorch/permutation.py @@ -48,9 +48,9 @@ def forward( if dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]: fp8 = True if fp8: - assert isinstance(inp, Float8Tensor), ( - "Input must be in Float8Tensor type for FP8 moe_permute." - ) + assert isinstance( + inp, Float8Tensor + ), "Input must be in Float8Tensor type for FP8 moe_permute." fp8_dtype = inp._fp8_dtype fp8_scale_inv = inp._scale_inv inp = inp._data @@ -73,13 +73,18 @@ def forward( _moe_permute.workspace = [] permuted_act, row_id_map, _moe_permute.workspace = tex.moe_permute_fwd( - inp, dtype, indices, num_out_tokens, _moe_permute.workspace, _moe_permute.max_expanded_token_num + inp, + dtype, + indices, + num_out_tokens, + _moe_permute.workspace, + _moe_permute.max_expanded_token_num, ) - + if fp8: - permuted_act = Float8Tensor(data=permuted_act, - fp8_dtype=fp8_dtype, - fp8_scale_inv=fp8_scale_inv) + permuted_act = Float8Tensor( + data=permuted_act, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv + ) ctx.row_id_map = row_id_map ctx.num_tokens = indices.size(0) @@ -102,11 +107,11 @@ def backward( fp8 = ctx.fp8 if fp8: - assert isinstance(permuted_act_grad, Float8Tensor), ( - "Grad of the output must be in Float8Tensor type for FP8 moe_permute." 
- ) - fp8_dtype=permuted_act_grad._fp8_dtype - fp8_scale_inv=permuted_act_grad._scale_inv + assert isinstance( + permuted_act_grad, Float8Tensor + ), "Grad of the output must be in Float8Tensor type for FP8 moe_permute." + fp8_dtype = permuted_act_grad._fp8_dtype + fp8_scale_inv = permuted_act_grad._scale_inv permuted_act_grad = permuted_act_grad._data row_id_map = ctx.row_id_map @@ -119,9 +124,9 @@ def backward( permuted_act_grad, _moe_permute.dtype, row_id_map, torch.empty(0), num_tokens, topK ) if fp8: - act_grad = Float8Tensor(data=act_grad, - fp8_dtype=fp8_dtype, - fp8_scale_inv=fp8_scale_inv) + act_grad = Float8Tensor( + data=act_grad, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv + ) return act_grad, None, None, None, None @@ -169,9 +174,9 @@ def forward( if dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]: fp8 = True if fp8: - assert isinstance(inp, Float8Tensor), ( - "Input must be in Float8Tensor type for FP8 moe_unpermute." - ) + assert isinstance( + inp, Float8Tensor + ), "Input must be in Float8Tensor type for FP8 moe_unpermute." fp8_dtype = inp._fp8_dtype fp8_scale_inv = inp._scale_inv inp = inp._data @@ -185,9 +190,9 @@ def forward( unpermuted_output = tex.moe_unpermute_fwd(inp, dtype, row_id_map, probs, num_tokens, topK) if fp8: - unpermuted_output = Float8Tensor(data=unpermuted_output, - fp8_dtype=fp8_dtype, - fp8_scale_inv=fp8_scale_inv) + unpermuted_output = Float8Tensor( + data=unpermuted_output, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv + ) ctx.dtype = dtype ctx.save_for_backward(inp, row_id_map, probs) @@ -208,11 +213,11 @@ def backward( fp8 = ctx.fp8 if fp8: - assert isinstance(unpermuted_act_grad, Float8Tensor), ( - "Grad of the output must be in Float8Tensor type for FP8 moe_unpermute." - ) - fp8_dtype=unpermuted_act_grad._fp8_dtype - fp8_scale_inv=unpermuted_act_grad._scale_inv + assert isinstance( + unpermuted_act_grad, Float8Tensor + ), "Grad of the output must be in Float8Tensor type for FP8 moe_unpermute." 
+ fp8_dtype = unpermuted_act_grad._fp8_dtype + fp8_scale_inv = unpermuted_act_grad._scale_inv unpermuted_act_grad = unpermuted_act_grad._data inp, row_id_map, probs = ctx.saved_tensors @@ -223,9 +228,9 @@ def backward( unpermuted_act_grad, inp, ctx.dtype, row_id_map, probs ) if fp8: - act_grad = Float8Tensor(data=act_grad, - fp8_dtype=fp8_dtype, - fp8_scale_inv=fp8_scale_inv) + act_grad = Float8Tensor( + data=act_grad, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv + ) if not ctx.needs_input_grad[3]: prob_grad = None From f5ad7fc704414febb1834199e167441dd6a42a52 Mon Sep 17 00:00:00 2001 From: Jiang Shao Date: Tue, 20 Aug 2024 21:19:12 +0000 Subject: [PATCH 31/33] Rescale fp8 for permute backward Signed-off-by: Jiang Shao --- transformer_engine/common/permutation/permutation.cu | 4 ++++ transformer_engine/pytorch/permutation.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/transformer_engine/common/permutation/permutation.cu b/transformer_engine/common/permutation/permutation.cu index ab1108ed5a..196cc21f26 100644 --- a/transformer_engine/common/permutation/permutation.cu +++ b/transformer_engine/common/permutation/permutation.cu @@ -113,6 +113,10 @@ __global__ void moe_unpermute_kernel(const T *input, T *unpermuted_output, const T *dest_row_ptr = unpermuted_output + source_token * num_cols; for (int e = 0; e < kElementsPerAccess; e++) { + if constexpr ((std::is_same_v || std::is_same_v) && + (!hasProb)) { + frag_sum[e] = frag_sum[e] / TCompute(topK); + } frag_load_store_ptr[e] = T(frag_sum[e]); } diff --git a/transformer_engine/pytorch/permutation.py b/transformer_engine/pytorch/permutation.py index 422694e1f5..e1fe763caa 100644 --- a/transformer_engine/pytorch/permutation.py +++ b/transformer_engine/pytorch/permutation.py @@ -125,7 +125,7 @@ def backward( ) if fp8: act_grad = Float8Tensor( - data=act_grad, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv + data=act_grad, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv * topK ) return 
act_grad, None, None, None, None From 5307a0143345e4a5b5e958d4840151e3f395d209 Mon Sep 17 00:00:00 2001 From: Jiang Shao Date: Tue, 20 Aug 2024 21:28:46 +0000 Subject: [PATCH 32/33] Move dtype to ctx Signed-off-by: Jiang Shao --- transformer_engine/pytorch/permutation.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/transformer_engine/pytorch/permutation.py b/transformer_engine/pytorch/permutation.py index e1fe763caa..c3fba509fa 100644 --- a/transformer_engine/pytorch/permutation.py +++ b/transformer_engine/pytorch/permutation.py @@ -21,7 +21,6 @@ class _moe_permute(torch.autograd.Function): """functional Permute""" workspace = None - dtype = None max_expanded_token_num = 0 @staticmethod @@ -68,10 +67,6 @@ def forward( _moe_permute.max_expanded_token_num = input_max_expanded_token_num _moe_permute.workspace = [] - if _moe_permute.dtype != dtype: - _moe_permute.dtype = dtype - _moe_permute.workspace = [] - permuted_act, row_id_map, _moe_permute.workspace = tex.moe_permute_fwd( inp, dtype, @@ -89,6 +84,7 @@ def forward( ctx.row_id_map = row_id_map ctx.num_tokens = indices.size(0) ctx.topK = indices.size(1) + ctx.dtype = dtype ctx.fp8 = fp8 return permuted_act, row_id_map @@ -121,7 +117,7 @@ def backward( act_grad = None if ctx.needs_input_grad[0]: act_grad = tex.moe_permute_bwd( - permuted_act_grad, _moe_permute.dtype, row_id_map, torch.empty(0), num_tokens, topK + permuted_act_grad, ctx.dtype, row_id_map, torch.empty(0), num_tokens, topK ) if fp8: act_grad = Float8Tensor( From 318caf884d3cb949acf5dbf90598ac4dd3f98b33 Mon Sep 17 00:00:00 2001 From: Jiang Shao Date: Thu, 22 Aug 2024 10:14:02 +0000 Subject: [PATCH 33/33] Format fix Signed-off-by: Jiang Shao --- transformer_engine/common/permutation/permutation.cu | 6 +++--- transformer_engine/pytorch/permutation.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/transformer_engine/common/permutation/permutation.cu b/transformer_engine/common/permutation/permutation.cu 
index 196cc21f26..2b894fbfdc 100644 --- a/transformer_engine/common/permutation/permutation.cu +++ b/transformer_engine/common/permutation/permutation.cu @@ -120,7 +120,7 @@ __global__ void moe_unpermute_kernel(const T *input, T *unpermuted_output, const frag_load_store_ptr[e] = T(frag_sum[e]); } - *(float4 *)(dest_row_ptr + i) = frag_load_store; + *reinterpret_cast(dest_row_ptr + i) = frag_load_store; } } @@ -184,7 +184,7 @@ __global__ void moe_permute_kernel(const T *input_bwd, const T *input_fwd, T *ac } T *dest_row_ptr = act_grad + dest_row * num_cols; - *(float4 *)(dest_row_ptr + i) = frag_load_store; + *reinterpret_cast(dest_row_ptr + i) = frag_load_store; if (hasProb) { // Inner product calculation for prob_grad in unpermute bwd @@ -197,7 +197,7 @@ __global__ void moe_permute_kernel(const T *input_bwd, const T *input_fwd, T *ac frag_input_fwd[e] = TCompute(frag_load_store_ptr[e]); for (int e = 0; e < kElementsPerAccess; e++) { - accum[k] += float(frag_src[e] * frag_input_fwd[e]); + accum[k] += static_cast(frag_src[e] * frag_input_fwd[e]); } } } diff --git a/transformer_engine/pytorch/permutation.py b/transformer_engine/pytorch/permutation.py index c3fba509fa..0c098830a9 100644 --- a/transformer_engine/pytorch/permutation.py +++ b/transformer_engine/pytorch/permutation.py @@ -3,9 +3,9 @@ # See LICENSE for license information. """Linear API""" -import torch import warnings from typing import Tuple +import torch import transformer_engine_torch as tex from transformer_engine.pytorch.float8_tensor import Float8Tensor