Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] Jcpot : Multi source DA with target shift #137

Merged
merged 29 commits into from Apr 15, 2020
Merged

Conversation

ievred
Copy link
Contributor

@ievred ievred commented Mar 31, 2020

Added jcpot class in the da.py, jcpot_barycenter with the optimization routine, an example and a test

@rflamary rflamary changed the title Jcpot [WIP] Jcpot Mar 31, 2020
Copy link
Collaborator

@rflamary rflamary left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

still a few things to handle the PR is looking good.

README.md Outdated
@@ -29,6 +29,7 @@ It provides the following solvers:
* Non regularized free support Wasserstein barycenters [20].
* Unbalanced OT with KL relaxation distance and barycenter [10, 25].
* Screening Sinkhorn Algorithm for OT [26].
* JCPOT algorithm for multi-source target shift [27].
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Multi source domain adaptation with target shift

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

ot/bregman.py Outdated

# build the cost matrix and the Gibbs kernel
M = dist(Xs[d], Xt, metric=metric)
M = M / np.median(M)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

normalization by median should be optional.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed median and adjusted parameters in the tests accordingly

Copy link
Collaborator

@rflamary rflamary left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is looking good. just a few more comments and we are done.

Thanks again @ievred

@@ -549,3 +547,57 @@ def test_linear_mapping_class():
Cst = np.cov(Xst.T)

np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2)


def test_jcpot_transport_class():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you also write a test for the jcpot function in addition to the class?

It's nice to chgeck both since the interface is available for both. you can also use it on another test dataset with obvious solutions (repeating samples and known proportions for instance).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

class JCPOTTransport(BaseTransport):

"""Domain Adapatation OT method for multi-source target shift based on Wasserstein barycenter algorithm.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add to the documentation what kind of mzpping is used? barycentric it seems but you could also so label prop by keeping the target position and providong non binary one hot encoding no? This couls be a parameter to give to the method.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added label prop for base class and jcpot + tests for all otda methods

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

very nice!

ot/bregman.py Outdated
The problem consist in solving a Wasserstein barycenter problem to estimate the proportions :math:`\mathbf{h}` in the target domain.

The algorithm used for solving the problem is the Iterative Bregman projections algorithm
with two sets of marginal constraints related to the unknown vector :math:`\mathbf{h}` and uniform tarhet distribution.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

target distribution

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

corrected

@rflamary rflamary changed the title [WIP] Jcpot [WIP] Jcpot : Multi source DA with target shift Apr 10, 2020
Copy link
Collaborator

@rflamary rflamary left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is very nice.

Just a few more comments related mostly to the new transfer_labels functions


bary = bary / np.sum(bary)

if log:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hum, do we really need to return the gammas by default?

The name of the function is barycenter which is the weights h in this case.
I would put the gammas in the log or at least second position.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as far as I understand we have to return gamma as in documentation of the BaseTransport of da.py it is said that: "fit method should estimate a coupling matrix and store it in a coupling_ attribute"

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

ot/da.py Outdated
# compute transported samples
transp_ys = np.dot(D1, transp)

return np.argmax(transp_ys, axis=0)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why return argmax? I woul return the smooth label estimations and let the user do the argmax if he really wants one label.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, changed documentation to "soft labels"

ot/da.py Outdated
# compute transported samples
transp_ys = np.dot(D1, transp.T)

return np.argmax(transp_ys, axis=0)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same comment

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, changed documentation to "soft labels"

ot/da.py Outdated
transp[~ np.isfinite(transp)] = 0

# compute transported labels
transp_ys.append(np.argmax(np.dot(D1, transp.T), axis=0))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same argmax comment

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, changed documentation to "soft labels"

@rflamary rflamary changed the title [WIP] Jcpot : Multi source DA with target shift [MRG] Jcpot : Multi source DA with target shift Apr 15, 2020
@rflamary rflamary merged commit adc5570 into PythonOT:master Apr 15, 2020
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants