New metric classes (#1326)
* Create metrics package

* Create metric.py

* Create utils.py

* Create __init__.py

* add tests for metric utils

* add docstrings for metrics utils

* add function to recursively apply other function to collection

* add tests for this function

* update test

* Update pytorch_lightning/metrics/metric.py

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* update metric name

* remove example docs

* fix tests

* add metric tests

* fix to tensor conversion

* fix apply to collection

* Update CHANGELOG.md

* Update pytorch_lightning/metrics/metric.py

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* remove tests from init

* add missing type annotations

* rename utils to convertors

* Create metrics.rst

* Update index.rst

* Update index.rst

* Update pytorch_lightning/metrics/convertors.py

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* Update pytorch_lightning/metrics/convertors.py

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* Update pytorch_lightning/metrics/convertors.py

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* Update pytorch_lightning/metrics/metric.py

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* Update tests/utilities/test_apply_to_collection.py

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* Update tests/utilities/test_apply_to_collection.py

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* Update tests/metrics/convertors.py

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* Apply suggestions from code review

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* add doctest example

* rename file and fix imports

* added parametrized test

* replace lambda with inlined function

* rename apply_to_collection to apply_func

* Separated class description from init args

* Apply suggestions from code review

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* adjust random values

* suppress output when seeding

* remove gpu from doctest

* Add requested changes and add ellipsis for doctest

* forgot to push these files...

* add explicit check for dtype to convert to

* fix ddp tests

* remove explicit ddp destruction

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
justusschock and Borda committed Apr 3, 2020
1 parent 868b172 commit 4c6c4e0
Showing 12 changed files with 899 additions and 1 deletion.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -4,6 +4,9 @@ All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## Metrics (will be added to unreleased once the metric branch is finished)
- Add Metric Base Classes ([#1326](https://github.com/PyTorchLightning/pytorch-lightning/pull/1326))

## [unreleased] - YYYY-MM-DD

### Added
4 changes: 3 additions & 1 deletion docs/source/index.rst
@@ -23,6 +23,7 @@ PyTorch Lightning Documentation
hooks
lightning-module
loggers
metrics
trainer

.. toctree::
@@ -105,7 +106,8 @@ Indices and tables
pytorch_lightning.core
pytorch_lightning.callbacks
pytorch_lightning.loggers
pytorch_lightning.metrics
pytorch_lightning.overrides
pytorch_lightning.profiler
pytorch_lightning.trainer
pytorch_lightning.utilities
4 changes: 4 additions & 0 deletions docs/source/metrics.rst
@@ -0,0 +1,4 @@
.. automodule:: pytorch_lightning.metrics
:members:
:noindex:
:exclude-members:
5 changes: 5 additions & 0 deletions pytorch_lightning/metrics/__init__.py
@@ -0,0 +1,5 @@
"""
Metrics
=======
TODO
"""
223 changes: 223 additions & 0 deletions pytorch_lightning/metrics/converters.py
@@ -0,0 +1,223 @@
"""
This file provides functions and decorators for automated input and output
conversion to/from numpy.ndarray and torch.Tensor as well as utilities to
sync tensors between different processes in a DDP scenario, when needed.
"""

import numbers
from typing import Union, Any, Callable

import numpy as np
import torch
from torch.utils.data._utils.collate import np_str_obj_array_pattern

from pytorch_lightning.utilities.apply_func import apply_to_collection


def _apply_to_inputs(func_to_apply: Callable, *dec_args, **dec_kwargs) -> Callable:
    """
    Decorator function to apply a function to all inputs of a function.

    Args:
        func_to_apply: the function to apply to the inputs
        *dec_args: positional arguments for the function to be applied
        **dec_kwargs: keyword arguments for the function to be applied

    Returns:
        the decorated function
    """

    def decorator_fn(func_to_decorate):
        # actual function applying the given function to the inputs
        def new_func(*args, **kwargs):
            args = func_to_apply(args, *dec_args, **dec_kwargs)
            kwargs = func_to_apply(kwargs, *dec_args, **dec_kwargs)
            return func_to_decorate(*args, **kwargs)

        return new_func

    return decorator_fn


def _apply_to_outputs(func_to_apply: Callable, *dec_args, **dec_kwargs) -> Callable:
    """
    Decorator function to apply a function to all outputs of a function.

    Args:
        func_to_apply: the function to apply to the outputs
        *dec_args: positional arguments for the function to be applied
        **dec_kwargs: keyword arguments for the function to be applied

    Returns:
        the decorated function
    """

    def decorator_fn(function_to_decorate):
        # actual function applying the given function to the outputs
        def new_func(*args, **kwargs):
            result = function_to_decorate(*args, **kwargs)
            return func_to_apply(result, *dec_args, **dec_kwargs)

        return new_func

    return decorator_fn
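
# Illustrative sketch (editorial example, not part of the committed file): the two
# wrappers compose, converting arguments on the way in and results on the way out.
# Here every int argument becomes a float before the call and the result is rounded
# afterwards; `_example_add` is a hypothetical helper used only for demonstration.
def _example_add(a, b):
    return a + b

_example_wrapped = _apply_to_outputs(round)(
    _apply_to_inputs(apply_to_collection, int, float)(_example_add))
_example_wrapped(1, 2)  # the ints are added as floats, then the result is rounded back to 3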


def _convert_to_tensor(data: Any) -> Any:
    """
    Maps all kinds of collections and numbers to tensors.

    Args:
        data: the data to convert to a tensor

    Returns:
        the converted data
    """
    if isinstance(data, numbers.Number):
        return torch.tensor([data])
    # is not an array of objects
    elif isinstance(data, np.ndarray) and np_str_obj_array_pattern.search(data.dtype.str) is None:
        return torch.from_numpy(data)
    elif isinstance(data, torch.Tensor):
        return data

    raise TypeError("The given type ('%s') cannot be converted to a tensor!" % type(data).__name__)
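
# Illustrative sketch (editorial, not part of the committed file) of the three
# accepted input kinds; anything else, e.g. a string, raises a TypeError.
_convert_to_tensor(3.14)                 # scalar number -> 1-element tensor
_convert_to_tensor(np.array([0, 1, 2]))  # ndarray -> tensor (via torch.from_numpy)
_convert_to_tensor(torch.zeros(2))       # tensors are returned unchanged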


def _convert_to_numpy(data: Union[torch.Tensor, np.ndarray, numbers.Number]) -> np.ndarray:
    """
    Converts tensors, numbers and numpy arrays to numpy arrays.

    Args:
        data: the tensor, number or array to convert to numpy

    Returns:
        the resulting numpy array
    """
    if isinstance(data, torch.Tensor):
        return data.cpu().detach().numpy()
    elif isinstance(data, numbers.Number):
        return np.array([data])
    elif isinstance(data, np.ndarray):
        return data

    raise TypeError("The given type ('%s') cannot be converted to a numpy array!" % type(data).__name__)
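
# Illustrative sketch (editorial, not part of the committed file): the inverse
# direction, used before handing values to numpy-based metric functions.
_convert_to_numpy(torch.ones(2))  # tensor -> ndarray (detached and moved to CPU first)
_convert_to_numpy(5)              # number -> 1-element ndarray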


def _numpy_metric_conversion(func_to_decorate: Callable) -> Callable:
    """
    Decorator handling the argument conversion for metrics working on numpy.
    All inputs of the decorated function will be converted to numpy and all
    outputs will be converted to tensors.

    Args:
        func_to_decorate: the function whose inputs and outputs shall be converted

    Returns:
        the decorated function
    """
    # applies collection conversion from tensor to numpy to all inputs
    # we need to include numpy arrays here, since otherwise they will also be treated as sequences
    func_convert_inputs = _apply_to_inputs(
        apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), _convert_to_numpy)(func_to_decorate)
    # converts all outputs back to tensors (device doesn't matter here, since this is handled by BaseMetric)
    func_convert_in_out = _apply_to_outputs(_convert_to_tensor)(func_convert_inputs)
    return func_convert_in_out
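
# Illustrative sketch (editorial, not part of the committed file): `_example_np_mae`
# is a hypothetical metric written purely in numpy; callers may still pass tensors
# and will get a tensor back.
@_numpy_metric_conversion
def _example_np_mae(pred, target):
    return np.abs(pred - target).mean()

_example_np_mae(torch.tensor([1.0, 2.0]), torch.tensor([2.0, 2.0]))  # returns a tensor holding 0.5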


def _tensor_metric_conversion(func_to_decorate: Callable) -> Callable:
    """
    Decorator handling the argument conversion for metrics working on tensors.
    All inputs and outputs of the decorated function will be converted to tensors.

    Args:
        func_to_decorate: the function whose inputs and outputs shall be converted

    Returns:
        the decorated function
    """
    # converts all inputs to tensors if possible
    # we need to include tensors here, since otherwise they will also be treated as sequences
    func_convert_inputs = _apply_to_inputs(
        apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), _convert_to_tensor)(func_to_decorate)
    # converts all outputs to tensors if possible
    return _apply_to_outputs(_convert_to_tensor)(func_convert_inputs)
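
# Illustrative sketch (editorial, not part of the committed file): `_example_t_mae`
# is a hypothetical metric written with torch ops; numpy inputs from the caller are
# converted to tensors before the function runs.
@_tensor_metric_conversion
def _example_t_mae(pred, target):
    return (pred - target).abs().mean()

_example_t_mae(np.array([1.0, 2.0]), np.array([2.0, 2.0]))  # returns a tensor holding 0.5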


def _sync_ddp_if_available(result: torch.Tensor,
                           group: Any = torch.distributed.group.WORLD,
                           reduce_op: torch.distributed.ReduceOp = torch.distributed.ReduceOp.SUM,
                           ) -> torch.Tensor:
    """
    Reduces a tensor across several DDP processes (no-op when torch.distributed
    is unavailable or not initialized).

    Args:
        result: the value to sync and reduce (typically a tensor or a number)
        group: the process group to gather results from. Defaults to all processes (world)
        reduce_op: the reduction operation. Defaults to sum

    Returns:
        the reduced value
    """
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        # sync all processes before reduction
        torch.distributed.barrier(group=group)
        torch.distributed.all_reduce(result, op=reduce_op, group=group,
                                     async_op=False)

    return result
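
# Illustrative sketch (editorial, not part of the committed file): outside an
# initialized process group the tensor passes through untouched; in a DDP run the
# same call would all-reduce (sum by default) the value across processes.
_example_value = torch.tensor([1.0])
_example_synced = _sync_ddp_if_available(_example_value)  # identical to _example_value on a single process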


def numpy_metric(group: Any = torch.distributed.group.WORLD,
                 reduce_op: torch.distributed.ReduceOp = torch.distributed.ReduceOp.SUM) -> Callable:
    """
    This decorator shall be used on all functional metrics working on numpy arrays.
    It handles the argument conversion and DDP reduction for metrics working on numpy.
    All inputs of the decorated function will be converted to numpy and all
    outputs will be converted to tensors.
    In DDP training all output tensors will be reduced according to the given rules.

    Args:
        group: the process group to gather results from. Defaults to all processes (world)
        reduce_op: the reduction operation. Defaults to sum

    Returns:
        the decorated function
    """

    def decorator_fn(func_to_decorate):
        return _apply_to_outputs(apply_to_collection, torch.Tensor, _sync_ddp_if_available,
                                 group=group,
                                 reduce_op=reduce_op)(_numpy_metric_conversion(func_to_decorate))

    return decorator_fn
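
# Illustrative sketch (editorial, not part of the committed file): `_example_rmse`
# is a hypothetical user metric. Note that the decorator factory is called, i.e.
# `@numpy_metric()`; in a DDP run the returned tensor would additionally be
# sum-reduced across processes.
@numpy_metric()
def _example_rmse(pred, target):
    return np.sqrt(np.mean((pred - target) ** 2))

_example_rmse(torch.tensor([1.0, 3.0]), torch.tensor([1.0, 1.0]))  # returns a tensor holding ~1.41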


def tensor_metric(group: Any = torch.distributed.group.WORLD,
                  reduce_op: torch.distributed.ReduceOp = torch.distributed.ReduceOp.SUM) -> Callable:
    """
    This decorator shall be used on all functional metrics working on tensors.
    It handles the argument conversion and DDP reduction for metrics working on tensors.
    All inputs and outputs of the decorated function will be converted to tensors.
    In DDP training all output tensors will be reduced according to the given rules.

    Args:
        group: the process group to gather results from. Defaults to all processes (world)
        reduce_op: the reduction operation. Defaults to sum

    Returns:
        the decorated function
    """

    def decorator_fn(func_to_decorate):
        return _apply_to_outputs(apply_to_collection, torch.Tensor, _sync_ddp_if_available,
                                 group=group,
                                 reduce_op=reduce_op)(_tensor_metric_conversion(func_to_decorate))

    return decorator_fn
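
# Illustrative sketch (editorial, not part of the committed file): `_example_accuracy`
# is a hypothetical user metric working directly on tensors; plain numbers and numpy
# arrays passed by the caller are converted first, and the output is reduced under DDP.
@tensor_metric()
def _example_accuracy(pred, target):
    return (pred == target).float().mean()

_example_accuracy(np.array([0, 1, 1]), np.array([0, 1, 0]))  # returns a tensor holding ~0.667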
