Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix list syncronization with partly empty lists #2468

Merged
merged 25 commits into from
Apr 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
61c5195
implementation + tests
SkafteNicki Mar 24, 2024
fbdbd3b
changelog
SkafteNicki Mar 24, 2024
b94824b
Merge branch 'master' into bugfix/sync_empty_lists
mergify[bot] Mar 26, 2024
8d2202b
Merge branch 'master' into bugfix/sync_empty_lists
mergify[bot] Mar 27, 2024
5d917f3
Merge branch 'master' into bugfix/sync_empty_lists
mergify[bot] Mar 27, 2024
fada68f
Merge branch 'master' into bugfix/sync_empty_lists
mergify[bot] Mar 27, 2024
a4a2250
Merge branch 'master' into bugfix/sync_empty_lists
mergify[bot] Mar 28, 2024
b112ab9
Merge branch 'master' into bugfix/sync_empty_lists
mergify[bot] Mar 28, 2024
62ce1ac
Merge branch 'master' into bugfix/sync_empty_lists
mergify[bot] Mar 28, 2024
c535d17
Merge branch 'master' into bugfix/sync_empty_lists
mergify[bot] Mar 28, 2024
5533eed
Merge branch 'master' into bugfix/sync_empty_lists
mergify[bot] Mar 28, 2024
d05de12
Merge branch 'master' into bugfix/sync_empty_lists
mergify[bot] Apr 2, 2024
e5f3465
Merge branch 'master' into bugfix/sync_empty_lists
mergify[bot] Apr 2, 2024
354e94b
Merge branch 'master' into bugfix/sync_empty_lists
mergify[bot] Apr 10, 2024
08ff965
Merge branch 'master' into bugfix/sync_empty_lists
mergify[bot] Apr 10, 2024
5571103
Merge branch 'master' into bugfix/sync_empty_lists
Borda Apr 10, 2024
e4848b5
Merge branch 'master' into bugfix/sync_empty_lists
SkafteNicki Apr 14, 2024
1bfef7e
fix tests
SkafteNicki Apr 14, 2024
3ba2667
Merge branch 'master' into bugfix/sync_empty_lists
mergify[bot] Apr 14, 2024
4f0cbc7
Merge branch 'master' into bugfix/sync_empty_lists
SkafteNicki Apr 15, 2024
ec09917
Merge branch 'master' into bugfix/sync_empty_lists
Borda Apr 15, 2024
c236951
Merge branch 'master' into bugfix/sync_empty_lists
Borda Apr 17, 2024
92c1507
only for newer versions
SkafteNicki Apr 19, 2024
5f87656
Merge branch 'master' into bugfix/sync_empty_lists
SkafteNicki Apr 19, 2024
e4dd328
Merge branch 'master' into bugfix/sync_empty_lists
mergify[bot] Apr 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed axis names with Precision-Recall curve ([#2462](https://github.com/Lightning-AI/torchmetrics/pull/2462))


- Fixed list synchronization with partly empty lists ([#2468](https://github.com/Lightning-AI/torchmetrics/pull/2468))


- Fixed memory leak in metrics using list states ([#2492](https://github.com/Lightning-AI/torchmetrics/pull/2492))


Expand Down
10 changes: 10 additions & 0 deletions src/torchmetrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
)
from torchmetrics.utilities.distributed import gather_all_tensors
from torchmetrics.utilities.exceptions import TorchMetricsUserError
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE, plot_single_or_multi_val
from torchmetrics.utilities.prints import rank_zero_warn

Expand Down Expand Up @@ -438,6 +439,15 @@ def _sync_dist(self, dist_sync_fn: Callable = gather_all_tensors, process_group:
if reduction_fn == dim_zero_cat and isinstance(input_dict[attr], list) and len(input_dict[attr]) > 1:
input_dict[attr] = [dim_zero_cat(input_dict[attr])]

# cornor case in distributed settings where a rank have not received any data, create empty to concatenate
if (
_TORCH_GREATER_EQUAL_2_1
and reduction_fn == dim_zero_cat
and isinstance(input_dict[attr], list)
and len(input_dict[attr]) == 0
):
input_dict[attr] = [torch.tensor([], device=self.device, dtype=self.dtype)]

output_dict = apply_to_collection(
input_dict,
Tensor,
Expand Down
20 changes: 19 additions & 1 deletion tests/unittests/bases/test_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from torchmetrics import Metric
from torchmetrics.utilities.distributed import gather_all_tensors
from torchmetrics.utilities.exceptions import TorchMetricsUserError
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1

from unittests import NUM_PROCESSES
from unittests._helpers import seed_all
Expand Down Expand Up @@ -269,11 +270,28 @@ def test_sync_on_compute(sync_on_compute, test_func):
def _test_sync_with_empty_lists(rank):
dummy = DummyListMetric()
val = dummy.compute()
assert val == []
assert torch.allclose(val, tensor([]))


@pytest.mark.DDP()
@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_2_1, reason="test only works on newer torch versions")
@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
def test_sync_with_empty_lists():
"""Test that synchronization of states can be enabled and disabled for compute."""
pytest.pool.map(_test_sync_with_empty_lists, range(NUM_PROCESSES))


def _test_sync_with_unequal_size_lists(rank):
"""Test that synchronization of list states work even when some ranks have not received any data yet."""
dummy = DummyListMetric()
if rank == 0:
dummy.update(torch.zeros(2))
assert torch.all(dummy.compute() == tensor([0.0, 0.0]))


@pytest.mark.DDP()
@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_2_1, reason="test only works on newer torch versions")
@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
def test_sync_with_unequal_size_lists():
"""Test that synchronization of states can be enabled and disabled for compute."""
pytest.pool.map(_test_sync_with_unequal_size_lists, range(NUM_PROCESSES))
Loading