@@ -8,7 +8,10 @@ namespace conversion {
 namespace converters {
 namespace impl {
 namespace {
-auto reduced_registrations = RegisterNodeConversionPatterns()
+
+
+
+auto reduce_registrations = RegisterNodeConversionPatterns()
     .pattern({
         "aten::mean(Tensor self, *, ScalarType? dtype=None) -> (Tensor)",
         [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
@@ -36,7 +39,7 @@ auto reduced_registrations = RegisterNodeConversionPatterns()
             LOG_DEBUG("Dim to reduce:" << util::toDims(dims)); // Some abuse of toDim but just for debug info
 
             uint32_t axis_mask = 0;
-            for (int d = 0; d < dims.size(); d++) {
+            for (size_t d = 0; d < dims.size(); d++) {
                 axis_mask |= 1 << dims[d];
             }
             LOG_DEBUG("Axis Mask" << std::bitset<32>(axis_mask));
@@ -52,6 +55,131 @@ auto reduced_registrations = RegisterNodeConversionPatterns()
             mean_layer->setName(util::node_info(n).c_str());
             auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], mean_layer->getOutput(0));
 
+            LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
+            return true;
+        }
+    }).pattern({
62+ " aten::sum(Tensor self, *, ScalarType? dtype=None) -> Tensor" ,
63+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
64+ auto in_tensor = args[0 ].ITensor ();
65+ auto in_dims = util::toVec (in_tensor->getDimensions ());
66+ LOG_WARNING (" Sum Converter disregards dtype" );
67+
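+            // Full reduction: set bits 0..rank-1 of the TensorRT axis mask so every dimension is summed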
+            uint32_t axis_mask = (uint32_t)(((uint64_t)1 << in_dims.size()) - 1);
+
+            auto sum_layer = ctx->net->addReduce(*in_tensor, nvinfer1::ReduceOperation::kSUM, axis_mask, false);
+
+            TRTORCH_CHECK(sum_layer, "Unable to create sum layer from node: " << *n);
+
+            sum_layer->setName(util::node_info(n).c_str());
+            auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], sum_layer->getOutput(0));
+
+            LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
+            return true;
+        }
+    }).pattern({
81+ " aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor" ,
82+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
83+ auto in_tensor = args[0 ].ITensor ();
84+ auto dims = args[1 ].unwrapToIntList ();
85+ LOG_DEBUG (" Dim to reduce:" << util::toDims (dims)); // Some abuse of toDim but just for debug info
86+
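+            // Build the axis mask by setting one bit per requested dimension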
+            uint32_t axis_mask = 0;
+            for (size_t d = 0; d < dims.size(); d++) {
+                axis_mask |= 1 << dims[d];
+            }
+            LOG_DEBUG("Axis Mask" << std::bitset<32>(axis_mask));
+
+            auto keepdim = args[2].unwrapToBool();
+            LOG_DEBUG("Keep dims :" << keepdim);
+
+            LOG_WARNING("Sum converter disregards dtype");
+            auto sum_layer = ctx->net->addReduce(*in_tensor, nvinfer1::ReduceOperation::kSUM, axis_mask, keepdim);
+
+            TRTORCH_CHECK(sum_layer, "Unable to create sum layer from node: " << *n);
+
+            sum_layer->setName(util::node_info(n).c_str());
+            auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], sum_layer->getOutput(0));
+
+            LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
+            return true;
+        }
+    }).pattern({
108+ " aten::prod(Tensor self, *, ScalarType? dtype=None) -> Tensor" ,
109+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
110+ auto in_tensor = args[0 ].ITensor ();
111+ auto in_dims = util::toVec (in_tensor->getDimensions ());
112+ LOG_WARNING (" Prod Converter disregards dtype" );
113+
114+ uint32_t axis_mask = (uint32_t )(((uint64_t )1 << in_dims.size ()) - 1 );
115+
116+ auto prod_layer = ctx->net ->addReduce (*in_tensor, nvinfer1::ReduceOperation::kPROD , axis_mask, false );
117+
+            TRTORCH_CHECK(prod_layer, "Unable to create prod layer from node: " << *n);
+
+            prod_layer->setName(util::node_info(n).c_str());
+            auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], prod_layer->getOutput(0));
+
+            LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
+            return true;
+        }
+    }).pattern({
127+ " aten::prod.dim_int(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor" ,
128+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
129+ auto in_tensor = args[0 ].ITensor ();
130+ auto dim = args[1 ].unwrapToInt ();
131+ LOG_DEBUG (" Dim to reduce:" << dim); // Some abuse of toDim but just for debug info
132+
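+            // Single-dimension reduction: only the bit for the requested dim is set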
+            uint32_t axis_mask = 1 << dim;
+            LOG_DEBUG("Axis Mask" << std::bitset<32>(axis_mask));
+
+            auto keepdim = args[2].unwrapToBool();
+            LOG_DEBUG("Keep dims :" << keepdim);
+
+            LOG_WARNING("Prod converter disregards dtype");
+            auto prod_layer = ctx->net->addReduce(*in_tensor, nvinfer1::ReduceOperation::kPROD, axis_mask, keepdim);
+
+            TRTORCH_CHECK(prod_layer, "Unable to create prod layer from node: " << *n);
+
+            prod_layer->setName(util::node_info(n).c_str());
+            auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], prod_layer->getOutput(0));
+
+            LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
+            return true;
+        }
+    }).pattern({
151+ " aten::max(Tensor self) -> Tensor" ,
152+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
153+ auto in_tensor = args[0 ].ITensor ();
154+ auto in_dims = util::toVec (in_tensor->getDimensions ());
155+
156+ uint32_t axis_mask = (uint32_t )(((uint64_t )1 << in_dims.size ()) - 1 );
157+
158+ auto max_layer = ctx->net ->addReduce (*in_tensor, nvinfer1::ReduceOperation::kMAX , axis_mask, false );
159+
160+ TRTORCH_CHECK (max_layer, " Unable to create max layer from node: " << *n);
161+
162+ max_layer->setName (util::node_info (n).c_str ());
163+ auto out_tensor = ctx->AssociateValueAndTensor (n->outputs ()[0 ], max_layer->getOutput (0 ));
164+
165+ LOG_DEBUG (" Output shape: " << out_tensor->getDimensions ());
166+ return true ;
167+ }
168+ }).pattern({
169+ " aten::min(Tensor self) -> Tensor" ,
170+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
171+ auto in_tensor = args[0 ].ITensor ();
172+ auto in_dims = util::toVec (in_tensor->getDimensions ());
173+
174+ uint32_t axis_mask = (uint32_t )(((uint64_t )1 << in_dims.size ()) - 1 );
175+
176+ auto min_layer = ctx->net ->addReduce (*in_tensor, nvinfer1::ReduceOperation::kMIN , axis_mask, false );
177+
178+ TRTORCH_CHECK (min_layer, " Unable to create min layer from node: " << *n);
179+
180+ min_layer->setName (util::node_info (n).c_str ());
181+ auto out_tensor = ctx->AssociateValueAndTensor (n->outputs ()[0 ], min_layer->getOutput (0 ));
182+
55183 LOG_DEBUG (" Output shape: " << out_tensor->getDimensions ());
56184 return true ;
57185 }
@@ -62,63 +190,3 @@ auto reduced_registrations = RegisterNodeConversionPatterns()
 } // namespace conversion
 } // namespace core
 } // namespace trtorch
-
-// #include "core/util/prelude.h"
-// #include "core/conversion/converters/converters.h"
-
-// namespace trtorch {
-// namespace core {
-// namespace conversion {
-// namespace converters {
-// namespace impl {
-// namespace {
-
-// #define convert(unary, trt_type)                                            \
-//   auto unary##_registrations TRTORCH_UNUSED =                               \
-//       RegisterNodeConversionPatterns().pattern(                             \
-//           {"aten::" #unary "(Tensor self) -> Tensor",                       \
-//            [](ConversionCtx *ctx, const torch::jit::Node *n,                \
-//               args &args) -> bool {                                         \
-//              auto in = args[0].ITensor();                                   \
-//              auto unary =                                                   \
-//                  ctx->net->addUnary(*in, nvinfer1::UnaryOperation::trt_type); \
-//                                                                             \
-//              TRTORCH_CHECK(                                                 \
-//                  unary,                                                     \
-//                  "Unable to create " #unary " layer from node: " << *n);    \
-//                                                                             \
-//              unary->setName(util::node_info(n).c_str());                    \
-//              auto out_tensor = ctx->AssociateValueAndTensor(                \
-//                  n->outputs()[0],                                           \
-//                  unary->getOutput(0));                                      \
-//              LOG_DEBUG(                                                     \
-//                  "Output tensor shape: " << out_tensor->getDimensions());   \
-//                                                                             \
-//              return true;                                                   \
-//            }});
-
-// convert(cos, kCOS);
-// convert(acos, kACOS);
-// convert(cosh, kCOSH);
-// convert(sin, kSIN);
-// convert(asin, kASIN);
-// convert(sinh, kSINH);
-// convert(tan, kTAN);
-// convert(atan, kATAN);
-// convert(abs, kABS);
-// convert(floor, kFLOOR);
-// convert(reciprocal, kRECIP);
-// convert(log, kLOG);
-// convert(ceil, kCEIL);
-// convert(sqrt, kSQRT);
-// convert(exp, kEXP);
-// convert(neg, kNEG);
-
-// #undef convert
-
-// } // namespace
-// } // namespace impl
-// } // namespace converters
-// } // namespace conversion
-// } // namespace core
-// } // namespace trtorch