Unstable TorchBackend.sqrtm() around repeated eigenvalues #773

@samuelbx

Describe the bug

TorchBackend.sqrtm relies on torch.linalg.eigh, whose gradients are undefined when the input matrix has repeated eigenvalues (the PyTorch documentation for torch.linalg.eigh describes this limitation).
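
For context, here is a sketch of why the backward pass breaks down. It uses the standard reverse-mode formula for a symmetric eigendecomposition; the notation (bars for incoming gradients, F for the coefficient matrix) is my own, and PyTorch's actual implementation may differ in symmetrization details.

```latex
A = V \Lambda V^{\top}, \qquad
\bar{A} = V \left( \operatorname{diag}(\bar{\lambda}) + F \circ (V^{\top} \bar{V}) \right) V^{\top}, \qquad
F_{ij} =
\begin{cases}
  \dfrac{1}{\lambda_j - \lambda_i}, & i \neq j, \\[1ex]
  0, & i = j.
\end{cases}
```

When two eigenvalues coincide, the corresponding entries of F involve a division by zero. For the identity matrix every off-diagonal entry is affected, which is consistent with the nan reported below.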

To Reproduce

import torch
from ot.backend import TorchBackend

torch.set_default_dtype(torch.float64)
torch.autograd.set_detect_anomaly(True)

nx = TorchBackend()
A = torch.eye(3, dtype=torch.float64, requires_grad=True)
nx.sqrtm(A)[0, 1].backward()
print('OK')

Output: RuntimeError: Function 'LinalgEighBackward0' returned nan values in its 0th output.
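
One possible workaround, sketched here under stated assumptions and not as POT's actual fix, is to keep eigh in the forward pass only and supply a hand-written backward. For symmetric positive-definite A with S = sqrt(A), differentiating S S = A gives S dS + dS S = dA, so the gradient with respect to A solves a Sylvester equation that only divides by sqrt(lambda_i) + sqrt(lambda_j). That sum stays nonzero for repeated eigenvalues as long as A is strictly positive definite. The names SymSqrtm and sym_sqrtm are hypothetical.

```python
import torch


class SymSqrtm(torch.autograd.Function):
    """Square root of a symmetric positive-definite matrix (hypothetical helper).

    Assumes A is symmetric with strictly positive eigenvalues.
    """

    @staticmethod
    def forward(ctx, A):
        # eigh is used only here, so autograd never calls its backward.
        w, V = torch.linalg.eigh(A)
        s = w.clamp_min(0).sqrt()
        ctx.save_for_backward(s, V)
        return (V * s) @ V.transpose(-1, -2)

    @staticmethod
    def backward(ctx, grad_out):
        s, V = ctx.saved_tensors
        Vt = V.transpose(-1, -2)
        # Only the symmetric part of the incoming gradient matters.
        G = 0.5 * (grad_out + grad_out.transpose(-1, -2))
        # Solve S X + X S = G in the eigenbasis of A:
        # elementwise division by s_i + s_j instead of lambda_j - lambda_i.
        denom = s.unsqueeze(-1) + s.unsqueeze(-2)
        return V @ ((Vt @ G @ V) / denom) @ Vt


sym_sqrtm = SymSqrtm.apply

# Same input as the reproduction above, routed through the sketch.
A = torch.eye(3, dtype=torch.float64, requires_grad=True)
sym_sqrtm(A)[0, 1].backward()
print(torch.isfinite(A.grad).all().item())
```

The returned gradient is symmetric, which matches treating A as a symmetric input. The sketch still fails for singular A, where sqrt(lambda_i) + sqrt(lambda_j) can be zero.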

Environment (please complete the following information):

  • OS (e.g. MacOS, Windows, Linux): MacOS
  • Python version: 9
  • How was POT installed (source, pip, conda): pip
