diff --git a/CHANGELOG.md b/CHANGELOG.md index fcc34c937..7d913493a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,10 @@ changelog does not include internal changes that do not affect the user. ### Added +- Added a new `torchjd.scalarization` package providing the abstract `Scalarizer` base class and + the concrete implementations `Constant`, `Mean`, `Random`, and `Sum`. These baselines simply + combine losses into a scalar that can be optimized with a standard backward pass, making them + useful for comparison with JD-based methods. - Added `FairGrad` and `FairGradWeighting` from [Fair Resource Allocation in Multi-Task Learning](https://arxiv.org/pdf/2402.15638). diff --git a/docs/source/docs/scalarization/constant.rst b/docs/source/docs/scalarization/constant.rst new file mode 100644 index 000000000..bcbf0217c --- /dev/null +++ b/docs/source/docs/scalarization/constant.rst @@ -0,0 +1,7 @@ +:hide-toc: + +Constant +======== + +.. autoclass:: torchjd.scalarization.Constant + :members: __call__ diff --git a/docs/source/docs/scalarization/index.rst b/docs/source/docs/scalarization/index.rst new file mode 100644 index 000000000..11381bfa4 --- /dev/null +++ b/docs/source/docs/scalarization/index.rst @@ -0,0 +1,21 @@ +scalarization +============= + +.. automodule:: torchjd.scalarization + :no-members: + +Abstract base class +------------------- + +.. autoclass:: torchjd.scalarization.Scalarizer + :members: __call__ + + +.. toctree:: + :hidden: + :maxdepth: 1 + + constant.rst + mean.rst + random.rst + sum.rst diff --git a/docs/source/docs/scalarization/mean.rst b/docs/source/docs/scalarization/mean.rst new file mode 100644 index 000000000..5a435b985 --- /dev/null +++ b/docs/source/docs/scalarization/mean.rst @@ -0,0 +1,7 @@ +:hide-toc: + +Mean +==== + +.. 
autoclass:: torchjd.scalarization.Mean + :members: __call__ diff --git a/docs/source/docs/scalarization/random.rst b/docs/source/docs/scalarization/random.rst new file mode 100644 index 000000000..0fffdc0e2 --- /dev/null +++ b/docs/source/docs/scalarization/random.rst @@ -0,0 +1,7 @@ +:hide-toc: + +Random +====== + +.. autoclass:: torchjd.scalarization.Random + :members: __call__ diff --git a/docs/source/docs/scalarization/sum.rst b/docs/source/docs/scalarization/sum.rst new file mode 100644 index 000000000..8f89702cd --- /dev/null +++ b/docs/source/docs/scalarization/sum.rst @@ -0,0 +1,7 @@ +:hide-toc: + +Sum +=== + +.. autoclass:: torchjd.scalarization.Sum + :members: __call__ diff --git a/docs/source/index.rst b/docs/source/index.rst index d8b14f830..20d0b6db8 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -31,6 +31,10 @@ Jacobian descent is the aggregator, which maps the Jacobian to an optimization s :doc:`Aggregation <docs/aggregation/index>`, we provide an overview of the various aggregators available in TorchJD, and their corresponding weightings. +For comparison against simple baselines, the :doc:`Scalarization <docs/scalarization/index>` +package provides scalarizers that combine a tensor of losses into a single scalar loss, allowing +standard gradient descent to be used. + A straightforward application of Jacobian descent is multi-task learning, in which the vector of per-task losses has to be minimized. To start using TorchJD for multi-task learning, follow our :doc:`MTL example <examples/mtl>`. @@ -70,4 +74,5 @@ TorchJD is open-source, under MIT License.
The source code is available on docs/autogram/index.rst docs/autojac/index.rst docs/aggregation/index.rst + docs/scalarization/index.rst docs/linalg/index.rst diff --git a/src/torchjd/aggregation/_utils/str.py b/src/torchjd/_vector_str.py similarity index 53% rename from src/torchjd/aggregation/_utils/str.py rename to src/torchjd/_vector_str.py index 82a045406..2f821ebbc 100644 --- a/src/torchjd/aggregation/_utils/str.py +++ b/src/torchjd/_vector_str.py @@ -9,3 +9,11 @@ def vector_to_str(vector: Tensor) -> str: weights_str = ", ".join([f"{value:.2f}".rstrip("0") for value in vector]) return weights_str + + +def pref_vector_to_str_suffix(pref_vector: Tensor | None) -> str: + """Returns a suffix string containing the representation of the optional preference vector.""" + + if pref_vector is None: + return "" + return f"([{vector_to_str(pref_vector)}])" diff --git a/src/torchjd/aggregation/_aligned_mtl.py b/src/torchjd/aggregation/_aligned_mtl.py index 961102631..da994d853 100644 --- a/src/torchjd/aggregation/_aligned_mtl.py +++ b/src/torchjd/aggregation/_aligned_mtl.py @@ -7,11 +7,12 @@ import torch from torch import Tensor +from torchjd._vector_str import pref_vector_to_str_suffix from torchjd.linalg import PSDMatrix from ._aggregator_bases import GramianWeightedAggregator from ._mean import MeanWeighting -from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting +from ._utils.pref_vector import pref_vector_to_weighting from ._weighting_bases import _GramianWeighting SUPPORTED_SCALE_MODE: TypeAlias = Literal["min", "median", "rmse"] diff --git a/src/torchjd/aggregation/_config.py b/src/torchjd/aggregation/_config.py index 7ca654b73..8485a83d4 100644 --- a/src/torchjd/aggregation/_config.py +++ b/src/torchjd/aggregation/_config.py @@ -5,12 +5,13 @@ import torch from torch import Tensor +from torchjd._vector_str import pref_vector_to_str_suffix from torchjd.linalg import Matrix from ._aggregator_bases import Aggregator from ._mixins import 
_NonDifferentiable from ._sum import SumWeighting -from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting +from ._utils.pref_vector import pref_vector_to_weighting # Non-differentiable: the pseudoinverse and the normalization are not differentiable in this context. diff --git a/src/torchjd/aggregation/_constant.py b/src/torchjd/aggregation/_constant.py index 54f973a22..c849155e9 100644 --- a/src/torchjd/aggregation/_constant.py +++ b/src/torchjd/aggregation/_constant.py @@ -1,7 +1,8 @@ from torch import Tensor +from torchjd._vector_str import vector_to_str + from ._aggregator_bases import WeightedAggregator -from ._utils.str import vector_to_str from ._weighting_bases import _MatrixWeighting diff --git a/src/torchjd/aggregation/_dualproj.py b/src/torchjd/aggregation/_dualproj.py index 15b6aa873..84206b476 100644 --- a/src/torchjd/aggregation/_dualproj.py +++ b/src/torchjd/aggregation/_dualproj.py @@ -1,12 +1,13 @@ from torch import Tensor from torchjd._linalg import DualConeProjector, projector_or_default +from torchjd._vector_str import pref_vector_to_str_suffix from torchjd.linalg import PSDMatrix from ._aggregator_bases import GramianWeightedAggregator from ._mean import MeanWeighting from ._mixins import _NonDifferentiable -from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting +from ._utils.pref_vector import pref_vector_to_weighting from ._weighting_bases import _GramianWeighting diff --git a/src/torchjd/aggregation/_upgrad.py b/src/torchjd/aggregation/_upgrad.py index e85b15f45..133bb0145 100644 --- a/src/torchjd/aggregation/_upgrad.py +++ b/src/torchjd/aggregation/_upgrad.py @@ -2,12 +2,13 @@ from torch import Tensor from torchjd._linalg import DualConeProjector, projector_or_default +from torchjd._vector_str import pref_vector_to_str_suffix from torchjd.linalg import PSDMatrix from ._aggregator_bases import GramianWeightedAggregator from ._mean import MeanWeighting from ._mixins import 
_NonDifferentiable -from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting +from ._utils.pref_vector import pref_vector_to_weighting from ._weighting_bases import _GramianWeighting diff --git a/src/torchjd/aggregation/_utils/pref_vector.py b/src/torchjd/aggregation/_utils/pref_vector.py index be87c3530..254559f40 100644 --- a/src/torchjd/aggregation/_utils/pref_vector.py +++ b/src/torchjd/aggregation/_utils/pref_vector.py @@ -4,8 +4,6 @@ from torchjd.aggregation._weighting_bases import Weighting from torchjd.linalg import Matrix -from .str import vector_to_str - def pref_vector_to_weighting( pref_vector: Tensor | None, @@ -24,11 +22,3 @@ def pref_vector_to_weighting( f"{pref_vector.ndim}`.", ) return ConstantWeighting(pref_vector) - - -def pref_vector_to_str_suffix(pref_vector: Tensor | None) -> str: - """Returns a suffix string containing the representation of the optional preference vector.""" - - if pref_vector is None: - return "" - return f"([{vector_to_str(pref_vector)}])" diff --git a/src/torchjd/scalarization/__init__.py b/src/torchjd/scalarization/__init__.py new file mode 100644 index 000000000..bf82aa115 --- /dev/null +++ b/src/torchjd/scalarization/__init__.py @@ -0,0 +1,34 @@ +""" +A :class:`~torchjd.scalarization.Scalarizer` reduces a tensor of values of any shape into a single +scalar value. This is the simple baseline +against which :class:`Aggregators <torchjd.aggregation.Aggregator>` are compared: instead of +combining the per-loss gradients via the Jacobian or its Gramian, a +:class:`~torchjd.scalarization.Scalarizer` combines the losses directly, and a standard call to +:meth:`~torch.Tensor.backward` produces the gradient. + +The following example shows how to use :class:`~torchjd.scalarization.Mean` to combine a vector of +losses into a single scalar loss.
+ +>>> from torch import tensor +>>> from torchjd.scalarization import Mean +>>> +>>> scalarizer = Mean() +>>> losses = tensor([1.0, 2.0, 3.0]) +>>> loss = scalarizer(losses) +>>> loss +tensor(2.) +""" + +from ._constant import Constant +from ._mean import Mean +from ._random import Random +from ._scalarizer_base import Scalarizer +from ._sum import Sum + +__all__ = [ + "Constant", + "Mean", + "Random", + "Scalarizer", + "Sum", +] diff --git a/src/torchjd/scalarization/_constant.py b/src/torchjd/scalarization/_constant.py new file mode 100644 index 000000000..5223f289d --- /dev/null +++ b/src/torchjd/scalarization/_constant.py @@ -0,0 +1,35 @@ +from torch import Tensor + +from torchjd._vector_str import pref_vector_to_str_suffix + +from ._scalarizer_base import Scalarizer + + +class Constant(Scalarizer): + """ + :class:`~torchjd.scalarization.Scalarizer` that combines the input tensor of values with + constant, pre-determined weights. + + :param weights: The weights to apply to the values. Must have the same shape as the values + passed at call time. + """ + + def __init__(self, weights: Tensor) -> None: + super().__init__() + self.weights = weights + + def forward(self, values: Tensor, /) -> Tensor: + if values.shape != self.weights.shape: + raise ValueError( + f"Parameter `values` should have shape {tuple(self.weights.shape)} (matching the " + f"shape of the weights). 
Found `values.shape = {tuple(values.shape)}`.", + ) + return (self.weights * values).sum() + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(weights={repr(self.weights)})" + + def __str__(self) -> str: + if self.weights.ndim == 1: + return f"{self.__class__.__name__}{pref_vector_to_str_suffix(self.weights)}" + return f"{self.__class__.__name__}(weights of shape {tuple(self.weights.shape)})" diff --git a/src/torchjd/scalarization/_mean.py b/src/torchjd/scalarization/_mean.py new file mode 100644 index 000000000..a03065bb3 --- /dev/null +++ b/src/torchjd/scalarization/_mean.py @@ -0,0 +1,12 @@ +from torch import Tensor + +from ._scalarizer_base import Scalarizer + + +class Mean(Scalarizer): + """ + :class:`~torchjd.scalarization.Scalarizer` that returns the mean of the input tensor of values. + """ + + def forward(self, values: Tensor, /) -> Tensor: + return values.mean() diff --git a/src/torchjd/scalarization/_random.py b/src/torchjd/scalarization/_random.py new file mode 100644 index 000000000..1b454b727 --- /dev/null +++ b/src/torchjd/scalarization/_random.py @@ -0,0 +1,19 @@ +import torch +from torch import Tensor +from torch.nn.functional import softmax + +from ._scalarizer_base import Scalarizer + + +class Random(Scalarizer): + """ + :class:`~torchjd.scalarization.Scalarizer` that combines the input tensor of values with + positive random weights summing to 1, as defined in Algorithm 2 of `Reasonable Effectiveness of + Random Weighting: A Litmus Test for Multi-Task Learning + <https://arxiv.org/pdf/2111.10603.pdf>`_.
+ """ + + def forward(self, values: Tensor, /) -> Tensor: + flat = torch.randn(values.numel(), device=values.device, dtype=values.dtype) + weights = softmax(flat, dim=-1).reshape(values.shape) + return (weights * values).sum() diff --git a/src/torchjd/scalarization/_scalarizer_base.py b/src/torchjd/scalarization/_scalarizer_base.py new file mode 100644 index 000000000..e83a25d69 --- /dev/null +++ b/src/torchjd/scalarization/_scalarizer_base.py @@ -0,0 +1,31 @@ +from abc import ABC, abstractmethod + +from torch import Tensor, nn + + +class Scalarizer(nn.Module, ABC): + """ + Abstract base class for all scalarizers. Reduces a tensor of values of any shape into a single + scalar value. + """ + + def __init__(self) -> None: + super().__init__() + + @abstractmethod + def forward(self, values: Tensor, /) -> Tensor: + """Computes the scalarization from input tensor.""" + + def __call__(self, values: Tensor, /) -> Tensor: + """ + Computes the scalar value from the input tensor of values and applies all registered hooks. + + :param values: The tensor of values to scalarize. May be of any shape. + """ + return super().__call__(values) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}()" + + def __str__(self) -> str: + return f"{self.__class__.__name__}" diff --git a/src/torchjd/scalarization/_sum.py b/src/torchjd/scalarization/_sum.py new file mode 100644 index 000000000..34f9c924d --- /dev/null +++ b/src/torchjd/scalarization/_sum.py @@ -0,0 +1,12 @@ +from torch import Tensor + +from ._scalarizer_base import Scalarizer + + +class Sum(Scalarizer): + """ + :class:`~torchjd.scalarization.Scalarizer` that returns the sum of the input tensor of values. 
+ """ + + def forward(self, values: Tensor, /) -> Tensor: + return values.sum() diff --git a/tests/unit/scalarization/__init__.py b/tests/unit/scalarization/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/scalarization/_asserts.py b/tests/unit/scalarization/_asserts.py new file mode 100644 index 000000000..8b6d5fc7d --- /dev/null +++ b/tests/unit/scalarization/_asserts.py @@ -0,0 +1,27 @@ +import torch +from torch import Tensor +from utils.tensors import randperm_ + +from torchjd.scalarization import Scalarizer + + +def assert_returns_scalar(scalarizer: Scalarizer, losses: Tensor) -> None: + out = scalarizer(losses) + assert out.dim() == 0 + assert out.isfinite() + + +def assert_grad_flow(scalarizer: Scalarizer, losses: Tensor) -> None: + leaf = losses.detach().requires_grad_() + out = scalarizer(leaf) + out.backward() + assert leaf.grad is not None + assert leaf.grad.isfinite().all() + + +def assert_permutation_invariant(scalarizer: Scalarizer, losses: Tensor) -> None: + out = scalarizer(losses) + flat = losses.flatten() + permuted = flat[randperm_(flat.numel())].reshape(losses.shape) + out_permuted = scalarizer(permuted) + torch.testing.assert_close(out, out_permuted) diff --git a/tests/unit/scalarization/_inputs.py b/tests/unit/scalarization/_inputs.py new file mode 100644 index 000000000..5933cfba9 --- /dev/null +++ b/tests/unit/scalarization/_inputs.py @@ -0,0 +1,5 @@ +from torch import Tensor +from utils.tensors import randn_ + +shapes: list[list[int]] = [[], [5], [3, 4], [2, 3, 4]] +all_inputs: list[Tensor] = [randn_(shape) for shape in shapes] diff --git a/tests/unit/scalarization/test_constant.py b/tests/unit/scalarization/test_constant.py new file mode 100644 index 000000000..126b91c01 --- /dev/null +++ b/tests/unit/scalarization/test_constant.py @@ -0,0 +1,63 @@ +from contextlib import nullcontext as does_not_raise + +import torch +from pytest import mark, raises +from torch import Tensor +from utils.contexts import 
ExceptionContext +from utils.tensors import ones_, tensor_ + +from torchjd.scalarization import Constant + +from ._asserts import assert_grad_flow, assert_returns_scalar +from ._inputs import all_inputs + + +def test_value() -> None: + losses = tensor_([1.0, 2.0, 3.0, 4.0]) + weights = tensor_([0.1, 0.2, 0.3, 0.4]) + torch.testing.assert_close(Constant(weights)(losses), tensor_(3.0)) + + +@mark.parametrize("losses", all_inputs) +def test_expected_structure(losses: Tensor) -> None: + weights = ones_(losses.shape) + assert_returns_scalar(Constant(weights), losses) + + +@mark.parametrize("losses", all_inputs) +def test_grad_flow(losses: Tensor) -> None: + weights = ones_(losses.shape) + assert_grad_flow(Constant(weights), losses) + + +@mark.parametrize( + ["weights_shape", "losses_shape", "expectation"], + [ + ((5,), (5,), does_not_raise()), + ((3, 4), (3, 4), does_not_raise()), + ((), (), does_not_raise()), + ((5,), (4,), raises(ValueError)), + ((5,), (5, 1), raises(ValueError)), + ((3, 4), (4, 3), raises(ValueError)), + ], +) +def test_shape_check( + weights_shape: tuple[int, ...], + losses_shape: tuple[int, ...], + expectation: ExceptionContext, +) -> None: + weights = ones_(weights_shape) + losses = ones_(losses_shape) + with expectation: + _ = Constant(weights)(losses) + + +def test_representations() -> None: + s = Constant(weights=torch.tensor([1.0, 2.0], device="cpu")) + assert repr(s) == "Constant(weights=tensor([1., 2.]))" + assert str(s) == "Constant([1., 2.])" + + +def test_str_with_non_vector_weights() -> None: + assert str(Constant(weights=ones_((3, 4)))) == "Constant(weights of shape (3, 4))" + assert str(Constant(weights=ones_(()))) == "Constant(weights of shape ())" diff --git a/tests/unit/scalarization/test_mean.py b/tests/unit/scalarization/test_mean.py new file mode 100644 index 000000000..566546755 --- /dev/null +++ b/tests/unit/scalarization/test_mean.py @@ -0,0 +1,39 @@ +import torch +from pytest import mark +from torch import Tensor +from 
utils.tensors import tensor_ + +from torchjd.scalarization import Mean + +from ._asserts import ( + assert_grad_flow, + assert_permutation_invariant, + assert_returns_scalar, +) +from ._inputs import all_inputs + + +def test_value() -> None: + losses = tensor_([1.0, 2.0, 3.0]) + torch.testing.assert_close(Mean()(losses), tensor_(2.0)) + + +@mark.parametrize("losses", all_inputs) +def test_expected_structure(losses: Tensor) -> None: + assert_returns_scalar(Mean(), losses) + + +@mark.parametrize("losses", all_inputs) +def test_grad_flow(losses: Tensor) -> None: + assert_grad_flow(Mean(), losses) + + +@mark.parametrize("losses", all_inputs) +def test_permutation_invariant(losses: Tensor) -> None: + assert_permutation_invariant(Mean(), losses) + + +def test_representations() -> None: + s = Mean() + assert repr(s) == "Mean()" + assert str(s) == "Mean" diff --git a/tests/unit/scalarization/test_random.py b/tests/unit/scalarization/test_random.py new file mode 100644 index 000000000..e11c71af9 --- /dev/null +++ b/tests/unit/scalarization/test_random.py @@ -0,0 +1,42 @@ +import torch +from pytest import mark +from torch import Tensor +from utils.contexts import fork_rng +from utils.tensors import ones_, tensor_ + +from torchjd.scalarization import Random + +from ._asserts import assert_grad_flow, assert_returns_scalar +from ._inputs import all_inputs + + +@mark.parametrize("losses", all_inputs) +def test_expected_structure(losses: Tensor) -> None: + assert_returns_scalar(Random(), losses) + + +@mark.parametrize("losses", all_inputs) +def test_grad_flow(losses: Tensor) -> None: + assert_grad_flow(Random(), losses) + + +def test_deterministic_under_seed() -> None: + losses = tensor_([1.0, 2.0, 3.0, 4.0]) + scalarizer = Random() + with fork_rng(seed=0): + a = scalarizer(losses) + with fork_rng(seed=0): + b = scalarizer(losses) + torch.testing.assert_close(a, b) + + +def test_weights_sum_to_one() -> None: + # If all losses equal 1, then sum(weights * losses) == 1 when weights 
sum to 1. + losses = ones_((5,)) + torch.testing.assert_close(Random()(losses), tensor_(1.0)) + + +def test_representations() -> None: + s = Random() + assert repr(s) == "Random()" + assert str(s) == "Random" diff --git a/tests/unit/scalarization/test_sum.py b/tests/unit/scalarization/test_sum.py new file mode 100644 index 000000000..9973197d8 --- /dev/null +++ b/tests/unit/scalarization/test_sum.py @@ -0,0 +1,39 @@ +import torch +from pytest import mark +from torch import Tensor +from utils.tensors import tensor_ + +from torchjd.scalarization import Sum + +from ._asserts import ( + assert_grad_flow, + assert_permutation_invariant, + assert_returns_scalar, +) +from ._inputs import all_inputs + + +def test_value() -> None: + losses = tensor_([1.0, 2.0, 3.0]) + torch.testing.assert_close(Sum()(losses), tensor_(6.0)) + + +@mark.parametrize("losses", all_inputs) +def test_expected_structure(losses: Tensor) -> None: + assert_returns_scalar(Sum(), losses) + + +@mark.parametrize("losses", all_inputs) +def test_grad_flow(losses: Tensor) -> None: + assert_grad_flow(Sum(), losses) + + +@mark.parametrize("losses", all_inputs) +def test_permutation_invariant(losses: Tensor) -> None: + assert_permutation_invariant(Sum(), losses) + + +def test_representations() -> None: + s = Sum() + assert repr(s) == "Sum()" + assert str(s) == "Sum"