[MRG] New ot.gpu with cupy #67

Merged
merged 7 commits into from Sep 28, 2018
10 changes: 3 additions & 7 deletions README.md
@@ -14,7 +14,7 @@ This open source Python library provide several solvers for optimization problem
It provides the following solvers:

* OT Network Flow solver for the linear program/ Earth Movers Distance [1].
* Entropic regularization OT solver with Sinkhorn Knopp Algorithm [2] and stabilized version [9][10] and greedy Sinkhorn [22] with optional GPU implementation (requires cudamat).
* Entropic regularization OT solver with Sinkhorn Knopp Algorithm [2] and stabilized version [9][10] and greedy Sinkhorn [22] with optional GPU implementation (requires cupy).
* Smooth optimal transport solvers (dual and semi-dual) for KL and squared L2 regularizations [17].
* Non regularized Wasserstein barycenters [16] with LP solver (only small scale).
* Bregman projections for Wasserstein barycenter [3], convolutional barycenter [21] and unmixing [4].
@@ -83,12 +83,8 @@ Some sub-modules require additional dependences which are discussed below
```
pip install pymanopt autograd
```
* **ot.gpu** (GPU accelerated OT) depends on cudamat that have to be installed with:
```
git clone https://github.com/cudamat/cudamat.git
cd cudamat
python setup.py install --user # for user install (no root)
```
* **ot.gpu** (GPU accelerated OT) depends on cupy, which has to be installed following the instructions on [this page](https://docs-cupy.chainer.org/en/stable/install.html).


obviously you need CUDA installed and a compatible GPU.
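
A quick way to check that cupy and the GPU are usable before calling `ot.gpu` (a minimal sketch, not part of this PR; it only relies on cupy's public runtime API):

```python
# Minimal sanity check (not part of this PR): verify that cupy is installed
# and that at least one CUDA device is visible before using ot.gpu.
import cupy as cp

print(cp.cuda.runtime.getDeviceCount())  # number of visible CUDA devices
x = cp.arange(10)                        # allocated on the GPU
print(float(x.sum()))                    # 45.0 if the device works
```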

6 changes: 6 additions & 0 deletions docs/source/all.rst
@@ -48,6 +48,12 @@ ot.da

.. automodule:: ot.da
:members:

ot.gpu
--------

.. automodule:: ot.gpu
:members:

ot.dr
--------
2 changes: 1 addition & 1 deletion docs/source/conf.py
@@ -31,7 +31,7 @@ class Mock(MagicMock):
@classmethod
def __getattr__(cls, name):
return MagicMock()
MOCK_MODULES = ['ot.lp.emd_wrap','autograd','pymanopt','cudamat','autograd.numpy','pymanopt.manifolds','pymanopt.solvers']
MOCK_MODULES = ['ot.lp.emd_wrap','autograd','pymanopt','cupy','autograd.numpy','pymanopt.manifolds','pymanopt.solvers']
# 'autograd.numpy','pymanopt.manifolds','pymanopt.solvers',
sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES)
# !!!!
26 changes: 25 additions & 1 deletion ot/gpu/__init__.py
@@ -1,12 +1,36 @@
# -*- coding: utf-8 -*-
"""

This module provides GPU implementations for several OT solvers and utility
functions. The GPU backend is handled by `cupy
<https://cupy.chainer.org/>`_.

By default, the functions in this module accept and return numpy arrays
in order to provide a drop-in replacement for the other POT functions, but
the transfer between CPU and GPU comes with a significant overhead.

To get the best performance, we recommend giving only cupy arrays to the
functions and deactivating the conversion of the result back to numpy with
the parameter ``to_numpy=False``.
"""

from . import bregman
from . import da
from .bregman import sinkhorn
from .da import sinkhorn_lpl1_mm

from . import utils
from .utils import dist, to_gpu, to_np


# Author: Remi Flamary <remi.flamary@unice.fr>
# Leo Gautheron <https://github.com/aje>
#
# License: MIT License

__all__ = ["bregman", "da", "sinkhorn"]
__all__ = ["utils", "dist", "sinkhorn",
"sinkhorn_lpl1_mm", 'bregman', 'da', 'to_gpu', 'to_np']
148 changes: 92 additions & 56 deletions ot/gpu/bregman.py
@@ -8,14 +8,18 @@
#
# License: MIT License

import numpy as np
import cudamat
import cupy as np # np used for matrix computation
import cupy as cp # cp used for cupy specific operations
from . import utils


def sinkhorn(a, b, M_GPU, reg, numItermax=1000, stopThr=1e-9, verbose=False,
log=False, returnAsGPU=False):
r"""
Solve the entropic regularization optimal transport problem on GPU
def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9,
verbose=False, log=False, to_numpy=True, **kwargs):
"""
Solve the entropic regularization optimal transport on GPU

If the input matrices are in numpy format, they will be uploaded to the
GPU first, which can incur a significant time overhead.

The function solves the following optimization problem:

@@ -40,9 +40,10 @@ def sinkhorn(a, b, M_GPU, reg, numItermax=1000, stopThr=1e-9, verbose=False,
----------
a : np.ndarray (ns,)
samples weights in the source domain
b : np.ndarray (nt,)
samples in the target domain
M_GPU : cudamat.CUDAMatrix (ns,nt)
b : np.ndarray (nt,) or np.ndarray (nt,nbb)
samples in the target domain; compute Sinkhorn with multiple targets
and a fixed M if b is a matrix (return OT loss + dual variables in log)
M : np.ndarray (ns,nt)
loss matrix
reg : float
Regularization term >0
@@ -54,8 +54,9 @@ def sinkhorn(a, b, M_GPU, reg, numItermax=1000, stopThr=1e-9, verbose=False,
Print information along iterations
log : bool, optional
record log if True
returnAsGPU : bool, optional
return the OT matrix as a cudamat.CUDAMatrix
to_numpy : boolean, optional (default True)
If True, convert the GPU array result back to numpy format.


Returns
-------
@@ -88,60 +94,78 @@ def sinkhorn(a, b, M_GPU, reg, numItermax=1000, stopThr=1e-9, verbose=False,
ot.optim.cg : General regularized OT

"""

a = cp.asarray(a)
b = cp.asarray(b)
M = cp.asarray(M)

if len(a) == 0:
a = np.ones((M.shape[0],)) / M.shape[0]
if len(b) == 0:
b = np.ones((M.shape[1],)) / M.shape[1]

# init data
Nini = len(a)
Nfin = len(b)

if len(b.shape) > 1:
nbb = b.shape[1]
else:
nbb = 0

if log:
log = {'err': []}

# we assume that no distances are null except those of the diagonal of
# distances
u = (np.ones(Nini) / Nini).reshape((Nini, 1))
u_GPU = cudamat.CUDAMatrix(u)
a_GPU = cudamat.CUDAMatrix(a.reshape((Nini, 1)))
ones_GPU = cudamat.empty(u_GPU.shape).assign(1)
v = (np.ones(Nfin) / Nfin).reshape((Nfin, 1))
v_GPU = cudamat.CUDAMatrix(v)
b_GPU = cudamat.CUDAMatrix(b.reshape((Nfin, 1)))

M_GPU.divide(-reg)
if nbb:
u = np.ones((Nini, nbb)) / Nini
v = np.ones((Nfin, nbb)) / Nfin
else:
u = np.ones(Nini) / Nini
v = np.ones(Nfin) / Nfin

K_GPU = cudamat.exp(M_GPU)
# print(reg)

ones_GPU.divide(a_GPU, target=a_GPU)
Kp_GPU = cudamat.empty(K_GPU.shape)
K_GPU.mult_by_col(a_GPU, target=Kp_GPU)
# Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
K = np.empty(M.shape, dtype=M.dtype)
np.divide(M, -reg, out=K)
np.exp(K, out=K)

tmp_GPU = cudamat.empty(K_GPU.shape)
# print(np.min(K))
tmp2 = np.empty(b.shape, dtype=M.dtype)

Kp = (1 / a).reshape(-1, 1) * K
cpt = 0
err = 1
while (err > stopThr and cpt < numItermax):
uprev_GPU = u_GPU.copy()
vprev_GPU = v_GPU.copy()
uprev = u
vprev = v

KtransposeU_GPU = K_GPU.transpose().dot(u_GPU)
b_GPU.divide(KtransposeU_GPU, target=v_GPU)
ones_GPU.divide(Kp_GPU.dot(v_GPU), target=u_GPU)
KtransposeU = np.dot(K.T, u)
v = np.divide(b, KtransposeU)
u = 1. / np.dot(Kp, v)

if (np.any(KtransposeU_GPU.asarray() == 0) or
not u_GPU.allfinite() or not v_GPU.allfinite()):
if (np.any(KtransposeU == 0) or
np.any(np.isnan(u)) or np.any(np.isnan(v)) or
np.any(np.isinf(u)) or np.any(np.isinf(v))):
# we have reached the machine precision
# come back to previous solution and quit loop
print('Warning: numerical errors at iteration', cpt)
u_GPU = uprev_GPU.copy()
v_GPU = vprev_GPU.copy()
u = uprev
v = vprev
break
if cpt % 10 == 0:
# we can speed up the process by checking for the error only every
# 10th iteration
K_GPU.mult_by_col(u_GPU, target=tmp_GPU)
tmp_GPU.mult_by_row(v_GPU.transpose(), target=tmp_GPU)

bcopy_GPU = b_GPU.copy().transpose()
bcopy_GPU.add_sums(tmp_GPU, axis=0, beta=-1)
err = bcopy_GPU.euclid_norm()**2
if nbb:
err = np.sum((u - uprev)**2) / np.sum((u)**2) + \
np.sum((v - vprev)**2) / np.sum((v)**2)
else:
# compute right marginal tmp2= (diag(u)Kdiag(v))^T1
tmp2 = np.sum(u[:, None] * K * v[None, :], 0)
#tmp2=np.einsum('i,ij,j->j', u, K, v)
err = np.linalg.norm(tmp2 - b)**2 # violation of marginal
if log:
log['err'].append(err)

@@ -150,20 +174,32 @@ def sinkhorn(a, b, M_GPU, reg, numItermax=1000, stopThr=1e-9, verbose=False,
print(
'{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
print('{:5d}|{:8e}|'.format(cpt, err))
cpt += 1
if log:
log['u'] = u_GPU.asarray()
log['v'] = v_GPU.asarray()

K_GPU.mult_by_col(u_GPU, target=K_GPU)
K_GPU.mult_by_row(v_GPU.transpose(), target=K_GPU)

if returnAsGPU:
res = K_GPU
else:
res = K_GPU.asarray()

cpt = cpt + 1
if log:
return res, log
else:
return res
log['u'] = u
log['v'] = v

if nbb: # return only loss
#res = np.einsum('ik,ij,jk,ij->k', u, K, v, M) (explodes cupy memory)
res = np.empty(nbb)
for i in range(nbb):
res[i] = np.sum(u[:, None, i] * (K * M) * v[None, :, i])
if to_numpy:
res = utils.to_np(res)
if log:
return res, log
else:
return res

else: # return OT matrix
res = u.reshape((-1, 1)) * K * v.reshape((1, -1))
if to_numpy:
res = utils.to_np(res)
if log:
return res, log
else:
return res


# define sinkhorn as sinkhorn_knopp
sinkhorn = sinkhorn_knopp
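
Since `b` can now be a 2-D array, a single call solves several Sinkhorn problems that share the same cost matrix and returns one loss per column; a minimal sketch (not part of the diff), assuming the default `to_numpy=True`:

```python
# Minimal sketch (not part of the diff): with a 2-D b, sinkhorn returns one
# regularized OT loss per column of b, all columns sharing the same M.
import numpy as np
import ot.gpu  # requires cupy and a CUDA-capable GPU

n = 500
rng = np.random.RandomState(0)
x = rng.randn(n, 2)
a = np.ones(n) / n                         # fixed source weights
B = rng.dirichlet(np.ones(n), size=10).T   # 10 target histograms, shape (n, 10)

M = ot.gpu.dist(x, x)                      # cost matrix (numpy by default)
losses = ot.gpu.sinkhorn(a, B, M, reg=1e-1)
print(losses.shape)                        # (10,), one loss per target
```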