Closed
Describe the bug
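emd2_1d errors when used with a distance metric that has no sped-up 1D code path, e.g. cosine or yule. (The emd_1d docstring lists sqeuclidean, cityblock and euclidean as the faster metrics.) The reproduction below calls emd_1d.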
To Reproduce
Steps to reproduce the behavior:
A simple test case adapted from the 1D example code:
import numpy as np
import matplotlib.pylab as pl
import ot
import ot.plot
from ot.datasets import make_1D_gauss as gauss
##############################################################################
# Generate data
# -------------
#%% parameters
n = 100 # nb bins
# bin positions
x = np.arange(n, dtype=np.float64)
# Gaussian distributions
a = gauss(n, m=20, s=5) # m= mean, s= std
b = gauss(n, m=60, s=10)
# use fast 1D solver
G0 = ot.emd_1d(x, x, a, b, metric="cosine")
Running this raises the following traceback:
54 G0 = ot.emd_1d(x, x, a, b, metric="cosine")
55
56 # Equivalent to
~/miniconda3/envs/ms-gen/lib/python3.8/site-packages/ot/lp/solver_1d.py in emd_1d(x_a, x_b, a, b, metric, p, dense, log, check_marginals)
257 perm_b = nx.argsort(x_b_1d)
258
--> 259 G_sorted, indices, cost = emd_1d_sorted(
260 nx.to_numpy(a[perm_a]).astype(np.float64),
261 nx.to_numpy(b[perm_b]).astype(np.float64),
ot/lp/emd_wrap.pyx in ot.lp.emd_wrap.emd_1d_sorted()
AttributeError: 'float' object has no attribute 'reshape'
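A guess at the cause, based only on the traceback: for metrics without a hard-coded fast path, emd_1d_sorted appears to compute each pairwise cost by calling .reshape on a single sorted position. Inside the compiled wrapper that position is presumably a plain Python float rather than a NumPy scalar, so the call fails. The NumPy-only snippet below illustrates that distinction; the variable names are illustrative and are not taken from POT's source.

```python
import numpy as np

x = np.arange(5, dtype=np.float64)

# Indexing a NumPy array from Python returns a NumPy scalar, which has .reshape.
numpy_scalar = x[2]
print(type(numpy_scalar).__name__, numpy_scalar.reshape((1, 1)).shape)  # float64 (1, 1)

# A plain Python float (assumed to be what the compiled loop ends up calling
# .reshape on) has no such method, matching the error in the traceback.
plain_float = float(x[2])
try:
    plain_float.reshape((1, 1))
except AttributeError as err:
    print("AttributeError:", err)  # AttributeError: 'float' object has no attribute 'reshape'

# Hypothetical fix: build the (1, 1) input explicitly instead of reshaping a scalar.
u_2d = np.array([[plain_float]])
print(u_2d.shape)  # (1, 1)
```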
Expected behavior
Should return the transport plan (or, for emd2_1d, the cost) instead of raising an error. I can't yet tell whether the result would be mathematically correct.
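As a possible stopgap, the sorted 1D algorithm can be sketched in pure NumPy with the cost passed as a scalar callable, which avoids any per-element .reshape. emd_1d_monotone is a hypothetical helper, not part of POT, and it assumes a and b have equal total mass. It returns the cost of the monotone (north-west corner) coupling, which is the optimal transport cost only for costs of the form h(u - v) with h convex, such as squared distance or |u - v|; for other metrics it is merely a feasible coupling.

```python
import numpy as np


def emd_1d_monotone(x_a, x_b, a, b, cost):
    """Hypothetical helper (not part of POT): cost of the monotone 1D coupling.

    `cost(u, v)` receives two plain Python floats, so no .reshape is involved.
    Assumes sum(a) == sum(b). The result is the optimal transport cost only
    when the cost has the form h(u - v) with h convex.
    """
    ia, ib = np.argsort(x_a), np.argsort(x_b)
    u = np.asarray(x_a, dtype=np.float64)[ia]
    v = np.asarray(x_b, dtype=np.float64)[ib]
    # Fancy indexing returns copies, so the caller's weights are not modified.
    wa = np.asarray(a, dtype=np.float64)[ia]  # remaining mass at each u
    wb = np.asarray(b, dtype=np.float64)[ib]  # remaining mass at each v

    i = j = 0
    total = 0.0
    while i < len(wa) and j < len(wb):
        mass = min(wa[i], wb[j])  # move as much as possible from u[i] to v[j]
        total += mass * cost(float(u[i]), float(v[j]))
        wa[i] -= mass
        wb[j] -= mass
        # At least one side is now empty; advance past it (or both).
        if wa[i] <= 1e-15:
            i += 1
        if wb[j] <= 1e-15:
            j += 1
    return total
```

With cost=lambda u, v: (u - v) ** 2 this should agree with ot.emd2_1d(x, x, a, b) on the reproduction data, since sqeuclidean is that function's default metric. For general metrics, building the full cost matrix with ot.dist(x.reshape(-1, 1), x.reshape(-1, 1), metric=...) and calling ot.emd avoids the 1D code path entirely. Note that cosine distance between scalar positions only depends on their signs and is undefined at 0, which np.arange(n) includes, so the cosine case in this example is ill-posed either way.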
Environment (please complete the following information):
- OS (e.g. MacOS, Windows, Linux): Linux
- Python version: 3.8.18
- How was POT installed (source, pip, conda): pip
- Build command you used (if compiling from source): pip install POT
Output of the following code snippet:
import platform; print(platform.platform())
import sys; print("Python", sys.version)
import numpy; print("NumPy", numpy.__version__)
import scipy; print("SciPy", scipy.__version__)
import ot; print("POT", ot.__version__)
Linux-5.15.0-117-generic-x86_64-with-glibc2.10
Python 3.8.18 | packaged by conda-forge | (default, Oct 10 2023, 15:44:36)
[GCC 12.3.0]
NumPy 1.24.4
SciPy 1.10.1
POT 0.9.4