From 1c3f25dee149949e7ace0b2ed81fb1d6a99d9253 Mon Sep 17 00:00:00 2001
From: Venky Iyer
Date: Tue, 9 Mar 2021 02:59:55 -0800
Subject: [PATCH] Handle batchnorms in BatchGradientVerification (#569)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Handle batchnorms in BatchGradientVerification

* address feedback

* rename and document context manager

* isort

* added tests

* remove unused imports

Co-authored-by: Adrian Wälchli
---
 .../callbacks/verification/batch_gradient.py | 38 +++++++++++++++-
 .../verification/test_batch_gradient.py      | 44 ++++++++++++++++++-
 2 files changed, 78 insertions(+), 4 deletions(-)

diff --git a/pl_bolts/callbacks/verification/batch_gradient.py b/pl_bolts/callbacks/verification/batch_gradient.py
index b8ec9963af..cea50262fe 100644
--- a/pl_bolts/callbacks/verification/batch_gradient.py
+++ b/pl_bolts/callbacks/verification/batch_gradient.py
@@ -1,7 +1,9 @@
 # type: ignore
-from typing import Any, Callable, List, Optional
+from contextlib import contextmanager
+from typing import Any, Callable, Iterable, Iterator, List, Optional, Type
 
 import torch
+import torch.nn as nn
 from pytorch_lightning import LightningModule, Trainer
 from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -16,6 +18,14 @@ class BatchGradientVerification(VerificationBase):
     on the wrong tensor dimensions.
     """
 
+    NORM_LAYER_CLASSES = (
+        torch.nn.BatchNorm1d,
+        torch.nn.BatchNorm2d,
+        torch.nn.BatchNorm3d,
+        torch.nn.SyncBatchNorm,
+        torch.nn.GroupNorm,
+    )
+
     def check(
         self,
         input_array: Any,
@@ -58,7 +68,8 @@ def check(
         input_batch.requires_grad = True
 
         self.model.zero_grad()
-        output = self._model_forward(input_array)
+        with selective_eval(self.model, self.NORM_LAYER_CLASSES):
+            output = self._model_forward(input_array)
 
         # backward on the i-th sample should lead to gradient only in i-th input slice
         output_mapping(output)[sample_idx].sum().backward()
@@ -190,3 +201,26 @@ def collect_batches(tensor: torch.Tensor) -> torch.Tensor:
 
     apply_to_collection(data, dtype=torch.Tensor, function=collect_batches)
     return tensors
+
+
+@contextmanager
+def selective_eval(model: nn.Module, layer_types: Iterable[Type[nn.Module]]) -> Iterator[None]:
+    """
+    A context manager that sets all layers of the requested types to eval mode. Matching uses an
+    ``isinstance`` check, so subclasses of the given types are affected as well.
+
+    Args:
+        model: The model whose layers are to be switched to eval mode.
+        layer_types: The layer classes whose instances will be set to eval mode.
+    """
+    to_revert = []
+    try:
+        for module in model.modules():
+            if isinstance(module, tuple(layer_types)):
+                if module.training:
+                    module.eval()
+                    to_revert.append(module)
+        yield
+    finally:
+        for module in to_revert:
+            module.train()
diff --git a/tests/callbacks/verification/test_batch_gradient.py b/tests/callbacks/verification/test_batch_gradient.py
index 0f9e10405e..329a3f82bf 100644
--- a/tests/callbacks/verification/test_batch_gradient.py
+++ b/tests/callbacks/verification/test_batch_gradient.py
@@ -7,7 +7,7 @@
 from torch import nn as nn
 
 from pl_bolts.callbacks import BatchGradientVerificationCallback
-from pl_bolts.callbacks.verification.batch_gradient import default_input_mapping, default_output_mapping
+from pl_bolts.callbacks.verification.batch_gradient import default_input_mapping, default_output_mapping, selective_eval
 from pl_bolts.utils import BatchGradientVerification
 
 
@@ -18,6 +18,7 @@ def __init__(self, mix_data=False):
         super().__init__()
         self.mix_data = mix_data
         self.linear = nn.Linear(10, 5)
+        self.bn = nn.BatchNorm1d(10)
         self.input_array = torch.rand(10, 5, 2)
 
     def forward(self, *args, **kwargs):
@@ -29,7 +30,7 @@ def forward__standard(self, x):
             x = x.view(10, -1).permute(1, 0).view(-1, 10)  # oops!
         else:
             x = x.view(-1, 10)  # good!
-        return self.linear(x)
+        return self.linear(self.bn(x))
 
 
 class MultipleInputModel(TemplateModel):
@@ -255,3 +256,42 @@ def test_default_output_mapping():
     )
     output = default_output_mapping(data)
     assert torch.all(output == expected)
+
+
+class BatchNormModel(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.batch_norm0 = nn.BatchNorm1d(2)
+        self.batch_norm1 = nn.BatchNorm1d(3)
+        self.instance_norm = nn.InstanceNorm1d(4)
+
+
+def test_selective_eval():
+    """ Test that the selective_eval context manager only applies to selected layer types. """
+    model = BatchNormModel()
+    model.train()
+    with selective_eval(model, [nn.BatchNorm1d]):
+        assert not model.batch_norm0.training
+        assert not model.batch_norm1.training
+        assert model.instance_norm.training
+
+    assert model.batch_norm0.training
+    assert model.batch_norm1.training
+    assert model.instance_norm.training
+
+
+def test_selective_eval_invariant():
+    """ Test that the selective_eval context manager does not undo layers that were already in eval mode. """
+    model = BatchNormModel()
+    model.train()
+    model.batch_norm1.eval()
+    assert model.batch_norm0.training
+    assert not model.batch_norm1.training
+
+    with selective_eval(model, [nn.BatchNorm1d]):
+        assert not model.batch_norm0.training
+        assert not model.batch_norm1.training
+
+    assert model.batch_norm0.training
+    assert not model.batch_norm1.training
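
For illustration, a minimal sketch of the failure mode this patch addresses; the toy model and tensor shapes below are assumptions for the example, not code from the PR. A ``BatchNorm1d`` layer in train mode normalizes over the batch dimension, so every output sample depends on every input sample, and the per-sample gradient check would report spurious cross-batch mixing. Running the forward pass under ``selective_eval`` switches the layer to eval mode, where it uses its running statistics and operates on each sample independently; this mirrors what ``check()`` now does internally via ``NORM_LAYER_CLASSES``::

    import torch
    import torch.nn as nn

    from pl_bolts.callbacks.verification.batch_gradient import selective_eval

    # Toy model (assumption for this example): BatchNorm followed by a Linear
    # layer, both applied per sample when the BatchNorm runs in eval mode.
    model = nn.Sequential(nn.BatchNorm1d(10), nn.Linear(10, 5))
    model.train()
    x = torch.rand(4, 10, requires_grad=True)

    # Without the context manager, backpropagating sample 1's output would put
    # nonzero gradients into every input sample, because train-mode BatchNorm
    # computes the mean and variance over the whole batch.
    with selective_eval(model, [nn.BatchNorm1d]):
        model(x)[1].sum().backward()

    # In eval mode the normalization uses running statistics, so only the
    # selected sample receives gradient.
    assert torch.all(x.grad[0] == 0)
    assert x.grad[1].abs().sum() > 0
    assert torch.all(x.grad[2:] == 0)

    # The context manager restores train mode on exit.
    assert model[0].training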