diff --git a/tests/tests_pytorch/utilities/test_combined_loader.py b/tests/tests_pytorch/utilities/test_combined_loader.py index bb782ddadc3e4..e912ad323ab86 100644 --- a/tests/tests_pytorch/utilities/test_combined_loader.py +++ b/tests/tests_pytorch/utilities/test_combined_loader.py @@ -228,12 +228,13 @@ def test_sequential_mode_limits_raises(): def test_combined_loader_flattened_setter(): - combined_loader = CombinedLoader([0, [1, [2]]]) + iterables = [[0], [[1], [[2]]]] + combined_loader = CombinedLoader(iterables) with pytest.raises(ValueError, match=r"Mismatch in flattened length \(1\) and existing length \(3\)"): combined_loader.flattened = [2] - combined_loader.flattened = [3, 2, 1] - # TODO(carmocca): this should be [3, [2, [1]]] - assert combined_loader.iterables == [3, [2, 1]] + assert combined_loader.flattened == [[0], [1], [2]] + combined_loader.flattened = [[3], [2], [1]] + assert combined_loader.iterables == [[3], [[2], [[1]]]] @pytest.mark.parametrize("lengths", [[4, 6], [5, 5], [6, 4]]) diff --git a/tests/tests_pytorch/utilities/test_pytree.py b/tests/tests_pytorch/utilities/test_pytree.py index cdd76af1bc8b5..c87a83f85f6ea 100644 --- a/tests/tests_pytorch/utilities/test_pytree.py +++ b/tests/tests_pytorch/utilities/test_pytree.py @@ -43,3 +43,24 @@ def test_flatten_unflatten(): dataset1, dataset2 = ["a", "b"], ["c"] datasets = [[dataset1, dataset2]] assert_tree_flatten_unflatten(datasets, [dataset1, dataset2]) + + +def test_flatten_unflatten_depth_2_or_more(): + datasets = [range(1), [range(2), [range(3)]]] + flat, spec = _tree_flatten(datasets) + assert flat == [range(1), range(2), range(3)] + unflattened = tree_unflatten(flat, spec) + assert unflattened == datasets + + datasets = [[1], [[2], [[3]]]] + flat, spec = _tree_flatten(datasets) + assert flat == [[1], [2], [3]] + unflattened = tree_unflatten(flat, spec) + assert unflattened == datasets + + datasets = [1, [2, [3]]] + flat, spec = _tree_flatten(datasets) + # [3] is a container of all primitives so it is treated as a leaf + assert flat == [1, 2, [3]] + unflattened = tree_unflatten(flat, spec) + assert unflattened == datasets