Skip to content

Commit

Permalink
Relaxing convolution infer checks.
Browse files Browse the repository at this point in the history
- Weight dtype can be different than idtype. So, using the weight tensor to set
the dtype of weight.
- For conv2d NCHWc operator, the weight can be of any dimension. For int8
computation on Intel, it can be 7D. Relaxing the weight type checking.
  • Loading branch information
anijain2305 committed Jul 8, 2019
1 parent 2a7aebe commit 03905c0
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions src/relay/op/nn/convolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,12 @@ bool Conv2DRel(const Array<Type>& types,
channels = param->channels;
dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
DataType weight_dtype = data->dtype;
if (weight != nullptr) {
weight_dtype = weight->dtype;
}
// assign result to reporter
reporter->Assign(types[1], TensorTypeNode::make(wshape, data->dtype));
reporter->Assign(types[1], TensorTypeNode::make(wshape, weight_dtype));
} else {
// use weight to infer the conv shape.
if (weight == nullptr) return false;
Expand Down Expand Up @@ -701,7 +705,7 @@ RELAY_REGISTER_OP("nn.contrib_conv2d_NCHWc")
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("weight", "Tensor", "The weight tensor.")
.set_support_level(10)
.add_type_rel("Conv2D", Conv2DRel)
.add_type_rel("Conv2DNCHWc", Conv2DWinogradRel<Conv2DAttrs>)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
Conv2DInferCorrectLayout<Conv2DAttrs>);

Expand Down

0 comments on commit 03905c0

Please sign in to comment.