In [88]:
import torch


def sinkhorn(scores, eps=0.1, n_iter=3):
    # 将代价矩阵通过 softmax 转换为初始概率分布
    scores = torch.tensor(scores)
    n, m = scores.shape
    Q = torch.softmax(-scores / eps, dim=1)  # 使用 softmax 归一化
    r = torch.ones(n)
    c = torch.ones(m) * (n / m)

    for _ in range(n_iter):
        # 行归一化
        u = c / Q.sum(dim=0)
        Q *= u.unsqueeze(0)
        # 列归一化
        v = r / Q.sum(dim=1)
        Q *= v.unsqueeze(1)
    return Q.numpy()


x = torch.randn(5, 5)
y = sinkhorn(x, eps=0.05, n_iter=20)
y.sum(0), y.sum(1)

  scores = torch.tensor(scores)


(array([0.97665113, 0.9992993 , 1.0000001 , 0.99942327, 1.0246263 ],
       dtype=float32),
 array([0.99999994, 1.        , 1.        , 1.        , 1.        ],
       dtype=float32))

In [93]:
def sinkhorn_log(scores, eps=0.1, n_iter=3):
    """
    Sinkhorn algorithm in the log domain.
    Inputs:
      scores: cost matrix (n x m)
      eps: temperature parameter
      n_iter: number of normalization iterations
    Returns:
      Q: normalized matrix as torch.Tensor
    """
    scores = torch.tensor(scores)  # ensure tensor conversion
    n, m = scores.shape

    # Initialize log domain matrix L = log Q = -scores / eps
    L = -scores / eps

    # Set log scaling factors (r are ones, c are scaled to match the desired sum along columns)
    log_r = torch.zeros(n, device=L.device)  # since r=1 => log(1)=0
    log_c = torch.log(torch.ones(m, device=L.device) * (n / m))

    for _ in range(n_iter):
        # Normalize columns: compute log sum over rows
        logsum_cols = torch.logsumexp(L, dim=0)  # shape: (m,)
        log_u = log_c - logsum_cols  # adjustment for columns
        # Broadcast addition over columns
        L = L + log_u.unsqueeze(0)

        # Normalize rows: compute log sum over columns
        logsum_rows = torch.logsumexp(L, dim=1)  # shape: (n,)
        log_v = log_r - logsum_rows  # adjustment for rows
        # Broadcast addition over rows
        L = L + log_v.unsqueeze(1)

    # Convert back from log-domain
    Q = torch.exp(L)
    return Q


# Example usage:
x = torch.randn(5, 5)
y = sinkhorn_log(x, eps=0.05, n_iter=20)
print("Column sums:", y.sum(dim=0))
print("Row sums:", y.sum(dim=1))

Column sums: tensor([0.6664, 0.4999, 0.6622, 0.4999, 0.6716])
Row sums: tensor([1.0000, 1.0000, 1.0000])


  scores = torch.tensor(scores)  # ensure tensor conversion
