7 changes: 7 additions & 0 deletions CHANGELOG.md
@@ -14,6 +14,13 @@ changes that do not affect the user.
   TorchJD. Users of these aggregators will have to use `pip install torchjd[cagrad]`, `pip install
   torchjd[nash_mtl]` or `pip install torchjd[full]` to install TorchJD alongside those dependencies.
   This should make TorchJD more lightweight.
+- **BREAKING**: Made the aggregator modules and the `autojac` package protected. The aggregators
+  must now always be imported via their package (e.g.
+  `from torchjd.aggregation.upgrad import UPGrad` must be changed to
+  `from torchjd.aggregation import UPGrad`). The `backward` and `mtl_backward` functions must now
+  always be imported directly from the `torchjd` package (e.g.
+  `from torchjd.autojac.mtl_backward import mtl_backward` must be changed to
+  `from torchjd import mtl_backward`).
 
 ### Fixed
 
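A minimal before/after sketch of the migration described in the changelog entry above. The import paths are taken verbatim from the entry; the `aggregator` variable at the end is only an illustrative assumption:

```python
# Before this release (now broken): importing from the protected modules.
# from torchjd.aggregation.upgrad import UPGrad
# from torchjd.autojac.mtl_backward import mtl_backward

# After: aggregators come from their package, and backward / mtl_backward
# come directly from the top-level torchjd package.
from torchjd import backward, mtl_backward
from torchjd.aggregation import UPGrad

aggregator = UPGrad()
```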
2 changes: 1 addition & 1 deletion docs/source/docs/aggregation/aligned_mtl.rst
@@ -3,7 +3,7 @@
 Aligned-MTL
 ===========
 
-.. automodule:: torchjd.aggregation.aligned_mtl
+.. autoclass:: torchjd.aggregation.AlignedMTL
    :members:
    :undoc-members:
    :exclude-members: forward
2 changes: 1 addition & 1 deletion docs/source/docs/aggregation/bases.rst
@@ -3,7 +3,7 @@
 Aggregator (abstract)
 =====================
 
-.. automodule:: torchjd.aggregation.aggregator_bases
+.. autoclass:: torchjd.aggregation.Aggregator
    :members:
    :undoc-members:
    :show-inheritance:
2 changes: 1 addition & 1 deletion docs/source/docs/aggregation/cagrad.rst
@@ -3,7 +3,7 @@
 CAGrad
 ======
 
-.. automodule:: torchjd.aggregation.cagrad
+.. autoclass:: torchjd.aggregation.CAGrad
    :members:
    :undoc-members:
    :exclude-members: forward
2 changes: 1 addition & 1 deletion docs/source/docs/aggregation/config.rst
@@ -3,7 +3,7 @@
 ConFIG
 ======
 
-.. automodule:: torchjd.aggregation.config
+.. autoclass:: torchjd.aggregation.ConFIG
    :members:
    :undoc-members:
    :exclude-members: forward
2 changes: 1 addition & 1 deletion docs/source/docs/aggregation/constant.rst
@@ -3,7 +3,7 @@
 Constant
 ========
 
-.. automodule:: torchjd.aggregation.constant
+.. autoclass:: torchjd.aggregation.Constant
    :members:
    :undoc-members:
    :exclude-members: forward
2 changes: 1 addition & 1 deletion docs/source/docs/aggregation/dualproj.rst
@@ -3,7 +3,7 @@
 DualProj
 ========
 
-.. automodule:: torchjd.aggregation.dualproj
+.. autoclass:: torchjd.aggregation.DualProj
    :members:
    :undoc-members:
    :exclude-members: forward
2 changes: 1 addition & 1 deletion docs/source/docs/aggregation/graddrop.rst
@@ -3,7 +3,7 @@
 GradDrop
 ========
 
-.. automodule:: torchjd.aggregation.graddrop
+.. autoclass:: torchjd.aggregation.GradDrop
    :members:
    :undoc-members:
    :exclude-members: forward
2 changes: 1 addition & 1 deletion docs/source/docs/aggregation/imtl_g.rst
@@ -3,7 +3,7 @@
 IMTL-G
 ======
 
-.. automodule:: torchjd.aggregation.imtl_g
+.. autoclass:: torchjd.aggregation.IMTLG
    :members:
    :undoc-members:
    :exclude-members: forward
2 changes: 1 addition & 1 deletion docs/source/docs/aggregation/krum.rst
@@ -3,7 +3,7 @@
 Krum
 ====
 
-.. automodule:: torchjd.aggregation.krum
+.. autoclass:: torchjd.aggregation.Krum
    :members:
    :undoc-members:
    :exclude-members: forward
2 changes: 1 addition & 1 deletion docs/source/docs/aggregation/mean.rst
@@ -3,7 +3,7 @@
 Mean
 ====
 
-.. automodule:: torchjd.aggregation.mean
+.. autoclass:: torchjd.aggregation.Mean
    :members:
    :undoc-members:
    :exclude-members: forward
2 changes: 1 addition & 1 deletion docs/source/docs/aggregation/mgda.rst
@@ -3,7 +3,7 @@
 MGDA
 ====
 
-.. automodule:: torchjd.aggregation.mgda
+.. autoclass:: torchjd.aggregation.MGDA
    :members:
    :undoc-members:
    :exclude-members: forward
2 changes: 1 addition & 1 deletion docs/source/docs/aggregation/nash_mtl.rst
@@ -3,7 +3,7 @@
 Nash-MTL
 ========
 
-.. automodule:: torchjd.aggregation.nash_mtl
+.. autoclass:: torchjd.aggregation.NashMTL
    :members:
    :undoc-members:
    :exclude-members: forward
2 changes: 1 addition & 1 deletion docs/source/docs/aggregation/pcgrad.rst
@@ -3,7 +3,7 @@
 PCGrad
 ======
 
-.. automodule:: torchjd.aggregation.pcgrad
+.. autoclass:: torchjd.aggregation.PCGrad
    :members:
    :undoc-members:
    :exclude-members: forward
2 changes: 1 addition & 1 deletion docs/source/docs/aggregation/random.rst
@@ -3,7 +3,7 @@
 Random
 ======
 
-.. automodule:: torchjd.aggregation.random
+.. autoclass:: torchjd.aggregation.Random
    :members:
    :undoc-members:
    :exclude-members: forward
2 changes: 1 addition & 1 deletion docs/source/docs/aggregation/sum.rst
@@ -3,7 +3,7 @@
 Sum
 ===
 
-.. automodule:: torchjd.aggregation.sum
+.. autoclass:: torchjd.aggregation.Sum
    :members:
    :undoc-members:
    :exclude-members: forward
2 changes: 1 addition & 1 deletion docs/source/docs/aggregation/trimmed_mean.rst
@@ -3,7 +3,7 @@
 Trimmed Mean
 ============
 
-.. automodule:: torchjd.aggregation.trimmed_mean
+.. autoclass:: torchjd.aggregation.TrimmedMean
    :members:
    :undoc-members:
    :exclude-members: forward
2 changes: 1 addition & 1 deletion docs/source/docs/aggregation/upgrad.rst
@@ -3,7 +3,7 @@
 UPGrad
 ======
 
-.. automodule:: torchjd.aggregation.upgrad
+.. autoclass:: torchjd.aggregation.UPGrad
    :members:
    :undoc-members:
    :exclude-members: forward
5 changes: 1 addition & 4 deletions docs/source/docs/autojac/backward.rst
@@ -3,7 +3,4 @@
 backward
 ========
 
-.. automodule:: torchjd.autojac.backward
-   :members:
-   :undoc-members:
-   :exclude-members:
+.. autofunction:: torchjd.backward
5 changes: 1 addition & 4 deletions docs/source/docs/autojac/mtl_backward.rst
@@ -3,7 +3,4 @@
 mtl_backward
 ============
 
-.. automodule:: torchjd.autojac.mtl_backward
-   :members:
-   :undoc-members:
-   :exclude-members:
+.. autofunction:: torchjd.mtl_backward
2 changes: 1 addition & 1 deletion src/torchjd/__init__.py
@@ -1 +1 @@
-from torchjd.autojac import backward, mtl_backward
+from ._autojac import backward, mtl_backward
2 changes: 2 additions & 0 deletions src/torchjd/_autojac/__init__.py
@@ -0,0 +1,2 @@
+from ._backward import backward
+from ._mtl_backward import mtl_backward
@@ -5,7 +5,7 @@
 from torchjd.aggregation import Aggregator
 
 from ._transform import Accumulate, Aggregate, Diagonalize, EmptyTensorDict, Init, Jac, Transform
-from ._transform.ordered_set import OrderedSet
+from ._transform._ordered_set import OrderedSet
 from ._utils import as_checked_ordered_set, check_optional_positive_chunk_size, get_leaf_tensors
 
 
@@ -16,7 +16,7 @@
     Stack,
     Transform,
 )
-from ._transform.ordered_set import OrderedSet
+from ._transform._ordered_set import OrderedSet
 from ._utils import as_checked_ordered_set, check_optional_positive_chunk_size, get_leaf_tensors
 
 
17 changes: 17 additions & 0 deletions src/torchjd/_autojac/_transform/__init__.py
@@ -0,0 +1,17 @@
+from ._accumulate import Accumulate
+from ._aggregate import Aggregate
+from ._base import Composition, Conjunction, RequirementError, Transform
+from ._diagonalize import Diagonalize
+from ._grad import Grad
+from ._init import Init
+from ._jac import Jac
+from ._select import Select
+from ._stack import Stack
+from ._tensor_dict import (
+    EmptyTensorDict,
+    Gradients,
+    GradientVectors,
+    JacobianMatrices,
+    Jacobians,
+    TensorDict,
+)
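The new `__init__.py` above makes the private `_transform` package the single import surface for the transform machinery. A short sketch of the resulting import pattern, exactly as it appears in the hunks elsewhere in this diff; note that `OrderedSet` is not in the re-export list:

```python
# Inside src/torchjd/_autojac, sibling modules now import the transforms
# from the _transform package rather than from its individual modules:
from ._transform import Accumulate, Aggregate, Diagonalize, EmptyTensorDict, Init, Jac, Transform

# OrderedSet is not re-exported, so it keeps a module-level import:
from ._transform._ordered_set import OrderedSet
```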
@@ -1,7 +1,7 @@
 from torch import Tensor
 
-from .base import Transform
-from .tensor_dict import EmptyTensorDict, Gradients
+from ._base import Transform
+from ._tensor_dict import EmptyTensorDict, Gradients
 
 
 class Accumulate(Transform[Gradients, EmptyTensorDict]):
@@ -6,9 +6,9 @@
 
 from torchjd.aggregation import Aggregator
 
-from .base import RequirementError, Transform
-from .ordered_set import OrderedSet
-from .tensor_dict import EmptyTensorDict, Gradients, GradientVectors, JacobianMatrices, Jacobians
+from ._base import RequirementError, Transform
+from ._ordered_set import OrderedSet
+from ._tensor_dict import EmptyTensorDict, Gradients, GradientVectors, JacobianMatrices, Jacobians
 
 _KeyType = TypeVar("_KeyType", bound=Hashable)
 _ValueType = TypeVar("_ValueType")
@@ -94,7 +94,7 @@ def _aggregate_group(
 ) -> GradientVectors:
     """
     Unites the jacobian matrices and aggregates them using an
-    :class:`~torchjd.aggregation.aggregator_bases.Aggregator`. Returns the obtained gradient
+    :class:`~torchjd.aggregation._aggregator_bases.Aggregator`. Returns the obtained gradient
     vectors.
     """
 
@@ -5,7 +5,7 @@
 
 from torch import Tensor
 
-from .tensor_dict import _A, _B, _C, EmptyTensorDict, _least_common_ancestor
+from ._tensor_dict import _A, _B, _C, EmptyTensorDict, _least_common_ancestor
 
 
 class RequirementError(ValueError):
@@ -1,9 +1,9 @@
 import torch
 from torch import Tensor
 
-from .base import RequirementError, Transform
-from .ordered_set import OrderedSet
-from .tensor_dict import Gradients, Jacobians
+from ._base import RequirementError, Transform
+from ._ordered_set import OrderedSet
+from ._tensor_dict import Gradients, Jacobians
 
 
 class Diagonalize(Transform[Gradients, Jacobians]):
@@ -3,9 +3,9 @@
 
 from torch import Tensor
 
-from .base import RequirementError, Transform
-from .ordered_set import OrderedSet
-from .tensor_dict import _A
+from ._base import RequirementError, Transform
+from ._ordered_set import OrderedSet
+from ._tensor_dict import _A
 
 
 class Differentiate(Transform[_A, _A], ABC):
@@ -5,8 +5,8 @@
 
 from ._differentiate import Differentiate
 from ._materialize import materialize
-from .ordered_set import OrderedSet
-from .tensor_dict import Gradients
+from ._ordered_set import OrderedSet
+from ._tensor_dict import Gradients
 
 
 class Grad(Differentiate[Gradients]):
@@ -3,8 +3,8 @@
 import torch
 from torch import Tensor
 
-from .base import RequirementError, Transform
-from .tensor_dict import EmptyTensorDict, Gradients
+from ._base import RequirementError, Transform
+from ._tensor_dict import EmptyTensorDict, Gradients
 
 
 class Init(Transform[EmptyTensorDict, Gradients]):
@@ -8,8 +8,8 @@
 
 from ._differentiate import Differentiate
 from ._materialize import materialize
-from .ordered_set import OrderedSet
-from .tensor_dict import Jacobians
+from ._ordered_set import OrderedSet
+from ._tensor_dict import Jacobians
 
 
 class Jac(Differentiate[Jacobians]):
@@ -2,8 +2,8 @@
 
 from torch import Tensor
 
-from .base import RequirementError, Transform
-from .tensor_dict import _A
+from ._base import RequirementError, Transform
+from ._tensor_dict import _A
 
 
 class Select(Transform[_A, _A]):
@@ -3,9 +3,9 @@
 import torch
 from torch import Tensor
 
+from ._base import Transform
 from ._materialize import materialize
-from .base import Transform
-from .tensor_dict import _A, Gradients, Jacobians
+from ._tensor_dict import _A, Gradients, Jacobians
 
 
 class Stack(Transform[_A, Jacobians]):
@@ -4,7 +4,7 @@
 from torch import Tensor
 from torch.autograd.graph import Node
 
-from ._transform.ordered_set import OrderedSet
+from ._transform._ordered_set import OrderedSet
 
 
 def check_optional_positive_chunk_size(parallel_chunk_size: int | None) -> None: