# Variables

In [None]:
# Local path to the image library root; `index.pq` is loaded from this
# directory by the "Load image index from library" cell.
library_path = ''

# Local folder that stores downloaded images; it is indexed and appended to
# the loaded index by the "Add image from local directory" cell.
image_folder = ''

# Config module search path

In [None]:
import sys
from pathlib import Path

# Directory two levels above the current working directory.
# NOTE(review): presumably this is where my_package lives relative to the
# notebook — confirm against the repository layout.
parent_dir = str(Path().resolve().parents[1])

# Put it at the front of sys.path (once) so imports resolve against it first.
if parent_dir not in sys.path:
  sys.path.insert(0, parent_dir)

# Show the resulting search path for a quick visual check.
from pprint import pprint
pprint(sys.path)

# Imports

In [None]:
import os
import sys
from pathlib import Path

import numpy as np
import polars as pl

# Select a sans-serif font so the Chinese query sentences used as plot
# titles can be rendered.
# NOTE(review): 'Heiti TC' is presumably not installed on every system —
# confirm, or pick a name printed by the "List all available fonts" cell.
from matplotlib import rcParams
rcParams['font.family'] = 'sans-serif'
rcParams['font.sans-serif'] = ['Heiti TC']

# Build image index with CLIP ViT-L/14

model: https://huggingface.co/openai/clip-vit-large-patch14

The image index has four columns:
- `category`: name of the category to which the image belongs
- `filepath`: image file path relative to the `index.pq` file, which **_MUST BE_** placed under the root folder
- `cluster`: cluster id assigned when the index is built with K-Means clustering; the smaller the cluster id, the larger the number of images in the cluster.
  - Clustering is performed within a single category.
- `embedding`: image embedding encoded with ViT-L/14

Image index is built with Polars.

In [None]:
# preload ViT model

# Create the image encoder once at module level; `index_images` reads the
# global `image_encoder` when it embeds image files.
from imageretrieval.retrieval import ImageEmbeddings

image_encoder = ImageEmbeddings()

In [None]:
def build_index(root_path, persist=False):
  """
  Index every image under `root_path` and optionally save the index.

  When `persist` is true, a copy of the index whose `filepath` column has
  been rewritten relative to `root_path` is written to
  `<root_path>/index.pq`. The frame returned to the caller is the one
  produced by `index_images`, with its paths left as they were.
  """
  index = index_images(root_path)
  if not persist:
    return index

  def to_relative(path):
    return os.path.relpath(path, root_path)

  relative_paths = index['filepath'].apply(to_relative)
  target = os.path.join(root_path, 'index.pq')
  index.with_columns(relative_paths).write_parquet(target)
  return index


def load_index(root_path):
  """
  Read `<root_path>/index.pq` and re-anchor its `filepath` column.

  Every stored path is joined onto `root_path`, which turns the relative
  paths written by `build_index(..., persist=True)` back into paths under
  the library root.
  """
  index = pl.read_parquet(os.path.join(root_path, 'index.pq'))

  def to_rooted(relative):
    return os.path.join(root_path, relative)

  return index.with_columns(index['filepath'].apply(to_rooted))


def clustering(embeddings, k=2) -> list[list[int]]:
  """
  Split `embeddings` into `k` groups with K-Means.

  K-Means is fitted on the pairwise cosine-similarity matrix of the
  embeddings rather than on the embeddings themselves. The result is a list
  of `k` lists of row positions, ordered from the largest group to the
  smallest.
  """
  from sklearn.cluster import KMeans
  from sklearn.metrics.pairwise import cosine_similarity

  # Every sample is described by its similarity to all samples.
  pairwise = cosine_similarity(embeddings)

  # Fixed seed and restart count as configured for this notebook.
  model = KMeans(n_clusters=k, random_state=0, n_init=10)
  labels = model.fit(pairwise).labels_

  # One bucket of row positions per cluster label.
  buckets: list[list[int]] = [
    [position for position, label in enumerate(labels) if label == cluster_id]
    for cluster_id in range(k)
  ]

  # Largest bucket first, so callers can number clusters by size.
  buckets.sort(key=len, reverse=True)
  return buckets


def index_images(folder_path, seen=None, df=None, recursive=True, k=1):
  """
  Embed every image below `folder_path` and return the rows as a Polars frame.

  Each leaf directory (a directory without sub-directories) is one category.
  The category name is the directory path relative to `folder_path` with '/'
  replaced by '.'; when `folder_path` itself is the leaf, its base name is
  used. A name that is already in `seen` gets the first free '-<n>' suffix.

  Args:
    folder_path: root directory to scan.
    seen: category names that are already taken. Names assigned here are
      appended to this list; when omitted, a fresh list is used per call.
    df: optional existing index; the new rows are appended with `df.extend`.
    recursive: accepted for compatibility; currently unused.
    k: number of K-Means clusters per category. Values <= 1 (or categories
      with fewer images than `k`) are handled without failing: the cluster
      count is capped at the number of images, and a single cluster 0 is
      used when only one cluster remains.

  Returns:
    A frame with the columns `category`, `filepath`, `cluster` and
    `embedding`, or the result of `df.extend(frame)` when `df` is given.
  """
  # A `[]` default would be shared by every call and silently remember the
  # categories of earlier runs, renaming them to 'name-1', 'name-2', ...
  if seen is None:
    seen = []

  categories = []
  filepath = []
  cluster = []
  embedding = []

  # Only treat leaf directories as categories
  leaf_dirs = [
    p for p in Path(folder_path).rglob('**')
    if not any(s != p for s in p.glob('**'))
  ]
  for path in leaf_dirs:
    print(f'Processing {path}...')

    category = os.path.relpath(path, folder_path).replace('/', '.')
    if category == '.':
      category = os.path.basename(folder_path)
    if category in seen:
      # Find the first unused 'base-<n>' name.
      base = category
      suf = 1
      while f'{base}-{suf}' in seen:
        suf += 1
      category = f'{base}-{suf}'
    seen.append(category)

    image_files = [
      *path.glob('**/*.png'),
      *path.glob('**/*.jpg'),
      *path.glob('**/*.jpeg'),
    ]
    if not image_files:
      # Nothing to encode or cluster; such a directory contributes no rows.
      print('No images found, skipping')
      continue

    embeddings = image_encoder.encode(image_files)
    print('Shape:', embeddings.shape)

    # K-Means cannot form more clusters than there are samples.
    n_clusters = min(k, len(embeddings))
    if n_clusters > 1:
      groups = clustering(embeddings, n_clusters)
    else:
      groups = [range(len(embeddings))]

    for idx, group in enumerate(groups):
      for i in group:
        categories.append(category)
        filepath.append(str(image_files[i]))
        cluster.append(idx)
        embedding.append(embeddings[i])

  frame = pl.DataFrame({
    'category': categories,
    'filepath': filepath,
    'cluster': cluster,
    'embedding': embedding,
  })

  return frame if df is None else df.extend(frame)


def image_plot(title, similarities, images, count=8, w=50):
  """
  Plot up to `count` images in a 4-column grid.

  Args:
    title: figure title; wrapped to lines of at most `w` characters.
    similarities: per-image values shown under each file name, indexed in
      parallel with `images`.
    images: image file paths, plotted in the given order.
    count: maximum number of images to show; also determines the number of
      grid rows and the figure height.
    w: wrap width for the figure title.
  """
  from matplotlib import pyplot as plt
  from textwrap import wrap
  from PIL import Image

  col = 4
  # Ceiling division: enough rows to hold `count` cells.
  row = (count + col - 1) // col
  fig = plt.figure(figsize=(16, row * 4), dpi=300)
  for i in range(min(count, len(images))):
    plt.subplot(row, col, i + 1)
    # Close the file handle as soon as the pixels are converted; the RGB
    # copy returned by `convert` no longer depends on the open file.
    with Image.open(images[i]) as source:
      plt.imshow(source.convert('RGB'))
    plt.title(f'{os.path.basename(images[i])}\n{similarities[i]}', size=12)
  fig.suptitle('\n'.join(wrap(title, w)), size=16)


def retrieval(sentences: list[str], df):
  """
  Text-based image retrieval

  Encodes `sentences` once with the module-level `text_encoder` and returns
  a `similarity` function that ranks the images of `df` against them.

  The returned `similarity(expr=True, top=8, visualize=True)`:
    - filters `df` with `expr`;
    - scores every remaining image against every sentence with a dot
      product of the embeddings (NOTE(review): presumably the embeddings
      are normalised so this behaves like cosine similarity — confirm);
    - for each sentence, picks the best-scoring image that has not already
      been picked for an earlier sentence;
    - when `visualize` is true, plots the `top` best images per sentence.

  It returns a list of `(category, filepath)` tuples, one per sentence for
  which an unused image was still available, and an empty list when the
  filter leaves no images.
  """

  text_embeddings = text_encoder.encode(sentences)

  def similarity(expr=True, top=8, visualize=True):
    sub = df.filter(expr)
    if sub.height == 0:
      # No candidate images: nothing to score, select or plot.
      return []

    image_embeddings = np.array(list(sub['embedding']))

    # scores[s][j]: similarity of sentence s to image j;
    # indices[s]: image positions sorted from best to worst for sentence s.
    scores = np.dot(text_embeddings, image_embeddings.T)
    indices = np.argsort(scores, axis=1)[:, ::-1]

    # Positions already handed out, so each image is selected at most once.
    seen = set()
    selected = []
    files = list(sub['filepath'])
    categories = list(sub['category'])
    for idx, text in enumerate(sentences):
      for i in indices[idx]:
        i = int(i)
        if i not in seen:
          seen.add(i)
          selected.append((categories[i], files[i]))
          break
      if visualize:
        images = [files[i] for i in indices[idx]]
        similarities = [scores[idx][i] for i in indices[idx]]
        image_plot(text, similarities, images, count=top)
    return selected

  return similarity

## Load image index from library

`index.pq` **_MUST_** exist in the `library_path`

In [None]:
# Fail early with a clear message when the library path is unset or missing.
if not os.path.exists(library_path):
  raise RuntimeError(f'Path not found: {library_path}')

# `df` is the working index used by the cells below.
df = load_index(library_path)

# Quick summary of what was loaded.
print(f'Root directory: {library_path}')
print(f'Total images: {df.height}')
print(f'Categories: {list(df["category"].unique())}')
print(df)

## Add images from a local directory

In [None]:
# Fail early with a clear message when the image folder is unset or missing.
if not os.path.exists(image_folder):
  raise RuntimeError(f'Path not found: {image_folder}')

# Pass the categories already in the index so that new folders with the same
# name receive a '-<n>' suffix instead of clashing.
# NOTE(review): `df` must already be bound (e.g. by the load cell above); the
# `is not None` checks only help if it was explicitly set to None.
seen = list(df['category'].unique()) if df is not None else []
# k=2: split every new category into two clusters.
newdf = index_images(image_folder, seen=seen, k=2)
print(f'Total images: {newdf.height}')
print(f'Categories: {list(newdf["category"].unique())}')
print(newdf)

# Append the new rows to the working index.
df = newdf if df is None else df.extend(newdf)
print(df)

# Text-based image retrieval

https://huggingface.co/M-CLIP/XLM-Roberta-Large-Vit-L-14

In [None]:
# Preload the multilingual text encoder once at module level; `retrieval`
# reads the global `text_encoder` when it encodes the query sentences.

from imageretrieval.retrieval import TextEmbeddings

text_encoder = TextEmbeddings()

In [None]:
# Chinese query sentences; each one is matched against the image index.
sentences = [
  '这145平米的样板间采用10多种色彩，创造了和谐、高级的空间',
  '客餐厅采用无吊顶设计，左右清晰分区',
  '墙面、顶面和门框采用黑色金属线条，营造简约现代感',
  '餐椅和抱枕的脏橘色贯穿客餐厅，避免了视觉上的割裂感',
  '选择雅琪诺悦动风华系列的窗帘，营造出温馨的氛围',
  '主卧和次卧都采用了不同的装饰元素，呈现出不同的风格，相应地丰富了整个空间'
]
# Encode the sentences once; `similarity` can then be called repeatedly with
# different filters over `df`.
similarity = retrieval(sentences, df)

In [None]:
# Rank only the images in cluster 0 of each category and plot the 8 best
# matches per sentence.
similarity(pl.col('cluster')==0, top=8, visualize=True)

# Misc

## Build and save image index

In [None]:
# fill in your local path (root folder of the image library to index)
root_path = ''

# Fail early with a clear message when the path is unset or missing.
if not os.path.exists(root_path):
  raise RuntimeError(f'Path not found: {root_path}')

# Index every image under `root_path` and write `index.pq` into that folder.
build_index(root_path, persist=True)

## List all available fonts

In [None]:
from pprint import pprint

from matplotlib.font_manager import FontManager

# Collect the distinct names of the fonts listed by matplotlib's font
# manager, e.g. to choose a value for rcParams['font.sans-serif'].
fm = FontManager()
mat_fonts = {font.name for font in fm.ttflist}

pprint(mat_fonts)

In [None]:
# Scratch check of the polars calls used by `load_index` and the
# "add images" cell: rewrite a string column with `apply`, then append the
# rewritten frame to the original with `extend`.
d1 = pl.DataFrame({
  'a': ['str1', 'str2'],
  'b': ['str3', 'str4'],
})
# Column 'a' with every value prefixed by the 'aa' directory.
d2 = d1.with_columns(d1['a'].apply(lambda x: os.path.join('aa', x)))
# Last expression of the cell, so the notebook displays its result.
d1.extend(d2)