Skip to content

[MRG] Wasserstein convolutional barycenter #64

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Sep 24, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -228,3 +228,5 @@ You can also post bug reports and feature requests in Github issues. Make sure t
[19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A. & Blondel, M. [Large-scale Optimal Transport and Mapping Estimation](https://arxiv.org/pdf/1711.02283.pdf). International Conference on Learning Representations (2018)

[20] Cuturi, M. and Doucet, A. (2014) [Fast Computation of Wasserstein Barycenters](http://proceedings.mlr.press/v32/cuturi14.html). International Conference on Machine Learning

[21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A., Nguyen, A. & Guibas, L. (2015). [Convolutional Wasserstein distances: Efficient optimal transportation on geometric domains](https://dl.acm.org/citation.cfm?id=2766963). ACM Transactions on Graphics (TOG), 34(4), 66.
Binary file added data/duck.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/heart.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/redcross.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/tooth.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
92 changes: 92 additions & 0 deletions examples/plot_convolutional_barycenter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@

#%%
# -*- coding: utf-8 -*-
"""
============================================
Convolutional Wasserstein Barycenter example
============================================

This example is designed to illustrate how the Convolutional Wasserstein Barycenter
function of POT works.
"""

# Author: Nicolas Courty <ncourty@irisa.fr>
#
# License: MIT License


import numpy as np
import pylab as pl
import ot

##############################################################################
# Data preparation
# ----------------
#
# The four distributions are constructed from 4 simple images


# Each image is read from disk, reduced to a single channel (index 2) and
# inverted with ``1 - x`` (presumably so that the dark shape carries the mass
# -- this depends on the PNG content), then normalized to sum to one so it can
# be used as a probability distribution.
# NOTE: the paths are relative, so the script must be run from ``examples/``.
image_names = ('redcross', 'duck', 'heart', 'tooth')

A = []
for name in image_names:
    img = 1 - pl.imread('../data/{}.png'.format(name))[:, :, 2]
    A.append(img / np.sum(img))
A = np.array(A)

# Convenient handles on the four normalized input distributions
f1, f2, f3, f4 = A

nb_images = 5

# those are the four corners coordinates that will be interpolated by bilinear
# interpolation
v1 = np.array((1, 0, 0, 0))
v2 = np.array((0, 1, 0, 0))
v3 = np.array((0, 0, 1, 0))
v4 = np.array((0, 0, 0, 1))


##############################################################################
# Barycenter computation and visualization
# ----------------------------------------
#

pl.figure(figsize=(10, 10))
# Figure-level title: ``pl.title`` would attach to an implicit full-figure
# axes created before the subplot grid, which the grid then replaces/overlaps.
pl.suptitle('Convolutional Wasserstein Barycenters in POT')
cm = 'Blues'
# regularization parameter
reg = 0.004

# At the four corners of the grid the bilinear weights are one-hot, so the
# corresponding input image is shown directly instead of solving a barycenter.
last = nb_images - 1
corners = {
    (0, 0): f1,
    (0, last): f3,
    (last, 0): f2,
    (last, last): f4,
}

for i in range(nb_images):
    for j in range(nb_images):
        pl.subplot(nb_images, nb_images, i * nb_images + j + 1)
        tx = float(i) / last
        ty = float(j) / last

        # weights are constructed by bilinear interpolation
        tmp1 = (1 - tx) * v1 + tx * v2
        tmp2 = (1 - tx) * v3 + tx * v4
        weights = (1 - ty) * tmp1 + ty * tmp2

        if (i, j) in corners:
            img = corners[(i, j)]
        else:
            # call to barycenter computation
            img = ot.bregman.convolutional_barycenter2d(A, reg, weights)

        pl.imshow(img, cmap=cm)
        pl.axis('off')
pl.show()
110 changes: 110 additions & 0 deletions ot/bregman.py
Original file line number Diff line number Diff line change
Expand Up @@ -919,6 +919,116 @@ def barycenter(A, M, reg, weights=None, numItermax=1000,
return geometricBar(weights, UKv)


def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1e-9, stabThr=1e-30, verbose=False, log=False):
    r"""Compute the entropic regularized wasserstein barycenter of distributions A
    where A is a collection of 2D images.

    The function solves the following optimization problem:

    .. math::
       \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i)

    where :

    - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn)
    - :math:`\mathbf{a}_i` are training distributions (2D images) in the last two dimensions of matrix :math:`\mathbf{A}`
    - reg is the regularization strength scalar value

    The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [21]_

    The ground cost is never stored explicitly: the Gibbs kernel is applied as
    a separable Gaussian filter, one 1D kernel per image axis. Each axis is
    mapped onto a regular grid of [0, 1], so for non-square images the pixel
    spacing differs between the two axes.

    Parameters
    ----------
    A : np.ndarray (n,w,h)
        n distributions (2D images) of size w x h
    reg : float
        Regularization term >0
    weights : np.ndarray (n,)
        Weights of each image on the simplex (barycentric coordinates).
        Uniform weights are used when None.
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (>0). The error is the L1 norm of the change
        of the barycenter between two consecutive iterations; it is only
        evaluated on iterations 1, 11, 21, ...
    stabThr : float, optional
        Stabilization threshold to avoid numerical precision issue
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True


    Returns
    -------
    a : (w,h) ndarray
        2D Wasserstein barycenter
    log : dict
        log dictionary return only if log==True in parameters.
        Keys are 'err' (list of recorded errors), 'niter' (number of
        iterations performed) and 'U' (final scaling array of shape (n,w,h)).

    Raises
    ------
    ValueError
        If A is not a 3D array or if the number of weights does not match
        the number of images.


    References
    ----------

    .. [21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A., Nguyen, A. & Guibas, L. (2015).
        Convolutional Wasserstein distances: Efficient optimal transportation on geometric domains
        ACM Transactions on Graphics (TOG), 34(4), 66


    """

    A = np.asarray(A)
    if A.ndim != 3:
        raise ValueError(
            'A must be a 3D array of shape (n, w, h), got {} dimension(s)'.format(A.ndim))
    if not np.issubdtype(A.dtype, np.floating):
        # the scalings U and KV inherit the dtype of A; an integer dtype would
        # silently truncate them at every update
        A = A.astype(np.float64)

    n_images = A.shape[0]

    if weights is None:
        weights = np.ones(n_images) / n_images
    elif len(weights) != n_images:
        # explicit exception: an ``assert`` would be stripped under ``python -O``
        raise ValueError(
            'weights must contain one value per image: got {} weights for {} images'.format(
                len(weights), n_images))

    # keep the boolean flag and the log container in separate variables
    log_dict = {'err': []} if log else None

    b = np.zeros_like(A[0, :, :])
    U = np.ones_like(A)
    KV = np.ones_like(A)

    cpt = 0
    err = 1

    # build the convolution operator: one 1D Gaussian kernel per image axis
    def gibbs_kernel_1d(size):
        t = np.linspace(0, 1, size)
        [Y, X] = np.meshgrid(t, t)
        return np.exp(-(X - Y)**2 / reg)

    xi1 = gibbs_kernel_1d(A.shape[1])
    # square images share a single kernel; otherwise the second axis needs a
    # kernel of its own size for the right-hand product to be defined
    xi2 = xi1 if A.shape[2] == A.shape[1] else gibbs_kernel_1d(A.shape[2])

    def K(x):
        # separable filtering: along the first axis, then along the second one
        return np.dot(np.dot(xi1, x), xi2)

    while (err > stopThr and cpt < numItermax):

        bold = b
        cpt = cpt + 1

        # weighted geometric mean of the current marginals, computed in the
        # log domain
        b = np.zeros_like(A[0, :, :])
        for r in range(n_images):
            KV[r, :, :] = K(A[r, :, :] / np.maximum(stabThr, K(U[r, :, :])))
            b += weights[r] * np.log(np.maximum(stabThr, U[r, :, :] * KV[r, :, :]))
        b = np.exp(b)

        # rescale every U so that its marginal matches the new barycenter
        for r in range(n_images):
            U[r, :, :] = b / np.maximum(stabThr, KV[r, :, :])

        if cpt % 10 == 1:
            err = np.sum(np.abs(bold - b))
            # log and verbose print
            if log:
                log_dict['err'].append(err)

            if verbose:
                # err is only evaluated when cpt % 10 == 1, so the header has
                # to be triggered on those iterations (1, 201, 401, ...)
                if (cpt - 1) % 200 == 0:
                    print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
                print('{:5d}|{:8e}|'.format(cpt, err))

    if log:
        log_dict['niter'] = cpt
        log_dict['U'] = U
        return b, log_dict
    else:
        return b


def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
stopThr=1e-3, verbose=False, log=False):
"""
Expand Down
24 changes: 24 additions & 0 deletions test/test_bregman.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,30 @@ def test_bary():
ot.bregman.barycenter(A, M, reg, log=True, verbose=True)


def test_wassersteinbary():
    """Check that the convolutional barycenter of two random images has unit mass.

    Also runs the solver once with ``log=True, verbose=True`` to make sure
    those code paths do not raise.
    """

    size = 100  # size of a square image
    # seeded generator so that the test is reproducible
    rng = np.random.RandomState(0)

    def random_distribution():
        img = rng.randn(size, size)
        # shift so that the smallest value becomes 0: *subtracting* the
        # (negative) minimum is what makes the image non-negative
        img -= img.min()
        return img / np.sum(img)

    a1 = random_distribution()
    a2 = random_distribution()

    # creating matrix A containing all distributions
    A = np.zeros((2, size, size))
    A[0, :, :] = a1
    A[1, :, :] = a2

    # wasserstein
    reg = 1e-3
    bary_wass = ot.bregman.convolutional_barycenter2d(A, reg)

    np.testing.assert_allclose(1, np.sum(bary_wass))

    # help in checking if log and verbose do not bug the function
    ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True)


def test_unmix():

n_bins = 50 # nb bins
Expand Down
8 changes: 4 additions & 4 deletions test/test_stochastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def test_stochastic_sag():
# test sag
n = 15
reg = 1
numItermax = 300000
numItermax = 30000
rng = np.random.RandomState(0)

x = rng.randn(n, 2)
Expand Down Expand Up @@ -62,7 +62,7 @@ def test_stochastic_asgd():
# test asgd
n = 15
reg = 1
numItermax = 300000
numItermax = 100000
rng = np.random.RandomState(0)

x = rng.randn(n, 2)
Expand Down Expand Up @@ -92,7 +92,7 @@ def test_sag_asgd_sinkhorn():
# test all algorithms
n = 15
reg = 1
nb_iter = 300000
nb_iter = 100000
rng = np.random.RandomState(0)

x = rng.randn(n, 2)
Expand Down Expand Up @@ -167,7 +167,7 @@ def test_dual_sgd_sinkhorn():
# test all dual algorithms
n = 10
reg = 1
nb_iter = 150000
nb_iter = 15000
batch_size = 10
rng = np.random.RandomState(0)

Expand Down