
issue with using other metric choices with emd2_1d #669

@mrunalimanj

Description

Describe the bug

emd2_1d (and emd_1d, which the example below uses) raises an error when the metric is not one of the ones with a sped-up 1D implementation, e.g. cosine or yule.
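Separately from the crash, it may be worth checking what cosine means for 1D positions. If each position is treated as a length-1 vector, the cosine distance is 1 - uv/(|u||v|) = 1 - sign(u)sign(v) for nonzero u and v. It is therefore 0 between points of the same sign, 2 between points of opposite sign, and undefined when either point is 0 (the repro's x starts at 0). A quick check with SciPy's cdist follows. We assume this is the function POT falls back to for such metrics, but have not confirmed it.

```python
import numpy as np
from scipy.spatial.distance import cdist

# 1D positions as column vectors: cdist expects (n_samples, n_features)
u = np.array([[1.0], [5.0]])
v = np.array([[2.0], [-3.0]])

# For length-1 vectors the cosine distance only depends on the signs
M = cdist(u, v, metric="cosine")
print(M)
```

With nonnegative bin positions like np.arange(n), every entry not involving x = 0 would presumably be 0, so the result may carry little information even once the error is fixed.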

To Reproduce

Steps to reproduce the behavior:
A simple test case, adapted from the 1D example code:


import numpy as np
import ot
from ot.datasets import make_1D_gauss as gauss

# Parameters
n = 100  # number of bins

# Bin positions
x = np.arange(n, dtype=np.float64)

# Gaussian distributions (m = mean, s = std)
a = gauss(n, m=20, s=5)
b = gauss(n, m=60, s=10)

# Use the fast 1D solver with a non-default metric
G0 = ot.emd_1d(x, x, a, b, metric="cosine")
Running this raises:

     54 G0 = ot.emd_1d(x, x, a, b, metric="cosine")
     55 
     56 # Equivalent to

~/miniconda3/envs/ms-gen/lib/python3.8/site-packages/ot/lp/solver_1d.py in emd_1d(x_a, x_b, a, b, metric, p, dense, log, check_marginals)
    257     perm_b = nx.argsort(x_b_1d)
    258 
--> 259     G_sorted, indices, cost = emd_1d_sorted(
    260         nx.to_numpy(a[perm_a]).astype(np.float64),
    261         nx.to_numpy(b[perm_b]).astype(np.float64),

ot/lp/emd_wrap.pyx in ot.lp.emd_wrap.emd_1d_sorted()

AttributeError: 'float' object has no attribute 'reshape'
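The traceback ends inside emd_1d_sorted, which walks the two sorted supports in a single merge pass. Our reading of the error is an assumption that we have not checked against the Cython source. We think that for metrics without a closed-form fast path, the solver calls .reshape on the current support point, and at that point the support point is a plain Python float rather than an array. The sketch below is a hypothetical pure-NumPy version of that merge and is not POT's code. Because the ground cost is a Python callable on two scalars, nothing needs reshaping. It assumes both supports are already sorted (emd_1d argsorts them before calling emd_1d_sorted, as the traceback shows) and that both histograms carry the same total mass.

```python
import numpy as np


def emd_1d_sorted_sketch(a, b, x_a, x_b, cost):
    """Hypothetical helper, not POT's implementation: monotone
    (north-west corner) coupling of two 1D histograms whose supports
    x_a and x_b are already sorted. Returns (total_cost, plan), where
    plan is a list of (i, j, mass) triples."""
    n, m = len(a), len(b)
    i, j = 0, 0
    w_i, w_j = a[0], b[0]  # mass left at source x_a[i] / room left at x_b[j]
    total, plan = 0.0, []
    while True:
        c_ij = cost(x_a[i], x_b[j])  # scalars in, scalar out: nothing to reshape
        if w_i < w_j or j == m - 1:
            # Source i runs out first (or j is the last target): move all of w_i
            total += c_ij * w_i
            plan.append((i, j, w_i))
            i += 1
            if i == n:
                break
            w_j -= w_i
            w_i = a[i]
        else:
            # Target j fills up first: move w_j and advance to the next target
            total += c_ij * w_j
            plan.append((i, j, w_j))
            j += 1
            if j == m:
                break
            w_i -= w_j
            w_j = b[j]
    return total, plan


# Two Gaussian-shaped histograms on the grid 0..99, mirroring the repro
x = np.arange(100, dtype=np.float64)
a = np.exp(-((x - 20) ** 2) / (2 * 5.0**2))
a /= a.sum()
b = np.exp(-((x - 60) ** 2) / (2 * 10.0**2))
b /= b.sum()

w1, plan = emd_1d_sorted_sketch(a, b, x, x, cost=lambda u, v: abs(u - v))
print(w1)
```

As a sanity check, with the cost |u - v| on this unit-spaced grid, w1 should agree (up to rounding) with np.abs(np.cumsum(a) - np.cumsum(b))[:-1].sum(), which is the closed form of the 1D Wasserstein-1 distance. The monotone coupling this merge produces is only guaranteed to be optimal when the cost is a convex function of u - v, such as |u - v|**p with p >= 1. For cosine, a patched emd_1d would return the cost of this particular coupling, which need not be the optimal transport cost.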

Expected behavior

It should return a result (a transport plan for emd_1d, a cost for emd_2d) instead of raising an error. I can't yet tell whether the math is correct.
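Until the 1D solver handles these metrics, one possible workaround avoids the sorted fast path altogether. It builds the full cost matrix and solves the general discrete OT problem. In POT this would presumably be ot.dist(x.reshape(-1, 1), x.reshape(-1, 1), metric="cosine") followed by ot.emd2(a, b, M), although that route is untested here against POT 0.9.4. To keep the sketch independent of POT, it solves the same transport linear program with SciPy's linprog instead. dense_ot_cost is a hypothetical helper name, and a dense LP like this is only practical for small histograms.

```python
import numpy as np
from scipy.optimize import linprog
from scipy.spatial.distance import cdist


def dense_ot_cost(a, b, x_a, x_b, metric):
    """Hypothetical helper: exact OT cost between 1D histograms for any
    cdist metric, by solving the transport LP on the full cost matrix."""
    # Full ground-cost matrix; positions become (n, 1) column vectors
    M = cdist(x_a.reshape(-1, 1), x_b.reshape(-1, 1), metric=metric)
    n, m = M.shape
    # The plan P is flattened row-major: P[i, j] -> index i * m + j
    rows = np.kron(np.eye(n), np.ones((1, m)))  # sum_j P[i, j] = a[i]
    cols = np.kron(np.ones((1, n)), np.eye(m))  # sum_i P[i, j] = b[j]
    res = linprog(
        M.ravel(),
        A_eq=np.vstack([rows, cols]),
        b_eq=np.concatenate([a, b]),
        bounds=(0, None),
        method="highs",
    )
    if not res.success:
        raise RuntimeError(res.message)
    return res.fun, res.x.reshape(n, m)


# Cosine on 1D points: cost 0 for same sign, 2 for opposite sign
a = np.array([0.5, 0.5])
b = np.array([0.5, 0.5])
x_a = np.array([1.0, 2.0])
x_b = np.array([3.0, -4.0])
cost, plan = dense_ot_cost(a, b, x_a, x_b, metric="cosine")
print(cost)
```

Here all source points are positive, so half of the mass has to cross to the negative target at cost 2, which gives a total of 1.0. Unlike the sorted merge, the LP optimum does not depend on the cost being convex in u - v.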

Environment (please complete the following information):

  • OS (e.g. MacOS, Windows, Linux): Linux
  • Python version: 3.8.18
  • How was POT installed (source, pip, conda): pip
  • Build command you used (if compiling from source): pip install POT

Output of the following code snippet:

import platform; print(platform.platform())
import sys; print("Python", sys.version)
import numpy; print("NumPy", numpy.__version__)
import scipy; print("SciPy", scipy.__version__)
import ot; print("POT", ot.__version__)

Linux-5.15.0-117-generic-x86_64-with-glibc2.10
Python 3.8.18 | packaged by conda-forge | (default, Oct 10 2023, 15:44:36)
[GCC 12.3.0]
NumPy 1.24.4
SciPy 1.10.1
POT 0.9.4

Additional context
