
Torch device error in example "plot_sliced_wass_grad_flow_pytorch.py" #371

@eloitanguy

Description


Describe the bug

Running the example "plot_sliced_wass_grad_flow_pytorch.py" raises a device-related torch RuntimeError.

To Reproduce

Steps to reproduce the behavior:

  1. From the POT source folder, navigate to examples/backends
  2. Run python plot_sliced_wass_grad_flow_pytorch.py

Terminal output (with edited paths)

2022-05-05 11:08:24.082850: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
Traceback (most recent call last):
  File "POT/examples/backends/plot_sliced_wass_grad_flow_pytorch.py", line 82, in <module>
    loss = ot.sliced_wasserstein_distance(x1_torch, x2_torch, n_projections=20, seed=gen)
  File "POT/ot/sliced.py", line 149, in sliced_wasserstein_distance
    projections = get_random_projections(d, n_projections, seed, backend=nx, type_as=X_s)
  File "POT/ot/sliced.py", line 58, in get_random_projections
    projections = nx.randn(d, n_projections, type_as=type_as)
  File "POT/ot/backend.py", line 1777, in randn
    return torch.randn(size=size, dtype=type_as.dtype, generator=self.rng_, device=type_as.device)
RuntimeError: Expected a 'cuda' device type for generator but found 'cpu'
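The failing call can likely be reproduced with plain torch, without POT. This is a minimal sketch, assuming the generator that ends up in `self.rng_` was created with `torch.Generator()` and no `device` argument (the example's source is not shown here). `randn_with_generator` is a hypothetical helper that mirrors the `torch.randn(..., generator=..., device=...)` call in the last traceback frame. A generator created this way lives on the CPU, so sampling onto `cuda` with it should fail when a GPU is present.

```python
import torch


def randn_with_generator(rows, cols, generator, device):
    # Hypothetical helper mirroring the last traceback frame:
    # torch.randn(size=..., generator=..., device=...)
    return torch.randn(size=(rows, cols), generator=generator, device=device)


# Without a device argument, torch.Generator() is created on the CPU.
cpu_gen = torch.Generator()
cpu_gen.manual_seed(0)
print(cpu_gen.device)  # cpu

# Generator and output device agree, so this call works.
x_cpu = randn_with_generator(3, 2, cpu_gen, "cpu")

if torch.cuda.is_available():
    try:
        # Generator on the CPU, output on CUDA: the same mismatch as in the report.
        randn_with_generator(3, 2, cpu_gen, "cuda")
    except RuntimeError as err:
        print("RuntimeError:", err)
```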

Expected behavior

The script should run entirely on the GPU and never expect CPU data, since torch.cuda.is_available() returns True in this case.
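One way to get this behavior might be to create the generator on the same device as the input tensors. Below is a minimal sketch in plain torch. It assumes the example's generator is what ends up in `self.rng_`, which the traceback suggests but does not confirm. `make_generator` is a hypothetical helper, not a POT or torch API.

```python
import torch


def make_generator(device, seed=0):
    # Hypothetical helper: build the generator on the target device, then seed it.
    gen = torch.Generator(device=device)
    gen.manual_seed(seed)
    return gen


# Use the GPU when available, as the report expects.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x1 = torch.randn(100, 2, device=device)
gen = make_generator(x1.device, seed=42)

# Same call shape as the last traceback frame; generator and output device now agree.
d, n_projections = x1.shape[1], 20
projections = torch.randn(
    size=(d, n_projections), dtype=x1.dtype, generator=gen, device=x1.device
)
print(projections.shape)  # torch.Size([2, 20])
```

Whether passing such a generator as `seed=gen` to `ot.sliced_wasserstein_distance` is enough on the POT side is not verified by this sketch.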

Environment (please complete the following information):

  • OS (e.g. MacOS, Windows, Linux): Linux
  • Python version: 3.9.12
  • How was POT installed (source, pip, conda): source
  • Build command you used (if compiling from source): python setup.py build_ext --inplace
  • Only for GPU related bugs:
    • CUDA version: 10.1.24
    • GPU models and configuration: RTX 2070MQ
    • Any other relevant information: N/A

Output of the following code snippet:

import platform; print(platform.platform())

Linux-5.13.0-40-generic-x86_64-with-glibc2.31

import sys; print("Python", sys.version)

Python 3.9.12 (main, Apr 5 2022, 06:56:58)
[GCC 7.5.0]

import numpy; print("NumPy", numpy.__version__)

NumPy 1.22.3

import scipy; print("SciPy", scipy.__version__)

SciPy 1.8.0

import ot; print("POT", ot.__version__)

2022-05-05 11:25:05.637911: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
POT 0.8.3dev

import torch;print("torch", torch.__version__)

torch 1.11.0+cu102

(yes, my CUDA version is old as dirt, but this should be irrelevant)

Additional Context

I prepare my conda env as follows:

conda create -n ot_dev
conda activate ot_dev
conda install pip
pip install -r requirements.txt
cd docs/
pip install -r requirements.txt
