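Here we run Sinkhorn on the palettes stored in Drive to create the training dataset for the neural network. This is essentially a labeling step: each randomly drawn (source, target) palette pair is labeled with its Sinkhorn transport matrix.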

In [None]:
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
import glob
import shutil
import os
from tqdm import tqdm

In [None]:
# download palettes onto colab locally

from google.colab import drive
drive.mount('/content/drive')

palette_path = '/content/drive/MyDrive/Amortized Optimal Transport/Data/palettes.zip'
local_zip_path = '/content/palettes.zip'
local_path = '/content/palettes'

# The extracted folder doubles as the "already downloaded" marker, so the
# copy + unzip only happens the first time this cell runs in a session.
if not os.path.exists(local_path):
  print("Copying palettes zip...")
  shutil.copy(palette_path, local_zip_path)
  print("Unzipping palettes zip...")
  # Pure-Python extraction instead of a `!unzip` shell call: it raises if the
  # archive is missing or corrupt (a shell magic never raises, so "Done" would
  # be printed regardless) and it does not depend on shell quoting of the path.
  shutil.unpack_archive(local_zip_path, local_path, format='zip')
  print("Done")
else:
  print("Palettes already stored locally")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Palettes already stored locally


# Naive Sinkhorn Algorithm Implementation

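My first custom implementation of Sinkhorn computed the scaling vectors directly on the Gibbs kernel. This led to a lot of underflow issues (understandably, since `exp(-cost/epsilon)` becomes tiny for small epsilon), so the final implementation works in log space instead.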

In [None]:
def sinkhorn_naive_torch(a, b, cost_matrix, epsilon, max_iters=1000):
  '''
  Returns the optimal transport matrix of the sinkhorn algorithm.
  Implemented naively, without regard for underflow or other optimizations:
  exp(-cost_matrix/epsilon) is formed directly, so a small epsilon or large
  costs can underflow to 0.

  Parameters:
    a: left/source probability vector of shape (n,)
    b: right/target probability vector of shape (m,)
    cost_matrix: cost matrix in euclidean space of shape (n,m)
    epsilon: smoothing parameter for the gibbs kernel
    max_iters: maximum number of (u, v) scaling updates. Convergence is only
      checked every 10th iteration. With max_iters=0 no scaling is applied and
      the plain gibbs kernel is returned.

  Returns:
    (n,m) matrix depicting the optimal transport matrix from a to b
  '''
  gibbs_kernel = torch.exp(-cost_matrix/epsilon)

  # Both scalings start at 1 so the return value is defined even when the
  # loop body never runs (max_iters=0 used to hit an unbound `u`).
  u = torch.ones_like(a)
  v = torch.ones_like(b)

  for i in range(max_iters):
    past_v = v
    u = a / (gibbs_kernel @ v)    # rescale rows to match a
    v = b / (gibbs_kernel.T @ u)  # rescale columns to match b
    # stop once v no longer changes between consecutive iterations
    if i % 10 == 0 and torch.allclose(v, past_v):
      break

  return u[:, None] * gibbs_kernel * v[None, :]

# Final Sinkhorn Implementation
Note: this is not identical to the implementation used in the Colab notebook "4. Training/Testing NN".
That implementation additionally allows warm-starting Sinkhorn with the neural network's guess.

In [None]:
def sinkhorn(a, b, cost_matrix, epsilon, max_iters):
  '''
  My sinkhorn implementation.
  Uses log probabilities to avoid underflow.

  Parameters:
    a: (n,) left/source probability vector
    b: (m,) right/target probability vector
    cost_matrix: (n,m) cost matrix in euclidean space
    epsilon: smoothing parameter for the gibbs kernel
    max_iters: maximum number of (log_u, log_v) updates. Iteration stops early
      once no entry of log_v moves by more than 1e-4 between two consecutive
      iterations. With max_iters=0 no scaling is applied and
      exp(-cost_matrix/epsilon) is returned.

  Returns:
    (n,m) matrix depicting the optimal transport matrix from a to b
  '''
  # log of the gibbs kernel; loop-invariant, so compute it once
  log_K = -cost_matrix / epsilon
  log_K_T = log_K.T  # (m,n) view of the same data, no copy
  log_a = torch.log(a)
  log_b = torch.log(b)

  # log-scalings start at 0 (i.e. u = v = 1) so the result is defined even if
  # the loop body never runs
  log_u = torch.zeros_like(a)
  log_v = torch.zeros_like(b)

  for _ in range(max_iters):
    past_log_v = log_v
    # (n,m) + (m,) broadcasts log_v across rows; reduce over the target axis
    log_u = log_a - torch.logsumexp(log_K + log_v, 1)
    # (m,n) + (n,) broadcasts log_u across rows; reduce over the source axis
    log_v = log_b - torch.logsumexp(log_K_T + log_u, 1)

    if torch.allclose(log_v, past_log_v, atol=1e-4, rtol=0):
      break

  return torch.exp(log_u[:, None] + log_K + log_v[None, :])

In [None]:
#@title sinkhorn_batch(a, b, cost_matrix, epsilon, max_iters)
def sinkhorn_batch(a, b, cost_matrix, epsilon=.1, max_iters=1000):
  '''
  My sinkhorn implementation (handles batches)
  Uses log probabilities to avoid underflow.

  Parameters:
    a: (batch, n) left/source probability vector
    b: (batch, m) right/target probability vector
    cost_matrix: (batch, n, m) cost matrix in euclidean space
    epsilon: smoothing parameter for the gibbs kernel
    max_iters: maximum number of (log_u, log_v) updates. Iteration stops early
      once no entry of log_v, anywhere in the batch, moves by more than 1e-4
      between two consecutive iterations. With max_iters=0 no scaling is
      applied and exp(-cost_matrix/epsilon) is returned.

  Returns:
    (batch, n, m) matrix depicting the optimal transport matrix from a to b
  '''
  # log of the gibbs kernel. This is the only (batch, n, m) tensor kept alive
  # across iterations: the column update reduces over axis 1 of the same
  # tensor instead of holding a second, transposed copy.
  log_K = -cost_matrix / epsilon
  log_a = torch.log(a)
  log_b = torch.log(b)

  # log-scalings start at 0 (i.e. u = v = 1) so the result is defined even if
  # the loop body never runs
  log_u = torch.zeros_like(a)
  log_v = torch.zeros_like(b)

  for _ in range(max_iters):
    past_log_v = log_v
    # row update: add log_v along the target axis, reduce over it -> (batch, n)
    log_u = log_a - torch.logsumexp(log_K + log_v.unsqueeze(1), 2)
    # column update: add log_u along the source axis, reduce over it -> (batch, m)
    log_v = log_b - torch.logsumexp(log_K + log_u.unsqueeze(2), 1)

    # the whole batch stops together, once every problem has converged
    if torch.allclose(log_v, past_log_v, atol=1e-4, rtol=0):
      break

  return torch.exp(log_u[:, :, None] + log_K + log_v[:, None, :])

# Creating NN Training Data

In [None]:
# Read the palette bank from local storage. Everything is deserialized onto the
# CPU first; only the two tensors sinkhorn needs are then moved to the GPU.
palettes_dict = torch.load('/content/palettes/palette_bank.pt', map_location='cpu')

filenames_list = palettes_dict['filenames']
# palettes_dict['memberships'] is also in the bank but is not needed here
weights_tensor = palettes_dict['weights'].cuda()
centroids_tensor = palettes_dict['centroids'].cuda()

# Quick look at what was loaded: shape then contents of each tensor,
# followed by the first ten and the last ten (reversed) filenames.
print(centroids_tensor.shape, centroids_tensor, sep='\n')
print(weights_tensor.shape, weights_tensor, sep='\n')
print(filenames_list[:10])
print(filenames_list[-1:-11:-1])

torch.Size([2000, 128, 3])
tensor([[[ 2.4221e+01, -1.3307e+01,  3.1794e+01],
         [ 5.8568e+01, -2.7090e+01, -6.0873e-02],
         [ 5.8039e+01, -1.7412e+01,  5.5339e+01],
         ...,
         [ 5.1207e+01, -1.9094e+01,  5.4930e+01],
         [ 5.9679e+01, -1.8525e+01,  6.1443e+01],
         [ 5.7916e+01, -2.2764e+00,  1.7729e+01]],

        [[ 1.5576e+01, -7.8221e+00,  5.9351e+00],
         [ 9.6874e+01,  9.7466e-01,  4.6768e+00],
         [ 9.9080e+01, -1.8537e+00,  7.0578e+00],
         ...,
         [ 9.3113e+01,  2.9814e+00,  9.6789e+00],
         [ 3.4186e+01, -7.3259e+00,  5.9161e+00],
         [ 9.7258e+01,  1.3090e+00,  3.7129e+00]],

        [[ 4.7509e+01,  2.4322e+01,  1.6912e+01],
         [ 4.7459e+01,  3.1948e+01,  2.2976e+01],
         [ 2.6430e+01,  7.6339e+00, -4.6345e+00],
         ...,
         [ 7.6952e+01,  1.2032e+01, -2.4740e+00],
         [ 4.1159e+01,  2.5760e+01,  1.9359e+01],
         [ 1.7052e+01, -7.7285e+00, -3.4121e+00]],

        ...,

        [[ 

In [None]:
#@title get_transport_matrices(centroids, weights, total_samples=100_000, batch_size = 32768, epsilon=.1, max_iters=10000)
# should run with high RAM on
def get_transport_matrices(centroids, weights, total_samples=100_000, batch_size = 32768, epsilon=.1, max_iters=10000):
  '''
  Generates the P transport matrices using batched sinkhorn.

  Each sample is a (source, target) pair of palettes drawn uniformly at random
  (with replacement) from `centroids`/`weights`. The cost matrix of a pair is
  the pairwise euclidean distance between its two sets of centroids.

  Parameters:
    centroids: shape (batch, k, 3)
    weights: shape (batch, k)
    total_samples: number of samples to generate
    batch_size: size of each batch to send to gpu (must be positive)
    epsilon and max_iters: params for sinkhorn
  Returns:
    X_src_indices: (total_samples,) Indices used from centroids and weights
    X_tgt_indices: (total_samples,) Indices used from centroids and weights
    y_transport_matrices: (total_samples, k, k), stored on the cpu
  Raises:
    ValueError: if batch_size is not positive
  '''
  if batch_size <= 0:
    raise ValueError(f"batch_size must be positive, got {batch_size}")

  num_palettes, k = centroids.shape[0], centroids.shape[1]

  if total_samples <= 0:
    # Nothing to generate. torch.cat refuses an empty list, so build the
    # empty result explicitly.
    return (torch.empty(0, dtype=torch.long),
            torch.empty(0, dtype=torch.long),
            torch.empty((0, k, k), dtype=torch.result_type(centroids, weights)))

  num_batches = (total_samples + batch_size - 1) // batch_size

  X_src_indices = []
  X_tgt_indices = []
  # The full cpu output is allocated once (as soon as the first batch reveals
  # its shape/dtype) and filled slice by slice. Collecting per-batch tensors in
  # a list and concatenating at the end would briefly need twice the RAM.
  y_transport_matrices = None

  for i in tqdm(range(num_batches)):
    start = i * batch_size
    curr_batch_size = min(batch_size, total_samples - start)  # last batch may be smaller
    source_indices = torch.randint(low=0, high=num_palettes, size=(curr_batch_size,))
    target_indices = torch.randint(low=0, high=num_palettes, size=(curr_batch_size,))

    # pure labeling: no autograd graph is needed for the cost or the plan
    with torch.no_grad():
      batch_cost_matrix = torch.cdist(centroids[source_indices], centroids[target_indices]) # shape (batch, k, k)
      batch_P = sinkhorn_batch(weights[source_indices], weights[target_indices], batch_cost_matrix, epsilon=epsilon, max_iters=max_iters)

    if y_transport_matrices is None:
      y_transport_matrices = torch.empty((total_samples,) + tuple(batch_P.shape[1:]), dtype=batch_P.dtype)
    y_transport_matrices[start:start + curr_batch_size].copy_(batch_P)  # device -> cpu copy

    # drop the references now so this batch's device memory can be reused
    # while the next batch is being computed
    del batch_cost_matrix, batch_P

    X_src_indices.append(source_indices.cpu())
    X_tgt_indices.append(target_indices.cpu())

  X_src_indices = torch.cat(X_src_indices) # cat along dim=0
  X_tgt_indices = torch.cat(X_tgt_indices)

  return X_src_indices, X_tgt_indices, y_transport_matrices

In [None]:
def save_sinkhorn_results_locally():
  '''
  Generates the sinkhorn training bank and writes it to local colab storage.

  Calls get_transport_matrices on the module-level centroids_tensor and
  weights_tensor with its default settings, then saves a dict with the keys
  'X_src_indices', 'X_tgt_indices' and 'y_transport_matrices' to
  /content/sinkhorn_inputs_matrices/sinkhorn_bank.pt.
  '''
  # Only the indices of the source/target palettes are stored (not the
  # centroids/weights themselves) to save space.
  bank_dir = '/content/sinkhorn_inputs_matrices/'
  os.makedirs(bank_dir, exist_ok=True)

  src_idx, tgt_idx, plans = get_transport_matrices(centroids_tensor, weights_tensor)

  bank = {
    'X_src_indices': src_idx,
    'X_tgt_indices': tgt_idx,
    'y_transport_matrices': plans
  }
  bank_file = os.path.join(bank_dir, 'sinkhorn_bank.pt')
  torch.save(bank, bank_file)


def save_sinkhorn_results_to_drive():
  '''
  Zips the locally saved sinkhorn results and copies the archive to drive.

  Archives everything under /content/sinkhorn_inputs_matrices/ into
  /content/sinkhorn_bank.zip and copies that file into the project's Data
  folder on drive. Failures are reported with a printed "error: ..." line
  rather than raised, so the local results are left untouched either way.
  '''
  drive_dir = '/content/drive/MyDrive/Amortized Optimal Transport/Data'

  # shutil.copy treats a destination that is not an existing directory as a
  # file name, so a missing folder would silently produce a file called
  # "Data". Check up front, before spending time on zipping.
  if not os.path.isdir(drive_dir):
    print(f"error: destination folder not found (is drive mounted?): {drive_dir}")
    return

  print("Zipping Sinkhorn results...")
  archive_path = shutil.make_archive(base_name='/content/sinkhorn_bank', format='zip', root_dir='/content/sinkhorn_inputs_matrices/')
  print("Copying sinkhorn results to drive...")
  try:
    shutil.copy(archive_path, drive_dir)
    print("Done")
  except OSError as e:
    # best effort: report the copy failure and keep the local archive
    print(f"error: {e}")

In [None]:
# Runs the full labeling job with the default settings of get_transport_matrices
# and saves the result locally. Slow: the recorded run below took ~1h50m for 4 batches.
save_sinkhorn_results_locally()

100%|██████████| 4/4 [1:50:28<00:00, 1657.03s/it]


In [None]:
# Zips /content/sinkhorn_inputs_matrices/ and copies the archive to the Drive data folder.
# Run save_sinkhorn_results_locally() first so there is something to archive.
save_sinkhorn_results_to_drive()

Zipping Sinkhorn results...
Copying sinkhorn results to drive...
Done
