Domain adaptation Classes #22
Slasnista commented Jul 28, 2017
- first proposal of DA class structure
- BaseEstimator: OTDA wrapper (does not work as a stand-alone but implements the methods common to any OTDA algorithm)
- SinkhornTransport: implements Sinkhorn algorithm for OTDA
- attempt at docstrings compliant with the numpy docstring conventions
Thank you @Slasnista, I will have a look this weekend or Monday. Maybe creating a compliant BaseEstimator class as @agramfort proposed is a nice way to avoid the sklearn dependency.
Hello, I left a few comments in response to @ngayraud. @Slasnista I think this is a good start, but we need to agree on a typical use case. Below is how I see the class working:
tr = EMDTransport(out_of_sample_map='ferradans', weight_est='uniform')
tr.fit(Xs, Xt)  # learn G matrix and store it, also store data
Xs_trans = tr.transform(Xs)  # barycentric mapping
Xs_trans = tr.transform()  # barycentric mapping with stored data
Xs2_trans = tr.transform(Xs2)  # barycentric mapping + Ferradans since out of sample
@rflamary thanks for the comments. I took them into account and I will update my PR. For the use cases, I agree with you on most of them. Yet, even if it's a bit redundant, it looks weird not to pass data to the transform method. So if we opt for this solution, we need to write a detailed example explaining the differences between the versions of transform. I will look at your examples and update them according to the class structure.
@Slasnista OK, on second thought I agree with you: it's weird with no params. Let's not do it. Still, I think it is very important to automatically handle both the known Xs and new out-of-sample Xs2. Rémi
This is nice, just a few more details and we have a workable class.
Think about adding some tests to test_da.py
ot/da.py
Outdated
##############################################################################
# adapted from scikit-learn

import warnings
feel free to put it at the top of the file
ot/da.py
Outdated
# ``if type(self).__module__.startswith('sklearn.')``.


def distribution_estimation_uniform(X):
Note that the function ot.utils.unif already does exactly the same thing. Maybe a simple
from .utils import unif
distribution_estimation_uniform = unif
would do the trick.
Ah no, it does not work, but you can still use unif in the return line.
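A minimal sketch of that suggestion, assuming `ot.utils.unif(n)` returns the uniform histogram `np.ones(n) / n`. A plain alias fails because `unif` expects a sample count while the estimator receives the sample matrix `X`. Calling `unif` on `X.shape[0]` in the return line bridges the two. The local `unif` below only mirrors that behavior so the snippet runs without POT installed.

```python
import numpy as np


def unif(n):
    """Local stand-in mirroring ot.utils.unif: uniform weights of length n."""
    return np.ones((n,)) / n


def distribution_estimation_uniform(X):
    """Uniform weights over the rows (samples) of X.

    The alias ``distribution_estimation_uniform = unif`` would not work,
    since unif expects an integer while this function receives X.
    """
    return unif(X.shape[0])


weights = distribution_estimation_uniform(np.zeros((4, 2)))
print(weights)  # [0.25 0.25 0.25 0.25]
```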
ot/da.py
Outdated
# store arrays of samples
self.Xs = Xs
self.Xt = Xt
Please remove sinkhorn from the BaseEstimator and put it in SinkhornTransport.
Keep all the rest so that we can call super() and still have all the storing and density estimation.
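Here is a minimal sketch of the requested split, under stated assumptions. `BaseTransport.fit` stores the samples, estimates the weights and builds a squared Euclidean cost. `SinkhornTransport.fit` calls it through `super()` and then runs its own solver. The helper `sinkhorn_knopp` is a textbook Sinkhorn loop standing in for POT's `ot.sinkhorn`. The attribute name `cost_` is hypothetical.

```python
import numpy as np


def distribution_estimation_uniform(X):
    """Uniform weights over the rows of X."""
    return np.ones(X.shape[0]) / X.shape[0]


def sinkhorn_knopp(a, b, M, reg, max_iter=1000, tol=1e-9):
    """Plain Sinkhorn-Knopp iterations for entropic OT (illustrative only)."""
    K = np.exp(-M / reg)
    u = np.ones_like(a)
    for _ in range(max_iter):
        v = b / (K.T @ u)
        u = a / (K @ v)
        # rows match a exactly after the u-update; check the column marginals
        if np.max(np.abs(v * (K.T @ u) - b)) < tol:
            break
    return u[:, None] * K * v[None, :]


class BaseTransport(object):
    """Shared logic: store samples, estimate weights, build the cost matrix.

    Subclasses must set ``self.distribution_estimation`` in __init__ and
    compute ``self.coupling_`` after calling the parent fit.
    """

    def fit(self, Xs, Xt):
        self.Xs, self.Xt = Xs, Xt
        self.mu_s = self.distribution_estimation(Xs)
        self.mu_t = self.distribution_estimation(Xt)
        # squared Euclidean cost between every source and target sample
        self.cost_ = ((Xs[:, None, :] - Xt[None, :, :]) ** 2).sum(-1)
        return self


class SinkhornTransport(BaseTransport):
    def __init__(self, reg_e=1., max_iter=1000, tol=1e-9,
                 distribution_estimation=distribution_estimation_uniform):
        self.reg_e = reg_e
        self.max_iter = max_iter
        self.tol = tol
        self.distribution_estimation = distribution_estimation

    def fit(self, Xs, Xt):
        # storing, weight estimation and cost come from BaseTransport
        super(SinkhornTransport, self).fit(Xs, Xt)
        # only the solver call is specific to this subclass
        self.coupling_ = sinkhorn_knopp(self.mu_s, self.mu_t, self.cost_,
                                        self.reg_e, self.max_iter, self.tol)
        return self


rng = np.random.RandomState(0)
Xs = rng.randn(5, 2)
Xt = rng.randn(6, 2) + 1.0
G = SinkhornTransport(reg_e=1.0).fit(Xs, Xt).coupling_
```

With this layout, an EMD-based class would only override the solver call.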
ot/da.py
Outdated
@@ -1114,7 +1255,10 @@ class SinkhornTransport(BaseTransport):

def __init__(self, reg_e=1., mode="unsupervised", max_iter=1000,
             tol=10e-9, verbose=False, log=False, mapping="barycentric",
             metric="sqeuclidean", distribution="uniform"):
             metric="sqeuclidean",
             distribution_estimation=distribution_estimation_uniform,
OK, good. It is much simpler and gives the user full freedom over the distribution estimation.
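The estimator is just a callable that takes the sample matrix and returns weights, so users can plug in their own. In this sketch, `distribution_estimation_kde` and the toy `Transport` class are hypothetical illustrations, not part of the PR. Extra arguments such as a bandwidth can be fixed with `functools.partial`.

```python
from functools import partial

import numpy as np


def distribution_estimation_uniform(X):
    return np.ones(X.shape[0]) / X.shape[0]


def distribution_estimation_kde(X, bandwidth=1.0):
    """Hypothetical estimator: weights proportional to a Gaussian KDE at each sample."""
    sq_dists = ((X[:, None, :] - X[None, :, :]) ** 2).sum(-1)
    density = np.exp(-sq_dists / (2 * bandwidth ** 2)).sum(axis=1)
    return density / density.sum()


class Transport(object):
    """Toy class: the estimator is a callable stored at construction."""

    def __init__(self, distribution_estimation=distribution_estimation_uniform):
        self.distribution_estimation = distribution_estimation

    def fit(self, Xs):
        self.mu_s = self.distribution_estimation(Xs)
        return self


Xs = np.array([[0.0, 0.0], [0.1, 0.0], [5.0, 5.0]])
mu_uniform = Transport().fit(Xs).mu_s
mu_kde = Transport(distribution_estimation_kde).fit(Xs).mu_s
# extra arguments (here the bandwidth) can be bound with functools.partial
mu_narrow = Transport(partial(distribution_estimation_kde, bandwidth=0.5)).fit(Xs).mu_s
```

With these settings, the isolated third sample gets a smaller weight than the two clustered ones.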
ot/da.py
Outdated
self.method = "sinkhorn"
self.out_of_sample_map = out_of_sample_map

def fit(self, Xs=None, ys=None, Xt=None, yt=None):
    """Build a coupling matrix from source and target sets of samples
I see that in every fit and transform, all of the parameters default to None. However, I think that at least Xs (and Xt for the inverse transform) should be required.
Hi,
Here are some questions:
OK, thank you @Slasnista, it's indeed starting to converge and I see that the tests are not failing anymore!
Before merging, I would like to see at least EMDTransport() implemented. I can take care of the classes with linear and kernel mapping estimation if you want.
I added a little comment, and here are my answers to your questions:
- Not in this PR, we cannot. This can be discussed in the future, but this PR focuses on a new and improved class for ot.
- A lot of things are imported because, as with numpy, I want all the base ot functions (emd, emd2, sinkhorn, sinkhorn2, dist) to be easily available without importing a submodule.
- The variable name change will not be done in this PR either, because it would completely break the legacy classes and all the examples. It is open for discussion, but note that ALL the ot functions in the toolbox use this naming, and consistency is IMHO very important. Finally, G/gamma/T are standard names for an OT matrix in the OT community, and M/C are standard names for cost matrices.
- In the Spyder editor, you can press Ctrl-I to open the documentation of the current function (just click on the function name first). The documentation can be viewed in raw or rendered form.
test/test_da.py
Outdated
np.random.seed(42)


def test_sinkhorn_transport():
maybe test_sinkhorn_transport_class would be a better name
Thank you for the reply, it's fine with me. I will rename the test function.
Here is a bit of code you can add for a simple EMD transport. I hope I have it right.
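The snippet referred to here was not preserved in this thread. As a hedged illustration only (this is not @ngayraud's code), an EMD-based transport differs from the Sinkhorn one mainly in the solver call. In this sketch, `exact_ot` solves the exact OT linear program with `scipy.optimize.linprog` (`method="highs"`, which assumes SciPy 1.6 or later) as a stand-in for POT's `ot.emd`. The dense formulation is only practical for small problems.

```python
import numpy as np
from scipy.optimize import linprog


def exact_ot(a, b, M):
    """Exact OT coupling via a dense LP (stand-in for ot.emd; small problems only)."""
    ns, nt = M.shape
    # the coupling G is flattened row-major: G[i, j] is variable i * nt + j
    A_rows = np.kron(np.eye(ns), np.ones((1, nt)))  # sum_j G[i, j] = a[i]
    A_cols = np.kron(np.ones((1, ns)), np.eye(nt))  # sum_i G[i, j] = b[j]
    res = linprog(M.ravel(), A_eq=np.vstack([A_rows, A_cols]),
                  b_eq=np.concatenate([a, b]), bounds=(0, None), method="highs")
    if not res.success:
        raise RuntimeError(res.message)
    return res.x.reshape(ns, nt)


class EMDTransport(object):
    """Toy EMD transport: uniform weights, squared Euclidean cost, exact solver."""

    def fit(self, Xs, Xt):
        self.Xs, self.Xt = Xs, Xt
        self.mu_s = np.ones(len(Xs)) / len(Xs)
        self.mu_t = np.ones(len(Xt)) / len(Xt)
        self.cost_ = ((Xs[:, None, :] - Xt[None, :, :]) ** 2).sum(-1)
        # unlike Sinkhorn, there is no regularization parameter here
        self.coupling_ = exact_ot(self.mu_s, self.mu_t, self.cost_)
        return self


Xs = np.array([[0.0], [1.0], [2.0]])
Xt = np.array([[0.1], [1.1], [2.1]])
G = EMDTransport().fit(Xs, Xt).coupling_
```

On these 1D points, the optimal coupling matches each source point to its shifted counterpart, so `G` is the identity scaled by 1/3.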
@ngayraud thank you for the code! I adapted it and put it in the PR ;) Just checking whether we can solve it like this:
That does not change much, but we no longer have the coupling estimation in a loop. Btw, do all the other methods return log in addition to couplings?
EMD does not accept log or verbose. It's just a wrapper for C code.
On 1 August 2017 at 13:01, "Stanislas Chambon" <notifications@github.com> wrote:
# coupling estimation
returned_ = emd(
    a=self.mu_s, b=self.mu_t, M=self.Cost,
    verbose=self.verbose, log=self.log)
if self.log:
    self.coupling_, self.log = returned_
else:
    self.coupling_ = returned_
@rflamary thanks, the test gave me this error ;) I removed the log and verbose inputs to emd for the moment.
I think all the Sinkhorn-based transports return a log! I was unsure about the EMD one, though.
Hi,
What do you think about these new modifications?
Hello @Slasnista, Those are nice new features. I just have a few questions/todos:
This will be a very nice merge. Thank you again for all the work.
I will address the following points in a second pass ;)
@Slasnista nice to see we agree. Point 3 is the tough one ;). For Sinkhorn and group-lasso regularization we set it to infinity (np.inf). Under constraint qualification conditions (there exists a solution that is not infinite), the algorithm then converges to a solution with zero mass on the do-not-link links. For the EMD solver (and conditional gradient, which uses EMD), I don't think the solver is going to appreciate infinity. So, as you propose, I would use the max value of the cost matrix plus something, maybe epsilon(max(Cost)). I think it should be a parameter anyway, with the max/inf taken if it is set to None. Take your time with the remaining points. I will be on vacation next week and probably won't find time to do a proper final review anyway.
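The following is a hedged sketch of the label-based penalty described above. `penalize_cost` is a hypothetical helper, not a POT function. Marking unlabelled target samples with `-1` is an assumed convention. The finite EMD penalty of `10 * max(M)` is an arbitrary illustrative choice of "a value above every real cost".

```python
import numpy as np


def penalize_cost(M, ys, yt, limit_max=None, solver="sinkhorn"):
    """Hypothetical helper: forbid transport between differently labelled samples.

    Target samples labelled -1 are treated as unlabelled and left untouched.
    If limit_max is None, np.inf is used for Sinkhorn-type solvers and a finite
    value above every real cost is used for exact (EMD-based) solvers.
    """
    if limit_max is None:
        if solver == "sinkhorn":
            limit_max = np.inf  # exp(-inf / reg) = 0: no mass on these links
        else:
            limit_max = 10 * M.max()  # arbitrary finite penalty, illustrative only
    M = M.copy()  # do not modify the caller's cost matrix
    mismatch = (ys[:, None] != yt[None, :]) & (yt[None, :] != -1)
    M[mismatch] = limit_max
    return M


M = np.array([[0.0, 1.0, 4.0],
              [1.0, 0.0, 1.0]])
ys = np.array([0, 1])
yt = np.array([0, 1, -1])  # last target sample is unlabelled
M_sinkhorn = penalize_cost(M, ys, yt)
M_emd = penalize_cost(M, ys, yt, solver="emd")
```

As noted above, infinite entries are only safe when each source row keeps at least one finite cost; otherwise the Sinkhorn scaling divides by zero.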
I completed all the modifications we talked about. I think we should grasp this occasion to think about the algorithms' outputs. In my opinion:
Point 2 is not a burning issue, but point 1 is. Right now, instantiating an object with
Hello @Slasnista, thank you for the new commits that handle everything we discussed. I will do a proper code review, but in the meantime you need to update the examples in the examples folder that use the old classes:
We don't want the examples to be deprecated ;) Once this is done, I will take care of updating the notebooks, which is not done automatically yet. About your questions: there should be a log parameter in the __init__ of the classes, not in fit. Note that throughout the toolbox the parameter is called log, not logs; log_ should be fine for storing the log. Depending on this parameter, the fit method should update the log_ attribute of the class. This attribute should always exist and be an empty dictionary when log is set to False (maybe it should be set in the __init__ of the Base class?). I understand that the variable-size return of the algorithms is surprising, but it keeps the functions very simple to use while still letting expert users get more information if needed. Changing the behavior of the algorithms now would completely break the toolbox and, again, is not the purpose of this PR.
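Here is a minimal sketch of the `log`/`log_` convention described above. `toy_solver` is hypothetical and only mimics solvers whose return size depends on `log`. `log` is set in `__init__`, and `log_` always exists, staying an empty dict when `log` is False.

```python
import numpy as np


def toy_solver(a, b, M, log=False):
    """Stand-in for a solver whose return size depends on `log`."""
    G = np.outer(a, b)  # independent coupling, just to have something to return
    if log:
        return G, {"cost": float((G * M).sum())}
    return G


class BaseTransport(object):
    def __init__(self, log=False):
        self.log = log      # user-facing flag, named as elsewhere in the toolbox
        self.log_ = dict()  # always defined; stays empty when log is False

    def fit(self, a, b, M):
        returned_ = toy_solver(a, b, M, log=self.log)
        if self.log:
            # store the dict in log_ so the boolean flag is not overwritten
            self.coupling_, self.log_ = returned_
        else:
            self.coupling_ = returned_
            self.log_ = dict()
        return self


a = np.array([0.5, 0.5])
b = np.array([0.25, 0.75])
M = np.array([[0.0, 1.0], [1.0, 0.0]])
quiet = BaseTransport().fit(a, b, M)
logged = BaseTransport(log=True).fit(a, b, M)
```

Keeping the flag (`log`) and the result (`log_`) in separate attributes means a second call to `fit` still sees the original boolean.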
Hi all, entering the discussion after a holiday break :)
" in the sphinx syntax" you mean in plain python :)
Hi @rflamary,
Hello everyone, @ncourty since you are back from holiday and full of energy, maybe you could take care of the test file formatting ;)? As POT maintainer, you should be able to push commits to the PR. @Slasnista thank you, everything looks great. I will do a proper code review shortly, but I think we will be able to merge soon.
Done updating the examples to use the new OTDA objects! @rflamary you're welcome, I now look forward to seeing it merged ;-)
OK, just a few changes and a go from Travis, and I think we can merge and do a new POT version.
@Slasnista don't forget to add yourself and @ngayraud to the contributors of the toolbox in the readme file.
ot/deprecation.py
Outdated
@@ -0,0 +1,103 @@
"""
Personally I think this short module is a bit overkill, maybe move the class and function to the utils.py module with the origin URL.
OK, I'll do that.
ot/da.py
Outdated
@@ -10,21 +10,28 @@
# License: MIT License

import numpy as np
import warnings
Is it still necessary? I think everything is in the deprecation module now.
I'll check whether it's used in the script otherwise I remove it
Oh, actually I need it in the BaseEstimator class. I will move the BaseEstimator class into utils.py. That would be more relevant.
@@ -706,12 +763,19 @@ def normalizeM(self, norm):
self.M = np.log(1 + np.log(1 + self.M))


@deprecated("The class OTDA_sinkhorn is deprecated in 0.3.1 and will be"
Love the decorator! Removal in 0.5 is soon, but probably a good thing since the old classes will disappear from the examples shortly.
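For readers unfamiliar with the pattern, here is a simplified sketch of a class-deprecation decorator in the spirit of scikit-learn's `deprecated` utility. It is an assumption for illustration, not the code in the PR's deprecation module. The `OTDA_sinkhorn` body is a hypothetical stub.

```python
import functools
import warnings


def deprecated(extra=""):
    """Class decorator: emit a DeprecationWarning each time the class is instantiated."""
    def wrap(cls):
        msg = "Class %s is deprecated" % cls.__name__
        if extra:
            msg += "; %s" % extra
        init = cls.__init__

        @functools.wraps(init)
        def wrapped_init(self, *args, **kwargs):
            warnings.warn(msg, category=DeprecationWarning, stacklevel=2)
            init(self, *args, **kwargs)

        cls.__init__ = wrapped_init
        return cls
    return wrap


@deprecated("use SinkhornTransport instead")
class OTDA_sinkhorn(object):
    def __init__(self, reg=1.0):  # hypothetical stub
        self.reg = reg


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # DeprecationWarning is hidden by default
    old = OTDA_sinkhorn(reg=0.5)
```

Wrapping `__init__` rather than replacing the class keeps the original behavior intact while flagging every use.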
@@ -0,0 +1,142 @@
# -*- coding: utf-8 -*-
You chose to move the DA examples into a sub-folder: do you think it is clearer, or is it to be similar to sk-learn? Did you remove the old ones? Every trace of the old classes should disappear from the examples.
- I think sub-categories would be clearer for the examples, and later on for the doc generation.
- I did not remove the old ones. Should I move them into the new sub-folder? @agramfort: do you have any advice on deprecating examples/documentation jointly with some code?
I think we should remove them from the examples, since that is what people look at when trying out the toolbox.
OK, I'll do that then.
yes remove deprecated code from all examples and documentation
ot/da.py
Outdated
    if p.name != 'self' and p.kind != p.VAR_KEYWORD]
for p in parameters:
    if p.kind == p.VAR_POSITIONAL:
        raise RuntimeError("scikit-learn estimators should always "
remove scikit-learn and say POT ;)
ot/da.py
Outdated
# adapted from sklearn

class BaseEstimator(object):
    """Base class for all estimators in scikit-learn
Add a POT reference and a link to the original source in sk-learn.
Oh yes, I'll remove all the scikit-learn references like this one.
ot/da.py
Outdated
class BaseTransport(BaseEstimator):
Add documentation and the list of parameters that must be set in __init__ for the class functions to run, such as metric.
I am not doing that for BaseTransport since it has no __init__ and no real parameters, but I will check the other classes so that no parameters are forgotten.
What I'm saying is that here you code the fit and transform functions, which require some parameters/properties to be set before they are called; hence those should be defined in __init__. You should state here which properties must be set for the functions to work.
This will allow new contributors to propose new classes without headaches ;)
ot/da.py
Outdated
class BaseTransport(BaseEstimator):

    def fit(self, Xs=None, ys=None, Xt=None, yt=None):
Since all input parameters default to None, we should handle the case where the user calls the function with nothing.
When a necessary parameter is missing, we should print a warning and not do anything.
Good remark, I will handle the possible exceptions
ot/da.py
Outdated
        return self

    def fit_transform(self, Xs=None, ys=None, Xt=None, yt=None):
Same here with the necessary parameters. Maybe we should code a test function with an error message that lists the necessary parameters. I'm thinking of something like:
def check_params(**kwargs):
    OK = True
    for param in kwargs:
        if kwargs[param] is None:
            OK = False
    if not OK:
        # print the list of necessary params in kwargs
        print("Necessary parameters are: " + ", ".join(kwargs))
    return OK

# the function can be called inside the classes as
if check_params(Xs=Xs, Xt=Xt):
    pass  # do the stuff
What is nice is that we can give different keyword params depending on the class (semi-supervised or otherwise).
This function, or at least a cleaner version, should be in utils.py.
ok
Sorry, but this is not what you did in the last commit; you just added a condition. I think the implementation above would be more future-proof, since it handles the warning/error message automatically with no need for an else.
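A possible cleaner version of the helper discussed above, as a hedged sketch. The name `check_params` comes from the suggestion, but this body and message are assumptions. It reports only the missing parameters and returns a boolean, so each `fit` needs a single `if` and no `else`. A semi-supervised class could simply pass `yt=yt` as well.

```python
import numpy as np


def check_params(**kwargs):
    """Return True if no keyword argument is None; otherwise report the missing ones."""
    # `is None` is an identity test, so it is safe for numpy arrays
    missing = [name for name, value in kwargs.items() if value is None]
    if missing:
        print("POT - The following parameters are required but missing: "
              + ", ".join(missing))
    return not missing


class BaseTransport(object):
    def fit(self, Xs=None, ys=None, Xt=None, yt=None):
        # unsupervised case: only Xs and Xt are required
        if check_params(Xs=Xs, Xt=Xt):
            self.Xs, self.Xt = Xs, Xt
        return self


incomplete = BaseTransport().fit(Xs=np.zeros((2, 1)))  # prints: ... missing: Xt
complete = BaseTransport().fit(Xs=np.zeros((2, 1)), Xt=np.ones((3, 1)))
```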
ot/da.py
Outdated
    The transport source samples.
    """

    if np.array_equal(self.Xs, Xs):
self.Xs is Xs
will test whether the array is the same object, avoiding the element-by-element comparison done by np.array_equal.
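Here is a hedged sketch of how the identity check could route `transform` between stored and new samples. The toy class `BarycentricTransport` and its constructor are hypothetical. The out-of-sample rule is one possible choice assumed here: shift each new point by the displacement of its nearest training source sample. It is not necessarily what the PR implements.

```python
import numpy as np


class BarycentricTransport(object):
    """Toy class: coupling_ is assumed to be already estimated by fit."""

    def __init__(self, Xs, Xt, coupling):
        self.Xs, self.Xt, self.coupling_ = Xs, Xt, coupling

    def transform(self, Xs):
        # barycentric mapping of the stored source samples onto the target
        transp = self.coupling_ / self.coupling_.sum(axis=1, keepdims=True)
        transp_Xs = transp @ self.Xt
        # `is` is an O(1) identity check; np.array_equal compares every entry
        if Xs is self.Xs:
            return transp_Xs
        # out of sample: reuse the displacement of the nearest stored sample
        d2 = ((Xs[:, None, :] - self.Xs[None, :, :]) ** 2).sum(-1)
        idx = d2.argmin(axis=1)
        return Xs + (transp_Xs - self.Xs)[idx]


Xs = np.array([[0.0, 0.0], [1.0, 0.0]])
Xt = np.array([[0.0, 1.0], [1.0, 1.0]])
tr = BarycentricTransport(Xs, Xt, np.eye(2) / 2)
in_sample = tr.transform(Xs)                   # identity branch
copy_case = tr.transform(Xs.copy())            # same values, new object
new_point = tr.transform(np.array([[0.1, 0.0]]))
```

One trade-off: a copy with identical values takes the out-of-sample branch. Here it still yields the same result, because each copied point's nearest stored sample is itself.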
OK seems good to me.
Let's sleep on it and proceed with the merge tomorrow if everyone is OK with it.
Thank you all, and especially @Slasnista, who did a wonderful job on these brand-new classes.
After the merge, we will wait for new features from @ncourty and release 0.4 on PyPI.
@@ -61,62 +50,72 @@ def mat2im(X, shape):
idx1 = np.random.randint(X1.shape[0], size=(nb,))
idx2 = np.random.randint(X2.shape[0], size=(nb,))
use RandomState
You mean a seed?
examples/da/plot_otda_d2.py
Outdated
import ot

# number of source and target points to generate
ns = 150
n_samples_source = 150
n_samples_target = 150
it makes your comment above unnecessary and the code more readable
examples/da/plot_otda_d2.py
Outdated
Xt, yt = ot.datasets.get_data_classif('3gauss2', nt)

# Cost matrix
M = ot.dist(Xs, Xt)
If it's computing a Euclidean distance, then the naming/API is not explicit enough IMO.
dist is basically a wrapper around scipy cdist where you can choose the metric.
Also it shouldn't be computed here since it's not used anywhere it seems.
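To make the metric question concrete: in the POT versions we are aware of, `ot.dist` defaults to `metric='sqeuclidean'`, i.e. squared Euclidean distances, not plain Euclidean ones. The `dist` function below is an illustrative wrapper around SciPy's `cdist` that mirrors that idea. It is not POT's implementation.

```python
import numpy as np
from scipy.spatial.distance import cdist


def dist(x1, x2=None, metric="sqeuclidean"):
    """Illustrative cdist wrapper; like ot.dist, it defaults to squared Euclidean."""
    if x2 is None:
        x2 = x1
    return cdist(x1, x2, metric=metric)


Xs = np.array([[0.0, 0.0], [1.0, 1.0]])
Xt = np.array([[0.0, 2.0]])
M_sq = dist(Xs, Xt)                      # squared Euclidean: [[4.], [2.]]
M_eu = dist(Xs, Xt, metric="euclidean")  # Euclidean: [[2.], [1.414...]]
```

Spelling out `metric=...` at the call site is one way to address the readability concern raised above.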
Actually I did compute it to plot it along with source and target points
examples/da/plot_otda_mapping.py
Outdated
|
||
n = 100  # nb samples in source and target datasets
theta = 2 * np.pi / 20
nz = 0.1
nz? Could you use more explicit names?
examples/da/plot_otda_mapping.py
Outdated
    max_iter=20, verbose=True)

ot_mapping_linear.fit(
    Xs=Xs, Xt=Xt)
it fits on one line
Final modifications according to @agramfort's comments pushed.
Sorry last change and we can merge on my side I promise ;)
examples/da/plot_otda_classes.py
Outdated
import ot

np.random.seed(42)
Actually, I think @agramfort prefers to use RandomState, as explained in
https://stackoverflow.com/questions/22994423/difference-between-np-random-seed-and-np-random-randomstate
This gives you a controllable random generator without affecting the global one, unlike np.random.seed.
Note that we plan to write proper dataset functions that accept a random state generator as a parameter to allow full control (see #21).
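Here is a short sketch of the `RandomState` approach applied to the index sampling in the example. `get_toy_data` is a hypothetical helper showing the "accept a random state" pattern. It is not the API planned in #21.

```python
import numpy as np

# a dedicated generator: reproducible draws without touching the global state
rng = np.random.RandomState(42)
idx1 = rng.randint(10, size=(5,))
idx1_again = np.random.RandomState(42).randint(10, size=(5,))  # same seed, same draws

# drawing from a RandomState does not advance the global generator
np.random.seed(0)
first = np.random.rand()
np.random.seed(0)
np.random.RandomState(1).rand(100)
second = np.random.rand()


def get_toy_data(n, random_state=None):
    """Hypothetical dataset helper accepting a seed, None or a RandomState."""
    if not isinstance(random_state, np.random.RandomState):
        random_state = np.random.RandomState(random_state)
    return random_state.randn(n, 2)
```

Accepting either a seed or a generator lets callers share one `RandomState` across several dataset calls while keeping the global NumPy state untouched.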
no worries ;)
OK, it is merged. Thank you to @ngayraud, @Slasnista and @agramfort for the work.