Skip to content

Commit

Permalink
Add gemv specialization for h100 fp16 grouped gemm (apache#60)
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Mar 20, 2024
1 parent 04c2787 commit ddf836e
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 2 deletions.
2 changes: 1 addition & 1 deletion cmake/modules/contrib/CUTLASS.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ if(USE_CUDA AND USE_CUTLASS)
set(TVM_CUTLASS_RUNTIME_SRCS "")

# TODO: Should get rid of the postfix 'a' and test sm >= 90
if (CMAKE_CUDA_ARCHITECTURES MATCHES "90|90a")
if (CMAKE_CUDA_ARCHITECTURES MATCHES "90a")
list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp16_group_gemm.cu)
endif()

Expand Down
23 changes: 22 additions & 1 deletion src/runtime/contrib/cutlass/fp16_group_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,15 @@ struct KernelTraits<cutlass::half_t> {
using ClusterShape = Shape<_2, _2, _1>; // Shape of the threadblocks in a cluster
};

// Forward declaration of FasterTransformer's grouped (MoE) GEMM entry point.
// The definition lives elsewhere in the project; only the declaration is
// needed here so the SM90 grouped-GEMM wrapper below can dispatch to it.
namespace fastertransformer {

// Runs `num_experts` back-to-back GEMMs of shape [rows_e, gemm_k] x
// [gemm_k, gemm_n] on `stream`, where `total_rows_before_expert` is an
// exclusive prefix-sum (indptr-style) giving each expert's row boundary in
// A/C and `total_rows` is the combined row count.
// NOTE(review): semantics inferred from the call site in this file, which
// passes nullptr for `weight_scales`/`biases` and std::nullopt for
// `activation` — confirm against the FasterTransformer implementation,
// including whether nullptr/nullopt mean "no scaling / no bias / identity
// activation" as assumed here.
template <typename T, typename WeightType>
void moe_gemm_bias_act(const T* A, const WeightType* B, const T* weight_scales, const T* biases,
T* C, int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n,
int64_t gemm_k, int num_experts, std::optional<std::string> activation,
cudaStream_t stream);
}

namespace tvm {
namespace runtime {

Expand All @@ -44,6 +53,8 @@ void tvm_cutlass_group_gemm_sm90(NDArray x, NDArray weight, NDArray indptr, NDAr
// Recommened size is 4MB.
auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream");
ICHECK(func != nullptr);
cudaStream_t stream = static_cast<cudaStream_t>((*func)().operator void*());

CHECK_EQ(x->ndim, 2);
CHECK_EQ(weight->ndim, 3);
CHECK_EQ(indptr->ndim, 1);
Expand All @@ -52,9 +63,19 @@ void tvm_cutlass_group_gemm_sm90(NDArray x, NDArray weight, NDArray indptr, NDAr
int num_groups = weight->shape[0];
int n = weight->shape[1];
int k = weight->shape[2];

if (x->shape[0] <= 4 && k % 8 == 0)
{
fastertransformer::moe_gemm_bias_act<half, half>(
reinterpret_cast<half*>(x->data), reinterpret_cast<half*>(weight->data), nullptr, nullptr,
reinterpret_cast<half*>(out->data),
reinterpret_cast<int64_t*>(indptr->data), x->shape[0], n, k, num_groups,
std::nullopt, stream);
return;
}

float alpha = 1.0f;
float beta = 0.0f;
cudaStream_t stream = static_cast<cudaStream_t>((*func)().operator void*());
cutlass_group_gemm(static_cast<ElementA*>(x->data), static_cast<ElementB*>(weight->data),
static_cast<int64_t*>(indptr->data), static_cast<uint8_t*>(workspace->data),
workspace->shape[0], n, k, num_groups, alpha, beta,
Expand Down

0 comments on commit ddf836e

Please sign in to comment.