Hello, I found some numerical instabilities in mdn_loss in mdn.py.
(1) When I tested Conv1dResnetMDN (Conv1dResnet + MDN)[1], the backpropagation of pow in torch.distributions.Normal returned nan[2]. My guess at the mechanism is:
i. Because we clipped log_prob at its minimum end, log_sigma grew larger to prevent the probability from becoming smaller.
ii. "scale = exp(log_sigma)" overflowed to +inf, and the backpropagation of "var = (self.scale ** 2)" produced nan.
This is fixed by using the centered target instead of the target and clipping it within +/-5 SD[3], as you recommended in the MDN PR[4].
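For illustration only, here is a minimal sketch of that clipping idea (not the actual mdn.py code; the helper name and the assumed shapes of (B, T, G, D_out) for mu/sigma and (B, T, D_out) for the target are mine):

import torch
from torch.distributions import Normal

def component_log_prob(sigma, mu, target, clip_sd=5.0):
    """Per-component Gaussian log-likelihood with the centered target
    clipped to +/- clip_sd standard deviations.

    Assumed shapes: sigma, mu -> (B, T, G, D_out), target -> (B, T, D_out).
    """
    target = target.unsqueeze(2).expand_as(mu)  # (B, T, G, D_out)
    centered = target - mu
    # Clip the centered target so an outlier cannot push log_sigma
    # (and hence scale = exp(log_sigma)) toward +inf during training.
    centered = torch.minimum(torch.maximum(centered, -clip_sd * sigma), clip_sd * sigma)
    # N(mu, sigma).log_prob(target) == N(0, sigma).log_prob(target - mu)
    dist = Normal(loc=torch.zeros_like(mu), scale=sigma)
    return dist.log_prob(centered)  # (B, T, G, D_out)

Clamping the centered target bounds the quadratic term of the Gaussian log-density directly, so there is no need for an explicit floor on log_prob that would drive log_sigma upward.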
(2) After making the change above, logsumexp in mdn_loss still occasionally returns nan.
20%|## | 10/50 [01:40<06:32, 9.81s/it][W ..\torch\csrc\autograd\python_anomaly_mode.cpp:104] Warning: Error detected in LogsumexpBackward. Traceback of
forward call that caused the error:
File "d:\cygwin64\opt\miniconda3\envs\nnsvs\lib\runpy.py", line 193, in _run_module_as_main
"__main__", mod_spec)
File "d:\cygwin64\opt\miniconda3\envs\nnsvs\lib\runpy.py", line 85, in _run_code
exec(code, run_globals)
File "D:\cygwin64\opt\miniconda3\envs\nnsvs\Scripts\nnsvs-train.exe\__main__.py", line 7, in <module>
sys.exit(entry())
File "d:\cygwin64\opt\miniconda3\envs\nnsvs\lib\site-packages\nnsvs\bin\train.py", line 275, in entry
my_app()
File "d:\cygwin64\opt\miniconda3\envs\nnsvs\lib\site-packages\hydra\main.py",line 37, in decorated_main
strict=strict,
File "d:\cygwin64\opt\miniconda3\envs\nnsvs\lib\site-packages\hydra\_internal\utils.py", line 356, in _run_hydra
lambda: hydra.run(
File "d:\cygwin64\opt\miniconda3\envs\nnsvs\lib\site-packages\hydra\_internal\utils.py", line 207, in run_and_report
return func()
File "d:\cygwin64\opt\miniconda3\envs\nnsvs\lib\site-packages\hydra\_internal\utils.py", line 359, in <lambda>
overrides=args.overrides,
File "d:\cygwin64\opt\miniconda3\envs\nnsvs\lib\site-packages\hydra\_internal\hydra.py", line 112, in run
configure_logging=with_log_configuration,
File "d:\cygwin64\opt\miniconda3\envs\nnsvs\lib\site-packages\hydra\core\utils.py", line 125, in run_job
ret.return_value = task_function(task_cfg)
File "d:\cygwin64\opt\miniconda3\envs\nnsvs\lib\site-packages\nnsvs\bin\train.py", line 271, in my_app
train_loop(config, device, model, optimizer, lr_scheduler, data_loaders)
File "d:\cygwin64\opt\miniconda3\envs\nnsvs\lib\site-packages\nnsvs\bin\train.py", line 175, in train_loop
loss = mdn_loss(pi, sigma, mu, y, reduce=False).masked_select(mask).mean()
File "d:\cygwin64\opt\miniconda3\envs\nnsvs\lib\site-packages\nnsvs\mdn.py", line 104, in mdn_loss
loss = -torch.logsumexp(loss, dim=2)
(function _print_stack)
20%|## | 10/50 [01:47<07:09, 10.74s/it]
Traceback (most recent call last):
File "d:\cygwin64\opt\miniconda3\envs\nnsvs\lib\site-packages\nnsvs\bin\train.py", line 271, in my_app
train_loop(config, device, model, optimizer, lr_scheduler, data_loaders)
File "d:\cygwin64\opt\miniconda3\envs\nnsvs\lib\site-packages\nnsvs\bin\train.py", line 199, in train_loop
loss.backward()
File "d:\cygwin64\opt\miniconda3\envs\nnsvs\lib\site-packages\torch\tensor.py", line 221, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph)
File "d:\cygwin64\opt\miniconda3\envs\nnsvs\lib\site-packages\torch\autograd\__init__.py", line 132, in backward
allow_unreachable=True) # allow_unreachable flag
RuntimeError: Function 'LogsumexpBackward' returned nan values in its 0th output.
I assumed a diagonal covariance for convenience, so mdn_loss sums log_prob along the axis of the target variables (D_out)[5]. Even though each per-dimension log_prob is not extreme on its own, their sum can become very large in the negative direction. The exponentials inside logsumexp can then underflow to 0, and logsumexp may return nan.
I struggled to solve case (2), but I could not find any good, mathematically correct solution. Please advise me.
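To illustrate the suspected mechanism, here is a minimal, self-contained demonstration (not the nnsvs code; the shapes and the exaggerated values used to force the float32 overflow are my own). Once every component's summed log-probability reaches -inf, logsumexp returns -inf in the forward pass, and its backward pass produces nan because the gradient involves exp(x - logsumexp(x)) = exp(-inf - (-inf)):

import torch

# Assumed shapes: log_prob (B, T, G, D_out) = (1, 1, 2, 60), log_pi (B, T, G)
log_prob = torch.full((1, 1, 2, 60), -1e37, requires_grad=True)
log_pi = torch.log(torch.full((1, 1, 2), 0.5))

# Summing 60 values of -1e37 exceeds the float32 range (~3.4e38) and
# overflows to -inf for every mixture component.
component_ll = log_pi + log_prob.sum(dim=3)          # all -inf
loss = -torch.logsumexp(component_ll, dim=2).mean()  # forward value: +inf
loss.backward()

print(component_ll)                      # tensor([[[-inf, -inf]]], ...)
print(torch.isnan(log_prob.grad).all())  # tensor(True)

If anomaly detection is enabled, this should be reported as LogsumexpBackward returning nan, as in the traceback above.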
Sorry for the late reply, and thank you very much for the detailed report as always! I will look into it soon and let you know if I find a fix for the numerical instability.
For the record, I encountered the following error:
RuntimeError: Function 'PowBackward0' returned nan values in its 0th output.
File "/home/ryuichi/sp/nnsvs/nnsvs/bin/train.py", line 77, in train_step
loss = mdn_loss(pi, sigma, mu, out_feats, reduce=False)
File "/home/ryuichi/sp/nnsvs/nnsvs/mdn.py", line 122, in mdn_loss
log_prob = dist.log_prob(centered_target)
File "/home/ryuichi/anaconda3/envs/py38/lib/python3.8/site-packages/torch/distributions/normal.py", line 75, in log_prob
var = (self.scale ** 2)
File "/home/ryuichi/anaconda3/envs/py38/lib/python3.8/site-packages/torch/_tensor.py", line 31, in wrapped
return f(*args, **kwargs)
(Triggered internally at /opt/conda/conda-bld/pytorch_1646755903507/work/torch/csrc/autograd/python_anomaly_mode.cpp:104.)
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
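For reference, the PowBackward0 failure above can be reproduced in isolation with a toy example (exaggerated values; not the library or nnsvs code): once scale overflows to +inf, the gradient flowing into var = scale ** 2 from the -(x - mu)^2 / (2 * var) term is 0, while the local derivative 2 * scale is +inf, and 0 * inf gives nan.

import torch
from torch.distributions import Normal

# exp(100) overflows float32 to +inf, standing in for a log_sigma that was
# pushed up during training.
log_sigma = torch.tensor([100.0], requires_grad=True)
scale = torch.exp(log_sigma)             # tensor([inf])
dist = Normal(loc=torch.zeros(1), scale=scale)
log_prob = dist.log_prob(torch.ones(1))  # internally: var = (self.scale ** 2)
log_prob.sum().backward()
print(log_sigma.grad)                    # tensor([nan])

With torch.autograd.set_detect_anomaly(True), the nan should be attributed to PowBackward0, matching the error above.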