Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
- name: Install default and test dependencies
run: pdm install --group test --frozen-lockfile
run: pdm install --group full --group test --frozen-lockfile
- name: Run unit and doc tests with coverage report
run: pdm run pytest tests/unit tests/doc --cov=src --cov-report=xml
- name: Upload results to Codecov
Expand Down
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,13 @@ changes that do not affect the user.

## [Unreleased]

### Changed

- **BREAKING**: Changed the dependencies of `CAGrad` and `NashMTL` to be optional when installing
TorchJD. Users of these aggregators will have to use `pip install torchjd[cagrad]`, `pip install
torchjd[nash_mtl]` or `pip install torchjd[full]` to install TorchJD alongside those dependencies.
This should make TorchJD more lightweight.

## [0.6.0] - 2025-04-19

### Added
Expand Down
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ TorchJD can be installed directly with pip:
pip install torchjd
```
<!-- end installation -->
Some aggregators may have additional dependencies. Please refer to the
[installation documentation](https://torchjd.org/stable/installation) for them.

## Usage
The main way to use TorchJD is to replace the usual call to `loss.backward()` by a call to
Expand Down
14 changes: 14 additions & 0 deletions docs/source/installation.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,17 @@
```

Note that `torchjd` requires python 3.10, 3.11, 3.12 or 3.13 and `torch>=2.0`.

Some aggregators (CAGrad and Nash-MTL) have additional dependencies that are not included by default
when installing `torchjd`. To install them, you can use:
```
pip install torchjd[cagrad]
```
```
pip install torchjd[nash_mtl]
```

To install `torchjd` with all of its optional dependencies, you can also use:
```
pip install torchjd[full]
```
15 changes: 13 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@ dependencies = [
"quadprog>=0.1.9, != 0.1.10", # Doesn't work before 0.1.9, 0.1.10 is yanked
"numpy>=1.21.0", # Does not work before 1.21
"qpsolvers>=1.0.1", # Does not work before 1.0.1
"cvxpy>=1.3.0", # No Clarabel solver before 1.3.0
"ecos>=2.0.14", # Does not work before 2.0.14
]
classifiers = [
"Development Status :: 4 - Beta",
Expand Down Expand Up @@ -67,3 +65,16 @@ plot = [
"dash>=2.16.0", # Recent version to avoid problems, could be relaxed
"kaleido==0.2.1", # Only works with locked version
]

[project.optional-dependencies]
nash_mtl = [
"cvxpy>=1.3.0", # Could be relaxed
"ecos>=2.0.14", # Does not work before 2.0.14
]
cagrad = [
"cvxpy>=1.3.0", # No Clarabel solver before 1.3.0
]
full = [
"cvxpy>=1.3.0", # No Clarabel solver before 1.3.0
"ecos>=2.0.14", # Does not work before 2.0.14
]
4 changes: 4 additions & 0 deletions src/torchjd/aggregation/cagrad.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ class CAGrad(_WeightedAggregator):
>>>
>>> A(J)
tensor([0.1835, 1.2041, 1.2041])
.. note::
This aggregator has dependencies that are not included by default when installing
``torchjd``. To install them, use ``pip install torchjd[cagrad]``.
"""

def __init__(self, c: float, norm_eps: float = 0.0001):
Expand Down
6 changes: 5 additions & 1 deletion src/torchjd/aggregation/nash_mtl.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,16 @@ class NashMTL(_WeightedAggregator):
>>> A(J)
tensor([0.0542, 0.7061, 0.7061])

.. note::
This aggregator has dependencies that are not included by default when installing
``torchjd``. To install them, use ``pip install torchjd[nash_mtl]``.

.. warning::
This implementation was adapted from the `official implementation
<https://github.com/AvivNavon/nash-mtl/tree/main>`_, which has some flaws. Use with caution.

.. warning::
The aggregator is stateful. Its output will thus depend not only on the input matrix, but
This aggregator is stateful. Its output will thus depend not only on the input matrix, but
also on its state. It thus depends on previously seen matrices. It should be reset between
experiments.
"""
Expand Down