In [None]:
import os
import torch
import torch.nn.functional as F
import numpy as np
import json
import matplotlib.pyplot as plt
import cv2
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from google.colab import drive
!pip install torchmetrics                 # NEEDED IN THE FIRST EXECUTION

# --- CONFIGURATIONS: SET PATHS, DEVICE, MODEL AND HYPERPARAMETERS ---

# GOOGLE DRIVE MOUNT POINT (PASSED TO drive.mount)
PATH_DRIVE = "/content/drive"
# DATASET ARCHIVE ON DRIVE (EXTRACTED WITH tar)
PATH_FILE = "/content/drive/MyDrive/SPair-71k.tar.gz"
# NOTE(REVIEW): NOT REFERENCED IN THE VISIBLE CODE
PATH_DB = "/content/SPair-71k"
# DIRECTORY WITH ONE JSON ANNOTATION FILE PER TEST PAIR
PATH_TEST = "/content/SPair-71k/PairAnnotation/test"
# IMAGES ARE READ FROM <IMAGE_FOLDER_NAME>/<category>/<image name>
IMAGE_FOLDER_NAME = "/content/SPair-71k/JPEGImages"
# TEST PAIR LIST, PARSED BY run_evaluation AS "<pair>:<category>" LINES
ALL_PAIRS_PATH = "/content/SPair-71k/Layout/small/test.txt"
# DINOv3 WEIGHTS (REPLACED BY THE SAM CHECKPOINT BELOW WHEN MODEL_VERSION == "sam")
PTH_PATH = "/content/drive/MyDrive/dinov3_vitb16_pretrain_lvd1689m-73cec8be.pth"

# OUTPUT DIRECTORY FOR THE PER-PAIR AND PER-CATEGORY PCK JSON FILES
PATH_RES = "/content/Results"
os.makedirs(PATH_RES, exist_ok=True)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_VERSION = "dinov2"          # MODEL IN [dinov2, dinov3, sam]

# SAM-ONLY SETUP: INSTALL segment-anything, IMPORT IT AND SWITCH TO THE SAM CHECKPOINT.
if MODEL_VERSION == "sam":
      !pip install git+https://github.com/facebookresearch/segment-anything.git             # FOR SAM
      from segment_anything import sam_model_registry, SamPredictor

      PTH_PATH = "/content/drive/MyDrive/sam_vit_b_01ec64.pth"
      # PLACEHOLDER: THE SamPredictor IS CREATED ONLY AFTER THE MODEL HAS BEEN LOADED.
      predictor = None

# INPUT RESOLUTION FOR THE DINO MODELS. THE FEATURE GRID HAS IMAGE_SIZE // PATCH_SIZE PATCHES PER SIDE
# (dinov2_vitb14_reg USES 14-PIXEL PATCHES, dinov3_vitb16 USES 16).
# THE SAM BRANCH OF get_descriptors DOES NOT USE H_PATCH / W_PATCH.
IMAGE_SIZE = 224
PATCH_SIZE = 14 if MODEL_VERSION == "dinov2" else 16
(H_PATCH, W_PATCH) = (IMAGE_SIZE // PATCH_SIZE, IMAGE_SIZE // PATCH_SIZE)
# PCK THRESHOLDS: A PREDICTION IS CORRECT IF ITS ERROR IS <= ALPHA * (LARGER SIDE OF THE TARGET BOUNDING BOX).
ALPHA = [0.05, 0.1, 0.2]
# CATEGORIES TO EVALUATE (THE OTHER CATEGORIES ARE COMMENTED OUT).
CATEGORIES = ["aeroplane"] # , "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "dog", "horse",
              # "motorbike", "person", "pottedplant", "sheep", "train", "tvmonitor"]

# --- PREPROCESSING IMAGES: RESIZE, CONVERT INTO TENSOR AND NORMALIZE ACCORDING TO IMAGENET. ---
# NOTE: THE RESIZE IS TO A SQUARE, SO THE ASPECT RATIO IS NOT PRESERVED; KEYPOINTS ARE THEREFORE RESCALED PER AXIS.

preprocess = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# --- HELPER FUNCTIONS. ---

# --- LOAD AND PREPROCESS AN IMAGE, THEN EXTRACT PATCH-WISE FEATURES WITH THE SELECTED BACKBONE (DINOv2, DINOv3 OR SAM). ---
# --- RESHAPE THE FEATURES INTO A SPATIAL GRID AND L2-NORMALIZE EACH DESCRIPTOR. ---
# --- FINALLY, RETURN THE FEATURE MAP AND THE ORIGINAL IMAGE WIDTH AND HEIGHT. ---

def get_descriptors(img_path):
    """Load an image and return its L2-normalized dense feature map.

    The backbone is selected by the global ``MODEL_VERSION``: the global
    ``model`` is used for "dinov2"/"dinov3", and the global ``predictor``
    (a SAM predictor) is used for "sam".

    Args:
        img_path: path of an image file readable by PIL.

    Returns:
        ``(feats, w, h)``. ``feats`` is a ``[1, H', W', D]`` tensor whose last
        dimension has unit L2 norm, so dot products between descriptors are
        cosine similarities. ``(w, h)`` is the ORIGINAL image size in pixels,
        which callers need to map keypoints to and from the feature grid.

    Raises:
        ValueError: if ``MODEL_VERSION`` is not "dinov2", "dinov3" or "sam".
    """
    with Image.open(img_path) as raw_img:
        img = raw_img.convert("RGB")
    (w, h) = img.size

    with torch.no_grad():

        if MODEL_VERSION == "dinov2":
            input_tensor = preprocess(img).unsqueeze(0).to(DEVICE)
            # LAST BLOCK'S TOKENS, [1, N, D]; THE RESHAPE REQUIRES N == H_PATCH * W_PATCH (PATCH TOKENS ONLY)
            feats = model.get_intermediate_layers(input_tensor, n=1)[0]
            feats = feats.reshape(1, H_PATCH, W_PATCH, feats.shape[2])   # [1, H, W, D]

        elif MODEL_VERSION == "dinov3":
            input_tensor = preprocess(img).unsqueeze(0).to(DEVICE)
            x = model.forward_features(input_tensor)["x_norm_patchtokens"]
            feats = x.reshape(1, H_PATCH, W_PATCH, x.shape[-1])          # [1, H, W, D]

        elif MODEL_VERSION == "sam":
            # SAM DOES ITS OWN PREPROCESSING FROM THE RAW RGB ARRAY
            predictor.set_image(np.array(img))
            feats = predictor.get_image_embedding()   # [1, D, H', W']
            feats = feats.permute(0, 2, 3, 1)         # [1, H', W', D]

        else:
            # WITHOUT THIS, `feats` WOULD BE UNBOUND AND FAIL WITH AN OBSCURE UnboundLocalError
            raise ValueError("Unsupported MODEL_VERSION: " + repr(MODEL_VERSION)
                             + " (expected 'dinov2', 'dinov3' or 'sam')")

    feats = F.normalize(feats, dim=-1)
    return (feats, w, h)

# --- LOAD IMAGES AND VISUALIZE SOURCE KEYPOINTS (RED), TARGET KEYPOINTS (GREEN) AND TARGET PREDICTED KEYPOINTS (BLUE). ---

def visualize_keypoints(src_path, trg_path, src_kps, pred_kps, trg_kps):
    """Plot the source keypoints next to the target ground-truth and predicted keypoints.

    Colors: source keypoints red, predicted target keypoints blue, and
    ground-truth target keypoints green ("X" markers).

    Args:
        src_path: path of the source image.
        trg_path: path of the target image.
        src_kps: source keypoints, indexed as ``[:, 0]`` = x and ``[:, 1]`` = y,
            in source-image pixels.
        pred_kps: predicted target keypoints, same layout, in target-image pixels.
        trg_kps: ground-truth target keypoints, same layout, in target-image pixels.

    Raises:
        FileNotFoundError: if OpenCV cannot read either image.
    """

    def _read_rgb(path):
        # Read an image with OpenCV and convert it from BGR to RGB.
        bgr = cv2.imread(path)
        if bgr is None:  # cv2.imread signals failure by returning None, not by raising
            raise FileNotFoundError("Could not read image: " + str(path))
        return bgr[:, :, ::-1]  # BGR -> RGB

    src_img = _read_rgb(src_path)
    trg_img = _read_rgb(trg_path)

    (_, (ax_src, ax_trg)) = plt.subplots(1, 2, figsize=(12, 6))

    # SOURCE
    ax_src.imshow(src_img)
    ax_src.scatter(src_kps[:, 0], src_kps[:, 1], c="r", s=40, label="src_kps")
    ax_src.set_title("Source Image")
    ax_src.legend()

    # TARGET: PREDICTED AND GROUND TRUTH
    ax_trg.imshow(trg_img)
    ax_trg.scatter(pred_kps[:, 0], pred_kps[:, 1], c="b", s=40, label="pred_kps")
    ax_trg.scatter(trg_kps[:, 0], trg_kps[:, 1], c="g", s=40, marker="X", label="gt_kps")
    ax_trg.set_title("Target Image")
    # One legend per axes: plt.legend() only targets the current axes, leaving the other panel unlabeled.
    ax_trg.legend()

    plt.show()

# --- EVALUATE KEYPOINT TRANSFER (PCK) FOR ONE CATEGORY. ---
# --- READ THE LIST OF TEST PAIRS AND KEEP ONLY THOSE OF THE CURRENT CATEGORY. ---
# --- FOR EACH PAIR: LOAD SOURCE AND TARGET IMAGES, EXTRACT DESCRIPTORS AND KEYPOINTS, AND RESCALE COORDINATES TO THE FEATURE MAP. ---
# --- FOR EACH SOURCE KEYPOINT: USE COSINE SIMILARITY TO FIND THE BEST-MATCHING POINT IN THE TARGET IMAGE. ---
# --- FOR EACH ALPHA: A PREDICTION IS CORRECT IF ITS DISTANCE FROM THE GROUND TRUTH IS AT MOST ALPHA * (LARGER SIDE OF THE TARGET BOUNDING BOX). ---
# --- FINALLY, RETURN THE PERCENTAGE OF CORRECT KEYPOINTS (PCK) FOR EACH ALPHA. ---
# --- OPTIONALLY: VISUALIZE RESULTS (SOURCE AND TARGET IMAGES WITH KEYPOINTS). ---

def _list_category_pairs(category):
    """Return the annotation paths (under PATH_TEST) of the pairs of `category` listed in ALL_PAIRS_PATH."""
    paths = []
    with open(ALL_PAIRS_PATH, "r") as pairs_file:
        for line in pairs_file:
            entry = line.strip()
            parts = entry.split(':')
            # Skip blank or malformed lines instead of failing with IndexError.
            if len(parts) < 2 or parts[1] != category:
                continue
            path = os.path.join(PATH_TEST, entry + ".json")        # FILTER
            if os.path.exists(path):
                paths.append(path)
    return paths


def _predict_target_keypoints(feat_src, feat_trg, src_kps, src_size, trg_size):
    """Map each source keypoint to the center of the most similar target feature cell.

    `feat_src` / `feat_trg` are [1, H, W, D] unit-norm feature maps, `src_kps` holds
    (x, y) source pixels in its first two columns, and `src_size` / `trg_size` are the
    original (width, height). Returns a (K, 2) NumPy array of (x, y) target pixels.
    """
    (sw, sh) = src_size
    (tw, th) = trg_size
    (_, Hs, Ws, _) = feat_src.shape
    (_, Hf, Wf, D) = feat_trg.shape
    trg_flat = feat_trg[0].reshape(Hf * Wf, D)

    # RESCALING SOURCE PIXELS TO THE SOURCE FEATURE GRID
    sx = np.clip((src_kps[:, 0] * (Ws / sw)).astype(int), 0, Ws - 1)
    sy = np.clip((src_kps[:, 1] * (Hs / sh)).astype(int), 0, Hs - 1)
    idx_y = torch.as_tensor(sy, dtype=torch.long, device=feat_src.device)
    idx_x = torch.as_tensor(sx, dtype=torch.long, device=feat_src.device)
    src_desc = feat_src[0, idx_y, idx_x]                  # [K, D]

    # COSINE SIMILARITY (DESCRIPTORS ARE UNIT-NORM), ALL KEYPOINTS AT ONCE
    sim = src_desc @ trg_flat.T                           # [K, Hf*Wf]
    # Move the indices to the host: NumPy cannot consume (CUDA) tensors downstream.
    best_idx = sim.argmax(dim=1).cpu().numpy()

    px = best_idx % Wf
    py = best_idx // Wf
    pred_x = (px + 0.5) * (tw / Wf)                       # CELL CENTER, BACK IN TARGET PIXELS
    pred_y = (py + 0.5) * (th / Hf)
    return np.stack([pred_x, pred_y], axis=1)


def run_evaluation(category, visualize=False):
    """Compute PCK of descriptor-based keypoint transfer for one category.

    For every test pair of `category`, each source keypoint is matched to the
    target feature cell with the highest cosine similarity. A prediction counts
    as correct for a given alpha if its Euclidean distance to the ground truth
    is <= alpha * (larger side of the target bounding box).

    Side effects: prints per-pair and per-category results, and writes one
    per-pair JSON file to PATH_RES.

    Args:
        category: category name as it appears after ':' in ALL_PAIRS_PATH.
        visualize: if True, plot each pair's keypoints.

    Returns:
        dict mapping each alpha in ALPHA to the PCK percentage (rounded to 2
        decimals), or NaN for every alpha if no keypoint was evaluated.
    """
    total_correct = {alpha: 0 for alpha in ALPHA}
    total_points = 0
    category_filenames = _list_category_pairs(category)
    pbar = tqdm(category_filenames, desc="Evaluating " + category)

    for filename in pbar:
        with open(filename, "r") as ann_file:
            annotation = json.load(ann_file)

        (src_name, trg_name) = (annotation["src_imname"], annotation["trg_imname"])
        src_path = os.path.join(IMAGE_FOLDER_NAME, category, src_name)
        trg_path = os.path.join(IMAGE_FOLDER_NAME, category, trg_name)

        src_kps = np.array(annotation["src_kps"])
        trg_kps = np.array(annotation["trg_kps"])                   # KEYPOINTS AND THRESHOLD
        if len(src_kps) == 0:
            # Nothing to score; the per-pair PCK would divide by zero.
            continue
        trg_bbox = np.array(annotation["trg_bndbox"])
        max_dim = max(trg_bbox[2] - trg_bbox[0], trg_bbox[3] - trg_bbox[1])

        (feat_src, sw, sh) = get_descriptors(src_path)
        (feat_trg, tw, th) = get_descriptors(trg_path)             # DESCRIPTORS

        pred_kps = _predict_target_keypoints(feat_src, feat_trg, src_kps, (sw, sh), (tw, th))

        # EUCLIDEAN DISTANCE TO VERIFY
        dist = np.linalg.norm(pred_kps - trg_kps[:, :2], axis=1)
        n_points = len(dist)
        total_points += n_points

        pair_pck = {}
        for alpha in ALPHA:
            n_correct = int(np.count_nonzero(dist <= alpha * max_dim))
            total_correct[alpha] += n_correct
            pair_pck[alpha] = round(100 * n_correct / n_points, 2)

        # REPORT PCK FOR THE CURRENT PAIR
        print()
        print("Results for the couple (", src_name, ",", trg_name, ")")
        print()

        for alpha in ALPHA:
            print("PCK for alpha=", alpha, " --> ", pair_pck[alpha] ,"%")

        mean_pck_image = sum(pair_pck.values()) / len(pair_pck)
        print("MEAN PCK: ", str(round(mean_pck_image,2)), "%")            # MEAN FOR IMAGE
        pair_pck["MEAN"] = mean_pck_image

        path = PATH_RES + "/pck_results_for_" + MODEL_VERSION + "_" + src_name + "_" + trg_name + ".json"
        with open(path, "w") as out_file:
            json.dump(pair_pck, out_file, indent=4)                  # SAVE FOR IMAGE

        # IMAGE VISUALIZATION
        if visualize:
            visualize_keypoints(src_path, trg_path, src_kps, pred_kps, trg_kps)

    # PCK
    print()
    print("Results for", category)
    print()

    for alpha in ALPHA:
        if total_points > 0:
            total_correct[alpha] = round(100 * total_correct[alpha] / total_points, 2)
        else:
            # No pairs / keypoints for this category: report NaN instead of dividing by zero.
            total_correct[alpha] = float("nan")
        print("PCK for alpha=", alpha, " --> ", total_correct[alpha] ,"%")

    return total_correct

# ----------------------------------------------------------------------------------------------------------------- #

# --- MOUNT DRIVE AND EXTRACT DATASET. ---

drive.mount(PATH_DRIVE, force_remount=True)
!tar -xzf {PATH_FILE}

# --- LOAD MODEL BASED ON CHOICE. ---

print()
print("Loading ", MODEL_VERSION, " model...")

if MODEL_VERSION == "dinov2":
    model = torch.hub.load("facebookresearch/dinov2", "dinov2_vitb14_reg", pretrained = True).to(DEVICE)

elif MODEL_VERSION == "dinov3":
    model = torch.hub.load("facebookresearch/dinov3", "dinov3_vitb16", weights = PTH_PATH).to(DEVICE)

elif MODEL_VERSION == "sam":
    model = sam_model_registry["vit_b"](checkpoint=PTH_PATH).to(DEVICE)
    predictor = SamPredictor(model)

model.eval()

# --- LOOP AMONG CLASSES: EVALUATE EACH CATEGORY AND SAVE ITS PCK (PER ALPHA + MEAN) AS JSON. ---

for cat in CATEGORIES:
    print("\n")
    print("Evaluating category: ", cat)
    results = run_evaluation(cat, visualize=False)
    path = os.path.join(PATH_RES, "pck_results_for_" + MODEL_VERSION + "_" + cat + ".json")

    # MEAN OVER THE ALPHA THRESHOLDS (COMPUTED BEFORE THE "MEAN" KEY IS ADDED)
    mean_pck = sum(results.values()) / len(results)
    print("="*40)
    print("MEAN PCK for ", cat, ": ", str(round(mean_pck,2)), "%")            # MEAN
    results["MEAN"] = mean_pck

    with open(path, "w") as out_file:
        json.dump(results, out_file, indent=4)                  # SAVE MEAN