Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ changes that do not affect the user.
should now leave the default value `retain_graph=False`, no matter what the value of
`parallel_chunk_size` is. This will reduce the memory overhead.

### Added

- RNN training usage example in the documentation.

## [0.3.1] - 2024-12-21

### Changed
Expand Down
3 changes: 3 additions & 0 deletions docs/source/examples/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ This section contains some usage examples for TorchJD.
- :doc:`Multi-Task Learning (MTL) <mtl>` provides an example of multi-task learning where Jacobian
descent is used to optimize the vector of per-task losses of a multi-task model, using the
dedicated backpropagation function :doc:`mtl_backward <../docs/autojac/mtl_backward>`.
- :doc:`Recurrent Neural Network (RNN) <rnn>` shows how to apply Jacobian descent to RNN training,
with one loss per output sequence element.
- :doc:`PyTorch Lightning Integration <lightning_integration>` showcases how to combine
TorchJD with PyTorch Lightning, by providing an example implementation of a multi-task
``LightningModule`` optimized by Jacobian descent.
Expand All @@ -23,4 +25,5 @@ This section contains some usage examples for TorchJD.
basic_usage.rst
iwrm.rst
mtl.rst
rnn.rst
lightning_integration.rst
38 changes: 38 additions & 0 deletions docs/source/examples/rnn.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
Recurrent Neural Network (RNN)
==============================

When training recurrent neural networks for sequence modelling, we can easily obtain one loss per
element of the output sequences. If the gradients of these losses are likely to conflict, Jacobian
descent can be leveraged to enhance optimization.

.. code-block:: python
:emphasize-lines: 5-6, 10, 17, 20

import torch
from torch.nn import RNN
from torch.optim import SGD

from torchjd import backward
from torchjd.aggregation import UPGrad

rnn = RNN(input_size=10, hidden_size=20, num_layers=2)
optimizer = SGD(rnn.parameters(), lr=0.1)
aggregator = UPGrad()

inputs = torch.randn(8, 5, 3, 10) # 8 batches of 3 sequences of length 5 and of dim 10.
targets = torch.randn(8, 5, 3, 20) # 8 batches of 3 sequences of length 5 and of dim 20.

for input, target in zip(inputs, targets):
output, _ = rnn(input) # output is of shape [5, 3, 20].
losses = ((output - target) ** 2).mean(dim=[1, 2]) # 1 loss per sequence element.

optimizer.zero_grad()
backward(losses, aggregator, parallel_chunk_size=1)
optimizer.step()

.. note::
At the time of writing, there seems to be an incompatibility between ``torch.vmap`` and
``torch.nn.RNN`` when running on CUDA (see `this issue
<https://github.com/TorchJD/torchjd/issues/220>`_ for more info), so we advise to set the
``parallel_chunk_size`` to ``1`` to avoid using ``torch.vmap``. To improve performance, you can
check whether ``parallel_chunk_size=None`` (maximal parallelization) works on your side.
24 changes: 24 additions & 0 deletions tests/doc/test_rst.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,3 +183,27 @@ def configure_optimizers(self) -> OptimizerLRScheduler:
)

trainer.fit(model=model, train_dataloaders=train_loader)


def test_rnn():
import torch
from torch.nn import RNN
from torch.optim import SGD

from torchjd import backward
from torchjd.aggregation import UPGrad

rnn = RNN(input_size=10, hidden_size=20, num_layers=2)
optimizer = SGD(rnn.parameters(), lr=0.1)
aggregator = UPGrad()

inputs = torch.randn(8, 5, 3, 10) # 8 batches of 3 sequences of length 5 and of dim 10.
targets = torch.randn(8, 5, 3, 20) # 8 batches of 3 sequences of length 5 and of dim 20.

for input, target in zip(inputs, targets):
output, _ = rnn(input) # output is of shape [5, 3, 20].
losses = ((output - target) ** 2).mean(dim=[1, 2]) # 1 loss per sequence element.

optimizer.zero_grad()
backward(losses, aggregator, parallel_chunk_size=1)
optimizer.step()