From 4c6c4e02d49c02891a310643f0e1506c5b44bfc0 Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Fri, 3 Apr 2020 21:10:40 +0200 Subject: [PATCH] New metric classes (#1326) * Create metrics package * Create metric.py * Create utils.py * Create __init__.py * add tests for metric utils * add docstrings for metrics utils * add function to recursively apply other function to collection * add tests for this function * update test * Update pytorch_lightning/metrics/metric.py Co-Authored-By: Jirka Borovec * update metric name * remove example docs * fix tests * add metric tests * fix to tensor conversion * fix apply to collection * Update CHANGELOG.md * Update pytorch_lightning/metrics/metric.py Co-Authored-By: Jirka Borovec * remove tests from init * add missing type annotations * rename utils to convertors * Create metrics.rst * Update index.rst * Update index.rst * Update pytorch_lightning/metrics/convertors.py Co-Authored-By: Jirka Borovec * Update pytorch_lightning/metrics/convertors.py Co-Authored-By: Jirka Borovec * Update pytorch_lightning/metrics/convertors.py Co-Authored-By: Jirka Borovec * Update pytorch_lightning/metrics/metric.py Co-Authored-By: Jirka Borovec * Update tests/utilities/test_apply_to_collection.py Co-Authored-By: Jirka Borovec * Update tests/utilities/test_apply_to_collection.py Co-Authored-By: Jirka Borovec * Update tests/metrics/convertors.py Co-Authored-By: Jirka Borovec * Apply suggestions from code review Co-Authored-By: Jirka Borovec * add doctest example * rename file and fix imports * added parametrized test * replace lambda with inlined function * rename apply_to_collection to apply_func * Separated class description from init args * Apply suggestions from code review Co-Authored-By: Jirka Borovec * adjust random values * suppress output when seeding * remove gpu from doctest * Add requested changes and add ellipsis for doctest * forgot to push these files... * add explicit check for dtype to convert to * fix ddp tests * remove explicit ddp destruction Co-authored-by: Jirka Borovec --- CHANGELOG.md | 3 + docs/source/index.rst | 4 +- docs/source/metrics.rst | 4 + pytorch_lightning/metrics/__init__.py | 5 + pytorch_lightning/metrics/converters.py | 223 +++++++++++++++++++ pytorch_lightning/metrics/metric.py | 260 ++++++++++++++++++++++ pytorch_lightning/utilities/apply_func.py | 36 +++ tests/metrics/__init__.py | 0 tests/metrics/test_converters.py | 214 ++++++++++++++++++ tests/metrics/test_metrics.py | 85 +++++++ tests/utilities/__init__.py | 0 tests/utilities/test_apply_func.py | 66 ++++++ 12 files changed, 899 insertions(+), 1 deletion(-) create mode 100644 docs/source/metrics.rst create mode 100644 pytorch_lightning/metrics/__init__.py create mode 100644 pytorch_lightning/metrics/converters.py create mode 100644 pytorch_lightning/metrics/metric.py create mode 100644 pytorch_lightning/utilities/apply_func.py create mode 100644 tests/metrics/__init__.py create mode 100644 tests/metrics/test_converters.py create mode 100644 tests/metrics/test_metrics.py create mode 100644 tests/utilities/__init__.py create mode 100644 tests/utilities/test_apply_func.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 9d7af014edecb..9a1802fa67809 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,9 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). +## Metrics (will be added to unreleased once the metric branch was finished) +- Add Metric Base Classes ([#1326](https://github.com/PyTorchLightning/pytorch-lightning/pull/1326)) + ## [unreleased] - YYYY-MM-DD ### Added diff --git a/docs/source/index.rst b/docs/source/index.rst index 1e11f7a0e9487..68b9b2abcb263 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -23,6 +23,7 @@ PyTorch Lightning Documentation hooks lightning-module loggers + metrics trainer .. toctree:: @@ -105,7 +106,8 @@ Indices and tables pytorch_lightning.core pytorch_lightning.callbacks pytorch_lightning.loggers + pytorch_lightning.metrics pytorch_lightning.overrides pytorch_lightning.profiler pytorch_lightning.trainer - pytorch_lightning.utilities \ No newline at end of file + pytorch_lightning.utilities diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst new file mode 100644 index 0000000000000..6f70a3c73f2d0 --- /dev/null +++ b/docs/source/metrics.rst @@ -0,0 +1,4 @@ +.. automodule:: pytorch_lightning.metrics + :members: + :noindex: + :exclude-members: diff --git a/pytorch_lightning/metrics/__init__.py b/pytorch_lightning/metrics/__init__.py new file mode 100644 index 0000000000000..18522e0dda94b --- /dev/null +++ b/pytorch_lightning/metrics/__init__.py @@ -0,0 +1,5 @@ +""" +Metrics +======= +TODO +""" diff --git a/pytorch_lightning/metrics/converters.py b/pytorch_lightning/metrics/converters.py new file mode 100644 index 0000000000000..8162876fc3b00 --- /dev/null +++ b/pytorch_lightning/metrics/converters.py @@ -0,0 +1,223 @@ +""" +This file provides functions and decorators for automated input and output +conversion to/from numpy.ndarray and torch.Tensor as well as utilities to +sync tensors between different processes in a DDP scenario, when needed. +""" + +import numbers +from typing import Union, Any, Callable + +import numpy as np +import torch +from torch.utils.data._utils.collate import np_str_obj_array_pattern + +from pytorch_lightning.utilities.apply_func import apply_to_collection + + +def _apply_to_inputs(func_to_apply: Callable, *dec_args, **dec_kwargs) -> Callable: + """ + Decorator function to apply a function to all inputs of a function. + Args: + func_to_apply: the function to apply to the inputs + *dec_args: positional arguments for the function to be applied + **dec_kwargs: keyword arguments for the function to be applied + + Returns: + the decorated function + """ + + def decorator_fn(func_to_decorate): + # actual function applying the give function to inputs + def new_func(*args, **kwargs): + args = func_to_apply(args, *dec_args, **dec_kwargs) + kwargs = func_to_apply(kwargs, *dec_args, **dec_kwargs) + return func_to_decorate(*args, **kwargs) + + return new_func + + return decorator_fn + + +def _apply_to_outputs(func_to_apply: Callable, *dec_args, **dec_kwargs) -> Callable: + """ + Decorator function to apply a function to all outputs of a function. + Args: + func_to_apply: the function to apply to the outputs + *dec_args: positional arguments for the function to be applied + **dec_kwargs: keyword arguments for the function to be applied + + Returns: + the decorated function + """ + + def decorator_fn(function_to_decorate): + # actual function applying the give function to outputs + def new_func(*args, **kwargs): + result = function_to_decorate(*args, **kwargs) + return func_to_apply(result, *dec_args, **dec_kwargs) + + return new_func + + return decorator_fn + + +def _convert_to_tensor(data: Any) -> Any: + """ + Maps all kind of collections and numbers to tensors. + + Args: + data: the data to convert to tensor + + Returns: + the converted data + + """ + if isinstance(data, numbers.Number): + return torch.tensor([data]) + # is not array of object + elif isinstance(data, np.ndarray) and np_str_obj_array_pattern.search(data.dtype.str) is None: + return torch.from_numpy(data) + elif isinstance(data, torch.Tensor): + return data + + raise TypeError("The given type ('%s') cannot be converted to a tensor!" % type(data).__name__) + + +def _convert_to_numpy(data: Union[torch.Tensor, np.ndarray, numbers.Number]) -> np.ndarray: + """Convert all tensors and numpy arrays to numpy arrays. + Args: + data: the tensor or array to convert to numpy + + Returns: + the resulting numpy array + + """ + if isinstance(data, torch.Tensor): + return data.cpu().detach().numpy() + elif isinstance(data, numbers.Number): + return np.array([data]) + elif isinstance(data, np.ndarray): + return data + + raise TypeError("The given type ('%s') cannot be converted to a numpy array!" % type(data).__name__) + + +def _numpy_metric_conversion(func_to_decorate: Callable) -> Callable: + """ + Decorator Handling the argument conversion for metrics working on numpy. + All inputs of the decorated function will be converted to numpy and all + outputs will be converted to Tensors + + Args: + func_to_decorate: the function whose inputs and outputs shall be converted + + Returns: + the decorated function + + """ + # applies collection conversion from tensor to numpy to all inputs + # we need to include numpy arrays here, since otherwise they will also be treated as sequences + func_convert_inputs = _apply_to_inputs( + apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), _convert_to_numpy)(func_to_decorate) + # converts all inputs back to tensors (device doesn't matter here, since this is handled by BaseMetric) + func_convert_in_out = _apply_to_outputs(_convert_to_tensor)(func_convert_inputs) + return func_convert_in_out + + +def _tensor_metric_conversion(func_to_decorate: Callable) -> Callable: + """ + Decorator Handling the argument conversion for metrics working on tensors. + All inputs and outputs of the decorated function will be converted to tensors + + Args: + func_to_decorate: the function whose inputs and outputs shall be converted + + Returns: + the decorated function + + """ + # converts all inputs to tensor if possible + # we need to include tensors here, since otherwise they will also be treated as sequences + func_convert_inputs = _apply_to_inputs( + apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), _convert_to_tensor)(func_to_decorate) + # convert all outputs to tensor if possible + return _apply_to_outputs(_convert_to_tensor)(func_convert_inputs) + + +def _sync_ddp_if_available(result: Union[torch.Tensor], + group: Any = torch.distributed.group.WORLD, + reduce_op: torch.distributed.ReduceOp = torch.distributed.ReduceOp.SUM, + ) -> torch.Tensor: + """ + Function to reduce the tensors from several ddp processes to one master process + + Args: + result: the value to sync and reduce (typically tensor or number) + group: the process group to gather results from. Defaults to all processes (world) + reduce_op: the reduction operation. Defaults to sum + + Returns: + reduced value + + """ + + if torch.distributed.is_available() and torch.distributed.is_initialized(): + # sync all processes before reduction + torch.distributed.barrier(group=group) + torch.distributed.all_reduce(result, op=reduce_op, group=group, + async_op=False) + + return result + + +def numpy_metric(group: Any = torch.distributed.group.WORLD, + reduce_op: torch.distributed.ReduceOp = torch.distributed.ReduceOp.SUM) -> Callable: + """ + This decorator shall be used on all function metrics working on numpy arrays. + + It handles the argument conversion and DDP reduction for metrics working on numpy. + All inputs of the decorated function will be converted to numpy and all + outputs will be converted to Tensors. + In DDP Training all output tensors will be reduced according to the given rules. + + Args: + group: the process group to gather results from. Defaults to all processes (world) + reduce_op: the reduction operation. Defaults to sum + + Returns: + the decorated function + + """ + + def decorator_fn(func_to_decorate): + return _apply_to_outputs(apply_to_collection, torch.Tensor, _sync_ddp_if_available, + group=group, + reduce_op=reduce_op)(_numpy_metric_conversion(func_to_decorate)) + + return decorator_fn + + +def tensor_metric(group: Any = torch.distributed.group.WORLD, + reduce_op: torch.distributed.ReduceOp = torch.distributed.ReduceOp.SUM) -> Callable: + """ + This decorator shall be used on all function metrics working on tensors. + + It handles the argument conversion and DDP reduction for metrics working on tensors. + All inputs and outputs of the decorated function will be converted to tensors . + In DDP Training all output tensors will be reduced according to the given rules. + + Args: + group: the process group to gather results from. Defaults to all processes (world) + reduce_op: the reduction operation. Defaults to sum + + Returns: + the decorated function + + """ + + def decorator_fn(func_to_decorate): + return _apply_to_outputs(apply_to_collection, torch.Tensor, _sync_ddp_if_available, + group=group, + reduce_op=reduce_op)(_tensor_metric_conversion(func_to_decorate)) + + return decorator_fn diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py new file mode 100644 index 0000000000000..50853105f94f9 --- /dev/null +++ b/pytorch_lightning/metrics/metric.py @@ -0,0 +1,260 @@ +from abc import ABC, abstractmethod +from typing import Any, Optional, Union + +import torch +import torch.distributed + +from pytorch_lightning.metrics.converters import tensor_metric, numpy_metric +from pytorch_lightning.utilities.apply_func import apply_to_collection + +__all__ = ['Metric', 'TensorMetric', 'NumpyMetric'] + + +class Metric(torch.nn.Module, ABC): + """ + Abstract Base Class for metric implementation. + + Should be used to implement metrics that + 1. Return multiple Outputs + 2. Handle their own DDP sync + """ + def __init__(self, name: str): + """ + Args: + name: the metric's name + + """ + super().__init__() + self.name = name + self._dtype = torch.get_default_dtype() + self._device = torch.device('cpu') + + @property + def dtype(self) -> Union[str, torch.dtype]: + return self._dtype + + @dtype.setter + def dtype(self, new_dtype: Union[str, torch.dtype]): + # necessary to avoid infinite recursion + raise RuntimeError('Cannot set the dtype explicitly. Please use metric.to(new_dtype).') + + @property + def device(self) -> Union[str, torch.device]: + return self._device + + @device.setter + def device(self, new_device: Union[str, torch.device]): + # Necessary to avoid infinite recursion + raise RuntimeError('Cannot set the device explicitly. Please use metric.to(new_device).') + + @abstractmethod + def forward(self, *args, **kwargs) -> torch.Tensor: + """ + Implements the actual metric computation. + + Returns: + metric value + + """ + raise NotImplementedError + + def to(self, *args, **kwargs) -> torch.nn.Module: + """Moves and/or casts the parameters and buffers. + + This can be called as + + .. function:: to(device=None, dtype=None, non_blocking=False) + + .. function:: to(dtype, non_blocking=False) + + .. function:: to(tensor, non_blocking=False) + + Its signature is similar to :meth:`torch.Tensor.to`, but only accepts + floating point desired :attr:`dtype` s. In addition, this method will + only cast the floating point parameters and buffers to :attr:`dtype` + (if given). The integral parameters and buffers will be moved + :attr:`device`, if that is given, but with dtypes unchanged. When + :attr:`non_blocking` is set, it tries to convert/move asynchronously + with respect to the host if possible, e.g., moving CPU Tensors with + pinned memory to CUDA devices. + + See below for examples. + + Note: + This method modifies the module in-place. + + Args: + device: the desired device of the parameters + and buffers in this module + dtype: the desired floating point type of + the floating point parameters and buffers in this module + tensor: Tensor whose dtype and device are the desired + dtype and device for all parameters and buffers in this module + + Returns: + Module: self + + Example:: + >>> class ExampleMetric(Metric): + ... def __init__(self, weight: torch.Tensor): + ... super().__init__('example') + ... self.register_buffer('weight', weight) + ... def forward(self, pred, target) -> torch.Tensor: + ... return (pred - target) * self.weight + >>> _ = torch.manual_seed(0) + >>> metric = ExampleMetric(torch.rand(3, 4)) + >>> metric.weight + tensor([[0.4963, 0.7682, 0.0885, 0.1320], + [0.3074, 0.6341, 0.4901, 0.8964], + [0.4556, 0.6323, 0.3489, 0.4017]]) + >>> metric.to(torch.double) #doctest: +ELLIPSIS + ExampleMetric() + >>> metric.weight + tensor([[...]], dtype=torch.float64) + >>> cpu = torch.device('cpu') + >>> metric.to(cpu, dtype=torch.half, non_blocking=True) + ExampleMetric() + >>> metric.weight #doctest: +ELLIPSIS + tensor([[...]], dtype=torch.float16) + >>> metric.to(cpu) + ExampleMetric() + >>> metric.weight #doctest: +ELLIPSIS + tensor([[...]], dtype=torch.float16) + + + """ + device, dtype, non_blocking = torch._C._nn._parse_to(*args, **kwargs) + if device is not None: + self._device = device + + if dtype is not None: + self._dtype = dtype + + return super().to(*args, **kwargs) + + def cuda(self, device: Optional[int] = None) -> torch.nn.Module: + """Moves all model parameters and buffers to the GPU. + + This also makes associated parameters and buffers different objects. So + it should be called before constructing optimizer if the module will + live on GPU while being optimized. + + Arguments: + device (int, optional): if specified, all parameters will be + copied to that device + + Returns: + Module: + """ + + self._device = torch.device('cuda', index=device) + return super().cuda(device=device) + + def cpu(self) -> torch.nn.Module: + """Moves all model parameters and buffers to the CPU. + + Returns: + Module: self + """ + self._device = torch.device('cpu') + return super().cpu() + + def type(self, dst_type: Union[str, torch.dtype]) -> torch.nn.Module: + """Casts all parameters and buffers to :attr:`dst_type`. + + Arguments: + dst_type (type or string): the desired type + + Returns: + Module: self + """ + self._dtype = dst_type + return super().type(dst_type=dst_type) + + def float(self) -> torch.nn.Module: + """Casts all floating point parameters and buffers to float datatype. + + Returns: + Module: self + """ + self._dtype = torch.float + return super().float() + + def double(self) -> torch.nn.Module: + """Casts all floating point parameters and buffers to ``double`` datatype. + + Returns: + Module: self + """ + self._dtype = torch.double + return super().double() + + def half(self) -> torch.nn.Module: + """Casts all floating point parameters and buffers to ``half`` datatype. + + Returns: + Module: self + """ + self._dtype = torch.half + return super().half() + + +class TensorMetric(Metric): + """ + Base class for metric implementation operating directly on tensors. + All inputs and outputs will be casted to tensors if necessary. + Already handles DDP sync and input/output conversions. + """ + def __init__(self, name: str, + reduce_group: Optional[Any] = torch.distributed.group.WORLD, + reduce_op: Optional[Any] = torch.distributed.ReduceOp.SUM): + """ + + Args: + name: the metric's name + reduce_group: the process group for DDP reduces (only needed for DDP training). + Defaults to all processes (world) + reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). + Defaults to sum. + """ + super().__init__(name) + self._orig_call = tensor_metric(group=reduce_group, + reduce_op=reduce_op)(super().__call__) + + def __call__(self, *args, **kwargs) -> torch.Tensor: + def _to_device_dtype(x: torch.Tensor) -> torch.Tensor: + return x.to(device=self.device, dtype=self.dtype) + + return apply_to_collection(self._orig_call(*args, **kwargs), torch.Tensor, + _to_device_dtype) + + +class NumpyMetric(Metric): + """ + Base class for metric implementation operating on numpy arrays. + All inputs will be casted to numpy if necessary and all outputs will + be casted to tensors if necessary. + Already handles DDP sync and input/output conversions. + """ + def __init__(self, name: str, + reduce_group: Optional[Any] = torch.distributed.group.WORLD, + reduce_op: Optional[Any] = torch.distributed.ReduceOp.SUM): + """ + + Args: + name: the metric's name + reduce_group: the process group for DDP reduces (only needed for DDP training). + Defaults to all processes (world) + reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). + Defaults to sum. + """ + super().__init__(name) + self._orig_call = numpy_metric(group=reduce_group, + reduce_op=reduce_op)(super().__call__) + + def __call__(self, *args, **kwargs) -> torch.Tensor: + def _to_device_dtype(x: torch.Tensor) -> torch.Tensor: + return x.to(device=self.device, dtype=self.dtype) + + return apply_to_collection(self._orig_call(*args, **kwargs), torch.Tensor, + _to_device_dtype) diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py new file mode 100644 index 0000000000000..724715c3d8607 --- /dev/null +++ b/pytorch_lightning/utilities/apply_func.py @@ -0,0 +1,36 @@ +from collections import Mapping, Sequence +from typing import Any, Callable, Union + + +def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable, *args, **kwargs) -> Any: + """ + Recursively applies a function to all elements of a certain dtype. + + Args: + data: the collection to apply the function to + dtype: the given function will be applied to all elements of this dtype + function: the function to apply + *args: positional arguments (will be forwarded to calls of ``function``) + **kwargs: keyword arguments (will be forwarded to calls of ``function``) + + Returns: + the resulting collection + + """ + elem_type = type(data) + + # Breaking condition + if isinstance(data, dtype): + return function(data, *args, **kwargs) + + # Recursively apply to collection items + elif isinstance(data, Mapping): + return elem_type({k: apply_to_collection(v, dtype, function, *args, **kwargs) + for k, v in data.items()}) + elif isinstance(data, tuple) and hasattr(data, '_fields'): # named tuple + return elem_type(*(apply_to_collection(d, dtype, function, *args, **kwargs) for d in data)) + elif isinstance(data, Sequence) and not isinstance(data, str): + return elem_type([apply_to_collection(d, dtype, function, *args, **kwargs) for d in data]) + + # data is neither of dtype, nor a collection + return data diff --git a/tests/metrics/__init__.py b/tests/metrics/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/metrics/test_converters.py b/tests/metrics/test_converters.py new file mode 100644 index 0000000000000..9abc11d4b07ad --- /dev/null +++ b/tests/metrics/test_converters.py @@ -0,0 +1,214 @@ +import numpy as np +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +import tests.base.utils as tutils +from pytorch_lightning.metrics.converters import ( + _apply_to_inputs, _apply_to_outputs, _convert_to_tensor, _convert_to_numpy, + _numpy_metric_conversion, _tensor_metric_conversion, _sync_ddp_if_available, tensor_metric, numpy_metric) + + +@pytest.mark.parametrize(['args', 'kwargs'], + [pytest.param([], {}), + pytest.param([1., 2.], {}), + pytest.param([], {'a': 1., 'b': 2.}), + pytest.param([1., 2.], {'a': 1., 'b': 2.})]) +def test_apply_to_inputs(args, kwargs): + def apply_fn(inputs, factor): + if isinstance(inputs, (float, int)): + return inputs * factor + elif isinstance(inputs, dict): + return {k: apply_fn(v, factor) for k, v in inputs.items()} + elif isinstance(inputs, (tuple, list)): + return [apply_fn(x, factor) for x in inputs] + + @_apply_to_inputs(apply_fn, factor=2.) + def test_fn(*func_args, **func_kwargs): + return func_args, func_kwargs + + result_args, result_kwargs = test_fn(*args, **kwargs) + assert isinstance(result_args, (list, tuple)) + assert isinstance(result_kwargs, dict) + assert len(result_args) == len(args) + assert len(result_kwargs) == len(kwargs) + assert all([k in result_kwargs for k in kwargs.keys()]) + for arg, result_arg in zip(args, result_args): + assert arg * 2. == result_arg + + for key in kwargs.keys(): + arg = kwargs[key] + result_arg = result_kwargs[key] + assert arg * 2. == result_arg + + +def test_apply_to_outputs(): + def apply_fn(inputs, additional_str): + return str(inputs) + additional_str + + @_apply_to_outputs(apply_fn, additional_str='_str') + def test_fn(*args, **kwargs): + return 'dummy' + + assert test_fn() == 'dummy_str' + + +def test_convert_to_tensor(): + for test_item in [1., np.array([1.])]: + result_tensor = _convert_to_tensor(test_item) + assert isinstance(result_tensor, torch.Tensor) + assert result_tensor.item() == 1. + + +def test_convert_to_numpy(): + for test_item in [1., torch.tensor([1.])]: + result = _convert_to_numpy(test_item) + assert isinstance(result, np.ndarray) + assert result.item() == 1. + + +def test_numpy_metric_conversion(): + @_numpy_metric_conversion + def numpy_test_metric(*args, **kwargs): + for arg in args: + assert isinstance(arg, np.ndarray) + + for v in kwargs.values(): + assert isinstance(v, np.ndarray) + + return 5. + + result = numpy_test_metric(torch.tensor([1.]), dummy_kwarg=2.) + assert isinstance(result, torch.Tensor) + assert result.item() == 5. + + +def test_tensor_metric_conversion(): + @_tensor_metric_conversion + def tensor_test_metric(*args, **kwargs): + for arg in args: + assert isinstance(arg, torch.Tensor) + + for v in kwargs.values(): + assert isinstance(v, torch.Tensor) + + return 5. + + result = tensor_test_metric(np.array([1.]), dummy_kwarg=2.) + assert isinstance(result, torch.Tensor) + assert result.item() == 5. + + +def setup_ddp(rank, worldsize, ): + import os + + os.environ['MASTER_ADDR'] = 'localhost' + + # initialize the process group + dist.init_process_group("gloo", rank=rank, world_size=worldsize) + + +def ddp_test_fn(rank, worldsize): + setup_ddp(rank, worldsize) + tensor = torch.tensor([1.], device='cuda:0') + + reduced_tensor = _sync_ddp_if_available(tensor) + + assert reduced_tensor.item() == dist.get_world_size(), \ + 'Sync-Reduce does not work properly with DDP and Tensors' + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +def test_sync_reduce_ddp(): + """Make sure sync-reduce works with DDP""" + tutils.reset_seed() + tutils.set_random_master_port() + + worldsize = 2 + mp.spawn(ddp_test_fn, args=(worldsize,), nprocs=worldsize) + + +def test_sync_reduce_simple(): + """Make sure sync-reduce works without DDP""" + tensor = torch.tensor([1.], device='cpu') + + reduced_tensor = _sync_ddp_if_available(tensor) + + assert torch.allclose(tensor, reduced_tensor), \ + 'Sync-Reduce does not work properly without DDP and Tensors' + + +def _test_tensor_metric(is_ddp: bool): + @tensor_metric() + def tensor_test_metric(*args, **kwargs): + for arg in args: + assert isinstance(arg, torch.Tensor) + + for v in kwargs.values(): + assert isinstance(v, torch.Tensor) + + return 5. + + if is_ddp: + factor = dist.get_world_size() + else: + factor = 1. + + result = tensor_test_metric(np.array([1.]), dummy_kwarg=2.) + assert isinstance(result, torch.Tensor) + assert result.item() == 5. * factor + + +def _ddp_test_tensor_metric(rank, worldsize): + setup_ddp(rank, worldsize) + _test_tensor_metric(True) + + +def test_tensor_metric_ddp(): + tutils.reset_seed() + tutils.set_random_master_port() + + world_size = 2 + mp.spawn(_ddp_test_tensor_metric, args=(world_size,), nprocs=world_size) + + +def test_tensor_metric_simple(): + _test_tensor_metric(False) + + +def _test_numpy_metric(is_ddp: bool): + @numpy_metric() + def numpy_test_metric(*args, **kwargs): + for arg in args: + assert isinstance(arg, np.ndarray) + + for v in kwargs.values(): + assert isinstance(v, np.ndarray) + + return 5. + + if is_ddp: + factor = dist.get_world_size() + else: + factor = 1. + + result = numpy_test_metric(torch.tensor([1.]), dummy_kwarg=2.) + assert isinstance(result, torch.Tensor) + assert result.item() == 5. * factor + + +def _ddp_test_numpy_metric(rank, worldsize): + setup_ddp(rank, worldsize) + _test_numpy_metric(True) + + +def test_numpy_metric_ddp(): + tutils.reset_seed() + tutils.set_random_master_port() + world_size = 2 + mp.spawn(_ddp_test_numpy_metric, args=(world_size,), nprocs=world_size) + + +def test_numpy_metric_simple(): + _test_tensor_metric(False) diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py new file mode 100644 index 0000000000000..e83a9d97b7a6c --- /dev/null +++ b/tests/metrics/test_metrics.py @@ -0,0 +1,85 @@ +import numpy as np +import torch + +from pytorch_lightning.metrics.metric import Metric, TensorMetric, NumpyMetric + + +class DummyTensorMetric(TensorMetric): + def __init__(self): + super().__init__('dummy') + + def forward(self, input1, input2): + assert isinstance(input1, torch.Tensor) + assert isinstance(input2, torch.Tensor) + return 1. + + +class DummyNumpyMetric(NumpyMetric): + def __init__(self): + super().__init__('dummy') + + def forward(self, input1, input2): + assert isinstance(input1, np.ndarray) + assert isinstance(input2, np.ndarray) + return 1. + + +def _test_metric(metric: Metric): + input1, input2 = torch.tensor([1.]), torch.tensor([2.]) + + def change_and_check_device_dtype(device, dtype): + metric.to(device=device, dtype=dtype) + + metric_val = metric(input1, input2) + assert isinstance(metric_val, torch.Tensor) + + if device is not None: + assert metric.device in [device, torch.device(device)] + assert metric_val.device in [device, torch.device(device)] + + if dtype is not None: + assert metric.dtype == dtype + assert metric_val.dtype == dtype + + devices = [None, 'cpu'] + if torch.cuda.is_available(): + devices += ['cuda:0'] + + for device in devices: + for dtype in [None, torch.float32, torch.float64]: + change_and_check_device_dtype(device=device, dtype=dtype) + + if torch.cuda.is_available(): + metric.cuda(0) + assert metric.device == torch.device('cuda', index=0) + assert metric(input1, input2).device == torch.device('cuda', index=0) + + metric.cpu() + assert metric.device == torch.device('cpu') + assert metric(input1, input2).device == torch.device('cpu') + + metric.type(torch.int8) + assert metric.dtype == torch.int8 + assert metric(input1, input2).dtype == torch.int8 + + metric.float() + assert metric.dtype == torch.float32 + assert metric(input1, input2).dtype == torch.float32 + + metric.double() + assert metric.dtype == torch.float64 + assert metric(input1, input2).dtype == torch.float64 + + if torch.cuda.is_available(): + metric.cuda() + metric.half() + assert metric.dtype == torch.float16 + assert metric(input1, input2).dtype == torch.float16 + + +def test_tensor_metric(): + _test_metric(DummyTensorMetric()) + + +def test_numpy_metric(): + _test_metric(DummyNumpyMetric()) diff --git a/tests/utilities/__init__.py b/tests/utilities/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/utilities/test_apply_func.py b/tests/utilities/test_apply_func.py new file mode 100644 index 0000000000000..dce1e56e2b332 --- /dev/null +++ b/tests/utilities/test_apply_func.py @@ -0,0 +1,66 @@ +import numbers +from collections import namedtuple + +import numpy as np +import torch + +from pytorch_lightning.utilities.apply_func import apply_to_collection + + +def test_recursive_application_to_collection(): + ntc = namedtuple('Foo', ['bar']) + + to_reduce = { + 'a': torch.tensor([1.]), # Tensor + 'b': [torch.tensor([2.])], # list + 'c': (torch.tensor([100.]),), # tuple + 'd': ntc(bar=5.), # named tuple + 'e': np.array([10.]), # numpy array + 'f': 'this_is_a_dummy_str', # string + 'g': 12. # number + } + + expected_result = { + 'a': torch.tensor([2.]), + 'b': [torch.tensor([4.])], + 'c': (torch.tensor([200.]),), + 'd': ntc(bar=torch.tensor([10.])), + 'e': np.array([20.]), + 'f': 'this_is_a_dummy_str', + 'g': 24. + } + + reduced = apply_to_collection(to_reduce, (torch.Tensor, numbers.Number, np.ndarray), + lambda x: x * 2) + + assert isinstance(reduced, dict), ' Type Consistency of dict not preserved' + assert all([x in reduced for x in to_reduce.keys()]), 'Not all entries of the dict were preserved' + assert all([isinstance(reduced[k], type(expected_result[k])) for k in to_reduce.keys()]), \ + 'At least one type was not correctly preserved' + + assert isinstance(reduced['a'], torch.Tensor), 'Reduction Result of a Tensor should be a Tensor' + assert torch.allclose(expected_result['a'], reduced['a']), \ + 'Reduction of a tensor does not yield the expected value' + + assert isinstance(reduced['b'], list), 'Reduction Result of a list should be a list' + assert all([torch.allclose(x, y) for x, y in zip(reduced['b'], expected_result['b'])]), \ + 'At least one value of list reduction did not come out as expected' + + assert isinstance(reduced['c'], tuple), 'Reduction Result of a tuple should be a tuple' + assert all([torch.allclose(x, y) for x, y in zip(reduced['c'], expected_result['c'])]), \ + 'At least one value of tuple reduction did not come out as expected' + + assert isinstance(reduced['d'], ntc), 'Type Consistency for named tuple not given' + assert isinstance(reduced['d'].bar, numbers.Number), \ + 'Failure in type promotion while reducing fields of named tuples' + assert reduced['d'].bar == expected_result['d'].bar + + assert isinstance(reduced['e'], np.ndarray), 'Type Promotion in reduction of numpy arrays failed' + assert reduced['e'] == expected_result['e'], \ + 'Reduction of numpy array did not yield the expected result' + + assert isinstance(reduced['f'], str), 'A string should not be reduced' + assert reduced['f'] == expected_result['f'], 'String not preserved during reduction' + + assert isinstance(reduced['g'], numbers.Number), 'Reduction of a number should result in a tensor' + assert reduced['g'] == expected_result['g'], 'Reduction of a number did not yield the desired result'