diff --git a/CHANGELOG.md b/CHANGELOG.md index bfd051c07..5104aa77e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,11 @@ changelog does not include internal changes that do not affect the user. ## [Unreleased] +### Added + +- Added `GradVac` and `GradVacWeighting` from + [Gradient Vaccine: Investigating and Improving Multi-task Optimization in Massively Multilingual Models](https://arxiv.org/pdf/2010.05874). + ### Fixed - Added a fallback for when the inner optimization of `NashMTL` fails (which can happen for example diff --git a/README.md b/README.md index ccf443fc0..05df2b1e1 100644 --- a/README.md +++ b/README.md @@ -281,6 +281,7 @@ TorchJD provides many existing aggregators from the literature, listed in the fo | [Constant](https://torchjd.org/stable/docs/aggregation/constant#torchjd.aggregation.Constant) | [ConstantWeighting](https://torchjd.org/stable/docs/aggregation/constant#torchjd.aggregation.ConstantWeighting) | - | | [DualProj](https://torchjd.org/stable/docs/aggregation/dualproj#torchjd.aggregation.DualProj) | [DualProjWeighting](https://torchjd.org/stable/docs/aggregation/dualproj#torchjd.aggregation.DualProjWeighting) | [Gradient Episodic Memory for Continual Learning](https://arxiv.org/pdf/1706.08840) | | [GradDrop](https://torchjd.org/stable/docs/aggregation/graddrop#torchjd.aggregation.GradDrop) | - | [Just Pick a Sign: Optimizing Deep Multitask Models with Gradient Sign Dropout](https://arxiv.org/pdf/2010.06808) | +| [GradVac](https://torchjd.org/stable/docs/aggregation/gradvac#torchjd.aggregation.GradVac) | [GradVacWeighting](https://torchjd.org/stable/docs/aggregation/gradvac#torchjd.aggregation.GradVacWeighting) | [Gradient Vaccine: Investigating and Improving Multi-task Optimization in Massively Multilingual Models](https://arxiv.org/pdf/2010.05874) | | [IMTLG](https://torchjd.org/stable/docs/aggregation/imtl_g#torchjd.aggregation.IMTLG) | 
[IMTLGWeighting](https://torchjd.org/stable/docs/aggregation/imtl_g#torchjd.aggregation.IMTLGWeighting) | [Towards Impartial Multi-task Learning](https://discovery.ucl.ac.uk/id/eprint/10120667/) | | [Krum](https://torchjd.org/stable/docs/aggregation/krum#torchjd.aggregation.Krum) | [KrumWeighting](https://torchjd.org/stable/docs/aggregation/krum#torchjd.aggregation.KrumWeighting) | [Machine Learning with Adversaries: Byzantine Tolerant Gradient Descent](https://proceedings.neurips.cc/paper/2017/file/f4b9ec30ad9f68f89b29639786cb62ef-Paper.pdf) | | [Mean](https://torchjd.org/stable/docs/aggregation/mean#torchjd.aggregation.Mean) | [MeanWeighting](https://torchjd.org/stable/docs/aggregation/mean#torchjd.aggregation.MeanWeighting) | - | diff --git a/docs/source/docs/aggregation/gradvac.rst b/docs/source/docs/aggregation/gradvac.rst new file mode 100644 index 000000000..1c2a7d0ae --- /dev/null +++ b/docs/source/docs/aggregation/gradvac.rst @@ -0,0 +1,14 @@ +:hide-toc: + +GradVac +======= + +.. autoclass:: torchjd.aggregation.GradVac + :members: + :undoc-members: + :exclude-members: forward, eps, beta + +.. 
class GradVac(GramianWeightedAggregator):
    r"""
    :class:`~torchjd.aggregation._aggregator_bases.Aggregator` implementing the aggregation step of
    Gradient Vaccine (GradVac) from `Gradient Vaccine: Investigating and Improving Multi-task
    Optimization in Massively Multilingual Models (ICLR 2021 Spotlight)
    <https://arxiv.org/pdf/2010.05874>`_.

    For each task :math:`i`, the order in which other tasks :math:`j` are visited is drawn at
    random. For each pair :math:`(i, j)`, the cosine similarity :math:`\phi_{ij}` between the
    (possibly already modified) gradient of task :math:`i` and the original gradient of task
    :math:`j` is compared to an EMA target :math:`\hat{\phi}_{ij}`. When
    :math:`\phi_{ij} < \hat{\phi}_{ij}`, a closed-form correction adds a scaled copy of
    :math:`g_j` to :math:`g_i^{(\mathrm{PC})}`. The EMA is then updated with
    :math:`\hat{\phi}_{ij} \leftarrow (1-\beta)\hat{\phi}_{ij} + \beta \phi_{ij}`. The aggregated
    vector is the sum of the modified rows.

    This aggregator is stateful: it keeps :math:`\hat{\phi}` across calls. The state is
    re-created automatically (with all targets set to zero) whenever the number of tasks or the
    dtype of the input changes; call :meth:`reset` to clear it explicitly.

    :param beta: EMA decay for :math:`\hat{\phi}`. Must be in :math:`[0, 1]`.
    :param eps: Small non-negative constant added to denominators.

    .. note::
        For each task :math:`i`, the order of other tasks :math:`j` is shuffled independently
        using the global PyTorch RNG (``torch.randperm``). Seed it with ``torch.manual_seed`` if
        you need reproducibility.
    """

    def __init__(self, beta: float = 0.5, eps: float = 1e-8) -> None:
        # Validation of `beta` and `eps` is delegated to the weighting, which owns both values.
        weighting = GradVacWeighting(beta=beta, eps=eps)
        super().__init__(weighting)
        self._gradvac_weighting = weighting
        self.register_full_backward_pre_hook(raise_non_differentiable_error)

    @property
    def beta(self) -> float:
        return self._gradvac_weighting.beta

    @beta.setter
    def beta(self, value: float) -> None:
        self._gradvac_weighting.beta = value

    @property
    def eps(self) -> float:
        return self._gradvac_weighting.eps

    @eps.setter
    def eps(self, value: float) -> None:
        self._gradvac_weighting.eps = value

    def reset(self) -> None:
        """Clears EMA state so the next forward starts from zero targets."""

        self._gradvac_weighting.reset()

    def __repr__(self) -> str:
        return f"GradVac(beta={self.beta!r}, eps={self.eps!r})"


class GradVacWeighting(Weighting[PSDMatrix]):
    r"""
    :class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of
    :class:`~torchjd.aggregation.GradVac`.

    All required quantities (gradient norms, cosine similarities, and their updates after the
    vaccine correction) are derived purely from the Gramian, without needing the full Jacobian.
    If :math:`g_i^{(\mathrm{PC})} = \sum_k c_{ik} g_k`, then:

    .. math::

        \|g_i^{(\mathrm{PC})}\|^2 = \mathbf{c}_i G \mathbf{c}_i^\top,\qquad
        g_i^{(\mathrm{PC})} \cdot g_j = \mathbf{c}_i G_{:,j}

    where :math:`G` is the Gramian. The correction :math:`g_i^{(\mathrm{PC})} \mathrel{+}= w
    g_j` then becomes :math:`c_{ij} \mathrel{+}= w`, and the updated dot products follow
    immediately.

    This weighting is stateful: it keeps :math:`\hat{\phi}` across calls. The state is
    re-created automatically (with all targets set to zero) whenever the number of tasks or the
    dtype of the Gramian changes; call :meth:`reset` to clear it explicitly. The stored targets
    are detached from the autograd graph.

    :param beta: EMA decay for :math:`\hat{\phi}`. Must be in :math:`[0, 1]`.
    :param eps: Small non-negative constant added to denominators.
    :raises ValueError: If ``beta`` is outside :math:`[0, 1]` or ``eps`` is negative.
    """

    def __init__(self, beta: float = 0.5, eps: float = 1e-8) -> None:
        super().__init__()
        if not (0.0 <= beta <= 1.0):
            raise ValueError(f"Parameter `beta` must be in [0, 1]. Found beta={beta!r}.")
        if eps < 0.0:
            raise ValueError(f"Parameter `eps` must be non-negative. Found eps={eps!r}.")

        self._beta = beta
        self._eps = eps
        # EMA targets (m x m, kept on CPU) and the (m, dtype) pair they were allocated for.
        self._phi_t: Tensor | None = None
        self._state_key: tuple[int, torch.dtype] | None = None

    @property
    def beta(self) -> float:
        return self._beta

    @beta.setter
    def beta(self, value: float) -> None:
        if not (0.0 <= value <= 1.0):
            raise ValueError(f"Attribute `beta` must be in [0, 1]. Found beta={value!r}.")
        self._beta = value

    @property
    def eps(self) -> float:
        return self._eps

    @eps.setter
    def eps(self, value: float) -> None:
        if value < 0.0:
            raise ValueError(f"Attribute `eps` must be non-negative. Found eps={value!r}.")
        self._eps = value

    def reset(self) -> None:
        """Clears EMA state so the next forward starts from zero targets."""

        self._phi_t = None
        self._state_key = None

    def forward(self, gramian: PSDMatrix, /) -> Tensor:
        # Move all computations on cpu to avoid moving memory between cpu and gpu at each iteration
        device = gramian.device
        dtype = gramian.dtype
        cpu = torch.device("cpu")

        G = cast(PSDMatrix, gramian.to(device=cpu))
        m = G.shape[0]

        self._ensure_state(m, dtype)
        phi_t = cast(Tensor, self._phi_t)

        beta = self._beta
        eps = self._eps

        # C[i, :] holds coefficients such that g_i^PC = sum_k C[i,k] * g_k (original gradients).
        # Initially each modified gradient equals the original, so C = I.
        C = torch.eye(m, device=cpu, dtype=dtype)

        # Norms of the original gradients never change during the procedure, so they are computed
        # once. The clamp guards against tiny negative diagonal entries caused by round-off.
        norms = G.diagonal().clamp(min=0.0).sqrt()

        for i in range(m):
            # Dot products of g_i^PC with every original g_j, shape (m,).
            cG = C[i] @ G

            others = [j for j in range(m) if j != i]
            perm = torch.randperm(len(others))

            for idx in perm.tolist():
                j = others[idx]

                dot_ij = cG[j]
                norm_i_sq = (cG * C[i]).sum()
                norm_i = norm_i_sq.clamp(min=0.0).sqrt()
                norm_j = norms[j]
                denom = norm_i * norm_j + eps
                phi_ij = dot_ij / denom

                phi_hat = phi_t[i, j]
                if phi_ij < phi_hat:
                    sqrt_1_phi2 = (1.0 - phi_ij * phi_ij).clamp(min=0.0).sqrt()
                    sqrt_1_hat2 = (1.0 - phi_hat * phi_hat).clamp(min=0.0).sqrt()
                    denom_w = norm_j * sqrt_1_hat2 + eps
                    w = norm_i * (phi_hat * sqrt_1_phi2 - phi_ij * sqrt_1_hat2) / denom_w
                    C[i, j] = C[i, j] + w
                    cG = cG + w * G[j]

                # The EMA target outlives this call. Store a detached value so that the state
                # never keeps the autograd graph of a previous call alive.
                phi_t[i, j] = ((1.0 - beta) * phi_hat + beta * phi_ij).detach()

        weights = C.sum(dim=0)
        return weights.to(device)

    def _ensure_state(self, m: int, dtype: torch.dtype) -> None:
        """Allocates zero EMA targets if none exist or if ``m`` or ``dtype`` changed."""

        key = (m, dtype)
        if self._phi_t is None or self._state_key != key:
            # Pinned to CPU because `forward` runs entirely on CPU, whatever the default device.
            self._phi_t = torch.zeros(m, m, dtype=dtype, device="cpu")
            self._state_key = key
# Each pair gets its own GradVac instance because the aggregator is stateful.
scaled_pairs = [(GradVac(), m) for m in scaled_matrices]
typical_pairs = [(GradVac(), m) for m in typical_matrices]
requires_grad_pairs = [(GradVac(), ones_(3, 5, requires_grad=True))]


def test_representations() -> None:
    """`repr` exposes the hyper-parameters, `str` only the class name."""

    gradvac = GradVac()
    assert repr(gradvac) == "GradVac(beta=0.5, eps=1e-08)"
    assert str(gradvac) == "GradVac"


def test_beta_out_of_range() -> None:
    """The constructor rejects `beta` values outside [0, 1]."""

    for bad_beta in (-0.1, 1.1):
        with raises(ValueError, match="beta"):
            GradVac(beta=bad_beta)


def test_beta_setter_out_of_range() -> None:
    """The `beta` setter rejects values outside [0, 1]."""

    gradvac = GradVac()
    for bad_beta in (-0.1, 1.1):
        with raises(ValueError, match="beta"):
            gradvac.beta = bad_beta


def test_beta_setter_updates_value() -> None:
    """A valid `beta` assigned through the setter is read back unchanged."""

    gradvac = GradVac()
    gradvac.beta = 0.25
    assert gradvac.beta == 0.25


def test_eps_rejects_negative() -> None:
    """The constructor rejects a negative `eps`."""

    with raises(ValueError, match="eps"):
        GradVac(eps=-1e-9)


def test_eps_setter_rejects_negative() -> None:
    """The `eps` setter rejects a negative value."""

    gradvac = GradVac()
    with raises(ValueError, match="eps"):
        gradvac.eps = -1e-9


def test_eps_can_be_changed_between_steps() -> None:
    """Changing `eps` between calls still yields finite outputs."""

    jacobian = tensor_([[1.0, 0.0], [0.0, 1.0]])
    gradvac = GradVac()

    gradvac.eps = 1e-6
    assert gradvac(jacobian).isfinite().all()

    gradvac.reset()
    gradvac.eps = 1e-10
    assert gradvac(jacobian).isfinite().all()


def test_zero_rows_returns_zero_vector() -> None:
    """A matrix with no rows aggregates to the zero vector."""

    empty_rows = tensor_([]).reshape(0, 3)
    result = GradVac()(empty_rows)
    assert_close(result, tensor_([0.0, 0.0, 0.0]))


def test_zero_columns_returns_zero_vector() -> None:
    """A matrix with no columns aggregates to an empty vector."""

    empty_columns = tensor_([]).reshape(2, 0)
    result = GradVac()(empty_columns)
    assert result.shape == (0,)


def test_reproducible_with_manual_seed() -> None:
    """Two fresh aggregators seeded identically produce the same output."""

    jacobian = randn_((3, 8))

    torch.manual_seed(12345)
    first_aggregator = GradVac(beta=0.3)
    first_output = first_aggregator(jacobian)

    torch.manual_seed(12345)
    second_aggregator = GradVac(beta=0.3)
    second_output = second_aggregator(jacobian)

    assert_close(first_output, second_output)


@mark.parametrize("matrix", typical_matrices_2_plus_rows)
def test_reset_restores_first_step_behavior(matrix: Tensor) -> None:
    """After `reset`, the aggregator behaves as it did on its very first call."""

    torch.manual_seed(7)
    gradvac = GradVac(beta=0.5)
    first_output = gradvac(matrix)

    # Second call only serves to move the EMA state away from its initial value.
    gradvac(matrix)

    gradvac.reset()
    torch.manual_seed(7)
    replayed_output = gradvac(matrix)
    assert_close(first_output, replayed_output)


@mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs)
def test_expected_structure(aggregator: GradVac, matrix: Tensor) -> None:
    """GradVac satisfies the structural checks shared by all aggregators."""

    assert_expected_structure(aggregator, matrix)


@mark.parametrize(["aggregator", "matrix"], requires_grad_pairs)
def test_non_differentiable(aggregator: GradVac, matrix: Tensor) -> None:
    """GradVac satisfies the shared non-differentiability check."""

    assert_non_differentiable(aggregator, matrix)


def test_weighting_beta_out_of_range() -> None:
    """The weighting constructor rejects `beta` values outside [0, 1]."""

    for bad_beta in (-0.1, 1.1):
        with raises(ValueError, match="beta"):
            GradVacWeighting(beta=bad_beta)


def test_weighting_eps_rejects_negative() -> None:
    """The weighting constructor rejects a negative `eps`."""

    with raises(ValueError, match="eps"):
        GradVacWeighting(eps=-1e-9)


def test_weighting_reset_restores_first_step_behavior() -> None:
    """After `reset`, the weighting behaves as it did on its very first call."""

    jacobian = randn_((3, 8))
    gramian = jacobian @ jacobian.T

    torch.manual_seed(7)
    weighting = GradVacWeighting(beta=0.5)
    first_weights = weighting(gramian)

    # Second call only serves to move the EMA state away from its initial value.
    weighting(gramian)

    weighting.reset()
    torch.manual_seed(7)
    replayed_weights = weighting(gramian)
    assert_close(first_weights, replayed_weights)


def test_aggregator_and_weighting_agree() -> None:
    """GradVac()(J) == GradVacWeighting()(J @ J.T) @ J for any matrix J."""

    jacobian = randn_((3, 8))
    gramian = jacobian @ jacobian.T

    torch.manual_seed(42)
    aggregator = GradVac(beta=0.3)
    expected = aggregator(jacobian)

    torch.manual_seed(42)
    weighting = GradVacWeighting(beta=0.3)
    result = weighting(gramian) @ jacobian

    assert_close(result, expected, rtol=1e-4, atol=1e-4)
J_base, tensor([0.0767, 1.0000, 1.0000])), (Krum(n_byzantine=1, n_selected=4), J_Krum, tensor([1.2500, 0.7500, 1.5000])), (Mean(), J_base, tensor([1.0, 1.0, 1.0])), @@ -77,6 +80,7 @@ (DualProjWeighting(), G_base, tensor([0.6109, 0.5000])), (IMTLGWeighting(), G_base, tensor([0.5923, 0.4077])), (KrumWeighting(1, 4), G_Krum, tensor([0.2500, 0.2500, 0.0000, 0.2500, 0.2500])), + (GradVacWeighting(), G_base, tensor([2.2222, 1.5789])), (MeanWeighting(), G_base, tensor([0.5000, 0.5000])), (MGDAWeighting(), G_base, tensor([0.6000, 0.4000])), (PCGradWeighting(), G_base, tensor([2.2222, 1.5789])),