# Implementation of the Sinkhorn-Knopp algorithm
## Introduction
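The [Sinkhorn-Knopp](https://www.cerfacs.fr/algor/reports/2006/TR_PA_06_42.pdf) algorithm rests on one observation: the optimal solution $\mathbf{P}_\lambda^*$ of the entropy-regularized transport problem $d_M^\lambda(\mathbf{r}, \mathbf{c})$ always has the form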

$$
(\mathbf{P}_\lambda^*)_{ij} = \alpha_i \beta_j e^{-\lambda M_{ij}}
$$

where $\alpha_i$ and $\beta_j$ are constants to be determined so that the rows of $\mathbf{P}_\lambda^*$ sum to $\mathbf{r}$ and its columns sum to $\mathbf{c}$. We therefore solve for them iteratively:

**Input:** $\mathbf{M}, \mathbf{r}, \mathbf{c}, \lambda$

**Initialize:** $(\mathbf{P}_\lambda)_{ij} = e^{-\lambda M_{ij}}$

**Loop:**

1. Scale each row so that the row sums equal $\mathbf{r}$.
2. Scale each column so that the column sums equal $\mathbf{c}$.

**Until convergence.**
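In terms of the constants $\alpha_i, \beta_j$ from the formula above, each pass of the loop is a pair of closed-form updates. Write $K_{ij} = e^{-\lambda M_{ij}}$. Imposing the row constraint on $(\mathbf{P}_\lambda)_{ij} = \alpha_i \beta_j K_{ij}$, then the column constraint, gives

$$
\alpha_i \leftarrow \frac{r_i}{\sum_j K_{ij} \beta_j}, \qquad
\beta_j \leftarrow \frac{c_j}{\sum_i \alpha_i K_{ij}}
$$

Scaling the rows of the current plan is the $\alpha$ update, and scaling the columns is the $\beta$ update. Starting from $\alpha = \beta = \mathbf{1}$ reproduces the initialization $(\mathbf{P}_\lambda)_{ij} = e^{-\lambda M_{ij}}$.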

## Implementation

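```python
import numpy as np

def sinkhorn_knopp(cost_matrix, source, target, lam, eps=1e-3, max_iter=10000):
    n, m = cost_matrix.shape
    source = np.asarray(source, dtype=float)
    target = np.asarray(target, dtype=float)
    # Rows carry the source mass, columns the target mass;
    # both marginals must have the same total mass
    assert source.shape == (n,) and target.shape == (m,)
    assert np.isclose(source.sum(), target.sum())

    # Initialize the transport matrix P with the kernel exp(-lam * M)
    P = np.exp(-lam * cost_matrix)
    # Normalize P to sum to 1 (a global rescaling does not change the result)
    P /= P.sum()

    # Row sums from the previous sweep, used for the stopping test
    u = np.zeros(n)

    # Source corresponds to rows,
    # target corresponds to columns
    source = source.reshape(-1, 1)
    target = target.reshape(1, -1)

    # Stop once the row sums change by at most eps between sweeps;
    # max_iter guards against an endless loop
    for _ in range(max_iter):
        if np.max(np.abs(u - P.sum(1))) <= eps:
            break
        u = P.sum(1)
        # Scale each row so that the row sums equal source
        row_ratio = source / P.sum(1, keepdims=True)
        P *= row_ratio
        # Scale each column so that the column sums equal target
        col_ratio = target / P.sum(0, keepdims=True)
        P *= col_ratio
    # Transport cost <P, M> of the entropy-regularized plan
    return P, np.sum(P * cost_matrix)
```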
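The loop above rescales the whole matrix on every sweep. The following is a minimal sketch of the equivalent vector form, which stores only the scaling vectors $\alpha$ and $\beta$ from the introduction and builds $\mathbf{P}$ once at the end. The function name `sinkhorn_scaling_vectors`, the fixed iteration count `n_iter` (used here instead of a tolerance test), and the 3×3 test problem are illustrative assumptions. They are not part of the original notebook.

```python
import numpy as np

def sinkhorn_scaling_vectors(cost_matrix, source, target, lam, n_iter=1000):
    # Hypothetical helper: runs a fixed number of sweeps, no tolerance test
    K = np.exp(-lam * cost_matrix)  # kernel K_ij = exp(-lam * M_ij)
    alpha = np.ones(K.shape[0])
    beta = np.ones(K.shape[1])
    for _ in range(n_iter):
        # Row step: alpha_i = r_i / sum_j K_ij beta_j
        alpha = source / (K @ beta)
        # Column step: beta_j = c_j / sum_i alpha_i K_ij
        beta = target / (K.T @ alpha)
    # P_ij = alpha_i * beta_j * K_ij, the form stated in the introduction
    P = alpha[:, None] * K * beta[None, :]
    return P, alpha, beta

# Illustrative 3x3 problem (assumed data, not from the notebook)
M = np.array([[0.0, 1.0, 2.0],
              [1.0, 0.0, 1.0],
              [2.0, 1.0, 0.0]])
r = np.array([0.5, 0.3, 0.2])
c = np.array([0.2, 0.3, 0.5])
P, alpha, beta = sinkhorn_scaling_vectors(M, r, c, lam=2.0)
print(P.sum(axis=1))  # row sums, should match r
print(P.sum(axis=0))  # column sums, should match c
```

Each sweep costs only two matrix-vector products and never rewrites an n×m matrix. For large `lam`, both forms can underflow in `np.exp`; a log-domain variant is the usual remedy.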

## Application of Sinkhorn-Knopp
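It’s Party Time! Suppose the lab is throwing a party for n = 8 people with m = 5 kinds of snacks. `source_dist` gives how much each person can eat, `target_dist` gives the number of servings of each snack, and `cost_matrix` gives each person's aversion to each snack, from +2 (strongest dislike) to -2 (favorite).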

```python
source_dist = np.array([3, 3, 3, 4, 2, 2, 2, 1])
target_dist = np.array([4, 2, 6, 4, 4])

# Aversion of person i (row) to snack j (column): +2 = strongest dislike, -2 = favorite
cost_matrix = np.array([
    [2, 2, 1, 0, 0],
    [0, -2, -2, -2, 2],
    [1, 2, 2, 2, -1],
    [2, 1, 0, 1, -1],
    [0.5, 2, 2, 1, 0],
    [0, 1, 1, 1, -1],
    [-2, 2, 2, 1, 1],
    [2, 1, 2, 1, -1]
])

transport_matrix, min_cost = sinkhorn_knopp(
    cost_matrix,
    source_dist,
    target_dist,
    lam=0.01
)
print(transport_matrix)
print(min_cost)
```

```
[[0.59472088 0.29811948 0.90063855 0.60525176 0.60126933]
 [0.59641114 0.3050063  0.91227554 0.60697194 0.57933508]
 [0.60249062 0.29900917 0.89433811 0.59503748 0.60912461]
 [0.78978408 0.3998789  1.20806044 0.79577136 0.80650522]
 [0.40327862 0.19914433 0.59564181 0.40028628 0.40164896]
 [0.40247353 0.19974299 0.5974324  0.39749471 0.40285637]
 [0.41220723 0.19852765 0.59379731 0.39904673 0.39642108]
 [0.19863391 0.10057117 0.29781584 0.20013973 0.20283935]]
11.365116552366768
```
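A sanity check on this output. With `lam=0.01`, every kernel entry $e^{-\lambda M_{ij}}$ lies between $e^{-0.02}$ and $e^{0.02}$, so the kernel is almost constant and the preferences barely shape the plan. The plan should therefore sit close to the independent coupling $r_i c_j / N$ with $N = 20$ total servings. For example, $3 \cdot 4 / 20 = 0.6$, compared with the printed $0.5947$. Its cost should also be close to that coupling's cost, $232/20 = 11.6$. The sketch below computes the coupling; the names `P_indep` and `indep_cost` are illustrative.

```python
import numpy as np

# Party data from the example above
source_dist = np.array([3, 3, 3, 4, 2, 2, 2, 1], dtype=float)
target_dist = np.array([4, 2, 6, 4, 4], dtype=float)
cost_matrix = np.array([
    [2, 2, 1, 0, 0],
    [0, -2, -2, -2, 2],
    [1, 2, 2, 2, -1],
    [2, 1, 0, 1, -1],
    [0.5, 2, 2, 1, 0],
    [0, 1, 1, 1, -1],
    [-2, 2, 2, 1, 1],
    [2, 1, 2, 1, -1]
])

# Independent coupling P_ij = r_i * c_j / N: satisfies both marginals
# but ignores cost_matrix entirely
N = source_dist.sum()
P_indep = np.outer(source_dist, target_dist) / N
indep_cost = np.sum(P_indep * cost_matrix)

print(P_indep[0, 0])         # 0.6
print(round(indep_cost, 6))  # 11.6
```

Raising `lam` should move the Sinkhorn plan away from this coupling toward the snacks people prefer and lower its transport cost.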


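## Measuring the distance between distributions
Optimal transport can also serve as a way to measure the distance between two distributions. There are many other measures, such as the KL divergence and the total variation (TV) distance. Those compare the two distributions bin by bin, however. The advantage of optimal transport is that the cost matrix lets us inject prior knowledge about how close the individual outcomes are to each other.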
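Here is a small illustration of that advantage, using assumed toy data: three point masses on five ordered bins, where moving mass from bin $i$ to bin $j$ costs $|i - j|$. For such one-dimensional histograms, the exact OT cost equals the sum of absolute differences between the two cumulative distributions, so no Sinkhorn iterations are needed. The helper names `tv_distance` and `w1_on_line` are hypothetical.

```python
import numpy as np

def tv_distance(p, q):
    # Total variation: half the L1 distance between the probability vectors
    return 0.5 * np.abs(p - q).sum()

def w1_on_line(p, q):
    # Exact OT cost for histograms on bins 0..K-1 with cost |i - j|:
    # sum of absolute differences of the cumulative distributions
    return np.abs(np.cumsum(p) - np.cumsum(q)).sum()

bins = 5
p = np.eye(bins)[0]  # all mass in bin 0
q = np.eye(bins)[1]  # all mass in bin 1 (a nearby shift)
s = np.eye(bins)[4]  # all mass in bin 4 (a distant shift)

print(tv_distance(p, q), tv_distance(p, s))  # 1.0 1.0
print(w1_on_line(p, q), w1_on_line(p, s))    # 1.0 4.0
```

TV rates the nearby and distant shifts as equally far apart. KL is infinite for both because the supports do not overlap. OT with the $|i - j|$ cost separates them.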