I get the following error:
my_conda_env/python3.11/site-packages/ot/backend.py", line 1822, in __init__
    self.rng_cuda_ = torch.Generator("cuda")
                     ^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: CUDA error: initialization error
I am using pot=0.9.4=py311h14de704_0 and pytorch=2.2.2=py3.11_cuda12.1_cudnn8.9.2_0.
My trace looks like this (inside OT):
  File "lib/python3.11/site-packages/ot/bregman/_sinkhorn.py", line 1037, in sinkhorn_stabilized
    a, b, M = list_to_array(a, b, M)
              ^^^^^^^^^^^^^^^^^^^^^^
  File "python3.11/site-packages/ot/utils.py", line 68, in list_to_array
    nx = get_backend(*lst_not_empty)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "python3.11/site-packages/ot/backend.py", line 222, in get_backend
    return _get_backend_instance(backend_impl)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "python3.11/site-packages/ot/backend.py", line 170, in _get_backend_instance
    _BACKENDS[backend_impl.__name__] = backend_impl()
                                       ^^^^^^^^^^^^^^
  File "python3.11/site-packages/ot/backend.py", line 1822, in __init__
    self.rng_cuda_ = torch.Generator("cuda")
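import ot
import torch

# Small problem placed entirely on the GPU
device = torch.device("cuda")
n1 = 10
rank = 5
p1 = torch.ones(n1, device=device) / n1
Q = torch.rand(n1, rank, device=device)

# This call raises the RuntimeError shown above
Q = ot.bregman.sinkhorn_stabilized(
    p1, torch.ones(rank, device=device) / rank, Q, reg=1e-1
)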
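The last frame of the trace only calls `torch.Generator("cuda")`, so one way to narrow this down might be to try that call directly, without going through POT. Below is a minimal sketch of such a check. The helper name `probe_cuda_generator` and its status strings are hypothetical (they are not part of POT or PyTorch), and it assumes the failure does not depend on anything else POT does before reaching that line.

```python
def probe_cuda_generator():
    """Try the call from the last frame of the trace, outside of POT.

    Returns a short status string instead of raising, so the result can be
    pasted into a report. The function name is illustrative only.
    """
    try:
        import torch
    except ImportError:
        return "torch not installed"

    # If PyTorch cannot see a CUDA device at all, the generator call is moot.
    if not torch.cuda.is_available():
        return "CUDA not available"

    try:
        # Same call as ot/backend.py line 1822 in the trace above.
        torch.Generator("cuda")
    except RuntimeError as exc:
        return f"RuntimeError: {exc}"
    return "ok"


if __name__ == "__main__":
    print(probe_cuda_generator())
```

If this reports a `RuntimeError` when run in the same process and environment as the code sample, the problem is likely in PyTorch's CUDA initialization rather than in POT itself. If it reports `ok`, the context in which `sinkhorn_stabilized` is called would be the next thing to look at.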
POT was installed with conda.
Output of the following code snippet:
>>> import platform; print(platform.platform())
Linux-5.15.153.1-microsoft-standard-WSL2-x86_64-with-glibc2.35
>>> import sys; print("Python", sys.version)
Python 3.11.9 | packaged by conda-forge | (main, Apr 19 2024, 18:36:13) [GCC 12.3.0]
>>> import numpy; print("NumPy", numpy.__version__)
NumPy 1.24.3
>>> import scipy; print("SciPy", scipy.__version__)
SciPy 1.14.0
>>> import ot; print("POT", ot.__version__)
POT 0.9.4