Description
Describe the bug
Running the example "plot_sliced_wass_grad_flow_pytorch.py" raises a PyTorch RuntimeError because the random generator and the data are on different devices.
To Reproduce
Steps to reproduce the behavior:
- From the POT source folder, navigate to examples/backends
- Run: python plot_sliced_wass_grad_flow_pytorch.py
Terminal output (with edited paths)
2022-05-05 11:08:24.082850: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
Traceback (most recent call last):
  File "POT/examples/backends/plot_sliced_wass_grad_flow_pytorch.py", line 82, in <module>
    loss = ot.sliced_wasserstein_distance(x1_torch, x2_torch, n_projections=20, seed=gen)
  File "POT/ot/sliced.py", line 149, in sliced_wasserstein_distance
    projections = get_random_projections(d, n_projections, seed, backend=nx, type_as=X_s)
  File "POT/ot/sliced.py", line 58, in get_random_projections
    projections = nx.randn(d, n_projections, type_as=type_as)
  File "POT/ot/backend.py", line 1777, in randn
    return torch.randn(size=size, dtype=type_as.dtype, generator=self.rng_, device=type_as.device)
RuntimeError: Expected a 'cuda' device type for generator but found 'cpu'
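The last traceback frame shows torch.randn being called with device=type_as.device (a CUDA tensor here) together with the backend's generator self.rng_, which, according to the error message, lives on the CPU. The following minimal sketch isolates that pairing outside POT. It assumes a CUDA-capable machine, and the function name is hypothetical. On a CPU-only machine it simply returns None, because the mismatch cannot be triggered there.

```python
import torch


def reproduce_generator_mismatch():
    """Pair a CPU generator with a CUDA output device, as in the last traceback frame.

    Hypothetical helper for isolating the error. Returns the RuntimeError message,
    "" if this PyTorch version accepts the combination, or None without a GPU.
    """
    if not torch.cuda.is_available():
        return None
    cpu_gen = torch.Generator()  # default device is 'cpu'
    type_as = torch.zeros(1, device="cuda")  # stands in for the CUDA input X_s
    try:
        torch.randn(
            size=(2, 20),
            dtype=type_as.dtype,
            generator=cpu_gen,  # CPU generator ...
            device=type_as.device,  # ... but CUDA output device
        )
    except RuntimeError as err:
        return str(err)
    return ""


print(reproduce_generator_mismatch())
```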
Expected behavior
The script should run entirely on the GPU and never expect CPU data, since torch.cuda.is_available() returns True in this case.
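One way to meet this expectation is to create the generator on the same device as the reference tensor, so the generator and the output device can never disagree. Below is a minimal sketch that assumes this approach. randn_on_device is a hypothetical helper, not part of POT or PyTorch, and this is not necessarily how POT resolves the issue.

```python
import torch


def randn_on_device(*size, type_as, seed=None):
    """Standard-normal samples on the same device and dtype as type_as.

    Hypothetical workaround sketch (not POT's API): the generator is built on
    type_as.device, so generator and output device always agree.
    """
    gen = torch.Generator(device=type_as.device)
    if seed is None:
        gen.seed()  # non-deterministic seed
    else:
        gen.manual_seed(seed)
    return torch.randn(*size, generator=gen, dtype=type_as.dtype, device=type_as.device)


# Works on CPU-only machines as well as on a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.zeros(5, 2, device=device)
projections = randn_on_device(2, 20, type_as=x, seed=0)
print(projections.shape, projections.device)
```

With a fixed seed, two calls on the same device return identical samples. CPU and CUDA generators produce different streams for the same seed, so results are reproducible per device, not across devices.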
Environment (please complete the following information):
- OS (e.g. MacOS, Windows, Linux): Linux
- Python version: 3.9.12
- How was POT installed (source, pip, conda): source
- Build command you used (if compiling from source): python setup.py build_ext --inplace
- Only for GPU-related bugs:
  - CUDA version: 10.1.24
  - GPU models and configuration: RTX 2070MQ
  - Any other relevant information: N/A
Output of the following code snippet:
import platform; print(platform.platform())
Linux-5.13.0-40-generic-x86_64-with-glibc2.31
import sys; print("Python", sys.version)
Python 3.9.12 (main, Apr 5 2022, 06:56:58)
[GCC 7.5.0]
import numpy; print("NumPy", numpy.__version__)
NumPy 1.22.3
import scipy; print("SciPy", scipy.__version__)
SciPy 1.8.0
import ot; print("POT", ot.__version__)
2022-05-05 11:25:05.637911: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
POT 0.8.3dev
import torch;print("torch", torch.__version__)
torch 1.11.0+cu102
(yes, my CUDA version is old as dirt, but this should be irrelevant)
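The version output above does not state which CUDA runtime the installed PyTorch build targets, although the +cu102 suffix hints at it. The following sketch reports that value next to GPU availability. The dictionary keys are arbitrary names chosen for illustration.

```python
import torch


def cuda_diagnostics():
    """Collect the PyTorch/CUDA facts relevant to a GPU bug report."""
    info = {
        "torch": torch.__version__,
        "built_with_cuda": torch.version.cuda,  # None for CPU-only builds
        "cuda_available": torch.cuda.is_available(),
    }
    if info["cuda_available"]:
        info["device_name"] = torch.cuda.get_device_name(0)
    return info


for key, value in cuda_diagnostics().items():
    print(f"{key}: {value}")
```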
Additional Context
I prepare my conda env as follows:
conda create -n ot_dev
conda activate ot_dev
conda install pip
pip install -r requirements.txt
cd docs/
pip install -r requirements.txt