diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index c5614a47..0adc6979 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -22,7 +22,7 @@ jobs: with: python-version: ${{ matrix.python-version }} - name: Install default and test dependencies - run: pdm install --group test --frozen-lockfile + run: pdm install --group full --group test --frozen-lockfile - name: Run unit and doc tests with coverage report run: pdm run pytest tests/unit tests/doc --cov=src --cov-report=xml - name: Upload results to Codecov diff --git a/CHANGELOG.md b/CHANGELOG.md index 5c130fb2..86493027 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,13 @@ changes that do not affect the user. ## [Unreleased] +### Changed + +- **BREAKING**: Changed the dependencies of `CAGrad` and `NashMTL` to be optional when installing + TorchJD. Users of these aggregators will have to use `pip install torchjd[cagrad]`, `pip install + torchjd[nash_mtl]` or `pip install torchjd[full]` to install TorchJD alongside those dependencies. + This should make TorchJD more lightweight. + ## [0.6.0] - 2025-04-19 ### Added diff --git a/README.md b/README.md index 047b7071..b2775201 100644 --- a/README.md +++ b/README.md @@ -52,6 +52,8 @@ TorchJD can be installed directly with pip: pip install torchjd ``` +Some aggregators may have additional dependencies. Please refer to the +[installation documentation](https://torchjd.org/stable/installation) for them. ## Usage The main way to use TorchJD is to replace the usual call to `loss.backward()` by a call to diff --git a/docs/source/installation.md b/docs/source/installation.md index 843c31a2..eb1307f8 100644 --- a/docs/source/installation.md +++ b/docs/source/installation.md @@ -6,3 +6,17 @@ ``` Note that `torchjd` requires python 3.10, 3.11, 3.12 or 3.13 and `torch>=2.0`. + +Some aggregators (CAGrad and Nash-MTL) have additional dependencies that are not included by default +when installing `torchjd`. To install them, you can use: +``` +pip install torchjd[cagrad] +``` +``` +pip install torchjd[nash_mtl] +``` + +To install `torchjd` with all of its optional dependencies, you can also use: +``` +pip install torchjd[full] +``` diff --git a/pyproject.toml b/pyproject.toml index 01e68561..8ab9d8a2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,8 +17,6 @@ dependencies = [ "quadprog>=0.1.9, != 0.1.10", # Doesn't work before 0.1.9, 0.1.10 is yanked "numpy>=1.21.0", # Does not work before 1.21 "qpsolvers>=1.0.1", # Does not work before 1.0.1 - "cvxpy>=1.3.0", # No Clarabel solver before 1.3.0 - "ecos>=2.0.14", # Does not work before 2.0.14 ] classifiers = [ "Development Status :: 4 - Beta", @@ -67,3 +65,16 @@ plot = [ "dash>=2.16.0", # Recent version to avoid problems, could be relaxed "kaleido==0.2.1", # Only works with locked version ] + +[project.optional-dependencies] +nash_mtl = [ + "cvxpy>=1.3.0", # Could be relaxed + "ecos>=2.0.14", # Does not work before 2.0.14 +] +cagrad = [ + "cvxpy>=1.3.0", # No Clarabel solver before 1.3.0 +] +full = [ + "cvxpy>=1.3.0", # No Clarabel solver before 1.3.0 + "ecos>=2.0.14", # Does not work before 2.0.14 +] diff --git a/src/torchjd/aggregation/cagrad.py b/src/torchjd/aggregation/cagrad.py index 93075fc0..e1d4e117 100644 --- a/src/torchjd/aggregation/cagrad.py +++ b/src/torchjd/aggregation/cagrad.py @@ -29,6 +29,10 @@ class CAGrad(_WeightedAggregator): >>> >>> A(J) tensor([0.1835, 1.2041, 1.2041]) + + .. note:: + This aggregator has dependencies that are not included by default when installing + ``torchjd``. To install them, use ``pip install torchjd[cagrad]``. """ def __init__(self, c: float, norm_eps: float = 0.0001): diff --git a/src/torchjd/aggregation/nash_mtl.py b/src/torchjd/aggregation/nash_mtl.py index c9681fbb..a314630b 100644 --- a/src/torchjd/aggregation/nash_mtl.py +++ b/src/torchjd/aggregation/nash_mtl.py @@ -60,12 +60,16 @@ class NashMTL(_WeightedAggregator): >>> A(J) tensor([0.0542, 0.7061, 0.7061]) + .. note:: + This aggregator has dependencies that are not included by default when installing + ``torchjd``. To install them, use ``pip install torchjd[nash_mtl]``. + .. warning:: This implementation was adapted from the `official implementation `_, which has some flaws. Use with caution. .. warning:: - The aggregator is stateful. Its output will thus depend not only on the input matrix, but + This aggregator is stateful. Its output will thus depend not only on the input matrix, but also on its state. It thus depends on previously seen matrices. It should be reset between experiments. """