Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/torchjd/aggregation/_pref_vector_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from torch import Tensor

from ._str_utils import _vector_to_str
from .bases import _Weighting
from .constant import _ConstantWeighting

Expand All @@ -25,3 +26,12 @@ def _pref_vector_to_weighting(pref_vector: Tensor | None, default: _Weighting) -
return default
else:
return _ConstantWeighting(pref_vector)


def _pref_vector_to_str_suffix(pref_vector: Tensor | None) -> str:
"""Returns a suffix string containing the representation of the optional preference vector."""

if pref_vector is None:
return ""
else:
return f"([{_vector_to_str(pref_vector)}])"
13 changes: 6 additions & 7 deletions src/torchjd/aggregation/aligned_mtl.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,11 @@
from torch import Tensor
from torch.linalg import LinAlgError

from ._pref_vector_utils import _check_pref_vector, _pref_vector_to_weighting
from ._str_utils import _vector_to_str
from ._pref_vector_utils import (
_check_pref_vector,
_pref_vector_to_str_suffix,
_pref_vector_to_weighting,
)
from .bases import _WeightedAggregator, _Weighting
from .mean import _MeanWeighting

Expand Down Expand Up @@ -73,11 +76,7 @@ def __repr__(self) -> str:
return f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)})"

def __str__(self) -> str:
if self._pref_vector is None:
suffix = ""
else:
suffix = f"([{_vector_to_str(self._pref_vector)}])"
return f"AlignedMTL{suffix}"
return f"AlignedMTL{_pref_vector_to_str_suffix(self._pref_vector)}"


class _AlignedMTLWrapper(_Weighting):
Expand Down
13 changes: 6 additions & 7 deletions src/torchjd/aggregation/dualproj.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@
from torch import Tensor

from ._gramian_utils import _compute_normalized_gramian
from ._pref_vector_utils import _check_pref_vector, _pref_vector_to_weighting
from ._str_utils import _vector_to_str
from ._pref_vector_utils import (
_check_pref_vector,
_pref_vector_to_str_suffix,
_pref_vector_to_weighting,
)
from .bases import _WeightedAggregator, _Weighting
from .mean import _MeanWeighting

Expand Down Expand Up @@ -68,11 +71,7 @@ def __repr__(self) -> str:
)

def __str__(self) -> str:
if self._pref_vector is None:
suffix = ""
else:
suffix = f"([{_vector_to_str(self._pref_vector)}])"
return f"DualProj{suffix}"
return f"DualProj{_pref_vector_to_str_suffix(self._pref_vector)}"


class _DualProjWrapper(_Weighting):
Expand Down
13 changes: 6 additions & 7 deletions src/torchjd/aggregation/upgrad.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@
from torch import Tensor

from ._gramian_utils import _compute_normalized_gramian
from ._pref_vector_utils import _check_pref_vector, _pref_vector_to_weighting
from ._str_utils import _vector_to_str
from ._pref_vector_utils import (
_check_pref_vector,
_pref_vector_to_str_suffix,
_pref_vector_to_weighting,
)
from .bases import _WeightedAggregator, _Weighting
from .mean import _MeanWeighting

Expand Down Expand Up @@ -67,11 +70,7 @@ def __repr__(self) -> str:
)

def __str__(self) -> str:
if self._pref_vector is None:
suffix = ""
else:
suffix = f"([{_vector_to_str(self._pref_vector)}])"
return f"UPGrad{suffix}"
return f"UPGrad{_pref_vector_to_str_suffix(self._pref_vector)}"


class _UPGradWrapper(_Weighting):
Expand Down