Feat: Optional RunningBatchNorm Affine
GilesStrong committed Apr 20, 2023
1 parent 5995ab7 commit 73700f8
Showing 2 changed files with 21 additions and 7 deletions.
1 change: 1 addition & 0 deletions CHANGES.md
@@ -21,6 +21,7 @@
 - `fold2foldfile`, `df2foldfile`, and `add_meta_data` can now deal with targets in the form of multi dimensional tensors, and convert them to sparse COO format
 - `df2foldfile` now has the option to not shuffle data into folds and instead split it into contiguous folds
 - Limited handling of PyTorch Geometric data: `TorchGeometricFoldYielder`, `TorchGeometricBatchYielder`, `TorchGeometricEvalMetric`
+- Make `RunningBatchNorm` affine transformation optional
 
 ## Removals
 
27 changes: 20 additions & 7 deletions lumin/nn/models/layers/batchnorms.py
@@ -31,16 +31,21 @@ class RunningBatchNorm1d(nn.Module):
         mom: momentum (fraction to add to running averages)
         n_warmup: number of warmup iterations (during which variance is clamped)
         eps: epsilon to prevent division by zero
+        affine: whether to apply a learnable linear transformation to incoming data
     '''
 
-    def __init__(self, nf:int, mom:float=0.1, n_warmup:int=20, eps:float=1e-5):
+    def __init__(self, nf:int, mom:float=0.1, n_warmup:int=20, eps:float=1e-5, affine:bool=True):
         super().__init__()
         store_attr()
         self._set_params()
 
     def _set_params(self) -> None:
-        self.weight = nn.Parameter(torch.ones(self.nf,1))
-        self.bias = nn.Parameter(torch.zeros(self.nf,1))
+        if self.affine:
+            self.weight = nn.Parameter(torch.ones(self.nf,1))
+            self.bias = nn.Parameter(torch.zeros(self.nf,1))
+        else:
+            self.weight = 1
+            self.bias = 0
         self.register_buffer('sums', torch.zeros(1,self.nf,1))
         self.register_buffer('sqrs', torch.zeros(1,self.nf,1))
         self.register_buffer('batch', tensor(0.))
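The forward pass is not part of this diff, but the fallback works because Python scalars broadcast against tensors: with `weight = 1` and `bias = 0`, the affine step reduces to the identity. A minimal sketch under that assumption (the normalisation below is illustrative, not lumin's actual forward):

```python
import torch

nf = 4
x = torch.randn(8, nf, 1)  # (batch, nf, 1), matching the (nf, 1) parameter shape

# Illustrative normalisation step; the real forward uses the running sums/sqrs buffers
mean = x.mean((0, 2), keepdim=True)
var = x.var((0, 2), keepdim=True)
x_hat = (x - mean) / (var + 1e-5).sqrt()

weight, bias = 1, 0            # the affine=False fallback: plain Python scalars
out = x_hat * weight + bias    # broadcasting makes this the identity transform
assert torch.equal(out, x_hat)
```

Using plain scalars rather than frozen tensors also keeps `weight` and `bias` out of `parameters()` and the `state_dict`, so an optimiser never sees them when `affine=False`.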
@@ -89,8 +94,12 @@ class RunningBatchNorm2d(RunningBatchNorm1d):
     '''
 
     def _set_params(self) -> None:
-        self.weight = nn.Parameter(torch.ones(self.nf,1,1))
-        self.bias = nn.Parameter(torch.zeros(self.nf,1,1))
+        if self.affine:
+            self.weight = nn.Parameter(torch.ones(self.nf,1,1))
+            self.bias = nn.Parameter(torch.zeros(self.nf,1,1))
+        else:
+            self.weight = 1
+            self.bias = 0
         self.register_buffer('sums', torch.zeros(1,self.nf,1,1))
         self.register_buffer('sqrs', torch.zeros(1,self.nf,1,1))
         self.register_buffer('batch', tensor(0.))
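A quick usage sketch of the new flag (the import path follows the file being edited; the parameter counts are inferred from `_set_params` above):

```python
from lumin.nn.models.layers.batchnorms import RunningBatchNorm2d

bn = RunningBatchNorm2d(nf=16)                    # default: learnable weight and bias
frozen = RunningBatchNorm2d(nf=16, affine=False)  # running normalisation only

# With affine=False no nn.Parameter is registered, so the layer
# contributes nothing for an optimiser to update
assert len(list(bn.parameters())) == 2
assert len(list(frozen.parameters())) == 0
```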
@@ -120,8 +129,12 @@ class RunningBatchNorm3d(RunningBatchNorm2d):
     '''
 
     def _set_params(self) -> None:
-        self.weight = nn.Parameter(torch.ones(self.nf,1,1,1))
-        self.bias = nn.Parameter(torch.zeros(self.nf,1,1,1))
+        if self.affine:
+            self.weight = nn.Parameter(torch.ones(self.nf,1,1,1))
+            self.bias = nn.Parameter(torch.zeros(self.nf,1,1,1))
+        else:
+            self.weight = 1
+            self.bias = 0
         self.register_buffer('sums', torch.zeros(1,self.nf,1,1,1))
         self.register_buffer('sqrs', torch.zeros(1,self.nf,1,1,1))
         self.register_buffer('batch', tensor(0.))
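The only change between the three classes is the number of trailing singleton dimensions: the parameters and buffers gain one axis per input rank so they broadcast over the extra spatial axes. A sketch of that broadcasting (the input layouts are assumptions inferred from the buffer shapes, not stated in the diff):

```python
import torch

nf = 16
# (parameter shape from the diff, plausible matching input layout)
cases = [
    (torch.ones(nf, 1),       torch.randn(8, nf, 32)),         # 1d: (batch, nf, length)
    (torch.ones(nf, 1, 1),    torch.randn(8, nf, 32, 32)),     # 2d: (batch, nf, H, W)
    (torch.ones(nf, 1, 1, 1), torch.randn(8, nf, 8, 32, 32)),  # 3d: (batch, nf, D, H, W)
]
for weight, x in cases:
    assert (x * weight).shape == x.shape  # trailing 1s broadcast over spatial axes
```

This mirrors how `nn.BatchNorm2d`'s per-channel parameters broadcast over height and width.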
