In [5]:
import pygmtools as pygm
import torch
import torch.nn.functional as F

# Run pygmtools on the PyTorch backend so its solvers operate on torch tensors.
pygm.set_backend('pytorch')

# Seed the global torch RNG so the random matrices generated in later cells
# are reproducible under "Restart Kernel -> Run All".
SEED = 0
torch.manual_seed(SEED)

In [6]:
# 2000 x 2000 matrix of standard-normal scores, shared by the solver benchmarks below.
arr = torch.randn((2000, 2000))
arr

tensor([[ 1.3087, -1.7306,  2.2575,  ..., -1.2932, -0.3575, -0.5444],
        [-0.6863, -0.9531, -0.0267,  ...,  0.4607,  1.1105, -0.1052],
        [ 0.7123,  0.2137, -0.3272,  ...,  0.1451,  0.2820,  0.4747],
        ...,
        [ 0.9561,  1.3059, -0.4667,  ..., -0.7492,  0.1205,  0.7719],
        [-0.1278,  0.9065, -0.1012,  ...,  0.0249,  0.2878, -0.3659],
        [ 0.1662,  2.3891,  0.5505,  ..., -0.3514, -1.0603, -0.4309]])

In [10]:
%timeit y=pygm.hungarian(arr)

111 ms ± 2.65 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [9]:
%timeit z=pygm.sinkhorn(arr)

52.7 ms ± 1.62 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [12]:
z = pygm.sinkhorn(arr)
# Inspect the marginals of the Sinkhorn output: per-row and per-column sums.
print('row sum:', z.sum(1))
print('col sum:', z.sum(0))

row sum: tensor([1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000])
col sum: tensor([1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000])


In [17]:
def sinkhorn(
    cost_matrix: torch.Tensor,
    num_iterations: int = 20,
    temperature: float = 0.1,
    eps: float = 1e-8,
) -> torch.Tensor:
    """
    Approximately solve an entropy-regularised optimal-transport problem with Sinkhorn.

    The scaling iterations run in log space (``torch.logsumexp``), so small
    temperatures do not overflow ``exp`` during the iterations.

    Args:
        cost_matrix (torch.Tensor): cost matrix of shape ``(n, m)`` or
            ``(batch_size, n, m)``. Lower cost makes a pairing more likely.
        num_iterations (int): number of row/column normalisation sweeps.
        temperature (float): entropic regularisation strength, must be > 0.
            Smaller values give a sharper (closer to hard) assignment and may
            need more iterations to converge.
        eps (float): lower bound applied to every entry of the plan before the
            final normalisation, for numerical stability.

    Returns:
        torch.Tensor: soft matching matrix of shape ``(n, m)`` for a 2-D input
        or ``(batch_size, n, m)`` for a 3-D input (also when ``batch_size == 1``).
        Columns are normalised last, so they sum to 1; rows sum to 1 only
        approximately, to the extent the iterations have converged.

    Raises:
        ValueError: if ``cost_matrix`` is not 2-D or 3-D, or ``temperature <= 0``.
    """
    if cost_matrix.dim() not in (2, 3):
        raise ValueError(
            f"cost_matrix must be 2-D or 3-D, got shape {tuple(cost_matrix.shape)}"
        )
    # T == 0 divides by zero (NaN plan); T < 0 would silently maximise cost.
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")

    # Remember whether we added the batch dimension ourselves, so that only that
    # dimension is removed again; a genuine batch of size 1 keeps its 3-D shape.
    unbatched = cost_matrix.dim() == 2
    if unbatched:
        cost_matrix = cost_matrix.unsqueeze(0)

    batch_size, n, m = cost_matrix.shape

    # Log of the Gibbs kernel exp(-C / T).
    log_P = -cost_matrix / temperature

    # Row and column scaling vectors in log space.
    log_u = torch.zeros(batch_size, n, 1, device=cost_matrix.device)
    log_v = torch.zeros(batch_size, 1, m, device=cost_matrix.device)

    for _ in range(num_iterations):
        # Row update: rescale so every row of the current plan sums to 1.
        log_u = -torch.logsumexp(log_P + log_v, dim=2, keepdim=True)

        # Column update: rescale so every column of the current plan sums to 1.
        log_v = -torch.logsumexp(log_P + log_u, dim=1, keepdim=True)

    # Transport plan P = diag(u) K diag(v).
    P = torch.exp(log_P + log_u + log_v)

    # Keep every entry in [eps, 1] so no entry is exactly zero.
    P = torch.clamp(P, eps, 1.0)

    # Final explicit row normalisation followed by column normalisation.
    P = P / P.sum(dim=2, keepdim=True)
    P = P / P.sum(dim=1, keepdim=True)

    return P.squeeze(0) if unbatched else P


In [68]:
# 5 x 5 cost matrix with integer-valued costs in [0, 9], stored as float32.
cost_matrix = torch.randint(10, (5, 5)).float()

In [92]:
# Apply Sinkhorn with a very small temperature to approximate a hard assignment.
# NOTE: at temperature=1e-5, 100 iterations may not fully converge. Columns are
# normalised last and sum to 1, but row sums can deviate noticeably from 1.
result = sinkhorn(cost_matrix, num_iterations=100, temperature=1e-5)

# Check the marginals of the transport plan (ideally both close to 1).
print("Row sums:", result.sum(1))
print("Column sums:", result.sum(0))
# Total transport cost <C, P> = sum_ij C_ij * P_ij (a scalar, not a matrix).
print("Result:", torch.sum(cost_matrix*result))

Row sums: tensor([1.0043, 0.9938, 0.9934, 1.0045, 1.0039])
Column sums: tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000])
Result: tensor(4.9784)


In [41]:
import torch.nn.functional as F
# 30 x 30 binary matrix: values 1..900 laid out row-major and thresholded at 450,
# so rows 0-14 are all zeros and rows 15-29 are all ones (exercises zero rows).
a = torch.arange(1, 901).reshape(30, 30).gt(450).float()
%timeit F.normalize(a, p=1, dim=1)

10.9 µs ± 255 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [42]:
%timeit a / (a.sum(dim=1, keepdim=True) + 1e-8)

6.19 µs ± 168 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [47]:
# Sanity check: the built-in and manual normalisations agree element-wise.
F.normalize(a, p=1, dim=1).isclose(a / (a.sum(dim=1, keepdim=True) + 1e-8)).all()

tensor(True)

In [44]:
def sparse_normalize(binary_matrix: torch.Tensor) -> torch.Tensor:
    """
    Row-normalise a 2-D matrix by routing the computation through sparse COO tensors.

    Only the non-zero entries are stored and each is divided by the sum of its
    row, so rows that are entirely zero come back as zeros. The result is
    returned as a dense tensor of the same size as ``binary_matrix``.
    """
    shape = binary_matrix.size()
    # One column per non-zero entry: coords[0] holds row indices, coords[1] column indices.
    coords = binary_matrix.nonzero().t()
    row_idx, col_idx = coords[0], coords[1]
    nz_values = binary_matrix[row_idx, col_idx]

    coo = torch.sparse_coo_tensor(coords, nz_values, shape)
    per_row_sum = torch.sparse.sum(coo, dim=1).to_dense()

    scaled_values = nz_values / per_row_sum[row_idx]
    return torch.sparse_coo_tensor(coords, scaled_values, shape).to_dense()

In [45]:
%timeit sparse_normalize(a)

94.6 µs ± 8.38 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [None]:
def fast_normalize(binary_matrix: torch.Tensor, dim: int = 1) -> torch.Tensor:
    """
    Normalise a (binary) matrix so each slice along ``dim`` sums to 1, using ``torch.where``.

    Slices whose sum is not positive (e.g. all-zero rows) are returned as zeros.

    Args:
        binary_matrix: input tensor; intended for non-negative (binary) values.
        dim: dimension to sum over; ``dim=1`` normalises the rows of a 2-D matrix.

    Returns:
        Tensor of the same shape as ``binary_matrix`` in which every slice with a
        positive sum is divided by that sum and every other slice is zero.
    """
    totals = binary_matrix.sum(dim=dim, keepdim=True)
    has_mass = totals > 0
    # Divide by 1 where the sum is not positive, so 0/0 = NaN is never computed.
    # torch.where only hides NaN in the forward pass; NaN from the discarded
    # branch would still propagate into gradients during backward.
    safe_totals = torch.where(has_mass, totals, torch.ones_like(totals))
    return torch.where(
        has_mass, binary_matrix / safe_totals, torch.zeros_like(binary_matrix)
    )
%timeit fast_normalize(a)

9.26 µs ± 28.7 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
