Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
2989e5c
Add submodule cutlass v3.5.0
StudyingShao Jun 16, 2024
0dc63b2
Add permutation functions
StudyingShao Jun 17, 2024
4b4930a
Building pass
StudyingShao Jun 17, 2024
5f881d9
Add permutation ops
StudyingShao Jun 17, 2024
1ce7fc1
Everything works fine
StudyingShao Jun 17, 2024
fb483f8
Replace get_ptr with getDataPtr
StudyingShao Jun 24, 2024
c8f7991
Rename some functions and kernels
StudyingShao Jun 24, 2024
6b6eb39
Refactor to fit the TE style. Part I
StudyingShao Jun 24, 2024
7135612
Remove the dependency on cutlass
StudyingShao Jun 25, 2024
1f92e83
Refactor to fit the TE style. Part II
StudyingShao Jun 25, 2024
64986fb
Move permutation.py out of module dir
StudyingShao Jun 25, 2024
115931c
pre-commit reformat
StudyingShao Jun 25, 2024
816c8f6
Rewrite the unit test
StudyingShao Jul 9, 2024
67c4764
Enable skipping if FP8 is unavailable
StudyingShao Jul 9, 2024
76c5635
Rename exposed C++ api and reorder its parameters
StudyingShao Jul 9, 2024
25276a8
Minor changes
StudyingShao Jul 9, 2024
57ce3d0
Minor changes
StudyingShao Jul 10, 2024
385741f
Remove the dependency on pytorch fp8 data type
StudyingShao Jul 24, 2024
b9ad0ae
Move dtype dispatch from pytorch dir to common dir
StudyingShao Jul 24, 2024
97a583a
Minor changes
StudyingShao Jul 24, 2024
31cfa79
Clear up the code path
StudyingShao Jul 25, 2024
b37533e
Minor Changes
StudyingShao Jul 25, 2024
c8a9977
Add some comments
StudyingShao Jul 25, 2024
83ddfad
Add some comments
StudyingShao Jul 25, 2024
f049343
Revise function description
StudyingShao Jul 25, 2024
2b331d4
Take NVTETensor as inputs
StudyingShao Aug 1, 2024
6d23c98
Split unit tests
StudyingShao Aug 1, 2024
0e01f8e
Change names
StudyingShao Aug 19, 2024
906dbb8
Use Float8Tensor for FP8 input
StudyingShao Aug 19, 2024
6740781
Reformat
StudyingShao Aug 19, 2024
f5ad7fc
Rescale fp8 for permute backward
StudyingShao Aug 20, 2024
5307a01
Move dtype to ctx
StudyingShao Aug 20, 2024
47be00f
Merge branch 'main' into jiangs/permutation
phu0ngng Aug 21, 2024
318caf8
Format fix
StudyingShao Aug 22, 2024
5f04315
Merge branch 'main' into jiangs/permutation
StudyingShao Aug 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
515 changes: 515 additions & 0 deletions tests/pytorch/test_permutation.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions transformer_engine/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ list(APPEND transformer_engine_SOURCES
layer_norm/ln_api.cpp
layer_norm/ln_bwd_semi_cuda_kernel.cu
layer_norm/ln_fwd_cuda_kernel.cu
permutation/permutation.cu
rmsnorm/rmsnorm_api.cpp
rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu
rmsnorm/rmsnorm_fwd_cuda_kernel.cu
Expand Down
2 changes: 1 addition & 1 deletion transformer_engine/common/gemm/cublaslt_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
"Unable to find suitable cuBLAS GEMM algorithm");
NVTE_CHECK_CUBLAS(status);

if (returnedResults == 0) throw std::runtime_error("Unable to find any suitable algorithms");
if (returnedResults == 0) NVTE_ERROR("Unable to find any suitable algorithms");

// D = alpha * (A * B) + beta * C
NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, operationDesc,
Expand Down
21 changes: 21 additions & 0 deletions transformer_engine/common/include/transformer_engine/permutation.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/

#ifndef TRANSFORMER_ENGINE_PERMUTATION_H_
#define TRANSFORMER_ENGINE_PERMUTATION_H_

#include "transformer_engine.h"

/*! \brief Declaration of the permute entry point; the definition is not in this header.
 *
 *  All tensor arguments are passed as NVTETensor handles. The parameter
 *  descriptions below are inferred from the parameter names only.
 *  NOTE(review): confirm each against the implementation before relying on it.
 *
 *  \param input          Tensor handle; presumably the rows to be permuted.
 *  \param output         Tensor handle; presumably receives the permuted rows.
 *  \param sorted_row_id  Tensor handle; presumably row indices in sorted order.
 *  \param row_id_map     Tensor handle (not const-qualified); presumably a
 *                        source-row to destination-row mapping. TODO confirm
 *                        whether it is read, written, or both.
 *  \param prob           Tensor handle; presumably per-row routing probabilities.
 *  \param prob_grad      Tensor handle (not const-qualified); presumably receives
 *                        the gradient with respect to prob. TODO confirm.
 *  \param input_fwd      Tensor handle; presumably the forward-pass input, used
 *                        when computing prob_grad. TODO confirm.
 *  \param num_rows       Row count. TODO confirm which tensor it describes.
 *  \param topK           Presumably the number of destinations selected per row.
 *  \param num_cols       Column count. TODO confirm which tensor it describes.
 *  \param num_out_tokens Presumably the number of rows in output. TODO confirm.
 *  \param stream         CUDA stream argument; nullptr is passed when the caller
 *                        omits it. Presumably the stream the work is issued on.
 *
 *  The default argument on stream means this declaration is only valid when
 *  compiled as C++ (C has no default arguments).
 *  cudaStream_t is not declared in this header; presumably it is provided via
 *  "transformer_engine.h". TODO confirm.
 */
void nvte_permute(const NVTETensor input, NVTETensor output, const NVTETensor sorted_row_id,
NVTETensor row_id_map, const NVTETensor prob, NVTETensor prob_grad,
const NVTETensor input_fwd, const int num_rows, const int topK,
const int num_cols, const int num_out_tokens, cudaStream_t stream = nullptr);

/*! \brief Declaration of the unpermute entry point; the definition is not in this header.
 *
 *  All tensor arguments are passed as NVTETensor handles. The parameter
 *  descriptions below are inferred from the parameter names only.
 *  NOTE(review): confirm each against the implementation before relying on it.
 *
 *  \param input      Tensor handle; presumably the permuted rows to restore.
 *  \param output     Tensor handle; presumably receives the rows in their
 *                    original order.
 *  \param row_id_map Tensor handle (not const-qualified); presumably the row
 *                    mapping produced by the matching permute call. TODO confirm
 *                    whether it is modified here.
 *  \param prob       Tensor handle; presumably per-row routing probabilities
 *                    applied while combining rows. TODO confirm.
 *  \param num_rows   Row count. TODO confirm which tensor it describes.
 *  \param topK       Presumably the number of permuted copies per original row.
 *  \param num_cols   Column count. TODO confirm which tensor it describes.
 *  \param stream     CUDA stream argument; nullptr is passed when the caller
 *                    omits it. Presumably the stream the work is issued on.
 *
 *  The default argument on stream means this declaration is only valid when
 *  compiled as C++ (C has no default arguments).
 */
void nvte_unpermute(const NVTETensor input, NVTETensor output, NVTETensor row_id_map,
const NVTETensor prob, const int num_rows, const int topK, const int num_cols,
cudaStream_t stream = nullptr);

#endif // TRANSFORMER_ENGINE_PERMUTATION_H_
Loading