gradient verification callback #465

Merged (40 commits) on Jan 18, 2021
Changes from 23 commits

Commits
b0e76b8
initial commit
awaelchli Dec 20, 2020
44f8ade
docs cleanup
awaelchli Dec 20, 2020
52a5134
isort
awaelchli Dec 20, 2020
64c389b
black
awaelchli Dec 20, 2020
8b8548e
top level imports
awaelchli Dec 20, 2020
1b77ce5
rst docs
awaelchli Dec 20, 2020
21cb8a1
update chlog
awaelchli Dec 20, 2020
b97a65a
isort again
awaelchli Dec 20, 2020
9f2adcd
format
Borda Dec 20, 2020
15a3240
Merge branch 'master' into feature/gradient-verification
awaelchli Jan 9, 2021
c7b610c
Apply suggestions from code review
awaelchli Jan 9, 2021
028be19
fix import
awaelchli Jan 9, 2021
188bf09
Merge remote-tracking branch 'origin/feature/gradient-verification' i…
awaelchli Jan 9, 2021
e0dc1fb
increase coverage
awaelchli Jan 9, 2021
1b54855
don't skip tests that can partially run on cpu
awaelchli Jan 9, 2021
04eb2d7
black format
awaelchli Jan 9, 2021
6fa47a0
make bots happy
awaelchli Jan 9, 2021
79d36dc
cleanup
awaelchli Jan 9, 2021
1f2d250
more tests for full coverage
awaelchli Jan 9, 2021
4365e9a
isort, black
awaelchli Jan 9, 2021
7cb9f43
mypy complaining
awaelchli Jan 9, 2021
b8b2af2
remove unused import
awaelchli Jan 9, 2021
4ff9630
stop complain
awaelchli Jan 9, 2021
beba0b9
try type ignore
awaelchli Jan 14, 2021
af7db67
try ignore
awaelchli Jan 14, 2021
15e8991
try ignore
awaelchli Jan 14, 2021
6f52175
try ignore
awaelchli Jan 14, 2021
b8ff39e
try ignore
awaelchli Jan 14, 2021
590b722
stupid mypy
awaelchli Jan 14, 2021
df8c746
stupid mypy
awaelchli Jan 14, 2021
f4d9866
stupid mypy
awaelchli Jan 16, 2021
791f4fa
stupid mypi
awaelchli Jan 16, 2021
b6402e7
stupid mypy
awaelchli Jan 16, 2021
2638e3e
ugly yapf
awaelchli Jan 18, 2021
4a1defc
Merge branch 'master' into feature/gradient-verification
awaelchli Jan 18, 2021
3101e06
yapf :(
awaelchli Jan 18, 2021
84d98bb
yapffffffff
awaelchli Jan 18, 2021
224de45
chlog
Borda Jan 18, 2021
8439bd9
Apply suggestions from code review
Borda Jan 18, 2021
ae63dad
yapf
Borda Jan 18, 2021
60 changes: 60 additions & 0 deletions docs/source/info_callbacks.rst
@@ -64,3 +64,63 @@ You can track all or just a selection of submodules:

This is especially useful for debugging the data flow in complex models and for identifying
numerical instabilities.


---------------

Model Verification
------------------


Gradient-Check for Batch-Optimization
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Gradient descent over a batch of samples not only benefits the optimization but also leverages data parallelism.
However, one has to be careful not to mix data across the batch dimension.
Even a small error in a reshape or permutation operation can leave the optimization stuck, and you won't
even get a runtime error. How can one tell whether the model mixes data within the batch?
A simple trick is to do the following:

1. run the model on an example batch (can be random data)
2. get the output batch and select the n-th sample (choose n)
3. compute a dummy loss value of only that sample and compute the gradient w.r.t. the entire input batch
4. observe that only the n-th sample in the input batch has a non-zero gradient

|

If the gradient is non-zero for the other samples in the batch, it means the forward pass of the model is mixing data!
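
For illustration, here is a minimal sketch of this manual check in plain PyTorch, using a stand-in
``torch.nn.Linear`` in place of a real model:

.. code-block:: python

    import torch

    model = torch.nn.Linear(10, 10)  # stand-in for your model
    batch = torch.rand(4, 10, requires_grad=True)  # example batch of 4 samples
    output = model(batch)

    n = 1  # pick the n-th sample
    loss = output[n].sum()  # dummy loss that depends only on the n-th output sample
    grad, = torch.autograd.grad(loss, batch)

    # every sample except the n-th one should have an all-zero gradient
    mixing_detected = bool(grad[torch.arange(4) != n].abs().sum() > 0)
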
The :class:`~pl_bolts.callbacks.verification.batch_gradient.BatchGradientVerificationCallback`
does all of that for you before training begins.

.. code-block:: python

    from pytorch_lightning import Trainer
    from pl_bolts.callbacks import BatchGradientVerificationCallback

    model = YourLightningModule()
    verification = BatchGradientVerificationCallback()
    trainer = Trainer(callbacks=[verification])
    trainer.fit(model)

This callback will warn the user with the following message if data mixing inside the batch is detected:

.. code-block::

    Your model is mixing data across the batch dimension.
    This can lead to wrong gradient updates in the optimizer.
    Check the operations that reshape and permute tensor dimensions in your model.


A non-Callback version
:class:`~pl_bolts.callbacks.verification.batch_gradient.BatchGradientVerification`
that works with any PyTorch :class:`~torch.nn.Module` is also available:

.. code-block:: python

    import torch

    from pl_bolts.callbacks.verification.batch_gradient import BatchGradientVerification

    model = YourPyTorchModel()
    verification = BatchGradientVerification(model)
    valid = verification.check(input_array=torch.rand(2, 3, 4), sample_idx=1)

In this example, we run the test on a batch of size 2 by inspecting the gradients of the second sample.
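
As a rough illustration of a failing check, consider a hypothetical model whose forward pass contains a
permute-and-reshape bug that interleaves the samples in the batch; the check should then come back negative:

.. code-block:: python

    import torch
    import torch.nn as nn

    from pl_bolts.callbacks.verification import BatchGradientVerification


    class FaultyModel(nn.Module):
        """ Hypothetical model whose reshape mixes the samples in the batch. """

        def __init__(self):
            super().__init__()
            self.layer = nn.Linear(12, 12)

        def forward(self, x):
            # bug: permuting before the reshape interleaves elements of different samples
            flat = x.permute(1, 0, 2).reshape(x.size(0), -1)
            return self.layer(flat).view(x.size(0), 3, 4)


    model = FaultyModel()
    verification = BatchGradientVerification(model)
    valid = verification.check(input_array=torch.rand(2, 3, 4), sample_idx=1)
    # valid is expected to be False: sample 1 of the output also depends on sample 0 of the input
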
1 change: 1 addition & 0 deletions pl_bolts/callbacks/__init__.py
@@ -6,6 +6,7 @@
from pl_bolts.callbacks.printing import PrintTableMetricsCallback # noqa: F401
from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator # noqa: F401
from pl_bolts.callbacks.variational import LatentDimInterpolator # noqa: F401
from pl_bolts.callbacks.verification import BatchGradientVerificationCallback # noqa: F401
from pl_bolts.callbacks.vision.confused_logit import ConfusedLogitCallback # noqa: F401
from pl_bolts.callbacks.vision.image_generation import TensorboardGenerativeModelImageSampler # noqa: F401

2 changes: 2 additions & 0 deletions pl_bolts/callbacks/verification/__init__.py
@@ -0,0 +1,2 @@
from pl_bolts.callbacks.verification.batch_gradient import BatchGradientVerification # noqa: F401
from pl_bolts.callbacks.verification.batch_gradient import BatchGradientVerificationCallback # noqa: F401
126 changes: 126 additions & 0 deletions pl_bolts/callbacks/verification/base.py
@@ -0,0 +1,126 @@
from abc import abstractmethod
from copy import deepcopy
from typing import Any, Optional

import torch.nn as nn
from pytorch_lightning import Callback
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import move_data_to_device, rank_zero_warn


class VerificationBase:
"""
Base class for model verification.
All verifications should run with any :class:`torch.nn.Module` unless otherwise stated.
"""

def __init__(self, model: nn.Module):
"""
Arguments:
model: The model to run verification for.
"""
super().__init__()
self.model = model

@abstractmethod
def check(self, *args: Any, **kwargs: Any) -> bool:
Review thread on this line:

Contributor (author): @akihironitta this mypy tool is not very smart :( It does not like that the subclass has
different args than here. I want this abstract method to be as general as possible and not specify concrete
arguments and types. It should only act as an interface. Any suggestions on how to proceed? I believe I have to
add # type: ignore

Contributor: I found a related issue in the mypy repo, and it basically says that an easy workaround would be to
add # type: ignore[override], so shall we just ignore it then? python/mypy#1237 (comment)

Contributor (author): No chance, it won't work. I tried to put it everywhere: at the top of the method, on the
same line as the signature, below it, on top of the class, on top of the file, both in the subclass and the
superclass.

Contributor (author): I had to spam # type: ignore everywhere to make it work. This mypy tool, I don't understand
it. I have spent hours studying the docs of this tool trying to figure out what the error messages mean. I tried
everything, but the type: ignore comments are unavoidable, yet they pollute the code unnecessarily. It's
unbelievably frustrating, and I can no longer work on this, sorry.

Contributor: I completely understand your frustration. Let's ignore them all for now.

""" Runs the actual test on the model. All verification classes must implement this.

Arguments:
*args: Any positional arguments that are needed to run the test
*kwargs: Keyword arguments that are needed to run the test

Returns:
`True` if the test passes, and `False` otherwise. Some verifications can only be performed
with a heuristic accuracy, thus the return value may not always reflect the true state of
the system in these cases.
"""

    def _get_input_array_copy(self, input_array: Optional[Any] = None) -> Any:
        """
        Returns a deep copy of the example input array in cases where it is expected that the
        input changes during the verification process.

        Arguments:
            input_array: The input to clone.
        """
        if input_array is None and isinstance(self.model, LightningModule):
            input_array = self.model.example_input_array
        input_array = deepcopy(input_array)

        if isinstance(self.model, LightningModule):
            input_array = self.model.transfer_batch_to_device(
                input_array, self.model.device
            )
        else:
            input_array = move_data_to_device(
                input_array, device=next(self.model.parameters()).device
            )

        return input_array

    def _model_forward(self, input_array: Any) -> Any:
        """
        Feeds the input array to the model via the ``__call__`` method.

        Arguments:
            input_array: The input that goes into the model. If it is a tuple, it gets
                interpreted as the sequence of positional arguments and is passed in by tuple unpacking.
                If it is a dict, the contents get passed in as named parameters by unpacking the dict.
                Otherwise, the input array gets passed in as a single argument.

        Returns:
            The output of the model.
        """
        if isinstance(input_array, tuple):
            return self.model(*input_array)
        if isinstance(input_array, dict):
            return self.model(**input_array)
        return self.model(input_array)


class VerificationCallbackBase(Callback):
"""
Base class for model verification in form of a callback.
This type of verification is expected to only work with
:class:`~pytorch_lightning.core.lightning.LightningModule` and will take the input array
from :attr:`~pytorch_lightning.core.lightning.LightningModule.example_input_array` if needed.
"""

def __init__(self, warn: bool = True, error: bool = False) -> None:
"""
Arguments:
warn: If ``True``, prints a warning message when verification fails. Default: ``True``.
error: If ``True``, prints an error message when verification fails. Default: ``False``.
"""
self._raise_warning = warn
self._raise_error = error

def message(self, *args: Any, **kwargs: Any) -> str:
"""
The message to be printed when the model does not pass the verification.
If the message for warning and error differ, override the
:meth:`warning_message` and :meth:`error_message`
methods directly.

Arguments:
*args: Any positional arguments that are needed to construct the message.
**kwargs: Any keyword arguments that are needed to construct the message.

Returns:
The message as a string.
"""

def warning_message(self, *args: Any, **kwargs: Any) -> str:
""" The warning message printed when the model does not pass the verification. """
return self.message(*args, **kwargs)

def error_message(self, *args: Any, **kwargs: Any) -> str:
""" The error message printed when the model does not pass the verification. """
return self.message(*args, **kwargs)

def _raise(self, *args: Any, **kwargs: Any) -> None:
if self._raise_error:
raise RuntimeError(self.error_message(*args, **kwargs))
if self._raise_warning:
rank_zero_warn(self.warning_message(*args, **kwargs))
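
For context, a concrete verification built on these base classes might look roughly like the following
sketch. It is hypothetical (not part of this PR) and assumes the model's forward pass returns a single tensor:

from typing import Any

import torch

from pl_bolts.callbacks.verification.base import VerificationBase, VerificationCallbackBase


class NaNVerification(VerificationBase):
    """ Checks that a forward pass on the (example) input does not produce NaN values. """

    def check(self, input_array: Any = None) -> bool:
        # reuse the helpers from VerificationBase to copy the input and run the forward pass
        input_array = self._get_input_array_copy(input_array)
        output = self._model_forward(input_array)
        return not torch.isnan(output).any().item()


class NaNVerificationCallback(VerificationCallbackBase):

    def message(self) -> str:
        return "The model produced NaN values in the forward pass."

    def on_train_start(self, trainer, pl_module):
        verification = NaNVerification(pl_module)
        if not verification.check():
            # warns or raises depending on how the callback was configured
            self._raise()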