feat(aggregation): Add CRMOGMWeighting #669

KhusPatel4450 wants to merge 10 commits into SimplexLab:main from
Conversation
|
Thanks a lot for the PR! I'm gonna review soon! In the meantime, you can try to get the CI to pass.
|
Hello! Happy to say, all checks have passed! Glad to have made my first PR as well. Looking forward to feedback.
| .. note::
|     The usage example in the docstring above imports
|     ``WeightedAggregator`` / ``GramianWeightedAggregator`` from
|     ``torchjd.aggregation._aggregator_bases``, which is a private module. These two
|     aggregator base classes are not currently part of the public ``torchjd.aggregation``
|     namespace, so this private-module import is the only path that works today. Promoting
|     them to the public namespace is a separate decision left to the maintainers.
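For reference, the import the note describes would look like this (a sketch of the only currently working, private-module path):

```python
from torchjd.aggregation._aggregator_bases import (
    GramianWeightedAggregator,
    WeightedAggregator,
)
```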
|
I forgot about that, but it's a major issue: those classes should become public if we want users to be able to comfortably use CRMOGMWeighting.
I'll make these classes public in another PR, together with all the required changes. It's essentially a prerequisite for your PR, IMO.
Thanks for being specific about this!
|
Done. We can now remove this note and update every import of those classes, now that they are public.
|
Seems like there is still a problem with this: WeightedAggregator now takes a MatrixWeighting (and not a Weighting[Matrix]), and CRMOGMWeighting[Matrix] is not a subtype of MatrixWeighting. See the type checking action that fails. I don't know how to fix this...
I'll think about that tomorrow probably.
|
I think what we can do is revert the __init__ annotations back to the broader types in _aggregator_bases.py:
def __init__(self, weighting: Weighting[Matrix]) -> None:
def __init__(self, gramian_weighting: Weighting[PSDMatrix]) -> None:
Do the aggregators really need to require the specific subtypes that they currently do?
class WeightedAggregator(Aggregator):
    def __init__(self, weighting: MatrixWeighting) -> None:
class GramianWeightedAggregator(WeightedAggregator):
    def __init__(self, gramian_weighting: GramianWeighting) -> None:
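For illustration, a sketch of the proposed revert in _aggregator_bases.py, under which CRMOGMWeighting[Matrix] would again be an acceptable argument (class and type names are taken from this thread):

```python
class WeightedAggregator(Aggregator):
    # Accepts any weighting over matrices, not only MatrixWeighting subtypes.
    def __init__(self, weighting: Weighting[Matrix]) -> None:
        ...


class GramianWeightedAggregator(WeightedAggregator):
    # Likewise, any weighting over PSD matrices (Gramians).
    def __init__(self, gramian_weighting: Weighting[PSDMatrix]) -> None:
        ...
```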
| def reset(self) -> None:
|     """Clears the EMA state so the next forward starts from uniform weights."""
|     self._lambda = None
|     self._state_key = None
|
This should also call the wrapped weighting's reset method if it has one.
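For illustration, a minimal sketch of the suggested reset, assuming Stateful (which CRMOGMWeighting inherits from later in this thread) supports isinstance checks and defines reset:

```python
def reset(self) -> None:
    """Clears the EMA state and resets the wrapped weighting if it is stateful."""
    self._lambda = None
    self._state_key = None
    # Also clear any internal state kept by the wrapped weighting.
    if isinstance(self.weighting, Stateful):
        self.weighting.reset()
```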
| device = stat.device
| dtype = stat.dtype
| m = stat.shape[0]
|
I'm not a fan of getting this information from stat, which is currently a _T (bound to Tensor) but may change in the future to be any stat computed from a Jacobian matrix (not just a tensor).
So I would rather first call self.weighting(stat) to obtain lambda_hat, and deduce device, dtype and m from this lambda_hat.
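For illustration, a sketch of the suggested reordering (names as in the diff above):

```python
# Call the wrapped weighting first, then deduce everything from its output.
lambda_hat = self.weighting(stat)
device = lambda_hat.device
dtype = lambda_hat.dtype
m = lambda_hat.shape[0]
```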
| def test_representations() -> None:
|     W = CRMOGMWeighting(MeanWeighting(), alpha=0.9)
|     expected = "CRMOGMWeighting(weighting=MeanWeighting(), alpha=0.9)"
|     # Weighting does not define __str__, so it falls back to __repr__.
|     assert repr(W) == expected
|     assert str(W) == expected
|
We don't currently define __repr__ or __str__ for the weightings, so we can get rid of that until we do.
| def __repr__(self) -> str:
|     return f"CRMOGMWeighting(weighting={self.weighting!r}, alpha={self.alpha!r})"
|
We don't currently define __repr__ or __str__ for the weightings, so we can get rid of that until we do.
| assert_expected_structure(aggregator, matrix)
|
|
| def test_reset_restores_first_step_behavior() -> None:
|
Need a test of CRMOGMWeighting wrapping GradVacWeighting (a stateful weighting), to verify that the reset method correctly calls the underlying weighting's reset method.
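A possible shape for such a test (a sketch; GradVacWeighting's constructor arguments are assumed to be defaultable):

```python
import torch


def test_reset_also_resets_wrapped_weighting() -> None:
    weighting = CRMOGMWeighting(GradVacWeighting(), alpha=0.9)
    matrix = torch.randn(3, 5)
    _ = weighting(matrix)  # Fills the EMA state and GradVac's internal state.
    weighting.reset()
    # After reset, the output should match that of a freshly constructed pair.
    fresh = CRMOGMWeighting(GradVacWeighting(), alpha=0.9)
    assert torch.equal(weighting(matrix), fresh(matrix))
```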
| :class:`~torchjd.aggregation._weighting_bases.Weighting` that wraps another
| :class:`~torchjd.aggregation._weighting_bases.Weighting` and stabilises the weights it
| produces with an exponential moving average (EMA) across calls. This is the weight-smoothing
| modifier from Conflict-Reduction Multi-Objective Gradient Methods (NeurIPS 2022).
|
You can add a link to https://proceedings.neurips.cc/paper_files/paper/2022/file/f91bd64a3620aad8e70a27ad9cb3ca57-Paper-Conference.pdf
Also, the name of the paper is: On the Convergence of Stochastic Multi-Objective Gradient Manipulation and Beyond.
| self.weighting = weighting
| self.alpha = alpha
| self._lambda: Tensor | None = None
| self._state_key: tuple[int, torch.dtype, torch.device] | None = None
|
Do we even need a state_key? I don't think we use it.
| aggregator.gramian_weighting.alpha = 1 - current_lr / initial_lr
| """
|
|
| def __init__(self, weighting: Weighting[_T], alpha: float = 0.1) -> None:
|
I think the initialization of lambda is debatable.
For now, we have 1/m all the time.
Maybe sometimes a user wants to provide their own starting weights (btw, they don't need to be in the simplex; even though it's stated like that in the paper, I think it's a mistake).
So we could have an initial_weights parameter, of type Tensor | None, so that the user can provide their weights, or we use 1/m if they don't.
The alternative would be to still have type Tensor | None, but if the user gives None, we use lambda_0 = lambda_1_hat.
This means that the first weights output by the CRMOGMWeighting will be lambda_1 = alpha * lambda_1_hat + (1 - alpha) * lambda_1_hat = lambda_1_hat.
I don't know which option we should go for. @PierreQuinton, maybe we need your insight on this.
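For illustration, a sketch of the first option (the initial_weights parameter is the one proposed here, not existing API):

```python
def __init__(
    self,
    weighting: Weighting[_T],
    alpha: float = 0.1,
    initial_weights: Tensor | None = None,
) -> None:
    self.weighting = weighting
    self.alpha = alpha
    # None means "initialize lazily to uniform 1/m weights on the first call".
    self._lambda: Tensor | None = initial_weights
```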
|
The weights definitely can leave the simplex (maybe not for MGDA, but for most other weightings they are not confined to it). So I would not limit ourselves to the simplex.
I don't know about the second question, but I would use the default value they use in the paper, which seems to be what the current implementation does.
|
All the things addressed:
Still open:
|
| _T = TypeVar("_T", contravariant=True, bound=Tensor)
|
|
| class CRMOGMWeighting(Weighting[_T], Stateful):
|
I think the name could be improved to reflect that this is purely a weighting wrapper, maybe CRMOGMWeightingWrapper?
|
I think it's a bit too long, so I'd rather not specify that in general. The same applies to normalizers, which, if implemented as wrappers, would always have Wrapper in their name.
Same for mixins: we don't use Mixin in their names.
| ``MatrixWeighting`` or a ``GramianWeighting``. The user composes it with the appropriate
| aggregator base:
|
Maybe?
Suggested change:
- ``MatrixWeighting`` or a ``GramianWeighting``. The user composes it with the appropriate
- aggregator base:
+ ``MatrixWeighting`` or a ``GramianWeighting``. Creating a corresponding
+ :class:`~torchjd.aggregation._weighting_bases.Aggregator` can be done by composing
+ it with the appropriate aggregator base:
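For illustration, a hypothetical composition following that wording (MeanWeighting is borrowed from the tests quoted above; which aggregator base applies depends on whether the wrapped weighting operates on matrices or on Gramians):

```python
weighting = CRMOGMWeighting(MeanWeighting(), alpha=0.1)
aggregator = WeightedAggregator(weighting)
```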
| This weighting is stateful: it keeps :math:`\lambda_{k-1}` across calls. Use :meth:`reset`
| when restarting the smoothing from uniform weights.
|
Maybe we should add a note that it will also reset the wrapped Weighting on a call to reset, so that there are no surprises?
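One possible wording for such a note, mirroring the docstring style quoted above (a sketch):

```
.. note::
    Calling :meth:`reset` also resets the wrapped
    :class:`~torchjd.aggregation._weighting_bases.Weighting` if it is stateful.
```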
| self._ensure_state(lambda_hat.shape[0], lambda_hat.dtype, lambda_hat.device)
| lambda_prev = cast(Tensor, self._lambda)
|
@ValerianRey Not sure this is a good idea, but if _ensure_state were to return self._lambda, then we could merge those two lines and remove the cast. This would give several responsibilities to _ensure_state, though.
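For illustration, a sketch of that variant (names as in the diff above):

```python
def _ensure_state(self, m: int, dtype: torch.dtype, device: torch.device) -> Tensor:
    # Initialize the state if needed and return it, so the call site needs no cast.
    if self._lambda is None:
        self._lambda = torch.full((m,), 1.0 / m, dtype=dtype, device=device)
    return self._lambda


# The two lines at the call site then collapse into one:
lambda_prev = self._ensure_state(lambda_hat.shape[0], lambda_hat.dtype, lambda_hat.device)
```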
| if (
|     self._lambda is None
|     or self._lambda.shape[0] != m
|     or self._lambda.dtype != dtype
|     or self._lambda.device != device
| ):
|     if m > 0:
|         self._lambda = torch.full((m,), 1.0 / m, dtype=dtype, device=device)
|     else:
|         self._lambda = torch.zeros(0, dtype=dtype, device=device)
|
I think I would remove the checks on dtype and device and let torch fail if there is a problem. I would also raise an error if self._lambda.shape[0] != m, because it looks like something users want to avoid (or maybe not, if we think about IWRM and the last batch being smaller).
I would also remove the inner condition; I think it is safe to assume that m != 0, and I hope something would fail before we reach this. If we want to guard against it, I would raise an error if m == 0.
|
We should not use CRMOGM for ssjd IMO.
|
But I agree that we can remove most of these checks from this function. Its role is just to initialize lambda if it is None.
|
You are right: if it is not the same shape, then it doesn't make sense, so I would actually raise. If a user comes up with a use case for this, we can remove the check later on.
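For illustration, a sketch of _ensure_state after the simplifications converged on in this thread (initialize when None, raise on a row-count change, drop the dtype/device/m == 0 checks):

```python
def _ensure_state(self, m: int, dtype: torch.dtype, device: torch.device) -> None:
    if self._lambda is None:
        self._lambda = torch.full((m,), 1.0 / m, dtype=dtype, device=device)
    elif self._lambda.shape[0] != m:
        raise ValueError(
            f"Number of weights changed between calls: "
            f"expected {self._lambda.shape[0]}, got {m}."
        )
```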
|
Hello, I updated the code with the changes that were requested in these two commits; it's just that the second commit has the simplified version and the raise on shape change in CRMOGMWeighting._ensure_state. |
|
I have a plan to fix the typing issue. Coming today or tomorrow. Edit: should be fixed with #673. |
Tests: