# Sinkhorn Algorithm for Optimal Transport

The idea of optimal transport is to find a transportation plan $\pi$ between two probability
measures $\mu$ and $\nu$ defined on the spaces $X$ and $Y$ respectively.
This transportation plan should be the most economical one.
To express that in math, we define a cost function $c$
(in the discrete case it is just a matrix), so that transporting a unit of mass
from a point $x \in X$ to a point $y \in Y$ costs $c(x, y)$.

There are several formulations of the Optimal Transport problem.

The **Kantorovich formulation** is stated as
$$
\min_{\pi \in \Pi(\mu, \nu)} \int_{X \times Y} c(x, y) d \pi (x, y),
$$
where $\Pi(\mu, \nu)$ is the set of plans that satisfy the marginal constraints for all measurable sets $A \subset X$ and $B \subset Y$
$$
\pi (A \times Y) = \mu(A)
\qquad \text{and} \qquad
\pi (X \times B) = \nu(B).
$$
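
In the discrete case the integral becomes a sum over the entries of a cost matrix and a plan matrix. The sketch below is a minimal illustration under assumed toy data: the points `x`, `y`, the weights `mu`, `nu`, and the squared-distance cost are hypothetical choices, not part of the formulation above. It uses the product plan $\mu \otimes \nu$, which always satisfies the marginal constraints but is generally not the cheapest plan.

```python
import jax.numpy as jnp

# Hypothetical toy data: 3 source points and 2 target points on the real line.
x = jnp.array([0.0, 1.0, 2.0])
y = jnp.array([0.5, 1.5])
mu = jnp.array([0.2, 0.5, 0.3])  # source weights, sum to 1
nu = jnp.array([0.6, 0.4])       # target weights, sum to 1

# Cost matrix C[i, j] = c(x_i, y_j); squared distance is an assumed choice.
C = (x[:, None] - y[None, :]) ** 2

# The product plan mu ⊗ nu is feasible, though usually not optimal.
pi = jnp.outer(mu, nu)

# Marginal constraints: rows sum to mu, columns sum to nu.
row_marginal = pi.sum(axis=1)
col_marginal = pi.sum(axis=0)

# Discrete transport cost: sum over i, j of C[i, j] * pi[i, j].
cost = jnp.sum(C * pi)
```

Any other matrix with the same row and column sums is also a feasible plan; the Kantorovich problem asks for the one with the smallest `cost`.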

Computationally, the Kantorovich formulation can be hard to work with directly,
as it imposes strict constraints on the problem to solve.
We can instead introduce the so-called _entropic regularization_ of the _Kantorovich formulation_:

$$
\min_{\pi \in \Pi(\mu, \nu)} \int_{X \times Y} c(x, y) d \pi (x, y) + \varepsilon KL(\pi | \mu \otimes \nu),
$$

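where $KL(\pi | \mu \otimes \nu)$ is the Kullback-Leibler divergence of $\pi$ with respect to
the product measure $\mu \otimes \nu$. Writing $\pi$ also for the density of the plan
with respect to $\mu \otimes \nu$, it is defined as
$$
KL(\pi | \mu \otimes \nu) = \int_{X \times Y} \pi \left( \ln \pi - 1 \right) d (\mu \otimes \nu).
$$
For a probability plan the $-1$ term only shifts the value by a constant, so it does not change the minimizer.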
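
In the discrete case the divergence becomes a finite sum over the entries of the plan. The sketch below assumes the standard form $\sum_{ij} \pi_{ij} \ln \left( \pi_{ij} / (\mu_i \nu_j) \right)$, which for probability plans differs from variants with an extra $-1$ term only by a constant. The helper name `kl_divergence` and the toy marginals are illustrative assumptions.

```python
import jax.numpy as jnp
from jax.scipy.special import xlogy

def kl_divergence(pi, mu, nu):
    """Discrete KL(pi | mu ⊗ nu) = sum_ij pi_ij * log(pi_ij / (mu_i * nu_j))."""
    ref = jnp.outer(mu, nu)
    # xlogy(a, b) = a * log(b) and returns 0 where a == 0,
    # so zero entries of the plan contribute nothing.
    return jnp.sum(xlogy(pi, pi / ref))

# Hypothetical toy marginals: two points with equal mass on each side.
mu = jnp.array([0.5, 0.5])
nu = jnp.array([0.5, 0.5])

product_plan = jnp.outer(mu, nu)                      # fully spread out
diagonal_plan = jnp.array([[0.5, 0.0], [0.0, 0.5]])   # fully concentrated

kl_product = kl_divergence(product_plan, mu, nu)    # zero: plan equals the reference
kl_diagonal = kl_divergence(diagonal_plan, mu, nu)  # log(2) for this toy case
```

The regularization term therefore penalizes plans that concentrate mass and favors plans close to $\mu \otimes \nu$, with $\varepsilon$ setting the strength of that preference.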

This regularized problem has a unique solution $\pi^*$, whose density with respect to $\mu \otimes \nu$ factorizes as
$$
\pi^* (x, y) = u(x) \; K(x, y) \; v(y),
$$
where $K(x, y) = \exp \left\{ - c(x, y) / \varepsilon \right\}$ and $u$, $v$ are positive scaling functions.

In the discrete case this becomes
$$
\pi^* = \text{diag}(u) \; K \; \text{diag}(v),
$$
where the matrix $K$ is built entrywise from the cost matrix $C$ as
$$
K_{ij} = \exp \left\{ \frac{- C_{ij}}{\varepsilon} \right\}.
$$
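
As a rough illustration of how $\varepsilon$ shapes the kernel, the sketch below builds $K$ for two values of the regularization strength. The grid `x`, the squared-distance cost, and the helper name `gibbs_kernel` are assumptions made for this example only.

```python
import jax.numpy as jnp

# Hypothetical toy cost matrix: squared distances between 4 points on [0, 1].
x = jnp.linspace(0.0, 1.0, 4)
C = (x[:, None] - x[None, :]) ** 2

def gibbs_kernel(C, epsilon):
    """Entrywise kernel K_ij = exp(-C_ij / epsilon)."""
    return jnp.exp(-C / epsilon)

# Larger epsilon gives a flatter kernel; smaller epsilon makes
# entries with nonzero cost decay quickly towards zero.
K_smooth = gibbs_kernel(C, epsilon=1.0)
K_sharp = gibbs_kernel(C, epsilon=0.01)
```

For very small $\varepsilon$ some entries of $K$ can underflow to zero in floating point, which is worth keeping in mind when choosing the regularization strength.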

The vectors $u$ and $v$ must be chosen so that the marginal constraints hold (here $\mathbb{1}$ denotes a vector of ones):
$$
\pi^* \mathbb{1} = \mu \qquad \text{and} \qquad \pi^{*T} \mathbb{1} = \nu. 
$$
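
Substituting the factorized form of $\pi^*$ into these constraints makes their structure explicit. Writing $\odot$ for the elementwise product (a notation introduced here for convenience), we get
$$
\text{diag}(u) \; K \; \text{diag}(v) \, \mathbb{1} = u \odot (K v) = \mu
\qquad \text{and} \qquad
\text{diag}(v) \; K^T \; \text{diag}(u) \, \mathbb{1} = v \odot (K^T u) = \nu.
$$
Each equation can be solved for one scaling vector by an elementwise division when the other vector is held fixed.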

These constraints lead to the Sinkhorn algorithm, which alternately updates
the vectors $u$ and $v$ until convergence (the divisions are elementwise):
$$
u := \frac{\mu}{K v}, \qquad v := \frac{\nu}{K^T u}.
$$
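
The updates above can be sketched in JAX as follows. This is a minimal illustration rather than a tuned implementation: the function name `sinkhorn`, the fixed iteration count in place of a convergence test, the default `epsilon`, and the toy problem are all assumptions made for the example.

```python
import jax
import jax.numpy as jnp

def sinkhorn(C, mu, nu, epsilon=0.2, num_iters=1000):
    """Return the plan diag(u) K diag(v) after a fixed number of updates."""
    K = jnp.exp(-C / epsilon)
    u0 = jnp.ones_like(mu)
    v0 = jnp.ones_like(nu)

    def body(_, uv):
        u, v = uv
        u = mu / (K @ v)    # enforce the row marginals
        v = nu / (K.T @ u)  # enforce the column marginals
        return u, v

    u, v = jax.lax.fori_loop(0, num_iters, body, (u0, v0))
    # Same matrix as diag(u) @ K @ diag(v), written with broadcasting.
    return u[:, None] * K * v[None, :]

# Hypothetical toy problem: uniform weights on two small grids in [0, 1].
x = jnp.linspace(0.0, 1.0, 5)
y = jnp.linspace(0.0, 1.0, 4)
mu = jnp.full(5, 1.0 / 5)
nu = jnp.full(4, 1.0 / 4)
C = (x[:, None] - y[None, :]) ** 2  # assumed squared-distance cost

pi = sinkhorn(C, mu, nu)
row_error = jnp.abs(pi.sum(axis=1) - mu).max()
col_error = jnp.abs(pi.sum(axis=0) - nu).max()
```

In practice one would typically stop when a quantity like `row_error` falls below a tolerance instead of running a fixed number of iterations. Smaller values of `epsilon` tend to need more iterations.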

```python
import jax
import jax.numpy as jnp
from matplotlib import pyplot as plt
```