diff --git a/src/torchjd/autojac/mtl_backward.py b/src/torchjd/autojac/mtl_backward.py
index 0c81342ad..e04570d2a 100644
--- a/src/torchjd/autojac/mtl_backward.py
+++ b/src/torchjd/autojac/mtl_backward.py
@@ -69,6 +69,9 @@ def mtl_backward(
     A usage example of ``mtl_backward`` is provided in :doc:`Multi-Task Learning (MTL)
     <../../examples/mtl>`.
 
+    .. note::
+        `shared_params` and `tasks_params` must be disjoint.
+
     .. warning::
         ``mtl_backward`` relies on a usage of ``torch.vmap`` that is not compatible with compiled
         functions. The arguments of ``mtl_backward`` should thus not come from a compiled model.
diff --git a/tests/unit/autojac/test_mtl_backward.py b/tests/unit/autojac/test_mtl_backward.py
index af5234e00..d1727ea5a 100644
--- a/tests/unit/autojac/test_mtl_backward.py
+++ b/tests/unit/autojac/test_mtl_backward.py
@@ -566,3 +566,26 @@ def test_mtl_backward_shared_params_overlap_with_tasks_params():
             shared_params=[p0],
             retain_graph=True,
         )
+
+
+def test_mtl_backward_default_shared_params_overlap_with_default_tasks_params():
+    """
+    Tests that mtl_backward raises an error when the set of shared params obtained by default
+    overlaps with the set of task-specific params obtained by default.
+    """
+
+    p0 = torch.tensor([1.0, 2.0], requires_grad=True, device=DEVICE)
+    p1 = torch.tensor(2.0, requires_grad=True, device=DEVICE)
+    p2 = torch.tensor(3.0, requires_grad=True, device=DEVICE)
+
+    r = torch.tensor([-1.0, 1.0], device=DEVICE) @ p0
+    y1 = r * p1
+    y2 = p0.sum() * r * p2
+
+    with raises(ValueError):
+        mtl_backward(
+            losses=[y1, y2],
+            features=[r],
+            A=UPGrad(),
+            retain_graph=True,
+        )
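
For context (not part of the patch): the sketch below illustrates the usage the new note describes, i.e. the same setup as the added test but with `shared_params` and `tasks_params` passed explicitly as disjoint sets, so the overlap check should not trigger. The import locations, the per-task nesting of `tasks_params`, and the exact loss shapes are assumptions, not details confirmed by this diff.

import torch

# Assumed import locations; the diff only shows module paths, not the public API surface.
from torchjd import mtl_backward
from torchjd.aggregation import UPGrad

p0 = torch.tensor([1.0, 2.0], requires_grad=True)  # shared parameter
p1 = torch.tensor(2.0, requires_grad=True)  # parameter of task 1
p2 = torch.tensor(3.0, requires_grad=True)  # parameter of task 2

r = torch.tensor([-1.0, 1.0]) @ p0  # shared representation, depends only on p0
y1 = r * p1  # task 1 loss
y2 = r * p2  # task 2 loss (no direct dependence on p0, so the two sets stay disjoint)

# Explicitly disjoint parameter sets: p0 is shared, p1 and p2 are task-specific.
mtl_backward(
    losses=[y1, y2],
    features=[r],
    tasks_params=[[p1], [p2]],  # assumed format: one list of params per task
    shared_params=[p0],
    A=UPGrad(),
)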