[WIP] small tentative for EMD 1D in torch #218

Merged
merged 8 commits on Jan 3, 2021

2 changes: 1 addition & 1 deletion .github/workflows/build_tests.yml
@@ -18,7 +18,7 @@ jobs:
strategy:
max-parallel: 4
matrix:
python-version: [3.5, 3.6, 3.7, 3.8]
python-version: [3.6, 3.7, 3.8]

steps:
- uses: actions/checkout@v1
99 changes: 95 additions & 4 deletions ot/torch/lp.py
@@ -5,7 +5,9 @@
import numpy as np
import torch
from torch.autograd import Function
from .. import emd
from ot import emd
from torch.nn.functional import pad
from ot.torch.utils import quantile_function


# Author: Remi Flamary <remi.flamary@unice.fr>
@@ -20,7 +22,6 @@ class OptimalTransportLossFunction(Function):
@staticmethod
# num_iter_max is an optional argument
def forward(ctx, a, b, M, num_iter_max=100000):

# convert to numpy
a2 = a.detach().cpu().numpy().astype(np.float64)
b2 = b.detach().cpu().numpy().astype(np.float64)
@@ -42,7 +43,6 @@ def forward(ctx, a, b, M, num_iter_max=100000):

@staticmethod
def backward(ctx, grad_output):

grad_a, grad_b, grad_M = ctx.saved_tensors

print(grad_a)
@@ -56,7 +56,6 @@ def ot_loss(a, b, M, num_iter_max=100000):


def ot_solve(a, b, M, num_iter_max=100000, log=False):

a2 = a.detach().cpu().numpy().astype(np.float64)
b2 = b.detach().cpu().numpy().astype(np.float64)
M2 = M.detach().cpu().numpy().astype(np.float64)
@@ -79,3 +78,95 @@ def ot_solve(a, b, M, num_iter_max=100000, log=False):
G = emd(a2, b2, M2, log=False, numItermax=num_iter_max)

return torch.from_numpy(G).type_as(M)


def ot_loss_1d(u_values, v_values, u_weights=None, v_weights=None, p=1, require_sort=True):
r"""
Computes the 1-dimensional OT loss [2] between two (batched) empirical distributions

.. math::
ot_{loss} = \int_0^1 |cdf_u^{-1}(q) - cdf_v^{-1}(q)|^p dq

It is formally the p-Wasserstein distance raised to the power p.
We do so in a vectorized way by first building the individual quantile functions and then integrating them.
This has a theoretically higher complexity than the core OT implementation, but it behaves better with PyTorch.

Parameters
----------
u_values: torch.Tensor (n, ...)
locations of the first empirical distribution
v_values: torch.Tensor (m, ...)
locations of the second empirical distribution
u_weights: torch.Tensor (n, ...), optional
weights of the first empirical distribution, if None then uniform weights are used
v_weights: torch.Tensor (m, ...), optional
weights of the second empirical distribution, if None then uniform weights are used
p: int, optional
order of the ground metric used, should be at least 1 (see [2, Chap. 2]), default is 1
require_sort: bool, optional
sort the locations along the first dimension before computing the loss; set to False if they are already sorted (default is True)

Returns
-------
cost: torch.Tensor (...,)
the batched EMD

Examples
--------
Simple example:
>>> import numpy as np
>>> import ot
>>> import torch
>>> np.random.seed(0)
>>> n_source = 7
>>> n_target = 100
>>> a = torch.tensor(ot.utils.unif(n_source), requires_grad=True)
>>> b = torch.tensor(ot.utils.unif(n_target))
>>> X_source = torch.tensor(np.random.randn(n_source,), requires_grad=True)
>>> Y_target = torch.tensor(np.random.randn(n_target,))
>>> loss = ot.torch.lp.ot_loss_1d(X_source, Y_target, a, b)
>>> torch.autograd.grad(loss, X_source)[0]
tensor([0.1429, 0.1429, 0.1429, 0.1229, 0.1429, 0.1429, 0.1429],
dtype=torch.float64)

References
----------
.. [2] Cuturi, M. (2013). [Sinkhorn distances: Lightspeed computation of optimal transport](https://arxiv.org/pdf/1306.0895.pdf). In Advances in Neural Information Processing Systems (pp. 2292-2300).

"""
assert p >= 1, "The OT loss is only valid for p>=1, {p} was given".format(p=p)
n = u_values.shape[0]
m = v_values.shape[0]

device = u_values.device
dtype = u_values.dtype

if u_weights is None:
u_weights = torch.full_like(u_values, 1 / n, dtype=dtype, device=device)
elif u_weights.ndim != u_values.ndim:
u_weights = torch.repeat_interleave(u_weights.unsqueeze(-1), u_values.shape[-1], -1)

if v_weights is None:
v_weights = torch.full_like(v_values, 1 / m, dtype=dtype, device=device)
elif v_weights.ndim != v_values.ndim:
v_weights = torch.repeat_interleave(v_weights.unsqueeze(-1), v_values.shape[-1], -1)

if require_sort:
u_values, u_sorter = torch.sort(u_values, 0)
v_values, v_sorter = torch.sort(v_values, 0)

u_weights = torch.gather(u_weights, 0, u_sorter)
v_weights = torch.gather(v_weights, 0, v_sorter)

u_cumweights = torch.cumsum(u_weights, 0)
v_cumweights = torch.cumsum(v_weights, 0)

qs, _ = torch.sort(torch.cat((u_cumweights, v_cumweights), 0), 0)
u_quantiles = quantile_function(qs, u_cumweights, u_values)
v_quantiles = quantile_function(qs, v_cumweights, v_values)

qs = pad(qs, (qs.ndim - 1) * (0, 0) + (1, 0))
delta = qs[1:, ...] - qs[:-1, ...]
diff_quantiles = torch.abs(u_quantiles - v_quantiles)

if p == 1:
return torch.sum(delta * torch.abs(diff_quantiles), dim=0)
return torch.sum(delta * torch.pow(diff_quantiles, p), dim=0)
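
As a quick illustration of the vectorized quantile-based computation implemented by ot_loss_1d above, here is a small usage sketch (not part of the diff; it assumes the ot.torch.lp module added in this PR is importable). Samples go along the first dimension and independent problems along the remaining ones, as described in the docstring:

import torch
import ot.torch

# two batched 1D empirical distributions: 300 and 200 samples, 5 independent problems
x = torch.randn(300, 5, dtype=torch.float64, requires_grad=True)
y = torch.randn(200, 5, dtype=torch.float64)

# with u_weights / v_weights left to None, uniform weights are used
loss = ot.torch.lp.ot_loss_1d(x, y, p=2)  # shape (5,): one squared W_2 value per problem

# the loss is differentiable through sort / cumsum / searchsorted / gather
grad_x, = torch.autograd.grad(loss.sum(), x)
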
28 changes: 27 additions & 1 deletion ot/torch/utils.py
@@ -58,7 +58,7 @@ def dist(x1, x2, metric="sqeuclidean"):
if x2 is None:
x2 = x1
if metric == "sqeuclidean":
return torch.cdist(x1, x2, p=2)**2
return torch.cdist(x1, x2, p=2) ** 2
elif metric == "euclidean":
p = 2
elif metric == "cityblock":
@@ -89,3 +89,29 @@ def proj_simplex(v, z=1):
return w[:, 0]
else:
return w


def quantile_function(qs, cws, xs):
# type: (torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
r""" Computes the quantile function of an empirical distribution

Parameters
----------
qs: torch.Tensor (n, ...)
quantiles at which the quantile function is evaluated
cws: torch.Tensor (m, ...)
cumulative weights of the 1D empirical distribution; if batched, the batch dimensions must match those of xs
xs: torch.Tensor (m, ...)
locations of the 1D empirical distribution, batched along the last `xs.ndim - 1` dimensions

Returns
-------
q: torch.Tensor (n, ...)
the quantiles of the distribution evaluated at qs
"""
n = xs.shape[0]

cws = cws.T.contiguous()
qs = qs.T.contiguous()
idx = torch.searchsorted(cws, qs).T
return torch.gather(xs, 0, idx.clip(0, n - 1))
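
For intuition about quantile_function above, here is a tiny hedged example (the values are made up for illustration; it assumes ot.torch.utils from this PR is importable). With three sorted locations and uniform weights, each queried quantile level picks the smallest location whose cumulative weight reaches it:

import torch
import ot.torch

xs = torch.tensor([[-1.0], [0.0], [2.0]])           # 3 sorted locations, one batch column
cws = torch.cumsum(torch.full_like(xs, 1 / 3), 0)   # cumulative weights 1/3, 2/3, 1
qs = torch.tensor([[0.1], [0.5], [0.9]])             # quantile levels to evaluate

print(ot.torch.utils.quantile_function(qs, cws, xs))
# expected output: tensor([[-1.], [0.], [2.]])
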
111 changes: 94 additions & 17 deletions test/test_torch.py
@@ -12,7 +12,6 @@

import ot.torch
import torch
nogo = False

lst_types = [torch.float32, torch.float64]

@@ -22,12 +21,10 @@


except BaseException:
nogo = True
pytest.skip("Missing pytorch", allow_module_level=True)


@pytest.mark.skipif(nogo, reason="Missing pytorch")
def test_dist():

n = 200

lst_metrics = ['sqeuclidean', 'euclidean', 'cityblock', 0, 0.5, 1, 2, 5]
@@ -39,16 +36,13 @@ def test_dist():
y = torch.randn(n, 2, dtype=dtype, device=device)

for metric in lst_metrics:

M = ot.torch.dist(x, y, metric)

assert M.shape[0] == n
assert M.shape[1] == n


@pytest.mark.skipif(nogo, reason="Missing pytorch")
def test_ot_loss():

n = 10

lst_metrics = ['sqeuclidean', 'euclidean', 'cityblock', 0, 0.5, 1, 2, 5]
@@ -63,21 +57,17 @@ def test_ot_loss():
b = ot.torch.unif(n, dtype=dtype, device=device)

for metric in lst_metrics:

M = ot.torch.dist(x, y, metric)
loss = ot.torch.ot_loss(a, b, M)

assert float(loss) >= 0


@pytest.mark.skipif(nogo, reason="Missing pytorch")
def test_proj_simplex():

n = 10

for dtype in lst_types:
for device in lst_devices:

x = torch.randn(n, dtype=dtype, device=device)

xp = ot.torch.proj_simplex(x)
@@ -93,9 +83,7 @@ def test_proj_simplex():
assert torch.allclose(xp.sum(0), torch.ones(3, dtype=dtype, device=device))


@pytest.mark.skipif(nogo, reason="Missing pytorch")
def test_ot_loss_grad():

n = 10

lst_metrics = ['sqeuclidean', 'euclidean', 'cityblock', 0, 0.5, 1, 2, 5]
@@ -104,7 +92,6 @@ def test_ot_loss_grad():
for device in lst_devices:

for metric in lst_metrics:

x = torch.randn(n, 2, dtype=dtype, device=device, requires_grad=True)
y = torch.randn(n, 2, dtype=dtype, device=device, requires_grad=True)

@@ -124,9 +111,7 @@ def test_ot_loss_grad():
assert float(loss) >= 0


@pytest.mark.skipif(nogo, reason="Missing pytorch")
def test_ot_solve():

n = 10

lst_metrics = ['sqeuclidean', 'euclidean', 'cityblock', 0, 0.5, 1, 2, 5]
@@ -141,9 +126,101 @@ def test_ot_solve():
b = ot.torch.unif(n, dtype=dtype, device=device)

for metric in lst_metrics:

M = ot.torch.dist(x, y, metric)
G = ot.torch.ot_solve(a, b, M)

np.testing.assert_allclose(ot.unif(n), G.sum(1).cpu().numpy())
np.testing.assert_allclose(ot.unif(n), G.sum(0).cpu().numpy()) # cf convergence sinkhorn


@pytest.mark.parametrize("random_weights", [True, False])
@pytest.mark.parametrize("batch_size", [0, 2, 10])
def test_ot_loss_1d(random_weights, batch_size):
torch.random.manual_seed(42)
n = 300
m = 200
k = 5
ps = [1, 2, 3]

for dtype in lst_types:
for device in lst_devices:
if batch_size:
x = torch.randn(n, batch_size, k, dtype=dtype, device=device)
y = torch.randn(m, batch_size, k, dtype=dtype, device=device)
else:
x = torch.randn(n, k, dtype=dtype, device=device)
y = torch.randn(m, k, dtype=dtype, device=device)
if random_weights:
if batch_size:
a = torch.rand(n, batch_size, dtype=dtype, device=device)
b = torch.rand(m, batch_size, dtype=dtype, device=device)
else:
a = torch.rand(n, dtype=dtype, device=device)
b = torch.rand(m, dtype=dtype, device=device)
a = a / torch.sum(a, 0, keepdim=True)
b = b / torch.sum(b, 0, keepdim=True)
np_a = a.cpu().numpy()
np_b = b.cpu().numpy()
else:
a = b = np_a = np_b = None

for p in ps:
same_dist_cost = ot.torch.lp.ot_loss_1d(x, x, a, a, p)
assert np.allclose(same_dist_cost.cpu().numpy(), 0., atol=1e-5)
torch_cost = ot.torch.lp.ot_loss_1d(x, y, a, b, p)

if batch_size:
cpu_cost = np.zeros((batch_size, k))
else:
cpu_cost = np.zeros(k)

for i in range(k):
if batch_size:
for batch_num in range(batch_size):
cpu_cost[batch_num, i] = ot.lp.emd2_1d(x[:, batch_num, i].cpu().numpy(),
y[:, batch_num, i].cpu().numpy(),
np_a if np_a is None else np_a[:, batch_num],
np_b if np_b is None else np_b[:, batch_num],
"minkowski", p=p)
else:
cpu_cost[i] = ot.lp.emd2_1d(x[:, i].cpu().numpy(), y[:, i].cpu().numpy(), np_a, np_b,
"minkowski", p=p)

np.testing.assert_allclose(cpu_cost, torch_cost.cpu().numpy(), atol=1e-5)


def test_ot_loss_1d_grad():
torch.random.manual_seed(42)
n = 10
m = 5
k = 5
ps = [1, 2, 3]

for dtype in lst_types:
for device in lst_devices:
x = torch.randn(n, k, dtype=dtype, device=device, requires_grad=True)
y = torch.randn(m, k, dtype=dtype, device=device, requires_grad=True)

a = torch.rand(n, dtype=dtype, device=device, requires_grad=True)
b = torch.rand(m, dtype=dtype, device=device, requires_grad=True)

for p in ps:
torch.autograd.gradcheck(lambda *inp: ot.torch.lp.ot_loss_1d(*inp, p=p), (x, y, a, b), eps=1e-3,
atol=1e-2, raise_exception=True)

res_equal = ot.torch.lp.ot_loss_1d(x, x, a, a, p=p).sum()
print(torch.autograd.grad(res_equal, (x, a)))


@pytest.mark.filterwarnings("error")
def test_quantile():
torch.random.manual_seed(42)
dims = (100, 5, 3)
cws = torch.rand(*dims)
cws = cws / cws.sum(0, keepdim=True)
cws = torch.cumsum(cws, 0)
qs, _ = torch.sort(torch.rand(*dims), dim=0)
xs = torch.randn(*dims)
res = ot.torch.utils.quantile_function(qs, cws, xs)
assert np.all(res.cpu().numpy() <= xs.max(0, keepdim=True)[0].cpu().numpy())
assert np.all(res.cpu().numpy() >= xs.min(0, keepdim=True)[0].cpu().numpy())