7 changes: 7 additions & 0 deletions CHANGELOG.md
@@ -8,6 +8,13 @@ changelog does not include internal changes that do not affect the user.

## [Unreleased]

### Changed

- Non-differentiable aggregators and weightings (UPGrad, DualProj, PCGrad, GradVac, IMTLG,
GradDrop, ConFIG, CAGrad, NashMTL) no longer build a computation graph when called on tensors
that require gradients. Their forward pass is now wrapped in `torch.no_grad()`, so differentiating
through them is no longer possible (previously, this raised a `NonDifferentiableError`).

### Added

- Made `WeightedAggregator` and `GramianWeightedAggregator` public. These abstract base classes are
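A minimal sketch of the new behavior described in the changelog entry above — not part of the diff. PCGrad is chosen only because its constructor takes no arguments; the matrix shape is arbitrary:

```python
import torch

from torchjd.aggregation import PCGrad

aggregator = PCGrad()
matrix = torch.randn(4, 5, requires_grad=True)  # one row of gradients per task

vector = aggregator(matrix)

# The forward pass now runs under torch.no_grad(), so no autograd graph is recorded.
assert not vector.requires_grad
# Before this change, vector.backward(torch.ones_like(vector)) raised NonDifferentiableError.
```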
10 changes: 4 additions & 6 deletions src/torchjd/aggregation/_cagrad.py
@@ -2,6 +2,7 @@

from torchjd.linalg import PSDMatrix

from ._mixins import _NonDifferentiable
from ._utils.check_dependencies import check_dependencies_are_installed
from ._weighting_bases import _GramianWeighting

@@ -15,10 +16,10 @@
from torchjd._linalg import normalize

from ._aggregator_bases import GramianWeightedAggregator
from ._utils.non_differentiable import raise_non_differentiable_error


class CAGradWeighting(_GramianWeighting):
# Non-differentiable: the cvxpy solver operates on numpy arrays, breaking the autograd graph.
class CAGradWeighting(_NonDifferentiable, _GramianWeighting):
"""
:class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`]
giving the weights of :class:`~torchjd.aggregation.CAGrad`.
@@ -92,7 +93,7 @@ def norm_eps(self, value: float) -> None:
self._norm_eps = value


class CAGrad(GramianWeightedAggregator):
class CAGrad(_NonDifferentiable, GramianWeightedAggregator):
"""
:class:`~torchjd.aggregation.GramianWeightedAggregator` as defined in Algorithm 1 of
`Conflict-Averse Gradient Descent for Multi-task Learning
@@ -113,9 +114,6 @@ class CAGrad(GramianWeightedAggregator):
def __init__(self, c: float, norm_eps: float = 0.0001) -> None:
super().__init__(CAGradWeighting(c=c, norm_eps=norm_eps))

# This prevents considering the computed weights as constant w.r.t. the matrix.
self.register_full_backward_pre_hook(raise_non_differentiable_error)

@property
def c(self) -> float:
return self.gramian_weighting.c
8 changes: 3 additions & 5 deletions src/torchjd/aggregation/_config.py
@@ -8,12 +8,13 @@
from torchjd.linalg import Matrix

from ._aggregator_bases import Aggregator
from ._mixins import _NonDifferentiable
from ._sum import SumWeighting
from ._utils.non_differentiable import raise_non_differentiable_error
from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting


class ConFIG(Aggregator):
# Non-differentiable: the pseudoinverse and the normalization are not differentiable in this context.
class ConFIG(_NonDifferentiable, Aggregator):
"""
:class:`~torchjd.aggregation.Aggregator` as defined in Equation 2 of `ConFIG:
Towards Conflict-free Training of Physics Informed Neural Networks
@@ -31,9 +32,6 @@ def __init__(self, pref_vector: Tensor | None = None) -> None:
super().__init__()
self.pref_vector = pref_vector

# This prevents computing gradients that can be very wrong.
self.register_full_backward_pre_hook(raise_non_differentiable_error)

def forward(self, matrix: Matrix, /) -> Tensor:
weights = self.weighting(matrix)
units = torch.nan_to_num((matrix / (matrix.norm(dim=1)).unsqueeze(1)), 0.0)
10 changes: 4 additions & 6 deletions src/torchjd/aggregation/_dualproj.py
@@ -5,13 +5,14 @@

from ._aggregator_bases import GramianWeightedAggregator
from ._mean import MeanWeighting
from ._mixins import _NonDifferentiable
from ._utils.dual_cone import SUPPORTED_SOLVER, project_weights
from ._utils.non_differentiable import raise_non_differentiable_error
from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting
from ._weighting_bases import _GramianWeighting


class DualProjWeighting(_GramianWeighting):
# Non-differentiable: the QP solver operates on numpy arrays, breaking the autograd graph.
class DualProjWeighting(_NonDifferentiable, _GramianWeighting):
r"""
:class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`]
giving the weights of :class:`~torchjd.aggregation.DualProj`.
@@ -77,7 +78,7 @@ def reg_eps(self, value: float) -> None:
self._reg_eps = value


class DualProj(GramianWeightedAggregator):
class DualProj(_NonDifferentiable, GramianWeightedAggregator):
r"""
:class:`~torchjd.aggregation.GramianWeightedAggregator` that averages the rows of the input
matrix, and projects the result onto the dual cone of the rows of the matrix. This corresponds
@@ -109,9 +110,6 @@ def __init__(
DualProjWeighting(pref_vector, norm_eps=norm_eps, reg_eps=reg_eps, solver=solver),
)

# This prevents considering the computed weights as constant w.r.t. the matrix.
self.register_full_backward_pre_hook(raise_non_differentiable_error)

@property
def pref_vector(self) -> Tensor | None:
return self.gramian_weighting.pref_vector
8 changes: 3 additions & 5 deletions src/torchjd/aggregation/_graddrop.py
@@ -6,14 +6,15 @@
from torchjd.linalg import Matrix

from ._aggregator_bases import Aggregator
from ._utils.non_differentiable import raise_non_differentiable_error
from ._mixins import _NonDifferentiable


def _identity(P: Tensor) -> Tensor:
return P


class GradDrop(Aggregator):
# Non-differentiable: the sign-based random masking is not differentiable.
class GradDrop(_NonDifferentiable, Aggregator):
"""
:class:`~torchjd.aggregation.Aggregator` that applies the gradient combination
steps from GradDrop, as defined in lines 10 to 15 of Algorithm 1 of `Just Pick a Sign:
@@ -31,9 +32,6 @@ def __init__(self, f: Callable = _identity, leak: Tensor | None = None) -> None:
self.f = f
self.leak = leak

# This prevents computing gradients that can be very wrong.
self.register_full_backward_pre_hook(raise_non_differentiable_error)

def forward(self, matrix: Matrix, /) -> Tensor:
self._check_matrix_has_enough_rows(matrix)

9 changes: 4 additions & 5 deletions src/torchjd/aggregation/_gradvac.py
@@ -5,15 +5,15 @@
import torch
from torch import Tensor

from torchjd.aggregation._mixins import Stateful
from torchjd.aggregation._mixins import Stateful, _NonDifferentiable
from torchjd.linalg import PSDMatrix

from ._aggregator_bases import GramianWeightedAggregator
from ._utils.non_differentiable import raise_non_differentiable_error
from ._weighting_bases import _GramianWeighting


class GradVacWeighting(_GramianWeighting, Stateful):
# Non-differentiable: weights are modified in-place during the gradient correction loop.
class GradVacWeighting(_NonDifferentiable, Stateful, _GramianWeighting):
r"""
:class:`~torchjd.aggregation._mixins.Stateful`
:class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`]
@@ -128,7 +128,7 @@ def _ensure_state(self, m: int, dtype: torch.dtype) -> None:
self._state_key = key


class GradVac(GramianWeightedAggregator, Stateful):
class GradVac(_NonDifferentiable, Stateful, GramianWeightedAggregator):
r"""
:class:`~torchjd.aggregation._mixins.Stateful`
:class:`~torchjd.aggregation.GramianWeightedAggregator` implementing the aggregation step of
@@ -167,7 +167,6 @@ def __init__(self, beta: float = 0.5, eps: float = 1e-8) -> None:
weighting = GradVacWeighting(beta=beta, eps=eps)
super().__init__(weighting)
self._gradvac_weighting = weighting
self.register_full_backward_pre_hook(raise_non_differentiable_error)

@property
def beta(self) -> float:
10 changes: 4 additions & 6 deletions src/torchjd/aggregation/_imtl_g.py
@@ -4,11 +4,12 @@
from torchjd.linalg import PSDMatrix

from ._aggregator_bases import GramianWeightedAggregator
from ._utils.non_differentiable import raise_non_differentiable_error
from ._mixins import _NonDifferentiable
from ._weighting_bases import _GramianWeighting


class IMTLGWeighting(_GramianWeighting):
# Non-differentiable: differentiating through pinv(gramian) would give incorrect gradients.
class IMTLGWeighting(_NonDifferentiable, _GramianWeighting):
"""
:class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`]
giving the weights of :class:`~torchjd.aggregation.IMTLG`.
@@ -24,7 +25,7 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor:
return weights


class IMTLG(GramianWeightedAggregator):
class IMTLG(_NonDifferentiable, GramianWeightedAggregator):
"""
:class:`~torchjd.aggregation.GramianWeightedAggregator` generalizing the method described in
`Towards Impartial Multi-task Learning <https://discovery.ucl.ac.uk/id/eprint/10120667/>`_.
@@ -36,6 +37,3 @@ class IMTLG(GramianWeightedAggregator):

def __init__(self) -> None:
super().__init__(IMTLGWeighting())

# This prevents computing gradients that can be very wrong.
self.register_full_backward_pre_hook(raise_non_differentiable_error)
20 changes: 20 additions & 0 deletions src/torchjd/aggregation/_mixins.py
@@ -1,4 +1,8 @@
from abc import ABC, abstractmethod
from typing import Any

import torch
from torch import nn


class Stateful(ABC):
@@ -7,3 +11,19 @@ class Stateful(ABC):
@abstractmethod
def reset(self) -> None:
"""Resets the internal state."""


class _NonDifferentiable(nn.Module):
"""
Mixin making a nn.Module non-differentiable, preventing autograd graph construction by wrapping
the call in :func:`torch.no_grad`.

.. warning::
This mixin must appear **before** any :class:`torch.nn.Module` base class in the inheritance
list. Placing it after will silently have no effect, because :meth:`__call__` would be
resolved to :class:`torch.nn.Module` before reaching this mixin.
"""

def __call__(self, *args: Any, **kwargs: Any) -> Any:
with torch.no_grad():
return super().__call__(*args, **kwargs)
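A minimal, self-contained sketch of the mechanism used by `_NonDifferentiable` — the toy module names below are illustrative and not part of torchjd:

```python
import torch
from torch import nn


class _NoGrad(nn.Module):
    # Same idea as _NonDifferentiable: run the whole call under torch.no_grad().
    def __call__(self, *args, **kwargs):
        with torch.no_grad():
            return super().__call__(*args, **kwargs)


class Scale(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return 2.0 * x


class NoGradScale(_NoGrad, Scale):  # mixin listed first, as the warning above requires
    pass


x = torch.ones(3, requires_grad=True)
y = NoGradScale()(x)
assert not y.requires_grad  # no computation graph was built
```

Because `_NoGrad.__call__` precedes `nn.Module.__call__` in the method resolution order, every forward pass of `NoGradScale` runs under `torch.no_grad()` without the subclass having to do anything.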
11 changes: 4 additions & 7 deletions src/torchjd/aggregation/_nash_mtl.py
@@ -1,7 +1,7 @@
# Partly adapted from https://github.com/AvivNavon/nash-mtl — MIT License, Copyright (c) 2022 Aviv Navon.
# See NOTICES for the full license text.

from torchjd.aggregation._mixins import Stateful
from torchjd.aggregation._mixins import Stateful, _NonDifferentiable

from ._utils.check_dependencies import check_dependencies_are_installed
from ._weighting_bases import _MatrixWeighting
@@ -15,10 +15,10 @@
from torch import Tensor

from ._aggregator_bases import WeightedAggregator
from ._utils.non_differentiable import raise_non_differentiable_error


class _NashMTLWeighting(_MatrixWeighting, Stateful):
# Non-differentiable: the cvxpy solver operates on numpy arrays, breaking the autograd graph.
class _NashMTLWeighting(_NonDifferentiable, Stateful, _MatrixWeighting):
"""
:class:`~torchjd.aggregation._mixins.Stateful`
:class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.Matrix`] that
@@ -199,7 +199,7 @@ def reset(self) -> None:
self.prvs_alpha = np.ones(self.n_tasks, dtype=np.float32)


class NashMTL(WeightedAggregator, Stateful):
class NashMTL(_NonDifferentiable, Stateful, WeightedAggregator):
"""
:class:`~torchjd.aggregation._mixins.Stateful`
:class:`~torchjd.aggregation.WeightedAggregator` as proposed in Algorithm 1 of
@@ -253,9 +253,6 @@ def __init__(
),
)

# This prevents considering the computed weights as constant w.r.t. the matrix.
self.register_full_backward_pre_hook(raise_non_differentiable_error)

@property
def n_tasks(self) -> int:
return self.weighting.n_tasks
10 changes: 4 additions & 6 deletions src/torchjd/aggregation/_pcgrad.py
@@ -6,11 +6,12 @@
from torchjd.linalg import PSDMatrix

from ._aggregator_bases import GramianWeightedAggregator
from ._utils.non_differentiable import raise_non_differentiable_error
from ._mixins import _NonDifferentiable
from ._weighting_bases import _GramianWeighting


class PCGradWeighting(_GramianWeighting):
# Non-differentiable: weights are modified in-place during the gradient projection loop.
class PCGradWeighting(_NonDifferentiable, _GramianWeighting):
"""
:class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`]
giving the weights of :class:`~torchjd.aggregation.PCGrad`.
@@ -46,7 +47,7 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor:
return weights.to(device)


class PCGrad(GramianWeightedAggregator):
class PCGrad(_NonDifferentiable, GramianWeightedAggregator):
"""
:class:`~torchjd.aggregation.GramianWeightedAggregator` as defined in Algorithm 1 of
`Gradient Surgery for Multi-Task Learning <https://arxiv.org/pdf/2001.06782.pdf>`_.
@@ -56,6 +57,3 @@ class PCGrad(GramianWeightedAggregator):

def __init__(self) -> None:
super().__init__(PCGradWeighting())

# This prevents running into a RuntimeError due to modifying stored tensors in place.
self.register_full_backward_pre_hook(raise_non_differentiable_error)
10 changes: 4 additions & 6 deletions src/torchjd/aggregation/_upgrad.py
@@ -6,13 +6,14 @@

from ._aggregator_bases import GramianWeightedAggregator
from ._mean import MeanWeighting
from ._mixins import _NonDifferentiable
from ._utils.dual_cone import SUPPORTED_SOLVER, project_weights
from ._utils.non_differentiable import raise_non_differentiable_error
from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting
from ._weighting_bases import _GramianWeighting


class UPGradWeighting(_GramianWeighting):
# Non-differentiable: the QP solver operates on numpy arrays, breaking the autograd graph.
class UPGradWeighting(_NonDifferentiable, _GramianWeighting):
r"""
:class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`]
giving the weights of :class:`~torchjd.aggregation.UPGrad`.
@@ -80,7 +81,7 @@ def reg_eps(self, value: float) -> None:
self._reg_eps = value


class UPGrad(GramianWeightedAggregator):
class UPGrad(_NonDifferentiable, GramianWeightedAggregator):
r"""
:class:`~torchjd.aggregation.GramianWeightedAggregator` that projects each row of the input
matrix onto the dual cone of all rows of this matrix, and that combines the result, as proposed
@@ -112,9 +113,6 @@ def __init__(
UPGradWeighting(pref_vector, norm_eps=norm_eps, reg_eps=reg_eps, solver=solver),
)

# This prevents considering the computed weights as constant w.r.t. the matrix.
self.register_full_backward_pre_hook(raise_non_differentiable_error)

@property
def pref_vector(self) -> Tensor | None:
return self.gramian_weighting.pref_vector
10 changes: 0 additions & 10 deletions src/torchjd/aggregation/_utils/non_differentiable.py

This file was deleted.

12 changes: 5 additions & 7 deletions tests/unit/aggregation/_asserts.py
@@ -1,11 +1,9 @@
import torch
from pytest import raises
from torch import Tensor
from torch.testing import assert_close
from utils.tensors import rand_, randperm_

from torchjd.aggregation import Aggregator
from torchjd.aggregation._utils.non_differentiable import NonDifferentiableError


def assert_expected_structure(aggregator: Aggregator, matrix: Tensor) -> None:
@@ -103,10 +101,10 @@ def assert_strongly_stationary(

def assert_non_differentiable(aggregator: Aggregator, matrix: Tensor) -> None:
"""
Tests empirically that a given non-differentiable `Aggregator` correctly raises a
NonDifferentiableError whenever we try to backward through it.
Tests that a non-differentiable `Aggregator` does not build a computation graph, even when the
input requires gradients.
"""

vector = aggregator(matrix)
with raises(NonDifferentiableError):
vector.backward(torch.ones_like(vector))
matrix_with_grad = matrix.clone().requires_grad_(True)
vector = aggregator(matrix_with_grad)
assert not vector.requires_grad
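For context, a hedged sketch of how the updated assertion could be exercised from a test module — the import path, aggregator choice, and matrix shape are assumptions:

```python
import torch

from torchjd.aggregation import PCGrad

from _asserts import assert_non_differentiable  # import path is an assumption


def test_pcgrad_builds_no_graph() -> None:
    matrix = torch.randn(4, 5)
    assert_non_differentiable(PCGrad(), matrix)
```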