metric docs
Co-authored-by: Teddy Koker <teddy.koker@gmail.com>
ananyahjha93 and teddykoker committed Oct 6, 2020
1 parent 2e7187c commit 2a09ea2
Showing 2 changed files with 37 additions and 16 deletions.
30 changes: 20 additions & 10 deletions docs/source/metrics.rst
@@ -10,16 +10,25 @@
Metrics
=======

``pytorch_lightning.metrics`` is a Metrics API created for easy metric development and usage in
PyTorch and PyTorch Lightning. It is rigorously tested for all edge cases and includes a growing list of
common metric implementations.

The metrics API provides ``update()``, ``compute()``, and ``reset()`` functions to the user. The metric base class inherits
from ``nn.Module``, which allows us to call ``metric(...)`` directly. The ``forward()`` method of the base ``Metric`` class
serves the dual purpose of calling ``update()`` on its input and simultaneously returning the value of the metric over the
provided input.
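
For illustration, a minimal sketch of these calls (assuming the built-in ``Accuracy`` metric from
``pytorch_lightning.metrics``) might look like this:

.. code-block:: python

    import torch
    from pytorch_lightning.metrics import Accuracy

    accuracy = Accuracy()
    preds = torch.tensor([0, 1, 1, 0])
    target = torch.tensor([0, 1, 0, 0])

    # forward() updates the internal state and returns the metric value for this input
    batch_acc = accuracy(preds, target)

    # update() only accumulates state; it does not return a value
    accuracy.update(preds, target)

    # compute() returns the metric over all inputs seen since the last reset()
    total_acc = accuracy.compute()
    accuracy.reset()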

These metrics work with DDP in PyTorch and PyTorch Lightning by default. When ``.compute()`` is called in
distributed mode, the internal state of each metric is synced and reduced across processes, so that the
logic present in ``.compute()`` is applied to state information from all processes.

The example below shows how to use a metric in your ``LightningModule``:

.. note::

For v0.10.0, the user is expected to call ``.compute()`` on the metric at the end of each epoch,
as shown in the example below. For the v1.0 release, we will integrate metrics
with logging and ``.compute()`` will be called automatically by PyTorch Lightning.

.. code-block:: python
@@ -40,7 +40,7 @@ These metrics work with DDP in PyTorch and PyTorch Lightning by default.
self.log('train_acc_epoch', self.accuracy.compute())
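
Most of this example is collapsed in the diff view above; a minimal sketch of the full ``LightningModule``
usage it refers to (assuming the built-in ``Accuracy`` metric and the standard Lightning hooks) might look like:

.. code-block:: python

    import torch
    from torch import nn
    from torch.nn import functional as F
    import pytorch_lightning as pl


    class LitClassifier(pl.LightningModule):

        def __init__(self):
            super().__init__()
            self.layer = nn.Linear(28 * 28, 10)
            self.accuracy = pl.metrics.Accuracy()

        def forward(self, x):
            return self.layer(x.view(x.size(0), -1))

        def training_step(self, batch, batch_idx):
            x, y = batch
            logits = self(x)
            loss = F.cross_entropy(logits, y)
            preds = torch.argmax(logits, dim=1)
            # calling the metric updates its state and returns the value for this batch
            self.log('train_acc_step', self.accuracy(preds, y))
            return loss

        def training_epoch_end(self, outputs):
            # for v0.10.0, compute() is called manually at the end of the epoch
            self.log('train_acc_epoch', self.accuracy.compute())

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-3)
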
This metrics API is independent of PyTorch Lightning. Metrics can be used directly in plain PyTorch, as shown in the example below:

.. code-block:: python
@@ -69,16 +78,17 @@ This metrics API is independent of PyTorch Lightning. If you please, they can be
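
The body of this example is collapsed in the diff view above; a minimal sketch of standalone usage
(assuming the built-in ``Accuracy`` metric) might look like:

.. code-block:: python

    import torch
    from pytorch_lightning.metrics import Accuracy

    accuracy = Accuracy()

    for _ in range(10):
        # simulate a batch of predictions and targets
        preds = torch.randint(0, 2, (16,))
        target = torch.randint(0, 2, (16,))
        # update() accumulates state without returning a value
        accuracy.update(preds, target)

    # compute() reduces the accumulated state to the final metric value
    print(accuracy.compute())
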
Implementing a Metric
---------------------

To implement your custom metric, subclass the base ``Metric`` class and implement the following methods:

- ``__init__()``: Each state variable should be registered by calling ``self.add_state(...)``.
- ``update()``: Any code needed to update the state given any inputs to the metric.
- ``compute()``: Computes a final value from the state of the metric.

All you need to do for a custom metric to work with DDP is to call ``add_state()`` correctly.
``reset()`` is handled automatically for metric state variables added using ``add_state()``.

To see how metric states are synchronized across distributed processes, refer to the ``add_state()`` docs
of the base ``Metric`` class.

Example implementation:

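The example implementation itself is collapsed in this diff view; a minimal sketch of a custom metric
under the API described above (a hypothetical ``MySum`` metric, not the example from the original docs)
might look like:

.. code-block:: python

    import torch
    from pytorch_lightning.metrics import Metric


    class MySum(Metric):

        def __init__(self):
            super().__init__()
            # state is registered with add_state() so that reset() and DDP syncing work
            self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

        def update(self, values: torch.Tensor):
            # accumulate into the registered state
            self.total += torch.sum(values)

        def compute(self):
            # reduce the accumulated state to the final metric value
            return self.total
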
23 changes: 17 additions & 6 deletions pytorch_lightning/metrics/metric.py
@@ -31,7 +31,7 @@ class Metric(nn.Module, ABC):
Note:
Different metrics only override ``update()`` and not ``forward()``. A call to ``update()``
is valid, but it won't return the metric value at the current step. A call to ``forward()``
automatically calls ``update()`` and also returns the metric value at the current step.
Args:
compute_on_step:
@@ -71,16 +71,27 @@ def add_state(self, name: str, default, dist_reduce_fx: Optional[Union[str, Call
name: The name of the state variable. The variable will then be accessible at ``self.name``.
default: Default value of the state; can either be a ``torch.Tensor`` or an empty list. The state will be
reset to this value when ``self.reset()`` is called.
dist_reduce_fx (Optional): Function to reduce state across multiple processes in distributed mode. If value is ``"sum"``,
``"mean"``, or ``"cat"``, we will use ``torch.sum``, ``torch.mean``, and ``torch.cat`` respectively,
each with argument ``dim=0``. The user can also pass a custom function in this parameter.
Note:
Setting ``dist_reduce_fx`` to ``None`` will return the metric state synchronized across processes, but no
reduction function will be applied to the synchronized state. The metric states are synced as follows:
- If the metric state is a ``torch.Tensor``, the synced value will be a ``torch.Tensor`` stacked across
the process dimension. The original ``torch.Tensor`` metric state retains its dimensions, so the
synchronized output will be of shape ``(num_process, ...)``.
- If the metric state is a ``list``, the synced value will be a ``list`` containing the combined
elements from all processes.
Note:
When passing a custom function to ``dist_reduce_fx``, expect the synchronized metric state to follow
the format described in the note above.
"""
if not isinstance(default, torch.Tensor) or (isinstance(default, list) and len(default) != 0):
raise ValueError(
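
To illustrate the ``dist_reduce_fx`` behaviour described in the docstring above, a hypothetical metric that
registers its state with a custom reduction function (a sketch, not part of this commit) might look like:

.. code-block:: python

    import torch
    from pytorch_lightning.metrics import Metric


    def take_max(stacked: torch.Tensor) -> torch.Tensor:
        # the custom reduction receives the state stacked across processes,
        # i.e. a tensor of shape (num_process, ...), and reduces over that dimension
        return stacked.max(dim=0).values


    class MaxValue(Metric):

        def __init__(self):
            super().__init__()
            self.add_state("max_value", default=torch.tensor(0.0), dist_reduce_fx=take_max)

        def update(self, values: torch.Tensor):
            # keep the running maximum of everything seen so far
            self.max_value = torch.max(self.max_value, torch.max(values))

        def compute(self):
            return self.max_value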
