forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
onnx_exporter.h
138 lines (109 loc) · 4.57 KB
/
onnx_exporter.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
#pragma once
#include "caffe2/core/common.h"
#include "caffe2/core/tensor.h"
#include "caffe2/onnx/helper.h"
#include "caffe2/proto/caffe2_pb.h"
#include "onnx/onnx_pb.h"

#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
namespace caffe2 {
namespace onnx {
// Bring the ONNX protobuf message types into caffe2::onnx so the declarations
// in this header can name them unqualified.
//
// These were previously wrapped in an unnamed namespace. An unnamed namespace
// in a header creates a separate namespace in every translation unit that
// includes it; declaring the names directly in caffe2::onnx makes them visible
// to exactly the same lookups without that per-TU indirection.
//
// NOTE: inside caffe2::onnx an unqualified `TensorProto` therefore names the
// ONNX type, not caffe2::TensorProto — spell out `caffe2::TensorProto` when the
// Caffe2 type is meant.
using ::ONNX_NAMESPACE::AttributeProto;
using ::ONNX_NAMESPACE::GraphProto;
using ::ONNX_NAMESPACE::ModelProto;
using ::ONNX_NAMESPACE::NodeProto;
using ::ONNX_NAMESPACE::TensorProto;
// Result of converting one Caffe2 operator: the ONNX nodes implementing it,
// paired with the ONNX tensors the conversion produced alongside those nodes
// (presumably initializers the nodes consume — confirm against the converter
// implementations).
using ConvertedResult =
    std::pair<std::vector<NodeProto>, std::vector<TensorProto>>;

// Rewrites the Caffe2 init and predict nets into SSA form. The external output
// names of the predict net are preserved.
//
// @param init_net  the init net; passed by non-const pointer and presumably
//                  rewritten in place.
// @param pred_net  the predict net; passed by non-const pointer and presumably
//                  rewritten in place.
// @return a string-to-string blob-name mapping produced by the rewrite.
//         NOTE(review): the direction/meaning of this mapping is not visible
//         from this header — confirm against the implementation.
CAFFE2_API std::unordered_map<std::string, std::string> SsaRewrite(
    caffe2::NetDef* init_net,
    caffe2::NetDef* pred_net);

// Maps a Caffe2 tensor element type to the corresponding ONNX element type.
// Behavior for Caffe2 types with no ONNX counterpart is defined by the
// implementation and is not visible here.
//
// Declared with CAFFE2_API for consistency with SsaRewrite and OnnxExporter in
// this header (CAFFE2_API is presumably the library's symbol-export macro), so
// that the function is usable from outside the Caffe2 shared library.
CAFFE2_API ::ONNX_NAMESPACE::TensorProto::DataType Caffe2TypeToOnnxType(
    caffe2::TensorProto::DataType t);
class CAFFE2_API OnnxExporter {
using SpecialOpConverter = ConvertedResult (OnnxExporter::*)(
const caffe2::OperatorDef&,
const std::unordered_map<std::string, caffe2::TensorShape>&);
public:
OnnxExporter(DummyName* dummy = nullptr) {
if (dummy) {
dummy_ = std::shared_ptr<DummyName>(dummy, [](DummyName*) {});
} else {
dummy_ = std::make_shared<DummyName>();
}
}
ConvertedResult Caffe2OpToOnnxNodes(
const caffe2::OperatorDef& def,
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
void InitOpToTensorProto(const caffe2::OperatorDef& def, TensorProto* tensor);
private:
ConvertedResult CommonCaffe2OpToOnnxNodes(const caffe2::OperatorDef& def);
ConvertedResult CreateArgMaxMinOpNodes(
const caffe2::OperatorDef& def,
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
ConvertedResult CreateBinaryElementwiseOpNodes(
const caffe2::OperatorDef& def,
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
ConvertedResult CreateCastNodes(
const caffe2::OperatorDef& def,
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
ConvertedResult CreateElementwiseLinearNodes(
const caffe2::OperatorDef& def,
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
ConvertedResult CreateConvPoolNodes(
const caffe2::OperatorDef& def,
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
ConvertedResult CreateGemmNodes(
const caffe2::OperatorDef& def,
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
ConvertedResult CreateReshapeNodes(
const caffe2::OperatorDef& def,
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
ConvertedResult CreateSliceNodes(
const caffe2::OperatorDef& def,
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
ConvertedResult CreateChannelShuffleNodes(
const caffe2::OperatorDef& def,
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
ConvertedResult CreateReduceMeanNodes(
const caffe2::OperatorDef& def,
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
ConvertedResult CreateConcatNodes(
const caffe2::OperatorDef& def,
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
ConvertedResult CreateMergeDimNodes(
const caffe2::OperatorDef& def,
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
ConvertedResult CreateLrnNodes(
const caffe2::OperatorDef& def,
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
ConvertedResult CreateUpsampleNodes(
const caffe2::OperatorDef& def,
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
// \brief Check black listed arguments where we won't pass down when
// converting to ONNX node
bool IsBlackListed(const caffe2::Argument& arg);
// \brief Convert Caffe2 argument to Onnx attribute
void CopyCaffe2ArgToOnnxAttr(
AttributeProto* attr,
const std::string& op_type,
const caffe2::Argument& arg);
// LUT getters
const std::unordered_map<std::string, std::string>& get_renamed_operators()
const;
const std::unordered_map<std::string, std::string>& get_renamed_attrs() const;
const std::
unordered_map<std::string, std::unordered_map<std::string, std::string>>&
get_per_op_renamed_attrs() const;
const std::unordered_map<std::string, OnnxExporter::SpecialOpConverter>&
get_special_operators() const;
// Dummy name generator
std::shared_ptr<DummyName> dummy_;
};
} // namespace onnx
} // namespace caffe2