In [None]:
import cv2
import numpy as np
import os
import pickle
import matplotlib.pyplot as plt
from sklearn.metrics.pairwise import cosine_similarity
from background_removal_exp import background_remover_w2
from descriptors import preprocess_image, extract_descriptor, extract_descriptors
from image_split_v2 import split_images

# Paths
IMG_FOLDER = "../Data/Week3/qst2_w3/"
IMG_FOLDER_GT = "../Data/Week3/BBDD/"
DESC_GT_PATH = "results/descriptors_gt.pkl"
RESULTS_PATH = "results/qst2_top10_results.pkl"

K = 10  # number of best results to retrieve

# The GT cache below is keyed only by its file path. After changing the descriptor
# code or its parameters, set this to True once; otherwise the stale cached GT
# descriptors are silently compared against freshly computed query descriptors.
RECOMPUTE_GT = False

# --- 1. Load or compute GT descriptors ---
# NOTE: pickle.load can execute arbitrary code -- only load a cache this notebook wrote itself.
if os.path.exists(DESC_GT_PATH) and not RECOMPUTE_GT:
    print(f"✅ Loading GT descriptors from {DESC_GT_PATH}")
    with open(DESC_GT_PATH, "rb") as f:
        data = pickle.load(f)
    desc_gt = data["desc_gt"]
    gt_names = data["gt_names"]
else:
    print("🧠 Computing GT descriptors...")
    # NOTE(review): GT images use preprocess=False, while query crops go through
    # preprocess_image() -- presumably intentional (clean DB images); confirm.
    desc_gt, gt_names = extract_descriptors(IMG_FOLDER_GT, preprocess=False)
    # `or "."`: os.makedirs("") raises if DESC_GT_PATH has no directory component.
    os.makedirs(os.path.dirname(DESC_GT_PATH) or ".", exist_ok=True)
    with open(DESC_GT_PATH, "wb") as f:
        pickle.dump({"desc_gt": desc_gt, "gt_names": gt_names}, f)
    print(f"💾 Saved GT descriptors to {DESC_GT_PATH}")


def crop_to_mask_rectangle(image, mask):
    """Crop ``image`` to the axis-aligned bounding box of the mask's foreground.

    A pixel is foreground when its mask value is > 0. For a multi-channel mask
    of shape (H, W, C), a pixel is foreground when any channel is > 0. Only a
    rectangle is cut, so background pixels inside the box are kept.

    Args:
        image: array whose first two axes are (rows, cols); assumed to have
            the same height/width as ``mask``.
        mask: 2-D (H, W) or 3-D (H, W, C) array.

    Returns:
        A view ``image[y0:y1 + 1, x0:x1 + 1]`` of the bounding box. If the mask
        has no foreground, ``image`` itself is returned unchanged.

    Raises:
        ValueError: if ``mask`` is not 2-D or 3-D.
    """
    foreground = np.asarray(mask) > 0
    if foreground.ndim == 3:
        # Collapse channels so e.g. a BGR-valued mask is also accepted.
        foreground = foreground.any(axis=2)
    if foreground.ndim != 2:
        raise ValueError(f"mask must be 2-D or 3-D, got shape {np.shape(mask)}")

    rows = np.flatnonzero(foreground.any(axis=1))
    cols = np.flatnonzero(foreground.any(axis=0))
    if rows.size == 0:
        return image  # fallback if mask empty: nothing to crop to

    # Same box as cv2.boundingRect on the non-zero points (inclusive extents).
    return image[rows[0]:rows[-1] + 1, cols[0]:cols[-1] + 1]


# --- 2. Process query images ---
SHOW_PLOTS = True  # set False to skip the per-artwork preview figures


def retrieve_top_k(artwork, title):
    """Return the K positions in ``desc_gt`` most similar to one query artwork.

    The artwork goes through ``preprocess_image``, then
    ``background_remover_w2.remove_background_morphological_gradient``. The
    image is cropped to the bounding box of the returned mask, and a
    descriptor is extracted from the crop. GT descriptors are ranked by cosine
    similarity.

    Args:
        artwork: one query artwork image (presumably BGR, as produced from
            ``cv2.imread`` via ``split_images``).
        title: figure title used when ``SHOW_PLOTS`` is enabled.

    Returns:
        list[int]: K indices into ``desc_gt``, most similar first.
        NOTE(review): these are positions in ``desc_gt``/``gt_names``. They equal
        BBDD ids only if ``extract_descriptors`` returns GT images sorted by id;
        confirm.
    """
    artwork = preprocess_image(artwork)
    _, mask, _, _ = background_remover_w2.remove_background_morphological_gradient(artwork)
    cropped = crop_to_mask_rectangle(artwork, mask)
    desc = extract_descriptor(cropped)

    # Cosine similarity to GT descriptors
    sims = cosine_similarity([desc], desc_gt)[0]
    top_k = np.argsort(-sims)[:K].tolist()

    if SHOW_PLOTS:
        plt.imshow(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB))
        plt.title(title)
        plt.axis("off")
        plt.show()
    return top_k


image_names = sorted(f for f in os.listdir(IMG_FOLDER) if f.endswith('.jpg'))
# One entry per query image, aligned with image_names: results[i] is a list
# holding one top-K list per artwork found in image_names[i] (one or two).
results = []

for img_name in image_names:
    print(f"\nProcessing {img_name} ...")
    img = cv2.imread(os.path.join(IMG_FOLDER, img_name))

    if img is None:
        print(f"⚠️ Skipping {img_name}: could not read image.")
        # Keep a placeholder (no predictions) so results[i] still belongs to
        # image_names[i]. Dropping the entry would shift every later query's results.
        results.append([[]])
        continue

    _, splitted = split_images(img)

    if isinstance(splitted, tuple):
        # --- Two artworks ---
        left_artwork, right_artwork = splitted
        artwork_results = [
            retrieve_top_k(art, f"{side} Artwork: {img_name}")
            for side, art in zip(["Left", "Right"], [left_artwork, right_artwork])
        ]
    else:
        # --- One artwork ---
        artwork_results = [retrieve_top_k(splitted, img_name)]

    results.append(artwork_results)


# --- 3. Save top-K results ---
# `or "."`: os.makedirs("") raises FileNotFoundError when RESULTS_PATH has no directory part.
os.makedirs(os.path.dirname(RESULTS_PATH) or ".", exist_ok=True)
with open(RESULTS_PATH, "wb") as f:
    pickle.dump(results, f)

print(f"\n✅ Saved top-{K} retrieval results to {RESULTS_PATH}")


✅ Loading GT descriptors from results/descriptors_gt.pkl


FileNotFoundError: [WinError 3] El sistema no puede encontrar la ruta especificada: '../Data/Week3/qst2_w3/non_augmented/'

In [7]:
# Raw retrieval output: one entry per query, each a list of top-K lists (one per artwork).
print(f"{results}")

[[[262, 156, 48, 19, 286, 197, 88, 6, 210, 18]], [[108, 236, 36, 116, 49, 1, 280, 221, 147, 84]], [[151, 56, 284, 266, 209, 145, 202, 114, 178, 194]], [[156, 242, 161, 255, 16, 176, 132, 103, 39, 184]], [[241, 255, 184, 221, 243, 207, 194, 1, 269, 263]], [[47, 36, 232, 104, 251, 236, 35, 270, 49, 277]], [[284, 156, 97, 4, 22, 217, 17, 181, 51, 57]], [[193, 177, 217, 44, 253, 164, 97, 4, 22, 182]], [[109, 210, 32, 276, 281, 255, 144, 171, 16, 250], [221, 198, 165, 6, 280, 108, 25, 106, 267, 1]], [[58, 229, 141, 38, 67, 2, 178, 167, 192, 243]], [[109, 274, 250, 210, 16, 186, 18, 276, 161, 107]], [[53, 109, 193, 197, 70, 7, 182, 250, 254, 4]], [[55, 254, 62, 179, 148, 219, 78, 26, 121, 9]], [[280, 236, 164, 1, 250, 221, 100, 108, 36, 201]], [[259, 196, 215, 262, 116, 268, 183, 258, 235, 255]], [[124, 53, 72, 82, 129, 224, 197, 79, 3, 19]], [[144, 276, 281, 210, 255, 109, 167, 16, 164, 181]], [[81, 222, 93, 3, 251, 226, 30, 247, 14, 107]], [[176, 184, 130, 235, 132, 39, 23, 156, 94, 262]],