Merged
README.md (1 change: 1 addition & 0 deletions)

@@ -196,6 +196,7 @@ The contributors to this library are
* [Adrien Corenflos](https://adriencorenflos.github.io/) (Sliced Wasserstein Distance)
* [Tanguy Kerdoncuff](https://hv0nnus.github.io/) (Sampled Gromov Wasserstein)
* [Minhui Huang](https://mhhuang95.github.io) (Projection Robust Wasserstein Distance)
* [Nathan Cassereau](https://github.com/ncassereau-idris) (Backends)

This toolbox benefits a lot from open source research, and we would like to thank the following persons for providing some code (in various languages):

ot/backend.py (302 changes: 300 additions & 2 deletions)

@@ -3,7 +3,7 @@
Multi-lib backend for POT

The goal is to write backend-agnostic code. Whether you're using Numpy, PyTorch,
-or Jax, POT code should work nonetheless.
+Jax, or Cupy, POT code should work nonetheless.
To achieve that, POT provides backend classes which implement functions in their respective backends,
imitating the Numpy API. As a convention, we use nx instead of np to refer to the backend.
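For instance, a minimal backend-agnostic sketch (an illustrative helper, not part of this diff; it relies only on `get_backend` and the documented `nx` methods):

import numpy as np
from ot.backend import get_backend

def normalized_sq_norms(x):
    # The same code runs on Numpy, PyTorch, Jax, or Cupy inputs:
    # nx mirrors the Numpy API for whichever backend matches x.
    nx = get_backend(x)
    sq = nx.sum(x ** 2, axis=1)
    return sq / nx.max(sq)

normalized_sq_norms(np.random.rand(4, 3))  # dispatches to NumpyBackend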

@@ -44,6 +44,14 @@
jax = False
jax_type = float

try:
import cupy as cp
import cupyx
cp_type = cp.ndarray
except ImportError:
cp = False
cp_type = float

str_type_error = "All arrays should be from the same type/backend. Current types are: {}"


@@ -57,6 +65,9 @@ def get_backend_list():
if jax:
lst.append(JaxBackend())

if cp:
lst.append(CupyBackend())

return lst
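
A quick way to check which backends were detected (the output below is illustrative; it depends on which optional dependencies are installed):

from ot.backend import get_backend_list

print([nx.__name__ for nx in get_backend_list()])
# e.g. ['numpy', 'torch', 'jax', 'cupy'] when all optional backends are available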


@@ -78,6 +89,8 @@ def get_backend(*args):
return TorchBackend()
elif isinstance(args[0], jax_type):
return JaxBackend()
elif isinstance(args[0], cp_type):
return CupyBackend()
else:
        raise ValueError("Unknown type: no implemented backend for this input.")

@@ -94,7 +107,8 @@ def to_numpy(*args):
class Backend():
"""
Backend abstract class.
-Implementations: :py:class:`JaxBackend`, :py:class:`NumpyBackend`, :py:class:`TorchBackend`
+Implementations: :py:class:`JaxBackend`, :py:class:`NumpyBackend`, :py:class:`TorchBackend`,
+:py:class:`CupyBackend`

- The `__name__` class attribute refers to the name of the backend.
- The `__type__` class attribute refers to the data structure used by the backend.
@@ -1500,3 +1514,287 @@ def assert_same_dtype_device(self, a, b):

assert a_dtype == b_dtype, "Dtype discrepancy"
assert a_device == b_device, f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}"


class CupyBackend(Backend): # pragma: no cover
"""
CuPy implementation of the backend

- `__name__` is "cupy"
- `__type__` is cp.ndarray
"""

__name__ = 'cupy'
__type__ = cp_type
__type_list__ = None

rng_ = None

def __init__(self):
self.rng_ = cp.random.RandomState()

self.__type_list__ = [
cp.array(1, dtype=cp.float32),
cp.array(1, dtype=cp.float64)
]

def to_numpy(self, a):
return cp.asnumpy(a)

def from_numpy(self, a, type_as=None):
if type_as is None:
return cp.asarray(a)
else:
with cp.cuda.Device(type_as.device):
return cp.asarray(a, dtype=type_as.dtype)

def set_gradients(self, val, inputs, grads):
        # No gradients for the cupy backend: like numpy, it does not
        # perform automatic differentiation, so inputs/grads are ignored
return val

def zeros(self, shape, type_as=None):
if isinstance(shape, (list, tuple)):
shape = tuple(int(i) for i in shape)
if type_as is None:
return cp.zeros(shape)
else:
with cp.cuda.Device(type_as.device):
return cp.zeros(shape, dtype=type_as.dtype)

def ones(self, shape, type_as=None):
if isinstance(shape, (list, tuple)):
shape = tuple(int(i) for i in shape)
if type_as is None:
return cp.ones(shape)
else:
with cp.cuda.Device(type_as.device):
return cp.ones(shape, dtype=type_as.dtype)

def arange(self, stop, start=0, step=1, type_as=None):
return cp.arange(start, stop, step)

def full(self, shape, fill_value, type_as=None):
if isinstance(shape, (list, tuple)):
shape = tuple(int(i) for i in shape)
if type_as is None:
return cp.full(shape, fill_value)
else:
with cp.cuda.Device(type_as.device):
return cp.full(shape, fill_value, dtype=type_as.dtype)

def eye(self, N, M=None, type_as=None):
if type_as is None:
return cp.eye(N, M)
else:
with cp.cuda.Device(type_as.device):
return cp.eye(N, M, dtype=type_as.dtype)

def sum(self, a, axis=None, keepdims=False):
return cp.sum(a, axis, keepdims=keepdims)

def cumsum(self, a, axis=None):
return cp.cumsum(a, axis)

def max(self, a, axis=None, keepdims=False):
return cp.max(a, axis, keepdims=keepdims)

def min(self, a, axis=None, keepdims=False):
return cp.min(a, axis, keepdims=keepdims)

def maximum(self, a, b):
return cp.maximum(a, b)

def minimum(self, a, b):
return cp.minimum(a, b)

def abs(self, a):
return cp.abs(a)

def exp(self, a):
return cp.exp(a)

def log(self, a):
return cp.log(a)

def sqrt(self, a):
return cp.sqrt(a)

def power(self, a, exponents):
return cp.power(a, exponents)

def dot(self, a, b):
return cp.dot(a, b)

def norm(self, a):
return cp.sqrt(cp.sum(cp.square(a)))

def any(self, a):
return cp.any(a)

def isnan(self, a):
return cp.isnan(a)

def isinf(self, a):
return cp.isinf(a)

def einsum(self, subscripts, *operands):
return cp.einsum(subscripts, *operands)

def sort(self, a, axis=-1):
return cp.sort(a, axis)

def argsort(self, a, axis=-1):
return cp.argsort(a, axis)

def searchsorted(self, a, v, side='left'):
if a.ndim == 1:
return cp.searchsorted(a, v, side)
else:
            # this is not a very efficient way to make cupy's
            # searchsorted work row-wise on 2d arrays
ret = cp.empty(v.shape, dtype=int)
for i in range(a.shape[0]):
ret[i, :] = cp.searchsorted(a[i, :], v[i, :], side)
return ret

def flip(self, a, axis=None):
return cp.flip(a, axis)

def outer(self, a, b):
return cp.outer(a, b)

def clip(self, a, a_min, a_max):
return cp.clip(a, a_min, a_max)

def repeat(self, a, repeats, axis=None):
return cp.repeat(a, repeats, axis)

def take_along_axis(self, arr, indices, axis):
return cp.take_along_axis(arr, indices, axis)

def concatenate(self, arrays, axis=0):
return cp.concatenate(arrays, axis)

def zero_pad(self, a, pad_width):
return cp.pad(a, pad_width)

def argmax(self, a, axis=None):
return cp.argmax(a, axis=axis)

def mean(self, a, axis=None):
return cp.mean(a, axis=axis)

def std(self, a, axis=None):
return cp.std(a, axis=axis)

def linspace(self, start, stop, num):
return cp.linspace(start, stop, num)

def meshgrid(self, a, b):
return cp.meshgrid(a, b)

def diag(self, a, k=0):
return cp.diag(a, k)

def unique(self, a):
return cp.unique(a)

def logsumexp(self, a, axis=None):
# Taken from
# https://github.com/scipy/scipy/blob/v1.7.1/scipy/special/_logsumexp.py#L7-L127
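        # Shift by the max for numerical stability:
        # log(sum(exp(a))) = a_max + log(sum(exp(a - a_max)))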
a_max = cp.amax(a, axis=axis, keepdims=True)

if a_max.ndim > 0:
a_max[~cp.isfinite(a_max)] = 0
elif not cp.isfinite(a_max):
a_max = 0

tmp = cp.exp(a - a_max)
s = cp.sum(tmp, axis=axis)
out = cp.log(s)
a_max = cp.squeeze(a_max, axis=axis)
out += a_max
return out

def stack(self, arrays, axis=0):
return cp.stack(arrays, axis)

def reshape(self, a, shape):
return cp.reshape(a, shape)

def seed(self, seed=None):
if seed is not None:
self.rng_.seed(seed)

def rand(self, *size, type_as=None):
if type_as is None:
return self.rng_.rand(*size)
else:
with cp.cuda.Device(type_as.device):
return self.rng_.rand(*size, dtype=type_as.dtype)

def randn(self, *size, type_as=None):
if type_as is None:
return self.rng_.randn(*size)
else:
with cp.cuda.Device(type_as.device):
return self.rng_.randn(*size, dtype=type_as.dtype)

def coo_matrix(self, data, rows, cols, shape=None, type_as=None):
data = self.from_numpy(data)
rows = self.from_numpy(rows)
cols = self.from_numpy(cols)
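        # cupyx builds the sparse matrix from device arrays, so the
        # (data, rows, cols) triplets are moved to the GPU first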
if type_as is None:
return cupyx.scipy.sparse.coo_matrix(
(data, (rows, cols)), shape=shape
)
else:
with cp.cuda.Device(type_as.device):
return cupyx.scipy.sparse.coo_matrix(
(data, (rows, cols)), shape=shape, dtype=type_as.dtype
)

def issparse(self, a):
return cupyx.scipy.sparse.issparse(a)

def tocsr(self, a):
if self.issparse(a):
return a.tocsr()
else:
return cupyx.scipy.sparse.csr_matrix(a)

def eliminate_zeros(self, a, threshold=0.):
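        # zero out entries whose magnitude is at or below the threshold,
        # then drop any explicit zeros from the sparse representation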
if threshold > 0:
if self.issparse(a):
a.data[self.abs(a.data) <= threshold] = 0
else:
a[self.abs(a) <= threshold] = 0
if self.issparse(a):
a.eliminate_zeros()
return a

def todense(self, a):
if self.issparse(a):
return a.toarray()
else:
return a

def where(self, condition, x, y):
return cp.where(condition, x, y)

def copy(self, a):
return a.copy()

def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
return cp.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)

def dtype_device(self, a):
return a.dtype, a.device

def assert_same_dtype_device(self, a, b):
a_dtype, a_device = self.dtype_device(a)
b_dtype, b_device = self.dtype_device(b)

        # cupy performs implicit dtype conversion, so the dtype check
        # always passes and only the device needs to be verified
assert a_device == b_device, f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}"
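
Taken together, a minimal usage sketch of the new backend (illustrative only; it assumes cupy is installed and a CUDA device is available):

import numpy as np
import cupy as cp
from ot.backend import get_backend

a = cp.asarray(np.random.rand(8, 2))
b = cp.asarray(np.random.rand(8, 2))
nx = get_backend(a, b)   # dispatches to CupyBackend
print(nx.__name__)       # 'cupy'
m = nx.mean(a, axis=0)   # computed on the GPU
print(nx.to_numpy(m))    # back to host as a numpy array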