diff --git a/include/TensorCompiler/Conversion/MLIRBuilder.hpp b/include/TensorCompiler/Conversion/MLIRBuilder.hpp new file mode 100644 index 0000000..c88cfba --- /dev/null +++ b/include/TensorCompiler/Conversion/MLIRBuilder.hpp @@ -0,0 +1,57 @@ +#pragma once +#include "TensorCompiler/Dialect/NNDialect.hpp" +#include "TensorCompiler/Graph/Graph.hpp" +#include "TensorCompiler/Graph/GraphVisitor.hpp" + +#include +#include + +#include +#include + +namespace tc::conversion { + +class HighLevelMLIRBuilder : public tc::graph::GraphVisitor { +public: + explicit HighLevelMLIRBuilder(mlir::MLIRContext &ctx); + + const mlir::ModuleOp &GetModule() const { return module_; } + + mlir::ModuleOp Build(const tc::graph::Graph &graph) { + graph.Accept(*this); + return module_; + } + +private: + void Visit(const tc::graph::Graph &graph) override; + void Visit(const tc::graph::TensorEntity &tensor) override; + void Visit(const tc::graph::OperationEntity &op) override; + void Visit(const tc::graph::ConstantEntity &constant) override; + void Finalize(const tc::graph::Graph &graph) override; + + mlir::Value GetValue(tc::graph::EntityId id) const; + mlir::Value GetValue(const std::string &name) const; + + void BuildRelu(const tc::graph::OperationEntity &op, mlir::Location loc); + void BuildAdd(const tc::graph::OperationEntity &op, mlir::Location loc); + void BuildMul(const tc::graph::OperationEntity &op, mlir::Location loc); + void BuildMatMul(const tc::graph::OperationEntity &op, mlir::Location loc); + void BuildGemm(const tc::graph::OperationEntity &op, mlir::Location loc); + void BuildConv(const tc::graph::OperationEntity &op, mlir::Location loc); + void BuildTranspose(const tc::graph::OperationEntity &op, mlir::Location loc); + + mlir::Type ConvertElemType(int32_t onnx_dtype); + mlir::RankedTensorType ConvertTensorType(int32_t dtype, + llvm::ArrayRef shape); + mlir::DenseElementsAttr ConvertTensorData(const tc::graph::TensorData &data); + +private: + mlir::MLIRContext &ctx_; + 
mlir::OpBuilder builder_; + mlir::ModuleOp module_; + + std::unordered_map valueMap_; + std::unordered_map tensorNameToId_; + std::unordered_set initializerIds_; +}; +} // namespace tc::conversion \ No newline at end of file diff --git a/include/TensorCompiler/Converter/MLIRBuilder.hpp b/include/TensorCompiler/Converter/MLIRBuilder.hpp deleted file mode 100644 index 00efcb4..0000000 --- a/include/TensorCompiler/Converter/MLIRBuilder.hpp +++ /dev/null @@ -1,53 +0,0 @@ -#pragma once -#include "TensorCompiler/Frontend/ONNXVisitor.hpp" - -#include -#include - -#include -#include -#include - -namespace tc::converter::onnx_to_high_mlir { -class MLIRBuilder final : public tc::frontend::ONNXVisitor { -public: - explicit MLIRBuilder(mlir::MLIRContext &ctx); - - const mlir::ModuleOp &GetModule() const; - - void Visit(const onnx::ModelProto &) override; - void Visit(const onnx::GraphProto &graph) override; - void Visit(const onnx::ValueInfoProto &) override; - void Visit(const onnx::TensorProto &tensor) override; - void Visit(const onnx::NodeProto &node) override; - void Visit(const onnx::AttributeProto &) override; - - void Finalize(const onnx::GraphProto &graph) override; - -private: - mlir::Value FindValue(const std::string &name) const; - - mlir::Type ConvertElemType(int onnx_dtype); - mlir::RankedTensorType ConvertTensorType(int onnx_dtype, - llvm::ArrayRef shape); - - mlir::DenseElementsAttr ConvertTensor(const onnx::TensorProto &tensor); - mlir::RankedTensorType ConvertValueInfo(const onnx::ValueInfoProto &info); - - void BuildRelu(const onnx::NodeProto &node, mlir::Location loc); - void BuildAdd(const onnx::NodeProto &node, mlir::Location loc); - void BuildMul(const onnx::NodeProto &node, mlir::Location loc); - void BuildMatMul(const onnx::NodeProto &node, mlir::Location loc); - void BuildGemm(const onnx::NodeProto &node, mlir::Location loc); - void BuildConv(const onnx::NodeProto &node, mlir::Location loc); - void BuildTranspose(const onnx::NodeProto &node, 
mlir::Location loc); - -private: - mlir::MLIRContext &ctx_; - mlir::OpBuilder builder_; - mlir::ModuleOp module_; - - std::map valueMap_; - std::set initializerNames_; -}; -} // namespace tc::converter::onnx_to_high_mlir \ No newline at end of file diff --git a/include/TensorCompiler/Dialect/NNDialect.hpp b/include/TensorCompiler/Dialect/NNDialect.hpp index 4194ade..dee05e6 100644 --- a/include/TensorCompiler/Dialect/NNDialect.hpp +++ b/include/TensorCompiler/Dialect/NNDialect.hpp @@ -1,7 +1,7 @@ #pragma once -#include #include +#include #include "NNDialect.h.inc" diff --git a/include/TensorCompiler/Converter/GraphBuilder.hpp b/include/TensorCompiler/Frontend/GraphBuilder.hpp similarity index 79% rename from include/TensorCompiler/Converter/GraphBuilder.hpp rename to include/TensorCompiler/Frontend/GraphBuilder.hpp index a0b648f..2061b92 100644 --- a/include/TensorCompiler/Converter/GraphBuilder.hpp +++ b/include/TensorCompiler/Frontend/GraphBuilder.hpp @@ -1,12 +1,12 @@ #pragma once #include "TensorCompiler/Frontend/ONNXVisitor.hpp" -#include "TensorCompiler/Graph/IR.hpp" +#include "TensorCompiler/Graph/Graph.hpp" #include #include #include -namespace tc::converter::onnx_to_graph { +namespace tc::frontend { using tc::graph::AttrValue; using tc::graph::EntityId; using tc::graph::Graph; @@ -15,12 +15,14 @@ using tc::graph::TensorData; class GraphBuilder final : public tc::frontend::ONNXVisitor { Graph graph_; std::set initializerNames_; + std::set inputNames_; + std::set outputNames_; TensorData ParseTensor(const onnx::TensorProto &tensor); AttrValue ParseAttribute(const onnx::AttributeProto &attr); EntityId EnsureTensor(const std::string &name, - const std::vector &shape = {}); + const std::vector &shape, int32_t dtype); public: const Graph &GetGraph() const { return graph_; } @@ -34,4 +36,4 @@ class GraphBuilder final : public tc::frontend::ONNXVisitor { void Finalize(const onnx::GraphProto &) override; }; -} // namespace tc::converter::onnx_to_graph \ No newline 
at end of file +} // namespace tc::frontend \ No newline at end of file diff --git a/include/TensorCompiler/Graph/IR.hpp b/include/TensorCompiler/Graph/Graph.hpp similarity index 84% rename from include/TensorCompiler/Graph/IR.hpp rename to include/TensorCompiler/Graph/Graph.hpp index a70ef81..c246e90 100644 --- a/include/TensorCompiler/Graph/IR.hpp +++ b/include/TensorCompiler/Graph/Graph.hpp @@ -7,6 +7,7 @@ #include namespace tc::graph { +class GraphVisitor; using IntList = std::vector; using DoubleList = std::vector; using StringList = std::vector; @@ -15,6 +16,7 @@ struct TensorData { std::string name; std::vector dims; std::vector raw_data; + int32_t dtype; }; using AttrScalar = std::variant; @@ -30,6 +32,7 @@ struct TensorEntity { std::string name; std::vector shape; bool is_initializer = false; + int32_t dtype; std::optional data; }; @@ -51,20 +54,17 @@ struct Entity { EntityKind kind; std::variant entity; - Entity(EntityKind k, OperationEntity op) - : kind(k), entity(std::move(op)) {} + Entity(EntityKind k, OperationEntity op) : kind(k), entity(std::move(op)) {} - Entity(EntityKind k, TensorEntity t) - : kind(k), entity(std::move(t)) {} + Entity(EntityKind k, TensorEntity t) : kind(k), entity(std::move(t)) {} - Entity(EntityKind k, ConstantEntity c) - : kind(k), entity(std::move(c)) {} + Entity(EntityKind k, ConstantEntity c) : kind(k), entity(std::move(c)) {} }; class Graph { public: - EntityId AddTensor(const std::string &name, std::vector shape = {}, - bool is_init = false); + EntityId AddTensor(const std::string &name, std::vector shape, + int32_t dtype, bool is_init = false); EntityId AddConstant(const TensorData &data); @@ -83,6 +83,8 @@ class Graph { void AddInput(EntityId id) { inputs_.push_back(id); } void AddOutput(EntityId id) { outputs_.push_back(id); } + void Accept(GraphVisitor &visitor) const; + private: std::vector inputs_; std::vector outputs_; diff --git a/include/TensorCompiler/Graph/Exporter.hpp 
b/include/TensorCompiler/Graph/GraphDumper.hpp similarity index 83% rename from include/TensorCompiler/Graph/Exporter.hpp rename to include/TensorCompiler/Graph/GraphDumper.hpp index e3c031c..7b6d64e 100644 --- a/include/TensorCompiler/Graph/Exporter.hpp +++ b/include/TensorCompiler/Graph/GraphDumper.hpp @@ -1,11 +1,10 @@ #pragma once -#include "TensorCompiler/Graph/IR.hpp" +#include "TensorCompiler/Graph/Graph.hpp" #include namespace tc::graph { std::string DumpGraph(const Graph &g); - std::string ToDot(const Graph &g); bool SaveDot(const Graph &g, const std::string &file); diff --git a/include/TensorCompiler/Graph/GraphVisitor.hpp b/include/TensorCompiler/Graph/GraphVisitor.hpp new file mode 100644 index 0000000..0bce4dd --- /dev/null +++ b/include/TensorCompiler/Graph/GraphVisitor.hpp @@ -0,0 +1,23 @@ +#pragma once + +namespace tc::graph { +class Graph; +struct TensorEntity; +struct OperationEntity; +struct ConstantEntity; +} // namespace tc::graph + +namespace tc::graph { + +class GraphVisitor { +public: + virtual ~GraphVisitor() = default; + + virtual void Visit(const tc::graph::Graph &graph) = 0; + virtual void Visit(const tc::graph::TensorEntity &tensor) = 0; + virtual void Visit(const tc::graph::OperationEntity &op) = 0; + virtual void Visit(const tc::graph::ConstantEntity &constant) = 0; + virtual void Finalize(const tc::graph::Graph &graph) = 0; +}; + +} // namespace tc::graph \ No newline at end of file diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index 3ed6e66..73af67a 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -1,7 +1,7 @@ add_subdirectory(Dialect) add_subdirectory(Frontend) add_subdirectory(Graph) -add_subdirectory(Converter) +add_subdirectory(Conversion) add_library(tc-core INTERFACE) @@ -10,7 +10,7 @@ target_link_libraries(tc-core tc-dialect tc-frontend tc-graph - tc-converter + tc-conversion MLIRParser MLIRFuncDialect MLIRArithDialect diff --git a/lib/Conversion/CMakeLists.txt b/lib/Conversion/CMakeLists.txt new file mode 
100644 index 0000000..20b009c --- /dev/null +++ b/lib/Conversion/CMakeLists.txt @@ -0,0 +1,11 @@ +add_library(tc-conversion MLIRBuilder.cpp) + +target_link_libraries(tc-conversion + tc-graph + tc-dialect + onnx) + +target_include_directories(tc-conversion PUBLIC + ${PROJECT_SOURCE_DIR}/include + ${CMAKE_CURRENT_BINARY_DIR}/../Dialect + ) \ No newline at end of file diff --git a/lib/Conversion/MLIRBuilder.cpp b/lib/Conversion/MLIRBuilder.cpp new file mode 100644 index 0000000..712478c --- /dev/null +++ b/lib/Conversion/MLIRBuilder.cpp @@ -0,0 +1,385 @@ +#include "TensorCompiler/Conversion/MLIRBuilder.hpp" + +#include +#include +#include +#include + +#include + +#include + +namespace tc::conversion { +using tc::graph::AttrValue; +using tc::graph::DoubleList; +using tc::graph::EntityId; +using tc::graph::IntList; +using tc::graph::StringList; +using tc::graph::TensorData; + +static const AttrValue *FindAttr(const tc::graph::OperationEntity &op, + std::string_view name) { + auto it = op.attrs.find(std::string(name)); + return (it != op.attrs.end()) ? 
&it->second : nullptr;
+}
+
+static int64_t GetAttrInt(const tc::graph::OperationEntity &op,
+                          std::string_view name, int64_t defaultVal) {
+  const auto *attr = FindAttr(op, name);
+  if (attr && std::holds_alternative<tc::graph::AttrScalar>(*attr)) {
+    const auto &scalar = std::get<tc::graph::AttrScalar>(*attr);
+    if (std::holds_alternative<int64_t>(scalar))
+      return std::get<int64_t>(scalar);
+  }
+  return defaultVal;
+}
+
+static float GetAttrFloat(const tc::graph::OperationEntity &op,
+                          std::string_view name, float defaultVal) {
+  const auto *attr = FindAttr(op, name);
+  if (attr && std::holds_alternative<tc::graph::AttrScalar>(*attr)) {
+    const auto &scalar = std::get<tc::graph::AttrScalar>(*attr);
+    if (std::holds_alternative<double>(scalar))
+      return static_cast<float>(std::get<double>(scalar));
+    if (std::holds_alternative<int64_t>(scalar))
+      return static_cast<float>(std::get<int64_t>(scalar));
+  }
+  return defaultVal;
+}
+
+static std::string GetAttrString(const tc::graph::OperationEntity &op,
+                                 std::string_view name,
+                                 std::string_view defaultVal) {
+  const auto *attr = FindAttr(op, name);
+  if (attr && std::holds_alternative<tc::graph::AttrScalar>(*attr)) {
+    const auto &scalar = std::get<tc::graph::AttrScalar>(*attr);
+    if (std::holds_alternative<std::string>(scalar))
+      return std::get<std::string>(scalar);
+  }
+  return std::string(defaultVal);
+}
+
+static llvm::SmallVector<int64_t>
+GetAttrInts(const tc::graph::OperationEntity &op, std::string_view name,
+            size_t fillLen = 0, int64_t fillVal = 0) {
+  const auto *attr = FindAttr(op, name);
+  if (attr && std::holds_alternative<tc::graph::AttrList>(*attr)) {
+    const auto &list = std::get<tc::graph::AttrList>(*attr);
+    if (std::holds_alternative<IntList>(list)) {
+      const auto &vec = std::get<IntList>(list);
+      return {vec.begin(), vec.end()};
+    }
+  }
+  return llvm::SmallVector<int64_t>(fillLen, fillVal);
+}
+
+HighLevelMLIRBuilder::HighLevelMLIRBuilder(mlir::MLIRContext &ctx)
+    : ctx_(ctx), builder_(&ctx_),
+      module_(mlir::ModuleOp::create(builder_.getUnknownLoc())) {
+  builder_.setInsertionPointToEnd(module_.getBody());
+}
+
+mlir::Value HighLevelMLIRBuilder::GetValue(tc::graph::EntityId id) const {
+  auto it = valueMap_.find(id);
+  return (it != valueMap_.end()) ?
it->second : mlir::Value{}; +} + +mlir::Value HighLevelMLIRBuilder::GetValue(const std::string &name) const { + auto it = tensorNameToId_.find(name); + if (it != tensorNameToId_.end()) + return GetValue(it->second); + return mlir::Value{}; +} + +mlir::Type HighLevelMLIRBuilder::ConvertElemType(int32_t onnx_dtype) { + switch (onnx_dtype) { + case onnx::TensorProto::FLOAT: + return builder_.getF32Type(); + case onnx::TensorProto::DOUBLE: + return builder_.getF64Type(); + case onnx::TensorProto::INT32: + return builder_.getI32Type(); + case onnx::TensorProto::INT64: + return builder_.getI64Type(); + case onnx::TensorProto::INT8: + return builder_.getI8Type(); + case onnx::TensorProto::UINT8: + return builder_.getIntegerType(8, false); + case onnx::TensorProto::BOOL: + return builder_.getI1Type(); + default: + llvm::errs() << "Unsupported dtype: " << onnx_dtype << "\n"; + llvm_unreachable("unsupported tensor dtype"); + } +} + +mlir::RankedTensorType +HighLevelMLIRBuilder::ConvertTensorType(int32_t dtype, + llvm::ArrayRef shape) { + return mlir::RankedTensorType::get(shape, ConvertElemType(dtype)); +} + +mlir::DenseElementsAttr +HighLevelMLIRBuilder::ConvertTensorData(const tc::graph::TensorData &data) { + auto type = ConvertTensorType(data.dtype, data.dims); + llvm::ArrayRef raw(data.raw_data); + + switch (data.dtype) { + case onnx::TensorProto::FLOAT: { + auto *p = reinterpret_cast(raw.data()); + return mlir::DenseElementsAttr::get( + type, llvm::ArrayRef(p, raw.size() / sizeof(float))); + } + case onnx::TensorProto::INT64: { + auto *p = reinterpret_cast(raw.data()); + return mlir::DenseElementsAttr::get( + type, llvm::ArrayRef(p, raw.size() / sizeof(int64_t))); + } + case onnx::TensorProto::INT32: { + auto *p = reinterpret_cast(raw.data()); + return mlir::DenseElementsAttr::get( + type, llvm::ArrayRef(p, raw.size() / sizeof(int32_t))); + } + default: + llvm_unreachable("unsupported raw data type"); + } +} + +void HighLevelMLIRBuilder::BuildRelu(const 
tc::graph::OperationEntity &op,
+                                  mlir::Location loc) {
+  mlir::Value in = GetValue(op.inputs[0]);
+  auto resultType = in.getType();
+  auto mlirOp = mlir::nn::ReluOp::create(builder_, loc, resultType, in);
+  valueMap_[op.outputs[0]] = mlirOp.getResult();
+}
+
+void HighLevelMLIRBuilder::BuildAdd(const tc::graph::OperationEntity &op,
+                                    mlir::Location loc) {
+  mlir::Value lhs = GetValue(op.inputs[0]);
+  mlir::Value rhs = GetValue(op.inputs[1]);
+  auto lhsTy = llvm::cast<mlir::RankedTensorType>(lhs.getType());
+  auto rhsTy = llvm::cast<mlir::RankedTensorType>(rhs.getType());
+  mlir::Type resultType =
+      (lhsTy.getRank() >= rhsTy.getRank()) ? lhs.getType() : rhs.getType();
+  auto mlirOp = mlir::nn::AddOp::create(builder_, loc, resultType, lhs, rhs);
+  valueMap_[op.outputs[0]] = mlirOp.getResult();
+}
+
+void HighLevelMLIRBuilder::BuildMul(const tc::graph::OperationEntity &op,
+                                    mlir::Location loc) {
+  mlir::Value lhs = GetValue(op.inputs[0]);
+  mlir::Value rhs = GetValue(op.inputs[1]);
+  auto lhsTy = llvm::cast<mlir::RankedTensorType>(lhs.getType());
+  auto rhsTy = llvm::cast<mlir::RankedTensorType>(rhs.getType());
+  mlir::Type resultType =
+      (lhsTy.getRank() >= rhsTy.getRank()) ? lhs.getType() : rhs.getType();
+  auto mlirOp = mlir::nn::MulOp::create(builder_, loc, resultType, lhs, rhs);
+  valueMap_[op.outputs[0]] = mlirOp.getResult();
+}
+
+void HighLevelMLIRBuilder::BuildMatMul(const tc::graph::OperationEntity &op,
+                                       mlir::Location loc) {
+  mlir::Value A = GetValue(op.inputs[0]);
+  mlir::Value B = GetValue(op.inputs[1]);
+  auto aTy = llvm::cast<mlir::RankedTensorType>(A.getType());
+  auto bTy = llvm::cast<mlir::RankedTensorType>(B.getType());
+  llvm::SmallVector<int64_t> shape(aTy.getShape());
+  shape.back() = bTy.getShape().back();
+  auto resultType = mlir::RankedTensorType::get(shape, aTy.getElementType());
+  auto mlirOp = mlir::nn::MatMulOp::create(builder_, loc, resultType, A, B);
+  valueMap_[op.outputs[0]] = mlirOp.getResult();
+}
+
+void HighLevelMLIRBuilder::BuildGemm(const tc::graph::OperationEntity &op,
+                                     mlir::Location loc) {
+  mlir::Value A = GetValue(op.inputs[0]);
+  mlir::Value B = GetValue(op.inputs[1]);
+  mlir::Value C =
+      (op.inputs.size() > 2) ? GetValue(op.inputs[2]) : mlir::Value{};
+
+  int64_t transA = GetAttrInt(op, "transA", 0);
+  int64_t transB = GetAttrInt(op, "transB", 0);
+  float alpha = GetAttrFloat(op, "alpha", 1.0f);
+  float beta = GetAttrFloat(op, "beta", 1.0f);
+
+  auto aTy = llvm::cast<mlir::RankedTensorType>(A.getType());
+  auto bTy = llvm::cast<mlir::RankedTensorType>(B.getType());
+  int64_t M = transA ? aTy.getDimSize(1) : aTy.getDimSize(0);
+  int64_t N = transB ? bTy.getDimSize(0) : bTy.getDimSize(1);
+  auto resultType = mlir::RankedTensorType::get({M, N}, aTy.getElementType());
+
+  auto mlirOp = mlir::nn::GemmOp::create(
+      builder_, loc, resultType, A, B, C, builder_.getI64IntegerAttr(transA),
+      builder_.getI64IntegerAttr(transB), builder_.getF32FloatAttr(alpha),
+      builder_.getF32FloatAttr(beta));
+  valueMap_[op.outputs[0]] = mlirOp.getResult();
+}
+
+void HighLevelMLIRBuilder::BuildConv(const tc::graph::OperationEntity &op,
+                                     mlir::Location loc) {
+  mlir::Value input = GetValue(op.inputs[0]);
+  mlir::Value weight = GetValue(op.inputs[1]);
+  mlir::Value bias =
+      (op.inputs.size() > 2) ?
GetValue(op.inputs[2]) : mlir::Value{}; + + auto inputTy = llvm::cast(input.getType()); + auto weightTy = llvm::cast(weight.getType()); + + int64_t spatialDims = inputTy.getRank() - 2; + + auto strides = GetAttrInts(op, "strides", spatialDims, 1); + auto dilations = GetAttrInts(op, "dilations", spatialDims, 1); + auto pads = GetAttrInts(op, "pads", spatialDims * 2, 0); + auto group = GetAttrInt(op, "group", 1); + auto autoPad = GetAttrString(op, "auto_pad", "NOTSET"); + + int64_t N = inputTy.getDimSize(0); + int64_t Cout = weightTy.getDimSize(0); + + llvm::SmallVector outputShape = {N, Cout}; + for (int i = 0; i < spatialDims; ++i) { + int64_t in = inputTy.getDimSize(2 + i); + int64_t k = weightTy.getDimSize(2 + i); + int64_t pb = pads[i]; + int64_t pe = pads[i + spatialDims]; + int64_t s = strides[i]; + int64_t d = dilations[i]; + + if (in == mlir::ShapedType::kDynamic || k == mlir::ShapedType::kDynamic) { + outputShape.push_back(mlir::ShapedType::kDynamic); + continue; + } + int64_t out = (in + pb + pe - d * (k - 1) - 1) / s + 1; + outputShape.push_back(out); + } + + auto resultType = + mlir::RankedTensorType::get(outputShape, inputTy.getElementType()); + + if (bias) { + auto biasTy = llvm::cast(bias.getType()); + if (biasTy.getRank() != 1 || biasTy.getDimSize(0) != Cout) { + llvm::errs() << "Conv bias must have shape [" << Cout << "]\n"; + llvm::report_fatal_error("Invalid Conv bias shape"); + } + } + + auto mlirOp = mlir::nn::ConvOp::create( + builder_, loc, resultType, input, weight, bias, + mlir::DenseI64ArrayAttr::get(builder_.getContext(), strides), + mlir::DenseI64ArrayAttr::get(builder_.getContext(), dilations), + mlir::DenseI64ArrayAttr::get(builder_.getContext(), pads), + builder_.getI64IntegerAttr(group), builder_.getStringAttr(autoPad)); + valueMap_[op.outputs[0]] = mlirOp.getResult(); +} + +void HighLevelMLIRBuilder::BuildTranspose(const tc::graph::OperationEntity &op, + mlir::Location loc) { + mlir::Value input = GetValue(op.inputs[0]); + auto 
inputTy = llvm::cast(input.getType()); + int64_t rank = inputTy.getRank(); + + auto perm = GetAttrInts(op, "perm", 0, 0); + if (perm.empty()) { + perm.resize(rank); + for (int64_t i = 0; i < rank; ++i) + perm[i] = rank - 1 - i; + } + + llvm::SmallVector outShape(rank); + for (int64_t i = 0; i < rank; ++i) + outShape[i] = inputTy.getDimSize(perm[i]); + auto resultType = + mlir::RankedTensorType::get(outShape, inputTy.getElementType()); + + auto mlirOp = mlir::nn::TransposeOp::create( + builder_, loc, resultType, input, + mlir::DenseI64ArrayAttr::get(builder_.getContext(), perm)); + valueMap_[op.outputs[0]] = mlirOp.getResult(); +} + +void HighLevelMLIRBuilder::Visit(const tc::graph::Graph &graph) { + llvm::SmallVector inputTypes, outputTypes; + for (tc::graph::EntityId id : graph.Inputs()) { + const auto *tensor = graph.GetTensor(id); + if (tensor && !tensor->is_initializer) { + inputTypes.push_back(ConvertTensorType(tensor->dtype, tensor->shape)); + } + } + for (tc::graph::EntityId id : graph.Outputs()) { + const auto *tensor = graph.GetTensor(id); + if (tensor) { + outputTypes.push_back(ConvertTensorType(tensor->dtype, tensor->shape)); + } + } + + std::string funcName = "main"; + auto func = mlir::func::FuncOp::create( + builder_.getUnknownLoc(), funcName, + builder_.getFunctionType(inputTypes, outputTypes)); + builder_.insert(func); + + auto *block = func.addEntryBlock(); + builder_.setInsertionPointToStart(block); + + size_t argIdx = 0; + for (tc::graph::EntityId id : graph.Inputs()) { + const auto *tensor = graph.GetTensor(id); + if (tensor && !tensor->is_initializer) { + valueMap_[id] = block->getArgument(argIdx++); + tensorNameToId_[tensor->name] = id; + } + } +} + +void HighLevelMLIRBuilder::Visit(const tc::graph::TensorEntity &tensor) { + if (!tensor.is_initializer || !tensor.data.has_value()) + return; + auto it = valueMap_.find(tensor.id); + if (it != valueMap_.end()) + return; + auto attr = ConvertTensorData(*tensor.data); + auto op = + 
mlir::arith::ConstantOp::create(builder_, builder_.getUnknownLoc(), attr); + valueMap_[tensor.id] = op.getResult(); + tensorNameToId_[tensor.name] = tensor.id; +} + +void HighLevelMLIRBuilder::Visit(const tc::graph::OperationEntity &op) { + auto loc = builder_.getUnknownLoc(); + const auto &opType = op.op_type; + + if (opType == "Relu") + BuildRelu(op, loc); + else if (opType == "Add") + BuildAdd(op, loc); + else if (opType == "Mul") + BuildMul(op, loc); + else if (opType == "MatMul") + BuildMatMul(op, loc); + else if (opType == "Gemm") + BuildGemm(op, loc); + else if (opType == "Conv") + BuildConv(op, loc); + else if (opType == "Transpose") + BuildTranspose(op, loc); + else + llvm::errs() << "[warn] Unsupported op: " << opType << "\n"; +} + +void HighLevelMLIRBuilder::Visit(const tc::graph::ConstantEntity &constant) { + /* Do nothing */ +} + +void HighLevelMLIRBuilder::Finalize(const tc::graph::Graph &graph) { + llvm::SmallVector outputs; + for (tc::graph::EntityId id : graph.Outputs()) { + auto it = valueMap_.find(id); + if (it != valueMap_.end()) + outputs.push_back(it->second); + else + llvm::errs() << "Warning: output " << id << " not found in valueMap\n"; + } + mlir::func::ReturnOp::create(builder_, builder_.getUnknownLoc(), outputs); +} +} // namespace tc::conversion \ No newline at end of file diff --git a/lib/Converter/CMakeLists.txt b/lib/Converter/CMakeLists.txt deleted file mode 100644 index 85ee3d3..0000000 --- a/lib/Converter/CMakeLists.txt +++ /dev/null @@ -1,9 +0,0 @@ -add_library(tc-converter MLIRBuilder.cpp GraphBuilder.cpp) - -target_link_libraries(tc-converter - onnx) - -target_include_directories(tc-converter PUBLIC - ${PROJECT_SOURCE_DIR}/include - ${CMAKE_CURRENT_BINARY_DIR}/../Dialect - ) \ No newline at end of file diff --git a/lib/Converter/MLIRBuilder.cpp b/lib/Converter/MLIRBuilder.cpp deleted file mode 100644 index 752b7f5..0000000 --- a/lib/Converter/MLIRBuilder.cpp +++ /dev/null @@ -1,366 +0,0 @@ -#include 
"TensorCompiler/Converter/MLIRBuilder.hpp" -#include "TensorCompiler/Dialect/NNDialect.hpp" - -#include -#include - -namespace tc::converter::onnx_to_high_mlir { -static const onnx::AttributeProto *FindAttr(const onnx::NodeProto &node, - std::string_view name) { - for (const auto &attr : node.attribute()) - if (attr.name() == name) - return &attr; - return nullptr; -} - -static int64_t GetAttrInt(const onnx::NodeProto &node, std::string_view name, - int64_t defaultVal) { - const auto *a = FindAttr(node, name); - return (a && a->type() == onnx::AttributeProto::INT) ? a->i() : defaultVal; -} - -static float GetAttrFloat(const onnx::NodeProto &node, std::string_view name, - float defaultVal) { - const auto *a = FindAttr(node, name); - return (a && a->type() == onnx::AttributeProto::FLOAT) ? a->f() : defaultVal; -} - -static std::string getAttrString(const onnx::NodeProto &node, - std::string_view name, - std::string_view defaultVal) { - const auto *a = FindAttr(node, name); - return (a && a->type() == onnx::AttributeProto::STRING) - ? a->s() - : std::string(defaultVal); -} - -static llvm::SmallVector GetAttrInts(const onnx::NodeProto &node, - std::string_view name, - size_t fillLen = 0, - int64_t fillVal = 0) { - const auto *a = FindAttr(node, name); - if (a && a->type() == onnx::AttributeProto::INTS) - return {a->ints().begin(), a->ints().end()}; - return llvm::SmallVector(fillLen, fillVal); -} - -mlir::Value MLIRBuilder::FindValue(const std::string &name) const { - if (name.empty()) - return mlir::Value{}; - auto it = valueMap_.find(name); - return (it != valueMap_.end()) ? 
it->second : mlir::Value{}; -} - -mlir::Type MLIRBuilder::ConvertElemType(int onnx_dtype) { - switch (onnx_dtype) { - case onnx::TensorProto::FLOAT: - return builder_.getF32Type(); - case onnx::TensorProto::DOUBLE: - return builder_.getF64Type(); - case onnx::TensorProto::INT32: - return builder_.getI32Type(); - case onnx::TensorProto::INT64: - return builder_.getI64Type(); - case onnx::TensorProto::INT8: - return builder_.getI8Type(); - case onnx::TensorProto::UINT8: - return builder_.getIntegerType(8, false); - case onnx::TensorProto::BOOL: - return builder_.getI1Type(); - default: - llvm::errs() << "Unsupported dtype: " << onnx_dtype << "\n"; - llvm_unreachable("unsupported tensor dtype"); - } -} - -mlir::RankedTensorType -MLIRBuilder::ConvertTensorType(int onnx_dtype, llvm::ArrayRef shape) { - return mlir::RankedTensorType::get(shape, ConvertElemType(onnx_dtype)); -} - -mlir::DenseElementsAttr -MLIRBuilder::ConvertTensor(const onnx::TensorProto &tensor) { - llvm::SmallVector shape; - for (size_t i = 0; i < tensor.dims_size(); ++i) - shape.push_back(tensor.dims(i)); - auto type = ConvertTensorType(tensor.data_type(), shape); - - if (tensor.float_data_size() > 0) { - llvm::SmallVector data(tensor.float_data().begin(), - tensor.float_data().end()); - return mlir::DenseElementsAttr::get(type, llvm::ArrayRef(data)); - } - if (tensor.has_raw_data()) { - const auto &raw = tensor.raw_data(); - switch (tensor.data_type()) { - case onnx::TensorProto::FLOAT: { - auto *p = reinterpret_cast(raw.data()); - return mlir::DenseElementsAttr::get( - type, llvm::ArrayRef(p, raw.size() / sizeof(float))); - } - case onnx::TensorProto::INT64: { - auto *p = reinterpret_cast(raw.data()); - return mlir::DenseElementsAttr::get( - type, llvm::ArrayRef(p, raw.size() / sizeof(int64_t))); - } - case onnx::TensorProto::INT32: { - auto *p = reinterpret_cast(raw.data()); - return mlir::DenseElementsAttr::get( - type, llvm::ArrayRef(p, raw.size() / sizeof(int32_t))); - } - default: - 
llvm_unreachable("unsupported raw data type"); - } - } - if (tensor.int64_data_size() > 0) { - llvm::SmallVector data(tensor.int64_data().begin(), - tensor.int64_data().end()); - return mlir::DenseElementsAttr::get(type, llvm::ArrayRef(data)); - } - if (tensor.int32_data_size() > 0) { - llvm::SmallVector data(tensor.int32_data().begin(), - tensor.int32_data().end()); - return mlir::DenseElementsAttr::get(type, llvm::ArrayRef(data)); - } - llvm_unreachable("unsupported tensor data format"); -} - -mlir::RankedTensorType -MLIRBuilder::ConvertValueInfo(const onnx::ValueInfoProto &info) { - assert(info.has_type() && info.type().has_tensor_type()); - const auto &tt = info.type().tensor_type(); - llvm::SmallVector shape; - if (tt.has_shape()) - for (const auto &dim : tt.shape().dim()) - shape.push_back(dim.has_dim_value() ? dim.dim_value() : -1); - return ConvertTensorType(tt.elem_type(), shape); -} - -void MLIRBuilder::BuildRelu(const onnx::NodeProto &node, mlir::Location loc) { - mlir::Value in = valueMap_.at(node.input(0)); - auto op = mlir::nn::ReluOp::create(builder_, loc, in.getType(), in); - valueMap_[node.output(0)] = op.getResult(); -} - -void MLIRBuilder::BuildAdd(const onnx::NodeProto &node, mlir::Location loc) { - mlir::Value lhs = valueMap_.at(node.input(0)); - mlir::Value rhs = valueMap_.at(node.input(1)); - auto lhsTy = llvm::cast(lhs.getType()); - auto rhsTy = llvm::cast(rhs.getType()); - mlir::Type resultType = - lhsTy.getRank() >= rhsTy.getRank() ? lhs.getType() : rhs.getType(); - auto op = mlir::nn::AddOp::create(builder_, loc, resultType, lhs, rhs); - valueMap_[node.output(0)] = op.getResult(); -} - -void MLIRBuilder::BuildMul(const onnx::NodeProto &node, mlir::Location loc) { - mlir::Value lhs = valueMap_.at(node.input(0)); - mlir::Value rhs = valueMap_.at(node.input(1)); - auto lhsTy = llvm::cast(lhs.getType()); - auto rhsTy = llvm::cast(rhs.getType()); - mlir::Type resultType = - lhsTy.getRank() >= rhsTy.getRank() ? 
lhs.getType() : rhs.getType(); - auto op = mlir::nn::MulOp::create(builder_, loc, resultType, lhs, rhs); - valueMap_[node.output(0)] = op.getResult(); -} - -void MLIRBuilder::BuildMatMul(const onnx::NodeProto &node, mlir::Location loc) { - mlir::Value A = valueMap_.at(node.input(0)); - mlir::Value B = valueMap_.at(node.input(1)); - auto aTy = llvm::cast(A.getType()); - auto bTy = llvm::cast(B.getType()); - llvm::SmallVector shape(aTy.getShape()); - shape.back() = bTy.getShape().back(); - auto resultType = mlir::RankedTensorType::get(shape, aTy.getElementType()); - auto op = mlir::nn::MatMulOp::create(builder_, loc, resultType, A, B); - valueMap_[node.output(0)] = op.getResult(); -} - -void MLIRBuilder::BuildGemm(const onnx::NodeProto &node, mlir::Location loc) { - mlir::Value A = valueMap_.at(node.input(0)); - mlir::Value B = valueMap_.at(node.input(1)); - mlir::Value C = - (node.input_size() > 2) ? FindValue(node.input(2)) : mlir::Value{}; - - int64_t transA = GetAttrInt(node, "transA", 0); - int64_t transB = GetAttrInt(node, "transB", 0); - float alpha = GetAttrFloat(node, "alpha", 1.0f); - float beta = GetAttrFloat(node, "beta", 1.0f); - - auto aTy = llvm::cast(A.getType()); - auto bTy = llvm::cast(B.getType()); - int64_t M = transA ? aTy.getDimSize(1) : aTy.getDimSize(0); - int64_t N = transB ? bTy.getDimSize(0) : bTy.getDimSize(1); - auto resultType = mlir::RankedTensorType::get({M, N}, aTy.getElementType()); - - auto op = mlir::nn::GemmOp::create( - builder_, loc, resultType, A, B, C, builder_.getI64IntegerAttr(transA), - builder_.getI64IntegerAttr(transB), builder_.getF32FloatAttr(alpha), - builder_.getF32FloatAttr(beta)); - valueMap_[node.output(0)] = op.getResult(); -} - -void MLIRBuilder::BuildConv(const onnx::NodeProto &node, mlir::Location loc) { - mlir::Value input = valueMap_.at(node.input(0)); - mlir::Value weight = valueMap_.at(node.input(1)); - mlir::Value bias = - (node.input_size() > 2) ? 
FindValue(node.input(2)) : mlir::Value{}; - - auto inputTy = llvm::cast(input.getType()); - auto weightTy = llvm::cast(weight.getType()); - - int64_t spatialDims = inputTy.getRank() - 2; - - auto strides = GetAttrInts(node, "strides", spatialDims, 1); - auto dilations = GetAttrInts(node, "dilations", spatialDims, 1); - auto pads = GetAttrInts(node, "pads", spatialDims * 2, 0); - auto group = GetAttrInt(node, "group", 1); - auto autoPad = getAttrString(node, "auto_pad", "NOTSET"); - - int64_t N = inputTy.getDimSize(0); - int64_t Cout = weightTy.getDimSize(0); - - auto inputShape = inputTy.getShape(); - auto kernelShape = weightTy.getShape(); - - llvm::SmallVector outputShape = {N, Cout}; - - for (int i = 0; i < spatialDims; ++i) { - int64_t in = inputShape[2 + i]; - int64_t k = kernelShape[2 + i]; - int64_t pb = pads[i]; - int64_t pe = pads[i + spatialDims]; - int64_t s = strides[i]; - int64_t d = dilations[i]; - - if (in == mlir::ShapedType::kDynamic || k == mlir::ShapedType::kDynamic) { - outputShape.push_back(mlir::ShapedType::kDynamic); - continue; - } - - int64_t out = (in + pb + pe - d * (k - 1) - 1) / s + 1; - outputShape.push_back(out); - } - - auto resultType = - mlir::RankedTensorType::get(outputShape, inputTy.getElementType()); - - if (bias) { - auto biasTy = llvm::cast(bias.getType()); - if (biasTy.getRank() != 1 || biasTy.getDimSize(0) != Cout) { - llvm::errs() << "Conv bias must have shape [" << Cout << "]\n"; - llvm::report_fatal_error("Invalid Conv bias shape"); - } - } - - auto op = mlir::nn::ConvOp::create( - builder_, loc, resultType, input, weight, bias, - mlir::DenseI64ArrayAttr::get(builder_.getContext(), strides), - mlir::DenseI64ArrayAttr::get(builder_.getContext(), dilations), - mlir::DenseI64ArrayAttr::get(builder_.getContext(), pads), - builder_.getI64IntegerAttr(group), builder_.getStringAttr(autoPad)); - - valueMap_[node.output(0)] = op.getResult(); -} - -void MLIRBuilder::BuildTranspose(const onnx::NodeProto &node, - mlir::Location loc) 
{ - mlir::Value input = valueMap_.at(node.input(0)); - auto inputTy = llvm::cast(input.getType()); - int64_t rank = inputTy.getRank(); - - auto perm = GetAttrInts(node, "perm", 0, 0); - if (perm.empty()) { - perm.resize(rank); - for (int64_t i = 0; i < rank; ++i) - perm[i] = rank - 1 - i; - } - - llvm::SmallVector outShape(rank); - for (int64_t i = 0; i < rank; ++i) - outShape[i] = inputTy.getDimSize(perm[i]); - auto resultType = - mlir::RankedTensorType::get(outShape, inputTy.getElementType()); - - auto op = mlir::nn::TransposeOp::create( - builder_, loc, resultType, input, - mlir::DenseI64ArrayAttr::get(builder_.getContext(), perm)); - valueMap_[node.output(0)] = op.getResult(); -} - -MLIRBuilder::MLIRBuilder(mlir::MLIRContext &ctx) - : ctx_(ctx), builder_(&ctx_), - module_(mlir::ModuleOp::create(builder_.getUnknownLoc())) { - builder_.setInsertionPointToEnd(module_.getBody()); -} - -const mlir::ModuleOp &MLIRBuilder::GetModule() const { return module_; } - -void MLIRBuilder::Visit(const onnx::ModelProto &) { /* Do nothing */ } - -void MLIRBuilder::Visit(const onnx::GraphProto &graph) { - for (const auto &init : graph.initializer()) - initializerNames_.insert(init.name()); - - llvm::SmallVector inputTypes, outputTypes; - for (const auto &in : graph.input()) - if (!initializerNames_.count(in.name())) - inputTypes.push_back(ConvertValueInfo(in)); - for (const auto &out : graph.output()) - outputTypes.push_back(ConvertValueInfo(out)); - - auto funcName = graph.name().empty() ? 
"main" : graph.name(); - auto func = mlir::func::FuncOp::create( - builder_.getUnknownLoc(), funcName, - builder_.getFunctionType(inputTypes, outputTypes)); - builder_.insert(func); - - auto *block = func.addEntryBlock(); - builder_.setInsertionPointToStart(block); - - unsigned argIdx = 0; - for (const auto &in : graph.input()) - if (!initializerNames_.count(in.name())) - valueMap_[in.name()] = block->getArgument(argIdx++); -} - -void MLIRBuilder::Visit(const onnx::ValueInfoProto &) { /* Do nothing */ } - -void MLIRBuilder::Visit(const onnx::TensorProto &tensor) { - auto attr = ConvertTensor(tensor); - auto op = - mlir::arith::ConstantOp::create(builder_, builder_.getUnknownLoc(), attr); - valueMap_[tensor.name()] = op.getResult(); -} - -void MLIRBuilder::Visit(const onnx::NodeProto &node) { - auto loc = builder_.getUnknownLoc(); - const auto &opType = node.op_type(); - - if (opType == "Relu") - BuildRelu(node, loc); - else if (opType == "Add") - BuildAdd(node, loc); - else if (opType == "Mul") - BuildMul(node, loc); - else if (opType == "MatMul") - BuildMatMul(node, loc); - else if (opType == "Gemm") - BuildGemm(node, loc); - else if (opType == "Conv") - BuildConv(node, loc); - else if (opType == "Transpose") - BuildTranspose(node, loc); - else - llvm::errs() << "[warn] Unsupported op: " << opType << "\n"; -} - -void MLIRBuilder::Visit(const onnx::AttributeProto &) { /* Do nothing */ } - -void MLIRBuilder::Finalize(const onnx::GraphProto &graph) { - llvm::SmallVector outputs; - for (size_t i = 0; i < graph.output_size(); ++i) - outputs.push_back(valueMap_.at(graph.output(i).name())); - mlir::func::ReturnOp::create(builder_, builder_.getUnknownLoc(), outputs); -} -} // namespace tc::converter::onnx_to_high_mlir \ No newline at end of file diff --git a/lib/Frontend/CMakeLists.txt b/lib/Frontend/CMakeLists.txt index c49069d..6385647 100644 --- a/lib/Frontend/CMakeLists.txt +++ b/lib/Frontend/CMakeLists.txt @@ -1,4 +1,4 @@ -add_library(tc-frontend ONNXDumper.cpp) 
+add_library(tc-frontend GraphBuilder.cpp ONNXDumper.cpp) target_link_libraries(tc-frontend onnx diff --git a/lib/Converter/GraphBuilder.cpp b/lib/Frontend/GraphBuilder.cpp similarity index 73% rename from lib/Converter/GraphBuilder.cpp rename to lib/Frontend/GraphBuilder.cpp index b6229ec..a4e31ae 100644 --- a/lib/Converter/GraphBuilder.cpp +++ b/lib/Frontend/GraphBuilder.cpp @@ -1,8 +1,8 @@ -#include "TensorCompiler/Converter/GraphBuilder.hpp" +#include "TensorCompiler/Frontend/GraphBuilder.hpp" #include -namespace tc::converter::onnx_to_graph { +namespace tc::frontend { using tc::graph::AttrValue; using tc::graph::DoubleList; using tc::graph::EntityId; @@ -13,6 +13,7 @@ using tc::graph::TensorData; TensorData GraphBuilder::ParseTensor(const onnx::TensorProto &t) { TensorData td; td.name = t.name(); + td.dtype = t.data_type(); td.dims.reserve(t.dims_size()); for (auto d : t.dims()) @@ -74,52 +75,59 @@ AttrValue GraphBuilder::ParseAttribute(const onnx::AttributeProto &a) { } } -EntityId GraphBuilder::EnsureTensor(const std::string &name, - const std::vector &shape) { +EntityId +GraphBuilder::EnsureTensor(const std::string &name, + const std::vector &shape = {}, + int32_t dtype = /* onnx::TensorProto::FLOAT */ 1) { if (name.empty()) return -1; - return graph_.AddTensor(name, shape, initializerNames_.count(name) != 0); + return graph_.AddTensor(name, shape, dtype, + initializerNames_.count(name) != 0); } void GraphBuilder::Visit(const onnx::ModelProto &) { /* Do nothing */ } void GraphBuilder::Visit(const onnx::GraphProto &g) { initializerNames_.clear(); + inputNames_.clear(); + outputNames_.clear(); for (auto &init : g.initializer()) initializerNames_.insert(init.name()); for (auto &in : g.input()) - if (!initializerNames_.count(in.name())) - graph_.AddInput(EnsureTensor(in.name())); + inputNames_.insert(in.name()); for (auto &out : g.output()) - graph_.AddOutput(EnsureTensor(out.name())); + outputNames_.insert(out.name()); } void GraphBuilder::Visit(const 
onnx::ValueInfoProto &value) { - if (!value.name().empty() && value.has_type() && - value.type().has_tensor_type()) { + if (value.name().empty() || !value.has_type() || + !value.type().has_tensor_type()) + return; - std::vector shape; + const auto &tt = value.type().tensor_type(); + int32_t dtype = tt.elem_type(); - const auto &tt = value.type().tensor_type(); - if (tt.has_shape()) { - for (auto &dim : tt.shape().dim()) - shape.push_back(dim.has_dim_value() ? dim.dim_value() : -1); - } - - EnsureTensor(value.name(), shape); + std::vector shape; + if (tt.has_shape()) { + for (auto &dim : tt.shape().dim()) + shape.push_back(dim.has_dim_value() ? dim.dim_value() : -1); } + + EntityId tid = EnsureTensor(value.name(), shape, dtype); + + if (inputNames_.count(value.name()) && !initializerNames_.count(value.name())) + graph_.AddInput(tid); + if (outputNames_.count(value.name())) + graph_.AddOutput(tid); } void GraphBuilder::Visit(const onnx::TensorProto &tensor) { TensorData td = ParseTensor(tensor); - - graph_.AddConstant(td); - - EntityId tid = EnsureTensor(td.name, td.dims); + EntityId tid = EnsureTensor(td.name, td.dims, td.dtype); if (auto *t = graph_.GetTensor(tid)) { t->is_initializer = true; @@ -153,4 +161,4 @@ void GraphBuilder::Visit(const onnx::AttributeProto &) { /* Do nothing */ } void GraphBuilder::Finalize(const onnx::GraphProto &) { /* Do nothing */ } -} // namespace tc::converter::onnx_to_graph \ No newline at end of file +} // namespace tc::frontend \ No newline at end of file diff --git a/lib/Graph/CMakeLists.txt b/lib/Graph/CMakeLists.txt index e384c2a..e4b01d8 100644 --- a/lib/Graph/CMakeLists.txt +++ b/lib/Graph/CMakeLists.txt @@ -1,4 +1,4 @@ -add_library(tc-graph IR.cpp Exporter.cpp) +add_library(tc-graph Graph.cpp GraphDumper.cpp) target_include_directories(tc-graph PUBLIC ${PROJECT_SOURCE_DIR}/include) \ No newline at end of file diff --git a/lib/Graph/Exporter.cpp b/lib/Graph/Exporter.cpp deleted file mode 100644 index 13337ab..0000000 --- 
a/lib/Graph/Exporter.cpp +++ /dev/null @@ -1,142 +0,0 @@ -#include "TensorCompiler/Graph/Exporter.hpp" -#include -#include - -namespace tc::graph { - -static constexpr size_t PAD = 2; -static constexpr size_t OP_PAD = 4; -static constexpr size_t IO_PAD = 6; - -static std::string indent(size_t n) { - return std::string(n, ' '); -} - -static std::string ShapeToStr(const std::vector& shape) { - std::ostringstream ss; - ss << "["; - for (size_t i = 0; i < shape.size(); ++i) { - if (i) ss << ", "; - ss << shape[i]; - } - ss << "]"; - return ss.str(); -} - -std::string DumpGraph(const Graph& g) { - std::ostringstream os; - - os << "=== Graph ===\n"; - - os << "\nInputs:\n"; - for (auto id : g.Inputs()) { - if (auto* t = g.GetTensor(id)) - os << indent(PAD) << t->name << " " << ShapeToStr(t->shape) << "\n"; - } - - os << "\nOperations:\n"; - - for (const auto& e : g.Entities()) { - if (e.kind != EntityKind::Operation) - continue; - - const auto& op = std::get(e.entity); - - os << "\n" << indent(PAD) - << "[" << op.op_type << "] " << op.name << "\n"; - - os << indent(OP_PAD) << "inputs:\n"; - for (auto in : op.inputs) - if (auto* t = g.GetTensor(in)) - os << indent(IO_PAD) << t->name << "\n"; - - os << indent(OP_PAD) << "outputs:\n"; - for (auto out : op.outputs) - if (auto* t = g.GetTensor(out)) - os << indent(IO_PAD) << t->name << "\n"; - } - - os << "\nOutputs:\n"; - for (auto id : g.Outputs()) { - if (auto* t = g.GetTensor(id)) - os << indent(PAD) << t->name << "\n"; - } - - os << "\nConstants:\n"; - for (const auto &e : g.Entities()) { - if (e.kind != EntityKind::Tensor) - continue; - - const auto &t = std::get(e.entity); - - if (!t.is_initializer) - continue; - - os << indent(PAD) - << t.name << " " - << ShapeToStr(t.shape) - << "\n"; - } - - os << "\nTensors:\n"; - for (const auto &e : g.Entities()) { - if (e.kind != EntityKind::Tensor) - continue; - - const auto &t = std::get(e.entity); - - os << indent(PAD) - << t.name - << " " << ShapeToStr(t.shape); - - if 
(t.is_initializer) - os << " (init)"; - - os << "\n"; - } - - return os.str(); -} - -std::string ToDot(const Graph& g) { - std::ostringstream os; - os << "digraph G {\n"; - os << " rankdir=LR;\n"; - - for (size_t i = 0; i < g.Entities().size(); ++i) { - const auto& e = g.Entities()[i]; - - if (e.kind == EntityKind::Tensor) { - const auto& t = std::get(e.entity); - os << " t" << i - << " [shape=oval,label=\"" << t.name << "\"];\n"; - } - - if (e.kind == EntityKind::Operation) { - const auto& op = std::get(e.entity); - - os << " op" << i - << " [shape=box,label=\"" << op.op_type << "\"];\n"; - - for (auto in : op.inputs) - os << " t" << in << " -> op" << i << ";\n"; - - for (auto out : op.outputs) - os << " op" << i << " -> t" << out << ";\n"; - } - } - - os << "}\n"; - return os.str(); -} - -bool SaveDot(const Graph& g, const std::string& file) { - std::ofstream f(file); - if (!f) - return false; - f << ToDot(g); - return true; -} - - -} // namespace tc::graph \ No newline at end of file diff --git a/lib/Graph/IR.cpp b/lib/Graph/Graph.cpp similarity index 58% rename from lib/Graph/IR.cpp rename to lib/Graph/Graph.cpp index 3603117..cb6956c 100644 --- a/lib/Graph/IR.cpp +++ b/lib/Graph/Graph.cpp @@ -1,13 +1,15 @@ -#include "TensorCompiler/Graph/IR.hpp" +#include "TensorCompiler/Graph/Graph.hpp" +#include "TensorCompiler/Graph/GraphVisitor.hpp" namespace tc::graph { + EntityId Graph::AddTensor(const std::string &name, std::vector shape, - bool is_init) { + int32_t dtype, bool is_init) { if (auto it = tensorByName_.find(name); it != tensorByName_.end()) return it->second; EntityId id = entities_.size(); - TensorEntity t{id, name, shape, is_init, std::nullopt}; + TensorEntity t{id, name, shape, is_init, dtype, std::nullopt}; entities_.emplace_back(EntityKind::Tensor, t); tensorByName_[name] = id; return id; @@ -32,14 +34,32 @@ Graph::AddOperation(const std::string &op, const std::string &name, } TensorEntity *Graph::GetTensor(EntityId id) { - if (entities_[id].kind != 
EntityKind::Tensor) + if (id >= entities_.size() || entities_[id].kind != EntityKind::Tensor) return nullptr; return &std::get(entities_[id].entity); } const TensorEntity *Graph::GetTensor(EntityId id) const { - if (entities_[id].kind != EntityKind::Tensor) + if (id >= entities_.size() || entities_[id].kind != EntityKind::Tensor) return nullptr; return &std::get(entities_[id].entity); } + +void Graph::Accept(GraphVisitor &visitor) const { + visitor.Visit(*this); + for (const auto &entity : Entities()) { + switch (entity.kind) { + case EntityKind::Tensor: + visitor.Visit(std::get(entity.entity)); + break; + case EntityKind::Operation: + visitor.Visit(std::get(entity.entity)); + break; + case EntityKind::Constant: + visitor.Visit(std::get(entity.entity)); + break; + } + } + visitor.Finalize(*this); +} } // namespace tc::graph \ No newline at end of file diff --git a/lib/Graph/GraphDumper.cpp b/lib/Graph/GraphDumper.cpp new file mode 100644 index 0000000..0e83269 --- /dev/null +++ b/lib/Graph/GraphDumper.cpp @@ -0,0 +1,200 @@ +#include "TensorCompiler/Graph/GraphDumper.hpp" +#include "TensorCompiler/Graph/GraphVisitor.hpp" +#include +#include +#include +#include + +namespace tc::graph { +namespace { + +static std::string indent(size_t n) { return std::string(n, ' '); } + +static std::string ShapeToStr(const std::vector &shape) { + std::ostringstream ss; + ss << "["; + for (size_t i = 0; i < shape.size(); ++i) { + if (i) + ss << ", "; + ss << shape[i]; + } + ss << "]"; + return ss.str(); +} + +class DumpGraphVisitor final : public tc::graph::GraphVisitor { +public: + void Visit(const Graph &graph) override { graph_ = &graph; } + + void Visit(const TensorEntity &tensor) override { + tensors_.push_back(&tensor); + } + + void Visit(const OperationEntity &op) override { ops_.push_back(&op); } + + void Visit(const ConstantEntity &) override { /* Do nothing */ } + + void Finalize(const Graph & /*graph*/) override { + std::ostringstream oss; + oss << "=== Graph ===\n"; + + 
oss << "\nInputs:\n"; + for (auto id : graph_->Inputs()) { + if (auto *t = graph_->GetTensor(id)) + oss << indent(2) << t->name << " " << ShapeToStr(t->shape) << "\n"; + } + + oss << "\nOperations:\n"; + for (const auto *op : ops_) { + oss << "\n" + << indent(2) << "[" << op->op_type << "] " << op->name << "\n"; + oss << indent(4) << "inputs:\n"; + for (auto in : op->inputs) { + if (auto *t = graph_->GetTensor(in)) + oss << indent(6) << t->name << "\n"; + } + oss << indent(4) << "outputs:\n"; + for (auto out : op->outputs) { + if (auto *t = graph_->GetTensor(out)) + oss << indent(6) << t->name << "\n"; + } + } + + oss << "\nOutputs:\n"; + for (auto id : graph_->Outputs()) { + if (auto *t = graph_->GetTensor(id)) + oss << indent(2) << t->name << "\n"; + } + + oss << "\nConstants:\n"; + for (const auto *t : tensors_) { + if (t->is_initializer) + oss << indent(2) << t->name << " " << ShapeToStr(t->shape) << "\n"; + } + + oss << "\nTensors:\n"; + for (const auto *t : tensors_) { + oss << indent(2) << t->name << " " << ShapeToStr(t->shape); + if (t->is_initializer) + oss << " (init)"; + oss << "\n"; + } + + result_ = oss.str(); + } + + std::string GetResult() const { return result_; } + +private: + const Graph *graph_ = nullptr; + std::vector tensors_; + std::vector ops_; + std::string result_; +}; + +class DotGraphVisitor final : public tc::graph::GraphVisitor { +public: + void Visit(const Graph &graph) override { + graph_ = &graph; + oss_ << "digraph G {\n"; + oss_ << " rankdir=LR;\n"; + } + + void Visit(const TensorEntity &tensor) override { + int id = nextId(); + tensorIds_[&tensor] = id; + oss_ << " t" << id << " [shape=oval,label=\"" << tensor.name << "\"];\n"; + } + + void Visit(const OperationEntity &op) override { + int id = nextId(); + opIds_[&op] = id; + ops_.push_back(&op); + oss_ << " op" << id << " [shape=box,label=\"" << op.op_type << "\"];\n"; + } + + void Visit(const ConstantEntity &) override { /* Do nothing */ } + + void Finalize(const Graph & 
/*graph*/) override { + for (const auto *op : ops_) { + auto opIt = opIds_.find(op); + if (opIt == opIds_.end()) + continue; + int opId = opIt->second; + + for (auto inId : op->inputs) { + auto *t = graph_->GetTensor(inId); + if (!t) + continue; + auto tIt = tensorIds_.find(t); + if (tIt != tensorIds_.end()) + oss_ << " t" << tIt->second << " -> op" << opId << ";\n"; + } + for (auto outId : op->outputs) { + auto *t = graph_->GetTensor(outId); + if (!t) + continue; + auto tIt = tensorIds_.find(t); + if (tIt != tensorIds_.end()) + oss_ << " op" << opId << " -> t" << tIt->second << ";\n"; + } + } + + oss_ << "}\n"; + result_ = oss_.str(); + } + + std::string GetResult() const { return result_; } + +private: + int nextId() { return idCounter_++; } + + const Graph *graph_ = nullptr; + std::ostringstream oss_; + int idCounter_ = 0; + std::vector ops_; + std::unordered_map tensorIds_; + std::unordered_map opIds_; + std::string result_; +}; + +} // namespace + +std::string DumpGraph(const Graph &g) { + DumpGraphVisitor visitor; + visitor.Visit(g); + for (const auto &e : g.Entities()) { + if (e.kind == EntityKind::Tensor) + visitor.Visit(std::get(e.entity)); + else if (e.kind == EntityKind::Operation) + visitor.Visit(std::get(e.entity)); + else if (e.kind == EntityKind::Constant) + visitor.Visit(std::get(e.entity)); + } + visitor.Finalize(g); + return visitor.GetResult(); +} + +std::string ToDot(const Graph &g) { + DotGraphVisitor visitor; + visitor.Visit(g); + for (const auto &e : g.Entities()) { + if (e.kind == EntityKind::Tensor) + visitor.Visit(std::get(e.entity)); + else if (e.kind == EntityKind::Operation) + visitor.Visit(std::get(e.entity)); + else if (e.kind == EntityKind::Constant) + visitor.Visit(std::get(e.entity)); + } + visitor.Finalize(g); + return visitor.GetResult(); +} + +bool SaveDot(const Graph &g, const std::string &file) { + std::ofstream f(file); + if (!f) + return false; + f << ToDot(g); + return true; +} +} // namespace tc::graph \ No newline at end 
of file diff --git a/tensor-compiler.cpp b/tensor-compiler.cpp index c45f5cf..d2d122b 100644 --- a/tensor-compiler.cpp +++ b/tensor-compiler.cpp @@ -1,8 +1,8 @@ -#include "TensorCompiler/Converter/GraphBuilder.hpp" -#include "TensorCompiler/Converter/MLIRBuilder.hpp" +#include "TensorCompiler/Conversion/MLIRBuilder.hpp" #include "TensorCompiler/Dialect/NNDialect.hpp" +#include "TensorCompiler/Frontend/GraphBuilder.hpp" #include "TensorCompiler/Frontend/ONNXModel.hpp" -#include "TensorCompiler/Graph/Exporter.hpp" +#include "TensorCompiler/Graph/GraphDumper.hpp" #include #include @@ -20,6 +20,9 @@ static cl::OptionCategory TCOptions("Tensor Compiler Options"); static cl::opt InputModel(cl::Positional, cl::desc(""), cl::Required, cl::cat(TCOptions)); +static cl::opt DumpGraph("graph-dump", cl::desc("Dump graph"), + cl::init(false), cl::cat(TCOptions)); + static cl::opt DumpGraphDot("graph-dot-dump", cl::desc("Dump graph to graph.dot"), cl::init(false), cl::cat(TCOptions)); @@ -34,11 +37,13 @@ int main(int argc, char **argv) { try { tc::frontend::ONNXModel model{InputModel}; - tc::converter::onnx_to_graph::GraphBuilder graphBuilder; + tc::frontend::GraphBuilder graphBuilder; model.Parse(graphBuilder); const auto &graph = graphBuilder.GetGraph(); - std::cout << tc::graph::DumpGraph(graph); + if (DumpGraph) { + llvm::outs() << tc::graph::DumpGraph(graph); + } if (DumpGraphDot) { if (!tc::graph::SaveDot(graph, "graph.dot")) { @@ -48,21 +53,17 @@ int main(int argc, char **argv) { llvm::outs() << "DOT graph saved to graph.dot\n"; } - if (DumpHIR) { - mlir::MLIRContext ctx; - ctx.loadDialect(); - - tc::converter::onnx_to_high_mlir::MLIRBuilder builder(ctx); - model.Parse(builder); + mlir::MLIRContext ctx; + ctx.loadDialect(); - const mlir::ModuleOp &module = builder.GetModule(); + tc::conversion::HighLevelMLIRBuilder builder(ctx); + const mlir::ModuleOp &module = builder.Build(graph); - llvm::outs() << "\n=== HIR Dialect Dump ===\n"; + if (DumpHIR) { + llvm::outs() << "HIR 
Dialect Dump:\n"; module->print(llvm::outs()); - llvm::outs() << "\n"; } - } catch (const std::exception &e) { llvm::errs() << "Compilation failed: " << e.what() << "\n"; return 1; diff --git a/tests/unit/graph.cpp b/tests/unit/graph.cpp index 7accc63..ef77f01 100644 --- a/tests/unit/graph.cpp +++ b/tests/unit/graph.cpp @@ -1,33 +1,37 @@ -#include "TensorCompiler/Graph/IR.hpp" -#include "TensorCompiler/Graph/Exporter.hpp" +#include "TensorCompiler/Graph/Graph.hpp" +#include "TensorCompiler/Graph/GraphDumper.hpp" #include #include using namespace tc::graph; -static TensorData MakeTensorData(std::string name, std::vector dims) { +static constexpr int32_t FLOAT_DTYPE = 1; // onnx::TensorProto::FLOAT + +static TensorData MakeTensorData(std::string name, std::vector dims, + int32_t dtype = FLOAT_DTYPE) { TensorData td; td.name = std::move(name); td.dims = std::move(dims); + td.dtype = dtype; return td; } -template -const T& GetScalar(const AttrValue& v) { +template const T &GetScalar(const AttrValue &v) { return std::get(std::get(v)); } TEST(GraphEntityTest, AddTensorAndGet) { Graph g; - EntityId id = g.AddTensor("t0", {1, 2, 3}, false); + EntityId id = g.AddTensor("t0", {1, 2, 3}, FLOAT_DTYPE, false); EXPECT_GE(id, 0); TensorEntity *t = g.GetTensor(id); ASSERT_NE(t, nullptr); EXPECT_EQ(t->name, "t0"); - EXPECT_EQ(t->shape, std::vector({1,2,3})); + EXPECT_EQ(t->shape, std::vector({1, 2, 3})); EXPECT_FALSE(t->is_initializer); + EXPECT_EQ(t->dtype, FLOAT_DTYPE); } TEST(GraphEntityTest, AddConstantAndInitializerFlag) { @@ -36,7 +40,7 @@ TEST(GraphEntityTest, AddConstantAndInitializerFlag) { TensorData data = MakeTensorData("w0", {4, 4}); g.AddConstant(data); - EntityId tId = g.AddTensor("w0", {4,4}, true); + EntityId tId = g.AddTensor("w0", {4, 4}, data.dtype, true); TensorEntity *t = g.GetTensor(tId); ASSERT_NE(t, nullptr); EXPECT_TRUE(t->is_initializer); @@ -48,7 +52,8 @@ TEST(GraphEntityTest, AddConstantAndInitializerFlag) { const auto &c = std::get(e.entity); if 
(c.data.name == "w0") { foundConstant = true; - EXPECT_EQ(c.data.dims, std::vector({4,4})); + EXPECT_EQ(c.data.dims, std::vector({4, 4})); + EXPECT_EQ(c.data.dtype, FLOAT_DTYPE); } } @@ -57,9 +62,9 @@ TEST(GraphEntityTest, AddConstantAndInitializerFlag) { TEST(GraphOperationTest, AddOperationCreatesCorrectEntity) { Graph g; - EntityId a = g.AddTensor("a"); - EntityId b = g.AddTensor("b"); - EntityId out = g.AddTensor("out"); + EntityId a = g.AddTensor("a", {}, FLOAT_DTYPE); + EntityId b = g.AddTensor("b", {}, FLOAT_DTYPE); + EntityId out = g.AddTensor("out", {}, FLOAT_DTYPE); std::unordered_map attrs; attrs["alpha"] = int64_t(2); @@ -73,7 +78,8 @@ TEST(GraphOperationTest, AddOperationCreatesCorrectEntity) { bool found = false; for (const auto &e : ents) { - if (e.kind != EntityKind::Operation) continue; + if (e.kind != EntityKind::Operation) + continue; const auto &op = std::get(e.entity); @@ -101,26 +107,29 @@ TEST(GraphOperationTest, AddOperationCreatesCorrectEntity) { TEST(GraphExporterTest, DumpGraphIncludesConstantsAndIO) { Graph g; - EntityId image = g.AddTensor("image", {}); - EntityId vec = g.AddTensor("vector", {128}); + EntityId image = g.AddTensor("image", {}, FLOAT_DTYPE); + EntityId vec = g.AddTensor("vector", {128}, FLOAT_DTYPE); - TensorData conv_w = MakeTensorData("conv_w", {64,3,7,7}); + TensorData conv_w = MakeTensorData("conv_w", {64, 3, 7, 7}); g.AddConstant(conv_w); - EntityId conv_w_t = g.AddTensor("conv_w", {64,3,7,7}, true); + EntityId conv_w_t = g.AddTensor("conv_w", {64, 3, 7, 7}, conv_w.dtype, true); TensorData conv_b = MakeTensorData("conv_b", {64}); g.AddConstant(conv_b); - EntityId conv_b_t = g.AddTensor("conv_b", {64}, true); + EntityId conv_b_t = g.AddTensor("conv_b", {64}, conv_b.dtype, true); - EntityId c = g.AddTensor("c"); + EntityId c = g.AddTensor("c", {}, FLOAT_DTYPE); g.AddOperation("Conv", "conv1", {image, conv_w_t, conv_b_t}, {c}, {}); - EntityId r = g.AddTensor("r"); + EntityId r = g.AddTensor("r", {}, FLOAT_DTYPE); 
g.AddOperation("Relu", "relu1", {c}, {r}, {}); - EntityId out = g.AddTensor("output"); + EntityId out = g.AddTensor("output", {}, FLOAT_DTYPE); g.AddOperation("Identity", "out_op", {r}, {out}, {}); + g.AddInput(image); + g.AddOutput(out); + std::string dump = DumpGraph(g); EXPECT_NE(dump.find("Inputs:"), std::string::npos); @@ -139,9 +148,9 @@ TEST(GraphExporterTest, DumpGraphIncludesConstantsAndIO) { TEST(GraphExporterTest, ToDotProducesEdges) { Graph g; - EntityId t0 = g.AddTensor("t0"); - EntityId t1 = g.AddTensor("t1"); - EntityId out = g.AddTensor("out"); + EntityId t0 = g.AddTensor("t0", {}, FLOAT_DTYPE); + EntityId t1 = g.AddTensor("t1", {}, FLOAT_DTYPE); + EntityId out = g.AddTensor("out", {}, FLOAT_DTYPE); g.AddOperation("OpA", "opA", {t0, t1}, {out}, {}); diff --git a/tests/unit/onnx2graph.cpp b/tests/unit/onnx2graph.cpp index 4309472..5c12b1c 100644 --- a/tests/unit/onnx2graph.cpp +++ b/tests/unit/onnx2graph.cpp @@ -1,54 +1,64 @@ +#include "TensorCompiler/Frontend/GraphBuilder.hpp" #include "TensorCompiler/Frontend/ONNXModel.hpp" -#include "TensorCompiler/Converter/GraphBuilder.hpp" -#include "TensorCompiler/Graph/IR.hpp" +#include "TensorCompiler/Graph/Graph.hpp" #include using namespace tc; -using namespace tc::converter::onnx_to_graph; +using namespace tc::frontend; using namespace tc::graph; static onnx::ModelProto BuildSimpleModel() { onnx::ModelProto model; - auto* graph = model.mutable_graph(); + auto *graph = model.mutable_graph(); graph->set_name("test_graph"); - auto* input = graph->add_input(); + auto *input = graph->add_input(); input->set_name("x"); - - auto* weight = graph->add_initializer(); + auto *input_type = input->mutable_type(); + auto *tensor_type = input_type->mutable_tensor_type(); + tensor_type->set_elem_type(onnx::TensorProto::FLOAT); + auto *shape = tensor_type->mutable_shape(); + shape->add_dim()->set_dim_value(2); + shape->add_dim()->set_dim_value(2); + + auto *weight = graph->add_initializer(); weight->set_name("w"); 
weight->add_dims(2); weight->add_dims(2); weight->set_data_type(onnx::TensorProto::FLOAT); - weight->add_float_data(1); - weight->add_float_data(1); - weight->add_float_data(1); - weight->add_float_data(1); + for (int i = 0; i < 4; ++i) + weight->add_float_data(1.0f); - auto* node = graph->add_node(); + auto *node = graph->add_node(); node->set_op_type("Relu"); node->add_input("x"); node->add_output("y"); - auto* output = graph->add_output(); + auto *output = graph->add_output(); output->set_name("y"); + auto *output_type = output->mutable_type(); + auto *output_tensor_type = output_type->mutable_tensor_type(); + output_tensor_type->set_elem_type(onnx::TensorProto::FLOAT); + auto *output_shape = output_tensor_type->mutable_shape(); + output_shape->add_dim()->set_dim_value(2); + output_shape->add_dim()->set_dim_value(2); return model; } -static Graph BuildGraphFromModel(const onnx::ModelProto& model) { +static Graph BuildGraphFromModel(const onnx::ModelProto &model) { GraphBuilder builder; builder.Visit(model); builder.Visit(model.graph()); - for (const auto& in : model.graph().input()) + for (const auto &in : model.graph().input()) builder.Visit(in); - - for (const auto& init : model.graph().initializer()) + for (const auto &out : model.graph().output()) + builder.Visit(out); + for (const auto &init : model.graph().initializer()) builder.Visit(init); - - for (const auto& node : model.graph().node()) + for (const auto &node : model.graph().node()) builder.Visit(node); builder.Finalize(model.graph()); @@ -57,7 +67,7 @@ static Graph BuildGraphFromModel(const onnx::ModelProto& model) { static Graph ONNXtoGraph(std::string_view fileName) { tc::frontend::ONNXModel model{fileName}; - tc::converter::onnx_to_graph::GraphBuilder builder; + tc::frontend::GraphBuilder builder; model.Parse(builder); return builder.GetGraph(); } @@ -66,7 +76,6 @@ static bool hasOp(const Graph &g, const std::string &opType) { for (const auto &e : g.Entities()) { if (e.kind != 
EntityKind::Operation) continue; - const auto &op = std::get(e.entity); if (op.op_type == opType) return true; @@ -76,9 +85,13 @@ static bool hasOp(const Graph &g, const std::string &opType) { static size_t countConstants(const Graph &g) { size_t n = 0; - for (auto &e : g.Entities()) - if (e.kind == EntityKind::Constant) - ++n; + for (const auto &e : g.Entities()) { + if (e.kind == EntityKind::Tensor) { + const auto &t = std::get(e.entity); + if (t.is_initializer) + ++n; + } + } return n; } @@ -92,33 +105,28 @@ TEST(ONNXGraphBuilderTest, BuildsGraphStructure) { bool foundTensorX = false; bool foundTensorY = false; bool foundInitializer = false; - bool foundConstant = false; bool foundRelu = false; - for (const auto& e : g.Entities()) { - + for (const auto &e : g.Entities()) { if (e.kind == EntityKind::Tensor) { - const auto& t = std::get(e.entity); - if (t.name == "x") foundTensorX = true; - if (t.name == "y") foundTensorY = true; - if (t.name == "w" && t.is_initializer) foundInitializer = true; - } - - if (e.kind == EntityKind::Constant) { - const auto& c = std::get(e.entity); - if (c.data.name == "w") foundConstant = true; + const auto &t = std::get(e.entity); + if (t.name == "x") + foundTensorX = true; + if (t.name == "y") + foundTensorY = true; + if (t.name == "w" && t.is_initializer) + foundInitializer = true; } - if (e.kind == EntityKind::Operation) { - const auto& op = std::get(e.entity); - if (op.op_type == "Relu") foundRelu = true; + const auto &op = std::get(e.entity); + if (op.op_type == "Relu") + foundRelu = true; } } EXPECT_TRUE(foundTensorX); EXPECT_TRUE(foundTensorY); EXPECT_TRUE(foundInitializer); - EXPECT_TRUE(foundConstant); EXPECT_TRUE(foundRelu); } @@ -128,11 +136,10 @@ TEST(ONNXGraphBuilderTest, DetectsInitializerTensor) { bool foundInitializer = false; - for (const auto& e : g.Entities()) { + for (const auto &e : g.Entities()) { if (e.kind != EntityKind::Tensor) continue; - - const auto& t = std::get(e.entity); + const auto &t = 
std::get(e.entity); if (t.name == "w") { foundInitializer = true; EXPECT_TRUE(t.is_initializer); @@ -150,12 +157,10 @@ TEST(ONNXGraphBuilderTest, OperationCreatedCorrectly) { bool foundRelu = false; - for (const auto& e : g.Entities()) { + for (const auto &e : g.Entities()) { if (e.kind != EntityKind::Operation) continue; - - const auto& op = std::get(e.entity); - + const auto &op = std::get(e.entity); if (op.op_type == "Relu") { foundRelu = true; EXPECT_EQ(op.inputs.size(), 1); @@ -168,29 +173,40 @@ TEST(ONNXGraphBuilderTest, OperationCreatedCorrectly) { TEST(ONNXGraphBuilderTest, HandlesBinaryOp) { onnx::ModelProto model; - auto* graph = model.mutable_graph(); + auto *graph = model.mutable_graph(); + + auto *input_a = graph->add_input(); + input_a->set_name("a"); + auto *type_a = input_a->mutable_type()->mutable_tensor_type(); + type_a->set_elem_type(onnx::TensorProto::FLOAT); + type_a->mutable_shape()->add_dim()->set_dim_value(2); - graph->add_input()->set_name("a"); - graph->add_input()->set_name("b"); + auto *input_b = graph->add_input(); + input_b->set_name("b"); + auto *type_b = input_b->mutable_type()->mutable_tensor_type(); + type_b->set_elem_type(onnx::TensorProto::FLOAT); + type_b->mutable_shape()->add_dim()->set_dim_value(2); - auto* node = graph->add_node(); + auto *node = graph->add_node(); node->set_op_type("Add"); node->add_input("a"); node->add_input("b"); node->add_output("c"); - graph->add_output()->set_name("c"); + auto *output = graph->add_output(); + output->set_name("c"); + auto *type_c = output->mutable_type()->mutable_tensor_type(); + type_c->set_elem_type(onnx::TensorProto::FLOAT); + type_c->mutable_shape()->add_dim()->set_dim_value(2); Graph g = BuildGraphFromModel(model); bool foundAdd = false; - for (auto& e : g.Entities()) { + for (auto &e : g.Entities()) { if (e.kind != EntityKind::Operation) continue; - - auto& op = std::get(e.entity); - + auto &op = std::get(e.entity); if (op.op_type == "Add") { foundAdd = true; 
EXPECT_EQ(op.inputs.size(), 2); @@ -203,47 +219,40 @@ TEST(ONNXGraphBuilderTest, HandlesBinaryOp) { TEST(ONNXtoGraph, AddMul) { auto g = ONNXtoGraph("tests/data/add_mul.onnx"); - EXPECT_TRUE(hasOp(g, "Add")); EXPECT_TRUE(hasOp(g, "Mul")); } TEST(ONNXtoGraph, ConvRelu) { auto g = ONNXtoGraph("tests/data/conv_relu.onnx"); - EXPECT_TRUE(hasOp(g, "Conv")); EXPECT_TRUE(hasOp(g, "Relu")); } TEST(ONNXtoGraph, MatMulRelu) { auto g = ONNXtoGraph("tests/data/matmul_relu.onnx"); - EXPECT_TRUE(hasOp(g, "MatMul")); EXPECT_TRUE(hasOp(g, "Relu")); } TEST(ONNXtoGraph, SingleRelu) { auto g = ONNXtoGraph("tests/data/relu.onnx"); - EXPECT_TRUE(hasOp(g, "Relu")); } TEST(ONNXtoGraph, TransposeMatMul) { auto g = ONNXtoGraph("tests/data/transpose_matmul.onnx"); - EXPECT_TRUE(hasOp(g, "Transpose")); EXPECT_TRUE(hasOp(g, "MatMul")); } TEST(ONNXtoGraph, Gemm) { auto g = ONNXtoGraph("tests/data/gemm.onnx"); - EXPECT_TRUE(hasOp(g, "Gemm")); } TEST(ONNXtoGraph, Pipeline) { auto g = ONNXtoGraph("tests/data/test_0.onnx"); - EXPECT_TRUE(hasOp(g, "Transpose")); EXPECT_TRUE(hasOp(g, "MatMul")); EXPECT_TRUE(hasOp(g, "Relu")); @@ -251,7 +260,6 @@ TEST(ONNXtoGraph, Pipeline) { EXPECT_TRUE(hasOp(g, "Add")); EXPECT_TRUE(hasOp(g, "Mul")); EXPECT_TRUE(hasOp(g, "Gemm")); - EXPECT_FALSE(g.Inputs().empty()); EXPECT_FALSE(g.Outputs().empty()); } @@ -263,17 +271,18 @@ TEST(ONNXtoGraph, WeightsAndBias) { bool foundBias = false; for (const auto &e : g.Entities()) { - if (e.kind != EntityKind::Constant) + if (e.kind != EntityKind::Tensor) + continue; + const auto &t = std::get(e.entity); + if (!t.is_initializer) continue; - const auto &c = std::get(e.entity); - - if (c.data.dims == std::vector{3,4}) { + if (t.shape == std::vector{3, 4}) { foundWeights = true; - EXPECT_EQ(c.data.raw_data.size(), 3 * 4 * sizeof(float)); + EXPECT_TRUE(t.data.has_value()); + EXPECT_EQ(t.data->raw_data.size(), 3 * 4 * sizeof(float)); } - - if (c.data.dims == std::vector{4}) { + if (t.shape == std::vector{4}) { foundBias = true; } } @@ 
-284,7 +293,6 @@ TEST(ONNXtoGraph, WeightsAndBias) { TEST(ONNXtoGraph, InputsOutputsAreSet) { auto g = ONNXtoGraph("tests/data/add_mul.onnx"); - EXPECT_GT(g.Inputs().size(), 0u); EXPECT_GT(g.Outputs().size(), 0u); diff --git a/tests/unit/onnx2mlir.cpp b/tests/unit/onnx2mlir.cpp index bb22d29..bb845b7 100644 --- a/tests/unit/onnx2mlir.cpp +++ b/tests/unit/onnx2mlir.cpp @@ -1,7 +1,7 @@ -#include "TensorCompiler/Frontend/ONNXModel.hpp" -#include "TensorCompiler/Frontend/ONNXDumper.hpp" -#include "TensorCompiler/Converter/MLIRBuilder.hpp" +#include "TensorCompiler/Conversion/MLIRBuilder.hpp" #include "TensorCompiler/Dialect/NNDialect.hpp" +#include "TensorCompiler/Frontend/GraphBuilder.hpp" +#include "TensorCompiler/Frontend/ONNXModel.hpp" #include #include @@ -13,25 +13,26 @@ #include #include +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/Dialect/Arith/IR/Arith.h" #include "llvm/Support/Casting.h" #include -static mlir::ModuleOp -ONNXtoMLIR(mlir::MLIRContext &ctx, std::string_view fileName) { - ctx.loadDialect< - mlir::arith::ArithDialect, - mlir::func::FuncDialect, - mlir::nn::NNDialect>(); +static mlir::ModuleOp ONNXtoMLIR(mlir::MLIRContext &ctx, + std::string_view fileName) { + ctx.loadDialect(); tc::frontend::ONNXModel model{fileName}; - tc::converter::onnx_to_high_mlir::MLIRBuilder builder(ctx); - model.Parse(builder); + tc::frontend::GraphBuilder graphBuilder; + model.Parse(graphBuilder); + const auto &graph = graphBuilder.GetGraph(); + tc::conversion::HighLevelMLIRBuilder builder(ctx); + graph.Accept(builder); return builder.GetModule(); } @@ -58,7 +59,7 @@ TEST(ONNXtoMLIR, AddMul) { auto mdl = ONNXtoMLIR(ctx, "tests/data/add_mul.onnx"); EXPECT_TRUE(mlir::succeeded(mdl.verify())); - EXPECT_TRUE(hasFunc(mdl, "AddMulGraph")); + EXPECT_TRUE(hasFunc(mdl, "main")); EXPECT_TRUE(hasOp(mdl, "nn.add")); 
EXPECT_TRUE(hasOp(mdl, "nn.mul")); } @@ -68,7 +69,7 @@ TEST(ONNXtoMLIR, ConvRelu) { auto mdl = ONNXtoMLIR(ctx, "tests/data/conv_relu.onnx"); EXPECT_TRUE(mlir::succeeded(mdl.verify())); - EXPECT_TRUE(hasFunc(mdl, "ConvReluGraph")); + EXPECT_TRUE(hasFunc(mdl, "main")); EXPECT_TRUE(hasOp(mdl, "nn.conv")); EXPECT_TRUE(hasOp(mdl, "nn.relu")); } @@ -78,7 +79,7 @@ TEST(ONNXtoMLIR, MatMulRelu) { auto mdl = ONNXtoMLIR(ctx, "tests/data/matmul_relu.onnx"); EXPECT_TRUE(mlir::succeeded(mdl.verify())); - EXPECT_TRUE(hasFunc(mdl, "MatMulReluGraph")); + EXPECT_TRUE(hasFunc(mdl, "main")); EXPECT_TRUE(hasOp(mdl, "nn.matmul")); EXPECT_TRUE(hasOp(mdl, "nn.relu")); } @@ -88,7 +89,7 @@ TEST(ONNXtoMLIR, SingleRelu) { auto mdl = ONNXtoMLIR(ctx, "tests/data/relu.onnx"); EXPECT_TRUE(mlir::succeeded(mdl.verify())); - EXPECT_TRUE(hasFunc(mdl, "SingleRelu")); + EXPECT_TRUE(hasFunc(mdl, "main")); EXPECT_TRUE(hasOp(mdl, "nn.relu")); } @@ -97,7 +98,7 @@ TEST(ONNXtoMLIR, TransposeMatMul) { auto mdl = ONNXtoMLIR(ctx, "tests/data/transpose_matmul.onnx"); EXPECT_TRUE(mlir::succeeded(mdl.verify())); - EXPECT_TRUE(hasFunc(mdl, "TransposeMatMulGraph")); + EXPECT_TRUE(hasFunc(mdl, "main")); EXPECT_TRUE(hasOp(mdl, "nn.transpose")); EXPECT_TRUE(hasOp(mdl, "nn.matmul")); } @@ -107,7 +108,7 @@ TEST(ONNXtoMLIR, Gemm) { auto mdl = ONNXtoMLIR(ctx, "tests/data/gemm.onnx"); EXPECT_TRUE(mlir::succeeded(mdl.verify())); - EXPECT_TRUE(hasFunc(mdl, "GemmGraph")); + EXPECT_TRUE(hasFunc(mdl, "main")); EXPECT_TRUE(hasOp(mdl, "nn.gemm")); } @@ -116,7 +117,7 @@ TEST(ONNXtoMLIR, Pipeline) { auto mdl = ONNXtoMLIR(ctx, "tests/data/test_0.onnx"); EXPECT_TRUE(mlir::succeeded(mdl.verify())); - EXPECT_TRUE(hasFunc(mdl, "PipelineGraph")); + EXPECT_TRUE(hasFunc(mdl, "main")); EXPECT_TRUE(hasOp(mdl, "nn.transpose")); EXPECT_TRUE(hasOp(mdl, "nn.matmul")); EXPECT_TRUE(hasOp(mdl, "nn.relu")); @@ -135,18 +136,16 @@ TEST(ONNXtoMLIR, WeightsAndBias) { bool foundBias = false; mdl.walk([&](mlir::arith::ConstantOp cst) { - auto type = - 
llvm::cast(cst.getType()); - auto dense = - llvm::dyn_cast(cst.getValue()); - - ASSERT_TRUE(dense); - - if (type.getShape() == llvm::ArrayRef{3,4}) { + auto type = llvm::cast(cst.getType()); + auto dense = llvm::dyn_cast(cst.getValue()); + if (!dense) + return; + + if (type.getShape() == llvm::ArrayRef{3, 4}) { std::vector values(dense.getValues().begin(), dense.getValues().end()); for (size_t i = 0; i < 12; ++i) { - EXPECT_EQ(values[i], i + 1); + EXPECT_EQ(values[i], static_cast(i + 1)); } foundWeights = true; }