Closed
Description
Describe the bug
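Not all tests in test/test_gpu.py pass: test_gpu_sinkhorn fails.
The transport matrix computed by ot.sinkhorn (ot.bregman) differs from the one computed by ot.gpu.sinkhorn (ot.gpu.bregman). The maximum relative difference is about 2e-5, well above the rtol=1e-10 the test requires.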
To Reproduce
Steps to reproduce the behavior:
- Run test/test_gpu.py
=================================== FAILURES ===================================
______________________________ test_gpu_sinkhorn _______________________________
    @pytest.mark.skipif(nogpu, reason="No GPU available")
    def test_gpu_sinkhorn():
        rng = np.random.RandomState(0)
        for n_samples in [50, 100, 500, 1000]:
            a = rng.rand(n_samples // 4, 100)
            b = rng.rand(n_samples, 100)
            wa = ot.unif(n_samples // 4)
            wb = ot.unif(n_samples)
            wb2 = np.random.rand(n_samples, 20)
            wb2 /= wb2.sum(0, keepdims=True)
            M = ot.dist(a.copy(), b.copy())
            M2 = ot.gpu.dist(a.copy(), b.copy(), to_numpy=False)
            reg = 1
            G = ot.sinkhorn(wa, wb, M, reg)
            G1 = ot.gpu.sinkhorn(wa, wb, M, reg)
>           np.testing.assert_allclose(G1, G, rtol=1e-10)
E           AssertionError:
E           Not equal to tolerance rtol=1e-10, atol=0
E
E           Mismatched elements: 600 / 600 (100%)
E           Max absolute difference: 1.37138433e-07
E           Max relative difference: 2.00548806e-05
E            x: array([[5.888717e-04, 8.415569e-05, 8.951892e-05, 2.190684e-03,
E                   1.977490e-05, 9.029307e-03, 2.036359e-04, 2.300168e-03,
E                   2.933039e-04, 1.124758e-04, 3.394681e-04, 1.416449e-03,...
E            y: array([[5.888707e-04, 8.415610e-05, 8.951908e-05, 2.190691e-03,
E                   1.977479e-05, 9.029328e-03, 2.036355e-04, 2.300179e-03,
E                   2.933086e-04, 1.124764e-04, 3.394675e-04, 1.416444e-03,...
test/test_gpu.py:73: AssertionError
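The figures in the assertion message can be recomputed by hand, which may help when judging how far the GPU result is from the CPU one. The sketch below is a hypothetical helper (`report_discrepancy` is not part of POT or NumPy) that mirrors the quantities `np.testing.assert_allclose` reports. It uses only NumPy and made-up arrays, not the actual `G` and `G1` from the failing test.

```python
import numpy as np


def report_discrepancy(actual, desired, rtol=1e-10, atol=0.0):
    """Summarize how `actual` differs from `desired` (hypothetical helper)."""
    actual = np.asarray(actual, dtype=np.float64)
    desired = np.asarray(desired, dtype=np.float64)
    abs_diff = np.abs(actual - desired)
    # Same element-wise criterion that np.testing.assert_allclose applies.
    mismatched = abs_diff > atol + rtol * np.abs(desired)
    # Relative difference is measured against `desired`, skipping zeros.
    nonzero = desired != 0
    if nonzero.any():
        max_rel = float(np.max(abs_diff[nonzero] / np.abs(desired[nonzero])))
    else:
        max_rel = 0.0
    return {
        "mismatched": int(mismatched.sum()),
        "total": int(desired.size),
        "max_abs": float(abs_diff.max()),
        "max_rel": max_rel,
    }


# Made-up values for illustration only, not the matrices from the failing test.
desired = np.array([1.0, 2.0, 4.0])
actual = np.array([1.0, 2.0 + 2e-5, 4.0])
summary = report_discrepancy(actual, desired)
print(summary)
```

Calling `report_discrepancy(G1, G)` on the real matrices (after converting `G1` to a NumPy array if needed) should give the same counts as the traceback above.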
Screenshots
Code sample
from test import test_gpu
test_gpu.test_gpu_sinkhorn()
Expected behavior
GPU unit tests should pass.
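To illustrate how strict `rtol=1e-10` is, here is a minimal NumPy-only sketch that runs the same fixed number of Sinkhorn scaling iterations once in float64 and once in float32, then compares the results. The function `sinkhorn_fixed_iters` is hypothetical and is not POT's implementation. The example makes no claim that single precision is the cause of the failure reported here; it only shows that agreement at `1e-10` requires both code paths to carry out essentially the same double-precision arithmetic.

```python
import numpy as np


def sinkhorn_fixed_iters(a, b, M, reg, n_iter=200, dtype=np.float64):
    """Plain Sinkhorn-Knopp scaling with a fixed iteration count (illustrative only)."""
    a = a.astype(dtype)
    b = b.astype(dtype)
    # Gibbs kernel, computed entirely in the requested precision.
    K = np.exp(-M.astype(dtype) / dtype(reg))
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(n_iter):
        v = b / (K.T @ u)
        u = a / (K @ v)
    # Transport plan diag(u) K diag(v).
    return u[:, None] * K * v[None, :]


rng = np.random.RandomState(0)
xs = rng.rand(12, 100)  # same shapes as the n_samples = 50 case in the test
xt = rng.rand(50, 100)
M = ((xs[:, None, :] - xt[None, :, :]) ** 2).sum(axis=-1)  # squared Euclidean cost
a = np.full(12, 1.0 / 12)
b = np.full(50, 1.0 / 50)

G64 = sinkhorn_fixed_iters(a, b, M, reg=1, dtype=np.float64)
G32 = sinkhorn_fixed_iters(a, b, M, reg=1, dtype=np.float32)

max_rel = float(np.max(np.abs(G32.astype(np.float64) - G64) / np.abs(G64)))
print("max relative difference, float32 vs float64:", max_rel)
```

If the two backends are only expected to agree up to rounding and stopping-criterion effects, loosening the tolerance in the test is one option; if bit-level agreement in double precision is intended, the discrepancy points to a difference in the computation itself.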
Environment (please complete the following information):
- OS (e.g. MacOS, Windows, Linux): Linux
- Python version: 3.8
- How was POT installed (source, pip, conda): source
- Build command you used (if compiling from source): make buildext ; make install
- Only for GPU related bugs:
- CUDA version: 10.2
- GPU models and configuration: Happens on NVIDIA V100 and NVIDIA A100
- Any other relevant information: CuPy version is 9.0.0
Output of the following code snippet:
import platform; print(platform.platform())
import sys; print("Python", sys.version)
import numpy; print("NumPy", numpy.__version__)
import scipy; print("SciPy", scipy.__version__)
import ot; print("POT", ot.__version__)
Linux-4.18.0-147.48.1.el8_1.x86_64-x86_64-with-redhat-8.1-Ootpa
Python 3.7.11 (default, Jul 27 2021, 14:32:16)
[GCC 7.5.0]
NumPy 1.21.2
SciPy 1.7.1
POT 0.8.0dev