From 96119df4ceb66ef86dc9efd9580ce2a15713ef88 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sat, 24 May 2025 14:22:24 +0200 Subject: [PATCH 1/9] Make aggregator modules protected and update their rst files --- docs/source/docs/aggregation/aligned_mtl.rst | 2 +- docs/source/docs/aggregation/bases.rst | 2 +- docs/source/docs/aggregation/cagrad.rst | 2 +- docs/source/docs/aggregation/config.rst | 2 +- docs/source/docs/aggregation/constant.rst | 2 +- docs/source/docs/aggregation/dualproj.rst | 2 +- docs/source/docs/aggregation/graddrop.rst | 2 +- docs/source/docs/aggregation/imtl_g.rst | 2 +- docs/source/docs/aggregation/krum.rst | 2 +- docs/source/docs/aggregation/mean.rst | 2 +- docs/source/docs/aggregation/mgda.rst | 2 +- docs/source/docs/aggregation/nash_mtl.rst | 2 +- docs/source/docs/aggregation/pcgrad.rst | 2 +- docs/source/docs/aggregation/random.rst | 2 +- docs/source/docs/aggregation/sum.rst | 2 +- docs/source/docs/aggregation/trimmed_mean.rst | 2 +- docs/source/docs/aggregation/upgrad.rst | 2 +- src/torchjd/aggregation/__init__.py | 34 +++++++++---------- ...gregator_bases.py => _aggregator_bases.py} | 0 .../{aligned_mtl.py => _aligned_mtl.py} | 6 ++-- .../aggregation/{cagrad.py => _cagrad.py} | 4 +-- .../aggregation/{config.py => _config.py} | 6 ++-- .../aggregation/{constant.py => _constant.py} | 6 ++-- .../aggregation/{dualproj.py => _dualproj.py} | 6 ++-- .../aggregation/{graddrop.py => _graddrop.py} | 4 +-- .../aggregation/{imtl_g.py => _imtl_g.py} | 4 +-- src/torchjd/aggregation/{krum.py => _krum.py} | 6 ++-- src/torchjd/aggregation/{mean.py => _mean.py} | 4 +-- src/torchjd/aggregation/{mgda.py => _mgda.py} | 4 +-- .../aggregation/{nash_mtl.py => _nash_mtl.py} | 4 +-- .../aggregation/{pcgrad.py => _pcgrad.py} | 6 ++-- .../aggregation/{random.py => _random.py} | 4 +-- src/torchjd/aggregation/{sum.py => _sum.py} | 4 +-- .../{trimmed_mean.py => _trimmed_mean.py} | 4 +-- .../aggregation/{upgrad.py => _upgrad.py} | 6 ++-- src/torchjd/aggregation/_utils/pref_vector.py | 2 +- src/torchjd/autojac/_transform/aggregate.py | 2 +- .../aggregation/_utils/test_pref_vector.py | 2 +- tests/unit/aggregation/test_mgda.py | 2 +- tests/unit/aggregation/test_pcgrad.py | 6 ++-- 40 files changed, 80 insertions(+), 80 deletions(-) rename src/torchjd/aggregation/{aggregator_bases.py => _aggregator_bases.py} (100%) rename src/torchjd/aggregation/{aligned_mtl.py => _aligned_mtl.py} (95%) rename src/torchjd/aggregation/{cagrad.py => _cagrad.py} (96%) rename src/torchjd/aggregation/{config.py => _config.py} (94%) rename src/torchjd/aggregation/{constant.py => _constant.py} (89%) rename src/torchjd/aggregation/{dualproj.py => _dualproj.py} (95%) rename src/torchjd/aggregation/{graddrop.py => _graddrop.py} (95%) rename src/torchjd/aggregation/{imtl_g.py => _imtl_g.py} (91%) rename src/torchjd/aggregation/{krum.py => _krum.py} (94%) rename src/torchjd/aggregation/{mean.py => _mean.py} (87%) rename src/torchjd/aggregation/{mgda.py => _mgda.py} (95%) rename src/torchjd/aggregation/{nash_mtl.py => _nash_mtl.py} (98%) rename src/torchjd/aggregation/{pcgrad.py => _pcgrad.py} (90%) rename src/torchjd/aggregation/{random.py => _random.py} (90%) rename src/torchjd/aggregation/{sum.py => _sum.py} (86%) rename src/torchjd/aggregation/{trimmed_mean.py => _trimmed_mean.py} (94%) rename src/torchjd/aggregation/{upgrad.py => _upgrad.py} (95%) diff --git a/docs/source/docs/aggregation/aligned_mtl.rst b/docs/source/docs/aggregation/aligned_mtl.rst index 3ea8de97..36ec8b44 100644 --- a/docs/source/docs/aggregation/aligned_mtl.rst +++ b/docs/source/docs/aggregation/aligned_mtl.rst @@ -3,7 +3,7 @@ Aligned-MTL =========== -.. automodule:: torchjd.aggregation.aligned_mtl +.. autoclass:: torchjd.aggregation.AlignedMTL :members: :undoc-members: :exclude-members: forward diff --git a/docs/source/docs/aggregation/bases.rst b/docs/source/docs/aggregation/bases.rst index 529d7f3d..1b890776 100644 --- a/docs/source/docs/aggregation/bases.rst +++ b/docs/source/docs/aggregation/bases.rst @@ -3,7 +3,7 @@ Aggregator (abstract) ===================== -.. automodule:: torchjd.aggregation.aggregator_bases +.. autoclass:: torchjd.aggregation.Aggregator :members: :undoc-members: :show-inheritance: diff --git a/docs/source/docs/aggregation/cagrad.rst b/docs/source/docs/aggregation/cagrad.rst index 7e97a39a..bef38f07 100644 --- a/docs/source/docs/aggregation/cagrad.rst +++ b/docs/source/docs/aggregation/cagrad.rst @@ -3,7 +3,7 @@ CAGrad ====== -.. automodule:: torchjd.aggregation.cagrad +.. autoclass:: torchjd.aggregation.CAGrad :members: :undoc-members: :exclude-members: forward diff --git a/docs/source/docs/aggregation/config.rst b/docs/source/docs/aggregation/config.rst index d25d9f05..94ab3f4a 100644 --- a/docs/source/docs/aggregation/config.rst +++ b/docs/source/docs/aggregation/config.rst @@ -3,7 +3,7 @@ ConFIG ====== -.. automodule:: torchjd.aggregation.config +.. autoclass:: torchjd.aggregation.ConFIG :members: :undoc-members: :exclude-members: forward diff --git a/docs/source/docs/aggregation/constant.rst b/docs/source/docs/aggregation/constant.rst index d34fc59e..776c5210 100644 --- a/docs/source/docs/aggregation/constant.rst +++ b/docs/source/docs/aggregation/constant.rst @@ -3,7 +3,7 @@ Constant ======== -.. automodule:: torchjd.aggregation.constant +.. autoclass:: torchjd.aggregation.Constant :members: :undoc-members: :exclude-members: forward diff --git a/docs/source/docs/aggregation/dualproj.rst b/docs/source/docs/aggregation/dualproj.rst index f8784e3e..af34d362 100644 --- a/docs/source/docs/aggregation/dualproj.rst +++ b/docs/source/docs/aggregation/dualproj.rst @@ -3,7 +3,7 @@ DualProj ======== -.. automodule:: torchjd.aggregation.dualproj +.. autoclass:: torchjd.aggregation.DualProj :members: :undoc-members: :exclude-members: forward diff --git a/docs/source/docs/aggregation/graddrop.rst b/docs/source/docs/aggregation/graddrop.rst index f3a2e30c..f27f3612 100644 --- a/docs/source/docs/aggregation/graddrop.rst +++ b/docs/source/docs/aggregation/graddrop.rst @@ -3,7 +3,7 @@ GradDrop ======== -.. automodule:: torchjd.aggregation.graddrop +.. autoclass:: torchjd.aggregation.GradDrop :members: :undoc-members: :exclude-members: forward diff --git a/docs/source/docs/aggregation/imtl_g.rst b/docs/source/docs/aggregation/imtl_g.rst index 8360f0b7..93bce564 100644 --- a/docs/source/docs/aggregation/imtl_g.rst +++ b/docs/source/docs/aggregation/imtl_g.rst @@ -3,7 +3,7 @@ IMTL-G ====== -.. automodule:: torchjd.aggregation.imtl_g +.. autoclass:: torchjd.aggregation.IMTLG :members: :undoc-members: :exclude-members: forward diff --git a/docs/source/docs/aggregation/krum.rst b/docs/source/docs/aggregation/krum.rst index 030fbdb5..25a3cd30 100644 --- a/docs/source/docs/aggregation/krum.rst +++ b/docs/source/docs/aggregation/krum.rst @@ -3,7 +3,7 @@ Krum ==== -.. automodule:: torchjd.aggregation.krum +.. autoclass:: torchjd.aggregation.Krum :members: :undoc-members: :exclude-members: forward diff --git a/docs/source/docs/aggregation/mean.rst b/docs/source/docs/aggregation/mean.rst index d0815aae..8e6e4a89 100644 --- a/docs/source/docs/aggregation/mean.rst +++ b/docs/source/docs/aggregation/mean.rst @@ -3,7 +3,7 @@ Mean ==== -.. automodule:: torchjd.aggregation.mean +.. autoclass:: torchjd.aggregation.Mean :members: :undoc-members: :exclude-members: forward diff --git a/docs/source/docs/aggregation/mgda.rst b/docs/source/docs/aggregation/mgda.rst index 32219846..9cd68319 100644 --- a/docs/source/docs/aggregation/mgda.rst +++ b/docs/source/docs/aggregation/mgda.rst @@ -3,7 +3,7 @@ MGDA ==== -.. automodule:: torchjd.aggregation.mgda +.. autoclass:: torchjd.aggregation.MGDA :members: :undoc-members: :exclude-members: forward diff --git a/docs/source/docs/aggregation/nash_mtl.rst b/docs/source/docs/aggregation/nash_mtl.rst index 3da37445..e95ddea9 100644 --- a/docs/source/docs/aggregation/nash_mtl.rst +++ b/docs/source/docs/aggregation/nash_mtl.rst @@ -3,7 +3,7 @@ Nash-MTL ======== -.. automodule:: torchjd.aggregation.nash_mtl +.. autoclass:: torchjd.aggregation.NashMTL :members: :undoc-members: :exclude-members: forward diff --git a/docs/source/docs/aggregation/pcgrad.rst b/docs/source/docs/aggregation/pcgrad.rst index 2dfa0ca1..64627947 100644 --- a/docs/source/docs/aggregation/pcgrad.rst +++ b/docs/source/docs/aggregation/pcgrad.rst @@ -3,7 +3,7 @@ PCGrad ====== -.. automodule:: torchjd.aggregation.pcgrad +.. autoclass:: torchjd.aggregation.PCGrad :members: :undoc-members: :exclude-members: forward diff --git a/docs/source/docs/aggregation/random.rst b/docs/source/docs/aggregation/random.rst index 5d204044..5fd64ac8 100644 --- a/docs/source/docs/aggregation/random.rst +++ b/docs/source/docs/aggregation/random.rst @@ -3,7 +3,7 @@ Random ====== -.. automodule:: torchjd.aggregation.random +.. autoclass:: torchjd.aggregation.Random :members: :undoc-members: :exclude-members: forward diff --git a/docs/source/docs/aggregation/sum.rst b/docs/source/docs/aggregation/sum.rst index e6273b54..79c31afc 100644 --- a/docs/source/docs/aggregation/sum.rst +++ b/docs/source/docs/aggregation/sum.rst @@ -3,7 +3,7 @@ Sum === -.. automodule:: torchjd.aggregation.sum +.. autoclass:: torchjd.aggregation.Sum :members: :undoc-members: :exclude-members: forward diff --git a/docs/source/docs/aggregation/trimmed_mean.rst b/docs/source/docs/aggregation/trimmed_mean.rst index a5063e5e..e7f5abd8 100644 --- a/docs/source/docs/aggregation/trimmed_mean.rst +++ b/docs/source/docs/aggregation/trimmed_mean.rst @@ -3,7 +3,7 @@ Trimmed Mean ============ -.. automodule:: torchjd.aggregation.trimmed_mean +.. autoclass:: torchjd.aggregation.TrimmedMean :members: :undoc-members: :exclude-members: forward diff --git a/docs/source/docs/aggregation/upgrad.rst b/docs/source/docs/aggregation/upgrad.rst index 5b206417..0e2df8a0 100644 --- a/docs/source/docs/aggregation/upgrad.rst +++ b/docs/source/docs/aggregation/upgrad.rst @@ -3,7 +3,7 @@ UPGrad ====== -.. automodule:: torchjd.aggregation.upgrad +.. autoclass:: torchjd.aggregation.UPGrad :members: :undoc-members: :exclude-members: forward diff --git a/src/torchjd/aggregation/__init__.py b/src/torchjd/aggregation/__init__.py index f38fe091..1af20e37 100644 --- a/src/torchjd/aggregation/__init__.py +++ b/src/torchjd/aggregation/__init__.py @@ -1,28 +1,28 @@ +from ._aggregator_bases import Aggregator +from ._aligned_mtl import AlignedMTL +from ._config import ConFIG +from ._constant import Constant +from ._dualproj import DualProj +from ._graddrop import GradDrop +from ._imtl_g import IMTLG +from ._krum import Krum +from ._mean import Mean +from ._mgda import MGDA +from ._pcgrad import PCGrad +from ._random import Random +from ._sum import Sum +from ._trimmed_mean import TrimmedMean +from ._upgrad import UPGrad from ._utils.check_dependencies import ( OptionalDepsNotInstalledError as _OptionalDepsNotInstalledError, ) -from .aggregator_bases import Aggregator -from .aligned_mtl import AlignedMTL -from .config import ConFIG -from .constant import Constant -from .dualproj import DualProj -from .graddrop import GradDrop -from .imtl_g import IMTLG -from .krum import Krum -from .mean import Mean -from .mgda import MGDA -from .pcgrad import PCGrad -from .random import Random -from .sum import Sum -from .trimmed_mean import TrimmedMean -from .upgrad import UPGrad try: - from .cagrad import CAGrad + from ._cagrad import CAGrad except _OptionalDepsNotInstalledError: # The required dependencies are not installed pass try: - from .nash_mtl import NashMTL + from ._nash_mtl import NashMTL except _OptionalDepsNotInstalledError: # The required dependencies are not installed pass diff --git a/src/torchjd/aggregation/aggregator_bases.py b/src/torchjd/aggregation/_aggregator_bases.py similarity index 100% rename from src/torchjd/aggregation/aggregator_bases.py rename to src/torchjd/aggregation/_aggregator_bases.py diff --git a/src/torchjd/aggregation/aligned_mtl.py b/src/torchjd/aggregation/_aligned_mtl.py similarity index 95% rename from src/torchjd/aggregation/aligned_mtl.py rename to src/torchjd/aggregation/_aligned_mtl.py index bc6878ff..57ed62e7 100644 --- a/src/torchjd/aggregation/aligned_mtl.py +++ b/src/torchjd/aggregation/_aligned_mtl.py @@ -28,15 +28,15 @@ import torch from torch import Tensor +from ._aggregator_bases import _GramianWeightedAggregator +from ._mean import _MeanWeighting from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting from ._weighting_bases import PSDMatrix, Weighting -from .aggregator_bases import _GramianWeightedAggregator -from .mean import _MeanWeighting class AlignedMTL(_GramianWeightedAggregator): """ - :class:`~torchjd.aggregation.aggregator_bases.Aggregator` as defined in Algorithm 1 of + :class:`~torchjd.aggregation._aggregator_bases.Aggregator` as defined in Algorithm 1 of `Independent Component Alignment for Multi-Task Learning `_. diff --git a/src/torchjd/aggregation/cagrad.py b/src/torchjd/aggregation/_cagrad.py similarity index 96% rename from src/torchjd/aggregation/cagrad.py rename to src/torchjd/aggregation/_cagrad.py index 01a800f6..fb5ec6f1 100644 --- a/src/torchjd/aggregation/cagrad.py +++ b/src/torchjd/aggregation/_cagrad.py @@ -8,14 +8,14 @@ import torch from torch import Tensor +from ._aggregator_bases import _GramianWeightedAggregator from ._utils.gramian import normalize from ._utils.non_differentiable import raise_non_differentiable_error -from .aggregator_bases import _GramianWeightedAggregator class CAGrad(_GramianWeightedAggregator): """ - :class:`~torchjd.aggregation.aggregator_bases.Aggregator` as defined in Algorithm 1 of + :class:`~torchjd.aggregation._aggregator_bases.Aggregator` as defined in Algorithm 1 of `Conflict-Averse Gradient Descent for Multi-task Learning `_. diff --git a/src/torchjd/aggregation/config.py b/src/torchjd/aggregation/_config.py similarity index 94% rename from src/torchjd/aggregation/config.py rename to src/torchjd/aggregation/_config.py index bdd98968..754a4481 100644 --- a/src/torchjd/aggregation/config.py +++ b/src/torchjd/aggregation/_config.py @@ -28,15 +28,15 @@ import torch from torch import Tensor +from ._aggregator_bases import Aggregator +from ._sum import _SumWeighting from ._utils.non_differentiable import raise_non_differentiable_error from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting -from .aggregator_bases import Aggregator -from .sum import _SumWeighting class ConFIG(Aggregator): """ - :class:`~torchjd.aggregation.aggregator_bases.Aggregator` as defined in Equation 2 of `ConFIG: + :class:`~torchjd.aggregation._aggregator_bases.Aggregator` as defined in Equation 2 of `ConFIG: Towards Conflict-free Training of Physics Informed Neural Networks `_. diff --git a/src/torchjd/aggregation/constant.py b/src/torchjd/aggregation/_constant.py similarity index 89% rename from src/torchjd/aggregation/constant.py rename to src/torchjd/aggregation/_constant.py index 7faa0318..27e2fcc2 100644 --- a/src/torchjd/aggregation/constant.py +++ b/src/torchjd/aggregation/_constant.py @@ -1,14 +1,14 @@ from torch import Tensor +from ._aggregator_bases import _WeightedAggregator from ._utils.str import vector_to_str from ._weighting_bases import Matrix, Weighting -from .aggregator_bases import _WeightedAggregator class Constant(_WeightedAggregator): """ - :class:`~torchjd.aggregation.aggregator_bases.Aggregator` that makes a linear combination of the - rows of the provided matrix, with constant, pre-determined weights. + :class:`~torchjd.aggregation._aggregator_bases.Aggregator` that makes a linear combination of + the rows of the provided matrix, with constant, pre-determined weights. :param weights: The weights associated to the rows of the input matrices. diff --git a/src/torchjd/aggregation/dualproj.py b/src/torchjd/aggregation/_dualproj.py similarity index 95% rename from src/torchjd/aggregation/dualproj.py rename to src/torchjd/aggregation/_dualproj.py index 1d5ddbfa..848e8219 100644 --- a/src/torchjd/aggregation/dualproj.py +++ b/src/torchjd/aggregation/_dualproj.py @@ -2,18 +2,18 @@ from torch import Tensor +from ._aggregator_bases import _GramianWeightedAggregator +from ._mean import _MeanWeighting from ._utils.dual_cone import project_weights from ._utils.gramian import normalize, regularize from ._utils.non_differentiable import raise_non_differentiable_error from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting from ._weighting_bases import PSDMatrix, Weighting -from .aggregator_bases import _GramianWeightedAggregator -from .mean import _MeanWeighting class DualProj(_GramianWeightedAggregator): """ - :class:`~torchjd.aggregation.aggregator_bases.Aggregator` that averages the rows of the input + :class:`~torchjd.aggregation._aggregator_bases.Aggregator` that averages the rows of the input matrix, and projects the result onto the dual cone of the rows of the matrix. This corresponds to the solution to Equation 11 of `Gradient Episodic Memory for Continual Learning `_. diff --git a/src/torchjd/aggregation/graddrop.py b/src/torchjd/aggregation/_graddrop.py similarity index 95% rename from src/torchjd/aggregation/graddrop.py rename to src/torchjd/aggregation/_graddrop.py index 662ccaa2..826d6cc8 100644 --- a/src/torchjd/aggregation/graddrop.py +++ b/src/torchjd/aggregation/_graddrop.py @@ -3,8 +3,8 @@ import torch from torch import Tensor +from ._aggregator_bases import Aggregator from ._utils.non_differentiable import raise_non_differentiable_error -from .aggregator_bases import Aggregator def _identity(P: Tensor) -> Tensor: @@ -13,7 +13,7 @@ def _identity(P: Tensor) -> Tensor: class GradDrop(Aggregator): """ - :class:`~torchjd.aggregation.aggregator_bases.Aggregator` that applies the gradient combination + :class:`~torchjd.aggregation._aggregator_bases.Aggregator` that applies the gradient combination steps from GradDrop, as defined in lines 10 to 15 of Algorithm 1 of `Just Pick a Sign: Optimizing Deep Multitask Models with Gradient Sign Dropout `_. diff --git a/src/torchjd/aggregation/imtl_g.py b/src/torchjd/aggregation/_imtl_g.py similarity index 91% rename from src/torchjd/aggregation/imtl_g.py rename to src/torchjd/aggregation/_imtl_g.py index 2d0798e1..66e9da35 100644 --- a/src/torchjd/aggregation/imtl_g.py +++ b/src/torchjd/aggregation/_imtl_g.py @@ -1,14 +1,14 @@ import torch from torch import Tensor +from ._aggregator_bases import _GramianWeightedAggregator from ._utils.non_differentiable import raise_non_differentiable_error from ._weighting_bases import PSDMatrix, Weighting -from .aggregator_bases import _GramianWeightedAggregator class IMTLG(_GramianWeightedAggregator): """ - :class:`~torchjd.aggregation.aggregator_bases.Aggregator` generalizing the method described in + :class:`~torchjd.aggregation._aggregator_bases.Aggregator` generalizing the method described in `Towards Impartial Multi-task Learning `_. This generalization, defined formally in `Jacobian Descent For Multi-Objective Optimization `_, supports matrices with some linearly dependant rows. diff --git a/src/torchjd/aggregation/krum.py b/src/torchjd/aggregation/_krum.py similarity index 94% rename from src/torchjd/aggregation/krum.py rename to src/torchjd/aggregation/_krum.py index b727e9d8..b3e1dcf7 100644 --- a/src/torchjd/aggregation/krum.py +++ b/src/torchjd/aggregation/_krum.py @@ -2,14 +2,14 @@ from torch import Tensor from torch.nn import functional as F +from ._aggregator_bases import _GramianWeightedAggregator from ._weighting_bases import PSDMatrix, Weighting -from .aggregator_bases import _GramianWeightedAggregator class Krum(_GramianWeightedAggregator): """ - :class:`~torchjd.aggregation.aggregator_bases.Aggregator` for adversarial federated learning, as - defined in `Machine Learning with Adversaries: Byzantine Tolerant Gradient Descent + :class:`~torchjd.aggregation._aggregator_bases.Aggregator` for adversarial federated learning, + as defined in `Machine Learning with Adversaries: Byzantine Tolerant Gradient Descent `_. :param n_byzantine: The number of rows of the input matrix that can come from an adversarial diff --git a/src/torchjd/aggregation/mean.py b/src/torchjd/aggregation/_mean.py similarity index 87% rename from src/torchjd/aggregation/mean.py rename to src/torchjd/aggregation/_mean.py index 76cd4c22..35cd46f9 100644 --- a/src/torchjd/aggregation/mean.py +++ b/src/torchjd/aggregation/_mean.py @@ -1,13 +1,13 @@ import torch from torch import Tensor +from ._aggregator_bases import _WeightedAggregator from ._weighting_bases import Matrix, Weighting -from .aggregator_bases import _WeightedAggregator class Mean(_WeightedAggregator): """ - :class:`~torchjd.aggregation.aggregator_bases.Aggregator` that averages the rows of the input + :class:`~torchjd.aggregation._aggregator_bases.Aggregator` that averages the rows of the input matrices. .. admonition:: diff --git a/src/torchjd/aggregation/mgda.py b/src/torchjd/aggregation/_mgda.py similarity index 95% rename from src/torchjd/aggregation/mgda.py rename to src/torchjd/aggregation/_mgda.py index 5d6ddd38..9cb5b709 100644 --- a/src/torchjd/aggregation/mgda.py +++ b/src/torchjd/aggregation/_mgda.py @@ -1,13 +1,13 @@ import torch from torch import Tensor +from ._aggregator_bases import _GramianWeightedAggregator from ._weighting_bases import PSDMatrix, Weighting -from .aggregator_bases import _GramianWeightedAggregator class MGDA(_GramianWeightedAggregator): r""" - :class:`~torchjd.aggregation.aggregator_bases.Aggregator` performing the gradient aggregation + :class:`~torchjd.aggregation._aggregator_bases.Aggregator` performing the gradient aggregation step of `Multiple-gradient descent algorithm (MGDA) for multiobjective optimization `_. The implementation is based on Algorithm 2 of `Multi-Task Learning as Multi-Objective Optimization diff --git a/src/torchjd/aggregation/nash_mtl.py b/src/torchjd/aggregation/_nash_mtl.py similarity index 98% rename from src/torchjd/aggregation/nash_mtl.py rename to src/torchjd/aggregation/_nash_mtl.py index 3c0b6210..91a20e62 100644 --- a/src/torchjd/aggregation/nash_mtl.py +++ b/src/torchjd/aggregation/_nash_mtl.py @@ -34,13 +34,13 @@ from cvxpy import Expression from torch import Tensor +from ._aggregator_bases import _WeightedAggregator from ._utils.non_differentiable import raise_non_differentiable_error -from .aggregator_bases import _WeightedAggregator class NashMTL(_WeightedAggregator): """ - :class:`~torchjd.aggregation.aggregator_bases.Aggregator` as proposed in Algorithm 1 of + :class:`~torchjd.aggregation._aggregator_bases.Aggregator` as proposed in Algorithm 1 of `Multi-Task Learning as a Bargaining Game `_. :param n_tasks: The number of tasks, corresponding to the number of rows in the provided diff --git a/src/torchjd/aggregation/pcgrad.py b/src/torchjd/aggregation/_pcgrad.py similarity index 90% rename from src/torchjd/aggregation/pcgrad.py rename to src/torchjd/aggregation/_pcgrad.py index 0cf6a97f..8bc85d2d 100644 --- a/src/torchjd/aggregation/pcgrad.py +++ b/src/torchjd/aggregation/_pcgrad.py @@ -1,15 +1,15 @@ import torch from torch import Tensor +from ._aggregator_bases import _GramianWeightedAggregator from ._utils.non_differentiable import raise_non_differentiable_error from ._weighting_bases import PSDMatrix, Weighting -from .aggregator_bases import _GramianWeightedAggregator class PCGrad(_GramianWeightedAggregator): """ - :class:`~torchjd.aggregation.aggregator_bases.Aggregator` as defined in algorithm 1 of `Gradient - Surgery for Multi-Task Learning `_. + :class:`~torchjd.aggregation._aggregator_bases.Aggregator` as defined in algorithm 1 of + `Gradient Surgery for Multi-Task Learning `_. .. admonition:: Example diff --git a/src/torchjd/aggregation/random.py b/src/torchjd/aggregation/_random.py similarity index 90% rename from src/torchjd/aggregation/random.py rename to src/torchjd/aggregation/_random.py index d1f7ba64..b5a92b1c 100644 --- a/src/torchjd/aggregation/random.py +++ b/src/torchjd/aggregation/_random.py @@ -2,13 +2,13 @@ from torch import Tensor from torch.nn import functional as F +from ._aggregator_bases import _WeightedAggregator from ._weighting_bases import Matrix, Weighting -from .aggregator_bases import _WeightedAggregator class Random(_WeightedAggregator): """ - :class:`~torchjd.aggregation.aggregator_bases.Aggregator` that computes a random combination of + :class:`~torchjd.aggregation._aggregator_bases.Aggregator` that computes a random combination of the rows of the provided matrices, as defined in algorithm 2 of `Reasonable Effectiveness of Random Weighting: A Litmus Test for Multi-Task Learning `_. diff --git a/src/torchjd/aggregation/sum.py b/src/torchjd/aggregation/_sum.py similarity index 86% rename from src/torchjd/aggregation/sum.py rename to src/torchjd/aggregation/_sum.py index d6658eab..434cbe09 100644 --- a/src/torchjd/aggregation/sum.py +++ b/src/torchjd/aggregation/_sum.py @@ -1,13 +1,13 @@ import torch from torch import Tensor +from ._aggregator_bases import _WeightedAggregator from ._weighting_bases import Matrix, Weighting -from .aggregator_bases import _WeightedAggregator class Sum(_WeightedAggregator): """ - :class:`~torchjd.aggregation.aggregator_bases.Aggregator` that sums of the rows of the input + :class:`~torchjd.aggregation._aggregator_bases.Aggregator` that sums of the rows of the input matrices. .. admonition:: diff --git a/src/torchjd/aggregation/trimmed_mean.py b/src/torchjd/aggregation/_trimmed_mean.py similarity index 94% rename from src/torchjd/aggregation/trimmed_mean.py rename to src/torchjd/aggregation/_trimmed_mean.py index 25d80b1a..750dab74 100644 --- a/src/torchjd/aggregation/trimmed_mean.py +++ b/src/torchjd/aggregation/_trimmed_mean.py @@ -1,12 +1,12 @@ import torch from torch import Tensor -from .aggregator_bases import Aggregator +from ._aggregator_bases import Aggregator class TrimmedMean(Aggregator): """ - :class:`~torchjd.aggregation.aggregator_bases.Aggregator` for adversarial federated learning, + :class:`~torchjd.aggregation._aggregator_bases.Aggregator` for adversarial federated learning, that trims the most extreme values of the input matrix, before averaging its rows, as defined in `Byzantine-Robust Distributed Learning: Towards Optimal Statistical Rates `_. diff --git a/src/torchjd/aggregation/upgrad.py b/src/torchjd/aggregation/_upgrad.py similarity index 95% rename from src/torchjd/aggregation/upgrad.py rename to src/torchjd/aggregation/_upgrad.py index 2abd02f3..b8939544 100644 --- a/src/torchjd/aggregation/upgrad.py +++ b/src/torchjd/aggregation/_upgrad.py @@ -3,18 +3,18 @@ import torch from torch import Tensor +from ._aggregator_bases import _GramianWeightedAggregator +from ._mean import _MeanWeighting from ._utils.dual_cone import project_weights from ._utils.gramian import normalize, regularize from ._utils.non_differentiable import raise_non_differentiable_error from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting from ._weighting_bases import PSDMatrix, Weighting -from .aggregator_bases import _GramianWeightedAggregator -from .mean import _MeanWeighting class UPGrad(_GramianWeightedAggregator): """ - :class:`~torchjd.aggregation.aggregator_bases.Aggregator` that projects each row of the input + :class:`~torchjd.aggregation._aggregator_bases.Aggregator` that projects each row of the input matrix onto the dual cone of all rows of this matrix, and that combines the result, as proposed in `Jacobian Descent For Multi-Objective Optimization `_. diff --git a/src/torchjd/aggregation/_utils/pref_vector.py b/src/torchjd/aggregation/_utils/pref_vector.py index 4517cc16..980ea263 100644 --- a/src/torchjd/aggregation/_utils/pref_vector.py +++ b/src/torchjd/aggregation/_utils/pref_vector.py @@ -1,7 +1,7 @@ from torch import Tensor +from torchjd.aggregation._constant import _ConstantWeighting from torchjd.aggregation._weighting_bases import Matrix, Weighting -from torchjd.aggregation.constant import _ConstantWeighting from .str import vector_to_str diff --git a/src/torchjd/autojac/_transform/aggregate.py b/src/torchjd/autojac/_transform/aggregate.py index deed0ce9..cf56113f 100644 --- a/src/torchjd/autojac/_transform/aggregate.py +++ b/src/torchjd/autojac/_transform/aggregate.py @@ -94,7 +94,7 @@ def _aggregate_group( ) -> GradientVectors: """ Unites the jacobian matrices and aggregates them using an - :class:`~torchjd.aggregation.aggregator_bases.Aggregator`. Returns the obtained gradient + :class:`~torchjd.aggregation._aggregator_bases.Aggregator`. Returns the obtained gradient vectors. """ diff --git a/tests/unit/aggregation/_utils/test_pref_vector.py b/tests/unit/aggregation/_utils/test_pref_vector.py index 6db951f5..a1f81d06 100644 --- a/tests/unit/aggregation/_utils/test_pref_vector.py +++ b/tests/unit/aggregation/_utils/test_pref_vector.py @@ -5,8 +5,8 @@ from torch import Tensor from unit._utils import ExceptionContext +from torchjd.aggregation._mean import _MeanWeighting from torchjd.aggregation._utils.pref_vector import pref_vector_to_weighting -from torchjd.aggregation.mean import _MeanWeighting @mark.parametrize( diff --git a/tests/unit/aggregation/test_mgda.py b/tests/unit/aggregation/test_mgda.py index 336a72cc..b98c9c56 100644 --- a/tests/unit/aggregation/test_mgda.py +++ b/tests/unit/aggregation/test_mgda.py @@ -4,8 +4,8 @@ from torch.testing import assert_close from torchjd.aggregation import MGDA +from torchjd.aggregation._mgda import _MGDAWeighting from torchjd.aggregation._utils.gramian import compute_gramian -from torchjd.aggregation.mgda import _MGDAWeighting from ._asserts import ( assert_expected_structure, diff --git a/tests/unit/aggregation/test_pcgrad.py b/tests/unit/aggregation/test_pcgrad.py index 19c1e795..823a9ce1 100644 --- a/tests/unit/aggregation/test_pcgrad.py +++ b/tests/unit/aggregation/test_pcgrad.py @@ -4,10 +4,10 @@ from torch.testing import assert_close from torchjd.aggregation import PCGrad +from torchjd.aggregation._pcgrad import _PCGradWeighting +from torchjd.aggregation._sum import _SumWeighting +from torchjd.aggregation._upgrad import _UPGradWrapper from torchjd.aggregation._utils.gramian import compute_gramian -from torchjd.aggregation.pcgrad import _PCGradWeighting -from torchjd.aggregation.sum import _SumWeighting -from torchjd.aggregation.upgrad import _UPGradWrapper from ._asserts import assert_expected_structure, assert_non_differentiable from ._inputs import scaled_matrices, typical_matrices From 4caf5f34c3b987899d0605abaa99fc6b8f4166c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sat, 24 May 2025 14:27:46 +0200 Subject: [PATCH 2/9] Make autojac function modules protected and update their rst files --- docs/source/docs/autojac/backward.rst | 5 +---- docs/source/docs/autojac/mtl_backward.rst | 5 +---- src/torchjd/autojac/__init__.py | 4 ++-- src/torchjd/autojac/{backward.py => _backward.py} | 0 src/torchjd/autojac/{mtl_backward.py => _mtl_backward.py} | 0 tests/unit/autojac/test_backward.py | 2 +- tests/unit/autojac/test_mtl_backward.py | 2 +- 7 files changed, 6 insertions(+), 12 deletions(-) rename src/torchjd/autojac/{backward.py => _backward.py} (100%) rename src/torchjd/autojac/{mtl_backward.py => _mtl_backward.py} (100%) diff --git a/docs/source/docs/autojac/backward.rst b/docs/source/docs/autojac/backward.rst index 0399961d..f7cc8ddd 100644 --- a/docs/source/docs/autojac/backward.rst +++ b/docs/source/docs/autojac/backward.rst @@ -3,7 +3,4 @@ backward ======== -.. automodule:: torchjd.autojac.backward - :members: - :undoc-members: - :exclude-members: +.. autofunction:: torchjd.backward diff --git a/docs/source/docs/autojac/mtl_backward.rst b/docs/source/docs/autojac/mtl_backward.rst index 963231c8..3ae6b0e9 100644 --- a/docs/source/docs/autojac/mtl_backward.rst +++ b/docs/source/docs/autojac/mtl_backward.rst @@ -3,7 +3,4 @@ mtl_backward ============ -.. automodule:: torchjd.autojac.mtl_backward - :members: - :undoc-members: - :exclude-members: +.. autofunction:: torchjd.mtl_backward diff --git a/src/torchjd/autojac/__init__.py b/src/torchjd/autojac/__init__.py index 54ecd87e..e2175c16 100644 --- a/src/torchjd/autojac/__init__.py +++ b/src/torchjd/autojac/__init__.py @@ -1,2 +1,2 @@ -from .backward import backward -from .mtl_backward import mtl_backward +from ._backward import backward +from ._mtl_backward import mtl_backward diff --git a/src/torchjd/autojac/backward.py b/src/torchjd/autojac/_backward.py similarity index 100% rename from src/torchjd/autojac/backward.py rename to src/torchjd/autojac/_backward.py diff --git a/src/torchjd/autojac/mtl_backward.py b/src/torchjd/autojac/_mtl_backward.py similarity index 100% rename from src/torchjd/autojac/mtl_backward.py rename to src/torchjd/autojac/_mtl_backward.py diff --git a/tests/unit/autojac/test_backward.py b/tests/unit/autojac/test_backward.py index 63ad6b93..083e9ce7 100644 --- a/tests/unit/autojac/test_backward.py +++ b/tests/unit/autojac/test_backward.py @@ -5,8 +5,8 @@ from torchjd import backward from torchjd.aggregation import MGDA, Aggregator, Mean, Random, Sum, UPGrad +from torchjd.autojac._backward import _create_transform from torchjd.autojac._transform.ordered_set import OrderedSet -from torchjd.autojac.backward import _create_transform def test_check_create_transform(): diff --git a/tests/unit/autojac/test_mtl_backward.py b/tests/unit/autojac/test_mtl_backward.py index f89febd5..8938fefc 100644 --- a/tests/unit/autojac/test_mtl_backward.py +++ b/tests/unit/autojac/test_mtl_backward.py @@ -5,8 +5,8 @@ from torchjd import mtl_backward from torchjd.aggregation import MGDA, Aggregator, Mean, Random, Sum, UPGrad +from torchjd.autojac._mtl_backward import _create_transform from torchjd.autojac._transform.ordered_set import OrderedSet -from torchjd.autojac.mtl_backward import _create_transform def test_check_create_transform(): From cfb3752e7f26e0f6f2c0357f167afe2bec0c433f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sat, 24 May 2025 14:32:05 +0200 Subject: [PATCH 3/9] Add changelog entry --- CHANGELOG.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index f9bd82b8..11cd122a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,12 @@ changes that do not affect the user. TorchJD. Users of these aggregators will have to use `pip install torchjd[cagrad]`, `pip install torchjd[nash_mtl]` or `pip install torchjd[full]` to install TorchJD alongside those dependencies. This should make TorchJD more lightweight. +- **BREAKING**: Made the aggregator modules and the modules of `backward` and `mtl_backward` + protected. They should now always be imported via their package (e.g. + `from torchjd.aggregation.upgrad import UPGrad` should become + `from torchjd.aggregation import UPGrad`, and + `from torchjd.autojac.mtl_backward import mtl_backward` should become + `from torchjd import mtl_backward`). ### Fixed From 5a023bb4ec6228650e80565876126bdd69bcd292 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sat, 24 May 2025 15:51:09 +0200 Subject: [PATCH 4/9] Make autojac package protected --- src/torchjd/__init__.py | 2 +- src/torchjd/{autojac => _autojac}/__init__.py | 0 src/torchjd/{autojac => _autojac}/_backward.py | 0 src/torchjd/{autojac => _autojac}/_mtl_backward.py | 0 src/torchjd/{autojac => _autojac}/_transform/__init__.py | 0 .../{autojac => _autojac}/_transform/_differentiate.py | 0 .../{autojac => _autojac}/_transform/_materialize.py | 0 .../{autojac => _autojac}/_transform/accumulate.py | 0 src/torchjd/{autojac => _autojac}/_transform/aggregate.py | 0 src/torchjd/{autojac => _autojac}/_transform/base.py | 0 .../{autojac => _autojac}/_transform/diagonalize.py | 0 src/torchjd/{autojac => _autojac}/_transform/grad.py | 0 src/torchjd/{autojac => _autojac}/_transform/init.py | 0 src/torchjd/{autojac => _autojac}/_transform/jac.py | 0 .../{autojac => _autojac}/_transform/ordered_set.py | 0 src/torchjd/{autojac => _autojac}/_transform/select.py | 0 src/torchjd/{autojac => _autojac}/_transform/stack.py | 0 .../{autojac => _autojac}/_transform/tensor_dict.py | 0 src/torchjd/{autojac => _autojac}/_utils.py | 0 tests/unit/autojac/_transform/test_accumulate.py | 2 +- tests/unit/autojac/_transform/test_aggregate.py | 8 ++++---- tests/unit/autojac/_transform/test_base.py | 4 ++-- tests/unit/autojac/_transform/test_diagonalize.py | 4 ++-- tests/unit/autojac/_transform/test_grad.py | 4 ++-- tests/unit/autojac/_transform/test_init.py | 2 +- tests/unit/autojac/_transform/test_interactions.py | 4 ++-- tests/unit/autojac/_transform/test_jac.py | 4 ++-- tests/unit/autojac/_transform/test_select.py | 2 +- tests/unit/autojac/_transform/test_stack.py | 2 +- tests/unit/autojac/_transform/test_tensor_dict.py | 4 ++-- tests/unit/autojac/test_backward.py | 4 ++-- tests/unit/autojac/test_mtl_backward.py | 4 ++-- tests/unit/autojac/test_utils.py | 2 +- 33 files changed, 26 insertions(+), 26 deletions(-) rename src/torchjd/{autojac => _autojac}/__init__.py (100%) rename src/torchjd/{autojac => _autojac}/_backward.py (100%) rename src/torchjd/{autojac => _autojac}/_mtl_backward.py (100%) rename src/torchjd/{autojac => _autojac}/_transform/__init__.py (100%) rename src/torchjd/{autojac => _autojac}/_transform/_differentiate.py (100%) rename src/torchjd/{autojac => _autojac}/_transform/_materialize.py (100%) rename src/torchjd/{autojac => _autojac}/_transform/accumulate.py (100%) rename src/torchjd/{autojac => _autojac}/_transform/aggregate.py (100%) rename src/torchjd/{autojac => _autojac}/_transform/base.py (100%) rename src/torchjd/{autojac => _autojac}/_transform/diagonalize.py (100%) rename src/torchjd/{autojac => _autojac}/_transform/grad.py (100%) rename src/torchjd/{autojac => _autojac}/_transform/init.py (100%) rename src/torchjd/{autojac => _autojac}/_transform/jac.py (100%) rename src/torchjd/{autojac => _autojac}/_transform/ordered_set.py (100%) rename src/torchjd/{autojac => _autojac}/_transform/select.py (100%) rename src/torchjd/{autojac => _autojac}/_transform/stack.py (100%) rename src/torchjd/{autojac => _autojac}/_transform/tensor_dict.py (100%) rename src/torchjd/{autojac => _autojac}/_utils.py (100%) diff --git a/src/torchjd/__init__.py b/src/torchjd/__init__.py index ffb74ef6..5f4a6cc9 100644 --- a/src/torchjd/__init__.py +++ b/src/torchjd/__init__.py @@ -1 +1 @@ -from torchjd.autojac import backward, mtl_backward +from torchjd._autojac import backward, mtl_backward diff --git a/src/torchjd/autojac/__init__.py b/src/torchjd/_autojac/__init__.py similarity index 100% rename from src/torchjd/autojac/__init__.py rename to src/torchjd/_autojac/__init__.py diff --git a/src/torchjd/autojac/_backward.py b/src/torchjd/_autojac/_backward.py similarity index 100% rename from src/torchjd/autojac/_backward.py rename to src/torchjd/_autojac/_backward.py diff --git a/src/torchjd/autojac/_mtl_backward.py b/src/torchjd/_autojac/_mtl_backward.py similarity index 100% rename from src/torchjd/autojac/_mtl_backward.py rename to src/torchjd/_autojac/_mtl_backward.py diff --git a/src/torchjd/autojac/_transform/__init__.py b/src/torchjd/_autojac/_transform/__init__.py similarity index 100% rename from src/torchjd/autojac/_transform/__init__.py rename to src/torchjd/_autojac/_transform/__init__.py diff --git a/src/torchjd/autojac/_transform/_differentiate.py b/src/torchjd/_autojac/_transform/_differentiate.py similarity index 100% rename from src/torchjd/autojac/_transform/_differentiate.py rename to src/torchjd/_autojac/_transform/_differentiate.py diff --git a/src/torchjd/autojac/_transform/_materialize.py b/src/torchjd/_autojac/_transform/_materialize.py similarity index 100% rename from src/torchjd/autojac/_transform/_materialize.py rename to src/torchjd/_autojac/_transform/_materialize.py diff --git a/src/torchjd/autojac/_transform/accumulate.py b/src/torchjd/_autojac/_transform/accumulate.py similarity index 100% rename from src/torchjd/autojac/_transform/accumulate.py rename to src/torchjd/_autojac/_transform/accumulate.py diff --git a/src/torchjd/autojac/_transform/aggregate.py b/src/torchjd/_autojac/_transform/aggregate.py similarity index 100% rename from src/torchjd/autojac/_transform/aggregate.py rename to src/torchjd/_autojac/_transform/aggregate.py diff --git a/src/torchjd/autojac/_transform/base.py b/src/torchjd/_autojac/_transform/base.py similarity index 100% rename from src/torchjd/autojac/_transform/base.py rename to src/torchjd/_autojac/_transform/base.py diff --git a/src/torchjd/autojac/_transform/diagonalize.py b/src/torchjd/_autojac/_transform/diagonalize.py similarity index 100% rename from src/torchjd/autojac/_transform/diagonalize.py rename to src/torchjd/_autojac/_transform/diagonalize.py diff --git a/src/torchjd/autojac/_transform/grad.py b/src/torchjd/_autojac/_transform/grad.py similarity index 100% rename from src/torchjd/autojac/_transform/grad.py rename to src/torchjd/_autojac/_transform/grad.py diff --git a/src/torchjd/autojac/_transform/init.py b/src/torchjd/_autojac/_transform/init.py similarity index 100% rename from src/torchjd/autojac/_transform/init.py rename to src/torchjd/_autojac/_transform/init.py diff --git a/src/torchjd/autojac/_transform/jac.py b/src/torchjd/_autojac/_transform/jac.py similarity index 100% rename from src/torchjd/autojac/_transform/jac.py rename to src/torchjd/_autojac/_transform/jac.py diff --git a/src/torchjd/autojac/_transform/ordered_set.py b/src/torchjd/_autojac/_transform/ordered_set.py similarity index 100% rename from src/torchjd/autojac/_transform/ordered_set.py rename to src/torchjd/_autojac/_transform/ordered_set.py diff --git a/src/torchjd/autojac/_transform/select.py b/src/torchjd/_autojac/_transform/select.py similarity index 100% rename from src/torchjd/autojac/_transform/select.py rename to src/torchjd/_autojac/_transform/select.py diff --git a/src/torchjd/autojac/_transform/stack.py b/src/torchjd/_autojac/_transform/stack.py similarity index 100% rename from src/torchjd/autojac/_transform/stack.py rename to src/torchjd/_autojac/_transform/stack.py diff --git a/src/torchjd/autojac/_transform/tensor_dict.py b/src/torchjd/_autojac/_transform/tensor_dict.py similarity index 100% rename from src/torchjd/autojac/_transform/tensor_dict.py rename to src/torchjd/_autojac/_transform/tensor_dict.py diff --git a/src/torchjd/autojac/_utils.py b/src/torchjd/_autojac/_utils.py similarity index 100% rename from src/torchjd/autojac/_utils.py rename to src/torchjd/_autojac/_utils.py diff --git a/tests/unit/autojac/_transform/test_accumulate.py b/tests/unit/autojac/_transform/test_accumulate.py index 4fd3a303..d6bf5537 100644 --- a/tests/unit/autojac/_transform/test_accumulate.py +++ b/tests/unit/autojac/_transform/test_accumulate.py @@ -1,7 +1,7 @@ import torch from pytest import mark, raises -from torchjd.autojac._transform import Accumulate, Gradients +from torchjd._autojac._transform import Accumulate, Gradients from ._dict_assertions import assert_tensor_dicts_are_close diff --git a/tests/unit/autojac/_transform/test_aggregate.py b/tests/unit/autojac/_transform/test_aggregate.py index 8a102209..79b0ad52 100644 --- a/tests/unit/autojac/_transform/test_aggregate.py +++ b/tests/unit/autojac/_transform/test_aggregate.py @@ -6,15 +6,15 @@ from torch import Tensor from unit.conftest import DEVICE -from torchjd.aggregation import Random -from torchjd.autojac._transform import ( +from torchjd._autojac._transform import ( GradientVectors, JacobianMatrices, Jacobians, RequirementError, ) -from torchjd.autojac._transform.aggregate import _AggregateMatrices, _Matrixify, _Reshape -from torchjd.autojac._transform.ordered_set import OrderedSet +from torchjd._autojac._transform.aggregate import _AggregateMatrices, _Matrixify, _Reshape +from torchjd._autojac._transform.ordered_set import OrderedSet +from torchjd.aggregation import Random from ._dict_assertions import assert_tensor_dicts_are_close diff --git a/tests/unit/autojac/_transform/test_base.py b/tests/unit/autojac/_transform/test_base.py index a855a5c3..c48bb4ba 100644 --- a/tests/unit/autojac/_transform/test_base.py +++ b/tests/unit/autojac/_transform/test_base.py @@ -4,8 +4,8 @@ from pytest import raises from torch import Tensor -from torchjd.autojac._transform.base import Conjunction, RequirementError, Transform -from torchjd.autojac._transform.tensor_dict import _B, _C, TensorDict +from torchjd._autojac._transform.base import Conjunction, RequirementError, Transform +from torchjd._autojac._transform.tensor_dict import _B, _C, TensorDict class FakeTransform(Transform[_B, _C]): diff --git a/tests/unit/autojac/_transform/test_diagonalize.py b/tests/unit/autojac/_transform/test_diagonalize.py index 95986948..7f5c4df2 100644 --- a/tests/unit/autojac/_transform/test_diagonalize.py +++ b/tests/unit/autojac/_transform/test_diagonalize.py @@ -1,8 +1,8 @@ import torch from pytest import raises -from torchjd.autojac._transform import Diagonalize, Gradients, RequirementError -from torchjd.autojac._transform.ordered_set import OrderedSet +from torchjd._autojac._transform import Diagonalize, Gradients, RequirementError +from torchjd._autojac._transform.ordered_set import OrderedSet from ._dict_assertions import assert_tensor_dicts_are_close diff --git a/tests/unit/autojac/_transform/test_grad.py b/tests/unit/autojac/_transform/test_grad.py index 5bf4f4bb..378e2988 100644 --- a/tests/unit/autojac/_transform/test_grad.py +++ b/tests/unit/autojac/_transform/test_grad.py @@ -1,8 +1,8 @@ import torch from pytest import raises -from torchjd.autojac._transform import Grad, Gradients, RequirementError -from torchjd.autojac._transform.ordered_set import OrderedSet +from torchjd._autojac._transform import Grad, Gradients, RequirementError +from torchjd._autojac._transform.ordered_set import OrderedSet from ._dict_assertions import assert_tensor_dicts_are_close diff --git a/tests/unit/autojac/_transform/test_init.py b/tests/unit/autojac/_transform/test_init.py index 0d5eda30..b9154c11 100644 --- a/tests/unit/autojac/_transform/test_init.py +++ b/tests/unit/autojac/_transform/test_init.py @@ -1,7 +1,7 @@ import torch from pytest import raises -from torchjd.autojac._transform import EmptyTensorDict, Init, RequirementError +from torchjd._autojac._transform import EmptyTensorDict, Init, RequirementError from ._dict_assertions import assert_tensor_dicts_are_close diff --git a/tests/unit/autojac/_transform/test_interactions.py b/tests/unit/autojac/_transform/test_interactions.py index 65f70605..3d4a604a 100644 --- a/tests/unit/autojac/_transform/test_interactions.py +++ b/tests/unit/autojac/_transform/test_interactions.py @@ -2,7 +2,7 @@ from pytest import raises from torch.testing import assert_close -from torchjd.autojac._transform import ( +from torchjd._autojac._transform import ( Accumulate, Conjunction, Diagonalize, @@ -17,7 +17,7 @@ Stack, TensorDict, ) -from torchjd.autojac._transform.ordered_set import OrderedSet +from torchjd._autojac._transform.ordered_set import OrderedSet from ._dict_assertions import assert_tensor_dicts_are_close diff --git a/tests/unit/autojac/_transform/test_jac.py b/tests/unit/autojac/_transform/test_jac.py index 274ebe1d..1988cab7 100644 --- a/tests/unit/autojac/_transform/test_jac.py +++ b/tests/unit/autojac/_transform/test_jac.py @@ -1,8 +1,8 @@ import torch from pytest import mark, raises -from torchjd.autojac._transform import Jac, Jacobians, RequirementError -from torchjd.autojac._transform.ordered_set import OrderedSet +from torchjd._autojac._transform import Jac, Jacobians, RequirementError +from torchjd._autojac._transform.ordered_set import OrderedSet from ._dict_assertions import assert_tensor_dicts_are_close diff --git a/tests/unit/autojac/_transform/test_select.py b/tests/unit/autojac/_transform/test_select.py index 051f31e1..60a476f1 100644 --- a/tests/unit/autojac/_transform/test_select.py +++ b/tests/unit/autojac/_transform/test_select.py @@ -1,7 +1,7 @@ import torch from pytest import raises -from torchjd.autojac._transform import RequirementError, Select, TensorDict +from torchjd._autojac._transform import RequirementError, Select, TensorDict from ._dict_assertions import assert_tensor_dicts_are_close diff --git a/tests/unit/autojac/_transform/test_stack.py b/tests/unit/autojac/_transform/test_stack.py index 2fa94316..efdfeaa7 100644 --- a/tests/unit/autojac/_transform/test_stack.py +++ b/tests/unit/autojac/_transform/test_stack.py @@ -3,7 +3,7 @@ import torch from torch import Tensor -from torchjd.autojac._transform import EmptyTensorDict, Gradients, Stack, Transform +from torchjd._autojac._transform import EmptyTensorDict, Gradients, Stack, Transform from ._dict_assertions import assert_tensor_dicts_are_close diff --git a/tests/unit/autojac/_transform/test_tensor_dict.py b/tests/unit/autojac/_transform/test_tensor_dict.py index 70707075..36fb7a62 100644 --- a/tests/unit/autojac/_transform/test_tensor_dict.py +++ b/tests/unit/autojac/_transform/test_tensor_dict.py @@ -5,7 +5,7 @@ from torch import Tensor from unit._utils import ExceptionContext -from torchjd.autojac._transform import ( +from torchjd._autojac._transform import ( EmptyTensorDict, Gradients, GradientVectors, @@ -13,7 +13,7 @@ Jacobians, TensorDict, ) -from torchjd.autojac._transform.tensor_dict import _least_common_ancestor +from torchjd._autojac._transform.tensor_dict import _least_common_ancestor _key_shapes = [[], [1], [2, 3]] diff --git a/tests/unit/autojac/test_backward.py b/tests/unit/autojac/test_backward.py index 083e9ce7..a6c13635 100644 --- a/tests/unit/autojac/test_backward.py +++ b/tests/unit/autojac/test_backward.py @@ -4,9 +4,9 @@ from torch.testing import assert_close from torchjd import backward +from torchjd._autojac._backward import _create_transform +from torchjd._autojac._transform.ordered_set import OrderedSet from torchjd.aggregation import MGDA, Aggregator, Mean, Random, Sum, UPGrad -from torchjd.autojac._backward import _create_transform -from torchjd.autojac._transform.ordered_set import OrderedSet def test_check_create_transform(): diff --git a/tests/unit/autojac/test_mtl_backward.py b/tests/unit/autojac/test_mtl_backward.py index 8938fefc..4bc158e6 100644 --- a/tests/unit/autojac/test_mtl_backward.py +++ b/tests/unit/autojac/test_mtl_backward.py @@ -4,9 +4,9 @@ from torch.testing import assert_close from torchjd import mtl_backward +from torchjd._autojac._mtl_backward import _create_transform +from torchjd._autojac._transform.ordered_set import OrderedSet from torchjd.aggregation import MGDA, Aggregator, Mean, Random, Sum, UPGrad -from torchjd.autojac._mtl_backward import _create_transform -from torchjd.autojac._transform.ordered_set import OrderedSet def test_check_create_transform(): diff --git a/tests/unit/autojac/test_utils.py b/tests/unit/autojac/test_utils.py index 24443d60..38e08128 100644 --- a/tests/unit/autojac/test_utils.py +++ b/tests/unit/autojac/test_utils.py @@ -2,7 +2,7 @@ from pytest import mark, raises from torch.nn import Linear, MSELoss, ReLU, Sequential -from torchjd.autojac._utils import get_leaf_tensors +from torchjd._autojac._utils import get_leaf_tensors def test_simple_get_leaf_tensors(): From a35c6e232d555ddddec6542757f1796eeae1ad29 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sat, 24 May 2025 15:52:38 +0200 Subject: [PATCH 5/9] Make relative import in torchjd/__init__.py --- src/torchjd/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/__init__.py b/src/torchjd/__init__.py index 5f4a6cc9..78f1528d 100644 --- a/src/torchjd/__init__.py +++ b/src/torchjd/__init__.py @@ -1 +1 @@ -from torchjd._autojac import backward, mtl_backward +from ._autojac import backward, mtl_backward From 3fe83616e72e9f9f0742d4cf1839e412b2d56d71 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sat, 24 May 2025 15:54:18 +0200 Subject: [PATCH 6/9] Update changelog --- CHANGELOG.md | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 11cd122a..2cd0256e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,10 +14,9 @@ changes that do not affect the user. TorchJD. Users of these aggregators will have to use `pip install torchjd[cagrad]`, `pip install torchjd[nash_mtl]` or `pip install torchjd[full]` to install TorchJD alongside those dependencies. This should make TorchJD more lightweight. -- **BREAKING**: Made the aggregator modules and the modules of `backward` and `mtl_backward` - protected. They should now always be imported via their package (e.g. - `from torchjd.aggregation.upgrad import UPGrad` should become - `from torchjd.aggregation import UPGrad`, and +- **BREAKING**: Made the aggregator modules and the autojac package protected. They should now + always be imported via their package (e.g. `from torchjd.aggregation.upgrad import UPGrad` should + become `from torchjd.aggregation import UPGrad`, and `from torchjd.autojac.mtl_backward import mtl_backward` should become `from torchjd import mtl_backward`). From 014e47aeb02c5ae71c167c1da3141f951df196df Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sat, 24 May 2025 16:07:12 +0200 Subject: [PATCH 7/9] Make WeightedAggregator and GramianWeightedAggregator public --- src/torchjd/aggregation/_aggregator_bases.py | 4 ++-- src/torchjd/aggregation/_aligned_mtl.py | 4 ++-- src/torchjd/aggregation/_cagrad.py | 4 ++-- src/torchjd/aggregation/_constant.py | 4 ++-- src/torchjd/aggregation/_dualproj.py | 4 ++-- src/torchjd/aggregation/_imtl_g.py | 4 ++-- src/torchjd/aggregation/_krum.py | 4 ++-- src/torchjd/aggregation/_mean.py | 4 ++-- src/torchjd/aggregation/_mgda.py | 4 ++-- src/torchjd/aggregation/_nash_mtl.py | 4 ++-- src/torchjd/aggregation/_pcgrad.py | 4 ++-- src/torchjd/aggregation/_random.py | 4 ++-- src/torchjd/aggregation/_sum.py | 4 ++-- src/torchjd/aggregation/_upgrad.py | 4 ++-- 14 files changed, 28 insertions(+), 28 deletions(-) diff --git a/src/torchjd/aggregation/_aggregator_bases.py b/src/torchjd/aggregation/_aggregator_bases.py index 1a7f43ae..8daebeb6 100644 --- a/src/torchjd/aggregation/_aggregator_bases.py +++ b/src/torchjd/aggregation/_aggregator_bases.py @@ -45,7 +45,7 @@ def __str__(self) -> str: return f"{self.__class__.__name__}" -class _WeightedAggregator(Aggregator): +class WeightedAggregator(Aggregator): """ Aggregator that combines the rows of the input jacobian matrix with weights given by applying a Weighting to it. @@ -76,7 +76,7 @@ def forward(self, matrix: Tensor) -> Tensor: return vector -class _GramianWeightedAggregator(_WeightedAggregator): +class GramianWeightedAggregator(WeightedAggregator): """ WeightedAggregator that computes the gramian of the input jacobian matrix before applying a Weighting to it. diff --git a/src/torchjd/aggregation/_aligned_mtl.py b/src/torchjd/aggregation/_aligned_mtl.py index 57ed62e7..4d554f11 100644 --- a/src/torchjd/aggregation/_aligned_mtl.py +++ b/src/torchjd/aggregation/_aligned_mtl.py @@ -28,13 +28,13 @@ import torch from torch import Tensor -from ._aggregator_bases import _GramianWeightedAggregator +from ._aggregator_bases import GramianWeightedAggregator from ._mean import _MeanWeighting from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting from ._weighting_bases import PSDMatrix, Weighting -class AlignedMTL(_GramianWeightedAggregator): +class AlignedMTL(GramianWeightedAggregator): """ :class:`~torchjd.aggregation._aggregator_bases.Aggregator` as defined in Algorithm 1 of `Independent Component Alignment for Multi-Task Learning diff --git a/src/torchjd/aggregation/_cagrad.py b/src/torchjd/aggregation/_cagrad.py index fb5ec6f1..87f65d81 100644 --- a/src/torchjd/aggregation/_cagrad.py +++ b/src/torchjd/aggregation/_cagrad.py @@ -8,12 +8,12 @@ import torch from torch import Tensor -from ._aggregator_bases import _GramianWeightedAggregator +from ._aggregator_bases import GramianWeightedAggregator from ._utils.gramian import normalize from ._utils.non_differentiable import raise_non_differentiable_error -class CAGrad(_GramianWeightedAggregator): +class CAGrad(GramianWeightedAggregator): """ :class:`~torchjd.aggregation._aggregator_bases.Aggregator` as defined in Algorithm 1 of `Conflict-Averse Gradient Descent for Multi-task Learning diff --git a/src/torchjd/aggregation/_constant.py b/src/torchjd/aggregation/_constant.py index 27e2fcc2..8658e1b6 100644 --- a/src/torchjd/aggregation/_constant.py +++ b/src/torchjd/aggregation/_constant.py @@ -1,11 +1,11 @@ from torch import Tensor -from ._aggregator_bases import _WeightedAggregator +from ._aggregator_bases import WeightedAggregator from ._utils.str import vector_to_str from ._weighting_bases import Matrix, Weighting -class Constant(_WeightedAggregator): +class Constant(WeightedAggregator): """ :class:`~torchjd.aggregation._aggregator_bases.Aggregator` that makes a linear combination of the rows of the provided matrix, with constant, pre-determined weights. diff --git a/src/torchjd/aggregation/_dualproj.py b/src/torchjd/aggregation/_dualproj.py index 848e8219..58a71676 100644 --- a/src/torchjd/aggregation/_dualproj.py +++ b/src/torchjd/aggregation/_dualproj.py @@ -2,7 +2,7 @@ from torch import Tensor -from ._aggregator_bases import _GramianWeightedAggregator +from ._aggregator_bases import GramianWeightedAggregator from ._mean import _MeanWeighting from ._utils.dual_cone import project_weights from ._utils.gramian import normalize, regularize @@ -11,7 +11,7 @@ from ._weighting_bases import PSDMatrix, Weighting -class DualProj(_GramianWeightedAggregator): +class DualProj(GramianWeightedAggregator): """ :class:`~torchjd.aggregation._aggregator_bases.Aggregator` that averages the rows of the input matrix, and projects the result onto the dual cone of the rows of the matrix. This corresponds diff --git a/src/torchjd/aggregation/_imtl_g.py b/src/torchjd/aggregation/_imtl_g.py index 66e9da35..8cb8e774 100644 --- a/src/torchjd/aggregation/_imtl_g.py +++ b/src/torchjd/aggregation/_imtl_g.py @@ -1,12 +1,12 @@ import torch from torch import Tensor -from ._aggregator_bases import _GramianWeightedAggregator +from ._aggregator_bases import GramianWeightedAggregator from ._utils.non_differentiable import raise_non_differentiable_error from ._weighting_bases import PSDMatrix, Weighting -class IMTLG(_GramianWeightedAggregator): +class IMTLG(GramianWeightedAggregator): """ :class:`~torchjd.aggregation._aggregator_bases.Aggregator` generalizing the method described in `Towards Impartial Multi-task Learning `_. diff --git a/src/torchjd/aggregation/_krum.py b/src/torchjd/aggregation/_krum.py index b3e1dcf7..97f5675a 100644 --- a/src/torchjd/aggregation/_krum.py +++ b/src/torchjd/aggregation/_krum.py @@ -2,11 +2,11 @@ from torch import Tensor from torch.nn import functional as F -from ._aggregator_bases import _GramianWeightedAggregator +from ._aggregator_bases import GramianWeightedAggregator from ._weighting_bases import PSDMatrix, Weighting -class Krum(_GramianWeightedAggregator): +class Krum(GramianWeightedAggregator): """ :class:`~torchjd.aggregation._aggregator_bases.Aggregator` for adversarial federated learning, as defined in `Machine Learning with Adversaries: Byzantine Tolerant Gradient Descent diff --git a/src/torchjd/aggregation/_mean.py b/src/torchjd/aggregation/_mean.py index 35cd46f9..e1aeef84 100644 --- a/src/torchjd/aggregation/_mean.py +++ b/src/torchjd/aggregation/_mean.py @@ -1,11 +1,11 @@ import torch from torch import Tensor -from ._aggregator_bases import _WeightedAggregator +from ._aggregator_bases import WeightedAggregator from ._weighting_bases import Matrix, Weighting -class Mean(_WeightedAggregator): +class Mean(WeightedAggregator): """ :class:`~torchjd.aggregation._aggregator_bases.Aggregator` that averages the rows of the input matrices. diff --git a/src/torchjd/aggregation/_mgda.py b/src/torchjd/aggregation/_mgda.py index 9cb5b709..edac91e5 100644 --- a/src/torchjd/aggregation/_mgda.py +++ b/src/torchjd/aggregation/_mgda.py @@ -1,11 +1,11 @@ import torch from torch import Tensor -from ._aggregator_bases import _GramianWeightedAggregator +from ._aggregator_bases import GramianWeightedAggregator from ._weighting_bases import PSDMatrix, Weighting -class MGDA(_GramianWeightedAggregator): +class MGDA(GramianWeightedAggregator): r""" :class:`~torchjd.aggregation._aggregator_bases.Aggregator` performing the gradient aggregation step of `Multiple-gradient descent algorithm (MGDA) for multiobjective optimization diff --git a/src/torchjd/aggregation/_nash_mtl.py b/src/torchjd/aggregation/_nash_mtl.py index 91a20e62..d1f7aeea 100644 --- a/src/torchjd/aggregation/_nash_mtl.py +++ b/src/torchjd/aggregation/_nash_mtl.py @@ -34,11 +34,11 @@ from cvxpy import Expression from torch import Tensor -from ._aggregator_bases import _WeightedAggregator +from ._aggregator_bases import WeightedAggregator from ._utils.non_differentiable import raise_non_differentiable_error -class NashMTL(_WeightedAggregator): +class NashMTL(WeightedAggregator): """ :class:`~torchjd.aggregation._aggregator_bases.Aggregator` as proposed in Algorithm 1 of `Multi-Task Learning as a Bargaining Game `_. diff --git a/src/torchjd/aggregation/_pcgrad.py b/src/torchjd/aggregation/_pcgrad.py index 8bc85d2d..509e298f 100644 --- a/src/torchjd/aggregation/_pcgrad.py +++ b/src/torchjd/aggregation/_pcgrad.py @@ -1,12 +1,12 @@ import torch from torch import Tensor -from ._aggregator_bases import _GramianWeightedAggregator +from ._aggregator_bases import GramianWeightedAggregator from ._utils.non_differentiable import raise_non_differentiable_error from ._weighting_bases import PSDMatrix, Weighting -class PCGrad(_GramianWeightedAggregator): +class PCGrad(GramianWeightedAggregator): """ :class:`~torchjd.aggregation._aggregator_bases.Aggregator` as defined in algorithm 1 of `Gradient Surgery for Multi-Task Learning `_. diff --git a/src/torchjd/aggregation/_random.py b/src/torchjd/aggregation/_random.py index b5a92b1c..dc8c113e 100644 --- a/src/torchjd/aggregation/_random.py +++ b/src/torchjd/aggregation/_random.py @@ -2,11 +2,11 @@ from torch import Tensor from torch.nn import functional as F -from ._aggregator_bases import _WeightedAggregator +from ._aggregator_bases import WeightedAggregator from ._weighting_bases import Matrix, Weighting -class Random(_WeightedAggregator): +class Random(WeightedAggregator): """ :class:`~torchjd.aggregation._aggregator_bases.Aggregator` that computes a random combination of the rows of the provided matrices, as defined in algorithm 2 of `Reasonable Effectiveness of diff --git a/src/torchjd/aggregation/_sum.py b/src/torchjd/aggregation/_sum.py index 434cbe09..92ce53e5 100644 --- a/src/torchjd/aggregation/_sum.py +++ b/src/torchjd/aggregation/_sum.py @@ -1,11 +1,11 @@ import torch from torch import Tensor -from ._aggregator_bases import _WeightedAggregator +from ._aggregator_bases import WeightedAggregator from ._weighting_bases import Matrix, Weighting -class Sum(_WeightedAggregator): +class Sum(WeightedAggregator): """ :class:`~torchjd.aggregation._aggregator_bases.Aggregator` that sums of the rows of the input matrices. diff --git a/src/torchjd/aggregation/_upgrad.py b/src/torchjd/aggregation/_upgrad.py index b8939544..96a3cccc 100644 --- a/src/torchjd/aggregation/_upgrad.py +++ b/src/torchjd/aggregation/_upgrad.py @@ -3,7 +3,7 @@ import torch from torch import Tensor -from ._aggregator_bases import _GramianWeightedAggregator +from ._aggregator_bases import GramianWeightedAggregator from ._mean import _MeanWeighting from ._utils.dual_cone import project_weights from ._utils.gramian import normalize, regularize @@ -12,7 +12,7 @@ from ._weighting_bases import PSDMatrix, Weighting -class UPGrad(_GramianWeightedAggregator): +class UPGrad(GramianWeightedAggregator): """ :class:`~torchjd.aggregation._aggregator_bases.Aggregator` that projects each row of the input matrix onto the dual cone of all rows of this matrix, and that combines the result, as proposed From a63c822821958c35ca7c3fd27f5658d556ca7050 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sat, 24 May 2025 16:14:08 +0200 Subject: [PATCH 8/9] Make transforms protected --- src/torchjd/_autojac/_backward.py | 2 +- src/torchjd/_autojac/_mtl_backward.py | 2 +- src/torchjd/_autojac/_transform/__init__.py | 20 +++++++++---------- .../{accumulate.py => _accumulate.py} | 4 ++-- .../{aggregate.py => _aggregate.py} | 6 +++--- .../_autojac/_transform/{base.py => _base.py} | 2 +- .../{diagonalize.py => _diagonalize.py} | 6 +++--- .../_autojac/_transform/_differentiate.py | 6 +++--- .../_autojac/_transform/{grad.py => _grad.py} | 4 ++-- .../_autojac/_transform/{init.py => _init.py} | 4 ++-- .../_autojac/_transform/{jac.py => _jac.py} | 4 ++-- .../{ordered_set.py => _ordered_set.py} | 0 .../_transform/{select.py => _select.py} | 4 ++-- .../_transform/{stack.py => _stack.py} | 4 ++-- .../{tensor_dict.py => _tensor_dict.py} | 0 src/torchjd/_autojac/_utils.py | 2 +- .../unit/autojac/_transform/test_aggregate.py | 4 ++-- tests/unit/autojac/_transform/test_base.py | 4 ++-- .../autojac/_transform/test_diagonalize.py | 2 +- tests/unit/autojac/_transform/test_grad.py | 2 +- .../autojac/_transform/test_interactions.py | 2 +- tests/unit/autojac/_transform/test_jac.py | 2 +- .../autojac/_transform/test_tensor_dict.py | 2 +- tests/unit/autojac/test_backward.py | 2 +- tests/unit/autojac/test_mtl_backward.py | 2 +- 25 files changed, 46 insertions(+), 46 deletions(-) rename src/torchjd/_autojac/_transform/{accumulate.py => _accumulate.py} (95%) rename src/torchjd/_autojac/_transform/{aggregate.py => _aggregate.py} (97%) rename src/torchjd/_autojac/_transform/{base.py => _base.py} (97%) rename src/torchjd/_autojac/_transform/{diagonalize.py => _diagonalize.py} (95%) rename src/torchjd/_autojac/_transform/{grad.py => _grad.py} (97%) rename src/torchjd/_autojac/_transform/{init.py => _init.py} (88%) rename src/torchjd/_autojac/_transform/{jac.py => _jac.py} (98%) rename src/torchjd/_autojac/_transform/{ordered_set.py => _ordered_set.py} (100%) rename src/torchjd/_autojac/_transform/{select.py => _select.py} (91%) rename src/torchjd/_autojac/_transform/{stack.py => _stack.py} (96%) rename src/torchjd/_autojac/_transform/{tensor_dict.py => _tensor_dict.py} (100%) diff --git a/src/torchjd/_autojac/_backward.py b/src/torchjd/_autojac/_backward.py index 2fb7ee75..b15f148f 100644 --- a/src/torchjd/_autojac/_backward.py +++ b/src/torchjd/_autojac/_backward.py @@ -5,7 +5,7 @@ from torchjd.aggregation import Aggregator from ._transform import Accumulate, Aggregate, Diagonalize, EmptyTensorDict, Init, Jac, Transform -from ._transform.ordered_set import OrderedSet +from ._transform._ordered_set import OrderedSet from ._utils import as_checked_ordered_set, check_optional_positive_chunk_size, get_leaf_tensors diff --git a/src/torchjd/_autojac/_mtl_backward.py b/src/torchjd/_autojac/_mtl_backward.py index ae50164f..b74297aa 100644 --- a/src/torchjd/_autojac/_mtl_backward.py +++ b/src/torchjd/_autojac/_mtl_backward.py @@ -16,7 +16,7 @@ Stack, Transform, ) -from ._transform.ordered_set import OrderedSet +from ._transform._ordered_set import OrderedSet from ._utils import as_checked_ordered_set, check_optional_positive_chunk_size, get_leaf_tensors diff --git a/src/torchjd/_autojac/_transform/__init__.py b/src/torchjd/_autojac/_transform/__init__.py index 1721675b..ecab986e 100644 --- a/src/torchjd/_autojac/_transform/__init__.py +++ b/src/torchjd/_autojac/_transform/__init__.py @@ -1,13 +1,13 @@ -from .accumulate import Accumulate -from .aggregate import Aggregate -from .base import Composition, Conjunction, RequirementError, Transform -from .diagonalize import Diagonalize -from .grad import Grad -from .init import Init -from .jac import Jac -from .select import Select -from .stack import Stack -from .tensor_dict import ( +from ._accumulate import Accumulate +from ._aggregate import Aggregate +from ._base import Composition, Conjunction, RequirementError, Transform +from ._diagonalize import Diagonalize +from ._grad import Grad +from ._init import Init +from ._jac import Jac +from ._select import Select +from ._stack import Stack +from ._tensor_dict import ( EmptyTensorDict, Gradients, GradientVectors, diff --git a/src/torchjd/_autojac/_transform/accumulate.py b/src/torchjd/_autojac/_transform/_accumulate.py similarity index 95% rename from src/torchjd/_autojac/_transform/accumulate.py rename to src/torchjd/_autojac/_transform/_accumulate.py index 6950971e..32e45b0d 100644 --- a/src/torchjd/_autojac/_transform/accumulate.py +++ b/src/torchjd/_autojac/_transform/_accumulate.py @@ -1,7 +1,7 @@ from torch import Tensor -from .base import Transform -from .tensor_dict import EmptyTensorDict, Gradients +from ._base import Transform +from ._tensor_dict import EmptyTensorDict, Gradients class Accumulate(Transform[Gradients, EmptyTensorDict]): diff --git a/src/torchjd/_autojac/_transform/aggregate.py b/src/torchjd/_autojac/_transform/_aggregate.py similarity index 97% rename from src/torchjd/_autojac/_transform/aggregate.py rename to src/torchjd/_autojac/_transform/_aggregate.py index cf56113f..f9719c76 100644 --- a/src/torchjd/_autojac/_transform/aggregate.py +++ b/src/torchjd/_autojac/_transform/_aggregate.py @@ -6,9 +6,9 @@ from torchjd.aggregation import Aggregator -from .base import RequirementError, Transform -from .ordered_set import OrderedSet -from .tensor_dict import EmptyTensorDict, Gradients, GradientVectors, JacobianMatrices, Jacobians +from ._base import RequirementError, Transform +from ._ordered_set import OrderedSet +from ._tensor_dict import EmptyTensorDict, Gradients, GradientVectors, JacobianMatrices, Jacobians _KeyType = TypeVar("_KeyType", bound=Hashable) _ValueType = TypeVar("_ValueType") diff --git a/src/torchjd/_autojac/_transform/base.py b/src/torchjd/_autojac/_transform/_base.py similarity index 97% rename from src/torchjd/_autojac/_transform/base.py rename to src/torchjd/_autojac/_transform/_base.py index c3973fe3..e1c8fd7e 100644 --- a/src/torchjd/_autojac/_transform/base.py +++ b/src/torchjd/_autojac/_transform/_base.py @@ -5,7 +5,7 @@ from torch import Tensor -from .tensor_dict import _A, _B, _C, EmptyTensorDict, _least_common_ancestor +from ._tensor_dict import _A, _B, _C, EmptyTensorDict, _least_common_ancestor class RequirementError(ValueError): diff --git a/src/torchjd/_autojac/_transform/diagonalize.py b/src/torchjd/_autojac/_transform/_diagonalize.py similarity index 95% rename from src/torchjd/_autojac/_transform/diagonalize.py rename to src/torchjd/_autojac/_transform/_diagonalize.py index 067d39c3..40ff4c7e 100644 --- a/src/torchjd/_autojac/_transform/diagonalize.py +++ b/src/torchjd/_autojac/_transform/_diagonalize.py @@ -1,9 +1,9 @@ import torch from torch import Tensor -from .base import RequirementError, Transform -from .ordered_set import OrderedSet -from .tensor_dict import Gradients, Jacobians +from ._base import RequirementError, Transform +from ._ordered_set import OrderedSet +from ._tensor_dict import Gradients, Jacobians class Diagonalize(Transform[Gradients, Jacobians]): diff --git a/src/torchjd/_autojac/_transform/_differentiate.py b/src/torchjd/_autojac/_transform/_differentiate.py index c7d783b7..ca3ca644 100644 --- a/src/torchjd/_autojac/_transform/_differentiate.py +++ b/src/torchjd/_autojac/_transform/_differentiate.py @@ -3,9 +3,9 @@ from torch import Tensor -from .base import RequirementError, Transform -from .ordered_set import OrderedSet -from .tensor_dict import _A +from ._base import RequirementError, Transform +from ._ordered_set import OrderedSet +from ._tensor_dict import _A class Differentiate(Transform[_A, _A], ABC): diff --git a/src/torchjd/_autojac/_transform/grad.py b/src/torchjd/_autojac/_transform/_grad.py similarity index 97% rename from src/torchjd/_autojac/_transform/grad.py rename to src/torchjd/_autojac/_transform/_grad.py index 2bcb96cb..0600e65e 100644 --- a/src/torchjd/_autojac/_transform/grad.py +++ b/src/torchjd/_autojac/_transform/_grad.py @@ -5,8 +5,8 @@ from ._differentiate import Differentiate from ._materialize import materialize -from .ordered_set import OrderedSet -from .tensor_dict import Gradients +from ._ordered_set import OrderedSet +from ._tensor_dict import Gradients class Grad(Differentiate[Gradients]): diff --git a/src/torchjd/_autojac/_transform/init.py b/src/torchjd/_autojac/_transform/_init.py similarity index 88% rename from src/torchjd/_autojac/_transform/init.py rename to src/torchjd/_autojac/_transform/_init.py index 9f1b3d44..86520011 100644 --- a/src/torchjd/_autojac/_transform/init.py +++ b/src/torchjd/_autojac/_transform/_init.py @@ -3,8 +3,8 @@ import torch from torch import Tensor -from .base import RequirementError, Transform -from .tensor_dict import EmptyTensorDict, Gradients +from ._base import RequirementError, Transform +from ._tensor_dict import EmptyTensorDict, Gradients class Init(Transform[EmptyTensorDict, Gradients]): diff --git a/src/torchjd/_autojac/_transform/jac.py b/src/torchjd/_autojac/_transform/_jac.py similarity index 98% rename from src/torchjd/_autojac/_transform/jac.py rename to src/torchjd/_autojac/_transform/_jac.py index 14a444c1..82f1d856 100644 --- a/src/torchjd/_autojac/_transform/jac.py +++ b/src/torchjd/_autojac/_transform/_jac.py @@ -8,8 +8,8 @@ from ._differentiate import Differentiate from ._materialize import materialize -from .ordered_set import OrderedSet -from .tensor_dict import Jacobians +from ._ordered_set import OrderedSet +from ._tensor_dict import Jacobians class Jac(Differentiate[Jacobians]): diff --git a/src/torchjd/_autojac/_transform/ordered_set.py b/src/torchjd/_autojac/_transform/_ordered_set.py similarity index 100% rename from src/torchjd/_autojac/_transform/ordered_set.py rename to src/torchjd/_autojac/_transform/_ordered_set.py diff --git a/src/torchjd/_autojac/_transform/select.py b/src/torchjd/_autojac/_transform/_select.py similarity index 91% rename from src/torchjd/_autojac/_transform/select.py rename to src/torchjd/_autojac/_transform/_select.py index be532bf2..424353fe 100644 --- a/src/torchjd/_autojac/_transform/select.py +++ b/src/torchjd/_autojac/_transform/_select.py @@ -2,8 +2,8 @@ from torch import Tensor -from .base import RequirementError, Transform -from .tensor_dict import _A +from ._base import RequirementError, Transform +from ._tensor_dict import _A class Select(Transform[_A, _A]): diff --git a/src/torchjd/_autojac/_transform/stack.py b/src/torchjd/_autojac/_transform/_stack.py similarity index 96% rename from src/torchjd/_autojac/_transform/stack.py rename to src/torchjd/_autojac/_transform/_stack.py index 6b67cace..cce06ad0 100644 --- a/src/torchjd/_autojac/_transform/stack.py +++ b/src/torchjd/_autojac/_transform/_stack.py @@ -3,9 +3,9 @@ import torch from torch import Tensor +from ._base import Transform from ._materialize import materialize -from .base import Transform -from .tensor_dict import _A, Gradients, Jacobians +from ._tensor_dict import _A, Gradients, Jacobians class Stack(Transform[_A, Jacobians]): diff --git a/src/torchjd/_autojac/_transform/tensor_dict.py b/src/torchjd/_autojac/_transform/_tensor_dict.py similarity index 100% rename from src/torchjd/_autojac/_transform/tensor_dict.py rename to src/torchjd/_autojac/_transform/_tensor_dict.py diff --git a/src/torchjd/_autojac/_utils.py b/src/torchjd/_autojac/_utils.py index 90dc6838..0b198775 100644 --- a/src/torchjd/_autojac/_utils.py +++ b/src/torchjd/_autojac/_utils.py @@ -4,7 +4,7 @@ from torch import Tensor from torch.autograd.graph import Node -from ._transform.ordered_set import OrderedSet +from ._transform._ordered_set import OrderedSet def check_optional_positive_chunk_size(parallel_chunk_size: int | None) -> None: diff --git a/tests/unit/autojac/_transform/test_aggregate.py b/tests/unit/autojac/_transform/test_aggregate.py index 79b0ad52..c9e68d8e 100644 --- a/tests/unit/autojac/_transform/test_aggregate.py +++ b/tests/unit/autojac/_transform/test_aggregate.py @@ -12,8 +12,8 @@ Jacobians, RequirementError, ) -from torchjd._autojac._transform.aggregate import _AggregateMatrices, _Matrixify, _Reshape -from torchjd._autojac._transform.ordered_set import OrderedSet +from torchjd._autojac._transform._aggregate import _AggregateMatrices, _Matrixify, _Reshape +from torchjd._autojac._transform._ordered_set import OrderedSet from torchjd.aggregation import Random from ._dict_assertions import assert_tensor_dicts_are_close diff --git a/tests/unit/autojac/_transform/test_base.py b/tests/unit/autojac/_transform/test_base.py index c48bb4ba..7426e302 100644 --- a/tests/unit/autojac/_transform/test_base.py +++ b/tests/unit/autojac/_transform/test_base.py @@ -4,8 +4,8 @@ from pytest import raises from torch import Tensor -from torchjd._autojac._transform.base import Conjunction, RequirementError, Transform -from torchjd._autojac._transform.tensor_dict import _B, _C, TensorDict +from torchjd._autojac._transform._base import Conjunction, RequirementError, Transform +from torchjd._autojac._transform._tensor_dict import _B, _C, TensorDict class FakeTransform(Transform[_B, _C]): diff --git a/tests/unit/autojac/_transform/test_diagonalize.py b/tests/unit/autojac/_transform/test_diagonalize.py index 7f5c4df2..0b4bb077 100644 --- a/tests/unit/autojac/_transform/test_diagonalize.py +++ b/tests/unit/autojac/_transform/test_diagonalize.py @@ -2,7 +2,7 @@ from pytest import raises from torchjd._autojac._transform import Diagonalize, Gradients, RequirementError -from torchjd._autojac._transform.ordered_set import OrderedSet +from torchjd._autojac._transform._ordered_set import OrderedSet from ._dict_assertions import assert_tensor_dicts_are_close diff --git a/tests/unit/autojac/_transform/test_grad.py b/tests/unit/autojac/_transform/test_grad.py index 378e2988..b9cd3659 100644 --- a/tests/unit/autojac/_transform/test_grad.py +++ b/tests/unit/autojac/_transform/test_grad.py @@ -2,7 +2,7 @@ from pytest import raises from torchjd._autojac._transform import Grad, Gradients, RequirementError -from torchjd._autojac._transform.ordered_set import OrderedSet +from torchjd._autojac._transform._ordered_set import OrderedSet from ._dict_assertions import assert_tensor_dicts_are_close diff --git a/tests/unit/autojac/_transform/test_interactions.py b/tests/unit/autojac/_transform/test_interactions.py index 3d4a604a..0b1eb796 100644 --- a/tests/unit/autojac/_transform/test_interactions.py +++ b/tests/unit/autojac/_transform/test_interactions.py @@ -17,7 +17,7 @@ Stack, TensorDict, ) -from torchjd._autojac._transform.ordered_set import OrderedSet +from torchjd._autojac._transform._ordered_set import OrderedSet from ._dict_assertions import assert_tensor_dicts_are_close diff --git a/tests/unit/autojac/_transform/test_jac.py b/tests/unit/autojac/_transform/test_jac.py index 1988cab7..a552b3fd 100644 --- a/tests/unit/autojac/_transform/test_jac.py +++ b/tests/unit/autojac/_transform/test_jac.py @@ -2,7 +2,7 @@ from pytest import mark, raises from torchjd._autojac._transform import Jac, Jacobians, RequirementError -from torchjd._autojac._transform.ordered_set import OrderedSet +from torchjd._autojac._transform._ordered_set import OrderedSet from ._dict_assertions import assert_tensor_dicts_are_close diff --git a/tests/unit/autojac/_transform/test_tensor_dict.py b/tests/unit/autojac/_transform/test_tensor_dict.py index 36fb7a62..0d9160b3 100644 --- a/tests/unit/autojac/_transform/test_tensor_dict.py +++ b/tests/unit/autojac/_transform/test_tensor_dict.py @@ -13,7 +13,7 @@ Jacobians, TensorDict, ) -from torchjd._autojac._transform.tensor_dict import _least_common_ancestor +from torchjd._autojac._transform._tensor_dict import _least_common_ancestor _key_shapes = [[], [1], [2, 3]] diff --git a/tests/unit/autojac/test_backward.py b/tests/unit/autojac/test_backward.py index a6c13635..48db9b6c 100644 --- a/tests/unit/autojac/test_backward.py +++ b/tests/unit/autojac/test_backward.py @@ -5,7 +5,7 @@ from torchjd import backward from torchjd._autojac._backward import _create_transform -from torchjd._autojac._transform.ordered_set import OrderedSet +from torchjd._autojac._transform._ordered_set import OrderedSet from torchjd.aggregation import MGDA, Aggregator, Mean, Random, Sum, UPGrad diff --git a/tests/unit/autojac/test_mtl_backward.py b/tests/unit/autojac/test_mtl_backward.py index 4bc158e6..f4d39161 100644 --- a/tests/unit/autojac/test_mtl_backward.py +++ b/tests/unit/autojac/test_mtl_backward.py @@ -5,7 +5,7 @@ from torchjd import mtl_backward from torchjd._autojac._mtl_backward import _create_transform -from torchjd._autojac._transform.ordered_set import OrderedSet +from torchjd._autojac._transform._ordered_set import OrderedSet from torchjd.aggregation import MGDA, Aggregator, Mean, Random, Sum, UPGrad From 59b2ccf9519218500c3b8f4a5fff355106d72cee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sat, 24 May 2025 16:20:34 +0200 Subject: [PATCH 9/9] Improve changelog --- CHANGELOG.md | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2cd0256e..cbdfa634 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,10 +14,12 @@ changes that do not affect the user. TorchJD. Users of these aggregators will have to use `pip install torchjd[cagrad]`, `pip install torchjd[nash_mtl]` or `pip install torchjd[full]` to install TorchJD alongside those dependencies. This should make TorchJD more lightweight. -- **BREAKING**: Made the aggregator modules and the autojac package protected. They should now - always be imported via their package (e.g. `from torchjd.aggregation.upgrad import UPGrad` should - become `from torchjd.aggregation import UPGrad`, and - `from torchjd.autojac.mtl_backward import mtl_backward` should become +- **BREAKING**: Made the aggregator modules and the `autojac` package protected. The aggregators + must now always be imported via their package (e.g. + `from torchjd.aggregation.upgrad import UPGrad` must be changed to + `from torchjd.aggregation import UPGrad`). The `backward` and `mtl_backward` functions must now + always be imported directly from the `torchjd` package (e.g. + `from torchjd.autojac.mtl_backward import mtl_backward` must be changed to `from torchjd import mtl_backward`). ### Fixed