gradient verification callback (#465)
* initial commit

* docs cleanup

* isort

* black

* top level imports

* rst docs

* update chlog

* isort again

* format

* Apply suggestions from code review

Co-authored-by: Akihiro Nitta <nitta@akihironitta.com>

* fix import

* increase coverage

* don't skip tests that can partially run on cpu

* black format

* make bots happy

* cleanup

* more tests for full coverage

* isort, black

* mypy complaining

* remove unused import

* stop complain

* try type ignore

* try ignore

* try ignore

* try ignore

* try ignore

* stupid mypy

* stupid mypy

* stupid mypy

* stupid mypi

* stupid mypy

* ugly yapf

* yapf :(

* yapffffffff

* chlog

* Apply suggestions from code review

* yapf

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Akihiro Nitta <nitta@akihironitta.com>
Co-authored-by: Jirka Borovec <jirka.borovec@seznam.cz>
4 people committed Jan 18, 2021
1 parent 58aa93a commit 42cfa8f
Showing 11 changed files with 736 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -23,6 +23,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added metric GIoU ([#347](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/347))
- Added Intersection over Union Metric/Loss ([#469](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/469))
- Added SimSiam model ([#407](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/407))
- Added gradient verification callback ([#465](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/465))

### Changed

60 changes: 60 additions & 0 deletions docs/source/info_callbacks.rst
@@ -64,3 +64,63 @@ You can track all or just a selection of submodules:
This is especially useful for debugging the data flow in complex models and to identify
numerical instabilities.


---------------

Model Verification
------------------


Gradient-Check for Batch-Optimization
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Gradient descent over a batch of samples not only benefits the optimization but also leverages data parallelism.
However, one has to be careful not to mix data across the batch dimension.
Even a small error in a reshape or permutation operation can cause the optimization to get stuck, and you won't
even get a runtime error. How can one tell if the model mixes data in the batch?
A simple trick is to do the following:

1. run the model on an example batch (can be random data)
2. get the output batch and select the n-th sample (choose n)
3. compute a dummy loss value from only that sample and compute the gradient w.r.t. the entire input batch
4. observe that only the n-th sample in the input batch has a non-zero gradient

|
If the gradient is non-zero for the other samples in the batch, it means the forward pass of the model is mixing data!
The :class:`~pl_bolts.callbacks.verification.batch_gradient.BatchGradientVerificationCallback`
does all of that for you before training begins.
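
The four steps above can be sketched in plain PyTorch as follows. This is a minimal illustration under stated assumptions, not the implementation shipped in this commit: the helper ``batch_is_independent`` and the two toy modules ``Independent`` and ``Mixing`` are hypothetical names invented for this example, and the sketch assumes the model takes a single tensor whose first dimension is the batch.

```python
import torch
import torch.nn as nn


def batch_is_independent(model: nn.Module, batch: torch.Tensor, sample_idx: int = 0) -> bool:
    """Hypothetical helper: True if only batch[sample_idx] receives a non-zero gradient."""
    batch = batch.clone().requires_grad_(True)
    model.zero_grad()
    output = model(batch)              # step 1: run the model on an example batch
    loss = output[sample_idx].sum()    # steps 2-3: dummy loss from the n-th output only
    loss.backward()                    # gradient w.r.t. the entire input batch
    others = torch.ones(batch.shape[0], dtype=torch.bool)
    others[sample_idx] = False
    # step 4: every sample other than the n-th must have an all-zero gradient
    return bool((batch.grad[others] == 0).all())


class Independent(nn.Module):
    """Toy model (assumption): each output row depends only on its own input row."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


class Mixing(nn.Module):
    """Toy model (assumption) with a deliberate bug in the reshape."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Bug: transposing before reshaping scatters each sample across several rows.
        return self.linear(x.t().reshape(x.shape[0], -1))


torch.manual_seed(0)
example_batch = torch.rand(3, 4)
ok_result = batch_is_independent(Independent(), example_batch, sample_idx=1)
bad_result = batch_is_independent(Mixing(), example_batch, sample_idx=1)
print(ok_result, bad_result)
```

Note that the buggy model runs without any shape error, which is exactly why a gradient-based check is useful here.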

.. code-block:: python

    from pytorch_lightning import Trainer
    from pl_bolts.callbacks import BatchGradientVerificationCallback

    model = YourLightningModule()
    verification = BatchGradientVerificationCallback()
    trainer = Trainer(callbacks=[verification])
    trainer.fit(model)

This Callback will warn the user with the following message in case data mixing inside the batch is detected:

.. code-block::

    Your model is mixing data across the batch dimension.
    This can lead to wrong gradient updates in the optimizer.
    Check the operations that reshape and permute tensor dimensions in your model.

A non-Callback version
:class:`~pl_bolts.callbacks.verification.batch_gradient.BatchGradientVerification`
that works with any PyTorch :class:`~torch.nn.Module` is also available:

.. code-block:: python

    import torch
    from pl_bolts.utils import BatchGradientVerification

    model = YourPyTorchModel()
    verification = BatchGradientVerification(model)
    valid = verification.check(input_array=torch.rand(2, 3, 4), sample_idx=1)

In this example we run the test with a batch size of 2 and inspect the gradients of the second sample.
2 changes: 2 additions & 0 deletions pl_bolts/callbacks/__init__.py
@@ -6,10 +6,12 @@
from pl_bolts.callbacks.printing import PrintTableMetricsCallback # noqa: F401
from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator # noqa: F401
from pl_bolts.callbacks.variational import LatentDimInterpolator # noqa: F401
from pl_bolts.callbacks.verification.batch_gradient import BatchGradientVerificationCallback # type: ignore
from pl_bolts.callbacks.vision.confused_logit import ConfusedLogitCallback # noqa: F401
from pl_bolts.callbacks.vision.image_generation import TensorboardGenerativeModelImageSampler # noqa: F401

__all__ = [
"BatchGradientVerificationCallback",
"BYOLMAWeightUpdate",
"ModuleDataMonitor",
"TrainingDataMonitor",
2 changes: 1 addition & 1 deletion pl_bolts/callbacks/data_monitor.py
@@ -18,7 +18,7 @@
import wandb
else: # pragma: no cover
warn_missing_pkg("wandb")
wandb = None
wandb = None # type: ignore


class DataMonitorBase(Callback):
Empty file.
123 changes: 123 additions & 0 deletions pl_bolts/callbacks/verification/base.py
@@ -0,0 +1,123 @@
# type: ignore
from abc import abstractmethod
from copy import deepcopy
from typing import Any, Optional

import torch.nn as nn
from pytorch_lightning import Callback
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import move_data_to_device, rank_zero_warn


class VerificationBase:
    """
    Base class for model verification.
    All verifications should run with any :class:`torch.nn.Module` unless otherwise stated.
    """

    def __init__(self, model: nn.Module):
        """
        Arguments:
            model: The model to run verification for.
        """
        super().__init__()
        self.model = model

    @abstractmethod
    def check(self, *args: Any, **kwargs: Any) -> bool:
        """ Runs the actual test on the model. All verification classes must implement this.

        Arguments:
            *args: Any positional arguments that are needed to run the test
            **kwargs: Keyword arguments that are needed to run the test

        Returns:
            `True` if the test passes, and `False` otherwise. Some verifications can only be performed
            with a heuristic accuracy, thus the return value may not always reflect the true state of
            the system in these cases.
        """

    def _get_input_array_copy(self, input_array: Optional[Any] = None) -> Any:
        """
        Returns a deep copy of the example input array in cases where it is expected that the
        input changes during the verification process.

        Arguments:
            input_array: The input to clone.
        """
        # Fall back to the LightningModule's example input if none was given.
        if input_array is None and isinstance(self.model, LightningModule):
            input_array = self.model.example_input_array
        input_array = deepcopy(input_array)

        # Move the copy to the device the model lives on.
        if isinstance(self.model, LightningModule):
            input_array = self.model.transfer_batch_to_device(input_array, self.model.device)
        else:
            input_array = move_data_to_device(input_array, device=next(self.model.parameters()).device)

        return input_array

    def _model_forward(self, input_array: Any) -> Any:
        """
        Feeds the input array to the model via the ``__call__`` method.

        Arguments:
            input_array: The input that goes into the model. If it is a tuple, it gets
                interpreted as the sequence of positional arguments and is passed in by tuple unpacking.
                If it is a dict, the contents get passed in as named parameters by unpacking the dict.
                Otherwise, the input array gets passed in as a single argument.

        Returns:
            The output of the model.
        """
        if isinstance(input_array, tuple):
            return self.model(*input_array)
        if isinstance(input_array, dict):
            return self.model(**input_array)
        return self.model(input_array)


class VerificationCallbackBase(Callback):
    """
    Base class for model verification in form of a callback.
    This type of verification is expected to only work with
    :class:`~pytorch_lightning.core.lightning.LightningModule` and will take the input array
    from :attr:`~pytorch_lightning.core.lightning.LightningModule.example_input_array` if needed.
    """

    def __init__(self, warn: bool = True, error: bool = False) -> None:
        """
        Arguments:
            warn: If ``True``, prints a warning message when verification fails. Default: ``True``.
            error: If ``True``, raises a ``RuntimeError`` with the error message when verification fails.
                Takes precedence over ``warn``. Default: ``False``.
        """
        self._raise_warning = warn
        self._raise_error = error

    def message(self, *args: Any, **kwargs: Any) -> str:
        """
        The message to be printed when the model does not pass the verification.
        If the message for warning and error differ, override the
        :meth:`warning_message` and :meth:`error_message`
        methods directly.

        Arguments:
            *args: Any positional arguments that are needed to construct the message.
            **kwargs: Any keyword arguments that are needed to construct the message.

        Returns:
            The message as a string.
        """

    def warning_message(self, *args: Any, **kwargs: Any) -> str:
        """ The warning message printed when the model does not pass the verification. """
        return self.message(*args, **kwargs)

    def error_message(self, *args: Any, **kwargs: Any) -> str:
        """ The error message printed when the model does not pass the verification. """
        return self.message(*args, **kwargs)

    def _raise(self, *args: Any, **kwargs: Any) -> None:
        # The error flag is checked first, so it takes precedence over the warning flag.
        if self._raise_error:
            raise RuntimeError(self.error_message(*args, **kwargs))
        if self._raise_warning:
            rank_zero_warn(self.warning_message(*args, **kwargs))
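
How the `warn` and `error` flags interact in `_raise` can be illustrated with a small, Lightning-free sketch. `PolicySketch` and its `report` method are hypothetical stand-ins written only for this illustration; they use the standard library's `warnings.warn` in place of `rank_zero_warn` and are not part of pl_bolts.

```python
import warnings


class PolicySketch:
    """Hypothetical stand-in mirroring the warn/error flags of VerificationCallbackBase."""

    def __init__(self, warn: bool = True, error: bool = False) -> None:
        self._raise_warning = warn
        self._raise_error = error

    def message(self) -> str:
        # A real subclass would describe the failed verification here.
        return "verification failed"

    def report(self) -> None:
        # Same order as `_raise`: the error flag is checked before the warning flag.
        if self._raise_error:
            raise RuntimeError(self.message())
        if self._raise_warning:
            warnings.warn(self.message())


# Default flags: a warning is emitted and execution continues.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    PolicySketch().report()
num_warnings = len(caught)

# error=True: a RuntimeError is raised even though warn is still True.
try:
    PolicySketch(warn=True, error=True).report()
    raised = False
except RuntimeError:
    raised = True
```

With both flags set to `False`, a failed verification is silently ignored, so at least one of them should normally stay enabled.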
