# Calculation of the discrete OT maps for colored MNIST

## 1. Imports

In [1]:
import sys

sys.path.append("..")

import numpy as np

%matplotlib inline 

import torch
import gc


# from src.resnet_generator import ResnetGenerator

from src.tools import load_dataset


from tqdm.notebook import tqdm
from IPython.display import clear_output


# Raise PIL's limit on the size of a PNG text chunk to 100 MiB.
# This is needed to use the dataloaders for some datasets.
from PIL import PngImagePlugin

LARGE_ENOUGH_NUMBER = 100
PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024**2)

In [2]:
# Start from a clean slate: run the Python garbage collector, then ask PyTorch
# to release the cached GPU memory it is no longer using.
gc.collect()
torch.cuda.empty_cache()

## 2. Pairwise distance calculation

In [3]:
from ot.bregman import sinkhorn
from ot.lp import emd
import warnings

import os

# Show every warning each time it is triggered instead of only the first time.
warnings.simplefilter("always")

# Root directory of the MNIST data. The default keeps the original location;
# set the MNIST_ROOT environment variable to run the notebook on another
# machine without editing this cell.
MNIST_ROOT = os.environ.get("MNIST_ROOT", "/home/zyz/data/MNIST")

# Source (X) and target (Y) datasets; both are read from the same root.
DATASET1, DATASET1_PATH = "MNIST-colored_2", MNIST_ROOT
DATASET2, DATASET2_PATH = "MNIST-colored_3", MNIST_ROOT

IMG_SIZE = 32  # passed to load_dataset as img_size
BATCH_SIZE = 100  # passed to load_dataset as batch_size
N = 1000  # number of samples taken from each test set for the OT problem

In [5]:
# Options shared by both dataset loads. Shuffling is disabled so that the
# first N items are the same on every run.
loader_kwargs = dict(
    img_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    shuffle=False,
    device="cpu",
)

# load_dataset returns a pair per dataset (presumably train and test samplers,
# judging by how they are used below - TODO confirm in src.tools).
X_sampler, X_test_sampler = load_dataset(DATASET1, DATASET1_PATH, **loader_kwargs)
Y_sampler, Y_test_sampler = load_dataset(DATASET2, DATASET2_PATH, **loader_kwargs)

# Free whatever the loading step left behind and clear the cell output.
torch.cuda.empty_cache()
gc.collect()
clear_output()

# Keep element 0 of the first N test items of each dataset
# (assumed to be the image tensor - TODO confirm).
X, Y = (
    test_sampler.loader.dataset[:N][0]
    for test_sampler in (X_test_sampler, Y_test_sampler)
)

In [9]:
# Pairwise cost matrix: M[i, j] is the squared difference between X[i] and Y[j],
# summed over dimensions 1-3 (so X and Y are expected to be 4-D torch tensors).
M = np.zeros((N, N))

for i in tqdm(range(N)):
    # X[i][None, :] adds a leading axis so one X sample broadcasts against all of Y.
    M[i] = ((X[i][None, :] - Y) ** 2).sum(dim=(1, 2, 3))

# Uniform weights over the N source and N target samples.
a = np.ones(N) / N
b = np.ones(N) / N

# Store the costs in extended precision. np.longdouble is the same type as
# np.float128 on platforms that define the latter, but it also exists where
# np.float128 does not (e.g. Windows), so this cell no longer raises
# AttributeError there.
M = M.astype(np.longdouble)

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for i in tqdm(range(N)):


  0%|          | 0/1000 [00:00<?, ?it/s]

## 3. Discrete OT calculation

### Discrete OT mapping calculation ($\epsilon = 0$)

In [10]:
# Unregularised case: solve the exact OT problem between the weights a and b
# under the cost matrix M. `mapping` is overwritten by every solver cell below,
# so save it before moving on.
mapping = emd(a, b, M)

In [13]:
# Save the current `mapping` in NumPy's .npy format.
# Because np.save receives an open file object rather than a path, no ".npy"
# suffix is appended: the file on disk is named exactly "eps_0".
fname = "../discrete_transport_mapping/eps_0"
with open(fname, "wb") as f:
    np.save(f, mapping)

### Discrete entropic OT mapping calculation

### $\epsilon = 1$

In [14]:
# Entropic regularisation strength for this run.
epsilon = 1
reg = epsilon

# Scale the cost and the regulariser by the same 1 / (3 * IMG_SIZE^2) factor;
# the ratio distance / reg_normed therefore still equals M / epsilon.
scale = 1 / (3 * IMG_SIZE**2)
distance = scale * M
reg_normed = scale * reg

# warn=True reports non-convergence; verbose=True prints the error table.
mapping = sinkhorn(
    a,
    b,
    distance,
    reg=reg_normed,
    numItermax=100000,
    warn=True,
    verbose=True,
)

It.  |Err         
-------------------
    0|5.892064e-02|
   10|4.190472e-02|
   20|3.258747e-02|
   30|2.552248e-02|
   40|2.135256e-02|
   50|1.850824e-02|
   60|1.639575e-02|
   70|1.412680e-02|
   80|1.208283e-02|
   90|1.048092e-02|
  100|8.960642e-03|
  110|8.257241e-03|
  120|7.484365e-03|
  130|6.335158e-03|
  140|5.731197e-03|
  150|5.189093e-03|
  160|4.625923e-03|
  170|4.134005e-03|
  180|3.734240e-03|
  190|3.318411e-03|
It.  |Err         
-------------------
  200|2.939215e-03|
  210|2.642794e-03|
  220|2.402936e-03|
  230|2.212529e-03|
  240|2.020100e-03|
  250|1.803921e-03|
  260|1.651152e-03|
  270|1.526713e-03|
  280|1.421299e-03|
  290|1.330766e-03|
  300|1.241093e-03|
  310|1.144747e-03|
  320|1.055067e-03|
  330|9.538251e-04|
  340|8.607615e-04|
  350|7.864946e-04|
  360|7.214967e-04|
  370|6.626201e-04|
  380|6.111753e-04|
  390|5.668389e-04|
It.  |Err         
-------------------
  400|5.282641e-04|
  410|4.939959e-04|
  420|4.629665e-04|
  430|4.344442e-04|
  4



In [15]:
# Save the current `mapping` in NumPy's .npy format. Run this right after the
# epsilon = 1 solver cell: every solver cell overwrites `mapping`.
# Because np.save receives an open file object rather than a path, no ".npy"
# suffix is appended: the file on disk is named exactly "eps_1".
fname = "../discrete_transport_mapping/eps_1"
with open(fname, "wb") as f:
    np.save(f, mapping)

### $\epsilon = 10$

In [16]:
# Entropic regularisation strength for this run.
epsilon = 10
reg = epsilon

# Scale the cost and the regulariser by the same 1 / (3 * IMG_SIZE^2) factor;
# the ratio distance / reg_normed therefore still equals M / epsilon.
scale = 1 / (3 * IMG_SIZE**2)
distance = scale * M
reg_normed = scale * reg

# warn=True reports non-convergence; verbose=True prints the error table.
mapping = sinkhorn(
    a,
    b,
    distance,
    reg=reg_normed,
    numItermax=100000,
    warn=True,
    verbose=True,
)

It.  |Err         
-------------------
    0|5.465605e-02|
   10|8.424768e-03|
   20|2.520885e-03|
   30|9.174680e-04|
   40|3.741655e-04|
   50|1.732377e-04|
   60|9.404735e-05|
   70|5.942719e-05|
   80|4.186366e-05|
   90|3.164121e-05|
  100|2.511126e-05|
  110|2.068348e-05|
  120|1.754765e-05|
  130|1.524408e-05|
  140|1.349517e-05|
  150|1.212693e-05|
  160|1.102743e-05|
  170|1.012298e-05|
  180|9.364013e-06|
  190|8.716412e-06|
It.  |Err         
-------------------
  200|8.156113e-06|
  210|7.665729e-06|
  220|7.232389e-06|
  230|6.846341e-06|
  240|6.500035e-06|
  250|6.187513e-06|
  260|5.903989e-06|
  270|5.645565e-06|
  280|5.409022e-06|
  290|5.191675e-06|
  300|4.991259e-06|
  310|4.805853e-06|
  320|4.633812e-06|
  330|4.473721e-06|
  340|4.324356e-06|
  350|4.184649e-06|
  360|4.053671e-06|
  370|3.930606e-06|
  380|3.814733e-06|
  390|3.705418e-06|
It.  |Err         
-------------------
  400|3.602097e-06|
  410|3.504271e-06|
  420|3.411493e-06|
  430|3.323364e-06|
  4

In [17]:
# Save the current `mapping` in NumPy's .npy format. Run this right after the
# epsilon = 10 solver cell: every solver cell overwrites `mapping`.
# Because np.save receives an open file object rather than a path, no ".npy"
# suffix is appended: the file on disk is named exactly "eps_10".
fname = "../discrete_transport_mapping/eps_10"
with open(fname, "wb") as f:
    np.save(f, mapping)