Skip to content

Commit

Permalink
Handle batchnorms in BatchGradientVerification (#569)
Browse files Browse the repository at this point in the history
* Handle batchnorms in BatchGradientVerification

* address feedback

* rename and document context manager

* isort

* added tests

* remove unused imports

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
  • Loading branch information
indigoviolet and awaelchli committed Mar 9, 2021
1 parent 4d194d4 commit 1c3f25d
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 4 deletions.
38 changes: 36 additions & 2 deletions pl_bolts/callbacks/verification/batch_gradient.py
@@ -1,7 +1,9 @@
# type: ignore
from typing import Any, Callable, List, Optional
from contextlib import contextmanager
from typing import Any, Callable, Iterable, List, Optional, Type

import torch
import torch.nn as nn
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.exceptions import MisconfigurationException
Expand All @@ -16,6 +18,14 @@ class BatchGradientVerification(VerificationBase):
on the wrong tensor dimensions.
"""

# Normalization layer types that ``check`` passes to ``selective_eval`` so they are in eval mode
# while the verification forward pass runs. Matching is by ``isinstance``, so subclasses count too.
NORM_LAYER_CLASSES = (
torch.nn.BatchNorm1d,
torch.nn.BatchNorm2d,
torch.nn.BatchNorm3d,
torch.nn.SyncBatchNorm,
torch.nn.GroupNorm,
)

def check(
self,
input_array: Any,
Expand Down Expand Up @@ -58,7 +68,8 @@ def check(
input_batch.requires_grad = True

self.model.zero_grad()
output = self._model_forward(input_array)
with selective_eval(self.model, self.NORM_LAYER_CLASSES):
output = self._model_forward(input_array)

# backward on the i-th sample should lead to gradient only in i-th input slice
output_mapping(output)[sample_idx].sum().backward()
Expand Down Expand Up @@ -190,3 +201,26 @@ def collect_batches(tensor: torch.Tensor) -> torch.Tensor:

apply_to_collection(data, dtype=torch.Tensor, function=collect_batches)
return tensors


@contextmanager
def selective_eval(model: nn.Module, layer_types: Iterable[Type[nn.Module]]) -> None:
    """
    A context manager that temporarily sets all requested types of layers to eval mode.

    Matching uses an ``isinstance`` check, so subclasses of the requested types are also affected.
    On exit (including when the body raises), only the layers that this context manager itself
    switched to eval mode are put back into training mode; layers that were already in eval mode
    are left untouched.

    Args:
        model: A model which has layers that need to be set to eval mode.
        layer_types: The class objects for which all layers of that type will be set to eval mode.
            Any iterable is accepted, including one-shot iterators such as generators.
    """
    # Materialize once, up front: ``isinstance`` needs a tuple, and re-reading a one-shot iterable
    # for every module would leave it empty after the first module was checked.
    layer_types = tuple(layer_types)
    to_revert = []
    try:
        for module in model.modules():
            if module.training and isinstance(module, layer_types):
                module.eval()
                # remember only what we changed, so the exit path never flips a layer that was already in eval
                to_revert.append(module)
        yield
    finally:
        for module in to_revert:
            module.train()
44 changes: 42 additions & 2 deletions tests/callbacks/verification/test_batch_gradient.py
Expand Up @@ -7,7 +7,7 @@
from torch import nn as nn

from pl_bolts.callbacks import BatchGradientVerificationCallback
from pl_bolts.callbacks.verification.batch_gradient import default_input_mapping, default_output_mapping
from pl_bolts.callbacks.verification.batch_gradient import default_input_mapping, default_output_mapping, selective_eval
from pl_bolts.utils import BatchGradientVerification


Expand All @@ -18,6 +18,7 @@ def __init__(self, mix_data=False):
super().__init__()
self.mix_data = mix_data
self.linear = nn.Linear(10, 5)
self.bn = nn.BatchNorm1d(10)
self.input_array = torch.rand(10, 5, 2)

def forward(self, *args, **kwargs):
Expand All @@ -29,7 +30,7 @@ def forward__standard(self, x):
x = x.view(10, -1).permute(1, 0).view(-1, 10) # oops!
else:
x = x.view(-1, 10) # good!
return self.linear(x)
return self.linear(self.bn(x))


class MultipleInputModel(TemplateModel):
Expand Down Expand Up @@ -255,3 +256,42 @@ def test_default_output_mapping():
)
output = default_output_mapping(data)
assert torch.all(output == expected)


class BatchNormModel(nn.Module):
    """ Minimal module holding two batch norm layers and one instance norm layer; it defines no forward. """

    def __init__(self):
        super().__init__()
        # registration order is kept: batch_norm0, batch_norm1, instance_norm
        self.batch_norm0 = nn.BatchNorm1d(num_features=2)
        self.batch_norm1 = nn.BatchNorm1d(num_features=3)
        self.instance_norm = nn.InstanceNorm1d(num_features=4)


def test_selective_eval():
    """ Check that ``selective_eval`` switches only the requested layer types to eval mode and restores them. """
    net = BatchNormModel()
    net.train()
    batch_norms = (net.batch_norm0, net.batch_norm1)

    with selective_eval(net, [nn.BatchNorm1d]):
        # inside the context: both batch norms are in eval mode, the instance norm is untouched
        assert not any(layer.training for layer in batch_norms)
        assert net.instance_norm.training

    # after the context: every layer is in training mode again
    assert all(layer.training for layer in batch_norms)
    assert net.instance_norm.training


def test_selective_eval_invariant():
    """ Check that ``selective_eval`` leaves a layer in eval mode if it was already in eval mode on entry. """
    net = BatchNormModel()
    net.train()
    net.batch_norm1.eval()

    def training_flags():
        # (batch_norm0.training, batch_norm1.training)
        return net.batch_norm0.training, net.batch_norm1.training

    # precondition: only batch_norm1 starts out in eval mode
    assert training_flags() == (True, False)

    with selective_eval(net, [nn.BatchNorm1d]):
        assert training_flags() == (False, False)

    # batch_norm0 is restored to training; batch_norm1 stays in eval as it was before
    assert training_flags() == (True, False)

0 comments on commit 1c3f25d

Please sign in to comment.