Skip to content

Commit

Permalink
Metrics docs (#2184)
Browse files Browse the repository at this point in the history
* fixes

* fixes

* fixes

* fixes

* fixes

* fixes

* fixes

* fixes

* fixes

* fixes

* fixes

* fixes

* fixes

* fixes

* fixes

* fixes

* fixes

* fixes

* fixes

* fixes

* Apply suggestions from code review

Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>

* add workers fix

* add workers fix

* add workers fix

* add workers fix

* add workers fix

* add workers fix

* add workers fix

* add workers fix

* add workers fix

* add workers fix

* Update docs/source/metrics.rst

Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>

* Update docs/source/metrics.rst

Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>

* Update docs/source/metrics.rst

Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>

* Update docs/source/metrics.rst

Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>

* add workers fix

* add workers fix

* add workers fix

* doctests

* add workers fix

* add workers fix

* fixes

* fix docs

* fixes

* fixes

* fixes

* fixes

* fixes

* fixes

* fixes

* fixes

* fixes

* fixes

* fixes

* fixes

* fixes

* fixes

* fixes

* fixes

* fixes

* Apply suggestions from code review

Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>

* add workers fix

* Update docs/source/metrics.rst

Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>

* doctests

* add workers fix

* fix docs

* fixes

* fixes

* fix doctests

* Apply suggestions from code review

* fix doctests

* fix examples

* bug

* Update docs/source/metrics.rst

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Update docs/source/metrics.rst

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Update docs/source/metrics.rst

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* fixes

* fixes

* fixes

* fixes

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
Co-authored-by: Jirka <jirka@pytorchlightning.ai>
Co-authored-by: Nicki Skafte <nugginea@gmail.com>
  • Loading branch information
5 people committed Jun 16, 2020
1 parent e289e45 commit 55fbcc0
Show file tree
Hide file tree
Showing 7 changed files with 696 additions and 44 deletions.
322 changes: 318 additions & 4 deletions docs/source/metrics.rst
Original file line number Diff line number Diff line change
@@ -1,4 +1,318 @@
.. automodule:: pytorch_lightning.metrics
:members:
:noindex:
:exclude-members:
.. testsetup:: *

from torch.nn import Module
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.metrics import TensorMetric, NumpyMetric

Metrics
=======
This is a general package for PyTorch Metrics. These can also be used with regular non-lightning PyTorch code.
Metrics are used to monitor model performance.

In this package we provide two major pieces of functionality.

1. A Metric class you can use to implement metrics with built-in distributed (ddp) support which are device agnostic.
2. A collection of popular metrics already implemented for you.

Example::

from pytorch_lightning.metrics.functional import accuracy

pred = torch.tensor([0, 1, 2, 3])
target = torch.tensor([0, 1, 2, 2])

# calculates accuracy across all GPUs and all Nodes used in training
accuracy(pred, target)

Out::

tensor(0.7500)

--------------

Implement a metric
------------------
You can implement metrics as either a PyTorch metric or a Numpy metric. Numpy metrics
will slow down training, use PyTorch metrics when possible.

Use :class:`TensorMetric` to implement native PyTorch metrics. This class
handles automated DDP syncing and converts all inputs and outputs to tensors.

Use :class:`NumpyMetric` to implement numpy metrics. This class
handles automated DDP syncing and converts all inputs and outputs to tensors.

.. warning::
Numpy metrics might slow down your training substantially,
since every metric computation requires a GPU sync to convert tensors to numpy.

TensorMetric
^^^^^^^^^^^^
Here's an example showing how to implement a TensorMetric

.. testcode::

class RMSE(TensorMetric):
def forward(self, x, y):
return torch.sqrt(torch.mean(torch.pow(x-y, 2.0)))

.. autoclass:: pytorch_lightning.metrics.metric.TensorMetric
:noindex:

NumpyMetric
^^^^^^^^^^^
Here's an example showing how to implement a NumpyMetric

.. testcode::

class RMSE(NumpyMetric):
def forward(self, x, y):
return np.sqrt(np.mean(np.power(x-y, 2.0)))


.. autoclass:: pytorch_lightning.metrics.metric.NumpyMetric
:noindex:

--------------

Class Metrics
-------------
The following are metrics which can be instantiated as part of a module definition (even with just
plain PyTorch).

.. testcode::

from pytorch_lightning.metrics import Accuracy

# Plain PyTorch
class MyModule(Module):
def __init__(self):
super().__init__()
self.metric = Accuracy()

def forward(self, x, y):
y_hat = ...
acc = self.metric(y_hat, y)

# PyTorch Lightning
class MyModule(LightningModule):
def __init__(self):
super().__init__()
self.metric = Accuracy()

def training_step(self, batch, batch_idx):
x, y = batch
y_hat = ...
acc = self.metric(y_hat, y)

These metrics even work when using distributed training:

.. code-block:: python
model = MyModule()
trainer = Trainer(gpus=8, num_nodes=2)
# any metric automatically reduces across GPUs (even the ones you implement using Lightning)
trainer.fit(model)
Accuracy
^^^^^^^^

.. autoclass:: pytorch_lightning.metrics.classification.Accuracy
:noindex:

AveragePrecision
^^^^^^^^^^^^^^^^

.. autoclass:: pytorch_lightning.metrics.classification.AveragePrecision
:noindex:

AUROC
^^^^^

.. autoclass:: pytorch_lightning.metrics.classification.AUROC
:noindex:

ConfusionMatrix
^^^^^^^^^^^^^^^

.. autoclass:: pytorch_lightning.metrics.classification.ConfusionMatrix
:noindex:

DiceCoefficient
^^^^^^^^^^^^^^^

.. autoclass:: pytorch_lightning.metrics.classification.DiceCoefficient
:noindex:

F1
^^

.. autoclass:: pytorch_lightning.metrics.classification.F1
:noindex:

FBeta
^^^^^

.. autoclass:: pytorch_lightning.metrics.classification.FBeta
:noindex:

PrecisionRecall
^^^^^^^^^^^^^^^

.. autoclass:: pytorch_lightning.metrics.classification.PrecisionRecall
:noindex:

Precision
^^^^^^^^^

.. autoclass:: pytorch_lightning.metrics.classification.Precision
:noindex:

Recall
^^^^^^

.. autoclass:: pytorch_lightning.metrics.classification.Recall
:noindex:

ROC
^^^

.. autoclass:: pytorch_lightning.metrics.classification.ROC
:noindex:

MulticlassROC
^^^^^^^^^^^^^

.. autoclass:: pytorch_lightning.metrics.classification.MulticlassROC
:noindex:

MulticlassPrecisionRecall
^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: pytorch_lightning.metrics.classification.MulticlassPrecisionRecall
:noindex:

--------------

Functional Metrics
------------------

accuracy (F)
^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.accuracy
:noindex:

auc (F)
^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.auc
:noindex:

auroc (F)
^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.auroc
:noindex:

average_precision (F)
^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.average_precision
:noindex:

confusion_matrix (F)
^^^^^^^^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.confusion_matrix
:noindex:

dice_score (F)
^^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.dice_score
:noindex:

f1_score (F)
^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.f1_score
:noindex:

fbeta_score (F)
^^^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.fbeta_score
:noindex:

multiclass_precision_recall_curve (F)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.multiclass_precision_recall_curve
:noindex:

multiclass_roc (F)
^^^^^^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.multiclass_roc
:noindex:

precision (F)
^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.precision
:noindex:

precision_recall (F)
^^^^^^^^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.precision_recall
:noindex:

precision_recall_curve (F)
^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.precision_recall_curve
:noindex:

recall (F)
^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.recall
:noindex:

roc (F)
^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.roc
:noindex:

stat_scores (F)
^^^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.stat_scores
:noindex:

stat_scores_multiple_classes (F)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.stat_scores_multiple_classes
:noindex:

----------------

Metric pre-processing
---------------------
Metric

to_categorical (F)
^^^^^^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.to_categorical
:noindex:

to_onehot (F)
^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.to_onehot
:noindex:
39 changes: 12 additions & 27 deletions pytorch_lightning/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,15 @@
"""
Metrics
=======
Metrics are generally used to monitor model performance.
The following package aims to provide the most convenient ones as well
as a structure to implement your custom metrics for all the fancy research
you want to do.
For native PyTorch implementations of metrics, it is recommended to use
the :class:`TensorMetric` which handles automated DDP syncing and conversions
to tensors for all inputs and outputs.
If your metrics implementation works on numpy, just use the
:class:`NumpyMetric`, which handles the automated conversion of
inputs to and outputs from numpy as well as automated ddp syncing.
.. warning:: Employing numpy in your metric calculation might slow
down your training substantially, since every metric computation
requires a GPU sync to convert tensors to numpy.
"""

from pytorch_lightning.metrics.converters import numpy_metric, tensor_metric
from pytorch_lightning.metrics.metric import Metric, TensorMetric, NumpyMetric
from pytorch_lightning.metrics.sklearn import (
SklearnMetric, Accuracy, AveragePrecision, AUC, ConfusionMatrix, F1, FBeta,
Precision, Recall, PrecisionRecallCurve, ROC, AUROC)
SklearnMetric,
Accuracy,
AveragePrecision,
AUC,
ConfusionMatrix,
F1,
FBeta,
Precision,
Recall,
PrecisionRecallCurve,
ROC,
AUROC)

0 comments on commit 55fbcc0

Please sign in to comment.