Torch device error in example "plot_sliced_wass_grad_flow_pytorch.py" #371
Comments
Hi, thanks for your report, I can reproduce your issue. It comes from the fact that `torch.randn` does not accept a `device` different from the one where the generator is located. It went unspotted because on GitHub POT does not have access to a GPU, so the examples were computed on CPU. From what I understand of https://pytorch.org/docs/stable/generated/torch.randn.html, this seems to be a PyTorch bug; I don't think this behaviour was intended. I can see two ways of fixing it: we could either generate the samples on CPU and move them to the target device, or create a generator directly on the target device.
The first solution is more straightforward, but it is a bit more time-consuming. See the benchmark below (done on PyTorch 1.10 with a V100):

```python
rng = torch.Generator("cpu")
rng.seed()
rng_gpu = torch.Generator("cuda")
rng_gpu.seed()

%timeit torch.randn(100000, generator=rng).to("cuda")
%timeit torch.randn(100000, generator=rng_gpu, device="cuda")
```
So the second solution is more efficient, but requires more code. What do you think @rflamary?
I think that we need fast generators in a lot of potential applications, so I'm sorry @ncassereau-idris but I prefer the second ;). But in that case it means that you need one generator per device, because if you have two GPUs then it will still have a problem, no?
No, actually; I just tested it with 2 V100s. As for the fix, I will suggest a PR tomorrow, or this afternoon if I have time.
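The per-device approach discussed above can be sketched as follows. This is only a minimal illustration, not POT's actual fix, and the helper name `randn_on` is made up for this example: it builds the generator on the same device as the requested output, so `torch.randn` never mixes a CPU generator with a CUDA tensor.

```python
import torch

def randn_on(shape, device, seed=None):
    # Hypothetical helper (not part of POT): create the generator on the
    # target device so torch.randn(..., generator=..., device=...) is
    # consistent, avoiding the device-mismatch RuntimeError.
    gen = torch.Generator(device=device)
    if seed is not None:
        gen.manual_seed(seed)
    else:
        gen.seed()
    return torch.randn(shape, generator=gen, device=device)

# Falls back to CPU when no GPU is available (as on the CI runners).
device = "cuda" if torch.cuda.is_available() else "cpu"
x = randn_on((3,), device, seed=0)
```

On a multi-GPU machine this pattern naturally gives one generator per device, since each call builds its generator on the device it samples on.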
Great, thanks for checking, it was not obvious on my side.
Describe the bug
Running the "plot_sliced_wass_grad_flow_pytorch.py" example raises a torch device-related `RuntimeError`.
To Reproduce
Steps to reproduce the behavior:

```shell
cd examples/backends
python plot_sliced_wass_grad_flow_pytorch.py
```
Terminal output (with edited paths)
Expected behavior
The script should run entirely on GPU and never expect CPU data, since `torch.cuda.is_available() == True` in this case.

Environment (please complete the following information):
How was POT installed (source, `pip`, `conda`): source (`python setup.py build_ext --inplace`)
Output of the following code snippet:
(yes, my CUDA version is old as dirt, but this should be irrelevant)
Additional Context
I prepare my conda env as follows:
```shell
conda create -n ot_dev
conda activate ot_dev
conda install pip
pip install -r requirements.txt
cd docs/
pip install -r requirements.txt
```