7 changes: 5 additions & 2 deletions CHANGELOG.md
@@ -10,6 +10,10 @@ changelog does not include internal changes that do not affect the user.

### Added

- Added `CRMOGMWeighting` from [On the Convergence of Stochastic Multi-Objective Gradient
Manipulation and Beyond](https://proceedings.neurips.cc/paper_files/paper/2022/file/f91bd64a3620aad8e70a27ad9cb3ca57-Paper-Conference.pdf)
(NeurIPS 2022). It wraps an existing `Weighting` and stabilises its weights with an exponential
moving average across calls.
- Made `WeightedAggregator`, `GramianWeightedAggregator`, `MatrixWeighting`, and `GramianWeighting`
public. These abstract base classes are now importable from `torchjd.aggregation` and documented.
They can be extended to easily implement custom `Weighting`s and `Aggregator`s.
@@ -20,8 +24,7 @@ changelog does not include internal changes that do not affect the user.
`CAGrad` and `CAGradWeighting`; `pref_vector` in `ConFIG`; `leak` in `GradDrop`, `n_byzantine` and
`n_selected` in `Krum` and `KrumWeighting`; `epsilon` and `max_iters` in `MGDA` and
`MGDAWeighting`; `n_tasks`, `max_norm`, `update_weights_every` and `optim_niter` in `NashMTL`;
`trim_number` in `TrimmedMean`. Setters validate their inputs with the same checks as the
existing constructors. Note that setters for `GradVac` and `GradVacWeighting` already existed.
PierreQuinton marked this conversation as resolved.

## [0.10.0] - 2026-04-16

15 changes: 15 additions & 0 deletions docs/source/docs/aggregation/cr_mogm.rst
@@ -0,0 +1,15 @@
:hide-toc:

CR-MOGM
=======

.. autoclass:: torchjd.aggregation.CRMOGMWeighting
    :members: __call__, reset

.. note::
    The usage example in the docstring above imports
    ``WeightedAggregator`` / ``GramianWeightedAggregator`` from
    ``torchjd.aggregation._aggregator_bases``, which is a private module. These two
    aggregator base classes are not currently part of the public ``torchjd.aggregation``
    namespace, so this private-module import is the only path that works today. Promoting
    them to the public namespace is a separate decision left to the maintainers.
Comment on lines +9 to +15
Contributor @ValerianRey commented on May 7, 2026:
I forgot about that, but it's a major issue: those classes should become public if we want users to be able to comfortably use CRMOGMWeighting.

I'll make these classes public in another PR, together with all the required changes. It's like a prerequisite to your PR IMO.

Thanks for being specific about this!

Contributor:

Should be fixed when we merge #670

Contributor:

Done. We can now remove this note and update every import of those classes now that they are public.

Contributor @ValerianRey commented on May 8, 2026:

Seems like there is still a problem with this: `WeightedAggregator` now takes a `MatrixWeighting` (and not a `Weighting[Matrix]`), and `CRMOGMWeighting[Matrix]` is not a subtype of `MatrixWeighting`. See the type-checking action that fails. I don't know how to fix this yet...

I'll think about that tomorrow probably.

Author:

I think what we can do is revert the `__init__` annotations back to the broader types in `_aggregator_bases.py`:

```python
def __init__(self, weighting: Weighting[Matrix]) -> None: ...
def __init__(self, gramian_weighting: Weighting[PSDMatrix]) -> None: ...
```

Do the aggregators really need to require the specific subtypes that are there currently?

```python
class WeightedAggregator(Aggregator):
    def __init__(self, weighting: MatrixWeighting) -> None: ...

class GramianWeightedAggregator(WeightedAggregator):
    def __init__(self, gramian_weighting: GramianWeighting) -> None: ...
```
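The subtyping problem above can be reproduced with a minimal, self-contained sketch. The class names mimic the torchjd ones but are stand-ins redefined locally, not imports from the library:

```python
from typing import Generic, TypeVar

class Matrix: ...

T = TypeVar("T", contravariant=True, bound=Matrix)

class Weighting(Generic[T]): ...

class MatrixWeighting(Weighting[Matrix]): ...  # concrete-input base class

class CRMOGMWeighting(Weighting[T]):  # generic wrapper, like in the PR
    def __init__(self, weighting: "Weighting[T]") -> None:
        self.weighting = weighting

# CRMOGMWeighting is a nominal subclass of Weighting, so an __init__ annotated
# with Weighting[Matrix] accepts CRMOGMWeighting[Matrix]...
assert issubclass(CRMOGMWeighting, Weighting)
# ...but it is NOT a subclass of MatrixWeighting, so an __init__ annotated with
# MatrixWeighting rejects it, which is what the type checker complains about.
assert not issubclass(CRMOGMWeighting, MatrixWeighting)
```

This is why broadening the aggregator base annotations back to `Weighting[Matrix]` / `Weighting[PSDMatrix]` would make the wrapper acceptable again.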

1 change: 1 addition & 0 deletions docs/source/docs/aggregation/index.rst
@@ -41,6 +41,7 @@ Abstract base classes
cagrad.rst
config.rst
constant.rst
cr_mogm.rst
dualproj.rst
flattening.rst
graddrop.rst
2 changes: 2 additions & 0 deletions src/torchjd/aggregation/__init__.py
@@ -63,6 +63,7 @@
from ._aligned_mtl import AlignedMTL, AlignedMTLWeighting
from ._config import ConFIG
from ._constant import Constant, ConstantWeighting
from ._cr_mogm import CRMOGMWeighting
from ._dualproj import DualProj, DualProjWeighting
from ._flattening import Flattening
from ._graddrop import GradDrop
@@ -89,6 +90,7 @@
"ConFIG",
"Constant",
"ConstantWeighting",
"CRMOGMWeighting",
"DualProj",
"DualProjWeighting",
"Flattening",
117 changes: 117 additions & 0 deletions src/torchjd/aggregation/_cr_mogm.py
@@ -0,0 +1,117 @@
from __future__ import annotations

from typing import TypeVar

import torch
from torch import Tensor

from torchjd.aggregation._mixins import Stateful

from ._weighting_bases import Weighting

_T = TypeVar("_T", contravariant=True, bound=Tensor)


class CRMOGMWeighting(Weighting[_T], Stateful):
Contributor:

I think the name could be improved to reflect that this is purely a weighting wrapper, maybe CRMOGMWeightingWrapper?

Contributor:

I think it's a bit too long, so I'd rather not specify that in general. Same applies to normalizers which, if implemented as wrappers, would always have Wrapper in their name.

Same for mixins, we don't use Mixin in their name.

    r"""
    :class:`~torchjd.aggregation._mixins.Stateful`
    :class:`~torchjd.aggregation._weighting_bases.Weighting` that wraps another
    :class:`~torchjd.aggregation._weighting_bases.Weighting` and stabilises the weights it
    produces with an exponential moving average (EMA) across calls. This is the weight-smoothing
    modifier from `On the Convergence of Stochastic Multi-Objective Gradient Manipulation and
    Beyond <https://proceedings.neurips.cc/paper_files/paper/2022/file/f91bd64a3620aad8e70a27ad9cb3ca57-Paper-Conference.pdf>`_
    (NeurIPS 2022).

    Let :math:`\hat{\lambda}_k` be the weights returned by the wrapped weighting at step
    :math:`k`. The smoothed weights returned by ``CRMOGMWeighting`` are:

    .. math::

        \lambda_k = \alpha \, \lambda_{k-1} + (1 - \alpha) \, \hat{\lambda}_k

    with :math:`\lambda_0 = \begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^\top
    \in \mathbb{R}^m`. The state :math:`\lambda_{k-1}` is initialised lazily on the first
    forward call once :math:`m` is known, and an error is raised if ``m`` changes without a
    prior call to :meth:`reset`.

    Because ``CRMOGMWeighting`` is generic in the input type ``_T``, it can wrap either a
    ``MatrixWeighting`` or a ``GramianWeighting``. Creating a corresponding
    :class:`~torchjd.aggregation.Aggregator` can be done by composing it with the appropriate
    aggregator base:

    .. code-block:: python

        from torchjd.aggregation import MeanWeighting, UPGradWeighting
        from torchjd.aggregation._aggregator_bases import (
            GramianWeightedAggregator,
            WeightedAggregator,
        )
        from torchjd.aggregation._cr_mogm import CRMOGMWeighting

        matrix_aggregator = WeightedAggregator(CRMOGMWeighting(MeanWeighting()))
        gramian_aggregator = GramianWeightedAggregator(CRMOGMWeighting(UPGradWeighting()))

    This weighting is stateful: it keeps :math:`\lambda_{k-1}` across calls. Use :meth:`reset`
    to restart the smoothing from uniform weights. Note that calling :meth:`reset` also resets
    the wrapped weighting if it is :class:`~torchjd.aggregation.Stateful`.

    :param weighting: The wrapped weighting whose output is smoothed.
    :param alpha: EMA coefficient on the previous weights. ``alpha=0`` disables smoothing
        (``CRMOGMWeighting`` returns ``weighting``'s output verbatim) and ``alpha=1`` freezes
        the weights at their initial uniform value. The default of ``0.1`` favours the fresh
        weights, consistent with the paper's recommendation that :math:`\alpha_k` start near 0.

    .. note::
        ``alpha`` is a fixed ``float`` for simplicity. Corollary 1 of the paper recommends a
        schedule where :math:`\alpha_k` starts near 0 and increases toward 1 as the learning
        rate decays. Update ``alpha`` between forward calls via the public attribute on the
        wrapping aggregator:

        .. code-block:: python

            # With WeightedAggregator
            aggregator.weighting.alpha = 1 - current_lr / initial_lr

            # With GramianWeightedAggregator
            aggregator.gramian_weighting.alpha = 1 - current_lr / initial_lr
    """

    def __init__(self, weighting: Weighting[_T], alpha: float = 0.1) -> None:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the initialization of lambda is debatable.

For now, we have 1/m all the time.

Maybe sometimes a user wants to provide their own starting weights (btw they don't need to be in the simplex; even though the paper states it that way, I think that's a mistake).

So we could have an `initial_weights` parameter, of type `Tensor | None`, so that the user can provide their weights, or we use 1/m if they don't.

The alternative would be to keep the type `Tensor | None`, but if the user gives `None`, we use `lambda_0 = lambda_1_hat`.

This means that the first weights output by the `CRMOGMWeighting` would be `lambda_1 = lambda_1_hat * alpha + (1 - alpha) * lambda_1_hat = lambda_1_hat`.

I don't know which option we should go for. @PierreQuinton maybe we need your insight on this.
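The two initialisation options above can be compared with a plain-Python sketch (a hypothetical `ema_steps` helper, not part of the PR; lists stand in for weight tensors):

```python
def ema_steps(lambda_hats, alpha, lambda_0):
    """Replay the EMA recurrence; lambda_0=None means 'start from the first raw weights'."""
    out = []
    prev = lambda_0
    for lam_hat in lambda_hats:
        if prev is None:
            prev = lam_hat  # option 2: lambda_0 = lambda_1_hat
        lam = [alpha * p + (1 - alpha) * h for p, h in zip(prev, lam_hat)]
        out.append(lam)
        prev = lam
    return out

# Option 1 (current behaviour): the uniform start pulls the first output toward 1/m.
print(ema_steps([[1.0, 0.0]], 0.5, [0.5, 0.5]))  # [[0.75, 0.25]]
# Option 2: the first output is exactly the wrapped weighting's output.
print(ema_steps([[1.0, 0.0]], 0.5, None))  # [[1.0, 0.0]]
```

The difference only affects the first few steps; after that both variants follow the same recurrence.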

Contributor:

The weights definitely can leave the simplex (maybe not for MGDA, but for most other weightings they can), so I would not limit ourselves to the simplex.

I don't know about the second question, but I would use the default value they use in the paper, which seems to be what is currently implemented.

        super().__init__()
        self.weighting = weighting
        self.alpha = alpha
        self._lambda: Tensor | None = None

    @property
    def alpha(self) -> float:
        return self._alpha

    @alpha.setter
    def alpha(self, value: float) -> None:
        if not (0.0 <= value <= 1.0):
            raise ValueError(f"Attribute `alpha` must be in [0, 1]. Found alpha={value!r}.")
        self._alpha = value

    def reset(self) -> None:
        """Clears the EMA state so the next forward starts from uniform weights."""

        if isinstance(self.weighting, Stateful):
            self.weighting.reset()
        self._lambda = None

    def forward(self, stat: _T, /) -> Tensor:
        lambda_hat = self.weighting(stat)

        lambda_prev = self._ensure_state(lambda_hat.shape[0], lambda_hat.dtype, lambda_hat.device)

        lambda_k = self._alpha * lambda_prev + (1.0 - self._alpha) * lambda_hat

        self._lambda = lambda_k.detach()
        return lambda_k

    def _ensure_state(self, m: int, dtype: torch.dtype, device: torch.device) -> Tensor:
        if self._lambda is None:
            self._lambda = torch.full((m,), 1.0 / m, dtype=dtype, device=device)
        elif self._lambda.shape[0] != m:
            raise ValueError(
                f"The number of objectives changed from {self._lambda.shape[0]} to {m}. Call "
                f"`reset()` before changing the number of objectives."
            )
        return self._lambda
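The alpha schedule suggested in the docstring note above can be sketched as a plain function (a hypothetical helper, assuming a learning rate that decays from `initial_lr` toward 0; not part of the PR):

```python
def alpha_schedule(current_lr: float, initial_lr: float) -> float:
    """alpha starts near 0 at full learning rate and grows toward 1 as the lr decays."""
    if initial_lr <= 0:
        raise ValueError("initial_lr must be positive")
    return 1.0 - current_lr / initial_lr

# At the start of training (lr == initial_lr) there is no smoothing yet:
print(alpha_schedule(1.0, 1.0))   # 0.0
# As the lr decays, the EMA leans more on the previous weights:
print(alpha_schedule(0.25, 1.0))  # 0.75
```

A training loop would then set `aggregator.weighting.alpha = alpha_schedule(current_lr, initial_lr)` between steps, mirroring the `1 - current_lr / initial_lr` expression in the docstring.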
165 changes: 165 additions & 0 deletions tests/unit/aggregation/test_cr_mogm.py
@@ -0,0 +1,165 @@
from pytest import mark, raises
from torch import Tensor
from torch.testing import assert_close
from utils.tensors import randn_, tensor_

from torchjd.aggregation import GradVacWeighting, MeanWeighting, UPGradWeighting
from torchjd.aggregation._aggregator_bases import (
    GramianWeightedAggregator,
    WeightedAggregator,
)
from torchjd.aggregation._cr_mogm import CRMOGMWeighting

from ._asserts import assert_expected_structure
from ._inputs import scaled_matrices, typical_matrices

# UPGradWeighting uses a QP solver that can fail on the extreme scales (0.0, 1e15) found in
# scaled_matrices, so the gramian-path structural test only uses typical_matrices.
matrix_pairs = [
    (WeightedAggregator(CRMOGMWeighting(MeanWeighting())), m)
    for m in typical_matrices + scaled_matrices
]
gramian_pairs = [
    (GramianWeightedAggregator(CRMOGMWeighting(UPGradWeighting())), m) for m in typical_matrices
]


@mark.parametrize(["aggregator", "matrix"], matrix_pairs)
def test_expected_structure_matrix_weighting(
    aggregator: WeightedAggregator, matrix: Tensor
) -> None:
    assert_expected_structure(aggregator, matrix)


@mark.parametrize(["aggregator", "matrix"], gramian_pairs)
def test_expected_structure_gramian_weighting(
    aggregator: GramianWeightedAggregator, matrix: Tensor
) -> None:
    assert_expected_structure(aggregator, matrix)


def test_reset_restores_first_step_behavior() -> None:
Contributor:

Need a test of CRMOGMWeighting wrapping GradVacWeighting (a stateful weighting), to verify that the reset method correctly calls the underlying weighting's reset method.

    """
    Use ``UPGradWeighting`` so the weights actually depend on the input — with
    ``MeanWeighting`` the EMA would be a fixed point at the uniform weights and the test would
    be trivial.
    """

    J = randn_((3, 8))
    G = J @ J.T
    W = CRMOGMWeighting(UPGradWeighting(), alpha=0.5)
    first = W(G)
    W(G)
    W.reset()
    assert_close(first, W(G))


def test_reset_propagates_to_stateful_weighting() -> None:
    """
    Verify that ``reset()`` calls the wrapped weighting's ``reset()`` when it is
    :class:`~torchjd.aggregation.Stateful`. Checks that ``GradVacWeighting``'s internal
    state is cleared after ``reset()``.
    """

    inner = GradVacWeighting()
    W = CRMOGMWeighting(inner, alpha=0.5)
    J = randn_((3, 8))
    W(J @ J.T)
    assert inner._phi_t is not None
    W.reset()
    assert inner._phi_t is None


def test_changing_m_raises() -> None:
    """Verify that changing the number of objectives after the first call raises a ValueError."""

    W = CRMOGMWeighting(MeanWeighting())
    W(randn_((3, 8)) @ randn_((3, 8)).T)
    with raises(ValueError, match="number of objectives"):
        W(randn_((2, 8)) @ randn_((2, 8)).T)


def test_alpha_setter_accepts_valid() -> None:
    W = CRMOGMWeighting(MeanWeighting())
    W.alpha = 0.0
    assert W.alpha == 0.0
    W.alpha = 0.5
    assert W.alpha == 0.5
    W.alpha = 1.0
    assert W.alpha == 1.0


def test_alpha_setter_rejects_out_of_range() -> None:
    W = CRMOGMWeighting(MeanWeighting())
    with raises(ValueError, match="alpha"):
        W.alpha = -0.1
    with raises(ValueError, match="alpha"):
        W.alpha = 1.1


def test_alpha_zero_reduces_to_bare_weighting() -> None:
    """
    With ``alpha=0`` the previous state is always multiplied by zero, so the smoothed weights
    equal the bare weighting's output on every call — not just the first.
    """

    J = randn_((3, 8))
    G = J @ J.T
    bare = UPGradWeighting()
    smoothed = CRMOGMWeighting(UPGradWeighting(), alpha=0.0)

    expected = bare(G)
    assert_close(smoothed(G), expected)
    assert_close(smoothed(G), expected)


def test_alpha_one_freezes_weights() -> None:
    """
    With ``alpha=1`` the fresh weights are multiplied by zero, so the smoothed weights stay at
    their initial uniform value forever. Note: the equality with uniform weights is a
    consequence of the uniform initialisation, not a general property of CR-MOGM.
    """

    J = randn_((3, 8))
    m = J.shape[0]
    W = CRMOGMWeighting(UPGradWeighting(), alpha=1.0)
    uniform = tensor_([1.0 / m] * m)

    assert_close(W(J @ J.T), uniform)
    assert_close(W(J @ J.T), uniform)


def test_ema_is_applied() -> None:
    """Run two steps with ``alpha=0.9`` and check the EMA recurrence by hand."""

    alpha = 0.9
    J1 = randn_((3, 8))
    J2 = randn_((3, 8))
    G1 = J1 @ J1.T
    G2 = J2 @ J2.T
    m = J1.shape[0]

    bare = UPGradWeighting()
    smoothed = CRMOGMWeighting(UPGradWeighting(), alpha=alpha)

    lambda_hat_1 = bare(G1)
    lambda_hat_2 = bare(G2)
    uniform = tensor_([1.0 / m] * m)

    expected_1 = alpha * uniform + (1.0 - alpha) * lambda_hat_1
    expected_2 = alpha * expected_1 + (1.0 - alpha) * lambda_hat_2

    assert_close(smoothed(G1), expected_1)
    assert_close(smoothed(G2), expected_2)


def test_zero_columns() -> None:
    """
    A ``(2, 0)`` matrix has no columns to combine, so the aggregation must be empty. Zero-row
    inputs are intentionally not tested: ``MeanWeighting`` does ``1/m`` in Python and would
    raise ``ZeroDivisionError`` at ``m=0``, which is the wrapped weighting's responsibility.
    """

    aggregator = WeightedAggregator(CRMOGMWeighting(MeanWeighting()))
    out = aggregator(tensor_([]).reshape(2, 0))
    assert out.shape == (0,)