This notebook will generate csv file containing simmilarity of colors between all final submissions and their potential source of inspiration. THe simmilarity will be computed as a weighetd mean of wasserstein distances of hue and saturation histograms with repsecitve weigth 0.66 and 0.33

### Load images

In [1]:
import cv2
import matplotlib.pyplot as plt
import numpy as np
import PIL
import os
import random
import pandas as pd
import re
import time

from scipy.stats import wasserstein_distance
from typing import Literal


def imshow(img):
    img = img.astype(np.uint8)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    display(PIL.Image.fromarray(img).convert("RGB"))


def imshow_on_axis(img, ax, title):
    img = img.astype(np.uint8)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    ax.imshow(img)
    ax.set_title(title)
    ax.axis("off")


def load_images_from_path(path, number):
    images = []
    pattern = rf"^{number}(?!\d)"

    for filename in os.listdir(path):
        if re.match(pattern, filename):
            img = cv2.imread(os.path.join(path, filename))
            if img is not None:
                images.append((img, filename))

    return images


def load_images(number):
    final_images = []

    final_path = f"data/final_submissions/{number}/"
    web_path = "data/web/"
    ai_path = "data/ai/"

    if not os.path.exists(final_path):
        print(f"The final submissions path '{final_path}' does not exist.")
        return None, None, None

    for filename in os.listdir(final_path):
        img = cv2.imread(os.path.join(final_path, filename))
        if img is not None:
            final_images.append((img, f"{number}_{filename}"))
    web_images = load_images_from_path(web_path, number)
    ai_images = load_images_from_path(ai_path, number)

    if not final_images:
        print(f"No images found in '{final_path}'. Please check the contents.")
        return None, None, None

    if not web_images and not ai_images:
        print(
            f"The number '{number}' does not correspond to any valid images in 'web' or 'ai' folders."
        )
        return None, None, None

    if not web_images:
        print(f"No web images found with prefix '{number}' in '{web_path}'.")
        return None, None, None
    if not ai_images:
        print(f"No AI images found with prefix '{number}' in '{ai_path}'.")
        return None, None, None

    return final_images, web_images, ai_images

### Color histograms

In [2]:
def get_hue_histogram(image, num_buckets):
    image_hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
    hist = cv2.calcHist([image_hsv], [0], None, [num_buckets], [0, 180])

    hist = hist / np.sum(hist)

    return hist


def get_saturation_histogram(image, num_buckets):
    image_hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
    hist = cv2.calcHist(
        [image_hsv], [1], None, [num_buckets], [0, 256]
    )  # Saturation ranges from 0 to 255

    # Normalize the histogram
    hist = hist / np.sum(hist)

    return hist


def get_hue_saturation_histogram(image, num_buckets):
    return (
        get_hue_histogram(image, num_buckets),
        get_saturation_histogram(image, num_buckets),
    )


def hue_to_rgb(hue):
    hsv_color = np.uint8([[[hue, 255, 255]]])
    rgb_color = cv2.cvtColor(hsv_color, cv2.COLOR_HSV2RGB)[0][0]
    return tuple(rgb_color / 255.0)


def plot_hue_saturation_histogram(
    hist_tuple, axes, row, col, bucket_size, plt_title="Hue and Saturation Histogram"
):
    hue_hist, saturation_hist = hist_tuple
    axes[row, col].clear()  # Clear the axis for fresh plotting
    axes[row, col].set_title(plt_title)
    axes[row, col].set_xlabel("Bins")
    axes[row, col].set_ylabel("Frequency")

    # Define bin widths
    hue_bin_width = 180 / bucket_size
    saturation_bin_width = 256 / bucket_size

    # Plot Hue histogram
    for i in range(bucket_size):
        avg_hue = i * hue_bin_width + hue_bin_width / 2
        color = hue_to_rgb(avg_hue)
        axes[row, col].bar(
            i - 0.2,
            hue_hist[i][0],
            color=color,
            width=0.4,
            label="Hue" if i == 0 else "",
        )

    # Plot Saturation histogram
    for i in range(bucket_size):
        avg_saturation = i * saturation_bin_width + saturation_bin_width / 2
        color = (
            avg_saturation / 255,
            avg_saturation / 255,
            avg_saturation / 255,
            0.5,
        )  # grayscale with transparency
        axes[row, col].bar(
            i + 0.2,
            saturation_hist[i][0],
            color=color,
            width=0.4,
            label="Saturation" if i == 0 else "",
        )

    # Set axis limits and grid
    axes[row, col].set_xlim([0, bucket_size - 1])
    axes[row, col].grid(axis="y", linestyle="--")

### Metrics for comparing hists

In [3]:
def get_wasserstein_distance(hist1, hist2):
    hist1_flat = hist1.flatten()
    hist2_flat = hist2.flatten()
    return wasserstein_distance(
        u_values=range(len(hist1_flat)),  # positions in the histogram
        v_values=range(len(hist2_flat)),
        u_weights=hist1_flat,  # weights for each position
        v_weights=hist2_flat,
    )


def get_norm_wasserstein_distance(hist1, hist2, bin_size=64):
    distance = get_wasserstein_distance(hist1, hist2)
    return distance / bin_size


def get_norm_hue_sat_wassertstein_distance(hist1, hist2, bin_size=64):
    distance_hue, distance_saturation = get_norm_wasserstein_distance(
        hist1[0], hist2[0], bin_size
    ), get_norm_wasserstein_distance(hist1[1], hist2[1], bin_size)

    return 1 - (
        0.66 * distance_hue + 0.33 * distance_saturation
    )  # Change into gain function

### Creating csv file

In [4]:
def add_colors_to_csv(group_id, bucket_size, results):
    final_images, web_images, ai_images = load_images(group_id)

    final_hists = [
        get_hue_saturation_histogram(image, bucket_size) for image, _ in final_images
    ]
    web_hists = [
        get_hue_saturation_histogram(image, bucket_size) for image, _ in web_images
    ]
    ai_hists = [
        get_hue_saturation_histogram(image, bucket_size) for image, _ in ai_images
    ]

    for i in range((len(final_hists))):

        dist_web = [
            [
                get_norm_hue_sat_wassertstein_distance(
                    final_hists[i], web_hists[j], bucket_size
                ),
                web_hists[j],
                web_images[j],
            ]
            for j in range(len(web_hists))
        ]
        dist_ai = [
            [
                get_norm_hue_sat_wassertstein_distance(
                    final_hists[i], ai_hists[j], bucket_size
                ),
                ai_hists[j],
                ai_images[j],
            ]
            for j in range(len(ai_hists))
        ]

        final_image = final_images[i]

        for j in range(len(dist_web)):

            similarity, hist, img_src = dist_web[j]
            img = img_src[0]
            src = img_src[1]

            row = pd.DataFrame(
                [
                    {
                        "Final_Submission": final_images[i][1],
                        "Inspiration": src,
                        "Similarity": similarity,
                    }
                ]
            )
            results = pd.concat([results, row], ignore_index=True)

        for j in range(len(dist_ai)):

            similarity, hist, img_src = dist_ai[j]
            img = img_src[0]
            src = img_src[1]

            row = pd.DataFrame(
                [
                    {
                        "Final_Submission": final_images[i][1],
                        "Inspiration": src,
                        "Similarity": similarity,
                    }
                ]
            )
            results = pd.concat([results, row], ignore_index=True)
        
    return results

The group 26 has no web images

In [5]:
columns = ["Final_Submission", "Inspiration", "Similarity"]
results_df = pd.DataFrame(columns=columns)

for i in range (1, 26):
    print(f"Calculating group {i}/27", end="\r", flush=True)
    results_df = add_colors_to_csv(i, 64, results_df)
results_df = add_colors_to_csv(27, 64, results_df)

results_df.to_csv("color_similarity.csv", index=False)
results_df

Calculating group 2/27

  results = pd.concat([results, row], ignore_index=True)


Calculating group 25/27

Unnamed: 0,Final_Submission,Inspiration,Similarity
0,1_1.png,1A_1.jpg,0.622241
1,1_1.png,1A_10.jpg,0.886676
2,1_1.png,1A_2.jpg,0.800207
3,1_1.png,1A_3.jpg,0.775083
4,1_1.png,1A_4.jpg,0.815200
...,...,...,...
8832,27_9.jpg,27_16.jpg,0.893497
8833,27_9.jpg,27_17.jpg,0.759103
8834,27_9.jpg,27_18.jpg,0.832084
8835,27_9.jpg,27_19.jpg,0.894955
