In [None]:
# Fetch the raw download script; the github.com/.../blob/... URL serves an HTML page, not the script.
# -O overwrites on re-run instead of creating mnist_get_data.sh.1, mnist_get_data.sh.2, ...
!wget -O mnist_get_data.sh https://raw.githubusercontent.com/sorki/python-mnist/master/bin/mnist_get_data.sh
# The downloaded file is neither executable nor on PATH, so run it through bash explicitly.
!bash mnist_get_data.sh

In [None]:
import os
import pickle

import numpy as np
from mnist import MNIST
from scipy.spatial.distance import cdist

from methods import APDAGD, PrimalDualAAM, PrimalDualAAMLS, Sinkhorn
from problems import EntropyRegularizedOTProblem

In [None]:
# python-mnist reader (https://pypi.org/project/python-mnist) rooted at the local data directory.
MNIST_DIR = './data/'
mndata = MNIST(MNIST_DIR)
# load_training() yields the training images together with their labels;
# only `images` is used in the cells shown here.
images, labels = mndata.load_training()

In [None]:
# Pixels per flattened image and the side length of the square pixel grid.
n = len(images[0])
# round() guards against sqrt landing just below an integer before truncation.
m = int(round(np.sqrt(n)))
# The cost matrix is built on an m x m grid, which only makes sense for square images.
assert m * m == n, f"expected square images, got {n} pixels per image"

def mnist(eps, p, q, dataset=None):
    """Build a smoothed pair of probability vectors from two MNIST images.

    Each selected image's pixel intensities are normalized to sum to one and
    then mixed with the uniform distribution:
    ``(1 - eps/8) * w + eps / (8 * size)``. For ``0 < eps <= 8`` every entry
    is strictly positive; ``eps = 0`` returns the plain normalized images.

    Args:
        eps: Smoothing accuracy; the weight given to the uniform vector is ``eps / 8``.
        p: Index of the source image in ``dataset``.
        q: Index of the target image in ``dataset``.
        dataset: Sequence of flattened images. Defaults to the global ``images``
            loaded from MNIST in this notebook.

    Returns:
        Tuple ``(p, q)`` of float64 arrays, each summing to one.

    Raises:
        ValueError: If a selected image has zero total intensity (it cannot be
            normalized into a probability vector).
    """
    # Resolve the global lazily so the default matches the original behavior.
    source = images if dataset is None else dataset

    def _smoothed(index):
        # Normalize one image to unit mass, then mix in eps/8 of the uniform vector.
        weights = np.asarray(source[index], dtype=np.float64)
        total = weights.sum()
        if total == 0:
            raise ValueError(f"image {index} has zero total intensity; cannot normalize")
        weights = weights / total
        # Vector length stands in for the global `n`; identical for MNIST images.
        return (1 - eps / 8) * weights + eps / (8 * weights.size)

    return _smoothed(p), _smoothed(q)

def cartesian_product(*arrays):
    """Return every combination of the given 1-D arrays as the rows of a 2-D array.

    Row order is lexicographic, with the last array varying fastest. The
    result has shape ``(prod(len(a) for a in arrays), len(arrays))`` and the
    dtype obtained by promoting all inputs together.
    """
    width = len(arrays)
    common_dtype = np.result_type(*arrays)
    grid_shape = tuple(len(values) for values in arrays) + (width,)
    out = np.empty(grid_shape, dtype=common_dtype)
    # np.ix_ gives open-mesh views; broadcasting fills coordinate `axis` along that axis.
    for axis, mesh in enumerate(np.ix_(*arrays)):
        out[..., axis] = mesh
    return out.reshape(-1, width)

# (row, col) coordinates of every pixel on the m x m grid, column index varying fastest
# (presumably matching the row-major flattening of the MNIST images).
grid = np.arange(m)
coords = cartesian_product(grid, grid)
# Pairwise Euclidean distances (cdist's default metric), rescaled so the largest cost is 1.
C = cdist(coords, coords)
C /= C.max()
C.max()

In [None]:
# MNIST training-set indices of the (source, target) image pairs used in the experiments;
# pair k is (p_list[k], q_list[k]).
p_list = [34860, 31226, 239, 37372, 17390]
q_list = [45815, 35817, 43981, 54698, 49947]

In [None]:
# Six target accuracies, evenly spaced in 1/eps from 1/2e-2 (= 50) to 1/4e-4 (= 2500),
# i.e. eps decreasing from 2e-2 to 4e-4.
x_array = np.linspace(1 / 2e-2, 1 / 4e-4, num=6)
epslist = 1.0 / x_array
epslist

In [None]:
# Sanity check: first image pair, smoothed with the coarsest accuracy.
p, q = mnist(eps=epslist[0], p=p_list[0], q=q_list[0])
p[:5], q[:5]

In [None]:
# First five rows of the normalized cost matrix.
C[:5, :]

In [None]:
# Output directory for the pickled solver histories; safe to re-run.
os.makedirs(name="reports/entropy_regularized", exist_ok=True)

# Sinkhorn

In [None]:
for k in range(1):  # only the first image pair is run for now
    for eps in epslist[:1]:  # only the coarsest accuracy is run for now
        # eps' = eps / 8; mnist() additionally mixes in eps'/8 of the uniform distribution.
        epsp = eps / 8
        p, q = mnist(epsp, p_list[k], q_list[k])
        # Entropic regularization strength gamma = eps / (4 ln n).
        gamma = eps / 4 / np.log(n)
        entr_reg = EntropyRegularizedOTProblem(gamma, n, C, p, q)
        # Zero starting point of length 2n (presumably one dual variable per entry of p and of q).
        lamu = np.zeros(2 * n)
        solver = Sinkhorn(epsp / 2, log=True)
        x, history = solver.fit(entr_reg, lamu)
        # NOTE(review): the file name depends only on eps, so widening range(1) to more
        # image pairs would overwrite earlier pairs' histories.
        report_path = f"reports/entropy_regularized/sinkhorn_{eps}.pkl"
        with open(report_path, "wb") as f:
            pickle.dump(history, f)

# APDAGD

In [None]:
for k in range(1):  # only the first image pair is run for now
    for eps in epslist[:1]:  # only the coarsest accuracy is run for now
        # NOTE(review): unlike the Sinkhorn/PDAAM cells, the marginals are not smoothed
        # here (eps=0 passed to mnist); confirm this is intentional.
        p, q = mnist(0, p_list[k], q_list[k])
        # Entropic regularization strength gamma = eps / (3 ln n).
        gamma = eps / 3 / np.log(n)
        entr_reg = EntropyRegularizedOTProblem(gamma, n, C, p, q)
        solver = APDAGD(eps / 6, log=True)
        x, history = solver.fit(entr_reg)
        # NOTE(review): the file name depends only on eps; more image pairs would overwrite it.
        report_path = f"reports/entropy_regularized/apdagd_{eps}.pkl"
        with open(report_path, "wb") as f:
            pickle.dump(history, f)

# PDAAM with line-search

In [None]:
for k in range(1):  # only the first image pair is run for now
    for eps in epslist[:1]:  # only the coarsest accuracy is run for now
        # eps' = eps / 8 is used only to smooth the marginals in this cell.
        epsp = eps / 8
        p, q = mnist(epsp, p_list[k], q_list[k])
        # Entropic regularization strength gamma = eps / (3 ln n).
        gamma = eps / 3 / np.log(n)
        entr_reg = EntropyRegularizedOTProblem(gamma, n, C, p, q)
        solver = PrimalDualAAMLS(eps / 6, log=True)
        x, history = solver.fit(entr_reg)
        # NOTE(review): the file name depends only on eps; more image pairs would overwrite it.
        report_path = f"reports/entropy_regularized/pdaam-ls_{eps}.pkl"
        with open(report_path, "wb") as f:
            pickle.dump(history, f)

# PDAAM

In [None]:
for k in range(1):  # only the first image pair is run for now
    for eps in epslist[:1]:  # only the coarsest accuracy is run for now
        # eps' = eps / 8 is used only to smooth the marginals in this cell.
        epsp = eps / 8
        p, q = mnist(epsp, p_list[k], q_list[k])
        # Entropic regularization strength gamma = eps / (3 ln n).
        gamma = eps / 3 / np.log(n)
        entr_reg = EntropyRegularizedOTProblem(gamma, n, C, p, q)
        solver = PrimalDualAAM(eps / 6, log=True)
        x, history = solver.fit(entr_reg)
        # NOTE(review): the file name depends only on eps; more image pairs would overwrite it.
        report_path = f"reports/entropy_regularized/pdaam_{eps}.pkl"
        with open(report_path, "wb") as f:
            pickle.dump(history, f)