@@ -231,3 +231,59 @@ def test_reshape_trailing_nontrivial_dim_raises() -> None:
231231 a = GrassmannTensor ((True ,), ((2 , 2 ),), torch .randn ([4 ]))
232232 with pytest .raises (AssertionError , match = "New shape exceeds after exhausting self dimensions" ):
233233 _ = a .reshape ((- 1 , (2 , 2 )))
234+
235+
236+ @pytest .mark .parametrize (
237+ "tensor" ,
238+ [
239+ GrassmannTensor (
240+ (True , True , True , True ),
241+ ((1 , 0 ), (1 , 0 ), (2 , 2 ), (8 , 8 )),
242+ torch .randn (1 , 1 , 4 , 16 ),
243+ ),
244+ ],
245+ )
246+ @pytest .mark .parametrize (
247+ "shape" ,
248+ [
249+ (1 , 64 ),
250+ ((1 , 0 ), 64 ),
251+ (- 1 , 64 ),
252+ ],
253+ )
254+ def test_reshape_trivial_head_equivalence (
255+ tensor : GrassmannTensor ,
256+ shape : tuple [int , ...],
257+ ) -> None :
258+ baseline_tensor = tensor .reshape ((1 , 64 ))
259+ actual_tensor = tensor .reshape (shape )
260+
261+ assert actual_tensor .edges == ((1 , 0 ), (32 , 32 ))
262+ assert torch .allclose (actual_tensor .tensor , baseline_tensor .tensor )
263+
264+ roundtrip_tensor = actual_tensor .reshape (tensor .edges )
265+ assert torch .allclose (roundtrip_tensor .tensor , tensor .tensor )
266+
267+
268+ def test_reshape_head_1_inserts_trivial_when_self_dim_not_one () -> None :
269+ a = GrassmannTensor (
270+ (True , True ),
271+ ((2 , 2 ), (8 , 8 )),
272+ torch .randn (4 , 16 ),
273+ )
274+ out = a .reshape ((1 , 64 ))
275+ assert out .edges == ((1 , 0 ), (32 , 32 ))
276+ assert out .tensor .shape == (1 , 64 )
277+ assert out .arrow [0 ] is False
278+
279+
280+ def test_reshape_plan_exhausted_then_skip_trivial_self_edges () -> None :
281+ a = GrassmannTensor (
282+ (False , False , False ),
283+ ((2 , 2 ), (1 , 0 ), (1 , 0 )),
284+ torch .randn (4 , 1 , 1 ),
285+ )
286+ out = a .reshape ((4 ,))
287+ assert out .edges == ((2 , 2 ),)
288+ assert out .tensor .shape == (4 ,)
289+ assert out .arrow == (False ,)
0 commit comments