Skip to content
Merged

Dev #59

Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
8c1b8fa
feat(onnx): 添加 onnx MatMulInteger 前端算子
YdrMaster Dec 14, 2023
4bbf121
style(onnx): 整理 MatMul cpu kernel
YdrMaster Dec 14, 2023
7a37779
feat(onnx): 前端支持 MatMulInteger 带有 4 个输入
YdrMaster Dec 15, 2023
6f11997
feat(computation): MatMulInteger 从 MatMul 分离
YdrMaster Dec 15, 2023
33ff398
refactor(kernel): 为 Broadcaster 表示不需要广播明确语义
YdrMaster Dec 15, 2023
6462a27
feat(kernel): 实现 MatMulInteger cpu kernel
YdrMaster Dec 15, 2023
916fd3d
test(kernel): 测试 MatMulInteger cpu kernel
YdrMaster Dec 18, 2023
c74b588
feat(kernel): 实现 MatMulInteger cublas kernel
YdrMaster Dec 18, 2023
55cd886
test(kernel): 测试 MatMulInteger cublas kernel
YdrMaster Dec 18, 2023
3d14bf0
feat: 添加逐张量的量化和反量化算子
YdrMaster Dec 18, 2023
103254b
feat(kernel): 添加逐张量量化的 cpu kernel
YdrMaster Dec 18, 2023
d6b4952
feat(kernel): 实现 DynamicQuantizeLinear cuda kernel
YdrMaster Dec 18, 2023
ee14c96
test(kernel): 测试 DynamicQuantizeLinear
YdrMaster Dec 18, 2023
2e75d38
feat(kernel): 实现反量化算子
YdrMaster Dec 19, 2023
5d78855
fix(kernel): 修正 dynamic quantize linear 错误
YdrMaster Dec 19, 2023
bed3627
style(kernel): 借助 cub 基础设施简化代码
YdrMaster Dec 19, 2023
416cd2e
fix(kernel): 稍微调整 MatMulInteger 逻辑
YdrMaster Dec 25, 2023
25d0c44
fix: add flatten front support
bitzyz Dec 27, 2023
d0c4692
Merge pull request #60 from InfiniTensor/dev_flatten
YdrMaster Dec 27, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions scripts/compare/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def parse_args():
args.actual,
)


def getDiff(base, test):
absolute_diff = np.subtract(base, test)
max_absolute_diff = np.max(np.abs(absolute_diff))
Expand All @@ -35,16 +36,19 @@ def getDiff(base, test):

return max_absolute_diff, max_relative_diff

def compare_npy(actual_path, expect_path, edge, node):

def compare_npy(node, actual_path, expect_path):
actual = np.load(actual_path)
expect = np.load(expect_path)
if np.isnan(actual).any():
print(f"NAN value in node:{node} edge:{edge}")
print(f"NAN value in node:{node}\t{actual_path}\t{expect_path}")
return

max_absolute_diff, max_relative_diff = getDiff(expect, actual)
if max_absolute_diff != 0.0: ## No need to print tensor with no diff
print(f'{max_absolute_diff}\t{max_relative_diff}\t{node}\t{edge}')
if max_absolute_diff != 0.0: ## No need to print tensor with no diff
print(
f"{max_absolute_diff}\t{max_relative_diff}\t{node}\t{actual_path}\t{expect_path}"
)


def main():
Expand All @@ -70,9 +74,7 @@ def main():
expect_file = expect_file + ".npy"
expect_file_path = os.path.join(expect_dir, expect_file)
if os.path.exists(expect_file_path):
compare_npy(
actual_file_path, expect_file_path, edge_name, node_name
)
compare_npy(meta_file, actual_file_path, expect_file_path)


if __name__ == "__main__":
Expand Down
1 change: 1 addition & 0 deletions src/04kernel/include/kernel/attributes/broadcaster.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ namespace refactor::kernel {
explicit Broadcaster(std::vector<slice_t<dim_t>>);
explicit Broadcaster(TensorRefs const &inputs);
void locate(dim_t k, dim_t ans[]) const noexcept;
bool needBroadcast() const noexcept;
};

}// namespace refactor::kernel
Expand Down
Original file line number Diff line number Diff line change
@@ -1,21 +1,20 @@
#ifndef KERNEL_MATMUL_INFO_H
#define KERNEL_MATMUL_INFO_H
#ifndef KERNEL_MAT_MUL_INFO_H
#define KERNEL_MAT_MUL_INFO_H

#include "kernel/attributes/broadcaster.h"
#include "kernel/attributes/expand_info.h"
#include <variant>

namespace refactor::kernel {

struct MatMulInfo {
DataType dataType;
float alpha, beta;
bool transA, transB;
size_t m, k, n;
dim_t m, k, n;
// Expand operation info for biasd
std::optional<ExpandInfo> biasExpand;
// A constant batch or a 2-directional broadcaster that deals with dimensions before the last 2 dimensions
std::variant<Broadcaster, size_t> broadcasterOrBatch;
// A 2-directional broadcaster that deals with dimensions before the last 2 dimensions
Broadcaster broadcaster;

MatMulInfo(Tensor const &, Tensor const &,
std::optional<std::reference_wrapper<Tensor const>>,
Expand All @@ -24,4 +23,4 @@ namespace refactor::kernel {

}// namespace refactor::kernel

#endif// KERNEL_MATMUL_INFO_H
#endif// KERNEL_MAT_MUL_INFO_H
28 changes: 28 additions & 0 deletions src/04kernel/include/kernel/attributes/mat_mul_integer_info.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#ifndef KERNEL_MAT_MUL_INTEGER_INFO_H
#define KERNEL_MAT_MUL_INTEGER_INFO_H

#include "kernel/attributes/broadcaster.h"

namespace refactor::kernel {

    /// Shape and quantization metadata describing one MatMulInteger operation.
    struct MatMulIntegerInfo {
        /// Quantization properties of one operand (A or B).
        struct Input {
            bool
                withZeroPoint,// a zero-point tensor is present and not constant all-zero
                signed_,      // element type of the zero point is signed (DataType::I8)
                scalar;       // the zero point is a single value (elementsSize() == 1)

            /// Inspects inputs[i + 2], the optional zero point of operand i
            /// (0 = A, 1 = B), when the operator carries 3 or 4 inputs.
            Input(TensorRefs const &, size_t i) noexcept;
        };

        Input a, b;
        // GEMM dimensions taken from the trailing two axes: A is (m, k), B is (k, n).
        dim_t m, k, n;
        // Broadcasts the leading (batch) dimensions of A and B.
        Broadcaster broadcaster;

        explicit MatMulIntegerInfo(TensorRefs const &inputs) noexcept;
        /// Number of (m, k) x (k, n) products implied by broadcasting.
        dim_t batch() const noexcept;
    };

}// namespace refactor::kernel

#endif// KERNEL_MAT_MUL_INTEGER_INFO_H
18 changes: 18 additions & 0 deletions src/04kernel/include/kernel/collectors/dequantize_linear.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#ifndef KERNEL_DEQUANTIZE_LINEAR_H
#define KERNEL_DEQUANTIZE_LINEAR_H

#include "../collector.h"

namespace refactor::kernel {

    /// Collects the kernel implementations able to execute the
    /// DequantizeLinear operator on the configured target backend.
    struct DequantizeLinearCollector final : public InfoCollector {

        explicit DequantizeLinearCollector(decltype(_target)) noexcept;

        /// Returns candidate kernels for the given input/output tensors;
        /// empty if no kernel on this target accepts them.
        std::vector<KernelBox>
        filter(TensorRefs inputs, TensorRefs outputs) const final;
    };

}// namespace refactor::kernel

#endif// KERNEL_DEQUANTIZE_LINEAR_H
18 changes: 18 additions & 0 deletions src/04kernel/include/kernel/collectors/dynamic_quantize_linear.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#ifndef KERNEL_DYNAMIC_QUANTIZE_LINEAR_H
#define KERNEL_DYNAMIC_QUANTIZE_LINEAR_H

#include "../collector.h"

namespace refactor::kernel {

struct DynamicQuantizeLinearCollector final : public InfoCollector {

explicit DynamicQuantizeLinearCollector(decltype(_target)) noexcept;

std::vector<KernelBox>
filter(TensorRefs inputs, TensorRefs outputs) const final;
};

}// namespace refactor::kernel

#endif// KERNEL_DYNAMIC_QUANTIZE_LINEAR_H
19 changes: 19 additions & 0 deletions src/04kernel/include/kernel/collectors/mat_mul_integer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#ifndef KERNEL_MAT_MUL_INTEGER_H
#define KERNEL_MAT_MUL_INTEGER_H

#include "../collector.h"

namespace refactor::kernel {

    /// Collects the kernel implementations able to execute the
    /// MatMulInteger operator on the configured target backend.
    struct MatMulIntegerCollector final : public InfoCollector {

        constexpr MatMulIntegerCollector(decltype(_target) target) noexcept
            : InfoCollector(target) {}

        /// Returns candidate kernels for the given input/output tensors;
        /// empty if no kernel on this target accepts them.
        std::vector<KernelBox>
        filter(TensorRefs inputs, TensorRefs outputs) const final;
    };

}// namespace refactor::kernel

#endif// KERNEL_MAT_MUL_INTEGER_H
4 changes: 4 additions & 0 deletions src/04kernel/src/attributes/broadcaster.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,4 +96,8 @@ namespace refactor::kernel {
}
}

/// Whether this broadcaster performs any actual broadcasting.
/// An empty stride table means every input already matches the output layout.
bool Broadcaster::needBroadcast() const noexcept {
    return not strides.empty();
}

}// namespace refactor::kernel
Original file line number Diff line number Diff line change
@@ -1,19 +1,17 @@
#include "kernel/attributes/matmul_info.h"
#include <cstddef>
#include <numeric>
#include "kernel/attributes/mat_mul_info.h"

namespace refactor::kernel {

ExpandInfo buildBias(size_t m, size_t n,
ExpandInfo buildBias(dim_t m, dim_t n,
Tensor const &a,
Tensor const &b,
Tensor const &c) {
std::vector<dim_t> output(std::max(a.rank(), b.rank()));
auto it = output.rbegin();
*it++ = n;
*it++ = m;
for (auto da = static_cast<size_t>(a.rank() - 2),
db = static_cast<size_t>(b.rank() - 2);
for (auto da = static_cast<dim_t>(a.rank() - 2),
db = static_cast<dim_t>(b.rank() - 2);
auto i : range0_(output.size() - 2)) {
auto a_ = i < da ? a.shape[da - i - 1] : 1;
auto b_ = i < db ? b.shape[db - i - 1] : 1;
Expand All @@ -26,13 +24,6 @@ namespace refactor::kernel {
slice(output.data(), output.size()));
}

std::variant<Broadcaster, size_t> buildBroadcasterOrBatch(slice_t<dim_t> dimA, slice_t<dim_t> dimB) {
if (std::equal(dimA.begin(), dimA.end(), dimB.begin(), dimB.end())) {
return std::accumulate(dimA.begin(), dimA.end(), (size_t) 1, std::multiplies<size_t>());
}
return Broadcaster({dimA, dimB});
}

MatMulInfo::MatMulInfo(
Tensor const &a, Tensor const &b,
std::optional<std::reference_wrapper<Tensor const>> c,
Expand All @@ -44,7 +35,8 @@ namespace refactor::kernel {
k(transA ? a.shape.rbegin()[1] : a.shape.rbegin()[0]),
n(transB ? b.shape.rbegin()[1] : b.shape.rbegin()[0]),
biasExpand(c ? std::make_optional(buildBias(m, n, a, b, *c)) : std::nullopt),
broadcasterOrBatch(buildBroadcasterOrBatch(slice(a.shape.data(), a.shape.size() - 2), slice(b.shape.data(), b.shape.size() - 2))) {
broadcaster({slice(a.shape.data(), a.shape.size() - 2),
slice(b.shape.data(), b.shape.size() - 2)}) {
auto kB = transB ? b.shape.rbegin()[0] : b.shape.rbegin()[1];
ASSERT(k == kB, "MatMul: input shape not matched.");
}
Expand Down
42 changes: 42 additions & 0 deletions src/04kernel/src/attributes/mat_mul_integer_info.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
#include "kernel/attributes/mat_mul_integer_info.h"

namespace refactor::kernel {

    /// Determines the zero-point properties of operand i (0 = A, 1 = B).
    /// The optional zero-point tensor, when present, is inputs[i + 2].
    MatMulIntegerInfo::Input::Input(TensorRefs const &inputs, size_t i) noexcept
        : withZeroPoint(false),
          signed_(true),
          scalar(true) {
        if (inputs.size() > i + 2) {
            auto const &t = inputs[i + 2].get();
            auto size = t.elementsSize();
            if (t.data) {
                // A constant zero point that is all zeros is a no-op and can be
                // dropped entirely. Checking raw bytes is valid for both i8 and
                // u8 element types, since zero has the same bit pattern in each.
                auto data = slice(t.data->get<uint8_t>(), size);
                if (std::all_of(data.begin(), data.end(), [](auto x) { return x == 0; })) {
                    return;
                }
            }
            withZeroPoint = true;
            signed_ = t.dataType == DataType::I8;
            scalar = size == 1;
        }
    }

    MatMulIntegerInfo::MatMulIntegerInfo(TensorRefs const &inputs) noexcept
        : a(inputs, 0),
          b(inputs, 1),
// Shorthand for the shapes of operands A and B inside the initializer list.
#define A (inputs[0].get().shape)
#define B (inputs[1].get().shape)
          // GEMM dimensions come from the trailing two axes: A is (m, k), B is (k, n).
          m(A.rbegin()[1]),
          k(A.rbegin()[0]),
          n(B.rbegin()[0]),
          // Everything before the trailing two (matrix) axes is broadcast.
          broadcaster({slice(A.data(), A.size() - 2),
                       slice(B.data(), B.size() - 2)}) {
    }
#undef A
#undef B

    /// Number of matrix products implied by broadcasting the batch dimensions.
    dim_t MatMulIntegerInfo::batch() const noexcept {
        return broadcaster.outputsCount;
    }

}// namespace refactor::kernel
32 changes: 32 additions & 0 deletions src/04kernel/src/collectors/dequantize_linear.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#include "kernel/collectors/dequantize_linear.h"
#include "../kernels/dequantize_linear/cpu_kernel.hh"
#include "../kernels/dequantize_linear/cuda_kernel.hh"

namespace refactor::kernel {

    /// Constructs a collector bound to the given target backend.
    DequantizeLinearCollector::
        DequantizeLinearCollector(decltype(_target) target) noexcept
        : InfoCollector(target) {}

    /// Gathers every kernel implementation on the configured target that can
    /// execute DequantizeLinear with these tensors; empty if none qualifies.
    std::vector<KernelBox>
    DequantizeLinearCollector::filter(TensorRefs inputs, TensorRefs outputs) const {
        auto const &output = outputs[0];
        std::vector<KernelBox> boxes;
        // Keep a candidate only when its build succeeded.
        auto collect = [&boxes](KernelBox box) {
            if (box) {
                boxes.emplace_back(std::move(box));
            }
        };
        switch (_target) {
            case decltype(_target)::Cpu:
                collect(DequantizeLinearCpu::build(inputs, output));
                break;
            case decltype(_target)::Nvidia:
                collect(DequantizeLinearCuda::build(inputs, output));
                break;
            default:
                UNREACHABLEX(void, "Unknown target");
        }
        return boxes;
    }

}// namespace refactor::kernel
33 changes: 33 additions & 0 deletions src/04kernel/src/collectors/dynamic_quantize_linear.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#include "kernel/collectors/dynamic_quantize_linear.h"
#include "../kernels/dynamic_quantize_linear/cpu_kernel.hh"
#include "../kernels/dynamic_quantize_linear/cuda_kernel.hh"

namespace refactor::kernel {

    /// Constructs a collector bound to the given target backend.
    DynamicQuantizeLinearCollector::
        DynamicQuantizeLinearCollector(decltype(_target) target) noexcept
        : InfoCollector(target) {}

    /// Gathers every kernel implementation on the configured target that can
    /// execute DynamicQuantizeLinear; kernels are built from the element count
    /// of the sole data input.
    std::vector<KernelBox>
    DynamicQuantizeLinearCollector::filter(TensorRefs inputs, TensorRefs outputs) const {
        auto size = inputs[0].get().elementsSize();

        std::vector<KernelBox> ans;
        switch (_target) {
            case decltype(_target)::Cpu:
                if (auto ptr = DynamicQuantizeLinearCpu::build(size); ptr) {
                    ans.emplace_back(std::move(ptr));
                }
                break;
            case decltype(_target)::Nvidia:
                if (auto ptr = DynamicQuantizeLinearCuda::build(size); ptr) {
                    ans.emplace_back(std::move(ptr));
                }
                break;
            default:
                UNREACHABLEX(void, "Unknown target");
        }
        return ans;
    }

}// namespace refactor::kernel
3 changes: 1 addition & 2 deletions src/04kernel/src/collectors/mat_mul.cc
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
#include "kernel/collectors/mat_mul.h"
#include "../kernels/mat_mul/cpu_kernel.hh"
#include "../kernels/mat_mul/cublas_kernel.hh"
#include "common.h"
#include "kernel/attributes/matmul_info.h"
#include "kernel/attributes/mat_mul_info.h"

namespace refactor::kernel {
#define REGISTER(T) \
Expand Down
30 changes: 30 additions & 0 deletions src/04kernel/src/collectors/mat_mul_integer.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#include "kernel/collectors/mat_mul_integer.h"
// Use the same relative include style as the sibling collectors
// (dequantize_linear.cc, dynamic_quantize_linear.cc): "../kernels/..."
// resolves to the same headers as the previous "../../src/kernels/..."
// but does not hard-code the "src" directory name.
#include "../kernels/mat_mul_integer/cpu_kernel.hh"
#include "../kernels/mat_mul_integer/cublas_kernel.hh"
#include "kernel/attributes/mat_mul_integer_info.h"

namespace refactor::kernel {

    /// Gathers every kernel implementation on the configured target that can
    /// execute MatMulInteger; kernels are built from the shape/quantization
    /// info derived from the inputs. Empty if none qualifies.
    std::vector<KernelBox>
    MatMulIntegerCollector::filter(TensorRefs inputs, TensorRefs outputs) const {
        MatMulIntegerInfo info(inputs);

        std::vector<KernelBox> ans;
        switch (_target) {
            case decltype(_target)::Cpu:
                if (auto ptr = MatMulIntegerCpu::build(info); ptr) {
                    ans.emplace_back(std::move(ptr));
                }
                break;
            case decltype(_target)::Nvidia:
                if (auto ptr = MatMulIntegerCublas::build(info); ptr) {
                    ans.emplace_back(std::move(ptr));
                }
                break;
            default:
                UNREACHABLEX(void, "Unknown target");
        }
        return ans;
    }

}// namespace refactor::kernel
Loading