New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MRG] Jcpot : Multi source DA with target shift #137
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
still a few things to handle the PR is looking good.
README.md
Outdated
@@ -29,6 +29,7 @@ It provides the following solvers: | |||
* Non regularized free support Wasserstein barycenters [20]. | |||
* Unbalanced OT with KL relaxation distance and barycenter [10, 25]. | |||
* Screening Sinkhorn Algorithm for OT [26]. | |||
* JCPOT algorithm for multi-source target shift [27]. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Multi source domain adaptation with target shift
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
ot/bregman.py
Outdated
|
||
# build the cost matrix and the Gibbs kernel | ||
M = dist(Xs[d], Xt, metric=metric) | ||
M = M / np.median(M) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
normalization by median should be optional.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
removed median and adjusted parameters in the tests accordingly
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is looking good. just a few more comments and we are done.
Thanks again @ievred
@@ -549,3 +547,57 @@ def test_linear_mapping_class(): | |||
Cst = np.cov(Xst.T) | |||
|
|||
np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2) | |||
|
|||
|
|||
def test_jcpot_transport_class(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
could you also write a test for the jcpot function in addition to the class?
It's nice to chgeck both since the interface is available for both. you can also use it on another test dataset with obvious solutions (repeating samples and known proportions for instance).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
class JCPOTTransport(BaseTransport): | ||
|
||
"""Domain Adapatation OT method for multi-source target shift based on Wasserstein barycenter algorithm. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you add to the documentation what kind of mzpping is used? barycentric it seems but you could also so label prop by keeping the target position and providong non binary one hot encoding no? This couls be a parameter to give to the method.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
added label prop for base class and jcpot + tests for all otda methods
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
very nice!
ot/bregman.py
Outdated
The problem consist in solving a Wasserstein barycenter problem to estimate the proportions :math:`\mathbf{h}` in the target domain. | ||
|
||
The algorithm used for solving the problem is the Iterative Bregman projections algorithm | ||
with two sets of marginal constraints related to the unknown vector :math:`\mathbf{h}` and uniform tarhet distribution. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
target distribution
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
corrected
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is very nice.
Just a few more comments related mostly to the new transfer_labels functions
|
||
bary = bary / np.sum(bary) | ||
|
||
if log: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hum, do we really need to return the gammas by default?
The name of the function is barycenter which is the weights h in this case.
I would put the gammas in the log or at least second position.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
as far as I understand we have to return gamma as in documentation of the BaseTransport of da.py it is said that: "fit method should estimate a coupling matrix and store it in a coupling_ attribute"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
ot/da.py
Outdated
# compute transported samples | ||
transp_ys = np.dot(D1, transp) | ||
|
||
return np.argmax(transp_ys, axis=0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why return argmax? I woul return the smooth label estimations and let the user do the argmax if he really wants one label.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done, changed documentation to "soft labels"
ot/da.py
Outdated
# compute transported samples | ||
transp_ys = np.dot(D1, transp.T) | ||
|
||
return np.argmax(transp_ys, axis=0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done, changed documentation to "soft labels"
ot/da.py
Outdated
transp[~ np.isfinite(transp)] = 0 | ||
|
||
# compute transported labels | ||
transp_ys.append(np.argmax(np.dot(D1, transp.T), axis=0)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same argmax comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done, changed documentation to "soft labels"
Added jcpot class in the da.py, jcpot_barycenter with the optimization routine, an example and a test