
Domain adaptation Classes #22


Merged: 29 commits merged into PythonOT:master on Aug 29, 2017

Conversation

@Slasnista (Contributor, PR author)

  • first proposal of a DA class structure
  • BaseEstimator: OTDA wrapper (does not work as a stand-alone but implements the methods common to any OTDA algorithm)
  • SinkhornTransport: implements the Sinkhorn algorithm for OTDA
  • tried to make docstrings compliant with numpy requirements

@rflamary (Collaborator)

Thank you @Slasnista, I will have a look this weekend or Monday.

Maybe creating a compliant BaseEstimator class as @agramfort proposed is a nice way to avoid the sklearn dependency.

@rflamary (Collaborator)

Hello, I left a few comments in response to @ngayraud.

@Slasnista I think this is a good start but we need to agree on a typical use case. Below is how I see the class working:

tr=EMDTransport(out_of_sample_map='ferradans',weight_est='uniform')

tr.fit(Xs,Xt) # learn G matrix and store it, also store data

Xs_trans=tr.transform(Xs) # barycentric mapping
Xs_trans=tr.transform() # barycentric mapping with stored data
Xs2_trans=tr.transform(Xs2) # barycentric mapping + Ferradans since out of sample

@Slasnista (Contributor, PR author)

@rflamary thanks for the comments. I took them into account and I will update my PR.

For the use cases, I agree with you on most of them. Yet even if it's a bit redundant, it looks weird not to pass data to the transform method.

So if we opt for this solution, we need to write a detailed example to explain the differences between the different versions of transform. I will look at your examples to update them according to the class structure.

@rflamary (Collaborator)

@Slasnista OK, on second thought I agree with you, it's weird with no params. Let's not do it.

Still, I think it is very important to automatically handle the already-seen Xs and the new out-of-sample Xs2.

Rémi

@Slasnista (Contributor, PR author)

Changes:

  • as @agramfort suggested, I removed the dependency on sklearn's BaseEstimator
  • also removed the dependency on sklearn's pairwise_distance, as @ngayraud proposed
  • added in __init__ a parameter density_estimation which takes a function as input (default: uniform)
  • stored the Xs and Xt used for coupling estimation as attributes
  • modified transform and inverse_transform: they check whether their respective input has already been seen; if yes, they apply barycentric mapping, otherwise they will apply out-of-sample mapping (not implemented yet, see the sketch below)
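A minimal sketch of that dispatch logic (illustrative only, not the PR code; BaseTransportSketch is a hypothetical stand-in and coupling_, Xs and Xt are assumed to be set by fit):

    import numpy as np

    class BaseTransportSketch(object):
        """Illustrative only: coupling_, Xs and Xt are assumed to be set by fit."""

        def transform(self, Xs=None):
            if np.array_equal(self.Xs, Xs):
                # samples seen at fit time: barycentric mapping through the coupling
                transp = self.coupling_ / np.sum(self.coupling_, axis=1)[:, None]
                return np.dot(transp, self.Xt)
            else:
                # new samples: out-of-sample mapping, not implemented yet
                raise NotImplementedError("out-of-sample mapping not implemented yet")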

@rflamary (Collaborator) left a comment

This is nice, just a few more details and we have a workable class.

Think about adding some tests to test_da.py.

ot/da.py Outdated
##############################################################################
# adapted from scikit-learn

import warnings
Collaborator:

feel free to put it at the top of the file

ot/da.py Outdated
# ``if type(self).__module__.startswith('sklearn.')``.


def distribution_estimation_uniform(X):
Collaborator:

note that there already exists the function ot.utils.unif that does exactly the same thing

maybe a simple

from .utils import unif
distribution_estimation_uniform=unif

would do the trick

Collaborator:

Ah no, it does not work, but you can still use unif at the return line.
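For instance, a minimal version keeping the helper but using unif at the return line (a sketch for ot/da.py):

    from .utils import unif

    def distribution_estimation_uniform(X):
        """Estimate a uniform distribution over the rows of X."""
        return unif(X.shape[0])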

ot/da.py Outdated

# store arrays of samples
self.Xs = Xs
self.Xt = Xt
Collaborator:

Remove sinkhorn from the BaseEstimator and put it in SinkhornTransport please.

Keep all the rest so that we can call super and still have all the storing and density estimation.
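A rough sketch of what that split could look like inside ot/da.py (assuming BaseTransport and the attribute names mu_s, mu_t and cost_ from this thread; not the final API):

    from .bregman import sinkhorn

    class SinkhornTransport(BaseTransport):

        def __init__(self, reg_e=1.):
            self.reg_e = reg_e

        def fit(self, Xs=None, ys=None, Xt=None, yt=None):
            # generic part (storing samples, density estimation, cost matrix)
            # is inherited from BaseTransport
            super(SinkhornTransport, self).fit(Xs, ys, Xt, yt)

            # algorithm-specific part: Sinkhorn coupling estimation
            self.coupling_ = sinkhorn(a=self.mu_s, b=self.mu_t,
                                      M=self.cost_, reg=self.reg_e)
            return self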

ot/da.py Outdated
@@ -1114,7 +1255,10 @@ class SinkhornTransport(BaseTransport):

def __init__(self, reg_e=1., mode="unsupervised", max_iter=1000,
tol=10e-9, verbose=False, log=False, mapping="barycentric",
metric="sqeuclidean", distribution="uniform"):
metric="sqeuclidean",
distribution_estimation=distribution_estimation_uniform,
Contributor:

OK, good. It is much simpler and gives the user full freedom with the distribution estimation.

ot/da.py Outdated
self.method = "sinkhorn"
self.out_of_sample_map = out_of_sample_map

def fit(self, Xs=None, ys=None, Xt=None, yt=None):
"""Build a coupling matrix from source and target sets of samples
Contributor:

I see that in every fit and transform, all of the parameters are initialized to None. However, I think that at least Xs (and Xt for the inverse transform) should be required.

@Slasnista (Contributor, PR author)

Hi,

  • I added a test function for SinkhornTransport() (see the sketch below)
  • I slightly changed the SinkhornTransport() class. There are still some changes needed but we are converging.
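A minimal sketch of what such a test could look like, assuming the SinkhornTransport API and the coupling_ attribute discussed in this thread together with the existing get_data_classif helper:

    import numpy as np
    import ot

    def test_sinkhorn_transport_class():
        ns, nt = 50, 50
        Xs, ys = ot.datasets.get_data_classif('3gauss', ns)
        Xt, yt = ot.datasets.get_data_classif('3gauss2', nt)

        clf = ot.da.SinkhornTransport()
        clf.fit(Xs=Xs, Xt=Xt)

        # the estimated coupling should be a valid transport plan between
        # the (uniform) estimated marginals
        assert clf.coupling_.shape == (ns, nt)
        np.testing.assert_allclose(clf.coupling_.sum(axis=1), ot.unif(ns), rtol=1e-3)
        np.testing.assert_allclose(clf.coupling_.sum(axis=0), ot.unif(nt), rtol=1e-3)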

Here are some questions:

  • could we move the test folder into the ot folder?
  • why are there so many things in the __init__ of the ot folder?
  • could we change the names of parameters and variables in the da functions? (I fear I, G, M... might be tricky for debugging in a few months)
  • do you have a way to check the rendering of docstrings? I'd like to make the new class docstrings compliant with your guidelines now, so that we won't have to spend extra sessions documenting code ;)

@rflamary (Collaborator) left a comment

OK, thank you @Slasnista, it's indeed starting to converge and I see that the tests are not failing anymore!

Before merging I would like to see at least EMDTransport() implemented. I can take care of the classes with linear and kernel mapping estimation if you want.

I added a little comment and I will answer your questions:

  • Not in this PR, we cannot. This can be discussed in the future, but this PR focuses on a new and improved class for OT.
  • A lot of things are imported because, as for numpy, I want all the base OT functions (emd, emd2, sinkhorn, sinkhorn2, dist) to be easily available with no need to import a submodule.
  • Variable name changes will not be done in this PR either, because it is a full breaker for the legacy classes and all examples. It is open for discussion, but note that ALL the OT functions in the toolbox use this naming and consistency is IMHO very important. Finally, G/gamma/T are standard names for an OT matrix in the OT community and M/C are standard names for cost matrices.
  • In the Spyder editor, you can press Ctrl-I to open the documentation of the current function (just click on its definition first). The documentation can be seen in raw form or with rendering.

test/test_da.py Outdated
np.random.seed(42)


def test_sinkhorn_transport():
Collaborator:

maybe test_sinkhorn_transport_class would be a better name

Contributor (PR author):

Thank you for the reply, that works for me. I will modify the name of the test function.

@ngayraud (Contributor)

ngayraud commented Aug 1, 2017

Here is a bit of code you can add for a simple EMD transport. I hope I have it right.
Also, from what I remember, if we set log = True, the methods return the log as well, so we need an "if log" statement (see the code below).

class EMDTransport(BaseTransport):
    """Domain Adapatation OT method based on Earth Mover's Distance
    Parameters
    ----------
    mode : string, optional (default="unsupervised")
        The DA mode. If "unsupervised" no target labels are taken into account
        to modify the cost matrix. If "semisupervised" the target labels
        are taken into account to set coefficients of the pairwise distance
        matrix to 0 for row and columns indices that correspond to source and
        target samples which share the same labels.
    mapping : string, optional (default="barycentric")
        The kind of mapping to apply to transport samples from a domain into
        another one.
        if "barycentric" only the samples used to estimate the coupling can
        be transported from a domain to another one.
    metric : string, optional (default="sqeuclidean")
        The ground metric for the Wasserstein problem
    distribution : string, optional (default="uniform")
        The kind of distribution estimation to employ
    verbose : int, optional (default=0)
        Controls the verbosity of the optimization algorithm
    log : int, optional (default=0)
        Controls the logs of the optimization algorithm
    Attributes
    ----------
    References
    ----------
    .. [1] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy,
           "Optimal Transport for Domain Adaptation," in IEEE Transactions
           on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
    """

    def __init__(self, mode="unsupervised", verbose=False, 
                 log=False, metric="sqeuclidean",
                 distribution_estimation=distribution_estimation_uniform,
                 out_of_sample_map='ferradans'):

        self.mode = mode
        self.verbose = verbose
        self.log = log
        self.metric = metric
        self.distribution_estimation = distribution_estimation
        self.out_of_sample_map = out_of_sample_map

    def fit(self, Xs, ys=None, Xt=None, yt=None):
        """Build a coupling matrix from source and target sets of samples
        (Xs, ys) and (Xt, yt)
        Parameters
        ----------
        Xs : array-like of shape = [n_source_samples, n_features]
            The training input samples.
        ys : array-like, shape = [n_source_samples]
            The class labels
        Xt : array-like of shape = [n_target_samples, n_features]
            The training input samples.
        yt : array-like, shape = [n_labeled_target_samples]
            The class labels
        Returns
        -------
        self : object
            Returns self.
        """

        self = super(EMDTransport, self).fit(Xs, ys, Xt, yt)

        # coupling estimation
        if self.log:
            self.gamma_, self.log_ = emd(
                a=self.mu_s, b=self.mu_t, M=self.Cost,
                verbose=self.verbose, log=self.log)
        else:
            self.gamma_ = emd(
                a=self.mu_s, b=self.mu_t, M=self.Cost,
                verbose=self.verbose, log=self.log)

        return self

@Slasnista (Contributor, PR author)

@ngayraud thank you for the code! I'll adapt it and put it inside the PR ;)

Just wondering if we could solve it like this:

# coupling estimation
returned_ = emd(
    a=self.mu_s, b=self.mu_t, M=self.Cost,
    verbose=self.verbose, log=self.log)
if self.log:
    self.coupling_, self.log_ = returned_
else:
    self.coupling_ = returned_

That does not change much, but we no longer duplicate the coupling estimation call.

Btw, do all the other methods return the log in addition to the coupling?

@rflamary (Collaborator)

rflamary commented Aug 1, 2017 via email

@Slasnista (Contributor, PR author)

@rflamary thanks, the test gave me this error ;) I removed the log and verbose inputs to emd for the moment.

@ngayraud (Contributor)

ngayraud commented Aug 1, 2017

I think that all the Sinkhorn-based transports return a log! I was unsure about the emd one though.

@Slasnista (Contributor, PR author)

Hi,
I did some updates to the PR:

  • added a test for the fit_transform method and fixed a bug in this method (it did not return self)
  • added 2 new classes with class regularization: SinkhornLpl1Transport() and SinkhornL1l2Transport()
  • added a few lines to the .fit() method of the BaseTransport() class to support semi-supervised domain adaptation. This implies that the labeled source (resp. target) samples are in the first rows (resp. columns). Also added a test to check that when semi-supervised is selected but no labels are given (neither ys, yt, nor both), unsupervised DA is performed instead of semi-supervised

What do you think about these new modifications?

@rflamary (Collaborator)

rflamary commented Aug 4, 2017

Hello @Slasnista,

Those are nice new features.

I just have a few questions/todos:

  • Why is the mode parameter necessary? I think the fact that ys and yt are set to None is enough to know whether semi-supervised should be used. No need to have two signals that can contradict each other...

  • Also, what is that about the supervised samples being the first ones? Couldn't we use ys/yt to define the supervised samples (0..N) and the unsupervised ones (-1)?

  • Semi-supervised as in the PAMI paper does not work with a 0 cost when the class is the same. It uses an infinite (or very large) cost when the labels are different. This is a very different approach based on DO-NOT-LINK constraints and still provides OT inside classes, which is not done when the cost is 0.

  • Finally, I think the last piece before merging is coding a proper out-of-sample mapping as done here:
    https://github.com/Slasnista/POT/blob/0b005906f9d78adbf4d52d2ea9610eb3fde96a7c/ot/da.py#L712
    Note that it's just a few lines, but one should perform it on minibatches in order to limit the memory requirement, and hence it should have a loop.

  • I know I am never happy, but the interpolation (linear and kernel regression) mapping methods are also missing.

  • Once all the classes are done properly, feel free to add a warning to the OTDA class that tells about the future deprecation and the new classes. I will add both of you to the contributor list BTW.

This will be a very nice merge, thank you again for all the work.

@Slasnista (Contributor, PR author)

Slasnista commented Aug 4, 2017

@rflamary:

  1. OK for me, this way it will not be too redundant.
  2. OK, the way it is written already works independently of the position of the labeled samples, so nothing to change.
  3. OK, but how do you set this infinite value? 10**2, 10**3, something based on the max value of the cost matrix? Do you have a predefined rule for doing this?

I'll address the remaining points in a second pass ;)

@rflamary (Collaborator)

rflamary commented Aug 4, 2017

@Slasnista nice to see we agree. Point 3 is the tough one ;).

For Sinkhorn and group-lasso regularization we set it to infinity (np.inf) and, under constraint qualification conditions (there exists a solution that is not infinite), the algorithm converges to a solution (with zero mass on the do-not-link links).

For the emd solver (and the conjugate gradient solver that uses emd) I don't think the solver is going to appreciate infinity, so I would use, as you propose, the max value of the cost matrix plus something, maybe epsilon * max(Cost). I think it should be a parameter anyway, with the max/inf taken if it is set to None (see the sketch below).
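A minimal sketch of such a label-based cost limiting; limit_cost_with_labels is a hypothetical helper (not an existing POT function) and unlabeled target samples are assumed to be marked with -1:

    import numpy as np

    def limit_cost_with_labels(M, ys, yt, limit_max=None):
        """Raise the cost between samples with different known labels
        (do-not-link constraints). If limit_max is None, use a large finite
        value derived from M, which is safer for the exact emd solver than
        np.inf."""
        if limit_max is None:
            limit_max = 1e2 * M.max()
        M = M.copy()
        for c in np.unique(ys):
            idx_s = np.where(ys == c)[0]
            # labeled target samples whose class differs from c (-1 = unlabeled)
            idx_t = np.where((yt != c) & (yt != -1))[0]
            for j in idx_t:
                M[idx_s, j] = limit_max
        return M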

Take your time for the remaining work, I will be on vacation next week and probably won't find time to do a proper final review anyway.

@Slasnista (Contributor, PR author)

I completed all the modifications we talked about.

I think we should seize this opportunity to think about the algorithms' outputs. In my opinion:

  1. functions which compute couplings and/or mappings should always return logs_.
  2. functions which compute couplings and mappings should always return coupling_, mapping_, bias_, logs_. I would prefer to return a zero vector instead of having an output of varying size depending on the bias.

Point 2 is not a burning issue but point 1 is.

Right now, instantiating an object with logs=True results in an error. I propose to remove logs from __init__ and to let all .fit methods create an attribute .logs_ which will contain the logs. What do you think about it?

@rflamary (Collaborator)

Hello @Slasnista,

Thank you for the new commits that handle all the things we discussed. I will do a proper code review, but in the meantime you need to update the examples in the examples folder that use the old classes:

  • plot_OTDA_classes.py
  • plot_OTDA_color_images.py
  • plot_OTDA_mapping.py
  • plot_OTDA_mapping_color_images.py

We don't want the examples to be deprecated ;) Once this is done I will take care of updating the notebooks, which is not done automatically yet.

About your questions: there should be a log parameter in the __init__ of the classes (note that as a parameter it is log and not logs everywhere in the toolbox; log_ should be fine for storing the log) and not in fit. Depending on this parameter, the fit method should update the log_ attribute of the class. This attribute should exist all the time and be an empty dictionary when log is set to False (maybe it should be set in the __init__ of the Base class?). See the sketch below.

I understand that the varying-size return of the algorithms is surprising, but it leads to a very simple use of the functions while still allowing expert users to get more information if needed. Changing the behavior of the algorithms now would completely break the toolbox and, again, is not the object of this PR.
The reason for this is that computing the log can be time consuming since more things have to be returned (and the user should be able to choose whether or not to compute it).
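A minimal sketch of that log handling inside ot/da.py, reusing the assumed names from the earlier sketch (BaseTransport, mu_s, mu_t, cost_); illustrative only:

    from .bregman import sinkhorn

    class SinkhornTransport(BaseTransport):

        def __init__(self, reg_e=1., log=False):
            self.reg_e = reg_e
            self.log = log

        def fit(self, Xs=None, ys=None, Xt=None, yt=None):
            super(SinkhornTransport, self).fit(Xs, ys, Xt, yt)

            returned_ = sinkhorn(a=self.mu_s, b=self.mu_t, M=self.cost_,
                                 reg=self.reg_e, log=self.log)

            # log_ always exists: filled when log=True, empty dict otherwise
            if self.log:
                self.coupling_, self.log_ = returned_
            else:
                self.coupling_ = returned_
                self.log_ = dict()
            return self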

@ncourty (Collaborator)

ncourty commented Aug 25, 2017

Hi all, entering the discussion after some holiday break :)
@rflamary, regarding the notebooks, since we decided to go full sphinx-gallery for the examples, it may be the right time to consider re-writing them in the Sphinx syntax?

@agramfort (Collaborator)

agramfort commented Aug 25, 2017 via email

@Slasnista (Contributor, PR author)

Hi @rflamary,

  • there is no error anymore when log=True
  • I created a .log_ attribute that is either filled or instantiated as an empty dictionary, depending on whether log is set to True
  • for methods without a log parameter (emd, lpl1) I removed log from the parameters in __init__
  • I added some tests to check that log_ is filled when log=True.

@rflamary (Collaborator)

Hello everyone,

@ncourty since you are back from holiday and full of energy, maybe you could take care of the test file formatting ;)? As a POT maintainer you should be able to push commits to the PR.

@Slasnista thank you, everything looks great. I will do a proper code review soon and I think we will be able to merge shortly.

@Slasnista (Contributor, PR author)

Done updating the examples to use the new OTDA objects!

@rflamary you're welcome, I now look forward to seeing it merged ;-)

@rflamary (Collaborator) left a comment

OK, just a few changes and a go from Travis, and I think we can merge and do a new POT version.

@Slasnista don't forget to add yourself and @ngayraud to the contributors of the toolbox in the readme file.

@@ -0,0 +1,103 @@
"""
Collaborator:

Personally I think this short module is a bit overkill; maybe move the class and function to the utils.py module with the original URL.

Contributor (PR author):

OK, I'll do that.

ot/da.py Outdated
@@ -10,21 +10,28 @@
# License: MIT License

import numpy as np
import warnings
Collaborator:

is it still necessary? I think everything is in deprecated now.

Contributor (PR author):

I'll check whether it's used in the script, otherwise I'll remove it.

Contributor (PR author):

Oh, actually I need it in the BaseEstimator class.

I will move the BaseEstimator class into utils.py. That would be more relevant.

@@ -706,12 +763,19 @@ def normalizeM(self, norm):
self.M = np.log(1 + np.log(1 + self.M))


@deprecated("The class OTDA_sinkhorn is deprecated in 0.3.1 and will be"
Collaborator:

Love the decorator! Removal in 0.5 is short notice, but probably a good thing since it will disappear from the examples shortly.

@@ -0,0 +1,142 @@
# -*- coding: utf-8 -*-
Collaborator:

You chose to move the DA examples into a sub-folder: do you think it is clearer, or is it to be similar to sk-learn? Did you remove the old ones? Every trace of the old classes should disappear from the examples.

Contributor (PR author):

  • I think sub-categories would be clearer for the examples and later on for the doc generation
  • I did not remove the old ones, should I move them into the new sub-folder? @agramfort: do you have some advice on deprecating examples / documentation jointly with some code?

Collaborator:

I think we should remove them from the examples since that is what people look at when trying out the toolbox.

Contributor (PR author):

OK, I'll do that then.

Collaborator:

Yes, remove deprecated code from all examples and documentation.

ot/da.py Outdated
if p.name != 'self' and p.kind != p.VAR_KEYWORD]
for p in parameters:
if p.kind == p.VAR_POSITIONAL:
raise RuntimeError("scikit-learn estimators should always "
Collaborator:

remove scikit-learn and say POT ;)

ot/da.py Outdated
# adapted from sklearn

class BaseEstimator(object):
"""Base class for all estimators in scikit-learn
Collaborator:

add a POT reference and a link to the original source in sk-learn

Contributor (PR author):

Oh yes, I'll remove all the scikit refs like this one.

ot/da.py Outdated


class BaseTransport(BaseEstimator):

Collaborator:

Add documentation and the list of necessary parameters in __init__ for the class to run its functions, such as metric.

@Slasnista (Contributor, PR author) commented on Aug 28, 2017:

I won't do that for BaseTransport since it has no __init__ and no real parameters, but I will check the others so as not to forget parameters.

Collaborator:

What I'm saying is that you code here fit and transform functions that require some parameters/properties to be set before being called, hence those should be defined in __init__. You should state here which properties must be set in order for the functions to work.

Collaborator:

This is to allow new contributors to propose new classes with no headaches ;)

ot/da.py Outdated

class BaseTransport(BaseEstimator):

def fit(self, Xs=None, ys=None, Xt=None, yt=None):
Collaborator:

Since we give all None values to the input parameters, we should handle those in case the user just calls the function with nothing.

When a necessary parameter is missing we should print a warning and not do anything.

Contributor (PR author):

Good remark, I will handle the possible exceptions

ot/da.py Outdated

return self

def fit_transform(self, Xs=None, ys=None, Xt=None, yt=None):
Collaborator:

Same here with the necessary parameters. Maybe we should code a check function with an error message that lists the necessary parameters, I'm thinking of something like:

def check_params(**kwargs):
    OK = True
    for param in kwargs:
        if kwargs[param] is None:
            OK = False

    if not OK:
        # print the list of necessary params that are missing
        missing = [p for p in kwargs if kwargs[p] is None]
        print("missing necessary parameters: {}".format(missing))

    return OK


# the function can be called inside the classes as
if check_params(Xs=Xs, Xt=Xt):
    pass  # do the stuff

What is nice is that we can give different keyword params depending on the class (semi-supervised or not).

This function, or at least a cleaner version, should be in utils.py.

Contributor (PR author):

ok

Collaborator:

Sorry, but this is not what you did in the last commit, you just added a condition. I think the implementation above would be more future-proof since it handles the warning/error message automatically, no need for an else.

ot/da.py Outdated
The transport source samples.
"""

if np.array_equal(self.Xs, Xs):
Collaborator:

self.Xs is Xs will test whether the array is the same object, avoiding the multiple tests and computation done by np.array_equal.
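A small standalone illustration of the difference (not code from the PR):

    import numpy as np

    Xs = np.random.rand(1000, 2)
    Xs_view = Xs          # same object
    Xs_copy = Xs.copy()   # equal values, different object

    print(Xs_view is Xs)                 # True: cheap identity check
    print(Xs_copy is Xs)                 # False: identity misses equal copies
    print(np.array_equal(Xs_copy, Xs))   # True, but compares every entry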

@rflamary (Collaborator) left a comment

OK, seems good to me.

Let's sleep on it and proceed with the merge tomorrow if everyone is OK with it.

Thank you to all of you and especially @Slasnista who did a wonderful job on these brand new classes.

After the merge we will wait for new features from @ncourty and release 0.4 on PyPI.

@@ -61,62 +50,72 @@ def mat2im(X, shape):
idx1 = np.random.randint(X1.shape[0], size=(nb,))
idx2 = np.random.randint(X2.shape[0], size=(nb,))
Collaborator:

use RandomState

Contributor (PR author):

You mean a seed?

import ot

# number of source and target points to generate
ns = 150
Collaborator:

n_samples_source = 150
n_samples_target = 150

it makes your comment above unnecessary and the code more readable

Xt, yt = ot.datasets.get_data_classif('3gauss2', nt)

# Cost matrix
M = ot.dist(Xs, Xt)
Collaborator:

if it's doing euclidean distance then the naming / API is not explicit enough IMO

Collaborator:

dist is basically a wrapper around scipy's cdist where you can choose the metric.

Also, it shouldn't be computed here since it doesn't seem to be used anywhere.
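For reference, a short standalone example of how dist is typically called (the metric argument follows scipy's cdist naming):

    import numpy as np
    import ot

    Xs = np.random.rand(10, 2)
    Xt = np.random.rand(15, 2)

    M = ot.dist(Xs, Xt)                           # squared euclidean by default
    M_euc = ot.dist(Xs, Xt, metric='euclidean')   # any scipy cdist metric works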

Contributor (PR author):

Actually I did compute it to plot it along with the source and target points.


n = 100 # nb samples in source and target datasets
theta = 2 * np.pi / 20
nz = 0.1
Collaborator:

nz? Could you use more explicit names?

max_iter=20, verbose=True)

ot_mapping_linear.fit(
Xs=Xs, Xt=Xt)
Collaborator:

it fits on one line

@Slasnista (Contributor, PR author)

Final modifications according to @agramfort's comments have been pushed.

@rflamary (Collaborator) left a comment

Sorry, last change and we can merge on my side, I promise ;)

import ot

np.random.seed(42)
Collaborator:

Actually I think @agramfort prefers to use RandomState, as explained in
https://stackoverflow.com/questions/22994423/difference-between-np-random-seed-and-np-random-randomstate

This allows you to use a controllable random generator without impacting the global one, as np.random.seed does.

Note that we plan on doing proper dataset functions that accept a random state generator as a parameter to allow full control (see #21).
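A short standalone illustration of the suggestion:

    import numpy as np

    rng = np.random.RandomState(42)   # local generator, global state untouched
    X1 = rng.rand(100, 2)
    idx1 = rng.randint(X1.shape[0], size=(10,))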

Contributor (PR author):

no worries ;)

@rflamary merged commit a2ec6e5 into PythonOT:master Aug 29, 2017
@rflamary (Collaborator)

OK it is merged, thank you to @ngayraud @Slasnista and @agramfort for the work.

@Slasnista mentioned this pull request Aug 30, 2017