Merged
6 changes: 3 additions & 3 deletions README.md
@@ -111,11 +111,11 @@ using [UPGrad](https://torchjd.org/stable/docs/aggregation/upgrad/).
loss1 = loss_fn(output1, target1)
loss2 = loss_fn(output2, target2)

optimizer.zero_grad()
- loss = loss1 + loss2
- loss.backward()
+ mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator)
optimizer.step()
optimizer.zero_grad()
```

> [!NOTE]
@@ -150,12 +150,12 @@ Jacobian descent using [UPGrad](https://torchjd.org/stable/docs/aggregation/upgr
- loss = loss_fn(output, target) # shape [1]
+ losses = loss_fn(output, target) # shape [16]

optimizer.zero_grad()
- loss.backward()
+ gramian = engine.compute_gramian(losses) # shape: [16, 16]
+ weights = weighting(gramian) # shape: [16]
+ losses.backward(weights)
optimizer.step()
optimizer.zero_grad()
```

Lastly, you can even combine the two approaches by considering multiple tasks and each element of
@@ -201,10 +201,10 @@ for input, target1, target2 in zip(inputs, task1_targets, task2_targets):
# Obtain the weights that lead to no conflict between reweighted gradients
weights = weighting(gramian) # shape: [16, 2]

optimizer.zero_grad()
# Do the standard backward pass, but weighted using the obtained weights
losses.backward(weights)
optimizer.step()
optimizer.zero_grad()
```

> [!NOTE]
4 changes: 2 additions & 2 deletions docs/source/examples/amp.rst
@@ -12,7 +12,7 @@ case, the losses) should preferably be scaled with a `GradScaler
following example shows the resulting code for a multi-task learning use-case.

.. code-block:: python
:emphasize-lines: 2, 17, 27, 34, 36-38
:emphasize-lines: 2, 17, 27, 34-37

import torch
from torch.amp import GradScaler
@@ -48,10 +48,10 @@ following example shows the resulting code for a multi-task learning use-case.
loss2 = loss_fn(output2, target2)

scaled_losses = scaler.scale([loss1, loss2])
optimizer.zero_grad()
mtl_backward(losses=scaled_losses, features=features, aggregator=aggregator)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()

.. hint::
Within the ``torch.autocast`` context, some operations may be done in ``float16`` type. For
12 changes: 6 additions & 6 deletions docs/source/examples/basic_usage.rst
@@ -59,12 +59,6 @@ We can now compute the losses associated to each element of the batch.

The last steps are similar to gradient descent-based optimization, but using the two losses.

Reset the ``.grad`` field of each model parameter:

.. code-block:: python

optimizer.zero_grad()

Perform the Jacobian descent backward pass:

.. code-block:: python
@@ -81,3 +75,9 @@ Update each parameter based on its ``.grad`` field, using the ``optimizer``:
optimizer.step()

The model's parameters have been updated!

As usual, you should now reset the ``.grad`` field of each model parameter:

.. code-block:: python

optimizer.zero_grad()
4 changes: 2 additions & 2 deletions docs/source/examples/iwmtl.rst
@@ -10,7 +10,7 @@ this Gramian to reweight the gradients and resolve conflict entirely.
The following example shows how to do that.

.. code-block:: python
:emphasize-lines: 5-6, 18-20, 31-32, 34-35, 37-38, 41-42
:emphasize-lines: 5-6, 18-20, 31-32, 34-35, 37-38, 40-41

import torch
from torch.nn import Linear, MSELoss, ReLU, Sequential
@@ -51,10 +51,10 @@ The following example shows how to do that.
# Obtain the weights that lead to no conflict between reweighted gradients
weights = weighting(gramian) # shape: [16, 2]

optimizer.zero_grad()
# Do the standard backward pass, but weighted using the obtained weights
losses.backward(weights)
optimizer.step()
optimizer.zero_grad()

.. note::
In this example, the tensor of losses is a matrix rather than a vector. The gramian is thus a
10 changes: 5 additions & 5 deletions docs/source/examples/iwrm.rst
@@ -64,19 +64,19 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac
for x, y in zip(X, Y):
y_hat = model(x).squeeze(dim=1) # shape: [16]
loss = loss_fn(y_hat, y) # shape: [] (scalar)
optimizer.zero_grad()
loss.backward()


optimizer.step()
optimizer.zero_grad()

In this baseline example, the update may negatively affect the loss of some elements of the
batch.

.. tab-item:: autojac

.. code-block:: python
:emphasize-lines: 5-6, 12, 16, 21, 23
:emphasize-lines: 5-6, 12, 16, 21-22

import torch
from torch.nn import Linear, MSELoss, ReLU, Sequential
@@ -99,19 +99,19 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac
for x, y in zip(X, Y):
y_hat = model(x).squeeze(dim=1) # shape: [16]
losses = loss_fn(y_hat, y) # shape: [16]
optimizer.zero_grad()
backward(losses, aggregator)


optimizer.step()
optimizer.zero_grad()

Here, we compute the Jacobian of the per-sample losses with respect to the model parameters
and use it to update the model such that no loss from the batch is (locally) increased.

.. tab-item:: autogram (recommended)

.. code-block:: python
:emphasize-lines: 5-6, 12, 16-17, 21, 23-25
:emphasize-lines: 5-6, 12, 16-17, 21-24

import torch
from torch.nn import Linear, MSELoss, ReLU, Sequential
@@ -134,11 +134,11 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac
for x, y in zip(X, Y):
y_hat = model(x).squeeze(dim=1) # shape: [16]
losses = loss_fn(y_hat, y) # shape: [16]
optimizer.zero_grad()
gramian = engine.compute_gramian(losses) # shape: [16, 16]
weights = weighting(gramian) # shape: [16]
losses.backward(weights)
optimizer.step()
optimizer.zero_grad()

Here, the per-sample gradients are never fully stored in memory, leading to large
improvements in memory usage and speed compared to autojac, in most practical cases. The
4 changes: 2 additions & 2 deletions docs/source/examples/lightning_integration.rst
@@ -11,7 +11,7 @@ The following code example demonstrates a basic multi-task learning setup using
<../docs/autojac/mtl_backward>` at each training iteration.

.. code-block:: python
:emphasize-lines: 9-10, 18, 32
:emphasize-lines: 9-10, 18, 31

import torch
from lightning import LightningModule, Trainer
@@ -43,9 +43,9 @@ The following code example demonstrates a basic multi-task learning setup using
loss2 = mse_loss(output2, target2)

opt = self.optimizers()
opt.zero_grad()
mtl_backward(losses=[loss1, loss2], features=features, aggregator=UPGrad())
opt.step()
opt.zero_grad()

def configure_optimizers(self) -> OptimizerLRScheduler:
optimizer = Adam(self.parameters(), lr=1e-3)
2 changes: 1 addition & 1 deletion docs/source/examples/monitoring.rst
@@ -63,6 +63,6 @@ they have a negative inner product).
loss1 = loss_fn(output1, target1)
loss2 = loss_fn(output2, target2)
optimizer.zero_grad()
mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator)
optimizer.step()
optimizer.zero_grad()
4 changes: 2 additions & 2 deletions docs/source/examples/mtl.rst
@@ -19,7 +19,7 @@ vectors of dimension 10, and their corresponding scalar labels for both tasks.


.. code-block:: python
:emphasize-lines: 5-6, 19, 33
:emphasize-lines: 5-6, 19, 32

import torch
from torch.nn import Linear, MSELoss, ReLU, Sequential
@@ -52,9 +52,9 @@ vectors of dimension 10, and their corresponding scalar labels for both tasks.
loss1 = loss_fn(output1, target1)
loss2 = loss_fn(output2, target2)

optimizer.zero_grad()
mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator)
optimizer.step()
optimizer.zero_grad()

.. note::
In this example, the Jacobian is only with respect to the shared parameters. The task-specific
2 changes: 1 addition & 1 deletion docs/source/examples/partial_jd.rst
@@ -41,8 +41,8 @@ first ``Linear`` layer, thereby reducing memory usage and computation time.
for x, y in zip(X, Y):
y_hat = model(x).squeeze(dim=1) # shape: [16]
losses = loss_fn(y_hat, y) # shape: [16]
optimizer.zero_grad()
gramian = engine.compute_gramian(losses)
weights = weighting(gramian)
losses.backward(weights)
optimizer.step()
optimizer.zero_grad()
4 changes: 2 additions & 2 deletions docs/source/examples/rnn.rst
@@ -6,7 +6,7 @@ element of the output sequences. If the gradients of these losses are likely to
descent can be leveraged to enhance optimization.

.. code-block:: python
:emphasize-lines: 5-6, 10, 17, 20
:emphasize-lines: 5-6, 10, 17, 19

import torch
from torch.nn import RNN
@@ -26,9 +26,9 @@ descent can be leveraged to enhance optimization.
output, _ = rnn(input) # output is of shape [5, 3, 20].
losses = ((output - target) ** 2).mean(dim=[1, 2]) # 1 loss per sequence element.

optimizer.zero_grad()
backward(losses, aggregator, parallel_chunk_size=1)
optimizer.step()
optimizer.zero_grad()

.. note::
At the time of writing, there seems to be an incompatibility between ``torch.vmap`` and
4 changes: 2 additions & 2 deletions src/torchjd/autogram/_engine.py
@@ -77,7 +77,7 @@ class Engine:
Train a model using Gramian-based Jacobian descent.

.. code-block:: python
:emphasize-lines: 5-6, 15-16, 18-19, 26-28
:emphasize-lines: 5-6, 15-16, 18-19, 26-29

import torch
from torch.nn import Linear, MSELoss, ReLU, Sequential
@@ -103,11 +103,11 @@ class Engine:
output = model(input).squeeze(dim=1) # shape: [16]
losses = criterion(output, target) # shape: [16]

optimizer.zero_grad()
gramian = engine.compute_gramian(losses) # shape: [16, 16]
weights = weighting(gramian) # shape: [16]
losses.backward(weights)
optimizer.step()
optimizer.zero_grad()

This is equivalent to just calling ``torchjd.autojac.backward(losses, UPGrad())``. However,
since the Jacobian never has to be entirely in memory, it is often much more
2 changes: 1 addition & 1 deletion tests/doc/test_autogram.py
@@ -26,8 +26,8 @@ def test_engine():
output = model(input).squeeze(dim=1) # shape: [16]
losses = criterion(output, target) # shape: [16]

optimizer.zero_grad()
gramian = engine.compute_gramian(losses) # shape: [16, 16]
weights = weighting(gramian) # shape: [16]
losses.backward(weights)
optimizer.step()
optimizer.zero_grad()
22 changes: 11 additions & 11 deletions tests/doc/test_rst.py
@@ -42,10 +42,10 @@ def test_amp():
loss2 = loss_fn(output2, target2)

scaled_losses = scaler.scale([loss1, loss2])
optimizer.zero_grad()
mtl_backward(losses=scaled_losses, features=features, aggregator=aggregator)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()


def test_basic_usage():
@@ -69,9 +69,9 @@ def test_basic_usage():
loss1 = loss_fn(output[:, 0], target1)
loss2 = loss_fn(output[:, 1], target2)

optimizer.zero_grad()
autojac.backward([loss1, loss2], aggregator)
optimizer.step()
optimizer.zero_grad()


def test_iwmtl():
@@ -114,10 +114,10 @@ def test_iwmtl():
# Obtain the weights that lead to no conflict between reweighted gradients
weights = weighting(gramian) # shape: [16, 2]

optimizer.zero_grad()
# Do the standard backward pass, but weighted using the obtained weights
losses.backward(weights)
optimizer.step()
optimizer.zero_grad()


def test_iwrm():
@@ -138,9 +138,9 @@ def test_autograd():
for x, y in zip(X, Y):
y_hat = model(x).squeeze(dim=1) # shape: [16]
loss = loss_fn(y_hat, y) # shape: [] (scalar)
optimizer.zero_grad()
loss.backward()
optimizer.step()
optimizer.zero_grad()

def test_autojac():
import torch
@@ -163,9 +163,9 @@ def test_autojac():
for x, y in zip(X, Y):
y_hat = model(x).squeeze(dim=1) # shape: [16]
losses = loss_fn(y_hat, y) # shape: [16]
optimizer.zero_grad()
backward(losses, aggregator)
optimizer.step()
optimizer.zero_grad()

def test_autogram():
import torch
@@ -189,11 +189,11 @@ def test_autogram():
for x, y in zip(X, Y):
y_hat = model(x).squeeze(dim=1) # shape: [16]
losses = loss_fn(y_hat, y) # shape: [16]
optimizer.zero_grad()
gramian = engine.compute_gramian(losses) # shape: [16, 16]
weights = weighting(gramian) # shape: [16]
losses.backward(weights)
optimizer.step()
optimizer.zero_grad()

test_autograd()
test_autojac()
@@ -240,9 +240,9 @@ def training_step(self, batch, batch_idx) -> None:
loss2 = mse_loss(output2, target2)

opt = self.optimizers()
opt.zero_grad()
mtl_backward(losses=[loss1, loss2], features=features, aggregator=UPGrad())
opt.step()
opt.zero_grad()

def configure_optimizers(self) -> OptimizerLRScheduler:
optimizer = Adam(self.parameters(), lr=1e-3)
@@ -314,9 +314,9 @@ def print_gd_similarity(_, inputs: tuple[torch.Tensor, ...], aggregation: torch.
loss1 = loss_fn(output1, target1)
loss2 = loss_fn(output2, target2)

optimizer.zero_grad()
mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator)
optimizer.step()
optimizer.zero_grad()


def test_mtl():
@@ -351,9 +351,9 @@ def test_mtl():
loss1 = loss_fn(output1, target1)
loss2 = loss_fn(output2, target2)

optimizer.zero_grad()
mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator)
optimizer.step()
optimizer.zero_grad()


def test_partial_jd():
@@ -382,11 +382,11 @@ def test_partial_jd():
for x, y in zip(X, Y):
y_hat = model(x).squeeze(dim=1) # shape: [16]
losses = loss_fn(y_hat, y) # shape: [16]
optimizer.zero_grad()
gramian = engine.compute_gramian(losses)
weights = weighting(gramian)
losses.backward(weights)
optimizer.step()
optimizer.zero_grad()


def test_rnn():
@@ -408,6 +408,6 @@ def test_rnn():
output, _ = rnn(input) # output is of shape [5, 3, 20].
losses = ((output - target) ** 2).mean(dim=[1, 2]) # 1 loss per sequence element.

optimizer.zero_grad()
backward(losses, aggregator, parallel_chunk_size=1)
optimizer.step()
optimizer.zero_grad()
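The whole PR applies one mechanical change: `optimizer.zero_grad()` moves from before the backward pass to after `optimizer.step()`. Both placements are valid in PyTorch — what matters is that gradients are cleared between consecutive backward passes — but clearing right after the step leaves `.grad` empty outside the update, which also releases gradient memory when `set_to_none=True` (the default since PyTorch 2.0). A minimal sketch of the resulting loop ordering, using plain `loss.backward()` rather than torchjd's aggregating backward:

```python
import torch
from torch.nn import Linear, MSELoss
from torch.optim import SGD

model = Linear(10, 1)
loss_fn = MSELoss()
optimizer = SGD(model.parameters(), lr=0.1)

X = torch.randn(4, 16, 10)  # 4 mini-batches of 16 examples
Y = torch.randn(4, 16, 1)

for x, y in zip(X, Y):
    loss = loss_fn(model(x), y)
    loss.backward()        # gradients accumulate into each parameter's .grad
    optimizer.step()       # update parameters using the accumulated .grad
    optimizer.zero_grad()  # clear .grad right after the step, so the next
                           # iteration's backward pass starts from a clean state
```

The same ordering drops in unchanged for the torchjd variants: replace `loss.backward()` with `mtl_backward(...)`, `backward(losses, aggregator)`, or the `compute_gramian`/`losses.backward(weights)` pair, and keep `zero_grad()` as the last call of the iteration.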