## Baseline submission

A notebook to generate a valid submission. Implements three local feature/matcher methods: LoFTR, DISK, and KeyNetAffNetHardNet.

Remember to enable a GPU accelerator and disable internet access, then press "submit" on the right pane.

In [None]:
!pip install /kaggle/input/imc2023-dependencies-data/wheels/einops-0.6.1-py3-none-any.whl

In [None]:
# General utilities
import os
from tqdm import tqdm
from time import time
from fastprogress import progress_bar
import gc
import numpy as np
import h5py
from IPython.display import clear_output
from collections import defaultdict
from copy import deepcopy
import pandas as pd
from dataclasses import dataclass

# CV/ML
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
import kornia as K
import kornia.feature as KF
from PIL import Image
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from torchvision import transforms
# 3D reconstruction
import pycolmap

In [None]:
# Prefer the GPU, but fall back to the CPU so the notebook can still be
# executed (slowly) in a session without a CUDA device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
import sys
sys.path.append("../input/super-glue-pretrained-network")
from models.matching import Matching
from models.utils import (compute_pose_error, compute_epipolar_error,
                          estimate_pose, make_matching_plot,
                          error_colormap, AverageTimer, pose_auc, read_image,
                          rotate_intrinsics, rotate_pose_inplane,
                          scale_intrinsics)

In [None]:
# `resize` and `resize_float` are not referenced by the code visible in this
# notebook; presumably they mirror the arguments of SuperGlue's `read_image`
# helper (TODO confirm before relying on them).
resize = [-1, ]
resize_float = True

# Options handed to the SuperPoint + SuperGlue `Matching` wrapper.
config = {
    "superpoint": {
        "nms_radius": 8,
        "keypoint_threshold": 0.005,
        "max_keypoints": 2048
    },
    "superglue": {
        "weights": "outdoor",
        "sinkhorn_iterations": 50,
        "match_threshold": 0.1,
    }
}
# Build the matcher once, switch it to eval mode and move it to `device`.
super_glue_matcher = Matching(config).eval().to(device)

In [None]:
#timm.list_models(pretrained=True)

In [None]:
# Print the versions of the installed kornia and pycolmap packages.
print('Kornia version', K.__version__)
print('Pycolmap version', pycolmap.__version__)

# Selects the local feature / matcher; the feature-detection code branches
# on 'DISK' and 'KeyNetAffNetHardNet', while 'LoFTR' is the value used here.
LOCAL_FEATURE = 'LoFTR'

# Can be LoFTR, KeyNetAffNetHardNet, or DISK

In [None]:
# Debug mode is enabled when the test directory holds fewer than two entries.
DEBUG_FLAG = len(os.listdir("/kaggle/input/image-matching-challenge-2023/test")) < 2

In [None]:
def arr_to_str(a):
    """Flatten ``a`` and join its elements into one ';'-separated string."""
    flat = a.reshape(-1)
    return ';'.join(map(str, flat))


def load_torch_image(fname, device=torch.device('cpu')):
    """Read an image file into a float RGB tensor with values in [0, 1].

    Args:
        fname: Path of the image file, read with ``cv2.imread``.
        device: Device the tensor is moved to before the BGR->RGB conversion.

    Returns:
        The image as a float tensor produced by ``K.image_to_tensor`` with
        ``keepdim=False`` (callers index ``shape[2:]`` as height/width).

    Raises:
        IOError: If OpenCV cannot read ``fname``.
    """
    bgr = cv2.imread(fname)
    if bgr is None:
        # cv2.imread reports failure by returning None instead of raising,
        # which would otherwise surface as an obscure error inside kornia.
        raise IOError(f'Cannot read image {fname}')
    img = K.image_to_tensor(bgr, False).float() / 255.
    img = K.color.bgr_to_rgb(img.to(device))
    return img

In [None]:
# Code to manipulate a colmap database.
# Forked from https://github.com/colmap/colmap/blob/dev/scripts/python/database.py

# Copyright (c) 2018, ETH Zurich and UNC Chapel Hill.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
#     * Redistributions of source code must retain the above copyright
#       notice, this list of conditions and the following disclaimer.
#
#     * Redistributions in binary form must reproduce the above copyright
#       notice, this list of conditions and the following disclaimer in the
#       documentation and/or other materials provided with the distribution.
#
#     * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of
#       its contributors may be used to endorse or promote products derived
#       from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
#
# Author: Johannes L. Schoenberger (jsch-at-demuc-dot-de)

# This script is based on an original implementation by True Price.

import sys
import sqlite3
import numpy as np


# True when the interpreter's major version is 3 or newer.
IS_PYTHON3 = sys.version_info[0] >= 3

# Exclusive upper bound enforced on image ids by the `images` table, and the
# multiplier used to pack two image ids into one pair id.
MAX_IMAGE_ID = 2**31 - 1

# Schema of the `cameras` table (params are stored as a binary blob).
CREATE_CAMERAS_TABLE = """CREATE TABLE IF NOT EXISTS cameras (
    camera_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
    model INTEGER NOT NULL,
    width INTEGER NOT NULL,
    height INTEGER NOT NULL,
    params BLOB,
    prior_focal_length INTEGER NOT NULL)"""

# Schema of the `descriptors` table: one rows x cols blob per image.
CREATE_DESCRIPTORS_TABLE = """CREATE TABLE IF NOT EXISTS descriptors (
    image_id INTEGER PRIMARY KEY NOT NULL,
    rows INTEGER NOT NULL,
    cols INTEGER NOT NULL,
    data BLOB,
    FOREIGN KEY(image_id) REFERENCES images(image_id) ON DELETE CASCADE)"""

# Schema of the `images` table; the image-id range check is filled in with
# MAX_IMAGE_ID through str.format.
CREATE_IMAGES_TABLE = """CREATE TABLE IF NOT EXISTS images (
    image_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
    name TEXT NOT NULL UNIQUE,
    camera_id INTEGER NOT NULL,
    prior_qw REAL,
    prior_qx REAL,
    prior_qy REAL,
    prior_qz REAL,
    prior_tx REAL,
    prior_ty REAL,
    prior_tz REAL,
    CONSTRAINT image_id_check CHECK(image_id >= 0 and image_id < {}),
    FOREIGN KEY(camera_id) REFERENCES cameras(camera_id))
""".format(MAX_IMAGE_ID)

# Schema of the `two_view_geometries` table: matches plus F, E and H blobs,
# keyed by pair id.
CREATE_TWO_VIEW_GEOMETRIES_TABLE = """
CREATE TABLE IF NOT EXISTS two_view_geometries (
    pair_id INTEGER PRIMARY KEY NOT NULL,
    rows INTEGER NOT NULL,
    cols INTEGER NOT NULL,
    data BLOB,
    config INTEGER NOT NULL,
    F BLOB,
    E BLOB,
    H BLOB)
"""

# Schema of the `keypoints` table: one rows x cols blob per image.
CREATE_KEYPOINTS_TABLE = """CREATE TABLE IF NOT EXISTS keypoints (
    image_id INTEGER PRIMARY KEY NOT NULL,
    rows INTEGER NOT NULL,
    cols INTEGER NOT NULL,
    data BLOB,
    FOREIGN KEY(image_id) REFERENCES images(image_id) ON DELETE CASCADE)
"""

# Schema of the `matches` table: one rows x cols blob per image pair.
CREATE_MATCHES_TABLE = """CREATE TABLE IF NOT EXISTS matches (
    pair_id INTEGER PRIMARY KEY NOT NULL,
    rows INTEGER NOT NULL,
    cols INTEGER NOT NULL,
    data BLOB)"""

# Unique index over image names.
CREATE_NAME_INDEX = \
    "CREATE UNIQUE INDEX IF NOT EXISTS index_name ON images(name)"

# All statements above joined into a single script for `executescript`.
CREATE_ALL = "; ".join([
    CREATE_CAMERAS_TABLE,
    CREATE_IMAGES_TABLE,
    CREATE_KEYPOINTS_TABLE,
    CREATE_DESCRIPTORS_TABLE,
    CREATE_MATCHES_TABLE,
    CREATE_TWO_VIEW_GEOMETRIES_TABLE,
    CREATE_NAME_INDEX
])


def image_ids_to_pair_id(image_id1, image_id2):
    """Pack two image ids into a single pair id, independent of their order.

    The smaller id is multiplied by ``MAX_IMAGE_ID`` and the larger one is
    added, so ``(a, b)`` and ``(b, a)`` map to the same value.
    """
    if image_id1 > image_id2:
        low, high = image_id2, image_id1
    else:
        low, high = image_id1, image_id2
    return low * MAX_IMAGE_ID + high


def pair_id_to_image_ids(pair_id):
    """Split a pair id back into the two image ids it was built from.

    Inverse of ``image_ids_to_pair_id``: the quotient by ``MAX_IMAGE_ID`` is
    the smaller image id and the remainder is the larger one.

    Args:
        pair_id: Integer pair id.

    Returns:
        Tuple ``(image_id1, image_id2)`` of integers.
    """
    # divmod keeps the arithmetic in integers; true division would return a
    # float image_id1 and lose precision for large pair ids.
    image_id1, image_id2 = divmod(pair_id, MAX_IMAGE_ID)
    return image_id1, image_id2


def array_to_blob(array):
    """Serialize a NumPy array to raw bytes for storage in a BLOB column.

    Args:
        array: NumPy array; its dtype and shape are not stored in the blob.

    Returns:
        ``bytes`` holding the array data in C order.
    """
    # ndarray.tostring() is a deprecated alias of tobytes(), and the former
    # Python 2 branch relied on np.getbuffer, which does not exist on
    # Python 3 (the only interpreter this f-string based file can run on).
    return array.tobytes()


def blob_to_array(blob, dtype, shape=(-1,)):
    """Rebuild a NumPy array from raw bytes read from a BLOB column.

    Args:
        blob: Bytes-like object holding the array data.
        dtype: Element dtype of the stored array.
        shape: Target shape passed to ``reshape`` (flat by default).

    Returns:
        A writable array of the requested dtype and shape that owns its data.
    """
    # Binary-mode np.fromstring is deprecated; np.frombuffer is the
    # replacement. It returns a read-only view of `blob`, so copy to keep the
    # writable, independent array the old call produced.
    return np.frombuffer(blob, dtype=dtype).reshape(*shape).copy()


class COLMAPDatabase(sqlite3.Connection):
    """SQLite connection with convenience methods for writing COLMAP tables."""

    @staticmethod
    def connect(database_path):
        """Open ``database_path`` and return it as a ``COLMAPDatabase``."""
        return sqlite3.connect(database_path, factory=COLMAPDatabase)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Zero-argument helpers, one per schema script.
        self.create_tables = lambda: self.executescript(CREATE_ALL)
        self.create_cameras_table = lambda: self.executescript(
            CREATE_CAMERAS_TABLE)
        self.create_descriptors_table = lambda: self.executescript(
            CREATE_DESCRIPTORS_TABLE)
        self.create_images_table = lambda: self.executescript(
            CREATE_IMAGES_TABLE)
        self.create_two_view_geometries_table = lambda: self.executescript(
            CREATE_TWO_VIEW_GEOMETRIES_TABLE)
        self.create_keypoints_table = lambda: self.executescript(
            CREATE_KEYPOINTS_TABLE)
        self.create_matches_table = lambda: self.executescript(
            CREATE_MATCHES_TABLE)
        self.create_name_index = lambda: self.executescript(CREATE_NAME_INDEX)

    def add_camera(self, model, width, height, params,
                   prior_focal_length=False, camera_id=None):
        """Insert a camera row and return its row id.

        ``params`` is converted to float64 and stored as a blob.
        """
        params_blob = array_to_blob(np.asarray(params, np.float64))
        row = (camera_id, model, width, height, params_blob,
               prior_focal_length)
        inserted = self.execute(
            "INSERT INTO cameras VALUES (?, ?, ?, ?, ?, ?)", row)
        return inserted.lastrowid

    def add_image(self, name, camera_id,
                  prior_q=np.zeros(4), prior_t=np.zeros(3), image_id=None):
        """Insert an image row (name, camera, pose priors); return its row id."""
        q_values = tuple(prior_q[i] for i in range(4))
        t_values = tuple(prior_t[i] for i in range(3))
        row = (image_id, name, camera_id) + q_values + t_values
        inserted = self.execute(
            "INSERT INTO images VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", row)
        return inserted.lastrowid

    def add_keypoints(self, image_id, keypoints):
        """Store a 2-D keypoint array (2, 4 or 6 columns) as float32."""
        assert(len(keypoints.shape) == 2)
        assert(keypoints.shape[1] in [2, 4, 6])

        kp_arr = np.asarray(keypoints, np.float32)
        row = (image_id, *kp_arr.shape, array_to_blob(kp_arr))
        self.execute("INSERT INTO keypoints VALUES (?, ?, ?, ?)", row)

    def add_descriptors(self, image_id, descriptors):
        """Store descriptors for ``image_id`` as a contiguous uint8 blob."""
        desc_arr = np.ascontiguousarray(descriptors, np.uint8)
        row = (image_id, *desc_arr.shape, array_to_blob(desc_arr))
        self.execute("INSERT INTO descriptors VALUES (?, ?, ?, ?)", row)

    def add_matches(self, image_id1, image_id2, matches):
        """Store an (N, 2) array of match indices for an image pair.

        The columns are swapped when ``image_id1 > image_id2`` so that they
        follow the id order encoded in the pair id.
        """
        assert(len(matches.shape) == 2)
        assert(matches.shape[1] == 2)

        ordered = matches[:,::-1] if image_id1 > image_id2 else matches
        pair_id = image_ids_to_pair_id(image_id1, image_id2)
        match_arr = np.asarray(ordered, np.uint32)
        row = (pair_id, *match_arr.shape, array_to_blob(match_arr))
        self.execute("INSERT INTO matches VALUES (?, ?, ?, ?)", row)

    def add_two_view_geometry(self, image_id1, image_id2, matches,
                              F=np.eye(3), E=np.eye(3), H=np.eye(3), config=2):
        """Store matches together with F, E and H matrices for an image pair.

        The match columns are swapped when ``image_id1 > image_id2``; the
        matrices are stored as float64 blobs.
        """
        assert(len(matches.shape) == 2)
        assert(matches.shape[1] == 2)

        ordered = matches[:,::-1] if image_id1 > image_id2 else matches
        pair_id = image_ids_to_pair_id(image_id1, image_id2)
        match_arr = np.asarray(ordered, np.uint32)
        matrix_blobs = tuple(
            array_to_blob(np.asarray(mat, dtype=np.float64))
            for mat in (F, E, H))
        row = (pair_id, *match_arr.shape, array_to_blob(match_arr),
               config) + matrix_blobs
        self.execute(
            "INSERT INTO two_view_geometries VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
            row)

In [None]:
# Code to interface DISK with Colmap.
# Forked from https://github.com/cvlab-epfl/disk/blob/37f1f7e971cea3055bb5ccfc4cf28bfd643fa339/colmap/h5_to_db.py

#  Copyright [2020] [Michał Tyszkiewicz, Pascal Fua, Eduard Trulls]
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.

import os, argparse, h5py, warnings
import numpy as np
from tqdm import tqdm
from PIL import Image, ExifTags


def get_focal(image_path, err_on_default=False):
    """Estimate the focal length of an image in pixels.

    The EXIF tag ``FocalLengthIn35mmFilm`` is used when present and positive
    (``focal = f_35mm / 35 * max(width, height)``, the same rule as the
    COLMAP code linked below); otherwise a prior of
    ``1.2 * max(width, height)`` is returned.

    Args:
        image_path: Path of the image file.
        err_on_default: If True, raise instead of falling back to the prior.

    Returns:
        Focal length in pixels.

    Raises:
        RuntimeError: If no usable EXIF focal length exists and
            ``err_on_default`` is True.
    """
    EXIF_IFD_POINTER = 0x8769  # tag that points at the Exif sub-IFD
    FOCAL_PRIOR = 1.2

    # The context manager closes the file handle that Image.open keeps open.
    with Image.open(image_path) as image:
        max_size = max(image.size)
        exif = image.getexif()
        tags = {}
        if exif is not None:
            tags.update(exif.items())
            # FocalLengthIn35mmFilm is stored in the Exif sub-IFD, which
            # current Pillow versions only expose through get_ifd(); older
            # versions without get_ifd() list it at the top level.
            get_ifd = getattr(exif, 'get_ifd', None)
            if get_ifd is not None:
                tags.update(get_ifd(EXIF_IFD_POINTER))

    focal = None
    # https://github.com/colmap/colmap/blob/d3a29e203ab69e91eda938d6e56e1c7339d62a99/src/util/bitmap.cc#L299
    for tag, value in tags.items():
        if ExifTags.TAGS.get(tag, None) == 'FocalLengthIn35mmFilm':
            focal_35mm = float(value)
            # A zero or negative entry means "unknown"; keep the prior then.
            if focal_35mm > 0:
                focal = focal_35mm / 35. * max_size
            break

    if focal is None:
        if err_on_default:
            raise RuntimeError("Failed to find focal length")

        # failed to find it in exif, use prior
        focal = FOCAL_PRIOR * max_size

    return focal

def create_camera(db, image_path, camera_model):
    """Add a camera for ``image_path`` to the database and return its id.

    The principal point is set to the image centre and the focal length comes
    from ``get_focal``.

    Args:
        db: Database object exposing ``add_camera``.
        image_path: Path of an image taken with the camera.
        camera_model: One of 'simple-pinhole', 'pinhole', 'simple-radial'
            or 'opencv'.

    Returns:
        Whatever ``db.add_camera`` returns (the new camera id).

    Raises:
        ValueError: If ``camera_model`` is not one of the supported names.
    """
    # Only the size is needed; close the file handle right away.
    with Image.open(image_path) as image:
        width, height = image.size

    focal = get_focal(image_path)

    if camera_model == 'simple-pinhole':
        model = 0 # simple pinhole
        param_arr = np.array([focal, width / 2, height / 2])
    elif camera_model == 'pinhole':
        model = 1 # pinhole
        param_arr = np.array([focal, focal, width / 2, height / 2])
    elif camera_model == 'simple-radial':
        model = 2 # simple radial
        param_arr = np.array([focal, width / 2, height / 2, 0.1])
    elif camera_model == 'opencv':
        model = 4 # opencv
        param_arr = np.array([focal, focal, width / 2, height / 2, 0., 0., 0., 0.])
    else:
        # Previously an unknown name surfaced as an UnboundLocalError below.
        raise ValueError(f'Unsupported camera model: {camera_model!r}')

    return db.add_camera(model, width, height, param_arr)


def add_keypoints(db, h5_path, image_path, img_ext, camera_model, single_camera = True):
    """Import every entry of ``keypoints.h5`` into the database.

    For each key in the file an image row is added (plus a camera row, either
    once or per image) and its keypoints are stored.

    Args:
        db: Database object exposing ``add_image`` and ``add_keypoints``.
        h5_path: Directory containing ``keypoints.h5``.
        image_path: Directory containing the image files; the h5 keys are
            used as file names as-is.
        img_ext: Unused; kept for interface compatibility.
        camera_model: Camera model name forwarded to ``create_camera``.
        single_camera: If True, one camera is shared by all images.

    Returns:
        Dict mapping each h5 key to the image id assigned by the database.

    Raises:
        IOError: If an h5 key has no matching file under ``image_path``.
    """
    camera_id = None
    fname_to_id = {}
    # The context manager closes the h5 file even if an image is missing.
    with h5py.File(os.path.join(h5_path, 'keypoints.h5'), 'r') as keypoint_f:
        for filename in tqdm(list(keypoint_f.keys())):
            keypoints = keypoint_f[filename][()]

            fname_with_ext = filename# + img_ext
            path = os.path.join(image_path, fname_with_ext)
            if not os.path.isfile(path):
                raise IOError(f'Invalid image path {path}')

            if camera_id is None or not single_camera:
                camera_id = create_camera(db, path, camera_model)
            image_id = db.add_image(fname_with_ext, camera_id)
            fname_to_id[filename] = image_id

            db.add_keypoints(image_id, keypoints)

    return fname_to_id

def add_matches(db, h5_path, fname_to_id):
    """Import every image pair stored in ``matches.h5`` into the database.

    The file is read as ``key_1 -> key_2 -> matches``. A pair whose pair id
    was already added is skipped with a warning.

    Args:
        db: Database object exposing ``add_matches``.
        h5_path: Directory containing ``matches.h5``.
        fname_to_id: Dict mapping h5 keys to database image ids.
    """
    added = set()

    # The context manager closes the h5 file even if a key lookup fails.
    with h5py.File(os.path.join(h5_path, 'matches.h5'), 'r') as match_file:
        # Count the stored pairs so the progress bar total is exact, instead
        # of guessing it from the number of top-level groups.
        n_total = sum(len(match_file[key_1]) for key_1 in match_file.keys())

        with tqdm(total=n_total) as pbar:
            for key_1 in match_file.keys():
                group = match_file[key_1]
                for key_2 in group.keys():
                    id_1 = fname_to_id[key_1]
                    id_2 = fname_to_id[key_2]

                    pair_id = image_ids_to_pair_id(id_1, id_2)
                    if pair_id in added:
                        warnings.warn(f'Pair {pair_id} ({id_1}, {id_2}) already added!')
                        pbar.update(1)
                        continue

                    matches = group[key_2][()]
                    db.add_matches(id_1, id_2, matches)

                    added.add(pair_id)

                    pbar.update(1)

In [None]:
# We will use global descriptors from timm backbones to get matching shortlists.
def get_global_desc(fnames, model, model2,
                    device =  torch.device('cpu')):
    """Compute one L2-normalised global descriptor per image.

    Each image is resized and centre-cropped to 600x600, normalised, and fed
    through ``forward_features`` of both models. The two feature maps are
    averaged over dims (-1, 2), flattened, concatenated and L2-normalised.

    Args:
        fnames: Iterable of image paths (must support ``len``).
        model: First backbone exposing ``forward_features``.
        model2: Second backbone exposing ``forward_features``.
        device: Device used for inference.

    Returns:
        CPU tensor with one row per image, in the order of ``fnames``.
    """
    model = model.eval().to(device)
    model2 = model2.eval().to(device)
    transform = transforms.Compose([
                transforms.Resize(600, interpolation=transforms.InterpolationMode.BICUBIC),
                transforms.CenterCrop(600),
                transforms.ToTensor(),
                transforms.Normalize([0.4850, 0.4560, 0.4060], [0.2290, 0.2240, 0.2250]),])
    global_descs = []
    for img_fname_full in tqdm(fnames, total=len(fnames)):
        # Close the file handle as soon as the pixels have been converted.
        with Image.open(img_fname_full) as pil_img:
            timg = transform(pil_img.convert('RGB')).unsqueeze(0).to(device)
        with torch.no_grad():
            desc = model.forward_features(timg).mean(dim=(-1,2)).view(1, -1)
            desc2 = model2.forward_features(timg).mean(dim=(-1,2)).view(1, -1)
            desc_norm = F.normalize(torch.cat([desc, desc2], dim=-1), dim=1, p=2)
        global_descs.append(desc_norm.detach().cpu())
    global_descs_all = torch.cat(global_descs, dim=0)
    return global_descs_all


def get_img_pairs_exhaustive(img_fnames):
    """Return every index pair ``(i, j)`` with ``i < j`` over ``img_fnames``."""
    count = len(img_fnames)
    return [(first, second)
            for first in range(count)
            for second in range(first + 1, count)]


def get_image_pairs_shortlist(fnames,
                              sim_th = 0.6, # should be strict
                              min_pairs = 20,
                              exhaustive_if_less = 20,
                              device=torch.device('cpu')):
    """Shortlist image pairs whose global descriptors are close.

    With at most ``exhaustive_if_less`` images every pair is returned.
    Otherwise each image (except the last) is paired with the images whose
    descriptor distance is at most ``sim_th``; if more than ``min_pairs``
    images qualify, only the ``min_pairs`` closest are kept.

    Args:
        fnames: List of image paths.
        sim_th: Maximum L2 distance between descriptors for a pair.
        min_pairs: Cap on the candidates kept per image.
        exhaustive_if_less: Image count up to which all pairs are returned.
        device: Device used to compute the global descriptors.

    Returns:
        Sorted list of unique ``(i, j)`` index tuples with ``i < j``.
    """
    num_imgs = len(fnames)

    if num_imgs <= exhaustive_if_less:
        return get_img_pairs_exhaustive(fnames)

    # Build the same two backbones as get_image_pairs_shortlist_nn: this
    # function used to reference an undefined `model` and to call
    # get_global_desc without its required second model.
    model = timm.create_model('tf_efficientnet_b7',
                              checkpoint_path='/kaggle/input/tf-efficientnet/pytorch/tf-efficientnet-b7/1/tf_efficientnet_b7_ra-6c08e654.pth')
    model2 = timm.create_model('tf_efficientnet_b6_ns',
                               checkpoint_path='/kaggle/input/imc-2023-timm-models/tf_efficientnet_b6_ns-51548356.pth')

    descs = get_global_desc(fnames, model, model2, device=device)
    dm = torch.cdist(descs, descs, p=2).detach().cpu().numpy()
    mask = dm <= sim_th
    ar = np.arange(num_imgs)
    matching = set()
    for st_idx in range(num_imgs-1):
        to_match = ar[mask[st_idx]]
        if len(to_match) > min_pairs:
            to_match = np.argsort(dm[st_idx])[:min_pairs]
        for idx in to_match:
            if st_idx == idx:
                continue
            # Kept from the original; this comparison also rejects NaN
            # distances.
            if dm[st_idx, idx] < 1000:
                matching.add(tuple(sorted((st_idx, idx.item()))))
    return sorted(matching)


def get_image_pairs_shortlist_nn(fnames,
                              exhaustive_if_less = 20,
                              nneighbor=6,
                              device=torch.device('cpu')):
    """Shortlist image pairs from the nearest neighbours of each image.

    With at most ``exhaustive_if_less`` (or ``nneighbor``) images every pair
    is returned. Otherwise each image is paired with its ``nneighbor`` most
    similar images by descriptor dot product; the image itself is dropped
    from that list when it appears in it.

    Args:
        fnames: List of image paths.
        exhaustive_if_less: Image count up to which all pairs are returned.
        nneighbor: Number of top-similarity entries taken per image.
        device: Device used to compute the global descriptors.

    Returns:
        Sorted list of unique ``(i, j)`` index tuples with ``i < j``.
    """
    num_imgs = len(fnames)
    if num_imgs <= exhaustive_if_less or num_imgs <= nneighbor :
        return get_img_pairs_exhaustive(fnames)

    model = timm.create_model('tf_efficientnet_b7',
                             checkpoint_path='/kaggle/input/tf-efficientnet/pytorch/tf-efficientnet-b7/1/tf_efficientnet_b7_ra-6c08e654.pth')
    model.eval()
    model2 = timm.create_model('tf_efficientnet_b6_ns',
                             checkpoint_path='/kaggle/input/imc-2023-timm-models/tf_efficientnet_b6_ns-51548356.pth')
    model2.eval()

    descs = get_global_desc(fnames, model, model2, device=device)
    dm = torch.einsum('bi,ki->bk', descs, descs).detach().cpu()
    print(f"sim shape: {dm.shape}")

    val, indices = torch.topk(dm, k=nneighbor, dim=1)
    if DEBUG_FLAG:
        print(f"sim: {val[:3]}")
        print(f"nn: {indices[:3]}")

    matching = set()
    # Nearest-neighbour lists are not symmetric, so every image (including
    # the last one, which used to be skipped) contributes its own list.
    for st_idx, neighbours in enumerate(indices.tolist()):
        for t in neighbours:
            if t == st_idx:
                continue
            matching.add(tuple(sorted((st_idx, t))))
    return sorted(matching)

In [None]:
# Making kornia local features loading w/o internet
class KeyNetAffNetHardNet(KF.LocalFeature):
    """KeyNet detector + AffNet + HardNet descriptor with weights read from
    local files, so that it can be built without internet access.

    .. image:: _static/img/keynet_affnet.jpg
    """

    def __init__(
        self,
        num_features: int = 5000,
        upright: bool = False,
        device = torch.device('cpu'),
        scale_laf: float = 1.0,
    ):
        # Orientation: pass-through for upright features, OriNet otherwise.
        if upright:
            ori_module = KF.PassLAF()
        else:
            ori_module = KF.LAFOrienter(angle_detector=KF.OriNet(False)).eval()
            ori_state = torch.load('/kaggle/input/kornia-local-feature-weights/OriNet.pth')['state_dict']
            ori_module.angle_detector.load_state_dict(ori_state)

        # Detector: KeyNet with the AffNet shape estimator.
        aff_module = KF.LAFAffNetShapeEstimator(False).eval()
        detector = KF.KeyNetDetector(
            False,
            num_features=num_features,
            ori_module=ori_module,
            aff_module=aff_module,
        ).to(device)
        keynet_state = torch.load('/kaggle/input/kornia-local-feature-weights/keynet_pytorch.pth')['state_dict']
        detector.model.load_state_dict(keynet_state)
        affnet_state = torch.load('/kaggle/input/kornia-local-feature-weights/AffNet.pth')['state_dict']
        detector.aff.load_state_dict(affnet_state)

        # Descriptor: HardNet evaluated on 32x32 grayscale patches.
        hardnet = KF.HardNet(False).eval()
        hardnet_state = torch.load('/kaggle/input/kornia-local-feature-weights/HardNetLib.pth')['state_dict']
        hardnet.load_state_dict(hardnet_state)
        descriptor = KF.LAFDescriptor(hardnet, patch_size=32, grayscale_descriptor=True).to(device)

        super().__init__(detector, descriptor, scale_laf)

In [None]:
def detect_features(img_fnames,
                    num_feats = 2048,
                    upright = False,
                    device=torch.device('cpu'),
                    feature_dir = '.featureout',
                    resize_small_edge_to = 600):
    """Detect local features and write them to h5 files in ``feature_dir``.

    Writes ``lafs.h5``, ``keypoints.h5`` and ``descriptors.h5``, keyed by
    image file name. Keypoints are detected on a resized copy of each image
    and rescaled to the original image resolution.

    Args:
        img_fnames: Iterable of image paths.
        num_feats: Number of features requested from the detector.
        upright: Forwarded to ``KeyNetAffNetHardNet``.
        device: Device used for inference.
        feature_dir: Output directory (created if missing).
        resize_small_edge_to: Size passed to ``K.geometry.resize``; None
            disables resizing.

    Raises:
        ValueError: If the global ``LOCAL_FEATURE`` is neither 'DISK' nor
            'KeyNetAffNetHardNet'.
    """
    # Fail early: any other value used to end in a NameError on `lafs` in the
    # middle of the loop, after the output files had already been created.
    if LOCAL_FEATURE not in ('DISK', 'KeyNetAffNetHardNet'):
        raise ValueError(
            f'detect_features does not support LOCAL_FEATURE={LOCAL_FEATURE!r}')

    if LOCAL_FEATURE == 'DISK':
        # Load DISK from Kaggle models so it can run when the notebook is offline.
        disk = KF.DISK().to(device)
        pretrained_dict = torch.load('/kaggle/input/disk/pytorch/depth-supervision/1/loftr_outdoor.ckpt', map_location=device)
        disk.load_state_dict(pretrained_dict['extractor'])
        disk.eval()
    else:
        feature = KeyNetAffNetHardNet(num_feats, upright, device).to(device).eval()

    if not os.path.isdir(feature_dir):
        os.makedirs(feature_dir)
    with h5py.File(f'{feature_dir}/lafs.h5', mode='w') as f_laf, \
         h5py.File(f'{feature_dir}/keypoints.h5', mode='w') as f_kp, \
         h5py.File(f'{feature_dir}/descriptors.h5', mode='w') as f_desc:
        for img_path in progress_bar(img_fnames):
            key = img_path.split('/')[-1]
            with torch.inference_mode():
                loaded = load_torch_image(img_path, device=device)
                if isinstance(loaded, tuple):
                    # A later cell redefines load_torch_image to return
                    # (tensor, original_shape); take the original height and
                    # width from that shape so the keypoints still end up in
                    # original-image coordinates.
                    timg, orig_shape = loaded
                    H, W = orig_shape[:2]
                else:
                    timg = loaded
                    H, W = timg.shape[2:]
                if resize_small_edge_to is None:
                    timg_resized = timg
                else:
                    timg_resized = K.geometry.resize(timg, resize_small_edge_to, antialias=True)
                    print(f'Resized {timg.shape} to {timg_resized.shape} (resize_small_edge_to={resize_small_edge_to})')
                h, w = timg_resized.shape[2:]
                if LOCAL_FEATURE == 'DISK':
                    features = disk(timg_resized, num_feats, pad_if_not_divisible=True)[0]
                    kps1, descs = features.keypoints, features.descriptors

                    lafs = KF.laf_from_center_scale_ori(kps1[None], torch.ones(1, len(kps1), 1, 1, device=device))
                else:
                    lafs, resps, descs = feature(K.color.rgb_to_grayscale(timg_resized))
                # Map the frames from the resized image to the original one.
                lafs[:,:,0,:] *= float(W) / float(w)
                lafs[:,:,1,:] *= float(H) / float(h)
                desc_dim = descs.shape[-1]
                kpts = KF.get_laf_center(lafs).reshape(-1, 2).detach().cpu().numpy()
                descs = descs.reshape(-1, desc_dim).detach().cpu().numpy()
                f_laf[key] = lafs.detach().cpu().numpy()
                f_kp[key] = kpts
                f_desc[key] = descs
    return

def get_unique_idxs(A, dim=0):
    """Return the index of the first occurrence of each unique entry of ``A``.

    The indices are ordered by the sorted unique values, not by position.

    Args:
        A: Input tensor.
        dim: Dimension along which entries are compared.

    Returns:
        1-D long tensor of first-occurrence indices; empty when ``A`` has no
        entries along ``dim`` (this used to raise IndexError).
    """
    # https://stackoverflow.com/questions/72001505/how-to-get-unique-elements-and-their-firstly-appeared-indices-of-a-pytorch-tenso
    if A.shape[dim] == 0:
        return torch.empty(0, dtype=torch.long, device=A.device)
    _, inverse, counts = torch.unique(A, dim=dim, sorted=True, return_inverse=True, return_counts=True)
    # A stable sort groups the positions by unique value and keeps their
    # original order inside each group, so each group starts with the first
    # occurrence.
    _, ind_sorted = torch.sort(inverse, stable=True)
    # Exclusive cumulative sum of the counts = start offset of each group.
    group_starts = counts.cumsum(0) - counts
    return ind_sorted[group_starts]

def match_features(img_fnames,
                   index_pairs,
                   feature_dir = '.featureout',
                   device=torch.device('cpu'),
                   min_matches=15, 
                   force_mutual = True,
                   matching_alg='smnn'
                  ):
    """Match stored descriptors for each index pair and write ``matches.h5``.

    Reads ``lafs.h5`` and ``descriptors.h5`` from ``feature_dir`` and matches
    each pair with SMNN or AdaLAM. When ``force_mutual`` is set, only the
    first match of every second-image keypoint is kept. Pairs with at least
    ``min_matches`` matches are stored under ``key1/key2``.
    """
    assert matching_alg in ['smnn', 'adalam']
    use_adalam = matching_alg == 'adalam'

    with h5py.File(f'{feature_dir}/lafs.h5', mode='r') as f_laf, \
         h5py.File(f'{feature_dir}/descriptors.h5', mode='r') as f_desc, \
         h5py.File(f'{feature_dir}/matches.h5', mode='w') as f_match:

        for idx1, idx2 in progress_bar(index_pairs):
            fname1 = img_fnames[idx1]
            fname2 = img_fnames[idx2]
            key1 = fname1.split('/')[-1]
            key2 = fname2.split('/')[-1]

            lafs1 = torch.from_numpy(f_laf[key1][...]).to(device)
            lafs2 = torch.from_numpy(f_laf[key2][...]).to(device)
            desc1 = torch.from_numpy(f_desc[key1][...]).to(device)
            desc2 = torch.from_numpy(f_desc[key2][...]).to(device)

            if use_adalam:
                # AdaLAM uses the frames and the image sizes in addition to
                # the descriptors.
                img1, img2 = cv2.imread(fname1), cv2.imread(fname2)
                hw1, hw2 = img1.shape[:2], img2.shape[:2]
                adalam_config = KF.adalam.get_adalam_default_config()
                adalam_config['force_seed_mnn']= False
                adalam_config['search_expansion'] = 16
                adalam_config['ransac_iters'] = 128
                adalam_config['device'] = device
                dists, idxs = KF.match_adalam(desc1, desc2,
                                              lafs1, lafs2,
                                              hw1=hw1, hw2=hw2,
                                              config=adalam_config)
            else:
                dists, idxs = KF.match_smnn(desc1, desc2, 0.98)

            if len(idxs)  == 0:
                continue

            if force_mutual:
                # Keep a single match per keypoint of the second image.
                keep = get_unique_idxs(idxs[:,1])
                idxs = idxs[keep]
                dists = dists[keep]

            # The group is created even when the pair ends up below the
            # threshold and nothing is stored in it.
            group = f_match.require_group(key1)
            if len(idxs) >= min_matches:
                group.create_dataset(key2, data=idxs.detach().cpu().numpy().reshape(-1, 2))
    return

In [None]:
def load_torch_image(fname, device, resize=True, img_max_size=1280):
    """Load an image as a float RGB tensor whose sides are multiples of 8.

    Args:
        fname: Path of the image file, read with ``cv2.imread``.
        device: Device the tensor is moved to.
        resize: If True, scale the image so its longer side equals
            ``img_max_size`` before rounding the sides down to multiples
            of 8.
        img_max_size: Target length of the longer side; a value <= 0
            disables the scaling.

    Returns:
        Tuple ``(tensor, original_shape)`` where ``tensor`` holds values in
        [0, 1] and ``original_shape`` is the shape of the array returned by
        ``cv2.imread``.

    Raises:
        IOError: If OpenCV cannot read ``fname``.
    """
    img = cv2.imread(fname)
    if img is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise IOError(f'Cannot read image {fname}')
    img_sz = img.shape
    if img_max_size <= 0:
        resize = False
    scale = img_max_size / max(img.shape[0], img.shape[1]) if resize else 1
    w = int(img.shape[1] * scale)
    h = int(img.shape[0] * scale)
    # Round down to multiples of 8, but never to 0: a very thin image would
    # otherwise make cv2.resize fail on an empty target size.
    target_size = (max(8, (w // 8) * 8), max(8, (h // 8) * 8))
    img = cv2.resize(img, target_size)
    img = K.image_to_tensor(img, False).float() /255.
    img = K.color.bgr_to_rgb(img)
    return img.to(device), img_sz

In [None]:
# Device used as the default by match_same_size and match_func.
DEVICE = device
# Longest-edge sizes at which match_func runs the matcher for every pair.
IMG_MAX_SIZE = [840, 640]
# When True, match_same_size also matches horizontally flipped copies.
DO_FLIP = False
# Maximum number of matches match_func keeps per pair (highest confidence first).
MAX_NUM_PAIRS = 6000

In [None]:
def match_same_size(img_path0, img_path1, matcher, device=DEVICE, flip=False, img_max_size=1280):
    """Match two images at one resolution and return original-scale keypoints.

    Both images are loaded with ``load_torch_image`` (longer side limited by
    ``img_max_size``), converted to grayscale and passed to ``matcher``.
    With ``flip`` set, a horizontally flipped copy of each image is stacked
    in the batch and the x coordinates of its matches are mirrored back.

    Returns:
        Dict with "keypoints0", "keypoints1" (coordinates in the original
        images) and "confidence", all as NumPy arrays.
    """
    def _to_original_scale(kpts, orig_shape, net_shape):
        # Rescale x/y in place from the network input size to the image size.
        sy, sx = orig_shape[0] / net_shape[0], orig_shape[1] / net_shape[1]
        kpts[:, 0] = kpts[:, 0] * sx
        kpts[:, 1] = kpts[:, 1] * sy

    img0, source_img_shape = load_torch_image(img_path0, device, img_max_size=img_max_size)
    img1, target_img_shape = load_torch_image(img_path1, device, img_max_size=img_max_size)

    input_dict = {
        "image0": K.color.rgb_to_grayscale(img0).to(device),
        "image1": K.color.rgb_to_grayscale(img1).to(device),
    }
    if flip:
        for name in ("image0", "image1"):
            plain = input_dict[name]
            input_dict[name] = torch.cat((plain, torch.flip(plain, dims=[3])), dim=0)

    with torch.no_grad():
        correspondences = matcher(input_dict)

    # https://github.com/kornia/kornia/blob/master/kornia/feature/loftr/loftr.py
    keypoints0 = correspondences["keypoints0"].cpu().numpy()
    keypoints1 = correspondences["keypoints1"].cpu().numpy()
    confidence = correspondences["confidence"].cpu().numpy()
    batch_indexes = correspondences["batch_indexes"].cpu().numpy()

    if flip:
        # Batch entry 1 is the flipped copy: mirror its x coordinates back.
        from_flipped = (batch_indexes == 1)
        keypoints0[from_flipped, 0] = input_dict["image0"].shape[-1] - keypoints0[from_flipped, 0]
        keypoints1[from_flipped, 0] = input_dict["image1"].shape[-1] - keypoints1[from_flipped, 0]

    _to_original_scale(keypoints0, source_img_shape, input_dict["image0"].shape[-2:])
    _to_original_scale(keypoints1, target_img_shape, input_dict["image1"].shape[-2:])

    torch.cuda.empty_cache()
    return {
        "keypoints0": keypoints0,
        "keypoints1":  keypoints1,
        "confidence": confidence
    }

In [None]:
def match_func(img_path0, img_path1, matcher, device=DEVICE):
    """Match an image pair at every size in ``IMG_MAX_SIZE`` and merge.

    The matches of all sizes are concatenated and the ``MAX_NUM_PAIRS`` most
    confident ones are kept.

    Args:
        img_path0: Path of the first image.
        img_path1: Path of the second image.
        matcher: Callable matcher forwarded to ``match_same_size``.
        device: Device used for matching (previously ignored in favour of
            the global ``DEVICE``, which is still the default).

    Returns:
        Tuple ``(mkpts0, mkpts1)`` of matched keypoint arrays, ordered by
        decreasing confidence.
    """
    per_size = [
        match_same_size(img_path0, img_path1, matcher, device=device,
                        flip=DO_FLIP, img_max_size=img_size)
        for img_size in IMG_MAX_SIZE
    ]
    mkpts0 = np.concatenate([pred["keypoints0"] for pred in per_size], axis=0)
    mkpts1 = np.concatenate([pred["keypoints1"] for pred in per_size], axis=0)
    confidence = np.concatenate([pred["confidence"] for pred in per_size], axis=0)

    # Keep only the most confident matches.
    keep = np.argsort(-confidence)[:MAX_NUM_PAIRS]
    return mkpts0[keep], mkpts1[keep]

In [None]:
import sys
from copy import deepcopy
sys.path.append('/kaggle/input/imc2023-dependencies-data/DKM/')

In [None]:
!mkdir -p pretrained/checkpoints
!cp /kaggle/input/imc2023-dependencies-data/pretrained/dkm.pth pretrained/checkpoints/dkm_base_v11.pth

!pip install -f /kaggle/input/imc2023-dependencies-data/wheels --no-index einops
!cp -r /kaggle/input/imc2023-dependencies-data/DKM/ /kaggle/working/DKM/
!cd /kaggle/working/DKM/; pip install -f /kaggle/input/imc2023-dependencies-data/wheels -e . 

In [None]:
# Point torch.hub at a local directory; presumably this lets
# dkm_base(pretrained=True) pick up a locally stored checkpoint without
# internet access (TODO confirm against the DKM package).
torch.hub.set_dir('/kaggle/working/pretrained/')
from dkm import dkm_base
# DKM "v11" base matcher, moved to `device` and switched to eval mode.
DKM_matcher = dkm_base(pretrained=True, version="v11").to(device).eval()

In [None]:
def match_loftr(img_fnames,
                   index_pairs,
                   feature_dir = '.featureout_loftr',
                   device=torch.device('cpu'),
                   min_matches=15, resize_to_ = (640, 480)):
    """Match image pairs and write per-image keypoints and pair matches to HDF5.

    Despite its name, this function matches with the module-level
    ``super_glue_matcher`` (SuperPoint + SuperGlue), not LoFTR.

    Three files are written into ``feature_dir``:

    * ``matches_loftr.h5`` - raw matches per pair, one ``(N, 4)`` array
      ``[x0, y0, x1, y1]`` stored under ``[key1][key2]``.
    * ``keypoints.h5`` - per image, the unique (rounded) keypoint coordinates.
    * ``matches.h5`` - per pair, an ``(M, 2)`` array of indices into the two
      images' unique keypoint lists.

    Args:
        img_fnames: List of image paths; the last path component is the key.
        index_pairs: Iterable of ``(idx1, idx2)`` indices into ``img_fnames``.
        feature_dir: Existing directory receiving the HDF5 files.
        device: Device handed to ``read_image`` for the input tensors.
        min_matches: Pairs with fewer matches than this are not stored.
        resize_to_: Unused; kept so existing callers keep working.
    """
    # First we do pairwise matching, and then extract "keypoints" from the matches.
    with h5py.File(f'{feature_dir}/matches_loftr.h5', mode='w') as f_match:
        for idx1, idx2 in progress_bar(index_pairs):
            fname1, fname2 = img_fnames[idx1], img_fnames[idx2]
            key1, key2 = fname1.split('/')[-1], fname2.split('/')[-1]

            # NOTE(review): resize=[-1] presumably means "keep native size" in
            # models.utils.read_image - confirm; returned scales are ignored.
            _, inp_1, _ = read_image(fname1, device, [-1,], 0, resize_float)
            _, inp_2, _ = read_image(fname2, device, [-1,], 0, resize_float)

            # Pure inference: do not build an autograd graph for every pair.
            with torch.no_grad():
                pred = super_glue_matcher({"image0": inp_1, "image1": inp_2})
            pred = {k: v[0].detach().cpu().numpy() for k, v in pred.items()}
            kpts1, kpts2 = pred["keypoints0"], pred["keypoints1"]
            matches = pred["matches0"]

            # matches[i] is the index in kpts2 matched to kpts1[i], or -1 if none.
            valid = matches > -1
            mkpts0 = kpts1[valid]
            mkpts1 = kpts2[matches[valid]]

            group = f_match.require_group(key1)
            if len(mkpts1) >= min_matches:
                group.create_dataset(key2, data=np.concatenate([mkpts0, mkpts1], axis=1))

    # Gather every matched pixel per image and remember, for each pair, where
    # its matches sit inside the per-image concatenated keypoint lists.
    kpts = defaultdict(list)
    match_indexes = defaultdict(dict)
    total_kpts = defaultdict(int)
    with h5py.File(f'{feature_dir}/matches_loftr.h5', mode='r') as f_match:
        for k1 in f_match.keys():
            group = f_match[k1]
            for k2 in group.keys():
                matches = group[k2][...]
                kpts[k1].append(matches[:, :2])
                kpts[k2].append(matches[:, 2:])
                # Column 0 indexes into image k1's list, column 1 into k2's.
                current_match = torch.arange(len(matches)).reshape(-1, 1).repeat(1, 2)
                current_match[:, 0] += total_kpts[k1]
                current_match[:, 1] += total_kpts[k2]
                total_kpts[k1] += len(matches)
                total_kpts[k2] += len(matches)
                match_indexes[k1][k2] = current_match

    # Snap coordinates to whole pixels so nearby detections collapse together.
    for k in kpts.keys():
        kpts[k] = np.round(np.concatenate(kpts[k], axis=0))

    # Deduplicate keypoints per image; the inverse index maps every original
    # keypoint position to its row in the unique list.
    unique_kpts = {}
    unique_match_idxs = {}
    out_match = defaultdict(dict)
    for k in kpts.keys():
        uniq_kps, uniq_reverse_idxs = torch.unique(torch.from_numpy(kpts[k]), dim=0, return_inverse=True)
        unique_match_idxs[k] = uniq_reverse_idxs
        unique_kpts[k] = uniq_kps.numpy()

    for k1, group in match_indexes.items():
        for k2, m in group.items():
            # Re-express the pair's matches as indices into the unique lists.
            m2 = deepcopy(m)
            m2[:, 0] = unique_match_idxs[k1][m2[:, 0]]
            m2[:, 1] = unique_match_idxs[k2][m2[:, 1]]
            mkpts = np.concatenate([unique_kpts[k1][m2[:, 0]],
                                    unique_kpts[k2][m2[:, 1]],
                                   ],
                                   axis=1)
            # Drop duplicated matches, then make the mapping one-to-one on
            # both sides (each keypoint used at most once per pair).
            unique_idxs_current = get_unique_idxs(torch.from_numpy(mkpts), dim=0)
            m2_semiclean = m2[unique_idxs_current]
            unique_idxs_current1 = get_unique_idxs(m2_semiclean[:, 0], dim=0)
            m2_semiclean = m2_semiclean[unique_idxs_current1]
            unique_idxs_current2 = get_unique_idxs(m2_semiclean[:, 1], dim=0)
            m2_semiclean2 = m2_semiclean[unique_idxs_current2]
            out_match[k1][k2] = m2_semiclean2.numpy()

    with h5py.File(f'{feature_dir}/keypoints.h5', mode='w') as f_kp:
        for k, kpts1 in unique_kpts.items():
            f_kp[k] = kpts1

    with h5py.File(f'{feature_dir}/matches.h5', mode='w') as f_match:
        for k1, gr in out_match.items():
            group = f_match.require_group(k1)
            for k2, match in gr.items():
                group[k2] = match

In [None]:
def import_into_colmap(img_dir,
                       feature_dir ='.featureout',
                       database_path = 'colmap.db',
                       img_ext='.jpg'):
    """Create a COLMAP database and load the stored keypoints and matches into it.

    Opens ``database_path`` through ``COLMAPDatabase``, creates its tables,
    registers keypoints with ``add_keypoints`` (camera model
    ``'simple-radial'``, ``single_camera`` disabled) and then the matches with
    ``add_matches`` before committing. Returns ``None``.
    """
    database = COLMAPDatabase.connect(database_path)
    database.create_tables()
    # Last positional argument is the single_camera flag (kept disabled).
    image_ids = add_keypoints(database, feature_dir, img_dir, img_ext, 'simple-radial', False)
    add_matches(database, feature_dir, image_ids)
    database.commit()

In [None]:
def construct_local_validation():
    """Write a small validation CSV built from the training labels.

    Takes the first 100 rows of the competition's ``train_labels.csv``, keeps
    the five submission columns and saves them, without index, as
    ``/kaggle/working/sample_submission.csv``.
    """
    columns = ["image_path", "dataset", "scene", "rotation_matrix", "translation_vector"]
    labels = pd.read_csv("/kaggle/input/image-matching-challenge-2023/train/train_labels.csv")
    subset = labels.head(100)[columns]
    subset.to_csv("/kaggle/working/sample_submission.csv", index=False)

In [None]:
# from https://www.kaggle.com/code/eduardtrulls/imc2023-evaluation
# Evaluation metric.

@dataclass
class Camera:
    """Pose of one image as read from a submission / ground-truth CSV."""
    # Rotation matrix; dict_from_csv stores it reshaped to (3, 3).
    rotmat: np.ndarray
    # Translation vector parsed from the ';'-separated CSV field.
    tvec: np.ndarray

def quaternion_from_matrix(matrix):
    """Return the quaternion (scalar component first) of a rotation matrix.

    The quaternion is the eigenvector, for the largest eigenvalue, of a
    symmetric 4x4 matrix built from the rotation entries. Its sign is chosen
    so that the first (scalar) component is non-negative.

    Args:
        matrix: Array-like with at least a 3x3 rotation in its top-left corner.

    Returns:
        1-D array of four floats.
    """
    # np.asarray copies only when needed; np.array(..., copy=False) raises
    # under NumPy 2 whenever a copy is unavoidable (lists, dtype conversion).
    M = np.asarray(matrix, dtype=np.float64)[:4, :4]
    m00 = M[0, 0]
    m01 = M[0, 1]
    m02 = M[0, 2]
    m10 = M[1, 0]
    m11 = M[1, 1]
    m12 = M[1, 2]
    m20 = M[2, 0]
    m21 = M[2, 1]
    m22 = M[2, 2]

    # Symmetric matrix K. Only the lower triangle is filled in, which is the
    # part np.linalg.eigh reads by default.
    K = np.array([[m00 - m11 - m22, 0.0, 0.0, 0.0],
                  [m01 + m10, m11 - m00 - m22, 0.0, 0.0],
                  [m02 + m20, m12 + m21, m22 - m00 - m11, 0.0],
                  [m21 - m12, m02 - m20, m10 - m01, m00 + m11 + m22]])
    K /= 3.0

    # Quaternion is eigenvector of K that corresponds to largest eigenvalue.
    w, V = np.linalg.eigh(K)
    # Reorder so the last eigenvector component comes first.
    q = V[[3, 0, 1, 2], np.argmax(w)]

    # q and -q describe the same rotation; keep the first component >= 0.
    if q[0] < 0.0:
        np.negative(q, q)
    return q

def evaluate_R_t(R_gt, t_gt, R, t, eps=1e-15):
    """Return (rotation error in degrees, translation error) of a pose estimate.

    The rotation error is derived from the unit quaternions of ``R_gt`` and
    ``R``. For the translation error the estimate ``t`` is first rescaled to
    the norm of ``t_gt``; the smaller of the distances to ``t_gt`` for
    ``+t`` and ``-t`` is returned.
    """
    t_est = t.flatten()
    t_ref = t_gt.flatten()

    # Rotation: compare normalised quaternions.
    quat_ref = quaternion_from_matrix(R_gt)
    quat_est = quaternion_from_matrix(R)
    quat_est = quat_est / (np.linalg.norm(quat_est) + eps)
    quat_ref = quat_ref / (np.linalg.norm(quat_ref) + eps)
    # Clamp from below so the arccos argument stays strictly under 1.
    loss_q = np.maximum(eps, (1.0 - np.sum(quat_est * quat_ref)**2))
    err_q = np.arccos(1 - 2 * loss_q)

    # Translation: bring the estimate to the ground-truth scale first.
    ref_scale = np.linalg.norm(t_ref)
    t_est = ref_scale * (t_est / (np.linalg.norm(t_est) + eps))
    dist_same_sign = np.linalg.norm(t_ref - t_est)
    dist_flipped = np.linalg.norm(t_ref + t_est)
    err_t = min(dist_same_sign, dist_flipped)

    return np.degrees(err_q), err_t

def compute_dR_dT(R1, T1, R2, T2):
    '''Relative pose of the second camera with respect to the first.

    Given absolute (R, T) of two cameras, returns (dR, dT) with
    dR = R2 * R1^T and dT = T2 - dR * T1.
    '''
    rel_rotation = np.dot(R2, R1.T)
    rel_translation = T2 - np.dot(rel_rotation, T1)
    return rel_rotation, rel_translation

def compute_mAA(err_q, err_t, ths_q, ths_t):
    '''Compute the accuracy at each (rotation, translation) threshold pair.

    Args:
        err_q: Rotation errors, one per image pair (any array-like).
        err_t: Translation errors, same length as err_q (any array-like).
        ths_q: Rotation thresholds.
        ths_t: Translation thresholds, paired element-wise with ths_q.

    Returns:
        Three arrays, one entry per threshold pair: the fraction of pairs
        within both thresholds, within the rotation threshold only, and
        within the translation threshold only.
    '''
    # Callers pass plain lists; coerce once so the comparisons below are
    # element-wise whatever the threshold type (NumPy scalar or Python float).
    err_q = np.asarray(err_q)
    err_t = np.asarray(err_t)

    acc, acc_q, acc_t = [], [], []
    for th_q, th_t in zip(ths_q, ths_t):
        cur_acc_q = (err_q <= th_q)
        cur_acc_t = (err_t <= th_t)
        # A pair counts only when both rotation and translation are in range.
        cur_acc = cur_acc_q & cur_acc_t

        acc.append(cur_acc.astype(np.float32).mean())
        acc_q.append(cur_acc_q.astype(np.float32).mean())
        acc_t.append(cur_acc_t.astype(np.float32).mean())
    return np.array(acc), np.array(acc_q), np.array(acc_t)

def dict_from_csv(csv_path, has_header):
    """Parse a submission-style CSV into ``{dataset: {scene: {image: Camera}}}``.

    Each data row has five comma-separated fields: image path, dataset, scene,
    rotation matrix (9 values separated by ';') and translation vector
    (values separated by ';').

    Args:
        csv_path: Path of the CSV file.
        has_header: When true, the first line is skipped.

    Returns:
        Nested dict keyed by dataset, scene and image path.

    Raises:
        ValueError: If a non-blank row does not have exactly five fields or
            its rotation field does not hold nine values.
    """
    csv_dict = {}
    with open(csv_path, 'r') as f:
        for i, l in enumerate(f):
            if has_header and i == 0:
                continue
            # Lines keep their trailing newline, so test the stripped text:
            # a blank line would otherwise fail the five-field unpacking.
            l = l.strip()
            if not l:
                continue
            image, dataset, scene, R_str, T_str = l.split(',')
            R = np.fromstring(R_str.strip(), sep=';').reshape(3, 3)
            T = np.fromstring(T_str.strip(), sep=';')
            scenes = csv_dict.setdefault(dataset, {})
            scenes.setdefault(scene, {})[image] = Camera(rotmat=R, tvec=T)
    return csv_dict

def eval_submission(submission_csv_path, ground_truth_csv_path, rotation_thresholds_degrees_dict, translation_thresholds_meters_dict, verbose=False):
    '''Score a submission CSV against a ground-truth CSV.

    For every scene, all image pairs are compared through their relative
    poses; the per-scene mAA is averaged per dataset and the per-dataset
    values are averaged into the returned metric. Thresholds are looked up
    with the key (dataset, scene). Asserts that every ground-truth
    dataset/scene/image is present in the submission.
    '''

    submission_dict = dict_from_csv(submission_csv_path, has_header=True)
    gt_dict = dict_from_csv(ground_truth_csv_path, has_header=True)

    # The submission must cover everything the ground truth lists.
    for dataset, gt_scenes in gt_dict.items():
        assert dataset in submission_dict, f'Unknown dataset: {dataset}'
        for scene, gt_cameras in gt_scenes.items():
            assert scene in submission_dict[dataset], f'Unknown scene: {dataset}->{scene}'
            for image in gt_cameras:
                assert image in submission_dict[dataset][scene], f'Unknown image: {dataset}->{scene}->{image}'

    if verbose:
        t = time()
        print('*** METRICS ***')

    metrics_per_dataset = []
    for dataset, gt_scenes in gt_dict.items():
        metrics_per_scene = []
        for scene, gt_cameras in gt_scenes.items():
            pred_cameras = submission_dict[dataset][scene]
            images = list(gt_cameras)
            err_q_all, err_t_all = [], []

            # Every unordered pair (i < j) of images in the scene.
            for i, name_i in enumerate(images):
                for name_j in images[i + 1:]:
                    gt_i, gt_j = gt_cameras[name_i], gt_cameras[name_j]
                    dR_gt, dT_gt = compute_dR_dT(gt_i.rotmat, gt_i.tvec, gt_j.rotmat, gt_j.tvec)

                    pred_i, pred_j = pred_cameras[name_i], pred_cameras[name_j]
                    dR_pred, dT_pred = compute_dR_dT(pred_i.rotmat, pred_i.tvec, pred_j.rotmat, pred_j.tvec)

                    err_q, err_t = evaluate_R_t(dR_gt, dT_gt, dR_pred, dT_pred)
                    err_q_all.append(err_q)
                    err_t_all.append(err_t)

            scene_key = (dataset, scene)
            mAA, mAA_q, mAA_t = compute_mAA(err_q=err_q_all,
                                            err_t=err_t_all,
                                            ths_q=rotation_thresholds_degrees_dict[scene_key],
                                            ths_t=translation_thresholds_meters_dict[scene_key])
            if verbose:
                print(f'{dataset} / {scene} ({len(images)} images, {len(err_q_all)} pairs) -> mAA={np.mean(mAA):.06f}, mAA_q={np.mean(mAA_q):.06f}, mAA_t={np.mean(mAA_t):.06f}')
            metrics_per_scene.append(np.mean(mAA))

        metrics_per_dataset.append(np.mean(metrics_per_scene))
        if verbose:
            print(f'{dataset} -> mAA={np.mean(metrics_per_scene):.06f}')
            print()

    if verbose:
        print(f'Final metric -> mAA={np.mean(metrics_per_dataset):.06f} (t: {time() - t} sec.)')
        print()

    return np.mean(metrics_per_dataset)

In [None]:
# Rotation thresholds (degrees) per (dataset, scene): ten values evenly
# spaced from the listed lower bound up to 10.
rotation_thresholds_degrees_dict = {
    (dataset, scene): np.linspace(lower, 10, 10)
    for dataset, scene, lower in [
        ('haiper', 'bike', 1),
        ('haiper', 'chairs', 1),
        ('haiper', 'fountain', 1),
        ('heritage', 'cyprus', 1),
        ('heritage', 'dioscuri', 1),
        ('heritage', 'wall', 0.2),
        ('urban', 'kyiv-puppet-theater', 1),
    ]
}

# Translation thresholds (meters) per (dataset, scene): ten geometrically
# spaced values between the listed bounds.
translation_thresholds_meters_dict = {
    (dataset, scene): np.geomspace(lower, upper, 10)
    for dataset, scene, lower, upper in [
        ('haiper', 'bike', 0.05, 0.5),
        ('haiper', 'chairs', 0.05, 0.5),
        ('haiper', 'fountain', 0.05, 0.5),
        ('heritage', 'cyprus', 0.1, 2),
        ('heritage', 'dioscuri', 0.1, 2),
        ('heritage', 'wall', 0.05, 1),
        ('urban', 'kyiv-puppet-theater', 0.5, 5),
    ]
}

In [None]:
# Competition data root. Normal runs read the official sample submission and
# the test images; debug runs build a local validation CSV from the training
# labels and read the training images instead.
src = '/kaggle/input/image-matching-challenge-2023'
if not DEBUG_FLAG:
    sample_submission_file = '/kaggle/input/image-matching-challenge-2023/sample_submission.csv'
    img_base_dir = os.path.join(src, "test")
else:
    construct_local_validation()
    sample_submission_file = '/kaggle/working/sample_submission.csv'
    img_base_dir = os.path.join(src, "train")

In [None]:
# Get data from csv: data_dict[dataset][scene] is the list of image paths,
# in file order.
data_dict = {}
with open(sample_submission_file, 'r') as f:
    for i, l in enumerate(f):
        # Lines keep their newline, so test the stripped text; a blank line
        # would otherwise fail the five-field unpacking below.
        l = l.strip()
        # Skip header and blank lines.
        if i == 0 or not l:
            continue
        image, dataset, scene, _, _ = l.split(',')
        data_dict.setdefault(dataset, {}).setdefault(scene, []).append(image)

In [None]:
# Report how many images each scene contains.
for dataset, scenes in data_dict.items():
    for scene, scene_images in scenes.items():
        print(f'{dataset} / {scene} -> {len(scene_images)} images')

In [None]:
# Poses collected per dataset and scene, filled in by the main loop.
out_results = {}
# Wall-clock seconds spent in each pipeline stage, one entry per scene.
timings = {
    stage: []
    for stage in ("shortlisting",
                  "feature_detection",
                  "feature_matching",
                  "RANSAC",
                  "Reconstruction")
}

In [None]:
# Function to create a submission file.
def create_submission(out_results, data_dict):
    """Write ``submission.csv`` with one pose row per image in ``data_dict``.

    ``out_results[dataset][scene][image]`` is expected to hold ``'R'`` and
    ``'t'`` arrays. Images without a result get the identity rotation and a
    zero translation. The path of every image that has a result is printed.
    """
    header = 'image_path,dataset,scene,rotation_matrix,translation_vector\n'
    with open('submission.csv', 'w') as f:
        f.write(header)
        for dataset, scenes in data_dict.items():
            res = out_results[dataset] if dataset in out_results else {}
            for scene, scene_images in scenes.items():
                scene_res = res[scene] if scene in res else {"R":{}, "t":{}}
                for image in scene_images:
                    if image not in scene_res:
                        # No pose for this image: fall back to identity / zero.
                        R = np.eye(3).reshape(-1)
                        T = np.zeros((3))
                    else:
                        print (image)
                        pose = scene_res[image]
                        R = pose['R'].reshape(-1)
                        T = pose['t'].reshape(-1)
                    f.write(f'{image},{dataset},{scene},{arr_to_str(R)},{arr_to_str(T)}\n')

In [None]:
# Main pipeline: for every scene, shortlist image pairs, detect/match
# features, import them into a COLMAP database, run the incremental mapper and
# store the recovered poses in out_results. The submission file is rewritten
# after each scene.
gc.collect()
# Snapshot of the dataset names found in the submission CSV.
datasets = []
for dataset in data_dict:
    datasets.append(dataset)

for dataset in datasets:
    print(dataset)
    if dataset not in out_results:
        out_results[dataset] = {}
    for scene in data_dict[dataset]:
        print(scene)
        # Fail gently if the notebook has not been submitted and the test data is not populated.
        # You may want to run this on the training data in that case?
        img_dir = f'{img_base_dir}/{dataset}/{scene}/images'
        if not os.path.exists(img_dir):
            continue
        # Wrap the meaty part in a try-except block.
        try:
            out_results[dataset][scene] = {}
            img_fnames = [f'{img_base_dir}/{x}' for x in data_dict[dataset][scene]]
            print (f"Got {len(img_fnames)} images")
            # Per-scene directory for the HDF5 features and the COLMAP database.
            feature_dir = f'featureout/{dataset}_{scene}'
            if not os.path.isdir(feature_dir):
                os.makedirs(feature_dir, exist_ok=True)
            
            ## Pairs
            # Choose which image pairs to match (helper defined elsewhere in the notebook).
            t=time()
            index_pairs =get_image_pairs_shortlist_nn(img_fnames,
                                                    nneighbor=20,              
                                                    exhaustive_if_less = 20,
                                                    device=device)
            t=time() -t 
            timings['shortlisting'].append(t)
            print (f'{len(index_pairs)}, pairs to match, {t:.4f} sec')
            gc.collect()
            
            # Desc & Matching
            t=time()
            if LOCAL_FEATURE != 'LoFTR':
                # Detect-then-match path: separate detection and matching steps.
                detect_features(img_fnames, 
                                2048,
                                feature_dir=feature_dir,
                                upright=True,
                                device=device,
                                resize_small_edge_to=600
                               )
                gc.collect()
                t=time() -t 
                timings['feature_detection'].append(t)
                print(f'Features detected in  {t:.4f} sec')
                t=time()
                match_features(img_fnames, index_pairs, feature_dir=feature_dir,device=device, matching_alg='adalam')
            else:
                # Detector-free path: no 'feature_detection' timing is recorded here.
                match_loftr(img_fnames, index_pairs, feature_dir=feature_dir, device=device, resize_to_=(600, 800))
            t=time() -t 
            timings['feature_matching'].append(t)
            print(f'Features matched in  {t:.4f} sec')
            
            ## Colmap
            # Start from a fresh database: remove any file left by a previous run.
            database_path = f'{feature_dir}/colmap.db'
            if os.path.isfile(database_path):
                os.remove(database_path)
            gc.collect()
            import_into_colmap(img_dir, feature_dir=feature_dir,database_path=database_path)
            output_path = f'{feature_dir}/colmap_rec_{LOCAL_FEATURE}'

            # Time spent in pycolmap.match_exhaustive is recorded under 'RANSAC'.
            t=time()
            pycolmap.match_exhaustive(database_path)
            t=time() - t 
            timings['RANSAC'].append(t)
            print(f'RANSAC in  {t:.4f} sec')

            t=time()
            # By default colmap does not generate a reconstruction if less than 10 images are registered. Lower it to 3.
            mapper_options = pycolmap.IncrementalMapperOptions()
            mapper_options.min_model_size = 3
            os.makedirs(output_path, exist_ok=True)
            maps = pycolmap.incremental_mapping(database_path=database_path, image_path=img_dir, output_path=output_path, options=mapper_options)
            print(maps)
            #clear_output(wait=False)
            t=time() - t
            timings['Reconstruction'].append(t)
            print(f'Reconstruction done in  {t:.4f} sec')
            
            ## Export
            # Pass 1: store poses from EVERY reconstruction (a later one
            # overwrites an earlier one for the same image) while tracking the
            # reconstruction with the most registered images.
            # NOTE(review): poses taken from different reconstructions are
            # presumably not in a common coordinate frame - confirm this
            # fallback is intended.
            imgs_registered  = 0
            best_idx = None
            print ("Looking for the best reconstruction")
            if isinstance(maps, dict):
                for idx1, rec in maps.items():
                    print (idx1, rec.summary())
                    if len(rec.images) > imgs_registered:
                        imgs_registered = len(rec.images)
                        best_idx = idx1
                    for k, im in maps[idx1].images.items():
                        key1 = f'{dataset}/{scene}/images/{im.name}'
                        if key1 not in out_results[dataset][scene]:
                            out_results[dataset][scene][key1] = {}
                        out_results[dataset][scene][key1]["R"] = deepcopy(im.rotmat())
                        out_results[dataset][scene][key1]["t"] = deepcopy(np.array(im.tvec))
                    
            # Pass 2: the largest reconstruction is written last, so its poses
            # win for every image it contains.
            if best_idx is not None:
                print (maps[best_idx].summary())
                for k, im in maps[best_idx].images.items():
                    key1 = f'{dataset}/{scene}/images/{im.name}'
                    if key1 not in out_results[dataset][scene]:
                            out_results[dataset][scene][key1] = {}
                    out_results[dataset][scene][key1]["R"] = deepcopy(im.rotmat())
                    out_results[dataset][scene][key1]["t"] = deepcopy(np.array(im.tvec))
            print(f'Registered: {dataset} / {scene} -> {len(out_results[dataset][scene])} images')
            print(f'Total: {dataset} / {scene} -> {len(data_dict[dataset][scene])} images')
            # Rewrite the submission after every scene so partial results are on disk.
            create_submission(out_results, data_dict)
            gc.collect()
        except Exception as e:
            # Best effort: report the error and continue with the next scene.
            # Whatever was already stored for this scene stays in out_results.
            print("*"*20)
            print(e)

In [None]:
# Write submission.csv once more with everything gathered in out_results;
# images without a pose get the defaults applied by create_submission.
create_submission(out_results, data_dict)

In [None]:
# In debug mode, score the generated submission against the local validation
# CSV (relative paths, resolved against the current working directory).
# The returned metric is not stored; verbose=True prints the results.
if DEBUG_FLAG:
    eval_submission(submission_csv_path='submission.csv',
                ground_truth_csv_path='sample_submission.csv',
                rotation_thresholds_degrees_dict=rotation_thresholds_degrees_dict,
                translation_thresholds_meters_dict=translation_thresholds_meters_dict,
                verbose=True)