In [None]:
import cv2
import numpy as np
import os
import pickle
import matplotlib.pyplot as plt
from sklearn.metrics.pairwise import cosine_similarity
import background_remover_w2 as background_remover
from noise_filter import preprocess_image
from image_split import split_images
from descriptors import compute_descriptor, create_extractor
from keypoints import detect_keypoints

In [None]:
# Paths (all relative to the notebook's working directory).
# Folder whose .jpg files are processed as query images.
IMG_FOLDER = "../Data/Week4/qsd1_w4/"
# Folder with the reference (ground-truth) images that queries are matched against.
IMG_FOLDER_GT = "../Data/BBDD/"
# Pickle holding the ground-truth correspondences used for the mAP evaluation.
GT_CORRESPS_PATH = "../Data/Week4/qsd1_w4/gt_corresps.pkl"
# Pickle used as an on-disk cache for the reference descriptors.
DESC_GT_PATH = "results/descriptors_gt.pkl"

In [None]:
# Descriptor method name; it selects the extractor built below.
METHOD = "SIFT"

# Extractor object created once and reused for both the GT set and the queries.
EXTRACTOR = create_extractor(METHOD)

# When True, every descriptor is passed through _l2norm before being stored.
L2NORM = False

def _l2norm(x):
    x = np.asarray(x, dtype=np.float32)
    n = np.linalg.norm(x) + 1e-12
    return x / n

In [None]:
# --- 1. Load or compute GT descriptors ---

def build_gt_descriptors(gt_folder, extractor):
    """Compute one descriptor per ``.jpg`` image found in ``gt_folder``.

    Images are visited in sorted filename order, so ``descs[i]`` always
    belongs to ``names[i]``. The images are described exactly as loaded by
    ``cv2.imread`` (BGR, no denoising, no background removal, no explicit
    keypoints); the reference images are assumed to be clean.

    Args:
        gt_folder: Directory containing the reference images.
        extractor: Extractor object forwarded to ``compute_descriptor``.

    Returns:
        ``(descs, names)``: the list of descriptors and the matching list of
        file names. Descriptors are L2-normalised when ``L2NORM`` is set.

    Raises:
        OSError: If an image file cannot be decoded. Silently skipping it
            would shift the positions of all following descriptors, and the
            evaluation compares those positions with the ground-truth ids.
    """
    names = sorted(f for f in os.listdir(gt_folder) if f.lower().endswith('.jpg'))
    descs = []
    for name in names:
        print("Processing image ",name)
        path = os.path.join(gt_folder, name)
        img = cv2.imread(path)
        # cv2.imread reports failure by returning None instead of raising.
        if img is None:
            raise OSError(f"Could not read ground-truth image: {path}")

        d = compute_descriptor(img, None, extractor)
        if L2NORM:
            d = _l2norm(d)
        descs.append(d)
    return descs, names


# Reuse the cached GT descriptors only when they were produced with the
# current METHOD / L2NORM settings; otherwise a change of configuration would
# silently evaluate against stale descriptors.
desc_gt = None
gt_names = None

if os.path.exists(DESC_GT_PATH):
    # NOTE: pickle.load can execute arbitrary code. Only load a cache file
    # that was written by this notebook.
    with open(DESC_GT_PATH, "rb") as f:
        data = pickle.load(f)

    cache_is_current = (
        isinstance(data, dict)
        and data.get("method") == METHOD
        and data.get("l2norm") == L2NORM
    )
    if cache_is_current:
        desc_gt = data["desc_gt"]
        gt_names = data["gt_names"]
    else:
        # Also reached for caches written before the settings were stored.
        print(f"Cached descriptors in {DESC_GT_PATH} do not match the current settings; recomputing.")

if desc_gt is None:
    desc_gt, gt_names = build_gt_descriptors(IMG_FOLDER_GT, EXTRACTOR)
    cache_dir = os.path.dirname(DESC_GT_PATH)
    if cache_dir:  # os.makedirs("") raises when the path has no directory part
        os.makedirs(cache_dir, exist_ok=True)
    with open(DESC_GT_PATH, "wb") as f:
        pickle.dump(
            {
                "desc_gt": desc_gt,
                "gt_names": gt_names,
                "method": METHOD,
                "l2norm": L2NORM,
            },
            f,
        )

In [None]:
# --- 2. Load GT correspondences ---
# gt_corresps is indexed by query position during evaluation; each entry is
# either a single ground-truth id or a list of ids.
# NOTE: pickle.load can execute arbitrary code; only load files from a trusted source.
with open(GT_CORRESPS_PATH, "rb") as f:
    gt_corresps = pickle.load(f)

In [None]:
# --- 3. Process all images in the folder ---

def _describe_artwork(artwork):
    """Denoise, crop and describe one artwork region of a query image.

    Returns ``(cropped, desc)``: the denoised artwork cropped to the rectangle
    of its foreground mask, and the descriptor computed on that crop
    (L2-normalised when ``L2NORM`` is set).
    """
    clean = preprocess_image(artwork)
    # Only the mask is needed; the crop is taken from the denoised image itself.
    _, mask, _, _ = background_remover.remove_background_morphological_gradient(clean)
    cropped = background_remover.crop_to_mask_rectangle(clean, mask)

    desc = compute_descriptor(cropped, None, EXTRACTOR)
    if L2NORM:
        desc = _l2norm(desc)
    return cropped, desc


def _show_crops(crops, titles):
    """Display the cropped artworks of one query image, side by side if several."""
    n_crops = len(crops)
    if n_crops > 1:
        plt.figure(figsize=(12, 6))
    for position, (crop, title) in enumerate(zip(crops, titles), start=1):
        if n_crops > 1:
            plt.subplot(1, n_crops, position)
        plt.imshow(cv2.cvtColor(crop, cv2.COLOR_BGR2RGB))
        plt.title(title)
        plt.axis('off')
    plt.show()


image_names = sorted([f for f in os.listdir(IMG_FOLDER) if f.endswith('.jpg')])
desc_query = []

for img_name in image_names:
    print(f"Processing {img_name} ...")
    img = cv2.imread(os.path.join(IMG_FOLDER, img_name))

    if img is None:
        # NOTE(review): a skipped image adds no entry to desc_query, so every
        # later query is evaluated against the wrong gt_corresps entry
        # (compute_map_at_k pairs them by position). Confirm this never
        # happens on the evaluated dataset.
        print(f"⚠️ Skipping {img_name}: could not read image.")
        continue

    # split_images returns (flag, payload): payload holds the two artworks
    # when the flag is set, otherwise the single artwork.
    split_result = split_images(img)
    has_two_artworks, payload = split_result[0], split_result[1]

    # Truthiness instead of `is True`, so a numpy.bool_ flag also works.
    if has_two_artworks:
        left_artwork, right_artwork = payload
        artworks = [left_artwork, right_artwork]
        titles = [f"Left Artwork: {img_name}", f"Right Artwork: {img_name}"]
    else:
        artworks = [payload]
        titles = [f"{img_name}"]

    described = [_describe_artwork(artwork) for artwork in artworks]

    # One entry per query image: a list with one descriptor per artwork.
    desc_query.append([desc for _, desc in described])
    _show_crops([crop for crop, _ in described], titles)

In [None]:
# --- 4. Compute mAP@1 and mAP@5 ---
def compute_map_at_k(desc_query, desc_gt, gt_corresps, k=5):
    """Mean average precision at ``k`` of the queries against the GT set.

    Every query may contain several artworks. Each artwork descriptor is
    paired with one ground-truth id, ranked against ``desc_gt`` by cosine
    similarity, and scored with the reciprocal rank of the ground-truth id
    inside the top ``k`` (0 when it is not there). The scores of a query's
    artworks are averaged, and the result is the mean over all queries.

    Args:
        desc_query: One list of descriptors per query image.
        desc_gt: Descriptors of the reference images; the position in this
            sequence is the id compared with the ground truth.
        gt_corresps: Ground truth, indexed like ``desc_query``. Each entry is
            a single id or a list/tuple/array of ids.
        k: Number of top-ranked reference images taken into account.

    Returns:
        The mAP@k as a ``float``; ``0.0`` when there are no queries.
    """
    query_aps = []

    for i, descs in enumerate(desc_query):
        # Normalise the ground truth of this query to a plain list of ids.
        q_gt = gt_corresps[i]
        if isinstance(q_gt, np.ndarray):
            q_gt = np.atleast_1d(q_gt).tolist()
        elif isinstance(q_gt, tuple):
            q_gt = list(q_gt)
        elif not isinstance(q_gt, list):
            q_gt = [q_gt]
        descs = list(descs)

        # Nothing to pair up (no detected artwork or no ground truth): the
        # query counts as a miss instead of failing on descs[-1] / q_gt[-1].
        if not descs or not q_gt:
            query_aps.append(0.0)
            continue

        # Make lengths match by repeating the last element of the shorter list.
        n_crops = len(descs)
        n_gts = len(q_gt)
        if n_crops > n_gts:
            q_gt = q_gt + [q_gt[-1]] * (n_crops - n_gts)
        elif n_crops < n_gts:
            descs = descs + [descs[-1]] * (n_gts - n_crops)

        crop_aps = []
        for desc, gt in zip(descs, q_gt):
            sims = cosine_similarity([desc], desc_gt)[0]
            ranked_indices = np.argsort(-sims)[:k]

            # Each reference index appears at most once in the ranking, so
            # the AP for a single ground-truth id is 1 / rank of its hit.
            ap = 0.0
            for rank, idx in enumerate(ranked_indices, start=1):
                if idx == gt:
                    ap = 1.0 / rank
                    break
            crop_aps.append(ap)

        # Average over the (artwork, ground-truth) pairs of this query.
        query_aps.append(float(np.mean(crop_aps)))

    # Final mAP over all queries.
    return float(np.mean(query_aps)) if query_aps else 0.0
# Evaluate the retrieval at the two reported cut-offs (top-1 and top-5).
map1, map5 = (
    compute_map_at_k(desc_query, desc_gt, gt_corresps, k=cutoff)
    for cutoff in (1, 5)
)

print(f"\n✅ mAP@1 = {map1:.4f}")
print(f"✅ mAP@5 = {map5:.4f}")
