From 0b782417ef802c96de08538db056ab81d2bcb03f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sun, 5 Jan 2025 21:50:05 +0100 Subject: [PATCH] Add _pref_vector_to_suffix --- src/torchjd/aggregation/_pref_vector_utils.py | 10 ++++++++++ src/torchjd/aggregation/aligned_mtl.py | 13 ++++++------- src/torchjd/aggregation/dualproj.py | 13 ++++++------- src/torchjd/aggregation/upgrad.py | 13 ++++++------- 4 files changed, 28 insertions(+), 21 deletions(-) diff --git a/src/torchjd/aggregation/_pref_vector_utils.py b/src/torchjd/aggregation/_pref_vector_utils.py index f8b10794a..22068cfef 100644 --- a/src/torchjd/aggregation/_pref_vector_utils.py +++ b/src/torchjd/aggregation/_pref_vector_utils.py @@ -1,5 +1,6 @@ from torch import Tensor +from ._str_utils import _vector_to_str from .bases import _Weighting from .constant import _ConstantWeighting @@ -25,3 +26,12 @@ def _pref_vector_to_weighting(pref_vector: Tensor | None, default: _Weighting) - return default else: return _ConstantWeighting(pref_vector) + + +def _pref_vector_to_str_suffix(pref_vector: Tensor | None) -> str: + """Returns a suffix string containing the representation of the optional preference vector.""" + + if pref_vector is None: + return "" + else: + return f"([{_vector_to_str(pref_vector)}])" diff --git a/src/torchjd/aggregation/aligned_mtl.py b/src/torchjd/aggregation/aligned_mtl.py index fdd52eabd..c9d76d89b 100644 --- a/src/torchjd/aggregation/aligned_mtl.py +++ b/src/torchjd/aggregation/aligned_mtl.py @@ -29,8 +29,11 @@ from torch import Tensor from torch.linalg import LinAlgError -from ._pref_vector_utils import _check_pref_vector, _pref_vector_to_weighting -from ._str_utils import _vector_to_str +from ._pref_vector_utils import ( + _check_pref_vector, + _pref_vector_to_str_suffix, + _pref_vector_to_weighting, +) from .bases import _WeightedAggregator, _Weighting from .mean import _MeanWeighting @@ -73,11 +76,7 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)})" def __str__(self) -> str: - if self._pref_vector is None: - suffix = "" - else: - suffix = f"([{_vector_to_str(self._pref_vector)}])" - return f"AlignedMTL{suffix}" + return f"AlignedMTL{_pref_vector_to_str_suffix(self._pref_vector)}" class _AlignedMTLWrapper(_Weighting): diff --git a/src/torchjd/aggregation/dualproj.py b/src/torchjd/aggregation/dualproj.py index f81fe5149..070610264 100644 --- a/src/torchjd/aggregation/dualproj.py +++ b/src/torchjd/aggregation/dualproj.py @@ -6,8 +6,11 @@ from torch import Tensor from ._gramian_utils import _compute_normalized_gramian -from ._pref_vector_utils import _check_pref_vector, _pref_vector_to_weighting -from ._str_utils import _vector_to_str +from ._pref_vector_utils import ( + _check_pref_vector, + _pref_vector_to_str_suffix, + _pref_vector_to_weighting, +) from .bases import _WeightedAggregator, _Weighting from .mean import _MeanWeighting @@ -68,11 +71,7 @@ def __repr__(self) -> str: ) def __str__(self) -> str: - if self._pref_vector is None: - suffix = "" - else: - suffix = f"([{_vector_to_str(self._pref_vector)}])" - return f"DualProj{suffix}" + return f"DualProj{_pref_vector_to_str_suffix(self._pref_vector)}" class _DualProjWrapper(_Weighting): diff --git a/src/torchjd/aggregation/upgrad.py b/src/torchjd/aggregation/upgrad.py index 61f0f2b56..19f9f3adf 100644 --- a/src/torchjd/aggregation/upgrad.py +++ b/src/torchjd/aggregation/upgrad.py @@ -6,8 +6,11 @@ from torch import Tensor from ._gramian_utils import _compute_normalized_gramian -from ._pref_vector_utils import _check_pref_vector, _pref_vector_to_weighting -from ._str_utils import _vector_to_str +from ._pref_vector_utils import ( + _check_pref_vector, + _pref_vector_to_str_suffix, + _pref_vector_to_weighting, +) from .bases import _WeightedAggregator, _Weighting from .mean import _MeanWeighting @@ -67,11 +70,7 @@ def __repr__(self) -> str: ) def __str__(self) -> str: - if self._pref_vector is None: - suffix = "" - else: - suffix = f"([{_vector_to_str(self._pref_vector)}])" - return f"UPGrad{suffix}" + return f"UPGrad{_pref_vector_to_str_suffix(self._pref_vector)}" class _UPGradWrapper(_Weighting):