diff --git a/src/ops/rms_norm/bang/rms_norm_cnnl.cc b/src/ops/rms_norm/bang/rms_norm_cnnl.cc deleted file mode 100644 index 01e9aacd..00000000 --- a/src/ops/rms_norm/bang/rms_norm_cnnl.cc +++ /dev/null @@ -1,56 +0,0 @@ -#include "rms_norm_cnnl.h" -#include "../../../devices/bang/common_bang.h" -#include "../../../devices/bang/handle_pool.h" -#include "../../utils.h" -#include "cnrt.h" - -RMSNormCnnlDescriptor::RMSNormCnnlDescriptor(Device device) { - this->device = device; - get_cnnl_pool(); -} - -void rms_norm_cnnl_f16(Tensor y, Tensor x, Tensor w, float epsilon, void *stream) { - ASSERT_EQ(y.layout->ndim, 2); - ASSERT_EQ(x.layout->ndim, 2); - ASSERT_EQ(w.layout->ndim, 1); - - auto n = y.layout->shape[0], - d = y.layout->shape[1]; - - ASSERT_EQ(x.layout->shape[0], n); - ASSERT_EQ(x.layout->shape[1], d); - ASSERT_EQ(w.layout->shape[0], d); - - cnnlTensorDescriptor_t yDesc, xDesc, wDesc; - cnnlCreateTensorDescriptor(&yDesc); - cnnlCreateTensorDescriptor(&xDesc); - cnnlCreateTensorDescriptor(&wDesc); - setCnnlTensor(yDesc, y.layout); - setCnnlTensor(xDesc, x.layout); - setCnnlTensor(wDesc, w.layout); - - cnnlFuseNormDescriptor_t opDesc; - cnnlCreateFuseNormDescriptor(&opDesc); - cnnlSetFuseNormDescriptor(opDesc, epsilon, 1.0, true, - false, false, false, false, - CNNL_DTYPE_HALF, CNNL_TRANSFORMER_RMSNORM); - - void *workspace; - - use_cnnl((cnrtQueue_t) stream, - [&](cnnlHandle_t handle) { - size_t wsSize; - cnnlGetFuseNormWorkspaceSize(handle, opDesc, xDesc, &wsSize); - cnrtMalloc(&workspace, wsSize); - cnnlFuseNorm(handle, opDesc, xDesc, x.data, - wDesc, w.data, nullptr, nullptr, - nullptr, nullptr, nullptr, nullptr, - workspace, wsSize, yDesc, y.data, nullptr, nullptr); - }); - - cnrtFree(workspace); - cnnlDestroyFuseNormDescriptor(opDesc); - cnnlDestroyTensorDescriptor(xDesc); - cnnlDestroyTensorDescriptor(yDesc); - cnnlDestroyTensorDescriptor(wDesc); -} diff --git a/src/ops/rms_norm/bang/rms_norm_cnnl.h b/src/ops/rms_norm/bang/rms_norm_cnnl.h deleted file mode 100644 index c76bf2d0..00000000 --- a/src/ops/rms_norm/bang/rms_norm_cnnl.h +++ /dev/null @@ -1,15 +0,0 @@ -#ifndef __CNNL_RMS_NORM_H__ -#define __CNNL_RMS_NORM_H__ - -#include "cnnl.h" -#include "cnnl_extra.h" -#include "operators.h" - -struct RMSNormCnnlDescriptor { - Device device; - RMSNormCnnlDescriptor(Device device); -}; - -void rms_norm_cnnl_f16(Tensor y, Tensor x, Tensor w, float epsilon, void *stream); - -#endif// __CNNL_RMS_NORM_H__ diff --git a/src/ops/rms_norm/operator.cc b/src/ops/rms_norm/operator.cc index e466d436..9aa4b206 100644 --- a/src/ops/rms_norm/operator.cc +++ b/src/ops/rms_norm/operator.cc @@ -13,7 +13,6 @@ #ifdef ENABLE_CAMBRICON_MLU #include "../../devices/bang/bang_handle.h" #include "bang/rms_norm_bang.h" -#include "bang/rms_norm_cnnl.h" #endif #ifdef ENABLE_ASCEND_NPU #include "ascend/rms_norm_aclnn.h"