In [None]:
import numpy as np
import torch
from torch.autograd import grad

# Use the GPU whenever PyTorch can see one; otherwise stay on the CPU.
use_cuda = torch.cuda.is_available()
# Tensor type applied to every numpy -> torch conversion in this notebook.
if use_cuda:
    dtype = torch.cuda.FloatTensor
else:
    dtype = torch.FloatTensor

from geomloss import SamplesLoss
# from pykeops.torch.cluster import grid_cluster, cluster_ranges_centroids

# Seed numpy's global RNG so that both sampled point clouds are reproducible.
np.random.seed(0)
# torch.manual_seed(0)
num_points = 10000
num_dims = 2

# Source measure: uniform weights A_i on points X_i drawn uniformly from
# [-0.5, 0.5)^num_dims. Uniform weights make it possible to see empirically how
# the loss behaves when a continuous distribution is discretised, i.e.
# approximated by a categorical distribution over sampled points.
# Disabled alternative -- random categorical weights:
#   A_i = np.random.rand(num_points, 1); A_i = A_i / np.sum(A_i); A_i = A_i.ravel()
# The final ravel() is IMPORTANT: keeping the (num_points, 1) shape caused a bug.
A_i = np.full((num_points,), 1.0 / num_points)
X_i = np.random.rand(num_points, num_dims) - 0.5

# Target measure: built exactly like the source one (X_i is drawn before Y_j).
B_j = np.full((num_points,), 1.0 / num_points)
Y_j = np.random.rand(num_points, num_dims) - 0.5

print(A_i.shape, X_i.shape, B_j.shape, Y_j.shape)
print(A_i, '\n', X_i)
print(B_j, '\n', Y_j)

# TODO: plot variations across num_points, num_dims, p = 1 or 2, blur, diameter,
# scaling (speed vs accuracy), and for the discretisation of a continuous
# distribution (sampling leads to a categorical distribution).

# The epsilon schedule goes from diameter**p to blur**p linearly on a
# logarithmic scale, with stepsize = p*log(scaling).

# Empirical notes, for a low number of points (= 10):
#  * p: the loss seems more robust for p = 2 than for p = 1.
#  * diameter: seems to matter if chosen too low (when smaller than blur?),
#    e.g. with 0.1 the loss even moves in the wrong direction, away from the
#    correct value, as more outer iterations elapse.
#  * blur: a value that is too low (0.0 to at least 4 decimal places) leads to
#    a NaN loss; more iterations are needed if this value is large.
#  * scaling: seems robust even to a very low value, e.g. 0.007.
#  * Multiple outer iterations do not seem to be needed if there are enough
#    inner ones (for a low number of points?).
#  * At a coarse scale the loss is usually underestimated.
#  * When approximating 2 identical uniform distributions over 2-D space with
#    2 sampled categorical distributions, the true Wasserstein distance is 0,
#    but the sampled one is: 10 points = 0.03361376, 20 points = 0.098184645,
#    100 points = 0.003122977, 1000 points = 0.0007166295,
#    10000 points = 7.326214e-05.
diameter = 1.
blur = 0.6561
scaling = 0.7

# Print the schedule for p = 2 first, then p = 1 (p is left at 1. afterwards).
for p, header in ((2., "Epsilon schedule for p = 2:"),
                  (1., "Epsilon schedule for p = 1:")):
    log_start = p * np.log(diameter)
    log_stop = p * np.log(blur)
    log_step = p * np.log(scaling)
    intermediate = [np.exp(e) for e in np.arange(log_start, log_stop, log_step)]
    print(header, [diameter**p] + intermediate + [blur**p])


In [None]:
import matplotlib.pyplot as plt


# Multiscale Sinkhorn settings. `scaling`, `Nits`, `cluster_scale` and `diameter`
# stay at module level because the cells further down read them as well.
scaling, Nits = .5, 9
cluster_scale = .1 if not use_cuda else .05
diameter = 1.


def _report_regime(blur, cluster_scale):
    """Print which side of `cluster_scale` the given `blur` falls on."""
    if blur > cluster_scale:
        print('Calculating Sinkhorn divergences over coarse clusters. blur, cluster_scale =', blur, cluster_scale)
    else:
        print('Calculating Sinkhorn divergences over actual points. blur, cluster_scale =', blur, cluster_scale)


def _to_numpy(tensor):
    """Return `tensor` as a numpy array.

    The explicit `.cpu()` matters: calling `.numpy()` directly on a CUDA tensor
    raises, and `dtype` is a CUDA tensor type whenever a GPU is available.
    """
    return tensor.detach().cpu().numpy()


def _sinkhorn_p1_p2(a_i, x_i, b_j, y_j, blur):
    """Build the p=1 and p=2 multiscale Sinkhorn losses at `blur` and evaluate both.

    Uses the module-level `diameter`, `cluster_scale` and `scaling`.
    Returns `(Loss_p1, loss_p1, Loss_p2, loss_p2)`: the two SamplesLoss objects
    and the values they produce on the weighted point clouds (a_i, x_i), (b_j, y_j).
    """
    Loss_p1 = SamplesLoss("sinkhorn", p=1, blur=blur, diameter=diameter, cluster_scale=cluster_scale,
                          scaling=scaling, backend="multiscale", verbose=True)
    loss_p1 = Loss_p1(a_i, x_i, b_j, y_j)
    Loss_p2 = SamplesLoss("sinkhorn", p=2, blur=blur, diameter=diameter, cluster_scale=cluster_scale,
                          scaling=scaling, backend="multiscale", verbose=True)
    loss_p2 = Loss_p2(a_i, x_i, b_j, y_j)
    return Loss_p1, loss_p1, Loss_p2, loss_p2


# --- First run: blur = scaling**3 ---------------------------------------------
i = 3
blur = scaling**i
_report_regime(blur, cluster_scale)

# Create a torch copy of the numpy data...
A_i_torch = torch.from_numpy(A_i).type(dtype)
X_i_torch = torch.from_numpy(X_i).contiguous().type(dtype)
B_j_torch = torch.from_numpy(B_j).type(dtype)
Y_j_torch = torch.from_numpy(Y_j).contiguous().type(dtype)
a_i, x_i = A_i_torch.clone(), X_i_torch.clone()
b_j, y_j = B_j_torch.clone(), Y_j_torch.clone()

# ...and require grad on the weights and on the source positions:
a_i.requires_grad = True
x_i.requires_grad = True
b_j.requires_grad = True

Loss_p1, loss_p1, Loss_p2, loss_p2 = _sinkhorn_p1_p2(a_i, x_i, b_j, y_j, blur)

print("Loss_p1 =", Loss_p1, "Loss_p2 =", Loss_p2)
print("loss_p1 =", _to_numpy(loss_p1), "loss_p2 =", _to_numpy(loss_p2))

# Scatter plot of the two sampled measures: X_i in blue, Y_j in red, with
# marker area proportional to the point weights.
plt.figure(figsize=(6, 4.5))

size_scale = 1000
plt.scatter(X_i[:, 0], X_i[:, 1], s=size_scale * A_i, c='blue')
plt.scatter(Y_j[:, 0], Y_j[:, 1], s=size_scale * B_j, c='red')

plt.tight_layout()
plt.show()

# --- Second run: blur = scaling**4, same tensors ------------------------------
i = 4
blur = scaling**i
_report_regime(blur, cluster_scale)

Loss_p1, loss_p1, Loss_p2, loss_p2 = _sinkhorn_p1_p2(a_i, x_i, b_j, y_j, blur)

print("loss_p1 =", _to_numpy(loss_p1), "loss_p2 =", _to_numpy(loss_p2))

In [None]:
# For p = 2
# Sweep blur = scaling**i for i in range(Nits). For each blur: evaluate the
# multiscale Sinkhorn loss, take its gradient w.r.t. the source points, and
# draw one subplot showing both point clouds, the grid-cluster centroids (only
# when blur > cluster_scale) and the resulting displacement field.
from pykeops.torch.cluster import grid_cluster, cluster_ranges_centroids

n_rows = (Nits - 1)//3 + 1  # 3 subplots per row
plt.figure(figsize=(12, n_rows * 4))

# The numpy -> torch conversion does not depend on the blur: do it once.
A_i_torch = torch.from_numpy(A_i).type(dtype)
X_i_torch = torch.from_numpy(X_i).contiguous().type(dtype)
B_j_torch = torch.from_numpy(B_j).type(dtype)
Y_j_torch = torch.from_numpy(Y_j).contiguous().type(dtype)

for i in range(Nits):
    blur = scaling**i
    Loss = SamplesLoss("sinkhorn", p=2, blur=blur, diameter=diameter, cluster_scale=cluster_scale,
                       scaling=scaling, backend="multiscale")

    # Work on fresh copies of the data at every blur level...
    a_i, x_i = A_i_torch.clone(), X_i_torch.clone()
    b_j, y_j = B_j_torch.clone(), Y_j_torch.clone()

    # ...and require grad:
    a_i.requires_grad = True
    x_i.requires_grad = True
    b_j.requires_grad = True

    # Compute the loss + gradients:
    Loss_xy = Loss(a_i, x_i, b_j, y_j)
    [F_i, G_j, dx_i] = grad(Loss_xy, [a_i, b_j, x_i])

    # .cpu() is required before .numpy() when the tensors live on the GPU.
    print("Iteration:", i, "Loss_xy", Loss_xy.detach().cpu().numpy())

    # The generalized "Brenier map" is (minus) the gradient of the Sinkhorn loss
    # with respect to the Wasserstein metric:
    BrenierMap = - dx_i / (a_i.view(-1, 1) + 1e-7)

    # Compute the coarse measures for display ----------------------------------

    x_lab = grid_cluster(x_i, cluster_scale)
    _, x_c, a_c = cluster_ranges_centroids(x_i, x_lab, weights=a_i)

    y_lab = grid_cluster(y_j, cluster_scale)
    _, y_c, b_c = cluster_ranges_centroids(y_j, y_lab, weights=b_j)

    # Fancy display: -----------------------------------------------------------

    ax = plt.subplot(n_rows, 3, i+1)

    size_scale = 200
    ax.scatter(X_i[:, 0], X_i[:, 1], s=size_scale * A_i, c='blue')
    ax.scatter(Y_j[:, 0], Y_j[:, 1], s=size_scale * B_j, c='red')
    if blur > cluster_scale:
        # Marker positions are the cluster centroids; marker sizes come from the
        # cluster weights a_c / b_c (not from the centroid coordinates).
        # ravel() keeps the sizes 1-D whether the weights come back as (C,) or (C, 1).
        x_c_ = x_c.detach().cpu().numpy()
        a_c_ = a_c.detach().cpu().numpy().ravel()
        y_c_ = y_c.detach().cpu().numpy()
        b_c_ = b_c.detach().cpu().numpy().ravel()
        ax.scatter(x_c_[:, 0], x_c_[:, 1], s=size_scale * a_c_, c='purple')
        ax.scatter(y_c_[:, 0], y_c_[:, 1], s=size_scale * b_c_, c='yellow')

    # Displacement field: one arrow per source point, in data units.
    v_ = BrenierMap.detach().cpu().numpy()
    x_ = x_i.detach().cpu().numpy()
    ax.quiver(x_[:, 0], x_[:, 1], v_[:, 0], v_[:, 1],
              scale=1, scale_units="xy",
              color="#5CBF3A", zorder=3, width=0.02/3)

    ax.set_title("iteration {}, blur = {:.3f}".format(i+1, blur))

    ax.axis([-1, 1, -1, 1]) ; ax.set_aspect('equal', adjustable='box')

plt.tight_layout()
plt.show()


In [None]:
# For p = 1
# Sweep blur = scaling**i for i in range(Nits). For each blur: evaluate the
# multiscale Sinkhorn loss, take its gradient w.r.t. the source points, and
# draw one subplot showing both point clouds, the grid-cluster centroids (only
# when blur > cluster_scale) and the resulting displacement field.
from pykeops.torch.cluster import grid_cluster, cluster_ranges_centroids

n_rows = (Nits - 1)//3 + 1  # 3 subplots per row
plt.figure(figsize=(12, n_rows * 4))

# The numpy -> torch conversion does not depend on the blur: do it once.
A_i_torch = torch.from_numpy(A_i).type(dtype)
X_i_torch = torch.from_numpy(X_i).contiguous().type(dtype)
B_j_torch = torch.from_numpy(B_j).type(dtype)
Y_j_torch = torch.from_numpy(Y_j).contiguous().type(dtype)

for i in range(Nits):
    blur = scaling**i
    # `diameter` is the shared module-level setting (1. in this notebook), so
    # this sweep follows the same configuration as the rest of the cells.
    Loss = SamplesLoss("sinkhorn", p=1, blur=blur, diameter=diameter, cluster_scale=cluster_scale,
                       scaling=scaling, backend="multiscale")

    # Work on fresh copies of the data at every blur level...
    a_i, x_i = A_i_torch.clone(), X_i_torch.clone()
    b_j, y_j = B_j_torch.clone(), Y_j_torch.clone()

    # ...and require grad:
    a_i.requires_grad = True
    x_i.requires_grad = True
    b_j.requires_grad = True

    # Compute the loss + gradients:
    Loss_xy = Loss(a_i, x_i, b_j, y_j)
    [F_i, G_j, dx_i] = grad(Loss_xy, [a_i, b_j, x_i])

    # .cpu() is required before .numpy() when the tensors live on the GPU.
    print("Iteration:", i, "Loss_xy", Loss_xy.detach().cpu().numpy())

    # The generalized "Brenier map" is (minus) the gradient of the Sinkhorn loss
    # with respect to the Wasserstein metric:
    BrenierMap = - dx_i / (a_i.view(-1, 1) + 1e-7)

    # Compute the coarse measures for display ----------------------------------

    x_lab = grid_cluster(x_i, cluster_scale)
    _, x_c, a_c = cluster_ranges_centroids(x_i, x_lab, weights=a_i)
    print("Clustered array size:", x_c.size())

    y_lab = grid_cluster(y_j, cluster_scale)
    _, y_c, b_c = cluster_ranges_centroids(y_j, y_lab, weights=b_j)

    # Fancy display: -----------------------------------------------------------

    ax = plt.subplot(n_rows, 3, i+1)

    size_scale = 200
    ax.scatter(X_i[:, 0], X_i[:, 1], s=size_scale * A_i, c='blue')
    ax.scatter(Y_j[:, 0], Y_j[:, 1], s=size_scale * B_j, c='red')
    if blur > cluster_scale:
        # Marker positions are the cluster centroids; marker sizes come from the
        # cluster weights a_c / b_c (not from the centroid coordinates).
        # ravel() keeps the sizes 1-D whether the weights come back as (C,) or (C, 1).
        x_c_ = x_c.detach().cpu().numpy()
        a_c_ = a_c.detach().cpu().numpy().ravel()
        y_c_ = y_c.detach().cpu().numpy()
        b_c_ = b_c.detach().cpu().numpy().ravel()
        ax.scatter(x_c_[:, 0], x_c_[:, 1], s=size_scale * a_c_, c='purple')
        ax.scatter(y_c_[:, 0], y_c_[:, 1], s=size_scale * b_c_, c='yellow')

    # Displacement field: one arrow per source point, in data units.
    v_ = BrenierMap.detach().cpu().numpy()
    x_ = x_i.detach().cpu().numpy()
    ax.quiver(x_[:, 0], x_[:, 1], v_[:, 0], v_[:, 1],
              scale=1, scale_units="xy",
              color="#5CBF3A", zorder=3, width=0.02/3)

    ax.set_title("iteration {}, blur = {:.3f}".format(i+1, blur))

    ax.axis([-1, 1, -1, 1]) ; ax.set_aspect('equal', adjustable='box')

plt.tight_layout()
plt.show()


In [None]:
# Copied code from above to see empirical sample complexities of sampling and approximating.
# For each cloud size, draw 1000 independent pairs of point clouds from the
# same uniform distribution, evaluate the p=2 Sinkhorn loss between each pair,
# and show the histogram, mean and median of the losses.
# NOTE: this cell overwrites the module-level num_points, A_i, X_i, B_j, Y_j.
num_dims = 2

# Only the loss at the finest blur level, scaling**(Nits-1), is recorded, so it
# is the only one evaluated. The loss object depends on neither the seed nor
# the cloud size, so it is built once.
i = Nits - 1
blur = scaling**i
Loss = SamplesLoss("sinkhorn", p=2, blur=blur, diameter=diameter, cluster_scale=cluster_scale,
                   scaling=scaling, backend="multiscale")

for num_points in [10, 20, 100, 1000]:
    all_losses = []
    for seed_ in range(1000):
        np.random.seed(seed_)
        A_i = np.ones((num_points,)) / num_points
        X_i = np.random.rand(num_points, num_dims) - np.ones((num_points, num_dims))/2
        B_j = np.ones((num_points,)) / num_points
        Y_j = np.random.rand(num_points, num_dims) - np.ones((num_points, num_dims))/2

        # Create a torch copy of the data (no gradients are needed here).
        a_i = torch.from_numpy(A_i).type(dtype)
        x_i = torch.from_numpy(X_i).contiguous().type(dtype)
        b_j = torch.from_numpy(B_j).type(dtype)
        y_j = torch.from_numpy(Y_j).contiguous().type(dtype)

        Loss_xy = Loss(a_i, x_i, b_j, y_j)

        # .cpu() is required before .numpy() when the tensors live on the GPU.
        Loss_xy_ = Loss_xy.detach().cpu().numpy()
        all_losses.append(Loss_xy_)

    plt.hist(all_losses, bins=100, range=(0.0, 0.2))
    plt.title("num_points: " + str(num_points))
    plt.show()
    print("Mean:", np.mean(all_losses), "Median:", np.median(all_losses))
# For 10, 20, 100, 1000 points over 100 seeds:
# Mean: 0.038914967 Median: 0.037568074
# Mean: 0.025193183 Median: 0.021923613
# Mean: 0.005761823 Median: 0.0054472107
# Mean: 0.00069120317 Median: 0.00065572176
# Over 1000 seeds:
# Mean: 0.04238657 Median: 0.03862535
# Mean: 0.023622321 Median: 0.021923613
# Mean: 0.00586633 Median: 0.005333068
# Mean: 0.00069400994 Median: 0.0006544423