https://torchjd.org/stable/examples/monitoring/
The hook should be registered on `aggregator.weighting.weighting` instead of `aggregator.weighting` (due to a change in the architecture of aggregators a few months ago).
Also, I think the only reason aggregators and weightings are `nn.Module`s is so that hooks can be registered on them. So we could maybe:
- Stop making aggregators and weightings subclasses of `nn.Module`.
- Add our own function to register hooks, which would be better documented, accept hooks with a nicer signature, etc.
- This would make it possible to have `aggregator.register_weighting_hook(hook)` for `GramianWeightedAggregator`, which would be equivalent to what `aggregator.weighting.weighting.register_forward_hook(hook)` does today.
Not sure this is worth the effort, but making aggregators and weightings not be `nn.Module`s might remove a bit of technical debt.
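A minimal pure-Python sketch of what the proposal could look like, assuming a hook that simply receives the computed weights. Everything here is hypothetical: `Weighting`, `HookHandle`, `register_hook`, `ToyMeanWeighting` and the `gramian` helper are illustrative names, and `GramianWeightedAggregator` is a simplified stand-in, not torchjd's actual class. Lists of floats stand in for tensors so the sketch runs without PyTorch.

```python
from typing import Callable, List

# Hypothetical hook signature: the hook only receives the computed weights,
# instead of the (module, inputs, output) triple of nn.Module forward hooks.
WeightingHook = Callable[[List[float]], None]


class HookHandle:
    """Hypothetical handle used to unregister a hook."""

    def __init__(self, hooks: List[WeightingHook], hook: WeightingHook) -> None:
        self._hooks = hooks
        self._hook = hook

    def remove(self) -> None:
        # Removing an already-removed hook is a no-op.
        if self._hook in self._hooks:
            self._hooks.remove(self._hook)


class Weighting:
    """Hypothetical weighting base class that is NOT an nn.Module."""

    def __init__(self) -> None:
        self._hooks: List[WeightingHook] = []

    def register_hook(self, hook: WeightingHook) -> HookHandle:
        self._hooks.append(hook)
        return HookHandle(self._hooks, hook)

    def compute(self, gramian: List[List[float]]) -> List[float]:
        raise NotImplementedError

    def __call__(self, gramian: List[List[float]]) -> List[float]:
        weights = self.compute(gramian)
        # Iterate over a copy so a hook may safely remove itself.
        for hook in list(self._hooks):
            hook(weights)
        return weights


class ToyMeanWeighting(Weighting):
    """Toy weighting: uniform weights, one per row of the Gramian."""

    def compute(self, gramian: List[List[float]]) -> List[float]:
        m = len(gramian)
        return [1.0 / m] * m


def gramian(matrix: List[List[float]]) -> List[List[float]]:
    """Computes matrix @ matrix.T with plain lists."""
    return [[sum(a * b for a, b in zip(r1, r2)) for r2 in matrix] for r1 in matrix]


class GramianWeightedAggregator:
    """Simplified stand-in: weights the rows of a matrix using its Gramian."""

    def __init__(self, weighting: Weighting) -> None:
        self.weighting = weighting

    def register_weighting_hook(self, hook: WeightingHook) -> HookHandle:
        # Delegates to the weighting, so users don't need to know how it is nested.
        return self.weighting.register_hook(hook)

    def __call__(self, matrix: List[List[float]]) -> List[float]:
        weights = self.weighting(gramian(matrix))
        n_cols = len(matrix[0])
        # Weighted sum of the rows: weights @ matrix.
        return [sum(w * row[j] for w, row in zip(weights, matrix)) for j in range(n_cols)]


aggregator = GramianWeightedAggregator(ToyMeanWeighting())
handle = aggregator.register_weighting_hook(lambda w: print(f"Weights: {w}"))
print(aggregator([[1.0, 2.0], [3.0, 4.0]]))  # prints "Weights: [0.5, 0.5]" then "[2.0, 3.0]"
handle.remove()
```

Returning a handle with a `remove()` method mirrors what `register_forward_hook` returns in PyTorch, so users would keep a familiar way to detach hooks.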