Add gradient accumulation logic to SupervisedTrainer #6101
jak0bw wants to merge 1 commit into Project-MONAI:dev
|
(Source code is strongly (and shamelessly) influenced by https://pytorch.org/ignite/generated/ignite.engine.supervised_training_step.html)
|
Needs feedback on whether this feature is desired and, if so, probably test(s) for the new parameter.
|
Hi @jak0bw, thanks for your idea and contribution here.
|
Hi @Nic-Ma, thank you for your answer. I am sorry if I misunderstand something critical here, but doesn't the standard ignite logic include gradient accumulation similar to my proposed changes (since ignite 0.4.7)? (My code is more or less copy-pasted from the ignite source code.) A change along the lines of this PR would therefore just restore feature parity between ignite and MONAI and would not count as customized logic. Link to ignite create_supervised_trainer: https://github.com/pytorch/ignite/blob/c7c0df0fbfdff2a86415476cf0e68f36a089c1d2/ignite/engine/__init__.py#L404 Link to the used step function(s):
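For readers following the thread, here is a minimal framework-free sketch of why the accumulation logic divides the loss by the number of accumulation steps. It assumes equally sized micro-batches and a mean-reduced loss; the toy model `y = w * x` and the helper names `grad_mse` and `accumulated_grad` are made up for illustration and are not part of ignite or MONAI.

```python
def grad_mse(w, xs, ys):
    """Gradient w.r.t. w of the mean squared error for the toy model y = w * x."""
    n = len(xs)
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / n


def accumulated_grad(w, xs, ys, steps):
    """Sum the gradients of (loss / steps) over `steps` equal micro-batches.

    Assumption: len(xs) is divisible by `steps`.
    """
    size = len(xs) // steps
    total = 0.0
    for i in range(steps):
        chunk = slice(i * size, (i + 1) * size)
        # Scaling each micro-batch gradient by 1 / steps mirrors dividing
        # the loss by the number of accumulation steps before backward().
        total += grad_mse(w, xs[chunk], ys[chunk]) / steps
    return total


xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]
full_batch = grad_mse(0.5, xs, ys)
accumulated = accumulated_grad(0.5, xs, ys, steps=2)
```

Under these assumptions the accumulated gradient matches the gradient of one large batch, which is the point of the feature: a larger effective batch size without the memory cost.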
|
I think if ignite's
    def grad_accumulation_iteration(steps=...):
        def iteration(engine, ...):
            ...
            return engine.output
        return iteration
and the usage would be monai.engine.SupervisedTrainer(..., iteration_update=monai.engine.utils.grad_accumulation_iteration(steps), ...)
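The factory suggested above could be fleshed out roughly as follows. This is only a sketch under stated assumptions, not the code from this PR: it assumes the batch is a plain `(inputs, targets)` pair, that the engine exposes `network`, `optimizer`, `loss_function` and a 1-based `state.iteration`, and it returns a dict whose keys are chosen for illustration. The zero_grad/step timing follows the pattern of ignite's `supervised_training_step`.

```python
def grad_accumulation_iteration(steps=1):
    """Return an iteration function that steps the optimizer every `steps` batches."""
    if steps < 1:
        raise ValueError("steps must be a positive integer")

    def iteration(engine, batchdata):
        # Assumption: batchdata is an (inputs, targets) pair; a real trainer
        # would typically run its own batch preparation here.
        inputs, targets = batchdata

        # Clear gradients only at the start of an accumulation window
        # (state.iteration is assumed to start at 1).
        if (engine.state.iteration - 1) % steps == 0:
            engine.optimizer.zero_grad()

        engine.network.train()
        preds = engine.network(inputs)

        # Scale the loss so the summed gradients match one large batch.
        loss = engine.loss_function(preds, targets) / steps
        loss.backward()

        # Update the weights only at the end of the window.
        if engine.state.iteration % steps == 0:
            engine.optimizer.step()

        # Hypothetical output layout, chosen for this sketch.
        return {"pred": preds, "label": targets, "loss": loss}

    return iteration
```

Because the closure only relies on duck typing, it can be passed wherever an `iteration_update` callable with an `(engine, batchdata)` signature is accepted. If the number of batches per epoch is not a multiple of `steps`, the last partial window carries over into the next epoch, since the counter is the global iteration.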
|
As mentioned in #6100 it is possible to directly use Ignite's |
|
Sure, please consider creating a function in MONAI/monai/engines/workflow.py (lines 125 to 128 in e375f2a). This is how we create various iteration_update functions, for example: https://github.com/Project-MONAI/MONAI/blob/dev/monai/apps/deepedit/interaction.py#LL26C7-L26C18
Signed-off-by: Jakob Weigand <jakob.weigand@tum.de>
|
@wyli I think I added a custom iteration update function as requested. The code still has a circular import error (the circular import of SupervisedTrainer), which I don't really know how to resolve, as it depends on how the MONAI project is structured (or how it deals with these problems in code). Additionally, tests are still missing. @Nic-Ma Unfortunately, I don't have time to work on this pull request further in the near future. This pull request (and/or the corresponding issue) can therefore be marked as contribution wanted, closed, or handled however the MONAI team prefers.
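One common way to avoid a circular import when a class is needed only for type annotations is to guard the import with `typing.TYPE_CHECKING`. The sketch below is a general Python pattern, not necessarily how MONAI resolves it; the helper `is_window_start` is hypothetical, and the guarded import is never executed at runtime.

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by static type checkers, never at runtime, so this
    # import cannot take part in an import cycle.
    from monai.engines import SupervisedTrainer


def is_window_start(engine: "SupervisedTrainer", steps: int) -> bool:
    """Hypothetical helper: True on the first iteration of an accumulation window.

    The string annotation keeps the class name out of runtime evaluation;
    only the engine's 1-based state.iteration attribute is actually used.
    """
    return (engine.state.iteration - 1) % steps == 0
```

The trade-off is that the class is unavailable for runtime checks such as `isinstance`; where one is needed, importing inside the function body is the usual alternative.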
#6100
Description
A few sentences describing the changes proposed in this pull request.
Types of changes
./runtests.sh -f -u --net --coverage
./runtests.sh --quick --unittests --disttests
make html command in the docs/ folder