feat(aggregation): Add CRMOGMWeighting #669
base: main
Changes from all commits: 0aa1c8b, 53f3eb3, 23c0f62, daf59f9, f3c21a0, 65df561, e846a9c, e16cf48, 1b74974, 1cb9953
@@ -0,0 +1,15 @@

```rst
:hide-toc:

CR-MOGM
=======

.. autoclass:: torchjd.aggregation.CRMOGMWeighting
    :members: __call__, reset

.. note::
    The usage example in the docstring above imports ``WeightedAggregator`` /
    ``GramianWeightedAggregator`` from ``torchjd.aggregation._aggregator_bases``, which is a
    private module. These two aggregator base classes are not currently part of the public
    ``torchjd.aggregation`` namespace, so this private-module import is the only path that
    works today. Promoting them to the public namespace is a separate decision left to the
    maintainers.
```
Comment on lines +9 to +15
Contributor

I forgot about that, but it's a major issue: those classes should become public if we want users to be able to comfortably use CRMOGMWeighting. I'll make these classes public in another PR, together with all the required changes. It's like a prerequisite to your PR IMO. Thanks for being specific about this!
Contributor

Should be fixed when we merge #670
Contributor

Done. We can now remove this note and update every import of those classes to the new version now that they are public.
Contributor

Seems like there is still a problem with this: `WeightedAggregator` now takes a `MatrixWeighting` (and not a `Weighting[Matrix]`), and `CRMOGMWeighting[Matrix]` is not a subtype of `Weighting[Matrix]`. See the type-checking action that fails. I don't know how to fix this yet; I'll think about it tomorrow.
Author

I think what we can do is revert the `__init__` annotations back to the broader types in `_aggregator_bases.py`: `def __init__(self, weighting: Weighting[Matrix]) -> None:`. Do the aggregators really need to require the specific subtypes that are there currently, e.g. `class WeightedAggregator(Aggregator):` taking a `MatrixWeighting`?
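The variance issue discussed above can be made concrete with a small self-contained sketch. All names below are simplified stand-ins for the torchjd classes, not the real implementations: a wrapper that is generic in a contravariant input type is an instance of the generic base, but not of a concrete subclass like `MatrixWeighting`, which is why the broader annotation accepts both.

```python
# Simplified stand-ins (hypothetical, not torchjd's real classes) illustrating
# why annotating the aggregator with the broader generic type accepts both a
# concrete weighting subclass and a generic wrapper around it.
from typing import Generic, TypeVar


class Matrix: ...  # stand-in for the matrix statistic type


T = TypeVar("T", contravariant=True)


class Weighting(Generic[T]):
    def __call__(self, stat: T) -> list[float]:
        return [1.0]  # dummy weights


class MatrixWeighting(Weighting[Matrix]): ...


class CRMOGMWeighting(Weighting[T]):
    # Generic wrapper: CRMOGMWeighting[Matrix] is a Weighting[Matrix], but it
    # is NOT a MatrixWeighting, which is what the failing type check complains
    # about when the aggregator requires the concrete subclass.
    def __init__(self, weighting: Weighting[T]) -> None:
        self.weighting = weighting


def make_aggregator(weighting: Weighting[Matrix]) -> Weighting[Matrix]:
    # Broader annotation: both calls below type-check.
    return weighting


concrete = make_aggregator(MatrixWeighting())
wrapped = make_aggregator(CRMOGMWeighting(MatrixWeighting()))
```

With the narrower `MatrixWeighting` annotation, only the first call would type-check; the broader `Weighting[Matrix]` accepts the wrapper as well.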
@@ -0,0 +1,117 @@

```python
from __future__ import annotations

from typing import TypeVar

import torch
from torch import Tensor

from torchjd.aggregation._mixins import Stateful

from ._weighting_bases import Weighting

_T = TypeVar("_T", contravariant=True, bound=Tensor)


class CRMOGMWeighting(Weighting[_T], Stateful):
```
Contributor

I think the name could be improved to reflect that this is purely a weighting wrapper, maybe

Contributor

I think it's a bit too long, so I'd rather not specify that in general. Same applies to normalizers which, if implemented as wrappers, would always have Wrapper in their name. Same for mixins, we don't use Mixin in their name.
```python
    r"""
    :class:`~torchjd.aggregation._mixins.Stateful`
    :class:`~torchjd.aggregation._weighting_bases.Weighting` that wraps another
    :class:`~torchjd.aggregation._weighting_bases.Weighting` and stabilises the weights it
    produces with an exponential moving average (EMA) across calls. This is the weight-smoothing
    modifier from `On the Convergence of Stochastic Multi-Objective Gradient Manipulation and
    Beyond <https://proceedings.neurips.cc/paper_files/paper/2022/file/f91bd64a3620aad8e70a27ad9cb3ca57-Paper-Conference.pdf>`_
    (NeurIPS 2022).

    Let :math:`\hat{\lambda}_k` be the weights returned by the wrapped weighting at step
    :math:`k`. The smoothed weights returned by ``CRMOGMWeighting`` are:

    .. math::

        \lambda_k = \alpha \, \lambda_{k-1} + (1 - \alpha) \, \hat{\lambda}_k

    with :math:`\lambda_0 = \begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^\top
    \in \mathbb{R}^m`. The state :math:`\lambda_{k-1}` is initialised lazily on the first
    forward call once :math:`m` is known; changing ``m`` afterwards without first calling
    :meth:`reset` raises a ``ValueError``.

    Because ``CRMOGMWeighting`` is generic in the input type ``_T``, it can wrap either a
    ``MatrixWeighting`` or a ``GramianWeighting``. Creating a corresponding
    :class:`~torchjd.aggregation.Aggregator` can be done by composing it with the appropriate
    aggregator base:

    .. code-block:: python

        from torchjd.aggregation import MeanWeighting, UPGradWeighting
        from torchjd.aggregation._aggregator_bases import (
            GramianWeightedAggregator,
            WeightedAggregator,
        )
        from torchjd.aggregation._cr_mogm import CRMOGMWeighting

        matrix_aggregator = WeightedAggregator(CRMOGMWeighting(MeanWeighting()))
        gramian_aggregator = GramianWeightedAggregator(CRMOGMWeighting(UPGradWeighting()))

    This weighting is stateful: it keeps :math:`\lambda_{k-1}` across calls. Use :meth:`reset`
    when restarting the smoothing from uniform weights. Note that calling :meth:`reset` will also
    reset the wrapped weighting if it is :class:`~torchjd.aggregation.Stateful`.

    :param weighting: The wrapped weighting whose output is smoothed.
    :param alpha: EMA coefficient on the previous weights. ``alpha=0`` disables smoothing
        (``CRMOGMWeighting`` returns ``weighting``'s output verbatim) and ``alpha=1`` freezes
        the weights at their initial uniform value. The default of ``0.9`` follows the usual
        EMA convention (analogous to Adam's :math:`\beta_1`).

    .. note::
        ``alpha`` is a fixed ``float`` for simplicity. Corollary 1 of the paper recommends a
        schedule where :math:`\alpha_k` starts near 0 and increases toward 1 as the learning
        rate decays. Update ``alpha`` between forward calls via the public attribute on the
        wrapping aggregator:

        .. code-block:: python

            # With WeightedAggregator
            aggregator.weighting.alpha = 1 - current_lr / initial_lr

            # With GramianWeightedAggregator
            aggregator.gramian_weighting.alpha = 1 - current_lr / initial_lr
    """

    def __init__(self, weighting: Weighting[_T], alpha: float = 0.9) -> None:
```
|
Contributor

I think the initialization of lambda is debatable. For now, we have 1/m all the time. Maybe sometimes a user wants to provide their own starting weights (btw they don't need to be in the simplex; even though it's stated like that in the paper, I think it's a mistake). So we could have an `initial_weights` parameter, of type `Tensor | None`, so that the user can provide their weights, or we use 1/m if they don't. The alternative would be to still have type `Tensor | None`, but if the user gives `None`, we use `lambda_0 = lambda_1_hat`. This means that the first weights output by the `CRMOGMWeighting` will be `lambda_1 = lambda_1_hat * alpha + (1 - alpha) * lambda_1_hat = lambda_1_hat`. I don't know which option we should go for. @PierreQuinton maybe need your insight on this.

Contributor

The weights definitely can leave the simplex (maybe not for MGDA, but for most other weightings it is not the case), so I would not limit ourselves to the simplex. I don't know about the second question, but I would use the default value they use in the paper, which seems to be what is currently the implementation.
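The first option discussed above could be sketched as follows. This is a hypothetical `initial_weights` parameter, not part of this PR, written in plain Python for brevity: `None` keeps the current uniform `1/m` initialisation, while a user-provided sequence is used as lambda_0 verbatim, with only a length check and deliberately no simplex constraint.

```python
# Hypothetical sketch of the proposed `initial_weights` option (not in this
# PR). `None` keeps the current uniform 1/m initialisation; a user-provided
# sequence is used as lambda_0 verbatim (no simplex check), with only a
# length check against the number of objectives m.
from __future__ import annotations

from collections.abc import Sequence


def initial_state(m: int, initial_weights: Sequence[float] | None = None) -> list[float]:
    if initial_weights is None:
        return [1.0 / m] * m  # current behaviour: uniform weights
    if len(initial_weights) != m:
        raise ValueError(f"Expected {m} initial weights, got {len(initial_weights)}.")
    return list(initial_weights)  # user's lambda_0, allowed to leave the simplex
```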
```python
        super().__init__()
        self.weighting = weighting
        self.alpha = alpha
        self._lambda: Tensor | None = None

    @property
    def alpha(self) -> float:
        return self._alpha

    @alpha.setter
    def alpha(self, value: float) -> None:
        if not (0.0 <= value <= 1.0):
            raise ValueError(f"Attribute `alpha` must be in [0, 1]. Found alpha={value!r}.")
        self._alpha = value

    def reset(self) -> None:
        """Clears the EMA state so the next forward starts from uniform weights."""
        if isinstance(self.weighting, Stateful):
            self.weighting.reset()
        self._lambda = None

    def forward(self, stat: _T, /) -> Tensor:
        lambda_hat = self.weighting(stat)
        lambda_prev = self._ensure_state(lambda_hat.shape[0], lambda_hat.dtype, lambda_hat.device)
        lambda_k = self._alpha * lambda_prev + (1.0 - self._alpha) * lambda_hat
        self._lambda = lambda_k.detach()
        return lambda_k

    def _ensure_state(self, m: int, dtype: torch.dtype, device: torch.device) -> Tensor:
        if self._lambda is None:
            self._lambda = torch.full((m,), 1.0 / m, dtype=dtype, device=device)
        elif self._lambda.shape[0] != m:
            raise ValueError(
                f"The number of objectives changed from {self._lambda.shape[0]} to {m}. Call "
                f"`reset()` before changing the number of objectives."
            )
        return self._lambda
```
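To make the recurrence implemented above concrete, here is a minimal plain-Python walk-through of the EMA update (no torch dependency; the numbers are illustrative, not taken from the paper): starting from uniform weights and repeatedly feeding the same fresh weights, the state moves geometrically toward them.

```python
# Minimal plain-Python illustration of the EMA recurrence
#     lambda_k = alpha * lambda_{k-1} + (1 - alpha) * lambda_hat_k,
# starting from uniform weights, as in CRMOGMWeighting (values illustrative).
def ema_step(prev: list[float], fresh: list[float], alpha: float) -> list[float]:
    return [alpha * p + (1.0 - alpha) * f for p, f in zip(prev, fresh)]


m = 2
alpha = 0.5
state = [1.0 / m] * m                       # lambda_0 = [0.5, 0.5]
state = ema_step(state, [1.0, 0.0], alpha)  # lambda_1 = [0.75, 0.25]
state = ema_step(state, [1.0, 0.0], alpha)  # lambda_2 = [0.875, 0.125]
```

With `alpha=0` each step returns the fresh weights unchanged; with `alpha=1` the state never moves from uniform, matching the edge cases exercised by the tests below.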
@@ -0,0 +1,165 @@

```python
from pytest import mark, raises
from torch import Tensor
from torch.testing import assert_close
from utils.tensors import randn_, tensor_

from torchjd.aggregation import GradVacWeighting, MeanWeighting, UPGradWeighting
from torchjd.aggregation._aggregator_bases import (
    GramianWeightedAggregator,
    WeightedAggregator,
)
from torchjd.aggregation._cr_mogm import CRMOGMWeighting

from ._asserts import assert_expected_structure
from ._inputs import scaled_matrices, typical_matrices

# UPGradWeighting uses a QP solver that can fail on the extreme scales (0.0, 1e15) found in
# scaled_matrices, so the gramian-path structural test only uses typical_matrices.
matrix_pairs = [
    (WeightedAggregator(CRMOGMWeighting(MeanWeighting())), m)
    for m in typical_matrices + scaled_matrices
]
gramian_pairs = [
    (GramianWeightedAggregator(CRMOGMWeighting(UPGradWeighting())), m) for m in typical_matrices
]


@mark.parametrize(["aggregator", "matrix"], matrix_pairs)
def test_expected_structure_matrix_weighting(
    aggregator: WeightedAggregator, matrix: Tensor
) -> None:
    assert_expected_structure(aggregator, matrix)


@mark.parametrize(["aggregator", "matrix"], gramian_pairs)
def test_expected_structure_gramian_weighting(
    aggregator: GramianWeightedAggregator, matrix: Tensor
) -> None:
    assert_expected_structure(aggregator, matrix)


def test_reset_restores_first_step_behavior() -> None:
```
Contributor

Need a test of CRMOGMWeighting wrapping GradVacWeighting (a stateful weighting), to verify that the reset method correctly calls the underlying weighting's reset method.
```python
    """
    Use ``UPGradWeighting`` so the weights actually depend on the input — with
    ``MeanWeighting`` the EMA would be a fixed point at the uniform weights and the test would
    be trivial.
    """

    J = randn_((3, 8))
    G = J @ J.T
    W = CRMOGMWeighting(UPGradWeighting(), alpha=0.5)
    first = W(G)
    W(G)
    W.reset()
    assert_close(first, W(G))


def test_reset_propagates_to_stateful_weighting() -> None:
    """
    Verify that ``reset()`` calls the wrapped weighting's ``reset()`` when it is
    :class:`~torchjd.aggregation.Stateful`. Checks that ``GradVacWeighting``'s internal
    state is cleared after ``reset()``.
    """

    inner = GradVacWeighting()
    W = CRMOGMWeighting(inner, alpha=0.5)
    J = randn_((3, 8))
    W(J @ J.T)
    assert inner._phi_t is not None
    W.reset()
    assert inner._phi_t is None


def test_changing_m_raises() -> None:
    """Verify that changing the number of objectives after the first call raises a ValueError."""

    W = CRMOGMWeighting(MeanWeighting())
    W(randn_((3, 8)) @ randn_((3, 8)).T)
    with raises(ValueError, match="number of objectives"):
        W(randn_((2, 8)) @ randn_((2, 8)).T)


def test_alpha_setter_accepts_valid() -> None:
    W = CRMOGMWeighting(MeanWeighting())
    W.alpha = 0.0
    assert W.alpha == 0.0
    W.alpha = 0.5
    assert W.alpha == 0.5
    W.alpha = 1.0
    assert W.alpha == 1.0


def test_alpha_setter_rejects_out_of_range() -> None:
    W = CRMOGMWeighting(MeanWeighting())
    with raises(ValueError, match="alpha"):
        W.alpha = -0.1
    with raises(ValueError, match="alpha"):
        W.alpha = 1.1


def test_alpha_zero_reduces_to_bare_weighting() -> None:
    """
    With ``alpha=0`` the previous state is always multiplied by zero, so the smoothed weights
    equal the bare weighting's output on every call — not just the first.
    """

    J = randn_((3, 8))
    G = J @ J.T
    bare = UPGradWeighting()
    smoothed = CRMOGMWeighting(UPGradWeighting(), alpha=0.0)

    expected = bare(G)
    assert_close(smoothed(G), expected)
    assert_close(smoothed(G), expected)


def test_alpha_one_freezes_weights() -> None:
    """
    With ``alpha=1`` the fresh weights are multiplied by zero, so the smoothed weights stay at
    their initial uniform value forever. Note: the equality with uniform weights is a
    consequence of the uniform initialisation, not a general property of CR-MOGM.
    """

    J = randn_((3, 8))
    m = J.shape[0]
    W = CRMOGMWeighting(UPGradWeighting(), alpha=1.0)
    uniform = tensor_([1.0 / m] * m)

    assert_close(W(J @ J.T), uniform)
    assert_close(W(J @ J.T), uniform)


def test_ema_is_applied() -> None:
    """Run two steps with ``alpha=0.9`` and check the EMA recurrence by hand."""

    alpha = 0.9
    J1 = randn_((3, 8))
    J2 = randn_((3, 8))
    G1 = J1 @ J1.T
    G2 = J2 @ J2.T
    m = J1.shape[0]

    bare = UPGradWeighting()
    smoothed = CRMOGMWeighting(UPGradWeighting(), alpha=alpha)

    lambda_hat_1 = bare(G1)
    lambda_hat_2 = bare(G2)
    uniform = tensor_([1.0 / m] * m)

    expected_1 = alpha * uniform + (1.0 - alpha) * lambda_hat_1
    expected_2 = alpha * expected_1 + (1.0 - alpha) * lambda_hat_2

    assert_close(smoothed(G1), expected_1)
    assert_close(smoothed(G2), expected_2)


def test_zero_columns() -> None:
    """
    A ``(2, 0)`` matrix has no columns to combine, so the aggregation must be empty. Zero-row
    inputs are intentionally not tested: ``MeanWeighting`` does ``1/m`` in Python and would
    raise ``ZeroDivisionError`` at ``m=0``, which is the wrapped weighting's responsibility.
    """

    aggregator = WeightedAggregator(CRMOGMWeighting(MeanWeighting()))
    out = aggregator(tensor_([]).reshape(2, 0))
    assert out.shape == (0,)
```