About the implementation #7

Closed
boluoweifenda opened this issue Jun 14, 2019 · 9 comments
Comments

@boluoweifenda

Thanks for your great work! I have two questions regarding the implementation details:
(1) In the strided-convolution case, why is the BlurPool layer placed after the ReLU rather than right next to the convolution?
It would be much more flexible if the conv and BlurPool could be coupled.
I was considering the implementation in the pre-activation ResNet.
(2) This question might be silly, but why not use bilinear interpolation layers to downsample the feature maps? I haven't seen any work that does.

@richzhang commented Jun 14, 2019

Thanks for the questions.

(1) If we do Conv(s1)-BlurPool(s2)-Relu, that is equivalent to Blur(s1)-Conv(s2)-Relu, since Blur and Conv commute. This amounts to blurring the input before processing it, which destroys information unnecessarily, so performance goes down.

So doing Conv(s1)-Relu-BlurPool(s2) is the right ordering.

(2) Yes, using [1 2 1] is essentially equivalent to bilinear interpolation.
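
For concreteness, here is a minimal sketch of what such an anti-aliased downsampling layer can look like in the Conv(s1)-Relu-BlurPool(s2) ordering described above. This is an illustrative re-implementation, not the repository's actual class; the class name BlurPool2d, the reflect padding, and the channel sizes are assumptions made for the example.

import torch
import torch.nn as nn
import torch.nn.functional as F

class BlurPool2d(nn.Module):
    """Sketch of anti-aliased downsampling: a fixed [1 2 1] (triangle)
    filter applied depthwise, followed by stride-2 subsampling."""
    def __init__(self, channels, stride=2):
        super().__init__()
        self.stride = stride
        self.channels = channels
        k1d = torch.tensor([1., 2., 1.])
        k2d = k1d[:, None] * k1d[None, :]
        k2d = k2d / k2d.sum()  # [[1,2,1],[2,4,2],[1,2,1]] / 16
        self.register_buffer("kernel", k2d.expand(channels, 1, 3, 3).clone())

    def forward(self, x):
        x = F.pad(x, (1, 1, 1, 1), mode="reflect")
        return F.conv2d(x, self.kernel, stride=self.stride, groups=self.channels)

# Replacing a strided "Conv(s2)-Relu" with "Conv(s1)-Relu-BlurPool(s2)":
block = nn.Sequential(
    nn.Conv2d(64, 128, 3, stride=1, padding=1),  # stride moved out of the conv
    nn.ReLU(inplace=True),
    BlurPool2d(128, stride=2),                   # anti-aliased downsampling
)
print(block(torch.randn(1, 64, 56, 56)).shape)   # torch.Size([1, 128, 28, 28])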

@boluoweifenda (Author)

Thanks for your quick reply; I will think about these points more deeply.

@wandering007 commented Aug 8, 2019

If we do Conv(s1)-BlurPool(s2)-Relu, that is equivalent to Blur(s1)-Conv(s2)-Relu, since Blur and Conv are commutable.

I don't think these two are strictly equivalent. Here is a code snippet to verify it:

import torch
import torch.nn as nn

conv = nn.Conv2d(3, 3, 3, stride=1, padding=1)

def test_conv_blur():
    inputs = torch.randn(1, 3, 32, 32)
    blur1 = nn.AvgPool2d(3, stride=2, padding=1)  # blur + downsample (stride 2)
    blur2 = nn.AvgPool2d(3, stride=1, padding=1)  # blur only (stride 1)
    with torch.no_grad():
        # path 1: Conv(s1) -> BlurPool(s2)
        outputs1 = blur1(conv(inputs))
        # path 2: Blur(s1) -> Conv(s2), reusing the same weights with stride 2
        conv.stride = (2, 2)
        outputs2 = conv(blur2(inputs))
    assert torch.equal(outputs1, outputs2), "diff mean: {}".format(
        torch.mean(torch.abs(outputs1 - outputs2)).item())

if __name__ == "__main__":
    test_conv_blur()

The assertion fails.

Furthermore, the increased FLOPs might contribute to the increased accuracy.

@richzhang

Thanks for the code snippet. Your implementation differs from my statement in a critical way -- note the striding.

The readme has a new plot showing accuracy vs run-time.

@wandering007 commented Aug 8, 2019

Thanks for your quick reply and for pointing out my mistake. I updated the code snippet, but the assertion still fails. The difference does become much smaller, though.

@richzhang commented Aug 8, 2019

Maybe you can print the norm of the error vs the norm of the output signal. I suspect the discrepancy is due to numerical issues.

In any case, the equivalence should be provable. The avgpool is a convolution with a [1 1; 1 1] filter, and convolutions commute. In fact, you could fold the two operations together by convolving the 3x3 conv kernels with [1 1; 1 1], giving a single 4x4 conv layer.
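
To make the folding concrete, here is a small check (an illustrative sketch, not code from this thread; the seed, channel counts, and disabled bias are arbitrary assumptions). It merges a 2x2 box blur into a 3x3 conv's weights by convolving each kernel slice with [1 1; 1 1] / 4, then verifies that the resulting single 4x4 stride-2 conv matches blur-after-conv with no padding.

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(1, 3, 32, 32)

# 3x3 conv (no padding, no bias) followed by a 2x2 average blur with stride 2
conv = nn.Conv2d(3, 8, 3, stride=1, padding=0, bias=False)
blur = nn.AvgPool2d(2, stride=2, padding=0)

with torch.no_grad():
    out_separate = blur(conv(x))

    # Fold the blur into the conv: convolve every 3x3 kernel slice with the
    # 2x2 box filter [1 1; 1 1] / 4, which yields a single 4x4 kernel.
    box = torch.full((1, 1, 2, 2), 0.25)
    w = conv.weight.reshape(-1, 1, 3, 3)              # treat each slice separately
    merged = F.conv2d(w, box, padding=1).reshape(8, 3, 4, 4)

    out_merged = F.conv2d(x, merged, stride=2, padding=0)

print(torch.allclose(out_separate, out_merged, atol=1e-5))  # expected: True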

@wandering007

The average difference is about 0.01, which is too large to be explained by numerical issues.

@richzhang commented Aug 8, 2019

Maybe you can take some time to go through the proof that convolutions are associative and commutative: https://en.wikipedia.org/wiki/Convolution

import torch
import torch.nn as nn

inp = torch.randn(1, 3, 32, 32)
conv = nn.Conv2d(3, 3, 3, stride=1, padding=0)  # no padding
blur = nn.AvgPool2d(3, stride=1, padding=0)     # no padding

out1 = conv(blur(inp))
out2 = blur(conv(inp))

# mean squared difference vs. mean squared output magnitude
print((torch.mean((out1 - out2) ** 2).item(), torch.mean(out1 ** 2).item()))

Output:
(1.284619802805647e-15, 0.05549472197890282)

@wandering007 commented Aug 8, 2019

Thanks a lot. I just figured it out: padding matters. I had set padding=1, and the extra zeros around the image border caused the inequality.
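
For anyone following along, a quick way to see that the mismatch really is a border effect (a sketch assuming the same layer shapes as the snippets above; the 2-pixel crop is chosen because the composed 5x5 receptive field touches the zero padding only within 2 pixels of the border):

import torch
import torch.nn as nn

torch.manual_seed(0)
inp = torch.randn(1, 3, 32, 32)
conv = nn.Conv2d(3, 3, 3, stride=1, padding=1)   # zero padding
blur = nn.AvgPool2d(3, stride=1, padding=1)      # zero padding

with torch.no_grad():
    diff = (blur(conv(inp)) - conv(blur(inp))).abs()

# The zero padding only corrupts a 2-pixel ring around the border;
# the interior of the two outputs agrees up to float rounding.
print(diff[..., 2:-2, 2:-2].max().item())  # tiny (float rounding)
print(diff.max().item())                   # noticeably larger, from the border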
