-
Notifications
You must be signed in to change notification settings - Fork 528
Domain adaptation Classes #22
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
29 commits
Select commit
Hold shift + click to select a range
553a456
remove linewidth error message
Slasnista ca9c9d6
first proposal for OT wrappers
Slasnista 84adadd
small modifs according to NG proposals
Slasnista cd3397f
integrate AG comments
Slasnista bd7c7d2
own BaseEstimator class written + rflamary comments addressed
Slasnista 122b5bf
update SinkhornTransport class + added test for class
Slasnista d9be6c2
added EMDTransport Class from NG's code + added dedicated test
Slasnista 70be034
added test for fit_transform + correction of fit_transform bug (missi…
Slasnista 64880e7
added new class SinkhornLpl1Transport() + dedicated test
Slasnista 727077a
added new class SinkhornL1l2Transport() + dedicated test
Slasnista 0b00590
semi supervised mode supported
Slasnista d793f1f
correction of semi supervised mode
Slasnista 778f4f7
reformat doc strings + remove useless log / verbose parameters for emd
Slasnista 738bfb1
out of samples by Ferradans supported for transform and inverse_trans…
Slasnista 8a21429
added new class MappingTransport to support linear and kernel mapping…
Slasnista 8149e05
make doc strings compliant with numpy / modif according to AG review
Slasnista 791a4a6
out of samples transform and inverse transform by batch
Slasnista 326d163
test functions for MappingTransport Class
Slasnista 0930223
added deprecation warning on old classes
Slasnista 2d4d0b4
solving log issues to avoid errors and adding further tests
Slasnista 74ca2d7
refactoring examples according to new DA classes
Slasnista f80693b
small corrections for examples
Slasnista 892d7ce
set properly path of data
Slasnista 55840f6
move no da objects into utils.py
Slasnista a8fa91b
handling input arguments in fit, transform... methods + remove old ex…
Slasnista c5d7c40
check input parameters with helper functions
Slasnista 7d3fc95
update readme
Slasnista a29e22d
addressed AG comments + adding random seed
Slasnista 65de6fc
pass on examples | introduced RandomState
Slasnista File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
# -*- coding: utf-8 -*- | ||
""" | ||
======================== | ||
OT for domain adaptation | ||
======================== | ||
|
||
This example introduces a domain adaptation in a 2D setting and the 4 OTDA | ||
approaches currently supported in POT. | ||
|
||
""" | ||
|
||
# Authors: Remi Flamary <remi.flamary@unice.fr> | ||
# Stanislas Chambon <stan.chambon@gmail.com> | ||
# | ||
# License: MIT License | ||
|
||
import matplotlib.pylab as pl | ||
import ot | ||
|
||
|
||
############################################################################## | ||
# generate data | ||
############################################################################## | ||
|
||
n_source_samples = 150 | ||
n_target_samples = 150 | ||
|
||
Xs, ys = ot.datasets.get_data_classif('3gauss', n_source_samples) | ||
Xt, yt = ot.datasets.get_data_classif('3gauss2', n_target_samples) | ||
|
||
|
||
############################################################################## | ||
# Instantiate the different transport algorithms and fit them | ||
############################################################################## | ||
|
||
# EMD Transport | ||
ot_emd = ot.da.EMDTransport() | ||
ot_emd.fit(Xs=Xs, Xt=Xt) | ||
|
||
# Sinkhorn Transport | ||
ot_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1) | ||
ot_sinkhorn.fit(Xs=Xs, Xt=Xt) | ||
|
||
# Sinkhorn Transport with Group lasso regularization | ||
ot_lpl1 = ot.da.SinkhornLpl1Transport(reg_e=1e-1, reg_cl=1e0) | ||
ot_lpl1.fit(Xs=Xs, ys=ys, Xt=Xt) | ||
|
||
# Sinkhorn Transport with Group lasso regularization l1l2 | ||
ot_l1l2 = ot.da.SinkhornL1l2Transport(reg_e=1e-1, reg_cl=2e0, max_iter=20, | ||
verbose=True) | ||
ot_l1l2.fit(Xs=Xs, ys=ys, Xt=Xt) | ||
|
||
# transport source samples onto target samples | ||
transp_Xs_emd = ot_emd.transform(Xs=Xs) | ||
transp_Xs_sinkhorn = ot_sinkhorn.transform(Xs=Xs) | ||
transp_Xs_lpl1 = ot_lpl1.transform(Xs=Xs) | ||
transp_Xs_l1l2 = ot_l1l2.transform(Xs=Xs) | ||
|
||
|
||
############################################################################## | ||
# Fig 1 : plots source and target samples | ||
############################################################################## | ||
|
||
pl.figure(1, figsize=(10, 5)) | ||
pl.subplot(1, 2, 1) | ||
pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples') | ||
pl.xticks([]) | ||
pl.yticks([]) | ||
pl.legend(loc=0) | ||
pl.title('Source samples') | ||
|
||
pl.subplot(1, 2, 2) | ||
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples') | ||
pl.xticks([]) | ||
pl.yticks([]) | ||
pl.legend(loc=0) | ||
pl.title('Target samples') | ||
pl.tight_layout() | ||
|
||
|
||
############################################################################## | ||
# Fig 2 : plot optimal couplings and transported samples | ||
############################################################################## | ||
|
||
param_img = {'interpolation': 'nearest', 'cmap': 'spectral'} | ||
|
||
pl.figure(2, figsize=(15, 8)) | ||
pl.subplot(2, 4, 1) | ||
pl.imshow(ot_emd.coupling_, **param_img) | ||
pl.xticks([]) | ||
pl.yticks([]) | ||
pl.title('Optimal coupling\nEMDTransport') | ||
|
||
pl.subplot(2, 4, 2) | ||
pl.imshow(ot_sinkhorn.coupling_, **param_img) | ||
pl.xticks([]) | ||
pl.yticks([]) | ||
pl.title('Optimal coupling\nSinkhornTransport') | ||
|
||
pl.subplot(2, 4, 3) | ||
pl.imshow(ot_lpl1.coupling_, **param_img) | ||
pl.xticks([]) | ||
pl.yticks([]) | ||
pl.title('Optimal coupling\nSinkhornLpl1Transport') | ||
|
||
pl.subplot(2, 4, 4) | ||
pl.imshow(ot_l1l2.coupling_, **param_img) | ||
pl.xticks([]) | ||
pl.yticks([]) | ||
pl.title('Optimal coupling\nSinkhornL1l2Transport') | ||
|
||
pl.subplot(2, 4, 5) | ||
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', | ||
label='Target samples', alpha=0.3) | ||
pl.scatter(transp_Xs_emd[:, 0], transp_Xs_emd[:, 1], c=ys, | ||
marker='+', label='Transp samples', s=30) | ||
pl.xticks([]) | ||
pl.yticks([]) | ||
pl.title('Transported samples\nEmdTransport') | ||
pl.legend(loc="lower left") | ||
|
||
pl.subplot(2, 4, 6) | ||
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', | ||
label='Target samples', alpha=0.3) | ||
pl.scatter(transp_Xs_sinkhorn[:, 0], transp_Xs_sinkhorn[:, 1], c=ys, | ||
marker='+', label='Transp samples', s=30) | ||
pl.xticks([]) | ||
pl.yticks([]) | ||
pl.title('Transported samples\nSinkhornTransport') | ||
|
||
pl.subplot(2, 4, 7) | ||
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', | ||
label='Target samples', alpha=0.3) | ||
pl.scatter(transp_Xs_lpl1[:, 0], transp_Xs_lpl1[:, 1], c=ys, | ||
marker='+', label='Transp samples', s=30) | ||
pl.xticks([]) | ||
pl.yticks([]) | ||
pl.title('Transported samples\nSinkhornLpl1Transport') | ||
|
||
pl.subplot(2, 4, 8) | ||
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', | ||
label='Target samples', alpha=0.3) | ||
pl.scatter(transp_Xs_l1l2[:, 0], transp_Xs_l1l2[:, 1], c=ys, | ||
marker='+', label='Transp samples', s=30) | ||
pl.xticks([]) | ||
pl.yticks([]) | ||
pl.title('Transported samples\nSinkhornL1l2Transport') | ||
pl.tight_layout() | ||
|
||
pl.show() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,165 @@ | ||
# -*- coding: utf-8 -*- | ||
""" | ||
======================================================== | ||
OT for domain adaptation with image color adaptation [6] | ||
======================================================== | ||
|
||
This example presents a way of transferring colors between two image | ||
with Optimal Transport as introduced in [6] | ||
|
||
[6] Ferradans, S., Papadakis, N., Peyre, G., & Aujol, J. F. (2014). | ||
Regularized discrete optimal transport. | ||
SIAM Journal on Imaging Sciences, 7(3), 1853-1882. | ||
""" | ||
|
||
# Authors: Remi Flamary <remi.flamary@unice.fr> | ||
# Stanislas Chambon <stan.chambon@gmail.com> | ||
# | ||
# License: MIT License | ||
|
||
import numpy as np | ||
from scipy import ndimage | ||
import matplotlib.pylab as pl | ||
import ot | ||
|
||
|
||
r = np.random.RandomState(42) | ||
|
||
|
||
def im2mat(I): | ||
"""Converts and image to matrix (one pixel per line)""" | ||
return I.reshape((I.shape[0] * I.shape[1], I.shape[2])) | ||
|
||
|
||
def mat2im(X, shape): | ||
"""Converts back a matrix to an image""" | ||
return X.reshape(shape) | ||
|
||
|
||
def minmax(I): | ||
return np.clip(I, 0, 1) | ||
|
||
|
||
############################################################################## | ||
# generate data | ||
############################################################################## | ||
|
||
# Loading images | ||
I1 = ndimage.imread('../../data/ocean_day.jpg').astype(np.float64) / 256 | ||
I2 = ndimage.imread('../../data/ocean_sunset.jpg').astype(np.float64) / 256 | ||
|
||
X1 = im2mat(I1) | ||
X2 = im2mat(I2) | ||
|
||
# training samples | ||
nb = 1000 | ||
idx1 = r.randint(X1.shape[0], size=(nb,)) | ||
idx2 = r.randint(X2.shape[0], size=(nb,)) | ||
|
||
Xs = X1[idx1, :] | ||
Xt = X2[idx2, :] | ||
|
||
|
||
############################################################################## | ||
# Instantiate the different transport algorithms and fit them | ||
############################################################################## | ||
|
||
# EMDTransport | ||
ot_emd = ot.da.EMDTransport() | ||
ot_emd.fit(Xs=Xs, Xt=Xt) | ||
|
||
# SinkhornTransport | ||
ot_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1) | ||
ot_sinkhorn.fit(Xs=Xs, Xt=Xt) | ||
|
||
# prediction between images (using out of sample prediction as in [6]) | ||
transp_Xs_emd = ot_emd.transform(Xs=X1) | ||
transp_Xt_emd = ot_emd.inverse_transform(Xt=X2) | ||
|
||
transp_Xs_sinkhorn = ot_emd.transform(Xs=X1) | ||
transp_Xt_sinkhorn = ot_emd.inverse_transform(Xt=X2) | ||
|
||
I1t = minmax(mat2im(transp_Xs_emd, I1.shape)) | ||
I2t = minmax(mat2im(transp_Xt_emd, I2.shape)) | ||
|
||
I1te = minmax(mat2im(transp_Xs_sinkhorn, I1.shape)) | ||
I2te = minmax(mat2im(transp_Xt_sinkhorn, I2.shape)) | ||
|
||
|
||
############################################################################## | ||
# plot original image | ||
############################################################################## | ||
|
||
pl.figure(1, figsize=(6.4, 3)) | ||
|
||
pl.subplot(1, 2, 1) | ||
pl.imshow(I1) | ||
pl.axis('off') | ||
pl.title('Image 1') | ||
|
||
pl.subplot(1, 2, 2) | ||
pl.imshow(I2) | ||
pl.axis('off') | ||
pl.title('Image 2') | ||
|
||
|
||
############################################################################## | ||
# scatter plot of colors | ||
############################################################################## | ||
|
||
pl.figure(2, figsize=(6.4, 3)) | ||
|
||
pl.subplot(1, 2, 1) | ||
pl.scatter(Xs[:, 0], Xs[:, 2], c=Xs) | ||
pl.axis([0, 1, 0, 1]) | ||
pl.xlabel('Red') | ||
pl.ylabel('Blue') | ||
pl.title('Image 1') | ||
|
||
pl.subplot(1, 2, 2) | ||
pl.scatter(Xt[:, 0], Xt[:, 2], c=Xt) | ||
pl.axis([0, 1, 0, 1]) | ||
pl.xlabel('Red') | ||
pl.ylabel('Blue') | ||
pl.title('Image 2') | ||
pl.tight_layout() | ||
|
||
|
||
############################################################################## | ||
# plot new images | ||
############################################################################## | ||
|
||
pl.figure(3, figsize=(8, 4)) | ||
|
||
pl.subplot(2, 3, 1) | ||
pl.imshow(I1) | ||
pl.axis('off') | ||
pl.title('Image 1') | ||
|
||
pl.subplot(2, 3, 2) | ||
pl.imshow(I1t) | ||
pl.axis('off') | ||
pl.title('Image 1 Adapt') | ||
|
||
pl.subplot(2, 3, 3) | ||
pl.imshow(I1te) | ||
pl.axis('off') | ||
pl.title('Image 1 Adapt (reg)') | ||
|
||
pl.subplot(2, 3, 4) | ||
pl.imshow(I2) | ||
pl.axis('off') | ||
pl.title('Image 2') | ||
|
||
pl.subplot(2, 3, 5) | ||
pl.imshow(I2t) | ||
pl.axis('off') | ||
pl.title('Image 2 Adapt') | ||
|
||
pl.subplot(2, 3, 6) | ||
pl.imshow(I2te) | ||
pl.axis('off') | ||
pl.title('Image 2 Adapt (reg)') | ||
pl.tight_layout() | ||
|
||
pl.show() |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You chose to move the da examples in a sub-folder do you think it is more clear or is it to be similar to sk-learn? did you remove the old ones, every trace of the old class should disappear from the examples?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we shoud remove them from the examples since it's what people look at when trying out the toolbox.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok I do that so
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes remove deprecated code from all examples and documentation