fix(aggregation): Fix __call__ docs #693

Merged
ValerianRey merged 3 commits into main from fix/nondifferentiable-mixin-mro-docs on May 20, 2026
@ValerianRey
Contributor

Problem

Every non-differentiable aggregator and weighting had _NonDifferentiable
listed first in its base-class tuple, e.g.:

class PCGradWeighting(_NonDifferentiable, _GramianWeighting): ...
class CAGrad(_NonDifferentiable, GramianWeightedAggregator): ...

Python's MRO resolves __call__ to the first class in the MRO that
defines it. Since _NonDifferentiable.__call__ has the generic signature
(*args, **kwargs), Sphinx documented every affected method with that
unhelpful signature instead of the more specific one from
_GramianWeighting.__call__(gramian: Tensor, /) or
Aggregator.__call__(matrix: Tensor, /).
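
To illustrate the signature problem without depending on torch, here is a minimal sketch using hypothetical stand-in classes (`Module`, `NonDifferentiable`, `GramianWeighting` are simplified placeholders, not the real torchjd or torch classes). It shows which `__call__` signature introspection reports for each base-class order, which is roughly what a documentation tool would pick up.

```python
import inspect


class Module:
    """Hypothetical stand-in for torch.nn.Module (avoids a torch dependency)."""

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)


class NonDifferentiable(Module):
    """Stand-in for the mixin: its __call__ has a generic signature."""

    def __call__(self, *args, **kwargs):
        return super().__call__(*args, **kwargs)


class GramianWeighting(Module):
    """Stand-in for the primary base: specific, positional-only signature."""

    def __call__(self, gramian, /):
        return super().__call__(gramian)

    def forward(self, gramian):
        return gramian


class MixinFirst(NonDifferentiable, GramianWeighting):
    """Old order: the mixin's __call__ is resolved first."""


class MixinLast(GramianWeighting, NonDifferentiable):
    """New order: the primary base's __call__ is resolved first."""


mixin_first_sig = str(inspect.signature(MixinFirst.__call__))
mixin_last_sig = str(inspect.signature(MixinLast.__call__))
print(mixin_first_sig)  # generic signature inherited from the mixin
print(mixin_last_sig)   # specific signature inherited from the primary base
```

Only the order of the bases differs between the two subclasses; the method bodies are identical.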

Fix

Swap the order so _NonDifferentiable comes after the primary base
class. Example:

class PCGradWeighting(_GramianWeighting, _NonDifferentiable): ...
class CAGrad(GramianWeightedAggregator, _NonDifferentiable): ...

The no_grad wrapping is fully preserved. The cooperative
super().__call__() chain now runs:

_GramianWeighting.__call__(gramian)   ← resolved first → correct docs
  → Weighting.__call__(gramian)
    → _NonDifferentiable.__call__(gramian)   ← applies no_grad
      → nn.Module.__call__(gramian)

_NonDifferentiable is still reached — just later in the chain.
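
The chain above can be traced with a small sketch. All classes here are hypothetical stand-ins, and `no_grad` is a toy context manager that flips a flag rather than torch's real `torch.no_grad()`; the point is only to show that the mixin is still entered before `forward` runs when it is listed last.

```python
from contextlib import contextmanager

grad_state = {"enabled": True}  # toy stand-in for torch's global grad mode
trace = []                      # records the order in which __call__ methods run


@contextmanager
def no_grad():
    """Hypothetical stand-in for torch.no_grad(): disables the flag temporarily."""
    previous = grad_state["enabled"]
    grad_state["enabled"] = False
    try:
        yield
    finally:
        grad_state["enabled"] = previous


class Module:
    def __call__(self, *args, **kwargs):
        trace.append("Module")
        return self.forward(*args, **kwargs)


class NonDifferentiable(Module):
    def __call__(self, *args, **kwargs):
        trace.append("NonDifferentiable")
        with no_grad():
            return super().__call__(*args, **kwargs)


class Weighting(Module):
    def __call__(self, stat, /):
        trace.append("Weighting")
        return super().__call__(stat)


class GramianWeighting(Weighting):
    def __call__(self, gramian, /):
        trace.append("GramianWeighting")
        return super().__call__(gramian)


class PCGradLike(GramianWeighting, NonDifferentiable):
    def forward(self, gramian):
        trace.append("forward")
        return grad_state["enabled"]  # False while the no_grad stand-in is active


grad_seen_in_forward = PCGradLike()("gramian")
mro_names = [cls.__name__ for cls in PCGradLike.__mro__]
```

In this sketch, `trace` lists the classes in the same order as the chain described above, and `grad_seen_in_forward` is `False`, meaning `forward` ran inside the wrapper.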

Misunderstanding about the MRO requirement

The old docstring warning in _NonDifferentiable said:

This mixin must appear before any torch.nn.Module base class in the
inheritance list.

This was imprecise. What actually matters is that _NonDifferentiable
appears before nn.Module itself in the resolved MRO, not
necessarily before every nn.Module subclass in the inheritance list.
C3 linearization guarantees the former as long as every class in the chain
calls super().__call__(). The warning has been updated to reflect this.

Scope

All 11 affected files in src/torchjd/aggregation/ were updated:
_cagrad, _config, _dualproj, _fairgrad, _graddrop, _gradvac,
_imtl_g, _mixins, _nash_mtl, _pcgrad, _upgrad.

No other mixins (_WithOptionalDeps, Stateful) define __call__, so
they are unaffected.

Verification

  • All 2982 unit tests pass.
  • Docs build cleanly.
  • Generated HTML now shows __call__(gramian, /) / __call__(matrix, /)
    instead of __call__(*args, **kwargs) for every affected class.

…t in MRO

Before this change, every non-differentiable aggregator/weighting had
_NonDifferentiable listed first in its base-class tuple, e.g.
`class PCGradWeighting(_NonDifferentiable, _GramianWeighting)`. Because
Python's MRO resolves `__call__` to the first class that defines it,
Sphinx documented the method with _NonDifferentiable.__call__'s generic
`(*args, **kwargs)` signature instead of the more specific one from
_GramianWeighting or Aggregator.

The fix is to list _NonDifferentiable after the primary base class. The
cooperative super().__call__() chain then becomes:
  _GramianWeighting.__call__(gramian) →
  Weighting.__call__(gramian) →
  _NonDifferentiable.__call__(*args) [applies no_grad] →
  nn.Module.__call__(...)

The no_grad wrapping is fully preserved because every class in the chain
calls super().__call__(), so _NonDifferentiable is still reached — just
later in the chain. The old warning in _NonDifferentiable said it must
appear "before any nn.Module base class", which was imprecise: what
actually matters is that it appears before nn.Module *itself* in the
resolved MRO, which C3 linearization guarantees as long as super()
chains are cooperative.

All 2982 unit tests pass. Generated docs now show the correct parameter
names (gramian / matrix) for every affected __call__.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@ValerianRey added the labels "package: aggregation" and "cc: fix" on May 20, 2026
@ValerianRey changed the title from "fix(aggregation): Fix __call__ docs by placing _NonDifferentiable last in MRO" to "fix(aggregation): Fix __call__ docs" on May 20, 2026
@ValerianRey
Contributor Author

The above is Claude-generated. Not completely wrong but a bit off.

The reality is that we only need _NonDifferentiable to appear before nn.Module in the MRO. But since _NonDifferentiable is a subclass of nn.Module, we have a guarantee that it will appear before nn.Module (a class always precedes its parents, which is the first guarantee of the C3 linearization that Python uses to compute the MRO). If we somehow had a class defined as:

class A(nn.Module, _NonDifferentiable): ...

then Python would raise: TypeError: Cannot create a consistent method resolution order (MRO) for bases Module, _NonDifferentiable. In fact, when I try this, ty also catches it and says: Cannot create a consistent method resolution order (MRO) for class A with bases list [<class 'Module'>, <class '_NonDifferentiable'>] ty[inconsistent-mro](https://ty.dev/rules#inconsistent-mro)

So in fact, there's no way this can even fail.
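
A minimal sketch of that guarantee, using hypothetical stand-in classes instead of the real nn.Module and _NonDifferentiable: listing the parent before its own subclass fails at class-creation time, while any order that Python accepts places the mixin before its parent in the MRO.

```python
class Module:
    """Hypothetical stand-in for torch.nn.Module."""


class NonDifferentiable(Module):
    """Hypothetical stand-in for the _NonDifferentiable mixin."""


class PrimaryBase(Module):
    """Hypothetical stand-in for a primary base such as a weighting class."""


try:
    # The parent is listed before its own subclass, which C3 cannot linearize.
    class A(Module, NonDifferentiable):
        pass
except TypeError as error:
    mro_error = str(error)
else:
    mro_error = None


class Valid(PrimaryBase, NonDifferentiable):
    """Mixin listed last: accepted, and still ahead of Module in the MRO."""


mro = Valid.__mro__
mixin_before_module = mro.index(NonDifferentiable) < mro.index(Module)
print(mro_error)
print(mixin_before_module)
```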

@ValerianRey
Contributor Author

Even worse actually, before this, the chain was CAGradWeighting.__call__ => _NonDifferentiable.__call__ => nn.Module.__call__ => CAGradWeighting.forward, i.e. we never called _GramianWeighting.__call__ or Weighting.__call__ (which just do a super().__call__(), but still it's dangerous to skip calling them).

@ValerianRey
Contributor Author

Example with CAGrad (but other classes, like UPGrad, were also affected):

Before:
image

After:
image

@ValerianRey
Contributor Author

@PierreQuinton Please take a look at this. I'll merge now because the documentation is currently broken and I'd like it to be fixed at least on the latest tab of torchjd.org.

@ValerianRey merged commit 6c02f3b into main on May 20, 2026
15 checks passed
@ValerianRey deleted the fix/nondifferentiable-mixin-mro-docs branch on May 20, 2026, 20:22