Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ changes that do not affect the user.
project onto the dual cone. This may minimally affect the output of these aggregators.

### Fixed
- Fixed the behavior of `backward` and `mtl_backward` when some tensors are repeated (i.e. when they
appear several times in a list of tensors provided as argument). Instead of raising an exception
in these cases, we are now aligned with the behavior of `torch.autograd.backward`. Repeated
tensors that we differentiate lead to repeated rows in the Jacobian, prior to aggregation, and
repeated tensors with respect to which we differentiate count only once.
- Removed arbitrary exception handling in `IMTLG` and `AlignedMTL` when the computation fails. In
practice, this fix should only affect some matrices with extremely large values, which should
not usually happen.
Expand Down
2 changes: 1 addition & 1 deletion src/torchjd/autojac/_transform/_differentiate.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def __init__(
retain_graph: bool,
create_graph: bool,
):
self.outputs = ordered_set(outputs)
self.outputs = list(outputs)
self.inputs = ordered_set(inputs)
self.retain_graph = retain_graph
self.create_graph = create_graph
Expand Down
9 changes: 1 addition & 8 deletions src/torchjd/autojac/_transform/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,7 @@


def ordered_set(elements: Iterable[_KeyType]) -> _OrderedSet[_KeyType]:
elements = list(elements)
result = OrderedDict.fromkeys(elements, None)
if len(elements) != len(result):
raise ValueError(
f"Parameter `elements` should contain unique elements. Found `elements = {elements}`."
)

return result
return OrderedDict.fromkeys(elements, None)


def dicts_union(dicts: Iterable[dict[_KeyType, _ValueType]]) -> dict[_KeyType, _ValueType]:
Expand Down
47 changes: 46 additions & 1 deletion tests/unit/autojac/test_backward.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import torch
from pytest import mark, raises
from torch.autograd import grad
from torch.testing import assert_close

from torchjd import backward
from torchjd.aggregation import MGDA, Aggregator, Mean, Random, UPGrad
from torchjd.aggregation import MGDA, Aggregator, Mean, Random, Sum, UPGrad


@mark.parametrize("aggregator", [Mean(), UPGrad(), MGDA(), Random()])
Expand Down Expand Up @@ -214,3 +215,47 @@ def test_tensor_used_multiple_times(chunk_size: int | None):
)

assert_close(a.grad, aggregator(expected_jacobian).squeeze())


def test_repeated_tensors():
    """
    Tests that backward handles duplicated entries in the list of tensors to differentiate.
    Since torch.autograd.backward sums the gradients of repeated tensors, autojac should
    equivalently give each repetition its own row in the Jacobian before aggregation.
    """

    param1 = torch.tensor([1.0, 2.0], requires_grad=True)
    param2 = torch.tensor([3.0, 4.0], requires_grad=True)

    loss1 = torch.tensor([-1.0, 1.0]) @ param1 + param2.sum()
    loss2 = (param1**2).sum() + (param2**2).sum()

    # loss1 appears twice, so its gradient must be counted twice, matching
    # the reference behavior of torch.autograd.grad on the same list.
    repeated_losses = [loss1, loss1, loss2]
    expected_grads = [grad(repeated_losses, p, retain_graph=True)[0] for p in (param1, param2)]

    backward(repeated_losses, Sum())

    assert_close(param1.grad, expected_grads[0])
    assert_close(param2.grad, expected_grads[1])


def test_repeated_inputs():
    """
    Tests that backward handles duplicated entries in `inputs`. Since torch.autograd.backward
    ignores repetition of the tensors with respect to which it differentiates, autojac should
    ignore it as well: each input's gradient is accumulated exactly once.
    """

    param1 = torch.tensor([1.0, 2.0], requires_grad=True)
    param2 = torch.tensor([3.0, 4.0], requires_grad=True)

    loss1 = torch.tensor([-1.0, 1.0]) @ param1 + param2.sum()
    loss2 = (param1**2).sum() + (param2**2).sum()

    losses = [loss1, loss2]
    expected_grads = [grad(losses, p, retain_graph=True)[0] for p in (param1, param2)]

    # param1 is passed twice as an input; this should not change the result.
    backward(losses, Sum(), inputs=[param1, param1, param2])

    assert_close(param1.grad, expected_grads[0])
    assert_close(param2.grad, expected_grads[1])
118 changes: 117 additions & 1 deletion tests/unit/autojac/test_mtl_backward.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import torch
from pytest import mark, raises
from torch.autograd import grad
from torch.testing import assert_close

from torchjd import mtl_backward
from torchjd.aggregation import MGDA, Aggregator, Mean, Random, UPGrad
from torchjd.aggregation import MGDA, Aggregator, Mean, Random, Sum, UPGrad


@mark.parametrize("aggregator", [Mean(), UPGrad(), MGDA(), Random()])
Expand Down Expand Up @@ -557,3 +558,118 @@ def test_default_shared_params_overlapping_with_default_tasks_params_fails():
features=[f],
aggregator=UPGrad(),
)


def test_repeated_losses():
    """
    Tests that mtl_backward handles duplicated losses. Since torch.autograd.backward sums the
    gradients of repeated losses, autojac should sum the task-specific gradients accordingly
    and, for the shared parameters, build and aggregate a Jacobian with one row per repetition.
    """

    shared_param = torch.tensor([1.0, 2.0], requires_grad=True)
    task1_param = torch.tensor([1.0, 2.0], requires_grad=True)
    task2_param = torch.tensor([3.0, 4.0], requires_grad=True)

    feature_a = torch.tensor([-1.0, 1.0]) @ shared_param
    feature_b = (shared_param**2).sum() + shared_param.norm()
    task1_loss = feature_a * task1_param[0] + feature_b * task1_param[1]
    task2_loss = feature_a * task2_param[0] + feature_b * task2_param[1]

    # task1_loss is repeated, so it counts twice in every expected gradient.
    expected_shared_grad = grad([task1_loss, task1_loss, task2_loss], [shared_param], retain_graph=True)[0]
    expected_task1_grad = grad([task1_loss, task1_loss], [task1_param], retain_graph=True)[0]
    expected_task2_grad = grad([task2_loss], [task2_param], retain_graph=True)[0]

    losses = [task1_loss, task1_loss, task2_loss]
    mtl_backward(losses=losses, features=[feature_a, feature_b], aggregator=Sum(), retain_graph=True)

    assert_close(shared_param.grad, expected_shared_grad)
    assert_close(task1_param.grad, expected_task1_grad)
    assert_close(task2_param.grad, expected_task2_grad)


def test_repeated_features():
    """
    Tests that mtl_backward handles duplicated features. Features play a double role: we
    differentiate with respect to them (where repetition should not matter) and we also
    differentiate them with respect to the shared parameters (where each repetition should
    contribute an extra row to the Jacobian).
    """

    shared_param = torch.tensor([1.0, 2.0], requires_grad=True)
    task1_param = torch.tensor([1.0, 2.0], requires_grad=True)
    task2_param = torch.tensor([3.0, 4.0], requires_grad=True)

    feature_a = torch.tensor([-1.0, 1.0]) @ shared_param
    feature_b = (shared_param**2).sum() + shared_param.norm()
    task1_loss = feature_a * task1_param[0] + feature_b * task1_param[1]
    task2_loss = feature_a * task2_param[0] + feature_b * task2_param[1]

    # Reference: chain the gradients through the (repeated) features by hand.
    features = [feature_a, feature_a, feature_b]
    grad_outputs = grad([task1_loss, task2_loss], features, retain_graph=True)
    expected_shared_grad = grad(features, [shared_param], grad_outputs, retain_graph=True)[0]
    expected_task1_grad = grad([task1_loss], [task1_param], retain_graph=True)[0]
    expected_task2_grad = grad([task2_loss], [task2_param], retain_graph=True)[0]

    mtl_backward(losses=[task1_loss, task2_loss], features=features, aggregator=Sum())

    assert_close(shared_param.grad, expected_shared_grad)
    assert_close(task1_param.grad, expected_task1_grad)
    assert_close(task2_param.grad, expected_task2_grad)


def test_repeated_shared_params():
    """
    Tests that mtl_backward handles duplicated shared params. These are tensors with respect
    to which we differentiate, so to match the behavior of torch.autograd.backward, the
    repetition should not affect the result.
    """

    shared_param = torch.tensor([1.0, 2.0], requires_grad=True)
    task1_param = torch.tensor([1.0, 2.0], requires_grad=True)
    task2_param = torch.tensor([3.0, 4.0], requires_grad=True)

    feature_a = torch.tensor([-1.0, 1.0]) @ shared_param
    feature_b = (shared_param**2).sum() + shared_param.norm()
    task1_loss = feature_a * task1_param[0] + feature_b * task1_param[1]
    task2_loss = feature_a * task2_param[0] + feature_b * task2_param[1]

    expected_shared_grad = grad([task1_loss, task2_loss], [shared_param], retain_graph=True)[0]
    expected_task1_grad = grad([task1_loss], [task1_param], retain_graph=True)[0]
    expected_task2_grad = grad([task2_loss], [task2_param], retain_graph=True)[0]

    # The shared param is passed twice; this should not change the result.
    mtl_backward(
        losses=[task1_loss, task2_loss],
        features=[feature_a, feature_b],
        aggregator=Sum(),
        shared_params=[shared_param, shared_param],
    )

    assert_close(shared_param.grad, expected_shared_grad)
    assert_close(task1_param.grad, expected_task1_grad)
    assert_close(task2_param.grad, expected_task2_grad)


def test_repeated_task_params():
    """
    Tests that mtl_backward handles duplicated task-specific params within a task. These are
    tensors with respect to which we differentiate, so to match the behavior of
    torch.autograd.backward, the repetition should not affect the result.
    """

    shared_param = torch.tensor([1.0, 2.0], requires_grad=True)
    task1_param = torch.tensor([1.0, 2.0], requires_grad=True)
    task2_param = torch.tensor([3.0, 4.0], requires_grad=True)

    feature_a = torch.tensor([-1.0, 1.0]) @ shared_param
    feature_b = (shared_param**2).sum() + shared_param.norm()
    task1_loss = feature_a * task1_param[0] + feature_b * task1_param[1]
    task2_loss = feature_a * task2_param[0] + feature_b * task2_param[1]

    expected_shared_grad = grad([task1_loss, task2_loss], [shared_param], retain_graph=True)[0]
    expected_task1_grad = grad([task1_loss], [task1_param], retain_graph=True)[0]
    expected_task2_grad = grad([task2_loss], [task2_param], retain_graph=True)[0]

    # task1's param list repeats the same tensor; this should not change the result.
    mtl_backward(
        losses=[task1_loss, task2_loss],
        features=[feature_a, feature_b],
        aggregator=Sum(),
        tasks_params=[[task1_param, task1_param], [task2_param]],
    )

    assert_close(shared_param.grad, expected_shared_grad)
    assert_close(task1_param.grad, expected_task1_grad)
    assert_close(task2_param.grad, expected_task2_grad)