Description
Describe the bug
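Calling ot.emd2 with several target histograms (b given as a matrix with one target per column) to compute the earth mover's distance (EMD) loss fails with a pickling error. According to the traceback below, emd2 dispatches the per-target computations to worker processes through ot.utils.parmap, and the local function emd2.<locals>.f cannot be pickled when multiprocessing uses the "spawn" start method, which is the default on macOS since Python 3.8. The error was observed on macOS at least.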
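The root cause can be reproduced without POT. This is a minimal sketch, assuming (as the traceback suggests) that the failure comes from serializing the worker function: with the "spawn" start method the Process object, including its target and arguments, is written out with multiprocessing.reduction.ForkingPickler, and pickle stores functions by their qualified name, so a function defined inside another function cannot be looked up again. The names module_level_worker, make_worker and can_pickle are hypothetical.

```python
import pickle
from multiprocessing.reduction import ForkingPickler


def module_level_worker(x):
    # Defined at module level: pickle can find it again by its qualified name.
    return 2 * x


def make_worker():
    # Defined inside another function, like emd2.<locals>.f in the traceback.
    def f(x):
        return 2 * x
    return f


def can_pickle(obj):
    """Return True if ForkingPickler (used by the "spawn" start method) can serialize obj."""
    try:
        ForkingPickler.dumps(obj)
        return True
    except (AttributeError, pickle.PicklingError):
        # Python 3.8 raises AttributeError ("Can't pickle local object ...");
        # other Python versions may raise pickle.PicklingError instead.
        return False


print(can_pickle(module_level_worker))  # True
print(can_pickle(make_worker()))        # False
```

Moving the worker to module level (and binding any extra arguments, e.g. with functools.partial) is one general way to make such code spawn-safe; it is not necessarily how POT addresses it.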
To Reproduce
Steps to reproduce the behavior:
- Run the code sample below.
Stack trace
Traceback (most recent call last):
  File "examples/test_emd2.py", line 25, in <module>
    emd_loss = ot.emd2(source, targets, cost)
  File "/Users/ayushkarnawat/Documents/dev/python_workspace/POT/ot/lp/__init__.py", line 429, in emd2
    res = parmap(f, [b[:, i] for i in range(nb)], processes)
  File "/Users/ayushkarnawat/Documents/dev/python_workspace/POT/ot/utils.py", line 249, in parmap
    p.start()
  File "/Users/ayushkarnawat/miniconda3/envs/pot/lib/python3.8/multiprocessing/process.py", line 121, in start
    self._popen = self._Popen(self)
  File "/Users/ayushkarnawat/miniconda3/envs/pot/lib/python3.8/multiprocessing/context.py", line 224, in _Popen
    return _default_context.get_context().Process._Popen(process_obj)
  File "/Users/ayushkarnawat/miniconda3/envs/pot/lib/python3.8/multiprocessing/context.py", line 283, in _Popen
    return Popen(process_obj)
  File "/Users/ayushkarnawat/miniconda3/envs/pot/lib/python3.8/multiprocessing/popen_spawn_posix.py", line 32, in __init__
    super().__init__(process_obj)
  File "/Users/ayushkarnawat/miniconda3/envs/pot/lib/python3.8/multiprocessing/popen_fork.py", line 19, in __init__
    self._launch(process_obj)
  File "/Users/ayushkarnawat/miniconda3/envs/pot/lib/python3.8/multiprocessing/popen_spawn_posix.py", line 47, in _launch
    reduction.dump(process_obj, fp)
  File "/Users/ayushkarnawat/miniconda3/envs/pot/lib/python3.8/multiprocessing/reduction.py", line 60, in dump
    ForkingPickler(file, protocol).dump(obj)
AttributeError: Can't pickle local object 'emd2.<locals>.f'
Screenshots
N.A.
Code sample
import numpy as np
import ot
from ot.datasets import make_1D_gauss as gauss
num_bins = 100
num_targets = 50
# Bin positions
x = np.arange(num_bins, dtype=np.float64)
means = np.linspace(20, 90, num_targets)
# Generate gaussian distributions
source = gauss(num_bins, m=20, s=5)  # m = mean, s = std
targets = np.zeros((num_bins, num_targets))
for i, m in enumerate(means):
    targets[:, i] = gauss(num_bins, m=m, s=5)  # one target histogram per column
# Compute cost matrix and normalize
cost = ot.dist(x.reshape((num_bins, 1)), x.reshape((num_bins, 1)), metric="euclidean")
cost /= cost.max()
# Compute EMD loss
emd_loss = ot.emd2(source, targets, cost)
Expected behavior
ot.emd2 returns one loss per target, i.e. the EMD between the source and each column of targets. For the example provided, the result should be:
[0.0, 0.014426924576813046, 0.0288559930665439, 0.04328574159548381, 0.057715687365961325, 0.07214568550413558, 0.08657569638056341, 0.1010057100978733, 0.1154357243964259, 0.12986573880413424, 0.1442957532306665, 0.15872576766018065, 0.17315578209012886, 0.1875857965201352, 0.20201581095014873, 0.21644582538016313, 0.23087583981017767, 0.245305854240192, 0.2597358686702063, 0.2741658831002206, 0.2885958975302355, 0.30302591196024975, 0.31745592639026415, 0.3318859408202786, 0.346315955250293, 0.36074596968030737, 0.3751759841103219, 0.38960599854033634, 0.40403601297035013, 0.4184660274003606, 0.43289604183033936, 0.44732605626007876, 0.4617560706881306, 0.47618608510525495, 0.4906160994573777, 0.5050461134543907, 0.5194761256704825, 0.5339061296913891, 0.5483360991299938, 0.5627659348379075, 0.5771952970338474, 0.5916231256118063, 0.6060464162679144, 0.6204574579249451, 0.6348383926216666, 0.6491520836171457, 0.6633296042796002, 0.6772576353391107, 0.6907726671326471, 0.7036700128578843]
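For one-dimensional histograms on evenly spaced bins with an |x - y| ground cost, the exact EMD has a closed form: the sum of absolute differences between the two cumulative distributions, times the bin spacing. The sketch below uses that identity (not ot.emd2) as an independent cross-check of the expected values, dividing by 99 to mirror the cost /= cost.max() normalization in the code sample. It assumes make_1D_gauss returns exp(-(x - m)^2 / (2 s^2)) normalized to sum to one; gauss_hist and emd_1d are hypothetical helpers.

```python
import numpy as np

num_bins = 100
num_targets = 50
x = np.arange(num_bins, dtype=np.float64)
means = np.linspace(20, 90, num_targets)


def gauss_hist(n, m, s):
    # Assumed equivalent of ot.datasets.make_1D_gauss: sampled Gaussian, normalized to sum to 1.
    xs = np.arange(n, dtype=np.float64)
    h = np.exp(-(xs - m) ** 2 / (2 * s ** 2))
    return h / h.sum()


def emd_1d(a, b, spacing=1.0):
    # Closed-form EMD (Wasserstein-1) between histograms on the same evenly spaced bins:
    # the integral of |CDF_a - CDF_b|.
    return np.sum(np.abs(np.cumsum(a) - np.cumsum(b))) * spacing


source = gauss_hist(num_bins, 20, 5)
max_cost = x.max() - x.min()  # 99, the largest entry of the Euclidean cost matrix
losses = np.array([emd_1d(source, gauss_hist(num_bins, m, 5)) / max_cost for m in means])
print(losses[:3])
```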
Environment (please complete the following information):
- OS (e.g. MacOS, Windows, Linux): MacOS
- Python version: 3.8.2
- How was POT installed (source, pip, conda): source
- Build command you used (if compiling from source): python setup.py develop
- Only for GPU-related bugs:
  - CUDA version:
  - GPU models and configuration:
- Any other relevant information:
macOS-10.13.6-x86_64-i386-64bit
Python 3.8.2 (default, May 6 2020, 02:49:43)
[Clang 4.0.1 (tags/RELEASE_401/final)]
NumPy 1.17.4
SciPy 1.4.1
POT 0.7.0