Skip to content

Commit

Permalink
Merge branch 'master' into refactor/minor-models
Browse files Browse the repository at this point in the history
  • Loading branch information
akihironitta committed Jan 18, 2021
2 parents eef3f5c + 4807ef3 commit 87a9722
Show file tree
Hide file tree
Showing 58 changed files with 1,520 additions and 124 deletions.
12 changes: 3 additions & 9 deletions .github/workflows/ci_test-full.yml
Expand Up @@ -16,15 +16,9 @@ jobs:
# max-parallel: 6
matrix:
# PyTorch 1.5 is failing on Win and bolts requires torchvision>=0.5
os: [ubuntu-20.04, macOS-10.15, windows-2019]
python-version: [3.6, 3.7, 3.8]
os: [ubuntu-20.04, macOS-10.15] #, windows-2019
python-version: [3.6, 3.8]
requires: ['minimal', 'latest']
exclude:
# TODO: temporary fix till hanging jobs on macOS for py38 is resolved
- python-version: 3.8
os: macOS-10.15
- python-version: 3.8
os: windows-2019

# Timeout: https://stackoverflow.com/a/59076067/4521646
timeout-minutes: 25
Expand Down Expand Up @@ -84,7 +78,6 @@ jobs:
run: |
# tox --sitepackages
coverage run --source pl_bolts -m py.test pl_bolts tests -v --junitxml=junit/test-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}.xml
coverage xml
- name: Upload pytest test results
uses: actions/upload-artifact@v2
Expand All @@ -98,6 +91,7 @@ jobs:
if: success()
run: |
coverage report
coverage xml
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v1
Expand Down
4 changes: 4 additions & 0 deletions CHANGELOG.md
Expand Up @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
[#348](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/348),
[#323](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/323))
- Added data monitor callbacks `ModuleDataMonitor` and `TrainingDataMonitor` ([#285](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/285))
- Added DCGAN module ([#403](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/403))
- Added `VisionDataModule` as parent class for `BinaryMNISTDataModule`, `CIFAR10DataModule`, `FashionMNISTDataModule`,
and `MNISTDataModule` ([#400](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/400))
- Added GIoU loss ([#347](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/347))
Expand All @@ -23,6 +24,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added metric GIoU ([#347](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/347))
- Added Intersection over Union Metric/Loss ([#469](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/469))
- Added SimSiam model ([#407](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/407))
- Added gradient verification callback ([#465](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/465))
- Added Backbones to FRCNN ([#475](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/475))

### Changed

Expand All @@ -41,6 +44,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Refactored `pl_bolts.callbacks` ([#477](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/477))
- Refactored the rest of `pl_bolts.models.self_supervised` ([#481](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/481),
[#479](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/479)
- Update [`torchvision.utils.make_grid`(https://pytorch.org/docs/stable/torchvision/utils.html#torchvision.utils.make_grid)] kwargs to `TensorboardGenerativeModelImageSampler` ([#494](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/494))

### Fixed

Expand Down
8 changes: 5 additions & 3 deletions README.md
Expand Up @@ -53,12 +53,14 @@

| System / PyTorch ver. | 1.6 (min. req.) | 1.7 (latest) |
| :---: | :---: | :---: |
| Linux py3.{6,7,8} | ![CI full testing](https://github.com/PyTorchLightning/pytorch-lightning-bolts/workflows/CI%20full%20testing/badge.svg?branch=master&event=push) | ![CI full testing](https://github.com/PyTorchLightning/pytorch-lightning-bolts/workflows/CI%20full%20testing/badge.svg?branch=master&event=push) |
| OSX py3.{6,7} | ![CI full testing](https://github.com/PyTorchLightning/pytorch-lightning-bolts/workflows/CI%20full%20testing/badge.svg?branch=master&event=push) | ![CI full testing](https://github.com/PyTorchLightning/pytorch-lightning-bolts/workflows/CI%20full%20testing/badge.svg?branch=master&event=push) |
| Windows py3.{6,7} | ![CI full testing](https://github.com/PyTorchLightning/pytorch-lightning-bolts/workflows/CI%20full%20testing/badge.svg?branch=master&event=push) | ![CI full testing](https://github.com/PyTorchLightning/pytorch-lightning-bolts/workflows/CI%20full%20testing/badge.svg?branch=master&event=push) |
| Linux py3.{6,8} | ![CI full testing](https://github.com/PyTorchLightning/pytorch-lightning-bolts/workflows/CI%20full%20testing/badge.svg?branch=master&event=push) | ![CI full testing](https://github.com/PyTorchLightning/pytorch-lightning-bolts/workflows/CI%20full%20testing/badge.svg?branch=master&event=push) |
| OSX py3.{6,8} | ![CI full testing](https://github.com/PyTorchLightning/pytorch-lightning-bolts/workflows/CI%20full%20testing/badge.svg?branch=master&event=push) | ![CI full testing](https://github.com/PyTorchLightning/pytorch-lightning-bolts/workflows/CI%20full%20testing/badge.svg?branch=master&event=push) |
| Windows py3.7* | ![CI base testing](https://github.com/PyTorchLightning/pytorch-lightning-bolts/workflows/CI%20base%20testing/badge.svg?branch=master&event=push) | ![CI base testing](https://github.com/PyTorchLightning/pytorch-lightning-bolts/workflows/CI%20base%20testing/badge.svg?branch=master&event=push) |

</center>

- _\* testing just the package itself, we skip full test suite - excluding `tests` folder_

## Install

Simple installation from PyPI
Expand Down
Binary file added docs/source/_images/gans/dcgan_lsun_dloss.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/source/_images/gans/dcgan_lsun_gloss.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/source/_images/gans/dcgan_lsun_outputs.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/source/_images/gans/dcgan_mnist_dloss.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/source/_images/gans/dcgan_mnist_gloss.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/source/_images/gans/dcgan_mnist_outputs.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
47 changes: 46 additions & 1 deletion docs/source/gans.rst
Expand Up @@ -40,4 +40,49 @@ Loss curves:
.. autoclass:: pl_bolts.models.gans.GAN
:noindex:
:noindex:

DCGAN
---------
DCGAN implementation from the paper `Unsupervised Representation Learning with Deep Convolutional Generative
Adversarial Networks <https://arxiv.org/pdf/1511.06434.pdf>`_. The implementation is based on the version from
PyTorch's `examples <https://github.com/pytorch/examples/blob/master/dcgan/main.py>`_.

Implemented by:

- `Christoph Clement <https://github.com/chris-clem>`_

Example MNIST outputs:

.. image:: _images/gans/dcgan_mnist_outputs.png
:width: 400
:alt: DCGAN generated MNIST samples

Example LSUN bedroom outputs:

.. image:: _images/gans/dcgan_lsun_outputs.png
:width: 400
:alt: DCGAN generated LSUN bedroom samples

MNIST Loss curves:

.. image:: _images/gans/dcgan_mnist_dloss.png
:width: 200
:alt: DCGAN MNIST disc loss

.. image:: _images/gans/dcgan_mnist_gloss.png
:width: 200
:alt: DCGAN MNIST gen loss

LSUN Loss curves:

.. image:: _images/gans/dcgan_lsun_dloss.png
:width: 200
:alt: DCGAN LSUN disc loss

.. image:: _images/gans/dcgan_lsun_gloss.png
:width: 200
:alt: DCGAN LSUN gen loss

.. autoclass:: pl_bolts.models.gans.DCGAN
:noindex:
60 changes: 60 additions & 0 deletions docs/source/info_callbacks.rst
Expand Up @@ -64,3 +64,63 @@ You can track all or just a selection of submodules:
This is especially useful for debugging the data flow in complex models and to identify
numerical instabilities.


---------------

Model Verification
------------------


Gradient-Check for Batch-Optimization
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Gradient descent over a batch of samples can not only benefit the optimization but also leverages data parallelism.
However, one has to be careful not to mix data across the batch dimension.
Only a small error in a reshape or permutation operation results in the optimization getting stuck and you won't
even get a runtime error. How can one tell if the model mixes data in the batch?
A simple trick is to do the following:

1. run the model on an example batch (can be random data)
2. get the output batch and select the n-th sample (choose n)
3. compute a dummy loss value of only that sample and compute the gradient w.r.t the entire input batch
4. observe that only the i-th sample in the input batch has non-zero gradient

|
If the gradient is non-zero for the other samples in the batch, it means the forward pass of the model is mixing data!
The :class:`~pl_bolts.callbacks.verification.batch_gradient.BatchGradientVerificationCallback`
does all of that for you before training begins.

.. code-block:: python
from pytorch_lightning import Trainer
from pl_bolts.callbacks import BatchGradientVerificationCallback
model = YourLightningModule()
verification = BatchGradientVerificationCallback()
trainer = Trainer(callbacks=[verification])
trainer.fit(model)
This Callback will warn the user with the following message in case data mixing inside the batch is detected:

.. code-block::
Your model is mixing data across the batch dimension.
This can lead to wrong gradient updates in the optimizer.
Check the operations that reshape and permute tensor dimensions in your model.
A non-Callback version
:class:`~pl_bolts.callbacks.verification.batch_gradient.BatchGradientVerification`
that works with any PyTorch :class:`~torch.nn.Module` is also available:

.. code-block:: python
from pl_bolts.utils import BatchGradientVerification
model = YourPyTorchModel()
verification = BatchGradientVerification(model)
valid = verification.check(input_array=torch.rand(2, 3, 4), sample_idx=1)
In this example we run the test on a batch size 2 by inspecting gradients on the second sample.
2 changes: 2 additions & 0 deletions pl_bolts/callbacks/__init__.py
Expand Up @@ -6,10 +6,12 @@
from pl_bolts.callbacks.printing import PrintTableMetricsCallback # noqa: F401
from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator # noqa: F401
from pl_bolts.callbacks.variational import LatentDimInterpolator # noqa: F401
from pl_bolts.callbacks.verification.batch_gradient import BatchGradientVerificationCallback # type: ignore
from pl_bolts.callbacks.vision.confused_logit import ConfusedLogitCallback # noqa: F401
from pl_bolts.callbacks.vision.image_generation import TensorboardGenerativeModelImageSampler # noqa: F401

__all__ = [
"BatchGradientVerificationCallback",
"BYOLMAWeightUpdate",
"ModuleDataMonitor",
"TrainingDataMonitor",
Expand Down
2 changes: 1 addition & 1 deletion pl_bolts/callbacks/data_monitor.py
Expand Up @@ -18,7 +18,7 @@
import wandb
else: # pragma: no cover
warn_missing_pkg("wandb")
wandb = None
wandb = None # type: ignore


class DataMonitorBase(Callback):
Expand Down
Empty file.
123 changes: 123 additions & 0 deletions pl_bolts/callbacks/verification/base.py
@@ -0,0 +1,123 @@
# type: ignore
from abc import abstractmethod
from copy import deepcopy
from typing import Any, Optional

import torch.nn as nn
from pytorch_lightning import Callback
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import move_data_to_device, rank_zero_warn


class VerificationBase:
"""
Base class for model verification.
All verifications should run with any :class:`torch.nn.Module` unless otherwise stated.
"""

def __init__(self, model: nn.Module):
"""
Arguments:
model: The model to run verification for.
"""
super().__init__()
self.model = model

@abstractmethod
def check(self, *args: Any, **kwargs: Any) -> bool:
""" Runs the actual test on the model. All verification classes must implement this.
Arguments:
*args: Any positional arguments that are needed to run the test
*kwargs: Keyword arguments that are needed to run the test
Returns:
`True` if the test passes, and `False` otherwise. Some verifications can only be performed
with a heuristic accuracy, thus the return value may not always reflect the true state of
the system in these cases.
"""

def _get_input_array_copy(self, input_array: Optional[Any] = None) -> Any:
"""
Returns a deep copy of the example input array in cases where it is expected that the
input changes during the verification process.
Arguments:
input_array: The input to clone.
"""
if input_array is None and isinstance(self.model, LightningModule):
input_array = self.model.example_input_array
input_array = deepcopy(input_array)

if isinstance(self.model, LightningModule):
input_array = self.model.transfer_batch_to_device(input_array, self.model.device)
else:
input_array = move_data_to_device(input_array, device=next(self.model.parameters()).device)

return input_array

def _model_forward(self, input_array: Any) -> Any:
"""
Feeds the input array to the model via the ``__call__`` method.
Arguments:
input_array: The input that goes into the model. If it is a tuple, it gets
interpreted as the sequence of positional arguments and is passed in by tuple unpacking.
If it is a dict, the contents get passed in as named parameters by unpacking the dict.
Otherwise, the input array gets passed in as a single argument.
Returns:
The output of the model.
"""
if isinstance(input_array, tuple):
return self.model(*input_array)
if isinstance(input_array, dict):
return self.model(**input_array)
return self.model(input_array)


class VerificationCallbackBase(Callback):
"""
Base class for model verification in form of a callback.
This type of verification is expected to only work with
:class:`~pytorch_lightning.core.lightning.LightningModule` and will take the input array
from :attr:`~pytorch_lightning.core.lightning.LightningModule.example_input_array` if needed.
"""

def __init__(self, warn: bool = True, error: bool = False) -> None:
"""
Arguments:
warn: If ``True``, prints a warning message when verification fails. Default: ``True``.
error: If ``True``, prints an error message when verification fails. Default: ``False``.
"""
self._raise_warning = warn
self._raise_error = error

def message(self, *args: Any, **kwargs: Any) -> str:
"""
The message to be printed when the model does not pass the verification.
If the message for warning and error differ, override the
:meth:`warning_message` and :meth:`error_message`
methods directly.
Arguments:
*args: Any positional arguments that are needed to construct the message.
**kwargs: Any keyword arguments that are needed to construct the message.
Returns:
The message as a string.
"""

def warning_message(self, *args: Any, **kwargs: Any) -> str:
""" The warning message printed when the model does not pass the verification. """
return self.message(*args, **kwargs)

def error_message(self, *args: Any, **kwargs: Any) -> str:
""" The error message printed when the model does not pass the verification. """
return self.message(*args, **kwargs)

def _raise(self, *args: Any, **kwargs: Any) -> None:
if self._raise_error:
raise RuntimeError(self.error_message(*args, **kwargs))
if self._raise_warning:
rank_zero_warn(self.warning_message(*args, **kwargs))

0 comments on commit 87a9722

Please sign in to comment.