Numerical instabilities of mdn_loss #39

Open
taroushirani opened this issue Nov 8, 2020 · 2 comments


taroushirani commented Nov 8, 2020

Hello, I found some numerical instabilities in mdn_loss in mdn.py.

(1) When I tested Conv1dResnetMDN (Conv1dResnet + MDN) [1], the backpropagation of pow in torch.distributions.Normal returned nan [2]. My guess at the mechanism is:

i. Because we clipped the minimum end of log_prob, log_sigma grew larger to prevent the probability from becoming smaller.
ii. scale = exp(log_sigma) overflowed to +inf, and the backpropagation of "var = (self.scale ** 2)" returned nan.

This is fixed by using the centered target instead of the target and clipping it to within +/- 5 SD [3], as you recommended in the MDN PR [4]; a sketch of the fix follows the references below.

  1. https://github.com/taroushirani/nnsvs/blob/f6b04a4a2e1a059b96dc156c1fcbbcc225618f3c/nnsvs/model.py#L216
  2. https://gist.github.com/taroushirani/e6f91ae272b90ca1dcd1e261044a14eb
  3. https://github.com/taroushirani/nnsvs/blob/83e85f030c68c703b151f97fd4da9c1fc31fc854/nnsvs/mdn.py#L81
  4. MDN implementation #20 (comment)
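
Here is a minimal sketch of the fix, with hypothetical function and argument names (not the exact code in my branch):

```python
import torch
from torch.distributions import Normal

def gaussian_log_prob_clipped(mu, log_sigma, target, clip_sd=5.0):
    """Minimal sketch of the centered-target fix (hypothetical names).

    Clipping the centered target to +/- clip_sd standard deviations bounds the
    log-probability from below, so log_sigma no longer has to blow up (and
    scale = exp(log_sigma) no longer overflows) to avoid a tiny probability.
    """
    sigma = torch.exp(log_sigma)
    centered = target - mu
    # Clip the residual to within +/- clip_sd standard deviations.
    centered = torch.max(torch.min(centered, clip_sd * sigma), -clip_sd * sigma)
    # Evaluate a zero-mean Normal at the clipped residual.
    return Normal(torch.zeros_like(mu), sigma).log_prob(centered)
```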

(2) After the change above, logsumexp in mdn_loss still returns nan occasionally.

20%|##        | 10/50 [01:40<06:32,  9.81s/it][W ..\torch\csrc\autograd\python_anomaly_mode.cpp:104] Warning: Error detected in LogsumexpBackward. Traceback of
 forward call that caused the error:
  File "d:\cygwin64\opt\miniconda3\envs\nnsvs\lib\runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "d:\cygwin64\opt\miniconda3\envs\nnsvs\lib\runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "D:\cygwin64\opt\miniconda3\envs\nnsvs\Scripts\nnsvs-train.exe\__main__.py", line 7, in <module>
    sys.exit(entry())
  File "d:\cygwin64\opt\miniconda3\envs\nnsvs\lib\site-packages\nnsvs\bin\train.py", line 275, in entry
    my_app()
  File "d:\cygwin64\opt\miniconda3\envs\nnsvs\lib\site-packages\hydra\main.py",line 37, in decorated_main
    strict=strict,
  File "d:\cygwin64\opt\miniconda3\envs\nnsvs\lib\site-packages\hydra\_internal\utils.py", line 356, in _run_hydra
    lambda: hydra.run(
  File "d:\cygwin64\opt\miniconda3\envs\nnsvs\lib\site-packages\hydra\_internal\utils.py", line 207, in run_and_report
    return func()
  File "d:\cygwin64\opt\miniconda3\envs\nnsvs\lib\site-packages\hydra\_internal\utils.py", line 359, in <lambda>
    overrides=args.overrides,
  File "d:\cygwin64\opt\miniconda3\envs\nnsvs\lib\site-packages\hydra\_internal\hydra.py", line 112, in run
    configure_logging=with_log_configuration,
  File "d:\cygwin64\opt\miniconda3\envs\nnsvs\lib\site-packages\hydra\core\utils.py", line 125, in run_job
    ret.return_value = task_function(task_cfg)
  File "d:\cygwin64\opt\miniconda3\envs\nnsvs\lib\site-packages\nnsvs\bin\train.py", line 271, in my_app
    train_loop(config, device, model, optimizer, lr_scheduler, data_loaders)
  File "d:\cygwin64\opt\miniconda3\envs\nnsvs\lib\site-packages\nnsvs\bin\train.py", line 175, in train_loop
    loss = mdn_loss(pi, sigma, mu, y, reduce=False).masked_select(mask).mean()
  File "d:\cygwin64\opt\miniconda3\envs\nnsvs\lib\site-packages\nnsvs\mdn.py", line 104, in mdn_loss
    loss = -torch.logsumexp(loss, dim=2)
 (function _print_stack)
 20%|##        | 10/50 [01:47<07:09, 10.74s/it]
Traceback (most recent call last):
  File "d:\cygwin64\opt\miniconda3\envs\nnsvs\lib\site-packages\nnsvs\bin\train.py", line 271, in my_app
    train_loop(config, device, model, optimizer, lr_scheduler, data_loaders)
  File "d:\cygwin64\opt\miniconda3\envs\nnsvs\lib\site-packages\nnsvs\bin\train.py", line 199, in train_loop
    loss.backward()
  File "d:\cygwin64\opt\miniconda3\envs\nnsvs\lib\site-packages\torch\tensor.py", line 221, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "d:\cygwin64\opt\miniconda3\envs\nnsvs\lib\site-packages\torch\autograd\__init__.py", line 132, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: Function 'LogsumexpBackward' returned nan values in its 0th output.

I assumed a diagonal covariance for convenience, so mdn_loss sums log_prob along the axis of the target variables (D_out) [5]. Even though each per-dimension log_prob is only moderately negative, their sum over D_out can become extremely negative. The exponential inside logsumexp then underflows to zero, and the backward pass of logsumexp can return nan; the snippet after the reference below reproduces this.

  5. https://github.com/r9y9/nnsvs/blob/6f783d289b8d11d69d954211b8cea83c79b2a49f/nnsvs/mdn.py#L96
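
For reference, here is a tiny reproduction of this failure mode with hypothetical shapes (not the actual training code): once the per-component summed log-probabilities underflow to -inf in float32, the backward pass of logsumexp yields nan.

```python
import torch

# If every component's log-probability summed over D_out underflows to -inf,
# logsumexp's gradient contains exp(x - logsumexp(x)) = exp(-inf - (-inf)) = nan.
summed_log_prob = torch.tensor([[-float("inf"), -float("inf")]], requires_grad=True)
log_pi = torch.log_softmax(torch.zeros(1, 2), dim=1)

loss = -torch.logsumexp(log_pi + summed_log_prob, dim=1).mean()
loss.backward()
print(summed_log_prob.grad)  # tensor([[nan, nan]])
```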

I struggled to solve case (2), but I could not find a good and mathematically correct solution. Please advise me.


r9y9 commented Nov 15, 2020

Sorry for the late reply, and thank you very much for the detailed report as always! I will look into it soon and let you know if I find a fix for the numerical instability.


r9y9 commented Apr 3, 2022

For the record, I encountered the following error:

RuntimeError: Function 'PowBackward0' returned nan values in its 0th output.
File "/home/ryuichi/sp/nnsvs/nnsvs/bin/train.py", line 77, in train_step
loss = mdn_loss(pi, sigma, mu, out_feats, reduce=False)
File "/home/ryuichi/sp/nnsvs/nnsvs/mdn.py", line 122, in mdn_loss
log_prob = dist.log_prob(centered_target)
File "/home/ryuichi/anaconda3/envs/py38/lib/python3.8/site-packages/torch/distributions/normal.py", line 75, in log_prob
var = (self.scale ** 2)
File "/home/ryuichi/anaconda3/envs/py38/lib/python3.8/site-packages/torch/_tensor.py", line 31, in wrapped
return f(*args, **kwargs)
(Triggered internally at  /opt/conda/conda-bld/pytorch_1646755903507/work/torch/csrc/autograd/python_anomaly_mode.cpp:104.)
Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
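
One possible guard (just a sketch, assuming the network predicts log_sigma directly; not a confirmed fix): bound log_sigma before exponentiating so that scale stays finite.

```python
import torch

# Sketch of a guard, assuming the network predicts log_sigma directly:
# clamping log_sigma keeps scale = exp(log_sigma) finite, so the backward
# pass of var = scale ** 2 cannot produce nan from an infinite scale.
def bounded_scale(log_sigma, min_log_sigma=-7.0, max_log_sigma=7.0):
    return torch.exp(torch.clamp(log_sigma, min=min_log_sigma, max=max_log_sigma))
```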

r9y9 added the bug label on Apr 3, 2022