Closed
Description
Describe the bug
ot.gpu.da.sinkhorn_lpl1_mm crashes with the error "TypeError: Expected list, got tuple".
To Reproduce
Run test_gpu_sinkhorn_lpl1() in the file https://github.com/PythonOT/POT/blob/master/test/test_gpu.py
Output:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
/tmp/ipykernel_162490/3474066936.py in <module>
----> 1 test_gpu_sinkhorn_lpl1()
/tmp/ipykernel_162490/3323216572.py in test_gpu_sinkhorn_lpl1()
18
19 G = ot.da.sinkhorn_lpl1_mm(wa, labels_a, wb, M, reg)
---> 20 G1 = ot.gpu.da.sinkhorn_lpl1_mm(wa, labels_a, wb, M, reg)
21
22 np.testing.assert_allclose(G1, G, rtol=1e-10)
~/.cache/pypoetry/virtualenvs/venvproject-py3.7/lib/python3.7/site-packages/ot/gpu/da.py in sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta, numItermax, numInnerItermax, stopInnerThr, verbose, log, to_numpy)
121 classes = npp.unique(labels_a2)
122 for c in classes:
--> 123 idxc, = utils.to_gpu(npp.where(labels_a2 == c))
124 indices_labels.append(idxc)
125
~/.cache/pypoetry/virtualenvs/venvproject-py3.7/lib/python3.7/site-packages/ot/gpu/utils.py in to_gpu(*args)
91 return (cp.asarray(x) for x in args)
92 else:
---> 93 return cp.asarray(args[0])
94
95
~/.cache/pypoetry/virtualenvs/venvproject-py3.7/lib/python3.7/site-packages/cupy/_creation/from_data.py in asarray(a, dtype, order)
64
65 """
---> 66 return _core.array(a, dtype, False, order)
67
68
cupy/_core/core.pyx in cupy._core.core.array()
cupy/_core/core.pyx in cupy._core.core.array()
TypeError: Expected list, got tuple
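The traceback shows the failing pattern: npp.where(labels_a2 == c) returns a 1-tuple of index arrays, and utils.to_gpu, when given a single argument, passes it unchanged to cp.asarray, which CuPy 9.4.0 rejects. Below is a minimal NumPy-only sketch of that pattern and one possible workaround (unpack the tuple before converting). to_device, strict_asarray and class_indices are hypothetical stand-ins, not POT or CuPy APIs; strict_asarray simply models the error above by rejecting a tuple of ndarrays, since the exact condition inside CuPy is not shown in this report.

```python
import numpy as np


def strict_asarray(x):
    # Hypothetical stand-in for the cp.asarray call that fails in the
    # traceback: it rejects a tuple of ndarrays with the same message.
    if isinstance(x, tuple) and all(isinstance(e, np.ndarray) for e in x):
        raise TypeError("Expected list, got tuple")
    return np.asarray(x)


def to_device(*args, asarray=strict_asarray):
    # Hypothetical mirror of the branch in ot/gpu/utils.py shown above:
    # several arguments -> generator of arrays, one argument -> one array.
    if len(args) > 1:
        return (asarray(x) for x in args)
    return asarray(args[0])


def class_indices(labels, asarray=strict_asarray):
    """Return one index array per class label (workaround pattern)."""
    indices = []
    for c in np.unique(labels):
        # np.where on a 1-D mask returns a 1-tuple; unpack it before
        # converting, so the converter receives an ndarray, not a tuple.
        idxc, = np.where(labels == c)
        indices.append(to_device(idxc, asarray=asarray))
    return indices


labels = np.array([0, 1, 0, 2, 1])

# Original pattern: the tuple itself reaches the converter and is rejected.
try:
    idxc, = to_device(np.where(labels == 0))
except TypeError as err:
    print("original pattern:", err)

# Workaround: unpack first, then convert each index array.
print([idx.tolist() for idx in class_indices(labels)])
```

In ot/gpu/da.py the analogous change would be to unpack the result of npp.where(...) before calling utils.to_gpu (or to call utils.to_gpu(*npp.where(...))); this is a suggested workaround, not necessarily how the issue was fixed upstream.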
Environment (please complete the following information):
- OS (e.g. MacOS, Windows, Linux): Ubuntu 20.10
- Python version: 3.7.12
- How was POT installed (source, pip, conda): poetry (uses pip)
- Build command you used (if compiling from source):
- Only for GPU related bugs:
- CUDA version: 11.1
- GPU models and configuration: Nvidia RTX 3090
- Any other relevant information: POT 0.7.0, Cupy 9.4.0 (cupy-111)
Output of the following code snippet:
import platform; print(platform.platform())
import sys; print("Python", sys.version)
import numpy; print("NumPy", numpy.__version__)
import scipy; print("SciPy", scipy.__version__)
import ot; print("POT", ot.__version__)
Linux-5.8.0-63-generic-x86_64-with-debian-bullseye-sid
Python 3.7.12 (default, Sep 15 2021, 10:55:12)
[GCC 10.3.0]
NumPy 1.21.1
SciPy 1.6.1
POT 0.7.0