In [None]:
%load_ext autoreload
%autoreload 2

import sys
import os
from tqdm import tqdm
from time import time, sleep
import gc
import numpy as np
import h5py
import dataclasses
from copy import deepcopy
import pandas as pd

import cv2
import torch
import torch.nn.functional as F
import kornia as K
import kornia.feature as KF

import torch
from tqdm import tqdm
from pathlib import Path
import numpy as np

from hloc import (
    extract_features,
    match_features,
    reconstruction,
    visualization,
    pairs_from_exhaustive,
)
from hloc.visualization import plot_images, read_image
from hloc.utils import viz_3d
from transformers import AutoImageProcessor, AutoModel
import pycolmap

import metric
import time
import json
from boq_inferface import boq_sort_topk
from boq_inferface import get_trained_boq


# Ask kornia for CUDA device index 0 and echo the resulting device.
# NOTE(review): the helper presumably falls back to CPU when CUDA is missing — confirm in kornia docs.
device = K.utils.get_cuda_device_if_available(0)
print(f'{device=}')

In [None]:
# Prefer the first GPU, but fall back to the CPU so this cell does not crash
# on machines without CUDA (the previous hard-coded 'cuda:0' did).
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Load the BoQ model with a DINOv2 backbone and 12288-dim output, move it to
# the chosen device and switch it to evaluation mode.
# NOTE(review): assumes get_trained_boq returns a torch.nn.Module, whose
# .to() moves parameters in place — confirm.
boq_model = get_trained_boq(backbone_name="dinov2", output_dim=12288)
boq_model.to(device)
boq_model.eval()

In [None]:
def load_torch_image(fname, device=torch.device('cpu')):
    """Load `fname` as an RGB32 kornia image on `device` with a leading batch axis.

    Args:
        fname: Path of the image file to read.
        device: Device the loaded tensor is placed on (CPU by default).

    Returns:
        The loaded image tensor with one extra dimension of size 1 in front.
    """
    image = K.io.load_image(fname, K.io.ImageLoadType.RGB32, device=device)
    # unsqueeze(0) is the same view as indexing with [None, ...].
    return image.unsqueeze(0)


# Global descriptors from DINOv2 (not EfficientNet), used to build matching shortlists.
def get_global_desc(fnames, device = torch.device('cpu')):
    """Compute one L2-normalised DINOv2 global descriptor per image.

    The 'facebook/dinov2-base' processor and model are loaded on every call.
    For each image, the first token of the last hidden state is dropped, the
    remaining tokens are max-pooled, and the result is L2-normalised.

    Args:
        fnames: Sequence of image file paths (must support ``len``).
        device: Device the model runs on; images are loaded on the CPU and
            the processed inputs are moved to this device.

    Returns:
        A CPU tensor with one row per entry of ``fnames``. For an empty
        ``fnames`` an empty ``(0, hidden_size)`` tensor is returned instead
        of letting ``torch.cat`` fail on an empty list.
    """
    processor = AutoImageProcessor.from_pretrained('facebook/dinov2-base')
    model = AutoModel.from_pretrained('facebook/dinov2-base')
    model = model.eval()
    model = model.to(device)

    global_descs_dinov2 = []
    for img_fname_full in tqdm(fnames, total=len(fnames)):
        timg = load_torch_image(img_fname_full)
        with torch.inference_mode():
            # do_rescale=False: the image was loaded as RGB32 (float), so the
            # processor must not rescale it again.
            # NOTE(review): assumes RGB32 yields values already in [0, 1] — confirm.
            inputs = processor(images=timg, return_tensors="pt", do_rescale=False).to(device)
            outputs = model(**inputs)
            # Skip token 0 (presumably the CLS token — confirm), max-pool the
            # rest along the token axis, then L2-normalise each descriptor.
            pooled = outputs.last_hidden_state[:, 1:].max(dim=1)[0]
            dino_mac = F.normalize(pooled, dim=1, p=2)
        global_descs_dinov2.append(dino_mac.detach().cpu())

    if not global_descs_dinov2:
        # Nothing to concatenate: return an empty, correctly shaped result.
        return torch.empty((0, model.config.hidden_size))
    return torch.cat(global_descs_dinov2, dim=0)



@dataclasses.dataclass
class Prediction:
    """One image of a dataset together with its (initially empty) prediction.

    Instances are created with only the identifying fields set; the cluster
    index, rotation and translation default to ``None``.
    """

    image_id: str | None  # A unique identifier for the row -- unused otherwise. Used only on the hidden test set.
    dataset: str  # Name of the dataset this image belongs to.
    filename: str  # Image file name as listed in the CSV's "image" column.
    cluster_index: int | None = None  # NOTE(review): presumably the cluster the image is assigned to — confirm.
    rotation: np.ndarray | None = None  # NOTE(review): presumably the predicted rotation matrix — shape not shown here.
    translation: np.ndarray | None = None  # NOTE(review): presumably the predicted translation vector — shape not shown here.

# is_train=True  -> run the notebook on the training data.
# is_train=False -> competition submission (the real test data is hidden and
#                   differs from what is visible in the "test" folder).
is_train = True
data_dir = 'data/image-matching-challenge-2025'

# Output root for all results; created up front and kept as a Path.
workdir = Path('result/')
workdir.mkdir(parents=True, exist_ok=True)

# Training runs read the labels file, submissions read the sample submission.
sample_submission_csv = os.path.join(
    data_dir,
    'train_labels.csv' if is_train else 'sample_submission.csv',
)

# Group the CSV rows into one list of Prediction objects per dataset.
samples = {}
competition_data = pd.read_csv(sample_submission_csv)
for _, row in competition_data.iterrows():
    # Note: For the test data, the "scene" column has no meaning, and the
    # rotation_matrix and translation_vector columns are random.
    entry = Prediction(
        image_id=None if is_train else row.image_id,
        dataset=row.dataset,
        filename=row.image,
    )
    samples.setdefault(row.dataset, []).append(entry)

# Report how many images each dataset contributes.
for dataset, dataset_predictions in samples.items():
    print(f'Dataset "{dataset}" -> num_images={len(dataset_predictions)}')

# Debug-only cap on the number of images per dataset; None disables the cap.
max_images = None
# Datasets to run; None means "process every dataset" (not the best convention).
datasets_to_process = None

if is_train:
    # For a quick debug run, uncomment:
    # max_images = 5

    # Running on the whole training set hits the notebook time limit, so list
    # here the specific datasets that should be processed.
    datasets_to_process = [
        # New data.
        'amy_gardens',
        'ETs',
        'fbk_vineyard',
        'stairs',
        # Data from IMC 2023 and 2024.
        'imc2024_dioscuri_baalshamin',
        'imc2023_theather_imc2024_church',
        'imc2023_heritage',
        'imc2023_haiper',
        'imc2024_lizard_pond',
        # Crowdsourced PhotoTourism data.
        'pt_stpeters_stpauls',
        'pt_brandenburg_british_buckingham',
        'pt_piazzasanmarco_grandplace',
        'pt_sacrecoeur_trevi_tajmahal',
    ]

import shutil

# Shortlist size and visualisation switch, shared by both top-k strategies.
topk = 20
vis = False

# For every selected dataset, build a top-k list of candidate image pairs and
# store it as JSON under workdir/<dataset>/topk.json. The strategy depends on
# the dataset size:
#   * more than 20 and fewer than 60 images: DISK features, exhaustive pairs, LightGlue matching
#   * 60 images or more:                     BoQ global-descriptor ranking
#   * 20 images or fewer:                    no shortlist is produced
for dataset, predictions in samples.items():
    if datasets_to_process and dataset not in datasets_to_process:
        print(f'Skipping "{dataset}"')
        continue

    images_dir = os.path.join(data_dir, 'train' if is_train else 'test', dataset)
    if not os.path.exists(images_dir):
        print(f'Images dir "{images_dir}" does not exist. Skipping "{dataset}"')
        continue

    images_dir = Path(images_dir)

    print(f'Images dir: {images_dir}')

    image_names = [p.filename for p in predictions]
    if max_images is not None:
        image_names = image_names[:max_images]
    num_images = len(image_names)

    print(f'\nProcessing dataset "{dataset}": {num_images} images')

    # Not used in this loop; kept because later cells may rely on it — confirm.
    filename_to_index = {p.filename: idx for idx, p in enumerate(predictions)}

    # Start from a clean output directory. shutil.rmtree replaces the
    # `!rm -rf` shell magic: it is plain Python and does not pass the dataset
    # name through a shell.
    dataset_dir = workdir / dataset
    shutil.rmtree(dataset_dir, ignore_errors=True)

    sfm_pairs = dataset_dir / "pairs-sfm.txt"
    loc_pairs = dataset_dir / "pairs-loc.txt"
    sfm_dir = dataset_dir / "sfm"
    features = dataset_dir / "features.h5"
    matches = dataset_dir / "matches.h5"
    topk_save_path = dataset_dir / "topk.json"
    topk_vis_path = dataset_dir / "show_topk"

    # Also creates dataset_dir itself, which the output files below need.
    os.makedirs(topk_vis_path, exist_ok=True)

    if 20 < num_images < 60:
        # Local-feature route: DISK + LightGlue on all pairs, then top-k selection.
        feature_conf = extract_features.confs["disk"]
        matcher_conf = match_features.confs["disk+lightglue"]
        start = time.time()
        extract_features.main(feature_conf, images_dir, image_list=image_names, feature_path=features)
        pairs_from_exhaustive.main(sfm_pairs, image_list=image_names)
        match_features.main(matcher_conf, sfm_pairs, features=features, matches=matches)

        from utils import parser_h5s, get_topk_candidates
        res = parser_h5s(features, matches, sfm_pairs, min_match_score=0.2)
        topks = get_topk_candidates(res, images_dir, topk_vis_path, k=topk, min_matches=15, vis=vis)
        with open(topk_save_path, "w", encoding="utf-8") as f:
            json.dump(topks, f, ensure_ascii=False, indent=4)

        end = time.time()
        print(f'\033[91m {dataset} with num {num_images}, Disk + light_glue + topk cost time: { end - start} seconds \033[0m')

    elif num_images >= 60:
        # Global-descriptor route: rank candidates with the BoQ model.
        start = time.time()
        topks = boq_sort_topk(images_dir, image_names, boq_model, device, k=topk, vis=vis, vis_save_dir=topk_vis_path)
        with open(topk_save_path, "w", encoding="utf-8") as f:
            json.dump(topks, f, ensure_ascii=False, indent=4)

        end = time.time()
        print(f'\033[93m {dataset} with num {num_images}, BOQ TOPK time: { end - start} seconds \033[0m')

    else:
        # 20 images or fewer: no top-k shortlist is built.
        # NOTE(review): the original note said "just do masr3d" — presumably
        # these datasets are handled by MASt3R in a later step; confirm.
        # The message ends with the reset code so the blue colour does not
        # leak into later output.
        print('\033[94m skip this dataset, cause number of imgs not greater than 20 \033[0m')

    # Release per-dataset objects before moving on to the next dataset.
    gc.collect()

