Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions grassmann_tensor/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,8 +297,8 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens
return GrassmannTensor(_arrow=(), _edges=(), _tensor=tensor)

if new_shape == (1,) and int(self.tensor.numel()) == 1:
eo = self._calculate_even_odd()
new_shape = (eo,)
even_self, odd_self = self._calculate_even_odd()
new_shape = ((even_self, odd_self),)

cursor_plan: int = 0
cursor_self: int = 0
Expand All @@ -318,6 +318,19 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens
f"edges={self.edges}, new_shape={new_shape}"
)

if cursor_plan != len(new_shape):
new_shape_check = new_shape[cursor_plan]
if (
isinstance(new_shape_check, int)
and new_shape_check == 1
and self.tensor.shape[cursor_self] != 1
):
arrow.append(False)
edges.append((1, 0))
shape.append(1)
cursor_plan += 1
continue

if cursor_plan != len(new_shape) and new_shape[cursor_plan] == -1:
# Does not change
arrow.append(self.arrow[cursor_self])
Expand Down
56 changes: 56 additions & 0 deletions tests/reshape_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,3 +231,59 @@ def test_reshape_trailing_nontrivial_dim_raises() -> None:
a = GrassmannTensor((True,), ((2, 2),), torch.randn([4]))
with pytest.raises(AssertionError, match="New shape exceeds after exhausting self dimensions"):
_ = a.reshape((-1, (2, 2)))


@pytest.mark.parametrize(
"tensor",
[
GrassmannTensor(
(True, True, True, True),
((1, 0), (1, 0), (2, 2), (8, 8)),
torch.randn(1, 1, 4, 16),
),
],
)
@pytest.mark.parametrize(
"shape",
[
(1, 64),
((1, 0), 64),
(-1, 64),
],
)
def test_reshape_trivial_head_equivalence(
tensor: GrassmannTensor,
shape: tuple[int, ...],
) -> None:
baseline_tensor = tensor.reshape((1, 64))
actual_tensor = tensor.reshape(shape)

assert actual_tensor.edges == ((1, 0), (32, 32))
assert torch.allclose(actual_tensor.tensor, baseline_tensor.tensor)

roundtrip_tensor = actual_tensor.reshape(tensor.edges)
assert torch.allclose(roundtrip_tensor.tensor, tensor.tensor)


def test_reshape_head_1_inserts_trivial_when_self_dim_not_one() -> None:
a = GrassmannTensor(
(True, True),
((2, 2), (8, 8)),
torch.randn(4, 16),
)
out = a.reshape((1, 64))
assert out.edges == ((1, 0), (32, 32))
assert out.tensor.shape == (1, 64)
assert out.arrow[0] is False


def test_reshape_plan_exhausted_then_skip_trivial_self_edges() -> None:
a = GrassmannTensor(
(False, False, False),
((2, 2), (1, 0), (1, 0)),
torch.randn(4, 1, 1),
)
out = a.reshape((4,))
assert out.edges == ((2, 2),)
assert out.tensor.shape == (4,)
assert out.arrow == (False,)