From 4163499ffeab9956d5ddeca64955daedb6a3a4e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 26 Mar 2025 01:49:05 +0100 Subject: [PATCH] Add test_stack_different_required_keys --- .../unit/autojac/_transform/test_interactions.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/unit/autojac/_transform/test_interactions.py b/tests/unit/autojac/_transform/test_interactions.py index cab0cab02..27db5b01f 100644 --- a/tests/unit/autojac/_transform/test_interactions.py +++ b/tests/unit/autojac/_transform/test_interactions.py @@ -1,4 +1,5 @@ import torch +from pytest import raises from torch.testing import assert_close from torchjd.autojac._transform import ( @@ -248,3 +249,17 @@ def test_equivalence_jac_grads(): assert_close(jac_A, torch.stack([grad_1_A, grad_2_A])) assert_close(jac_b, torch.stack([grad_1_b, grad_2_b])) assert_close(jac_c, torch.stack([grad_1_c, grad_2_c])) + + +def test_stack_different_required_keys(): + """Tests that the Stack transform fails on transforms with different required keys.""" + + a = torch.tensor(1.0, requires_grad=True) + y1 = a * 2.0 + y2 = a * 3.0 + + grad1 = Grad([y1], [a]) + grad2 = Grad([y2], [a]) + + with raises(ValueError): + _ = Stack([grad1, grad2])