-
Notifications
You must be signed in to change notification settings - Fork 7
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fixing Bug with HPO #345
Fixing Bug with HPO #345
Conversation
Coverage reportThe coverage rate went from
Diff Coverage details (click to unfold)src/renate/updaters/experimental/er.py
src/renate/updaters/learner_components/reinitialization.py
src/renate/updaters/learner_components/losses.py
src/renate/updaters/learner_components/component.py
|
def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: | ||
"""Add plastic and stable model to checkpoint.""" | ||
super().on_save_checkpoint(checkpoint) | ||
checkpoint["component-cls-plastic-model"] = self._plastic_model |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should save and load self._plastic_model.state_dict()
instead of the model object itself.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
@@ -5,6 +5,6 @@ | |||
"dataset": "mnist.json", | |||
"backend": "local", | |||
"job_name": "class-incremental-mlp-cls-er", | |||
"expected_accuracy_linux": [[0.9839243292808533, 0.9740450382232666]], | |||
"expected_accuracy_linux": [[0.9834515452384949, 0.9740450382232666]], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shouldn't the results remain unaffected by the changes?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See PR description for more details. In short, working with torch.tensor(1.0) and float(1.0) seem to cause differences.
Hyperparameter defined as part of the components were overwritten by values in
state_dict
. To address this issue,Component
no longer is ann.Module
.Moving from torch.tensor to Python floats had led to changes for CLS (and therefore Super-ER) at the 8th digit. These errors accumulate and therefore give some different numbers. To account for this, I've changed the expected values for CLS and Super-ER.
By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.