Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions include/TensorCompiler/Conversion/MLIRBuilder.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
#pragma once
#include "TensorCompiler/Dialect/NNDialect.hpp"
#include "TensorCompiler/Graph/Graph.hpp"
#include "TensorCompiler/Graph/GraphVisitor.hpp"

#include <mlir/IR/Builders.h>
#include <mlir/IR/BuiltinOps.h>

#include <unordered_map>
#include <unordered_set>

namespace tc::conversion {

class HighLevelMLIRBuilder : public tc::graph::GraphVisitor {
public:
explicit HighLevelMLIRBuilder(mlir::MLIRContext &ctx);

const mlir::ModuleOp &GetModule() const { return module_; }

mlir::ModuleOp Build(const tc::graph::Graph &graph) {
graph.Accept(*this);
return module_;
}

private:
void Visit(const tc::graph::Graph &graph) override;
void Visit(const tc::graph::TensorEntity &tensor) override;
void Visit(const tc::graph::OperationEntity &op) override;
void Visit(const tc::graph::ConstantEntity &constant) override;
void Finalize(const tc::graph::Graph &graph) override;

mlir::Value GetValue(tc::graph::EntityId id) const;
mlir::Value GetValue(const std::string &name) const;

void BuildRelu(const tc::graph::OperationEntity &op, mlir::Location loc);
void BuildAdd(const tc::graph::OperationEntity &op, mlir::Location loc);
void BuildMul(const tc::graph::OperationEntity &op, mlir::Location loc);
void BuildMatMul(const tc::graph::OperationEntity &op, mlir::Location loc);
void BuildGemm(const tc::graph::OperationEntity &op, mlir::Location loc);
void BuildConv(const tc::graph::OperationEntity &op, mlir::Location loc);
void BuildTranspose(const tc::graph::OperationEntity &op, mlir::Location loc);

mlir::Type ConvertElemType(int32_t onnx_dtype);
mlir::RankedTensorType ConvertTensorType(int32_t dtype,
llvm::ArrayRef<int64_t> shape);
mlir::DenseElementsAttr ConvertTensorData(const tc::graph::TensorData &data);

private:
mlir::MLIRContext &ctx_;
mlir::OpBuilder builder_;
mlir::ModuleOp module_;

std::unordered_map<tc::graph::EntityId, mlir::Value> valueMap_;
std::unordered_map<std::string, tc::graph::EntityId> tensorNameToId_;
std::unordered_set<tc::graph::EntityId> initializerIds_;
};
} // namespace tc::conversion
53 changes: 0 additions & 53 deletions include/TensorCompiler/Converter/MLIRBuilder.hpp

This file was deleted.

2 changes: 1 addition & 1 deletion include/TensorCompiler/Dialect/NNDialect.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#pragma once

#include <mlir/Interfaces/SideEffectInterfaces.h>
#include <mlir/Interfaces/InferTypeOpInterface.h>
#include <mlir/Interfaces/SideEffectInterfaces.h>

#include "NNDialect.h.inc"

Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
#pragma once
#include "TensorCompiler/Frontend/ONNXVisitor.hpp"
#include "TensorCompiler/Graph/IR.hpp"
#include "TensorCompiler/Graph/Graph.hpp"

#include <set>
#include <string>
#include <vector>

namespace tc::converter::onnx_to_graph {
namespace tc::frontend {
using tc::graph::AttrValue;
using tc::graph::EntityId;
using tc::graph::Graph;
Expand All @@ -15,12 +15,14 @@ using tc::graph::TensorData;
class GraphBuilder final : public tc::frontend::ONNXVisitor {
Graph graph_;
std::set<std::string> initializerNames_;
std::set<std::string> inputNames_;
std::set<std::string> outputNames_;

TensorData ParseTensor(const onnx::TensorProto &tensor);
AttrValue ParseAttribute(const onnx::AttributeProto &attr);

EntityId EnsureTensor(const std::string &name,
const std::vector<int64_t> &shape = {});
const std::vector<int64_t> &shape, int32_t dtype);

public:
const Graph &GetGraph() const { return graph_; }
Expand All @@ -34,4 +36,4 @@ class GraphBuilder final : public tc::frontend::ONNXVisitor {

void Finalize(const onnx::GraphProto &) override;
};
} // namespace tc::converter::onnx_to_graph
} // namespace tc::frontend
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <vector>

namespace tc::graph {
class GraphVisitor;
using IntList = std::vector<int64_t>;
using DoubleList = std::vector<double>;
using StringList = std::vector<std::string>;
Expand All @@ -15,6 +16,7 @@ struct TensorData {
std::string name;
std::vector<int64_t> dims;
std::vector<uint8_t> raw_data;
int32_t dtype;
};

using AttrScalar = std::variant<int64_t, double, std::string, bool>;
Expand All @@ -30,6 +32,7 @@ struct TensorEntity {
std::string name;
std::vector<int64_t> shape;
bool is_initializer = false;
int32_t dtype;
std::optional<TensorData> data;
};

Expand All @@ -51,20 +54,17 @@ struct Entity {
EntityKind kind;
std::variant<TensorEntity, OperationEntity, ConstantEntity> entity;

Entity(EntityKind k, OperationEntity op)
: kind(k), entity(std::move(op)) {}
Entity(EntityKind k, OperationEntity op) : kind(k), entity(std::move(op)) {}

Entity(EntityKind k, TensorEntity t)
: kind(k), entity(std::move(t)) {}
Entity(EntityKind k, TensorEntity t) : kind(k), entity(std::move(t)) {}

Entity(EntityKind k, ConstantEntity c)
: kind(k), entity(std::move(c)) {}
Entity(EntityKind k, ConstantEntity c) : kind(k), entity(std::move(c)) {}
};

class Graph {
public:
EntityId AddTensor(const std::string &name, std::vector<int64_t> shape = {},
bool is_init = false);
EntityId AddTensor(const std::string &name, std::vector<int64_t> shape,
int32_t dtype, bool is_init = false);

EntityId AddConstant(const TensorData &data);

Expand All @@ -83,6 +83,8 @@ class Graph {
void AddInput(EntityId id) { inputs_.push_back(id); }
void AddOutput(EntityId id) { outputs_.push_back(id); }

void Accept(GraphVisitor &visitor) const;

private:
std::vector<EntityId> inputs_;
std::vector<EntityId> outputs_;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
#pragma once
#include "TensorCompiler/Graph/IR.hpp"
#include "TensorCompiler/Graph/Graph.hpp"
#include <string>

namespace tc::graph {

std::string DumpGraph(const Graph &g);

std::string ToDot(const Graph &g);
bool SaveDot(const Graph &g, const std::string &file);

Expand Down
23 changes: 23 additions & 0 deletions include/TensorCompiler/Graph/GraphVisitor.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#pragma once

namespace tc::graph {
class Graph;
struct TensorEntity;
struct OperationEntity;
struct ConstantEntity;
} // namespace tc::graph

namespace tc::graph {

class GraphVisitor {
public:
virtual ~GraphVisitor() = default;

virtual void Visit(const tc::graph::Graph &graph) = 0;
virtual void Visit(const tc::graph::TensorEntity &tensor) = 0;
virtual void Visit(const tc::graph::OperationEntity &op) = 0;
virtual void Visit(const tc::graph::ConstantEntity &constant) = 0;
virtual void Finalize(const tc::graph::Graph &graph) = 0;
};

} // namespace tc::graph
4 changes: 2 additions & 2 deletions lib/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
add_subdirectory(Dialect)
add_subdirectory(Frontend)
add_subdirectory(Graph)
add_subdirectory(Converter)
add_subdirectory(Conversion)

add_library(tc-core INTERFACE)

Expand All @@ -10,7 +10,7 @@ target_link_libraries(tc-core
tc-dialect
tc-frontend
tc-graph
tc-converter
tc-conversion
MLIRParser
MLIRFuncDialect
MLIRArithDialect
Expand Down
11 changes: 11 additions & 0 deletions lib/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
add_library(tc-conversion MLIRBuilder.cpp)

target_link_libraries(tc-conversion
tc-graph
tc-dialect
onnx)

target_include_directories(tc-conversion PUBLIC
${PROJECT_SOURCE_DIR}/include
${CMAKE_CURRENT_BINARY_DIR}/../Dialect
)
Loading
Loading