In [None]:
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
import glob
import shutil
import os
from tqdm import tqdm

# Loading Data In
Idea: reading individual files from Google Drive on every access is too slow. Instead, store the whole dataset as a single zip on Drive, copy that zip into the Colab runtime once, and extract it there. All later reads then hit the runtime's local SSD, which is fast.

In [None]:
from google.colab import drive
drive.mount('/content/drive')

drive_path = '/content/drive/MyDrive/Amortized Optimal Transport/Data/images.zip'
local_path = '/content/dataset'

# Copy the archive from Drive into the runtime once, then extract it on the
# runtime's local disk so later per-image reads avoid Drive latency.
if not os.path.exists(local_path):
  print("Getting data zip from google drive...")
  shutil.copy(drive_path, '/content/data.zip')
  print("Unzipping data locally...")
  # Extract into a scratch directory and move it into place only once
  # extraction has fully succeeded. unpack_archive raises on a bad archive
  # (a `!unzip` shell magic would fail silently), and an interrupted run never
  # leaves a half-filled local_path that the existence check above would
  # mistake for a complete dataset next time.
  partial_path = local_path + '.partial'
  shutil.rmtree(partial_path, ignore_errors=True)
  shutil.unpack_archive('/content/data.zip', partial_path, format='zip')
  os.rename(partial_path, local_path)
  print("Dataset successfully loaded in /content/dataset")
else:
  print("Dataset zip already loaded and unzipped locally")


Mounted at /content/drive
Getting data zip from google drive...
Unzipping data locally...
Dataset successfully loaded in /content/dataset


# Extracting color palette from input image

In [None]:
!pip install kornia

import cv2
import kornia

Collecting kornia
  Downloading kornia-0.8.2-py2.py3-none-any.whl.metadata (18 kB)
Collecting kornia_rs>=0.1.9 (from kornia)
  Downloading kornia_rs-0.1.10-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Downloading kornia-0.8.2-py2.py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m57.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading kornia_rs-0.1.10-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.0/3.0 MB[0m [31m121.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: kornia_rs, kornia
Successfully installed kornia-0.8.2 kornia_rs-0.1.10


In [None]:
#@title fast_kmeans(x, k, max_iters=100, tol=1e-3)
def fast_kmeans(x, k, max_iters=100, tol=1e-3):
    """
    GPU-friendly k-means (Lloyd's algorithm) written in plain torch.

    kmeans_pytorch has an issue where a center with no members yields
    0/0 = NaN in its center_shift, so this implementation keeps the previous
    position of any center whose cluster ends up empty instead of dividing
    by zero.

    Args:
        x (Tensor): data of shape (N, D), floating dtype. All work stays on
            x.device and in x.dtype.
        k (int): number of clusters. If k > N, initial centers are drawn with
            replacement so that k centers are still returned; surplus
            duplicate centers simply stay empty.
        max_iters (int): maximum number of assignment/update rounds.
        tol (float): stop once the norm of the total center movement in one
            round drops below this value.
    Returns:
        labels (Tensor): (N,) int64 cluster assignments from the last
            assignment step
        centers (Tensor): (k, D) cluster centers. For max_iters >= 1, every
            non-empty cluster's center is the mean of the points `labels`
            assigns to it; empty clusters keep their previous position.
    Raises:
        ValueError: if x has no rows or k < 1.
    """
    N, D = x.shape
    if N == 0:
        raise ValueError("fast_kmeans needs at least one data point")
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")

    # Start from k random data points (distinct indices when possible).
    # With fewer points than clusters, sample with replacement so the center
    # tensor is still (k, D) and callers always get k centers back.
    if k <= N:
        indices = torch.randperm(N, device=x.device)[:k]
    else:
        indices = torch.randint(N, (k,), device=x.device)
    centers = x[indices].clone()

    labels = torch.zeros(N, dtype=torch.long, device=x.device)

    for _ in range(max_iters):
        centers_old = centers  # safe alias: centers is rebuilt, never mutated in place

        # Assignment step: nearest center by Euclidean distance.
        dists = torch.cdist(x, centers)
        labels = torch.argmin(dists, dim=1)

        # Update step: per-cluster coordinate sums and member counts.
        # The accumulator must share x's dtype, otherwise index_add_ rejects
        # the source tensor (e.g. float64 data into a float32 buffer).
        sums = torch.zeros(k, D, dtype=x.dtype, device=x.device)
        sums.index_add_(0, labels, x)
        counts = torch.bincount(labels, minlength=k)

        # Divide by max(count, 1) to avoid 0/0; the resulting placeholder for
        # an empty cluster is replaced below by that cluster's previous center.
        candidates = sums / counts.clamp(min=1).unsqueeze(1).to(x.dtype)
        valid_mask = (counts > 0).unsqueeze(1)
        centers = torch.where(valid_mask, candidates, centers_old)

        shift = torch.norm(centers - centers_old)
        if shift < tol:
            break

    return labels, centers

In [None]:
#@title get_palette(image_path, k=128, testing=False, device=None)
def get_palette(image_path, k=128, testing=False, device=None):
  '''
  Runs k-means clustering on the Lab-space colors of the input picture.
  RGB distance is not perceptually meaningful; Euclidean distance in Lab space
  (CIE76 Delta E) models how differently humans perceive colors much better.

  Everything heavy runs on a torch device: kornia does the RGB -> Lab
  conversion and fast_kmeans does the clustering.

  Parameters:
    image_path: string path to image in local colab env. Should be /content/dataset/XXXXX.jpg
    k: number of clusters to make in k-means
    testing: if True, will imshow the image get_palette is called on, plus a
             rescaled view of its Lab channels
    device: torch device to run on. Defaults to 'cuda' when available, else 'cpu'.

  Returns (all on CPU):
    Centroids: (k, 3) Lab coordinates of all k cluster means
    Weights: (k,) probability distribution weighing each cluster proportional to the number of elements in it
    Pixel labels: (h*w,) cluster membership per pixel, in row-major pixel order

  Raises:
    FileNotFoundError: if OpenCV cannot read image_path.
  '''
  if device is None:
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

  image = cv2.imread(image_path) # BGR numpy array, or None if the file can't be read
  if image is None:
    # cv2.imread signals failure by returning None instead of raising; fail
    # here with the offending path rather than with a cryptic cvtColor error.
    raise FileNotFoundError(f"cv2.imread could not read image: {image_path}")

  # Important note: bgr2rgb was used here rather than bgr2lab because OpenCV's
  #                 8-bit Lab output is rescaled/offset to fit uint8
  #                 (L*255/100, a+128, b+128), which is not the standard
  #                 L, a, b scaling needed for perceptual distances.
  #                 kornia's rgb_to_lab below returns unscaled Lab.
  image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

  # (H, W, 3) -> (1, 3, H, W) floats in [0, 1]; kornia expects
  # (batch size, num channels, height, width).
  image_tensor = torch.from_numpy(image_rgb).permute(2, 0, 1).unsqueeze(0).to(device)
  image_tensor = image_tensor.float() / 255.0

  lab_image_tensor = kornia.color.rgb_to_lab(image_tensor) # kornia has no bgr_to_lab, hence the RGB conversion above

  if testing:
    print(type(image_rgb))
    print(image_rgb.shape)
    plt.imshow(image_rgb)
    plt.axis("off")
    plt.show()

    print(type(lab_image_tensor))
    print(lab_image_tensor.shape)
    # Plot from CPU. Raw Lab values (L in [0, 100], a/b roughly [-128, 127])
    # would just be clipped by imshow, so map each channel into [0, 1] first.
    lab_vis = lab_image_tensor[0].permute(1, 2, 0).detach().cpu().numpy()
    lab_vis = (lab_vis - np.array([0.0, -128.0, -128.0])) / np.array([100.0, 255.0, 255.0])
    plt.imshow(np.clip(lab_vis, 0.0, 1.0))
    plt.axis("off")
    plt.show()

  # kmeans only cares about color distance, so flatten (3, H, W) to (H*W, 3);
  # -1 folds height and width together in row-major pixel order.
  pixels = lab_image_tensor[0].reshape(3, -1).permute(1, 0).contiguous()

  cluster_ids_x, cluster_centers = fast_kmeans(x=pixels, k=k)

  # Weight of each cluster = fraction of pixels assigned to it (0 if empty).
  counts = torch.bincount(cluster_ids_x, minlength=k)
  weights = counts / counts.sum()

  return cluster_centers.cpu(), weights.cpu(), cluster_ids_x.cpu()

In [None]:
# iterate through dataset, run get_palette on each image, save all kmeans palette outputs locally

def save_palettes_locally(k=128, data_dir='/content/dataset/', palette_dir='/content/palettes/'):
  '''
  Runs get_palette on every image in data_dir and stores all palettes in a
  single file, <palette_dir>/palette_bank.pt.

  Parameters:
    k: number of k-means clusters (palette size) per image
    data_dir: directory holding the extracted dataset images
    palette_dir: output directory, created if it does not exist

  The saved dict has keys:
    'centroids': (num_images, k, 3) Lab cluster centers
    'weights': (num_images, k) cluster weights; each row sums to 1
    'memberships': list with one (h*w,) per-pixel label tensor per image
    'filenames': sorted image filenames, aligned with the rows above
  '''
  os.makedirs(palette_dir, exist_ok=True)

  # Keep only regular files with an image extension: zip archives often carry
  # extras (sub-folders, __MACOSX/, .DS_Store) that cv2.imread cannot decode.
  image_exts = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff', '.webp')
  filepaths = sorted(
      name for name in os.listdir(data_dir)
      if name.lower().endswith(image_exts)
      and os.path.isfile(os.path.join(data_dir, name))
  )

  centroid_tensor = torch.empty(len(filepaths), k, 3)
  weights_tensor = torch.empty(len(filepaths), k)
  memberships_list = []
  # Wrap the list itself (not enumerate) so tqdm knows the total and draws a real bar.
  for i, filepath in enumerate(tqdm(filepaths)):
    centroids, weights, memberships = get_palette(os.path.join(data_dir, filepath), k=k, testing=False)
    centroid_tensor[i] = centroids
    weights_tensor[i] = weights
    memberships_list.append(memberships)

  torch.save({
      'centroids': centroid_tensor,
      'weights': weights_tensor,
      'memberships': memberships_list,
      'filenames': filepaths
  }, os.path.join(palette_dir, "palette_bank.pt"))

def save_palettes_to_drive(palette_dir='/content/palettes/',
                           drive_dir='/content/drive/MyDrive/Amortized Optimal Transport/Data'):
  '''
  Zips palette_dir into a sibling archive (e.g. /content/palettes.zip) and
  copies that archive into drive_dir, creating drive_dir if needed.

  Parameters:
    palette_dir: local directory whose contents are archived
    drive_dir: destination directory (on the mounted Drive by default)

  A failed copy is reported with a printed message rather than raised, so a
  Drive hiccup does not abort the notebook; the local archive is kept.
  '''
  print("Zipping Palettes...")
  # make_archive appends '.zip' to base_name. normpath drops any trailing
  # separator so the archive lands next to palette_dir, not inside it.
  base_name = os.path.normpath(palette_dir)
  archive_path = shutil.make_archive(base_name=base_name, format='zip', root_dir=palette_dir)
  os.makedirs(drive_dir, exist_ok=True)

  print("Copying palettes to drive...")
  try:
    shutil.copy(archive_path, drive_dir)
    print("Done")
  except OSError as e:  # includes shutil.Error / SameFileError
    print(f"error: {e}")


In [None]:
PALETTE_K = 128  # number of k-means clusters (palette colors) extracted per image
save_palettes_locally(k=PALETTE_K)
save_palettes_to_drive()

2000it [14:24,  2.31it/s]


Zipping Palettes...
Copying palettes to drive...
Done
