In [None]:
import cv2
import numpy as np
import os
import pickle
import matplotlib.pyplot as plt
from sklearn.metrics.pairwise import cosine_similarity
from background_removal_exp import background_remover_w2 as background_remover
from descriptors import preprocess_image, extract_descriptor, extract_descriptors
from image_split import split_images

# Paths
IMG_FOLDER = "../Data/Week3/qsd2_w3/"
IMG_FOLDER_GT = "../Data/Week3/BBDD/"
GT_CORRESPS_PATH = "../Data/Week3/qsd2_w3/gt_corresps.pkl"
DESC_GT_PATH = "results/descriptors_gt.pkl"

# --- 1. Load or compute GT descriptors ---
# Database descriptors depend only on IMG_FOLDER_GT, so they are cached on disk.
# NOTE: the cache is never invalidated automatically; delete DESC_GT_PATH after
# changing the descriptor implementation or the BBDD contents.
if os.path.exists(DESC_GT_PATH):
    print(f"✅ Loading GT descriptors from {DESC_GT_PATH}")
    # Only unpickle files this script wrote itself: pickle can execute arbitrary code.
    with open(DESC_GT_PATH, "rb") as f:
        data = pickle.load(f)
    desc_gt = data["desc_gt"]
    gt_names = data["gt_names"]
else:
    print("🧠 Computing GT descriptors...")
    desc_gt, gt_names = extract_descriptors(IMG_FOLDER_GT, preprocess=False)

    # dirname() is "" for a bare filename and os.makedirs("") raises, so fall back to ".".
    os.makedirs(os.path.dirname(DESC_GT_PATH) or ".", exist_ok=True)
    with open(DESC_GT_PATH, "wb") as f:
        pickle.dump({"desc_gt": desc_gt, "gt_names": gt_names}, f)
    print(f"💾 Saved GT descriptors to {DESC_GT_PATH}")

# --- 2. Load GT correspondences ---
# Presumably one entry per query image (sorted filename order) holding the BBDD
# index of each artwork in that image — TODO confirm against the dataset docs.
with open(GT_CORRESPS_PATH, "rb") as f:
    gt_corresps = pickle.load(f)


def crop_to_mask_rectangle(image, mask):
    """Crop ``image`` to the axis-aligned bounding box of the foreground in ``mask``.

    Used after background removal to drop the border around an artwork.

    Args:
        image: Array indexed as ``image[row, col, ...]`` (grayscale or multi-channel).
        mask: Array whose first two axes match ``image``. Pixels with a value > 0
            are foreground. Extra trailing axes (e.g. a 3-channel mask) are
            collapsed: a pixel is foreground if any of its channels is > 0.

    Returns:
        A view of ``image`` restricted to the bounding box, or ``image`` itself
        (unchanged) when the mask contains no foreground pixel.
    """
    foreground = np.asarray(mask) > 0
    # Reduce channel axes so the bounding box is computed on a 2-D map; this lets
    # both single- and multi-channel masks be used.
    while foreground.ndim > 2:
        foreground = foreground.any(axis=-1)

    rows = np.flatnonzero(foreground.any(axis=1))
    cols = np.flatnonzero(foreground.any(axis=0))
    if rows.size == 0:
        return image  # empty mask: nothing to crop to

    # Bounds are inclusive indices of the first/last foreground row/column.
    top, bottom = rows[0], rows[-1] + 1
    left, right = cols[0], cols[-1] + 1
    return image[top:bottom, left:right]


# --- 3. Process all images in the folder ---
# Set to True to display every cropped artwork while processing (debugging aid).
SHOW_PLOTS = False


def _describe_artwork(artwork):
    """Preprocess one artwork, strip its background and return ``(descriptor, cropped)``."""
    artwork = preprocess_image(artwork)
    _, mask, _, _ = background_remover.remove_background_morphological_gradient(artwork)
    # Crop to the mask's bounding box so background pixels do not pollute the descriptor.
    cropped = crop_to_mask_rectangle(artwork, mask)
    return extract_descriptor(cropped), cropped


def _show_artworks(crops, titles):
    """Display the given BGR crops side by side with their titles."""
    plt.figure(figsize=(6 * len(crops), 6))
    for pos, (crop, title) in enumerate(zip(crops, titles), start=1):
        plt.subplot(1, len(crops), pos)
        plt.imshow(cv2.cvtColor(crop, cv2.COLOR_BGR2RGB))
        plt.title(title)
        plt.axis('off')
    plt.show()


image_names = sorted(f for f in os.listdir(IMG_FOLDER) if f.endswith('.jpg'))
# desc_query[i] is the list of descriptors (one per detected artwork) of image_names[i].
desc_query = []

for img_name in image_names:
    print(f"Processing {img_name} ...")
    img = cv2.imread(os.path.join(IMG_FOLDER, img_name))

    if img is None:
        print(f"⚠️ Skipping {img_name}: could not read image.")
        # Keep an empty entry so desc_query[i] stays aligned with gt_corresps[i].
        desc_query.append([])
        continue

    # split_images returns a tuple when it detects several artworks, else one image.
    splitted = split_images(img)
    artworks = list(splitted) if isinstance(splitted, tuple) else [splitted]

    results = [_describe_artwork(artwork) for artwork in artworks]
    # One descriptor per artwork, in the order split_images returned them.
    desc_query.append([desc for desc, _ in results])

    if SHOW_PLOTS:
        if len(results) == 2:
            titles = [f"Left Artwork: {img_name}", f"Right Artwork: {img_name}"]
        else:
            titles = [img_name] * len(results)
        _show_artworks([crop for _, crop in results], titles)

# --- 4. Compute mAP@1 and mAP@5 ---
def compute_map_at_k(desc_query, desc_gt, gt_corresps, k=5):
    """Compute mean Average Precision at ``k`` (mAP@k) over every detected artwork.

    ``desc_query[i]`` holds one descriptor per artwork detected in query image
    ``i``. ``gt_corresps[i]`` holds the database index of each artwork in that
    image; a bare index is also accepted. Each descriptor is scored as its own
    retrieval against ``desc_gt``, ranked by cosine similarity.

    When the number of descriptors equals the number of ground-truth entries,
    descriptor ``j`` is paired with ``gt_corresps[i][j]``, so it has exactly one
    relevant database item. This assumes artworks are described in the same
    order as the ground truth lists them — TODO confirm against split_images.
    Otherwise the pairing is ambiguous, and any of the image's ground-truth
    indices counts as relevant for every descriptor.

    AP@k is the sum of precision@rank over the hits inside the top ``k``,
    divided by ``min(number of relevant items, k)``.

    Args:
        desc_query: One entry per query image, each a sequence of 1-D descriptors.
        desc_gt: Database descriptors, one row per database image.
        gt_corresps: Ground-truth database indices per query image, aligned
            positionally with ``desc_query``.
        k: Number of top-ranked database items considered.

    Returns:
        Mean AP@k over all scored descriptors, or 0.0 if there are none.
    """
    db = np.asarray(desc_gt, dtype=float)
    # Normalise database rows once; zero rows stay zero (similarity 0), matching
    # sklearn's cosine_similarity.
    db_norms = np.linalg.norm(db, axis=1, keepdims=True)
    db_unit = db / np.where(db_norms == 0, 1.0, db_norms)

    aps = []
    for i, descs in enumerate(desc_query):
        query_gt = np.atleast_1d(gt_corresps[i]).ravel().tolist()

        if len(descs) == len(query_gt):
            relevant_sets = [{gt_idx} for gt_idx in query_gt]
        else:
            relevant_sets = [set(query_gt)] * len(descs)

        for desc, relevant in zip(descs, relevant_sets):
            query = np.asarray(desc, dtype=float).ravel()
            query_norm = np.linalg.norm(query)
            sims = db_unit @ (query / query_norm if query_norm > 0 else query)
            ranked_indices = np.argsort(-sims)[:k]  # descending similarity

            num_correct = 0
            precision_sum = 0.0
            for rank, idx in enumerate(ranked_indices, start=1):
                if int(idx) in relevant:
                    num_correct += 1
                    precision_sum += num_correct / rank

            denom = min(len(relevant), k)
            aps.append(precision_sum / denom if denom > 0 else 0.0)

    return float(np.mean(aps)) if aps else 0.0


# Evaluate at both cut-offs (k=1 first, then k=5) and report them.
map_scores = {
    cutoff: compute_map_at_k(desc_query, desc_gt, gt_corresps, k=cutoff)
    for cutoff in (1, 5)
}
map1, map5 = map_scores[1], map_scores[5]

print(f"\n✅ mAP@1 = {map1:.4f}")
print(f"✅ mAP@5 = {map5:.4f}")


✅ Loading GT descriptors from results/descriptors_gt.pkl
Processing 00000.jpg ...
Processing 00001.jpg ...
Processing 00002.jpg ...
Processing 00003.jpg ...
Processing 00004.jpg ...
Processing 00005.jpg ...
Processing 00006.jpg ...
Processing 00007.jpg ...
Processing 00008.jpg ...
Processing 00009.jpg ...
Processing 00010.jpg ...
Processing 00011.jpg ...
Processing 00012.jpg ...
Processing 00013.jpg ...
Processing 00014.jpg ...
Processing 00015.jpg ...
Processing 00016.jpg ...
Processing 00017.jpg ...
Processing 00018.jpg ...
Processing 00019.jpg ...
Processing 00020.jpg ...
Processing 00021.jpg ...
Processing 00022.jpg ...
Processing 00023.jpg ...
Processing 00024.jpg ...
Processing 00025.jpg ...
Processing 00026.jpg ...
Processing 00027.jpg ...
Processing 00028.jpg ...
Processing 00029.jpg ...

✅ mAP@1 = 0.2838
✅ mAP@5 = 0.3218
