Merge branch 'PytorchBNFix' into 'master'
[Pytorch2N2D2] Fix issue with BN when affine=False

See merge request n2d2/n2d2!114
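For context: with affine=False, PyTorch creates a BatchNorm layer whose weight and bias attributes are None, so the previous code's unconditional bias.detach().clone() raised an AttributeError. A minimal reproduction sketch (not part of the commit):

import torch

bn = torch.nn.BatchNorm2d(8, affine=False)
print(bn.weight, bn.bias)  # None None: no learnable scale/shift when affine=False

try:
    saved_weight = bn.weight.detach().clone()  # mirrors the pre-fix save step
except AttributeError as err:
    print(f"pre-fix save fails: {err}")  # 'NoneType' object has no attribute 'detach'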
cmoineau committed Aug 30, 2023
2 parents 99cdd30 + 9a7ff63 commit b0bdfa5
Showing 1 changed file with 15 additions and 13 deletions.
python/pytorch_to_n2d2/pytorch_interface.py: 28 changes (15 additions & 13 deletions)
@@ -36,7 +36,7 @@ def _to_n2d2(torch_tensor:torch.Tensor)->n2d2.Tensor:
This method also converts the shape of the tensor to follow the N2D2 convention.
"""
n2d2_tensor = None
if torch_tensor.is_cuda:
if not torch_tensor.is_contiguous():
# If torch_tensor is not contiguous, we need to make a copy!
numpy_tensor = torch_tensor.cpu().detach().numpy()
@@ -45,7 +45,7 @@ def _to_n2d2(torch_tensor:torch.Tensor)->n2d2.Tensor:
n2d2_tensor = n2d2_tensor.cuda()
n2d2_tensor.htod()
if n2d2_tensor.nb_dims() == 4:
n2d2_tensor.reshape(_switching_convention(n2d2_tensor.dims()))
else:
# We avoid a copy on the GPU!
dtype = torch_tensor.dtype
@@ -135,7 +135,7 @@ def get_block(self) -> n2d2.cells.Block:
return self._block

def summary(self)->None:
"""Print model information.
"""Print model information.
"""
self._block.summary()

@@ -165,7 +165,7 @@ def forward(ctx, inputs):
if self.current_batch_size != self.batch_size:
# Pad an incomplete batch with 0, as N2D2 doesn't support incomplete batches.
n2d2_input_shape[0] = self.batch_size

n2d2_tensor.resize(n2d2_input_shape)


@@ -174,7 +174,7 @@ def forward(ctx, inputs):
n2d2_tensor.htod()

# training is a torch.nn.Module attribute (cf. https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module)
if self.training:
self._block.learn()
else:
self._block.test()
@@ -240,11 +240,11 @@ def backward(ctx, grad_output):

if self.current_batch_size < self.batch_size:
# Making sure we have a full batch
new_shape = list(grad_output.shape)
new_shape[0] = self.batch_size
tmp_numpy = t_grad_output.to_numpy(copy=True)
tmp_numpy.resize(new_shape)
t_grad_output = n2d2.Tensor.from_numpy(tmp_numpy)

t_grad_output=t_grad_output.N2D2()
if len(self.deepnet.getLayers()[-1]) > 1:
@@ -262,7 +262,7 @@ def backward(ctx, grad_output):
diffOutput = self.deepnet.getCell_Frame_Top(self.deepnet.getLayers()[1][0]).getDiffOutputs()

outputs = _to_torch(diffOutput)

outputs = outputs.resize_(self.input_shape) # in place operation
outputs = torch.mul(outputs, -1/self.batch_size)

@@ -313,9 +313,10 @@ def fake_forward(inputs,
# Save copy of values before propagation
saved_run_mean = current_bn.running_mean.detach().clone()
saved_run_var = current_bn.running_var.detach().clone()
-    saved_bias = current_bn.bias.detach().clone()
-    saved_weight = current_bn.weight.detach().clone()
+    if current_bn.affine: # Bias and Weights only created if affine=True
+        saved_bias = current_bn.bias.detach().clone()
+        saved_weight = current_bn.weight.detach().clone()
print(current_bn.running_mean.shape)
# Real batchnorm forward
output_tensor = torch.nn.functional.batch_norm(
inputs,
@@ -330,8 +331,9 @@

current_bn.running_mean.copy_(torch.nn.Parameter(saved_run_mean).requires_grad_(False))
current_bn.running_var.copy_(torch.nn.Parameter(saved_run_var).requires_grad_(False))
-    current_bn.bias = (torch.nn.Parameter(saved_bias))
-    current_bn.weight = (torch.nn.Parameter(saved_weight))
+    if current_bn.affine:
+        current_bn.bias = (torch.nn.Parameter(saved_bias))
+        current_bn.weight = (torch.nn.Parameter(saved_weight))

return output_tensor
# Update the Batchnorm forward with the fake forward!
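
For readability, here is the patched save/propagate/restore pattern from fake_forward collected into one standalone sketch; the function name and annotations are illustrative, not the repository's:

import torch

def guarded_fake_forward(inputs: torch.Tensor, bn: torch.nn.BatchNorm2d) -> torch.Tensor:
    # Save running statistics; the real forward updates them in place when training.
    saved_mean = bn.running_mean.detach().clone()
    saved_var = bn.running_var.detach().clone()
    if bn.affine:  # weight and bias exist only when affine=True
        saved_weight = bn.weight.detach().clone()
        saved_bias = bn.bias.detach().clone()

    out = torch.nn.functional.batch_norm(
        inputs, bn.running_mean, bn.running_var,
        weight=bn.weight, bias=bn.bias,  # both None when affine=False, which F.batch_norm accepts
        training=bn.training, momentum=bn.momentum, eps=bn.eps)

    # Restore everything the forward may have modified in place.
    bn.running_mean.copy_(saved_mean)
    bn.running_var.copy_(saved_var)
    if bn.affine:
        bn.weight = torch.nn.Parameter(saved_weight)
        bn.bias = torch.nn.Parameter(saved_bias)
    return out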
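
The first two hunks sit in _to_n2d2, which routes non-contiguous CUDA tensors through a NumPy copy. A small sketch of the contiguity distinction the comment refers to (shapes are arbitrary):

import torch

t = torch.arange(12.0).reshape(3, 4).t()  # a transposed view is not contiguous
print(t.is_contiguous())                   # False: strides no longer match row-major layout
c = t.contiguous()                         # explicit copy into a fresh row-major buffer
print(c.is_contiguous())                   # True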

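The forward/backward hunks (around original file lines 165 and 240) zero-pad incomplete batches up to the full batch size, since N2D2 only accepts full batches. A minimal NumPy sketch of that resize-based padding (shapes are made up for illustration):

import numpy as np

batch_size = 32                                  # full batch size expected by N2D2
grad = np.ones((20, 3, 8, 8), dtype=np.float32)  # incomplete final batch

new_shape = list(grad.shape)
new_shape[0] = batch_size
grad.resize(new_shape, refcheck=False)           # in place; appended rows are zero-filled
print(grad.shape, grad[20:].sum())               # (32, 3, 8, 8) 0.0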