From 8c1b8fab8ba4a7cdac8722383ef5c2a0577fa3d1 Mon Sep 17 00:00:00 2001 From: YdrMaster Date: Thu, 14 Dec 2023 11:28:37 +0800 Subject: [PATCH 01/18] =?UTF-8?q?feat(onnx):=20=E6=B7=BB=E5=8A=A0=20onnx?= =?UTF-8?q?=20MatMulInteger=20=E5=89=8D=E7=AB=AF=E7=AE=97=E5=AD=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: YdrMaster --- src/04kernel/src/kernels/conv/cudnn_kernel.cc | 16 +++- .../src/utilities/cuda/cudnn_functions.cc | 38 +++++++++ .../src/utilities/cuda/cudnn_functions.cu | 35 -------- src/07onnx/src/operators.cpp | 2 + src/07onnx/src/operators/mat_mul_integer.cc | 82 +++++++++++++++++++ src/07onnx/src/operators/mat_mul_integer.hh | 23 ++++++ 6 files changed, 159 insertions(+), 37 deletions(-) create mode 100644 src/04kernel/src/utilities/cuda/cudnn_functions.cc delete mode 100644 src/04kernel/src/utilities/cuda/cudnn_functions.cu create mode 100644 src/07onnx/src/operators/mat_mul_integer.cc create mode 100644 src/07onnx/src/operators/mat_mul_integer.hh diff --git a/src/04kernel/src/kernels/conv/cudnn_kernel.cc b/src/04kernel/src/kernels/conv/cudnn_kernel.cc index 82848483a..88d293f57 100644 --- a/src/04kernel/src/kernels/conv/cudnn_kernel.cc +++ b/src/04kernel/src/kernels/conv/cudnn_kernel.cc @@ -19,10 +19,17 @@ namespace refactor::kernel { Tensor const &w, std::optional> b, Tensor const &y) -> KernelBox { + static const std::unordered_set + SET{DataType::FP16, DataType::BF16, DataType::F32, DataType::F64, DataType::I8}; #ifndef USE_CUDA return nullptr; #endif + auto dt = x.dataType; + if (!SET.contains(dt) || w.dataType != dt || y.dataType != dt) { + return nullptr; + } + std::optional biasExpand = std::nullopt; if (b) { ASSERT(b->get().shape[0] == y.shape[1], ""); @@ -42,7 +49,7 @@ namespace refactor::kernel { p = poolAttributes.pads(), s = poolAttributes.strides(); return std::make_unique(decltype(info){ - x.dataType, + dt, { static_cast(x.shape[0]), static_cast(x.shape[1]), @@ -134,13 +141,18 @@ namespace refactor::kernel { auto pp = info.pad; auto ss = info.stride; auto dd = info.dilation; + // clang-format off + auto computation = info.dt == DataType::F64 ? DataType::F64 + : info.dt == DataType::I8 ? DataType::I32 + : DataType::F32; + // clang-format on CUDNN_ASSERT(cudnnSetConvolution2dDescriptor( d->conv, std::min(pp[0], pp[2]), std::min(pp[1], pp[3]), ss[0], ss[1], dd[0], dd[1], CUDNN_CROSS_CORRELATION, - cudnnDataTypeConvert(d->f32 ? 
DataType::F32 : DataType::F64))); + cudnnDataTypeConvert(computation))); if (auto group = xs[1] / ws[1]; group > 1) { CUDNN_ASSERT(cudnnSetConvolutionGroupCount(d->conv, group)); diff --git a/src/04kernel/src/utilities/cuda/cudnn_functions.cc b/src/04kernel/src/utilities/cuda/cudnn_functions.cc new file mode 100644 index 000000000..1beeaaade --- /dev/null +++ b/src/04kernel/src/utilities/cuda/cudnn_functions.cc @@ -0,0 +1,38 @@ +#ifdef USE_CUDA + +#include "cudnn_functions.h" + +namespace refactor::kernel::cudnn { + + cudnnDataType_t cudnnDataTypeConvert(DataType dataType) { + // clang-format off + switch (dataType) { + case DataType::F32 : return CUDNN_DATA_FLOAT; break; + case DataType::F64 : return CUDNN_DATA_DOUBLE; break; + case DataType::FP16: return CUDNN_DATA_HALF; break; + case DataType::I8 : return CUDNN_DATA_INT8; break; + case DataType::I32 : return CUDNN_DATA_INT32; break; + case DataType::U8 : return CUDNN_DATA_UINT8; break; + case DataType::BF16: return CUDNN_DATA_BFLOAT16; break; + case DataType::I64 : return CUDNN_DATA_INT64; break; + case DataType::Bool: return CUDNN_DATA_BOOLEAN; break; + default: UNREACHABLE(); + } + // clang-format on + } + + void setCudnnTensor(cudnnTensorDescriptor_t t, DataType dt, slice_t d) { + auto dt_ = cudnnDataTypeConvert(dt); + if (auto n = d.size(); n == 4) { + CUDNN_ASSERT(cudnnSetTensor4dDescriptor(t, CUDNN_TENSOR_NCHW, dt_, d[0], d[1], d[2], d[3])); + } else if (n < 4) { + int d_[]{1, 1, 1, 1}; + std::copy_n(d.begin(), n, d_ + 4 - n); + CUDNN_ASSERT(cudnnSetTensor4dDescriptor(t, CUDNN_TENSOR_NCHW, dt_, d_[0], d_[1], d_[2], d_[3])); + } else { + CUDNN_ASSERT(cudnnSetTensorNdDescriptorEx(t, CUDNN_TENSOR_NCHW, dt_, d.size(), d.begin())); + } + } +}// namespace refactor::kernel::cudnn + +#endif diff --git a/src/04kernel/src/utilities/cuda/cudnn_functions.cu b/src/04kernel/src/utilities/cuda/cudnn_functions.cu deleted file mode 100644 index 62781f950..000000000 --- a/src/04kernel/src/utilities/cuda/cudnn_functions.cu +++ /dev/null @@ -1,35 +0,0 @@ -#include "cudnn_functions.h" - -namespace refactor::kernel::cudnn { - using DT = DataType; - - cudnnDataType_t cudnnDataTypeConvert(DT dataType) { - switch (dataType) { - // clang-format off - case DT::F32 : return CUDNN_DATA_FLOAT; break; - case DT::F64 : return CUDNN_DATA_DOUBLE; break; - case DT::FP16: return CUDNN_DATA_HALF; break; - case DT::I8 : return CUDNN_DATA_INT8; break; - case DT::I32 : return CUDNN_DATA_INT32; break; - case DT::U8 : return CUDNN_DATA_UINT8; break; - case DT::BF16: return CUDNN_DATA_BFLOAT16; break; - case DT::I64 : return CUDNN_DATA_INT64; break; - case DT::Bool: return CUDNN_DATA_BOOLEAN; break; - default: UNREACHABLE(); - // clang-format on - } - } - - void setCudnnTensor(cudnnTensorDescriptor_t t, DT dt, slice_t d) { - auto dt_ = cudnnDataTypeConvert(dt); - if (auto n = d.size(); n == 4) { - CUDNN_ASSERT(cudnnSetTensor4dDescriptor(t, CUDNN_TENSOR_NCHW, dt_, d[0], d[1], d[2], d[3])); - } else if (n < 4) { - int d_[]{1, 1, 1, 1}; - std::copy_n(d.begin(), n, d_ + 4 - n); - CUDNN_ASSERT(cudnnSetTensor4dDescriptor(t, CUDNN_TENSOR_NCHW, dt_, d_[0], d_[1], d_[2], d_[3])); - } else { - CUDNN_ASSERT(cudnnSetTensorNdDescriptorEx(t, CUDNN_TENSOR_NCHW, dt_, d.size(), d.begin())); - } - } -}// namespace refactor::kernel::cudnn diff --git a/src/07onnx/src/operators.cpp b/src/07onnx/src/operators.cpp index 62436887d..18a651b9a 100644 --- a/src/07onnx/src/operators.cpp +++ b/src/07onnx/src/operators.cpp @@ -15,6 +15,7 @@ #include "operators/gemm.hh" #include 
"operators/global_pool.hh" #include "operators/mat_mul.hh" +#include "operators/mat_mul_integer.hh" #include "operators/pool.hh" #include "operators/range.hh" #include "operators/reduce.hh" @@ -60,6 +61,7 @@ namespace refactor::onnx { REGISTER(GlobalLpPool , GlobalPool ); REGISTER(GlobalMaxPool , GlobalPool ); REGISTER(MatMul , MatMul ); + REGISTER(MatMulInteger , MatMulInteger ); REGISTER(AveragePool , Pool ); REGISTER(LpPool , Pool ); REGISTER(MaxPool , Pool ); diff --git a/src/07onnx/src/operators/mat_mul_integer.cc b/src/07onnx/src/operators/mat_mul_integer.cc new file mode 100644 index 000000000..d2466d519 --- /dev/null +++ b/src/07onnx/src/operators/mat_mul_integer.cc @@ -0,0 +1,82 @@ +#include "mat_mul_integer.hh" +#include "common.h" +#include "computation/operators/mat_mul.h" +#include + +namespace refactor::onnx { + using Op = MatMulInteger; + + Op::MatMulInteger() : Operator() {} + + auto Op::build(ModelContext const &, std::string_view, Attributes attributes) -> OpBox { + ASSERT(attributes.empty(), "MatMulInteger operator should not have attributes"); + return OpBox(std::make_unique()); + } + auto Op::typeId() -> size_t { + static uint8_t ID = 1; + return reinterpret_cast(&ID); + } + + auto Op::opTypeId() const -> size_t { return typeId(); } + auto Op::opTypeName() const -> std::string_view { return "onnx::MatMulInteger"; } + + auto Op::infer(TensorRefs inputs, InferOptions const &options) const -> InferResult { + static const std::unordered_set + SET{DataType::I8, DataType::U8}; + + switch (inputs.size()) { + case 2: + break; + case 3: + case 4: + return Err(InferError(ERROR_MSG("Quantization tensor not support currently"))); + default: + return Err(InferError(ERROR_MSG("Input size error"))); + } + + auto const &a = inputs[0]; + auto const &b = inputs[1]; + if (!SET.contains(a.dataType) || !SET.contains(b.dataType)) { + return Err(InferError(ERROR_MSG("Input data type not support"))); + } + auto sa = a.shape, sb = b.shape; + switch (sa.size()) { + case 1: + sa.insert(sa.begin(), DimExpr(1)); + break; + case 0: + return Err(InferError(ERROR_MSG("Input shape not support"))); + default: + break; + } + switch (sb.size()) { + case 1: + sb.emplace_back(1); + break; + case 0: + return Err(InferError(ERROR_MSG("Input shape not support"))); + default: + break; + } + auto k = sa.back(); + sa.pop_back(); + auto m = sa.back(); + sa.pop_back(); + auto n = sb.back(); + sb.pop_back(); + if (k != sb.back()) { + return Err(InferError(ERROR_MSG("Input shape not support"))); + } + sb.pop_back(); + MULTIDIR_BROADCAST((ShapeRefs{sa, sb})) + output.emplace_back(std::move(m)); + output.emplace_back(std::move(n)); + return Ok(Tensors{Tensor::share(DataType::I32, std::move(output), extractDependency(inputs))}); + } + + auto Op::lower(TensorRefs) const -> computation::OpBox { + using Op_ = computation::MatMul; + return std::make_unique(1.0, 1.0, false, false); + } + +}// namespace refactor::onnx diff --git a/src/07onnx/src/operators/mat_mul_integer.hh b/src/07onnx/src/operators/mat_mul_integer.hh new file mode 100644 index 000000000..a581ad1af --- /dev/null +++ b/src/07onnx/src/operators/mat_mul_integer.hh @@ -0,0 +1,23 @@ +#ifndef ONNX_MAT_MUL_INTEGER_HH +#define ONNX_MAT_MUL_INTEGER_HH + +#include "frontend/operator.h" + +namespace refactor::onnx { + using namespace frontend; + + struct MatMulInteger final : public Operator { + MatMulInteger(); + + static OpBox build(ModelContext const &, std::string_view, Attributes); + static size_t typeId(); + + size_t opTypeId() const final; + std::string_view 
opTypeName() const final; + InferResult infer(TensorRefs, InferOptions const &) const final; + computation::OpBox lower(TensorRefs) const final; + }; + +}// namespace refactor::onnx + +#endif// ONNX_MAT_MUL_INTEGER_HH From 4bbf12100c9c1fb2872597754b7707f7d4ff8e3f Mon Sep 17 00:00:00 2001 From: YdrMaster Date: Thu, 14 Dec 2023 11:57:09 +0800 Subject: [PATCH 02/18] =?UTF-8?q?style(onnx):=20=E6=95=B4=E7=90=86=20MatMu?= =?UTF-8?q?l=20cpu=20kernel?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: YdrMaster --- .../src/kernels/mat_mul/cpu_kernel.cc | 158 +++++++++--------- .../simple_unary/cudnn_activation_kernel.cc | 31 ++-- .../src/kernels/softmax/cudnn_kernel.cc | 7 +- 3 files changed, 97 insertions(+), 99 deletions(-) diff --git a/src/04kernel/src/kernels/mat_mul/cpu_kernel.cc b/src/04kernel/src/kernels/mat_mul/cpu_kernel.cc index e2c448d43..7bbab466f 100644 --- a/src/04kernel/src/kernels/mat_mul/cpu_kernel.cc +++ b/src/04kernel/src/kernels/mat_mul/cpu_kernel.cc @@ -9,11 +9,9 @@ namespace refactor::kernel { : Kernel(), info(std::move(info_)) {} auto K::build(MatMulInfo info) noexcept -> KernelBox { - if (!info.dataType.isCpuNumberic()) { - return nullptr; - } - - return std::make_unique(std::move(info)); + return info.dataType.isCpuNumberic() + ? std::make_unique(std::move(info)) + : nullptr; } auto K::typeId() noexcept -> size_t { @@ -26,97 +24,103 @@ namespace refactor::kernel { return "Performing MatMul using CPU"; } + template struct MatMulCPUMetaData { size_t M, K, N; size_t strideA0, strideA1, strideB0, strideB1; - }; + T alpha, beta; - /* - * 2D matrix multiplication: Y = a * A @ B + b * Y - * Assume bias C has been broadcast to Y already. Beta should be 0 in the absence of bias. - */ - template - void matrixMultiply(T const *A, T const *B, T *Y, - T const alpha, T const beta, - const MatMulCPUMetaData md) { - // #pragma omp parallel for - for (size_t i = 0; i < md.M; i++) { - for (size_t j = 0; j < md.N; j++) { - T sum = 0; - // #pragma omp simd reduction(+ : sum) - for (size_t k = 0; k < md.K; k++) { - sum += A[i * md.strideA0 + k * md.strideA1] * B[k * md.strideB0 + j * md.strideB1]; + /* + * 2D matrix multiplication: Y = a * A @ B + b * Y + * Assume bias C has been broadcast to Y already. Beta should be 0 in the absence of bias. + */ + void matrixMultiply(T const *A, T const *B, T *Y) const noexcept { + // #pragma omp parallel for + for (size_t i = 0; i < M; i++) { + for (size_t j = 0; j < N; j++) { + T sum = 0; + // #pragma omp simd reduction(+ : sum) + for (size_t k = 0; k < K; k++) { + sum += A[i * strideA0 + k * strideA1] * B[k * strideB0 + j * strideB1]; + } + Y[i * N + j] = beta * Y[i * N + j] + alpha * sum; } - Y[i * md.N + j] = beta * Y[i * md.N + j] + alpha * sum; } } - } + }; -#define CASE(T) \ - case DT::T: { \ - using T_ = primitive::type; \ - if (std::holds_alternative(info.broadcasterOrBatch)) { \ - return [alpha = static_cast(info.alpha), \ - beta = static_cast(info.biasExpand ? info.beta : 0.0f), \ - broadcaster = std::get(info.broadcasterOrBatch), \ - md, \ - stepY = info.m * info.n, \ - stepA = info.m * info.k, \ - stepB = info.k * info.n, \ - biasEx = info.biasExpand \ - ? 
std::make_optional(ExpandCpu(*info.biasExpand).lower(res).routine) \ - : std::nullopt](runtime::Resources &res, void *workspace, void const *const *inputs, void *const *outputs) { \ - if (biasEx) { (*biasEx)(res, workspace, inputs + 2, outputs); } \ - auto A = reinterpret_cast(inputs[0]); \ - auto B = reinterpret_cast(inputs[1]); \ - auto Y = reinterpret_cast(outputs[0]); \ - dim_t offset[2]; \ - for (size_t i = 0; i < broadcaster.outputsCount; i++) { \ - broadcaster.locate(i, offset); \ - matrixMultiply(A + stepA * offset[0], B + stepB * offset[1], Y + stepY * i, alpha, beta, md); \ - } \ - }; \ - } else { \ - return [alpha = static_cast(info.alpha), \ - beta = static_cast(info.biasExpand ? info.beta : 0.0f), \ - batch = std::get(info.broadcasterOrBatch), \ - md, \ - stepY = info.m * info.n, \ - stepA = info.m * info.k, \ - stepB = info.k * info.n, \ - biasEx = info.biasExpand \ - ? std::make_optional(ExpandCpu(*info.biasExpand).lower(res).routine) \ - : std::nullopt](runtime::Resources &res, void *workspace, void const *const *inputs, void *const *outputs) { \ - if (biasEx) { (*biasEx)(res, workspace, inputs + 2, outputs); } \ - auto A = reinterpret_cast(inputs[0]); \ - auto B = reinterpret_cast(inputs[1]); \ - auto Y = reinterpret_cast(outputs[0]); \ - for (size_t i = 0; i < batch; i++) { \ - matrixMultiply(A + stepA * i, B + stepB * i, Y + stepY * i, alpha, beta, md); \ - } \ - }; \ - } \ + template + static auto lowerTyped(MatMulInfo const &info, Resources &res) noexcept -> RoutineWorkspace { + MatMulCPUMetaData const md{ + .M = info.m, + .K = info.k, + .N = info.n, + .strideA0 = info.transA ? 1 : info.k, + .strideA1 = info.transA ? info.m : 1, + .strideB0 = info.transB ? 1 : info.n, + .strideB1 = info.transB ? info.k : 1, + .alpha = static_cast(info.alpha), + .beta = static_cast(info.biasExpand ? info.beta : 0.0f), + }; + + auto stepY = info.m * info.n, + stepA = info.m * info.k, + stepB = info.k * info.n; + auto biasEx = info.biasExpand + ? std::make_optional(ExpandCpu(*info.biasExpand).lower(res).routine) + : std::nullopt; + + if (std::holds_alternative(info.broadcasterOrBatch)) { + return [broadcaster = std::get(info.broadcasterOrBatch), + stepY, stepA, stepB, + md, biasEx]// + (runtime::Resources & res, void *workspace, void const *const *inputs, void *const *outputs) { + if (biasEx) { (*biasEx)(res, workspace, inputs + 2, outputs); } + + auto a = reinterpret_cast(inputs[0]); + auto b = reinterpret_cast(inputs[1]); + auto y = reinterpret_cast(outputs[0]); + dim_t offset[2]; + for (auto i : range0_(broadcaster.outputsCount)) { + broadcaster.locate(i, offset); + md.matrixMultiply(a + stepA * offset[0], b + stepB * offset[1], y + stepY * i); + } + }; + } else { + return [batch = std::get(info.broadcasterOrBatch), + stepY, stepA, stepB, + md, biasEx]// + (runtime::Resources & res, void *workspace, void const *const *inputs, void *const *outputs) { + if (biasEx) { (*biasEx)(res, workspace, inputs + 2, outputs); } + + auto a = reinterpret_cast(inputs[0]); + auto b = reinterpret_cast(inputs[1]); + auto y = reinterpret_cast(outputs[0]); + for (auto i : range0_(batch)) { + md.matrixMultiply(a + stepA * i, b + stepB * i, y + stepY * i); + } + }; + } } auto K::lower(Resources &res) const noexcept -> RoutineWorkspace { - MatMulCPUMetaData md; - md.M = info.m, md.K = info.k, md.N = info.n; - md.strideA0 = info.transA ? 1 : info.k; - md.strideA1 = info.transA ? info.m : 1; - md.strideB0 = info.transB ? 1 : info.n; - md.strideB1 = info.transB ? 
info.k : 1; +#define CASE(T) \ + case DataType::T: \ + return lowerTyped::type>(info, res); switch (info.dataType) { CASE(F32); + CASE(F64); + CASE(U8); - CASE(I8); CASE(U16); + CASE(U32); + CASE(U64); + + CASE(I8); CASE(I16); CASE(I32); CASE(I64); - CASE(F64); - CASE(U32); - CASE(U64); default: UNREACHABLE(); } diff --git a/src/04kernel/src/kernels/simple_unary/cudnn_activation_kernel.cc b/src/04kernel/src/kernels/simple_unary/cudnn_activation_kernel.cc index 92662a61a..ec3ad5698 100644 --- a/src/04kernel/src/kernels/simple_unary/cudnn_activation_kernel.cc +++ b/src/04kernel/src/kernels/simple_unary/cudnn_activation_kernel.cc @@ -64,28 +64,25 @@ namespace refactor::kernel { auto d = std::make_shared(); // clang-format off - cudnnActivationMode_t - mode = type == Ty::Relu ? CUDNN_ACTIVATION_RELU - : type == Ty::Sigmoid ? CUDNN_ACTIVATION_SIGMOID - : type == Ty::Tanh ? CUDNN_ACTIVATION_TANH - : UNREACHABLEX(cudnnActivationMode_t, ""); + auto mode = type == Ty::Relu ? CUDNN_ACTIVATION_RELU + : type == Ty::Sigmoid ? CUDNN_ACTIVATION_SIGMOID + : type == Ty::Tanh ? CUDNN_ACTIVATION_TANH + : UNREACHABLEX(cudnnActivationMode_t, ""); // clang-format on + setCudnnTensor(d->tensor, dataType, slice(&size, 1)); CUDNN_ASSERT(cudnnSetActivationDescriptor(d->activation, mode, CUDNN_PROPAGATE_NAN, 0.0)); - CUDNN_ASSERT(cudnnSetTensor4dDescriptor(d->tensor, CUDNN_TENSOR_NCHW, cudnnDataTypeConvert(dataType), 1, 1, 1, size)); res.fetchOrStore(); - // nvcc at c++11 doesn't support real move capture - return [d = std::move(d)](Resources &res, void *workspace, void const *const *inputs, void *const *outputs) { - // fetch cudnn handle from resources - auto handle = res.fetchOrStore()->handle; - // name inputs and outputs - auto x = inputs[0]; - auto y = outputs[0]; - // call cudnn activation - float alpha = 1, beta = 0; - CUDNN_ASSERT(cudnnActivationForward(handle, d->activation, &alpha, d->tensor, x, &beta, d->tensor, y)); - }; + return [d = std::move(d)]// + (Resources & res, void *, void const *const *inputs, void *const *outputs) { + float alpha = 1, beta = 0; + CUDNN_ASSERT(cudnnActivationForward( + res.fetchOrStore()->handle, + d->activation, + &alpha, d->tensor, inputs[0], + &beta, d->tensor, outputs[0])); + }; } #endif diff --git a/src/04kernel/src/kernels/softmax/cudnn_kernel.cc b/src/04kernel/src/kernels/softmax/cudnn_kernel.cc index 0536073a2..cff6d26bb 100644 --- a/src/04kernel/src/kernels/softmax/cudnn_kernel.cc +++ b/src/04kernel/src/kernels/softmax/cudnn_kernel.cc @@ -58,11 +58,8 @@ namespace refactor::kernel { auto d = std::make_shared( static_cast(algo), dataType != DataType::F64); - CUDNN_ASSERT(cudnnSetTensor4dDescriptor( - d->t, - CUDNN_TENSOR_NCHW, - cudnnDataTypeConvert(dataType), - pre, mid, post, 1)); + int dims[]{pre, mid, post, 1}; + setCudnnTensor(d->t, dataType, slice(dims, 4)); res.fetchOrStore(); return [d = std::move(d)](Resources &res, void *workspace, void const *const *inputs, void *const *outputs) { From 7a37779e0c6e928c385d9f32f6880382705fe829 Mon Sep 17 00:00:00 2001 From: YdrMaster Date: Fri, 15 Dec 2023 14:51:02 +0800 Subject: [PATCH 03/18] =?UTF-8?q?feat(onnx):=20=E5=89=8D=E7=AB=AF=E6=94=AF?= =?UTF-8?q?=E6=8C=81=20MatMulInteger=20=E5=B8=A6=E6=9C=89=204=20=E4=B8=AA?= =?UTF-8?q?=E8=BE=93=E5=85=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: YdrMaster --- .../include/computation/operators/cast.h | 2 +- .../include/computation/operators/identity.h | 3 +- .../include/computation/operators/where.h | 3 +- 
src/07onnx/src/operators/expand.cc | 2 - src/07onnx/src/operators/expand.hh | 3 +- src/07onnx/src/operators/mat_mul.cc | 2 - src/07onnx/src/operators/mat_mul.hh | 3 +- src/07onnx/src/operators/mat_mul_integer.cc | 42 ++++++++++++++----- src/07onnx/src/operators/mat_mul_integer.hh | 3 +- src/07onnx/src/operators/range.cc | 2 - src/07onnx/src/operators/range.hh | 3 +- src/07onnx/src/operators/slice.cc | 2 - src/07onnx/src/operators/slice.hh | 3 +- src/07onnx/src/operators/tile.cc | 2 - src/07onnx/src/operators/tile.hh | 2 +- src/07onnx/src/operators/where.cc | 2 - src/07onnx/src/operators/where.hh | 2 +- .../src/operators/all_reduce.cc | 2 - .../src/operators/all_reduce.hh | 2 +- 19 files changed, 49 insertions(+), 36 deletions(-) diff --git a/src/05computation/include/computation/operators/cast.h b/src/05computation/include/computation/operators/cast.h index 3af10655c..259bdc0eb 100644 --- a/src/05computation/include/computation/operators/cast.h +++ b/src/05computation/include/computation/operators/cast.h @@ -7,7 +7,7 @@ namespace refactor::computation { struct Cast final : public Operator { - constexpr explicit Cast() noexcept : Operator() {} + constexpr explicit Cast() noexcept = default; static size_t typeId() noexcept; size_t opTypeId() const noexcept final; diff --git a/src/05computation/include/computation/operators/identity.h b/src/05computation/include/computation/operators/identity.h index cf23bf92a..4f887321f 100644 --- a/src/05computation/include/computation/operators/identity.h +++ b/src/05computation/include/computation/operators/identity.h @@ -6,7 +6,8 @@ namespace refactor::computation { struct Identity final : public Operator { - constexpr Identity() noexcept : Operator() {} + + constexpr Identity() noexcept = default; static size_t typeId() noexcept; size_t opTypeId() const noexcept final; diff --git a/src/05computation/include/computation/operators/where.h b/src/05computation/include/computation/operators/where.h index a50a28cfc..5af1296cc 100644 --- a/src/05computation/include/computation/operators/where.h +++ b/src/05computation/include/computation/operators/where.h @@ -6,7 +6,8 @@ namespace refactor::computation { struct Where final : public Operator { - constexpr Where() noexcept : Operator() {} + + constexpr Where() noexcept = default; static size_t typeId() noexcept; size_t opTypeId() const noexcept final; diff --git a/src/07onnx/src/operators/expand.cc b/src/07onnx/src/operators/expand.cc index 35af90e6f..b5e729810 100644 --- a/src/07onnx/src/operators/expand.cc +++ b/src/07onnx/src/operators/expand.cc @@ -6,8 +6,6 @@ namespace refactor::onnx { using Op = Expand; - Op::Expand() : Operator() {} - auto Op::build(ModelContext const &, std::string_view, Attributes attributes) -> OpBox { ASSERT(attributes.empty(), "Expand operator should not have attributes"); return OpBox(std::make_unique()); diff --git a/src/07onnx/src/operators/expand.hh b/src/07onnx/src/operators/expand.hh index f621306ba..a4f3b63e1 100644 --- a/src/07onnx/src/operators/expand.hh +++ b/src/07onnx/src/operators/expand.hh @@ -7,7 +7,8 @@ namespace refactor::onnx { using namespace frontend; struct Expand final : public Operator { - Expand(); + + constexpr Expand() noexcept = default; static OpBox build(ModelContext const &, std::string_view, Attributes); static size_t typeId(); diff --git a/src/07onnx/src/operators/mat_mul.cc b/src/07onnx/src/operators/mat_mul.cc index b9acd44b1..0850ba6f0 100644 --- a/src/07onnx/src/operators/mat_mul.cc +++ b/src/07onnx/src/operators/mat_mul.cc @@ -5,8 +5,6 @@ 
namespace refactor::onnx { using Op = MatMul; - Op::MatMul() : Operator() {} - auto Op::build(ModelContext const &, std::string_view, Attributes attributes) -> OpBox { ASSERT(attributes.empty(), "MatMul operator should not have attributes"); return OpBox(std::make_unique()); diff --git a/src/07onnx/src/operators/mat_mul.hh b/src/07onnx/src/operators/mat_mul.hh index d0ce5fcdd..e845de275 100644 --- a/src/07onnx/src/operators/mat_mul.hh +++ b/src/07onnx/src/operators/mat_mul.hh @@ -7,7 +7,8 @@ namespace refactor::onnx { using namespace frontend; struct MatMul final : public Operator { - MatMul(); + + constexpr MatMul() noexcept = default; static OpBox build(ModelContext const &, std::string_view, Attributes); static size_t typeId(); diff --git a/src/07onnx/src/operators/mat_mul_integer.cc b/src/07onnx/src/operators/mat_mul_integer.cc index d2466d519..c0ad0beff 100644 --- a/src/07onnx/src/operators/mat_mul_integer.cc +++ b/src/07onnx/src/operators/mat_mul_integer.cc @@ -6,8 +6,6 @@ namespace refactor::onnx { using Op = MatMulInteger; - Op::MatMulInteger() : Operator() {} - auto Op::build(ModelContext const &, std::string_view, Attributes attributes) -> OpBox { ASSERT(attributes.empty(), "MatMulInteger operator should not have attributes"); return OpBox(std::make_unique()); @@ -20,18 +18,40 @@ namespace refactor::onnx { auto Op::opTypeId() const -> size_t { return typeId(); } auto Op::opTypeName() const -> std::string_view { return "onnx::MatMulInteger"; } + static bool checkZeroPoint(TensorRefs inputs, size_t scalarN) { + if (inputs.size() <= scalarN + 2) { + return true; + } + auto const &t = inputs[scalarN]; + auto const &zp = inputs[scalarN + 2]; + if (zp.dataType != t.dataType) { + return false; + } + if (zp.rank() == 0) { + return true; + } + switch (t.rank()) { + case 1: + return zp.shape == decltype(zp.shape){DimExpr(1)}; + case 2: + return zp.shape == decltype(zp.shape){t.shape[scalarN]}; + default: { + auto expect = t.shape; + expect[expect.size() - 1 - scalarN] = DimExpr(1); + return zp.shape == expect; + } + } + } + auto Op::infer(TensorRefs inputs, InferOptions const &options) const -> InferResult { static const std::unordered_set SET{DataType::I8, DataType::U8}; - switch (inputs.size()) { - case 2: - break; - case 3: - case 4: - return Err(InferError(ERROR_MSG("Quantization tensor not support currently"))); - default: - return Err(InferError(ERROR_MSG("Input size error"))); + if (inputs.size() < 2 || 4 < inputs.size()) { + return Err(InferError(ERROR_MSG("Input size not support"))); + } + if (!checkZeroPoint(inputs, 0) || !checkZeroPoint(inputs, 1)) { + return Err(InferError(ERROR_MSG("Input zero point not support"))); } auto const &a = inputs[0]; @@ -42,7 +62,7 @@ namespace refactor::onnx { auto sa = a.shape, sb = b.shape; switch (sa.size()) { case 1: - sa.insert(sa.begin(), DimExpr(1)); + sa.emplace(sa.begin(), 1); break; case 0: return Err(InferError(ERROR_MSG("Input shape not support"))); diff --git a/src/07onnx/src/operators/mat_mul_integer.hh b/src/07onnx/src/operators/mat_mul_integer.hh index a581ad1af..1d6d164a7 100644 --- a/src/07onnx/src/operators/mat_mul_integer.hh +++ b/src/07onnx/src/operators/mat_mul_integer.hh @@ -7,7 +7,8 @@ namespace refactor::onnx { using namespace frontend; struct MatMulInteger final : public Operator { - MatMulInteger(); + + constexpr MatMulInteger() noexcept = default; static OpBox build(ModelContext const &, std::string_view, Attributes); static size_t typeId(); diff --git a/src/07onnx/src/operators/range.cc 
b/src/07onnx/src/operators/range.cc index 8f4ce357a..d5191734b 100644 --- a/src/07onnx/src/operators/range.cc +++ b/src/07onnx/src/operators/range.cc @@ -4,8 +4,6 @@ namespace refactor::onnx { using Op = Range; - Op::Range() : Operator() {} - auto Op::build(ModelContext const &, std::string_view, Attributes) -> OpBox { return OpBox(std::make_unique()); } diff --git a/src/07onnx/src/operators/range.hh b/src/07onnx/src/operators/range.hh index baf29aa20..dcf17462a 100644 --- a/src/07onnx/src/operators/range.hh +++ b/src/07onnx/src/operators/range.hh @@ -7,7 +7,8 @@ namespace refactor::onnx { using namespace frontend; struct Range final : public Operator { - Range(); + + constexpr Range() noexcept = default; static OpBox build(ModelContext const &, std::string_view, Attributes); static size_t typeId(); diff --git a/src/07onnx/src/operators/slice.cc b/src/07onnx/src/operators/slice.cc index 49540e245..0a0853a85 100644 --- a/src/07onnx/src/operators/slice.cc +++ b/src/07onnx/src/operators/slice.cc @@ -7,8 +7,6 @@ namespace refactor::onnx { using computation::Dimensions; using Op = Slice; - Op::Slice() : Operator() {} - auto Op::build(ModelContext const &, std::string_view, Attributes attributes) -> OpBox { ASSERT(attributes.empty(), "Slice operator should not have attributes"); return OpBox(std::make_unique()); diff --git a/src/07onnx/src/operators/slice.hh b/src/07onnx/src/operators/slice.hh index 8afa0067b..3a082a1d5 100644 --- a/src/07onnx/src/operators/slice.hh +++ b/src/07onnx/src/operators/slice.hh @@ -7,7 +7,8 @@ namespace refactor::onnx { using namespace frontend; struct Slice final : public Operator { - Slice(); + + constexpr Slice() noexcept = default; static OpBox build(ModelContext const &, std::string_view, Attributes); static size_t typeId(); diff --git a/src/07onnx/src/operators/tile.cc b/src/07onnx/src/operators/tile.cc index a1fa4b293..14a13fb2f 100644 --- a/src/07onnx/src/operators/tile.cc +++ b/src/07onnx/src/operators/tile.cc @@ -5,8 +5,6 @@ namespace refactor::onnx { using Op = Tile; - Op::Tile() : Operator() {} - auto Op::build(ModelContext const &, std::string_view, Attributes attributes) -> OpBox { ASSERT(attributes.empty(), "Tile operator should not have attributes"); return OpBox(std::make_unique()); diff --git a/src/07onnx/src/operators/tile.hh b/src/07onnx/src/operators/tile.hh index 1d4e60e32..9fed86358 100644 --- a/src/07onnx/src/operators/tile.hh +++ b/src/07onnx/src/operators/tile.hh @@ -8,7 +8,7 @@ namespace refactor::onnx { struct Tile final : public Operator { - Tile(); + constexpr Tile() noexcept = default; static OpBox build(ModelContext const &, std::string_view, Attributes); static size_t typeId(); diff --git a/src/07onnx/src/operators/where.cc b/src/07onnx/src/operators/where.cc index 757ba0c78..eea30e691 100644 --- a/src/07onnx/src/operators/where.cc +++ b/src/07onnx/src/operators/where.cc @@ -6,8 +6,6 @@ namespace refactor::onnx { using Op = Where; - Op::Where() : Operator() {} - auto Op::build(ModelContext const &, std::string_view, Attributes attributes) -> OpBox { ASSERT(attributes.empty(), "Where operator should not have attributes"); return OpBox(std::make_unique()); diff --git a/src/07onnx/src/operators/where.hh b/src/07onnx/src/operators/where.hh index defd0af0b..9c0d9ad68 100644 --- a/src/07onnx/src/operators/where.hh +++ b/src/07onnx/src/operators/where.hh @@ -8,7 +8,7 @@ namespace refactor::onnx { struct Where final : public Operator { - Where(); + constexpr Where() noexcept = default; static OpBox build(ModelContext const &, std::string_view, 
Attributes); static size_t typeId(); diff --git a/src/08communication/src/operators/all_reduce.cc b/src/08communication/src/operators/all_reduce.cc index c508b7b17..989ed6ad9 100644 --- a/src/08communication/src/operators/all_reduce.cc +++ b/src/08communication/src/operators/all_reduce.cc @@ -4,8 +4,6 @@ namespace refactor::communication { using Op = AllReduce; - Op::AllReduce() : Operator() {} - auto Op::build(ModelContext const &, std::string_view, Attributes) -> OpBox { return OpBox(std::make_unique()); } diff --git a/src/08communication/src/operators/all_reduce.hh b/src/08communication/src/operators/all_reduce.hh index a31d6f2f6..a14bc74f4 100644 --- a/src/08communication/src/operators/all_reduce.hh +++ b/src/08communication/src/operators/all_reduce.hh @@ -8,7 +8,7 @@ namespace refactor::communication { struct AllReduce final : public Operator { - AllReduce(); + constexpr AllReduce() noexcept = default; static OpBox build(ModelContext const &, std::string_view, Attributes); static size_t typeId(); From 6f1199797dc1bc1d4ca4028c23fdc894c6c3e348 Mon Sep 17 00:00:00 2001 From: YdrMaster Date: Fri, 15 Dec 2023 15:13:56 +0800 Subject: [PATCH 04/18] =?UTF-8?q?feat(computation):=20MatMulInteger=20?= =?UTF-8?q?=E4=BB=8E=20MatMul=20=E5=88=86=E7=A6=BB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: YdrMaster --- .../kernel/collectors/mat_mul_integer.h | 19 +++++++++++++++++ src/04kernel/src/collectors/mat_mul.cc | 1 - .../src/collectors/mat_mul_integer.cc | 18 ++++++++++++++++ .../src/kernels/mat_mul/cublas_kernel.cc | 3 +-- .../computation/operators/mat_mul_integer.h | 21 +++++++++++++++++++ .../src/operators/mat_mul_integer.cc | 20 ++++++++++++++++++ src/07onnx/src/operators/mat_mul_integer.cc | 6 +++--- 7 files changed, 82 insertions(+), 6 deletions(-) create mode 100644 src/04kernel/include/kernel/collectors/mat_mul_integer.h create mode 100644 src/04kernel/src/collectors/mat_mul_integer.cc create mode 100644 src/05computation/include/computation/operators/mat_mul_integer.h create mode 100644 src/05computation/src/operators/mat_mul_integer.cc diff --git a/src/04kernel/include/kernel/collectors/mat_mul_integer.h b/src/04kernel/include/kernel/collectors/mat_mul_integer.h new file mode 100644 index 000000000..3534b0125 --- /dev/null +++ b/src/04kernel/include/kernel/collectors/mat_mul_integer.h @@ -0,0 +1,19 @@ +#ifndef KERNEL_MAT_MUL_INTEGER_H +#define KERNEL_MAT_MUL_INTEGER_H + +#include "../collector.h" + +namespace refactor::kernel { + + struct MatMulIntegerCollector final : public InfoCollector { + + constexpr MatMulIntegerCollector(decltype(_target) target) noexcept + : InfoCollector(target) {} + + std::vector + filter(TensorRefs inputs, TensorRefs outputs) const final; + }; + +}// namespace refactor::kernel + +#endif// KERNEL_MAT_MUL_INTEGER_H diff --git a/src/04kernel/src/collectors/mat_mul.cc b/src/04kernel/src/collectors/mat_mul.cc index 5a72240f8..08d10efbf 100644 --- a/src/04kernel/src/collectors/mat_mul.cc +++ b/src/04kernel/src/collectors/mat_mul.cc @@ -1,7 +1,6 @@ #include "kernel/collectors/mat_mul.h" #include "../kernels/mat_mul/cpu_kernel.hh" #include "../kernels/mat_mul/cublas_kernel.hh" -#include "common.h" #include "kernel/attributes/matmul_info.h" namespace refactor::kernel { diff --git a/src/04kernel/src/collectors/mat_mul_integer.cc b/src/04kernel/src/collectors/mat_mul_integer.cc new file mode 100644 index 000000000..2abaded6f --- /dev/null +++ b/src/04kernel/src/collectors/mat_mul_integer.cc @@ -0,0 +1,18 @@ 
+#include "kernel/collectors/mat_mul_integer.h" + +namespace refactor::kernel { + std::vector + MatMulIntegerCollector::filter(TensorRefs inputs, TensorRefs outputs) const { + std::vector ans; + switch (_target) { + case decltype(_target)::Cpu: + break; + case decltype(_target)::Nvidia: + break; + default: + UNREACHABLEX(void, "Unknown target"); + } + return ans; + } + +}// namespace refactor::kernel diff --git a/src/04kernel/src/kernels/mat_mul/cublas_kernel.cc b/src/04kernel/src/kernels/mat_mul/cublas_kernel.cc index 7a69a192c..a21f1dfe0 100644 --- a/src/04kernel/src/kernels/mat_mul/cublas_kernel.cc +++ b/src/04kernel/src/kernels/mat_mul/cublas_kernel.cc @@ -8,12 +8,11 @@ namespace refactor::kernel { : Kernel(), info(std::move(info_)) {} auto K::build(MatMulInfo info) noexcept -> KernelBox { - static const std::unordered_set TYPE{DT::F32, DT::F64, DT::FP16}; #ifndef USE_CUDA return nullptr; #endif - return TYPE.contains(info.dataType) + return info.dataType.isIeee754() || info.dataType == DT::I8 ? std::make_unique(std::move(info)) : nullptr; } diff --git a/src/05computation/include/computation/operators/mat_mul_integer.h b/src/05computation/include/computation/operators/mat_mul_integer.h new file mode 100644 index 000000000..262469974 --- /dev/null +++ b/src/05computation/include/computation/operators/mat_mul_integer.h @@ -0,0 +1,21 @@ +#ifndef COMPUTATION_MAT_MUL_INTEGER_H +#define COMPUTATION_MAT_MUL_INTEGER_H + +#include "../operator.h" + +namespace refactor::computation { + + struct MatMulInteger final : public LayoutDependentOperator { + + constexpr MatMulInteger() noexcept = default; + + static size_t typeId() noexcept; + size_t opTypeId() const noexcept final; + std::string_view name() const noexcept final; + kernel::CollectorBox candidateKernels(Target) const noexcept final; + std::string serialize() const noexcept final; + }; + +}// namespace refactor::computation + +#endif// #ifndef COMPUTATION_MAT_MUL_INTEGER_H diff --git a/src/05computation/src/operators/mat_mul_integer.cc b/src/05computation/src/operators/mat_mul_integer.cc new file mode 100644 index 000000000..4d03d21b3 --- /dev/null +++ b/src/05computation/src/operators/mat_mul_integer.cc @@ -0,0 +1,20 @@ +#include "computation/operators/mat_mul_integer.h" +#include "kernel/collectors/mat_mul_integer.h" + +namespace refactor::computation { + using Op = MatMulInteger; + + auto Op::typeId() noexcept -> size_t { + static uint8_t ID = 1; + return reinterpret_cast(&ID); + } + auto Op::opTypeId() const noexcept -> size_t { return typeId(); } + auto Op::name() const noexcept -> std::string_view { return "MatMulInteger"; } + auto Op::candidateKernels(Target target) const noexcept -> kernel::CollectorBox { + return std::make_unique(target); + } + auto Op::serialize() const noexcept -> std::string { + return "MatMulInteger()"; + } + +}// namespace refactor::computation diff --git a/src/07onnx/src/operators/mat_mul_integer.cc b/src/07onnx/src/operators/mat_mul_integer.cc index c0ad0beff..de5de0525 100644 --- a/src/07onnx/src/operators/mat_mul_integer.cc +++ b/src/07onnx/src/operators/mat_mul_integer.cc @@ -1,6 +1,6 @@ #include "mat_mul_integer.hh" #include "common.h" -#include "computation/operators/mat_mul.h" +#include "computation/operators/mat_mul_integer.h" #include namespace refactor::onnx { @@ -95,8 +95,8 @@ namespace refactor::onnx { } auto Op::lower(TensorRefs) const -> computation::OpBox { - using Op_ = computation::MatMul; - return std::make_unique(1.0, 1.0, false, false); + using Op_ = computation::MatMulInteger; + return 
std::make_unique(); } }// namespace refactor::onnx From 33ff398c4c7a7b7603e98080718ce4dc824b8fa2 Mon Sep 17 00:00:00 2001 From: YdrMaster Date: Fri, 15 Dec 2023 17:22:01 +0800 Subject: [PATCH 05/18] =?UTF-8?q?refactor(kernel):=20=E4=B8=BA=20Broadcast?= =?UTF-8?q?er=20=E8=A1=A8=E7=A4=BA=E4=B8=8D=E9=9C=80=E8=A6=81=E5=B9=BF?= =?UTF-8?q?=E6=92=AD=E6=98=8E=E7=A1=AE=E8=AF=AD=E4=B9=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit feat(kernel): 实现 MatMulInteger 的信息抽取 Signed-off-by: YdrMaster --- .../include/kernel/attributes/broadcaster.h | 1 + .../{matmul_info.h => mat_mul_info.h} | 13 ++--- .../kernel/attributes/mat_mul_integer_info.h | 25 ++++++++ src/04kernel/src/attributes/broadcaster.cc | 4 ++ .../{matmul_info.cc => mat_mul_info.cc} | 20 ++----- .../src/attributes/mat_mul_integer_info.cc | 26 +++++++++ src/04kernel/src/collectors/mat_mul.cc | 2 +- .../src/collectors/mat_mul_integer.cc | 8 +++ .../src/kernels/mat_mul/cpu_kernel.cc | 34 ++--------- .../src/kernels/mat_mul/cpu_kernel.hh | 7 +-- .../src/kernels/mat_mul/cublas_kernel.cc | 2 +- .../src/kernels/mat_mul/cublas_kernel.cu | 57 ++++++++++--------- .../src/kernels/mat_mul/cublas_kernel.hh | 6 +- .../kernels/mat_mul_common/cpu_template.hpp | 33 +++++++++++ .../src/kernels/mat_mul_integer/cpu_kernel.cc | 29 ++++++++++ .../src/kernels/mat_mul_integer/cpu_kernel.hh | 25 ++++++++ .../src/kernels/simple_binary/cuda_kernel.cc | 2 +- 17 files changed, 206 insertions(+), 88 deletions(-) rename src/04kernel/include/kernel/attributes/{matmul_info.h => mat_mul_info.h} (62%) create mode 100644 src/04kernel/include/kernel/attributes/mat_mul_integer_info.h rename src/04kernel/src/attributes/{matmul_info.cc => mat_mul_info.cc} (66%) create mode 100644 src/04kernel/src/attributes/mat_mul_integer_info.cc create mode 100644 src/04kernel/src/kernels/mat_mul_common/cpu_template.hpp create mode 100644 src/04kernel/src/kernels/mat_mul_integer/cpu_kernel.cc create mode 100644 src/04kernel/src/kernels/mat_mul_integer/cpu_kernel.hh diff --git a/src/04kernel/include/kernel/attributes/broadcaster.h b/src/04kernel/include/kernel/attributes/broadcaster.h index 006fa9adb..b95c45680 100644 --- a/src/04kernel/include/kernel/attributes/broadcaster.h +++ b/src/04kernel/include/kernel/attributes/broadcaster.h @@ -15,6 +15,7 @@ namespace refactor::kernel { explicit Broadcaster(std::vector>); explicit Broadcaster(TensorRefs const &inputs); void locate(dim_t k, dim_t ans[]) const noexcept; + bool needBroadcast() const noexcept; }; }// namespace refactor::kernel diff --git a/src/04kernel/include/kernel/attributes/matmul_info.h b/src/04kernel/include/kernel/attributes/mat_mul_info.h similarity index 62% rename from src/04kernel/include/kernel/attributes/matmul_info.h rename to src/04kernel/include/kernel/attributes/mat_mul_info.h index 831c1f9a3..95440ba2f 100644 --- a/src/04kernel/include/kernel/attributes/matmul_info.h +++ b/src/04kernel/include/kernel/attributes/mat_mul_info.h @@ -1,9 +1,8 @@ -#ifndef KERNEL_MATMUL_INFO_H -#define KERNEL_MATMUL_INFO_H +#ifndef KERNEL_MAT_MUL_INFO_H +#define KERNEL_MAT_MUL_INFO_H #include "kernel/attributes/broadcaster.h" #include "kernel/attributes/expand_info.h" -#include namespace refactor::kernel { @@ -11,11 +10,11 @@ namespace refactor::kernel { DataType dataType; float alpha, beta; bool transA, transB; - size_t m, k, n; + dim_t m, k, n; // Expand operation info for biasd std::optional biasExpand; - // A constant batch or a 2-directional broadcaster that deals with dimensions before the last 
2 dimensions - std::variant broadcasterOrBatch; + // A 2-directional broadcaster that deals with dimensions before the last 2 dimensions + Broadcaster broadcaster; MatMulInfo(Tensor const &, Tensor const &, std::optional>, @@ -24,4 +23,4 @@ namespace refactor::kernel { }// namespace refactor::kernel -#endif// KERNEL_MATMUL_INFO_H +#endif// KERNEL_MAT_MUL_INFO_H diff --git a/src/04kernel/include/kernel/attributes/mat_mul_integer_info.h b/src/04kernel/include/kernel/attributes/mat_mul_integer_info.h new file mode 100644 index 000000000..83e8fcd22 --- /dev/null +++ b/src/04kernel/include/kernel/attributes/mat_mul_integer_info.h @@ -0,0 +1,25 @@ +#ifndef KERNEL_MAT_MUL_INTEGER_INFO_H +#define KERNEL_MAT_MUL_INTEGER_INFO_H + +#include "kernel/attributes/broadcaster.h" + +namespace refactor::kernel { + + struct MatMulIntegerInfo { + struct Input { + bool signed_; + bool withZeroPoint; + + Input(TensorRefs const &, size_t i) noexcept; + }; + + Input a, b; + dim_t m, k, n; + Broadcaster broadcaster; + + explicit MatMulIntegerInfo(TensorRefs const &inputs) noexcept; + }; + +}// namespace refactor::kernel + +#endif// KERNEL_MAT_MUL_INTEGER_INFO_H diff --git a/src/04kernel/src/attributes/broadcaster.cc b/src/04kernel/src/attributes/broadcaster.cc index c35fb3ae1..d0c4bea73 100644 --- a/src/04kernel/src/attributes/broadcaster.cc +++ b/src/04kernel/src/attributes/broadcaster.cc @@ -96,4 +96,8 @@ namespace refactor::kernel { } } + bool Broadcaster::needBroadcast() const noexcept { + return !strides.empty(); + } + }// namespace refactor::kernel diff --git a/src/04kernel/src/attributes/matmul_info.cc b/src/04kernel/src/attributes/mat_mul_info.cc similarity index 66% rename from src/04kernel/src/attributes/matmul_info.cc rename to src/04kernel/src/attributes/mat_mul_info.cc index 5f1162072..3619e9b59 100644 --- a/src/04kernel/src/attributes/matmul_info.cc +++ b/src/04kernel/src/attributes/mat_mul_info.cc @@ -1,10 +1,8 @@ -#include "kernel/attributes/matmul_info.h" -#include -#include +#include "kernel/attributes/mat_mul_info.h" namespace refactor::kernel { - ExpandInfo buildBias(size_t m, size_t n, + ExpandInfo buildBias(dim_t m, dim_t n, Tensor const &a, Tensor const &b, Tensor const &c) { @@ -12,8 +10,8 @@ namespace refactor::kernel { auto it = output.rbegin(); *it++ = n; *it++ = m; - for (auto da = static_cast(a.rank() - 2), - db = static_cast(b.rank() - 2); + for (auto da = static_cast(a.rank() - 2), + db = static_cast(b.rank() - 2); auto i : range0_(output.size() - 2)) { auto a_ = i < da ? a.shape[da - i - 1] : 1; auto b_ = i < db ? b.shape[db - i - 1] : 1; @@ -26,13 +24,6 @@ namespace refactor::kernel { slice(output.data(), output.size())); } - std::variant buildBroadcasterOrBatch(slice_t dimA, slice_t dimB) { - if (std::equal(dimA.begin(), dimA.end(), dimB.begin(), dimB.end())) { - return std::accumulate(dimA.begin(), dimA.end(), (size_t) 1, std::multiplies()); - } - return Broadcaster({dimA, dimB}); - } - MatMulInfo::MatMulInfo( Tensor const &a, Tensor const &b, std::optional> c, @@ -44,7 +35,8 @@ namespace refactor::kernel { k(transA ? a.shape.rbegin()[1] : a.shape.rbegin()[0]), n(transB ? b.shape.rbegin()[1] : b.shape.rbegin()[0]), biasExpand(c ? std::make_optional(buildBias(m, n, a, b, *c)) : std::nullopt), - broadcasterOrBatch(buildBroadcasterOrBatch(slice(a.shape.data(), a.shape.size() - 2), slice(b.shape.data(), b.shape.size() - 2))) { + broadcaster({slice(a.shape.data(), a.shape.size() - 2), + slice(b.shape.data(), b.shape.size() - 2)}) { auto kB = transB ? 
b.shape.rbegin()[0] : b.shape.rbegin()[1]; ASSERT(k == kB, "MatMul: input shape not matched."); } diff --git a/src/04kernel/src/attributes/mat_mul_integer_info.cc b/src/04kernel/src/attributes/mat_mul_integer_info.cc new file mode 100644 index 000000000..862565c9a --- /dev/null +++ b/src/04kernel/src/attributes/mat_mul_integer_info.cc @@ -0,0 +1,26 @@ +#include "kernel/attributes/mat_mul_integer_info.h" + +namespace refactor::kernel { + +#define A (inputs[0].get().shape) +#define B (inputs[1].get().shape) + + MatMulIntegerInfo::Input::Input(TensorRefs const &inputs, size_t i) noexcept + : signed_(inputs[i].get().dataType == DataType::I8), + withZeroPoint(false) { + if (inputs.size() > i + 2) { + auto const &t = inputs[i + 2].get(); + withZeroPoint = t.rank() != 0 || !t.data || t.data->get() != 0; + } + } + + MatMulIntegerInfo::MatMulIntegerInfo(TensorRefs const &inputs) noexcept + : a(inputs, 0), + b(inputs, 1), + m(A.rbegin()[1]), + k(A.rbegin()[0]), + n(B.rbegin()[0]), + broadcaster({slice(A.data(), A.size() - 2), + slice(B.data(), B.size() - 2)}) {} + +}// namespace refactor::kernel diff --git a/src/04kernel/src/collectors/mat_mul.cc b/src/04kernel/src/collectors/mat_mul.cc index 08d10efbf..7581200cd 100644 --- a/src/04kernel/src/collectors/mat_mul.cc +++ b/src/04kernel/src/collectors/mat_mul.cc @@ -1,7 +1,7 @@ #include "kernel/collectors/mat_mul.h" #include "../kernels/mat_mul/cpu_kernel.hh" #include "../kernels/mat_mul/cublas_kernel.hh" -#include "kernel/attributes/matmul_info.h" +#include "kernel/attributes/mat_mul_info.h" namespace refactor::kernel { #define REGISTER(T) \ diff --git a/src/04kernel/src/collectors/mat_mul_integer.cc b/src/04kernel/src/collectors/mat_mul_integer.cc index 2abaded6f..3575c6b10 100644 --- a/src/04kernel/src/collectors/mat_mul_integer.cc +++ b/src/04kernel/src/collectors/mat_mul_integer.cc @@ -1,11 +1,19 @@ #include "kernel/collectors/mat_mul_integer.h" +#include "../../src/kernels/mat_mul_integer/cpu_kernel.hh" +#include "kernel/attributes/mat_mul_integer_info.h" namespace refactor::kernel { + std::vector MatMulIntegerCollector::filter(TensorRefs inputs, TensorRefs outputs) const { + MatMulIntegerInfo info(inputs); + std::vector ans; switch (_target) { case decltype(_target)::Cpu: + if (auto ptr = MatMulIntegerCPU::build(info); ptr) { + ans.emplace_back(std::move(ptr)); + } break; case decltype(_target)::Nvidia: break; diff --git a/src/04kernel/src/kernels/mat_mul/cpu_kernel.cc b/src/04kernel/src/kernels/mat_mul/cpu_kernel.cc index 7bbab466f..0d1ca1c4b 100644 --- a/src/04kernel/src/kernels/mat_mul/cpu_kernel.cc +++ b/src/04kernel/src/kernels/mat_mul/cpu_kernel.cc @@ -1,5 +1,6 @@ #include "cpu_kernel.hh" #include "../expand/cpu_kernel.hh" +#include "../mat_mul_common/cpu_template.hpp" namespace refactor::kernel { using K = MatMulCPU; @@ -8,7 +9,7 @@ namespace refactor::kernel { K::MatMulCPU(decltype(info) info_) noexcept : Kernel(), info(std::move(info_)) {} - auto K::build(MatMulInfo info) noexcept -> KernelBox { + auto K::build(decltype(info) info) noexcept -> KernelBox { return info.dataType.isCpuNumberic() ? std::make_unique(std::move(info)) : nullptr; @@ -24,31 +25,6 @@ namespace refactor::kernel { return "Performing MatMul using CPU"; } - template - struct MatMulCPUMetaData { - size_t M, K, N; - size_t strideA0, strideA1, strideB0, strideB1; - T alpha, beta; - - /* - * 2D matrix multiplication: Y = a * A @ B + b * Y - * Assume bias C has been broadcast to Y already. Beta should be 0 in the absence of bias. 
- */ - void matrixMultiply(T const *A, T const *B, T *Y) const noexcept { - // #pragma omp parallel for - for (size_t i = 0; i < M; i++) { - for (size_t j = 0; j < N; j++) { - T sum = 0; - // #pragma omp simd reduction(+ : sum) - for (size_t k = 0; k < K; k++) { - sum += A[i * strideA0 + k * strideA1] * B[k * strideB0 + j * strideB1]; - } - Y[i * N + j] = beta * Y[i * N + j] + alpha * sum; - } - } - } - }; - template static auto lowerTyped(MatMulInfo const &info, Resources &res) noexcept -> RoutineWorkspace { MatMulCPUMetaData const md{ @@ -70,8 +46,8 @@ namespace refactor::kernel { ? std::make_optional(ExpandCpu(*info.biasExpand).lower(res).routine) : std::nullopt; - if (std::holds_alternative(info.broadcasterOrBatch)) { - return [broadcaster = std::get(info.broadcasterOrBatch), + if (info.broadcaster.needBroadcast()) { + return [broadcaster = info.broadcaster, stepY, stepA, stepB, md, biasEx]// (runtime::Resources & res, void *workspace, void const *const *inputs, void *const *outputs) { @@ -87,7 +63,7 @@ namespace refactor::kernel { } }; } else { - return [batch = std::get(info.broadcasterOrBatch), + return [batch = info.broadcaster.outputsCount, stepY, stepA, stepB, md, biasEx]// (runtime::Resources & res, void *workspace, void const *const *inputs, void *const *outputs) { diff --git a/src/04kernel/src/kernels/mat_mul/cpu_kernel.hh b/src/04kernel/src/kernels/mat_mul/cpu_kernel.hh index ca2ce140b..950cb7e24 100644 --- a/src/04kernel/src/kernels/mat_mul/cpu_kernel.hh +++ b/src/04kernel/src/kernels/mat_mul/cpu_kernel.hh @@ -1,18 +1,17 @@ #ifndef KERNEL_MATMUL_CPU_KERNEL_HH #define KERNEL_MATMUL_CPU_KERNEL_HH -#include "kernel/attributes/matmul_info.h" +#include "kernel/attributes/mat_mul_info.h" #include "kernel/kernel.h" -#include "kernel/tensor.h" namespace refactor::kernel { struct MatMulCPU final : public Kernel { MatMulInfo info; - explicit MatMulCPU(MatMulInfo) noexcept; + explicit MatMulCPU(decltype(info)) noexcept; - static KernelBox build(MatMulInfo) noexcept; + static KernelBox build(decltype(info)) noexcept; static size_t typeId() noexcept; size_t kernelTypeId() const noexcept final; diff --git a/src/04kernel/src/kernels/mat_mul/cublas_kernel.cc b/src/04kernel/src/kernels/mat_mul/cublas_kernel.cc index a21f1dfe0..5f97d56ba 100644 --- a/src/04kernel/src/kernels/mat_mul/cublas_kernel.cc +++ b/src/04kernel/src/kernels/mat_mul/cublas_kernel.cc @@ -7,7 +7,7 @@ namespace refactor::kernel { K::MatMulCublas(decltype(info) info_) noexcept : Kernel(), info(std::move(info_)) {} - auto K::build(MatMulInfo info) noexcept -> KernelBox { + auto K::build(decltype(info) info) noexcept -> KernelBox { #ifndef USE_CUDA return nullptr; #endif diff --git a/src/04kernel/src/kernels/mat_mul/cublas_kernel.cu b/src/04kernel/src/kernels/mat_mul/cublas_kernel.cu index 8e32bad24..7476a6a1b 100644 --- a/src/04kernel/src/kernels/mat_mul/cublas_kernel.cu +++ b/src/04kernel/src/kernels/mat_mul/cublas_kernel.cu @@ -28,34 +28,8 @@ namespace refactor::kernel { ? 
std::make_optional(ExpandCuda(*info.biasExpand).lower(res).routine) : std::nullopt; // clang-format on - if (std::holds_alternative(info.broadcasterOrBatch)) { - return [batch = std::get(info.broadcasterOrBatch), - cudaDataType, - alpha, beta, tA, tB, - m, n, k, - strideA, strideB, - lda, ldb, - biasEx]// - (Resources & res, void *workspace, void const *const *inputs, void *const *outputs) { - // Call expand kernel to broadcast bias if bias is used - if (biasEx) { (*biasEx)(res, workspace, inputs + 2, outputs); } - - auto a = reinterpret_cast(inputs[0]); - auto b = reinterpret_cast(inputs[1]); - auto y = reinterpret_cast(outputs[0]); - cublasGemmStridedBatchedEx( - res.fetchOrStore()->handle, - tB, tA, - n, m, k, - &alpha, - b, cudaDataType, ldb, strideB, - a, cudaDataType, lda, strideA, - &beta, y, cudaDataType, - n, m * n, batch, cudaDataType, - CUBLAS_GEMM_DEFAULT); - }; - } else {// if use boradcaster - return [broadcaster = std::get(info.broadcasterOrBatch), + if (info.broadcaster.needBroadcast()) { + return [broadcaster = info.broadcaster, cudaDataType, alpha, beta, tA, tB, m, n, k, @@ -83,6 +57,33 @@ namespace refactor::kernel { CUBLAS_GEMM_DEFAULT); } }; + + } else { + return [batch = info.broadcaster.outputsCount, + cudaDataType, + alpha, beta, tA, tB, + m, n, k, + strideA, strideB, + lda, ldb, + biasEx]// + (Resources & res, void *workspace, void const *const *inputs, void *const *outputs) { + // Call expand kernel to broadcast bias if bias is used + if (biasEx) { (*biasEx)(res, workspace, inputs + 2, outputs); } + + auto a = reinterpret_cast(inputs[0]); + auto b = reinterpret_cast(inputs[1]); + auto y = reinterpret_cast(outputs[0]); + cublasGemmStridedBatchedEx( + res.fetchOrStore()->handle, + tB, tA, + n, m, k, + &alpha, + b, cudaDataType, ldb, strideB, + a, cudaDataType, lda, strideA, + &beta, y, cudaDataType, + n, m * n, batch, cudaDataType, + CUBLAS_GEMM_DEFAULT); + }; } } diff --git a/src/04kernel/src/kernels/mat_mul/cublas_kernel.hh b/src/04kernel/src/kernels/mat_mul/cublas_kernel.hh index 01a511527..9d49aa805 100644 --- a/src/04kernel/src/kernels/mat_mul/cublas_kernel.hh +++ b/src/04kernel/src/kernels/mat_mul/cublas_kernel.hh @@ -1,7 +1,7 @@ #ifndef KERNEL_MATMUL_CUBLAS_KERNEL_HH #define KERNEL_MATMUL_CUBLAS_KERNEL_HH -#include "kernel/attributes/matmul_info.h" +#include "kernel/attributes/mat_mul_info.h" #include "kernel/kernel.h" namespace refactor::kernel { @@ -9,9 +9,9 @@ namespace refactor::kernel { struct MatMulCublas final : public Kernel { MatMulInfo info; - explicit MatMulCublas(MatMulInfo) noexcept; + explicit MatMulCublas(decltype(info)) noexcept; - static KernelBox build(MatMulInfo) noexcept; + static KernelBox build(decltype(info)) noexcept; static size_t typeId() noexcept; size_t kernelTypeId() const noexcept final; diff --git a/src/04kernel/src/kernels/mat_mul_common/cpu_template.hpp b/src/04kernel/src/kernels/mat_mul_common/cpu_template.hpp new file mode 100644 index 000000000..eee00ee1f --- /dev/null +++ b/src/04kernel/src/kernels/mat_mul_common/cpu_template.hpp @@ -0,0 +1,33 @@ +#ifndef KERNEL_MATMUL_COMMON_CPU_TEMPLATE_HPP +#define KERNEL_MATMUL_COMMON_CPU_TEMPLATE_HPP + +namespace refactor::kernel { + + template + struct MatMulCPUMetaData { + size_t M, K, N; + size_t strideA0, strideA1, strideB0, strideB1; + T alpha, beta; + + /* + * 2D matrix multiplication: Y = a * A @ B + b * Y + * Assume bias C has been broadcast to Y already. Beta should be 0 in the absence of bias. 
+ */ + void matrixMultiply(T const *A, T const *B, T *Y) const noexcept { + // #pragma omp parallel for + for (size_t i = 0; i < M; i++) { + for (size_t j = 0; j < N; j++) { + T sum = 0; + // #pragma omp simd reduction(+ : sum) + for (size_t k = 0; k < K; k++) { + sum += A[i * strideA0 + k * strideA1] * B[k * strideB0 + j * strideB1]; + } + Y[i * N + j] = beta * Y[i * N + j] + alpha * sum; + } + } + } + }; + +}// namespace refactor::kernel + +#endif// KERNEL_MATMUL_COMMON_CPU_TEMPLATE_HPP diff --git a/src/04kernel/src/kernels/mat_mul_integer/cpu_kernel.cc b/src/04kernel/src/kernels/mat_mul_integer/cpu_kernel.cc new file mode 100644 index 000000000..3472ed749 --- /dev/null +++ b/src/04kernel/src/kernels/mat_mul_integer/cpu_kernel.cc @@ -0,0 +1,29 @@ +#include "cpu_kernel.hh" +#include "../mat_mul_common/cpu_template.hpp" + +namespace refactor::kernel { + using K = MatMulIntegerCPU; + using DT = DataType; + + K::MatMulIntegerCPU(decltype(info) info_) noexcept + : Kernel(), info(std::move(info_)) {} + + auto K::build(decltype(info) info) noexcept -> KernelBox { + return std::make_unique(std::move(info)); + } + + auto K::typeId() noexcept -> size_t { + static uint8_t ID = 1; + return reinterpret_cast(&ID); + } + + auto K::kernelTypeId() const noexcept -> size_t { return typeId(); } + auto K::description() const noexcept -> std::string_view { + return "Performing MatMulInteger using CPU"; + } + + auto K::lower(Resources &res) const -> RoutineWorkspace { + TODO(""); + }; + +}// namespace refactor::kernel diff --git a/src/04kernel/src/kernels/mat_mul_integer/cpu_kernel.hh b/src/04kernel/src/kernels/mat_mul_integer/cpu_kernel.hh new file mode 100644 index 000000000..5610d732d --- /dev/null +++ b/src/04kernel/src/kernels/mat_mul_integer/cpu_kernel.hh @@ -0,0 +1,25 @@ +#ifndef KERNEL_MATMUL_INTEGER_CPU_KERNEL_HH +#define KERNEL_MATMUL_INTEGER_CPU_KERNEL_HH + +#include "kernel/attributes/mat_mul_integer_info.h" +#include "kernel/kernel.h" + +namespace refactor::kernel { + + struct MatMulIntegerCPU final : public Kernel { + MatMulIntegerInfo info; + + explicit MatMulIntegerCPU(decltype(info)) noexcept; + + static KernelBox build(decltype(info)) noexcept; + static size_t typeId() noexcept; + + size_t kernelTypeId() const noexcept final; + std::string_view description() const noexcept final; + + RoutineWorkspace lower(Resources &) const final; + }; + +}// namespace refactor::kernel + +#endif// KERNEL_MATMUL_INTEGER_CPU_KERNEL_HH diff --git a/src/04kernel/src/kernels/simple_binary/cuda_kernel.cc b/src/04kernel/src/kernels/simple_binary/cuda_kernel.cc index 3c2351c29..58d5f677e 100644 --- a/src/04kernel/src/kernels/simple_binary/cuda_kernel.cc +++ b/src/04kernel/src/kernels/simple_binary/cuda_kernel.cc @@ -154,7 +154,7 @@ extern "C" __global__ void kernel( auto op_ = op(opType, dataType); auto params = cuda::ThreadsDistributer()(broadcaster.outputsCount); - if (broadcaster.strides.empty()) { + if (!broadcaster.needBroadcast()) { auto name = fmt::format("binary{}", postfix); auto code = fmt::format(NO_BROADCAST, dt_, op_); return [params, h = nvrtc::Handler::compile(name.c_str(), code.c_str(), "kernel")]// From 6462a270c5343ad7fa65b5ab08ed580e31054768 Mon Sep 17 00:00:00 2001 From: YdrMaster Date: Fri, 15 Dec 2023 18:32:57 +0800 Subject: [PATCH 06/18] =?UTF-8?q?feat(kernel):=20=E5=AE=9E=E7=8E=B0=20=20M?= =?UTF-8?q?atMulInteger=20cpu=20kernel?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: YdrMaster --- .../kernel/attributes/mat_mul_integer_info.h | 
3 +- .../src/attributes/mat_mul_integer_info.cc | 12 ++- .../src/kernels/mat_mul/cpu_kernel.cc | 74 +++++++---------- .../kernels/mat_mul_common/cpu_template.hpp | 18 +++-- .../src/kernels/mat_mul_integer/cpu_kernel.cc | 81 ++++++++++++++++++- .../src/kernels/mat_mul_integer/cpu_kernel.hh | 2 +- 6 files changed, 132 insertions(+), 58 deletions(-) diff --git a/src/04kernel/include/kernel/attributes/mat_mul_integer_info.h b/src/04kernel/include/kernel/attributes/mat_mul_integer_info.h index 83e8fcd22..ad091b0e9 100644 --- a/src/04kernel/include/kernel/attributes/mat_mul_integer_info.h +++ b/src/04kernel/include/kernel/attributes/mat_mul_integer_info.h @@ -7,8 +7,9 @@ namespace refactor::kernel { struct MatMulIntegerInfo { struct Input { - bool signed_; bool withZeroPoint; + bool signed_; + dim_t groupCount, groupSize; Input(TensorRefs const &, size_t i) noexcept; }; diff --git a/src/04kernel/src/attributes/mat_mul_integer_info.cc b/src/04kernel/src/attributes/mat_mul_integer_info.cc index 862565c9a..36ec18ff0 100644 --- a/src/04kernel/src/attributes/mat_mul_integer_info.cc +++ b/src/04kernel/src/attributes/mat_mul_integer_info.cc @@ -6,11 +6,17 @@ namespace refactor::kernel { #define B (inputs[1].get().shape) MatMulIntegerInfo::Input::Input(TensorRefs const &inputs, size_t i) noexcept - : signed_(inputs[i].get().dataType == DataType::I8), - withZeroPoint(false) { + : withZeroPoint(false), + signed_(true), + groupCount(1), + groupSize(1) { if (inputs.size() > i + 2) { auto const &t = inputs[i + 2].get(); - withZeroPoint = t.rank() != 0 || !t.data || t.data->get() != 0; + if (withZeroPoint = t.rank() != 0 || !t.data || t.data->get() != 0) { + signed_ = t.dataType == DataType::I8; + groupCount = t.elementsSize(); + groupSize = inputs[i].get().elementsSize() / groupCount; + } } } diff --git a/src/04kernel/src/kernels/mat_mul/cpu_kernel.cc b/src/04kernel/src/kernels/mat_mul/cpu_kernel.cc index 0d1ca1c4b..bcd184c3f 100644 --- a/src/04kernel/src/kernels/mat_mul/cpu_kernel.cc +++ b/src/04kernel/src/kernels/mat_mul/cpu_kernel.cc @@ -27,56 +27,44 @@ namespace refactor::kernel { template static auto lowerTyped(MatMulInfo const &info, Resources &res) noexcept -> RoutineWorkspace { - MatMulCPUMetaData const md{ - .M = info.m, - .K = info.k, - .N = info.n, - .strideA0 = info.transA ? 1 : info.k, - .strideA1 = info.transA ? info.m : 1, - .strideB0 = info.transB ? 1 : info.n, - .strideB1 = info.transB ? info.k : 1, - .alpha = static_cast(info.alpha), - .beta = static_cast(info.biasExpand ? info.beta : 0.0f), - }; - auto stepY = info.m * info.n, - stepA = info.m * info.k, - stepB = info.k * info.n; auto biasEx = info.biasExpand ? 
std::make_optional(ExpandCpu(*info.biasExpand).lower(res).routine) : std::nullopt; - if (info.broadcaster.needBroadcast()) { - return [broadcaster = info.broadcaster, - stepY, stepA, stepB, - md, biasEx]// - (runtime::Resources & res, void *workspace, void const *const *inputs, void *const *outputs) { - if (biasEx) { (*biasEx)(res, workspace, inputs + 2, outputs); } + return [info = info, biasEx](runtime::Resources &res, void *, void const *const *inputs, void *const *outputs) { + if (biasEx) { (*biasEx)(res, nullptr, inputs + 2, outputs); } - auto a = reinterpret_cast(inputs[0]); - auto b = reinterpret_cast(inputs[1]); - auto y = reinterpret_cast(outputs[0]); - dim_t offset[2]; - for (auto i : range0_(broadcaster.outputsCount)) { - broadcaster.locate(i, offset); - md.matrixMultiply(a + stepA * offset[0], b + stepB * offset[1], y + stepY * i); - } - }; - } else { - return [batch = info.broadcaster.outputsCount, - stepY, stepA, stepB, - md, biasEx]// - (runtime::Resources & res, void *workspace, void const *const *inputs, void *const *outputs) { - if (biasEx) { (*biasEx)(res, workspace, inputs + 2, outputs); } + MatMulCPUMetaData const md{ + .M = info.m, + .K = info.k, + .N = info.n, + .strideA0 = info.transA ? 1 : info.k, + .strideA1 = info.transA ? info.m : 1, + .strideB0 = info.transB ? 1 : info.n, + .strideB1 = info.transB ? info.k : 1, + .alpha = static_cast(info.alpha), + .beta = static_cast(info.biasExpand ? info.beta : 0.0f), + }; + auto const stepY = info.m * info.n, + stepA = info.m * info.k, + stepB = info.k * info.n; - auto a = reinterpret_cast(inputs[0]); - auto b = reinterpret_cast(inputs[1]); - auto y = reinterpret_cast(outputs[0]); - for (auto i : range0_(batch)) { - md.matrixMultiply(a + stepA * i, b + stepB * i, y + stepY * i); - } - }; - } + auto a = reinterpret_cast(inputs[0]); + auto b = reinterpret_cast(inputs[1]); + auto y = reinterpret_cast(outputs[0]); + if (info.broadcaster.needBroadcast()) { + dim_t offset[2]; + for (auto i : range0_(info.broadcaster.outputsCount)) { + info.broadcaster.locate(i, offset); + md.matrixMultiply(a + stepA * offset[0], b + stepB * offset[1], y + stepY * i); + } + } else { + for (auto i : range0_(info.broadcaster.outputsCount)) { + md.matrixMultiply(a + stepA * i, b + stepB * i, y + stepY * i); + } + } + }; } auto K::lower(Resources &res) const noexcept -> RoutineWorkspace { diff --git a/src/04kernel/src/kernels/mat_mul_common/cpu_template.hpp b/src/04kernel/src/kernels/mat_mul_common/cpu_template.hpp index eee00ee1f..766279b3f 100644 --- a/src/04kernel/src/kernels/mat_mul_common/cpu_template.hpp +++ b/src/04kernel/src/kernels/mat_mul_common/cpu_template.hpp @@ -3,26 +3,28 @@ namespace refactor::kernel { - template + template struct MatMulCPUMetaData { - size_t M, K, N; - size_t strideA0, strideA1, strideB0, strideB1; - T alpha, beta; + size_t M, K, N, + strideA0, strideA1, + strideB0, strideB1; + TI alpha; + TO beta; /* * 2D matrix multiplication: Y = a * A @ B + b * Y * Assume bias C has been broadcast to Y already. Beta should be 0 in the absence of bias. 
*/ - void matrixMultiply(T const *A, T const *B, T *Y) const noexcept { + void matrixMultiply(TI const *a, TI const *b, TO *y) const noexcept { // #pragma omp parallel for for (size_t i = 0; i < M; i++) { for (size_t j = 0; j < N; j++) { - T sum = 0; + TO sum = 0; // #pragma omp simd reduction(+ : sum) for (size_t k = 0; k < K; k++) { - sum += A[i * strideA0 + k * strideA1] * B[k * strideB0 + j * strideB1]; + sum += static_cast(a[i * strideA0 + k * strideA1] * b[k * strideB0 + j * strideB1]); } - Y[i * N + j] = beta * Y[i * N + j] + alpha * sum; + y[i * N + j] = beta * y[i * N + j] + alpha * sum; } } } diff --git a/src/04kernel/src/kernels/mat_mul_integer/cpu_kernel.cc b/src/04kernel/src/kernels/mat_mul_integer/cpu_kernel.cc index 3472ed749..0c014ff42 100644 --- a/src/04kernel/src/kernels/mat_mul_integer/cpu_kernel.cc +++ b/src/04kernel/src/kernels/mat_mul_integer/cpu_kernel.cc @@ -22,8 +22,85 @@ namespace refactor::kernel { return "Performing MatMulInteger using CPU"; } - auto K::lower(Resources &res) const -> RoutineWorkspace { - TODO(""); + template static int8_t sub(T, T); + template<> int8_t sub(int8_t a, int8_t b) { return a - b; } + template<> int8_t sub(uint8_t a, uint8_t b) { return static_cast(static_cast(a) - static_cast(b)); } + + template + static void applyZeroPoint(MatMulIntegerInfo::Input meta, int8_t *dst, void const *src_, void const *zp_) { + auto src = reinterpret_cast(src_), + zp = reinterpret_cast(zp_); + for (auto i : range0_(meta.groupCount)) { + for (auto j : range0_(meta.groupSize)) { + dst[meta.groupSize * i + j] = sub(src[meta.groupSize * i + j], zp[i]); + } + } + } + + auto K::lower(Resources &res) const noexcept -> RoutineWorkspace { + using namespace runtime; + + size_t workspace = 0; + if (info.a.withZeroPoint) { + workspace += info.a.groupCount * info.a.groupSize; + } + if (info.b.withZeroPoint) { + workspace += info.b.groupCount * info.b.groupSize; + } + + auto routine = [info = info](Resources &, void *workspace, void const *const *inputs, void *const *outputs) { + auto workspacePtr = reinterpret_cast(workspace); + auto a = reinterpret_cast(inputs[0]), + b = reinterpret_cast(inputs[1]); + auto y = reinterpret_cast(outputs[0]); + + if (auto meta = info.a; meta.withZeroPoint) { + if (meta.signed_) { + applyZeroPoint(meta, workspacePtr, a, inputs[2]); + } else { + applyZeroPoint(meta, workspacePtr, a, inputs[2]); + } + a = workspacePtr; + workspacePtr += meta.groupCount * meta.groupSize; + } + if (auto meta = info.b; meta.withZeroPoint) { + if (meta.signed_) { + applyZeroPoint(meta, workspacePtr, b, inputs[3]); + } else { + applyZeroPoint(meta, workspacePtr, b, inputs[3]); + } + b = workspacePtr; + } + + MatMulCPUMetaData const md{ + .M = info.m, + .K = info.k, + .N = info.n, + .strideA0 = info.k, + .strideA1 = 1, + .strideB0 = info.n, + .strideB1 = 1, + .alpha = 1, + .beta = 0, + }; + auto const stepY = info.m * info.n, + stepA = info.m * info.k, + stepB = info.k * info.n; + + if (info.broadcaster.needBroadcast()) { + dim_t offset[2]; + for (auto i : range0_(info.broadcaster.outputsCount)) { + info.broadcaster.locate(i, offset); + md.matrixMultiply(a + stepA * offset[0], b + stepB * offset[1], y + stepY * i); + } + } else { + for (auto i : range0_(info.broadcaster.outputsCount)) { + md.matrixMultiply(a + stepA * i, b + stepB * i, y + stepY * i); + } + } + }; + + return {std::move(routine), workspace}; }; }// namespace refactor::kernel diff --git a/src/04kernel/src/kernels/mat_mul_integer/cpu_kernel.hh b/src/04kernel/src/kernels/mat_mul_integer/cpu_kernel.hh 
index 5610d732d..2cb8cb312 100644 --- a/src/04kernel/src/kernels/mat_mul_integer/cpu_kernel.hh +++ b/src/04kernel/src/kernels/mat_mul_integer/cpu_kernel.hh @@ -17,7 +17,7 @@ namespace refactor::kernel { size_t kernelTypeId() const noexcept final; std::string_view description() const noexcept final; - RoutineWorkspace lower(Resources &) const final; + RoutineWorkspace lower(Resources &) const noexcept final; }; }// namespace refactor::kernel From 916fd3d3c66f8f2273d60fc6c7b650d43d9aac4b Mon Sep 17 00:00:00 2001 From: YdrMaster Date: Mon, 18 Dec 2023 09:19:54 +0800 Subject: [PATCH 07/18] =?UTF-8?q?test(kernel):=20=E6=B5=8B=E8=AF=95=20MatM?= =?UTF-8?q?ulInteger=20cpu=20kernel?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: YdrMaster --- .../src/collectors/mat_mul_integer.cc | 2 +- .../src/kernels/mat_mul_integer/cpu_kernel.cc | 4 +- .../src/kernels/mat_mul_integer/cpu_kernel.hh | 4 +- .../test/kernels/mat_mul/test_cpu.cpp | 74 +++++++++++++++ .../test_cublas.cpp} | 0 .../mat_mul_integer/test_cpu_kernel.cpp | 31 ++++++ .../test/kernels/matmul/test_matmul_cpu.cpp | 95 ------------------- 7 files changed, 110 insertions(+), 100 deletions(-) create mode 100644 src/04kernel/test/kernels/mat_mul/test_cpu.cpp rename src/04kernel/test/kernels/{matmul/test_matmul_cublas.cpp => mat_mul/test_cublas.cpp} (100%) create mode 100644 src/04kernel/test/kernels/mat_mul_integer/test_cpu_kernel.cpp delete mode 100644 src/04kernel/test/kernels/matmul/test_matmul_cpu.cpp diff --git a/src/04kernel/src/collectors/mat_mul_integer.cc b/src/04kernel/src/collectors/mat_mul_integer.cc index 3575c6b10..123f0beee 100644 --- a/src/04kernel/src/collectors/mat_mul_integer.cc +++ b/src/04kernel/src/collectors/mat_mul_integer.cc @@ -11,7 +11,7 @@ namespace refactor::kernel { std::vector ans; switch (_target) { case decltype(_target)::Cpu: - if (auto ptr = MatMulIntegerCPU::build(info); ptr) { + if (auto ptr = MatMulIntegerCpu::build(info); ptr) { ans.emplace_back(std::move(ptr)); } break; diff --git a/src/04kernel/src/kernels/mat_mul_integer/cpu_kernel.cc b/src/04kernel/src/kernels/mat_mul_integer/cpu_kernel.cc index 0c014ff42..0cf31a956 100644 --- a/src/04kernel/src/kernels/mat_mul_integer/cpu_kernel.cc +++ b/src/04kernel/src/kernels/mat_mul_integer/cpu_kernel.cc @@ -2,10 +2,10 @@ #include "../mat_mul_common/cpu_template.hpp" namespace refactor::kernel { - using K = MatMulIntegerCPU; + using K = MatMulIntegerCpu; using DT = DataType; - K::MatMulIntegerCPU(decltype(info) info_) noexcept + K::MatMulIntegerCpu(decltype(info) info_) noexcept : Kernel(), info(std::move(info_)) {} auto K::build(decltype(info) info) noexcept -> KernelBox { diff --git a/src/04kernel/src/kernels/mat_mul_integer/cpu_kernel.hh b/src/04kernel/src/kernels/mat_mul_integer/cpu_kernel.hh index 2cb8cb312..53e6f501c 100644 --- a/src/04kernel/src/kernels/mat_mul_integer/cpu_kernel.hh +++ b/src/04kernel/src/kernels/mat_mul_integer/cpu_kernel.hh @@ -6,10 +6,10 @@ namespace refactor::kernel { - struct MatMulIntegerCPU final : public Kernel { + struct MatMulIntegerCpu final : public Kernel { MatMulIntegerInfo info; - explicit MatMulIntegerCPU(decltype(info)) noexcept; + explicit MatMulIntegerCpu(decltype(info)) noexcept; static KernelBox build(decltype(info)) noexcept; static size_t typeId() noexcept; diff --git a/src/04kernel/test/kernels/mat_mul/test_cpu.cpp b/src/04kernel/test/kernels/mat_mul/test_cpu.cpp new file mode 100644 index 000000000..f6e8f3147 --- /dev/null +++ 
b/src/04kernel/test/kernels/mat_mul/test_cpu.cpp @@ -0,0 +1,74 @@ +#include "../src/kernels/mat_mul/cpu_kernel.hh" +#include + +using namespace refactor; +using namespace kernel; + +template +static void check( + Resources &&res, + Routine &&routine, + std::vector ans, + std::vector const &a, + std::vector const &b, + std::vector const &c) { + std::vector result(ans.size()); + // inference + void const *inputs[]{a.data(), b.data(), c.data()}; + void *outputs[]{result.data()}; + routine(res, nullptr, inputs, outputs); + // check + EXPECT_EQ(result, ans); +} + +TEST(kernel, MatMulCPU_WithBias) { + // build routine + auto A = Tensor::share(DataType::F32, Shape{1, 2, 2}); + auto B = Tensor::share(DataType::F32, Shape{2, 2}); + auto C = Tensor::share(DataType::F32, Shape{}); + auto Y = Tensor::share(DataType::F32, Shape{2, 2}); + auto kernel = MatMulCPU::build(MatMulInfo(*A, *B, *C, false, false, 1, 1)); + ASSERT_TRUE(kernel); + auto res = runtime::Resources(); + check(std::move(res), kernel->lower(res).routine, + {2, 4, 1, 1.25}, + {1.0, 2.0, 0.0, 0.5}, + {1.0, 2.0, 0.0, 0.5}, + {1.0}); +} + +TEST(kernel, MatMulCPU_UINT16NoBias) { + // build routine + auto A = Tensor::share(DataType::U16, Shape{2, 2}); + auto B = Tensor::share(DataType::U16, Shape{2, 2}); + auto Y = Tensor::share(DataType::U16, Shape{2, 2}); + auto kernel = MatMulCPU::build(MatMulInfo(*A, *B, std::nullopt, false, false, 1, 1)); + ASSERT_TRUE(kernel); + auto res = runtime::Resources(); + check(std::move(res), kernel->lower(res).routine, + {7, 6, 2, 3}, + {3, 2, 0, 1}, + {1, 0, 2, 3}, + {}); +} + +TEST(kernel, MatMulCPU_Broadcast) { + // build routine + auto A = Tensor::share(DataType::F32, Shape{2, 1, 2, 2}); + auto B = Tensor::share(DataType::F32, Shape{1, 2, 2, 2}); + auto C = Tensor::share(DataType::F32, Shape{2, 1}); + auto Y = Tensor::share(DataType::F32, Shape{2, 2, 2, 2}); + auto kernel = MatMulCPU::build(MatMulInfo(*A, *B, *C, false, false, 1, 1)); + ASSERT_TRUE(kernel); + auto res = runtime::Resources(); + check(std::move(res), kernel->lower(res).routine, + {2.0, 4.0, 0.0, 0.25, + 2.0, 3.0, 0.0, 0.5, + 2.0, 3.0, 0.0, 0.5, + 2.0, 1.0, 0.0, 1.0}, + {1.0, 2.0, 0.0, 0.5, + 1.0, 0.0, 0.0, 1.0}, + {1.0, 2.0, 0.0, 0.5, + 1.0, 0.0, 0.0, 1.0}, + {1.0, 0.0}); +} diff --git a/src/04kernel/test/kernels/matmul/test_matmul_cublas.cpp b/src/04kernel/test/kernels/mat_mul/test_cublas.cpp similarity index 100% rename from src/04kernel/test/kernels/matmul/test_matmul_cublas.cpp rename to src/04kernel/test/kernels/mat_mul/test_cublas.cpp diff --git a/src/04kernel/test/kernels/mat_mul_integer/test_cpu_kernel.cpp b/src/04kernel/test/kernels/mat_mul_integer/test_cpu_kernel.cpp new file mode 100644 index 000000000..cbf4d3514 --- /dev/null +++ b/src/04kernel/test/kernels/mat_mul_integer/test_cpu_kernel.cpp @@ -0,0 +1,31 @@ +#include "../src/kernels/mat_mul_integer/cpu_kernel.hh" +#include + +using namespace refactor; +using namespace kernel; + +TEST(kernel, MatMulIntegerCpu) { + // build routine + auto A = Tensor::share(DataType::U8, Shape{2, 3}); + auto B = Tensor::share(DataType::U8, Shape{3, 1}); + auto Y = Tensor::share(DataType::I32, Shape{2, 1}); + auto kernel = MatMulIntegerCpu::build(MatMulIntegerInfo(TensorRefs{*A, *B})); + ASSERT_TRUE(kernel); + auto res = runtime::Resources(); + auto routine = kernel->lower(res).routine; + // put input data + std::vector + dataA{1, 2, 3, 4, 5, 6}, + dataB{1, 2, 3}; + std::vector + result(Y->elementsSize()), + ans{14, 32}; + // inference + { + void const *inputs[]{dataA.data(), dataB.data()}; + void 
*outputs[]{result.data()}; + routine(res, nullptr, inputs, outputs); + } + // check + EXPECT_EQ(result, ans); +} diff --git a/src/04kernel/test/kernels/matmul/test_matmul_cpu.cpp b/src/04kernel/test/kernels/matmul/test_matmul_cpu.cpp deleted file mode 100644 index 0228d1fb7..000000000 --- a/src/04kernel/test/kernels/matmul/test_matmul_cpu.cpp +++ /dev/null @@ -1,95 +0,0 @@ -#include "../src/kernels/mat_mul/cpu_kernel.hh" -#include - -using namespace refactor; -using namespace kernel; - -TEST(kernel, MatMulCPU_WithBias) { - // build routine - auto A = Tensor::share(DataType::F32, Shape{1, 2, 2}); - auto B = Tensor::share(DataType::F32, Shape{2, 2}); - auto C = Tensor::share(DataType::F32, Shape{}); - auto Y = Tensor::share(DataType::F32, Shape{2, 2}); - auto kernel = MatMulCPU::build(MatMulInfo(*A, *B, *C, false, false, 1, 1)); - ASSERT_TRUE(kernel); - auto res = runtime::Resources(); - auto routine = kernel->lower(res).routine; - // put input data - std::vector - dataA{1.0, 2.0, 0.0, 0.5}, - dataB{1.0, 2.0, 0.0, 0.5}, - dataC{1.0}, - result(Y->elementsSize()), - ans{2, 4, 1, 1.25}; - // inference - { - void const *inputs[]{dataA.data(), dataB.data(), dataC.data()}; - void *outputs[]{result.data()}; - routine(res, nullptr, inputs, outputs); - } - // check - for (auto i : range0_(result.size())) { - EXPECT_FLOAT_EQ(result[i], ans[i]); - } -} - -TEST(kernel, MatMulCPU_UINT16NoBias) { - // build routine - auto A = Tensor::share(DataType::U16, Shape{2, 2}); - auto B = Tensor::share(DataType::U16, Shape{2, 2}); - auto Y = Tensor::share(DataType::U16, Shape{2, 2}); - auto kernel = MatMulCPU::build(MatMulInfo(*A, *B, std::nullopt, false, false, 1, 1)); - ASSERT_TRUE(kernel); - auto res = runtime::Resources(); - auto routine = kernel->lower(res).routine; - // put input data - std::vector - dataA{3, 2, 0, 1}, - dataB{1, 0, 2, 3}, - result(Y->elementsSize()), - ans{7, 6, 2, 3}; - // inference - { - void const *inputs[]{dataA.data(), dataB.data()}; - void *outputs[]{result.data()}; - routine(res, nullptr, inputs, outputs); - } - // check - for (auto i : range0_(result.size())) { - EXPECT_EQ(result[i], ans[i]); - } -} - -TEST(kernel, MatMulCPU_Broadcast) { - // build routine - auto A = Tensor::share(DataType::F32, Shape{2, 1, 2, 2}); - auto B = Tensor::share(DataType::F32, Shape{1, 2, 2, 2}); - auto C = Tensor::share(DataType::F32, Shape{2, 1}); - auto Y = Tensor::share(DataType::F32, Shape{2, 2, 2, 2}); - auto kernel = MatMulCPU::build(MatMulInfo(*A, *B, *C, false, false, 1, 1)); - ASSERT_TRUE(kernel); - auto res = runtime::Resources(); - auto routine = kernel->lower(res).routine; - // put input data - std::vector - dataA{1.0, 2.0, 0.0, 0.5, - 1.0, 0.0, 0.0, 1.0}, - dataB{1.0, 2.0, 0.0, 0.5, - 1.0, 0.0, 0.0, 1.0}, - dataC{1.0, 0.0}, - result(Y->elementsSize()), - ans{2.0, 4.0, 0.0, 0.25, - 2.0, 3.0, 0.0, 0.5, - 2.0, 3.0, 0.0, 0.5, - 2.0, 1.0, 0.0, 1.0}; - // inference - { - void const *inputs[]{dataA.data(), dataB.data(), dataC.data()}; - void *outputs[]{result.data()}; - routine(res, nullptr, inputs, outputs); - } - // check - for (auto i : range0_(result.size())) { - EXPECT_FLOAT_EQ(result[i], ans[i]); - } -} From c74b5885c2ed262dc535e26cdae366e3c5ebe01f Mon Sep 17 00:00:00 2001 From: YdrMaster Date: Mon, 18 Dec 2023 09:51:23 +0800 Subject: [PATCH 08/18] =?UTF-8?q?feat(kernel):=20=E5=AE=9E=E7=8E=B0=20MatM?= =?UTF-8?q?ulInteger=20cublas=20kernel?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: YdrMaster --- 
.../src/collectors/mat_mul_integer.cc | 4 + .../src/kernels/mat_mul_integer/cpu_kernel.cc | 2 +- .../kernels/mat_mul_integer/cublas_kernel.cc | 28 +++++ .../kernels/mat_mul_integer/cublas_kernel.cu | 115 ++++++++++++++++++ .../kernels/mat_mul_integer/cublas_kernel.hh | 26 ++++ 5 files changed, 174 insertions(+), 1 deletion(-) create mode 100644 src/04kernel/src/kernels/mat_mul_integer/cublas_kernel.cc create mode 100644 src/04kernel/src/kernels/mat_mul_integer/cublas_kernel.cu create mode 100644 src/04kernel/src/kernels/mat_mul_integer/cublas_kernel.hh diff --git a/src/04kernel/src/collectors/mat_mul_integer.cc b/src/04kernel/src/collectors/mat_mul_integer.cc index 123f0beee..c0124de91 100644 --- a/src/04kernel/src/collectors/mat_mul_integer.cc +++ b/src/04kernel/src/collectors/mat_mul_integer.cc @@ -1,5 +1,6 @@ #include "kernel/collectors/mat_mul_integer.h" #include "../../src/kernels/mat_mul_integer/cpu_kernel.hh" +#include "../../src/kernels/mat_mul_integer/cublas_kernel.hh" #include "kernel/attributes/mat_mul_integer_info.h" namespace refactor::kernel { @@ -16,6 +17,9 @@ namespace refactor::kernel { } break; case decltype(_target)::Nvidia: + if (auto ptr = MatMulIntegerCublas::build(info); ptr) { + ans.emplace_back(std::move(ptr)); + } break; default: UNREACHABLEX(void, "Unknown target"); diff --git a/src/04kernel/src/kernels/mat_mul_integer/cpu_kernel.cc b/src/04kernel/src/kernels/mat_mul_integer/cpu_kernel.cc index 0cf31a956..751fb7c0d 100644 --- a/src/04kernel/src/kernels/mat_mul_integer/cpu_kernel.cc +++ b/src/04kernel/src/kernels/mat_mul_integer/cpu_kernel.cc @@ -37,7 +37,7 @@ namespace refactor::kernel { } } - auto K::lower(Resources &res) const noexcept -> RoutineWorkspace { + auto K::lower(Resources &) const noexcept -> RoutineWorkspace { using namespace runtime; size_t workspace = 0; diff --git a/src/04kernel/src/kernels/mat_mul_integer/cublas_kernel.cc b/src/04kernel/src/kernels/mat_mul_integer/cublas_kernel.cc new file mode 100644 index 000000000..d1eeb607e --- /dev/null +++ b/src/04kernel/src/kernels/mat_mul_integer/cublas_kernel.cc @@ -0,0 +1,28 @@ +#include "cublas_kernel.hh" + +namespace refactor::kernel { + using K = MatMulIntegerCublas; + using DT = DataType; + + K::MatMulIntegerCublas(decltype(info) info_) noexcept + : Kernel(), info(std::move(info_)) {} + + auto K::build(decltype(info) info) noexcept -> KernelBox { +#ifndef USE_CUDA + return nullptr; +#endif + + return std::make_unique(std::move(info)); + } + + auto K::typeId() noexcept -> size_t { + static uint8_t ID = 1; + return reinterpret_cast(&ID); + } + + auto K::kernelTypeId() const noexcept -> size_t { return typeId(); } + auto K::description() const noexcept -> std::string_view { + return "Performing MatMulInteger using CUBLAS"; + } + +}// namespace refactor::kernel diff --git a/src/04kernel/src/kernels/mat_mul_integer/cublas_kernel.cu b/src/04kernel/src/kernels/mat_mul_integer/cublas_kernel.cu new file mode 100644 index 000000000..94d3aeb0d --- /dev/null +++ b/src/04kernel/src/kernels/mat_mul_integer/cublas_kernel.cu @@ -0,0 +1,115 @@ +#include "../../utilities/cuda/cublas_context.hh" +#include "cublas_kernel.hh" +#include +#include +#include + +namespace refactor::kernel { + using namespace runtime; + using namespace cublas; + + template __device__ __forceinline__ static int8_t sub(T, T); + template<> __device__ __forceinline__ int8_t sub(int8_t a, int8_t b) { return a - b; } + template<> __device__ __forceinline__ int8_t sub(uint8_t a, uint8_t b) { return static_cast(static_cast(a) - static_cast(b)); } + 
+ template + struct MatMulIntegerZPFunctor { + dim_t groupSize; + T const *src, *zp; + + __device__ int8_t operator()(size_t i) const noexcept { + return sub(src[i], zp[i / groupSize]); + } + }; + + template + static void applyZeroPoint(MatMulIntegerInfo::Input meta, int8_t *dst, void const *src, void const *zp) { + thrust::tabulate( + thrust::device, + dst, dst + meta.groupCount * meta.groupSize, + MatMulIntegerZPFunctor{ + .groupSize = meta.groupSize, + .src = reinterpret_cast(src), + .zp = reinterpret_cast(zp), + }); + } + + auto MatMulIntegerCublas::lower(Resources &res) const noexcept -> RoutineWorkspace { + + size_t workspace = 0; + if (info.a.withZeroPoint) { + workspace += info.a.groupCount * info.a.groupSize; + } + if (info.b.withZeroPoint) { + workspace += info.b.groupCount * info.b.groupSize; + } + + auto routine = [info = info](Resources &res, void *workspace, void const *const *inputs, void *const *outputs) { + auto workspacePtr = reinterpret_cast(workspace); + auto a = reinterpret_cast(inputs[0]), + b = reinterpret_cast(inputs[1]); + auto y = reinterpret_cast(outputs[0]); + + if (auto meta = info.a; meta.withZeroPoint) { + if (meta.signed_) { + applyZeroPoint(meta, workspacePtr, a, inputs[2]); + } else { + applyZeroPoint(meta, workspacePtr, a, inputs[2]); + } + a = workspacePtr; + workspacePtr += meta.groupCount * meta.groupSize; + } + if (auto meta = info.b; meta.withZeroPoint) { + if (meta.signed_) { + applyZeroPoint(meta, workspacePtr, b, inputs[3]); + } else { + applyZeroPoint(meta, workspacePtr, b, inputs[3]); + } + b = workspacePtr; + } + + int32_t alpha = 1, beta = 0; + auto m = info.m, + n = info.n, + k = info.k; + auto strideY = m * n, + strideA = m * k, + strideB = k * n; + auto lda = info.k, + ldb = info.n; + if (info.broadcaster.needBroadcast()) { + + uint32_t offset[2]; + for (auto i : range0_(info.broadcaster.outputsCount)) { + info.broadcaster.locate(i, offset); + cublasGemmEx( + res.fetchOrStore()->handle, + CUBLAS_OP_N, CUBLAS_OP_N, + n, m, k, + &alpha, + b + strideB * offset[1], CUDA_R_8I, ldb, + a + strideA * offset[0], CUDA_R_8I, lda, + &beta, y + strideY * i, CUDA_R_32I, + n, CUDA_R_32I, + CUBLAS_GEMM_DEFAULT); + } + } else { + + cublasGemmStridedBatchedEx( + res.fetchOrStore()->handle, + CUBLAS_OP_N, CUBLAS_OP_N, + n, m, k, + &alpha, + b, CUDA_R_8I, ldb, strideB, + a, CUDA_R_8I, lda, strideA, + &beta, y, CUDA_R_32I, + n, m * n, info.broadcaster.outputsCount, CUDA_R_32I, + CUBLAS_GEMM_DEFAULT); + } + }; + + res.fetchOrStore(); + return {std::move(routine), workspace}; + } + +}// namespace refactor::kernel diff --git a/src/04kernel/src/kernels/mat_mul_integer/cublas_kernel.hh b/src/04kernel/src/kernels/mat_mul_integer/cublas_kernel.hh new file mode 100644 index 000000000..d0d0400a7 --- /dev/null +++ b/src/04kernel/src/kernels/mat_mul_integer/cublas_kernel.hh @@ -0,0 +1,26 @@ +#ifndef KERNEL_MATMUL_CUBLAS_KERNEL_HH +#define KERNEL_MATMUL_CUBLAS_KERNEL_HH + +#include "kernel/attributes/mat_mul_integer_info.h" +#include "kernel/kernel.h" + +namespace refactor::kernel { + + struct MatMulIntegerCublas final : public Kernel { + MatMulIntegerInfo info; + + explicit MatMulIntegerCublas(decltype(info)) noexcept; + + static KernelBox build(decltype(info)) noexcept; + static size_t typeId() noexcept; + + size_t kernelTypeId() const noexcept final; + std::string_view description() const noexcept final; +#ifdef USE_CUDA + RoutineWorkspace lower(Resources &) const noexcept final; +#endif + }; + +}// namespace refactor::kernel + +#endif// KERNEL_MATMUL_CUBLAS_KERNEL_HH 
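For reference, both MatMulInteger kernels above follow the same scheme: a non-zero zero point is first subtracted into a signed int8 workspace (applyZeroPoint), and the matrix product is then accumulated in int32. The standalone C++ sketch below is not part of the patch series and uses illustrative names only; it is a plain reference implementation of that semantics, checked against the same data as the MatMulIntegerCpu test above (A = [[1,2,3],[4,5,6]], B = [[1],[2],[3]], expected {14, 32}).

// Reference sketch (illustrative, not from the patch): (a - za) * (b - zb), accumulated in int32.
#include <cstdint>
#include <cstdio>
#include <vector>

static std::vector<int32_t> matMulIntegerRef(std::vector<uint8_t> const &a, std::vector<uint8_t> const &b,
                                             int m, int k, int n, uint8_t za = 0, uint8_t zb = 0) {
    std::vector<int32_t> y(m * n, 0);
    for (int i = 0; i < m; ++i) {
        for (int j = 0; j < n; ++j) {
            int32_t sum = 0;// accumulate in int32 so the int8 products cannot overflow
            for (int kk = 0; kk < k; ++kk) {
                sum += (static_cast<int32_t>(a[i * k + kk]) - za) * (static_cast<int32_t>(b[kk * n + j]) - zb);
            }
            y[i * n + j] = sum;
        }
    }
    return y;
}

int main() {
    auto y = matMulIntegerRef({1, 2, 3, 4, 5, 6}, {1, 2, 3}, 2, 3, 1);
    std::printf("%d %d\n", y[0], y[1]);// expect: 14 32
    return 0;
}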
From 55cd88693da873f5886cd610e93548dbd6f86d41 Mon Sep 17 00:00:00 2001 From: YdrMaster Date: Mon, 18 Dec 2023 10:37:40 +0800 Subject: [PATCH 09/18] =?UTF-8?q?test(kernel):=20=E6=B5=8B=E8=AF=95=20MatM?= =?UTF-8?q?ulInteger=20cublas=20kernel?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: YdrMaster --- .../src/kernels/mat_mul/cublas_kernel.cu | 40 ++++++------ .../kernels/mat_mul_integer/cublas_kernel.cu | 28 ++++----- .../src/utilities/cuda/cublas_context.hh | 7 +++ ...cudnn_functions.cc => cudnn_functions.cpp} | 0 .../test/kernels/mat_mul/test_cublas.cpp | 12 +--- .../mat_mul_integer/test_cpu_kernel.cpp | 7 ++- .../mat_mul_integer/test_cublas_kernel.cpp | 61 +++++++++++++++++++ 7 files changed, 108 insertions(+), 47 deletions(-) rename src/04kernel/src/utilities/cuda/{cudnn_functions.cc => cudnn_functions.cpp} (100%) create mode 100644 src/04kernel/test/kernels/mat_mul_integer/test_cublas_kernel.cpp diff --git a/src/04kernel/src/kernels/mat_mul/cublas_kernel.cu b/src/04kernel/src/kernels/mat_mul/cublas_kernel.cu index 7476a6a1b..b0c2a7a0b 100644 --- a/src/04kernel/src/kernels/mat_mul/cublas_kernel.cu +++ b/src/04kernel/src/kernels/mat_mul/cublas_kernel.cu @@ -1,7 +1,6 @@ #include "../../utilities/cuda/cublas_context.hh" #include "../expand/cuda_kernel.hh" #include "cublas_kernel.hh" -#include namespace refactor::kernel { using namespace runtime; @@ -19,11 +18,11 @@ namespace refactor::kernel { auto m = info.m, n = info.n, k = info.k; - auto strideY = info.m * info.n, - strideA = info.m * info.k, - strideB = info.k * info.n; - auto lda = info.transA ? info.m : info.k, - ldb = info.transB ? info.k : info.n; + auto strideY = m * n, + strideA = m * k, + strideB = k * n; + auto lda = info.transA ? m : k, + ldb = info.transB ? k : n; auto biasEx = info.biasExpand ? 
std::make_optional(ExpandCuda(*info.biasExpand).lower(res).routine) : std::nullopt; @@ -39,22 +38,21 @@ namespace refactor::kernel { (Resources & res, void *workspace, void const *const *inputs, void *const *outputs) { if (biasEx) { (*biasEx)(res, workspace, inputs + 2, outputs); } + auto handle = res.fetchOrStore()->handle; auto a = reinterpret_cast(inputs[0]); auto b = reinterpret_cast(inputs[1]); auto y = reinterpret_cast(outputs[0]); uint32_t offset[2]; for (auto i : range0_(broadcaster.outputsCount)) { broadcaster.locate(i, offset); - cublasGemmEx( - res.fetchOrStore()->handle, - tB, tA, - n, m, k, + CUBLAS_ASSERT(cublasGemmEx( + handle, + tB, tA, n, m, k, &alpha, b + strideB * offset[1], cudaDataType, ldb, a + strideA * offset[0], cudaDataType, lda, - &beta, y + strideY * i, cudaDataType, - n, cudaDataType, - CUBLAS_GEMM_DEFAULT); + &beta, y + strideY * i, cudaDataType, n, + cudaDataType, CUBLAS_GEMM_DEFAULT)); } }; @@ -63,26 +61,26 @@ namespace refactor::kernel { cudaDataType, alpha, beta, tA, tB, m, n, k, - strideA, strideB, + strideY, strideA, strideB, lda, ldb, biasEx]// (Resources & res, void *workspace, void const *const *inputs, void *const *outputs) { // Call expand kernel to broadcast bias if bias is used if (biasEx) { (*biasEx)(res, workspace, inputs + 2, outputs); } + auto handle = res.fetchOrStore()->handle; auto a = reinterpret_cast(inputs[0]); auto b = reinterpret_cast(inputs[1]); auto y = reinterpret_cast(outputs[0]); - cublasGemmStridedBatchedEx( - res.fetchOrStore()->handle, - tB, tA, - n, m, k, + CUBLAS_ASSERT(cublasGemmStridedBatchedEx( + handle, + tB, tA, n, m, k, &alpha, b, cudaDataType, ldb, strideB, a, cudaDataType, lda, strideA, - &beta, y, cudaDataType, - n, m * n, batch, cudaDataType, - CUBLAS_GEMM_DEFAULT); + &beta, y, cudaDataType, n, + strideY, batch, + cudaDataType, CUBLAS_GEMM_DEFAULT)); }; } } diff --git a/src/04kernel/src/kernels/mat_mul_integer/cublas_kernel.cu b/src/04kernel/src/kernels/mat_mul_integer/cublas_kernel.cu index 94d3aeb0d..d125a52c5 100644 --- a/src/04kernel/src/kernels/mat_mul_integer/cublas_kernel.cu +++ b/src/04kernel/src/kernels/mat_mul_integer/cublas_kernel.cu @@ -1,6 +1,5 @@ #include "../../utilities/cuda/cublas_context.hh" #include "cublas_kernel.hh" -#include #include #include @@ -68,43 +67,44 @@ namespace refactor::kernel { b = workspacePtr; } - int32_t alpha = 1, beta = 0; + auto handle = res.fetchOrStore()->handle; + int32_t alpha = 1, + beta = 0; auto m = info.m, n = info.n, k = info.k; auto strideY = m * n, strideA = m * k, strideB = k * n; - auto lda = info.k, - ldb = info.n; + auto lda = k, + ldb = n; if (info.broadcaster.needBroadcast()) { uint32_t offset[2]; for (auto i : range0_(info.broadcaster.outputsCount)) { info.broadcaster.locate(i, offset); - cublasGemmEx( - res.fetchOrStore()->handle, + CUBLAS_ASSERT(cublasGemmEx( + handle, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &alpha, b + strideB * offset[1], CUDA_R_8I, ldb, a + strideA * offset[0], CUDA_R_8I, lda, - &beta, y + strideY * i, CUDA_R_32I, - n, CUDA_R_32I, - CUBLAS_GEMM_DEFAULT); + &beta, y + strideY * i, CUDA_R_32I, n, + CUDA_R_32I, CUBLAS_GEMM_DEFAULT)); } } else { - cublasGemmStridedBatchedEx( - res.fetchOrStore()->handle, + CUBLAS_ASSERT(cublasGemmStridedBatchedEx( + handle, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &alpha, b, CUDA_R_8I, ldb, strideB, a, CUDA_R_8I, lda, strideA, - &beta, y, CUDA_R_32I, - n, m * n, info.broadcaster.outputsCount, CUDA_R_32I, - CUBLAS_GEMM_DEFAULT); + &beta, y, CUDA_R_32I, n, + strideY, info.broadcaster.outputsCount, + CUDA_R_32I, 
CUBLAS_GEMM_DEFAULT)); } }; diff --git a/src/04kernel/src/utilities/cuda/cublas_context.hh b/src/04kernel/src/utilities/cuda/cublas_context.hh index 1013fe972..69f56cd4a 100644 --- a/src/04kernel/src/utilities/cuda/cublas_context.hh +++ b/src/04kernel/src/utilities/cuda/cublas_context.hh @@ -4,6 +4,13 @@ #include "runtime/resource.h" #include +#define CUBLAS_ASSERT(STATUS) \ + if (auto status = (STATUS); status != CUBLAS_STATUS_SUCCESS) { \ + fmt::println("cublas failed on \"" #STATUS "\" with {}", \ + (int) status); \ + abort(); \ + } + namespace refactor::kernel::cublas { struct CublasContext final : public runtime::Resource { diff --git a/src/04kernel/src/utilities/cuda/cudnn_functions.cc b/src/04kernel/src/utilities/cuda/cudnn_functions.cpp similarity index 100% rename from src/04kernel/src/utilities/cuda/cudnn_functions.cc rename to src/04kernel/src/utilities/cuda/cudnn_functions.cpp diff --git a/src/04kernel/test/kernels/mat_mul/test_cublas.cpp b/src/04kernel/test/kernels/mat_mul/test_cublas.cpp index 7f96d5f27..ef1328f20 100644 --- a/src/04kernel/test/kernels/mat_mul/test_cublas.cpp +++ b/src/04kernel/test/kernels/mat_mul/test_cublas.cpp @@ -92,9 +92,7 @@ TEST(kernel, MatMulCublas_Broadcast) { std::vector result(Y->elementsSize()); my->copyToHost(result.data(), Y->bytesSize()); // check - for (auto i : range0_(result.size())) { - EXPECT_FLOAT_EQ(result[i], cpuOut[i]); - } + EXPECT_EQ(result, cpuOut); } TEST(kernel, MatMulCublas_TransABNoBias) { @@ -137,9 +135,7 @@ TEST(kernel, MatMulCublas_TransABNoBias) { std::vector result(Y->elementsSize()); my->copyToHost(result.data(), Y->bytesSize()); // check - for (auto i : range0_(result.size())) { - EXPECT_FLOAT_EQ(result[i], cpuOut[i]); - } + EXPECT_EQ(result, cpuOut); } TEST(kernel, MatMulCublas_Large) { @@ -192,9 +188,7 @@ TEST(kernel, MatMulCublas_Large) { std::vector result(Y->elementsSize()); my->copyToHost(result.data(), Y->bytesSize()); // check - for (auto i : range0_(result.size())) { - EXPECT_FLOAT_EQ(result[i], cpuOut[i]); - } + EXPECT_EQ(result, cpuOut); } #endif diff --git a/src/04kernel/test/kernels/mat_mul_integer/test_cpu_kernel.cpp b/src/04kernel/test/kernels/mat_mul_integer/test_cpu_kernel.cpp index cbf4d3514..82cf8b4f1 100644 --- a/src/04kernel/test/kernels/mat_mul_integer/test_cpu_kernel.cpp +++ b/src/04kernel/test/kernels/mat_mul_integer/test_cpu_kernel.cpp @@ -12,11 +12,12 @@ TEST(kernel, MatMulIntegerCpu) { auto kernel = MatMulIntegerCpu::build(MatMulIntegerInfo(TensorRefs{*A, *B})); ASSERT_TRUE(kernel); auto res = runtime::Resources(); - auto routine = kernel->lower(res).routine; + auto [routine, workspaceSize] = kernel->lower(res); // put input data std::vector dataA{1, 2, 3, 4, 5, 6}, - dataB{1, 2, 3}; + dataB{1, 2, 3}, + workspace(workspaceSize); std::vector result(Y->elementsSize()), ans{14, 32}; @@ -24,7 +25,7 @@ TEST(kernel, MatMulIntegerCpu) { { void const *inputs[]{dataA.data(), dataB.data()}; void *outputs[]{result.data()}; - routine(res, nullptr, inputs, outputs); + routine(res, workspace.data(), inputs, outputs); } // check EXPECT_EQ(result, ans); diff --git a/src/04kernel/test/kernels/mat_mul_integer/test_cublas_kernel.cpp b/src/04kernel/test/kernels/mat_mul_integer/test_cublas_kernel.cpp new file mode 100644 index 000000000..55ff188e4 --- /dev/null +++ b/src/04kernel/test/kernels/mat_mul_integer/test_cublas_kernel.cpp @@ -0,0 +1,61 @@ +#ifdef USE_CUDA + +#include "../src/kernels/mat_mul_integer/cpu_kernel.hh" +#include "../src/kernels/mat_mul_integer/cublas_kernel.hh" +#include 
"hardware/device_manager.h" +#include +#include + +using namespace refactor; +using namespace kernel; +using namespace hardware; + +TEST(kernel, MatMulIntegerCublas) { + // build routine + auto A = Tensor::share(DataType::U8, Shape{1, 4}); + auto B = Tensor::share(DataType::U8, Shape{4, 12}); + auto Y = Tensor::share(DataType::I32, Shape{1, 12}); + MatMulIntegerInfo info(TensorRefs{*A, *B}); + auto cpuKernel = MatMulIntegerCpu::build(info); + auto gpuKernel = MatMulIntegerCublas::build(info); + ASSERT_TRUE(cpuKernel && gpuKernel); + auto res = runtime::Resources(); + auto [cpuRoutine, workspace] = cpuKernel->lower(res); + auto [gpuRoutine, workspace_] = gpuKernel->lower(res); + ASSERT_EQ(workspace, workspace_); + // put input data + std::vector + dataA(A->elementsSize()), + dataB(B->elementsSize()); + std::vector + dataY(Y->elementsSize()), + result(Y->elementsSize()); + std::iota(dataA.begin(), dataA.end(), 1); + std::iota(dataB.data() + 0, dataB.data() + 12, 1); + std::iota(dataB.data() + 12, dataB.data() + 24, 1); + std::iota(dataB.data() + 24, dataB.data() + 36, 1); + std::iota(dataB.data() + 36, dataB.data() + 48, 1); + auto &dev = *device::init(Device::Type::Nvidia, 0, ""); + auto ma = dev.malloc(A->bytesSize()), + mb = dev.malloc(B->bytesSize()), + my = dev.malloc(Y->bytesSize()); + ma->copyFromHost(dataA.data(), A->bytesSize()); + mb->copyFromHost(dataB.data(), B->bytesSize()); + // inference + { + void const *inputs[]{*ma, *mb}; + void *outputs[]{*my}; + gpuRoutine(res, nullptr, inputs, outputs); + } + { + void const *inputs[]{dataA.data(), dataB.data()}; + void *outputs[]{dataY.data()}; + cpuRoutine(res, nullptr, inputs, outputs); + } + // take output data + my->copyToHost(result.data(), Y->bytesSize()); + // check + EXPECT_EQ(result, dataY); +} + +#endif From 3d14bf03843df362c72af3f0619bf56d0781b117 Mon Sep 17 00:00:00 2001 From: YdrMaster Date: Mon, 18 Dec 2023 11:26:10 +0800 Subject: [PATCH 10/18] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E9=80=90?= =?UTF-8?q?=E5=BC=A0=E9=87=8F=E7=9A=84=E9=87=8F=E5=8C=96=E5=92=8C=E5=8F=8D?= =?UTF-8?q?=E9=87=8F=E5=8C=96=E7=AE=97=E5=AD=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: YdrMaster --- .../kernel/collectors/dequantize_linear.h | 18 ++ .../collectors/dynamic_quantize_linear.h | 18 ++ .../src/collectors/dequantize_linear.cc | 23 +++ .../src/collectors/dynamic_quantize_linear.cc | 23 +++ .../include/computation/operators/cast.h | 2 +- .../computation/operators/dequantize_linear.h | 21 +++ .../operators/dynamic_quantize_linear.h | 21 +++ .../src/operators/dequantize_linear.cc | 21 +++ .../src/operators/dynamic_quantize_linear.cc | 21 +++ src/07onnx/src/operators.cpp | 162 +++++++++--------- src/07onnx/src/operators/dequantize_linear.cc | 52 ++++++ src/07onnx/src/operators/dequantize_linear.hh | 25 +++ .../src/operators/dynamic_quantize_linear.cc | 39 +++++ .../src/operators/dynamic_quantize_linear.hh | 24 +++ 14 files changed, 390 insertions(+), 80 deletions(-) create mode 100644 src/04kernel/include/kernel/collectors/dequantize_linear.h create mode 100644 src/04kernel/include/kernel/collectors/dynamic_quantize_linear.h create mode 100644 src/04kernel/src/collectors/dequantize_linear.cc create mode 100644 src/04kernel/src/collectors/dynamic_quantize_linear.cc create mode 100644 src/05computation/include/computation/operators/dequantize_linear.h create mode 100644 src/05computation/include/computation/operators/dynamic_quantize_linear.h create mode 100644 
src/05computation/src/operators/dequantize_linear.cc create mode 100644 src/05computation/src/operators/dynamic_quantize_linear.cc create mode 100644 src/07onnx/src/operators/dequantize_linear.cc create mode 100644 src/07onnx/src/operators/dequantize_linear.hh create mode 100644 src/07onnx/src/operators/dynamic_quantize_linear.cc create mode 100644 src/07onnx/src/operators/dynamic_quantize_linear.hh diff --git a/src/04kernel/include/kernel/collectors/dequantize_linear.h b/src/04kernel/include/kernel/collectors/dequantize_linear.h new file mode 100644 index 000000000..640a4329a --- /dev/null +++ b/src/04kernel/include/kernel/collectors/dequantize_linear.h @@ -0,0 +1,18 @@ +#ifndef KERNEL_DEQUANTIZE_LINEAR_H +#define KERNEL_DEQUANTIZE_LINEAR_H + +#include "../collector.h" + +namespace refactor::kernel { + + struct DequantizeLinearCollector final : public InfoCollector { + + explicit DequantizeLinearCollector(decltype(_target)) noexcept; + + std::vector + filter(TensorRefs inputs, TensorRefs outputs) const final; + }; + +}// namespace refactor::kernel + +#endif// KERNEL_DEQUANTIZE_LINEAR_H diff --git a/src/04kernel/include/kernel/collectors/dynamic_quantize_linear.h b/src/04kernel/include/kernel/collectors/dynamic_quantize_linear.h new file mode 100644 index 000000000..52a2a0445 --- /dev/null +++ b/src/04kernel/include/kernel/collectors/dynamic_quantize_linear.h @@ -0,0 +1,18 @@ +#ifndef KERNEL_DYNAMIC_QUANTIZE_LINEAR_H +#define KERNEL_DYNAMIC_QUANTIZE_LINEAR_H + +#include "../collector.h" + +namespace refactor::kernel { + + struct DynamicQuantizeLinearCollector final : public InfoCollector { + + explicit DynamicQuantizeLinearCollector(decltype(_target)) noexcept; + + std::vector + filter(TensorRefs inputs, TensorRefs outputs) const final; + }; + +}// namespace refactor::kernel + +#endif// KERNEL_DYNAMIC_QUANTIZE_LINEAR_H diff --git a/src/04kernel/src/collectors/dequantize_linear.cc b/src/04kernel/src/collectors/dequantize_linear.cc new file mode 100644 index 000000000..d2eb5a696 --- /dev/null +++ b/src/04kernel/src/collectors/dequantize_linear.cc @@ -0,0 +1,23 @@ +#include "kernel/collectors/dequantize_linear.h" + +namespace refactor::kernel { + + DequantizeLinearCollector:: + DequantizeLinearCollector(decltype(_target) target) noexcept + : InfoCollector(target) {} + + std::vector + DequantizeLinearCollector::filter(TensorRefs inputs, TensorRefs outputs) const { + std::vector ans; + switch (_target) { + case decltype(_target)::Cpu: + break; + case decltype(_target)::Nvidia: + break; + default: + UNREACHABLEX(void, "Unknown target"); + } + return ans; + } + +}// namespace refactor::kernel diff --git a/src/04kernel/src/collectors/dynamic_quantize_linear.cc b/src/04kernel/src/collectors/dynamic_quantize_linear.cc new file mode 100644 index 000000000..c105caab5 --- /dev/null +++ b/src/04kernel/src/collectors/dynamic_quantize_linear.cc @@ -0,0 +1,23 @@ +#include "kernel/collectors/dynamic_quantize_linear.h" + +namespace refactor::kernel { + + DynamicQuantizeLinearCollector:: + DynamicQuantizeLinearCollector(decltype(_target) target) noexcept + : InfoCollector(target) {} + + std::vector + DynamicQuantizeLinearCollector::filter(TensorRefs inputs, TensorRefs outputs) const { + std::vector ans; + switch (_target) { + case decltype(_target)::Cpu: + break; + case decltype(_target)::Nvidia: + break; + default: + UNREACHABLEX(void, "Unknown target"); + } + return ans; + } + +}// namespace refactor::kernel diff --git a/src/05computation/include/computation/operators/cast.h 
b/src/05computation/include/computation/operators/cast.h index 259bdc0eb..1c9bc10bd 100644 --- a/src/05computation/include/computation/operators/cast.h +++ b/src/05computation/include/computation/operators/cast.h @@ -7,7 +7,7 @@ namespace refactor::computation { struct Cast final : public Operator { - constexpr explicit Cast() noexcept = default; + constexpr Cast() noexcept = default; static size_t typeId() noexcept; size_t opTypeId() const noexcept final; diff --git a/src/05computation/include/computation/operators/dequantize_linear.h b/src/05computation/include/computation/operators/dequantize_linear.h new file mode 100644 index 000000000..321b6449b --- /dev/null +++ b/src/05computation/include/computation/operators/dequantize_linear.h @@ -0,0 +1,21 @@ +#ifndef COMPUTATION_DEQUANTIZE_LINEAR_H +#define COMPUTATION_DEQUANTIZE_LINEAR_H + +#include "../operator.h" + +namespace refactor::computation { + + struct DequantizeLinear final : public Operator { + + constexpr DequantizeLinear() noexcept = default; + + static size_t typeId() noexcept; + size_t opTypeId() const noexcept final; + std::string_view name() const noexcept final; + kernel::CollectorBox candidateKernels(Target) const noexcept final; + std::string serialize() const noexcept final; + }; + +}// namespace refactor::computation + +#endif// COMPUTATION_DEQUANTIZE_LINEAR_H diff --git a/src/05computation/include/computation/operators/dynamic_quantize_linear.h b/src/05computation/include/computation/operators/dynamic_quantize_linear.h new file mode 100644 index 000000000..6f57fdd71 --- /dev/null +++ b/src/05computation/include/computation/operators/dynamic_quantize_linear.h @@ -0,0 +1,21 @@ +#ifndef COMPUTATION_DYNAMIC_QUANTIZE_LINEAR_H +#define COMPUTATION_DYNAMIC_QUANTIZE_LINEAR_H + +#include "../operator.h" + +namespace refactor::computation { + + struct DynamicQuantizeLinear final : public Operator { + + constexpr DynamicQuantizeLinear() noexcept = default; + + static size_t typeId() noexcept; + size_t opTypeId() const noexcept final; + std::string_view name() const noexcept final; + kernel::CollectorBox candidateKernels(Target) const noexcept final; + std::string serialize() const noexcept final; + }; + +}// namespace refactor::computation + +#endif// COMPUTATION_DYNAMIC_QUANTIZE_LINEAR_H diff --git a/src/05computation/src/operators/dequantize_linear.cc b/src/05computation/src/operators/dequantize_linear.cc new file mode 100644 index 000000000..e4c1cf726 --- /dev/null +++ b/src/05computation/src/operators/dequantize_linear.cc @@ -0,0 +1,21 @@ +#include "computation/operators/dequantize_linear.h" +#include "kernel/collectors/dequantize_linear.h" + +namespace refactor::computation { + using Op = DequantizeLinear; + + size_t Op::typeId() noexcept { + static uint8_t ID = 1; + return reinterpret_cast(&ID); + } + size_t Op::opTypeId() const noexcept { return typeId(); } + std::string_view Op::name() const noexcept { return "DequantizeLinear"; } + auto Op::candidateKernels(Target target) const noexcept -> kernel::CollectorBox { + using Collector = kernel::DequantizeLinearCollector; + return std::make_unique(target); + } + auto Op::serialize() const noexcept -> std::string { + return "DequantizeLinear()"; + } + +}// namespace refactor::computation diff --git a/src/05computation/src/operators/dynamic_quantize_linear.cc b/src/05computation/src/operators/dynamic_quantize_linear.cc new file mode 100644 index 000000000..88cff76a3 --- /dev/null +++ b/src/05computation/src/operators/dynamic_quantize_linear.cc @@ -0,0 +1,21 @@ +#include 
"computation/operators/dynamic_quantize_linear.h" +#include "kernel/collectors/dynamic_quantize_linear.h" + +namespace refactor::computation { + using Op = DynamicQuantizeLinear; + + size_t Op::typeId() noexcept { + static uint8_t ID = 1; + return reinterpret_cast(&ID); + } + size_t Op::opTypeId() const noexcept { return typeId(); } + std::string_view Op::name() const noexcept { return "DynamicQuantizeLinear"; } + auto Op::candidateKernels(Target target) const noexcept -> kernel::CollectorBox { + using Collector = kernel::DynamicQuantizeLinearCollector; + return std::make_unique(target); + } + auto Op::serialize() const noexcept -> std::string { + return "DynamicQuantizeLinear()"; + } + +}// namespace refactor::computation diff --git a/src/07onnx/src/operators.cpp b/src/07onnx/src/operators.cpp index 18a651b9a..2fbfe8b03 100644 --- a/src/07onnx/src/operators.cpp +++ b/src/07onnx/src/operators.cpp @@ -8,6 +8,8 @@ #include "operators/constant_of_shape.hh" #include "operators/conv.hh" #include "operators/cum_sum.hh" +#include "operators/dequantize_linear.hh" +#include "operators/dynamic_quantize_linear.hh" #include "operators/einsum.hh" #include "operators/expand.hh" #include "operators/gather.hh" @@ -39,85 +41,87 @@ namespace refactor::onnx { void register_() { // clang-format off #define REGISTER(NAME, CLASS) Operator::register_("onnx::" #NAME) - REGISTER(BatchNormalization, BatchNormalization); - REGISTER(Cast , Cast ); - REGISTER(Clip , Clip ); - REGISTER(Equal , Compair ); - REGISTER(Greater , Compair ); - REGISTER(GreaterOrEqual , Compair ); - REGISTER(Less , Compair ); - REGISTER(LessOrEqual , Compair ); - REGISTER(Concat , Concat ); - REGISTER(Constant , Constant ); - REGISTER(ConstantOfShape , ConstantOfShape ); - REGISTER(Conv , Conv ); - REGISTER(CumSum , CumSum ); - REGISTER(Einsum , Einsum ); - REGISTER(Expand , Expand ); - REGISTER(Gather , Gather ); - REGISTER(GatherElements , GatherElements ); - REGISTER(Gemm , Gemm ); - REGISTER(GlobalAveragePool , GlobalPool ); - REGISTER(GlobalLpPool , GlobalPool ); - REGISTER(GlobalMaxPool , GlobalPool ); - REGISTER(MatMul , MatMul ); - REGISTER(MatMulInteger , MatMulInteger ); - REGISTER(AveragePool , Pool ); - REGISTER(LpPool , Pool ); - REGISTER(MaxPool , Pool ); - REGISTER(Range , Range ); - REGISTER(ReduceMean , Reduce ); - REGISTER(ReduceL1 , Reduce ); - REGISTER(ReduceL2 , Reduce ); - REGISTER(ReduceLogSum , Reduce ); - REGISTER(ReduceLogSumExp , Reduce ); - REGISTER(ReduceMax , Reduce ); - REGISTER(ReduceMin , Reduce ); - REGISTER(ReduceProd , Reduce ); - REGISTER(ReduceSum , Reduce ); - REGISTER(ReduceSumSquare , Reduce ); - REGISTER(Reshape , Reshape ); - REGISTER(ScatterND , ScatterND ); - REGISTER(Max , Select ); - REGISTER(Min , Select ); - REGISTER(Shape , Shape ); - REGISTER(Add , SimpleBinary ); - REGISTER(Sub , SimpleBinary ); - REGISTER(Mul , SimpleBinary ); - REGISTER(Div , SimpleBinary ); - REGISTER(Pow , SimpleBinary ); - REGISTER(And , SimpleBinary ); - REGISTER(Or , SimpleBinary ); - REGISTER(Xor , SimpleBinary ); - REGISTER(Abs , SimpleUnary ); - REGISTER(Acos , SimpleUnary ); - REGISTER(Acosh , SimpleUnary ); - REGISTER(Asin , SimpleUnary ); - REGISTER(Asinh , SimpleUnary ); - REGISTER(Atan , SimpleUnary ); - REGISTER(Atanh , SimpleUnary ); - REGISTER(Cos , SimpleUnary ); - REGISTER(Cosh , SimpleUnary ); - REGISTER(Sin , SimpleUnary ); - REGISTER(Sinh , SimpleUnary ); - REGISTER(Tan , SimpleUnary ); - REGISTER(Tanh , SimpleUnary ); - REGISTER(Relu , SimpleUnary ); - REGISTER(Sqrt , SimpleUnary ); - REGISTER(Sigmoid 
, SimpleUnary ); - REGISTER(Erf , SimpleUnary ); - REGISTER(Log , SimpleUnary ); - REGISTER(Not , SimpleUnary ); - REGISTER(Neg , SimpleUnary ); - REGISTER(Identity , SimpleUnary ); - REGISTER(Slice , Slice ); - REGISTER(Softmax , Softmax ); - REGISTER(Split , Split ); - REGISTER(Squeeze , Squeeze ); - REGISTER(Tile , Tile ); - REGISTER(Transpose , Transpose ); - REGISTER(Unsqueeze , Unsqueeze ); - REGISTER(Where , Where ); + REGISTER(BatchNormalization , BatchNormalization ); + REGISTER(Cast , Cast ); + REGISTER(Clip , Clip ); + REGISTER(Equal , Compair ); + REGISTER(Greater , Compair ); + REGISTER(GreaterOrEqual , Compair ); + REGISTER(Less , Compair ); + REGISTER(LessOrEqual , Compair ); + REGISTER(Concat , Concat ); + REGISTER(Constant , Constant ); + REGISTER(ConstantOfShape , ConstantOfShape ); + REGISTER(Conv , Conv ); + REGISTER(DequantizeLinear , DequantizeLinear ); + REGISTER(DynamicQuantizeLinear, DynamicQuantizeLinear); + REGISTER(CumSum , CumSum ); + REGISTER(Einsum , Einsum ); + REGISTER(Expand , Expand ); + REGISTER(Gather , Gather ); + REGISTER(GatherElements , GatherElements ); + REGISTER(Gemm , Gemm ); + REGISTER(GlobalAveragePool , GlobalPool ); + REGISTER(GlobalLpPool , GlobalPool ); + REGISTER(GlobalMaxPool , GlobalPool ); + REGISTER(MatMul , MatMul ); + REGISTER(MatMulInteger , MatMulInteger ); + REGISTER(AveragePool , Pool ); + REGISTER(LpPool , Pool ); + REGISTER(MaxPool , Pool ); + REGISTER(Range , Range ); + REGISTER(ReduceMean , Reduce ); + REGISTER(ReduceL1 , Reduce ); + REGISTER(ReduceL2 , Reduce ); + REGISTER(ReduceLogSum , Reduce ); + REGISTER(ReduceLogSumExp , Reduce ); + REGISTER(ReduceMax , Reduce ); + REGISTER(ReduceMin , Reduce ); + REGISTER(ReduceProd , Reduce ); + REGISTER(ReduceSum , Reduce ); + REGISTER(ReduceSumSquare , Reduce ); + REGISTER(Reshape , Reshape ); + REGISTER(ScatterND , ScatterND ); + REGISTER(Max , Select ); + REGISTER(Min , Select ); + REGISTER(Shape , Shape ); + REGISTER(Add , SimpleBinary ); + REGISTER(Sub , SimpleBinary ); + REGISTER(Mul , SimpleBinary ); + REGISTER(Div , SimpleBinary ); + REGISTER(Pow , SimpleBinary ); + REGISTER(And , SimpleBinary ); + REGISTER(Or , SimpleBinary ); + REGISTER(Xor , SimpleBinary ); + REGISTER(Abs , SimpleUnary ); + REGISTER(Acos , SimpleUnary ); + REGISTER(Acosh , SimpleUnary ); + REGISTER(Asin , SimpleUnary ); + REGISTER(Asinh , SimpleUnary ); + REGISTER(Atan , SimpleUnary ); + REGISTER(Atanh , SimpleUnary ); + REGISTER(Cos , SimpleUnary ); + REGISTER(Cosh , SimpleUnary ); + REGISTER(Sin , SimpleUnary ); + REGISTER(Sinh , SimpleUnary ); + REGISTER(Tan , SimpleUnary ); + REGISTER(Tanh , SimpleUnary ); + REGISTER(Relu , SimpleUnary ); + REGISTER(Sqrt , SimpleUnary ); + REGISTER(Sigmoid , SimpleUnary ); + REGISTER(Erf , SimpleUnary ); + REGISTER(Log , SimpleUnary ); + REGISTER(Not , SimpleUnary ); + REGISTER(Neg , SimpleUnary ); + REGISTER(Identity , SimpleUnary ); + REGISTER(Slice , Slice ); + REGISTER(Softmax , Softmax ); + REGISTER(Split , Split ); + REGISTER(Squeeze , Squeeze ); + REGISTER(Tile , Tile ); + REGISTER(Transpose , Transpose ); + REGISTER(Unsqueeze , Unsqueeze ); + REGISTER(Where , Where ); #undef REGISTER // clang-format on } diff --git a/src/07onnx/src/operators/dequantize_linear.cc b/src/07onnx/src/operators/dequantize_linear.cc new file mode 100644 index 000000000..5dc691513 --- /dev/null +++ b/src/07onnx/src/operators/dequantize_linear.cc @@ -0,0 +1,52 @@ +#include "dequantize_linear.hh" +#include "common.h" +#include "computation/operators/dequantize_linear.h" + +namespace 
refactor::onnx { + using Op = DequantizeLinear; + + Op::DequantizeLinear(Int axis_) noexcept + : Operator(), axis(axis_) {} + + auto Op::build(ModelContext const &, std::string_view, Attributes attrs) -> OpBox { + auto axis = defaultOr(attrs, "axis", {1}).int_(); + return OpBox(std::make_unique(axis)); + } + auto Op::typeId() -> size_t { + static uint8_t ID = 1; + return reinterpret_cast(&ID); + } + + auto Op::opTypeId() const -> size_t { return typeId(); } + auto Op::opTypeName() const -> std::string_view { return "onnx::DequantizeLinear"; } + + auto Op::infer(TensorRefs inputs, InferOptions const &options) const -> InferResult { + switch (inputs.size()) { + case 2: + case 3: + break; + default: + return Err(InferError(ERROR_MSG("Input size error"))); + } + + auto const &x = inputs[0]; + auto const &xScale = inputs[1]; + if (xScale.rank() != 0) { + return Err(InferError(ERROR_MSG("Only support per-tensor quantization currently"))); + } + if (inputs.size() > 2) { + auto const &xZeroPoint = inputs[2]; + if (xZeroPoint.dataType != x.dataType || xZeroPoint.shape != xScale.shape) { + return Err(InferError(ERROR_MSG("x_zero_point info mismatch"))); + } + } + + return Ok(Tensors{Tensor::share(xScale.dataType, x.shape, extractDependency(inputs))}); + } + + auto Op::lower(TensorRefs inputs) const -> computation::OpBox { + using Op_ = computation::DequantizeLinear; + return std::make_unique(); + } + +}// namespace refactor::onnx diff --git a/src/07onnx/src/operators/dequantize_linear.hh b/src/07onnx/src/operators/dequantize_linear.hh new file mode 100644 index 000000000..a24c7bbac --- /dev/null +++ b/src/07onnx/src/operators/dequantize_linear.hh @@ -0,0 +1,25 @@ +#ifndef ONNX_DEQUANTIZE_LINEAR_HH +#define ONNX_DEQUANTIZE_LINEAR_HH + +#include "frontend/operator.h" + +namespace refactor::onnx { + using namespace frontend; + + struct DequantizeLinear final : public Operator { + Int axis; + + explicit DequantizeLinear(Int) noexcept; + + static OpBox build(ModelContext const &, std::string_view, Attributes); + static size_t typeId(); + + size_t opTypeId() const final; + std::string_view opTypeName() const final; + InferResult infer(TensorRefs, InferOptions const &) const final; + computation::OpBox lower(TensorRefs) const final; + }; + +}// namespace refactor::onnx + +#endif// ONNX_DEQUANTIZE_LINEAR_HH diff --git a/src/07onnx/src/operators/dynamic_quantize_linear.cc b/src/07onnx/src/operators/dynamic_quantize_linear.cc new file mode 100644 index 000000000..17031607b --- /dev/null +++ b/src/07onnx/src/operators/dynamic_quantize_linear.cc @@ -0,0 +1,39 @@ +#include "dynamic_quantize_linear.hh" +#include "common.h" +#include "computation/operators/dynamic_quantize_linear.h" + +namespace refactor::onnx { + using Op = DynamicQuantizeLinear; + + auto Op::build(ModelContext const &, std::string_view, Attributes) -> OpBox { + return OpBox(std::make_unique()); + } + auto Op::typeId() -> size_t { + static uint8_t ID = 1; + return reinterpret_cast(&ID); + } + + auto Op::opTypeId() const -> size_t { return typeId(); } + auto Op::opTypeName() const -> std::string_view { return "onnx::DynamicQuantizeLinear"; } + + auto Op::infer(TensorRefs inputs, InferOptions const &options) const -> InferResult { + EXPECT_SIZE(1) + + auto const &x = inputs[0]; + if (x.dataType != DataType::F32) { + return Err(InferError(ERROR_MSG("Input data type not supported"))); + } + auto deps = extractDependency(inputs); + return Ok(Tensors{ + Tensor::share(DataType::U8, x.shape, deps), + Tensor::share(DataType::F32, {}, deps), +
Tensor::share(DataType::U8, {}, deps), + }); + } + + auto Op::lower(TensorRefs inputs) const -> computation::OpBox { + using Op_ = computation::DynamicQuantizeLinear; + return std::make_unique(); + } + +}// namespace refactor::onnx diff --git a/src/07onnx/src/operators/dynamic_quantize_linear.hh b/src/07onnx/src/operators/dynamic_quantize_linear.hh new file mode 100644 index 000000000..899a97695 --- /dev/null +++ b/src/07onnx/src/operators/dynamic_quantize_linear.hh @@ -0,0 +1,24 @@ +#ifndef ONNX_DYNAMIC_QUANTIZE_LINEAR_HH +#define ONNX_DYNAMIC_QUANTIZE_LINEAR_HH + +#include "frontend/operator.h" + +namespace refactor::onnx { + using namespace frontend; + + struct DynamicQuantizeLinear final : public Operator { + + DynamicQuantizeLinear() = default; + + static OpBox build(ModelContext const &, std::string_view, Attributes); + static size_t typeId(); + + size_t opTypeId() const final; + std::string_view opTypeName() const final; + InferResult infer(TensorRefs, InferOptions const &) const final; + computation::OpBox lower(TensorRefs) const final; + }; + +}// namespace refactor::onnx + +#endif// ONNX_DYNAMIC_QUANTIZE_LINEAR_HH From 103254b09b1f02219df9eaaa693df3de5982ec42 Mon Sep 17 00:00:00 2001 From: YdrMaster Date: Mon, 18 Dec 2023 14:22:51 +0800 Subject: [PATCH 11/18] =?UTF-8?q?feat(kernel):=20=E6=B7=BB=E5=8A=A0?= =?UTF-8?q?=E9=80=90=E5=BC=A0=E9=87=8F=E9=87=8F=E5=8C=96=E7=9A=84=20cpu=20?= =?UTF-8?q?kernel?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit fix(kernel): 改正 mat mul integer Signed-off-by: YdrMaster --- .../kernel/attributes/mat_mul_integer_info.h | 8 +- .../src/attributes/mat_mul_integer_info.cc | 30 +++-- .../src/collectors/dynamic_quantize_linear.cc | 6 + .../dynamic_quantize_linear/cpu_kernel.cc | 64 +++++++++ .../dynamic_quantize_linear/cpu_kernel.hh | 23 ++++ .../dynamic_quantize_linear/cuda_kernel.cc | 23 ++++ .../dynamic_quantize_linear/cuda_kernel.cu | 16 +++ .../dynamic_quantize_linear/cuda_kernel.hh | 25 ++++ .../src/kernels/mat_mul_integer/cpu_kernel.cc | 79 ++++++++--- .../kernels/mat_mul_integer/cublas_kernel.cu | 124 ++++++++++++++---- 10 files changed, 344 insertions(+), 54 deletions(-) create mode 100644 src/04kernel/src/kernels/dynamic_quantize_linear/cpu_kernel.cc create mode 100644 src/04kernel/src/kernels/dynamic_quantize_linear/cpu_kernel.hh create mode 100644 src/04kernel/src/kernels/dynamic_quantize_linear/cuda_kernel.cc create mode 100644 src/04kernel/src/kernels/dynamic_quantize_linear/cuda_kernel.cu create mode 100644 src/04kernel/src/kernels/dynamic_quantize_linear/cuda_kernel.hh diff --git a/src/04kernel/include/kernel/attributes/mat_mul_integer_info.h b/src/04kernel/include/kernel/attributes/mat_mul_integer_info.h index ad091b0e9..b7d8d4759 100644 --- a/src/04kernel/include/kernel/attributes/mat_mul_integer_info.h +++ b/src/04kernel/include/kernel/attributes/mat_mul_integer_info.h @@ -7,9 +7,10 @@ namespace refactor::kernel { struct MatMulIntegerInfo { struct Input { - bool withZeroPoint; - bool signed_; - dim_t groupCount, groupSize; + bool + withZeroPoint, + signed_, + scalar; Input(TensorRefs const &, size_t i) noexcept; }; @@ -19,6 +20,7 @@ namespace refactor::kernel { Broadcaster broadcaster; explicit MatMulIntegerInfo(TensorRefs const &inputs) noexcept; + dim_t batch() const noexcept; }; }// namespace refactor::kernel diff --git a/src/04kernel/src/attributes/mat_mul_integer_info.cc b/src/04kernel/src/attributes/mat_mul_integer_info.cc index 36ec18ff0..a16689fe2 100644 --- 
a/src/04kernel/src/attributes/mat_mul_integer_info.cc +++ b/src/04kernel/src/attributes/mat_mul_integer_info.cc @@ -2,31 +2,41 @@ namespace refactor::kernel { -#define A (inputs[0].get().shape) -#define B (inputs[1].get().shape) - MatMulIntegerInfo::Input::Input(TensorRefs const &inputs, size_t i) noexcept : withZeroPoint(false), signed_(true), - groupCount(1), - groupSize(1) { + scalar(true) { if (inputs.size() > i + 2) { auto const &t = inputs[i + 2].get(); - if (withZeroPoint = t.rank() != 0 || !t.data || t.data->get() != 0) { - signed_ = t.dataType == DataType::I8; - groupCount = t.elementsSize(); - groupSize = inputs[i].get().elementsSize() / groupCount; + auto size = t.elementsSize(); + if (t.data) { + auto data = slice(t.data->get(), size); + if (std::all_of(data.begin(), data.end(), [](auto x) { return x == 0; })) { + return; + } } + withZeroPoint = true; + signed_ = t.dataType == DataType::I8; + scalar = size == 1; } } MatMulIntegerInfo::MatMulIntegerInfo(TensorRefs const &inputs) noexcept : a(inputs, 0), b(inputs, 1), +#define A (inputs[0].get().shape) +#define B (inputs[1].get().shape) m(A.rbegin()[1]), k(A.rbegin()[0]), n(B.rbegin()[0]), broadcaster({slice(A.data(), A.size() - 2), - slice(B.data(), B.size() - 2)}) {} + slice(B.data(), B.size() - 2)}) { + } +#undef A +#undef B + + dim_t MatMulIntegerInfo::batch() const noexcept { + return broadcaster.outputsCount; + } }// namespace refactor::kernel diff --git a/src/04kernel/src/collectors/dynamic_quantize_linear.cc b/src/04kernel/src/collectors/dynamic_quantize_linear.cc index c105caab5..d4ba7c95f 100644 --- a/src/04kernel/src/collectors/dynamic_quantize_linear.cc +++ b/src/04kernel/src/collectors/dynamic_quantize_linear.cc @@ -1,4 +1,5 @@ #include "kernel/collectors/dynamic_quantize_linear.h" +#include "../kernels/dynamic_quantize_linear/cpu_kernel.hh" namespace refactor::kernel { @@ -8,9 +9,14 @@ namespace refactor::kernel { std::vector DynamicQuantizeLinearCollector::filter(TensorRefs inputs, TensorRefs outputs) const { + auto size = inputs[0].get().elementsSize(); + std::vector ans; switch (_target) { case decltype(_target)::Cpu: + if (auto ptr = DynamicQuantizeLinearCpu::build(size); ptr) { + ans.emplace_back(std::move(ptr)); + } break; case decltype(_target)::Nvidia: break; diff --git a/src/04kernel/src/kernels/dynamic_quantize_linear/cpu_kernel.cc b/src/04kernel/src/kernels/dynamic_quantize_linear/cpu_kernel.cc new file mode 100644 index 000000000..77a7715f9 --- /dev/null +++ b/src/04kernel/src/kernels/dynamic_quantize_linear/cpu_kernel.cc @@ -0,0 +1,64 @@ +#include "cpu_kernel.hh" +#include +#include + +namespace refactor::kernel { + using K = DynamicQuantizeLinearCpu; + + K::DynamicQuantizeLinearCpu(decltype(size) size_) noexcept + : Kernel(), size(size_) {} + + auto K::build(decltype(size) size) noexcept -> KernelBox { + return std::make_unique(size); + } + + auto K::typeId() noexcept -> size_t { + static uint8_t ID = 1; + return reinterpret_cast(&ID); + } + + auto K::kernelTypeId() const noexcept -> size_t { return typeId(); } + auto K::description() const noexcept -> std::string_view { + return "Performing dynamic quantize linear using CPU"; + } + + auto K::lower(Resources &) const noexcept -> RoutineWorkspace { + using namespace runtime; + return [size = size](Resources &, void *, void const *const *inputs, void *const *outputs) { + using TI = float; + using TO = uint8_t; + + constexpr static auto + ZERO = static_cast(0), + _MIN = std::numeric_limits::min(), + _MAX = std::numeric_limits::max(), + QMIN = 
static_cast(std::numeric_limits::min()), + QMAX = static_cast(std::numeric_limits::max()), + QLEN = QMAX - QMIN; + + auto x = reinterpret_cast(inputs[0]); + auto [min, max] = std::accumulate( + x, x + size, + std::pair{_MAX, _MIN}, + [](auto acc, auto it) { + auto [min, max] = acc; + return std::pair{ + std::min(min, it), + std::max(max, it), + }; + }); + auto len = std::max(ZERO, max) - std::min(ZERO, min); + auto scale = len / QLEN; + auto zp = static_cast(std::round(QMIN - min * QLEN / len)); + + std::transform( + std::execution::par_unseq, + x, x + size, + reinterpret_cast(outputs[0]), + [=](auto it) { return static_cast(std::round(it / scale) + zp); }); + *reinterpret_cast(outputs[1]) = scale; + *reinterpret_cast(outputs[2]) = zp; + }; + } + +}// namespace refactor::kernel diff --git a/src/04kernel/src/kernels/dynamic_quantize_linear/cpu_kernel.hh b/src/04kernel/src/kernels/dynamic_quantize_linear/cpu_kernel.hh new file mode 100644 index 000000000..97020832f --- /dev/null +++ b/src/04kernel/src/kernels/dynamic_quantize_linear/cpu_kernel.hh @@ -0,0 +1,23 @@ +#ifndef KERNEL_DYNAMIC_QUANTIZE_LINEAR_CPU_KERNEL_HH +#define KERNEL_DYNAMIC_QUANTIZE_LINEAR_CPU_KERNEL_HH + +#include "kernel/kernel.h" + +namespace refactor::kernel { + + struct DynamicQuantizeLinearCpu final : public Kernel { + size_t size; + + explicit DynamicQuantizeLinearCpu(decltype(size)) noexcept; + + static KernelBox build(decltype(size)) noexcept; + static size_t typeId() noexcept; + + size_t kernelTypeId() const noexcept final; + std::string_view description() const noexcept final; + RoutineWorkspace lower(Resources &) const noexcept final; + }; + +}// namespace refactor::kernel + +#endif// KERNEL_SOFTMAX_CPU_KERNEL_HH diff --git a/src/04kernel/src/kernels/dynamic_quantize_linear/cuda_kernel.cc b/src/04kernel/src/kernels/dynamic_quantize_linear/cuda_kernel.cc new file mode 100644 index 000000000..3f26397fe --- /dev/null +++ b/src/04kernel/src/kernels/dynamic_quantize_linear/cuda_kernel.cc @@ -0,0 +1,23 @@ +#include "cuda_kernel.hh" + +namespace refactor::kernel { + using K = DynamicQuantizeLinearCuda; + + K::DynamicQuantizeLinearCuda(decltype(size) size_) noexcept + : Kernel(), size(size_) {} + + auto K::build(decltype(size) size) noexcept -> KernelBox { + return std::make_unique(size); + } + + auto K::typeId() noexcept -> size_t { + static uint8_t ID = 1; + return reinterpret_cast(&ID); + } + + auto K::kernelTypeId() const noexcept -> size_t { return typeId(); } + auto K::description() const noexcept -> std::string_view { + return "Performing dynamic quantize linear using Nvidia GPU"; + } + +}// namespace refactor::kernel diff --git a/src/04kernel/src/kernels/dynamic_quantize_linear/cuda_kernel.cu b/src/04kernel/src/kernels/dynamic_quantize_linear/cuda_kernel.cu new file mode 100644 index 000000000..d82c0d2a0 --- /dev/null +++ b/src/04kernel/src/kernels/dynamic_quantize_linear/cuda_kernel.cu @@ -0,0 +1,16 @@ +#include "cuda_kernel.hh" +#include + +namespace refactor::kernel { + using K = DynamicQuantizeLinearCuda; + + auto K::lower(Resources &) const noexcept -> RoutineWorkspace { + using namespace runtime; + using TI = float; + using TO = uint8_t; + + return [size = size](Resources &, void *, void const *const *inputs, void *const *outputs) { + }; + } + +}// namespace refactor::kernel diff --git a/src/04kernel/src/kernels/dynamic_quantize_linear/cuda_kernel.hh b/src/04kernel/src/kernels/dynamic_quantize_linear/cuda_kernel.hh new file mode 100644 index 000000000..b0151b875 --- /dev/null +++ 
b/src/04kernel/src/kernels/dynamic_quantize_linear/cuda_kernel.hh @@ -0,0 +1,25 @@ +#ifndef KERNEL_DYNAMIC_QUANTIZE_LINEAR_CUDA_KERNEL_HH +#define KERNEL_DYNAMIC_QUANTIZE_LINEAR_CUDA_KERNEL_HH + +#include "kernel/kernel.h" + +namespace refactor::kernel { + + struct DynamicQuantizeLinearCuda final : public Kernel { + size_t size; + + explicit DynamicQuantizeLinearCuda(decltype(size)) noexcept; + + static KernelBox build(decltype(size)) noexcept; + static size_t typeId() noexcept; + + size_t kernelTypeId() const noexcept final; + std::string_view description() const noexcept final; +#ifdef USE_CUDA + RoutineWorkspace lower(Resources &) const noexcept final; +#endif + }; + +}// namespace refactor::kernel + +#endif// KERNEL_DYNAMIC_QUANTIZE_LINEAR_CUDA_KERNEL_HH diff --git a/src/04kernel/src/kernels/mat_mul_integer/cpu_kernel.cc b/src/04kernel/src/kernels/mat_mul_integer/cpu_kernel.cc index 751fb7c0d..14040f3fc 100644 --- a/src/04kernel/src/kernels/mat_mul_integer/cpu_kernel.cc +++ b/src/04kernel/src/kernels/mat_mul_integer/cpu_kernel.cc @@ -1,5 +1,6 @@ #include "cpu_kernel.hh" #include "../mat_mul_common/cpu_template.hpp" +#include namespace refactor::kernel { using K = MatMulIntegerCpu; @@ -27,14 +28,38 @@ namespace refactor::kernel { template<> int8_t sub(uint8_t a, uint8_t b) { return static_cast(static_cast(a) - static_cast(b)); } template - static void applyZeroPoint(MatMulIntegerInfo::Input meta, int8_t *dst, void const *src_, void const *zp_) { + static void applyZeroPointScalar( + size_t size, int8_t *dst, void const *src_, void const *zp_) { + + auto src = reinterpret_cast(src_); + auto zp = *reinterpret_cast(zp_); + std::transform(std::execution::par_unseq, + src, src + size, + dst, [zp](auto x) { return sub(x, zp); }); + } + template + static void applyZeroPointA( + dim_t b, dim_t m, dim_t n, + int8_t *dst, void const *src_, void const *zp_) { + auto src = reinterpret_cast(src_), zp = reinterpret_cast(zp_); - for (auto i : range0_(meta.groupCount)) { - for (auto j : range0_(meta.groupSize)) { - dst[meta.groupSize * i + j] = sub(src[meta.groupSize * i + j], zp[i]); - } - } + for (auto i : range0_(b)) + for (auto j : range0_(m)) + for (auto k : range0_(n)) + dst[i * m * n + j * n + k] = sub(src[i * m * n + j * n + k], zp[i * m + j]); + } + template + static void applyZeroPointB( + dim_t b, dim_t m, dim_t n, + int8_t *dst, void const *src_, void const *zp_) { + + auto src = reinterpret_cast(src_), + zp = reinterpret_cast(zp_); + for (auto i : range0_(b)) + for (auto j : range0_(m)) + for (auto k : range0_(n)) + dst[i * m * n + j * n + k] = sub(src[i * m * n + j * n + k], zp[i * n + k]); } auto K::lower(Resources &) const noexcept -> RoutineWorkspace { @@ -42,10 +67,10 @@ namespace refactor::kernel { size_t workspace = 0; if (info.a.withZeroPoint) { - workspace += info.a.groupCount * info.a.groupSize; + workspace += info.batch() * info.m * info.k; } if (info.b.withZeroPoint) { - workspace += info.b.groupCount * info.b.groupSize; + workspace += info.batch() * info.k * info.n; } auto routine = [info = info](Resources &, void *workspace, void const *const *inputs, void *const *outputs) { @@ -55,19 +80,39 @@ namespace refactor::kernel { auto y = reinterpret_cast(outputs[0]); if (auto meta = info.a; meta.withZeroPoint) { - if (meta.signed_) { - applyZeroPoint(meta, workspacePtr, a, inputs[2]); + auto size = info.batch() * info.m * info.k; + auto zp = inputs[2]; + if (meta.scalar) { + if (meta.signed_) { + applyZeroPointScalar(size, workspacePtr, a, zp); + } else { + 
applyZeroPointScalar(size, workspacePtr, a, zp); + } } else { - applyZeroPoint(meta, workspacePtr, a, inputs[2]); + if (meta.signed_) { + applyZeroPointA(info.batch(), info.m, info.k, workspacePtr, a, zp); + } else { + applyZeroPointA(info.batch(), info.m, info.k, workspacePtr, a, zp); + } } a = workspacePtr; - workspacePtr += meta.groupCount * meta.groupSize; + workspacePtr += size; } if (auto meta = info.b; meta.withZeroPoint) { - if (meta.signed_) { - applyZeroPoint(meta, workspacePtr, b, inputs[3]); + auto size = info.batch() * info.k * info.n; + auto zp = inputs[3]; + if (meta.scalar) { + if (meta.signed_) { + applyZeroPointScalar(size, workspacePtr, b, zp); + } else { + applyZeroPointScalar(size, workspacePtr, b, zp); + } } else { - applyZeroPoint(meta, workspacePtr, b, inputs[3]); + if (meta.signed_) { + applyZeroPointA(info.batch(), info.k, info.n, workspacePtr, b, zp); + } else { + applyZeroPointA(info.batch(), info.k, info.n, workspacePtr, b, zp); + } } b = workspacePtr; } @@ -89,12 +134,12 @@ namespace refactor::kernel { if (info.broadcaster.needBroadcast()) { dim_t offset[2]; - for (auto i : range0_(info.broadcaster.outputsCount)) { + for (auto i : range0_(info.batch())) { info.broadcaster.locate(i, offset); md.matrixMultiply(a + stepA * offset[0], b + stepB * offset[1], y + stepY * i); } } else { - for (auto i : range0_(info.broadcaster.outputsCount)) { + for (auto i : range0_(info.batch())) { md.matrixMultiply(a + stepA * i, b + stepB * i, y + stepY * i); } } diff --git a/src/04kernel/src/kernels/mat_mul_integer/cublas_kernel.cu b/src/04kernel/src/kernels/mat_mul_integer/cublas_kernel.cu index d125a52c5..36383ea9c 100644 --- a/src/04kernel/src/kernels/mat_mul_integer/cublas_kernel.cu +++ b/src/04kernel/src/kernels/mat_mul_integer/cublas_kernel.cu @@ -2,6 +2,7 @@ #include "cublas_kernel.hh" #include #include +#include namespace refactor::kernel { using namespace runtime; @@ -12,35 +13,90 @@ namespace refactor::kernel { template<> __device__ __forceinline__ int8_t sub(uint8_t a, uint8_t b) { return static_cast(static_cast(a) - static_cast(b)); } template - struct MatMulIntegerZPFunctor { - dim_t groupSize; + struct MatMulIntegerZPFunctorScalar { + T const *zp; + + __device__ int8_t operator()(T x) const noexcept { + return sub(x, *zp); + } + }; + + template + static void applyZeroPointScalar( + size_t size, int8_t *dst, void const *src_, void const *zp_) { + + auto src = reinterpret_cast(src_), + zp = reinterpret_cast(zp_); + thrust::transform(thrust::device, + src, src + size, + dst, MatMulIntegerZPFunctorScalar{zp}); + } + + template + struct MatMulIntegerZPFunctorA { + dim_t m, n; T const *src, *zp; - __device__ int8_t operator()(size_t i) const noexcept { - return sub(src[i], zp[i / groupSize]); + __device__ int8_t operator()(size_t idx) const noexcept { + auto + // k = idx % n, + j = idx / n % m, + i = idx / n / m; + return sub(src[idx], zp[i * m + j]); } }; template - static void applyZeroPoint(MatMulIntegerInfo::Input meta, int8_t *dst, void const *src, void const *zp) { - thrust::tabulate( - thrust::device, - dst, dst + meta.groupCount * meta.groupSize, - MatMulIntegerZPFunctor{ - .groupSize = meta.groupSize, - .src = reinterpret_cast(src), - .zp = reinterpret_cast(zp), - }); + static void applyZeroPointA( + dim_t b, dim_t m, dim_t n, + int8_t *dst, void const *src_, void const *zp_) { + thrust::tabulate(thrust::device, + dst, dst + b * m * n, + MatMulIntegerZPFunctorA{ + m, + n, + reinterpret_cast(src_), + reinterpret_cast(zp_), + }); + } + + template + struct 
MatMulIntegerZPFunctorB { + dim_t m, n; + T const *src, *zp; + + __device__ int8_t operator()(size_t idx) const noexcept { + auto + k = idx % n, + // j = idx / n % m, + i = idx / n / m; + return sub(src[idx], zp[i * n + k]); + } + }; + + template + static void applyZeroPointB( + dim_t b, dim_t m, dim_t n, + int8_t *dst, void const *src_, void const *zp_) { + + thrust::tabulate(thrust::device, + dst, dst + b * m * n, + MatMulIntegerZPFunctorB{ + m, + n, + reinterpret_cast(src_), + reinterpret_cast(zp_), + }); } auto MatMulIntegerCublas::lower(Resources &res) const noexcept -> RoutineWorkspace { size_t workspace = 0; if (info.a.withZeroPoint) { - workspace += info.a.groupCount * info.a.groupSize; + workspace += info.batch() * info.m * info.k; } if (info.b.withZeroPoint) { - workspace += info.b.groupCount * info.b.groupSize; + workspace += info.batch() * info.k * info.n; } auto routine = [info = info](Resources &res, void *workspace, void const *const *inputs, void *const *outputs) { @@ -50,19 +106,39 @@ namespace refactor::kernel { auto y = reinterpret_cast(outputs[0]); if (auto meta = info.a; meta.withZeroPoint) { - if (meta.signed_) { - applyZeroPoint(meta, workspacePtr, a, inputs[2]); + auto size = info.batch() * info.m * info.k; + auto zp = inputs[2]; + if (meta.scalar) { + if (meta.signed_) { + applyZeroPointScalar(size, workspacePtr, a, zp); + } else { + applyZeroPointScalar(size, workspacePtr, a, zp); + } } else { - applyZeroPoint(meta, workspacePtr, a, inputs[2]); + if (meta.signed_) { + applyZeroPointA(info.batch(), info.m, info.k, workspacePtr, a, zp); + } else { + applyZeroPointA(info.batch(), info.m, info.k, workspacePtr, a, zp); + } } a = workspacePtr; - workspacePtr += meta.groupCount * meta.groupSize; + workspacePtr += size; } if (auto meta = info.b; meta.withZeroPoint) { - if (meta.signed_) { - applyZeroPoint(meta, workspacePtr, b, inputs[3]); + auto size = info.batch() * info.k * info.n; + auto zp = inputs[3]; + if (meta.scalar) { + if (meta.signed_) { + applyZeroPointScalar(size, workspacePtr, b, zp); + } else { + applyZeroPointScalar(size, workspacePtr, b, zp); + } } else { - applyZeroPoint(meta, workspacePtr, b, inputs[3]); + if (meta.signed_) { + applyZeroPointA(info.batch(), info.k, info.n, workspacePtr, b, zp); + } else { + applyZeroPointA(info.batch(), info.k, info.n, workspacePtr, b, zp); + } } b = workspacePtr; } @@ -81,7 +157,7 @@ namespace refactor::kernel { if (info.broadcaster.needBroadcast()) { uint32_t offset[2]; - for (auto i : range0_(info.broadcaster.outputsCount)) { + for (auto i : range0_(info.batch())) { info.broadcaster.locate(i, offset); CUBLAS_ASSERT(cublasGemmEx( handle, @@ -103,7 +179,7 @@ namespace refactor::kernel { b, CUDA_R_8I, ldb, strideB, a, CUDA_R_8I, lda, strideA, &beta, y, CUDA_R_32I, n, - strideY, info.broadcaster.outputsCount, + strideY, info.batch(), CUDA_R_32I, CUBLAS_GEMM_DEFAULT)); } }; From d6b4952858280cccf76e4be2c48b783fc0cb760e Mon Sep 17 00:00:00 2001 From: YdrMaster Date: Mon, 18 Dec 2023 15:53:30 +0800 Subject: [PATCH 12/18] =?UTF-8?q?feat(kernel):=20=E5=AE=9E=E7=8E=B0=20Dyna?= =?UTF-8?q?micQuantizeLinear=20cuda=20kernel?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: YdrMaster --- .../dynamic_quantize_linear/cuda_kernel.cu | 124 +++++++++++++++++- 1 file changed, 122 insertions(+), 2 deletions(-) diff --git a/src/04kernel/src/kernels/dynamic_quantize_linear/cuda_kernel.cu b/src/04kernel/src/kernels/dynamic_quantize_linear/cuda_kernel.cu index d82c0d2a0..3070036a6 
100644 --- a/src/04kernel/src/kernels/dynamic_quantize_linear/cuda_kernel.cu +++ b/src/04kernel/src/kernels/dynamic_quantize_linear/cuda_kernel.cu @@ -1,16 +1,136 @@ #include "cuda_kernel.hh" +#include "hardware/functions.h" +#include "kernel/cuda/threads_distributer.cuh" #include +#include +#include namespace refactor::kernel { using K = DynamicQuantizeLinearCuda; + template + struct QuantizeMinMax { + T min, max; + }; + + template + struct QuantizeMapMinMaxFunctor { + __device__ __forceinline__ QuantizeMinMax + operator()(T x) const { + return {x, x}; + } + }; + + template + struct QuantizeReduceMinMaxFunctor { + __device__ __forceinline__ QuantizeMinMax + operator()(QuantizeMinMax a, QuantizeMinMax b) const { + return {a.min < b.min ? a.min : b.min, + a.max > b.max ? a.max : b.max}; + } + }; + + template + constexpr static auto + ZERO = static_cast(0); + + template + constexpr static auto + QMIN = static_cast(std::numeric_limits::min()); + + template + constexpr static auto + QMAX = static_cast(std::numeric_limits::max()); + + template + constexpr static auto + QLEN = QMAX - QMIN; + + template + __global__ static void kernel( + size_t n, + QuantizeMinMax const *__restrict__ minmax, + TI const *__restrict__ x, + TO *__restrict__ y, + TI *__restrict__ scale_, + TO *__restrict__ zp_) { + + auto const [min, max] = *minmax; + auto temp = QuantizeReduceMinMaxFunctor{}({min, max}, {ZERO, ZERO}); + auto scale = (temp.max - temp.min) / QLEN; + auto zp = static_cast(round(QMIN - min / scale)); + + auto tid = blockIdx.x * blockDim.x + threadIdx.x; + for (auto step = blockDim.x * gridDim.x; + tid < n; + tid += step) { + y[tid] = static_cast(std::round(x[tid] / scale) + zp); + } + switch (tid) { + case 0: + *scale_ = scale; + break; + case 1: + *zp_ = zp; + break; + } + } + auto K::lower(Resources &) const noexcept -> RoutineWorkspace { using namespace runtime; using TI = float; using TO = uint8_t; - return [size = size](Resources &, void *, void const *const *inputs, void *const *outputs) { - }; + constexpr static auto + _MIN = std::numeric_limits::min(), + _MAX = std::numeric_limits::max(); + + auto workspaceSize = hardware::alignBytes(size * sizeof(QuantizeMinMax), 256); + + QuantizeMinMax *nullTyped = nullptr; + size_t tempStorageBytes = 0; + cub::DeviceReduce::Reduce( + nullptr, tempStorageBytes, + nullTyped, nullTyped, 0, + QuantizeReduceMinMaxFunctor{}, + QuantizeMinMax{}); + + auto offset0 = workspaceSize; + workspaceSize += tempStorageBytes; + workspaceSize = hardware::alignBytes(workspaceSize, 256); + + auto offset1 = workspaceSize; + workspaceSize += sizeof(QuantizeMinMax); + + auto routine = [offset0, tempStorageBytes, offset1, + params = cuda::ThreadsDistributer()(size)]// + (Resources &, void *workspacePtr, void const *const *inputs, void *const *outputs) { + auto x = reinterpret_cast(inputs[0]); + auto y = reinterpret_cast(outputs[0]); + auto scale = reinterpret_cast(outputs[1]); + auto zp = reinterpret_cast(outputs[2]); + auto workspace = reinterpret_cast(workspacePtr); + auto doubled = reinterpret_cast *>(workspace); + auto tempStorage = workspace + offset0; + auto minmax = reinterpret_cast *>(workspace + offset1); + + thrust::transform( + thrust::device, + x, x + params.n, doubled, + QuantizeMapMinMaxFunctor{}); + + auto tempStorageSize_ = tempStorageBytes; + cub::DeviceReduce::Reduce( + tempStorage, tempStorageSize_, + doubled, minmax, params.n, + QuantizeReduceMinMaxFunctor{}, + QuantizeMinMax{_MIN, _MAX}); + + kernel<<>>( + params.n, minmax, x, y, scale, zp); + }; + + return 
{std::move(routine), workspaceSize}; } }// namespace refactor::kernel From ee14c96fe24fc2c4aefe30441fe9b45d8060ef11 Mon Sep 17 00:00:00 2001 From: YdrMaster Date: Mon, 18 Dec 2023 17:05:46 +0800 Subject: [PATCH 13/18] =?UTF-8?q?test(kernel):=20=E6=B5=8B=E8=AF=95=20Dyna?= =?UTF-8?q?micQuantizeLinear?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: YdrMaster --- .../dynamic_quantize_linear/cpu_kernel.cc | 12 +++- .../dynamic_quantize_linear/cuda_kernel.cu | 2 +- .../dynamic_quantize_linear/test_cpu.cpp | 31 ++++++++++ .../dynamic_quantize_linear/test_cuda.cpp | 56 +++++++++++++++++++ 4 files changed, 98 insertions(+), 3 deletions(-) create mode 100644 src/04kernel/test/kernels/dynamic_quantize_linear/test_cpu.cpp create mode 100644 src/04kernel/test/kernels/dynamic_quantize_linear/test_cuda.cpp diff --git a/src/04kernel/src/kernels/dynamic_quantize_linear/cpu_kernel.cc b/src/04kernel/src/kernels/dynamic_quantize_linear/cpu_kernel.cc index 77a7715f9..cc52f6bd4 100644 --- a/src/04kernel/src/kernels/dynamic_quantize_linear/cpu_kernel.cc +++ b/src/04kernel/src/kernels/dynamic_quantize_linear/cpu_kernel.cc @@ -22,6 +22,14 @@ namespace refactor::kernel { return "Performing dynamic quantize linear using CPU"; } + template + static TO saturate(TI x) { + constexpr static auto + QMIN = static_cast(std::numeric_limits::min()), + QMAX = static_cast(std::numeric_limits::max()); + return static_cast(std::round(std::clamp(x, QMIN, QMAX))); + } + auto K::lower(Resources &) const noexcept -> RoutineWorkspace { using namespace runtime; return [size = size](Resources &, void *, void const *const *inputs, void *const *outputs) { @@ -49,13 +57,13 @@ namespace refactor::kernel { }); auto len = std::max(ZERO, max) - std::min(ZERO, min); auto scale = len / QLEN; - auto zp = static_cast(std::round(QMIN - min * QLEN / len)); + auto zp = saturate(QMIN - min * QLEN / len); std::transform( std::execution::par_unseq, x, x + size, reinterpret_cast(outputs[0]), - [=](auto it) { return static_cast(std::round(it / scale) + zp); }); + [=](auto it) { return saturate(std::round(it / scale) + zp); }); *reinterpret_cast(outputs[1]) = scale; *reinterpret_cast(outputs[2]) = zp; }; diff --git a/src/04kernel/src/kernels/dynamic_quantize_linear/cuda_kernel.cu b/src/04kernel/src/kernels/dynamic_quantize_linear/cuda_kernel.cu index 3070036a6..4f4905912 100644 --- a/src/04kernel/src/kernels/dynamic_quantize_linear/cuda_kernel.cu +++ b/src/04kernel/src/kernels/dynamic_quantize_linear/cuda_kernel.cu @@ -124,7 +124,7 @@ namespace refactor::kernel { tempStorage, tempStorageSize_, doubled, minmax, params.n, QuantizeReduceMinMaxFunctor{}, - QuantizeMinMax{_MIN, _MAX}); + QuantizeMinMax{_MAX, _MIN}); kernel<<>>( params.n, minmax, x, y, scale, zp); diff --git a/src/04kernel/test/kernels/dynamic_quantize_linear/test_cpu.cpp b/src/04kernel/test/kernels/dynamic_quantize_linear/test_cpu.cpp new file mode 100644 index 000000000..4101e1895 --- /dev/null +++ b/src/04kernel/test/kernels/dynamic_quantize_linear/test_cpu.cpp @@ -0,0 +1,31 @@ +#include "../../../src/kernels/dynamic_quantize_linear/cpu_kernel.hh" +#include +#include + +using namespace refactor; +using namespace kernel; + +TEST(kernel, DynamicQuantizeLinearCpu) { + // build routine + auto kernel = DynamicQuantizeLinearCpu::build(6); + ASSERT_TRUE(kernel); + auto res = runtime::Resources(); + auto routine = kernel->lower(res).routine; + // put input data + std::vector x{0, 2, -3, -2.5, 1.34, 0.5}; + std::vector y(x.size()); + float 
scale; + uint8_t zeroPoint; + // inference + { + void const *inputs[]{x.data()}; + void *outputs[]{y.data(), &scale, &zeroPoint}; + routine(res, nullptr, inputs, outputs); + } + // check + EXPECT_FLOAT_EQ(scale, (2 + 3) / 255.f); + EXPECT_EQ(zeroPoint, 153); + for (auto i : range0_(y.size())) { + EXPECT_EQ(y[i], static_cast(std::round(x[i] / scale) + zeroPoint)); + } +} diff --git a/src/04kernel/test/kernels/dynamic_quantize_linear/test_cuda.cpp b/src/04kernel/test/kernels/dynamic_quantize_linear/test_cuda.cpp new file mode 100644 index 000000000..40f769f29 --- /dev/null +++ b/src/04kernel/test/kernels/dynamic_quantize_linear/test_cuda.cpp @@ -0,0 +1,56 @@ +#ifdef USE_CUDA + +#include "../../../src/kernels/dynamic_quantize_linear/cpu_kernel.hh" +#include "../../../src/kernels/dynamic_quantize_linear/cuda_kernel.hh" +#include "hardware/device_manager.h" +#include + +using namespace refactor; +using namespace kernel; +using namespace hardware; + +TEST(kernel, DynamicQuantizeLinearCuda) { + auto size = 20; + // build routine + auto kernel = DynamicQuantizeLinearCuda::build(size), + kCpu = DynamicQuantizeLinearCpu::build(size); + ASSERT_TRUE(kernel && kCpu); + auto res = runtime::Resources(); + auto [routine, workspaceSize] = kernel->lower(res); + auto rCpu = kCpu->lower(res).routine; + // malloc + auto &dev = *device::init(Device::Type::Nvidia, 0, ""); + auto xGpu = dev.malloc(size * sizeof(float)), + yGpu = dev.malloc(size * sizeof(uint8_t)), + scaleGpu = dev.malloc(sizeof(float)), + zpGpu = dev.malloc(sizeof(uint8_t)), + workspace = dev.malloc(workspaceSize); + // put input data + std::vector x(size); + std::vector y(size); + float scale; + uint8_t zeroPoint; + for (auto i : range0_(size)) { + x[i] = i * 3 + 15; + } + xGpu->copyFromHost(x.data()); + // inference + { + void const *inputs[]{*xGpu}; + void *outputs[]{*yGpu, *scaleGpu, *zpGpu}; + routine(res, *workspace, inputs, outputs); + } + { + void const *inputs[]{x.data()}; + void *outputs[]{y.data(), &scale, &zeroPoint}; + rCpu(res, nullptr, inputs, outputs); + } + // check + { + std::vector result(size); + yGpu->copyToHost(result.data()); + EXPECT_EQ(result, y); + } +} + +#endif From 2e75d3893c8ec8ad42af18f1c895a6a85db4179c Mon Sep 17 00:00:00 2001 From: YdrMaster Date: Tue, 19 Dec 2023 09:06:53 +0800 Subject: [PATCH 14/18] =?UTF-8?q?feat(kernel):=20=E5=AE=9E=E7=8E=B0?= =?UTF-8?q?=E5=8F=8D=E9=87=8F=E5=8C=96=E7=AE=97=E5=AD=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: YdrMaster --- .../src/collectors/dequantize_linear.cc | 9 ++ .../src/collectors/dynamic_quantize_linear.cc | 4 + .../kernels/dequantize_linear/cpu_kernel.cc | 76 +++++++++++++++ .../kernels/dequantize_linear/cpu_kernel.hh | 29 ++++++ .../kernels/dequantize_linear/cuda_kernel.cc | 92 +++++++++++++++++++ .../kernels/dequantize_linear/cuda_kernel.hh | 32 +++++++ .../dynamic_quantize_linear/cpu_kernel.hh | 2 +- .../kernels/dequantize_linear/test_cpu.cpp | 31 +++++++ .../kernels/dequantize_linear/test_cuda.cpp | 57 ++++++++++++ 9 files changed, 331 insertions(+), 1 deletion(-) create mode 100644 src/04kernel/src/kernels/dequantize_linear/cpu_kernel.cc create mode 100644 src/04kernel/src/kernels/dequantize_linear/cpu_kernel.hh create mode 100644 src/04kernel/src/kernels/dequantize_linear/cuda_kernel.cc create mode 100644 src/04kernel/src/kernels/dequantize_linear/cuda_kernel.hh create mode 100644 src/04kernel/test/kernels/dequantize_linear/test_cpu.cpp create mode 100644 
src/04kernel/test/kernels/dequantize_linear/test_cuda.cpp diff --git a/src/04kernel/src/collectors/dequantize_linear.cc b/src/04kernel/src/collectors/dequantize_linear.cc index d2eb5a696..78fc939a4 100644 --- a/src/04kernel/src/collectors/dequantize_linear.cc +++ b/src/04kernel/src/collectors/dequantize_linear.cc @@ -1,4 +1,6 @@ #include "kernel/collectors/dequantize_linear.h" +#include "../kernels/dequantize_linear/cpu_kernel.hh" +#include "../kernels/dequantize_linear/cuda_kernel.hh" namespace refactor::kernel { @@ -8,11 +10,18 @@ namespace refactor::kernel { std::vector DequantizeLinearCollector::filter(TensorRefs inputs, TensorRefs outputs) const { + auto const &output = outputs[0]; std::vector ans; switch (_target) { case decltype(_target)::Cpu: + if (auto ptr = DequantizeLinearCpu::build(inputs, output); ptr) { + ans.emplace_back(std::move(ptr)); + } break; case decltype(_target)::Nvidia: + if (auto ptr = DequantizeLinearCuda::build(inputs, output); ptr) { + ans.emplace_back(std::move(ptr)); + } break; default: UNREACHABLEX(void, "Unknown target"); diff --git a/src/04kernel/src/collectors/dynamic_quantize_linear.cc b/src/04kernel/src/collectors/dynamic_quantize_linear.cc index d4ba7c95f..8d54eee03 100644 --- a/src/04kernel/src/collectors/dynamic_quantize_linear.cc +++ b/src/04kernel/src/collectors/dynamic_quantize_linear.cc @@ -1,5 +1,6 @@ #include "kernel/collectors/dynamic_quantize_linear.h" #include "../kernels/dynamic_quantize_linear/cpu_kernel.hh" +#include "../kernels/dynamic_quantize_linear/cuda_kernel.hh" namespace refactor::kernel { @@ -19,6 +20,9 @@ namespace refactor::kernel { } break; case decltype(_target)::Nvidia: + if (auto ptr = DynamicQuantizeLinearCuda::build(size); ptr) { + ans.emplace_back(std::move(ptr)); + } break; default: UNREACHABLEX(void, "Unknown target"); diff --git a/src/04kernel/src/kernels/dequantize_linear/cpu_kernel.cc b/src/04kernel/src/kernels/dequantize_linear/cpu_kernel.cc new file mode 100644 index 000000000..7a53f533b --- /dev/null +++ b/src/04kernel/src/kernels/dequantize_linear/cpu_kernel.cc @@ -0,0 +1,76 @@ +#include "cpu_kernel.hh" +#include +#include + +namespace refactor::kernel { + using K = DequantizeLinearCpu; + + K::DequantizeLinearCpu( + decltype(from) from_, + decltype(size) size_, + decltype(withZeroPoint) withZeroPoint_) noexcept + : Kernel(), + from(from_), + size(size_), + withZeroPoint(withZeroPoint_) {} + + auto K::build(TensorRefs const &inputs, Tensor const &output) noexcept -> KernelBox { + if (inputs[1].get().elementsSize() != 1) { + return nullptr; + } + if (output.dataType != DataType::F32) { + return nullptr; + } + return std::make_unique( + inputs[0].get().dataType, + inputs[0].get().elementsSize(), + inputs.size() > 2); + } + + auto K::typeId() noexcept -> size_t { + static uint8_t ID = 1; + return reinterpret_cast(&ID); + } + + auto K::kernelTypeId() const noexcept -> size_t { return typeId(); } + auto K::description() const noexcept -> std::string_view { + return "Performing dequantize linear using CPU"; + } + + template + auto lowerTyped(size_t size, bool withZeroPoint) noexcept -> RoutineWorkspace { + + return [size, withZeroPoint]// + (Resources &, void *, void const *const *inputs, void *const *outputs) { + auto x = reinterpret_cast(inputs[0]); + auto scale = *reinterpret_cast(inputs[1]); + auto zp = withZeroPoint ? 
*reinterpret_cast(inputs[2]) : 0; + auto y = reinterpret_cast(outputs[0]); + std::transform( + std::execution::par_unseq, + x, x + size, + y, + [scale, zp](TI x) { + return static_cast(x - zp) * scale; + }); + }; + } + + auto K::lower(Resources &) const noexcept -> RoutineWorkspace { + switch (from) { + case DataType::U8: + return lowerTyped(size, withZeroPoint); + case DataType::U16: + return lowerTyped(size, withZeroPoint); + case DataType::I8: + return lowerTyped(size, withZeroPoint); + case DataType::I16: + return lowerTyped(size, withZeroPoint); + case DataType::I32: + return lowerTyped(size, withZeroPoint); + default: + UNREACHABLE(); + } + } + +}// namespace refactor::kernel diff --git a/src/04kernel/src/kernels/dequantize_linear/cpu_kernel.hh b/src/04kernel/src/kernels/dequantize_linear/cpu_kernel.hh new file mode 100644 index 000000000..977692fd2 --- /dev/null +++ b/src/04kernel/src/kernels/dequantize_linear/cpu_kernel.hh @@ -0,0 +1,29 @@ +#ifndef KERNEL_DEQUANTIZE_LINEAR_CPU_KERNEL_HH +#define KERNEL_DEQUANTIZE_LINEAR_CPU_KERNEL_HH + +#include "kernel/kernel.h" +#include "kernel/tensor.h" + +namespace refactor::kernel { + + struct DequantizeLinearCpu final : public Kernel { + DataType from; + size_t size; + bool withZeroPoint; + + DequantizeLinearCpu( + decltype(from), + decltype(size), + decltype(withZeroPoint)) noexcept; + + static KernelBox build(TensorRefs const &, Tensor const &) noexcept; + static size_t typeId() noexcept; + + size_t kernelTypeId() const noexcept final; + std::string_view description() const noexcept final; + RoutineWorkspace lower(Resources &) const noexcept final; + }; + +}// namespace refactor::kernel + +#endif// KERNEL_DEQUANTIZE_LINEAR_CPU_KERNEL_HH diff --git a/src/04kernel/src/kernels/dequantize_linear/cuda_kernel.cc b/src/04kernel/src/kernels/dequantize_linear/cuda_kernel.cc new file mode 100644 index 000000000..1cdc1c45b --- /dev/null +++ b/src/04kernel/src/kernels/dequantize_linear/cuda_kernel.cc @@ -0,0 +1,92 @@ +#include "cuda_kernel.hh" + +#ifdef USE_CUDA +#include "../../generator/nvrtc_repo.h" +#include "kernel/cuda/threads_distributer.cuh" +#endif + +namespace refactor::kernel { + using K = DequantizeLinearCuda; + + K::DequantizeLinearCuda( + decltype(from) from_, + decltype(to) to_, + decltype(size) size_, + decltype(withZeroPoint) withZeroPoint_) noexcept + : Kernel(), + from(from_), + to(to_), + size(size_), + withZeroPoint(withZeroPoint_) {} + + auto K::build(TensorRefs const &inputs, Tensor const &output) noexcept -> KernelBox { +#ifndef USE_CUDA + return nullptr; +#endif + + auto const &x = inputs[0].get(); + if (inputs[1].get().elementsSize() != 1) { + return nullptr; + } + return std::make_unique( + x.dataType, + output.dataType, + x.elementsSize(), + inputs.size() > 2); + } + + auto K::typeId() noexcept -> size_t { + static uint8_t ID = 1; + return reinterpret_cast(&ID); + } + + auto K::kernelTypeId() const noexcept -> size_t { return typeId(); } + auto K::description() const noexcept -> std::string_view { + return "Performing dequantize linear using Nvidia GPU"; + } + +#ifdef USE_CUDA + + constexpr static const char *TEMPLATE = R"~( +extern "C" __global__ void kernel( + {0:} *__restrict__ y, + {1:} const *__restrict__ x, + {0:} const *__restrict__ scale_, + {1:} const *__restrict__ zp_, + size_t n +) {{ + auto zp = zp_ ? 
*zp_ : static_cast<{1:}>(0); + auto scale = *scale_; + for (auto tid = blockIdx.x * blockDim.x + threadIdx.x, + step = blockDim.x * gridDim.x; + tid < n; + tid += step) {{ + y[tid] = static_cast<{0:}>(x[tid] - zp) * scale; + }} +}} +)~"; + + auto K::lower(Resources &res) const -> RoutineWorkspace { + using namespace runtime; + + auto name = fmt::format("DequantizeLinear{}->{}", from.name(), to.name()); + auto code = fmt::format(TEMPLATE, nvrtc::dataType(to), nvrtc::dataType(from)); + return [withZeroPoint = withZeroPoint, + params = cuda::ThreadsDistributer()(size), + h = nvrtc::Handler::compile(name.c_str(), code.c_str(), "kernel")]// + (Resources &, void *, void const *const *inputs, void *const *outputs) { + auto y = outputs[0]; + auto x = inputs[0], + scale = inputs[1], + zp = withZeroPoint ? inputs[2] : nullptr; + auto n = params.n; + void *args[]{&y, &x, &scale, &zp, &n}; + h->launch(params.gridSize, 1, 1, + params.blockSize, 1, 1, + 0, args); + }; + } + +#endif + +}// namespace refactor::kernel diff --git a/src/04kernel/src/kernels/dequantize_linear/cuda_kernel.hh b/src/04kernel/src/kernels/dequantize_linear/cuda_kernel.hh new file mode 100644 index 000000000..143552fd0 --- /dev/null +++ b/src/04kernel/src/kernels/dequantize_linear/cuda_kernel.hh @@ -0,0 +1,32 @@ +#ifndef KERNEL_DEQUANTIZE_LINEAR_CUDA_KERNEL_HH +#define KERNEL_DEQUANTIZE_LINEAR_CUDA_KERNEL_HH + +#include "kernel/kernel.h" +#include "kernel/tensor.h" + +namespace refactor::kernel { + + struct DequantizeLinearCuda final : public Kernel { + DataType from, to; + size_t size; + bool withZeroPoint; + + DequantizeLinearCuda( + decltype(from), + decltype(to), + decltype(size), + decltype(withZeroPoint)) noexcept; + + static KernelBox build(TensorRefs const &, Tensor const &) noexcept; + static size_t typeId() noexcept; + + size_t kernelTypeId() const noexcept final; + std::string_view description() const noexcept final; +#ifdef USE_CUDA + RoutineWorkspace lower(Resources &) const final; +#endif + }; + +}// namespace refactor::kernel + +#endif// KERNEL_DEQUANTIZE_LINEAR_CUDA_KERNEL_HH diff --git a/src/04kernel/src/kernels/dynamic_quantize_linear/cpu_kernel.hh b/src/04kernel/src/kernels/dynamic_quantize_linear/cpu_kernel.hh index 97020832f..1883179b0 100644 --- a/src/04kernel/src/kernels/dynamic_quantize_linear/cpu_kernel.hh +++ b/src/04kernel/src/kernels/dynamic_quantize_linear/cpu_kernel.hh @@ -20,4 +20,4 @@ namespace refactor::kernel { }// namespace refactor::kernel -#endif// KERNEL_SOFTMAX_CPU_KERNEL_HH +#endif// KERNEL_DYNAMIC_QUANTIZE_LINEAR_CPU_KERNEL_HH diff --git a/src/04kernel/test/kernels/dequantize_linear/test_cpu.cpp b/src/04kernel/test/kernels/dequantize_linear/test_cpu.cpp new file mode 100644 index 000000000..831dda424 --- /dev/null +++ b/src/04kernel/test/kernels/dequantize_linear/test_cpu.cpp @@ -0,0 +1,31 @@ +#include "../../../src/kernels/dequantize_linear/cpu_kernel.hh" +#include +#include + +using namespace refactor; +using namespace kernel; + +TEST(kernel, DequantizeLinearCpu) { + // build routine + auto x = Tensor::share(DataType::U8, {4}); + auto scale = Tensor::share(DataType::F32, {}); + auto zeroPoint = Tensor::share(DataType::U8, {}); + auto y = Tensor::share(DataType::F32, {4}); + auto kernel = DequantizeLinearCpu::build({*x, *scale, *zeroPoint}, *y); + ASSERT_TRUE(kernel); + auto res = runtime::Resources(); + auto routine = kernel->lower(res).routine; + // put input data + std::vector xData{0, 3, 128, 255}; + float scale_ = 2; + uint8_t zp_ = 128; + std::vector yData(xData.size()); + // 
inference + { + void const *inputs[]{xData.data(), &scale_, &zp_}; + void *outputs[]{yData.data()}; + routine(res, nullptr, inputs, outputs); + } + // check + ASSERT_EQ(yData, (decltype(yData){-256, -250, 0, 254})); +} diff --git a/src/04kernel/test/kernels/dequantize_linear/test_cuda.cpp b/src/04kernel/test/kernels/dequantize_linear/test_cuda.cpp new file mode 100644 index 000000000..87ec6fd95 --- /dev/null +++ b/src/04kernel/test/kernels/dequantize_linear/test_cuda.cpp @@ -0,0 +1,57 @@ +#ifdef USE_CUDA + +#include "../../../src/kernels/dequantize_linear/cpu_kernel.hh" +#include "../../../src/kernels/dequantize_linear/cuda_kernel.hh" +#include "hardware/device_manager.h" +#include + +using namespace refactor; +using namespace kernel; +using namespace hardware; + +TEST(kernel, DequantizeLinearCuda) { + // build routine + auto x = Tensor::share(DataType::U8, {4}); + auto scale = Tensor::share(DataType::F32, {}); + auto zeroPoint = Tensor::share(DataType::U8, {}); + auto y = Tensor::share(DataType::F32, {4}); + auto kernel = DequantizeLinearCuda::build({*x, *scale, *zeroPoint}, *y), + kCpu = DequantizeLinearCpu::build({*x, *scale, *zeroPoint}, *y); + ASSERT_TRUE(kernel && kCpu); + auto res = runtime::Resources(); + auto [routine, workspaceSize] = kernel->lower(res); + auto rCpu = kCpu->lower(res).routine; + // malloc + auto &dev = *device::init(Device::Type::Nvidia, 0, ""); + auto xGpu = dev.malloc(x->bytesSize()), + scaleGpu = dev.malloc(sizeof(float)), + zpGpu = dev.malloc(sizeof(uint8_t)), + yGpu = dev.malloc(y->bytesSize()); + // put input data + std::vector xData{0, 3, 128, 255}; + float scale_ = 2; + uint8_t zp_ = 128; + std::vector yData(xData.size()); + xGpu->copyFromHost(xData.data()); + scaleGpu->copyFromHost(&scale_); + zpGpu->copyFromHost(&zp_); + // inference + { + void const *inputs[]{*xGpu, *scaleGpu, *zpGpu}; + void *outputs[]{*yGpu}; + routine(res, nullptr, inputs, outputs); + } + { + void const *inputs[]{xData.data(), &scale_, &zp_}; + void *outputs[]{yData.data()}; + rCpu(res, nullptr, inputs, outputs); + } + // check + { + std::vector result(yData.size()); + yGpu->copyToHost(result.data()); + EXPECT_EQ(result, yData); + } +} + +#endif From 5d7885574429ac3cf0cff6e5d58b7c9282bbdb16 Mon Sep 17 00:00:00 2001 From: YdrMaster Date: Tue, 19 Dec 2023 12:55:23 +0800 Subject: [PATCH 15/18] =?UTF-8?q?fix(kernel):=20=E4=BF=AE=E6=AD=A3=20dynam?= =?UTF-8?q?ic=20quantize=20linear=20=E9=94=99=E8=AF=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: YdrMaster --- .../dynamic_quantize_linear/cuda_kernel.cu | 31 ++++++++++--------- .../dynamic_quantize_linear/test_cuda.cpp | 6 ++++ 2 files changed, 23 insertions(+), 14 deletions(-) diff --git a/src/04kernel/src/kernels/dynamic_quantize_linear/cuda_kernel.cu b/src/04kernel/src/kernels/dynamic_quantize_linear/cuda_kernel.cu index 4f4905912..0936ac7ad 100644 --- a/src/04kernel/src/kernels/dynamic_quantize_linear/cuda_kernel.cu +++ b/src/04kernel/src/kernels/dynamic_quantize_linear/cuda_kernel.cu @@ -61,18 +61,15 @@ namespace refactor::kernel { auto zp = static_cast(round(QMIN - min / scale)); auto tid = blockIdx.x * blockDim.x + threadIdx.x; - for (auto step = blockDim.x * gridDim.x; - tid < n; - tid += step) { - y[tid] = static_cast(std::round(x[tid] / scale) + zp); + for (auto step = blockDim.x * gridDim.x, + i = tid; + i < n; + i += step) { + y[i] = static_cast(std::round(x[i] / scale) + zp); } - switch (tid) { - case 0: - *scale_ = scale; - break; - case 1: - *zp_ = zp; - break; + 
if (tid == 0) { + *scale_ = scale; + *zp_ = zp; } } @@ -89,11 +86,14 @@ namespace refactor::kernel { QuantizeMinMax *nullTyped = nullptr; size_t tempStorageBytes = 0; - cub::DeviceReduce::Reduce( + auto e = cub::DeviceReduce::Reduce( nullptr, tempStorageBytes, - nullTyped, nullTyped, 0, + nullTyped, nullTyped, size, QuantizeReduceMinMaxFunctor{}, QuantizeMinMax{}); + if (e != cudaSuccess) { + RUNTIME_ERROR(fmt::format("error: {} {}", (int) e, cudaGetErrorString(e))); + } auto offset0 = workspaceSize; workspaceSize += tempStorageBytes; @@ -120,11 +120,14 @@ namespace refactor::kernel { QuantizeMapMinMaxFunctor{}); auto tempStorageSize_ = tempStorageBytes; - cub::DeviceReduce::Reduce( + auto e = cub::DeviceReduce::Reduce( tempStorage, tempStorageSize_, doubled, minmax, params.n, QuantizeReduceMinMaxFunctor{}, QuantizeMinMax{_MAX, _MIN}); + if (e != cudaSuccess) { + RUNTIME_ERROR(fmt::format("error: {} {}", (int) e, cudaGetErrorString(e))); + } kernel<<>>( params.n, minmax, x, y, scale, zp); diff --git a/src/04kernel/test/kernels/dynamic_quantize_linear/test_cuda.cpp b/src/04kernel/test/kernels/dynamic_quantize_linear/test_cuda.cpp index 40f769f29..2d99f6476 100644 --- a/src/04kernel/test/kernels/dynamic_quantize_linear/test_cuda.cpp +++ b/src/04kernel/test/kernels/dynamic_quantize_linear/test_cuda.cpp @@ -50,6 +50,12 @@ TEST(kernel, DynamicQuantizeLinearCuda) { std::vector result(size); yGpu->copyToHost(result.data()); EXPECT_EQ(result, y); + float scale_; + scaleGpu->copyToHost(&scale_); + EXPECT_EQ(scale_, scale); + uint8_t zp_; + zpGpu->copyToHost(&zp_); + EXPECT_EQ(zp_, zeroPoint); } } From bed3627234a56f2206e308652f26fd47c175c0ba Mon Sep 17 00:00:00 2001 From: YdrMaster Date: Tue, 19 Dec 2023 18:19:31 +0800 Subject: [PATCH 16/18] =?UTF-8?q?style(kernel):=20=E5=80=9F=E5=8A=A9=20cub?= =?UTF-8?q?=20=E5=9F=BA=E7=A1=80=E8=AE=BE=E6=96=BD=E7=AE=80=E5=8C=96?= =?UTF-8?q?=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: YdrMaster --- .../dynamic_quantize_linear/cuda_kernel.cu | 13 ++++++------- .../dynamic_quantize_linear/cuda_kernel.hh | 2 +- .../src/kernels/softmax/cuda_kernel.cu | 19 +++---------------- 3 files changed, 10 insertions(+), 24 deletions(-) diff --git a/src/04kernel/src/kernels/dynamic_quantize_linear/cuda_kernel.cu b/src/04kernel/src/kernels/dynamic_quantize_linear/cuda_kernel.cu index 0936ac7ad..71b2ce31e 100644 --- a/src/04kernel/src/kernels/dynamic_quantize_linear/cuda_kernel.cu +++ b/src/04kernel/src/kernels/dynamic_quantize_linear/cuda_kernel.cu @@ -16,7 +16,7 @@ namespace refactor::kernel { template struct QuantizeMapMinMaxFunctor { __device__ __forceinline__ QuantizeMinMax - operator()(T x) const { + operator()(T x) const noexcept { return {x, x}; } }; @@ -24,9 +24,8 @@ namespace refactor::kernel { template struct QuantizeReduceMinMaxFunctor { __device__ __forceinline__ QuantizeMinMax - operator()(QuantizeMinMax a, QuantizeMinMax b) const { - return {a.min < b.min ? a.min : b.min, - a.max > b.max ? 
a.max : b.max}; + operator()(QuantizeMinMax a, QuantizeMinMax b) const noexcept { + return {CUB_MIN(a.min, b.min), CUB_MAX(a.max, b.max)}; } }; @@ -56,8 +55,8 @@ namespace refactor::kernel { TO *__restrict__ zp_) { auto const [min, max] = *minmax; - auto temp = QuantizeReduceMinMaxFunctor{}({min, max}, {ZERO, ZERO}); - auto scale = (temp.max - temp.min) / QLEN; + auto cover0 = QuantizeReduceMinMaxFunctor{}({min, max}, {ZERO, ZERO}); + auto scale = (cover0.max - cover0.min) / QLEN; auto zp = static_cast(round(QMIN - min / scale)); auto tid = blockIdx.x * blockDim.x + threadIdx.x; @@ -73,7 +72,7 @@ namespace refactor::kernel { } } - auto K::lower(Resources &) const noexcept -> RoutineWorkspace { + auto K::lower(Resources &) const -> RoutineWorkspace { using namespace runtime; using TI = float; using TO = uint8_t; diff --git a/src/04kernel/src/kernels/dynamic_quantize_linear/cuda_kernel.hh b/src/04kernel/src/kernels/dynamic_quantize_linear/cuda_kernel.hh index b0151b875..d027d751e 100644 --- a/src/04kernel/src/kernels/dynamic_quantize_linear/cuda_kernel.hh +++ b/src/04kernel/src/kernels/dynamic_quantize_linear/cuda_kernel.hh @@ -16,7 +16,7 @@ namespace refactor::kernel { size_t kernelTypeId() const noexcept final; std::string_view description() const noexcept final; #ifdef USE_CUDA - RoutineWorkspace lower(Resources &) const noexcept final; + RoutineWorkspace lower(Resources &) const final; #endif }; diff --git a/src/04kernel/src/kernels/softmax/cuda_kernel.cu b/src/04kernel/src/kernels/softmax/cuda_kernel.cu index 7c65ea029..114cb453a 100644 --- a/src/04kernel/src/kernels/softmax/cuda_kernel.cu +++ b/src/04kernel/src/kernels/softmax/cuda_kernel.cu @@ -4,9 +4,6 @@ namespace refactor::kernel { using namespace runtime; - template - __device__ __forceinline__ T max_(T a, T b) { return a > b ? 
a : b; } - template __device__ __forceinline__ T exp_(T x); template<> __device__ __forceinline__ float exp_(float x) { return expf(x); } @@ -58,16 +55,6 @@ namespace refactor::kernel { } } - template struct SumOp { - __device__ __forceinline__ T operator()(T const &a, T const &b) const { - return a + b; - } - }; - template struct MaxOp { - __device__ __forceinline__ T operator()(T const &a, T const &b) const { - return max_(a, b); - } - }; template __device__ __forceinline__ T WarpAllReduce(T val, ReductionOp op) { for (int mask = blockDim.x >> 1; mask > 0; mask >>= 1) { @@ -92,9 +79,9 @@ namespace refactor::kernel { T maxData = -__FLT_MAX__; for (int i = threadIdx.x; i < dimsize; i += blockDim.x) { - maxData = max_(maxData, input[tid + i * stride]); + maxData = CUB_MAX(maxData, input[tid + i * stride]); } - maxData = WarpAllReduce(maxData, MaxOp{}); + maxData = WarpAllReduce(maxData, cub::Max()); if (threadIdx.x == 0) { maxTotal[threadIdx.y] = maxData; } @@ -104,7 +91,7 @@ namespace refactor::kernel { for (int i = threadIdx.x; i < dimsize; i += blockDim.x) { sumData += exp_(input[tid + i * stride] - maxTotal[threadIdx.y]); } - sumData = WarpAllReduce(sumData, SumOp{}); + sumData = WarpAllReduce(sumData, cub::Sum()); if (threadIdx.x == 0) { sumTotal[threadIdx.y] = sumData; } From 416cd2ef64370047b1192a2cd38c6f495c078a1c Mon Sep 17 00:00:00 2001 From: YdrMaster Date: Mon, 25 Dec 2023 10:24:57 +0800 Subject: [PATCH 17/18] =?UTF-8?q?fix(kernel):=20=E7=A8=8D=E5=BE=AE?= =?UTF-8?q?=E8=B0=83=E6=95=B4=20MatMulInteger=20=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: YdrMaster --- scripts/compare/compare.py | 18 +++++---- .../kernels/mat_mul_integer/cublas_kernel.cu | 37 ++++++++----------- 2 files changed, 26 insertions(+), 29 deletions(-) diff --git a/scripts/compare/compare.py b/scripts/compare/compare.py index c235fd1bf..3f8b0f766 100644 --- a/scripts/compare/compare.py +++ b/scripts/compare/compare.py @@ -23,6 +23,7 @@ def parse_args(): args.actual, ) + def getDiff(base, test): absolute_diff = np.subtract(base, test) max_absolute_diff = np.max(np.abs(absolute_diff)) @@ -35,16 +36,19 @@ def getDiff(base, test): return max_absolute_diff, max_relative_diff -def compare_npy(actual_path, expect_path, edge, node): + +def compare_npy(node, actual_path, expect_path): actual = np.load(actual_path) expect = np.load(expect_path) if np.isnan(actual).any(): - print(f"NAN value in node:{node} edge:{edge}") + print(f"NAN value in node:{node}\t{actual_path}\t{expect_path}") return - + max_absolute_diff, max_relative_diff = getDiff(expect, actual) - if max_absolute_diff != 0.0: ## No need to print tensor with no diff - print(f'{max_absolute_diff}\t{max_relative_diff}\t{node}\t{edge}') + if max_absolute_diff != 0.0: ## No need to print tensor with no diff + print( + f"{max_absolute_diff}\t{max_relative_diff}\t{node}\t{actual_path}\t{expect_path}" + ) def main(): @@ -70,9 +74,7 @@ def main(): expect_file = expect_file + ".npy" expect_file_path = os.path.join(expect_dir, expect_file) if os.path.exists(expect_file_path): - compare_npy( - actual_file_path, expect_file_path, edge_name, node_name - ) + compare_npy(meta_file, actual_file_path, expect_file_path) if __name__ == "__main__": diff --git a/src/04kernel/src/kernels/mat_mul_integer/cublas_kernel.cu b/src/04kernel/src/kernels/mat_mul_integer/cublas_kernel.cu index 36383ea9c..ccff7001c 100644 --- a/src/04kernel/src/kernels/mat_mul_integer/cublas_kernel.cu +++ 
b/src/04kernel/src/kernels/mat_mul_integer/cublas_kernel.cu @@ -10,7 +10,10 @@ namespace refactor::kernel { template __device__ __forceinline__ static int8_t sub(T, T); template<> __device__ __forceinline__ int8_t sub(int8_t a, int8_t b) { return a - b; } - template<> __device__ __forceinline__ int8_t sub(uint8_t a, uint8_t b) { return static_cast(static_cast(a) - static_cast(b)); } + template<> __device__ __forceinline__ int8_t sub(uint8_t a, uint8_t b) { + constexpr static int16_t MAX = 127; + return static_cast(CUB_MIN(MAX, static_cast(a) - static_cast(b))); + } template struct MatMulIntegerZPFunctorScalar { @@ -33,16 +36,16 @@ namespace refactor::kernel { } template - struct MatMulIntegerZPFunctorA { - dim_t m, n; + struct MatMulIntegerZPFunctor { + dim_t m, n, a, b, c; T const *src, *zp; __device__ int8_t operator()(size_t idx) const noexcept { auto - // k = idx % n, + k = idx % n, j = idx / n % m, i = idx / n / m; - return sub(src[idx], zp[i * m + j]); + return sub(src[idx], zp[i * a + j * b + k * c]); } }; @@ -52,28 +55,17 @@ namespace refactor::kernel { int8_t *dst, void const *src_, void const *zp_) { thrust::tabulate(thrust::device, dst, dst + b * m * n, - MatMulIntegerZPFunctorA{ + MatMulIntegerZPFunctor{ m, n, + m, + 1, + 0, reinterpret_cast(src_), reinterpret_cast(zp_), }); } - template - struct MatMulIntegerZPFunctorB { - dim_t m, n; - T const *src, *zp; - - __device__ int8_t operator()(size_t idx) const noexcept { - auto - k = idx % n, - // j = idx / n % m, - i = idx / n / m; - return sub(src[idx], zp[i * n + k]); - } - }; - template static void applyZeroPointB( dim_t b, dim_t m, dim_t n, @@ -81,9 +73,12 @@ namespace refactor::kernel { thrust::tabulate(thrust::device, dst, dst + b * m * n, - MatMulIntegerZPFunctorB{ + MatMulIntegerZPFunctor{ m, n, + n, + 0, + 1, reinterpret_cast(src_), reinterpret_cast(zp_), }); From 25d0c443ca7ec75804adb2c488421f76b207ce98 Mon Sep 17 00:00:00 2001 From: zhangyunze Date: Wed, 27 Dec 2023 14:47:56 +0800 Subject: [PATCH 18/18] fix: add flatten front support --- src/07onnx/src/operators.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/07onnx/src/operators.cpp b/src/07onnx/src/operators.cpp index 2fbfe8b03..6e657f272 100644 --- a/src/07onnx/src/operators.cpp +++ b/src/07onnx/src/operators.cpp @@ -12,6 +12,7 @@ #include "operators/dynamic_quantize_linear.hh" #include "operators/einsum.hh" #include "operators/expand.hh" +#include "operators/flatten.hh" #include "operators/gather.hh" #include "operators/gather_elements.hh" #include "operators/gemm.hh" @@ -81,6 +82,7 @@ namespace refactor::onnx { REGISTER(ReduceSum , Reduce ); REGISTER(ReduceSumSquare , Reduce ); REGISTER(Reshape , Reshape ); + REGISTER(Flatten , Flatten ); REGISTER(ScatterND , ScatterND ); REGISTER(Max , Select ); REGISTER(Min , Select );