
Fix the gradient_clip_algorithm has no effect issue. #6928

Merged
merged 8 commits into Lightning-AI:master on Apr 14, 2021

Conversation

ceshine
Contributor

@ceshine ceshine commented Apr 9, 2021

What does this PR do?

It contains some necessary changes to make gradient_clip_algorithm actually work (fixes #6920).

Also, I added a temporary workaround for #6807 to make the test cases work. (I can remove it so that this PR does only one thing.) EDIT: removed this workaround for now, since it surfaced four errors that are unrelated to this PR.

Notes:

  1. TPUAccelerator.clip_gradients does not implement clipping by value. Passing gradient_clip_algorithm="value" should raise an exception.
  2. Updated the test cases test_gradient_clipping_by_value and test_gradient_clipping_by_value_fp16. They now clip the gradients to a maximum of 1e-5 and check that the maximum gradient value in the result is almost equal to 1e-5 (this threshold is small enough that there should always be some pre-clipping gradients larger than it). A minimal sketch of the clipping dispatch is included after this list.
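
For context, here is a minimal sketch of the kind of dispatch this fix relies on; the function name and signature are illustrative rather than the actual Lightning accelerator code, with "norm" and "value" mapping onto the two standard torch.nn.utils helpers:

import torch

def clip_gradients(parameters, clip_val, gradient_clip_algorithm="norm"):
    # Illustrative dispatch between the two clipping algorithms.
    if clip_val <= 0:
        return
    if gradient_clip_algorithm == "value":
        # Clamp each gradient element to [-clip_val, clip_val].
        torch.nn.utils.clip_grad_value_(parameters, clip_value=clip_val)
    elif gradient_clip_algorithm == "norm":
        # Rescale gradients so that their total norm is at most clip_val.
        torch.nn.utils.clip_grad_norm_(parameters, max_norm=clip_val)
    else:
        raise ValueError(f"Unknown gradient_clip_algorithm: {gradient_clip_algorithm!r}")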

Before submitting

  • Was this discussed/approved via a GitHub issue? (not for typos and docs)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests? (not for typos and docs)
  • Did you verify new and existing tests pass locally with your changes?
  • Did you update the CHANGELOG? (not for typos, docs, test updates, or internal minor changes/refactorings)

PR review

Anyone in the community is free to review the PR once the tests have passed.
Before you start reviewing, make sure you have read the Review guidelines. In short, see the following bullet-list:

  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified

Did you have fun?

Make sure you had fun coding 🙃

@pep8speaks

pep8speaks commented Apr 9, 2021

Hello @ceshine! Thanks for updating this PR.

There are currently no PEP 8 issues detected in this Pull Request. Cheers! 🍻

Comment last updated at 2021-04-10 15:10:55 UTC

@codecov

codecov bot commented Apr 9, 2021

Codecov Report

Merging #6928 (3d3f7cc) into master (3baac71) will decrease coverage by 5%.
The diff coverage is 57%.

@@           Coverage Diff           @@
##           master   #6928    +/-   ##
=======================================
- Coverage      92%     86%    -5%     
=======================================
  Files         194     194            
  Lines       12347   12553   +206     
=======================================
- Hits        11322   10856   -466     
- Misses       1025    1697   +672     

@ceshine
Contributor Author

ceshine commented Apr 9, 2021

The main test that is failing now is tests/models/test_horovod.py::test_horovod_multi_optimizer, which I don't believe has anything to do with this PR. From what I have observed, this test fails somewhat randomly (for example, the first check succeeded once but failed in the next run), and I don't think changing the TPU test case would affect the horovod tests.

The #6807 issue is really interfering with the testing inside Trainer.fit calls. I tried enabling a workaround to see whether the errors raised were related to gradient clipping; they weren't, but there is no guarantee after disabling it. I would be happy to fix any errors raised by these changes once #6807 has been fixed.

Comment on lines 941 to 942
assert abs(round(grad_max.item(), 6) - grad_clip_val) < 1e-6, \
f"Gradient max value {grad_max} != grad_clip_val {grad_clip_val} ."
Contributor

In rare cases, this test will be a problem if the gradient values are all smaller than the gradient clipping value.

Contributor Author

Yeah, I'm aware. I'd argue that the possibility of that happening is really low (since the threshold is 1e-5); at least I haven't encountered it in my local testing. If you want to make it even lower, we can set the threshold to 1e-10 or 1e-13 (within the fp16 range).

This is the approach I came up with to distinguish clipping by norm from clipping by value. I'm open to better ideas, of course.

Contributor Author

@ceshine ceshine Apr 10, 2021

If you really want to prevent that false-positive case from happening, we can add an if statement before the assertion to make sure the minimum gradient is larger than the threshold (this might create some false negatives, though).

EDIT: this solution would need to get the gradients before clipping, which is not possible in the current test setup.

Comment on lines 1021 to 1022
assert abs(round(grad_max.item(), 6) - grad_clip_val) < 1e-6, \
f"Gradient max value {grad_max} != grad_clip_val {grad_clip_val} ."
Contributor

same

@awaelchli awaelchli added the bug Something isn't working label Apr 10, 2021
@awaelchli awaelchli modified the milestones: 1.2.x, 1.3 Apr 10, 2021
@ceshine
Contributor Author

ceshine commented Apr 10, 2021

I've changed the clipping threshold in the test cases to 1e-10 and max_steps to 1 (matching test_gradient_clipping) to address @dhkim0225's concern.

I don't think a false positive (the test failing when it shouldn't) would happen in this setup. Even if one did, I'd argue that the problem lies in the BoringModel, because a test setup should not create a situation where all gradients are almost zero after just one backward pass.
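
For reference, a minimal standalone sketch of the kind of check described above; the tiny linear model stands in for BoringModel, and this is not the actual test code:

import torch

model = torch.nn.Linear(32, 2)  # hypothetical stand-in for BoringModel
model(torch.randn(8, 32)).sum().backward()

grad_clip_val = 1e-10
torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=grad_clip_val)

# After clipping by value, the largest absolute gradient should sit at the clip
# value, assuming at least one pre-clip gradient exceeded it (with a threshold
# this small, that is practically always the case after one backward pass).
grad_max = max(p.grad.abs().max() for p in model.parameters())
assert abs(grad_max.item() - grad_clip_val) < 1e-6, \
    f"Gradient max value {grad_max} != grad_clip_val {grad_clip_val}."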

Contributor

@tchaton tchaton left a comment

LGTM !

@kaushikb11 kaushikb11 changed the title from "Fix the gradient_clip_algorithm has no effect issue. (#6920)" to "Fix the gradient_clip_algorithm has no effect issue." Apr 14, 2021
@kaushikb11 kaushikb11 merged commit 24d0295 into Lightning-AI:master Apr 14, 2021
@carmocca
Member

Hi @ceshine, quick question

TPUAccelerator.clip_gradients does not implement clipping by value. Passing gradient_clip_algorithm="value" should raise an exception.

Is there any reason why we can't just use torch.nn.utils.clip_grad_value_? Do you have more info?

@ceshine
Contributor Author

ceshine commented Apr 15, 2021

Hi @ceshine, quick question

TPUAccelerator.clip_gradients does not implement clipping by value. Passing gradient_clip_algorithm="value" should raise an exception.

Is there any reason why we can't just use torch.nn.utils.clip_grad_value_? Do you have more info?

I'm not really familiar with XLA, but I think it is for the same reason that xla_clip_grad_norm_ is used in TPUAccelerator.

The original #6123 implementation did not even have a gradient_clip_algorithm argument in TPUAccelerator, which would create problems when we use this argument in the training loop. I merely added that argument and made sure to let users know that only the "norm" algorithm has been implemented for TPU.
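
To illustrate, a rough sketch of that argument check; the names here are illustrative, and the real TPUAccelerator delegates norm clipping to an XLA-aware helper rather than the plain torch call used below:

import torch

def tpu_clip_gradients(parameters, clip_val, gradient_clip_algorithm="norm"):
    if clip_val <= 0:
        return
    if gradient_clip_algorithm != "norm":
        # Only norm clipping is implemented for TPU, so anything else errors out.
        raise ValueError(
            f"Only gradient_clip_algorithm='norm' is supported on TPU, got {gradient_clip_algorithm!r}"
        )
    # Stand-in for the XLA-aware norm clipping used on TPU.
    torch.nn.utils.clip_grad_norm_(parameters, max_norm=clip_val)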

Labels
bug Something isn't working
Development

Successfully merging this pull request may close these issues.

Trainer(gradient_clip_algorithm='value') has no effect (from #6123)
8 participants