In [None]:
!wget https://github.com/sorki/python-mnist/blob/master/bin/mnist_get_data.sh
!mnist_get_data.sh

In [1]:
import os
import pickle

import numpy as np
from mnist import MNIST
from scipy.spatial.distance import cdist

from methods import APDAGD, CLVR, PrimalDualAAM, PrimalDualAAMLS, Sinkhorn
from problems import EuclideanRegularizedOTProblem

In [2]:
# Load the MNIST training split with python-mnist (https://pypi.org/project/python-mnist).
# Files are read from ./data/ — presumably the directory populated by
# mnist_get_data.sh in the first cell; confirm if the layout differs.
mndata = MNIST('./data/')
images, labels = mndata.load_training()

In [3]:
# Each image is a flat vector of n pixels; the cost matrix below places those
# pixels on an m x m grid, so the images must be square.
n = len(images[0])
m = int(round(np.sqrt(n)))
assert m * m == n, f"expected square images, got {n} pixels per image"

def mnist(eps, p, q, data=None):
    """Build a smoothed pair of OT marginals from two images.

    Each image is normalised to a probability vector and then mixed with the
    uniform distribution using weight ``eps / 8``, so that every entry is
    strictly positive (for 0 <= eps < 8).

    Parameters
    ----------
    eps : float
        Accuracy parameter controlling the amount of uniform smoothing.
    p, q : int
        Indices of the first and second image in ``data``.
    data : sequence of pixel sequences, optional
        Image collection to index into. Defaults to the global ``images``
        loaded from the MNIST training split.

    Returns
    -------
    tuple of numpy.ndarray
        Two float64 vectors of the same length as an image, each summing to 1.

    Raises
    ------
    ValueError
        If either selected image has zero total mass (it cannot be normalised).
    """
    if data is None:
        data = images
    p_hist = np.asarray(data[p], dtype=np.float64)
    q_hist = np.asarray(data[q], dtype=np.float64)

    p_mass, q_mass = p_hist.sum(), q_hist.sum()
    if p_mass == 0 or q_mass == 0:
        raise ValueError(f"cannot normalise an all-zero image (indices {p}, {q})")
    p_hist = p_hist / p_mass
    q_hist = q_hist / q_mass

    # Uniform smoothing: keeps total mass 1 while bounding every entry below
    # by eps / (8 * size), which avoids zero marginals in the OT problem.
    p_hist = (1 - eps / 8) * p_hist + eps / (8 * p_hist.size)
    q_hist = (1 - eps / 8) * q_hist + eps / (8 * q_hist.size)

    return p_hist, q_hist

def cartesian_product(*arrays):
    """Return every combination of one element taken from each 1-D input.

    The result has one row per combination and one column per input array.
    Rows are ordered with the last array varying fastest, and the dtype is
    the common result type of all inputs.
    """
    n_axes = len(arrays)
    grid_shape = tuple(len(axis_values) for axis_values in arrays) + (n_axes,)
    grid = np.empty(grid_shape, dtype=np.result_type(*arrays))
    open_mesh = np.ix_(*arrays)
    for axis in range(n_axes):
        # Broadcasting fills column `axis` across every other dimension.
        grid[..., axis] = open_mesh[axis]
    return grid.reshape(-1, n_axes)

# Ground cost: Euclidean distance (cdist's default metric) between the
# coordinates of every pair of points on an m x m grid, normalised so that
# the largest entry is 1. Each stage gets its own name instead of reusing `C`.
pixel_axis = np.arange(m)
pixel_coords = cartesian_product(pixel_axis, pixel_axis)  # shape (m*m, 2)
C = cdist(pixel_coords, pixel_coords)
C /= C.max()
C.max()

1.0

In [4]:
# MNIST training-set indices used in the experiments: pair k takes image
# p_list[k] as the first marginal and image q_list[k] as the second.
p_list, q_list = (
    [34860, 31226, 239, 37372, 17390],
    [45815, 35817, 43981, 54698, 49947],
)

In [5]:
# Accuracy targets: six eps values whose reciprocals are evenly spaced
# between 1/2e-2 (about 50) and 1/4e-4 (about 2500).
x_array = np.linspace(start=1 / 2e-2, stop=1 / 4e-4, num=6)
epslist = 1.0 / x_array

In [10]:
# Subset of the accuracy targets (entries 0, 1 and 4 of epslist) used by the CLVR cell.
chosen_epslist = epslist[[0, 1, 4]]
chosen_epslist

array([0.02      , 0.00185185, 0.00049751])

In [6]:
# Output directory for the pickled solver histories written by the solver cells;
# exist_ok=True makes re-running this cell harmless.
os.makedirs("reports/euclidean_regularized", exist_ok=True)

# Sinkhorn

In [None]:
# All-zero starting vector handed to Sinkhorn.fit: the length-n vectors
# lambda_ and mu_ (presumably the two dual variables) stacked into one
# float64 array of length 2n.
lambda_ = np.zeros(n, dtype=np.float64)
mu_ = np.zeros_like(lambda_)
lamu = np.concatenate([lambda_, mu_])

In [None]:
# Sinkhorn baseline. Unlike the APDAGD/PDAAM/CLVR cells (which smooth with eps
# and use eps / 3), here the marginals are smoothed with epsp = eps / 4,
# gamma = eps / 2, and epsp / 4 is passed to the solver.
for k in range(1):
    for eps in epslist[:1]:
        epsp = eps / 4
        p, q = mnist(epsp, p_list[k], q_list[k])
        gamma = eps / 2
        eucl_reg = EuclideanRegularizedOTProblem(gamma, n, C, p, q)
        # Pass a copy of the zero starting point. If fit() updated it in place
        # (not visible from here), the global `lamu` would otherwise leak into
        # later iterations and into re-runs of this cell.
        x, history = Sinkhorn(epsp / 4, log=True).fit(eucl_reg, lamu.copy())
        # NOTE(review): the file name depends only on eps, so widening range(1)
        # to more image pairs would overwrite earlier results.
        with open(f"reports/euclidean_regularized/sinkhorn_{eps}.pkl", "wb") as f:
            pickle.dump(history, f)

# APDAGD

In [None]:
# APDAGD: marginals smoothed with eps, gamma = eps / 3, and eps / 3 passed to
# the solver; one pickled history per eps (file name does not include k).
for k in range(1):
    src_idx, dst_idx = p_list[k], q_list[k]
    for eps in epslist[:1]:
        p, q = mnist(eps, src_idx, dst_idx)
        gamma = eps / 3
        eucl_reg = EuclideanRegularizedOTProblem(gamma, n, C, p, q)
        solver = APDAGD(eps / 3, log=True)
        x, history = solver.fit(eucl_reg)
        report_path = f"reports/euclidean_regularized/apdagd_{eps}.pkl"
        with open(report_path, "wb") as f:
            pickle.dump(history, f)

# PDAAM with line-search

In [None]:
# PDAAM with line search: same problem setup as APDAGD (smoothing with eps,
# gamma = eps / 3, eps / 3 passed to the solver); file name does not include k.
for k in range(1):
    src_idx, dst_idx = p_list[k], q_list[k]
    for eps in epslist[:1]:
        p, q = mnist(eps, src_idx, dst_idx)
        gamma = eps / 3
        eucl_reg = EuclideanRegularizedOTProblem(gamma, n, C, p, q)
        solver = PrimalDualAAMLS(eps / 3, log=True)
        x, history = solver.fit(eucl_reg)
        report_path = f"reports/euclidean_regularized/pdaam-ls_{eps}.pkl"
        with open(report_path, "wb") as f:
            pickle.dump(history, f)

# PDAAM

In [None]:
# Inspection only: display the eps value at index 4 of epslist.
# NOTE(review): the PDAAM loop below iterates over epslist[:1], not this slice —
# confirm which accuracy target was intended.
epslist[4:5]

In [None]:
# PDAAM without line search: smoothing with eps, gamma = eps / 3, and eps / 3
# passed to the solver; one pickled history per eps (file name does not include k).
for k in range(1):
    src_idx, dst_idx = p_list[k], q_list[k]
    for eps in epslist[:1]:
        p, q = mnist(eps, src_idx, dst_idx)
        gamma = eps / 3
        eucl_reg = EuclideanRegularizedOTProblem(gamma, n, C, p, q)
        solver = PrimalDualAAM(eps / 3, log=True)
        x, history = solver.fit(eucl_reg)
        report_path = f"reports/euclidean_regularized/pdaam_{eps}.pkl"
        with open(report_path, "wb") as f:
            pickle.dump(history, f)

# CLVR

In [7]:
# First positional argument to CLVR (echoed as "alpha = 2" in its configuration
# printout); its meaning is defined in methods.CLVR, which is not visible here.
alpha = 2

In [11]:
# CLVR is run for the two smaller targets in chosen_epslist (entries 1 and 2),
# whereas the other solver cells use epslist[:1]. Setup otherwise matches them:
# smoothing with eps, gamma = eps / 3, and eps / 3 passed to the solver.
# NOTE(review): the captured log shows the second stopping quantity oscillating
# (e.g. 0.00039 at iteration 400, then >100 later) rather than settling — worth a look.
clvr_max_iter = 10000  # passed to CLVR.fit as max_iter
for k in range(1):
    for eps in chosen_epslist[1:]:
        p, q = mnist(eps, p_list[k], q_list[k])
        gamma = eps / 3
        eucl_reg = EuclideanRegularizedOTProblem(gamma, n, C, p, q)
        x, history = CLVR(alpha, eps / 3, log=True).fit(eucl_reg, max_iter=clvr_max_iter)
        # File name depends only on eps; widening range(1) would overwrite results.
        with open(f"reports/euclidean_regularized/clvr_{eps}.pkl", "wb") as f:
            pickle.dump(history, f)

–––––––––––––––––––––––––––––
CLVR configuration:
alpha = 2
epsilon = 0.0006172839506172839
–––––––––––––––––––––––––––––

– Outer iteration 0: 0.014236 > 0.000617 or 0.194347 > 0.000617
– Outer iteration 200: 0.150786 > 0.000617 or 39.833843 > 0.000617
– Outer iteration 400: 0.118084 > 0.000617 or 0.000390 > 0.000617
– Outer iteration 600: 0.118727 > 0.000617 or 1.048332 > 0.000617
– Outer iteration 800: 0.117519 > 0.000617 or 182.753281 > 0.000617
– Outer iteration 1000: 0.109845 > 0.000617 or 115.716605 > 0.000617
– Outer iteration 1200: 0.102621 > 0.000617 or 18.826074 > 0.000617
– Outer iteration 1400: 0.099510 > 0.000617 or 183.171381 > 0.000617
– Outer iteration 1600: 0.092266 > 0.000617 or 362.739469 > 0.000617
– Outer iteration 1800: 0.085902 > 0.000617 or 187.805591 > 0.000617
– Outer iteration 2000: 0.083911 > 0.000617 or 240.449721 > 0.000617
– Outer iteration 2200: 0.081314 > 0.000617 or 360.344792 > 0.000617
– Outer iteration 2400: 0.078430 > 0.000617 or 274.806492 > 0.00