feat(aggregation): Add CRMOGMWeighting #669

Open

KhusPatel4450 wants to merge 10 commits into SimplexLab:main from KhusPatel4450:feat/cr-mogm-weighting

Conversation

@KhusPatel4450

  • Adds CRMOGMWeighting, a stateful Weighting modifier from Conflict-Reduction Multi-Objective Gradient Methods (NeurIPS 2022)
  • Wraps any existing Weighting and smooths its output with an EMA: λ_k = α·λ_{k−1} + (1−α)·λ̂_k
  • Generic over the input type, so it composes correctly with both WeightedAggregator and GramianWeightedAggregator (see the sketch below)
  • Stateful via the Stateful mixin; reset() restores uniform initial weights
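A hedged usage sketch of that composition (the import paths are assumptions; as noted later in this thread, the aggregator bases start out private and only become public with #670):

import torch
from torchjd.aggregation import CRMOGMWeighting, MeanWeighting, WeightedAggregator

# Wrap an existing weighting; the wrapper smooths its outputs across calls
# with the EMA recurrence above.
weighting = CRMOGMWeighting(MeanWeighting(), alpha=0.9)
aggregator = WeightedAggregator(weighting)

matrix = torch.randn(3, 5)       # 3 objectives, 5 parameters
aggregated = aggregator(matrix)  # expected: aggregated gradient of length 5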

Tests:

  • uv run pytest tests/unit/aggregation/test_cr_mogm.py -v, 92 tests covering EMA recurrence, alpha boundaries, reset, structural checks on both aggregator paths
  • uv run pytest tests/unit -q, full regression (2889 passed)
  • uv run ty check src/torchjd/aggregation/_cr_mogm.py, passes
  • Sphinx doctest, 94 tests, 0 failures

@KhusPatel4450 KhusPatel4450 changed the title Add CRMOGMWeighting from NeurIPS 2022 (Aggregation Feature) feat(aggregation): Add CRMOGMWeighting from NeurIPS 2022 May 7, 2026
@ValerianRey ValerianRey changed the title feat(aggregation): Add CRMOGMWeighting from NeurIPS 2022 feat(aggregation): Add CRMOGMWeighting May 7, 2026
@ValerianRey ValerianRey added cc: feat Conventional commit type for new features. package: aggregation labels May 7, 2026
@ValerianRey
Contributor

Thanks a lot for the PR! I'm gonna review soon! In the meantime, you can try to get the CI to pass.

@ValerianRey ValerianRey mentioned this pull request May 7, 2026
@KhusPatel4450
Author

Hello,

Happy to say, all checks have passed! Glad to have gotten my first PR up as well. Looking forward to feedback.

Comment on lines +9 to +15
.. note::
The usage example in the docstring above imports
``WeightedAggregator`` / ``GramianWeightedAggregator`` from
``torchjd.aggregation._aggregator_bases``, which is a private module. These two
aggregator base classes are not currently part of the public ``torchjd.aggregation``
namespace, so this private-module import is the only path that works today. Promoting
them to the public namespace is a separate decision left to the maintainers.
Contributor

@ValerianRey ValerianRey May 7, 2026


I forgot about that, but it's a major issue: those classes should become public if we want users to be able to comfortably use CRMOGMWeighting.

I'll make these classes public in another PR, together with all the required changes. It's like a prerequisite to your PR IMO.

Thanks for being specific about this!

Contributor


Should be fixed when we merge #670

Contributor


Done. We can now remove this note and update every import of those classes now that they are public.

Contributor

@ValerianRey ValerianRey May 8, 2026


Seems like there is still a problem with this, because now WeightedAggregator takes a MatrixWeighting (and not a Weighting[Matrix]), and CRMOGMWeighting[Matrix] is not a subtype of MatrixWeighting. See the type checking action that fails. I don't know how to fix this yet...

I'll think about that tomorrow probably.

Author


I think what we can do is revert the __init__ annotations back to the broader types in _aggregator_bases.py:

def __init__(self, weighting: Weighting[Matrix]) -> None:
def __init__(self, gramian_weighting: Weighting[PSDMatrix]) -> None:

Do the aggregators really need to require the specific subtypes that they currently take:

class WeightedAggregator(Aggregator):
    def __init__(self, weighting: MatrixWeighting) -> None:

class GramianWeightedAggregator(WeightedAggregator):
    def __init__(self, gramian_weighting: GramianWeighting) -> None:
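For context, a minimal standalone sketch of the variance problem, using simplified stand-in classes rather than the real ones:

from typing import Generic, TypeVar

_T = TypeVar("_T", contravariant=True)

class Matrix: ...                              # stand-in for the real Matrix type
class Weighting(Generic[_T]): ...              # stand-in base
class MatrixWeighting(Weighting[Matrix]): ...  # nominal subclass
class CRMOGMWeighting(Weighting[_T]): ...      # generic wrapper, as in this PR

def make_weighted_aggregator(weighting: MatrixWeighting) -> None: ...

w: CRMOGMWeighting[Matrix] = CRMOGMWeighting()
# Rejected by the type checker: w is a Weighting[Matrix], but not a MatrixWeighting.
make_weighted_aggregator(w)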

Comment thread src/torchjd/aggregation/_cr_mogm.py Outdated
Comment on lines +91 to +95
def reset(self) -> None:
"""Clears the EMA state so the next forward starts from uniform weights."""

self._lambda = None
self._state_key = None
Contributor


This should also call the wrapped weighting's reset method if it has one.
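A minimal sketch of the fix, assuming the Stateful mixin exposes a reset() method:

def reset(self) -> None:
    """Clears the EMA state so the next forward starts from uniform weights."""
    self._lambda = None
    self._state_key = None
    # Also propagate the reset to the wrapped weighting when it is stateful.
    if isinstance(self.weighting, Stateful):
        self.weighting.reset()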

Comment thread src/torchjd/aggregation/_cr_mogm.py Outdated
Comment on lines +98 to +100
device = stat.device
dtype = stat.dtype
m = stat.shape[0]
Contributor


I'm not a fan of getting this information from stat, which is currently a _T (bound to be a Tensor), but which may change in the future to be any stat computed from a Jacobian matrix (not just a tensor).

So I would rather first call self.weighting(stat) to obtain lambda_hat, and deduce device, dtype and m from this lambda_hat.
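A sketch of that reordering; the body is reconstructed from the EMA recurrence in the PR description and the snippets shown later in this review (cast is from typing):

def forward(self, stat: _T) -> Tensor:
    # Call the wrapped weighting first; deduce m, dtype, and device from its
    # output rather than from stat, which may not remain a Tensor.
    lambda_hat = self.weighting(stat)
    self._ensure_state(lambda_hat.shape[0], lambda_hat.dtype, lambda_hat.device)
    lambda_prev = cast(Tensor, self._lambda)
    # EMA update: lambda_k = alpha * lambda_{k-1} + (1 - alpha) * lambda_hat_k
    self._lambda = self.alpha * lambda_prev + (1 - self.alpha) * lambda_hat
    return self._lambda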

Comment thread tests/unit/aggregation/test_cr_mogm.py Outdated
Comment on lines +27 to +32
def test_representations() -> None:
W = CRMOGMWeighting(MeanWeighting(), alpha=0.9)
expected = "CRMOGMWeighting(weighting=MeanWeighting(), alpha=0.9)"
# Weighting does not define __str__, so it falls back to __repr__.
assert repr(W) == expected
assert str(W) == expected
Contributor


We don't currently define __repr__ or __str__ for the weightings, so we can get rid of that until we do.

Comment thread src/torchjd/aggregation/_cr_mogm.py Outdated
Comment on lines +120 to +121
def __repr__(self) -> str:
return f"CRMOGMWeighting(weighting={self.weighting!r}, alpha={self.alpha!r})"
Contributor


We don't currently define __repr__ or __str__ for the weightings, so we can get rid of that until we do.

assert_expected_structure(aggregator, matrix)


def test_reset_restores_first_step_behavior() -> None:
Contributor


Need a test of CRMOGMWeighting wrapping GradVacWeighting (a stateful weighting), to verify that the reset method correctly calls the underlying weighting's reset method.
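A hedged sketch of such a test, assuming GradVacWeighting accumulates state across calls and can be constructed with defaults:

def test_reset_propagates_to_wrapped_weighting() -> None:
    matrix = torch.randn(3, 5)
    weighting = CRMOGMWeighting(GradVacWeighting(), alpha=0.9)
    first = weighting(matrix).clone()
    weighting(matrix)  # advances both the EMA and the wrapped weighting's state
    weighting.reset()
    # After reset, the wrapper and the wrapped weighting should behave as if
    # freshly constructed.
    assert torch.equal(weighting(matrix), first)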

Comment thread src/torchjd/aggregation/_cr_mogm.py Outdated
:class:`~torchjd.aggregation._weighting_bases.Weighting` that wraps another
:class:`~torchjd.aggregation._weighting_bases.Weighting` and stabilises the weights it
produces with an exponential moving average (EMA) across calls. This is the weight-smoothing
modifier from Conflict-Reduction Multi-Objective Gradient Methods (NeurIPS 2022).
Contributor


Can add the link to https://proceedings.neurips.cc/paper_files/paper/2022/file/f91bd64a3620aad8e70a27ad9cb3ca57-Paper-Conference.pdf

Also the name of the paper is: On the Convergence of Stochastic Multi-Objective Gradient Manipulation and Beyond

Comment thread src/torchjd/aggregation/_cr_mogm.py Outdated
self.weighting = weighting
self.alpha = alpha
self._lambda: Tensor | None = None
self._state_key: tuple[int, torch.dtype, torch.device] | None = None
Contributor


Do we even need a state_key? I think we don't use it.

aggregator.gramian_weighting.alpha = 1 - current_lr / initial_lr
"""

def __init__(self, weighting: Weighting[_T], alpha: float = 0.1) -> None:
Contributor


I think the initialization of lambda is debatable.

For now, we have 1/m all the time.

Maybe sometimes a user wants to provide their own starting weights (btw they don't need to be in the simplex; even though it's stated like that in the paper, I think it's a mistake).

So we could have an initial_weights parameter, of type Tensor | None, so that the user can provide their weights, or we use 1/m if they don't.

The alternative would be to keep the type Tensor | None, but if the user gives None, we use lambda_0 = lambda_1_hat.

This means that the first weights output by the CRMOGMWeighting would be lambda_1 = alpha * lambda_1_hat + (1 - alpha) * lambda_1_hat = lambda_1_hat.

I don't know which option we should go for. @PierreQuinton we may need your insight on this.

Contributor


The weights definitely can leave the simplex (maybe not for MGDA, but for most other weightings it is not the case), so I would not limit ourselves to the simplex.

I don't know about the second question, but I would use the default value they use in the paper, which seems to be what is currently implemented.
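If the initial_weights option were adopted, the constructor could look like this sketch (the parameter is only proposed in this thread, not part of the PR):

def __init__(
    self,
    weighting: Weighting[_T],
    alpha: float = 0.1,
    initial_weights: Tensor | None = None,  # proposed above; None means uniform 1/m
) -> None:
    self.weighting = weighting
    self.alpha = alpha
    # With None, the state falls back to uniform weights at the first call.
    self._lambda: Tensor | None = initial_weights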

Comment thread CHANGELOG.md Outdated
@KhusPatel4450
Author

All the things addressed:

  • Reset propagation: reset() now calls self.weighting.reset() if the wrapped weighting is Stateful.

  • device/dtype/m from weighting output: forward() now calls self.weighting(stat) first and reads everything from lambda_hat, not from stat.

  • Removed __repr__ and the corresponding test_representations test.

  • Removed _state_key: _ensure_state now checks shape/dtype/device directly off _lambda.

  • Added test for reset() propagation using GradVacWeighting as the inner stateful weighting.

Still open:

  • Initial weight strategy (uniform 1/m vs first weighting output)

  • Type checking failure

_T = TypeVar("_T", contravariant=True, bound=Tensor)


class CRMOGMWeighting(Weighting[_T], Stateful):
Contributor


I think the name could be improved to reflect that this is purely a weighting wrapper, maybe CRMOGMWeightingWrapper?

Contributor


I think it's a bit too long, so I'd rather not specify that in general. The same applies to normalizers, which, if implemented as wrappers, would always have Wrapper in their name.

Same for mixins: we don't use Mixin in their names.

Comment thread src/torchjd/aggregation/_cr_mogm.py Outdated
Comment on lines +38 to +39
``MatrixWeighting`` or a ``GramianWeighting``. The user composes it with the appropriate
aggregator base:
Contributor


Maybe?

Suggested change:
- ``MatrixWeighting`` or a ``GramianWeighting``. The user composes it with the appropriate
- aggregator base:
+ ``MatrixWeighting`` or a ``GramianWeighting``. Creating a corresponding
+ :class:`~torchjd.aggregation._aggregator_bases.Aggregator` can be done by composing
+ it with the appropriate aggregator base:


Comment thread src/torchjd/aggregation/_cr_mogm.py Outdated
Comment on lines +52 to +53
This weighting is stateful: it keeps :math:`\lambda_{k-1}` across calls. Use :meth:`reset`
when restarting the smoothing from uniform weights.
Contributor


Maybe we should add a note that it will also reset the wrapped Weighting when reset is called, so that there are no surprises?

Comment thread src/torchjd/aggregation/_cr_mogm.py Outdated
Comment on lines +102 to +103
self._ensure_state(lambda_hat.shape[0], lambda_hat.dtype, lambda_hat.device)
lambda_prev = cast(Tensor, self._lambda)
Contributor


@ValerianRey Not sure this is a good idea, but if _ensure_state were to return self._lambda afterwards, then we could merge those two lines and remove the cast. This would give several responsibilities to _ensure_state though.

Contributor


Could do that yeah.
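A sketch of the call site under that change:

# If _ensure_state initializes and returns self._lambda, the two lines above
# collapse into one and the cast disappears:
lambda_prev = self._ensure_state(lambda_hat.shape[0], lambda_hat.dtype, lambda_hat.device)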

Comment thread src/torchjd/aggregation/_cr_mogm.py Outdated
Comment on lines +111 to +120
if (
self._lambda is None
or self._lambda.shape[0] != m
or self._lambda.dtype != dtype
or self._lambda.device != device
):
if m > 0:
self._lambda = torch.full((m,), 1.0 / m, dtype=dtype, device=device)
else:
self._lambda = torch.zeros(0, dtype=dtype, device=device)
Contributor


I think I would remove the checks on dtype and device and let torch fail if there is a problem. I would also raise an error if self._lambda.shape[0] != m, because it looks like something users want to avoid (or maybe not, if we think about IWRM and the last batch being smaller).

I would also remove the inner condition; I think it is safe to assume that m != 0, and I hope something would fail before we reach this. If we want to handle this, I would raise an error if m == 0.

Contributor


We should not use CRMOGM for ssjd IMO.

Contributor


But I agree that we can remove most of these checks from this function. Its role is just to initialize lambda if it is None.

Contributor


You are right: if it is not the same shape, then it doesn't make sense, so I would actually raise. If a user comes up with a use for this, we can remove the check later on.
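Putting the thread's conclusions together, a sketch of the simplified helper (combined with the earlier suggestion that it return the state; the merged version may differ):

def _ensure_state(self, m: int, dtype: torch.dtype, device: torch.device) -> Tensor:
    # Initialize lambda on the first call; afterwards, only guard the shape and
    # let torch surface any dtype or device mismatch on its own.
    if self._lambda is None:
        self._lambda = torch.full((m,), 1.0 / m, dtype=dtype, device=device)
    elif self._lambda.shape[0] != m:
        raise ValueError(f"Number of objectives changed from {self._lambda.shape[0]} to {m}.")
    return self._lambda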

Comment thread CHANGELOG.md
@KhusPatel4450
Author

Hello, I updated the code with the requested changes in these two commits; the 2nd commit has the simplified version and the raise on shape change in CRMOGMWeighting._ensure_state.

@ValerianRey
Contributor

ValerianRey commented May 8, 2026

I have a plan to fix the typing issue. Coming today or tomorrow.

Edit: should be fixed with #673.


Labels

cc: feat Conventional commit type for new features. package: aggregation


3 participants