Describe the bug
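`TorchBackend.sqrtm` relies on `torch.linalg.eigh`, whose gradients are undefined when the input matrix has repeated eigenvalues (the PyTorch documentation for `torch.linalg.eigh` explains the issue). As a result, backpropagating through `sqrtm` yields NaN gradients for inputs such as the identity matrix.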
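For context, here is a sketch of why the gradient breaks. It assumes the standard reverse-mode formula for a symmetric eigendecomposition $A = V \operatorname{diag}(\lambda) V^\top$, where $\bar{A}$, $\bar{\lambda}$ and $\bar{V}$ denote the incoming gradients. PyTorch's actual backward may symmetrize the off-diagonal term differently, but it contains the same $1/(\lambda_j - \lambda_i)$ factors:

$$
\bar{A} = V\left(\operatorname{diag}(\bar{\lambda}) + F \circ \left(V^\top \bar{V}\right)\right)V^\top,
\qquad
F_{ij} = \begin{cases} \dfrac{1}{\lambda_j - \lambda_i}, & i \neq j,\\[4pt] 0, & i = j. \end{cases}
$$

With `A = torch.eye(3)` every eigenvalue equals 1, so every off-diagonal $F_{ij}$ is $1/0$. Because the output of `sqrtm` depends on the eigenvectors $V$, $\bar{V}$ is generally nonzero and the backward of `torch.linalg.eigh` produces `inf`/`nan` values.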
To Reproduce
```python
import torch
from ot.backend import TorchBackend

torch.set_default_dtype(torch.float64)
# Raise an error as soon as a backward function returns NaN
torch.autograd.set_detect_anomaly(True)

nx = TorchBackend()
# The identity has a single eigenvalue (1) with multiplicity 3
A = torch.eye(3, dtype=torch.float64, requires_grad=True)
nx.sqrtm(A)[0, 1].backward()
print('OK')  # never reached
```

Output:

```
RuntimeError: Function 'LinalgEighBackward0' returned nan values in its 0th output.
```
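One possible workaround, sketched here under stated assumptions and not part of POT's API, is to compute the square root with the coupled Newton–Schulz iteration. It uses only matrix products, so autograd never differentiates through `torch.linalg.eigh`. The function `sqrtm_newton_schulz` and its `num_iters` parameter are hypothetical names. The sketch assumes a single (non-batched) symmetric positive definite input and a fixed iteration count.

```python
import torch


def sqrtm_newton_schulz(a: torch.Tensor, num_iters: int = 30) -> torch.Tensor:
    """Hypothetical helper: square root of a symmetric positive definite matrix
    via the coupled Newton-Schulz iteration. Only matmuls are used, so the
    backward pass never goes through an eigendecomposition."""
    n = a.shape[-1]
    eye = torch.eye(n, dtype=a.dtype, device=a.device)
    # Scale by the Frobenius norm: for SPD input, the eigenvalues of
    # a / ||a||_F lie in (0, 1], which guarantees convergence.
    norm = torch.linalg.norm(a)
    y = a / norm
    z = eye
    for _ in range(num_iters):
        t = 0.5 * (3.0 * eye - z @ y)
        y = y @ t  # converges to sqrt(a / norm)
        z = t @ z  # converges to the inverse of sqrt(a / norm)
    # Undo the scaling: sqrt(a) = sqrt(norm) * sqrt(a / norm)
    return y * torch.sqrt(norm)


if __name__ == "__main__":
    torch.set_default_dtype(torch.float64)
    torch.autograd.set_detect_anomaly(True)
    # Same input as the reproducer: repeated eigenvalues
    A = torch.eye(3, requires_grad=True)
    sqrtm_newton_schulz(A)[0, 1].backward()
    print(A.grad)  # finite gradient, no anomaly error
```

The iteration converges quadratically once close to the solution. Ill-conditioned matrices (tiny smallest eigenvalue relative to the Frobenius norm) may need more iterations, so a fixed `num_iters` is a trade-off rather than a guarantee of exactness.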
Environment (please complete the following information):
- OS (e.g. MacOS, Windows, Linux): MacOS
- Python version: 9
- How was POT installed (source, pip, conda): pip