From 2bd54e460296b343f87480be4048e36b01ea5168 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Tue, 21 Feb 2023 17:27:00 +0100 Subject: [PATCH] Add more pytree tests (#16825) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- .../utilities/test_combined_loader.py | 9 ++++---- tests/tests_pytorch/utilities/test_pytree.py | 21 +++++++++++++++++++ 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/tests/tests_pytorch/utilities/test_combined_loader.py b/tests/tests_pytorch/utilities/test_combined_loader.py index bb782ddadc3e4..e912ad323ab86 100644 --- a/tests/tests_pytorch/utilities/test_combined_loader.py +++ b/tests/tests_pytorch/utilities/test_combined_loader.py @@ -228,12 +228,13 @@ def test_sequential_mode_limits_raises(): def test_combined_loader_flattened_setter(): - combined_loader = CombinedLoader([0, [1, [2]]]) + iterables = [[0], [[1], [[2]]]] + combined_loader = CombinedLoader(iterables) with pytest.raises(ValueError, match=r"Mismatch in flattened length \(1\) and existing length \(3\)"): combined_loader.flattened = [2] - combined_loader.flattened = [3, 2, 1] - # TODO(carmocca): this should be [3, [2, [1]]] - assert combined_loader.iterables == [3, [2, 1]] + assert combined_loader.flattened == [[0], [1], [2]] + combined_loader.flattened = [[3], [2], [1]] + assert combined_loader.iterables == [[3], [[2], [[1]]]] @pytest.mark.parametrize("lengths", [[4, 6], [5, 5], [6, 4]]) diff --git a/tests/tests_pytorch/utilities/test_pytree.py b/tests/tests_pytorch/utilities/test_pytree.py index cdd76af1bc8b5..c87a83f85f6ea 100644 --- a/tests/tests_pytorch/utilities/test_pytree.py +++ b/tests/tests_pytorch/utilities/test_pytree.py @@ -43,3 +43,24 @@ def test_flatten_unflatten(): dataset1, dataset2 = ["a", "b"], ["c"] datasets = [[dataset1, dataset2]] assert_tree_flatten_unflatten(datasets, [dataset1, dataset2]) + + +def test_flatten_unflatten_depth_2_or_more(): + datasets = [range(1), [range(2), [range(3)]]] + flat, spec = _tree_flatten(datasets) + assert flat == [range(1), range(2), range(3)] + unflattened = tree_unflatten(flat, spec) + assert unflattened == datasets + + datasets = [[1], [[2], [[3]]]] + flat, spec = _tree_flatten(datasets) + assert flat == [[1], [2], [3]] + unflattened = tree_unflatten(flat, spec) + assert unflattened == datasets + + datasets = [1, [2, [3]]] + flat, spec = _tree_flatten(datasets) + # [3] is a container of all primitives so it is treated as a leaf + assert flat == [1, 2, [3]] + unflattened = tree_unflatten(flat, spec) + assert unflattened == datasets