Skip to content

Domain adaptation Classes #22

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 29 commits into from
Aug 29, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
553a456
remove linewidth error message
Slasnista Jul 28, 2017
ca9c9d6
first proposal for OT wrappers
Slasnista Jul 28, 2017
84adadd
small modifs according to NG proposals
Slasnista Jul 28, 2017
cd3397f
integrate AG comments
Slasnista Jul 28, 2017
bd7c7d2
own BaseEstimator class written + rflamary comments addressed
Slasnista Jul 31, 2017
122b5bf
update SinkhornTransport class + added test for class
Slasnista Aug 1, 2017
d9be6c2
added EMDTransport Class from NG's code + added dedicated test
Slasnista Aug 1, 2017
70be034
added test for fit_transform + correction of fit_transform bug (missi…
Slasnista Aug 4, 2017
64880e7
added new class SinkhornLpl1Transport() + dedicated test
Slasnista Aug 4, 2017
727077a
added new class SinkhornL1l2Transport() + dedicated test
Slasnista Aug 4, 2017
0b00590
semi supervised mode supported
Slasnista Aug 4, 2017
d793f1f
correction of semi supervised mode
Slasnista Aug 4, 2017
778f4f7
reformat doc strings + remove useless log / verbose parameters for emd
Slasnista Aug 4, 2017
738bfb1
out of samples by Ferradans supported for transform and inverse_trans…
Slasnista Aug 4, 2017
8a21429
added new class MappingTransport to support linear and kernel mapping…
Slasnista Aug 4, 2017
8149e05
make doc strings compliant with numpy / modif according to AG review
Slasnista Aug 23, 2017
791a4a6
out of samples transform and inverse transform by batch
Slasnista Aug 23, 2017
326d163
test functions for MappingTransport Class
Slasnista Aug 23, 2017
0930223
added deprecation warning on old classes
Slasnista Aug 23, 2017
2d4d0b4
solving log issues to avoid errors and adding further tests
Slasnista Aug 25, 2017
74ca2d7
refactoring examples according to new DA classes
Slasnista Aug 25, 2017
f80693b
small corrections for examples
Slasnista Aug 25, 2017
892d7ce
set properly path of data
Slasnista Aug 25, 2017
55840f6
move no da objects into utils.py
Slasnista Aug 28, 2017
a8fa91b
handling input arguments in fit, transform... methods + remove old ex…
Slasnista Aug 28, 2017
c5d7c40
check input parameters with helper functions
Slasnista Aug 28, 2017
7d3fc95
update readme
Slasnista Aug 28, 2017
a29e22d
addressed AG comments + adding random seed
Slasnista Aug 29, 2017
65de6fc
pass on examples | introduced RandomState
Slasnista Aug 29, 2017
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,8 @@ The contributors to this library are:
* [Laetitia Chapel](http://people.irisa.fr/Laetitia.Chapel/)
* [Michael Perrot](http://perso.univ-st-etienne.fr/pem82055/) (Mapping estimation)
* [Léo Gautheron](https://github.com/aje) (GPU implementation)
* [Nathalie Gayraud](https://www.linkedin.com/in/nathalie-t-h-gayraud/?ppe=1)
* [Stanislas Chambon](https://slasnista.github.io/)

This toolbox benefit a lot from open source research and we would like to thank the following persons for providing some code (in various languages):

Expand Down
150 changes: 150 additions & 0 deletions examples/da/plot_otda_classes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
# -*- coding: utf-8 -*-
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You chose to move the da examples in a sub-folder do you think it is more clear or is it to be similar to sk-learn? did you remove the old ones, every trace of the old class should disappear from the examples?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • I think sub-categories would clearer for examples and later on for the doc generation
  • I did not remove the old ones, shoud I move them in the new sub-folder ? @agramfort : do you have some advice to deprecate examples / documention jointly with some code ?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we shoud remove them from the examples since it's what people look at when trying out the toolbox.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok I do that so

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes remove deprecated code from all examples and documentation

"""
========================
OT for domain adaptation
========================

This example introduces a domain adaptation in a 2D setting and the 4 OTDA
approaches currently supported in POT.

"""

# Authors: Remi Flamary <remi.flamary@unice.fr>
# Stanislas Chambon <stan.chambon@gmail.com>
#
# License: MIT License

import matplotlib.pylab as pl
import ot


##############################################################################
# generate data
##############################################################################

n_source_samples = 150
n_target_samples = 150

Xs, ys = ot.datasets.get_data_classif('3gauss', n_source_samples)
Xt, yt = ot.datasets.get_data_classif('3gauss2', n_target_samples)


##############################################################################
# Instantiate the different transport algorithms and fit them
##############################################################################

# EMD Transport
ot_emd = ot.da.EMDTransport()
ot_emd.fit(Xs=Xs, Xt=Xt)

# Sinkhorn Transport
ot_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1)
ot_sinkhorn.fit(Xs=Xs, Xt=Xt)

# Sinkhorn Transport with Group lasso regularization
ot_lpl1 = ot.da.SinkhornLpl1Transport(reg_e=1e-1, reg_cl=1e0)
ot_lpl1.fit(Xs=Xs, ys=ys, Xt=Xt)

# Sinkhorn Transport with Group lasso regularization l1l2
ot_l1l2 = ot.da.SinkhornL1l2Transport(reg_e=1e-1, reg_cl=2e0, max_iter=20,
verbose=True)
ot_l1l2.fit(Xs=Xs, ys=ys, Xt=Xt)

# transport source samples onto target samples
transp_Xs_emd = ot_emd.transform(Xs=Xs)
transp_Xs_sinkhorn = ot_sinkhorn.transform(Xs=Xs)
transp_Xs_lpl1 = ot_lpl1.transform(Xs=Xs)
transp_Xs_l1l2 = ot_l1l2.transform(Xs=Xs)


##############################################################################
# Fig 1 : plots source and target samples
##############################################################################

pl.figure(1, figsize=(10, 5))
pl.subplot(1, 2, 1)
pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')
pl.xticks([])
pl.yticks([])
pl.legend(loc=0)
pl.title('Source samples')

pl.subplot(1, 2, 2)
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')
pl.xticks([])
pl.yticks([])
pl.legend(loc=0)
pl.title('Target samples')
pl.tight_layout()


##############################################################################
# Fig 2 : plot optimal couplings and transported samples
##############################################################################

param_img = {'interpolation': 'nearest', 'cmap': 'spectral'}

pl.figure(2, figsize=(15, 8))
pl.subplot(2, 4, 1)
pl.imshow(ot_emd.coupling_, **param_img)
pl.xticks([])
pl.yticks([])
pl.title('Optimal coupling\nEMDTransport')

pl.subplot(2, 4, 2)
pl.imshow(ot_sinkhorn.coupling_, **param_img)
pl.xticks([])
pl.yticks([])
pl.title('Optimal coupling\nSinkhornTransport')

pl.subplot(2, 4, 3)
pl.imshow(ot_lpl1.coupling_, **param_img)
pl.xticks([])
pl.yticks([])
pl.title('Optimal coupling\nSinkhornLpl1Transport')

pl.subplot(2, 4, 4)
pl.imshow(ot_l1l2.coupling_, **param_img)
pl.xticks([])
pl.yticks([])
pl.title('Optimal coupling\nSinkhornL1l2Transport')

pl.subplot(2, 4, 5)
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
label='Target samples', alpha=0.3)
pl.scatter(transp_Xs_emd[:, 0], transp_Xs_emd[:, 1], c=ys,
marker='+', label='Transp samples', s=30)
pl.xticks([])
pl.yticks([])
pl.title('Transported samples\nEmdTransport')
pl.legend(loc="lower left")

pl.subplot(2, 4, 6)
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
label='Target samples', alpha=0.3)
pl.scatter(transp_Xs_sinkhorn[:, 0], transp_Xs_sinkhorn[:, 1], c=ys,
marker='+', label='Transp samples', s=30)
pl.xticks([])
pl.yticks([])
pl.title('Transported samples\nSinkhornTransport')

pl.subplot(2, 4, 7)
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
label='Target samples', alpha=0.3)
pl.scatter(transp_Xs_lpl1[:, 0], transp_Xs_lpl1[:, 1], c=ys,
marker='+', label='Transp samples', s=30)
pl.xticks([])
pl.yticks([])
pl.title('Transported samples\nSinkhornLpl1Transport')

pl.subplot(2, 4, 8)
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
label='Target samples', alpha=0.3)
pl.scatter(transp_Xs_l1l2[:, 0], transp_Xs_l1l2[:, 1], c=ys,
marker='+', label='Transp samples', s=30)
pl.xticks([])
pl.yticks([])
pl.title('Transported samples\nSinkhornL1l2Transport')
pl.tight_layout()

pl.show()
165 changes: 165 additions & 0 deletions examples/da/plot_otda_color_images.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
# -*- coding: utf-8 -*-
"""
========================================================
OT for domain adaptation with image color adaptation [6]
========================================================

This example presents a way of transferring colors between two image
with Optimal Transport as introduced in [6]

[6] Ferradans, S., Papadakis, N., Peyre, G., & Aujol, J. F. (2014).
Regularized discrete optimal transport.
SIAM Journal on Imaging Sciences, 7(3), 1853-1882.
"""

# Authors: Remi Flamary <remi.flamary@unice.fr>
# Stanislas Chambon <stan.chambon@gmail.com>
#
# License: MIT License

import numpy as np
from scipy import ndimage
import matplotlib.pylab as pl
import ot


r = np.random.RandomState(42)


def im2mat(I):
"""Converts and image to matrix (one pixel per line)"""
return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))


def mat2im(X, shape):
"""Converts back a matrix to an image"""
return X.reshape(shape)


def minmax(I):
return np.clip(I, 0, 1)


##############################################################################
# generate data
##############################################################################

# Loading images
I1 = ndimage.imread('../../data/ocean_day.jpg').astype(np.float64) / 256
I2 = ndimage.imread('../../data/ocean_sunset.jpg').astype(np.float64) / 256

X1 = im2mat(I1)
X2 = im2mat(I2)

# training samples
nb = 1000
idx1 = r.randint(X1.shape[0], size=(nb,))
idx2 = r.randint(X2.shape[0], size=(nb,))

Xs = X1[idx1, :]
Xt = X2[idx2, :]


##############################################################################
# Instantiate the different transport algorithms and fit them
##############################################################################

# EMDTransport
ot_emd = ot.da.EMDTransport()
ot_emd.fit(Xs=Xs, Xt=Xt)

# SinkhornTransport
ot_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1)
ot_sinkhorn.fit(Xs=Xs, Xt=Xt)

# prediction between images (using out of sample prediction as in [6])
transp_Xs_emd = ot_emd.transform(Xs=X1)
transp_Xt_emd = ot_emd.inverse_transform(Xt=X2)

transp_Xs_sinkhorn = ot_emd.transform(Xs=X1)
transp_Xt_sinkhorn = ot_emd.inverse_transform(Xt=X2)

I1t = minmax(mat2im(transp_Xs_emd, I1.shape))
I2t = minmax(mat2im(transp_Xt_emd, I2.shape))

I1te = minmax(mat2im(transp_Xs_sinkhorn, I1.shape))
I2te = minmax(mat2im(transp_Xt_sinkhorn, I2.shape))


##############################################################################
# plot original image
##############################################################################

pl.figure(1, figsize=(6.4, 3))

pl.subplot(1, 2, 1)
pl.imshow(I1)
pl.axis('off')
pl.title('Image 1')

pl.subplot(1, 2, 2)
pl.imshow(I2)
pl.axis('off')
pl.title('Image 2')


##############################################################################
# scatter plot of colors
##############################################################################

pl.figure(2, figsize=(6.4, 3))

pl.subplot(1, 2, 1)
pl.scatter(Xs[:, 0], Xs[:, 2], c=Xs)
pl.axis([0, 1, 0, 1])
pl.xlabel('Red')
pl.ylabel('Blue')
pl.title('Image 1')

pl.subplot(1, 2, 2)
pl.scatter(Xt[:, 0], Xt[:, 2], c=Xt)
pl.axis([0, 1, 0, 1])
pl.xlabel('Red')
pl.ylabel('Blue')
pl.title('Image 2')
pl.tight_layout()


##############################################################################
# plot new images
##############################################################################

pl.figure(3, figsize=(8, 4))

pl.subplot(2, 3, 1)
pl.imshow(I1)
pl.axis('off')
pl.title('Image 1')

pl.subplot(2, 3, 2)
pl.imshow(I1t)
pl.axis('off')
pl.title('Image 1 Adapt')

pl.subplot(2, 3, 3)
pl.imshow(I1te)
pl.axis('off')
pl.title('Image 1 Adapt (reg)')

pl.subplot(2, 3, 4)
pl.imshow(I2)
pl.axis('off')
pl.title('Image 2')

pl.subplot(2, 3, 5)
pl.imshow(I2t)
pl.axis('off')
pl.title('Image 2 Adapt')

pl.subplot(2, 3, 6)
pl.imshow(I2te)
pl.axis('off')
pl.title('Image 2 Adapt (reg)')
pl.tight_layout()

pl.show()
Loading