diff --git a/official/quantization/models/resnet.py b/official/quantization/models/resnet.py index 3a482abf..73003226 100644 --- a/official/quantization/models/resnet.py +++ b/official/quantization/models/resnet.py @@ -79,7 +79,7 @@ def __init__( if in_channels == channels and stride == 1 else M.ConvBn2d(in_channels, channels, 1, stride, bias=False) ) - self.add = M.Elemwise("ADD") + self.add = M.Elemwise("FUSE_ADD_RELU") def forward(self, x): identity = x @@ -87,7 +87,6 @@ def forward(self, x): x = self.conv_bn2(x) identity = self.downsample(identity) x = self.add(x, identity) - x = F.relu(x) return x @@ -125,7 +124,7 @@ def __init__( if in_channels == channels * self.expansion and stride == 1 else M.ConvBn2d(in_channels, channels * self.expansion, 1, stride, bias=False) ) - self.add = M.Elemwise("ADD") + self.add = M.Elemwise("FUSE_ADD_RELU") def forward(self, x): identity = x @@ -134,7 +133,6 @@ def forward(self, x): x = self.conv_bn3(x) identity = self.downsample(identity) x = self.add(x, identity) - x = F.relu(x) return x