In [None]:
# Select matplotlib's interactive "notebook" backend for figures created below.
%matplotlib notebook
# Load IPython's autoreload extension; mode 2 re-imports modules before each
# cell execution so edits to imported source files are picked up.
%load_ext autoreload
%autoreload 2

In [None]:
import matplotlib.pyplot as plt
import numpy as np

from scipy.stats import multivariate_normal
from scipy.linalg import sqrtm

def wasserstein(m1, m2, S1, S2):
    """Closed-form 2-Wasserstein distance between two Gaussians.

    Evaluates

        sqrt(|m2 - m1|^2 + tr(S1 + S2 - 2 * (S1^{1/2} S2 S1^{1/2})^{1/2}))

    Parameters
    ----------
    m1, m2 : array_like, shape (d,)
        Mean vectors of the two distributions.
    S1, S2 : array_like, shape (d, d)
        Covariance matrices; expected to be symmetric positive semi-definite.

    Returns
    -------
    numpy.float64
        The distance, always real and non-negative for valid inputs.
    """
    # Accept plain lists/tuples as well as arrays.
    m1 = np.asarray(m1, dtype=float)
    m2 = np.asarray(m2, dtype=float)
    S1 = np.asarray(S1, dtype=float)
    S2 = np.asarray(S2, dtype=float)

    dm = m2 - m1
    # sqrtm may hand back a complex array whose imaginary part is pure
    # round-off; for PSD inputs only the real part is meaningful.
    S1_half = np.real(sqrtm(S1))
    cross_root = np.real(sqrtm(S1_half.dot(S2).dot(S1_half)))
    dS = S1 + S2 - 2 * cross_root

    squared = dm.dot(dm) + np.trace(dS)
    # Round-off can push a (near-)zero squared distance slightly below zero,
    # which would otherwise produce NaN under the square root.
    return np.sqrt(max(squared, 0.0))

def plot(p, xlim=(-5, 5), ylim=(-5, 5), n=200, ax=None):
    """Draw contour lines of a 2-D density evaluated on a regular grid.

    Parameters
    ----------
    p : object with a ``pdf`` method
        ``p.pdf`` is called with an array of shape (n, n, 2) whose last axis
        holds the (x, y) coordinates of each grid point; its result is used
        as the contour heights.
    xlim, ylim : sequence
        Arguments forwarded to ``np.linspace`` ahead of ``n``; normally a
        ``(low, high)`` pair giving the extent along that axis.
    n : int
        Number of grid points along each axis.
    ax : object with a ``contour`` method, optional
        Axes to draw on. When omitted, ``plt.contour`` is used, i.e. the
        contours go to pyplot's current axes.

    Returns
    -------
    None
    """
    xs = np.linspace(*xlim, n)
    ys = np.linspace(*ylim, n)
    X, Y = np.meshgrid(xs, ys)
    # Stack into (n, n, 2) so each grid cell carries its own (x, y) pair.
    grid = np.dstack((X, Y))
    Z = p.pdf(grid)
    # Fall back to the pyplot state machine only when no axes were supplied.
    target = plt if ax is None else ax
    target.contour(X, Y, Z)
    
# First Gaussian: mean vector and diagonal covariance.
m1 = np.array([0, 2])
S1 = np.diag([100, 10])

# Second Gaussian: mean vector and covariance with off-diagonal terms.
m2 = np.array([0, -2])
S2 = np.array([[1, .3], [.3, 1]])

print("Wasserstein distance:", wasserstein(m1, m2, S1, S2))

# Build both frozen distributions, then overlay their contours on one new figure.
p1 = multivariate_normal(m1, S1)
p2 = multivariate_normal(m2, S2)
plt.figure()
for density in (p1, p2):
    plot(density)

In [None]:
import cvxpy as cp

def opt_wasserstein(m1, m2, S1, S2):
    """2-Wasserstein distance between two Gaussians via a semidefinite program.

    Optimises over a joint covariance ``C = [[S1, K], [K^T, S2]]`` that is
    constrained to be positive semi-definite, minimising

        |m2 - m1|^2 + tr(S1 + S2 - 2 K)

    and returns the square root of the optimal value.

    Parameters
    ----------
    m1, m2 : array_like, shape (d,)
        Mean vectors.
    S1, S2 : array_like, shape (d, d)
        Covariance matrices used as the fixed diagonal blocks of ``C``.

    Returns
    -------
    numpy.float64
        Square root of the optimal objective, clamped at zero.

    Raises
    ------
    cp.SolverError
        If the solver finishes without producing an objective value.
    """
    m1 = np.asarray(m1, dtype=float)
    m2 = np.asarray(m2, dtype=float)
    S1 = np.asarray(S1, dtype=float)
    S2 = np.asarray(S2, dtype=float)
    # Dimension of each Gaussian; the joint covariance is (2d x 2d).
    d = S1.shape[0]

    C = cp.Variable((2 * d, 2 * d), PSD=True)
    # Pin the diagonal blocks to the given marginals; only the
    # off-diagonal block K = C[:d, d:] is free.
    constraints = [C[:d, :d] == S1, C[d:, d:] == S2]
    cost = cp.sum_squares(m2 - m1) + cp.trace(S1 + S2 - 2 * C[:d, d:])
    prob = cp.Problem(cp.Minimize(cost), constraints)
    prob.solve()

    if prob.value is None:
        raise cp.SolverError(
            "solver returned no objective value (status: %s)" % prob.status
        )
    # Solver tolerance can leave a (near-)zero optimum slightly negative,
    # which would otherwise give NaN under the square root.
    return np.sqrt(max(prob.value, 0.0))
# Evaluate the SDP formulation on the same m1, m2, S1, S2 defined earlier in
# the notebook and print it under the same label as the closed-form result.
print("Wasserstein distance:", opt_wasserstein(m1, m2, S1, S2))