
Why is batch_norm used inside AdaptiveInstanceNorm2d? #31

Open
leitro opened this issue Oct 8, 2019 · 2 comments

Comments


leitro commented Oct 8, 2019

Hi! I have a question about the code in blocks.py (L188-L192), shown below:

class AdaptiveInstanceNorm2d(nn.Module):
        ...
        x_reshaped = x.contiguous().view(1, b * c, *x.size()[2:])
        out = F.batch_norm(
            x_reshaped, running_mean, running_var, self.weight, self.bias,
            True, self.momentum, self.eps)
        return out.view(b, c, *x.size()[2:])

This is the definition of adaptive instance normalization. It looks like you reshape a batch of images into a "bigger" single-sample batch, apply "batch normalization" on it, and finally reshape it back to (batch, channel, height, width). But whether or not it is reshaped into a single batch, it seems to me that the features of each channel are normalized across the whole batch, so I am wondering how this could be an instance normalization.

I believe the code is perfectly correct, but could you please explain the trick used here? Thanks in advance!
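
For reference, here is a minimal self-contained sketch of the pattern I am asking about (my own toy example, not code from blocks.py; the tensor sizes and the fresh zero/one running stats are assumptions), which compares the reshape + F.batch_norm output against F.instance_norm:

import torch
import torch.nn.functional as F

b, c, h, w = 4, 8, 16, 16
x = torch.randn(b, c, h, w)

# Reference: plain instance norm, i.e. each (sample, channel) map is
# normalized with its own mean/std over H x W.
ref = F.instance_norm(x)

# The pattern from blocks.py: fold the batch into the channel axis, so
# batch_norm's per-channel statistics become per-(sample, channel) statistics.
x_reshaped = x.contiguous().view(1, b * c, h, w)
running_mean = torch.zeros(b * c)  # fresh stats; training=True recomputes them anyway
running_var = torch.ones(b * c)
out = F.batch_norm(x_reshaped, running_mean, running_var,
                   weight=None, bias=None,
                   training=True, momentum=0.1, eps=1e-5)
out = out.view(b, c, h, w)

print(torch.allclose(ref, out, atol=1e-5))  # True: both normalize per instance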


iperov commented Oct 10, 2019

This is why I don't like PyTorch. :D


pomelyu commented Dec 20, 2019

There is a slight difference between instance norm and adaptive instance norm.

In instance norm, the data is normalized over each whole image separately for each channel, so the shapes of weight and bias are both (num_channels,). In adaptive instance norm, however, the shapes of weight and bias should be (batch_size * num_channels,), since each sample receives a different modulation from its corresponding latent.

That is why the code reshapes x to (1, batch_size * num_channels, H, W) and then uses F.batch_norm, rather than F.instance_norm, to apply the modulation to each sample and each channel.
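
To make that concrete, here is a small sketch (my own illustration, not code from blocks.py; the random weight/bias stand in for the values an MLP would predict from the latent) showing per-sample, per-channel parameters of shape (batch_size * num_channels,) applied through the same reshape + F.batch_norm trick, together with the equivalent hand-written computation. F.instance_norm only accepts weight and bias of shape (num_channels,), which is why it cannot be used directly here:

import torch
import torch.nn.functional as F

b, c, h, w = 4, 8, 16, 16
x = torch.randn(b, c, h, w)

# In AdaIN these would be predicted from the style/latent code by an MLP;
# random values here, only the (b * c,) shape matters for the illustration.
weight = torch.randn(b * c)  # per-(sample, channel) scale
bias = torch.randn(b * c)    # per-(sample, channel) shift

# Reshape so that every (sample, channel) pair becomes its own "channel",
# then let batch_norm apply the per-pair affine parameters.
x_reshaped = x.contiguous().view(1, b * c, h, w)
out = F.batch_norm(x_reshaped, torch.zeros(b * c), torch.ones(b * c),
                   weight, bias, True, 0.1, 1e-5).view(b, c, h, w)

# The same thing written out by hand: instance-normalize, then modulate
# each sample and channel with its own scale/shift.
mean = x.mean(dim=(2, 3), keepdim=True)
std = torch.sqrt(x.var(dim=(2, 3), unbiased=False, keepdim=True) + 1e-5)
ref = (x - mean) / std * weight.view(b, c, 1, 1) + bias.view(b, c, 1, 1)

print(torch.allclose(out, ref, atol=1e-4))  # True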
