In [4]:
# Install the wheels shipped with the attached Kaggle input; --no-index keeps pip
# off the network and --no-deps skips dependency resolution.
!pip install --no-index /kaggle/input/imc2024-packages-lightglue-rerun-kornia/* --no-deps
# Copy the ALIKED and LightGlue weights into the torch hub checkpoint directory
# (presumably so the models can be loaded without a download — confirm).
!mkdir -p /root/.cache/torch/hub/checkpoints
!cp /kaggle/input/aliked/pytorch/aliked-n16/1/* /root/.cache/torch/hub/checkpoints/
!cp /kaggle/input/lightglue/pytorch/aliked/1/* /root/.cache/torch/hub/checkpoints/
# NOTE(review): the target name ends in "-pth" rather than ".pth". This is presumably
# the cache file name the kornia LightGlue loader looks for — confirm before "fixing" it.
!cp /kaggle/input/lightglue/pytorch/aliked/1/aliked_lightglue.pth /root/.cache/torch/hub/checkpoints/aliked_lightglue_v0-1_arxiv-pth

Processing /kaggle/input/imc2024-packages-lightglue-rerun-kornia/kornia-0.7.2-py2.py3-none-any.whl
Processing /kaggle/input/imc2024-packages-lightglue-rerun-kornia/kornia_moons-0.2.9-py3-none-any.whl
Processing /kaggle/input/imc2024-packages-lightglue-rerun-kornia/kornia_rs-0.1.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Processing /kaggle/input/imc2024-packages-lightglue-rerun-kornia/lightglue-0.0-py3-none-any.whl
Processing /kaggle/input/imc2024-packages-lightglue-rerun-kornia/pycolmap-0.6.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Processing /kaggle/input/imc2024-packages-lightglue-rerun-kornia/rerun_sdk-0.15.0a2-cp38-abi3-manylinux_2_31_x86_64.whl
Installing collected packages: rerun-sdk, pycolmap, lightglue, kornia-rs, kornia-moons, kornia
  Attempting uninstall: kornia-rs
    Found existing installation: kornia_rs 0.1.8
    Uninstalling kornia_rs-0.1.8:
      Successfully uninstalled kornia_rs-0.1.8
  Attempting uninstall: kornia
    Found exist

In [5]:
import matplotlib.pyplot as plt

import os
from tqdm import tqdm
from pathlib import Path
from time import time, sleep
from fastprogress import progress_bar
import gc
import numpy as np
import h5py
from IPython.display import clear_output
from collections import defaultdict
from copy import deepcopy
from typing import Any
import itertools
import pandas as pd

import cv2
import torch
from torch import Tensor as T
import torch.nn.functional as F
import kornia as K
import kornia.feature as KF
from PIL import Image
from transformers import AutoImageProcessor, AutoModel

from lightglue import match_pair
from lightglue import LightGlue, ALIKED
from lightglue.utils import load_image, rbd

import pycolmap

import sys
sys.path.append("/kaggle/input/colmap-db-import")

from database import *
from h5_to_db import *

  @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)


In [26]:
def arr_to_str(input_arr):
    """Serialise an array as one ';'-separated string of its flattened values.

    Args:
        input_arr: object exposing ``reshape`` (e.g. a NumPy array).

    Returns:
        The ``str()`` of every element of ``input_arr.reshape(-1)``, joined by ';'.
    """
    flat_values = input_arr.reshape(-1)
    return ';'.join(map(str, flat_values))

def load_torch_image(input_path, device=torch.device('cpu')):
    """Load an image with kornia as RGB32 and prepend a leading axis.

    Args:
        input_path: path of the image file to read.
        device: device handed to ``K.io.load_image``.

    Returns:
        The loaded image indexed with ``[None, ...]``, i.e. with one extra
        leading dimension (presumably the batch axis — confirm).
    """
    return K.io.load_image(input_path, K.io.ImageLoadType.RGB32, device=device)[None, ...]


# Pick the compute device through kornia's helper and show the choice; the
# argument 0 is presumably the CUDA device index — confirm.
device = K.utils.get_cuda_device_if_available(0)
print(device)

cuda:0


In [27]:
def get_global_descriptor(input_paths, device=torch.device('cpu')):
    """Compute one L2-normalised DINOv2 global descriptor per image.

    For every image the model's ``last_hidden_state`` is taken without its
    first token (presumably the CLS token — confirm), max-pooled over the
    remaining tokens and L2-normalised.

    Args:
        input_paths: sequence of image file paths. It must be non-empty,
            because the per-image descriptors are concatenated with
            ``torch.cat``.
        device: device the model and the preprocessed inputs are moved to.

    Returns:
        A CPU tensor with one descriptor row per input image, in input order.
    """
    weights_dir = '/kaggle/input/dinov2/pytorch/base/1'
    processor = AutoImageProcessor.from_pretrained(weights_dir)
    model = AutoModel.from_pretrained(weights_dir)
    model = model.eval()
    model = model.to(device)

    global_desc_dinov2 = []

    for img_path in tqdm(input_paths, total=len(input_paths)):
        # The image is loaded on the CPU; only the processor output is moved to `device`.
        torch_img = load_torch_image(img_path)

        with torch.inference_mode():
            # do_rescale=False: the loaded image is presumably already in [0, 1] — confirm.
            inputs = processor(images=torch_img, return_tensors="pt", do_rescale=False).to(device)
            outputs = model(**inputs)
            # Drop the first token, max-pool over the rest, then L2-normalise.
            dino_norm = F.normalize(outputs.last_hidden_state[:, 1:].max(dim=1)[0], dim=1, p=2)

        global_desc_dinov2.append(dino_norm.detach().cpu())

    global_desc_dinov2 = torch.cat(global_desc_dinov2, dim=0)

    return global_desc_dinov2

def get_img_pairs(input_paths):
    """Enumerate every unordered pair of image indices.

    Args:
        input_paths: sized collection of images; only its length is used.

    Returns:
        A list of ``(i, j)`` tuples with ``i < j``, in lexicographic order.
    """
    image_count = len(input_paths)
    return list(itertools.combinations(range(image_count), 2))

def get_img_pairs_shortlist(input_paths, sim_th=0.3, min_pairs=20, exhausive_if_less=20, device=torch.device('cpu')):
    """Shortlist image pairs to match using global-descriptor distances.

    Small collections (``len(input_paths) <= exhausive_if_less``) are matched
    exhaustively. Otherwise each image is paired with every image whose
    descriptor distance is at most ``sim_th``; if that gives fewer than
    ``min_pairs`` candidates (the image itself included), the ``min_pairs``
    closest images are used instead.

    Args:
        input_paths: sequence of image file paths.
        sim_th: maximum L2 distance between descriptors for a pair to be kept.
        min_pairs: size of the nearest-neighbour fallback list per image.
        exhausive_if_less: collection size up to which all pairs are returned.
        device: device used to compute the global descriptors.

    Returns:
        A sorted list of unique ``(i, j)`` index tuples with ``i < j``.
    """
    num_imgs = len(input_paths)
    if num_imgs <= exhausive_if_less:
        return get_img_pairs(input_paths)

    descs = get_global_descriptor(input_paths, device=device)
    dist_matrix = torch.cdist(descs, descs, p=2).detach().cpu().numpy()

    within_threshold = dist_matrix <= sim_th
    all_indices = np.arange(num_imgs)

    unique_pairs = set()

    # Visit every image, including the last one. The threshold mask is
    # symmetric, but the nearest-neighbour fallback is not, so skipping the
    # last row could leave that image with no candidate pairs at all.
    for idx in range(num_imgs):
        to_match = all_indices[within_threshold[idx]]

        if len(to_match) < min_pairs:
            to_match = np.argsort(dist_matrix[idx])[:min_pairs]

        for other in to_match:
            other = int(other)
            if other == idx:
                continue
            # NOTE(review): this upper bound is presumably a sanity guard and
            # looks unreachable for L2-normalised descriptors — confirm.
            if dist_matrix[idx, other] < 1000:
                unique_pairs.add((min(idx, other), max(idx, other)))

    return sorted(unique_pairs)

In [28]:
def detect_aliked(input_paths, feature_path='.featureout', num_features=2048, resize_to=1024, device=torch.device('cpu')):
    """Extract ALIKED keypoints and descriptors and store them in HDF5 files.

    Writes ``keypoints.h5`` and ``descriptors.h5`` into ``feature_path``
    (created if missing), with one dataset per image keyed by the last
    '/'-separated component of its path.

    Args:
        input_paths: iterable of image path strings.
        feature_path: output directory for the two HDF5 files.
        num_features: passed to ALIKED as ``max_num_keypoints``.
        resize_to: passed to ALIKED as ``resize``.
        device: device the extractor runs on and images are loaded to.

    Returns:
        None.
    """
    data_type = torch.float32
    extractor = ALIKED(max_num_keypoints=num_features, detection_threshold=0.01, resize=resize_to)
    extractor = extractor.eval().to(device, data_type)

    os.makedirs(feature_path, exist_ok=True)

    with h5py.File(f'{feature_path}/keypoints.h5', mode='w') as feature_keypoints:
        with h5py.File(f'{feature_path}/descriptors.h5', mode='w') as feature_descriptors:
            for img_path in tqdm(input_paths):
                # The file name alone is the dataset key in both HDF5 files.
                key = img_path.rsplit('/', 1)[-1]

                with torch.inference_mode():
                    input_img = load_torch_image(img_path, device=device).to(data_type)
                    img_feature = extractor.extract(input_img)
                    img_kpts = img_feature['keypoints'].reshape(-1, 2).detach().cpu().numpy()
                    # One descriptor row per keypoint.
                    img_descs = img_feature['descriptors'].reshape(len(img_kpts), -1).detach().cpu().numpy()

                feature_keypoints[key] = img_kpts
                feature_descriptors[key] = img_descs

def match_lightglue(input_paths, index_pairs, feature_path='.featureout', device=torch.device('cpu'), min_matches=15, verbose=False):
    """Match stored ALIKED features pairwise with LightGlue and save the matches.

    Reads ``keypoints.h5`` and ``descriptors.h5`` from ``feature_path`` and
    writes ``matches.h5`` there. Matches for a pair are stored under
    ``matches[key_src][key_dst]`` as an (N, 2) index array, and only when the
    pair has at least ``min_matches`` matches.

    Args:
        input_paths: image path strings; the last '/'-separated component is
            the HDF5 key.
        index_pairs: iterable of ``(src_index, dst_index)`` into ``input_paths``.
        feature_path: directory holding the feature files.
        device: device the matcher and feature tensors are moved to.
        min_matches: minimum number of matches for a pair to be stored.
        verbose: print the match count of every non-empty pair.

    Returns:
        None.
    """
    # "mp" is switched on only for CUDA devices (presumably mixed precision — confirm).
    matcher_config = {"width_confidence": -1, "depth_confidence": -1, "mp": 'cuda' in str(device)}
    matcher = KF.LightGlueMatcher("aliked", matcher_config).eval().to(device)

    keypoints_file = f'{feature_path}/keypoints.h5'
    descriptors_file = f'{feature_path}/descriptors.h5'
    matches_file = f'{feature_path}/matches.h5'

    with h5py.File(keypoints_file, mode='r') as feature_keypoints, \
            h5py.File(descriptors_file, mode='r') as feature_descriptors, \
            h5py.File(matches_file, mode='w') as feature_matches:
        for idx_src, idx_dst in tqdm(index_pairs):
            key_src = input_paths[idx_src].split('/')[-1]
            key_dst = input_paths[idx_dst].split('/')[-1]

            kpts_src = torch.from_numpy(feature_keypoints[key_src][...]).to(device)
            kpts_dst = torch.from_numpy(feature_keypoints[key_dst][...]).to(device)

            descs_src = torch.from_numpy(feature_descriptors[key_src][...]).to(device)
            descs_dst = torch.from_numpy(feature_descriptors[key_dst][...]).to(device)

            with torch.inference_mode():
                # The matcher takes local affine frames, built here from the keypoint centres.
                lafs_src = KF.laf_from_center_scale_ori(kpts_src[None])
                lafs_dst = KF.laf_from_center_scale_ori(kpts_dst[None])
                dists, idxs = matcher(descs_src, descs_dst, lafs_src, lafs_dst)

            num_matches = len(idxs)
            if num_matches == 0:
                continue

            if verbose:
                print(f'{key_src}-{key_dst}: {num_matches} matches')

            # The source group is created even if the pair ends up below the threshold.
            group = feature_matches.require_group(key_src)

            if num_matches >= min_matches:
                group.create_dataset(key_dst, data=idxs.detach().cpu().numpy().reshape(-1, 2))

def import_into_colmap(img_dir, feature_dir='.featureout', db_path='colmap.db'):
    """Create a COLMAP database and fill it with stored keypoints and matches.

    Args:
        img_dir: directory containing the images, passed to ``add_keypoints``.
        feature_dir: directory holding the feature/match files read by
            ``add_keypoints`` and ``add_matches``.
        db_path: path of the COLMAP database to create or open.

    Returns:
        None.
    """
    db = COLMAPDatabase.connect(db_path)

    try:
        db.create_tables()

        # One camera per image rather than a single shared camera.
        is_single_cam = False
        # NOTE(review): '' and 'simple-pinhole' are presumably the image
        # extension and camera model expected by add_keypoints — confirm.
        filename_to_id = add_keypoints(db, feature_dir, img_dir, '', 'simple-pinhole', is_single_cam)
        add_matches(db, feature_dir, filename_to_id)

        db.commit()
    finally:
        # Always release the connection so the database file is not left
        # open when COLMAP reads it afterwards or when an import step fails.
        db.close()

    return

In [29]:
# Root directory of the competition data.
input_src = '/kaggle/input/image-matching-challenge-2024/'
# dataset name -> scene name -> list of image paths as written in the CSV.
data_dict = {}

with open(os.path.join(input_src, 'sample_submission.csv'), 'r') as submission_file:
    for idx, val in enumerate(submission_file):
        if idx == 0:
            # Show the header line, then move on to the data rows.
            print(val)
            continue

        row = val.strip()
        if not row:
            # A blank (e.g. trailing) line has no fields to unpack; skip it.
            continue

        # Each row has five comma-separated fields; the two pose columns are ignored here.
        img_path, dataset, scene, _, _ = row.split(',')
        data_dict.setdefault(dataset, {}).setdefault(scene, []).append(img_path)

for dataset in data_dict:
    for scene in data_dict[dataset]:
        print(f'{dataset} / {scene} -> {len(data_dict[dataset][scene])} images')

# Poses per dataset and scene, filled in by the reconstruction loop.
output_results = {}
# Wall-clock seconds per pipeline stage, one entry per processed scene.
timings = {"shortlisting":[], "feature_detection":[], "feature_matching":[], "RANSAC":[], "Reconstruction":[]}

image_path,dataset,scene,rotation_matrix,translation_vector

church / church -> 41 images


In [30]:
def create_submission(output_results, data_dict):
    """Write ``submission.csv`` in the current directory, one row per image.

    Args:
        output_results: ``dataset -> scene -> image -> {'R': ..., 't': ...}``
            with the estimated pose of each registered image.
        data_dict: ``dataset -> scene -> list of image paths``; defines which
            rows are written and in which order.

    Images without an entry in ``output_results`` get the identity rotation
    and a zero translation.
    """
    header = 'image_path,dataset,scene,rotation_matrix,translation_vector\n'

    with open('submission.csv', 'w') as result_file:
        result_file.write(header)

        for dataset, scenes in data_dict.items():
            dataset_poses = output_results.get(dataset, {})

            for scene, images in scenes.items():
                scene_poses = dataset_poses.get(scene, {"R":{}, "t":{}})

                for image in images:
                    if image not in scene_poses:
                        # No pose was estimated for this image: use the identity pose.
                        rotation = np.eye(3).reshape(-1)
                        translation = np.zeros((3))
                    else:
                        rotation = scene_poses[image]['R'].reshape(-1)
                        translation = scene_poses[image]['t'].reshape(-1)

                    result_file.write(f'{image},{dataset},{scene},{arr_to_str(rotation)},{arr_to_str(translation)}\n')

In [31]:
gc.collect()

# Names of every dataset listed in the sample submission.
datasets = list(data_dict)

for dataset in data_dict:
    print(dataset)

    if dataset not in output_results:
        output_results[dataset] = {}

    for scene in data_dict[dataset]:
        print(scene)

        scene_images = data_dict[dataset][scene]
        # Start from an empty result so a failing scene falls back to the
        # default poses written by create_submission.
        output_results[dataset][scene] = {}

        try:
            # All images of a scene share the directory of the first image.
            img_dir = os.path.join(input_src, '/'.join(scene_images[0].split('/')[:-1]))

            input_paths = [os.path.join(input_src, x) for x in scene_images]
            feature_path = f'featureout/{dataset}/{scene}'
            os.makedirs(feature_path, exist_ok=True)

            # Stage 1: choose which image pairs to match.
            t = time()
            index_pairs = get_img_pairs_shortlist(input_paths, device=device)
            t = time() - t
            timings['shortlisting'].append(t)

            print(f'{len(index_pairs)}, pairs to match, {t:.4f} sec')
            gc.collect()

            # Stage 2: extract local features.
            t = time()
            detect_aliked(input_paths, feature_path, device=device)
            t = time() - t
            timings['feature_detection'].append(t)

            print(f'Features detected in {t:.4f} sec')
            gc.collect()

            # Stage 3: match features for the shortlisted pairs.
            t = time()
            match_lightglue(input_paths, index_pairs, feature_path, device=device)
            t = time() - t
            timings['feature_matching'].append(t)

            print(f'Features matched in {t:.4f} sec')
            gc.collect()

            # Stage 4: rebuild the COLMAP database from scratch.
            db_path = f'{feature_path}/colmap.db'
            if os.path.isfile(db_path):
                os.remove(db_path)

            sleep(1)

            import_into_colmap(img_dir, feature_path, db_path)

            # Timed under the "RANSAC" label (presumably the geometric
            # verification of the imported matches — confirm).
            t = time()
            pycolmap.match_exhaustive(db_path)
            t = time() - t
            timings['RANSAC'].append(t)

            print(f'RANSAC in {t:.4f} sec')

            # Stage 5: incremental reconstruction.
            t = time()
            recon_options = pycolmap.IncrementalPipelineOptions()
            recon_options.min_model_size = 3
            recon_options.max_num_models = 2

            output_path = f'{feature_path}/colmap_rec_aliked'
            os.makedirs(output_path, exist_ok=True)

            recon_results = pycolmap.incremental_mapping(database_path=db_path, image_path=img_dir, output_path=output_path, options=recon_options)

            sleep(1)

            print(recon_results)
            clear_output(wait=False)
            t = time() - t
            timings['Reconstruction'].append(t)

            print(f'Reconstruction done in {t:.4f} sec')

            # Keep the reconstruction that registered the most images.
            imgs_registered = 0
            best_idx = None

            if isinstance(recon_results, dict):
                for idx, rec in recon_results.items():
                    try:
                        if len(rec.images) > imgs_registered:
                            imgs_registered = len(rec.images)
                            best_idx = idx
                    except Exception:
                        # Skip a reconstruction whose images cannot be inspected.
                        continue

            if best_idx is not None:
                # create_submission looks poses up by the image path written in
                # the CSV, so map each reconstructed file name back to that
                # path. The fallback keeps the previous hard-coded layout.
                name_to_csv_path = {p.split('/')[-1]: p for p in scene_images}

                for _, img in recon_results[best_idx].images.items():
                    key_img = name_to_csv_path.get(img.name, f'test/{scene}/images/{img.name}')

                    output_results[dataset][scene][key_img] = {
                        "R": deepcopy(img.cam_from_world.rotation.matrix()),
                        "t": deepcopy(np.array(img.cam_from_world.translation)),
                    }
        except Exception as e:
            # Best effort: report which scene failed and carry on with the rest.
            print(f'{dataset} / {scene} failed: {type(e).__name__}: {e}')

        # Rewrite the submission after every scene, including failed ones, so
        # that a valid file exists even if the last (or only) scene fails.
        try:
            create_submission(output_results, data_dict)
        except Exception as e:
            print(f'Could not write submission after {dataset} / {scene}: {type(e).__name__}: {e}')

        gc.collect()

Reconstruction done in 84.5267 sec
