diff --git a/tests/pytorch/test_permutation.py b/tests/pytorch/test_permutation.py new file mode 100644 index 0000000000..99bd706b45 --- /dev/null +++ b/tests/pytorch/test_permutation.py @@ -0,0 +1,515 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import torch +import pytest +from typing import Dict, List + +from transformer_engine.pytorch import moe_permute as te_permute, moe_unpermute as te_unpermute +from transformer_engine.pytorch.utils import is_bf16_compatible +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.pytorch.float8_tensor import Float8Tensor +import transformer_engine_torch as tex + + +seed = 1234 +torch.manual_seed(seed) +torch.cuda.manual_seed(seed) + + +def pytorch_permute(tokens, indices, num_out_tokens: int = None): + """ + Permute the tokens based on the indices. Token with the same index will be grouped together. + The input indices shape is [tokens, top_k], it indicates which experts were selected by each token separately. + + Args: + tokens: torch.Tensor + The input token tensor. + indices: torch.Tensor + The token to expert indices tensor, should have a shape of [num_tokens] or [num_tokens, topk]. + num_out_tokens: int, optional + The effective output token count, when enabling the capacity factor, should equal the number of tokens not dropped. + By default, set to None, meaning no tokens are dropped. + + Returns: + torch.Tensor: + The permuted tensor. + torch.Tensor: + The sorted_indices corresponding permuted tensor. 
+ """ + if indices.dim() == 1: + topk = 1 + else: + topk = indices.size(1) + flatten_indices = indices.view(-1) + sorted_indices = torch.argsort(flatten_indices, stable=True) + num_out_tokens = num_out_tokens if num_out_tokens is not None else flatten_indices.size(0) + + permuted_tokens = tokens.index_select(0, sorted_indices[:num_out_tokens] // topk) + return permuted_tokens, sorted_indices + + +def pytorch_unpermute( + permuted_tokens: torch.Tensor, + sorted_indices: torch.Tensor, + probs: torch.Tensor = None, +): + """ + Unpermute a tensor of permuted tokens based on sorted indices, and optionally merge the tokens with their + corresponding probabilities. + + Args: + permuted_tokens: torch.Tensor + The tensor of permuted tokens to be unpermuted. + sorted_indices: torch.Tensor + The tensor of sorted indices used to unpermute the tokens. + probs: torch.Tensor, optional + The tensor of probabilities corresponding to the permuted tokens. If provided, the unpermuted tokens will + be merged with their respective probabilities. + + Returns: + torch.Tensor: + The unpermuted tokens, optionally merged with probabilities. 
+ """ + + if probs is not None: + # Unpermute and merge the tokens with their probabilities + num_unpermuted_tokens = probs.numel() + topk = probs.size(1) + else: + # Unpermute the tokens without merge + num_unpermuted_tokens = sorted_indices.size(0) + topk = 1 + unpermuted_tokens = torch.zeros( + [num_unpermuted_tokens, permuted_tokens.shape[-1]], + dtype=permuted_tokens.dtype, + device=permuted_tokens.device, + ) + + unpermuted_tokens.index_copy_(0, sorted_indices[: permuted_tokens.size(0)], permuted_tokens) + unpermuted_tokens = unpermuted_tokens.reshape(-1, topk, permuted_tokens.size(-1)) + if probs is not None: + unpermuted_tokens = unpermuted_tokens * probs.unsqueeze(-1) + unpermuted_tokens = unpermuted_tokens.sum(dim=1) + return unpermuted_tokens + + +def dtype_tols(te_dtype: tex.DType) -> Dict[str, float]: + """Estimated tolerances for a datatype + + Based on tolerances for torch.testing.assert_close. + + """ + if te_dtype == tex.DType.kFloat32: + return dict(rtol=1.0e-6, atol=1.0e-6) + if te_dtype == tex.DType.kFloat16: + return dict(rtol=3.0e-3, atol=1.0e-5) + if te_dtype == tex.DType.kBFloat16: + return dict(rtol=2.0e-2, atol=1.0e-5) + if te_dtype == tex.DType.kFloat8E5M2 or te_dtype == tex.DType.kFloat8E4M3: + return dict(rtol=2.0e-1, atol=1.0e-1) + raise ValueError(f"Unsuppored dtype ({te_dtype})") + + +def _test_permutation( + te_dtype, + num_tokens, + num_expert, + hidden_size, + topK, + num_out_tokens, + with_probs, + BENCHMARK=False, +): + if not with_probs and topK > 1: + pytest.skip("Only permutations with topK=1 and without probabilities are supported.") + + if topK > num_expert: + pytest.skip("topK should be smaller than the number of experts.") + + if num_out_tokens == None: + num_out_tokens = num_tokens * topK + + print( + f"token:{num_tokens} hidden_size:{hidden_size} expert:{num_expert} topK:{topK} {te_dtype}" + ) + + fp8 = False + # Convert TE dtypes to PyTorch dtypes + if te_dtype == tex.DType.kFloat32: + dtype = torch.float32 + elif 
te_dtype == tex.DType.kFloat16: + dtype = torch.float16 + elif te_dtype == tex.DType.kBFloat16: + dtype = torch.bfloat16 + elif fp8_available and (te_dtype == tex.DType.kFloat8E5M2 or te_dtype == tex.DType.kFloat8E4M3): + dtype = torch.uint8 + fp8 = True + else: + pytest.skip("Invalid dtype.") + + if fp8: + permute_fwd_input = torch.rand( + size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda" + ) + permute_bwd_input = torch.rand( + size=(num_out_tokens, hidden_size), dtype=torch.float32, device="cuda" + ) + unpermute_bwd_input = torch.rand( + size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda" + ) + + permute_fwd_input = Float8Tensor.to_float8( + permute_fwd_input, fp8_dtype=te_dtype, scale=torch.full([1], 1.0) + ) + permute_bwd_input = Float8Tensor.to_float8( + permute_bwd_input, fp8_dtype=te_dtype, scale=torch.full([1], 1.0) + ) + unpermute_bwd_input = Float8Tensor.to_float8( + unpermute_bwd_input, fp8_dtype=te_dtype, scale=torch.full([1], 1.0) + ) + + pytorch_permute_fwd_input = permute_fwd_input.from_float8(torch.float16) + pytorch_permute_bwd_input = permute_bwd_input.from_float8(torch.float16) + pytorch_unpermute_bwd_input = unpermute_bwd_input.from_float8(torch.float16) + else: + pytorch_permute_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() + pytorch_permute_bwd_input = torch.rand((num_out_tokens, hidden_size), dtype=dtype).cuda() + pytorch_unpermute_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() + + pytorch_permute_fwd_input.requires_grad_(True) + + if num_tokens > 0: + indices = torch.stack([torch.randperm(num_expert)[:topK] for _ in range(num_tokens)]) + else: + indices = torch.empty((num_tokens, topK)) + indices = indices.to(torch.int32).cuda() + + probs = None + if with_probs: + probs = torch.rand(num_tokens, topK).cuda() + row_sums = probs.sum(dim=1, keepdim=True) + probs = probs / row_sums + probs.requires_grad_(True) + + 
################################################################################################################################### + # + # PyTorch Permutation + # + ################################################################################################################################### + pytorch_permute_output, sorted_indices = pytorch_permute( + pytorch_permute_fwd_input, indices, num_out_tokens + ) + pytorch_permute_output.backward(pytorch_permute_bwd_input, retain_graph=True) + + pytorch_unpermute_fwd_input = pytorch_permute_output.detach() + pytorch_unpermute_fwd_input.requires_grad_(True) + + pytorch_unpermute_output = pytorch_unpermute( + pytorch_unpermute_fwd_input, sorted_indices, probs=probs + ) + pytorch_unpermute_output.backward(pytorch_unpermute_bwd_input, retain_graph=True) + + ################################################################################################################################### + # + # TE Permutation + # + ################################################################################################################################### + te_permute_fwd_input = permute_fwd_input if fp8 else pytorch_permute_fwd_input.detach() + te_permute_fwd_input.requires_grad_(True) + te_permute_bwd_input = permute_bwd_input if fp8 else pytorch_permute_bwd_input.detach() + + te_permute_output, row_id_map = te_permute( + te_permute_fwd_input, te_dtype, indices, num_out_tokens + ) + te_permute_output.backward(te_permute_bwd_input, retain_graph=True) + + te_probs = None + if with_probs: + te_probs = probs.detach() + te_probs.requires_grad_(True) + te_unpermute_fwd_input = te_permute_output.detach() + te_unpermute_fwd_input.requires_grad_(True) + te_unpermute_bwd_input = unpermute_bwd_input if fp8 else pytorch_unpermute_bwd_input.detach() + + te_unpermute_output = te_unpermute(te_unpermute_fwd_input, te_dtype, row_id_map, te_probs) + te_unpermute_output.backward(te_unpermute_bwd_input, retain_graph=True) + + 
################################################################################################################################### + # + # Results Check + # + ################################################################################################################################### + tols = dtype_tols(te_dtype) + + if fp8: + te_permute_output_ = te_permute_output.from_float8(torch.float32) + te_permute_fwd_input_grad = te_permute_fwd_input.grad.from_float8(torch.float32) + te_unpermute_output_ = te_unpermute_output.from_float8(torch.float32) + te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.from_float8(torch.float32) + else: + te_permute_output_ = te_permute_output.float() + te_permute_fwd_input_grad = te_permute_fwd_input.grad.float() + te_unpermute_output_ = te_unpermute_output.float() + te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.float() + + torch.testing.assert_close( + pytorch_permute_output.float(), + te_permute_output_, + msg=f"Mismatch in te_permute fwd", + ) + torch.testing.assert_close( + pytorch_permute_fwd_input.grad.float(), + te_permute_fwd_input_grad, + msg=f"Mismatch in te_permute bwd", + **tols, + ) + torch.testing.assert_close( + pytorch_unpermute_output.float(), + te_unpermute_output_, + msg=f"Mismatch in te_unpermute fwd", + **tols, + ) + torch.testing.assert_close( + pytorch_unpermute_fwd_input.grad.float(), + te_unpermute_fwd_input_grad, + msg=f"Mismatch in te_unpermute bwd", + **tols, + ) + if with_probs: + torch.testing.assert_close( + probs.grad.float(), te_probs.grad.float(), msg=f"Mismatch in te_unpermute bwd", **tols + ) + + if not pytorch_permute_fwd_input.numel(): + print("Empty pytorch_permute_fwd_input activation test passed.") + return + + ################################################################################################################################### + # + # Benchmark + # + 
################################################################################################################################### + def backward_wrapper( + act, backward_input, forward_input=[], retain_graph=True, accumulate_grad=False + ): + # Set forward_input.grad to None to avoid grad accumulation. + if accumulate_grad == False: + for i in forward_input: + i.grad = None + return act.backward(backward_input, retain_graph=retain_graph) + + if BENCHMARK: + t1 = perf_test_cuda_kernel( + lambda: pytorch_permute(pytorch_permute_fwd_input, indices, num_out_tokens) + ) + t2 = perf_test_cuda_kernel( + lambda: te_permute(te_permute_fwd_input, te_dtype, indices, num_out_tokens) + ) + print(f"permute\t\tfwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms") + + t1 = perf_test_cuda_kernel( + lambda: backward_wrapper( + pytorch_permute_output, + pytorch_permute_bwd_input, + forward_input=[pytorch_permute_fwd_input], + retain_graph=True, + accumulate_grad=False, + ) + ) + t2 = perf_test_cuda_kernel( + lambda: backward_wrapper( + te_permute_output, + te_permute_bwd_input, + forward_input=[te_permute_fwd_input], + retain_graph=True, + accumulate_grad=False, + ) + ) + print(f"permute\t\tbwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms") + + t1 = perf_test_cuda_kernel( + lambda: pytorch_unpermute(pytorch_unpermute_fwd_input, sorted_indices, probs=probs) + ) + t2 = perf_test_cuda_kernel( + lambda: te_unpermute(te_unpermute_fwd_input, te_dtype, row_id_map, te_probs) + ) + print(f"unpermute\tfwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms") + + t1 = perf_test_cuda_kernel( + lambda: backward_wrapper( + pytorch_unpermute_output, + pytorch_unpermute_bwd_input, + forward_input=( + [pytorch_unpermute_fwd_input, probs] + if with_probs + else [pytorch_unpermute_fwd_input] + ), + retain_graph=True, + accumulate_grad=False, + ) + ) + t2 = perf_test_cuda_kernel( + lambda: backward_wrapper( + te_unpermute_output, + te_unpermute_bwd_input, + forward_input=( + [te_unpermute_fwd_input, te_probs] if with_probs else 
[te_unpermute_fwd_input] + ), + retain_graph=True, + accumulate_grad=False, + ) + ) + print(f"unpermute\tbwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms") + + +def perf_test_cuda_kernel(cuda_kernel_fn): + if torch.cuda.is_available(): + # create CUDA event + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + # warmup + for _ in range(50): + cuda_kernel_fn() + + start_event.record() + for _ in range(100): + cuda_kernel_fn() + end_event.record() + torch.cuda.synchronize() + + elapsed_time_ms = start_event.elapsed_time(end_event) + return elapsed_time_ms / 100 + else: + pytest.skip("CUDA is not available.") + + +# TE tensor dtypes +_te_dtypes: List[tex.DType] = [tex.DType.kFloat32, tex.DType.kFloat16] +if is_bf16_compatible(): + _te_dtypes.append(tex.DType.kBFloat16) + + +@pytest.mark.parametrize("te_dtype", _te_dtypes) +@pytest.mark.parametrize("num_tokens", [4096]) +@pytest.mark.parametrize("num_expert", [8, 16]) +@pytest.mark.parametrize("hidden_size", [4096]) +@pytest.mark.parametrize("topK", [1, 2, 5]) +@pytest.mark.parametrize("num_out_tokens", [None, 2039]) +def test_permutation( + te_dtype, + num_tokens, + num_expert, + hidden_size, + topK, + num_out_tokens, +): + with_probs = True + BENCHMARK = False + + _test_permutation( + te_dtype=te_dtype, + num_tokens=num_tokens, + num_expert=num_expert, + hidden_size=hidden_size, + topK=topK, + num_out_tokens=num_out_tokens, + with_probs=with_probs, + BENCHMARK=BENCHMARK, + ) + + +# Only run FP8 tests on H100. 
+fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() + + +@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) +@pytest.mark.parametrize("te_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) +@pytest.mark.parametrize("num_tokens", [2048]) +@pytest.mark.parametrize("num_expert", [8, 16]) +@pytest.mark.parametrize("hidden_size", [4096]) +@pytest.mark.parametrize("topK", [1, 2, 5]) +@pytest.mark.parametrize("num_out_tokens", [None, 2039]) +def test_permutation_fp8( + te_dtype, + num_tokens, + num_expert, + hidden_size, + topK, + num_out_tokens, +): + with_probs = True + BENCHMARK = False + + _test_permutation( + te_dtype=te_dtype, + num_tokens=num_tokens, + num_expert=num_expert, + hidden_size=hidden_size, + topK=topK, + num_out_tokens=num_out_tokens, + with_probs=with_probs, + BENCHMARK=BENCHMARK, + ) + + +@pytest.mark.parametrize("te_dtype", _te_dtypes) +@pytest.mark.parametrize("num_tokens", [4096]) +@pytest.mark.parametrize("num_expert", [8, 16]) +@pytest.mark.parametrize("hidden_size", [4096]) +def test_permutation_topk1_no_probs( + te_dtype, + num_tokens, + num_expert, + hidden_size, +): + topK = 1 + num_out_tokens = None + with_probs = False + BENCHMARK = False + + _test_permutation( + te_dtype=te_dtype, + num_tokens=num_tokens, + num_expert=num_expert, + hidden_size=hidden_size, + topK=topK, + num_out_tokens=num_out_tokens, + with_probs=with_probs, + BENCHMARK=BENCHMARK, + ) + + +def test_permutation_single_case(): + print("GPU:", torch.cuda.get_device_name(0)) + + # te_dtype = tex.DType.kFloat32 + # te_dtype = tex.DType.kFloat16 + # te_dtype = tex.DType.kBFloat16 + te_dtype = tex.DType.kFloat8E5M2 + # te_dtype = tex.DType.kFloat8E4M3 + + num_tokens = 10 + num_expert = 4 + hidden_size = 16 + topK = 2 + num_out_tokens = num_tokens * topK - 1 + with_probs = True + Benchmark = True + + _test_permutation( + te_dtype=te_dtype, + num_tokens=num_tokens, + num_expert=num_expert, + hidden_size=hidden_size, + topK=topK, + 
num_out_tokens=num_out_tokens, + with_probs=with_probs, + BENCHMARK=Benchmark, + ) + + +if __name__ == "__main__": + test_permutation_single_case() diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index a4497751f4..06bfec49b4 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -62,6 +62,7 @@ list(APPEND transformer_engine_SOURCES layer_norm/ln_api.cpp layer_norm/ln_bwd_semi_cuda_kernel.cu layer_norm/ln_fwd_cuda_kernel.cu + permutation/permutation.cu rmsnorm/rmsnorm_api.cpp rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu rmsnorm/rmsnorm_fwd_cuda_kernel.cu diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 8667b64e65..593ec086d7 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -255,7 +255,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, "Unable to find suitable cuBLAS GEMM algorithm"); NVTE_CHECK_CUBLAS(status); - if (returnedResults == 0) throw std::runtime_error("Unable to find any suitable algorithms"); + if (returnedResults == 0) NVTE_ERROR("Unable to find any suitable algorithms"); // D = alpha * (A * B) + beta * C NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, operationDesc, diff --git a/transformer_engine/common/include/transformer_engine/permutation.h b/transformer_engine/common/include/transformer_engine/permutation.h new file mode 100644 index 0000000000..c6263bf87e --- /dev/null +++ b/transformer_engine/common/include/transformer_engine/permutation.h @@ -0,0 +1,21 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_PERMUTATION_H_ +#define TRANSFORMER_ENGINE_PERMUTATION_H_ + +#include "transformer_engine.h" + +void nvte_permute(const NVTETensor input, NVTETensor output, const NVTETensor sorted_row_id, + NVTETensor row_id_map, const NVTETensor prob, NVTETensor prob_grad, + const NVTETensor input_fwd, const int num_rows, const int topK, + const int num_cols, const int num_out_tokens, cudaStream_t stream = nullptr); + +void nvte_unpermute(const NVTETensor input, NVTETensor output, NVTETensor row_id_map, + const NVTETensor prob, const int num_rows, const int topK, const int num_cols, + cudaStream_t stream = nullptr); + +#endif // TRANSFORMER_ENGINE_PERMUTATION_H_ diff --git a/transformer_engine/common/permutation/permutation.cu b/transformer_engine/common/permutation/permutation.cu new file mode 100644 index 0000000000..2b894fbfdc --- /dev/null +++ b/transformer_engine/common/permutation/permutation.cu @@ -0,0 +1,369 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include + +#include "../common.h" + +static __global__ void moe_permute_row_map(const int *sorted_row_id, int *row_id_map, + const int num_rows, const int topK, + const int num_out_tokens) { + // Each block corresponds to one source token + // row_id_map[topK][num_rows] + const int bid = blockIdx.x; + const int tid = threadIdx.x; + const int idx = bid * blockDim.x + tid; + + if (idx >= num_rows * topK) return; + + int source_row = sorted_row_id[idx]; + int source_token_id = source_row / topK; + int source_topK_id = source_row % topK; + + if (idx >= num_out_tokens) { + // Set the indices of dropped tokens to -1 + row_id_map[source_topK_id * num_rows + source_token_id] = -1; + } else { + // Create a row id map for subsequent unpermute operation + row_id_map[source_topK_id * num_rows + source_token_id] = idx; + } +} + +template +__global__ void moe_unpermute_kernel(const T *input, T *unpermuted_output, const int *row_id_map, + const float *prob, const int num_rows, const int topK, + const int num_cols) { + extern __shared__ int8_t s_mem[]; + TCompute *s_prob = reinterpret_cast(s_mem); + + // Each block corresponds to one dest token + const int source_token = blockIdx.x; + const int tid = threadIdx.x; + + if (hasProb) { + for (int i = tid; i < topK; i += blockDim.x * blockDim.y) { + // Load all the topK probs related to the source row into smem + s_prob[i] = TCompute(prob[source_token * topK + i]); + } + __syncthreads(); + } + + // Register buffers for vector type (float4) memory access + float4 frag_load_store; + T *frag_load_store_ptr = reinterpret_cast(&frag_load_store); + + // Number of elemments in frag_load_store + static constexpr int kElementsPerAccess = 16 / sizeof(T); + + // Traverse along the hidden dimention + for (int i = tid * kElementsPerAccess; i < num_cols; i += blockDim.x * kElementsPerAccess) { + TCompute frag_elem[kElementsPerAccess]; + TCompute 
frag_sum[kElementsPerAccess]; + + int source_row = row_id_map[source_token]; + + // source_row == -1 represents a dropped token + if (source_row != -1) { + const T *source_row_ptr = input + source_row * num_cols; + + frag_load_store = __ldlu(reinterpret_cast(source_row_ptr + i)); + + for (int e = 0; e < kElementsPerAccess; e++) { + frag_sum[e] = TCompute(frag_load_store_ptr[e]); + } + + if (hasProb) { + for (int e = 0; e < kElementsPerAccess; e++) { + frag_sum[e] = frag_sum[e] * s_prob[0]; + } + } + } else { + for (int e = 0; e < kElementsPerAccess; e++) { + frag_sum[e] = TCompute(0.0f); + } + } + + for (int k = 1; k < topK; k++) { + source_row = row_id_map[k * num_rows + source_token]; + + if (source_row == -1) continue; + + const T *source_row_ptr = input + source_row * num_cols; + + frag_load_store = __ldlu(reinterpret_cast(source_row_ptr + i)); + + for (int e = 0; e < kElementsPerAccess; e++) { + frag_elem[e] = TCompute(frag_load_store_ptr[e]); + } + + if (hasProb) { + for (int e = 0; e < kElementsPerAccess; e++) { + frag_elem[e] = frag_elem[e] * s_prob[k]; + } + } + + for (int e = 0; e < kElementsPerAccess; e++) { + frag_sum[e] = frag_sum[e] + frag_elem[e]; + } + } + + T *dest_row_ptr = unpermuted_output + source_token * num_cols; + + for (int e = 0; e < kElementsPerAccess; e++) { + if constexpr ((std::is_same_v || std::is_same_v) && + (!hasProb)) { + frag_sum[e] = frag_sum[e] / TCompute(topK); + } + frag_load_store_ptr[e] = T(frag_sum[e]); + } + + *reinterpret_cast(dest_row_ptr + i) = frag_load_store; + } +} + +template +__global__ void moe_permute_kernel(const T *input_bwd, const T *input_fwd, T *act_grad, + const float *prob, float *prob_grad, const int *row_id_map, + const int num_rows, const int topK, const int num_cols) { + extern __shared__ int8_t s_mem[]; + TCompute *s_prob = reinterpret_cast(s_mem); + + // Each block corresponds to one source token + const int source_token = blockIdx.x; + const int tid = threadIdx.x; + + if (hasProb) { + for (int i = 
tid; i < topK; i += blockDim.x) { + // Load all the topK probs related to the source row into smem + s_prob[i] = TCompute(prob[source_token * topK + i]); + } + __syncthreads(); + } + + // Accumulators for the calculation of prob_grad + float accum[topKTile] = {0.0f}; + + // Register buffers for vector type (float4) memory access + float4 frag_load_store; + T *frag_load_store_ptr = reinterpret_cast(&frag_load_store); + + // Number of elemments in frag_load_store + static constexpr int kElementsPerAccess = 16 / sizeof(T); + + // The starting address of each source row + const T *source_row_ptr = input_bwd + source_token * num_cols; + + // Traverse along the hidden dimention + for (int i = tid * kElementsPerAccess; i < num_cols; i += blockDim.x * kElementsPerAccess) { + TCompute frag_src[kElementsPerAccess]; + + frag_load_store = __ldlu(reinterpret_cast(source_row_ptr + i)); + + for (int e = 0; e < kElementsPerAccess; e++) frag_src[e] = TCompute(frag_load_store_ptr[e]); + + int index = source_token; + + // Process each row in the corresponding topK rows + for (int k = 0; k < topKTile; k++) { + if (k == topK) break; + + int dest_row = row_id_map[index]; + index += num_rows; + + if (dest_row != -1) { + if (hasProb) { + // Calculate act_grad in unpermute bwd + for (int e = 0; e < kElementsPerAccess; e++) + frag_load_store_ptr[e] = T(frag_src[e] * s_prob[k]); + } else { + // permute fwd + for (int e = 0; e < kElementsPerAccess; e++) frag_load_store_ptr[e] = T(frag_src[e]); + } + + T *dest_row_ptr = act_grad + dest_row * num_cols; + *reinterpret_cast(dest_row_ptr + i) = frag_load_store; + + if (hasProb) { + // Inner product calculation for prob_grad in unpermute bwd + const T *input_fwd_ptr = input_fwd + dest_row * num_cols; + + frag_load_store = __ldlu(reinterpret_cast(input_fwd_ptr + i)); + + TCompute frag_input_fwd[kElementsPerAccess]; + for (int e = 0; e < kElementsPerAccess; e++) + frag_input_fwd[e] = TCompute(frag_load_store_ptr[e]); + + for (int e = 0; e < 
kElementsPerAccess; e++) { + accum[k] += static_cast(frag_src[e] * frag_input_fwd[e]); + } + } + } + } + } + + if (hasProb) { + for (int k = 0; k < topKTile; k++) { + if (k == topK) break; + // Warp-level reduction + for (int mask = 16; mask > 0; mask /= 2) { + accum[k] = accum[k] + __shfl_xor_sync(0xffffffff, accum[k], mask, 32); + } + } + + if (tid == 0) { + for (int k = 0; k < topKTile; k++) { + if (k == topK) break; + prob_grad[source_token * topK + k] = accum[k]; + } + } + } +} + +template +void nvte_permute_launcher(const T *input, T *output, const int *sorted_row_id, int *row_id_map, + const float *prob, float *prob_grad, const T *input_fwd, + const int num_rows, const int topK, const int num_cols, + const int num_out_tokens, cudaStream_t stream) { + using TCompute = typename std::conditional<(std::is_same::value || + std::is_same::value), + half, T>::type; + + static constexpr int kElementsPerAccess = 16 / sizeof(T); + + if (input_fwd == nullptr) { + // moe_permute_fwd + + int threads = 64; + int blocks = (num_rows * topK + threads - 1) / threads; + + moe_permute_row_map<<>>(sorted_row_id, row_id_map, num_rows, topK, + num_out_tokens); + + blocks = num_rows; + threads = std::min(num_cols / kElementsPerAccess, 1024); + moe_permute_kernel<<>>( + input, nullptr, output, nullptr, nullptr, row_id_map, num_rows, topK, num_cols); + } else { + // moe_unpermute_bwd + + int threads = 32; + int blocks = num_rows; + + if (prob == nullptr) { + // moe_unpermute_bwd without probs + + moe_permute_kernel<<>>( + input, input_fwd, output, nullptr, nullptr, row_id_map, num_rows, topK, num_cols); + } else { + // moe_unpermute_bwd with probs + + size_t smem_bytes = topK * sizeof(TCompute); + + if (topK <= 8) { + moe_permute_kernel<<>>( + input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, topK, num_cols); + } else if (topK <= 16) { + moe_permute_kernel<<>>( + input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, topK, num_cols); + } else if (topK <= 32) { 
+ moe_permute_kernel<<>>( + input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, topK, num_cols); + } else if (topK <= 64) { + moe_permute_kernel<<>>( + input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, topK, num_cols); + } else if (topK <= 128) { + moe_permute_kernel<<>>( + input, input_fwd, output, prob, prob_grad, row_id_map, num_rows, topK, num_cols); + } else { + NVTE_ERROR("topK cannot exceed 128."); + } + } + } +} + +template +void nvte_unpermute_launcher(const T *input, T *output, int *row_id_map, const float *prob, + const int num_rows, const int topK, const int num_cols, + cudaStream_t stream) { + using TCompute = typename std::conditional<(std::is_same::value || + std::is_same::value), + half, T>::type; + + static constexpr int kElementsPerAccess = 16 / sizeof(T); + + int blocks = num_rows; + int threads = std::min(num_cols / kElementsPerAccess, 1024); + size_t smem_bytes = topK * sizeof(TCompute); + + if (prob == nullptr) { + // moe_permute_bwd + // moe_unpermute_fwd without probs + + moe_unpermute_kernel<<>>( + input, output, row_id_map, nullptr, num_rows, topK, num_cols); + } else { + // moe_unpermute_fwd with probs + + moe_unpermute_kernel<<>>( + input, output, row_id_map, prob, num_rows, topK, num_cols); + } +} + +void nvte_permute(const NVTETensor input, NVTETensor output, const NVTETensor sorted_row_id, + NVTETensor row_id_map, const NVTETensor prob, NVTETensor prob_grad, + const NVTETensor input_fwd, const int num_rows, const int topK, + const int num_cols, const int num_out_tokens, cudaStream_t stream) { + NVTE_API_CALL(nvte_permute); + + const transformer_engine::Tensor *input_cu = + reinterpret_cast(input); + const transformer_engine::Tensor *output_cu = + reinterpret_cast(output); + const transformer_engine::Tensor *sorted_row_id_cu = + reinterpret_cast(sorted_row_id); + const transformer_engine::Tensor *row_id_map_cu = + reinterpret_cast(row_id_map); + const transformer_engine::Tensor *prob_cu = + 
reinterpret_cast(prob); + const transformer_engine::Tensor *prob_grad_cu = + reinterpret_cast(prob_grad); + const transformer_engine::Tensor *input_fwd_cu = + reinterpret_cast(input_fwd); + + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( + input_cu->data.dtype, T, + nvte_permute_launcher(reinterpret_cast(input_cu->data.dptr), + reinterpret_cast(output_cu->data.dptr), + reinterpret_cast(sorted_row_id_cu->data.dptr), + reinterpret_cast(row_id_map_cu->data.dptr), + reinterpret_cast(prob_cu->data.dptr), + reinterpret_cast(prob_grad_cu->data.dptr), + reinterpret_cast(input_fwd_cu->data.dptr), num_rows, topK, + num_cols, num_out_tokens, stream);); +} + +void nvte_unpermute(const NVTETensor input, NVTETensor output, NVTETensor row_id_map, + const NVTETensor prob, const int num_rows, const int topK, const int num_cols, + cudaStream_t stream) { + NVTE_API_CALL(nvte_unpermute); + + const transformer_engine::Tensor *input_cu = + reinterpret_cast(input); + const transformer_engine::Tensor *output_cu = + reinterpret_cast(output); + const transformer_engine::Tensor *row_id_map_cu = + reinterpret_cast(row_id_map); + const transformer_engine::Tensor *prob_cu = + reinterpret_cast(prob); + + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( + input_cu->data.dtype, T, + nvte_unpermute_launcher(reinterpret_cast(input_cu->data.dptr), + reinterpret_cast(output_cu->data.dptr), + reinterpret_cast(row_id_map_cu->data.dptr), + reinterpret_cast(prob_cu->data.dptr), num_rows, topK, + num_cols, stream);); +} diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index 20b6f79da6..1c755491b0 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -44,6 +44,7 @@ def _load_library(): from transformer_engine.pytorch.attention import InferenceParams from transformer_engine.pytorch.attention import MultiheadAttention from transformer_engine.pytorch.transformer import TransformerLayer +from transformer_engine.pytorch.permutation import 
moe_permute, moe_unpermute from transformer_engine.pytorch.fp8 import fp8_autocast from transformer_engine.pytorch.fp8 import fp8_model_init from transformer_engine.pytorch.graph import make_graphed_callables diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index aac693a430..7fb9953f94 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -28,6 +28,7 @@ #include #include #include +#include #include #include #include diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 05e4e97112..1a6f5f157e 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -10,6 +10,26 @@ #include "common.h" #include "common/common.h" +/*************************************************************************************************** + * Permutation + **************************************************************************************************/ + +std::tuple> moe_permute_fwd( + at::Tensor input, const transformer_engine::DType dtype, at::Tensor indices, + int64_t num_out_tokens, std::vector workspace, int64_t max_expanded_token_num); + +at::Tensor moe_permute_bwd(at::Tensor input, const transformer_engine::DType dtype, + at::Tensor row_id_map, at::Tensor prob, int64_t num_tokens, + int64_t topK); + +at::Tensor moe_unpermute_fwd(at::Tensor input, const transformer_engine::DType dtype, + at::Tensor row_id_map, at::Tensor prob, int64_t num_tokens, + int64_t topK); + +std::tuple moe_unpermute_bwd(at::Tensor input_bwd, at::Tensor input_fwd, + const transformer_engine::DType dtype, + at::Tensor row_id_map, at::Tensor prob); + /*************************************************************************************************** * Attention **************************************************************************************************/ diff --git 
a/transformer_engine/pytorch/csrc/extensions/permutation.cu b/transformer_engine/pytorch/csrc/extensions/permutation.cu new file mode 100644 index 0000000000..0c9bed45e0 --- /dev/null +++ b/transformer_engine/pytorch/csrc/extensions/permutation.cu @@ -0,0 +1,170 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include <cub/cub.cuh> + +#include "extensions.h" + +std::tuple<at::Tensor, at::Tensor, std::vector<at::Tensor>> moe_permute_fwd( + at::Tensor input, const transformer_engine::DType dtype, at::Tensor indices, + int64_t num_out_tokens, std::vector<at::Tensor> workspace, int64_t max_expanded_token_num) { + const int num_tokens = input.size(0); + int num_cols = input.size(1); + const int topK = indices.size(1); + + // Initialize the workspace on the first run + if (workspace.empty()) { + auto options = + torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false); + + at::Tensor sorted_indices = torch::empty(max_expanded_token_num, options); + at::Tensor row_id = torch::range(0, max_expanded_token_num - 1, 1, options); + at::Tensor sorted_row_id = + torch::empty(max_expanded_token_num, + torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false)); + + size_t temp_storage_bytes = 0; + int *temp_ptr = nullptr; + cub::DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_ptr, temp_ptr, temp_ptr, + temp_ptr, max_expanded_token_num); + at::Tensor temp_storage = torch::empty( + temp_storage_bytes, torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false)); + + workspace.push_back(sorted_indices); + workspace.push_back(row_id); + workspace.push_back(sorted_row_id); + workspace.push_back(temp_storage); + } + + int *indices_ptr = reinterpret_cast<int *>(getDataPtr(indices, 0)); + int *sorted_indices_ptr = reinterpret_cast<int *>(getDataPtr(workspace[0], 0)); + int 
*row_id_ptr = reinterpret_cast<int *>(getDataPtr(workspace[1], 0)); + int *sorted_row_id_ptr = reinterpret_cast<int *>(getDataPtr(workspace[2], 0)); + + void *d_temp_storage = getDataPtr(workspace[3], 0); + size_t temp_storage_bytes = std::numeric_limits<size_t>::max(); + + cub::DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, indices_ptr, + sorted_indices_ptr, row_id_ptr, sorted_row_id_ptr, + num_tokens * topK); + + // Activations type + at::ScalarType _st; + if (dtype == transformer_engine::DType::kFloat8E4M3 || + dtype == transformer_engine::DType::kFloat8E5M2) + _st = at::ScalarType::Byte; + else + _st = input.scalar_type(); + + // Output buffer alloc + num_out_tokens = (num_out_tokens > 0) ? num_out_tokens : num_tokens * topK; + at::Tensor permuted_output = torch::empty( + {num_out_tokens, num_cols}, torch::dtype(_st).device(torch::kCUDA).requires_grad(false)); + at::Tensor row_id_map = torch::empty( + {num_tokens * topK}, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false)); + + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + auto input_cu = makeTransformerEngineTensor( + input.data_ptr(), {static_cast<size_t>(input.size(0)), static_cast<size_t>(num_cols)}, dtype); + auto permuted_output_cu = makeTransformerEngineTensor( + permuted_output.data_ptr(), + {static_cast<size_t>(permuted_output.size(0)), static_cast<size_t>(num_cols)}, dtype); + auto sorted_row_id_cu = + makeTransformerEngineTensor(sorted_row_id_ptr, {static_cast<size_t>(num_tokens * topK)}, + transformer_engine::DType::kInt32); + auto row_id_map_cu = makeTransformerEngineTensor(row_id_map); + + nvte_permute(input_cu.data(), permuted_output_cu.data(), sorted_row_id_cu.data(), + row_id_map_cu.data(), transformer_engine::TensorWrapper().data(), + transformer_engine::TensorWrapper().data(), + transformer_engine::TensorWrapper().data(), num_tokens, topK, num_cols, + num_out_tokens, stream); + + return std::make_tuple(permuted_output, row_id_map, workspace); +} + +at::Tensor moe_permute_bwd(at::Tensor input, const 
transformer_engine::DType dtype, + at::Tensor row_id_map, at::Tensor prob, int64_t num_tokens, + int64_t topK) { + return moe_unpermute_fwd(input, dtype, row_id_map, prob, num_tokens, topK); +} + +at::Tensor moe_unpermute_fwd(at::Tensor input, const transformer_engine::DType dtype, + at::Tensor row_id_map, at::Tensor prob, int64_t num_tokens, + int64_t topK) { + int num_cols = input.size(1); + + // Activations type + at::ScalarType _st; + if (dtype == transformer_engine::DType::kFloat8E4M3 || + dtype == transformer_engine::DType::kFloat8E5M2) + _st = at::ScalarType::Byte; + else + _st = input.scalar_type(); + + // Output buffer alloc + at::Tensor unpermuted_output = torch::empty( + {num_tokens, num_cols}, torch::dtype(_st).device(torch::kCUDA).requires_grad(false)); + + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + auto input_cu = makeTransformerEngineTensor( + input.data_ptr(), {static_cast<size_t>(input.size(0)), static_cast<size_t>(num_cols)}, dtype); + auto unpermuted_output_cu = makeTransformerEngineTensor( + unpermuted_output.data_ptr(), + {static_cast<size_t>(unpermuted_output.size(0)), static_cast<size_t>(num_cols)}, dtype); + auto row_id_map_cu = makeTransformerEngineTensor(row_id_map); + auto prob_cu = makeTransformerEngineTensor(prob); + + nvte_unpermute(input_cu.data(), unpermuted_output_cu.data(), row_id_map_cu.data(), prob_cu.data(), + num_tokens, topK, num_cols, stream); + + return unpermuted_output; +} + +std::tuple<at::Tensor, at::Tensor> moe_unpermute_bwd(at::Tensor input_bwd, at::Tensor input_fwd, + const transformer_engine::DType dtype, + at::Tensor row_id_map, at::Tensor prob) { + const int topK = (prob.numel() > 0) ? prob.size(1) : 1; + const int num_tokens = (prob.numel() > 0) ? 
prob.size(0) : row_id_map.size(0); + int num_cols = input_bwd.size(1); + + // Activations type + at::ScalarType _st; + if (dtype == transformer_engine::DType::kFloat8E4M3 || + dtype == transformer_engine::DType::kFloat8E5M2) + _st = at::ScalarType::Byte; + else + _st = input_bwd.scalar_type(); + + // Output buffer alloc + at::Tensor act_grad = torch::empty({input_fwd.size(0), num_cols}, + torch::dtype(_st).device(torch::kCUDA).requires_grad(false)); + at::Tensor prob_grad = torch::empty( + {num_tokens, topK}, torch::dtype(torch::kFloat32).device(torch::kCUDA).requires_grad(false)); + + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + auto input_bwd_cu = makeTransformerEngineTensor( + input_bwd.data_ptr(), {static_cast<size_t>(input_bwd.size(0)), static_cast<size_t>(num_cols)}, + dtype); + auto act_grad_cu = makeTransformerEngineTensor( + act_grad.data_ptr(), {static_cast<size_t>(act_grad.size(0)), static_cast<size_t>(num_cols)}, + dtype); + auto input_fwd_cu = makeTransformerEngineTensor( + input_fwd.data_ptr(), {static_cast<size_t>(input_fwd.size(0)), static_cast<size_t>(num_cols)}, + dtype); + auto row_id_map_cu = makeTransformerEngineTensor(row_id_map); + auto prob_cu = makeTransformerEngineTensor(prob); + auto prob_grad_cu = makeTransformerEngineTensor(prob_grad); + + nvte_permute(input_bwd_cu.data(), act_grad_cu.data(), transformer_engine::TensorWrapper().data(), + row_id_map_cu.data(), prob_cu.data(), prob_grad_cu.data(), input_fwd_cu.data(), + num_tokens, topK, num_cols, 0, stream); + + return std::make_tuple(act_grad, prob_grad); +} diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 11b47ccdec..f903a1c35b 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -10,6 +10,12 @@ #include "../extensions.h" PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + // Permutation functions + m.def("moe_permute_fwd", moe_permute_fwd); + m.def("moe_permute_bwd", 
moe_permute_bwd); + m.def("moe_unpermute_fwd", moe_unpermute_fwd); + m.def("moe_unpermute_bwd", moe_unpermute_bwd); + // Softmax functions m.def("scaled_softmax_forward", &scaled_softmax_forward, "Scaled Softmax FWD", py::call_guard()); diff --git a/transformer_engine/pytorch/permutation.py b/transformer_engine/pytorch/permutation.py new file mode 100644 index 0000000000..0c098830a9 --- /dev/null +++ b/transformer_engine/pytorch/permutation.py @@ -0,0 +1,289 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Linear API""" +import warnings +from typing import Tuple +import torch + +import transformer_engine_torch as tex +from transformer_engine.pytorch.float8_tensor import Float8Tensor + + +__all__ = [ + "moe_permute", + "moe_unpermute", +] + + +class _moe_permute(torch.autograd.Function): + """functional Permute""" + + workspace = None + max_expanded_token_num = 0 + + @staticmethod + def forward( + ctx, + inp: torch.Tensor, + dtype: tex.DType, + indices: torch.Tensor, + num_out_tokens: int, + max_token_num: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Empty input check + if not inp.numel(): + return inp, None + + # Device check + assert inp.is_cuda, "TransformerEngine needs CUDA." + assert indices.is_cuda, "TransformerEngine needs CUDA." + # Shape check + assert inp.size(0) == indices.size(0), "Permute not possible" + + # Data type check + fp8 = False + if dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]: + fp8 = True + if fp8: + assert isinstance( + inp, Float8Tensor + ), "Input must be in Float8Tensor type for FP8 moe_permute." + fp8_dtype = inp._fp8_dtype + fp8_scale_inv = inp._scale_inv + inp = inp._data + if indices.dtype != torch.int32: + warnings.warn( + f"The data type of the input `indices` of Permute is {indices.dtype}! " + "The recommended type is torch.int32." 
+ ) + indices = indices.to(torch.int32) + + topK = indices.size(1) + + input_max_expanded_token_num = max(max_token_num, inp.size(0)) * topK + if _moe_permute.max_expanded_token_num < input_max_expanded_token_num: + _moe_permute.max_expanded_token_num = input_max_expanded_token_num + _moe_permute.workspace = [] + + permuted_act, row_id_map, _moe_permute.workspace = tex.moe_permute_fwd( + inp, + dtype, + indices, + num_out_tokens, + _moe_permute.workspace, + _moe_permute.max_expanded_token_num, + ) + + if fp8: + permuted_act = Float8Tensor( + data=permuted_act, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv + ) + + ctx.row_id_map = row_id_map + ctx.num_tokens = indices.size(0) + ctx.topK = indices.size(1) + ctx.dtype = dtype + ctx.fp8 = fp8 + return permuted_act, row_id_map + + @staticmethod + def backward( + ctx, + permuted_act_grad: torch.Tensor, + _, + ) -> Tuple[torch.Tensor, ...]: + # Empty input check + if not permuted_act_grad.numel(): + return permuted_act_grad, None, None, None + + if not permuted_act_grad.is_contiguous(): + permuted_act_grad = permuted_act_grad.contiguous() + + fp8 = ctx.fp8 + if fp8: + assert isinstance( + permuted_act_grad, Float8Tensor + ), "Grad of the output must be in Float8Tensor type for FP8 moe_permute." 
+ fp8_dtype = permuted_act_grad._fp8_dtype + fp8_scale_inv = permuted_act_grad._scale_inv + permuted_act_grad = permuted_act_grad._data + + row_id_map = ctx.row_id_map + num_tokens = ctx.num_tokens + topK = ctx.topK + + act_grad = None + if ctx.needs_input_grad[0]: + act_grad = tex.moe_permute_bwd( + permuted_act_grad, ctx.dtype, row_id_map, torch.empty(0), num_tokens, topK + ) + if fp8: + act_grad = Float8Tensor( + data=act_grad, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv * topK + ) + + return act_grad, None, None, None, None + + +class _moe_unpermute(torch.autograd.Function): + """functional Unpermute""" + + @staticmethod + def forward( + ctx, + inp: torch.Tensor, + dtype: tex.DType, + row_id_map: torch.Tensor, + probs: torch.Tensor, + ) -> torch.Tensor: + # Empty input check + if not inp.numel(): + ctx.probs = probs + return inp + + # None probs check + if probs is not None: + assert probs.is_cuda, "TransformerEngine needs CUDA." + + if probs.dtype != torch.float32: + warnings.warn( + f"The data type of the input `probs` of Unpermute is {probs.dtype}! " + "The recommended type is torch.float32." + ) + probs = probs.to(torch.float32) + + num_tokens = probs.size(0) + topK = probs.size(1) + else: + num_tokens = row_id_map.size(0) + topK = 1 + probs = torch.empty(0) + + # Device check + assert inp.is_cuda, "TransformerEngine needs CUDA." + assert row_id_map.is_cuda, "TransformerEngine needs CUDA." + + # Data type check + fp8 = False + if dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]: + fp8 = True + if fp8: + assert isinstance( + inp, Float8Tensor + ), "Input must be in Float8Tensor type for FP8 moe_unpermute." + fp8_dtype = inp._fp8_dtype + fp8_scale_inv = inp._scale_inv + inp = inp._data + if row_id_map.dtype != torch.int32: + warnings.warn( + f"The data type of the input `row_id_map` of Unpermute is {row_id_map.dtype}! " + "The recommended type is torch.int32." 
+ ) + row_id_map = row_id_map.to(torch.int32) + + unpermuted_output = tex.moe_unpermute_fwd(inp, dtype, row_id_map, probs, num_tokens, topK) + + if fp8: + unpermuted_output = Float8Tensor( + data=unpermuted_output, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv + ) + + ctx.dtype = dtype + ctx.save_for_backward(inp, row_id_map, probs) + ctx.fp8 = fp8 + return unpermuted_output + + @staticmethod + def backward( + ctx, + unpermuted_act_grad: torch.Tensor, + ) -> Tuple[torch.Tensor, None, torch.Tensor]: + # Empty input check + if not unpermuted_act_grad.numel(): + return unpermuted_act_grad, None, ctx.probs + + if not unpermuted_act_grad.is_contiguous(): + unpermuted_act_grad = unpermuted_act_grad.contiguous() + + fp8 = ctx.fp8 + if fp8: + assert isinstance( + unpermuted_act_grad, Float8Tensor + ), "Grad of the output must be in Float8Tensor type for FP8 moe_unpermute." + fp8_dtype = unpermuted_act_grad._fp8_dtype + fp8_scale_inv = unpermuted_act_grad._scale_inv + unpermuted_act_grad = unpermuted_act_grad._data + + inp, row_id_map, probs = ctx.saved_tensors + + act_grad = None + if ctx.needs_input_grad[0]: + act_grad, prob_grad = tex.moe_unpermute_bwd( + unpermuted_act_grad, inp, ctx.dtype, row_id_map, probs + ) + if fp8: + act_grad = Float8Tensor( + data=act_grad, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv + ) + if not ctx.needs_input_grad[3]: + prob_grad = None + + return act_grad, None, None, prob_grad + + +def moe_permute( + inp: torch.Tensor, + dtype: tex.DType, + indices: torch.Tensor, + num_out_tokens: int = -1, + max_token_num: int = -1, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Permute the tokens based on the indices. Token with the same index will be grouped together. + + Parameters + ---------- + inp: torch.Tensor + Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied. + dtype: tex.DType + Data type of the input tensor. 
+ indices: torch.Tensor + The token to expert indices tensor of shape [num_tokens, topK] and dtype 'int32'. + num_out_tokens: int, default = -1 + The effective output token count, representing the number of tokens not dropped. + By default, set to '-1', meaning no tokens are dropped. + max_token_num: int, default = -1 + The maximum number of tokens, used for workspace allocation. + By default, set to '-1', meaning the calculation of the size of workspace is + automatically taken over by the operator. + """ + return _moe_permute.apply(inp, dtype, indices, num_out_tokens, max_token_num) + + +def moe_unpermute( + inp: torch.Tensor, + dtype: tex.DType, + row_id_map: torch.Tensor, + probs: torch.Tensor = None, +) -> torch.Tensor: + """ + Unpermute a tensor with permuted tokens, and optionally merge the tokens with their + corresponding probabilities. + + Parameters + ---------- + inp: torch.Tensor + Input tensor with permuted tokens of shape `[num_tokens, hidden_size]` to be unpermuted. + dtype: tex.DType + Data type of the input tensor. + row_id_map: torch.Tensor + The tensor of a mapping table for sorted indices used to unpermute the tokens, + which is the second output tensor of `Permute`. + probs: torch.Tensor + The tensor of probabilities corresponding to the permuted tokens. If provided, + the unpermuted tokens will be merged with their respective probabilities. + By default, set to an empty tensor, which means that the tokens are directly merged by accumulation. + """ + return _moe_unpermute.apply(inp, dtype, row_id_map, probs)