Skip to content

Commit

Permalink
Add more pytree tests (#16825)
Browse files Browse the repository at this point in the history
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
  • Loading branch information
carmocca and awaelchli committed Feb 21, 2023
1 parent 0009cde commit 2bd54e4
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 4 deletions.
9 changes: 5 additions & 4 deletions tests/tests_pytorch/utilities/test_combined_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,12 +228,13 @@ def test_sequential_mode_limits_raises():


def test_combined_loader_flattened_setter():
combined_loader = CombinedLoader([0, [1, [2]]])
iterables = [[0], [[1], [[2]]]]
combined_loader = CombinedLoader(iterables)
with pytest.raises(ValueError, match=r"Mismatch in flattened length \(1\) and existing length \(3\)"):
combined_loader.flattened = [2]
combined_loader.flattened = [3, 2, 1]
# TODO(carmocca): this should be [3, [2, [1]]]
assert combined_loader.iterables == [3, [2, 1]]
assert combined_loader.flattened == [[0], [1], [2]]
combined_loader.flattened = [[3], [2], [1]]
assert combined_loader.iterables == [[3], [[2], [[1]]]]


@pytest.mark.parametrize("lengths", [[4, 6], [5, 5], [6, 4]])
Expand Down
21 changes: 21 additions & 0 deletions tests/tests_pytorch/utilities/test_pytree.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,24 @@ def test_flatten_unflatten():
dataset1, dataset2 = ["a", "b"], ["c"]
datasets = [[dataset1, dataset2]]
assert_tree_flatten_unflatten(datasets, [dataset1, dataset2])


def test_flatten_unflatten_depth_2_or_more():
datasets = [range(1), [range(2), [range(3)]]]
flat, spec = _tree_flatten(datasets)
assert flat == [range(1), range(2), range(3)]
unflattened = tree_unflatten(flat, spec)
assert unflattened == datasets

datasets = [[1], [[2], [[3]]]]
flat, spec = _tree_flatten(datasets)
assert flat == [[1], [2], [3]]
unflattened = tree_unflatten(flat, spec)
assert unflattened == datasets

datasets = [1, [2, [3]]]
flat, spec = _tree_flatten(datasets)
# [3] is a container of all primitives so it is treated as a leaf
assert flat == [1, 2, [3]]
unflattened = tree_unflatten(flat, spec)
assert unflattened == datasets

0 comments on commit 2bd54e4

Please sign in to comment.