Description
Describe the bug
When I initialize ot.da.SinkhornLpl1Transport() with log=True and then call fit(), it raises "ValueError: too many values to unpack (expected 2)".
To Reproduce
Steps to reproduce the behavior:
I ran the following two lines of code:
1. ot_lpl1 = ot.da.SinkhornLpl1Transport(reg_e=1e02, reg_cl=1e-2, log=True, verbose=True)
2. ot_lpl1.fit(Xs=Xs, ys=ys, Xt=Xt)
Error Message:
ValueError Traceback (most recent call last)
/tmp/ipykernel_176879/2895611854.py in
1 # Sinkhorn Transport with Group lasso regularization
2 ot_lpl1 = ot.da.SinkhornLpl1Transport(reg_e=1e02,reg_cl=1e-2,log=True,verbose=True)
----> 3 ot_lpl1.fit(Xs=Xs, ys=ys, Xt=Xt)
4 transp_Xs_lpl1 = ot_lpl1.transform(Xs=Xs)
5 pd.DataFrame(transp_Xs_lpl1).head()
~/.local/lib/python3.8/site-packages/ot/da.py in fit(self, Xs, ys, Xt, yt)
1748 # deal with the value of log
1749 if self.log:
-> 1750 self.coupling_, self.log_ = returned_
1751 else:
1752 self.coupling_ = returned_
ValueError: too many values to unpack (expected 2)
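The last frame shows fit() unpacking returned_ into two names when self.log is true. The message is consistent with returned_ being the bare coupling matrix rather than a (coupling, log) pair. A minimal sketch, assuming exactly that, reproduces the same error without POT. The function fit_like is hypothetical, and a list of rows stands in for the NumPy coupling array, which unpacks row by row in the same way.

```python
def fit_like(returned_, log):
    # Hypothetical mirror of the branch shown in the traceback (da.py lines 1749-1752)
    if log:
        coupling, log_ = returned_  # needs exactly two items
        return coupling, log_
    return returned_, None


# Hypothetical 3x3 coupling returned on its own, without a log dict
coupling_only = [[0.2, 0.1, 0.0], [0.0, 0.3, 0.0], [0.1, 0.0, 0.3]]

try:
    # Three rows cannot be unpacked into two names
    fit_like(coupling_only, log=True)
except ValueError as err:
    print(err)  # a "too many values to unpack" message

# A (coupling, log) pair unpacks as the fit() code expects
coupling, log_ = fit_like((coupling_only, {"err": []}), log=True)
```

The same unpacking succeeds once the solver returns a two-element tuple. This suggests that the problem lies in what the solver returns when log=True, not in the unpacking line itself.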
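Code sample
import ot
# Xs, ys: source samples and labels; Xt: target samples (defined earlier in the notebook)
ot_lpl1 = ot.da.SinkhornLpl1Transport(reg_e=1e02, reg_cl=1e-2, log=True, verbose=True)
ot_lpl1.fit(Xs=Xs, ys=ys, Xt=Xt)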
Expected behavior
fit() should complete without error, and the fitted object should expose both the coupling (coupling_) and the log (log_).
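The expected behavior can be sketched with a toy stand-in. It assumes a hypothetical solver, toy_solver, that honors its log flag by returning a (coupling, log) pair. ToyTransport is also hypothetical and is not POT's implementation, and the uniform coupling is placeholder data.

```python
def toy_solver(n_s, n_t, log=False):
    # Hypothetical solver: uniform coupling with total mass 1
    coupling = [[1.0 / (n_s * n_t)] * n_t for _ in range(n_s)]
    if log:
        # Return a pair only when a log is requested
        return coupling, {"n_iter": 0}
    return coupling


class ToyTransport:
    def __init__(self, log=False):
        self.log = log

    def fit(self, n_s, n_t):
        returned_ = toy_solver(n_s, n_t, log=self.log)
        # Same branch as in the traceback; safe because the solver returns a pair when log=True
        if self.log:
            self.coupling_, self.log_ = returned_
        else:
            self.coupling_ = returned_
        return self


model = ToyTransport(log=True).fit(3, 4)
print(len(model.coupling_), model.log_)  # 3 {'n_iter': 0}
```

The unpacking in fit() only works if the solver's return shape tracks the log flag. Keeping that contract in the solver means the estimator code can stay unchanged.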
Environment (please complete the following information):
- OS (e.g. MacOS, Windows, Linux): Ubuntu 20
- Python version: 3.8
- How was POT installed (source, pip, conda): pip
Output of the following code snippet:
import platform; print(platform.platform())
import sys; print("Python", sys.version)
import numpy; print("NumPy", numpy.__version__)
import scipy; print("SciPy", scipy.__version__)
import ot; print("POT", ot.__version__)
Linux-5.4.0-132-generic-x86_64-with-glibc2.29
Python 3.8.10 (default, Jun 22 2022, 20:18:18)
[GCC 9.4.0]
NumPy 1.21.4
SciPy 1.8.0
POT 0.8.1.0