Skip to content

Commit

Permalink
pep8, changelog
Browse files Browse the repository at this point in the history
Co-authored-by: Teddy Koker <teddy.koker@gmail.com>
  • Loading branch information
ananyahjha93 and teddykoker committed Oct 6, 2020
1 parent 2a09ea2 commit 53cd2af
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 13 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added new Metrics API. ([#3868](https://github.com/PyTorchLightning/pytorch-lightning/pull/3868))

- Enable PyTorch 1.7 compatibility ([#3541](https://github.com/PyTorchLightning/pytorch-lightning/pull/3541))

- Added hooks to metric module interface ([#2528](https://github.com/PyTorchLightning/pytorch-lightning/pull/2528))
Expand Down Expand Up @@ -63,6 +65,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Removed

- Remove old Metrics API. ([#3868](https://github.com/PyTorchLightning/pytorch-lightning/pull/3868))

### Fixed

Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/metrics/classification/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class Accuracy(Metric):
>>> preds = torch.tensor([0, 2, 1, 3])
>>> accuracy = Accuracy()
>>> accuracy(preds, target)
tensor(0.5)
tensor(0.5000)
"""

Expand Down
22 changes: 13 additions & 9 deletions pytorch_lightning/metrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.metrics.utils import _flatten, gather_all_tensors_if_available
from pytorch_lightning.metrics.utils import dim_zero_cat, dim_zero_mean, dim_zero_sum


class Metric(nn.Module, ABC):
Expand Down Expand Up @@ -71,9 +72,10 @@ def add_state(self, name: str, default, dist_reduce_fx: Optional[Union[str, Call
name: The name of the state variable. The variable will then be accessible at ``self.name``.
default: Default value of the state; can either be a ``torch.Tensor`` or an empty list. The state will be
reset to this value when ``self.reset()`` is called.
dist_reduce_fx (Optional): Function to reduce state accross mutliple processes in distributed mode. If value is ``"sum"``,
``"mean"``, or ``"cat"``, we will use ``torch.sum``, ``torch.mean``, and ``torch.cat`` respectively,
each with argument ``dim=0``. The user can also pass a custom function in this parameter.
dist_reduce_fx (Optional): Function to reduce state accross mutliple processes in distributed mode.
If value is ``"sum"``, ``"mean"``, or ``"cat"``, we will use ``torch.sum``, ``torch.mean``,
and ``torch.cat`` respectively, each with argument ``dim=0``. The user can also pass a custom
function in this parameter.
Note:
Setting ``dist_reduce_fx`` to None will return the metric state synchronized across different processes.
Expand All @@ -99,11 +101,11 @@ def add_state(self, name: str, default, dist_reduce_fx: Optional[Union[str, Call
)

if dist_reduce_fx == "sum":
dist_reduce_fx = lambda x: torch.sum(x, dim=0)
dist_reduce_fx = dim_zero_sum
elif dist_reduce_fx == "mean":
dist_reduce_fx = lambda x: torch.mean(x, dim=0)
dist_reduce_fx = dim_zero_mean
elif dist_reduce_fx == "cat":
dist_reduce_fx = lambda x: torch.cat(x, dim=0)
dist_reduce_fx = dim_zero_cat
elif dist_reduce_fx is not None and not isinstance(dist_reduce_fx, Callable):
raise ValueError(
"`dist_reduce_fx` must be callable or one of ['mean', 'sum', 'cat', None]"
Expand Down Expand Up @@ -177,9 +179,11 @@ def wrapped_func(*args, **kwargs):
if self._computed is not None:
return self._computed

if self._to_sync \
and torch.distributed.is_available() \
and torch.distributed.is_initialized():
if (
self._to_sync
and torch.distributed.is_available()
and torch.distributed.is_initialized()
):
self._sync_dist()

self._computed = compute(*args, **kwargs)
Expand Down
1 change: 0 additions & 1 deletion pytorch_lightning/metrics/regression/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from pytorch_lightning.metrics.regression.mean_squared_error import MeanSquaredError
from pytorch_lightning.metrics.regression.mean_absolute_error import MeanAbsoluteError
from pytorch_lightning.metrics.regression.mean_squared_log_error import MeanSquaredLogError

3 changes: 1 addition & 2 deletions pytorch_lightning/metrics/regression/mean_absolute_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from pytorch_lightning.metrics.metric import Metric



class MeanAbsoluteError(Metric):
"""
Computes mean absolute error.
Expand All @@ -16,7 +15,7 @@ class MeanAbsoluteError(Metric):
>>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0])
>>> mean_absolute_error = MeanAbsoluteError()
>>> mean_absolute_error(preds, target)
tensor(0.5)
tensor(0.5000)
"""

def __init__(
Expand Down
12 changes: 12 additions & 0 deletions pytorch_lightning/metrics/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,18 @@
from typing import Any, Callable, Optional, Union


def dim_zero_cat(x):
return torch.cat(x, dim=0)


def dim_zero_sum(x):
return torch.sum(x, dim=0)


def dim_zero_mean(x):
return torch.mean(x, dim=0)


def _flatten(x):
return [item for sublist in x for item in sublist]

Expand Down

0 comments on commit 53cd2af

Please sign in to comment.