1+ #include < bitset>
12#include " core/util/prelude.h"
23#include " core/conversion/converters/converters.h"
34
@@ -22,25 +23,36 @@ auto reduced_registrations = RegisterNodeConversionPatterns()
2223 TRTORCH_CHECK (mean_layer, " Unable to create mean layer from node: " << *n);
2324
2425 mean_layer->setName (util::node_info (n).c_str ());
25- ctx->AssociateValueAndTensor (n->outputs ()[0 ], mean_layer->getOutput (0 ));
26+ auto out_tensor = ctx->AssociateValueAndTensor (n->outputs ()[0 ], mean_layer->getOutput (0 ));
27+
28+ LOG_DEBUG (" Output shape: " << out_tensor->getDimensions ());
2629 return true ;
2730 }
2831 }).pattern({
29- " aten::mean.dim(Tensor self, int[1 ] dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)" ,
32+ " aten::mean.dim(Tensor self, int[] dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)" ,
3033 [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
3134 auto in_tensor = args[0 ].ITensor ();
32- auto dim = args[1 ].unwrapToIntList ();
33- auto keepdim = args[ 2 ]. unwrapToBool ();
35+ auto dims = args[1 ].unwrapToIntList ();
36+ LOG_DEBUG ( " Dim to reduce: " << util::toDims (dims)); // Some abuse of toDim but just for debug info
3437
35- uint32_t axis_mask = 1 << dim[0 ];
38+ uint32_t axis_mask = 0 ;
39+ for (int d = 0 ; d < dims.size (); d++) {
40+ axis_mask |= 1 << dims[d];
41+ }
42+ LOG_DEBUG (" Axis Mask" << std::bitset<32 >(axis_mask));
43+
44+ auto keepdim = args[2 ].unwrapToBool ();
45+ LOG_DEBUG (" Keep dims :" << keepdim);
3646
3747 LOG_WARNING (" Mean converter disregards dtype" );
3848 auto mean_layer = ctx->net ->addReduce (*in_tensor, nvinfer1::ReduceOperation::kAVG , axis_mask, keepdim);
3949
4050 TRTORCH_CHECK (mean_layer, " Unable to create mean layer from node: " << *n);
4151
4252 mean_layer->setName (util::node_info (n).c_str ());
43- ctx->AssociateValueAndTensor (n->outputs ()[0 ], mean_layer->getOutput (0 ));
53+ auto out_tensor = ctx->AssociateValueAndTensor (n->outputs ()[0 ], mean_layer->getOutput (0 ));
54+
55+ LOG_DEBUG (" Output shape: " << out_tensor->getDimensions ());
4456 return true ;
4557 }
4658 });
0 commit comments