Add Base class for metrics #1293

Closed · wants to merge 10 commits

Changes from 8 commits
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -4,6 +4,9 @@ All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## Metrics Package (will be added to [unreleased] once it is finished)
- Added metric base class ([#1293](https://github.com/PyTorchLightning/pytorch-lightning/pull/1293))

## [unreleased] - YYYY-MM-DD

### Added
6 changes: 6 additions & 0 deletions pytorch_lightning/metrics/__init__.py
@@ -0,0 +1,6 @@
"""
Metrics
=======

TODO
"""
122 changes: 122 additions & 0 deletions pytorch_lightning/metrics/metric.py
@@ -0,0 +1,122 @@
import numbers
from abc import ABC, abstractmethod
from collections.abc import Mapping, Sequence
from functools import partial
from typing import Union, Any, Optional

import torch
import torch.distributed
from torch.utils.data._utils.collate import np_str_obj_array_pattern

__all__ = ['BaseMetric']


class BaseMetric(torch.nn.Module, ABC):
    def __init__(self, name: str,
                 reduce_group: Optional[Any] = torch.distributed.group.WORLD,
                 reduce_op: Optional[Any] = torch.distributed.ReduceOp.SUM):
Member: isn't mean more frequently used?
"""
Abstract Base Class for metric implementation.

Automatically handles the computation
Args:
name: the metric's name
reduce_group: the process group for DDP reduces (only needed for DDP training).
Defaults to all processes (world)
reduce_op: the operation to perform during reduction within DDP (only needed for DDP training).
Defaults to sum.
"""
super().__init__()
self.name = name
self.reduce_op = reduce_op
self.reduce_group = reduce_group

    @abstractmethod
    def forward(self, *args, **kwargs) -> torch.Tensor:
        """
        Implements the actual metric computation.

        Returns:
            metric value

        """
        raise NotImplementedError
Member: for an abstract method there is no need for this raise

Member Author: I know, but it doesn't matter whether you put `pass` or `raise NotImplementedError` there, and I found this to be more intuitive. We can still change it if you want...

Member: this line won't ever be tested, so it lowers the test coverage count... lol


    def __call__(self, *args, **kwargs) -> torch.Tensor:
        return _sync_collections(super().__call__(*args, **kwargs),
                                 group=self.reduce_group,
                                 reduce_op=self.reduce_op)
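
As an illustration of the intended usage, here is a minimal sketch of a concrete subclass; the `CorrectCount` class and its logic are made up for this example and are not part of the diff. The subclass only implements `forward` to produce a per-process value, and `BaseMetric.__call__` syncs/reduces the returned value across DDP processes (a no-op in a single-process run).

```python
import torch

from pytorch_lightning.metrics.metric import BaseMetric


class CorrectCount(BaseMetric):
    """Illustrative metric: counts correct predictions on the local process."""

    def __init__(self):
        # name is required by BaseMetric; the DDP reduce settings keep their defaults (world / SUM)
        super().__init__(name='correct_count')

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # per-process value; BaseMetric.__call__ all-reduces (sums) it under DDP
        return (pred.argmax(dim=-1) == target).sum()


metric = CorrectCount()
preds = torch.tensor([[0.1, 0.9], [0.8, 0.2]])
targets = torch.tensor([1, 1])
print(metric(preds, targets))  # tensor(1) in a single-process run
```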


def _sync_ddp(result: Union[torch.Tensor, numbers.Number],
              group: Any = torch.distributed.group.WORLD,
              reduce_op: torch.distributed.ReduceOp = torch.distributed.ReduceOp.SUM) -> torch.Tensor:
"""
Function to reduce the tensors from several ddp processes to one master process

Args:
result: the value to sync and reduce (typically tensor or number)
group: the process group to gather results from. Defaults to all processes (world)
reduce_op: the reduction operation. Defaults to sum

Returns:
reduced value

"""

    # convert to tensor if necessary
    if not isinstance(result, torch.Tensor):
        result = torch.tensor(result)

    if torch.distributed.is_available() and torch.distributed.is_initialized():
        # sync all processes before reduction
        torch.distributed.barrier(group=group)
        torch.distributed.all_reduce(result, op=reduce_op, group=group,
                                     async_op=False)

    return result


def _sync_collections(result: Union[torch.Tensor, numbers.Number,
                                    Mapping, Sequence],
                      group: Any = torch.distributed.group.WORLD,
                      reduce_op: torch.distributed.ReduceOp = torch.distributed.ReduceOp.SUM
                      ) -> Union[torch.Tensor, numbers.Number,
                                 Mapping, Sequence]:
"""
Recursively applies sync_ddp to collections

Args:
result: Tensor or Number or Mapping or Sequence holding the values to be reduced
group: the process group to gather results from. Defaults to all processes (world)
reduce_op: the reduction operation. Defaults to sum

Returns:
the reduced collection

"""
    # function adapted from torch.utils.data._utils.collate
    elem_type = type(result)

    func = partial(_sync_collections, group=group, reduce_op=reduce_op)

    # convert numpy to tensor if possible
    if elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        # array not of string classes and object
        if elem_type.__name__ != 'ndarray' \
                or np_str_obj_array_pattern.search(result.dtype.str) is None:
            result = torch.as_tensor(result)

    if isinstance(result, (torch.Tensor, numbers.Number)):
        return _sync_ddp(result, group=group, reduce_op=reduce_op)

    elif isinstance(result, Mapping):
        return elem_type({key: func(result[key]) for key in result})
    elif isinstance(result, tuple) and hasattr(result, '_fields'):  # namedtuple
        return elem_type(*(func(r) for r in result))
    elif isinstance(result, Sequence) and not isinstance(result, str):
        return elem_type([func(r) for r in result])
    # not possible to reduce this type
    else:
        return result
Member: this is not safe since all result assignment is under some conditions...
rather than returning in each branch, do the result assignment there and always return result at the end

Member Author: this should be safe, since we just return the input variable as-is if there is nothing to sync...
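
For illustration of the recursive behaviour discussed above, a small sketch run outside of DDP (so nothing is actually reduced): container types are preserved, while numbers and numpy arrays are promoted to tensors. The nested dict below is made up for this example.

```python
import numpy as np
import torch

from pytorch_lightning.metrics.metric import _sync_collections

nested = {
    'loss': 0.5,                                      # number -> tensor
    'scores': [torch.tensor([1.]), np.array([2.])],   # list is preserved, numpy array -> tensor
    'tag': 'validation',                              # strings are returned unchanged
}

synced = _sync_collections(nested)
# without an initialized process group there is no all-reduce, so the values stay as-is:
# {'loss': tensor(0.5000), 'scores': [tensor([1.]), tensor([2.])], 'tag': 'validation'}
```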

Empty file added tests/metrics/__init__.py
Empty file.
162 changes: 162 additions & 0 deletions tests/metrics/base.py
@@ -0,0 +1,162 @@
from collections import namedtuple

import numpy as np
import pytest
import torch
import torch.distributed as dist

import tests.base.utils as tutils
from pytorch_lightning.metrics.metric import _sync_ddp, _sync_collections, BaseMetric


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason='Not enough GPUs to test sync reduce')
def test_sync_reduce_ddp():
"""Make sure sync-reduce works with DDP"""
tutils.reset_seed()
tutils.set_random_master_port()

dist.init_process_group('gloo')

tensor = torch.tensor([1.], device='cuda:0')

reduced_tensor = _sync_ddp(tensor)
assert reduced_tensor.item() == dist.get_world_size(), \
'Sync-Reduce does not work properly with DDP and Tensors'

number = 1.
reduced_number = _sync_ddp(number)
assert isinstance(reduced_number, torch.Tensor), 'When reducing a number we should get a tensor out'
assert reduced_number.item() == dist.get_world_size(), \
'Sync-Reduce does not work properly with DDP and Numbers'

dist.destroy_process_group()


def test_sync_reduce_no_ddp():
"""Make sure sync-reduce works without DDP"""
tensor = torch.tensor([1.], device='cpu')

reduced_tensor = _sync_ddp(tensor)

assert torch.allclose(tensor,
reduced_tensor), 'Sync-Reduce does not work properly without DDP and Tensors'

number = 1.
reduced_number = _sync_ddp(number)
assert isinstance(reduced_number, torch.Tensor), 'When reducing a number we should get a tensor out'
assert reduced_number.item() == number, 'Sync-Reduce does not work properly without DDP and Numbers'


def _sync_collections_test(is_ddp: bool):
    ntc = namedtuple('Foo', ['bar'])
    to_reduce = {
        'a': torch.tensor([1.]),  # Tensor
        'b': [torch.tensor([2.])],  # list
        'c': (torch.tensor([100.]),),  # tuple
        'd': ntc(bar=5.),  # named tuple
        'e': np.array([10.]),  # numpy array
        'f': 'this_is_a_dummy_str',  # string
        'g': 12.  # number
    }

    if is_ddp:
        factor = dist.get_world_size()
    else:
        factor = 1.

    expected_result = {
        'a': torch.tensor([1. * factor]),
        'b': [torch.tensor([2. * factor])],
        'c': (torch.tensor([100. * factor]),),
        'd': ntc(bar=torch.tensor([5. * factor])),
        'e': torch.tensor([10. * factor]),
        'f': 'this_is_a_dummy_str',
        'g': torch.tensor([12. * factor])
    }

    reduced = _sync_collections(to_reduce)

    assert isinstance(reduced, dict), 'Type Consistency of dict not preserved'
    assert all([x in reduced for x in to_reduce.keys()]), 'Not all entries of the dict were preserved'
    assert all([isinstance(reduced[k], type(expected_result[k])) for k in to_reduce.keys()]), \
        'At least one type was not correctly preserved'

    assert isinstance(reduced['a'], torch.Tensor), 'Reduction Result of a Tensor should be a Tensor'
    assert torch.allclose(expected_result['a'],
                          reduced['a']), 'Reduction of a tensor does not yield the expected value'

    assert isinstance(reduced['b'], list), 'Reduction Result of a list should be a list'
    assert all([torch.allclose(x, y) for x, y in zip(reduced['b'], expected_result['b'])]), \
        'At least one value of list reduction did not come out as expected'

    assert isinstance(reduced['c'], tuple), 'Reduction Result of a tuple should be a tuple'
    assert all([torch.allclose(x, y) for x, y in zip(reduced['c'], expected_result['c'])]), \
        'At least one value of tuple reduction did not come out as expected'

    assert isinstance(reduced['d'], ntc), 'Type Consistency for named tuple not given'
    assert isinstance(reduced['d'].bar,
                      torch.Tensor), 'Failure in type promotion while reducing fields of named tuples'
    assert torch.allclose(reduced['d'].bar, expected_result['d'].bar)

    assert isinstance(reduced['e'], torch.Tensor), 'Type Promotion in reduction of numpy arrays failed'
    assert torch.allclose(reduced['e'], expected_result['e']), \
        'Reduction of numpy array did not yield the expected result'

    assert isinstance(reduced['f'], str), 'A string should not be reduced'
    assert reduced['f'] == expected_result['f'], 'String not preserved during reduction'

    assert isinstance(reduced['g'], torch.Tensor), 'Reduction of a number should result in a tensor'
    assert torch.allclose(reduced['g'],
                          expected_result['g']), 'Reduction of a number did not yield the desired result'


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason='Not enough GPUs to test sync reduce')
def test_sync_collections_ddp():
    tutils.reset_seed()
    tutils.set_random_master_port()

    dist.init_process_group('gloo')

    _sync_collections_test(True)

    dist.destroy_process_group()


def test_sync_collections_no_ddp():
    _sync_collections_test(False)


def _test_base_metric(is_ddp):
    class DummyMetric(BaseMetric):
        def __init__(self):
            super().__init__(name='Dummy')

        def forward(self):
            return 1.

    dummy_metric = DummyMetric()

    assert dummy_metric.name == 'Dummy'
    metric_val = dummy_metric()

    if is_ddp:
        expected = dist.get_world_size()
    else:
        expected = 1.

    assert isinstance(metric_val, torch.Tensor), \
        'The result value should be synced and reduced which would promote the type from number to tensor'
    assert metric_val.item() == expected, 'Invalid Value for reduction'


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason='Not enough GPUs to test with ddp')
def test_base_metric_ddp():
    _test_base_metric(True)


def test_base_metric_no_ddp():
    _test_base_metric(False)