
NaN and AssertionError #1

Closed
tdeboissiere opened this issue Jun 30, 2017 · 21 comments

Comments

@tdeboissiere

Thanks for open-sourcing the code!

I've tried it on a simple MLP and could not find a set of parameters (lr and mu) that would not yield one of those two errors:

assert root.size == 1 AssertionError

and

numpy.linalg.linalg.LinAlgError: Array must not contain infs or NaNs

Any tips on best practices?

@JianGoForIt
Owner

Hi, thanks for trying our optimizer. From the information you provided, I think this is the result of an all-zero gradient in the first iteration. Could you verify that and avoid all-zero gradients?
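
For concreteness, a minimal sketch of such a check; model and opt are placeholders for your own module and YFOptimizer instance:

```python
def grads_are_all_zero(model):
    # returns True when every parameter gradient is exactly zero
    total = 0.0
    for p in model.parameters():
        if p.grad is not None:
            total += float(p.grad.data.abs().sum())
    return total == 0.0

# inside the training loop:
#   loss.backward()
#   if grads_are_all_zero(model):
#       print("all-zero gradients this iteration; skipping opt.step()")
#   else:
#       opt.step()
```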

@tdeboissiere
Author

I'll make doubly sure. The error does not occur at the first optimization step though; it typically appears after ~5 steps.

@victor-shepardson

victor-shepardson commented Jun 30, 2017

I am also having this problem with a simple MLP autoencoder. I see numpy.linalg.linalg.LinAlgError: Array must not contain infs or NaNs after a few batches. Using clip_thresh=1 allows training for ~35 epochs, then the same error occurs. The training curves look normal: train/valid both converge, slight overfitting begins, then the LinAlgError appears.

This is with torch 0.1.12_2, numpy 1.13.0, Anaconda Python 3.6

edit: after trying a few different depths/activations, I also see assert root.size == 1 AssertionError sometimes. Every run produces one of the two errors, sometimes immediately, sometimes after several epochs.

@sniklaus

sniklaus commented Jun 30, 2017

I am having the same issue with a rather complex residual network. I eventually had to set the initial learning rate to 0.0000001 in order to avoid this, but it does not really converge anymore with such a low initial learning rate. I also tried clipping the gradients, without success though.

On a side note, the readme states "like any tensorflow optimizer", which should probably be "like any PyTorch optimizer" instead. And thank you for your great work!

@JianGoForIt
Owner

@tdeboissiere Could you please point me to a usable code base to reproduce the error? @victor-shepardson @sniklaus It seems @tdeboissiere's case reproduces the error most quickly; if it is solved, I hope the fix will generalize to your cases as well.

@JianGoForIt
Owner

JianGoForIt commented Jun 30, 2017

@sniklaus @victor-shepardson it would also be great if you could point me to a codebase to reproduce. Thanks.

@skaae

skaae commented Jun 30, 2017

I had NaNs because 'self._grad_var' was something like 1e16 at https://github.com/JianGoForIt/YellowFin_Pytorch/blob/master/tuner_utils/yellowfin.py#L176
Lowering the learning rate and clamping the gradients helps.
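
A minimal, self-contained sketch of that workaround, with a tiny stand-in MLP and plain SGD in place of YellowFin; torch.nn.utils.clip_grad_norm is the clipping helper in the torch 0.1.x releases used in this thread (newer releases rename it clip_grad_norm_):

```python
import torch
import torch.nn as nn
from torch.autograd import Variable

# stand-in model, data, and optimizer; with YellowFin you would construct
# YFOptimizer here instead of SGD
model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 1))
opt = torch.optim.SGD(model.parameters(), lr=1e-3)  # lowered learning rate
x, y = Variable(torch.randn(16, 10)), Variable(torch.randn(16, 1))

opt.zero_grad()
loss = nn.MSELoss()(model(x), y)
loss.backward()
# clamp the gradient norm before the optimizer step
torch.nn.utils.clip_grad_norm(model.parameters(), max_norm=1.0)
opt.step()
```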

@tdeboissiere
Author

tdeboissiere commented Jul 1, 2017

I can't send you the exact script that gave the error as it relies on proprietary data.

I do have a dummy example: https://gist.github.com/tdeboissiere/27660736ca897025c80fd43674e69750
which sometimes crashes (assert root.size == 1) and sometimes runs perfectly well.

The example above seems to work perfectly fine when the dummy data is replaced with randomly generated data, though.

Python 2.7.13 :: Continuum Analytics, Inc.
torch 0.1.12_2

Setting clip_thresh = 1.0 did not help. Lowering the learning rate did not help either.

@JianGoForIt
Owner

@tdeboissiere Thanks for the example. We will take a look into it and get back to you soon.

@JianGoForIt
Owner

Hi @tdeboissiere @victor-shepardson @sniklaus @skaae

I have committed a fix for the reported NaN and inf issue.

Basically, the clipping was not functioning for @tdeboissiere's example (and possibly for the other cases as well).

In the previous version, the _var_list of YFOptimizer was a generator. After it was consumed once, subsequent uses of _var_list (for gradient clipping) saw an empty sequence. We now convert the generator to a list in the initialization of YFOptimizer.
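
For reference, a small standalone illustration of why a generator makes the second pass (the gradient clipping) see nothing:

```python
params = (p for p in [1, 2, 3])      # a generator, like model.parameters()
print(list(params))                  # [1, 2, 3] -- the first use consumes it
print(list(params))                  # []        -- later uses (the clipping) see nothing

params = list(p for p in [1, 2, 3])  # converting to a list once fixes this
print(list(params))                  # [1, 2, 3]
print(list(params))                  # [1, 2, 3]
```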

The fix is in commit 13b833e.

It would be great if you could also keep us updated about the performance of YF on MSE-style losses in your projects. We also have other ideas for curvature estimation specialized for MSE-style losses; if necessary, we can discuss them. Ideally, though, the current implementation should also work for MSE-style losses with some slight clipping.

@sniklaus

sniklaus commented Jul 3, 2017

Thank you for having a look at this! I am afraid I am still having the same issue. I will send you an email with my model once the related paper is accepted; my apologies that I am currently unable to share it.

@AngusGLChen

I am also still having the same issue. For the model, I used the one described in the PyTorch tutorial: http://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html

@JianGoForIt
Owner

Hi @AngusGLChen,

I will take a look at the model you are using. Just to confirm: the default settings in the tutorial trigger the issue, right? Does it take long to get the error? What initial values for the learning rate and momentum did you use?

Cheers

@AngusGLChen

@JianGoForIt Hello, thanks for your quick response. Yes, the default settings trigger the issue. Usually the error occurs after about 30 to 60 minutes. I used YellowFin's default setting for the learning rate.

@JianGoForIt
Owner

JianGoForIt commented Jul 4, 2017

Hi @AngusGLChen, Thanks for the info.

Since it happens after 30 to 60 minutes, it is most likely a sudden, very strong exploding gradient (or a zero gradient) driving YellowFin's internal lr and momentum crazy.

Could you please try the gradient clipping feature in YellowFin? More specifically:

  1. First get a feel for the normal magnitude of the gradients in your model; monitor training and record the magnitude that triggers the error.

  2. Using what you learn in step 1, pass the clip_thresh argument (YellowFin's default is no clipping) to clip gradients within YellowFin; see the sketch below.
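
A hedged sketch of both steps; the norm monitoring is plain PyTorch, the YFOptimizer call only assumes the lr, mu, and clip_thresh arguments discussed in this thread, and 10.0 is a placeholder threshold you would replace with what step 1 suggests:

```python
def grad_norm(model):
    # overall L2 norm of all parameter gradients
    total = 0.0
    for p in model.parameters():
        if p.grad is not None:
            total += float(p.grad.data.norm()) ** 2
    return total ** 0.5

# Step 1: log grad_norm(model) after every backward() and note the magnitude
# right before the error appears.
# Step 2: re-create the optimizer with clipping enabled, e.g.
#   opt = YFOptimizer(model.parameters(), lr=1.0, mu=0.0, clip_thresh=10.0)
```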

There is also another quick patch: you can catch the exception and redo (or skip) that iteration. Since this is PyTorch rather than TensorFlow, it should be easy to do.
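
A hedged sketch of that patch, simply skipping the failing iteration; model, opt, data_loader, and compute_loss are placeholders for your own training loop:

```python
import numpy.linalg

for batch in data_loader:
    opt.zero_grad()
    loss = compute_loss(model, batch)
    loss.backward()
    try:
        opt.step()
    except (numpy.linalg.LinAlgError, AssertionError):
        # the two failure modes reported in this thread
        print("optimizer step failed on this batch; skipping it")
        continue
```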

If the error happens near the beginning of training (which seems not to be your case), you could also slightly lower the initial learning rate in these corner cases. It may help avoid loss explosion, and the influence of a different initial learning rate typically diminishes after a few thousand iterations.

Please let us know whether the above two solutions help you or not.

@AngusGLChen

@JianGoForIt Hello, the first method works! Thanks!

@JianGoForIt
Owner

@AngusGLChen, glad to hear it works.

@tdeboissiere
Author

It's been fixed on my side as well. Thanks for the quick update.

@sniklaus

sniklaus commented Jul 5, 2017

I tried 13b833e again and additionally used gradient clipping this time. The combination of both solved the issue for me as well. Thanks!

@JianGoForIt
Owner

We have set up a new auto_clip branch for automatic gradient clipping; it might further help with the issue. If it works well, we will merge it into master in the future.

@JianGoForIt
Owner

Just as general information, we have added an auto-clipping feature in the auto_clip branch. It is designed for the sudden exploding-gradient issue discussed in this thread. Could you please try it out and give us feedback?

Thanks guys.
