Skip to content

feat(aggregation): Add GradVac aggregator#638

Merged
ValerianRey merged 25 commits intoSimplexLab:mainfrom
rkhosrowshahi:feature/gradvac
Apr 14, 2026
Merged

feat(aggregation): Add GradVac aggregator#638
ValerianRey merged 25 commits intoSimplexLab:mainfrom
rkhosrowshahi:feature/gradvac

Conversation

@rkhosrowshahi
Copy link
Copy Markdown
Contributor

@rkhosrowshahi rkhosrowshahi commented Apr 9, 2026

Summary

Adds Gradient Vaccine (GradVac) from ICLR 2021 as a stateful Aggregator on the full task Jacobian.

Behavior

  • Per-block cosine statistics and EMA targets \bar{\rho}, with the closed-form vaccine update when \rho < \bar{\rho}.
  • group_type: 0 whole model (single block); 1 all_layer via encoder (leaf modules with parameters); 2 all_matrix via shared_params (one block per tensor, iteration order = Jacobian column order).
  • DEFAULT_GRADVAC_EPS and configurable eps (constructor + mutable attribute).
  • Autogram not supported (needs full rows and per-block inner products). Task shuffle uses torch.randperm; use torch.manual_seed for reproducibility.

Files

  • src/torchjd/aggregation/_gradvac.py, export in __init__.py
  • docs/source/docs/aggregation/gradvac.rst + index toctree
  • tests/unit/aggregation/test_gradvac.py

Verification

  • ruff format / ruff check on touched paths
  • ty check on _gradvac.py
  • pytest tests/unit/aggregation/test_gradvac.py tests/unit/aggregation/test_values.py -W error
  • Sphinx HTML build for docs

Implement Gradient Vaccine (ICLR 2021) as a stateful Jacobian aggregator.
Support group_type 0 (whole model), 1 (all_layer via encoder), and 2
(all_matrix via shared_params), with DEFAULT_GRADVAC_EPS and configurable
eps. Add Sphinx page and unit tests. Autogram is not supported; use
torch.manual_seed for reproducible task shuffle order.
@ValerianRey ValerianRey added cc: feat Conventional commit type for new features. package: aggregation labels Apr 9, 2026
@rkhosrowshahi
Copy link
Copy Markdown
Contributor Author

Hey @ValerianRey :
Could you please add cc: feat and package: aggregation labels to this PR?

@github-actions github-actions bot changed the title feat(aggregation): add GradVac aggregator feat(aggregation): Add GradVac aggregator Apr 9, 2026
@ValerianRey
Copy link
Copy Markdown
Contributor

I think the link check action failed because of bad luck, re-running the job fixed it.

@ValerianRey
Copy link
Copy Markdown
Contributor

Thanks a lot for the PR, I'll look at it tonight or tomorrow!

@ValerianRey
Copy link
Copy Markdown
Contributor

/opencode:Plan can you please check that the actual implementation (from line 219 to line 251 of _gradvac.py) is equivalent to the implementation in https://github.com/median-research-group/LibMTL/blob/main/LibMTL/weighting/GradVac.py? Please report anything that differs.

@opencode-agent
Copy link
Copy Markdown
Contributor

Let me analyze the differences between the two implementations:

Comparison Summary

After comparing lines 219-251 of _gradvac.py with the LibMTL reference implementation, here are the findings:

✅ Equivalent Behaviors

  1. Core Algorithm: Both implement the same GradVac algorithm:

    • Clone the gradients to pc_grads
    • Iterate over task pairs (i, j) where i ≠ j
    • Random shuffle of j for each i
    • Compute cosine similarity rho_ijk per block
    • Apply vaccine update when rho_ijk < bar (EMA target)
    • Update EMA: bar = (1-beta)*bar + beta*rho_ijk
  2. Weight Formula: Mathematically identical:

    • LibMTL: ||g_i||*(ρ̄*√(1-ρ²)-ρ*√(1-ρ̄²))/(||g_j||*√(1-ρ̄²)+ε)
    • TorchJD: norm_i*(bar*sqrt_1_rho2-rho_ijk*sqrt_1_bar2)/(norm_j*sqrt_1_bar2+eps)
  3. Group Types: Both support the same three grouping strategies (whole_model, all_layer, all_matrix)

🔍 Differences Found

Aspect LibMTL TorchJD (this PR)
RNG for shuffling random.shuffle() (Python stdlib) torch.randperm() (PyTorch)
Epsilon Hardcoded 1e-8 Configurable eps attribute
Numerical safety No clamp before sqrt() .clamp(min=0.0) before sqrt()
Block indexing Runtime sum(self.k_idx[:k]) Precomputed offsets list
Whole model case Special case with k_idx=[-1] Natural via sizes=(n,)
Return value batch_weight (unused) pc_grads.sum(dim=0)
State management self.step counter _state_key tuple

📝 Notes

  1. The clamp operation in TorchJD ((1.0 - rho_ijk * rho_ijk).clamp(min=0.0).sqrt()) is a numerical safety improvement not present in LibMTL - it prevents potential NaN from negative values due to floating-point errors.

  2. The random number generator difference is documented in the TorchJD docstring: "Task-order shuffling uses the global PyTorch RNG".

  3. Both implementations correctly handle the asymmetric nature of the algorithm (updating pc_grads[i] based on comparisons with grads[j]).

Conclusion: The implementations are algorithmically equivalent. TorchJD includes minor improvements (configurable eps, numerical clamp) and follows the codebase's design patterns.

New%20session%20-%202026-04-11T12%3A55%3A53.983Z
opencode session  |  github run

@ValerianRey
Copy link
Copy Markdown
Contributor

/opencode:Plan Please review this PR in detail.

@opencode-agent

This comment was marked as low quality.

Copy link
Copy Markdown
Contributor

@ValerianRey ValerianRey left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Really good work, tyvm! This seems almost ready to merge IMO. Just a bunch of nitpicks that should be really easy to fix, and 1 somewhat important implementation question (iterating over groups in outer loop vs inner loop).

I would also argue that it seems that this could be implemented as a Weighting internally (because we actually act on norms and cosine similarities between gradients, which is what the gramian contains). Also, it's possible to keep track of norms and cosine similarities between projected gradients even if we don't have those gradients, just by making some operations on the gramian. This is what we did to implement PCGrad as a Weighting.

For example, imagine you have g1 and g2 be two gradients. From the gramian, you know ||g1||, ||g2|| (the sqrt of the diag elements), and g1 . g2 (an off-diag element), so you can deduce cos(g1, g2) from that.

If you compute g1' = g1 + w * g2, you can also directly deduce the norm of g1':
||g1'||² = ||g1||² + w² ||g2||² + 2w g1 . g2 (all elements of the right handside are known).

Similarly, you can compute g1' . g2 = (g1 + w * g2) . g2 = g1 . g2 + w g1 . g2.

So even after projection, you still know the dot products between all of your gradients, meaning that you still know the "new" gramian.

I didn't think through it entirely but at a first glance it seems possible to adapt this as a weighting, because of that. The implementation may even be faster actually (because we have fewer norms to recompute). But it may be hard to implement, so IMO we should merge this without even trying to implement it as a Weighting, and we can always improve later. @PierreQuinton what do you think about that?

@ValerianRey

This comment was marked as resolved.

ValerianRey and others added 2 commits April 11, 2026 14:56
- Add GOVERNANCE.md documenting technical governance structure
- Add CODEOWNERS file defining project maintainers
- Add CODE_OF_CONDUCT.md referencing Linux Foundation CoC

These files are required for PyTorch Ecosystem membership.

---------

Co-authored-by: Valérian Rey <31951177+ValerianRey@users.noreply.github.com>
@rkhosrowshahi
Copy link
Copy Markdown
Contributor Author

Opencode's review was quite low quality, but it mentioned something that I missed: we need a test for GradVac in tests/unit/aggregation/test_values.py.

Similarly, i'd like to have GradVac added to tests/plots/interactive_plotter.py.

Thanks. I added the GradVac to the code and improved the code a bit to be more user-friendly. See the PCGrad and GradVac in the plot, find the same aggregated gradient. If you liked the changes, I can add to the commit as well.
PCGrad vs. GradVac

@ValerianRey

This comment was marked as outdated.

@PierreQuinton

This comment was marked as resolved.

- Use group_type "whole_model" | "all_layer" | "all_matrix" instead of 0/1/2
- Remove DEFAULT_GRADVAC_EPS from the public API; keep default 1e-8; allow eps=0
- Validate beta via setter; tighten GradVac repr/str expectations
- Fix all_layer leaf sizing via children() and parameters() instead of private fields
- Trim redundant GradVac.rst prose; align docs with the new API
- Tests: GradVac cases, value regression with torch.manual_seed for GradVac
- Plotter: factory dict + fresh aggregator instances per update; legend from
  selected keys; MathJax labels and live angle/length readouts in the sidebar

This commit includes GradVac implementation with Aggregator class.
@rkhosrowshahi rkhosrowshahi requested a review from a team as a code owner April 12, 2026 16:54
…hting

GradVac only needs gradient norms and dot products, which are fully
determined by the Gramian. This makes GradVac compatible with the autogram path.

- Remove grouping parameters (group_type, encoder, shared_params) from GradVac
- Export GradVacWeighting publicly
ValerianRey and others added 2 commits April 12, 2026 21:38
Seed is already set to 0 because of the autoused fix_randomness fixture declared in conftest.py
@ValerianRey
Copy link
Copy Markdown
Contributor

ValerianRey commented Apr 12, 2026

I think this is ready to merge, except for some plotting things. Can we remove the changes to the plotter and make plotter improvements in a different PR (except adding GradVac to the list of aggregators in the plotter)? I see a few issues in the plotter changes, and I'd rather merge this PR now and make the rest of the changes in a different PR. @rkhosrowshahi

BTW the link check action will fail because the links I added in the readme point to some documentation that will only be built after we merge this.

ValerianRey and others added 6 commits April 13, 2026 13:51
Add a Grouping example page covering all four strategies from the GradVac
paper (whole_model, enc_dec, all_layer, all_matrix), with a runnable code
block for each. Update the GradVac docstring note to link to the new page
instead of the previous placeholder text. Fix trailing whitespace in
CHANGELOG.md.
@rkhosrowshahi rkhosrowshahi force-pushed the feature/gradvac branch 2 times, most recently from c4ed86f to e626475 Compare April 13, 2026 16:46
rkhosrowshahi and others added 5 commits April 13, 2026 12:55
- The plan is to add it back in another PR
- I think it's good to have properties for beta and eps, so that we can check the values before assigning them, but I also think that the documentation of those properties is a bit bloating the documentation of the aggregator itself (and it's quite a duplicate of the documentation of the parameters). So IMO the best thing to do is to keep those two as properties, but not display the documentation of those properties (it still shows in the documentation of these parameters)
- The goal here is to uniformize a bit with the rest of the library: it's not really needed to indicate default values, because they already appear in the built documentation (unless the default is some None that later gets transformed into something else, but it's not the case here). Also, I don't think it's needed to indicate that these parameters can be changed afterwards, because I don't think a lot of people will do that, and it's actually the case of the parameters of all aggregators (or it should be). Lastly, I made the description be the same between aggregator and weighting (for ease of maintainance).
@ValerianRey ValerianRey merged commit 012b1ba into SimplexLab:main Apr 14, 2026
14 of 15 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

cc: feat Conventional commit type for new features. package: aggregation

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants