
Self distances with ot.bregman.empirical_sinkhorn2 higher than expected #667

Open
Jabbath opened this issue Jul 31, 2024 · 0 comments

Jabbath commented Jul 31, 2024

I am computing matrices of W_2 distances with ot.bregman.empirical_sinkhorn2 between point clouds centered along a curve. I expect the distance from a point cloud to itself to be zero or close to zero. However, this is not the case: the self-distances are in fact higher than the distances to some neighboring point clouds. This seems like unexpected behavior, and I am wondering whether there is an underlying issue causing it. I've created a code snippet below that reproduces the issue. Any input would be greatly appreciated.
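For reference, here is a sketch of the quantity the snippet appears to compute. It assumes POT's default squared Euclidean ground cost for `empirical_sinkhorn2` (the snippet does not pass `metric`), and it assumes the returned value is the transport cost of the entropic plan rather than the regularized objective itself. For clouds $x^{(i)}$ and $x^{(j)}$ of $N$ points with uniform weights $a = b = \mathbf{1}/N$ and regularization $\varepsilon$ (`epsilon` in the snippet):

$$
M_{kl} = \lVert x^{(i)}_k - x^{(j)}_l \rVert_2^2, \qquad
P^{\varepsilon} = \operatorname*{arg\,min}_{P \in U(a,b)} \; \langle P, M \rangle + \varepsilon \sum_{k,l} P_{kl} \log P_{kl},
$$

$$
U(a,b) = \{ P \in \mathbb{R}_{\ge 0}^{N \times N} : P\mathbf{1} = a,\; P^{\top}\mathbf{1} = b \}, \qquad
P^{\varepsilon} = \operatorname{diag}(u)\, K \operatorname{diag}(v), \quad K_{kl} = e^{-M_{kl}/\varepsilon},
$$

$$
d_{ij} = \sqrt{\langle P^{\varepsilon}, M \rangle}.
$$

Since every entry of $K$, and hence of $P^{\varepsilon}$, is strictly positive for $\varepsilon > 0$, the plan for $i = j$ always puts some mass on pairs $k \neq l$ with $M_{kl} > 0$. Under these assumptions $d_{ii} > 0$, whereas the unregularized $W_2$ distance of a cloud to itself is $0$ (identity coupling).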

To Reproduce

import ot
import numpy as np
import matplotlib.pyplot as plt

T = 40 # Number of point clouds
N = 25 # Number of points in each cloud
max_x = 4*np.pi # Cosine is evaluated on [0, max_x]
variance = 0.4 # Per-coordinate variance of the Gaussian at each point cloud
epsilon = 0.5 # Entropic regularization strength (the reg argument)

def generate_cos(T=100, N=50, max_x=2*np.pi, variance=1):
    """
    Generates a dataset which follows a cosine wave. Point clouds are 2-D Gaussians
    centered at points on the cosine wave.
    :param T: Number of point clouds to generate
    :param N: Number of samples to take at each timepoint
    :param max_x: The right end of the cosine wave
    :param variance: The variance of the (isotropic) Gaussian distributions
    :return: The data matrix of shape (T, N, 2)
    """
    # Form the matrix [[x, cos(x)], ...]
    span = np.linspace(0, max_x, T, endpoint=True)
    y_vals = np.cos(span)
    means = np.zeros((T, 2))
    means[:, 0] = span
    means[:, 1] = y_vals

    # Sample a normal dist centered at each point
    x = np.zeros((T, N, 2))
    for i, mean in enumerate(means):
        dist = np.random.multivariate_normal(mean, np.eye(2)*variance, N)
        x[i, :, :] = dist

    return x

x = generate_cos(T=T, N=N, max_x=max_x, variance=variance)

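dists = np.zeros(shape=(T, T))

# Fill the upper triangle, including the diagonal
for i in range(T):
    for j in range(i, T):
        d = np.sqrt(ot.bregman.empirical_sinkhorn2(x[i], x[j], epsilon,
                                                   a=ot.unif(x[i].shape[0]),
                                                   b=ot.unif(x[j].shape[0])))
        dists[i, j] = d

# Mirror only the strict upper triangle; adding dists.T directly would double the diagonal
dists = dists + np.triu(dists, k=1).T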

plt.figure(figsize=(10, 10))
plt.matshow(dists[10:20, 10:20], fignum=1)
plt.title('$W_2$ distance matrix')
plt.colorbar()
plt.show()
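To check the observation independently of POT, here is a minimal pure-NumPy sketch. It is a hand-written log-domain Sinkhorn loop (an illustration, not POT's implementation) that returns the transport cost ⟨P, M⟩ of the entropic plan between two uniformly weighted clouds. The squared Euclidean cost, the fixed iteration count, and the helper names `logsumexp` and `entropic_transport_cost` are assumptions of this sketch. It evaluates one cloud, drawn the same way as in `generate_cos`, against itself for several values of `eps`.

```python
import numpy as np


def logsumexp(z, axis):
    """Numerically stable log(sum(exp(z))) along one axis."""
    zmax = z.max(axis=axis, keepdims=True)
    out = zmax + np.log(np.exp(z - zmax).sum(axis=axis, keepdims=True))
    return out.squeeze(axis)


def entropic_transport_cost(xs, xt, eps, n_iter=2000):
    """Return <P, M> for the entropic OT plan P between two uniformly weighted clouds.

    Hand-written log-domain Sinkhorn for illustration only (not POT's code).
    Assumes a squared Euclidean ground cost, like POT's default metric.
    """
    n, m = len(xs), len(xt)
    M = ((xs[:, None, :] - xt[None, :, :]) ** 2).sum(axis=-1)  # squared distances
    log_a = np.full(n, -np.log(n))  # uniform weights, like ot.unif(n)
    log_b = np.full(m, -np.log(m))
    f, g = np.zeros(n), np.zeros(m)  # dual potentials
    for _ in range(n_iter):
        # Alternate updates so the row, then column, marginals match a and b
        f = -eps * logsumexp((g[None, :] - M) / eps + log_b[None, :], axis=1)
        g = -eps * logsumexp((f[:, None] - M) / eps + log_a[:, None], axis=0)
    # Plan P_kl = a_k b_l exp((f_k + g_l - M_kl) / eps); mathematically every entry is > 0
    P = np.exp((f[:, None] + g[None, :] - M) / eps + log_a[:, None] + log_b[None, :])
    return float((P * M).sum())


rng = np.random.default_rng(0)
cloud = rng.multivariate_normal([0.0, 1.0], 0.4 * np.eye(2), size=25)
for eps in (0.5, 0.1, 0.05):
    # The self-cost should stay positive for eps > 0 and shrink as eps decreases
    print(eps, entropic_transport_cost(cloud, cloud, eps))
```

A clearly positive self-cost at `eps = 0.5` that decreases as `eps` shrinks would be consistent with the non-zero diagonal coming from the regularization itself; comparing these numbers with `ot.bregman.empirical_sinkhorn2(cloud, cloud, eps)` would show whether POT returns the same quantity.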

Screenshots

[Image: dist_mat, the distance matrix plot produced by the snippet above]

Expected Behavior

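I expect the diagonal entries (the distance from each point cloud to itself) to be close to zero, and no larger than the distances to neighboring point clouds.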
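The expected behavior can also be checked numerically on the `dists` array from the snippet. The sketch below uses a hypothetical helper, `self_vs_nearest`, that compares each diagonal entry with the smallest off-diagonal entry in its row and lists the rows where the self-distance is larger, which is the situation described in the report.

```python
import numpy as np


def self_vs_nearest(dists):
    """Compare each self-distance dists[i, i] with the nearest other cloud.

    Returns (diagonal, nearest off-diagonal value per row,
    indices of rows where the self-distance exceeds that nearest value).
    """
    dists = np.asarray(dists, dtype=float)
    off = dists.copy()
    np.fill_diagonal(off, np.inf)  # ignore the diagonal when taking row minima
    nearest = off.min(axis=1)
    diag = np.diag(dists).copy()
    return diag, nearest, np.flatnonzero(diag > nearest)


# Toy 3x3 matrix in which cloud 1 is farther from itself than from cloud 2
toy = np.array([[0.0, 1.0, 2.0],
                [1.0, 3.0, 0.5],
                [2.0, 0.5, 0.0]])
print(self_vs_nearest(toy))
```

After running the snippet, `diag, nearest, rows = self_vs_nearest(dists)` would give the per-cloud comparison; an empty `rows` together with near-zero `diag` values would match the expected behavior.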

Environment:

  • OS: Linux
  • Python version: 3.11.5
  • How was POT installed (source, pip, conda): pip
  • POT version: 0.9.4

Output of code snippet:

Linux-4.18.0-513.5.1.el8_9.x86_64-x86_64-with-glibc2.28
Python 3.11.5 (main, Sep 22 2023, 15:34:29) [GCC 8.5.0 20210514 (Red Hat 8.5.0-20)]
NumPy 1.26.0
SciPy 1.11.3
[KeOps] Warning : Cuda libraries were not detected on the system or could not be loaded ; using cpu only mode
POT 0.9.4
