In [None]:
import os
import torch
import torch.nn.functional as F
import numpy as np
import json
import matplotlib.pyplot as plt
import cv2
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from google.colab import drive
# NOTE(review): torchmetrics is not imported anywhere in this notebook's visible
# code — confirm it is still needed before keeping this install.
!pip install torchmetrics                # ONLY FOR FIRST EXECUTION

# --- CONFIGURATIONS: SET PATHS, DEVICE, MODEL AND HYPERPARAMETERS ---

# Google Drive mount point, fine-tuned checkpoint FILE (saved/loaded as a single
# .pth file) and the SPair-71k archive that gets extracted after mounting.
PATH_DRIVE = "/content/drive"
PATH_EXPORT = "/content/drive/MyDrive/best_model.pth"
PATH_FILE = "/content/drive/MyDrive/SPair-71k.tar.gz"

# Directories holding one JSON annotation file per image pair, for each split.
PATH_TEST = "/content/SPair-71k/PairAnnotation/test"
PATH_VAL = "/content/SPair-71k/PairAnnotation/val"
PATH_TRAIN = "/content/SPair-71k/PairAnnotation/trn"

# Split list files: one "<pair_id>:<category>" entry per line (this is the
# format SPair71kDataset parses).
ALL_TEST_PATH = "/content/SPair-71k/Layout/small/test.txt"
ALL_TRAIN_PATH = "/content/SPair-71k/Layout/small/trn.txt"
ALL_VAL_PATH = "/content/SPair-71k/Layout/small/val.txt"

# Image root; images are looked up under IMAGE_FOLDER_NAME/<category>/<name>.
IMAGE_FOLDER_NAME = "/content/SPair-71k/JPEGImages"
# Backbone checkpoint: used for DINOv3 below, replaced by the SAM checkpoint
# in the SAM section.
PTH_PATH = "/content/drive/MyDrive/dinov3_vitb16_pretrain_lvd1689m-73cec8be.pth"

# Output root for the per-pair / per-category PCK JSON files.
PATH_RES = "/content/Results"
os.makedirs(PATH_RES, exist_ok=True)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_VERSION = "dinov2"         # BACKBONE: "dinov2", "dinov3" OR "sam"
# True: FINE-TUNE (Train) AND PREDICT WITH A FULL-GRID SOFTMAX; False: INFERENCE ONLY
TUNING = False                     # TRAINING OR INFERENCE

# AT INFERENCE: True -> WINDOWED SOFT-ARGMAX, False -> PLAIN ARGMAX
USE_WIN = True                # SOFTMAX OR ARGMAX
# SIDE OF THE SOFT-ARGMAX WINDOW, IN FEATURE-GRID CELLS
WINDOW_SOFTMAX = 7
# SOFTMAX TEMPERATURE (ALSO USED BY THE FULL-GRID SOFTMAX DURING TUNING)
TAU_SOFTMAX = 0.01       # POINT P3 (SIZE: 3,5,7; TAU: 0.01, 0.05, 0.07)

# Input resolution and patch size per backbone; H_PATCH x W_PATCH is the
# resulting feature grid (e.g. 224 // 14 = 16 for DINOv2, 1024 // 16 = 64 for SAM).
IMAGE_SIZE = 1024 if MODEL_VERSION == "sam" else 224
PATCH_SIZE = 14 if MODEL_VERSION == "dinov2" else 16
H_PATCH, W_PATCH = IMAGE_SIZE // PATCH_SIZE, IMAGE_SIZE // PATCH_SIZE
# PCK thresholds, as a fraction of the larger side of the target bounding box.
ALPHA = [0.05, 0.1, 0.2]
# Categories to evaluate (uncomment the rest to run the full benchmark).
CATEGORIES = [ "aeroplane" ] # , "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "dog", "horse",
                #  "motorbike", "person", "pottedplant", "sheep", "train", "tvmonitor" ]

# --- IF SAM MODEL ---

if MODEL_VERSION == "sam":
    # INSTALL AND IMPORT THE SEGMENT-ANYTHING PACKAGE ONLY WHEN SAM IS SELECTED
    !pip install git+https://github.com/facebookresearch/segment-anything.git
    from segment_anything import sam_model_registry, SamPredictor                 # DOWNLOAD

    PTH_PATH = "/content/drive/MyDrive/sam_vit_b_01ec64.pth"
    # PLACEHOLDER; THE REAL SamPredictor IS CREATED AFTER THE MODEL IS LOADED
    predictor = None

# --- DATASET CLASS ---
# --- COLLECT THE PAIR-ANNOTATION (JSON) NAMES OF A SPLIT, OPTIONALLY FILTERED BY CATEGORY ---
# --- ON ACCESS, LOAD THAT PAIR'S JSON FILE AND RETURN THE NEEDED FIELDS AS A DICTIONARY ---

class SPair71kDataset(Dataset):
    """SPair-71k keypoint-pair dataset backed by a split list file.

    Each non-empty line of ``pair_path`` names one pair annotation as
    ``<pair_id>:<category>`` (an optional trailing ``.json`` is ignored). The
    annotation itself is loaded lazily from ``<source_path>/<line>.json``.

    Args:
        pair_path: split list file (e.g. ``Layout/small/test.txt``).
        source_path: directory holding the per-pair JSON annotations.
        category_filter: keep only pairs of this category; ``None`` keeps every
            pair of the split (used for training / validation).
    """

    def __init__(self, pair_path, source_path, category_filter=None):
        self.pair_files = []          # list of (pair_name, category)
        self.image_path = source_path

        with open(pair_path, "r") as file:
            for line in file:
                pair_name = line.strip().split(".json")[0]
                if not pair_name:
                    continue            # tolerate blank / trailing lines

                (_, category) = pair_name.split(":")
                # Filter only on the category: keying this on the training mode
                # made per-category evaluation read every category after tuning.
                if category_filter is None or category == category_filter:
                    self.pair_files.append((pair_name, category))

    def __len__(self):
        return len(self.pair_files)

    def __getitem__(self, idx):
        """Load pair ``idx`` and return its image paths, keypoints and target bbox.

        Values are NumPy arrays / strings; a DataLoader with ``batch_size=1``
        wraps each of them in a batch of size 1.
        """
        (pair_name, category) = self.pair_files[idx]
        file_name = pair_name.strip() + ".json"
        json_path = os.path.join(self.image_path, file_name)

        with open(json_path, "r") as file:
            annotation = json.load(file)

        image_dir = os.path.join(IMAGE_FOLDER_NAME, category)

        return {
            "src_path": os.path.join(image_dir, annotation["src_imname"]),
            "trg_path": os.path.join(image_dir, annotation["trg_imname"]),
            "src_kps": np.array(annotation["src_kps"]),
            "trg_kps": np.array(annotation["trg_kps"]),
            "kps_ids": np.array([int(el) for el in annotation["kps_ids"]]),
            "trg_bndbox": np.array(annotation["trg_bndbox"]),
            "file_name": file_name,
        }

# --- IMAGE PREPROCESSING ---
# --- RESIZE IMAGE TO STANDARD MODEL DIMENSIONS, CONVERT IT INTO A TENSOR AND NORMALIZE IT ---

# Shared preprocessing for every backbone: resize to a fixed
# IMAGE_SIZE x IMAGE_SIZE square (the aspect ratio is NOT preserved, so feature
# grid cells map linearly onto original pixel coordinates along each axis),
# convert to a float CHW tensor, then normalize with the ImageNet RGB mean/std.
preprocess = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# --- HELPER FUNCTIONS ---

# --- UNFREEZE ONLY THE LAST num_last_blocks BLOCKS AND THE FINAL LAYER NORM ---
# --- FIRST FREEZE ALL PARAMETERS, THEN UNFREEZE THE LAST BLOCKS (AND THE NORM, IF ACCESSIBLE) ---

def setup_light_finetuning(model, num_last_blocks, *, model_version=None):
    """Freeze the whole model, then unfreeze its last blocks and final norm.

    Args:
        model: backbone modified in place. For ``"sam"`` the blocks are read
            from ``model.image_encoder.blocks`` and the norm from
            ``model.image_encoder.post_norm`` (only if present); otherwise from
            ``model.blocks`` and ``model.norm``.
        num_last_blocks: how many trailing blocks become trainable. ``0`` (or a
            negative value) keeps every block frozen and only frees the norm.
        model_version: backbone family; defaults to the global ``MODEL_VERSION``.

    Returns:
        The same ``model`` object, for call chaining.
    """
    if model_version is None:
        model_version = MODEL_VERSION

    if model_version == "sam":
        encoder = model.image_encoder
        all_blocks = encoder.blocks
        norm = getattr(encoder, "post_norm", None)   # NOT SURE THE FINAL NORM IS ACCESSIBLE
    else:
        all_blocks = model.blocks
        norm = model.norm

    # blocks[-0:] would select EVERY block, so 0 must be handled explicitly.
    blocks_to_unfreeze = list(all_blocks)[-num_last_blocks:] if num_last_blocks > 0 else []

    n_params = 0
    for param in model.parameters():                 # FREEZE ALL
        n_params += param.numel()
        param.requires_grad = False

    modules_to_unfreeze = list(blocks_to_unfreeze)
    if norm is not None:                             # nn.Module truthiness is not a presence test
        modules_to_unfreeze.append(norm)

    n_free_params = 0
    for module in modules_to_unfreeze:               # UNFREEZE BLOCKS + NORM
        for param in module.parameters():
            n_free_params += param.numel()
            param.requires_grad = True

    share = round(100 * n_free_params / n_params, 2) if n_params else 0.0

    print("Total parameters:", n_params)
    print("Total trainable:", n_free_params)
    print("Percentage trainable:", share, "%")
    return model

# --- LOAD AND PREPROCESS IMAGE. EXTRACT PATCH-WISE FEATURES FROM MODEL. ---
# --- RESHAPE AND NORMALIZE FEATURE MAPS. ---
# --- FINALLY, RETURN THE EXTRACTED FEATURE MAP AND THE ORIGINAL DIMENSIONS. --

def get_descriptors(img_path, grad):
    """Extract unit-norm patch descriptors for one image.

    The image is loaded as RGB and passed through the backbone selected by the
    global ``MODEL_VERSION`` (using the global ``model`` / ``predictor``).

    Args:
        img_path: path of the image file.
        grad: when True autograd records the forward pass (fine-tuning);
            when False features are computed without gradients.

    Returns:
        ``(feats, w, h)``: ``feats`` is a channels-last grid of shape
        ``(1, H_PATCH, W_PATCH, D)``, L2-normalized along the last axis, that
        covers the whole image; ``(w, h)`` is the original image size in pixels.

    Raises:
        ValueError: if ``MODEL_VERSION`` is not "dinov2", "dinov3" or "sam".
    """
    with Image.open(img_path) as raw:            # close the file handle promptly
        img = raw.convert("RGB")
    (w, h) = img.size
    input_tensor = preprocess(img).unsqueeze(0).to(DEVICE)

    with torch.set_grad_enabled(grad):             # ENABLE WEIGHT UPDATING OR NOT

        if MODEL_VERSION == "dinov2":
            # Assumes get_intermediate_layers returns patch tokens only
            # (N = H_PATCH * W_PATCH); the reshape relies on it.
            feats = model.get_intermediate_layers(input_tensor, n=1)[0]
            feats = feats.reshape(1, H_PATCH, W_PATCH, feats.shape[2])

        elif MODEL_VERSION == "dinov3":
            x = model.forward_features(input_tensor)["x_norm_patchtokens"]
            feats = x.reshape(1, H_PATCH, W_PATCH, x.shape[-1])

        elif MODEL_VERSION == "sam":

            if TUNING:
                # SAM's image encoder returns a channels-first map (B, C, H', W')
                # with no CLS token: move channels last instead of slicing tokens.
                feats = model.image_encoder(input_tensor).permute(0, 2, 3, 1)
            else:
                predictor.set_image(np.array(img))
                emb = predictor.get_image_embedding()          # (1, C, H', W')
                # SamPredictor resizes the longest side and zero-pads the rest, so
                # only the top-left part of the embedding covers the image. Crop
                # that part and stretch it back to the full grid, so grid
                # coordinates map onto pixels exactly as in the other branches.
                (in_h, in_w) = predictor.input_size
                valid_h = -(-in_h // PATCH_SIZE)               # ceil division
                valid_w = -(-in_w // PATCH_SIZE)
                emb = emb[:, :, :valid_h, :valid_w]
                emb = F.interpolate(emb, size=(H_PATCH, W_PATCH), mode="bilinear", align_corners=False)
                feats = emb.permute(0, 2, 3, 1)

        else:
            raise ValueError("Unknown MODEL_VERSION: " + repr(MODEL_VERSION))

    feats = F.normalize(feats, dim=-1)
    return (feats, w, h)

# --- MATCH KEYPOINTS. ---
# --- EXTRACT FEATURE SIZE, CREATE PROBABILITY GRID FOR SOFTMAX ---
# --- FOR EACH SRC_KPS RESCALE IT, USE COSINE SIMILARITY METRIC AND COMPUTE THE PREDICTION ---
# --- WITH SIMPLE ARGMAX (STEP 1), SIMPLE SOFTMAX (STEP 2) OR WINDOW SOFTMAX (STEP 3) ---
# --- IN THE LAST CASE, COMPUTE THE RESULT WITH DIFFERENT WINDOW SIZE AND TAU ---

def match_keypoints(batch, grad):
    """Transfer every source keypoint of a batch-of-1 pair onto the target image.

    For each source keypoint the descriptor of its source grid cell is compared
    (dot product of unit-norm features, i.e. cosine similarity) with every
    target cell, and a target location is chosen with:

    * a soft-argmax over the whole grid when ``TUNING`` (differentiable);
    * a windowed soft-argmax around the peak when ``USE_WIN``;
    * a plain argmax otherwise.

    Args:
        batch: DataLoader batch (size 1) from ``SPair71kDataset``.
        grad: forwarded to ``get_descriptors``.

    Returns:
        Tensor of shape ``(N, 2)`` on ``DEVICE`` with predicted (x, y) target
        coordinates in original-image pixels (``(0, 2)`` when N == 0).
    """
    src_kps = batch["src_kps"][0].numpy()
    if src_kps.shape[0] == 0:
        return torch.zeros((0, 2), device=DEVICE)

    (feat_src, sw, sh) = get_descriptors(batch["src_path"][0], grad)          # GET DESCRIPTORS
    (feat_trg, tw, th) = get_descriptors(batch["trg_path"][0], grad)

    # Source keypoints must be mapped with the SOURCE grid size.
    (_, Hs, Ws, _) = feat_src.shape
    (_, Hf, Wf, D) = feat_trg.shape
    trg_flat = feat_trg[0].reshape(Hf * Wf, D)          # row-major: k = y * Wf + x

    if TUNING:
        # (x, y) of every row of trg_flat, in the SAME row-major order. Building
        # the grid from (Wf, Hf) with 'ij' indexing would transpose x and y.
        (ys, xs) = torch.meshgrid(torch.arange(Hf, device=DEVICE),
                                  torch.arange(Wf, device=DEVICE), indexing="ij")
        coords = torch.stack([xs.flatten(), ys.flatten()], dim=1).float()

    pred_kps = []

    for (kx, ky) in src_kps[:, :2]:
        sx = min(max(int(kx * Ws / sw), 0), Ws - 1)
        sy = min(max(int(ky * Hs / sh), 0), Hs - 1)
        src_desc = feat_src[0, sy, sx, :]

        sim = trg_flat @ src_desc            # COSINE SIMILARITY

        if TUNING:                                              # FULL-GRID SOFTMAX
            prob = F.softmax(sim / TAU_SOFTMAX, dim=0)
            pred_xy = (coords * prob[:, None]).sum(dim=0)
        elif USE_WIN:                                           # WINDOW SOFTMAX
            pred_xy = window_soft_argmax(sim.reshape(Hf, Wf))
        else:                                                   # PLAIN ARGMAX
            best_idx = sim.argmax()
            pred_xy = torch.stack([best_idx % Wf, best_idx // Wf]).float()

        # Grid cell (+0.5 = cell centre) back to target-image pixels.
        pred_x = (pred_xy[0] + 0.5) * (tw / Wf)
        pred_y = (pred_xy[1] + 0.5) * (th / Hf)
        pred_kps.append(torch.stack([pred_x, pred_y]))

    return torch.stack(pred_kps)

# --- WINDOWED SOFT-ARGMAX: REFINE THE ARGMAX PEAK INTO A SUB-PATCH PREDICTION ---

def window_soft_argmax(similarity_map, window_size=None, tau=None):
    """Soft-argmax restricted to a window centred on the similarity peak.

    The hard argmax of ``similarity_map`` selects the centre; a softmax with
    temperature ``tau`` over the window (clipped at the map borders) then
    weights the cell coordinates to give a sub-cell estimate.

    Args:
        similarity_map: 2-D tensor ``(H, W)`` of similarity scores.
        window_size: window side in cells; defaults to ``WINDOW_SOFTMAX``.
            An even value yields a ``window_size + 1`` window (centre ± size//2).
        tau: softmax temperature; defaults to ``TAU_SOFTMAX``.

    Returns:
        Float tensor ``[x, y]`` in grid-cell units, on the map's device.
    """
    if window_size is None:
        window_size = WINDOW_SOFTMAX
    if tau is None:
        tau = TAU_SOFTMAX

    (H, W) = similarity_map.shape

    # Plain Python ints so max/min/arange/slicing don't juggle 0-d tensors.
    (y_peak, x_peak) = divmod(int(similarity_map.argmax()), W)

    half = window_size // 2
    (y0, y1) = (max(y_peak - half, 0), min(y_peak + half + 1, H))
    (x0, x1) = (max(x_peak - half, 0), min(x_peak + half + 1, W))

    window = similarity_map[y0:y1, x0:x1]

    device = similarity_map.device
    (ys, xs) = torch.meshgrid(torch.arange(y0, y1, device=device),
                              torch.arange(x0, x1, device=device),
                              indexing="ij")
    coords = torch.stack([xs.flatten(), ys.flatten()], dim=1).float()

    prob = F.softmax(window.flatten() / tau, dim=0)
    return (coords * prob[:, None]).sum(dim=0)


# --- VISUALIZE RESULTS AND COMPARE CORRECT AND PREDICTED KEYPOINTS ON TARGET IMAGE. ---

def visualize_keypoints(src_path, trg_path, src_kps, pred_kps, trg_kps):
    """Show source keypoints next to predicted vs. ground-truth target keypoints.

    Args:
        src_path, trg_path: image paths, read with OpenCV.
        src_kps: source keypoints, array-like with pixel x in column 0 and y in column 1.
        pred_kps, trg_kps: predicted / ground-truth target keypoints, same layout
            (must be CPU tensors or NumPy arrays).

    Raises:
        FileNotFoundError: if OpenCV cannot read either image.
    """
    images = []
    for path in (src_path, trg_path):
        bgr = cv2.imread(path)
        if bgr is None:          # cv2.imread reports failure by returning None
            raise FileNotFoundError("Could not read image: " + str(path))
        images.append(bgr[:, :, ::-1])   # BGR -> RGB for matplotlib
    (src_img, trg_img) = images

    src_kps = np.asarray(src_kps)
    pred_kps = np.asarray(pred_kps)
    trg_kps = np.asarray(trg_kps)

    (_, axes) = plt.subplots(1, 2, figsize=(12, 6))
    axes[0].imshow(src_img)
    axes[0].scatter(src_kps[:, 0], src_kps[:, 1], c="r", s=40, label="src_kps")
    axes[0].set_title("Source Image")
    axes[0].legend()             # plt.legend() alone would only label the last axes

    axes[1].imshow(trg_img)
    axes[1].scatter(pred_kps[:, 0], pred_kps[:, 1], c="b", s=40, label="pred_kps")
    axes[1].scatter(trg_kps[:, 0], trg_kps[:, 1], c="g", s=40, marker="X", label="gt_kps")
    axes[1].set_title("Target Image")
    axes[1].legend()

    plt.show()

# --- COMPUTE THE MEAN LOSS ON THE VALIDATION PAIRS ---
# --- RUN THE MODEL ON EVERY VALIDATION PAIR (NO GRADIENTS) TO MONITOR THE EFFECT OF FINE-TUNING ---

def compute_evaluation_loss():
    """Return the mean SmoothL1 keypoint loss over all validation pairs.

    Each validation pair is run through ``match_keypoints`` without gradients,
    and the per-pair loss between predicted and ground-truth target keypoints
    (pixel coordinates) is averaged.

    Raises:
        RuntimeError: if the validation split yields no pairs.
    """
    criterion = torch.nn.SmoothL1Loss()
    total_loss = 0.0
    total_samples = 0

    dataset = SPair71kDataset(ALL_VAL_PATH, PATH_VAL)
    loader = DataLoader(dataset, batch_size=1, shuffle=False)                # ANALYZE ALL VAL IMAGES
    bar = tqdm(loader, desc="Computing evaluation loss")

    for batch in bar:
        # Already a tensor after collation: move/cast it instead of re-wrapping.
        trg_kps = batch["trg_kps"][0].to(DEVICE, dtype=torch.float32)
        pred_kps = match_keypoints(batch, grad=False)                  # NO GRADIENTS

        total_loss += criterion(pred_kps, trg_kps).item()
        total_samples += 1

    if total_samples == 0:
        raise RuntimeError("No validation pairs found in " + ALL_VAL_PATH)

    mean_loss = total_loss / total_samples
    print("Mean evaluation loss: ", str(mean_loss))               # MEAN VALUE

    return mean_loss

# --- TRAINING FUNCTION. ---
# --- BUILD THE TRAINING DATA, LOSS AND OPTIMIZER, THEN UNFREEZE THE LAST LAYERS ---
# --- FOR EACH EPOCH, ITERATE OVER ALL TRAINING PAIRS AND UPDATE THE WEIGHTS FROM THE LOSS. ---
# --- THEN COMPUTE THE LOSS ON THE VALIDATION PAIRS AND SAVE THE MODEL IF IT IS THE BEST SO FAR. ---

def Train(model, number_of_Epochs, learning_rate, layers):
    """Lightly fine-tune ``model`` and keep the best checkpoint at ``PATH_EXPORT``.

    Only the last ``layers`` blocks (plus the final norm) are trained, with Adam
    and a SmoothL1 loss on predicted vs. ground-truth target keypoints. After
    every epoch the validation loss is computed; whenever it improves, the
    model's ``state_dict`` is written to ``PATH_EXPORT`` (a file path).

    NOTE(review): predictions go through get_descriptors, which reads the global
    ``model``; ``model`` passed here must be that same object.

    Args:
        model: backbone to fine-tune (modified in place).
        number_of_Epochs: number of passes over the training split.
        learning_rate: Adam learning rate.
        layers: number of trailing blocks to unfreeze.
    """
    best_loss = float("inf")
    criterion = torch.nn.SmoothL1Loss()

    dataset = SPair71kDataset(ALL_TRAIN_PATH, PATH_TRAIN, None)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=True)           # INITIALIZATION

    print()
    print("Trying with", layers, "free layers")
    model = setup_light_finetuning(model, layers)

    params = [p for p in model.parameters() if p.requires_grad]        # TRAINABLE PARAMETERS
    optimizer = torch.optim.Adam(params, lr=learning_rate)

    export_dir = os.path.dirname(PATH_EXPORT)
    if export_dir:
        os.makedirs(export_dir, exist_ok=True)

    for epoch in range(number_of_Epochs):
        model.train()
        print()
        print("Starting epoch", epoch)
        pbar = tqdm(dataloader, desc="Training")

        for batch in pbar:
            trg_kps = batch["trg_kps"][0].to(DEVICE, dtype=torch.float32)

            pred_kps = match_keypoints(batch, grad=True)     # PREDICTION WITH FULL-GRID SOFTMAX
            loss = criterion(pred_kps, trg_kps)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            pbar.set_postfix(loss=loss.item())

        # EVALUATION

        model.eval()
        val_loss = compute_evaluation_loss()

        if val_loss < best_loss:
            # PATH_EXPORT already names the checkpoint FILE; joining another
            # "best_model.pth" onto it pointed inside a non-existent directory.
            torch.save(model.state_dict(), PATH_EXPORT)
            best_loss = val_loss

    print("Training finished")

# --- EVALUATION FUNCTION ---
# --- CONSIDER ONLY ONE CATEGORY. ---
# --- FOR EACH PAIR: LOAD SOURCE AND TARGET IMAGES ---
# --- EXTRACT DESCRIPTORS AND RESCALE COORDINATES ---
# --- USE COSINE SIMILARITY IN ORDER TO FIND THE CORRESPONDING POINT IN THE TARGET IMAGE. ---
# --- FINALLY, COMPUTE PCK (SHARE OF PREDICTED KEYPOINTS WITHIN alpha * LARGER BBOX SIDE OF THE GROUND TRUTH), PRINT IT AND SAVE IT AS JSON. ---
# --- OPTIONALLY: VISUALIZE RESULTS. ---

def _save_json(path, obj):
    """Write ``obj`` to ``path`` as indented JSON."""
    with open(path, "w") as file:
        json.dump(obj, file, indent=4)


def _pck_summary(correct, points):
    """Turn per-alpha hit counts into rounded percentages plus their "MEAN".

    Returns 0.0 percentages when ``points`` is 0 instead of dividing by zero.
    """
    summary = {alpha: (round(100 * hits / points, 2) if points else 0.0)
               for (alpha, hits) in correct.items()}
    summary["MEAN"] = round(sum(summary.values()) / len(summary), 2) if summary else 0.0
    return summary


def run_evaluation(category, result_path, visualize=False):
    """Evaluate keypoint transfer (PCK) on the test pairs of one category.

    A keypoint is correct at threshold alpha when its prediction lies within
    ``alpha * max(bbox width, bbox height)`` pixels of the ground truth, using
    the target bounding box. Results are printed and written to ``result_path``:

    * ``pck_percentages_for_<pair>.json``: per-pair PCK for every alpha + mean;
    * ``keypoints_percentages_for_<category>.json``: per keypoint id, a list of
      ``[alpha, correct]`` entries;
    * ``pck_mean_percentages_for_<category>.json``: category PCK + mean.

    Args:
        category: SPair-71k category to evaluate.
        result_path: existing directory where the JSON files are written.
        visualize: if True, plot every pair.
    """
    results_keypoints = {}
    total_correct = {alpha: 0 for alpha in ALPHA}
    total_points = 0

    dataset_cat = SPair71kDataset(ALL_TEST_PATH, PATH_TEST, category)
    loader = DataLoader(dataset_cat, batch_size=1, shuffle=False)         # EVALUATE THIS CATEGORY OF IMAGES
    pbar = tqdm(loader, desc="Evaluating " + category)

    for batch in pbar:
        trg_kps = batch["trg_kps"][0].numpy()
        kps_ids = batch["kps_ids"][0].numpy()
        # Predictions live on DEVICE: a CUDA tensor must be moved to CPU before .numpy().
        pred_kps = match_keypoints(batch, False).detach().cpu().numpy()

        bbox = batch["trg_bndbox"][0]
        max_dim = float(max(bbox[2] - bbox[0], bbox[3] - bbox[1]))
        total_correct_image = {alpha: 0 for alpha in ALPHA}
        total_points_image = 0                                      # FOR THE CURRENT IMAGE

        for (i, key) in enumerate(kps_ids):
            key_results = results_keypoints.setdefault(str(key), [])
            dist = np.linalg.norm(pred_kps[i] - trg_kps[i])            # DISTANCE METRIC
            total_points += 1
            total_points_image += 1

            for alpha in ALPHA:
                correct = bool(dist <= alpha * max_dim)                # PREDICTION IS CORRECT?
                key_results.append((alpha, correct))
                if correct:
                    total_correct[alpha] += 1
                    total_correct_image[alpha] += 1

        # PRINT PER IMAGE

        print()

        if USE_WIN:
            print("Results for file " + batch["file_name"][0] + " using softmax")
            print("Window size:", WINDOW_SOFTMAX)
            print("Tau:", TAU_SOFTMAX)
        else:
            print("Results for file " + batch["file_name"][0] + " using argmax")

        print()

        image_pck = _pck_summary(total_correct_image, total_points_image)
        for alpha in ALPHA:
            print("PCK@" + str(alpha) + ": " + str(image_pck[alpha]) + "%")
        print("MEAN PCK:", image_pck["MEAN"], "%")

        _save_json(os.path.join(result_path, "pck_percentages_for_" + batch["file_name"][0]), image_pck)

        if visualize:
            visualize_keypoints(batch["src_path"][0], batch["trg_path"][0], batch["src_kps"][0], pred_kps, trg_kps)

    # KEYPOINTS

    _save_json(os.path.join(result_path, "keypoints_percentages_for_" + category + ".json"), results_keypoints)

    # CATEGORY PCK

    print()
    print("=" * 40)
    print("Category:", category)
    print()

    category_pck = _pck_summary(total_correct, total_points)
    for alpha in ALPHA:
        print("PCK@" + str(alpha) + ": " + str(category_pck[alpha]), "%")
    print("MEAN PCK: " + str(category_pck["MEAN"]) + "%")

    _save_json(os.path.join(result_path, "pck_mean_percentages_for_" + category + ".json"), category_pck)

# --- MOUNT DRIVE AND EXTRACT DATASET ---

# Mount Google Drive (forcing a fresh mount) and unpack the SPair-71k archive
# into the current working directory.
# NOTE(review): the PATH_* constants expect the archive under /content/SPair-71k,
# i.e. this assumes the notebook's working directory is /content — confirm.
drive.mount(PATH_DRIVE, force_remount=True)
!tar -xzf {PATH_FILE}

# --- LOAD MODEL ---

print()
print("Loading ", MODEL_VERSION, " model...")

# Build the selected backbone; for SAM also wrap it in a SamPredictor, which
# get_descriptors uses at inference time.
if MODEL_VERSION == "dinov2":
    model = torch.hub.load("facebookresearch/dinov2", "dinov2_vitb14_reg", pretrained=True).to(DEVICE)
elif MODEL_VERSION == "dinov3":
    model = torch.hub.load("facebookresearch/dinov3", "dinov3_vitb16", weights=PTH_PATH).to(DEVICE)
elif MODEL_VERSION == "sam":
    model = sam_model_registry["vit_b"](checkpoint=PTH_PATH).to(DEVICE)
    predictor = SamPredictor(model)
else:
    # Fail fast instead of hitting a NameError on `model` further down.
    raise ValueError("Unknown MODEL_VERSION: " + repr(MODEL_VERSION))

# --- TRAINING IF ENABLED ---

if TUNING:
    Train(model, number_of_Epochs=5, learning_rate=1e-5, layers=3)

    # Reload the best checkpoint saved during training (expected at PATH_EXPORT).
    best_model = torch.load(PATH_EXPORT, map_location=DEVICE)              # LOADING BEST MODEL
    model.load_state_dict(best_model)

model.eval()

# --- EVALUATION PER CATEGORY ---

for cat in CATEGORIES:
    print()
    print("Evaluating category: ", cat)

    path = os.path.join(PATH_RES, MODEL_VERSION, cat)                    # ONE RESULT DIRECTORY PER CATEGORY
    os.makedirs(path, exist_ok=True)
    run_evaluation(cat, path, visualize=False)
