can you add the inverse/backward for batch norm layer #2

Closed
sndnyang opened this issue Jan 25, 2019 · 3 comments

@sndnyang

Hi, thank you for your excellent work!
I have a question: you have implemented inverse layers for the Conv/Linear/Dropout/Pool layers, but the batch norm layer, which is also widely used in neural networks, seems to be missing.
Could you add an NN example with batch norm layers?

@ZiangYan
Owner

Hi, thanks for your interest in our work.

Actually, we use a simple strategy for batch norm layers: we just freeze all variables in the BN layers during both training and testing.

The ResNet-18 result reported in our paper is produced with this method.

The core code for this part should be something like:

for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        m.eval()  # use the (frozen) running statistics instead of batch statistics
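Once frozen this way, each BN layer reduces to a fixed per-channel affine transform. A minimal sanity check of that equivalence (a sketch; the layer width and input shape here are arbitrary):

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(8)
bn.eval()  # frozen: use running statistics, no updates

x = torch.randn(4, 8, 16, 16)
# In eval mode BN computes y = (x - running_mean) / sqrt(running_var + eps) * weight + bias,
# i.e. a fixed per-channel scale and shift:
scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
shift = bn.bias - bn.running_mean * scale
y_manual = x * scale.view(1, -1, 1, 1) + shift.view(1, -1, 1, 1)
assert torch.allclose(bn(x), y_manual, atol=1e-6)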

@sndnyang
Author

sndnyang commented Jan 25, 2019

But in many cases we need to use batch norm, e.g. when the methods we compare against use it. So I'm interested in an implementation like the one below.

For example, I use an MLP like:

class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        # an affine operation: y = Wx + b
        self.fc1 = nn.Linear(784, 1200)
        self.bn_fc1 = nn.BatchNorm1d(1200)
        self.fc2 = nn.Linear(1200, 600)
        self.bn_fc2 = nn.BatchNorm1d(600)
        self.fc3 = nn.Linear(600, 10)
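Its forward pass would be something like (a sketch, assuming ReLU activations after each BN layer; the cached pre-ReLU outputs provide the masks used by the inverse below):

    def forward(self, x):
        self.flat_out = x.view(-1, 784)
        self.bn1_out = self.bn_fc1(self.fc1(self.flat_out))
        self.relu1_out = F.relu(self.bn1_out)  # relu1_mask = (self.bn1_out > 0).float()
        self.bn2_out = self.bn_fc2(self.fc2(self.relu1_out))
        self.relu2_out = F.relu(self.bn2_out)  # relu2_mask = (self.bn2_out > 0).float()
        return self.fc3(self.relu2_out)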

How should I write the InverseMLP and its forward pass (based on your MLP model)?

class InverseMLP(nn.Module):
    def __init__(self):
        super(InverseMLP, self).__init__()
        self.transposefc3 = LinearTranspose(10, 600, bias=False)
        self.transposefc2 = LinearTranspose(600, 1200, bias=False)
        self.transposefc1 = LinearTranspose(1200, 784, bias=False)

    def forward(self, x, relu1_mask, relu2_mask):
        self.relu2_out = self.transposefc3(x)
        self.fc2_out = self.relu2_out * relu2_mask
        self.relu1_out = self.transposefc2(self.fc2_out)
        self.fc1_out = self.relu1_out * relu1_mask
        self.flat_out = self.transposefc1(self.fc1_out)
        self.input_out = self.flat_out.view(-1, 1, 28, 28)
        return self.input_out

Thank you!

@ZiangYan
Owner

Hi, sorry for the late reply.

I've added an example with batch norm.

class MLPBN(nn.Module):
    def __init__(self):
        super(MLPBN, self).__init__()
        # an affine operation: y = Wx + b
        self.fc1 = nn.Linear(784, 500)
        self.bn1 = nn.BatchNorm1d(500)
        self.fc2 = nn.Linear(500, 150)
        self.bn2 = nn.BatchNorm1d(150)
        self.fc3 = nn.Linear(150, 10)

    def forward(self, x):
        self.x = x
        self.flat_out = self.x.view(-1, 784)
        self.fc1_out = self.fc1(self.flat_out)
        self.bn1_out = self.bn1(self.fc1_out)
        self.relu1_out = F.relu(self.bn1_out)
        self.fc2_out = self.fc2(self.relu1_out)
        self.bn2_out = self.bn2(self.fc2_out)
        self.relu2_out = F.relu(self.bn2_out)
        self.fc3_out = self.fc3(self.relu2_out)
        return self.fc3_out

    def load_weights(self, source=None):
        if source is None:
            source = 'data/mnist-mlpbn-25b43980.pth'
        self.load_state_dict(torch.load(source))


class InverseMLPBN(nn.Module):
    def __init__(self):
        super(InverseMLPBN, self).__init__()
        self.transposefc3 = LinearTranspose(10, 150, bias=False)
        self.transposebn2 = BNTranspose(150)
        self.transposefc2 = LinearTranspose(150, 500, bias=False)
        self.transposebn1 = BNTranspose(500)
        self.transposefc1 = LinearTranspose(500, 784, bias=False)

    def forward(self, x, relu1_mask, relu2_mask):
        self.relu2_out = self.transposefc3(x)
        self.bn2_out = self.relu2_out * relu2_mask
        self.fc2_out = self.transposebn2(self.bn2_out)
        self.relu1_out = self.transposefc2(self.fc2_out)
        self.bn1_out = self.relu1_out * relu1_mask
        self.fc1_out = self.transposebn1(self.bn1_out)
        self.flat_out = self.transposefc1(self.fc1_out)
        self.input_out = self.flat_out.view(-1, 1, 28, 28)
        return self.input_out

    def copy_from(self, net):
        for k in ['fc1', 'fc2', 'fc3']:
            t = net.__getattr__(k)
            tt = self.__getattr__('transpose%s' % k)
            assert t.weight.data.size() == tt.weight.data.size()
            tt.weight = t.weight
        for k in ['bn1', 'bn2']:
            t = net.__getattr__(k)
            tt = self.__getattr__('transpose%s' % k)
            tt.weight.data[:] = (t.weight / torch.sqrt(t.running_var + t.eps)).data[:]
            # tt.bias.data[:] = (-t.running_mean * tt.weight + t.bias).data[:]

    def forward_from_net(self, net, input_image, idx):
        num_target_label = idx.size()[1]
        batch_size = input_image.size()[0]
        image_shape = input_image.size()[1:]
        output_var = net(input_image.cuda())
        dzdy = np.zeros((idx.numel(), output_var.size()[1]), dtype=np.float32)
        dzdy[np.arange(idx.numel()), idx.view(idx.numel()).cpu().numpy()] = 1.
        inverse_input_var = torch.from_numpy(dzdy).cuda()
        inverse_input_var.requires_grad = True
        inverse_output_var = self.forward(
            inverse_input_var,
            (net.bn1_out > 0).float().repeat(1, num_target_label).view(idx.numel(), 500),
            (net.bn2_out > 0).float().repeat(1, num_target_label).view(idx.numel(), 150),
        )
        dzdx = inverse_output_var.view(input_image.size()[0], idx.size()[1], -1).transpose(1, 2)
        return dzdx
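(`BNTranspose` itself is not shown in this thread. Judging from how `copy_from` fills its weight with gamma / sqrt(running_var + eps), it is presumably a per-feature elementwise scaling; a minimal sketch of such a module could be:)

class BNTranspose(nn.Module):
    # Sketch: transpose of a frozen 1-d batch norm layer. A frozen BN layer
    # computes y = x * scale + shift with scale = gamma / sqrt(running_var + eps),
    # so its transpose is elementwise multiplication by the same per-feature scale.
    def __init__(self, num_features):
        super(BNTranspose, self).__init__()
        self.weight = nn.Parameter(torch.ones(num_features))

    def forward(self, x):
        return x * self.weight  # broadcasts over the batch dimension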

You can find the download URL for the reference model of this example in the README.
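Putting the pieces together, usage would look roughly like this (a sketch; `images` and `idx` are placeholder inputs):

net = MLPBN().cuda()
net.load_weights()  # defaults to data/mnist-mlpbn-25b43980.pth

inverse_net = InverseMLPBN().cuda()
inverse_net.copy_from(net)  # copy/share weights from the trained net

# images: (N, 1, 28, 28) batch; idx: (N, k) LongTensor of target class indices
dzdx = inverse_net.forward_from_net(net, images, idx)  # -> (N, 784, k)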

Thanks.

@ZiangYan ZiangYan closed this as completed Feb 1, 2019