🪺 SDMGrad, DWA, FAMO
This release introduces a new weighting SDMGradWeighting and two new scalarizers DWA and FAMO. Thanks a lot to @KhusPatel4450 and @ppraneth for the contributions!
We're trying to grow the community and build even more features! To participate, you can join the Discord community!
Changelog
Added
- Added `SDMGradWeighting` from Direction-oriented Multi-objective Learning: Simple and Provable
  Stochastic Algorithms (NeurIPS 2023). It is a stateful `Weighting` that solves for task weights
  via a simplex-projected inner loop on a cross-batch matrix `A = J_1 @ J_2.T` (computed from two
  independent mini-batches using `autojac.jac`), with a direction-oriented regularizer pulling the
  descent direction toward a preference direction.
- Added `DWA` (Dynamic Weight Average) from End-to-End Multi-Task Learning with Attention
  (CVPR 2019), a stateful `Scalarizer` that weights each value by the relative rate at which its
  loss decreased over the two previous epochs. It has no learnable parameters; call its `step()`
  method once per epoch to roll the loss history.
- Added `FAMO` (Fast Adaptive Multitask Optimization) from FAMO: Fast Adaptive Multitask
  Optimization (NeurIPS 2023), a stateful `Scalarizer` that decreases all task losses at an
  approximately equal rate using only the loss values. It learns the task weights internally;
  after the model step, call its `update()` method with the losses recomputed on the same batch
  to adjust them.
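
To give a feel for the simplex-projected inner loop mentioned in the SDMGradWeighting entry, here is a minimal pure-Python sketch. It is not the library's implementation. The function names `project_to_simplex` and `solve_task_weights`, the hyperparameters `lam`, `lr`, and `n_iter`, and the choice of objective (projected gradient descent on `0.5 * (w + lam * p)^T A (w + lam * p)`, with `p` the preference weights) are assumptions made for illustration. `A` stands in for the cross-batch matrix and is passed as a plain list of lists.

```python
def project_to_simplex(v):
    """Euclidean projection of v onto the probability simplex (sort-based method)."""
    u = sorted(v, reverse=True)
    cumulative = 0.0
    theta = 0.0
    for j, u_j in enumerate(u, start=1):
        cumulative += u_j
        candidate = (cumulative - 1.0) / j
        # Keep the threshold from the last index that still satisfies the condition.
        if u_j - candidate > 0:
            theta = candidate
    return [max(x - theta, 0.0) for x in v]


def solve_task_weights(A, preference, lam=0.3, lr=0.1, n_iter=50):
    """Hypothetical inner loop: projected gradient descent on the simplex.

    A is a k x k matrix (assumed here to approximate the Gramian of the
    per-task gradients), and preference is a length-k vector of weights
    describing the preferred direction.
    """
    k = len(A)
    w = [1.0 / k] * k  # start from uniform weights
    for _ in range(n_iter):
        shifted = [w_i + lam * p_i for w_i, p_i in zip(w, preference)]
        grad = [sum(A[i][j] * shifted[j] for j in range(k)) for i in range(k)]
        w = project_to_simplex([w_i - lr * g_i for w_i, g_i in zip(w, grad)])
    return w


if __name__ == "__main__":
    # Toy matrix: the second task has a much larger gradient norm.
    A = [[1.0, 0.0], [0.0, 4.0]]
    print(solve_task_weights(A, preference=[0.5, 0.5]))
```

With this toy matrix, the weights settle near `[0.89, 0.11]`: the task with the larger gradient norm receives the smaller weight, while the regularizer keeps the solution from ignoring the uniform preference entirely.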
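
The DWA weighting rule can be illustrated with a short standalone sketch. In the CVPR 2019 paper, each task's weight is a softmax over the ratio of its losses from the two previous epochs, scaled so that the weights sum to the number of tasks. The function name `dwa_weights` and the `temperature` parameter name are hypothetical and are not taken from the library's `DWA` class; the default of 2.0 follows the value commonly used with the paper.

```python
import math


def dwa_weights(prev_losses, prev_prev_losses, temperature=2.0):
    """Sketch of Dynamic Weight Average for one epoch.

    prev_losses:      per-task average losses from epoch t-1
    prev_prev_losses: per-task average losses from epoch t-2
    """
    # Ratio close to 1 means the loss barely moved; a small ratio means fast decrease.
    ratios = [l1 / l2 for l1, l2 in zip(prev_losses, prev_prev_losses)]
    exps = [math.exp(r / temperature) for r in ratios]
    total = sum(exps)
    k = len(ratios)
    # Softmax over the ratios, rescaled so the weights sum to the number of tasks.
    return [k * e / total for e in exps]


if __name__ == "__main__":
    # Task 0 dropped from 1.0 to 0.9, task 1 from 1.0 to 0.5.
    print(dwa_weights([0.9, 0.5], [1.0, 1.0]))
```

The slower-improving task (task 0 in the example) gets the larger weight. During the first two epochs there is not enough history to form the ratios, so an implementation typically falls back to uniform weights.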
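
The two-phase pattern described for FAMO (scalarize, step the model, then adjust the weights from recomputed losses) can be sketched in plain Python. This is an illustration under stated assumptions, not the library's `FAMO` class: the function names are hypothetical, losses are assumed strictly positive, and a plain gradient step with weight decay replaces the Adam optimizer used for the logits in the paper's reference setup.

```python
import math


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def famo_loss_weights(logits, losses):
    """Per-task coefficients applied to the task gradients; they sum to 1."""
    z = softmax(logits)
    c = 1.0 / sum(z_i / l_i for z_i, l_i in zip(z, losses))
    return [c * z_i / l_i for z_i, l_i in zip(z, losses)]


def famo_update_logits(logits, losses_before, losses_after, lr=0.025, decay=0.001):
    """Adjust the logits from the log-loss improvement on the same batch."""
    z = softmax(logits)
    # Improvement of each task in log space (positive means the loss went down).
    delta = [math.log(b) - math.log(a) for b, a in zip(losses_before, losses_after)]
    # Vector-Jacobian product of delta with the softmax Jacobian.
    mean_delta = sum(z_i * d_i for z_i, d_i in zip(z, delta))
    grad = [z_i * (d_i - mean_delta) for z_i, d_i in zip(z, delta)]
    return [x - lr * (g + decay * x) for x, g in zip(logits, grad)]


if __name__ == "__main__":
    logits = [0.0, 0.0]
    before = [1.0, 1.0]  # losses used for the model step
    after = [0.5, 0.9]   # losses recomputed on the same batch after the step
    print(famo_loss_weights(logits, before))
    logits = famo_update_logits(logits, before, after)
    print(softmax(logits))
```

In the example, task 0 improved more than task 1, so its logit is lowered and the next step gives relatively more weight to task 1, which is how the method pushes all losses to decrease at a similar rate.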