In [5]:
import numpy as np
from collections import Counter
import torch
import torch.nn.functional as F

In [6]:
def pairwise_euclidean_distance(x, y):
    """Return the matrix of squared Euclidean distances between rows of x and y.

    Uses the expansion ||x_i - y_j||^2 = ||x_i||^2 + ||y_j||^2 - 2 <x_i, y_j>.
    Despite the function name, no square root is taken.

    Args:
        x: 2-D tensor of shape (n, d).
        y: 2-D tensor of shape (m, d).

    Returns:
        Tensor of shape (n, m) whose entry (i, j) is the squared distance
        between x[i] and y[j], clamped so it is never negative.
    """
    x_sq = torch.sum(x ** 2, dim=1, keepdim=True)  # (n, 1)
    y_sq = torch.sum(y ** 2, dim=1)  # (m,), broadcasts along the rows
    cost = x_sq + y_sq - 2 * torch.matmul(x, y.t())
    # Floating-point cancellation in the expansion can leave tiny negative
    # entries (typically for near-identical rows); a squared distance cannot
    # be negative, so clip them to zero.
    return torch.clamp(cost, min=0)

In [7]:
# Optional marginal distributions. forward() falls back to uniform marginals
# whenever the corresponding init_*_dist is None; otherwise it applies a
# softmax to a_dist / b_dist.
init_a_dist = init_b_dist = None
a_dist = b_dist = None

# Sinkhorn settings read by forward().
sinkhorn_alpha = 3.0  # scales the cost inside the kernel exp(-M * sinkhorn_alpha)
OT_max_iter = 5000    # upper bound on the number of Sinkhorn iterations
stopThr = 5e-3        # iterations stop once the marginal error is <= this value
epsilon = 1e-16       # added to the scaling denominators to avoid division by zero

In [8]:
def forward(x, y):
    """Entropic optimal transport between the rows of x and y (Sinkhorn's algorithm).

    Reads the notebook-level settings sinkhorn_alpha, OT_max_iter, stopThr and
    epsilon, plus the optional marginals init_a_dist / a_dist and
    init_b_dist / b_dist.

    Args:
        x: 2-D tensor of shape (n, d).
        y: 2-D tensor of shape (m, d).

    Returns:
        Tuple (loss_ETP, transp): the scalar tensor sum(transp * M), where M is
        the pairwise cost matrix, and the (n, m) transport plan.
    """
    # Pairwise cost matrix between rows of x and rows of y, shape (n, m).
    M = pairwise_euclidean_distance(x, y)

    # Gibbs kernel: element-wise exponential of the negated, scaled cost.
    K = torch.exp(-M * sinkhorn_alpha)
    device, dtype = K.device, K.dtype
    n_rows, n_cols = M.shape

    # Source marginal a (n x 1). It is created directly in K's dtype/device:
    # torch.matmul does not promote dtypes, so float32 marginals would raise
    # for float64 inputs.
    if init_a_dist is None:
        a = torch.ones(n_rows, 1, dtype=dtype, device=device) / n_rows
    else:
        # NOTE(review): presumably a_dist has shape (n, 1) and is set whenever
        # init_a_dist is -- confirm where they are assigned.
        a = F.softmax(a_dist, dim=0).to(device=device, dtype=dtype)

    # Target marginal b (m x 1), built the same way.
    if init_b_dist is None:
        b = torch.ones(n_cols, 1, dtype=dtype, device=device) / n_cols
    else:
        # NOTE(review): presumably b_dist has shape (m, 1) -- confirm.
        b = F.softmax(b_dist, dim=0).to(device=device, dtype=dtype)

    # Initial row scaling vector, same shape as a.
    u = torch.ones_like(a) / a.size()[0]

    err = 1
    cpt = 0
    while err > stopThr and cpt < OT_max_iter:
        # Alternate the two scaling updates (element-wise division).
        v = torch.div(b, torch.matmul(K.t(), u) + epsilon)
        u = torch.div(a, torch.matmul(K, v) + epsilon)
        cpt += 1
        # Re-evaluate the stopping criterion only on iterations 1, 51, 101, ...
        if cpt % 50 == 1:
            # Column marginal of the current plan (element-wise product).
            bb = torch.mul(v, torch.matmul(K.t(), u))
            err = torch.norm(torch.sum(torch.abs(bb - b), dim=0), p=float('inf'))

    # Transport plan diag(u) @ K @ diag(v), written with broadcasting.
    transp = u * (K * v.T)

    # Total transport cost under the plan.
    loss_ETP = torch.sum(transp * M)

    return loss_ETP, transp

In [9]:
# Two identical sets of three 2-D points; y is an independent copy of x.
x = torch.tensor(
    [
        [1.0, 2.0],
        [3.0, 4.0],
        [5.0, 6.0],
    ],
    dtype=torch.float32,
)
y = x.clone()

In [10]:
# Solve the transport problem for the sample points, then show the total cost
# followed by the transport plan.
loss, transp = forward(x, y)
print(loss, transp, sep="\n")

tensor(4.0268e-10)
tensor([[3.3333e-01, 1.2584e-11, 6.7683e-43],
        [1.2584e-11, 3.3333e-01, 1.2584e-11],
        [6.7683e-43, 1.2584e-11, 3.3333e-01]])


array([1. , 0.4, 0.4, 0.2, 0.2])