# The Wasserstein Distance

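The Wasserstein distance measures the discrepancy between two probability distributions.  
For simplicity, we consider discrete distributions \(a\) and \(b\) with weights \(a_i\) and \(b_j\), supported on the points
\([ \delta^a_1, \delta^a_2, \ldots, \delta^a_n ]\) and \([ \delta^b_1, \delta^b_2, \ldots, \delta^b_n ]\) respectively.  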

Given a ground metric, for instance the \(L^2\) norm  
\[
c(x,y) = \lVert x - y \rVert_2,
\]  
we can construct a distance matrix  
\[
C_{i,j} = c(\delta^a_i, \delta^b_j).
\]  

Then the \(p\)-Wasserstein distance between \(a\) and \(b\) is defined as  

\[
W_p(a,b) = \left( \min_{P \in U(a,b)} \langle C^p, P \rangle \right)^{1/p}
= \left( \min_{P \in U(a,b)} \sum_{i,j} C_{i,j}^p P_{i,j} \right)^{1/p},
\]  

where  
\[
U(a,b) = \Bigl\{ P \in \mathbb{R}_{\ge 0}^{n \times n} \;\big|\; \sum_{i,j} P_{i,j} = 1, \;
\sum_j P_{i,j} = a_i, \;
\sum_i P_{i,j} = b_j \Bigr\}
\]  

is the set of joint distributions over  
\([ \delta^a_1, \delta^a_2, \ldots, \delta^a_n ] \times [ \delta^b_1, \delta^b_2, \ldots, \delta^b_n ]\) whose marginals are \(a\) and \(b\).
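
As a quick illustration of the definitions above, the sketch below builds a cost matrix by broadcasting and checks that the independent coupling (the outer product of the two weight vectors) satisfies the constraints of \(U(a,b)\). The support points and weights are made-up toy values, not the histograms used later in this notebook. Because the Wasserstein distance is a minimum over \(U(a,b)\), the cost of any feasible coupling is only an upper bound on \(W_p\).

```python
import numpy as np

# Hypothetical toy data (assumption): three support points per distribution.
x = np.array([0.0, 1.0, 3.0])   # support of a
y = np.array([0.5, 2.0, 2.5])   # support of b
a = np.array([0.2, 0.5, 0.3])   # weights of a, summing to 1
b = np.array([0.6, 0.1, 0.3])   # weights of b, summing to 1

# Ground-cost matrix C[i, j] = |x_i - y_j|, built by broadcasting.
C = np.abs(x[:, None] - y[None, :])

# The independent coupling: P[i, j] = a_i * b_j.
P = np.outer(a, b)

# Membership in U(a, b): nonnegative, rows sum to a, columns sum to b.
rows_ok = np.allclose(P.sum(axis=1), a)
cols_ok = np.allclose(P.sum(axis=0), b)

# Cost of this particular coupling for p = 2: an upper bound on W_2(a, b).
p = 2
upper_bound = np.sum(C**p * P) ** (1 / p)
print(rows_ok, cols_ok, upper_bound)
```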


In [2]:
import random
import numpy as np
import cvxpy as cp
import holoviews as hv
from holoviews import opts
from bokeh.layouts import gridplot, row, column
from bokeh.plotting import figure, output_file, show
from bokeh.io import output_notebook, export_png
hv.extension('bokeh')
output_notebook()

In [3]:
###       discretize a mixed normal distribution and gamma distribution         ###
n_bins = 50
def mixedGaussian(mu1, mu2, sigma1, sigma2, n):
    bernoulli = np.random.binomial(n = 1, p = 0.5, size = n)
    gaussian1 = np.random.normal(mu1, sigma1, n)
    gaussian2 = np.random.normal(mu2, sigma2, n)
    # pick gaussian1 where the coin flip is 1, gaussian2 otherwise
    return np.where(bernoulli == 1, gaussian1, gaussian2)
dist_a = mixedGaussian(mu1 = 1, mu2 = 10, sigma1 = 2, sigma2 = 1.5, n = 100000)
p_a, edges_a = np.histogram(dist_a, bins=n_bins)
pa = figure(title='dist_a: mixed gaussian distribution, 0.5*N(1, 2) + 0.5*N(10, 1.5)', background_fill_color="#fafafa", tools = "save", height=300)
p_a = p_a/100000
pa.quad(top=p_a, bottom=0, left=edges_a[:-1], right=edges_a[1:], fill_color="navy", line_color="white", alpha=0.5)
dist_b = np.random.gamma(7, scale = 1, size = 100000)
p_b, edges_b = np.histogram(dist_b, bins=n_bins)
p_b = p_b/100000
pb = figure(title='dist_b: gamma distribution, Gamma(7, 1)', background_fill_color="#fafafa", y_range = pa.y_range, height=300)
pb.quad(top=p_b, bottom=0, left=edges_b[:-1], right=edges_b[1:], fill_color="navy", line_color="white", alpha=0.5)
show(row(pa, pb)) #export_png(row(pa, pb), filename="sinkhorn428_p1.png")

In [4]:
###       linear programming to solve the Wasserstein distance       ###
edges_a = edges_a[:-1]
edges_b = edges_b[:-1]
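# the squared-distance cost matrix (C^p with p = 2), evaluated at the left bin edges
C = (edges_a.reshape((n_bins,1)) - edges_b.reshape((1,n_bins)))**2
# Create the optimization variable: a nonnegative n_bins x n_bins transport plan.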
P0 = cp.Variable((n_bins, n_bins), nonneg=True)

# Marginal constraints: rows of P0 sum to p_a, columns sum to p_b.
constraints = [cp.sum(P0, axis = 1) == p_a, cp.sum(P0, axis = 0) == p_b]

# Form objective.
obj = cp.Minimize(cp.trace(C.T@P0))

# Form and solve problem.
prob = cp.Problem(obj, constraints)
prob.solve()  # Returns the optimal value <C, P>, which is W_2^2 rather than W_2.
print("status:", prob.status)
print("optimal value", prob.value)
print("optimal var", P0.value)

status: optimal
optimal value 8.69396924147564
optimal var [[9.99987109e-06 2.99383390e-10 7.86848396e-11 ... 0.00000000e+00
  0.00000000e+00 0.00000000e+00]
 [2.91259785e-11 1.69241240e-11 1.41361052e-11 ... 0.00000000e+00
  0.00000000e+00 0.00000000e+00]
 [2.69670795e-11 1.51121235e-11 1.30763667e-11 ... 0.00000000e+00
  0.00000000e+00 0.00000000e+00]
 ...
 [1.16380119e-11 0.00000000e+00 0.00000000e+00 ... 1.89479052e-11
  1.97005691e-11 6.33652648e-11]
 [9.41458447e-12 0.00000000e+00 0.00000000e+00 ... 2.29515787e-11
  2.51905237e-11 2.00693480e-10]
 [1.24193889e-11 0.00000000e+00 0.00000000e+00 ... 2.76663786e-11
  3.13489647e-11 1.99997055e-05]]


# Entropic Regularization

Define the relative entropy between \(P\) and \(a \otimes b\) as  

\[
\mathrm{KL}(P \,\|\, a \otimes b)
= \sum_{i,j} P_{i,j} \log \frac{P_{i,j}}{a_i \times b_j}
= \sum_{i,j} P_{i,j} \log P_{i,j} - \sum_i a_i \log a_i - \sum_j b_j \log b_j,
\]  

where \(a \otimes b\) denotes the joint distribution when the marginal distributions \(a\) and \(b\) are independent.  

By the above derivation, the regularization term \(\mathrm{KL}(P \,\|\, a \otimes b)\) can be replaced, for the purpose of minimization, by the negative entropy  

\[
-H(P) = \sum_{i,j} P_{i,j} \left( \log P_{i,j} - 1 \right),
\]  

since the two differ by a constant that does not depend on \(P\) when \(P \in U(a,b)\).  
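
The two claims above can be checked numerically. The sketch below uses a hypothetical one-parameter family of \(2 \times 2\) couplings with fixed marginals (an illustrative assumption) and reads the entropy as \(H(P) = -\sum_{i,j} P_{i,j}(\log P_{i,j} - 1)\). It verifies the expansion of the KL term and confirms that the gap between \(\mathrm{KL}(P \,\|\, a \otimes b)\) and \(-H(P)\) is the same for every coupling in the family.

```python
import numpy as np

# Toy marginals (assumption).
a = np.array([0.3, 0.7])
b = np.array([0.6, 0.4])

def coupling(t):
    """A 2x2 coupling of a and b, valid for 0 < t < 0.3."""
    return np.array([[t, 0.3 - t],
                     [0.6 - t, 0.1 + t]])

def kl_to_product(P):
    """KL(P || a outer b) for a strictly positive P."""
    return np.sum(P * np.log(P / np.outer(a, b)))

def neg_entropy(P):
    """-H(P) = sum_ij P_ij (log P_ij - 1)."""
    return np.sum(P * (np.log(P) - 1.0))

identity_errors = []
gaps = []
for t in (0.05, 0.15, 0.25):
    P = coupling(t)
    expanded = (np.sum(P * np.log(P))
                - np.sum(a * np.log(a))
                - np.sum(b * np.log(b)))
    identity_errors.append(abs(kl_to_product(P) - expanded))
    gaps.append(kl_to_product(P) - neg_entropy(P))

print(identity_errors)  # each entry should be at rounding-error level
print(gaps)             # the same constant for every t
```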

Then the entropic penalized Wasserstein distance is defined as  

\[
W_{p}^{\varepsilon,p} = \min_{P \in U(a,b)} \langle C^p, P \rangle + \varepsilon \, \mathrm{KL}(P \,\|\, a \otimes b),
\]  

or equivalently, up to an additive constant that does not change the minimizer,  

\[
W_{p}^{\varepsilon,p} = \min_{P \in U(a,b)} \langle C^p, P \rangle - \varepsilon H(P).
\]  

---

### Properties of the Entropic Penalized Wasserstein Distance

- Since \(\mathrm{KL}(P \,\|\, a \otimes b)\) is strongly convex in \(P\), the above optimization problem has a unique minimizer.  
  By contrast, the unregularized Kantorovich problem is a linear program and may have several minimizers.  

- With entropy regularization, the solution \(P\) is less sparse, in the sense that fewer entries \(P_{i,j}\) are zero.  
  In contrast, the unregularized linear program always has an optimal solution at a vertex of the polytope \(U(a,b)\),  
  and such a vertex has at most \(2n-1\) nonzero entries, so most entries of \(P\) are zero.  

- When \(\varepsilon \to \infty\), the solution \(P \to a \otimes b\).  
  When \(\varepsilon \to 0\), the solution \(P \to P_{\text{OT}}\) (the optimal transport plan).  
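
The limiting behaviour in \(\varepsilon\) can be seen on a problem small enough to solve by grid search. The sketch below is a toy setup chosen for illustration, not taken from the notebook: both marginals are \((0.5, 0.5)\), the cost is 0 on the diagonal and 1 off it, and every coupling has the form \(P(t) = \begin{pmatrix} t & 0.5 - t \\ 0.5 - t & t \end{pmatrix}\) with \(t \in [0, 0.5]\). Here \(t = 0.5\) is the sparse optimal transport plan and \(t = 0.25\) is \(a \otimes b\).

```python
import numpy as np

# Toy problem (assumption): uniform two-point marginals, 0/1 cost.
a = np.array([0.5, 0.5])
b = np.array([0.5, 0.5])
C = np.array([[0.0, 1.0],
              [1.0, 0.0]])
Q = np.outer(a, b)  # the independent coupling

def coupling(t):
    return np.array([[t, 0.5 - t],
                     [0.5 - t, t]])

def objective(t, eps):
    """<C, P> + eps * KL(P || a outer b), with the convention 0 * log 0 = 0."""
    P = coupling(t)
    safe_P = np.where(P > 0, P, 1.0)  # avoids log(0); masked out below
    kl = np.sum(np.where(P > 0, P * np.log(safe_P / Q), 0.0))
    return np.sum(C * P) + eps * kl

ts = np.linspace(0.0, 0.5, 5001)
best_t = {}
for eps in (0.01, 1.0, 100.0):
    values = [objective(t, eps) for t in ts]
    best_t[eps] = ts[int(np.argmin(values))]
    print("eps =", eps, "minimizing t =", best_t[eps])
```

For this toy problem the first-order condition gives \(t^\ast = 0.5 \, e^{1/\varepsilon} / (1 + e^{1/\varepsilon})\), so the grid minimizer should sit at the sparse plan for small \(\varepsilon\) and approach 0.25 as \(\varepsilon\) grows.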

---

# Sinkhorn Algorithm

We now consider the Wasserstein-2 distance (\(p=2\)).  
Let us relabel \(C^2\) as \(C\).  

The Sinkhorn algorithm uses the dual formulation of the constrained convex optimization,  
which reduces the unknowns from \(P\) (\(n^2\) variables) to the dual variables \(f, g\) (\(2n\) variables).  

Define the Lagrangian:  

\[
L(P,f,g) = \langle C, P \rangle - \varepsilon H(P) - \langle f, P \mathbf{1} - a \rangle - \langle g, P^\top \mathbf{1} - b \rangle.
\]  

The first-order condition is  

\[
\frac{\partial L(P,f,g)}{\partial P_{i,j}} = C_{i,j} + \varepsilon \log P_{i,j} - f_i - g_j = 0,
\]  

which leads to the solution  

\[
P = \mathrm{diag}\!\left(e^{f/\varepsilon}\right) \cdot e^{-C/\varepsilon} \cdot \mathrm{diag}\!\left(e^{g/\varepsilon}\right).
\]  

Therefore, the solution must be of the form  

\[
P = \mathrm{diag}(u) \, K \, \mathrm{diag}(v),
\quad \text{with } K = e^{-C/\varepsilon}.
\]  

The marginal constraints require  

\[
\mathrm{diag}(u) \, K \, \mathrm{diag}(v) \mathbf{1} = a,
\quad \text{and} \quad
\mathrm{diag}(v) \, K^\top \, \mathrm{diag}(u) \mathbf{1} = b.
\]  
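
The scaling form can be sanity-checked numerically. The sketch below draws an arbitrary cost matrix and arbitrary dual vectors (random placeholders, not solutions of any transport problem). It confirms that the entrywise formula \(P_{i,j} = e^{(f_i + g_j - C_{i,j})/\varepsilon}\) matches \(\mathrm{diag}(u) K \mathrm{diag}(v)\), and that the row and column sums reduce to the matrix-vector products \(u \odot (K v)\) and \(v \odot (K^\top u)\).

```python
import numpy as np

rng = np.random.default_rng(0)
n, eps = 4, 0.5
C = rng.random((n, n))     # arbitrary cost matrix (assumption)
f = rng.normal(size=n)     # arbitrary dual vector (assumption)
g = rng.normal(size=n)     # arbitrary dual vector (assumption)

K = np.exp(-C / eps)
u = np.exp(f / eps)
v = np.exp(g / eps)

# Three ways of writing the same matrix.
P_entrywise = np.exp((f[:, None] + g[None, :] - C) / eps)
P_diag = np.diag(u) @ K @ np.diag(v)
P_broadcast = u[:, None] * K * v[None, :]

# Row and column sums without forming P explicitly.
row_sums = u * (K @ v)
col_sums = v * (K.T @ u)

print(np.allclose(P_entrywise, P_diag), np.allclose(P_diag, P_broadcast))
print(np.allclose(P_diag.sum(axis=1), row_sums),
      np.allclose(P_diag.sum(axis=0), col_sums))
```

The last two identities are what makes each Sinkhorn update cost only a matrix-vector product with \(K\).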

---

### Algorithm (Sinkhorn Iteration)

**Input:** \(C, a, b, \varepsilon\)  

1. **Initialization:**  
   \[
   u = v = \mathbf{1}, \quad K = e^{-C/\varepsilon}
   \]  

2. **Main loop:**  
   While \(P\) changes (the divisions are elementwise):  
   \[
   u^{(i+1)} = \frac{a}{K v^{(i)}},
   \quad v^{(i+1)} = \frac{b}{K^\top u^{(i+1)}}
   \]  

3. **Return:**  
   \[
   P = \mathrm{diag}(u) K \mathrm{diag}(v),
   \quad \hat{W}_p^p = \mathrm{trace}(C^\top P).
   \]  


In [None]:
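def sinkhorn(C, a, b, epsilon, precision, max_iter=10000):
    """Sinkhorn iterations; returns the plan P and the transport cost <C, P>."""
    a = a.reshape((C.shape[0], 1))
    b = b.reshape((C.shape[1], 1))
    K = np.exp(-C/epsilon)

    # initialization
    u = np.ones((C.shape[0], 1))
    v = np.ones((C.shape[1], 1))
    P = u * K * v.T  # same as diag(u) @ K @ diag(v), without building the diagonal matrices
    p_norm = np.sum(P**2)  # squared Frobenius norm, equal to trace(P.T @ P)

    # max_iter is an added safety cap so that a very small precision cannot loop forever
    for _ in range(max_iter):
        u = a/np.maximum(K @ v, 1e-300)  # floor the denominator to avoid division by zero
        v = b/np.maximum(K.T @ u, 1e-300)
        P = u * K * v.T
        new_norm = np.sum(P**2)
        # stop once the relative change of the squared Frobenius norm is within precision
        if abs(new_norm - p_norm) <= precision * p_norm:
            break
        p_norm = new_norm
    return P, np.sum(C * P)  # <C, P> = trace(C.T @ P)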

###       apply sinkhorn algorithm with epsilon = 1 (0.1, 10, 100 can be run the same way)       ###
P1, W_1 = sinkhorn(C, p_a, p_b, epsilon = 1, precision = 1e-30)

###       visualize the joint distribution for different values of epsilon       ###
bounds=(edges_b.min(), edges_a.min(), edges_b.max(), edges_a.max())   # Coordinate system: (left, bottom, right, top)
img_0 = hv.Image(np.flip(P0.value, axis=0), bounds=bounds).relabel("epsilon = 0" + ', W =' + str(round(prob.value, 2))).opts(colorbar=False, cmap = 'BuPu', color_levels = int(1e4), width=300, xlabel='dist_b', ylabel='dist_a').redim.range(z = (0, np.max(P0.value)))
# img_01 = hv.Image(np.flip(P01, axis=0), bounds=bounds).relabel((', ').join(["epsilon = 0.1",  'W_hat =' + str(round(W_01, 2))])).opts(colorbar=False, cmap = 'BuPu', color_levels = int(1e4), width=300, xlabel='dist_b', ylabel='dist_a').redim.range(z = (0, np.max(P0.value)))

P_independent = p_a.reshape((n_bins, 1))*p_b.reshape((1, n_bins))
img_infty = hv.Image(np.flip(P_independent, axis = 0), bounds=bounds).relabel((', ').join(["epsilon = infty",  'W_hat =' + str(round(np.trace(C.T @ P_independent), 2))])).opts(colorbar=True, cmap = 'BuPu', color_levels = int(1e4), width=350, xlabel='dist_b', ylabel='dist_a').redim.range(z = (0, np.max(P0.value)))

layout = hv.Layout([img_0, # img_01, img_1, img_10, img_100,
                    img_infty]).cols(2)
hv.save(layout, filename="sinkhorn428_p2.png")
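
The `1e-300` floor in `sinkhorn` guards against `K = np.exp(-C/epsilon)` underflowing to zero when `epsilon` is small relative to the entries of `C`. A common alternative is to run the same updates on the dual variables \(f = \varepsilon \log u\) and \(g = \varepsilon \log v\) in the log domain. The sketch below is one way this could look and is not part of the original notebook: `sinkhorn_log`, the local `logsumexp` helper, the fixed iteration count and the toy data are all illustrative assumptions, and it requires strictly positive marginals.

```python
import numpy as np

def logsumexp(M, axis):
    """Numerically stable log(sum(exp(M))) along one axis."""
    m = np.max(M, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(M - m), axis=axis))

def sinkhorn_log(C, a, b, epsilon, n_iter=500):
    """Log-domain Sinkhorn sketch; assumes every entry of a and b is > 0."""
    f = np.zeros(C.shape[0])
    g = np.zeros(C.shape[1])
    log_a, log_b = np.log(a), np.log(b)
    for _ in range(n_iter):
        # u = a / (K v) rewritten for f = epsilon * log(u)
        f = epsilon * (log_a - logsumexp((g[None, :] - C) / epsilon, axis=1))
        # v = b / (K^T u) rewritten for g = epsilon * log(v)
        g = epsilon * (log_b - logsumexp((f[:, None] - C) / epsilon, axis=0))
    P = np.exp((f[:, None] + g[None, :] - C) / epsilon)
    return P, np.sum(C * P)

# Toy data (assumption): 20 grid points, squared-distance cost.
x = np.linspace(0.0, 10.0, 20)
C_toy = (x[:, None] - x[None, :]) ** 2
a_toy = np.full(20, 1.0 / 20)
b_toy = np.arange(1, 21) / np.arange(1, 21).sum()

P_toy, cost_toy = sinkhorn_log(C_toy, a_toy, b_toy, epsilon=0.05)
print("max column-marginal error:", np.abs(P_toy.sum(axis=0) - b_toy).max())
print("max row-marginal error:   ", np.abs(P_toy.sum(axis=1) - a_toy).max())
```

Because the loop ends on a `g` update, the column sums match `b_toy` up to rounding by construction. The row sums match `a_toy` only once the iteration has converged, so the second printed error is a convergence diagnostic.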

In [1]:
! uv pip install selenium

Using Python 3.12.11 environment at: /usr
Audited 1 package in 970ms


In [7]:
!uv pip install geckodriver-autoinstaller

Using Python 3.12.11 environment at: /usr
Resolved 1 package in 179ms
Prepared 1 package in 24ms
Installed 1 package in 9ms
 + geckodriver-autoinstaller==0.1.0


In [None]:
from selenium import webdriver
import geckodriver_autoinstaller


geckodriver_autoinstaller.install()
driver = webdriver.Firefox()