refactor(aggregation): Add _NonDifferentiable mixin #677

Open
ValerianRey wants to merge 4 commits into main from non-differentiable-mixin

Conversation

@ValerianRey
Contributor

Summary

  • Adds a _NonDifferentiable mixin (in _mixins.py) that wraps __call__ in torch.no_grad(), preventing autograd graph construction entirely (a minimal sketch follows this list)
  • Applies it to all 9 non-differentiable aggregators and their paired weighting classes: UPGrad, DualProj, PCGrad, GradVac, IMTLG, GradDrop, ConFIG, CAGrad, NashMTL
  • Removes _utils/non_differentiable.py (NonDifferentiableError + raise_non_differentiable_error) and all register_full_backward_pre_hook calls, which are now redundant
  • Updates assert_non_differentiable in tests to check that the output has no grad_fn (graph was never built) instead of catching NonDifferentiableError on backward
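
For reference, here is a minimal sketch of the mixin reconstructed from the description above (the actual code lives in _mixins.py and may differ in details):

```python
import torch
from torch import nn


class _NonDifferentiable(nn.Module):
    """Mixin disabling autograd graph construction for a module's forward call.

    Must be listed before any other nn.Module base class in the
    inheritance list (see the warning in the real docstring).
    """

    def __call__(self, *args, **kwargs):
        # Delegate to the next __call__ in the MRO (ultimately
        # nn.Module.__call__) with grad mode disabled, so the output
        # can never carry a grad_fn.
        with torch.no_grad():
            return super().__call__(*args, **kwargs)
```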

Motivation

The old approach registered a full_backward_pre_hook to raise NonDifferentiableError. This only caught the problem after a graph had already been built through the module — calling a non-differentiable aggregator on a requires_grad=True tensor would silently produce a result with grad_fn = BackwardHookFunctionBackward. The new approach is both stricter (no graph is ever built) and simpler (no error class, no hook registration).
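
A toy reproduction of that failure mode, with nn.Linear standing in for a non-differentiable aggregator and the hook body mirroring what the removed raise_non_differentiable_error helper did (names in this snippet are illustrative):

```python
import torch
from torch import nn


def raise_non_differentiable_error(module, grad_output):
    raise RuntimeError(f"{module} is not differentiable.")


module = nn.Linear(4, 2)  # stand-in for a non-differentiable aggregator
module.register_full_backward_pre_hook(raise_non_differentiable_error)

x = torch.randn(3, 4, requires_grad=True)
y = module(x)

print(y.grad_fn)  # <BackwardHookFunctionBackward ...>: a graph was silently built

try:
    y.sum().backward()
except RuntimeError as error:
    print(error)  # the error only surfaces here, after the fact
```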

Design notes

  • _NonDifferentiable inherits from nn.Module. This avoids a cast and makes the inheritance constraint explicit: it only makes sense for modules.
  • It must be listed before any other nn.Module base class to take effect (documented with a warning in the docstring; a quick runtime check follows this list). If placed after, nn.Module.__call__ is resolved first and the mixin is silently bypassed.
  • _NonDifferentiable is applied to both aggregators and their paired weightings, so the invariant holds whether a class is used via the Aggregator interface (autojac) or the Weighting interface (autogram).
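
One way to verify at runtime that the mixin actually takes effect is to call an affected aggregator on a requires_grad input and check that the output carries no graph; a hedged sanity check, assuming torchjd's public import path:

```python
import torch
from torchjd.aggregation import UPGrad

aggregator = UPGrad()
matrix = torch.randn(3, 5, requires_grad=True)  # toy Jacobian-like input

output = aggregator(matrix)

# With _NonDifferentiable first in the MRO, the forward pass ran under
# torch.no_grad() and no graph was ever constructed:
assert output.grad_fn is None
```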

🤖 Generated with Claude Code

ValerianRey and others added 2 commits May 10, 2026 03:04
…mixin

Non-differentiable aggregators and weightings previously registered a
full_backward_pre_hook to raise NonDifferentiableError. This only caught
the problem after a graph had already been built through the module.

The new _NonDifferentiable mixin (in _mixins.py) wraps __call__ in
torch.no_grad(), so no graph is ever constructed in the first place,
making the hook and NonDifferentiableError entirely redundant.

The mixin is applied to both the aggregator and its paired weighting
class for each non-differentiable method (UPGrad, DualProj, PCGrad,
GradVac, IMTLG, GradDrop, ConFIG, CAGrad, NashMTL).

The test helper assert_non_differentiable is updated to assert that the
output has no grad_fn rather than catching NonDifferentiableError (a
plausible sketch of the helper follows this commit message).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
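
A plausible shape for the updated test helper, reconstructed from the description above (the actual signature in the test suite may differ):

```python
from torch import Tensor


def assert_non_differentiable(output: Tensor) -> None:
    # Under the new mixin the forward pass runs inside torch.no_grad(),
    # so the output must be detached from any autograd graph.
    assert output.grad_fn is None
    assert not output.requires_grad
```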
Mixins should be listed before the main base class so they are resolved
first in the MRO. Move Stateful before GramianWeightedAggregator /
WeightedAggregator / _MatrixWeighting / _GramianWeighting in GradVac,
GradVacWeighting, NashMTL, and _NashMTLWeighting.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@ValerianRey ValerianRey added the package: aggregation and cc: refactor labels May 10, 2026
@ValerianRey ValerianRey changed the title from "refactor(aggregation): Replace backward hook with _NonDifferentiable mixin" to "refactor(aggregation): Add _NonDifferentiable mixin" May 10, 2026
@ValerianRey ValerianRey marked this pull request as ready for review May 10, 2026 01:23
@ValerianRey ValerianRey requested review from a team and PierreQuinton as code owners May 10, 2026 01:23
Comment thread on CHANGELOG.md (outdated)
ValerianRey and others added 2 commits May 10, 2026 03:26
…-differentiable

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Contributor

@PierreQuinton PierreQuinton left a comment

Very cool. If you are certain this is run, then I'm happy with this. BTW, since this needs to be first among the classes we inherit from, we cannot have two such mixins. If we do, we'll need to think about it.

.. warning::
This mixin must appear **before** any :class:`torch.nn.Module` base class in the inheritance
list. Placing it after will silently have no effect, because :meth:`__call__` would be
resolved to :class:`torch.nn.Module` before reaching this mixin.
Contributor

Is this the Python-specific MRO used to solve diamonds? I guess the reason is to prevent diamond-inheritance issues?

Contributor

Are you certain this is run, BTW? Did you try to put a raise or something?
