## Import modules

In [1]:
import os
import torch
import pandas as pd
from PIL import Image
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_image
from random import random, randrange
from torchvision import transforms,models
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from transformers import CLIPProcessor, CLIPModel
from copy import deepcopy
import matplotlib.image as mpimg

from color_analysis import ColorAnalysis

  from .autonotebook import tqdm as notebook_tqdm


## Import fashion clip model

In [2]:
model_name = "patrickjohncyh/fashion-clip"
image_processor = CLIPProcessor.from_pretrained(model_name)
model = CLIPModel.from_pretrained(model_name)

## Creating images lists

In [3]:
CLEAN_IMAGES_CSV_PATH = os.path.join("../../data","DAM")
reference_images = [os.path.join(CLEAN_IMAGES_CSV_PATH, img)for img in os.listdir(CLEAN_IMAGES_CSV_PATH)]

TEST_IMAGES_PATH = os.path.join("../../data","test_image_headmind")
img_list = [os.path.join(TEST_IMAGES_PATH,i) for i in os.listdir(TEST_IMAGES_PATH)]

## Creating Answers Directory

In [4]:
csv_all = pd.read_csv("results17-32.csv",header=None, sep=", ")
csv_arthur = pd.read_csv("answer65-80.csv", header=None, sep=", ")
csv_all = pd.concat([pd.read_csv("answers1-16.csv", header=None, sep= ", "),csv_all, pd.read_csv("results33-48.csv", header=None, sep=", "), pd.read_csv("results49-64.csv", header=None, sep=", "), csv_arthur])
csv_all.head()
csv_all.columns = ['0','1']

  csv_all = pd.read_csv("results17-32.csv",header=None, sep=", ")
  csv_arthur = pd.read_csv("answer65-80.csv", header=None, sep=", ")
  csv_all = pd.concat([pd.read_csv("answers1-16.csv", header=None, sep= ", "),csv_all, pd.read_csv("results33-48.csv", header=None, sep=", "), pd.read_csv("results49-64.csv", header=None, sep=", "), csv_arthur])
  csv_all = pd.concat([pd.read_csv("answers1-16.csv", header=None, sep= ", "),csv_all, pd.read_csv("results33-48.csv", header=None, sep=", "), pd.read_csv("results49-64.csv", header=None, sep=", "), csv_arthur])
  csv_all = pd.concat([pd.read_csv("answers1-16.csv", header=None, sep= ", "),csv_all, pd.read_csv("results33-48.csv", header=None, sep=", "), pd.read_csv("results49-64.csv", header=None, sep=", "), csv_arthur])


In [5]:
answers = {}
for _, row in csv_all.iterrows():
    if row['1'] in answers.keys():
        answers[row['1']].append(row['0'])
    else:
        answers[row['1']] = [row['0']]

## Creating embeddings for the referrence images

In [6]:
reference_features = []
for img in tqdm(reference_images):
    image = Image.open(img).convert("RGB")
    inputs = image_processor(images=image, return_tensors="pt", padding=True)
    with torch.no_grad():
        embedding = model.get_image_features(**inputs)
        embedding = embedding / embedding.norm(p=2, dim=-1)
        reference_features.append(embedding)

100%|██████████| 2766/2766 [03:04<00:00, 14.98it/s]


In [7]:
client_features = []
for img in tqdm(img_list):
    image = Image.open(img).convert("RGB")
    inputs = image_processor(images=image, return_tensors="pt", padding=True)
    with torch.no_grad():
        embedding = model.get_image_features(**inputs)
        embedding = embedding / embedding.norm(p=2, dim=-1)
        client_features.append(embedding)

100%|██████████| 80/80 [00:14<00:00,  5.62it/s]


In [8]:
for i in range(len(reference_features)):
    reference_features[i] = reference_features[i].flatten()

for i in range(len(client_features)):
    client_features[i] = client_features[i].flatten()

In [9]:
client_embedding_tensor = torch.stack(client_features).squeeze(1)
embedding_tensor = torch.stack(reference_features).squeeze(1)

cosine_similarities = torch.mm(client_embedding_tensor, embedding_tensor.t())

# closest_indices = torch.argsort(cosine_similarities, dim=1,descending=True)[:,:11]
# closest_indices

### Color code

In [10]:
####################################################
# -- HERE: Incorporate the ColorAnalysis “buff” ----
####################################################

# 1) Load the precomputed color info for reference images from CSV
#    The CSV is assumed to be created via save_color_info_to_csv(...).
#
#    Columns:
#    [filename, L1, a1, b1, L2, a2, b2, L3, a3, b3]
#
color_database_csv = "./color_database.csv"  # Adjust if needed
df_ref = pd.read_csv(color_database_csv)

In [11]:
# Transform CSV rows into a dictionary:
#   reference_color_dict["some_image.jpg"] = np.array([[L1,a1,b1],[L2,a2,b2],[L3,a3,b3]], dtype=np.float32)
reference_color_dict = {}
for _, row in df_ref.iterrows():
    filename = row["filename"]
    colors_flat = [
        [row["L1"], row["a1"], row["b1"]],
        [row["L2"], row["a2"], row["b2"]],
        [row["L3"], row["a3"], row["b3"]]
    ]
    reference_color_dict[filename] = np.array(colors_flat, dtype=np.float32)

# Map the full path of reference_images to their Lab color arrays
reference_colors = {}
for ref_path in reference_images:
    ref_filename = os.path.basename(ref_path)  # just the file name
    if ref_filename in reference_color_dict:
        reference_colors[ref_path] = reference_color_dict[ref_filename]
    else:
        reference_colors[ref_path] = None

In [12]:
# Read client color from csv (if exists) or extract them
color_test_csv = "color_test.csv"
client_colors = {}  # Reset or define anew
color_analyzer = ColorAnalysis()

if os.path.exists(color_test_csv):
    print("Loading client color data from color_test.csv ...")
    df_test = pd.read_csv(color_test_csv)
    for _, row in df_test.iterrows():
        filename = row["filename"]
        color_sets = []
        for i in range(1, 4):  # 3 colors
            color_sets.append([row[f"L{i}"], row[f"a{i}"], row[f"b{i}"]])
        # Convert to numpy array (Lab values)
        img_lab = np.array(color_sets, dtype=np.float32)

        # Reconstruct full path if needed, or store by filename only
        img_path = os.path.join(TEST_IMAGES_PATH, filename)
        client_colors[img_path] = img_lab
else:
    print("No color_test.csv found, computing colors on the fly...")
    client_colors = {}
    for img_path in tqdm(img_list, desc="Extracting Colors for Client Images"):
        colors_lab = color_analyzer.extract_object_colors(img_path, num_colors=3)
        client_colors[img_path] = colors_lab

Loading client color data from color_test.csv ...


In [13]:
if not os.path.exists(color_test_csv):
    # AFTER you've computed client_colors, save them to a CSV named "color_test.csv"

    # Prepare rows for CSV: [filename, L1, a1, b1, L2, a2, b2, L3, a3, b3]
    rows = []
    for img_path, colors_lab in client_colors.items():
        if colors_lab is not None:
            flattened = colors_lab.flatten().tolist()
            filename = os.path.basename(img_path)
            rows.append([filename] + flattened)

    # Define columns
    columns = ["filename"]
    for i in range(1, 4):  # num_colors = 3
        columns += [f"L{i}", f"a{i}", f"b{i}"]

    # Save CSV
    df_client = pd.DataFrame(rows, columns=columns)
    df_client.to_csv("color_test.csv", index=False)
    print("Client color info saved to color_test.csv")

In [17]:
### Fine-tuning the color parameters (search to find best ones)
import itertools

best_threshold = None
best_boost = None
best_accuracy = 0.0

# Some example ranges to iterate over (customize as you wish)
threshold_range = [80]
boost_range = [0.03, 0.02] 

for color_threshold, color_boost in itertools.product(threshold_range, boost_range):
    # Make a fresh copy of the original cosines
    final_sims = cosine_similarities.clone()

    num_clients = final_sims.shape[0]
    num_refs = final_sims.shape[1]

    for i in range(num_clients):
        client_img_path = img_list[i]
        client_color_set = client_colors[client_img_path]
        if client_color_set is None:
            continue  # no color info, skip

        for j in range(num_refs):
            ref_img_path = reference_images[j]
            ref_color_set = reference_colors[ref_img_path]
            if ref_color_set is None:
                continue

            # Compute Lab distance
            dist = color_analyzer.color_distance(client_color_set, ref_color_set)
            if dist < color_threshold:
                ratio = dist / color_threshold   # 0..1
                bonus = color_boost * (1.0 - ratio)
                #bonus = color_boost
                final_sims[i, j] += bonus

    # Re-rank
    closest_indices = torch.argsort(final_sims, dim=1, descending=True)[:, :11]

    # Build guesses
    guesses = {}
    for i, img in enumerate(img_list):
        file_name = os.path.basename(img)
        img_id = file_name.split(".")[0]
        guesses[img_id] = [reference_images[idx] for idx in closest_indices[i].tolist()]

    # Evaluate
    accuracy = 0
    nb_guess = 0
    for answer_key in answers.keys():
        found = 0
        if answer_key in guesses.keys():
            nb_guess += 1
            for value_guess in guesses[answer_key]:
                for value_answ in answers[answer_key]:
                    if value_answ in value_guess and not found:
                        accuracy += 1
                        found = 1

    final_acc = accuracy / nb_guess * 100 if nb_guess > 0 else 0.0

    # Check if it's the best so far
    if final_acc > best_accuracy:
        best_accuracy = final_acc
        best_threshold = color_threshold
        best_boost = color_boost

print(f"Best threshold: {best_threshold}, Best boost: {best_boost}, Accuracy: {best_accuracy:.2f}%")


Best threshold: 80, Best boost: 0.03, Accuracy: 93.51%


In [24]:
# 3) Define your thresholds and boost factor
color_threshold = 80.0  # distance in Lab; typical from your snippet
color_boost     = 0.03  # how much to “push up” similarity if color is close

In [25]:
# 4) Combine color info with the existing CLIP similarity
#    We'll make a copy of the original cosines so we can store the final
#    “color-boosted” similarity in a new tensor.
final_sims = cosine_similarities.clone()  # shape [num_clients, num_refs]

num_clients = final_sims.shape[0]
num_refs    = final_sims.shape[1]

for i in range(num_clients):
    client_img_path = img_list[i]
    client_color_set = client_colors[client_img_path]
    if client_color_set is None:
        # No color info, skip
        continue

    for j in range(num_refs):
        ref_img_path = reference_images[j]
        ref_color_set = reference_colors[ref_img_path]
        if ref_color_set is None:
            continue

        # Compute Lab distance between sets
        dist = color_analyzer.color_distance(client_color_set, ref_color_set)
        # If below threshold, linearly scale the boost
        if dist < color_threshold:
            ratio = dist / color_threshold   # 0..1
            # Add a partial color_boost, bigger if dist is small
            bonus = color_boost * (1.0 - ratio)  
            final_sims[i, j] = final_sims[i, j] + bonus

## Code back to original

In [26]:
##############################################
# Now run the "closest indices" selection
##############################################
closest_indices = torch.argsort(final_sims, dim=1, descending=True)[:, :11]

In [27]:
guesses = {}

for i, img in enumerate(img_list):
    file_name = img.split("\\")[-1]
    img_id = file_name.split(".")[0]
    guesses[img_id] = [reference_images[ind] for ind in closest_indices[i].tolist()]

In [None]:
accuracy = 0
nb_guess = 0
for answer_key in answers.keys():
    found = 0
    if answer_key in guesses.keys():
        nb_guess += 1
        for value_guess in guesses[answer_key]:
            for value_answ in answers[answer_key]:
                if value_answ in value_guess and not found:
                    accuracy += 1
                    found = 1

accuracy/nb_guess*100

In [None]:
# Initialize accuracy counts
accuracy_top1 = 0
accuracy_top5 = 0
accuracy_top10 = 0
nb_guess = 0

for answer_key in answers.keys():
    if answer_key in guesses.keys():
        nb_guess += 1
        found_top1 = False
        found_top5 = False
        found_top10 = False

        for value_answ in answers[answer_key]:
            if any(value_answ in value_guess for value_guess in guesses[answer_key][:1]):
                accuracy_top1 += 1
                found_top1 = True
            if any(value_answ in value_guess for value_guess in guesses[answer_key][:5]):
                accuracy_top5 += 1
                found_top5 = True
            if any(value_answ in value_guess for value_guess in guesses[answer_key][:10]):
                accuracy_top10 += 1
                found_top10 = True

# Compute percentage accuracy
accuracy_top1 = accuracy_top1 / nb_guess * 100 if nb_guess > 0 else 0
accuracy_top5 = accuracy_top5 / nb_guess * 100 if nb_guess > 0 else 0
accuracy_top10 = accuracy_top10 / nb_guess * 100 if nb_guess > 0 else 0
print(accuracy_top1, accuracy_top5, accuracy_top10)

### Displaying predictions

In [29]:
img_path = "../../data/test_image_headmind/IMG_6893.jpg"

In [None]:
for i,img_path in enumerate(img_list):

    if i < 10:
        continue
    if i > 12:
        break
    file_name = img_path.split("\\")[-1]
    img_id = file_name.split(".")[0]
    guess_list = []
    for g in guesses[img_id]:
        guess_list.append(g)

    fig, axs = plt.subplots(2,6,figsize=(20,5))
    fig
    axs[0,0].imshow(np.swapaxes(mpimg.imread(img_path),0,1))
    for j in range(1,12):
        axs[int(j/6),int(j%6)].imshow(mpimg.imread(guess_list[j-1]))