Skip to content

Fix type mismatch bug in NashMTL#317

Merged
ValerianRey merged 2 commits intomainfrom
fix-nash-mtl
Apr 19, 2025
Merged

Fix type mismatch bug in NashMTL#317
ValerianRey merged 2 commits intomainfrom
fix-nash-mtl

Conversation

@ValerianRey
Copy link
Copy Markdown
Contributor

@ValerianRey ValerianRey commented Apr 19, 2025

  • Move cast of alpha to torch.Tensor out of the condition on the step value
  • Add changelog entry

Notes:

  • self.prvs_alpha is always a numpy array, so in both cases (if (self.step % self.update_weights_every) == 0: and else), we have to cast alpha to a Tensor.
  • In the original implementation of https://github.com/AvivNavon/nash-mtl/blob/main/methods/weight_methods.py#L238, there was already a mismatch of type, with alpha being a tensor when entering the condition (if (self.step % self.update_weights_every) == 0), and being a numpy array otherwise, but the following line (weighted_loss = sum([losses[i] * alpha[i] for i in range(len(alpha))])) made it work regardless.

* self.prvs_alpha is always a numpy array, so in both cases (if (self.step % self.update_weights_every) == 0: and else), we have to cast alpha to a Tensor.
* In the original implementation of https://github.com/AvivNavon/nash-mtl/blob/main/methods/weight_methods.py#L238, there was already a mismatch of type, with alpha being a tensor when entering if (self.step % self.update_weights_every) == 0, and being a numpy array otherwise, but the following line weighted_loss = sum([losses[i] * alpha[i] for i in range(len(alpha))]) made it work regardless.
@ValerianRey ValerianRey added package: aggregation cc: fix Conventional commit type for bug fixes of the actual library (changes to src). labels Apr 19, 2025
@ValerianRey ValerianRey self-assigned this Apr 19, 2025
@codecov
Copy link
Copy Markdown

codecov bot commented Apr 19, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Files with missing lines Coverage Δ
src/torchjd/aggregation/nash_mtl.py 88.23% <100.00%> (ø)
🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@ValerianRey ValerianRey merged commit e42d677 into main Apr 19, 2025
15 checks passed
@ValerianRey ValerianRey deleted the fix-nash-mtl branch April 19, 2025 15:15
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

cc: fix Conventional commit type for bug fixes of the actual library (changes to src). package: aggregation

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant