Skip to content

Commit

Permalink
refactor: add comments for the code of RevIN;
Browse files Browse the repository at this point in the history
  • Loading branch information
WenjieDu committed Jul 2, 2024
1 parent 094579f commit dfcdfd1
Showing 1 changed file with 28 additions and 20 deletions.
48 changes: 28 additions & 20 deletions pypots/nn/modules/revin/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,27 @@


class RevIN(nn.Module):
    """RevIN: Reversible Instance Normalization.

    Normalizes each input sample with its own statistics in ``_normalize``
    and restores the original scale/location later in ``_denormalize``,
    optionally with a learnable affine transform in between.

    Parameters
    ----------
    n_features :
        the number of features or channels

    eps :
        a value added for numerical stability

    affine :
        if True, RevIN has learnable affine parameters
    """

def __init__(
self,
n_features: int,
eps: float = 1e-9,
affine: bool = True,
):
"""
Parameters
----------
n_features :
the number of features or channels
eps :
a value added for numerical stability
affine :
if True, RevIN has learnable affine parameters
"""
super().__init__()
self.n_features = n_features
self.eps = eps
Expand All @@ -54,36 +56,42 @@ def _init_params(self):
def _normalize(self, x, missing_mask=None):
dim2reduce = tuple(range(1, x.ndim - 1))

# calculate mean and stdev
if missing_mask is None:
# original implementation
self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()
self.stdev = torch.sqrt(
mean = torch.mean(x, dim=dim2reduce, keepdim=True)
stdev = torch.sqrt(
torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps
).detach()
)
else:
# pypots implementation for POTS data
missing_sum = (
torch.sum(missing_mask == 1, dim=dim2reduce, keepdim=True) + self.eps
)
self.mean = torch.sum(x, dim=dim2reduce, keepdim=True) / missing_sum
mean = torch.sum(x, dim=dim2reduce, keepdim=True) / missing_sum
x_enc = x.masked_fill(missing_mask == 0, 0)
variance = torch.sum(x_enc * x_enc, dim=dim2reduce, keepdim=True) + self.eps
self.stdev = torch.sqrt(variance / missing_sum)
self.mean = self.mean.detach()
self.stdev = self.stdev.detach()
stdev = torch.sqrt(variance / missing_sum)

# detach mean and stdev to avoid backpropagation
self.mean = mean.detach()
self.stdev = stdev.detach()
# normalize the input
x = x - self.mean
x = x / self.stdev

if self.affine:
# apply affine transformation
x = x * self.affine_weight
x = x + self.affine_bias
return x

def _denormalize(self, x):
# reverse affine transformation
if self.affine:
x = x - self.affine_bias
x = x / (self.affine_weight + self.eps)
# denormalize the input
x = x * self.stdev
x = x + self.mean
return x

0 comments on commit dfcdfd1

Please sign in to comment.