In [1]:
import numpy as np
import ot
import torch
import torch.nn.functional as F

In [2]:
import torch

class SinkhornDistance(torch.nn.Module):
    r"""
        Log-domain Sinkhorn solver for entropy-regularized optimal transport.

        Given marginals :math:`\mu` (length :math:`P_1`), :math:`\nu` (length :math:`P_2`)
        and a cost matrix :math:`C` (:math:`P_1 \times P_2`), it alternately updates the
        dual potentials :math:`u, v` so that the plan
        :math:`\pi_{ij} = \exp((-C_{ij} + u_i + v_j) / \varepsilon)` approximately has row
        sums :math:`\mu` and column sums :math:`\nu`. It then returns the transport cost
        :math:`\sum_{ij} \pi_{ij} C_{ij}` together with :math:`\pi`.

        Args:
            eps (float): regularization coefficient :math:`\varepsilon`
            max_iter (int): number of Sinkhorn iterations (always run to completion,
                no convergence check)
            reduction (string, optional): Specifies the reduction to apply to the output:
                'none' | 'mean' | 'sum'. 'none': no reduction will be applied,
                'mean': the sum of the output will be divided by the number of
                elements in the output, 'sum': the output will be summed. Default: 'none'

        Raises:
            ValueError: if ``reduction`` is not one of 'none', 'mean', 'sum'.

        Shape:
            - Input: ``mu`` :math:`(*, P_1)`, ``nu`` :math:`(*, P_2)`,
              ``C`` :math:`(*, P_1, P_2)`, where :math:`*` is a shared batch shape
            - Output: cost :math:`(*)` or :math:`()`, depending on `reduction`;
              plan ``pi`` :math:`(*, P_1, P_2)` (never reduced, detached from autograd)
    """

    _REDUCTIONS = ('none', 'mean', 'sum')

    def __init__(self, eps=1e-3, max_iter=100, reduction='none'):
        super().__init__()
        # Fail fast: previously an unknown reduction was silently treated as 'none'.
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                "reduction must be one of {}, got {!r}".format(self._REDUCTIONS, reduction))
        self.eps = eps
        self.max_iter = max_iter
        self.reduction = reduction

    def forward(self, mu, nu, C):
        """Run Sinkhorn and return ``(cost, pi)``.

        Args:
            mu: row marginal, shape ``(*, P_1)``.
            nu: column marginal, shape ``(*, P_2)``.
            C: cost matrix, shape ``(*, P_1, P_2)``.

        Returns:
            cost: ``sum(pi * C)`` over the last two dims, reduced per ``self.reduction``.
            pi: the (detached) transport plan, shape ``(*, P_1, P_2)``.
        """
        u = torch.ones_like(mu)
        v = torch.ones_like(nu)

        # Sinkhorn iterations in the log domain. The 1e-8 keeps log() finite when a
        # marginal has zero entries.
        for _ in range(self.max_iter):
            # Column update: push pi's column sums towards nu.
            v = self.eps * (
                torch.log(nu + 1e-8)
                - torch.logsumexp(self.M(C, u, v).transpose(-2, -1), dim=-1)) + v
            # Row update, done last: pi's row sums equal mu (+1e-8) after each step.
            u = self.eps * (
                torch.log(mu + 1e-8)
                - torch.logsumexp(self.M(C, u, v), dim=-1)) + u

        # Transport plan pi = diag(a)*K*diag(b), with a = exp(u/eps), b = exp(v/eps),
        # K = exp(-C/eps). Detached, so gradients of the cost w.r.t. C flow only through
        # the explicit C factor below.
        pi = torch.exp(self.M(C, u, v)).detach()
        # Transport cost <pi, C> per batch element.
        cost = torch.sum(pi * C, dim=(-2, -1))

        if self.reduction == 'mean':
            cost = cost.mean()
        elif self.reduction == 'sum':
            cost = cost.sum()
        return cost, pi

    def M(self, C, u, v):
        r"""Modified cost for logarithmic updates.

        :math:`M_{ij} = (-C_{ij} + u_i + v_j) / \varepsilon`, broadcast over any leading
        batch dims; ``exp(M)`` is the transport plan for potentials ``u`` and ``v``.
        """
        return (-C + u.unsqueeze(-1) + v.unsqueeze(-2)) / self.eps

def distributed_sinkhorn(out, sinkhorn_iterations=100, epsilon=0.05):
    """Balanced soft assignment of samples to prototypes via Sinkhorn-Knopp (SwAV-style).

    Args:
        out: 2-D tensor of scores, shape (B, K) for B samples and K prototypes.
            Higher means a preferred assignment (pass ``-cost`` to use costs).
        sinkhorn_iterations: number of alternating row/column normalizations.
        epsilon: entropic regularization; the kernel is ``exp(out / epsilon)``.

    Returns:
        Tensor of shape (B, K). Each row (sample) sums to 1, and each column
        (prototype) sums approximately to B / K.
    """
    scaled = out / epsilon
    if scaled.numel() > 0:
        # Subtract the global max before exp(). The constant factor cancels in the
        # normalization below, but without it exp() overflows to inf (and then NaN)
        # for large scores / small epsilon.
        scaled = scaled - scaled.max().detach()
    Q = torch.exp(scaled).t()  # K x B
    K, B = Q.shape

    # Make the whole matrix sum to 1.
    # Out-of-place updates keep autograd valid: in-place ops would modify exp()'s
    # saved output and break backward().
    Q = Q / Q.sum()

    for _ in range(sinkhorn_iterations):
        # Rows (prototypes): each sums to 1/K.
        Q = Q / torch.sum(Q, dim=1, keepdim=True) / K
        # Columns (samples): each sums to 1/B.
        Q = Q / torch.sum(Q, dim=0, keepdim=True) / B

    # Scale so each sample's assignment sums to 1, then return as B x K.
    return (Q * B).t()

In [5]:
# Compare exact OT (EMD), POT's entropic Sinkhorn, SwAV-style distributed_sinkhorn
# and the batched log-domain SinkhornDistance on one small 3x3 problem.
first_histogram = np.array([1.0, 1.0, 1.0])
second_histogram = np.array([1.0, 1.0, 1.0])
cost = np.array([[4.0, 1, 3], [2, 0, 5], [3, 2, 2]])

print("Args M, u, v:")
print(cost)
print(first_histogram)
print(second_histogram)

# Exact (unregularized) OT: the reference plan and cost.
pi_emd = ot.emd(first_histogram, second_histogram, cost)
cost_emd = ot.emd2(first_histogram, second_histogram, cost)

print("--------------emd: pi, cost, sum------------")
print(pi_emd)
print(cost_emd)
print((cost * pi_emd).sum())

# Entropic OT with a large regularizer (reg=1): the plan is blurred relative to EMD.
pi_sinkhorn = ot.sinkhorn(first_histogram, second_histogram, cost, reg=1)
cost_sinkhorn = ot.sinkhorn2(first_histogram, second_histogram, cost, reg=1)
print("--------------sinkhorn: pi, cost, sum------------")
print(pi_sinkhorn)
print(cost_sinkhorn)
print((cost * pi_sinkhorn).sum())

# distributed_sinkhorn takes scores (higher is better), so negate the cost.
# Its marginals (rows sum to 1, columns to B/K = 1) match the unit histograms here.
cost_t = torch.tensor(cost)
pi_dsinkhorn = distributed_sinkhorn(-cost_t)
print("--------------dsinkhorn: pi, sum------------")
print(pi_dsinkhorn)
print((cost_t * pi_dsinkhorn).sum())

# SinkhornDistance is batched: replicate the same problem twice along a leading
# batch dim (expand creates views, not copies).
first_histogram_pt = torch.tensor(first_histogram).expand(2, -1)
second_histogram_pt = torch.tensor(second_histogram).expand(2, -1)
cost_pt = cost_t.expand(2, -1, -1)

solver = SinkhornDistance()
cost_sinkhornd, pi_sinkhornd = solver(first_histogram_pt, second_histogram_pt, cost_pt)
print("--------------sinkhornd: pi, cost------------")
print(pi_sinkhornd)
print(cost_sinkhornd)

Args M, u, v:
[[4. 1. 3.]
 [2. 0. 5.]
 [3. 2. 2.]]
[1. 1. 1.]
[1. 1. 1.]
--------------emd: pi, cost, sum------------
[[0. 1. 0.]
 [1. 0. 0.]
 [0. 0. 1.]]
5.0
5.0
--------------sinkhorn: pi, cost, sum------------
[[0.18340519 0.44219386 0.37440095]
 [0.51965476 0.46091571 0.01942953]
 [0.29694005 0.09689043 0.60616952]]
5.732414695260431
5.732414695260431
--------------dsinkhorn: pi, sum------------
tensor([[1.9952e-07, 9.8983e-01, 1.0174e-02],
        [9.8988e-01, 1.0122e-02, 9.1099e-31],
        [1.9611e-05, 4.1333e-16, 9.9998e-01]], dtype=torch.float64)
tensor(5.0001, dtype=torch.float64)
--------------sinkhornd: pi, cost------------
tensor([[[0.0000, 0.9898, 0.0102],
         [0.9899, 0.0101, 0.0000],
         [0.0000, 0.0000, 1.0000]],

        [[0.0000, 0.9898, 0.0102],
         [0.9899, 0.0101, 0.0000],
         [0.0000, 0.0000, 1.0000]]], dtype=torch.float64)
tensor([5.0001, 5.0001], dtype=torch.float64)


In [65]:
# Monte-Carlo check: on random sample-to-prototype cosine costs, how often do
# distributed_sinkhorn and POT's ot.sinkhorn pick the same prototype per sample,
# and what total transport cost does each reach?
SEED = 0
N_TRIALS = 1000
N_PROTOS, N_SAMPLES, DIM = 4, 3, 768
SINKHORN_EPS = 0.05

# Fix the seed so the printed ratio and costs are reproducible on re-run.
torch.manual_seed(SEED)

consist_cnt = 0
total_cnt = 0
costa = 0.0  # total transport cost of distributed_sinkhorn plans
costb = 0.0  # total transport cost of POT plans
for _ in range(N_TRIALS):
    protos = torch.randn(N_PROTOS, DIM)
    hiddens = torch.randn(N_SAMPLES, DIM)
    # Cosine similarity (rows are L2-normalized), shape samples x prototypes.
    sim = torch.einsum("sd,pd->sp", F.normalize(hiddens), F.normalize(protos))
    cost = 1 - sim

    res_algo = distributed_sinkhorn(-cost, epsilon=SINKHORN_EPS)

    n_samples, n_proto = cost.size()
    # Same marginals distributed_sinkhorn targets: 1 per sample, B/K per prototype.
    sample_constraint = torch.ones(n_samples, dtype=torch.float)
    proto_constraint = torch.ones(n_proto, dtype=torch.float) * n_samples / n_proto
    # NOTE(review): POT receives cost / cost.max(), but distributed_sinkhorn receives the
    # raw cost. Their effective regularization therefore differs by a factor of cost.max().
    res_pot = ot.sinkhorn(sample_constraint, proto_constraint, M=cost / cost.max(),
                          reg=SINKHORN_EPS, warn=False)

    consist_mask = torch.argmax(res_algo, dim=-1) == torch.argmax(res_pot, dim=-1)
    consist_cnt += consist_mask.sum().item()
    total_cnt += res_algo.size(0)
    costa += (res_algo * cost).sum().item()
    costb += (res_pot * cost).sum().item()

print("Ratio = {}/{}={}".format(consist_cnt, total_cnt, consist_cnt / total_cnt))
print(costa)
print(costb)

Ratio = 2993/3000=0.9976666666666667
2966.9846572875977
2968.5397703647614
