Describe the bug
Because my data points are empirical distributions, I want to use the Wasserstein distance as a loss function over a batch of shape (n_batch, n_points, dimension). The standard way to turn a per-sample function into one that accepts a batch is torch.vmap, but doing so raises the error described below.
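For context, torch.vmap maps a function written for a single sample over the leading (batch) dimension of its inputs. This works as long as the function body uses only PyTorch operations. Here is a minimal sketch that builds, for each batch element, the same squared-Euclidean cost matrix that ot.dist(X, Y, metric="sqeuclidean") returns. The helper name `sq_euclidean_cost` is hypothetical and is not a POT function.

```python
import torch

def sq_euclidean_cost(x, y):
    # x: (n, d), y: (m, d) -> (n, m) matrix of squared Euclidean distances,
    # built only from PyTorch operations so that vmap can batch it
    return ((x[:, None, :] - y[None, :, :]) ** 2).sum(dim=-1)

# vmap maps the per-sample function over dim 0 of every input
batched_cost = torch.vmap(sq_euclidean_cost)

X = torch.randn(4, 5, 3)  # (n_batch, n_points, dimension)
Y = torch.randn(4, 6, 3)
C = batched_cost(X, Y)    # one (5, 6) cost matrix per batch element
```

The trouble starts only when the per-sample function leaves PyTorch, as ot.emd2 does.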
To Reproduce
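import ot
import torch

def wasserstein2_loss(X, Y):
    n, m = X.shape[0], Y.shape[0]
    a = torch.ones(n) / n  # uniform weights on X's points
    b = torch.ones(m) / m  # uniform weights on Y's points
    M = ot.dist(X, Y, metric="sqeuclidean")
    return ot.emd2(a, b, M) ** 0.5

wasserstein2_loss_batched = torch.vmap(wasserstein2_loss)

X = torch.randn(8, 50, 2)  # example batch: (n_batch, n_points, dimension)
Y = torch.randn(8, 50, 2)
W2 = wasserstein2_loss_batched(X, Y)  # should be an array of shape `n_batch`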
Error
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[6], line 1
----> 1 W2 = wasserstein2_loss_batched(X, Y)
File /usr/local/lib/python3.10/dist-packages/torch/_functorch/vmap.py:434, in vmap.<locals>.wrapped(*args, **kwargs)
430 return _chunked_vmap(func, flat_in_dims, chunks_flat_args,
431 args_spec, out_dims, randomness, **kwargs)
433 # If chunk_size is not specified.
--> 434 return _flat_vmap(
435 func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs
436 )
File /usr/local/lib/python3.10/dist-packages/torch/_functorch/vmap.py:39, in doesnt_support_saved_tensors_hooks.<locals>.fn(*args, **kwargs)
36 @functools.wraps(f)
37 def fn(*args, **kwargs):
38 with torch.autograd.graph.disable_saved_tensors_hooks(message):
---> 39 return f(*args, **kwargs)
File /usr/local/lib/python3.10/dist-packages/torch/_functorch/vmap.py:619, in _flat_vmap(func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs)
617 try:
618 batched_inputs = _create_batched_inputs(flat_in_dims, flat_args, vmap_level, args_spec)
--> 619 batched_outputs = func(*batched_inputs, **kwargs)
620 return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func)
621 finally:
Cell In[4], line 13, in wasserstein2_loss(X, Y)
11 b = torch.ones(m) / m
12 M = ot.dist(X, Y, metric="sqeuclidean")
---> 13 return wasserstein_distance(a, b, M) ** 0.5
File /usr/local/lib/python3.10/dist-packages/ot/lp/__init__.py:488, in emd2(a, b, M, processes, numItermax, log, return_matrix, center_dual, numThreads, check_marginals)
485 nx = get_backend(M0, a0, b0)
487 # convert to numpy
--> 488 M, a, b = nx.to_numpy(M, a, b)
490 a = np.asarray(a, dtype=np.float64)
491 b = np.asarray(b, dtype=np.float64)
File /usr/local/lib/python3.10/dist-packages/ot/backend.py:207, in Backend.to_numpy(self, *arrays)
205 return self._to_numpy(arrays[0])
206 else:
--> 207 return [self._to_numpy(array) for array in arrays]
File /usr/local/lib/python3.10/dist-packages/ot/backend.py:207, in <listcomp>(.0)
205 return self._to_numpy(arrays[0])
206 else:
--> 207 return [self._to_numpy(array) for array in arrays]
File /usr/local/lib/python3.10/dist-packages/ot/backend.py:1763, in TorchBackend._to_numpy(self, a)
1761 if isinstance(a, float) or isinstance(a, int) or isinstance(a, np.ndarray):
1762 return np.array(a)
-> 1763 return a.cpu().detach().numpy()
RuntimeError: Cannot access data pointer of Tensor that doesn't have storage
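The last frame shows the root cause. POT's TorchBackend._to_numpy calls a.cpu().detach().numpy(). Inside torch.vmap, the tensors a function receives are batched wrappers without storage of their own, so the NumPy conversion fails. The following minimal sketch reproduces the same failure without POT. The function name `to_numpy_inside` is hypothetical.

```python
import torch

def to_numpy_inside(x):
    # Same call chain as POT's TorchBackend._to_numpy; under vmap, x is a
    # batched wrapper tensor, so .numpy() cannot reach any underlying data
    return torch.from_numpy(x.cpu().detach().numpy())

try:
    torch.vmap(to_numpy_inside)(torch.randn(3, 2))
    raised = False
except RuntimeError as err:
    raised = True
    message = str(err)  # e.g. the "doesn't have storage" error seen above
```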
Expected behavior
POT distance functions should be batchable via torch.vmap. The Sinkhorn distance code seems to have the same problem.
The exact ot.emd2 solver relies on compiled C++ code, so all inputs must be moved to the CPU and converted to NumPy arrays. That is why it cannot be used with vmap, which requires the function to consist only of PyTorch operations.
We might be able to make Sinkhorn compatible in the future, but not emd2. The exact solver is also highly non-vectorizable, so even if vmap support were possible, batching would bring no gain.
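For the Sinkhorn case, one way to sidestep vmap is to write the solver so that it operates on the batch dimension directly. The sketch below is a plain log-domain Sinkhorn iteration in PyTorch, not POT's implementation. `sinkhorn_cost_batched`, `reg`, and `n_iter` are illustrative choices. It returns the transport cost <P, C> of the entropic plan P. That cost only approximates the squared W2 computed by ot.emd2, and the approximation tightens as reg shrinks.

```python
import math
import torch

def sinkhorn_cost_batched(X, Y, reg=0.1, n_iter=200):
    """Entropic OT cost <P, C> for a batch of uniform point clouds.

    X: (B, n, d), Y: (B, m, d). Returns a (B,) tensor approximating the
    squared W2 distance (not the exact emd2 value).
    """
    B, n, _ = X.shape
    m = Y.shape[1]
    # Squared Euclidean cost matrices, shape (B, n, m)
    C = ((X[:, :, None, :] - Y[:, None, :, :]) ** 2).sum(dim=-1)
    log_a = torch.full((B, n), -math.log(n), dtype=X.dtype, device=X.device)
    log_b = torch.full((B, m), -math.log(m), dtype=X.dtype, device=X.device)
    f = torch.zeros_like(log_a)  # dual potential on X's points
    g = torch.zeros_like(log_b)  # dual potential on Y's points
    for _ in range(n_iter):
        # Log-domain Sinkhorn updates: match row marginals, then column marginals
        f = -reg * torch.logsumexp((g[:, None, :] - C) / reg + log_b[:, None, :], dim=2)
        g = -reg * torch.logsumexp((f[:, :, None] - C) / reg + log_a[:, :, None], dim=1)
    # Entropic plan P_ij = a_i * b_j * exp((f_i + g_j - C_ij) / reg)
    log_P = (f[:, :, None] + g[:, None, :] - C) / reg + log_a[:, :, None] + log_b[:, None, :]
    return (log_P.exp() * C).sum(dim=(1, 2))
```

Usage: `sinkhorn_cost_batched(X, Y).sqrt()` gives an approximate (n_batch,) W2 loss that backpropagates through every iteration. Choose reg relative to the scale of the costs. The log-domain updates keep small values of reg numerically stable, but they need more iterations to converge.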