This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

added on_backward trainer callback #5249

Merged
@ArjunSubramonian merged 5 commits into main from arjuns/during-backward-callback on Jun 11, 2021

Conversation

@ArjunSubramonian (Contributor) commented Jun 9, 2021

Additions proposed in this pull request:

  • Added on_backward training callback which allows for control over backpropagation and gradient manipulation.
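
As a hedged sketch of how such a callback might be used (the signature is assumed from the discussion below; the class name and registration key here are illustrative, not part of the PR):

from typing import Dict

import torch

from allennlp.training.callbacks import TrainerCallback


@TrainerCallback.register("my-on-backward")  # registration key is hypothetical
class MyBackwardCallback(TrainerCallback):
    def on_backward(
        self,
        trainer,
        batch_outputs: Dict[str, torch.Tensor],
        backward_called: bool,
        **kwargs,
    ) -> bool:
        if backward_called:
            # Some other callback already ran backward(); nothing to do here.
            return False
        batch_outputs["loss"].backward()
        # Gradient manipulation (clipping, masking, etc.) could happen here.
        return True  # tells the trainer that backward() has been handled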

@ArjunSubramonian self-assigned this on Jun 9, 2021
@ArjunSubramonian requested review from dirkgr, epwalsh and AkshitaB, and removed the review request for dirkgr, on Jun 9, 2021 at 22:23
@dirkgr (Member) left a comment

I think this design is too heavy. I don't like putting core functionality like loss.backward() into a callback. It makes it too hard to see what's going on in the trainer.

Instead, can we use the normal TrainerCallback, and give it some extra methods, like pre_backward() and post_backward()? Can you solve your problem with that?
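
As a rough sketch of this alternative (hook names taken from the comment above; the signatures are hypothetical, and this is not the design that was ultimately merged):

# Hypothetical hook pair; illustrative only, not AllenNLP API.
class TrainerCallbackWithHooks:
    def pre_backward(self, trainer, loss) -> None:
        # Would run just before the trainer calls loss.backward().
        ...

    def post_backward(self, trainer, loss) -> None:
        # Would run right after loss.backward(), before the optimizer step.
        ...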

@ArjunSubramonian (Contributor, Author) replied:

> I think this design is too heavy. I don't like putting core functionality like loss.backward() into a callback. It makes it too hard to see what's going on in the trainer.
>
> Instead, can we use the normal TrainerCallback, and give it some extra methods, like pre_backward() and post_backward()? Can you solve your problem with that?

I can't do adversarial training without backward() being called in a callback, because I need to call backward() twice, and one of the calls requires retain_graph=True.
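
To make the constraint concrete, here is a plain-PyTorch sketch (the second objective is a stand-in, not the actual adversarial-training code): calling backward() twice over the same graph requires retain_graph=True on the first call.

import torch
import torch.nn.functional as F

model = torch.nn.Linear(4, 1)
x, y = torch.randn(8, 4), torch.randn(8, 1)

out = model(x)
loss = F.mse_loss(out, y)
adv_loss = out.pow(2).mean()  # stand-in for an adversarial objective

loss.backward(retain_graph=True)  # keep the graph alive for the second pass
adv_loss.backward()               # second backward over the same graph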

@dirkgr (Member) commented Jun 9, 2021

You could leave one backward() call as it is, and issue the second one in the callback?

@ArjunSubramonian (Contributor, Author) replied:

> You could leave one backward() call as it is, and issue the second one in the callback?

@dirkgr I have made the revisions you suggested :)

@dirkgr (Member) left a comment

Other than the changelog entry, this is great!

CHANGELOG.md (Outdated)
@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

### Added

- Added `BackwardCallback`, a training callback which allows for control over backpropagation and gradient manipulation.
@dirkgr (Member):
This comment isn't accurate anymore, is it?

if not backward_called:
    # AMP path: scale the loss through the grad scaler before backward().
    trainer._scaler.scale(loss).backward()  # type: ignore
    return True
return False
@dirkgr (Member):
Is it an error if this gets called with backward_called == True? Should we throw an exception in that case?
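
The final commit message below notes that an OnBackwardException was added, presumably for exactly this case; a hedged sketch of such a guard (the placement and message are assumed, only the exception name comes from the commit message):

if backward_called:
    # Assumed guard: refuse to run backward() twice for the same batch.
    raise OnBackwardException(
        "on_backward was invoked, but backward() has already been called."
    )
trainer._scaler.scale(loss).backward()  # type: ignore
return True
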

if not backward_called:
    loss.backward()
    # Zero out every gradient after the backward pass.
    for param in trainer.model.parameters():
        param.grad *= 0.0
@dirkgr (Member):
Is that really the best way to do that?

Suggested change
- param.grad *= 0.0
+ param.grad.zero_()

I don't know for sure, but I would guess that zero_() is faster.

@ArjunSubramonian changed the title from "added BackwardCallback" to "added on_backward trainer callback" on Jun 10, 2021
@dirkgr (Member) left a comment

This is great! Will make a great update for tomorrow's meeting, too!

@ArjunSubramonian merged commit a6cfb12 into main on Jun 11, 2021
@ArjunSubramonian deleted the arjuns/during-backward-callback branch on Jun 11, 2021 at 00:04
Abhishek-P pushed a commit to Abhishek-P/allennlp that referenced this pull request on Aug 11, 2021
* added BackwardCallback

* finished tests

* fixed linting issue

* revised design per Dirk's suggestion

* added OnBackwardException, changed loss to batch_outputs, etc.

Co-authored-by: Arjun Subramonian <arjuns@Arjuns-MacBook-Pro.local>