
Commit

Merge pull request #22 from usernameandme/fix_for_issue_21
Changes for Pytorch 1.5.0 (adapted from mapillary inplace_abn repo)
GoGoDuck912 committed Jul 9, 2020
2 parents 7cc9e64 + e182ac2 commit 9b9b382
Showing 2 changed files with 10 additions and 6 deletions.
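
Background on the fix (an added note, not part of the commit message): starting with PyTorch 1.5, a custom torch.autograd.Function that updates buffers such as running_mean / running_var in place is expected to return those buffers from forward() and mark them with ctx.mark_non_differentiable(), and backward() must then accept one gradient argument per forward output. The sketch below is a minimal, self-contained toy that follows this contract; the class and names are illustrative, not the repo's actual inplace_abn kernels.

import torch
from torch.autograd.function import once_differentiable


class _RunningScale(torch.autograd.Function):
    # Toy Function illustrating the PyTorch >= 1.5 contract adopted by this commit:
    # every tensor returned by forward() gets a matching gradient argument in
    # backward(), and buffers updated in place are marked non-differentiable.

    @staticmethod
    def forward(ctx, x, weight, running_mean, momentum=0.1):
        # Update the buffer in place, like running_mean / running_var in inplace_abn.
        running_mean.mul_(1 - momentum).add_(x.mean().detach(), alpha=momentum)
        ctx.save_for_backward(x, weight)
        # The buffer is returned alongside the activation but carries no gradient.
        ctx.mark_non_differentiable(running_mean)
        return x * weight, running_mean

    @staticmethod
    @once_differentiable
    def backward(ctx, dy, _drunning_mean):
        # One gradient argument per forward output; the buffer's placeholder is ignored.
        x, weight = ctx.saved_tensors
        # One gradient returned per forward input (momentum is a plain float -> None).
        return dy * weight, (dy * x).sum(), None, None


x = torch.randn(8, requires_grad=True)
w = torch.tensor(2.0, requires_grad=True)
buf = torch.zeros(())
y, buf = _RunningScale.apply(x, w, buf)
y.sum().backward()
print(x.grad.shape, w.grad.item(), buf.item())

The two diffs below apply exactly this pattern to InPlaceABN / InPlaceABNSync and their functional backends.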
6 changes: 4 additions & 2 deletions modules/bn.py
@@ -105,8 +105,9 @@ def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation
         super(InPlaceABN, self).__init__(num_features, eps, momentum, affine, activation, slope)
 
     def forward(self, x):
-        return inplace_abn(x, self.weight, self.bias, self.running_mean, self.running_var,
+        x, _, _ = inplace_abn(x, self.weight, self.bias, self.running_mean, self.running_var,
                            self.training, self.momentum, self.eps, self.activation, self.slope)
+        return x
 
 
 class InPlaceABNSync(ABN):
@@ -115,8 +116,9 @@ class InPlaceABNSync(ABN):
     """
 
     def forward(self, x):
-        return inplace_abn_sync(x, self.weight, self.bias, self.running_mean, self.running_var,
+        x, _, _ = inplace_abn_sync(x, self.weight, self.bias, self.running_mean, self.running_var,
                                 self.training, self.momentum, self.eps, self.activation, self.slope)
+        return x
 
     def __repr__(self):
         rep = '{name}({num_features}, eps={eps}, momentum={momentum},' \
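
The wrapper change above is interface-preserving: inplace_abn / inplace_abn_sync now return the buffers alongside the activation, so forward() unpacks the tuple and still hands callers a single tensor. A hedged usage sketch follows, assuming the import path shown in this diff, that the repo's CUDA/C++ extension is built, and that a GPU is available.

import torch
from modules.bn import InPlaceABN          # import path taken from this diff

# Assumes the repo's compiled CUDA/C++ extension and a GPU.
abn = InPlaceABN(64).cuda()                # num_features; eps/momentum/activation keep their defaults
x = torch.randn(2, 64, 32, 32, device="cuda")
y = abn(x)                                 # still a single tensor, exactly as before this commit
y.mean().backward()                        # gradients flow into abn.weight / abn.bias as usual
print(y.shape, abn.running_mean[:3])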
10 changes: 6 additions & 4 deletions modules/functions.py
@@ -112,11 +112,12 @@ def forward(ctx, x, weight, bias, running_mean, running_var,
         # Output
         ctx.var = var
         ctx.save_for_backward(x, var, weight, bias)
-        return x
+        ctx.mark_non_differentiable(running_mean, running_var)
+        return x, running_mean, running_var
 
     @staticmethod
     @once_differentiable
-    def backward(ctx, dz):
+    def backward(ctx, dz, _drunning_mean, _drunning_var):
         z, var, weight, bias = ctx.saved_tensors
         dz = dz.contiguous()
 
@@ -200,11 +201,12 @@ def forward(cls, ctx, x, weight, bias, running_mean, running_var,
         # Output
         ctx.var = var
         ctx.save_for_backward(x, var, weight, bias)
-        return x
+        ctx.mark_non_differentiable(running_mean, running_var)
+        return x, running_mean, running_var
 
     @staticmethod
     @once_differentiable
-    def backward(ctx, dz):
+    def backward(ctx, dz, _drunning_mean, _drunning_var):
         z, var, weight, bias = ctx.saved_tensors
         dz = dz.contiguous()
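
At the functional level, the positional argument order matches the call shown in modules/bn.py above; the "leaky_relu" activation name and 0.01 slope below are assumed defaults, the input shapes are hypothetical, and the call needs the repo's compiled extension on a CUDA device. A sketch of the new three-value return:

import torch
from modules.functions import inplace_abn   # module changed in this diff

C = 64
conv = torch.nn.Conv2d(3, C, 3, padding=1).cuda()
x = conv(torch.randn(2, 3, 16, 16, device="cuda"))   # non-leaf input, as inside a real network
weight = torch.ones(C, device="cuda", requires_grad=True)
bias = torch.zeros(C, device="cuda", requires_grad=True)
running_mean = torch.zeros(C, device="cuda")
running_var = torch.ones(C, device="cuda")

# After this commit the function returns the activation plus the running statistics;
# the buffers are marked non-differentiable, so only the first output carries gradient.
# Positional order follows modules/bn.py; "leaky_relu"/0.01 are assumed defaults.
y, running_mean, running_var = inplace_abn(
    x, weight, bias, running_mean, running_var,
    True, 0.1, 1e-5, "leaky_relu", 0.01)    # training, momentum, eps, activation, slope
y.sum().backward()                          # placeholder grads for the buffers are ignored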
