## Image Embeddings

In [None]:
import json
import matplotlib.pyplot as plt
import numpy as np

from os import listdir, makedirs, path
from random import shuffle

from sklearn.cluster import KMeans, HDBSCAN
from sklearn.metrics import silhouette_score
from sklearn.mixture import GaussianMixture
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

from PIL import Image as PImage, ImageOps as PImageOps

from parameters.arquigrafia import IMAGES_PATH
from parameters.embeddings import EMBEDDINGS_PATH

from models.embedding_models import Clip, EfficientNet, ResNet, Vit

# Make sure the directory that receives one embeddings JSON per image exists.
makedirs(name=EMBEDDINGS_PATH, exist_ok=True)

### Run embeddings

In [None]:
# Compute CLIP / EfficientNet / ResNet / ViT embeddings for every JPEG in
# IMAGES_PATH and store them as one JSON file per image in EMBEDDINGS_PATH.
# Images that already have an output file are skipped, so this cell can be
# re-run to resume an interrupted pass.
import os

input_files = sorted(f for f in listdir(IMAGES_PATH) if f.endswith("jpg"))

for idx, io_file in enumerate(input_files):
  input_file_path = path.join(IMAGES_PATH, io_file)
  output_file_path = path.join(EMBEDDINGS_PATH, io_file.replace(".jpg", ".json"))

  if path.isfile(output_file_path):
    continue

  if idx % 100 == 0:
    print(idx, IMAGES_PATH, io_file)

  # Open via a context manager so the file handle is released promptly;
  # convert to RGB, then apply the EXIF orientation tag (if any).
  with PImage.open(input_file_path) as src:
    image = PImageOps.exif_transpose(src.convert("RGB"))

  raw_embs = {
    "clip": Clip.get_embedding(image).tolist(),
    "efficient": EfficientNet.get_embedding(image).tolist(),
    "resnet": ResNet.get_embedding(image).tolist(),
    "vit": Vit.get_embedding(image).tolist(),
  }

  image_embs = {"raw": raw_embs}

  # Write to a temporary file and atomically rename it into place. Otherwise an
  # interrupted write would leave a truncated JSON behind. The `isfile` check
  # above would then skip that image on every re-run, and the loading cell
  # would fail in json.load.
  tmp_file_path = output_file_path + ".tmp"
  with open(tmp_file_path, "w", encoding="utf-8") as of:
    json.dump(image_embs, of, sort_keys=True, separators=(',',':'), ensure_ascii=False)
  os.replace(tmp_file_path, output_file_path)

### Reduce Dimensions and Cluster

In [None]:
# Gather the stored raw embeddings into one list per model, keyed by the
# model names found under "raw" in each JSON file. Entry i of every list
# belongs to the i-th image, in sorted file order, that has an embeddings
# file; images without one are skipped.
input_files = sorted(name for name in listdir(IMAGES_PATH) if name.endswith("jpg"))

raw_embs = {}

for file_idx, image_name in enumerate(input_files):
  emb_path = path.join(EMBEDDINGS_PATH, image_name.replace(".jpg", ".json"))

  if not path.isfile(emb_path):
    continue

  if file_idx % 100 == 0:
    print(file_idx, IMAGES_PATH, image_name)

  with open(emb_path, "r", encoding="utf8") as emb_file:
    image_embs = json.load(emb_file)

  for model_name, vector in image_embs["raw"].items():
    raw_embs.setdefault(model_name, []).append(vector)

In [None]:
def pca_kmeans(emb_raw, n_clusters=8, n_components=4, random_state=None):
  """Standardize embeddings, project them with PCA, and cluster them with KMeans.

  Clustering is performed on the standardized full-dimensional features, not
  on the PCA projection. The projection is only returned, e.g. for plotting.

  Args:
    emb_raw: 2D array-like (n_samples, n_features), as StandardScaler requires.
    n_clusters: number of KMeans clusters.
    n_components: number of PCA components to keep.
    random_state: seed forwarded to PCA and KMeans. The default (None) keeps
      the previous non-deterministic behaviour. Pass an int to get the same
      cluster labels on every call, so that plots and image grids produced in
      different cells agree.

  Returns:
    (emb_std, emb_pca, emb_clusters, inertia):
    - emb_std: the standardized features.
    - emb_pca: their PCA projection.
    - emb_clusters: the KMeans label of each row.
    - inertia: the KMeans inertia on the standardized features.
  """
  scaler = StandardScaler()
  reducer = PCA(n_components=n_components, random_state=random_state)
  clusterer = KMeans(n_clusters=n_clusters, random_state=random_state)

  emb_std = scaler.fit_transform(emb_raw)
  emb_pca = reducer.fit_transform(emb_std)
  # Cluster on the standardized features (not emb_pca); plot_elbows computes
  # the silhouette score on the same space.
  emb_clusters = clusterer.fit_predict(emb_std)

  return emb_std, emb_pca, emb_clusters, clusterer.inertia_

In [None]:
def plot_pca_clusters(pcas, clusters, title=""):
  """Scatter-plot the first three PCA components, coloured by cluster label.

  Shows one 2D scatter for each pair of the first three components,
  (0, 1), (0, 2) and (1, 2), then a single 3D scatter of all three.

  Args:
    pcas: array indexed as pcas[:, k]; needs at least 3 columns.
    clusters: per-row cluster labels used as the colour values.
    title: title applied to every figure.
  """
  point_style = dict(c=clusters, marker='o', linestyle='')

  for first, second in ((0, 1), (0, 2), (1, 2)):
    plt.scatter(pcas[:, first], pcas[:, second], alpha=0.5, **point_style)
    plt.title(title)
    plt.show()

  # 3D view of the same three components
  fig = plt.figure(figsize=(8, 8))
  ax3d = fig.add_subplot(projection='3d')
  ax3d.scatter(pcas[:, 0], pcas[:, 1], pcas[:, 2], alpha=0.25, **point_style)
  ax3d.set_title(title)
  plt.show()

In [None]:
def plot_elbows(embs, title, min_clusters=2, max_clusters=10):
  """Plot KMeans WSS and silhouette score against the number of clusters.

  For every cluster count in [min_clusters, max_clusters], runs pca_kmeans on
  `embs` and records two values:
  - the KMeans inertia;
  - the silhouette score of the labels, computed on the standardized features.
  Each metric is then shown as its own line plot.
  """
  cluster_counts = range(min_clusters, max_clusters + 1)
  wss = []
  sil = []

  for n in cluster_counts:
    std_feats, _, labels, inertia = pca_kmeans(embs, n_clusters=n)
    wss.append(inertia)
    sil.append(silhouette_score(std_feats, labels))

  for values, metric in ((wss, "Within-Cluster Sum of Squares"), (sil, "Silhouette Score")):
    plt.plot(cluster_counts, values)
    plt.title(f"{title} - {metric}")
    plt.show()

In [None]:
# Elbow / silhouette curves for each embedding model.
for model_name in ("clip", "efficient", "resnet", "vit"):
  plot_elbows(raw_embs[model_name], model_name)

In [None]:
# Cluster each model's embeddings with default settings and plot the PCA projection.
for model_name in ("clip", "efficient", "resnet", "vit"):
  _, pca_proj, labels, _ = pca_kmeans(raw_embs[model_name])
  plot_pca_clusters(pca_proj, labels, title=model_name)

In [None]:
# CLIP: standardized features, PCA projection, KMeans labels and inertia (default settings).
clip_std, clip_pca, clip_clusters, clip_wss = pca_kmeans(emb_raw=raw_embs["clip"])

In [None]:
# EfficientNet: standardized features, PCA projection, KMeans labels and inertia (default settings).
efficient_std, efficient_pca, efficient_clusters, efficient_wss = pca_kmeans(emb_raw=raw_embs["efficient"])

In [None]:
# ResNet: standardized features, PCA projection, KMeans labels and inertia (default settings).
resnet_std, resnet_pca, resnet_clusters, resnet_wss = pca_kmeans(emb_raw=raw_embs["resnet"])

In [None]:
# ViT: standardized features, PCA projection, KMeans labels and inertia (default settings).
vit_std, vit_pca, vit_clusters, vit_wss = pca_kmeans(emb_raw=raw_embs["vit"])

In [None]:
# File names aligned with the rows of raw_embs and the cluster labels. The
# loading cell skips images that have no embeddings file, so apply the same
# filter here. Without it, cluster index i would point at the wrong image as
# soon as any image lacks embeddings. This assumes the embeddings directory has
# not changed since raw_embs was loaded.
input_files = sorted(
  f for f in listdir(IMAGES_PATH)
  if f.endswith("jpg")
  and path.isfile(path.join(EMBEDDINGS_PATH, f.replace(".jpg", ".json")))
)

In [None]:
# For each cluster, show a random sample of up to 64 of its images in an 8x8 grid.
m_clusters = clip_clusters

n_rows, n_cols = 8, 8

for c in np.unique(m_clusters):
  # Image indices belonging to this cluster, in random order.
  this_cluster = np.random.permutation(np.flatnonzero(m_clusters == c))
  fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols)
  fig.set_size_inches(10, 10)
  fig.set_dpi(72)

  fig.suptitle(f"Cluster {c}")
  for ax in axes.flat:
    ax.axis("off")

  # zip stops at the shorter sequence. Clusters with fewer than
  # n_rows * n_cols members therefore leave the remaining cells blank
  # instead of raising IndexError.
  for ax, iidx in zip(axes.flat, this_cluster):
    with PImage.open(path.join(IMAGES_PATH, input_files[iidx])) as src:
      img = PImageOps.exif_transpose(src.convert("RGB")).resize((128,128))
    ax.imshow(img)

  plt.tight_layout()
  plt.show()

In [None]:
# KMeans notes (quality by number of clusters):
# EfficientNet: bad
# CLIP: 4 bad, 6 ok, 8 good
# ResNet: 4 ok, 6 bad, 8 bad
# ViT: 4 bad, 6 ok, 8 good

In [None]:
# TODO: PCA
# TODO: write output

  # with open(output_file_path, "w", encoding="utf-8") as of:
  #   json.dump(image_embs, of, sort_keys=True, separators=(',',':'), ensure_ascii=False)

## Test Embeddings

In [None]:
# Load a few sample images, preprocessed the same way as the embedding cell
# (RGB conversion + EXIF orientation). Test embeddings then match what the
# pipeline actually stores, even for grayscale/CMYK or rotated JPEGs.
sample_names = ["100.jpg", "101.jpg", "1000.jpg", "1001.jpg", "1010.jpg", "1011.jpg"]

imgs = []
for name in sample_names:
  with PImage.open(path.join(IMAGES_PATH, name)) as src:
    imgs.append(PImageOps.exif_transpose(src.convert("RGB")))

img = imgs[0]

for i in imgs:
  display(i.resize((128,128)))

In [None]:
# Embed the whole list at once, then rank the rows by Euclidean distance to
# row 1 (imgs[1]). Assumes rows are samples (indexing emb[1], summing over dim=1).
emb = ResNet.get_embedding(imgs)
print(emb.shape)

reference = emb[1]
squared_dist = (emb - reference).pow(2).sum(dim=1)
emb_diff = squared_dist.pow(0.5)
emb_diff.argsort()

In [None]:
# Single-image call (img is imgs[0], not a list): print the embedding shape.
emb = ResNet.get_embedding(img)
emb_shape = emb.shape
print(emb_shape)

In [None]:
# Embed the whole list at once, then rank the rows by Euclidean distance to
# row 1 (imgs[1]). Assumes rows are samples (indexing emb[1], summing over dim=1).
emb = EfficientNet.get_embedding(imgs)
print(emb.shape)

reference = emb[1]
squared_dist = (emb - reference).pow(2).sum(dim=1)
emb_diff = squared_dist.pow(0.5)
emb_diff.argsort()

In [None]:
# Single-image call (img is imgs[0], not a list): print the embedding shape.
emb = EfficientNet.get_embedding(img)
emb_shape = emb.shape
print(emb_shape)

In [None]:
# Embed the whole list at once, then rank the rows by Euclidean distance to
# row 1 (imgs[1]). Assumes rows are samples (indexing emb[1], summing over dim=1).
emb = Vit.get_embedding(imgs)
print(emb.shape)

reference = emb[1]
squared_dist = (emb - reference).pow(2).sum(dim=1)
emb_diff = squared_dist.pow(0.5)
emb_diff.argsort()

In [None]:
# Single-image call (img is imgs[0], not a list): print the embedding shape.
emb = Vit.get_embedding(img)
emb_shape = emb.shape
print(emb_shape)

In [None]:
# Embed the whole list at once, then rank the rows by Euclidean distance to
# row 1 (imgs[1]). Assumes rows are samples (indexing emb[1], summing over dim=1).
emb = Clip.get_embedding(imgs)
print(emb.shape)

reference = emb[1]
squared_dist = (emb - reference).pow(2).sum(dim=1)
emb_diff = squared_dist.pow(0.5)
emb_diff.argsort()

In [None]:
# Single-image call (img is imgs[0], not a list): print the embedding shape.
emb = Clip.get_embedding(img)
emb_shape = emb.shape
print(emb_shape)